From 38d81760c4d7e22b95252e3545596602c9e38806 Mon Sep 17 00:00:00 2001 From: Tom Lane Date: Mon, 9 Jan 2023 12:44:00 -0500 Subject: [PATCH] Invent random_normal() to provide normally-distributed random numbers. There is already a version of this in contrib/tablefunc, but it seems sufficiently widely useful to justify having it in core. Paul Ramsey Discussion: https://postgr.es/m/CACowWR0DqHAvOKUCNxTrASFkWsDLqKMd6WiXvVvaWg4pV1BMnQ@mail.gmail.com --- doc/src/sgml/func.sgml | 26 ++++++++++++++- src/backend/catalog/system_functions.sql | 7 +++++ src/backend/utils/adt/float.c | 40 +++++++++++++++++++++--- src/bin/pgbench/pgbench.c | 24 ++------------ src/common/pg_prng.c | 37 +++++++++++++++++++++- src/include/catalog/catversion.h | 2 +- src/include/catalog/pg_proc.dat | 4 +++ src/include/common/pg_prng.h | 1 + src/test/regress/expected/random.out | 28 +++++++++++++++++ src/test/regress/sql/random.sql | 24 ++++++++++++++ 10 files changed, 164 insertions(+), 29 deletions(-) diff --git a/doc/src/sgml/func.sgml b/doc/src/sgml/func.sgml index 3bf8d021c3..b67dc26a35 100644 --- a/doc/src/sgml/func.sgml +++ b/doc/src/sgml/func.sgml @@ -1815,6 +1815,28 @@ repeat('Pg', 4) PgPgPgPg + + + + random_normal + + + random_normal ( + mean double precision + , stddev double precision ) + double precision + + + Returns a random value from the normal distribution with the given + parameters; mean defaults to 0.0 + and stddev defaults to 1.0 + + + random_normal(0.0, 1.0) + 0.051285419 + + + @@ -1824,7 +1846,8 @@ repeat('Pg', 4) PgPgPgPg void - Sets the seed for subsequent random() calls; + Sets the seed for subsequent random() and + random_normal() calls; argument must be between -1.0 and 1.0, inclusive @@ -1848,6 +1871,7 @@ repeat('Pg', 4) PgPgPgPg Without any prior setseed() call in the same session, the first random() call obtains a seed from a platform-dependent source of random bits. + These remarks hold equally for random_normal(). 
diff --git a/src/backend/catalog/system_functions.sql b/src/backend/catalog/system_functions.sql index f2470708e9..83ca893444 100644 --- a/src/backend/catalog/system_functions.sql +++ b/src/backend/catalog/system_functions.sql @@ -66,6 +66,13 @@ CREATE OR REPLACE FUNCTION bit_length(text) IMMUTABLE PARALLEL SAFE STRICT COST 1 RETURN octet_length($1) * 8; +CREATE OR REPLACE FUNCTION + random_normal(mean float8 DEFAULT 0, stddev float8 DEFAULT 1) + RETURNS float8 + LANGUAGE internal + VOLATILE PARALLEL RESTRICTED STRICT COST 1 +AS 'drandom_normal'; + CREATE OR REPLACE FUNCTION log(numeric) RETURNS numeric LANGUAGE sql diff --git a/src/backend/utils/adt/float.c b/src/backend/utils/adt/float.c index 56e349b888..d290b4ca67 100644 --- a/src/backend/utils/adt/float.c +++ b/src/backend/utils/adt/float.c @@ -2743,13 +2743,11 @@ datanh(PG_FUNCTION_ARGS) /* - * drandom - returns a random number + * initialize_drandom_seed - initialize drandom_seed if not yet done */ -Datum -drandom(PG_FUNCTION_ARGS) +static void +initialize_drandom_seed(void) { - float8 result; - /* Initialize random seed, if not done yet in this process */ if (unlikely(!drandom_seed_set)) { @@ -2769,6 +2767,17 @@ drandom(PG_FUNCTION_ARGS) } drandom_seed_set = true; } +} + +/* + * drandom - returns a random number + */ +Datum +drandom(PG_FUNCTION_ARGS) +{ + float8 result; + + initialize_drandom_seed(); /* pg_prng_double produces desired result range [0.0 - 1.0) */ result = pg_prng_double(&drandom_seed); @@ -2776,6 +2785,27 @@ drandom(PG_FUNCTION_ARGS) PG_RETURN_FLOAT8(result); } +/* + * drandom_normal - returns a random number from a normal distribution + */ +Datum +drandom_normal(PG_FUNCTION_ARGS) +{ + float8 mean = PG_GETARG_FLOAT8(0); + float8 stddev = PG_GETARG_FLOAT8(1); + float8 result, + z; + + initialize_drandom_seed(); + + /* Get random value from standard normal(mean = 0.0, stddev = 1.0) */ + z = pg_prng_double_normal(&drandom_seed); + /* Transform the normal standard variable (z) */ + /* using the 
target normal distribution parameters */ + result = (stddev * z) + mean; + + PG_RETURN_FLOAT8(result); +} /* * setseed - set seed for the random number generator diff --git a/src/bin/pgbench/pgbench.c b/src/bin/pgbench/pgbench.c index 18d9c94ebd..9c12ffaea9 100644 --- a/src/bin/pgbench/pgbench.c +++ b/src/bin/pgbench/pgbench.c @@ -1136,8 +1136,8 @@ getGaussianRand(pg_prng_state *state, int64 min, int64 max, Assert(parameter >= MIN_GAUSSIAN_PARAM); /* - * Get user specified random number from this loop, with -parameter < - * stdev <= parameter + * Get normally-distributed random number in the range -parameter <= stdev + * < parameter. * * This loop is executed until the number is in the expected range. * @@ -1149,25 +1149,7 @@ getGaussianRand(pg_prng_state *state, int64 min, int64 max, */ do { - /* - * pg_prng_double generates [0, 1), but for the basic version of the - * Box-Muller transform the two uniformly distributed random numbers - * are expected to be in (0, 1] (see - * https://en.wikipedia.org/wiki/Box-Muller_transform) - */ - double rand1 = 1.0 - pg_prng_double(state); - double rand2 = 1.0 - pg_prng_double(state); - - /* Box-Muller basic form transform */ - double var_sqrt = sqrt(-2.0 * log(rand1)); - - stdev = var_sqrt * sin(2.0 * M_PI * rand2); - - /* - * we may try with cos, but there may be a bias induced if the - * previous value fails the test. To be on the safe side, let us try - * over. 
- */ + stdev = pg_prng_double_normal(state); } while (stdev < -parameter || stdev >= parameter); diff --git a/src/common/pg_prng.c b/src/common/pg_prng.c index e58b471cff..c7bb92ede3 100644 --- a/src/common/pg_prng.c +++ b/src/common/pg_prng.c @@ -19,11 +19,17 @@ #include "c.h" -#include <math.h> /* for ldexp() */ +#include <math.h> #include "common/pg_prng.h" #include "port/pg_bitutils.h" +/* X/Open (XSI) requires to provide M_PI, but core POSIX does not */ +#ifndef M_PI +#define M_PI 3.14159265358979323846 +#endif + + /* process-wide state vector */ pg_prng_state pg_global_prng_state; @@ -235,6 +241,35 @@ pg_prng_double(pg_prng_state *state) return ldexp((double) (v >> (64 - 52)), -52); } +/* + * Select a random double from the normal distribution with + * mean = 0.0 and stddev = 1.0. + * + * To get a result from a different normal distribution use + * STDDEV * pg_prng_double_normal() + MEAN + * + * Uses https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform + */ +double +pg_prng_double_normal(pg_prng_state *state) +{ + double u1, + u2, + z0; + + /* + * pg_prng_double generates [0, 1), but for the basic version of the + * Box-Muller transform the two uniformly distributed random numbers are + * expected to be in (0, 1]; in particular we'd better not compute log(0). + */ + u1 = 1.0 - pg_prng_double(state); + u2 = 1.0 - pg_prng_double(state); + + /* Apply Box-Muller transform to get one normal-valued output */ + z0 = sqrt(-2.0 * log(u1)) * sin(2.0 * M_PI * u2); + return z0; +} + /* * Select a random boolean value. 
*/ diff --git a/src/include/catalog/catversion.h b/src/include/catalog/catversion.h index e63ddeeb07..3a0ef3d874 100644 --- a/src/include/catalog/catversion.h +++ b/src/include/catalog/catversion.h @@ -57,6 +57,6 @@ */ /* yyyymmddN */ -#define CATALOG_VERSION_NO 202301091 +#define CATALOG_VERSION_NO 202301092 #endif diff --git a/src/include/catalog/pg_proc.dat b/src/include/catalog/pg_proc.dat index 7be9a50147..3810de7b22 100644 --- a/src/include/catalog/pg_proc.dat +++ b/src/include/catalog/pg_proc.dat @@ -3359,6 +3359,10 @@ { oid => '1598', descr => 'random value', proname => 'random', provolatile => 'v', proparallel => 'r', prorettype => 'float8', proargtypes => '', prosrc => 'drandom' }, +{ oid => '8074', descr => 'random value from normal distribution', + proname => 'random_normal', provolatile => 'v', proparallel => 'r', + prorettype => 'float8', proargtypes => 'float8 float8', + prosrc => 'drandom_normal' }, { oid => '1599', descr => 'set random seed', proname => 'setseed', provolatile => 'v', proparallel => 'r', prorettype => 'void', proargtypes => 'float8', prosrc => 'setseed' }, diff --git a/src/include/common/pg_prng.h b/src/include/common/pg_prng.h index 9e11e8fffd..b5c0b8d288 100644 --- a/src/include/common/pg_prng.h +++ b/src/include/common/pg_prng.h @@ -55,6 +55,7 @@ extern uint32 pg_prng_uint32(pg_prng_state *state); extern int32 pg_prng_int32(pg_prng_state *state); extern int32 pg_prng_int32p(pg_prng_state *state); extern double pg_prng_double(pg_prng_state *state); +extern double pg_prng_double_normal(pg_prng_state *state); extern bool pg_prng_bool(pg_prng_state *state); #endif /* PG_PRNG_H */ diff --git a/src/test/regress/expected/random.out b/src/test/regress/expected/random.out index a919b28d8d..30bd866138 100644 --- a/src/test/regress/expected/random.out +++ b/src/test/regress/expected/random.out @@ -51,3 +51,31 @@ SELECT AVG(random) FROM RANDOM_TBL ----- (0 rows) +-- now test random_normal() +TRUNCATE random_tbl; +INSERT INTO random_tbl 
(random) + SELECT count(*) + FROM onek WHERE random_normal(0, 1) < 0; +INSERT INTO random_tbl (random) + SELECT count(*) + FROM onek WHERE random_normal(0) < 0; +INSERT INTO random_tbl (random) + SELECT count(*) + FROM onek WHERE random_normal() < 0; +INSERT INTO random_tbl (random) + SELECT count(*) + FROM onek WHERE random_normal(stddev => 1, mean => 0) < 0; +-- expect similar, but not identical values +SELECT random, count(random) FROM random_tbl + GROUP BY random HAVING count(random) > 3; + random | count +--------+------- +(0 rows) + +-- approximately check expected distribution +SELECT AVG(random) FROM random_tbl + HAVING AVG(random) NOT BETWEEN 400 AND 600; + avg +----- +(0 rows) + diff --git a/src/test/regress/sql/random.sql b/src/test/regress/sql/random.sql index 8187b2c288..3104af46b7 100644 --- a/src/test/regress/sql/random.sql +++ b/src/test/regress/sql/random.sql @@ -42,3 +42,27 @@ SELECT random, count(random) FROM RANDOM_TBL SELECT AVG(random) FROM RANDOM_TBL HAVING AVG(random) NOT BETWEEN 80 AND 120; + +-- now test random_normal() + +TRUNCATE random_tbl; +INSERT INTO random_tbl (random) + SELECT count(*) + FROM onek WHERE random_normal(0, 1) < 0; +INSERT INTO random_tbl (random) + SELECT count(*) + FROM onek WHERE random_normal(0) < 0; +INSERT INTO random_tbl (random) + SELECT count(*) + FROM onek WHERE random_normal() < 0; +INSERT INTO random_tbl (random) + SELECT count(*) + FROM onek WHERE random_normal(stddev => 1, mean => 0) < 0; + +-- expect similar, but not identical values +SELECT random, count(random) FROM random_tbl + GROUP BY random HAVING count(random) > 3; + +-- approximately check expected distribution +SELECT AVG(random) FROM random_tbl + HAVING AVG(random) NOT BETWEEN 400 AND 600;