diff liboctave/Array-util.cc @ 9479:d9716e3ee0dd

supply optimized compiled sub2ind & ind2sub
author Jaroslav Hajek <highegg@gmail.com>
date Mon, 03 Aug 2009 15:52:40 +0200
parents 864805896876
children b096d11237be
line wrap: on
line diff
--- a/liboctave/Array-util.cc
+++ b/liboctave/Array-util.cc
@@ -27,6 +27,7 @@
 #include "Array-util.h"
 #include "dim-vector.h"
 #include "lo-error.h"
+#include "oct-locbuf.h"
 
 bool
 index_in_bounds (const Array<octave_idx_type>& ra_idx,
@@ -475,6 +476,130 @@
   return rdv;
 }
 
+// A helper class.
+struct sub2ind_helper
+{
+  octave_idx_type *ind, n;
+  sub2ind_helper (octave_idx_type *_ind, octave_idx_type _n)
+    : ind(_ind), n(_n) { }
+  void operator ()(octave_idx_type k) { (*ind++ *= n) += k; }
+};
+
+idx_vector sub2ind (const dim_vector& dv, const Array<idx_vector>& idxa)
+{
+  idx_vector retval;
+  octave_idx_type len = idxa.length ();
+
+  if (len >= 2)
+    {
+      const dim_vector dvx = dv.redim (len);
+      bool all_ranges = true;
+      octave_idx_type clen = -1;
+
+      for (octave_idx_type i = 0; i < len; i++)
+        {
+          idx_vector idx = idxa(i);
+          octave_idx_type n = dvx(i);
+
+          all_ranges = all_ranges && idx.is_range ();
+          if (clen < 0)
+            clen = idx.length (n);
+          else if (clen != idx.length (n))
+            current_liboctave_error_handler ("sub2ind: lengths of indices must match");
+
+          if (idx.extent (n) > n)
+            current_liboctave_error_handler ("sub2ind: index out of range");
+        }
+
+      if (clen == 1)
+        {
+          // All scalars case - the result is a scalar.
+          octave_idx_type idx = idxa(len-1)(0);
+          for (octave_idx_type i = len - 2; i >= 0; i--)
+            idx = idx * dvx(i) + idxa(i)(0);
+          retval = idx_vector (idx);
+        }
+      else if (all_ranges && clen != 0)
+        {
+          // All ranges case - the result is a range.
+          octave_idx_type start = 0, step = 0;
+          for (octave_idx_type i = len - 1; i >= 0; i--)
+            {
+              octave_idx_type xstart = idxa(i)(0), xstep = idxa(i)(1) - xstart;
+              start = start * dvx(i) + xstart;
+              step = step * dvx(i) + xstep;
+            }
+          retval = idx_vector::make_range (start, step, clen);
+        }
+      else
+        {
+          Array<octave_idx_type> idx (idxa(0).orig_dimensions ());
+          octave_idx_type *idx_vec = idx.fortran_vec ();
+
+          for (octave_idx_type i = len - 1; i >= 0; i--)
+            {
+              if (i < len - 1)
+                idxa(i).loop (clen, sub2ind_helper (idx_vec, dvx(i)));
+              else
+                idxa(i).copy_data (idx_vec);
+            }
+
+          retval = idx_vector (idx);
+        }
+    }
+  else
+    current_liboctave_error_handler ("sub2ind: needs at least 2 indices");
+
+  return retval;
+}
+
+Array<idx_vector> ind2sub (const dim_vector& dv, const idx_vector& idx)
+{
+  octave_idx_type len = idx.length (0), n = dv.length ();
+  Array<idx_vector> retval(n);
+  octave_idx_type numel = dv.numel ();
+
+  if (idx.extent (numel) > numel)
+    current_liboctave_error_handler ("ind2sub: index out of range");
+  else
+    {
+      if (idx.is_scalar ())
+        {
+          octave_idx_type k = idx(0);
+          for (octave_idx_type j = 0; j < n; j++)
+            {
+              retval(j) = k % dv(j);
+              k /= dv(j);
+            }
+        }
+      else
+        {
+          OCTAVE_LOCAL_BUFFER (Array<octave_idx_type>, rdata, n);
+
+          dim_vector odv = idx.orig_dimensions ();
+          for (octave_idx_type j = 0; j < n; j++)
+            rdata[j] = Array<octave_idx_type> (odv);
+
+          for (octave_idx_type i = 0; i < len; i++)
+            {
+              octave_idx_type k = idx(i);
+              for (octave_idx_type j = 0; j < n; j++)
+                {
+                  rdata[j](i) = k % dv(j);
+                  k /= dv(j);
+                }
+            }
+
+          for (octave_idx_type j = 0; j < n; j++)
+            retval(j) = rdata[j];
+        }
+
+
+    }
+
+  return retval;
+}
+
 int
 permute_vector_compare (const void *a, const void *b)
 {