diff liboctave/oct-sort.cc @ 8814:de16ebeef93d

improve lookup, provide Array<T>::lookup
author Jaroslav Hajek <highegg@gmail.com>
date Thu, 19 Feb 2009 15:19:59 +0100
parents d5af326a3ede
children a4a8f871be81
line wrap: on
line diff
--- a/liboctave/oct-sort.cc
+++ b/liboctave/oct-sort.cc
@@ -96,6 +96,7 @@
 
 #include <cassert>
 #include <algorithm>
+#include <functional>
 #include <cstring>
 #include <stack>
 
@@ -105,15 +106,22 @@
 #include "oct-locbuf.h"
 
 template <class T>
-octave_sort<T>::octave_sort (void) : compare (ascending_compare)
+octave_sort<T>::octave_sort (void) : 
+  compare (ascending_compare), ms (0)
 { 
-  merge_init ();
 }
 
 template <class T>
-octave_sort<T>::octave_sort (compare_fcn_type comp) : compare (comp) 
+octave_sort<T>::octave_sort (compare_fcn_type comp) 
+  : compare (comp), ms (0)
 { 
-  merge_init (); 
+}
+
+template <class T>
+octave_sort<T>::~octave_sort () 
+{ 
+  merge_freemem ();
+  delete ms;
 }
 
 template <class T>
@@ -476,11 +484,12 @@
 void
 octave_sort<T>::merge_init (void)
 {
-  ms.a = 0;
-  ms.ia = 0;
-  ms.alloced = 0;
-  ms.n = 0;
-  ms.min_gallop = MIN_GALLOP;
+  if (! ms) ms = new MergeState;
+  ms->a = 0;
+  ms->ia = 0;
+  ms->alloced = 0;
+  ms->n = 0;
+  ms->min_gallop = MIN_GALLOP;
 }
 
 /* Free all the temp memory owned by the MergeState.  This must be called
@@ -491,10 +500,13 @@
 void
 octave_sort<T>::merge_freemem (void)
 {
-  delete [] ms.a;
-  delete [] ms.ia;
-  ms.alloced = 0;
-  ms.a = 0;
+  if (ms)
+    {
+      delete [] ms->a;
+      delete [] ms->ia;
+      ms->alloced = 0;
+      ms->a = 0;
+    }
 }
 
 static inline octave_idx_type
@@ -541,7 +553,7 @@
 int
 octave_sort<T>::merge_getmem (octave_idx_type need)
 {
-  if (need <= ms.alloced)
+  if (need <= ms->alloced)
     return 0;
 
   need = roundupsize (need); 
@@ -549,10 +561,10 @@
    * we don't care what's in the block.
    */
   merge_freemem ();
-  ms.a = new T[need];
-  if (ms.a)
+  ms->a = new T[need];
+  if (ms->a)
     {
-      ms.alloced = need;
+      ms->alloced = need;
       return 0;
     }
   merge_freemem ();	/* reset to sane state */
@@ -564,7 +576,7 @@
 int
 octave_sort<T>::merge_getmemi (octave_idx_type need)
 {
-  if (need <= ms.alloced && ms.a && ms.ia)
+  if (need <= ms->alloced && ms->a && ms->ia)
     return 0;
 
   need = roundupsize (need); 
@@ -572,11 +584,11 @@
    * we don't care what's in the block.
    */
   merge_freemem ();
-  ms.a = new T[need];
-  ms.ia = new octave_idx_type[need];
-  if (ms.a && ms.ia)
+  ms->a = new T[need];
+  ms->ia = new octave_idx_type[need];
+  if (ms->a && ms->ia)
     {
-      ms.alloced = need;
+      ms->alloced = need;
       return 0;
     }
   merge_freemem ();	/* reset to sane state */
@@ -600,13 +612,13 @@
   octave_idx_type k;
   T *dest;
   int result = -1;	/* guilty until proved innocent */
-  octave_idx_type min_gallop = ms.min_gallop;
+  octave_idx_type min_gallop = ms->min_gallop;
 
   if (merge_getmem (na) < 0)
     return -1;
-  std::copy (pa, pa + na, ms.a);
+  std::copy (pa, pa + na, ms->a);
   dest = pa;
-  pa = ms.a;
+  pa = ms->a;
 
   *dest++ = *pb++;
   --nb;
@@ -662,7 +674,7 @@
       do
 	{
 	  min_gallop -= min_gallop > 1;
-	  ms.min_gallop = min_gallop;
+	  ms->min_gallop = min_gallop;
 	  k = gallop_right (*pb, pa, na, 0, comp);
 	  acount = k;
 	  if (k)
@@ -706,7 +718,7 @@
       while (acount >= MIN_GALLOP || bcount >= MIN_GALLOP);
 
       ++min_gallop;	/* penalize it for leaving galloping mode */
-      ms.min_gallop = min_gallop;
+      ms->min_gallop = min_gallop;
     }
 
  Succeed:
@@ -736,14 +748,14 @@
   T *dest;
   octave_idx_type *idest;
   int result = -1;	/* guilty until proved innocent */
-  octave_idx_type min_gallop = ms.min_gallop;
+  octave_idx_type min_gallop = ms->min_gallop;
 
   if (merge_getmemi (na) < 0)
     return -1;
-  std::copy (pa, pa + na, ms.a);
-  std::copy (ipa, ipa + na, ms.ia);
+  std::copy (pa, pa + na, ms->a);
+  std::copy (ipa, ipa + na, ms->ia);
   dest = pa; idest = ipa;
-  pa = ms.a; ipa = ms.ia;
+  pa = ms->a; ipa = ms->ia;
 
   *dest++ = *pb++; *idest++ = *ipb++;
   --nb;
@@ -796,7 +808,7 @@
       do
 	{
 	  min_gallop -= min_gallop > 1;
-	  ms.min_gallop = min_gallop;
+	  ms->min_gallop = min_gallop;
 	  k = gallop_right (*pb, pa, na, 0, comp);
 	  acount = k;
 	  if (k)
@@ -842,7 +854,7 @@
       while (acount >= MIN_GALLOP || bcount >= MIN_GALLOP);
 
       ++min_gallop;	/* penalize it for leaving galloping mode */
-      ms.min_gallop = min_gallop;
+      ms->min_gallop = min_gallop;
     }
 
  Succeed:
@@ -883,15 +895,15 @@
   T *dest;
   int result = -1;	/* guilty until proved innocent */
   T *basea, *baseb;
-  octave_idx_type min_gallop = ms.min_gallop;
+  octave_idx_type min_gallop = ms->min_gallop;
 
   if (merge_getmem (nb) < 0)
     return -1;
   dest = pb + nb - 1;
-  std::copy (pb, pb + nb, ms.a);
+  std::copy (pb, pb + nb, ms->a);
   basea = pa;
-  baseb = ms.a;
-  pb = ms.a + nb - 1;
+  baseb = ms->a;
+  pb = ms->a + nb - 1;
   pa += na - 1;
 
   *dest-- = *pa--;
@@ -944,7 +956,7 @@
       do 
 	{
 	  min_gallop -= min_gallop > 1;
-	  ms.min_gallop = min_gallop;
+	  ms->min_gallop = min_gallop;
 	  k = gallop_right (*pb, basea, na, na-1, comp);
 	  if (k < 0)
 	    goto Fail;
@@ -989,7 +1001,7 @@
 	    goto Succeed;
 	} while (acount >= MIN_GALLOP || bcount >= MIN_GALLOP);
       ++min_gallop;	/* penalize it for leaving galloping mode */
-      ms.min_gallop = min_gallop;
+      ms->min_gallop = min_gallop;
     }
 
 Succeed:
@@ -1022,17 +1034,17 @@
   int result = -1;	/* guilty until proved innocent */
   T *basea, *baseb;
   octave_idx_type *ibasea, *ibaseb;
-  octave_idx_type min_gallop = ms.min_gallop;
+  octave_idx_type min_gallop = ms->min_gallop;
 
   if (merge_getmemi (nb) < 0)
     return -1;
   dest = pb + nb - 1;
   idest = ipb + nb - 1;
-  std::copy (pb, pb + nb, ms.a);
-  std::copy (ipb, ipb + nb, ms.ia);
+  std::copy (pb, pb + nb, ms->a);
+  std::copy (ipb, ipb + nb, ms->ia);
   basea = pa; ibasea = ipa;
-  baseb = ms.a; ibaseb = ms.ia;
-  pb = ms.a + nb - 1; ipb = ms.ia + nb - 1;
+  baseb = ms->a; ibaseb = ms->ia;
+  pb = ms->a + nb - 1; ipb = ms->ia + nb - 1;
   pa += na - 1; ipa += na - 1;
 
   *dest-- = *pa--; *idest-- = *ipa--;
@@ -1085,7 +1097,7 @@
       do 
 	{
 	  min_gallop -= min_gallop > 1;
-	  ms.min_gallop = min_gallop;
+	  ms->min_gallop = min_gallop;
 	  k = gallop_right (*pb, basea, na, na-1, comp);
 	  if (k < 0)
 	    goto Fail;
@@ -1132,7 +1144,7 @@
 	    goto Succeed;
 	} while (acount >= MIN_GALLOP || bcount >= MIN_GALLOP);
       ++min_gallop;	/* penalize it for leaving galloping mode */
-      ms.min_gallop = min_gallop;
+      ms->min_gallop = min_gallop;
     }
 
 Succeed:
@@ -1169,19 +1181,19 @@
   octave_idx_type na, nb;
   octave_idx_type k;
 
-  pa = data + ms.pending[i].base;
-  na = ms.pending[i].len;
-  pb = data + ms.pending[i+1].base;
-  nb = ms.pending[i+1].len;
+  pa = data + ms->pending[i].base;
+  na = ms->pending[i].len;
+  pb = data + ms->pending[i+1].base;
+  nb = ms->pending[i+1].len;
 
   /* Record the length of the combined runs; if i is the 3rd-last
    * run now, also slide over the last run (which isn't involved
    * in this merge).  The current run i+1 goes away in any case.
    */
-  ms.pending[i].len = na + nb;
-  if (i == ms.n - 3)
-    ms.pending[i+1] = ms.pending[i+2];
-  --ms.n;
+  ms->pending[i].len = na + nb;
+  if (i == ms->n - 3)
+    ms->pending[i+1] = ms->pending[i+2];
+  ms->n--;
 
   /* Where does b start in a?  Elements in a before that can be
    * ignored (already in place).
@@ -1221,21 +1233,21 @@
   octave_idx_type na, nb;
   octave_idx_type k;
 
-  pa = data + ms.pending[i].base;
-  ipa = idx + ms.pending[i].base;
-  na = ms.pending[i].len;
-  pb = data + ms.pending[i+1].base;
-  ipb = idx + ms.pending[i+1].base;
-  nb = ms.pending[i+1].len;
+  pa = data + ms->pending[i].base;
+  ipa = idx + ms->pending[i].base;
+  na = ms->pending[i].len;
+  pb = data + ms->pending[i+1].base;
+  ipb = idx + ms->pending[i+1].base;
+  nb = ms->pending[i+1].len;
 
   /* Record the length of the combined runs; if i is the 3rd-last
    * run now, also slide over the last run (which isn't involved
    * in this merge).  The current run i+1 goes away in any case.
    */
-  ms.pending[i].len = na + nb;
-  if (i == ms.n - 3)
-    ms.pending[i+1] = ms.pending[i+2];
-  --ms.n;
+  ms->pending[i].len = na + nb;
+  if (i == ms->n - 3)
+    ms->pending[i+1] = ms->pending[i+2];
+  ms->n--;
 
   /* Where does b start in a?  Elements in a before that can be
    * ignored (already in place).
@@ -1279,11 +1291,11 @@
 int
 octave_sort<T>::merge_collapse (T *data, Comp comp)
 {
-  struct s_slice *p = ms.pending;
+  struct s_slice *p = ms->pending;
 
-  while (ms.n > 1) 
+  while (ms->n > 1) 
     {
-      octave_idx_type n = ms.n - 2;
+      octave_idx_type n = ms->n - 2;
       if (n > 0 && p[n-1].len <= p[n].len + p[n+1].len) 
 	{
 	  if (p[n-1].len < p[n+1].len)
@@ -1308,11 +1320,11 @@
 int
 octave_sort<T>::merge_collapse (T *data, octave_idx_type *idx, Comp comp)
 {
-  struct s_slice *p = ms.pending;
+  struct s_slice *p = ms->pending;
 
-  while (ms.n > 1) 
+  while (ms->n > 1) 
     {
-      octave_idx_type n = ms.n - 2;
+      octave_idx_type n = ms->n - 2;
       if (n > 0 && p[n-1].len <= p[n].len + p[n+1].len) 
 	{
 	  if (p[n-1].len < p[n+1].len)
@@ -1342,11 +1354,11 @@
 int
 octave_sort<T>::merge_force_collapse (T *data, Comp comp)
 {
-  struct s_slice *p = ms.pending;
+  struct s_slice *p = ms->pending;
 
-  while (ms.n > 1) 
+  while (ms->n > 1) 
     {
-      octave_idx_type n = ms.n - 2;
+      octave_idx_type n = ms->n - 2;
       if (n > 0 && p[n-1].len < p[n+1].len)
 	--n;
       if (merge_at (n, data, comp) < 0)
@@ -1361,11 +1373,11 @@
 int
 octave_sort<T>::merge_force_collapse (T *data, octave_idx_type *idx, Comp comp)
 {
-  struct s_slice *p = ms.pending;
+  struct s_slice *p = ms->pending;
 
-  while (ms.n > 1) 
+  while (ms->n > 1) 
     {
-      octave_idx_type n = ms.n - 2;
+      octave_idx_type n = ms->n - 2;
       if (n > 0 && p[n-1].len < p[n+1].len)
 	--n;
       if (merge_at (n, data, idx, comp) < 0)
@@ -1406,8 +1418,13 @@
 octave_sort<T>::sort (T *data, octave_idx_type nel, Comp comp)
 {
   /* Re-initialize the Mergestate as this might be the second time called */
-  ms.n = 0;
-  ms.min_gallop = MIN_GALLOP;
+  if (ms)
+    {
+      ms->n = 0;
+      ms->min_gallop = MIN_GALLOP;
+    }
+  else
+    merge_init ();
 
   if (nel > 1)
     {
@@ -1437,10 +1454,10 @@
 	      n = force;
 	    }
 	  /* Push run onto pending-runs stack, and maybe merge. */
-	  assert (ms.n < MAX_MERGE_PENDING);
-	  ms.pending[ms.n].base = lo;
-	  ms.pending[ms.n].len = n;
-	  ++ms.n;
+	  assert (ms->n < MAX_MERGE_PENDING);
+	  ms->pending[ms->n].base = lo;
+	  ms->pending[ms->n].len = n;
+	  ms->n++;
 	  if (merge_collapse (data, comp) < 0)
 	    goto fail;
 	  /* Advance to find next run. */
@@ -1462,10 +1479,6 @@
 octave_sort<T>::sort (T *data, octave_idx_type *idx, octave_idx_type nel, 
                       Comp comp)
 {
-  /* Re-initialize the Mergestate as this might be the second time called */
-  ms.n = 0;
-  ms.min_gallop = MIN_GALLOP;
-
   if (nel > 1)
     {
       octave_idx_type nremaining = nel; 
@@ -1497,10 +1510,10 @@
 	      n = force;
 	    }
 	  /* Push run onto pending-runs stack, and maybe merge. */
-	  assert (ms.n < MAX_MERGE_PENDING);
-	  ms.pending[ms.n].base = lo;
-	  ms.pending[ms.n].len = n;
-	  ++ms.n;
+	  assert (ms->n < MAX_MERGE_PENDING);
+	  ms->pending[ms->n].base = lo;
+	  ms->pending[ms->n].len = n;
+	  ms->n++;
 	  if (merge_collapse (data, idx, comp) < 0)
 	    goto fail;
 	  /* Advance to find next run. */
@@ -1520,7 +1533,17 @@
 void
 octave_sort<T>::sort (T *data, octave_idx_type nel)
 {
+  /* Re-initialize the Mergestate as this might be the second time called */
+  if (ms)
+    {
+      ms->n = 0;
+      ms->min_gallop = MIN_GALLOP;
+    }
+  else
+    merge_init ();
+
   merge_getmem (1024);
+
 #ifdef INLINE_ASCENDING_SORT
   if (compare == ascending_compare)
     sort (data, nel, std::less<T> ());
@@ -1539,7 +1562,17 @@
 void
 octave_sort<T>::sort (T *data, octave_idx_type *idx, octave_idx_type nel)
 {
+  /* Re-initialize the Mergestate as this might be the second time called */
+  if (ms)
+    {
+      ms->n = 0;
+      ms->min_gallop = MIN_GALLOP;
+    }
+  else
+    merge_init ();
+
   merge_getmemi (1024);
+
 #ifdef INLINE_ASCENDING_SORT
   if (compare == ascending_compare)
     sort (data, idx, nel, std::less<T> ());
@@ -1761,6 +1794,188 @@
 }
 
 
+template <class T> template <class Comp>
+octave_idx_type 
+octave_sort<T>::lookup (const T *data, octave_idx_type nel,
+                        const T& value, Comp comp)
+{
+  return std::upper_bound (data, data + nel, value, comp) - data;
+}
+
+template <class T>
+octave_idx_type 
+octave_sort<T>::lookup (const T *data, octave_idx_type nel,
+                        const T& value)
+{
+  octave_idx_type retval = 0;
+
+#ifdef INLINE_ASCENDING_SORT
+  if (compare == ascending_compare)
+    retval = lookup (data, nel, value, std::less<T> ());
+  else
+#endif
+#ifdef INLINE_DESCENDING_SORT    
+    if (compare == descending_compare)
+      retval = lookup (data, nel, value, std::greater<T> ());
+  else
+#endif
+    if (compare)
+      retval = lookup (data, nel, value, std::ptr_fun (compare));
+
+  return retval;
+}
+
+// a unary functor that checks whether a value is outside [a,b) range
+template<class T, class Comp>
+class out_of_range_pred : public std::unary_function<T, bool>
+{
+public:
+  out_of_range_pred (const T& aa, const T& bb, Comp c) 
+    : a (aa), b (bb), comp (c) { }
+  bool operator () (const T& x) { return comp (x, a) || ! comp (x, b); }
+
+private:
+  T a, b;
+  Comp comp;
+};
+
+// a unary functor that checks whether a value is < a
+template<class T, class Comp>
+class less_than_pred : public std::unary_function<T, bool>
+{
+  typedef typename ref_param<T>::type param_type;
+public:
+  less_than_pred (param_type aa, Comp c) 
+    : a (aa), comp (c) { }
+  bool operator () (const T& x) { return comp (x, a); }
+
+private:
+  T a;
+  Comp comp;
+};
+
+// a unary functor that checks whether a value is >= a
+template<class T, class Comp>
+class greater_or_equal_pred : public std::unary_function<T, bool>
+{
+public:
+  greater_or_equal_pred (const T& aa, Comp c) 
+    : a (aa), comp (c) { }
+  bool operator () (const T& x) { return ! comp (x, a); }
+
+private:
+  T a;
+  Comp comp;
+};
+
+// conveniently constructs the above functors.
+// NOTE: with SGI extensions, this one can be written as
+// compose2 (logical_and<bool>(), bind2nd (less<T>(), a),
+//           not1 (bind2nd (less<T>(), b)))
+template<class T, class Comp>
+inline out_of_range_pred<T, Comp> 
+out_of_range (const T& a, 
+              const T& b, Comp comp)
+{
+  return out_of_range_pred<T, Comp> (a, b, comp);
+}
+
+// Note: these could be written as
+//    std::not1 (std::bind2nd (comp, *cur))
+// and
+//    std::bind2nd (comp, *(cur-1)));
+// but that doesn't work for functions with reference parameters in g++ 4.3.
+template<class T, class Comp>
+inline less_than_pred<T, Comp> 
+less_than (const T& a, Comp comp)
+{
+  return less_than_pred<T, Comp> (a, comp);
+}
+template<class T, class Comp>
+inline greater_or_equal_pred<T, Comp> 
+greater_or_equal (const T& a, Comp comp)
+{
+  return greater_or_equal_pred<T, Comp> (a, comp);
+}
+
+
+template <class T> template <class Comp>
+void 
+octave_sort<T>::lookup (const T *data, octave_idx_type nel,
+                        const T *values, octave_idx_type nvalues,
+                        octave_idx_type *idx, octave_idx_type offset, Comp comp)
+{
+  if (nel == 0)
+    // the trivial case of empty table
+    std::fill_n (idx, nvalues, offset);
+  else
+    {
+      const T *vcur = values;
+      const T *vend = values + nvalues;
+
+      const T *cur = data;
+      const T *end = data + nel;
+
+      while (vcur != vend)
+        {
+          // determine the enclosing interval for next value, trying
+          // ++cur as a special case;
+          if (cur == end || comp (*vcur, *cur))
+            cur = std::upper_bound (data, cur, *vcur, comp);
+          else
+            {
+              ++cur;
+              if (cur != end && ! comp (*vcur, *cur))
+                cur = std::upper_bound (cur + 1, end, *vcur, comp);
+            }
+
+          octave_idx_type vidx = cur - data + offset;
+          // store index of the current interval.
+          *(idx++) = vidx;
+          ++vcur;
+
+          // find first value not in current subrange
+          const T *vnew;
+          if (cur != end)
+            if (cur != data)
+              // inner interval
+              vnew = std::find_if (vcur, vend,
+                                   out_of_range (*(cur-1), *cur, comp));
+            else
+              // special case: lowermost range (-Inf, min) 
+              vnew = std::find_if (vcur, vend, greater_or_equal (*cur, comp));
+          else
+            // special case: uppermost range [max, Inf)
+            vnew = std::find_if (vcur, vend, less_than (*(cur-1), comp));
+
+          // store index of the current interval.
+          std::fill_n (idx, vnew - vcur, vidx);
+          idx += (vnew - vcur);
+          vcur = vnew;
+        }
+    }
+}
+
+template <class T>
+void 
+octave_sort<T>::lookup (const T *data, octave_idx_type nel,
+                        const T* values, octave_idx_type nvalues,
+                        octave_idx_type *idx, octave_idx_type offset)
+{
+#ifdef INLINE_ASCENDING_SORT
+  if (compare == ascending_compare)
+    lookup (data, nel, values, nvalues, idx, offset, std::less<T> ());
+  else
+#endif
+#ifdef INLINE_DESCENDING_SORT    
+    if (compare == descending_compare)
+      lookup (data, nel, values, nvalues, idx, offset, std::greater<T> ());
+  else
+#endif
+    if (compare)
+      lookup (data, nel, values, nvalues, idx, offset, std::ptr_fun (compare));
+}
+
 template <class T>
 bool 
 octave_sort<T>::ascending_compare (typename ref_param<T>::type x,