Mercurial > hg > octave-lyh
changeset 7533:ff52243af934
save state separately for each MT random number generator
author | John W. Eaton <jwe@octave.org> |
---|---|
date | Tue, 26 Feb 2008 05:28:59 -0500 |
parents | 493bb0de3199 |
children | ef755c763b62 |
files | liboctave/ChangeLog liboctave/oct-rand.cc liboctave/oct-rand.h liboctave/randmtzig.c src/ChangeLog src/DLD-FUNCTIONS/rand.cc |
diffstat | 6 files changed, 217 insertions(+), 72 deletions(-) [+] |
line wrap: on
line diff
--- a/liboctave/ChangeLog +++ b/liboctave/ChangeLog @@ -1,5 +1,24 @@ 2008-02-26 John W. Eaton <jwe@octave.org> + * oct-rand.cc (rand_states): New static variable. + (initialize_rand_states, get_dist_id, get_internal_state, + set_internal_state, switch_to_generator, save_state): New functions. + (octave_rand::state): New arg to specify distribution. + Save state in rand_states instead of setting internal state. + Return named state. Use set_internal_state to generate proper + state vector from user supplied state. Save and restore current + state if specified and current distributions are different. + (octave_rand::distribution (void)): Use switch rather than if/else. + (octave_rand::distribution (const std::string&)): Likewise. + (octave_rand::uniform_distribution, + octave_rand::normal_distribution, + octave_rand::exponential_distribution, + octave_rand::poisson_distribution, + octave_rand::gamma_distribution): Call switch_to_generator. + (octave_rand::state, maybe_initialize): For new_generators, just + call initialize_rand_states if not already initialized. + (octave_rand::scalar, fill_rand): Save state after generating value. + * dMatrix.cc (Matrix::lssolve): Avoid another dgelsd lwork query bug. * CMatrix.cc (ComplexMatrix::lssolve): Likewise, for zgelsd
--- a/liboctave/oct-rand.cc +++ b/liboctave/oct-rand.cc @@ -23,6 +23,8 @@ #ifdef HAVE_CONFIG_H #include <config.h> #endif + +#include <map> #include <vector> #include "f77-fcn.h" @@ -53,6 +55,8 @@ static bool new_initialized = false; static bool use_old_generators = false; +std::map<int, ColumnVector> rand_states; + extern "C" { F77_RET_T @@ -126,6 +130,46 @@ old_initialized = true; } +static ColumnVector +get_internal_state (void) +{ + ColumnVector s (MT_N + 1); + + OCTAVE_LOCAL_BUFFER (uint32_t, tmp, MT_N + 1); + + oct_get_state (tmp); + + for (octave_idx_type i = 0; i <= MT_N; i++) + s.elem (i) = static_cast<double> (tmp [i]); + + return s; +} + +static inline void +save_state (void) +{ + rand_states[current_distribution] = get_internal_state ();; +} + +static void +initialize_rand_states (void) +{ + if (! new_initialized) + { + oct_init_by_entropy (); + + ColumnVector s = get_internal_state (); + + rand_states[uniform_dist] = s; + rand_states[normal_dist] = s; + rand_states[expon_dist] = s; + rand_states[poisson_dist] = s; + rand_states[gamma_dist] = s; + + new_initialized = true; + } +} + static inline void maybe_initialize (void) { @@ -137,10 +181,56 @@ else { if (! new_initialized) - { - oct_init_by_entropy (); - new_initialized = true; - } + initialize_rand_states (); + } +} + +static int +get_dist_id (const std::string& d) +{ + int retval; + + if (d == "uniform" || d == "rand") + retval = uniform_dist; + else if (d == "normal" || d == "randn") + retval = normal_dist; + else if (d == "exponential" || d == "rande") + retval = expon_dist; + else if (d == "poisson" || d == "randp") + retval = poisson_dist; + else if (d == "gamma" || d == "rangd") + retval = gamma_dist; + else + (*current_liboctave_error_handler) ("rand: invalid distribution"); + + return retval; +} + +static void +set_internal_state (const ColumnVector& s) +{ + octave_idx_type len = s.length (); + octave_idx_type n = len < MT_N + 1 ? len : MT_N + 1; + + OCTAVE_LOCAL_BUFFER (uint32_t, tmp, MT_N + 1); + + for (octave_idx_type i = 0; i < n; i++) + tmp[i] = static_cast<uint32_t> (s.elem(i)); + + if (len == MT_N + 1 && tmp[MT_N] <= MT_N && tmp[MT_N] > 0) + oct_set_state (tmp); + else + oct_init_by_array (tmp, len); +} + +static inline void +switch_to_generator (int dist) +{ + if (dist != current_distribution) + { + current_distribution = dist; + + set_internal_state (rand_states[dist]); } } @@ -172,6 +262,7 @@ octave_rand::seed (double s) { use_old_generators = true; + maybe_initialize (); int i0, i1; @@ -197,77 +288,104 @@ } ColumnVector -octave_rand::state (void) +octave_rand::state (const std::string& d) { - ColumnVector s (MT_N + 1); if (! new_initialized) - { - oct_init_by_entropy (); - new_initialized = true; - } + initialize_rand_states (); - OCTAVE_LOCAL_BUFFER (uint32_t, tmp, MT_N + 1); - oct_get_state (tmp); - for (octave_idx_type i = 0; i <= MT_N; i++) - s.elem (i) = static_cast<double>(tmp [i]); - return s; + return rand_states[d.empty () ? current_distribution : get_dist_id (d)]; } void -octave_rand::state (const ColumnVector &s) +octave_rand::state (const ColumnVector& s, const std::string& d) { use_old_generators = false; + maybe_initialize (); - octave_idx_type len = s.length(); - octave_idx_type n = len < MT_N + 1 ? len : MT_N + 1; - OCTAVE_LOCAL_BUFFER (uint32_t, tmp, MT_N + 1); - for (octave_idx_type i = 0; i < n; i++) - tmp[i] = static_cast<uint32_t> (s.elem(i)); + int old_dist = current_distribution; + + int new_dist = d.empty () ? current_distribution : get_dist_id (d); + + ColumnVector saved_state; - if (len == MT_N + 1 && tmp[MT_N] <= MT_N && tmp[MT_N] > 0) - oct_set_state (tmp); - else - oct_init_by_array (tmp, len); + if (old_dist != new_dist) + saved_state = get_internal_state (); + + set_internal_state (s); + + rand_states[new_dist] = get_internal_state (); + + if (old_dist != new_dist) + rand_states[old_dist] = saved_state; } std::string octave_rand::distribution (void) { + std::string retval; + maybe_initialize (); - if (current_distribution == uniform_dist) - return "uniform"; - else if (current_distribution == normal_dist) - return "normal"; - else if (current_distribution == expon_dist) - return "exponential"; - else if (current_distribution == poisson_dist) - return "poisson"; - else if (current_distribution == gamma_dist) - return "gamma"; - else + switch (current_distribution) { - abort (); - return ""; + case uniform_dist: + retval = "uniform"; + break; + + case normal_dist: + retval = "normal"; + break; + + case expon_dist: + retval = "exponential"; + break; + + case poisson_dist: + retval = "poisson"; + break; + + case gamma_dist: + retval = "gamma"; + break; + + default: + (*current_liboctave_error_handler) ("rand: invalid distribution"); + break; } + + return retval; } void octave_rand::distribution (const std::string& d) { - if (d == "uniform") - octave_rand::uniform_distribution (); - else if (d == "normal") - octave_rand::normal_distribution (); - else if (d == "exponential") - octave_rand::exponential_distribution (); - else if (d == "poisson") - octave_rand::poisson_distribution (); - else if (d == "gamma") - octave_rand::gamma_distribution (); - else - (*current_liboctave_error_handler) ("rand: invalid distribution"); + switch (get_dist_id (d)) + { + case uniform_dist: + octave_rand::uniform_distribution (); + break; + + case normal_dist: + octave_rand::normal_distribution (); + break; + + case expon_dist: + octave_rand::exponential_distribution (); + break; + + case poisson_dist: + octave_rand::poisson_distribution (); + break; + + case gamma_dist: + octave_rand::gamma_distribution (); + break; + + default: + (*current_liboctave_error_handler) ("rand: invalid distribution"); + break; + } } void @@ -275,7 +393,7 @@ { maybe_initialize (); - current_distribution = uniform_dist; + switch_to_generator (uniform_dist); F77_FUNC (setcgn, SETCGN) (uniform_dist); } @@ -285,7 +403,7 @@ { maybe_initialize (); - current_distribution = normal_dist; + switch_to_generator (normal_dist); F77_FUNC (setcgn, SETCGN) (normal_dist); } @@ -295,7 +413,7 @@ { maybe_initialize (); - current_distribution = expon_dist; + switch_to_generator (expon_dist); F77_FUNC (setcgn, SETCGN) (expon_dist); } @@ -305,7 +423,7 @@ { maybe_initialize (); - current_distribution = poisson_dist; + switch_to_generator (poisson_dist); F77_FUNC (setcgn, SETCGN) (poisson_dist); } @@ -315,7 +433,7 @@ { maybe_initialize (); - current_distribution = gamma_dist; + switch_to_generator (gamma_dist); F77_FUNC (setcgn, SETCGN) (gamma_dist); } @@ -363,7 +481,7 @@ break; default: - abort (); + (*current_liboctave_error_handler) ("rand: invalid distribution"); break; } } @@ -372,29 +490,31 @@ switch (current_distribution) { case uniform_dist: - retval = oct_randu(); + retval = oct_randu (); break; case normal_dist: - retval = oct_randn(); + retval = oct_randn (); break; case expon_dist: - retval = oct_rande(); + retval = oct_rande (); break; case poisson_dist: - retval = oct_randp(a); + retval = oct_randp (a); break; case gamma_dist: - retval = oct_randg(a); + retval = oct_randg (a); break; default: - abort (); + (*current_liboctave_error_handler) ("rand: invalid distribution"); break; } + + save_state (); } return retval; @@ -494,10 +614,12 @@ break; default: - abort (); + (*current_liboctave_error_handler) ("rand: invalid distribution"); break; } + save_state (); + return; }
--- a/liboctave/oct-rand.h +++ b/liboctave/oct-rand.h @@ -40,16 +40,17 @@ static void seed (double s); // Return the current state. - static ColumnVector state (void); + static ColumnVector state (const std::string& d = std::string ()); // Set the current state/ - static void state (const ColumnVector &s); + static void state (const ColumnVector &s, + const std::string& d = std::string ()); // Return the current distribution. static std::string distribution (void); // Set the current distribution. May be either "uniform" (the - // default) or "normal". + // default), "normal", "exponential", "poisson", or "gamma". static void distribution (const std::string& d); static void uniform_distribution (void);
--- a/liboctave/randmtzig.c +++ b/liboctave/randmtzig.c @@ -203,7 +203,7 @@ /* init_key is the array for initializing keys */ /* key_length is its length */ void -oct_init_by_array (uint32_t init_key[], int key_length) +oct_init_by_array (uint32_t *init_key, int key_length) { int i, j, k; oct_init_by_int (19650218UL); @@ -281,17 +281,17 @@ } void -oct_set_state (uint32_t save[]) +oct_set_state (uint32_t *save) { int i; - for (i=0; i < MT_N; i++) + for (i = 0; i < MT_N; i++) state[i] = save[i]; left = save[MT_N]; next = state + (MT_N - left + 1); } void -oct_get_state (uint32_t save[]) +oct_get_state (uint32_t *save) { int i; for (i = 0; i < MT_N; i++)
--- a/src/ChangeLog +++ b/src/ChangeLog @@ -1,5 +1,8 @@ 2008-02-26 John W. Eaton <jwe@octave.org> + * DLD-FUNCTIONS/rand.cc (do_rand): Pass name of calling function + to octave_rand::state. + * variables.cc (bind_ans): Handle cs-lists recursively. * ov-cs-list.h, ov-cs-list.cc (octave_cs_list::print,
--- a/src/DLD-FUNCTIONS/rand.cc +++ b/src/DLD-FUNCTIONS/rand.cc @@ -113,7 +113,7 @@ } else if (s_arg == "state" || s_arg == "twister") { - retval = octave_rand::state (); + retval = octave_rand::state (fcn); } else if (s_arg == "uniform") { @@ -250,7 +250,7 @@ ColumnVector (args(idx+1).vector_value(false, true)); if (! error_state) - octave_rand::state (s); + octave_rand::state (s, fcn); } else error ("%s: unrecognized string argument", fcn);