changeset 9341:9fd5c56ce57a

extend lookup capabilities
author Jaroslav Hajek <highegg@gmail.com>
date Fri, 12 Jun 2009 16:01:53 +0200
parents 49fe8721bae1
children 2ca8879a140c
files liboctave/Array.cc liboctave/Array.h liboctave/ChangeLog liboctave/oct-sort.cc liboctave/oct-sort.h src/ChangeLog src/DLD-FUNCTIONS/lookup.cc
diffstat 7 files changed, 283 insertions(+), 50 deletions(-) [+]
line wrap: on
line diff
--- a/liboctave/Array.cc
+++ b/liboctave/Array.cc
@@ -2470,6 +2470,56 @@
 
 template <class T>
 Array<octave_idx_type> 
+Array<T>::lookupm (const Array<T>& values, sortmode mode) const
+{
+  octave_idx_type n = numel ();
+  octave_sort<T> lsort;
+  Array<octave_idx_type> idx (values.dims ());
+
+  if (mode == UNSORTED)
+    {
+      // auto-detect mode
+      if (n > 1 && lsort.descending_compare (elem (0), elem (n-1)))
+        mode = DESCENDING;
+      else
+        mode = ASCENDING;
+    }
+
+  lsort.set_compare (mode);
+
+  lsort.lookupm (data (), n, values.data (), values.numel (),
+                 idx.fortran_vec ());
+
+  return idx;
+}
+
+template <class T>
+Array<bool> 
+Array<T>::lookupb (const Array<T>& values, sortmode mode) const
+{
+  octave_idx_type n = numel ();
+  octave_sort<T> lsort;
+  Array<bool> match (values.dims ());
+
+  if (mode == UNSORTED)
+    {
+      // auto-detect mode
+      if (n > 1 && lsort.descending_compare (elem (0), elem (n-1)))
+        mode = DESCENDING;
+      else
+        mode = ASCENDING;
+    }
+
+  lsort.set_compare (mode);
+
+  lsort.lookupb (data (), n, values.data (), values.numel (),
+                 match.fortran_vec ());
+
+  return match;
+}
+
+template <class T>
+Array<octave_idx_type> 
 Array<T>::find (octave_idx_type n, bool backward) const
 {
   Array<octave_idx_type> retval;
@@ -2581,6 +2631,12 @@
 template <> Array<octave_idx_type>  \
 Array<T>::lookup (const Array<T>&, sortmode, bool, bool) const \
 { return Array<octave_idx_type> (); } \
+template <> Array<octave_idx_type>  \
+Array<T>::lookupm (const Array<T>&, sortmode) const \
+{ return Array<octave_idx_type> (); } \
+template <> Array<bool>  \
+Array<T>::lookupb (const Array<T>&, sortmode) const \
+{ return Array<bool> (); } \
 template <> Array<octave_idx_type> \
 Array<T>::find (octave_idx_type, bool) const\
 { return Array<octave_idx_type> (); } \
--- a/liboctave/Array.h
+++ b/liboctave/Array.h
@@ -582,6 +582,13 @@
   Array<octave_idx_type> lookup (const Array<T>& values, sortmode mode = UNSORTED, 
                                  bool linf = false, bool rinf = false) const;
 
+  // This looks up only exact matches, giving their indices. Non-exact matches get
+  // the value -1.
+  Array<octave_idx_type> lookupm (const Array<T>& values, sortmode mode = UNSORTED) const;
+
+  // This looks up only exact matches, returning true/false if match.
+  Array<bool> lookupb (const Array<T>& values, sortmode mode = UNSORTED) const;
+
   // Find indices of (at most n) nonzero elements. If n is specified, backward
   // specifies search from backward.
   Array<octave_idx_type> find (octave_idx_type n = -1, bool backward = false) const;
--- a/liboctave/ChangeLog
+++ b/liboctave/ChangeLog
@@ -1,3 +1,11 @@
+2009-06-12  Jaroslav Hajek  <highegg@gmail.com>
+
+	* oct-sort.cc (octave_sort::lookupm, octave_sort::lookupb): New
+	overloaded methods.
+	* oct-sort.h: Declare them.
+	* Array.cc (Array<T>::lookupm, Array<T>::lookupb): New methods.
+	* Array.h: Declare them.
+
 2009-06-09  Jaroslav Hajek  <highegg@gmail.com>
 
 	* cmd-edit.cc (command_editor::force_default_editor): New static
--- a/liboctave/oct-sort.cc
+++ b/liboctave/oct-sort.cc
@@ -1925,6 +1925,74 @@
       lookup (data, nel, values, nvalues, idx, offset, std::ptr_fun (compare));
 }
 
+template <class T> template <class Comp>
+void 
+octave_sort<T>::lookupm (const T *data, octave_idx_type nel,
+                         const T *values, octave_idx_type nvalues,
+                         octave_idx_type *idx, Comp comp)
+{
+  const T *end = data + nel;
+  for (octave_idx_type i = 0; i < nvalues; i++)
+    {
+      const T *ptr = std::lower_bound (data, end, values[i], comp);
+      if (ptr != end && ! comp (values[i], *ptr))
+        idx[i] = ptr - data;
+      else
+        idx[i] = -1;
+    }
+}
+
+template <class T>
+void 
+octave_sort<T>::lookupm (const T *data, octave_idx_type nel,
+                         const T* values, octave_idx_type nvalues,
+                         octave_idx_type *idx)
+{
+#ifdef INLINE_ASCENDING_SORT
+  if (compare == ascending_compare)
+    lookupm (data, nel, values, nvalues, idx, std::less<T> ());
+  else
+#endif
+#ifdef INLINE_DESCENDING_SORT    
+    if (compare == descending_compare)
+      lookupm (data, nel, values, nvalues, idx, std::greater<T> ());
+  else
+#endif
+    if (compare)
+      lookupm (data, nel, values, nvalues, idx, std::ptr_fun (compare));
+}
+
+template <class T> template <class Comp>
+void 
+octave_sort<T>::lookupb (const T *data, octave_idx_type nel,
+                         const T *values, octave_idx_type nvalues,
+                         bool *match, Comp comp)
+{
+  const T *end = data + nel;
+  for (octave_idx_type i = 0; i < nvalues; i++)
+    match[i] = std::binary_search (data, end, values[i], comp);
+}
+
+template <class T>
+void 
+octave_sort<T>::lookupb (const T *data, octave_idx_type nel,
+                         const T* values, octave_idx_type nvalues,
+                         bool *match)
+{
+#ifdef INLINE_ASCENDING_SORT
+  if (compare == ascending_compare)
+    lookupb (data, nel, values, nvalues, match, std::less<T> ());
+  else
+#endif
+#ifdef INLINE_DESCENDING_SORT    
+    if (compare == descending_compare)
+      lookupb (data, nel, values, nvalues, match, std::greater<T> ());
+  else
+#endif
+    if (compare)
+      lookupb (data, nel, values, nvalues, match, std::ptr_fun (compare));
+}
+
 template <class T>
 bool 
 octave_sort<T>::ascending_compare (typename ref_param<T>::type x,
--- a/liboctave/oct-sort.h
+++ b/liboctave/oct-sort.h
@@ -148,6 +148,17 @@
                const T* values, octave_idx_type nvalues,
                octave_idx_type *idx, octave_idx_type offset = 0);
 
+  // Lookup an array of values, only returning indices of
+  // exact matches. Non-matches are returned as -1.
+  void lookupm (const T *data, octave_idx_type nel,
+                const T* values, octave_idx_type nvalues,
+                octave_idx_type *idx);
+
+  // Lookup an array of values, only indicating exact matches.
+  void lookupb (const T *data, octave_idx_type nel,
+                const T* values, octave_idx_type nvalues,
+                bool *match);
+
   static bool ascending_compare (typename ref_param<T>::type,
 				 typename ref_param<T>::type);
 
@@ -302,6 +313,15 @@
                const T* values, octave_idx_type nvalues,
                octave_idx_type *idx, octave_idx_type offset, Comp comp);
 
+  template <class Comp>
+  void lookupm (const T *data, octave_idx_type nel,
+                const T* values, octave_idx_type nvalues,
+                octave_idx_type *idx, Comp comp);
+
+  template <class Comp>
+  void lookupb (const T *data, octave_idx_type nel,
+                const T* values, octave_idx_type nvalues,
+                bool *match, Comp comp);
 };
 
 template <class T>
--- a/src/ChangeLog
+++ b/src/ChangeLog
@@ -1,3 +1,9 @@
+2009-06-12  Jaroslav Hajek  <highegg@gmail.com>
+
+	* DLD-FUNCTIONS/lookup.cc (do_numeric_lookup): New template function.
+	(Flookup): Extend to support b and m options, improve diagnostic.
+	Refactor.
+
 2009-06-12  Kai NODA  <nodakai@gmail.com>
 
 	* ls-mat4.h: Fix include guard
--- a/src/DLD-FUNCTIONS/lookup.cc
+++ b/src/DLD-FUNCTIONS/lookup.cc
@@ -99,8 +99,37 @@
 
 #define INT_ARRAY_LOOKUP(TYPE) \
   (table.is_ ## TYPE ## _type () && y.is_ ## TYPE ## _type ()) \
-    idx = table.TYPE ## _array_value ().lookup (y.TYPE ## _array_value (), \
-                                                UNSORTED, left_inf, right_inf);
+    retval = do_numeric_lookup (table.TYPE ## _array_value (), \
+                                y.TYPE ## _array_value (), \
+                                left_inf, right_inf, \
+                                match_idx, match_bool);
+template <class ArrayT>
+static octave_value
+do_numeric_lookup (const ArrayT& array, const ArrayT& values, 
+                   bool left_inf, bool right_inf,
+                   bool match_idx, bool match_bool)
+{
+  octave_value retval;
+
+  if (match_bool)
+    {
+      boolNDArray match = ArrayN<bool> (array.lookupb (values));
+      retval = match;
+    }
+  else
+    {
+      Array<octave_idx_type> idx;
+
+      if (match_idx)
+        idx = array.lookupm (values);
+      else
+        idx = array.lookup (values, UNSORTED, left_inf, right_inf);
+
+      retval = NDArray (idx, match_idx);
+    }
+
+  return retval;
+}
 
 DEFUN_DLD (lookup, args, ,
   "-*- texinfo -*-\n\
@@ -112,7 +141,7 @@
 @code{table(idx(i)) <= y(i) < table(idx(i+1))} for all @code{y(i)}\n\
 within the table.  If @code{y(i) < table (1)} then\n\
 @code{idx(i)} is 0. If @code{y(i) >= table(end)} then\n\
-@code{idx(i)} is @code{table(n)}.\n\
+@code{idx(i)} is @code{n}.\n\
 \n\
 If the table is strictly decreasing, then the tests are reversed.\n\
 There are no guarantees for tables which are non-monotonic or are not\n\
@@ -129,6 +158,12 @@
 \n\
 If @var{opts} is specified, it shall be a string with letters indicating\n\
 additional options.\n\
+\n\
+If 'm' is specified as option, @code{table(idx(i)) == val(i)} if @code{val(i)}\n\
+occurs in table; otherwise, @code{idx(i)} is zero.\n\
+If 'b' is specified, then @code{idx(i)} is a logical 1 or 0, indicating whether\n\
+@code{val(i)} is contained in table or not.\n\
+\n\
 For numeric lookup, 'l' in @var{opts} indicates that\n\
 the leftmost subinterval shall be extended to infinity (i.e., all indices\n\
 at least 1), and 'r' indicates that the rightmost subinterval shall be\n\
@@ -137,7 +172,7 @@
 For string lookup, 'i' indicates case-insensitive comparison.\n\
 @end deftypefn") 
 {
-  octave_value_list retval;
+  octave_value retval;
 
   int nargin = args.length ();
 
@@ -153,18 +188,36 @@
 
   bool num_case = table.is_numeric_type () && y.is_numeric_type ();
   bool str_case = table.is_cellstr () && (y.is_string () || y.is_cellstr ());
+  bool left_inf = false;
+  bool right_inf = false;
+  bool match_idx = false;
+  bool match_bool = false;
+  bool icase = false;
+
+  if (nargin == 3)
+    {
+      std::string opt = args(2).string_value ();
+      left_inf = contains_char (opt, 'l');
+      right_inf = contains_char (opt, 'r');
+      icase = contains_char (opt, 'i');
+      match_idx = contains_char (opt, 'm');
+      match_bool = contains_char (opt, 'b');
+    }
+
+  if ((match_idx || match_bool) && (left_inf || right_inf))
+    error ("lookup: m, b cannot be specified with l or r");
+  else if (match_idx && match_bool)
+    error ("lookup: only one of m, b can be specified");
+  else if (str_case && (left_inf || right_inf))
+    error ("lookup: l,r not recognized for string lookups");
+  else if (num_case && icase)
+    error ("lookup: i not recognized for numeric lookups");
+
+  if (error_state)
+    return retval;
 
   if (num_case) 
     {
-      bool left_inf = false;
-      bool right_inf = false;
-
-      if (nargin == 3)
-        {
-          std::string opt = args(2).string_value ();
-          left_inf = contains_char (opt, 'l');
-          right_inf = contains_char (opt, 'r');
-        }
 
       // In the case of a complex array, absolute values will be used for compatibility
       // (though it's not too meaningful).
@@ -187,13 +240,15 @@
       else if INT_ARRAY_LOOKUP (uint32)
       else if INT_ARRAY_LOOKUP (uint64)
       else if (table.is_single_type () || y.is_single_type ())
-        idx = table.float_array_value ().lookup (y.float_array_value (), 
-                                                 UNSORTED, left_inf, right_inf);
+        retval = do_numeric_lookup (table.float_array_value (),
+                                    y.float_array_value (),
+                                    left_inf, right_inf,
+                                    match_idx, match_bool);
       else
-        idx = table.array_value ().lookup (y.array_value (), 
-                                           UNSORTED, left_inf, right_inf);
-
-      retval(0) = NDArray (idx);
+        retval = do_numeric_lookup (table.array_value (),
+                                    y.array_value (),
+                                    left_inf, right_inf,
+                                    match_idx, match_bool);
 
     }
   else if (str_case)
@@ -203,13 +258,11 @@
       // Here we'll use octave_sort directly to avoid converting the array
       // for case-insensitive comparison.
 
-      bool icase = false;
 
       // check for case-insensitive option
       if (nargin == 3)
         {
           std::string opt = args(2).string_value ();
-          icase = contains_char (opt, 'i');
         }
 
       sortmode mode = (icase ? get_sort_mode (str_table, stri_comp_gt)
@@ -224,27 +277,38 @@
         str_comp = icase ? stri_comp_lt : octave_sort<std::string>::ascending_compare;
 
       octave_sort<std::string> lsort (str_comp);
+      Array<std::string> str_y (1);
+
       if (y.is_cellstr ())
+        str_y = y.cellstr_value ();
+      else
+        str_y(0) = y.string_value ();
+
+      if (match_bool)
         {
-          Array<std::string> str_y = y.cellstr_value ();
+          boolNDArray match (str_y.dims ());
+
+          lsort.lookupb (str_table.data (), str_table.nelem (), str_y.data (),
+                         str_y.nelem (), match.fortran_vec ());
 
+          retval = match;
+        }
+      else
+        {
           Array<octave_idx_type> idx (str_y.dims ());
 
-          lsort.lookup (str_table.data (), str_table.nelem (), str_y.data (),
-                        str_y.nelem (), idx.fortran_vec ());
+          if (match_idx)
+            {
+              lsort.lookupm (str_table.data (), str_table.nelem (), str_y.data (),
+                             str_y.nelem (), idx.fortran_vec ());
+            }
+          else
+            {
+              lsort.lookup (str_table.data (), str_table.nelem (), str_y.data (),
+                            str_y.nelem (), idx.fortran_vec ());
+            }
 
-          retval(0) = NDArray (idx);
-        }
-      else if (y.is_string ())
-        {
-          std::string str_y = y.string_value ();
-
-          octave_idx_type idx;
-
-          lsort.lookup (str_table.data (), str_table.nelem (), &str_y,
-                        1, &idx);
-
-          retval(0) = idx;
+          retval = NDArray (idx, match_idx);
         }
     }
   else
@@ -255,23 +319,27 @@
 }  
 
 /*
-%!assert (real(lookup(1:3, 0.5)), 0)     # value before table
-%!assert (real(lookup(1:3, 3.5)), 3)     # value after table error
-%!assert (real(lookup(1:3, 1.5)), 1)     # value within table error
-%!assert (real(lookup(1:3, [3,2,1])), [3,2,1])
-%!assert (real(lookup([1:4]', [1.2, 3.5]')), [1, 3]');
-%!assert (real(lookup([1:4], [1.2, 3.5]')), [1, 3]');
-%!assert (real(lookup([1:4]', [1.2, 3.5])), [1, 3]);
-%!assert (real(lookup([1:4], [1.2, 3.5])), [1, 3]);
-%!assert (real(lookup(1:3, [3, 2, 1])), [3, 2, 1]);
-%!assert (real(lookup([3:-1:1], [3.5, 3, 1.2, 2.5, 2.5])), [0, 1, 2, 1, 1])
+%!assert (lookup(1:3, 0.5), 0)     # value before table
+%!assert (lookup(1:3, 3.5), 3)     # value after table error
+%!assert (lookup(1:3, 1.5), 1)     # value within table error
+%!assert (lookup(1:3, [3,2,1]), [3,2,1])
+%!assert (lookup([1:4]', [1.2, 3.5]'), [1, 3]');
+%!assert (lookup([1:4], [1.2, 3.5]'), [1, 3]');
+%!assert (lookup([1:4]', [1.2, 3.5]), [1, 3]);
+%!assert (lookup([1:4], [1.2, 3.5]), [1, 3]);
+%!assert (lookup(1:3, [3, 2, 1]), [3, 2, 1]);
+%!assert (lookup([3:-1:1], [3.5, 3, 1.2, 2.5, 2.5]), [0, 1, 2, 1, 1])
 %!assert (isempty(lookup([1:3], [])))
 %!assert (isempty(lookup([1:3]', [])))
-%!assert (real(lookup(1:3, [1, 2; 3, 0.5])), [1, 2; 3, 0]);
+%!assert (lookup(1:3, [1, 2; 3, 0.5]), [1, 2; 3, 0]);
+%!assert (lookup(1:4, [1, 1.2; 3, 2.5], "m"), [1, 0; 3, 0]);
+%!assert (lookup(4:-1:1, [1, 1.2; 3, 2.5], "m"), [4, 0; 2, 0]);
+%!assert (lookup(1:4, [1, 1.2; 3, 2.5], "b"), logical ([1, 0; 3, 0]));
+%!assert (lookup(4:-1:1, [1, 1.2; 3, 2.5], "b"), logical ([4, 0; 2, 0]));
 %!
-%!assert (real(lookup({"apple","lemon","orange"}, {"banana","kiwi"; "ananas","mango"})), [1,1;0,2])
-%!assert (real(lookup({"apple","lemon","orange"}, "potato")), 3)
-%!assert (real(lookup({"orange","lemon","apple"}, "potato")), 0)
+%!assert (lookup({"apple","lemon","orange"}, {"banana","kiwi"; "ananas","mango"}), [1,1;0,2])
+%!assert (lookup({"apple","lemon","orange"}, "potato"), 3)
+%!assert (lookup({"orange","lemon","apple"}, "potato"), 0)
 */
 
 /*