Mercurial > hg > octave-terminal
changeset 13255:dd3c5325039c
Use a hash map to store permutations in randperm's truncated Knuth shuffle
author | Jordi Gutiérrez Hermoso <jordigh@octave.org> |
---|---|
date | Thu, 29 Sep 2011 17:44:32 -0500 |
parents | e749d0b568c8 |
children | 41c2f4633a62 |
files | src/DLD-FUNCTIONS/rand.cc |
diffstat | 1 files changed, 43 insertions(+), 12 deletions(-) [+] |
line wrap: on
line diff
--- a/src/DLD-FUNCTIONS/rand.cc +++ b/src/DLD-FUNCTIONS/rand.cc @@ -26,7 +26,7 @@ #endif #include <ctime> - +#include <tr1/unordered_map> #include <string> #include "f77-fcn.h" @@ -1020,9 +1020,10 @@ @deftypefnx {Loadable Function} {} randperm (@var{n}, @var{m})\n\ Return a row vector containing a random permutation of @code{1:@var{n}}.\n\ If @var{m} is supplied, return @var{m} unique entries, sampled without\n\ -replacement from @code{1:@var{n}}. The complexity is O(N) in memory and \n\ -O(M) in time. The randomization is performed using rand(). All\n\ -permutations are equally likely.\n\ +replacement from @code{1:@var{n}}. The complexity is O(@var{n}) in\n\ +memory and O(@var{m}) in time, unless @var{m} < @var{n}/5, in which case\n\ +O(@var{m}) memory is used as well. The randomization is performed using\n\ +rand(). All permutations are equally likely.\n\ @seealso{perms}\n\ @end deftypefn") { @@ -1046,25 +1047,55 @@ if (m > n) error ("randperm: M must be less than or equal to N"); + // Quick and dirty heuristic to decide if we allocate or not the + // whole vector for tracking the truncated shuffle. + bool short_shuffle = m < n/5 && m < 1e5; + if (! error_state) { // Generate random numbers. NDArray r = octave_rand::nd_array (dim_vector (1, m)); - - Array<octave_idx_type> idx (dim_vector (1, n)); - double *rvec = r.fortran_vec (); + octave_idx_type idx_len = short_shuffle ? m : n; + Array<octave_idx_type> idx (dim_vector (1, idx_len)); octave_idx_type *ivec = idx.fortran_vec (); - for (octave_idx_type i = 0; i < n; i++) + for (octave_idx_type i = 0; i < idx_len; i++) ivec[i] = i; - // Perform the Knuth shuffle of the first m entries - for (octave_idx_type i = 0; i < m; i++) + if (short_shuffle) { - octave_idx_type k = i + gnulib::floor (rvec[i] * (n - i)); - std::swap (ivec[i], ivec[k]); + std::tr1::unordered_map<octave_idx_type, + octave_idx_type> map (m); + + // Perform the Knuth shuffle only keeping track of moved + // entries in the map + for (octave_idx_type i = 0; i < m; i++) + { + octave_idx_type k = i + + gnulib::floor (rvec[i] * (n - i)); + + if (map.find(k) == map.end()) + { + map[k] = ivec[i]; + ivec[i] = k; + } + else + std::swap (ivec[i], map[k]); + + } + } + else + { + + // Perform the Knuth shuffle of the first m entries + for (octave_idx_type i = 0; i < m; i++) + { + octave_idx_type k = i + + gnulib::floor (rvec[i] * (n - i)); + std::swap (ivec[i], ivec[k]); + } } // Convert to doubles, reusing r.