changeset 18978:dccbc8bff5cb stable

Fix binmap for sparse-scalar or scalar-sparse operations (bug #40813). * oct-binmap.h (binmap (Sparse, Scalar), binmap (Scalar, Sparse)): Check that function is sparsity preserving before using sparse algorithm. Initialize retval row, column indices from original sparse array. Call maybe_compress (true) to remove zero elements that function may have produced. * data.cc (Fatan2): Only preserve sparsity if *first* argument is sparse. Add %!tests to verify sparse operations. * data.cc (do_hypot): Call binmap with sparse inputs whenever at least one of the inputs is sparse. * data.cc (Fhypot): Add %!tests to verify sparse operations. * data.cc (Frem): Call binmap with sparse inputs whenever at least one of the inputs is sparse. Add %!tests to verify sparse operations.
author Rik <rik@octave.org>
date Sun, 01 Jun 2014 21:41:58 -0700
parents 658d23da2c46
children 59975c3cea6b 322eb69e30ad
files libinterp/corefcn/data.cc liboctave/util/oct-binmap.h
diffstat 2 files changed, 150 insertions(+), 47 deletions(-) [+]
line wrap: on
line diff
--- a/libinterp/corefcn/data.cc
+++ b/libinterp/corefcn/data.cc
@@ -244,12 +244,9 @@
         }
       else
         {
-          bool a0_scalar = args(0).is_scalar_type ();
-          bool a1_scalar = args(1).is_scalar_type ();
-          if (a0_scalar && a1_scalar)
+          if (args(0).is_scalar_type () && args(1).is_scalar_type ())
             retval = atan2 (args(0).scalar_value (), args(1).scalar_value ());
-          else if ((a0_scalar || args(0).is_sparse_type ())
-                   && (a1_scalar || args(1).is_sparse_type ()))
+          else if (args(0).is_sparse_type ())
             {
               SparseMatrix m0 = args(0).sparse_matrix_value ();
               SparseMatrix m1 = args(1).sparse_matrix_value ();
@@ -292,6 +289,40 @@
 %! x = single ([1, 3, 1, 1, 1, 1, 3, 1]);
 %! assert (atan2 (y, x), v, sqrt (eps ("single")));
 
+## Test sparse implementations
+%!shared xs
+%! xs = sparse (0:3);
+%!test
+%! y = atan2 (1, xs);
+%! assert (issparse (y), false);
+%! assert (nnz (y), 4);
+%! assert (y, atan2 (1, 0:3));
+%!test
+%! y = atan2 (0, xs);
+%! assert (issparse (y), false);
+%! assert (nnz (y), 0);
+%! assert (y, zeros (1,4));
+%!test
+%! y = atan2 (xs, 1);
+%! assert (issparse (y));
+%! assert (nnz (y), 3);
+%! assert (y, sparse (atan2 (0:3, 1)));
+%!test
+%! y = atan2 (xs, 0);
+%! assert (issparse (y));
+%! assert (nnz (y), 3);
+%! assert (y, sparse (atan2 (0:3, 0)));
+%!test
+%! y = atan2 (xs, sparse (ones (1, 4)));
+%! assert (issparse (y));
+%! assert (nnz (y), 3);
+%! assert (y, sparse (atan2 (0:3, ones (1,4))));
+%!test
+%! y = atan2 (xs, sparse (zeros (1,4)));
+%! assert (issparse (y));
+%! assert (nnz (y), 3);
+%! assert (y, sparse (atan2 (0:3, zeros (1,4))));
+
 %!error atan2 ()
 %!error atan2 (1, 2, 3)
 */
@@ -327,12 +358,9 @@
         }
       else
         {
-          bool a0_scalar = arg0.is_scalar_type ();
-          bool a1_scalar = arg1.is_scalar_type ();
-          if (a0_scalar && a1_scalar)
+          if (arg0.is_scalar_type () && arg1.is_scalar_type ())
             retval = hypot (arg0.scalar_value (), arg1.scalar_value ());
-          else if ((a0_scalar || arg0.is_sparse_type ())
-                   && (a1_scalar || arg1.is_sparse_type ()))
+          else if (arg0.is_sparse_type () || arg1.is_sparse_type ())
             {
               SparseMatrix m0 = arg0.sparse_matrix_value ();
               SparseMatrix m1 = arg1.sparse_matrix_value ();
@@ -397,6 +425,35 @@
 %!assert (size (hypot (1, 2)), [1, 1])
 %!assert (hypot (1:10, 1:10), sqrt (2) * [1:10], 16*eps)
 %!assert (hypot (single (1:10), single (1:10)), single (sqrt (2) * [1:10]))
+
+## Test sparse implementations
+%!shared xs
+%! xs = sparse (0:3);
+%!test
+%! y = hypot (1, xs);
+%! assert (nnz (y), 4);
+%! assert (y, sparse (hypot (1, 0:3)));
+%!test
+%! y = hypot (0, xs);
+%! assert (nnz (y), 3);
+%! assert (y, xs);
+%!test
+%! y = hypot (xs, 1);
+%! assert (nnz (y), 4);
+%! assert (y, sparse (hypot (0:3, 1)));
+%!test
+%! y = hypot (xs, 0);
+%! assert (nnz (y), 3);
+%! assert (y, xs);
+%!test
+%! y = hypot (sparse ([0 0]), sparse ([0 1]));
+%! assert (nnz (y), 1);
+%! assert (y, sparse ([0 1]));
+%!test
+%! y = hypot (sparse ([0 1]), sparse ([0 0]));
+%! assert (nnz (y), 1);
+%! assert (y, sparse ([0 1]));
+
 */
 
 template<typename T, typename ET>
@@ -593,12 +650,9 @@
         }
       else
         {
-          bool a0_scalar = args(0).is_scalar_type ();
-          bool a1_scalar = args(1).is_scalar_type ();
-          if (a0_scalar && a1_scalar)
+          if (args(0).is_scalar_type () && args(1).is_scalar_type ())
             retval = xrem (args(0).scalar_value (), args(1).scalar_value ());
-          else if ((a0_scalar || args(0).is_sparse_type ())
-                   && (a1_scalar || args(1).is_sparse_type ()))
+          else if (args(0).is_sparse_type () || args(1).is_sparse_type ())
             {
               SparseMatrix m0 = args(0).sparse_matrix_value ();
               SparseMatrix m1 = args(1).sparse_matrix_value ();
@@ -619,26 +673,52 @@
 }
 
 /*
+%!assert (size (fmod (zeros (0, 2), zeros (0, 2))), [0, 2])
+%!assert (size (fmod (rand (2, 3, 4), zeros (2, 3, 4))), [2, 3, 4])
+%!assert (size (fmod (rand (2, 3, 4), 1)), [2, 3, 4])
+%!assert (size (fmod (1, rand (2, 3, 4))), [2, 3, 4])
+%!assert (size (fmod (1, 2)), [1, 1])
+
 %!assert (rem ([1, 2, 3; -1, -2, -3], 2), [1, 0, 1; -1, 0, -1])
 %!assert (rem ([1, 2, 3; -1, -2, -3], 2 * ones (2, 3)),[1, 0, 1; -1, 0, -1])
 %!assert (rem (uint8 ([1, 2, 3; -1, -2, -3]), uint8 (2)), uint8 ([1, 0, 1; -1, 0, -1]))
 %!assert (uint8 (rem ([1, 2, 3; -1, -2, -3], 2 * ones (2, 3))),uint8 ([1, 0, 1; -1, 0, -1]))
 
+## Test sparse implementations
+%!shared xs
+%! xs = sparse (0:3);
+%!test
+%! y = rem (11, xs);
+%! assert (nnz (y), 3);
+%! assert (y, sparse (rem (11, 0:3)));
+%!test
+%! y = rem (0, xs);
+%! assert (nnz (y), 0);
+%! assert (y, sparse (zeros (1,4)));
+%!test
+%! y = rem (xs, 2);
+%! assert (nnz (y), 2);
+%! assert (y, sparse (rem (0:3, 2)));
+%!test
+%! y = rem (xs, 1);
+%! assert (nnz (y), 0);
+%! assert (y, sparse (rem (0:3, 1)));
+%!test
+%! y = rem (sparse ([11 11 11 11]), xs);
+%! assert (nnz (y), 3);
+%! assert (y, sparse (rem (11, 0:3)));
+%!test
+%! y = rem (sparse ([0 0 0 0]), xs);
+%! assert (nnz (y), 0);
+%! assert (y, sparse (zeros (1,4)));
+
 %!error rem (uint (8), int8 (5))
 %!error rem (uint8 ([1, 2]), uint8 ([3, 4, 5]))
 %!error rem ()
 %!error rem (1, 2, 3)
 %!error rem ([1, 2], [3, 4, 5])
 %!error rem (i, 1)
-*/
-
-/*
-
-%!assert (size (fmod (zeros (0, 2), zeros (0, 2))), [0, 2])
-%!assert (size (fmod (rand (2, 3, 4), zeros (2, 3, 4))), [2, 3, 4])
-%!assert (size (fmod (rand (2, 3, 4), 1)), [2, 3, 4])
-%!assert (size (fmod (1, rand (2, 3, 4))), [2, 3, 4])
-%!assert (size (fmod (1, 2)), [1, 1])
+
 */
 
 DEFALIAS (fmod, rem)
@@ -727,12 +807,9 @@
         }
       else
         {
-          bool a0_scalar = args(0).is_scalar_type ();
-          bool a1_scalar = args(1).is_scalar_type ();
-          if (a0_scalar && a1_scalar)
+          if (args(0).is_scalar_type () && args(1).is_scalar_type ())
             retval = xmod (args(0).scalar_value (), args(1).scalar_value ());
-          else if ((a0_scalar || args(0).is_sparse_type ())
-                   && (a1_scalar || args(1).is_sparse_type ()))
+          else if (args(0).is_sparse_type () || args(1).is_sparse_type ())
             {
               SparseMatrix m0 = args(0).sparse_matrix_value ();
               SparseMatrix m1 = args(1).sparse_matrix_value ();
--- a/liboctave/util/oct-binmap.h
+++ b/liboctave/util/oct-binmap.h
@@ -218,17 +218,30 @@
 Sparse<U>
 binmap (const T& x, const Sparse<R>& ys, F fcn)
 {
-  octave_idx_type nz = ys.nnz ();
-  Sparse<U> retval (ys.rows (), ys.cols (), nz);
-  for (octave_idx_type i = 0; i < nz; i++)
+  R yzero = R ();
+  U fz = fcn (x, yzero);
+
+  if (fz == U ())  // Sparsity preserving fcn
     {
+      octave_idx_type nz = ys.nnz ();
+      Sparse<U> retval (ys.rows (), ys.cols (), nz);
+      copy_or_memcpy (nz, ys.ridx (), retval.ridx ());
+      copy_or_memcpy (ys.cols () + 1, ys.cidx (), retval.cidx ());
+
+      for (octave_idx_type i = 0; i < nz; i++)
+        {
+          octave_quit ();
+          // FIXME: Could keep track of whether fcn call results in a 0.
+          //        If no zeroes are created could skip maybe_compress()
+          retval.xdata (i) = fcn (x, ys.data (i));
+        }
+
       octave_quit ();
-      retval.xdata (i) = fcn (x, ys.data (i));
+      retval.maybe_compress (true);
+      return retval;
     }
-
-  octave_quit ();
-  retval.maybe_compress ();
-  return retval;
+  else
+    return Sparse<U> (binmap<U, T, R, F> (x, ys.array_value (), fcn));
 }
 
 // Sparse-scalar
@@ -236,17 +249,30 @@
 Sparse<U>
 binmap (const Sparse<T>& xs, const R& y, F fcn)
 {
-  octave_idx_type nz = xs.nnz ();
-  Sparse<U> retval (xs.rows (), xs.cols (), nz);
-  for (octave_idx_type i = 0; i < nz; i++)
+  T xzero = T ();
+  U fz = fcn (xzero, y);
+
+  if (fz == U ())  // Sparsity preserving fcn
     {
+      octave_idx_type nz = xs.nnz ();
+      Sparse<U> retval (xs.rows (), xs.cols (), nz);
+      copy_or_memcpy (nz, xs.ridx (), retval.ridx ());
+      copy_or_memcpy (xs.cols () + 1, xs.cidx (), retval.cidx ());
+
+      for (octave_idx_type i = 0; i < nz; i++)
+        {
+          octave_quit ();
+          // FIXME: Could keep track of whether fcn call results in a 0.
+          //        If no zeroes are created could skip maybe_compress()
+          retval.xdata (i) = fcn (xs.data (i), y);
+        }
+
       octave_quit ();
-      retval.xdata (i) = fcn (xs.data (i), y);
+      retval.maybe_compress (true);
+      return retval;
     }
-
-  octave_quit ();
-  retval.maybe_compress ();
-  return retval;
+  else
+    return Sparse<U> (binmap<U, T, R, F> (xs.array_value (), y, fcn));
 }
 
 // Sparse-Sparse (treats singletons as scalars)
@@ -295,8 +321,8 @@
                   jx++;
                   jx_lt_max = jx < jx_max;
                 }
-              else if (! jx_lt_max ||
-                  (jy_lt_max && (ys.ridx (jy) < xs.ridx (jx))))
+              else if (! jx_lt_max
+                       || (jy_lt_max && (ys.ridx (jy) < xs.ridx (jx))))
                 {
                   retval.xridx (nz) = ys.ridx (jy);
                   retval.xdata (nz) = fcn (xzero, ys.data (jy));