changeset 10764:e141bcb1befd

implement map concat optimizations for [] operator
author Jaroslav Hajek <highegg@gmail.com>
date Fri, 02 Jul 2010 10:10:51 +0200
parents b397b8edd8c5
children 3952b4c4e44a
files src/ChangeLog src/oct-map.cc src/oct-map.h src/pt-mat.cc
diffstat 4 files changed, 91 insertions(+), 2 deletions(-) [+]
line wrap: on
line diff
--- a/src/ChangeLog
+++ b/src/ChangeLog
@@ -1,3 +1,17 @@
+2010-07-02  Jaroslav Hajek  <highegg@gmail.com>
+
+	* pt-mat.cc (tm_row_const::tm_row_const_rep::all_1x1,
+	tm_cont::all_1x1): New member fields.
+	(tm_row_const::tm_row_const_rep::init, tm_const::init):
+	Handle them here.
+	(tm_row_const::all_1x1_p, tm_const::all_1x1_p): New methods.
+	(single_type_concat<MAP> (octave_map&, ...)): New template
+	overload.
+	(do_single_type_concat<octave_map>): New template specialization.
+	(tree_matrix::rvalue1): Specialize for cell and struct classes.
+	* oct-map.cc (octave_map::do_cat (..., const octave_map *, ...)): 
+	Assign result dimensions.
+
 2010-07-02  Jaroslav Hajek  <highegg@gmail.com>
 
 	* oct-map.cc (octave_map::cat (..., const octave_scalar_map *)):
--- a/src/oct-map.cc
+++ b/src/oct-map.cc
@@ -609,6 +609,8 @@
         field_list[i] = map_list[i].xvals[j];
 
       retval.xvals.push_back (Array<octave_value>::cat (dim, n, field_list));
+      if (j == 0)
+        retval.dimensions = retval.xvals[j].dims ();
     }
 }
 
--- a/src/oct-map.h
+++ b/src/oct-map.h
@@ -364,6 +364,7 @@
   // The Array-like methods.
   octave_idx_type numel (void) const { return dimensions.numel (); }
   octave_idx_type length (void) const { return numel (); }
+  bool is_empty (void) const { return dimensions.any_zero (); }
 
   octave_idx_type rows (void) const { return dimensions(0); }
   octave_idx_type cols (void) const { return dimensions(1); }
--- a/src/pt-mat.cc
+++ b/src/pt-mat.cc
@@ -70,14 +70,14 @@
         all_sq_str (false), all_dq_str (false),
         some_str (false), all_real (false), all_cmplx (false),
         all_mt (true), any_sparse (false), any_class (false),
-        class_nm (), ok (false)
+        all_1x1 (false), class_nm (), ok (false)
     { }
 
     tm_row_const_rep (const tree_argument_list& row)
       : count (1), dv (0, 0), all_str (false), all_sq_str (false),
         some_str (false), all_real (false), all_cmplx (false),
         all_mt (true), any_sparse (false), any_class (false),
-        class_nm (), ok (false)
+        all_1x1 (! row.empty ()), class_nm (), ok (false)
     { init (row); }
 
     ~tm_row_const_rep (void) { }
@@ -95,6 +95,7 @@
     bool all_mt;
     bool any_sparse;
     bool any_class;
+    bool all_1x1;
 
     std::string class_nm;
 
@@ -171,6 +172,7 @@
   bool all_empty_p (void) const { return rep->all_mt; }
   bool any_sparse_p (void) const { return rep->any_sparse; }
   bool any_class_p (void) const { return rep->any_class; }
+  bool all_1x1_p (void) const { return rep->all_1x1; }
 
   std::string class_name (void) const { return rep->class_nm; }
 
@@ -326,6 +328,8 @@
   if (!any_class && val.is_object ())
     any_class = true;
 
+  all_1x1 = all_1x1 && val.numel () == 1;
+
   return true;
 }
 
@@ -420,6 +424,7 @@
   bool all_empty_p (void) const { return all_mt; }
   bool any_sparse_p (void) const { return any_sparse; }
   bool any_class_p (void) const { return any_class; }
+  bool all_1x1_p (void) const { return all_1x1; }
 
   std::string class_name (void) const { return class_nm; }
 
@@ -438,6 +443,7 @@
   bool all_mt;
   bool any_sparse;
   bool any_class;
+  bool all_1x1;
 
   std::string class_nm;
 
@@ -462,6 +468,7 @@
   all_cmplx = true;
   any_sparse = false;
   any_class = false;
+  all_1x1 = ! empty ();
 
   bool first_elem = true;
 
@@ -507,6 +514,8 @@
           if (!any_class && tmp.any_class_p ())
             any_class = true;
 
+          all_1x1 = all_1x1 && tmp.all_1x1_p ();
+
           append (tmp);
         }
       else
@@ -681,6 +690,7 @@
     {
       // If possible, forward the operation to liboctave.
       // Single row.
+      // FIXME: optimize all scalars case.
       tm_row_const& row = tmp.front ();
       octave_idx_type ncols = row.length (), i = 0;
       OCTAVE_LOCAL_BUFFER (Array<T>, array_list, ncols);
@@ -752,6 +762,49 @@
   result = Sparse<T>::cat (0, nrows, sparse_row_list);
 }
 
+template<class MAP>
+static void 
+single_type_concat (octave_map& result,
+                    const dim_vector& dv,
+                    tm_const& tmp)
+{
+  if (dv.any_zero ())
+    {
+      result = octave_map (dv);
+      return;
+    }
+
+  octave_idx_type nrows = tmp.length (), j = 0;
+  OCTAVE_LOCAL_BUFFER (octave_map, map_row_list, nrows);
+  for (tm_const::iterator p = tmp.begin (); p != tmp.end (); p++)
+    {
+      tm_row_const row = *p;
+      octave_idx_type ncols = row.length (), i = 0;
+      OCTAVE_LOCAL_BUFFER (MAP, map_list, ncols);
+
+      for (tm_row_const::iterator q = row.begin ();
+           q != row.end () && ! error_state;
+           q++)
+        {
+          octave_quit ();
+
+          // Use 0x0 in place of all empty arrays to allow looser rules.
+          // If MAP is octave_scalar_map, the condition is vacuously true.
+          if (! q->is_empty ())
+            map_list[i] = octave_value_extract<MAP> (*q);
+          i++;
+        }
+
+      octave_map mtmp = octave_map::cat (1, ncols, map_list);
+      // Use 0x0 in place of all empty arrays to allow looser rules.
+      if (! mtmp.is_empty ())
+        map_row_list[j] = mtmp;
+      j++;
+    }
+
+  result = octave_map::cat (0, nrows, map_row_list);
+}
+
 template<class TYPE>
 static octave_value 
 do_single_type_concat (const dim_vector& dv,
@@ -764,6 +817,21 @@
   return result;
 }
 
+template<>
+octave_value 
+do_single_type_concat<octave_map> (const dim_vector& dv,
+                                   tm_const& tmp)
+{
+  octave_map result;
+
+  if (tmp.all_1x1_p ())
+    single_type_concat<octave_scalar_map> (result, dv, tmp);
+  else
+    single_type_concat<octave_map> (result, dv, tmp);
+
+  return result;
+}
+
 template<class TYPE, class OV_TYPE>
 static octave_value 
 do_single_type_concat_no_mutate (const dim_vector& dv,
@@ -933,6 +1001,10 @@
         retval = do_single_type_concat<uint32NDArray> (dv, tmp);
       else if (result_type == "uint64")
         retval = do_single_type_concat<uint64NDArray> (dv, tmp);
+      else if (result_type == "cell")
+        retval = do_single_type_concat<Cell> (dv, tmp);
+      else if (result_type == "struct")
+        retval = do_single_type_concat<octave_map> (dv, tmp);
       else
         {
           // The line below might seem crazy, since we take a copy of