Mercurial > hg > octave-lyh
changeset 9725:aea3a3a950e1
implement nth_element
author | Jaroslav Hajek <highegg@gmail.com> |
---|---|
date | Wed, 14 Oct 2009 13:23:31 +0200 |
parents | f22bbc5d56e9 |
children | b7b89061bd0e |
files | liboctave/Array.cc liboctave/Array.h liboctave/ArrayN.h liboctave/ChangeLog liboctave/idx-vector.cc liboctave/idx-vector.h liboctave/oct-sort.cc liboctave/oct-sort.h src/ChangeLog src/data.cc src/ov.cc src/ov.h |
diffstat | 12 files changed, 359 insertions(+), 13 deletions(-) [+] |
line wrap: on
line diff
--- a/liboctave/Array.cc +++ b/liboctave/Array.cc @@ -2004,7 +2004,7 @@ template <class T> Array<T> -Array<T>::sort (octave_idx_type dim, sortmode mode) const +Array<T>::sort (int dim, sortmode mode) const { if (dim < 0 || dim >= ndims ()) { @@ -2119,7 +2119,7 @@ template <class T> Array<T> -Array<T>::sort (Array<octave_idx_type> &sidx, octave_idx_type dim, +Array<T>::sort (Array<octave_idx_type> &sidx, int dim, sortmode mode) const { if (dim < 0 || dim >= ndims ()) @@ -2594,16 +2594,166 @@ return retval; } +template <class T> +Array<T> +Array<T>::nth_element (const idx_vector& n, int dim) const +{ + if (dim < 0 || dim >= ndims ()) + { + (*current_liboctave_error_handler) + ("nth_element: invalid dimension"); + return Array<T> (); + } + + dim_vector dv = dims (); + octave_idx_type ns = dv(dim); + + octave_idx_type nn = n.length (ns); + + dv(dim) = std::min (nn, ns); + dv.chop_trailing_singletons (); + + Array<T> m (dv); + + if (m.numel () == 0) + return m; + + sortmode mode = UNSORTED; + octave_idx_type lo = 0; + + switch (n.idx_class ()) + { + case idx_vector::class_scalar: + mode = ASCENDING; + lo = n(0); + break; + case idx_vector::class_range: + { + octave_idx_type inc = n.increment (); + if (inc == 1) + { + mode = ASCENDING; + lo = n(0); + } + else if (inc == -1) + { + mode = DESCENDING; + lo = ns - 1 - n(0); + } + } + default: + break; + } + + if (mode == UNSORTED) + { + (*current_liboctave_error_handler) + ("nth_element: n must be a scalar or a contiguous range"); + return Array<T> (); + } + + octave_idx_type up = lo + nn; + + if (lo < 0 || up > ns) + { + (*current_liboctave_error_handler) + ("nth_element: invalid element index"); + return Array<T> (); + } + + octave_idx_type iter = numel () / ns; + octave_idx_type stride = 1; + + for (int i = 0; i < dim; i++) + stride *= dv(i); + + T *v = m.fortran_vec (); + const T *ov = data (); + + OCTAVE_LOCAL_BUFFER (T, buf, ns); + + octave_sort<T> lsort; + lsort.set_compare (mode); + + for (octave_idx_type j = 0; j < iter; j++) + { + octave_idx_type kl = 0, ku = ns; + + if (stride == 1) + { + // copy without NaNs. + // FIXME: impact on integer types noticeable? + for (octave_idx_type i = 0; i < ns; i++) + { + T tmp = ov[i]; + if (sort_isnan<T> (tmp)) + buf[--ku] = tmp; + else + buf[kl++] = tmp; + } + + ov += ns; + } + else + { + octave_idx_type offset = j % stride; + // copy without NaNs. + // FIXME: impact on integer types noticeable? + for (octave_idx_type i = 0; i < ns; i++) + { + T tmp = ov[offset + i*stride]; + if (sort_isnan<T> (tmp)) + buf[--ku] = tmp; + else + buf[kl++] = tmp; + } + + if (offset == stride-1) + ov += ns*stride; + } + + if (ku == ns) + lsort.nth_element (buf, ns, lo, up); + else if (mode == ASCENDING) + lsort.nth_element (buf, ku, lo, std::min (ku, up)); + else + { + octave_idx_type nnan = ns - ku; + lsort.nth_element (buf, ku, std::max (lo - nnan, 0), + std::max (up - nnan, 0)); + std::rotate (buf, buf + ku, buf + ns); + } + + if (stride == 1) + { + for (octave_idx_type i = 0; i < nn; i++) + v[i] = buf[lo + i]; + + v += nn; + } + else + { + octave_idx_type offset = j % stride; + for (octave_idx_type i = 0; i < nn; i++) + v[offset + stride * i] = buf[lo + i]; + if (offset == stride-1) + v += nn*stride; + } + } + + return m; +} + #define INSTANTIATE_ARRAY_SORT(T) template class OCTAVE_API octave_sort<T>; #define NO_INSTANTIATE_ARRAY_SORT(T) \ \ template <> Array<T> \ -Array<T>::sort (octave_idx_type, sortmode) const { return *this; } \ +Array<T>::sort (int, sortmode) const { return *this; } \ \ template <> Array<T> \ -Array<T>::sort (Array<octave_idx_type> &sidx, octave_idx_type, sortmode) const \ +Array<T>::sort (Array<octave_idx_type> &sidx, int, sortmode) const \ { sidx = Array<octave_idx_type> (); return *this; } \ \ template <> sortmode \ @@ -2637,6 +2787,9 @@ template <> Array<octave_idx_type> \ Array<T>::find (octave_idx_type, bool) const\ { return Array<octave_idx_type> (); } \ + \ +template <> Array<T> \ +Array<T>::nth_element (const idx_vector&, int) const { return Array<T> (); } \ template <class T>
--- a/liboctave/Array.h +++ b/liboctave/Array.h @@ -595,8 +595,8 @@ // You should not use it anywhere else. void *mex_get_data (void) const { return const_cast<T *> (data ()); } - Array<T> sort (octave_idx_type dim = 0, sortmode mode = ASCENDING) const; - Array<T> sort (Array<octave_idx_type> &sidx, octave_idx_type dim = 0, + Array<T> sort (int dim = 0, sortmode mode = ASCENDING) const; + Array<T> sort (Array<octave_idx_type> &sidx, int dim = 0, sortmode mode = ASCENDING) const; // Ordering is auto-detected or can be specified. @@ -631,6 +631,10 @@ // specifies search from backward. Array<octave_idx_type> find (octave_idx_type n = -1, bool backward = false) const; + // Returns the n-th element in increasing order, using the same ordering as + // used for sort. n can either be a scalar index or a contiguous range. + Array<T> nth_element (const idx_vector& n, int dim = 0) const; + Array<T> diag (octave_idx_type k = 0) const; template <class U, class F>
--- a/liboctave/ArrayN.h +++ b/liboctave/ArrayN.h @@ -131,17 +131,20 @@ return ArrayN<T> (tmp, tmp.dims ()); } - ArrayN<T> sort (octave_idx_type dim = 0, sortmode mode = ASCENDING) const + ArrayN<T> sort (int dim = 0, sortmode mode = ASCENDING) const { - Array<T> tmp = Array<T>::sort (dim, mode); - return ArrayN<T> (tmp, tmp.dims ()); + return Array<T>::sort (dim, mode); } - ArrayN<T> sort (Array<octave_idx_type> &sidx, octave_idx_type dim = 0, + ArrayN<T> sort (Array<octave_idx_type> &sidx, int dim = 0, sortmode mode = ASCENDING) const { - Array<T> tmp = Array<T>::sort (sidx, dim, mode); - return ArrayN<T> (tmp, tmp.dims ()); + return Array<T>::sort (sidx, dim, mode); + } + + ArrayN<T> nth_element (const idx_vector& n, int dim = 0) const + { + return Array<T>::nth_element (n, dim); } ArrayN<T> diag (octave_idx_type k) const
--- a/liboctave/ChangeLog +++ b/liboctave/ChangeLog @@ -1,3 +1,13 @@ +2009-10-14 Jaroslav Hajek <highegg@gmail.com> + + * oct-sort.cc (octave_sort<T>::nth_element): New overloaded method. + * oct-sort.h: Declare it. + * Array.cc (Array<T>::nth_element): New method. + * Array.h: Declare it. + (Array<T>::sort): Use int for dim argument. + * ArrayN.h (ArrayN<T>::nth_element): Wrap. + (ArrayN<T>::sort): Use int for dim argument. + 2009-10-13 Jaroslav Hajek <highegg@gmail.com> * lo-traits.h (equal_types, is_instance, subst_template_param): New
--- a/liboctave/idx-vector.cc +++ b/liboctave/idx-vector.cc @@ -541,6 +541,26 @@ return res; } +octave_idx_type +idx_vector::increment (void) const +{ + octave_idx_type retval = 0; + switch (rep->idx_class ()) + { + case class_colon: + retval = 1; + case class_range: + retval = dynamic_cast<idx_range_rep *> (rep) -> get_step (); + break; + case class_vector: + { + if (length (0) > 1) + retval = elem (1) - elem (0); + } + } + return retval; +} + void idx_vector::copy_data (octave_idx_type *data) const {
--- a/liboctave/idx-vector.h +++ b/liboctave/idx-vector.h @@ -796,6 +796,10 @@ bool is_cont_range (octave_idx_type n, octave_idx_type& l, octave_idx_type& u) const; + // Returns the increment for ranges and colon, 0 for scalars and empty + // vectors, 1st difference otherwise. + octave_idx_type increment (void) const; + idx_vector complement (octave_idx_type n) const;
--- a/liboctave/oct-sort.cc +++ b/liboctave/oct-sort.cc @@ -1919,7 +1919,6 @@ lookupm (data, nel, values, nvalues, idx, std::ptr_fun (compare)); } -#include <iostream> template <class T> template <class Comp> void octave_sort<T>::lookupb (const T *data, octave_idx_type nel, @@ -1983,6 +1982,53 @@ lookupb (data, nel, values, nvalues, match, std::ptr_fun (compare)); } +template <class T> template <class Comp> +void +octave_sort<T>::nth_element (T *data, octave_idx_type nel, + octave_idx_type lo, octave_idx_type up, + Comp comp) +{ + // Simply wrap the STL algorithms. + // FIXME: this will fail if we attempt to inline <,> for Complex. + if (up == lo+1) + std::nth_element (data, data + lo, data + nel, comp); + else if (lo == 0) + std::partial_sort (data, data + up, data + nel, comp); + else + { + std::nth_element (data, data + lo, data + nel, comp); + if (up == lo + 2) + { + // Finding two subsequent elements. + std::swap (data[lo+1], + *std::min_element (data + lo + 1, data + nel, comp)); + } + else + std::partial_sort (data + lo + 1, data + up, data + nel, comp); + } +} + +template <class T> +void +octave_sort<T>::nth_element (T *data, octave_idx_type nel, + octave_idx_type lo, octave_idx_type up) +{ + if (up < 0) + up = lo + 1; +#ifdef INLINE_ASCENDING_SORT + if (compare == ascending_compare) + nth_element (data, nel, lo, up, std::less<T> ()); + else +#endif +#ifdef INLINE_DESCENDING_SORT + if (compare == descending_compare) + nth_element (data, nel, lo, up, std::greater<T> ()); + else +#endif + if (compare) + nth_element (data, nel, lo, up, std::ptr_fun (compare)); +} + template <class T> bool octave_sort<T>::ascending_compare (typename ref_param<T>::type x,
--- a/liboctave/oct-sort.h +++ b/liboctave/oct-sort.h @@ -159,6 +159,11 @@ const T* values, octave_idx_type nvalues, bool *match); + // Rearranges the array so that the elements with indices + // lo..up-1 are in their correct place. + void nth_element (T *data, octave_idx_type nel, + octave_idx_type lo, octave_idx_type up = -1); + static bool ascending_compare (typename ref_param<T>::type, typename ref_param<T>::type); @@ -322,6 +327,11 @@ void lookupb (const T *data, octave_idx_type nel, const T* values, octave_idx_type nvalues, bool *match, Comp comp); + + template <class Comp> + void nth_element (T *data, octave_idx_type nel, + octave_idx_type lo, octave_idx_type up, + Comp comp); }; template <class T>
--- a/src/ChangeLog +++ b/src/ChangeLog @@ -1,3 +1,10 @@ +2009-10-14 Jaroslav Hajek <highegg@gmail.com> + + * ov.cc (octave_value::octave_value (const Array<std::string>&)): New + constructor. + * ov.h: Declare it. + * data.cc (Fnth_element): New DEFUN. + 2009-10-13 Jaroslav Hajek <highegg@gmail.com> * data.cc (Fcumsum, Fcumprod, Fprod, Fsum, Fsumsq): Correct help
--- a/src/data.cc +++ b/src/data.cc @@ -6153,6 +6153,88 @@ return retval; } +DEFUN (nth_element, args, , + "-*- texinfo -*-\n\ +@deftypefn {Built-in Function} {} nth_element (@var{x}, @var{n})\n\ +@deftypefnx {Built-in Function} {} nth_element (@var{x}, @var{n}, @var{dim})\n\ +Select the n-th smallest element of a vector, using the ordering defined by @code{sort}.\n\ +In other words, the result is equivalent to @code{sort(@var{x})(@var{n})}.\n\ +@var{n} can also be a contiguous range, either ascending @code{l:u}\n\ +or descending @code{u:-1:l}, in which case a range of elements is returned.\n\ +If @var{x} is an array, @code{nth_element} operates along the dimension defined by @var{dim},\n\ +or the first non-singleton dimension if @var{dim} is not given.\n\ +\n\ +nth_element encapsulates the C++ STL algorithms nth_element and partial_sort.\n\ +On average, the complexity of the operation is O(M*log(K)), where\n\ +@code{M = size(@var{x}, @var{dim})} and @code{K = length (@var{n})}.\n\ +This function is intended for cases where the ratio K/M is small; otherwise,\n\ +it may be better to use @code{sort}.\n\ +@seealso{sort, min, max}\n\ +@end deftypefn") +{ + octave_value retval; + int nargin = args.length (); + + if (nargin == 2 || nargin == 3) + { + octave_value argx = args(0); + + int dim = -1; + if (nargin == 3) + { + dim = args(2).int_value (true) - 1; + if (dim < 0 || dim >= argx.ndims ()) + error ("nth_element: dim must be a valid dimension"); + } + if (dim < 0) + dim = argx.dims ().first_non_singleton (); + + idx_vector n = args(1).index_vector (); + + if (error_state) + return retval; + + switch (argx.builtin_type ()) + { + case btyp_double: + retval = argx.array_value ().nth_element (n, dim); + break; + case btyp_float: + retval = argx.float_array_value ().nth_element (n, dim); + break; + case btyp_complex: + retval = argx.complex_array_value ().nth_element (n, dim); + break; + case btyp_float_complex: + retval = argx.float_complex_array_value ().nth_element (n, dim); + break; +#define MAKE_INT_BRANCH(X) \ + case btyp_ ## X: \ + retval = argx.X ## _array_value ().nth_element (n, dim); \ + break + + MAKE_INT_BRANCH (int8); + MAKE_INT_BRANCH (int16); + MAKE_INT_BRANCH (int32); + MAKE_INT_BRANCH (int64); + MAKE_INT_BRANCH (uint8); + MAKE_INT_BRANCH (uint16); + MAKE_INT_BRANCH (uint32); + MAKE_INT_BRANCH (uint64); +#undef MAKE_INT_BRANCH + default: + if (argx.is_cellstr ()) + retval = argx.cellstr_value ().nth_element (n, dim); + else + gripe_wrong_type_arg ("nth_element", argx); + } + } + else + print_usage (); + + return retval; +} + template <class NDT> static NDT do_accumarray_sum (const idx_vector& idx, const NDT& vals,
--- a/src/ov.cc +++ b/src/ov.cc @@ -1078,6 +1078,12 @@ maybe_mutate (); } +octave_value::octave_value (const Array<std::string>& cellstr) + : rep (new octave_cell (cellstr)) +{ + maybe_mutate (); +} + octave_value::octave_value (double base, double limit, double inc) : rep (new octave_range (base, limit, inc)) {
--- a/src/ov.h +++ b/src/ov.h @@ -270,6 +270,7 @@ octave_value (const ArrayN<octave_uint64>& inda); octave_value (const Array<octave_idx_type>& inda, bool zero_based = false, bool cache_index = false); + octave_value (const Array<std::string>& cellstr); octave_value (const idx_vector& idx); octave_value (double base, double limit, double inc); octave_value (const Range& r);