diff src/data.cc @ 10716:f7f26094021b

improve cat code design in data.cc, make horzcat/vertcat more Matlab compatible
author Jaroslav Hajek <highegg@gmail.com>
date Mon, 21 Jun 2010 15:48:56 +0200
parents fbd7843974fa
children f3892d8eea9f
line wrap: on
line diff
--- a/src/data.cc
+++ b/src/data.cc
@@ -1376,17 +1376,17 @@
                     int dim)
 {
   int n_args = args.length ();
-  OCTAVE_LOCAL_BUFFER (Array<T>, array_list, n_args - 1);
-
-  for (int j = 1; j < n_args && ! error_state; j++)
+  OCTAVE_LOCAL_BUFFER (Array<T>, array_list, n_args);
+
+  for (int j = 0; j < n_args && ! error_state; j++)
     {
       octave_quit ();
 
-      array_list[j-1] = octave_value_extract<TYPE> (args(j));
+      array_list[j] = octave_value_extract<TYPE> (args(j));
     }
 
   if (! error_state)
-    result = Array<T>::cat (dim, n_args-1, array_list);
+    result = Array<T>::cat (dim, n_args, array_list);
 }
 
 template <class TYPE, class T>
@@ -1396,17 +1396,17 @@
                     int dim)
 {
   int n_args = args.length ();
-  OCTAVE_LOCAL_BUFFER (Sparse<T>, sparse_list, n_args-1);
-
-  for (int j = 1; j < n_args && ! error_state; j++)
+  OCTAVE_LOCAL_BUFFER (Sparse<T>, sparse_list, n_args);
+
+  for (int j = 0; j < n_args && ! error_state; j++)
     {
       octave_quit ();
 
-      sparse_list[j-1] = octave_value_extract<TYPE> (args(j));
+      sparse_list[j] = octave_value_extract<TYPE> (args(j));
     }
 
   if (! error_state)
-    result = Sparse<T>::cat (dim, n_args-1, sparse_list);
+    result = Sparse<T>::cat (dim, n_args, sparse_list);
 }
 
 // Dispatcher.
@@ -1422,172 +1422,168 @@
 }
 
 static octave_value
-do_cat (const octave_value_list& args, std::string fname)
+do_cat (const octave_value_list& args, int dim, std::string fname)
 {
   octave_value retval;
 
   int n_args = args.length (); 
 
-  if (n_args == 1)
+  if (n_args == 0)
     retval = Matrix ();
-  else if (n_args == 2)
-    retval = args(1);
-  else if (n_args > 2)
+  else if (n_args == 1)
+    retval = args(0);
+  else if (n_args > 1)
     {
-      octave_idx_type dim = args(0).int_value () - 1;
-
-      if (error_state)
-        {
-          error ("cat: expecting first argument to be a integer");
-          return retval;
-        }
-  
-      if (dim >= 0)
+
+      std::string result_type = args(0).class_name ();
+
+      bool all_sq_strings_p = args(0).is_sq_string ();
+      bool all_dq_strings_p = args(0).is_dq_string ();
+      bool all_real_p = args(0).is_real_type ();
+      bool any_sparse_p = args(0).is_sparse_type();
+
+      for (int i = 1; i < args.length (); i++)
         {
-          
-          std::string result_type = args(1).class_name ();
-          
-          bool all_sq_strings_p = args(1).is_sq_string ();
-          bool all_dq_strings_p = args(1).is_dq_string ();
-          bool all_real_p = args(1).is_real_type ();
-          bool any_sparse_p = args(1).is_sparse_type();
-
-          for (int i = 2; i < args.length (); i++)
-            {
-              result_type = 
-                get_concat_class (result_type, args(i).class_name ());
-
-              if (all_sq_strings_p && ! args(i).is_sq_string ())
-                all_sq_strings_p = false;
-              if (all_dq_strings_p && ! args(i).is_dq_string ())
-                all_dq_strings_p = false;
-              if (all_real_p && ! args(i).is_real_type ())
-                all_real_p = false;
-              if (!any_sparse_p && args(i).is_sparse_type ())
-                any_sparse_p = true;
+          result_type = 
+            get_concat_class (result_type, args(i).class_name ());
+
+          if (all_sq_strings_p && ! args(i).is_sq_string ())
+            all_sq_strings_p = false;
+          if (all_dq_strings_p && ! args(i).is_dq_string ())
+            all_dq_strings_p = false;
+          if (all_real_p && ! args(i).is_real_type ())
+            all_real_p = false;
+          if (!any_sparse_p && args(i).is_sparse_type ())
+            any_sparse_p = true;
+        }
+
+      if (result_type == "double")
+        {
+          if (any_sparse_p)
+            {           
+              if (all_real_p)
+                retval = do_single_type_concat<SparseMatrix> (args, dim);
+              else
+                retval = do_single_type_concat<SparseComplexMatrix> (args, dim);
             }
-
-          if (result_type == "double")
-            {
-              if (any_sparse_p)
-                {           
-                  if (all_real_p)
-                    retval = do_single_type_concat<SparseMatrix> (args, dim);
-                  else
-                    retval = do_single_type_concat<SparseComplexMatrix> (args, dim);
-                }
-              else
-                {
-                  if (all_real_p)
-                    retval = do_single_type_concat<NDArray> (args, dim);
-                  else
-                    retval = do_single_type_concat<ComplexNDArray> (args, dim);
-                }
-            }
-          else if (result_type == "single")
+          else
             {
               if (all_real_p)
-                retval = do_single_type_concat<FloatNDArray> (args, dim);
+                retval = do_single_type_concat<NDArray> (args, dim);
               else
-                retval = do_single_type_concat<FloatComplexNDArray> (args, dim);
-            }
-          else if (result_type == "char")
-            {
-              char type = all_dq_strings_p ? '"' : '\'';
-
-              maybe_warn_string_concat (all_dq_strings_p, all_sq_strings_p);
-
-              charNDArray result =  do_single_type_concat<charNDArray> (args, dim);
-
-              retval = octave_value (result, type);
-            }
-          else if (result_type == "logical")
-            {
-              if (any_sparse_p)
-                retval = do_single_type_concat<SparseBoolMatrix> (args, dim);
-              else
-                retval = do_single_type_concat<boolNDArray> (args, dim);
+                retval = do_single_type_concat<ComplexNDArray> (args, dim);
             }
-          else if (result_type == "int8")
-            retval = do_single_type_concat<int8NDArray> (args, dim);
-          else if (result_type == "int16")
-            retval = do_single_type_concat<int16NDArray> (args, dim);
-          else if (result_type == "int32")
-            retval = do_single_type_concat<int32NDArray> (args, dim);
-          else if (result_type == "int64")
-            retval = do_single_type_concat<int64NDArray> (args, dim);
-          else if (result_type == "uint8")
-            retval = do_single_type_concat<uint8NDArray> (args, dim);
-          else if (result_type == "uint16")
-            retval = do_single_type_concat<uint16NDArray> (args, dim);
-          else if (result_type == "uint32")
-            retval = do_single_type_concat<uint32NDArray> (args, dim);
-          else if (result_type == "uint64")
-            retval = do_single_type_concat<uint64NDArray> (args, dim);
+        }
+      else if (result_type == "single")
+        {
+          if (all_real_p)
+            retval = do_single_type_concat<FloatNDArray> (args, dim);
+          else
+            retval = do_single_type_concat<FloatComplexNDArray> (args, dim);
+        }
+      else if (result_type == "char")
+        {
+          char type = all_dq_strings_p ? '"' : '\'';
+
+          maybe_warn_string_concat (all_dq_strings_p, all_sq_strings_p);
+
+          charNDArray result =  do_single_type_concat<charNDArray> (args, dim);
+
+          retval = octave_value (result, type);
+        }
+      else if (result_type == "logical")
+        {
+          if (any_sparse_p)
+            retval = do_single_type_concat<SparseBoolMatrix> (args, dim);
           else
+            retval = do_single_type_concat<boolNDArray> (args, dim);
+        }
+      else if (result_type == "int8")
+        retval = do_single_type_concat<int8NDArray> (args, dim);
+      else if (result_type == "int16")
+        retval = do_single_type_concat<int16NDArray> (args, dim);
+      else if (result_type == "int32")
+        retval = do_single_type_concat<int32NDArray> (args, dim);
+      else if (result_type == "int64")
+        retval = do_single_type_concat<int64NDArray> (args, dim);
+      else if (result_type == "uint8")
+        retval = do_single_type_concat<uint8NDArray> (args, dim);
+      else if (result_type == "uint16")
+        retval = do_single_type_concat<uint16NDArray> (args, dim);
+      else if (result_type == "uint32")
+        retval = do_single_type_concat<uint32NDArray> (args, dim);
+      else if (result_type == "uint64")
+        retval = do_single_type_concat<uint64NDArray> (args, dim);
+      else
+        {
+          dim_vector  dv = args(0).dims ();
+
+          // Default concatenation.
+          bool (dim_vector::*concat_rule) (const dim_vector&, int) = &dim_vector::concat;
+
+          if (dim == -1 || dim == -2)
             {
-              dim_vector  dv = args(1).dims ();
-
-              for (int i = 2; i < args.length (); i++)
+              concat_rule = &dim_vector::hvcat;
+              dim = -dim - 1;
+            }
+
+          for (int i = 1; i < args.length (); i++)
+            {
+              if (! (dv.*concat_rule) (args(i).dims (), dim))
                 {
-                  if (! dv.concat (args(i).dims (), dim))
-                    {
-                      // Dimensions do not match. 
-                      error ("cat: dimension mismatch");
-                      return retval;
-                    }
+                  // Dimensions do not match. 
+                  error ("cat: dimension mismatch");
+                  return retval;
                 }
-              
-              // The lines below might seem crazy, since we take a copy
-              // of the first argument, resize it to be empty and then resize
-              // it to be full. This is done since it means that there is no
-              // recopying of data, as would happen if we used a single resize.
-              // It should be noted that resize operation is also significantly 
-              // slower than the do_cat_op function, so it makes sense to have
-              // an empty matrix and copy all data.
-              //
-              // We might also start with a empty octave_value using
-              //   tmp = octave_value_typeinfo::lookup_type 
-              //                                (args(1).type_name());
-              // and then directly resize. However, for some types there might
-              // be some additional setup needed, and so this should be avoided.
-
-              octave_value tmp = args (1);
-              tmp = tmp.resize (dim_vector (0,0)).resize (dv);
+            }
+
+          // The lines below might seem crazy, since we take a copy
+          // of the first argument, resize it to be empty and then resize
+          // it to be full. This is done since it means that there is no
+          // recopying of data, as would happen if we used a single resize.
+          // It should be noted that resize operation is also significantly 
+          // slower than the do_cat_op function, so it makes sense to have
+          // an empty matrix and copy all data.
+          //
+          // We might also start with a empty octave_value using
+          //   tmp = octave_value_typeinfo::lookup_type 
+          //                                (args(1).type_name());
+          // and then directly resize. However, for some types there might
+          // be some additional setup needed, and so this should be avoided.
+
+          octave_value tmp = args (0);
+          tmp = tmp.resize (dim_vector (0,0)).resize (dv);
+
+          if (error_state)
+            return retval;
+
+          int dv_len = dv.length ();
+          Array<octave_idx_type> ra_idx (dv_len, 1, 0);
+
+          for (int j = 0; j < n_args; j++)
+            {
+              // Can't fast return here to skip empty matrices as something
+              // like cat(1,[],single([])) must return an empty matrix of
+              // the right type.
+              tmp = do_cat_op (tmp, args (j), ra_idx);
 
               if (error_state)
                 return retval;
 
-              int dv_len = dv.length ();
-              Array<octave_idx_type> ra_idx (dv_len, 1, 0);
-
-              for (int j = 1; j < n_args; j++)
+              dim_vector dv_tmp = args (j).dims ();
+
+              if (dim >= dv_len)
                 {
-                  // Can't fast return here to skip empty matrices as something
-                  // like cat(1,[],single([])) must return an empty matrix of
-                  // the right type.
-                  tmp = do_cat_op (tmp, args (j), ra_idx);
-
-                  if (error_state)
-                    return retval;
-
-                  dim_vector dv_tmp = args (j).dims ();
-
-                  if (dim >= dv_len)
-                    {
-                      if (j > 1)
-                        error ("%s: indexing error", fname.c_str ());
-                      break;
-                    }
-                  else
-                    ra_idx (dim) += (dim < dv_tmp.length () ? 
-                                     dv_tmp (dim) : 1);
+                  if (j > 1)
+                    error ("%s: indexing error", fname.c_str ());
+                  break;
                 }
-              retval = tmp;
+              else
+                ra_idx (dim) += (dim < dv_tmp.length () ? 
+                                 dv_tmp (dim) : 1);
             }
+          retval = tmp;
         }
-      else
-        error ("%s: invalid dimension argument", fname.c_str ());
     }
   else
     print_usage ();
@@ -1603,15 +1599,7 @@
 @seealso{cat, vertcat}\n\
 @end deftypefn")
 {
-  octave_value_list args_tmp = args;
-  
-  int dim = 2;
-  
-  octave_value d (dim);
-  
-  args_tmp.prepend (d);
-  
-  return do_cat (args_tmp, "horzcat");
+  return do_cat (args, -2, "horzcat");
 }
 
 DEFUN (vertcat, args, ,
@@ -1622,15 +1610,7 @@
 @seealso{cat, horzcat}\n\
 @end deftypefn")
 {
-  octave_value_list args_tmp = args;
-  
-  int dim = 1;
-  
-  octave_value d (dim);
-  
-  args_tmp.prepend (d);
-  
-  return do_cat (args_tmp, "vertcat");
+  return do_cat (args, -1, "vertcat");
 }
 
 DEFUN (cat, args, ,
@@ -1681,7 +1661,26 @@
 @seealso{horzcat, vertcat}\n\
 @end deftypefn")
 {
-  return do_cat (args, "cat");
+  octave_value retval;
+
+  if (args.length () > 0)
+    {
+      int dim = args(0).int_value () - 1;
+
+      if (! error_state)
+        {
+          if (dim >= 0)
+            retval = do_cat (args.slice (1, args.length () - 1), dim, "cat");
+          else
+            error ("cat: invalid dimension specified");
+        }
+      else
+        error ("cat: expecting first argument to be a integer");
+    }
+  else
+    print_usage ();
+
+  return retval;
 }
 
 /*