changeset 9678:c929f09457b7

rewrite num2cell for speed-up + a few associated fixes
author Jaroslav Hajek <highegg@gmail.com>
date Thu, 01 Oct 2009 14:07:06 +0200
parents 8cf522ce9c4d
children 0896714301e4
files liboctave/Array.cc liboctave/ChangeLog src/ChangeLog src/DLD-FUNCTIONS/cellfun.cc
diffstat 4 files changed, 201 insertions(+), 64 deletions(-) [+]
line wrap: on
line diff
--- a/liboctave/Array.cc
+++ b/liboctave/Array.cc
@@ -592,6 +592,8 @@
   // Need this array to check for identical elements in permutation array.
   OCTAVE_LOCAL_BUFFER_INIT (bool, checked, perm_vec_len, false);
 
+  bool identity = true;
+
   // Find dimension vector of permuted array.
   for (int i = 0; i < perm_vec_len; i++)
     {
@@ -614,11 +616,17 @@
 	  return retval;
 	}
       else
-	checked[perm_elt] = true;
+        {
+          checked[perm_elt] = true;
+          identity = identity && perm_elt == i;
+        }
 
       dv_new(i) = dv(perm_elt);
     }
 
+  if (identity)
+    return *this;
+
   if (inv)
     {
       for (int i = 0; i < perm_vec_len; i++)
--- a/liboctave/ChangeLog
+++ b/liboctave/ChangeLog
@@ -1,3 +1,7 @@
+2009-10-01  Jaroslav Hajek  <highegg@gmail.com>
+
+	* Array.cc (Array<T>::permute): Fast case identity permutation.
+
 2009-09-27  Jaroslav Hajek  <highegg@gmail.com>
 
 	* oct-cmplx.h: Fix complex-real orderings.
--- a/src/ChangeLog
+++ b/src/ChangeLog
@@ -1,3 +1,10 @@
+2009-10-01  Jaroslav Hajek  <highegg@gmail.com>
+
+	* DLD-FUNCTIONS/cellfun.cc 
+	(do_num2cell_helper, do_num2cell): New funcs.
+	(Fnum2cell): Rewrite.
+	(do_cellslices_nda): Do not leave trailing dims.
+
 2009-09-30  John W. Eaton  <jwe@octave.org>
 
 	* error.cc (error_1, pr_where_2, handle_message):
--- a/src/DLD-FUNCTIONS/cellfun.cc
+++ b/src/DLD-FUNCTIONS/cellfun.cc
@@ -40,6 +40,7 @@
 #include "variables.h"
 #include "ov-colon.h"
 #include "unwind-prot.h"
+#include "gripes.h"
 
 // Rationale:
 // The octave_base_value::subsasgn method carries too much overhead for
@@ -905,6 +906,91 @@
 
 */
 
+static void
+do_num2cell_helper (const dim_vector& dv,
+                    const Array<int>& dimv,
+                    dim_vector& celldv, dim_vector& arraydv,
+                    Array<int>& perm)
+{
+  int dvl = dimv.length ();
+  int maxd = dv.length ();
+  celldv = dv;
+  for (int i = 0; i < dvl; i++)
+    maxd = std::max (maxd, dimv(i));
+  if (maxd > dv.length ())
+    celldv.resize (maxd, 1);
+  arraydv = celldv;
+
+  OCTAVE_LOCAL_BUFFER_INIT (bool, sing, maxd, false);
+
+  perm.clear (maxd);
+  for (int i = 0; i < dvl; i++)
+    {
+      int k = dimv(i) - 1;
+      if (k < 0)
+        {
+          error ("num2cell: dimension indices must be positive");
+          return;
+        }
+      else if (i > 0 && k < dimv(i-1) - 1)
+        {
+          error ("num2cell: dimension indices must be strictly increasing");
+          return;
+        }
+
+      sing[k] = true;
+      perm(i) = k;
+    }
+
+  for (int k = 0, i = dvl; k < maxd; k++)
+    if (! sing[k])
+      perm(i++) = k;
+
+  for (int i = 0; i < maxd; i++)
+    if (sing[i])
+      celldv(i) = 1;
+    else
+      arraydv(i) = 1;
+}
+
+template<class NDA>
+static Cell
+do_num2cell (const NDA& array, const Array<int>& dimv)
+{
+  if (dimv.is_empty ())
+    {
+      Cell retval (array.dims ());
+      octave_idx_type nel = array.numel ();
+      for (octave_idx_type i = 0; i < nel; i++)
+        retval.xelem (i) = array(i);
+
+      return retval;
+    }
+  else
+    {
+      dim_vector celldv, arraydv;
+      Array<int> perm;
+      do_num2cell_helper (array.dims (), dimv, celldv, arraydv, perm);
+      if (error_state)
+        return Cell ();
+
+      NDA parray = array.permute (perm);
+
+      octave_idx_type nela = arraydv.numel (), nelc = celldv.numel ();
+      parray = parray.reshape (dim_vector (nela, nelc));
+
+      Cell retval (celldv);
+      for (octave_idx_type i = 0; i < nelc; i++)
+        {
+          NDA tmp (parray.index (idx_vector::colon, idx_vector (i)));
+          retval.xelem (i) = tmp.reshape (arraydv);
+        }
+
+      return retval;
+    }
+}
+
+
 DEFUN_DLD (num2cell, args, ,
   "-*- texinfo -*-\n\
 @deftypefn  {Loadable Function} {@var{c} =} num2cell (@var{m})\n\
@@ -922,72 +1008,83 @@
     print_usage ();
   else
     {
-      dim_vector dv = args(0).dims ();
-      Array<int> sings;
-
-      if (nargin == 2)
-	{
-	  ColumnVector dsings = ColumnVector (args(1).vector_value 
-						  (false, true));
-	  sings.resize (dsings.length());
-
-	  if (!error_state)
-	    for (octave_idx_type i = 0; i < dsings.length(); i++)
-	      if (dsings(i) > dv.length() || dsings(i) < 1 ||
-		  D_NINT(dsings(i)) != dsings(i))
-		{
-		  error ("invalid dimension specified");
-		  break;
-		}
-	      else
-		sings(i) = NINT(dsings(i)) - 1;
-	}
-
-      if (! error_state)
-	{
-	  Array<bool> idx_colon (dv.length());
-	  dim_vector new_dv (dv);
-	  octave_value_list lst (new_dv.length(), octave_value());
+      octave_value array = args(0);
+      Array<int> dimv;
+      if (nargin > 1)
+        dimv = args (1).int_vector_value (true);
 
-	  for (int i = 0; i < dv.length(); i++)
-	    {
-	      idx_colon(i) = false;
-	      for (int j = 0; j < sings.length(); j++)
-		{
-		  if (sings(j) == i)
-		    {
-		      new_dv(i) = 1;
-		      idx_colon(i) = true;
-		      lst(i) = octave_value (octave_value::magic_colon_t); 
-		      break;
-		    }
-		}
-	    }
-
-	  Cell ret (new_dv);
-	  octave_idx_type nel = new_dv.numel();
-	  octave_idx_type ntot = 1;
+      if (error_state)
+        ;
+      else if (array.is_bool_type ())
+        retval = do_num2cell (array.bool_array_value (), dimv);
+      else if (array.is_char_matrix ())
+        retval = do_num2cell (array.char_array_value (), dimv);
+      else if (array.is_numeric_type ())
+        {
+          if (array.is_integer_type ())
+            {
+              if (array.is_int8_type ())
+                retval = do_num2cell (array.int8_array_value (), dimv);
+              else if (array.is_int16_type ())
+                retval = do_num2cell (array.int16_array_value (), dimv);
+              else if (array.is_int32_type ())
+                retval = do_num2cell (array.int32_array_value (), dimv);
+              else if (array.is_int64_type ())
+                retval = do_num2cell (array.int64_array_value (), dimv);
+              else if (array.is_uint8_type ())
+                retval = do_num2cell (array.uint8_array_value (), dimv);
+              else if (array.is_uint16_type ())
+                retval = do_num2cell (array.uint16_array_value (), dimv);
+              else if (array.is_uint32_type ())
+                retval = do_num2cell (array.uint32_array_value (), dimv);
+              else if (array.is_uint64_type ())
+                retval = do_num2cell (array.uint64_array_value (), dimv);
+            }
+          else if (array.is_complex_type ())
+            {
+              if (array.is_single_type ())
+                retval = do_num2cell (array.float_complex_array_value (), dimv);
+              else
+                retval = do_num2cell (array.complex_array_value (), dimv);
+            }
+          else
+            {
+              if (array.is_single_type ())
+                retval = do_num2cell (array.float_array_value (), dimv);
+              else
+                retval = do_num2cell (array.array_value (), dimv);
+            }
+        }
+      else if (array.is_cell () || array.is_map ())
+        {
+          dim_vector celldv, arraydv;
+          Array<int> perm;
+          do_num2cell_helper (array.dims (), dimv, celldv, arraydv, perm);
 
-	  for (int j = 0; j < new_dv.length()-1; j++)
-	    ntot *= new_dv(j);
+          if (! error_state)
+            {
+              // FIXME: this operation may be rather inefficient.
+              octave_value parray = array.permute (perm);
+
+              octave_idx_type nela = arraydv.numel (), nelc = celldv.numel ();
+              parray = parray.reshape (dim_vector (nela, nelc));
+
+              Cell retcell (celldv);
+              octave_value_list idx (2);
+              idx(0) = octave_value::magic_colon_t;
 
-	  for (octave_idx_type i = 0; i <  nel; i++)
-	    {
-	      octave_idx_type n = ntot;
-	      octave_idx_type ii = i;
-	      for (int j = new_dv.length() - 1; j >= 0 ; j--)
-		{
-		  if (! idx_colon(j))
-		    lst (j) = ii/n + 1;
-		  ii = ii % n;
-		  if (j != 0)
-		    n /= new_dv(j-1);
-		}
-	      ret(i) = args(0).do_index_op(lst, 0);
-	    }
+              for (octave_idx_type i = 0; i < nelc; i++)
+                {
+                  idx(1) = i + 1;
+                  octave_value tmp = parray.do_index_op (idx);
+                  retcell(i) = tmp.reshape (arraydv);
+                }
 
-	  retval = ret;
-	}
+              retval = retcell;
+            }
+        }
+      else
+        gripe_wrong_type_arg ("num2cell", array);
     }
 
   return retval;
@@ -1224,7 +1321,7 @@
 */
 
 template <class NDA>
-Cell 
+static Cell 
 do_cellslices_nda (const NDA& array, const idx_vector& lb, const idx_vector& ub)
 {
   octave_idx_type n = lb.length (0);
@@ -1242,7 +1339,9 @@
       for (octave_idx_type i = 0; i < n && ! error_state; i++)
         {
           // Do it with a single index to speed things up.
+          dv = array.dims ();
           dv(dv.length () - 1) = ub(i) + 1 - lb(i);
+          dv.chop_trailing_singletons ();
           retval(i) = array.index (idx_vector (nl*lb(i), nl*(ub(i) + 1))).reshape (dv);
         }
     }
@@ -1293,6 +1392,25 @@
                     retcell = do_cellslices_nda (x.bool_array_value (), lb, ub);
                   else if (x.is_char_matrix ())
                     retcell = do_cellslices_nda (x.char_array_value (), lb, ub);
+                  else if (x.is_integer_type ())
+                    {
+                      if (x.is_int8_type ())
+                        retcell = do_cellslices_nda (x.int8_array_value (), lb, ub);
+                      else if (x.is_int16_type ())
+                        retcell = do_cellslices_nda (x.int16_array_value (), lb, ub);
+                      else if (x.is_int32_type ())
+                        retcell = do_cellslices_nda (x.int32_array_value (), lb, ub);
+                      else if (x.is_int64_type ())
+                        retcell = do_cellslices_nda (x.int64_array_value (), lb, ub);
+                      else if (x.is_uint8_type ())
+                        retcell = do_cellslices_nda (x.uint8_array_value (), lb, ub);
+                      else if (x.is_uint16_type ())
+                        retcell = do_cellslices_nda (x.uint16_array_value (), lb, ub);
+                      else if (x.is_uint32_type ())
+                        retcell = do_cellslices_nda (x.uint32_array_value (), lb, ub);
+                      else if (x.is_uint64_type ())
+                        retcell = do_cellslices_nda (x.uint64_array_value (), lb, ub);
+                    }
                   else if (x.is_complex_type ())
                     {
                       if (x.is_single_type ())