changeset 10786:146a97c3bc97

some optimizations for mat2cell
author Jaroslav Hajek <highegg@gmail.com>
date Tue, 13 Jul 2010 13:40:51 +0200
parents c2041adcf234
children ac433932ce23
files src/ChangeLog src/DLD-FUNCTIONS/cellfun.cc
diffstat 2 files changed, 282 insertions(+), 144 deletions(-) [+]
line wrap: on
line diff
--- a/src/ChangeLog
+++ b/src/ChangeLog
@@ -1,3 +1,9 @@
+2010-07-13  Jaroslav Hajek  <highegg@gmail.com>
+
+	* DLD-FUNCTIONS/cellfun.cc (mat2cell_mismatch, prepare_idx,
+	do_mat2cell_2d, do_mat2cell_nd, do_mat2cell): New helper funcs.
+	(Fmat2cell): Use them here.
+
 2010-07-13  Jaroslav Hajek  <highegg@gmail.com>
 
 	* data.cc (do_sparse_diff): Use typecasts where needed.
--- a/src/DLD-FUNCTIONS/cellfun.cc
+++ b/src/DLD-FUNCTIONS/cellfun.cc
@@ -1059,6 +1059,226 @@
 
 */
 
+static bool 
+mat2cell_mismatch (const dim_vector& dv,
+                   const Array<octave_idx_type> *d, int nd)
+{
+  for (int i = 0; i < nd; i++)
+    {
+      octave_idx_type s = 0;
+      for (octave_idx_type j = 0; j < d[i].length (); j++)
+        s += d[i](j);
+
+      octave_idx_type r = i < dv.length () ? dv(i) : 1;
+
+      if (s != r)
+        {
+          error ("mat2cell: mismatch on %d-th dimension (%d != %d)",
+                 i+1, r, s);
+          return true;
+        }
+    }
+
+  return false;
+}
+
+template<class container>
+static void 
+prepare_idx (container *idx, int idim, int nd,
+             const Array<octave_idx_type>* d)
+{
+  octave_idx_type nidx = idim < nd ? d[idim].numel () : 1;
+  if (nidx == 1)
+    idx[0] = idx_vector::colon;
+  else
+    {
+      octave_idx_type l = 0;
+      for (octave_idx_type i = 0; i < nidx; i++)
+        {
+          octave_idx_type u = l + d[idim](i);
+          idx[i] = idx_vector (l, u);
+          l = u;
+        }
+    }
+}
+
+// 2D specialization, works for Array, Sparse and octave_map.
+// Uses 1D or 2D indexing.
+
+template <class Array2D>
+static Cell
+do_mat2cell_2d (const Array2D& a, const Array<octave_idx_type> *d, int nd)
+{
+  NoAlias<Cell> retval;
+  assert (nd == 1 || nd == 2);
+  assert (a.ndims () == 2);
+
+  if (mat2cell_mismatch (a.dims (), d, nd))
+    return retval;
+
+  octave_idx_type nridx = d[0].length ();
+  octave_idx_type ncidx = nd == 1 ? 1 : d[1].length ();
+  retval.clear (nridx, ncidx);
+
+  int ivec = -1;
+  if (a.rows () > 1 && a.cols () == 1 && ncidx == 1)
+    ivec = 0;
+  else if (a.rows () == 1 && nridx == 1 && nd == 2)
+    ivec = 1;
+
+  if (ivec >= 0)
+    {
+      // Vector split. Use 1D indexing.
+      octave_idx_type l = 0, nidx = (ivec == 0 ? nridx : ncidx);
+      for (octave_idx_type i = 0; i < nidx; i++)
+        {
+          octave_idx_type u = l + d[ivec](i);
+          retval(i) = a.index (idx_vector (l, u));
+          l = u;
+        }
+    }
+  else
+    {
+      // General 2D case. Use 2D indexing.
+      OCTAVE_LOCAL_BUFFER (idx_vector, ridx, nridx);
+      prepare_idx (ridx, 0, nd, d);
+
+      OCTAVE_LOCAL_BUFFER (idx_vector, cidx, ncidx);
+      prepare_idx (cidx, 1, nd, d);
+
+      for (octave_idx_type j = 0; j < ncidx; j++)
+        for (octave_idx_type i = 0; i < nridx; i++)
+          {
+            octave_quit ();
+
+            retval(i,j) = a.index (ridx[i], cidx[j]);
+          }
+    }
+
+  return retval;
+}
+
+// Nd case. Works for Arrays and octave_map.
+// Uses Nd indexing.
+
+template <class ArrayND>
+Cell
+do_mat2cell_nd (const ArrayND& a, const Array<octave_idx_type> *d, int nd)
+{
+  NoAlias<Cell> retval;
+  assert (nd >= 1);
+
+  if (mat2cell_mismatch (a.dims (), d, nd))
+    return retval;
+
+  dim_vector rdv = dim_vector::alloc (nd);
+  OCTAVE_LOCAL_BUFFER (octave_idx_type, nidx, nd);
+  octave_idx_type idxtot = 0;
+  for (int i = 0; i < nd; i++)
+    {
+      rdv(i) = nidx[i] = d[i].length ();
+      idxtot += nidx[i];
+    }
+
+  retval.clear (rdv);
+
+  OCTAVE_LOCAL_BUFFER (idx_vector, xidx, idxtot);
+  OCTAVE_LOCAL_BUFFER (idx_vector *, idx, nd);
+
+  idxtot = 0;
+  for (int i = 0; i < nd; i++)
+    {
+      idx[i] = xidx + idxtot;
+      prepare_idx (idx[i], i, nd, d);
+      idxtot += nidx[i];
+    }
+
+  OCTAVE_LOCAL_BUFFER_INIT (octave_idx_type, ridx, nd, 0);
+  NoAlias< Array<idx_vector> > ra_idx (1, std::max (nd, a.ndims ()),
+                                       idx_vector::colon);
+
+  for (octave_idx_type j = 0; j < retval.numel (); j++)
+    {
+      octave_quit ();
+
+      for (int i = 0; i < nd; i++)
+        ra_idx(i) = idx[i][ridx[i]];
+
+      retval(j) = a.index (ra_idx);
+
+      rdv.increment_index (ridx);
+    }
+
+  return retval;
+}
+
+// Dispatcher.
+template <class ArrayND>
+Cell
+do_mat2cell (const ArrayND& a, const Array<octave_idx_type> *d, int nd)
+{
+  if (a.ndims () == 2 && nd <= 2)
+    return do_mat2cell_2d (a, d, nd);
+  else
+    return do_mat2cell_nd (a, d, nd);
+}
+
+// General case. Works for any class supporting do_index_op.
+// Uses Nd indexing.
+
+Cell
+do_mat2cell (octave_value& a, const Array<octave_idx_type> *d, int nd)
+{
+  NoAlias<Cell> retval;
+  assert (nd >= 1);
+
+  if (mat2cell_mismatch (a.dims (), d, nd))
+    return retval;
+
+  dim_vector rdv = dim_vector::alloc (nd);
+  OCTAVE_LOCAL_BUFFER (octave_idx_type, nidx, nd);
+  octave_idx_type idxtot = 0;
+  for (int i = 0; i < nd; i++)
+    {
+      rdv(i) = nidx[i] = d[i].length ();
+      idxtot += nidx[i];
+    }
+
+  retval.clear (rdv);
+
+  OCTAVE_LOCAL_BUFFER (octave_value, xidx, idxtot);
+  OCTAVE_LOCAL_BUFFER (octave_value *, idx, nd);
+
+  idxtot = 0;
+  for (int i = 0; i < nd; i++)
+    {
+      idx[i] = xidx + idxtot;
+      prepare_idx (idx[i], i, nd, d);
+      idxtot += nidx[i];
+    }
+
+  OCTAVE_LOCAL_BUFFER_INIT (octave_idx_type, ridx, nd, 0);
+  octave_value_list ra_idx (std::max (nd, a.ndims ()), 
+                            octave_value::magic_colon_t);
+
+  for (octave_idx_type j = 0; j < retval.numel (); j++)
+    {
+      octave_quit ();
+
+      for (int i = 0; i < nd; i++)
+        ra_idx(i) = idx[i][ridx[i]];
+
+      retval(j) = a.do_index_op (ra_idx);
+
+      if (error_state)
+        break;
+
+      rdv.increment_index (ridx);
+    }
+
+  return retval;
+}
+
 DEFUN_DLD (mat2cell, args, ,
   "-*- texinfo -*-\n\
 @deftypefn {Loadable Function} {@var{b} =} mat2cell (@var{a}, @var{m}, @var{n})\n\
@@ -1108,158 +1328,70 @@
     print_usage ();
   else
     {
-      dim_vector dv = args(0).dims();
-      dim_vector new_dv;
-      new_dv.resize(dv.length());
-      
-      if (nargin > 2)
-        {
-          octave_idx_type nmax = -1;
-
-          if (nargin - 1 != dv.length())
-            error ("mat2cell: Incorrect number of dimensions");
-          else
-            {
-              for (octave_idx_type j = 0; j < dv.length(); j++)
-                {
-                  ColumnVector d = ColumnVector (args(j+1).vector_value 
-                                                 (false, true));
-
-                  if (d.length() < 1)
-                    {
-                      error ("mat2cell: dimension can not be empty");
-                      break;
-                    }
-                  else
-                    {
-                      if (nmax < d.length())
-                        nmax = d.length();
-
-                      for (octave_idx_type i = 1; i < d.length(); i++)
-                        {
-                          OCTAVE_QUIT;
+      // Prepare indices.
+      OCTAVE_LOCAL_BUFFER (Array<octave_idx_type>, d, nargin-1);
 
-                          if (d(i) >= 0)
-                            d(i) += d(i-1);
-                          else
-                            {
-                              error ("mat2cell: invalid dimensional argument");
-                              break;
-                            }
-                        }
-
-                      if (d(0) < 0)
-                        error ("mat2cell: invalid dimensional argument");
-                      
-                      if (d(d.length() - 1) != dv(j))
-                        error ("mat2cell: inconsistent dimensions");
-
-                      if (error_state)
-                        break;
+      for (int i = 1; i < nargin; i++)
+        {
+          d[i-1] = args(i).octave_idx_type_vector_value (true);
+          if (error_state)
+            return retval;
+        }
 
-                      new_dv(j) = d.length();
-                    }
-                }
-            }
-
-          if (! error_state)
-            {
-              // Construct a matrix with the index values
-              Matrix dimargs(nmax, new_dv.length());
-              for (octave_idx_type j = 0; j < new_dv.length(); j++)
-                {
-                  OCTAVE_QUIT;
-
-                  ColumnVector d = ColumnVector (args(j+1).vector_value 
-                                                 (false, true));
-
-                  dimargs(0,j) = d(0);
-                  for (octave_idx_type i = 1; i < d.length(); i++)
-                    dimargs(i,j) = dimargs(i-1,j) + d(i);
-                }
-
-
-              octave_value_list lst (new_dv.length(), octave_value());
-              Cell ret (new_dv);
-              octave_idx_type nel = new_dv.numel();
-              octave_idx_type ntot = 1;
+      octave_value a = args(0);
+      bool sparse = a.is_sparse_type ();
+      if (sparse && nargin > 3)
+        {
+          error ("mat2cell: sparse arguments only support 2D indexing");
+          return retval;
+        }
 
-              for (int j = 0; j < new_dv.length()-1; j++)
-                ntot *= new_dv(j);
-
-              for (octave_idx_type i = 0; i <  nel; i++)
-                {
-                  octave_idx_type n = ntot;
-                  octave_idx_type ii = i;
-                  for (octave_idx_type j =  new_dv.length() - 1;  j >= 0; j--)
-                    {
-                      OCTAVE_QUIT;
-                  
-                      octave_idx_type idx = ii / n;
-                      lst (j) = Range((idx == 0 ? 1. : dimargs(idx-1,j)+1.),
-                                      dimargs(idx,j));
-                      ii = ii % n;
-                      if (j != 0)
-                        n /= new_dv(j-1);
-                    }
-                  ret(i) = octave_value(args(0)).do_index_op(lst, 0);
-                  if (error_state)
-                    break;
-                }
-          
-              if (!error_state)
-                retval = ret;
-            }
-        }
-      else
+      switch (a.builtin_type ())
         {
-          ColumnVector d = ColumnVector (args(1).vector_value 
-                                         (false, true));
-
-          double sumd = 0.;
-          for (octave_idx_type i = 0; i < d.length(); i++)
-            {
-              OCTAVE_QUIT;
+        case btyp_double:
+          {
+            if (sparse)
+              retval = do_mat2cell_2d (a.sparse_matrix_value (), d, nargin-1);
+            else
+              retval = do_mat2cell (a.array_value (), d, nargin - 1);
+            break;
+          }
+        case btyp_complex:
+          {
+            if (sparse)
+              retval = do_mat2cell_2d (a.sparse_complex_matrix_value (), d, nargin-1);
+            else
+              retval = do_mat2cell (a.complex_array_value (), d, nargin - 1);
+            break;
+          }
+#define BTYP_BRANCH(X,Y) \
+        case btyp_ ## X: \
+            retval = do_mat2cell (a.Y ## _value (), d, nargin - 1); \
+          break
 
-              if (d(i) >= 0)
-                sumd += d(i);
-              else
-                {
-                  error ("mat2cell: invalid dimensional argument");
-                  break;
-                }
-            }
-
-          if (sumd != dv(0))
-            error ("mat2cell: inconsistent dimensions");
-
-          new_dv(0) = d.length();
-          for (octave_idx_type i = 1; i < dv.length(); i++)
-            new_dv(i) = 1;
+        BTYP_BRANCH (float, float_array);
+        BTYP_BRANCH (float_complex, float_complex_array);
+        BTYP_BRANCH (bool, bool_array);
+        BTYP_BRANCH (char, char_array);
 
-          if (! error_state)
-            {
-              octave_value_list lst (new_dv.length(), octave_value());
-              Cell ret (new_dv);
+        BTYP_BRANCH (int8,  int8_array);
+        BTYP_BRANCH (int16, int16_array);
+        BTYP_BRANCH (int32, int32_array);
+        BTYP_BRANCH (int64, int64_array);
+        BTYP_BRANCH (uint8,  uint8_array);
+        BTYP_BRANCH (uint16, uint16_array);
+        BTYP_BRANCH (uint32, uint32_array);
+        BTYP_BRANCH (uint64, uint64_array);
 
-              for (octave_idx_type i = 1; i < new_dv.length(); i++)
-                lst (i) = Range (1., static_cast<double>(dv(i)));
-              
-              double idx = 0.;
-              for (octave_idx_type i = 0; i <  new_dv(0); i++)
-                {
-                  OCTAVE_QUIT;
+        BTYP_BRANCH (cell, cell);
+        BTYP_BRANCH (struct, map);
+#undef BTYP_BRANCH
 
-                  lst(0) = Range(idx + 1., idx + d(i));
-                  ret(i) = octave_value(args(0)).do_index_op(lst, 0);
-                  idx += d(i);
-                  if (error_state)
-                    break;
-                }
-          
-              if (!error_state)
-                retval = ret;
-            }
+        case btyp_func_handle:
+          gripe_wrong_type_arg ("mat2cell", a);
+          break;
+        default:
+          retval = do_mat2cell (a, d, nargin-1);
         }
     }