diff src/data.cc @ 9725:aea3a3a950e1

implement nth_element
author Jaroslav Hajek <highegg@gmail.com>
date Wed, 14 Oct 2009 13:23:31 +0200
parents f426899f4b9c
children b4fdfee405b5
line wrap: on
line diff
--- a/src/data.cc
+++ b/src/data.cc
@@ -6153,6 +6153,88 @@
   return retval;
 }
 
+DEFUN (nth_element, args, ,
+  "-*- texinfo -*-\n\
+@deftypefn {Built-in Function} {} nth_element (@var{x}, @var{n})\n\
+@deftypefnx {Built-in Function} {} nth_element (@var{x}, @var{n}, @var{dim})\n\
+Select the n-th smallest element of a vector, using the ordering defined by @code{sort}.\n\
+In other words, the result is equivalent to @code{sort(@var{x})(@var{n})}.\n\
+@var{n} can also be a contiguous range, either ascending @code{l:u}\n\
+or descending @code{u:-1:l}, in which case a range of elements is returned.\n\
+If @var{x} is an array, @code{nth_element} operates along the dimension defined by @var{dim},\n\
+or the first non-singleton dimension if @var{dim} is not given.\n\
+\n\
+nth_element encapsulates the C++ STL algorithms nth_element and partial_sort.\n\
+On average, the complexity of the operation is O(M*log(K)), where\n\
+@code{M = size(@var{x}, @var{dim})} and @code{K = length (@var{n})}.\n\
+This function is intended for cases where the ratio K/M is small; otherwise,\n\
+it may be better to use @code{sort}.\n\
+@seealso{sort, min, max}\n\
+@end deftypefn")
+{
+  octave_value retval;
+  int nargin = args.length ();
+
+  if (nargin == 2 || nargin == 3)
+    {
+      octave_value argx = args(0);
+
+      int dim = -1;
+      if (nargin == 3)
+        {
+          dim = args(2).int_value (true) - 1;
+          if (dim < 0 || dim >= argx.ndims ())
+            error ("nth_element: dim must be a valid dimension");
+        }
+      if (dim < 0)
+        dim = argx.dims ().first_non_singleton ();
+
+      idx_vector n = args(1).index_vector ();
+
+      if (error_state)
+        return retval;
+
+      switch (argx.builtin_type ())
+        {
+        case btyp_double:
+          retval = argx.array_value ().nth_element (n, dim);
+          break;
+        case btyp_float:
+          retval = argx.float_array_value ().nth_element (n, dim);
+          break;
+        case btyp_complex:
+          retval = argx.complex_array_value ().nth_element (n, dim);
+          break;
+        case btyp_float_complex:
+          retval = argx.float_complex_array_value ().nth_element (n, dim);
+          break;
+#define MAKE_INT_BRANCH(X) \
+        case btyp_ ## X: \
+          retval = argx.X ## _array_value ().nth_element (n, dim); \
+          break
+
+        MAKE_INT_BRANCH (int8);
+        MAKE_INT_BRANCH (int16);
+        MAKE_INT_BRANCH (int32);
+        MAKE_INT_BRANCH (int64);
+        MAKE_INT_BRANCH (uint8);
+        MAKE_INT_BRANCH (uint16);
+        MAKE_INT_BRANCH (uint32);
+        MAKE_INT_BRANCH (uint64);
+#undef MAKE_INT_BRANCH
+        default:
+          if (argx.is_cellstr ())
+            retval = argx.cellstr_value ().nth_element (n, dim);
+          else
+            gripe_wrong_type_arg ("nth_element", argx);
+        }
+    }
+  else
+    print_usage ();
+
+  return retval;
+}
+
 template <class NDT>
 static NDT 
 do_accumarray_sum (const idx_vector& idx, const NDT& vals,