Invent random_normal() to provide normally-distributed random numbers.

There is already a version of this in contrib/tablefunc, but it
seems sufficiently widely useful to justify having it in core.

Paul Ramsey

Discussion: https://postgr.es/m/CACowWR0DqHAvOKUCNxTrASFkWsDLqKMd6WiXvVvaWg4pV1BMnQ@mail.gmail.com
This commit is contained in:
Tom Lane 2023-01-09 12:44:00 -05:00
parent 2673ebf49a
commit 38d81760c4
10 changed files with 164 additions and 29 deletions

View File

@ -1815,6 +1815,28 @@ repeat('Pg', 4) <returnvalue>PgPgPgPg</returnvalue>
</para></entry>
</row>
<row>
<entry role="func_table_entry"><para role="func_signature">
<indexterm>
<primary>random_normal</primary>
</indexterm>
<function>random_normal</function> (
<optional> <parameter>mean</parameter> <type>double precision</type>
<optional>, <parameter>stddev</parameter> <type>double precision</type> </optional></optional> )
<returnvalue>double precision</returnvalue>
</para>
<para>
Returns a random value from the normal distribution with the given
parameters; <parameter>mean</parameter> defaults to 0.0
and <parameter>stddev</parameter> defaults to 1.0
</para>
<para>
<literal>random_normal(0.0, 1.0)</literal>
<returnvalue>0.051285419</returnvalue>
</para></entry>
</row>
<row>
<entry role="func_table_entry"><para role="func_signature">
<indexterm>
@ -1824,7 +1846,8 @@ repeat('Pg', 4) <returnvalue>PgPgPgPg</returnvalue>
<returnvalue>void</returnvalue>
</para>
<para>
Sets the seed for subsequent <literal>random()</literal> calls;
Sets the seed for subsequent <literal>random()</literal> and
<literal>random_normal()</literal> calls;
argument must be between -1.0 and 1.0, inclusive
</para>
<para>
@ -1848,6 +1871,7 @@ repeat('Pg', 4) <returnvalue>PgPgPgPg</returnvalue>
Without any prior <function>setseed()</function> call in the same
session, the first <function>random()</function> call obtains a seed
from a platform-dependent source of random bits.
These remarks hold equally for <function>random_normal()</function>.
</para>
<para>

View File

@ -66,6 +66,13 @@ CREATE OR REPLACE FUNCTION bit_length(text)
IMMUTABLE PARALLEL SAFE STRICT COST 1
RETURN octet_length($1) * 8;
CREATE OR REPLACE FUNCTION
random_normal(mean float8 DEFAULT 0, stddev float8 DEFAULT 1)
RETURNS float8
LANGUAGE internal
VOLATILE PARALLEL RESTRICTED STRICT COST 1
AS 'drandom_normal';
CREATE OR REPLACE FUNCTION log(numeric)
RETURNS numeric
LANGUAGE sql

View File

@ -2743,13 +2743,11 @@ datanh(PG_FUNCTION_ARGS)
/*
* drandom - returns a random number
* initialize_drandom_seed - initialize drandom_seed if not yet done
*/
Datum
drandom(PG_FUNCTION_ARGS)
static void
initialize_drandom_seed(void)
{
float8 result;
/* Initialize random seed, if not done yet in this process */
if (unlikely(!drandom_seed_set))
{
@ -2769,6 +2767,17 @@ drandom(PG_FUNCTION_ARGS)
}
drandom_seed_set = true;
}
}
/*
* drandom - returns a random number
*/
Datum
drandom(PG_FUNCTION_ARGS)
{
float8 result;
initialize_drandom_seed();
/* pg_prng_double produces desired result range [0.0 - 1.0) */
result = pg_prng_double(&drandom_seed);
@ -2776,6 +2785,27 @@ drandom(PG_FUNCTION_ARGS)
PG_RETURN_FLOAT8(result);
}
/*
* drandom_normal - returns a random number from a normal distribution
*/
Datum
drandom_normal(PG_FUNCTION_ARGS)
{
float8 mean = PG_GETARG_FLOAT8(0);
float8 stddev = PG_GETARG_FLOAT8(1);
float8 result,
z;
initialize_drandom_seed();
/* Get random value from standard normal(mean = 0.0, stddev = 1.0) */
z = pg_prng_double_normal(&drandom_seed);
/* Transform the normal standard variable (z) */
/* using the target normal distribution parameters */
result = (stddev * z) + mean;
PG_RETURN_FLOAT8(result);
}
/*
* setseed - set seed for the random number generator

View File

@ -1136,8 +1136,8 @@ getGaussianRand(pg_prng_state *state, int64 min, int64 max,
Assert(parameter >= MIN_GAUSSIAN_PARAM);
/*
* Get user specified random number from this loop, with -parameter <
* stdev <= parameter
* Get normally-distributed random number in the range -parameter <= stdev
* < parameter.
*
* This loop is executed until the number is in the expected range.
*
@ -1149,25 +1149,7 @@ getGaussianRand(pg_prng_state *state, int64 min, int64 max,
*/
do
{
/*
* pg_prng_double generates [0, 1), but for the basic version of the
* Box-Muller transform the two uniformly distributed random numbers
* are expected to be in (0, 1] (see
* https://en.wikipedia.org/wiki/Box-Muller_transform)
*/
double rand1 = 1.0 - pg_prng_double(state);
double rand2 = 1.0 - pg_prng_double(state);
/* Box-Muller basic form transform */
double var_sqrt = sqrt(-2.0 * log(rand1));
stdev = var_sqrt * sin(2.0 * M_PI * rand2);
/*
* we may try with cos, but there may be a bias induced if the
* previous value fails the test. To be on the safe side, let us try
* over.
*/
stdev = pg_prng_double_normal(state);
}
while (stdev < -parameter || stdev >= parameter);

View File

@ -19,11 +19,17 @@
#include "c.h"
#include <math.h> /* for ldexp() */
#include <math.h>
#include "common/pg_prng.h"
#include "port/pg_bitutils.h"
/* X/Open (XSI) requires <math.h> to provide M_PI, but core POSIX does not */
#ifndef M_PI
#define M_PI 3.14159265358979323846
#endif
/* process-wide state vector */
pg_prng_state pg_global_prng_state;
@ -235,6 +241,35 @@ pg_prng_double(pg_prng_state *state)
return ldexp((double) (v >> (64 - 52)), -52);
}
/*
* Select a random double from the normal distribution with
* mean = 0.0 and stddev = 1.0.
*
* To get a result from a different normal distribution use
* STDDEV * pg_prng_double_normal() + MEAN
*
* Uses https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform
*/
double
pg_prng_double_normal(pg_prng_state *state)
{
double u1,
u2,
z0;
/*
* pg_prng_double generates [0, 1), but for the basic version of the
* Box-Muller transform the two uniformly distributed random numbers are
* expected to be in (0, 1]; in particular we'd better not compute log(0).
*/
u1 = 1.0 - pg_prng_double(state);
u2 = 1.0 - pg_prng_double(state);
/* Apply Box-Muller transform to get one normal-valued output */
z0 = sqrt(-2.0 * log(u1)) * sin(2.0 * M_PI * u2);
return z0;
}
/*
* Select a random boolean value.
*/

View File

@ -57,6 +57,6 @@
*/
/* yyyymmddN */
#define CATALOG_VERSION_NO 202301091
#define CATALOG_VERSION_NO 202301092
#endif

View File

@ -3359,6 +3359,10 @@
{ oid => '1598', descr => 'random value',
proname => 'random', provolatile => 'v', proparallel => 'r',
prorettype => 'float8', proargtypes => '', prosrc => 'drandom' },
{ oid => '8074', descr => 'random value from normal distribution',
proname => 'random_normal', provolatile => 'v', proparallel => 'r',
prorettype => 'float8', proargtypes => 'float8 float8',
prosrc => 'drandom_normal' },
{ oid => '1599', descr => 'set random seed',
proname => 'setseed', provolatile => 'v', proparallel => 'r',
prorettype => 'void', proargtypes => 'float8', prosrc => 'setseed' },

View File

@ -55,6 +55,7 @@ extern uint32 pg_prng_uint32(pg_prng_state *state);
extern int32 pg_prng_int32(pg_prng_state *state);
extern int32 pg_prng_int32p(pg_prng_state *state);
extern double pg_prng_double(pg_prng_state *state);
extern double pg_prng_double_normal(pg_prng_state *state);
extern bool pg_prng_bool(pg_prng_state *state);
#endif /* PG_PRNG_H */

View File

@ -51,3 +51,31 @@ SELECT AVG(random) FROM RANDOM_TBL
-----
(0 rows)
-- now test random_normal()
TRUNCATE random_tbl;
INSERT INTO random_tbl (random)
SELECT count(*)
FROM onek WHERE random_normal(0, 1) < 0;
INSERT INTO random_tbl (random)
SELECT count(*)
FROM onek WHERE random_normal(0) < 0;
INSERT INTO random_tbl (random)
SELECT count(*)
FROM onek WHERE random_normal() < 0;
INSERT INTO random_tbl (random)
SELECT count(*)
FROM onek WHERE random_normal(stddev => 1, mean => 0) < 0;
-- expect similar, but not identical values
SELECT random, count(random) FROM random_tbl
GROUP BY random HAVING count(random) > 3;
random | count
--------+-------
(0 rows)
-- approximately check expected distribution
SELECT AVG(random) FROM random_tbl
HAVING AVG(random) NOT BETWEEN 400 AND 600;
avg
-----
(0 rows)

View File

@ -42,3 +42,27 @@ SELECT random, count(random) FROM RANDOM_TBL
SELECT AVG(random) FROM RANDOM_TBL
HAVING AVG(random) NOT BETWEEN 80 AND 120;
-- now test random_normal()
TRUNCATE random_tbl;
INSERT INTO random_tbl (random)
SELECT count(*)
FROM onek WHERE random_normal(0, 1) < 0;
INSERT INTO random_tbl (random)
SELECT count(*)
FROM onek WHERE random_normal(0) < 0;
INSERT INTO random_tbl (random)
SELECT count(*)
FROM onek WHERE random_normal() < 0;
INSERT INTO random_tbl (random)
SELECT count(*)
FROM onek WHERE random_normal(stddev => 1, mean => 0) < 0;
-- expect similar, but not identical values
SELECT random, count(random) FROM random_tbl
GROUP BY random HAVING count(random) > 3;
-- approximately check expected distribution
SELECT AVG(random) FROM random_tbl
HAVING AVG(random) NOT BETWEEN 400 AND 600;