Re: [PATCH] random_normal function

From: Tom Lane
Subject: Re: [PATCH] random_normal function
Msg-id: 3983784.1673223640@sss.pgh.pa.us
In reply to: Re: [PATCH] random_normal function  (Tom Lane <tgl@sss.pgh.pa.us>)
Replies: Re: [PATCH] random_normal function  (Tom Lane <tgl@sss.pgh.pa.us>)
         Re: [PATCH] random_normal function  (Dean Rasheed <dean.a.rasheed@gmail.com>)
List: pgsql-hackers
I wrote:
> So the problem in this patch is that it's trying to include
> utils/float.h in a src/common file, where we have not included
> postgres.h.  Question is, why did you do that?

(Ah, for M_PI ... but our practice is just to duplicate that #define
where needed outside the backend.)

I spent some time reviewing this patch.  I'm on board with
inventing random_normal(): the definition seems solid and
the use-case for it seems reasonably well established.
I'm not necessarily against inventing similar functions for
other distributions, but this patch is not required to do so.
We can leave that discussion until somebody is motivated to
submit a patch for one.

On the other hand, I'm much less on board with inventing
random_string(): we don't have any clear field demand for it
and the appropriate definitional details are a lot less obvious
(for example, whether it needs to be based on pg_strong_random()
rather than the random() sequence).  I think we should leave that
out, and I have done so in the attached updated patch.

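I noted several errors in the submitted patch.  It was creating
the function as PARALLEL SAFE, which is just wrong: like random(),
it depends on the session's random-seed state, so it has to be
PARALLEL RESTRICTED.  And the whole business with checking PG_NARGS
is useless because it will always be 2.  (That's not how default
arguments work: the defaults are filled into the call before the
function is invoked, so the C function always receives both
arguments.)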

The business with checking against DBL_EPSILON seems wrong too.
All we need is to ensure that u1 > 0 so that log(u1) will not
choke; per spec, log() is defined for any positive input.  That
check seems to have been modeled on the C++ code in the Wikipedia
page on the Box-Muller transform, but I'm not sure that C++'s
epsilon means the same thing, and if it does then their example
code is just wrong.  See the discussion about "tails truncation"
immediately above it: artificially constraining the range of u1
just limits how much of the tail of the distribution we can
reproduce.  That led me to do it the same way as the existing
Box-Muller code in pgbench, which I then replaced with a call to
the new common function, per Fabien's advice.

BTW, the pgbench code was using sin() rather than cos(), and I kept
that because switching to cos() changes the expected output of the
pgbench tests.  I'm not sure whether there was any hard reason to
prefer one or the other, and we can certainly change the expected
output if there's some reason to prefer cos().

I concur with not worrying about the Inf/NaN cases that Mark
pointed out.  It's not obvious that the results the proposed code
produces are wrong, and it's even less obvious that anyone will
ever care.

Also, I tried running the new random.sql regression cases over
and over, and found that the "not all duplicates" test fails about
one time in 100000 or so.  We could probably tolerate that given
that the random test is marked "ignore" in parallel_schedule, but
I thought it best to add one more iteration to knock the odds
down.  I also varied the iterations so they don't all invoke
random_normal() in exactly the same way.

This version seems committable to me, barring objections.

            regards, tom lane

diff --git a/doc/src/sgml/func.sgml b/doc/src/sgml/func.sgml
index 3bf8d021c3..b67dc26a35 100644
--- a/doc/src/sgml/func.sgml
+++ b/doc/src/sgml/func.sgml
@@ -1815,6 +1815,28 @@ repeat('Pg', 4) <returnvalue>PgPgPgPg</returnvalue>
        </para></entry>
       </row>

+      <row>
+       <entry role="func_table_entry"><para role="func_signature">
+        <indexterm>
+         <primary>random_normal</primary>
+        </indexterm>
+
+         <function>random_normal</function> (
+         <optional> <parameter>mean</parameter> <type>double precision</type>
+         <optional>, <parameter>stddev</parameter> <type>double precision</type> </optional></optional> )
+         <returnvalue>double precision</returnvalue>
+       </para>
+       <para>
+        Returns a random value from the normal distribution with the given
+        parameters; <parameter>mean</parameter> defaults to 0.0
+        and <parameter>stddev</parameter> defaults to 1.0
+       </para>
+       <para>
+        <literal>random_normal(0.0, 1.0)</literal>
+        <returnvalue>0.051285419</returnvalue>
+       </para></entry>
+      </row>
+
       <row>
        <entry role="func_table_entry"><para role="func_signature">
         <indexterm>
@@ -1824,7 +1846,8 @@ repeat('Pg', 4) <returnvalue>PgPgPgPg</returnvalue>
         <returnvalue>void</returnvalue>
        </para>
        <para>
-        Sets the seed for subsequent <literal>random()</literal> calls;
+        Sets the seed for subsequent <literal>random()</literal> and
+        <literal>random_normal()</literal> calls;
         argument must be between -1.0 and 1.0, inclusive
        </para>
        <para>
@@ -1848,6 +1871,7 @@ repeat('Pg', 4) <returnvalue>PgPgPgPg</returnvalue>
    Without any prior <function>setseed()</function> call in the same
    session, the first <function>random()</function> call obtains a seed
    from a platform-dependent source of random bits.
+   These remarks hold equally for <function>random_normal()</function>.
   </para>

   <para>
diff --git a/src/backend/catalog/system_functions.sql b/src/backend/catalog/system_functions.sql
index f2470708e9..83ca893444 100644
--- a/src/backend/catalog/system_functions.sql
+++ b/src/backend/catalog/system_functions.sql
@@ -66,6 +66,13 @@ CREATE OR REPLACE FUNCTION bit_length(text)
  IMMUTABLE PARALLEL SAFE STRICT COST 1
 RETURN octet_length($1) * 8;

+CREATE OR REPLACE FUNCTION
+ random_normal(mean float8 DEFAULT 0, stddev float8 DEFAULT 1)
+ RETURNS float8
+ LANGUAGE internal
+ VOLATILE PARALLEL RESTRICTED STRICT COST 1
+AS 'drandom_normal';
+
 CREATE OR REPLACE FUNCTION log(numeric)
  RETURNS numeric
  LANGUAGE sql
diff --git a/src/backend/utils/adt/float.c b/src/backend/utils/adt/float.c
index 56e349b888..d290b4ca67 100644
--- a/src/backend/utils/adt/float.c
+++ b/src/backend/utils/adt/float.c
@@ -2743,13 +2743,11 @@ datanh(PG_FUNCTION_ARGS)


 /*
- *        drandom        - returns a random number
+ * initialize_drandom_seed - initialize drandom_seed if not yet done
  */
-Datum
-drandom(PG_FUNCTION_ARGS)
+static void
+initialize_drandom_seed(void)
 {
-    float8        result;
-
     /* Initialize random seed, if not done yet in this process */
     if (unlikely(!drandom_seed_set))
     {
@@ -2769,6 +2767,17 @@ drandom(PG_FUNCTION_ARGS)
         }
         drandom_seed_set = true;
     }
+}
+
+/*
+ *        drandom        - returns a random number
+ */
+Datum
+drandom(PG_FUNCTION_ARGS)
+{
+    float8        result;
+
+    initialize_drandom_seed();

     /* pg_prng_double produces desired result range [0.0 - 1.0) */
     result = pg_prng_double(&drandom_seed);
@@ -2776,6 +2785,27 @@ drandom(PG_FUNCTION_ARGS)
     PG_RETURN_FLOAT8(result);
 }

+/*
+ *        drandom_normal    - returns a random number from a normal distribution
+ */
+Datum
+drandom_normal(PG_FUNCTION_ARGS)
+{
+    float8        mean = PG_GETARG_FLOAT8(0);
+    float8        stddev = PG_GETARG_FLOAT8(1);
+    float8        result,
+                z;
+
+    initialize_drandom_seed();
+
+    /* Get random value from standard normal(mean = 0.0, stddev = 1.0) */
+    z = pg_prng_double_normal(&drandom_seed);
+    /* Transform the standard normal variable (z) */
+    /* using the target normal distribution parameters */
+    result = (stddev * z) + mean;
+
+    PG_RETURN_FLOAT8(result);
+}

 /*
  *        setseed        - set seed for the random number generator
diff --git a/src/bin/pgbench/pgbench.c b/src/bin/pgbench/pgbench.c
index 18d9c94ebd..9c12ffaea9 100644
--- a/src/bin/pgbench/pgbench.c
+++ b/src/bin/pgbench/pgbench.c
@@ -1136,8 +1136,8 @@ getGaussianRand(pg_prng_state *state, int64 min, int64 max,
     Assert(parameter >= MIN_GAUSSIAN_PARAM);

     /*
-     * Get user specified random number from this loop, with -parameter <
-     * stdev <= parameter
+     * Get normally-distributed random number in the range -parameter <= stdev
+     * < parameter.
      *
      * This loop is executed until the number is in the expected range.
      *
@@ -1149,25 +1149,7 @@ getGaussianRand(pg_prng_state *state, int64 min, int64 max,
      */
     do
     {
-        /*
-         * pg_prng_double generates [0, 1), but for the basic version of the
-         * Box-Muller transform the two uniformly distributed random numbers
-         * are expected to be in (0, 1] (see
-         * https://en.wikipedia.org/wiki/Box-Muller_transform)
-         */
-        double        rand1 = 1.0 - pg_prng_double(state);
-        double        rand2 = 1.0 - pg_prng_double(state);
-
-        /* Box-Muller basic form transform */
-        double        var_sqrt = sqrt(-2.0 * log(rand1));
-
-        stdev = var_sqrt * sin(2.0 * M_PI * rand2);
-
-        /*
-         * we may try with cos, but there may be a bias induced if the
-         * previous value fails the test. To be on the safe side, let us try
-         * over.
-         */
+        stdev = pg_prng_double_normal(state);
     }
     while (stdev < -parameter || stdev >= parameter);

diff --git a/src/common/pg_prng.c b/src/common/pg_prng.c
index e58b471cff..6e07d1c810 100644
--- a/src/common/pg_prng.c
+++ b/src/common/pg_prng.c
@@ -19,11 +19,17 @@

 #include "c.h"

-#include <math.h>                /* for ldexp() */
+#include <math.h>

 #include "common/pg_prng.h"
 #include "port/pg_bitutils.h"

+/* X/Open (XSI) requires <math.h> to provide M_PI, but core POSIX does not */
+#ifndef M_PI
+#define M_PI 3.14159265358979323846
+#endif
+
+
 /* process-wide state vector */
 pg_prng_state pg_global_prng_state;

@@ -235,6 +241,35 @@ pg_prng_double(pg_prng_state *state)
     return ldexp((double) (v >> (64 - 52)), -52);
 }

+/*
+ * Select a random double from the normal distribution with
+ * mean = 0.0 and stddev = 1.0.
+ *
+ * To get a result from a different normal distribution use
+ *   STDDEV * pg_prng_double_normal() + MEAN
+ *
+ * Uses https://en.wikipedia.org/wiki/Box–Muller_transform
+ */
+double
+pg_prng_double_normal(pg_prng_state *state)
+{
+    double        u1,
+                u2,
+                z0;
+
+    /*
+     * pg_prng_double generates [0, 1), but for the basic version of the
+     * Box-Muller transform the two uniformly distributed random numbers are
+     * expected to be in (0, 1]; in particular we'd better not compute log(0).
+     */
+    u1 = 1.0 - pg_prng_double(state);
+    u2 = 1.0 - pg_prng_double(state);
+
+    /* Apply Box-Muller transform to get one normal-valued output */
+    z0 = sqrt(-2.0 * log(u1)) * sin(2.0 * M_PI * u2);
+    return z0;
+}
+
 /*
  * Select a random boolean value.
  */
diff --git a/src/include/catalog/pg_proc.dat b/src/include/catalog/pg_proc.dat
index 7be9a50147..3810de7b22 100644
--- a/src/include/catalog/pg_proc.dat
+++ b/src/include/catalog/pg_proc.dat
@@ -3359,6 +3359,10 @@
 { oid => '1598', descr => 'random value',
   proname => 'random', provolatile => 'v', proparallel => 'r',
   prorettype => 'float8', proargtypes => '', prosrc => 'drandom' },
+{ oid => '8074', descr => 'random value from normal distribution',
+  proname => 'random_normal', provolatile => 'v', proparallel => 'r',
+  prorettype => 'float8', proargtypes => 'float8 float8',
+  prosrc => 'drandom_normal' },
 { oid => '1599', descr => 'set random seed',
   proname => 'setseed', provolatile => 'v', proparallel => 'r',
   prorettype => 'void', proargtypes => 'float8', prosrc => 'setseed' },
diff --git a/src/include/common/pg_prng.h b/src/include/common/pg_prng.h
index 9e11e8fffd..b5c0b8d288 100644
--- a/src/include/common/pg_prng.h
+++ b/src/include/common/pg_prng.h
@@ -55,6 +55,7 @@ extern uint32 pg_prng_uint32(pg_prng_state *state);
 extern int32 pg_prng_int32(pg_prng_state *state);
 extern int32 pg_prng_int32p(pg_prng_state *state);
 extern double pg_prng_double(pg_prng_state *state);
+extern double pg_prng_double_normal(pg_prng_state *state);
 extern bool pg_prng_bool(pg_prng_state *state);

 #endif                            /* PG_PRNG_H */
diff --git a/src/test/regress/expected/random.out b/src/test/regress/expected/random.out
index a919b28d8d..547b9c9b2b 100644
--- a/src/test/regress/expected/random.out
+++ b/src/test/regress/expected/random.out
@@ -51,3 +51,34 @@ SELECT AVG(random) FROM RANDOM_TBL
 -----
 (0 rows)

+-- now test random_normal()
+TRUNCATE random_tbl;
+INSERT INTO random_tbl (random)
+  SELECT count(*)
+  FROM onek WHERE random_normal(0, 1) < 0;
+INSERT INTO random_tbl (random)
+  SELECT count(*)
+  FROM onek WHERE random_normal(0) < 0;
+INSERT INTO random_tbl (random)
+  SELECT count(*)
+  FROM onek WHERE random_normal() < 0;
+INSERT INTO random_tbl (random)
+  SELECT count(*)
+  FROM onek WHERE random_normal(0, 10) < 0;
+INSERT INTO random_tbl (random)
+  SELECT count(*)
+  FROM onek WHERE random_normal(stddev => 1, mean => 0) < 0;
+-- expect similar, but not identical values
+SELECT random, count(random) FROM random_tbl
+  GROUP BY random HAVING count(random) > 4;
+ random | count
+--------+-------
+(0 rows)
+
+-- approximately check expected distribution
+SELECT AVG(random) FROM random_tbl
+  HAVING AVG(random) NOT BETWEEN 400 AND 600;
+ avg
+-----
+(0 rows)
+
diff --git a/src/test/regress/sql/random.sql b/src/test/regress/sql/random.sql
index 8187b2c288..56eb9b045c 100644
--- a/src/test/regress/sql/random.sql
+++ b/src/test/regress/sql/random.sql
@@ -42,3 +42,30 @@ SELECT random, count(random) FROM RANDOM_TBL

 SELECT AVG(random) FROM RANDOM_TBL
   HAVING AVG(random) NOT BETWEEN 80 AND 120;
+
+-- now test random_normal()
+
+TRUNCATE random_tbl;
+INSERT INTO random_tbl (random)
+  SELECT count(*)
+  FROM onek WHERE random_normal(0, 1) < 0;
+INSERT INTO random_tbl (random)
+  SELECT count(*)
+  FROM onek WHERE random_normal(0) < 0;
+INSERT INTO random_tbl (random)
+  SELECT count(*)
+  FROM onek WHERE random_normal() < 0;
+INSERT INTO random_tbl (random)
+  SELECT count(*)
+  FROM onek WHERE random_normal(0, 10) < 0;
+INSERT INTO random_tbl (random)
+  SELECT count(*)
+  FROM onek WHERE random_normal(stddev => 1, mean => 0) < 0;
+
+-- expect similar, but not identical values
+SELECT random, count(random) FROM random_tbl
+  GROUP BY random HAVING count(random) > 4;
+
+-- approximately check expected distribution
+SELECT AVG(random) FROM random_tbl
+  HAVING AVG(random) NOT BETWEEN 400 AND 600;
