changeset 10900:b64803a8be4e

optimize element-wise sparse-dense multiplication and division
author Jaroslav Hajek <highegg@gmail.com>
date Wed, 18 Aug 2010 14:02:16 +0200
parents 686e3bc432a2
children 860427ac9b77
files liboctave/ChangeLog liboctave/Sparse-op-defs.h liboctave/mx-inlines.cc
diffstat 3 files changed, 89 insertions(+), 59 deletions(-) [+]
line wrap: on
line diff
--- a/liboctave/ChangeLog
+++ b/liboctave/ChangeLog
@@ -1,3 +1,14 @@
+2010-08-18  Jaroslav Hajek  <highegg@gmail.com>
+
+	* mx-inlines.cc (mx_inline_all_finite): New check.
+	* Sparse-op-defs.h (SPARSE_SMS_BIN_OP_1, SPARSE_SMS_BIN_OP_2,
+	SPARSE_SSM_BIN_OP_1, SPARSE_SSM_BIN_OP_2): Use unchecked access where
+	appropriate.
+	(SPARSE_SMM_BIN_OP_1, SPARSE_MSM_BIN_OP_1): Simplify.
+	(SPARSE_SMM_BIN_OP_2, SPARSE_MSM_BIN_OP_2): Use optimized code path
+	if all values are finite.
+	(SPARSE_MSM_BIN_OPS): Use SPARSE_MSM_BIN_OP_1 for division.
+
 2010-07-31  Rik <octave@nomad.inbox5.com>
 
 	* DASPK-opts.in, DASRT-opts.in, DASSL-opts.in, LSODE-opts.in, 
--- a/liboctave/Sparse-op-defs.h
+++ b/liboctave/Sparse-op-defs.h
@@ -28,6 +28,7 @@
 #include "Array-util.h"
 #include "mx-ops.h"
 #include "oct-locbuf.h"
+#include "mx-inlines.cc"
 
 #define SPARSE_BIN_OP_DECL(R, OP, X, Y, API) \
   extern API R OP (const X&, const Y&)
@@ -57,7 +58,7 @@
  \
     for (octave_idx_type j = 0; j < nc; j++) \
       for (octave_idx_type i = m.cidx (j); i < m.cidx (j+1); i++) \
-        r.elem (m.ridx (i), j) = m.data (i) OP s; \
+        r.xelem (m.ridx (i), j) = m.data (i) OP s; \
     return r; \
   }
 
@@ -73,11 +74,11 @@
  \
     for (octave_idx_type i = 0; i < nz; i++) \
       { \
-        r.data(i) = m.data(i) OP s; \
-        r.ridx(i) = m.ridx(i); \
+        r.xdata(i) = m.data(i) OP s; \
+        r.xridx(i) = m.ridx(i); \
       } \
     for (octave_idx_type i = 0; i < nc + 1; i++) \
-      r.cidx(i) = m.cidx(i); \
+      r.xcidx(i) = m.cidx(i); \
     \
     r.maybe_compress (true); \
     return r; \
@@ -225,7 +226,7 @@
  \
     for (octave_idx_type j = 0; j < nc; j++) \
       for (octave_idx_type i = m.cidx (j); i < m.cidx (j+1); i++) \
-        r.elem (m.ridx (i), j) = s OP m.data (i); \
+        r.xelem (m.ridx (i), j) = s OP m.data (i); \
  \
     return r; \
   }
@@ -242,11 +243,11 @@
  \
     for (octave_idx_type i = 0; i < nz; i++) \
       { \
-        r.data(i) = s OP m.data(i); \
-        r.ridx(i) = m.ridx(i); \
+        r.xdata(i) = s OP m.data(i); \
+        r.xridx(i) = m.ridx(i); \
       } \
     for (octave_idx_type i = 0; i < nc + 1; i++) \
-      r.cidx(i) = m.cidx(i); \
+      r.xcidx(i) = m.cidx(i); \
  \
     r.maybe_compress(true); \
     return r; \
@@ -1095,16 +1096,12 @@
       gripe_nonconformant (#F, m1_nr, m1_nc, m2_nr, m2_nc); \
     else \
       { \
-        r = R (m1_nr, m1_nc); \
-        \
-        for (octave_idx_type j = 0; j < m1_nc; j++) \
-          for (octave_idx_type i = 0; i < m1_nr; i++) \
-            r.elem (i, j) = m1.elem (i, j) OP m2.elem (i, j); \
+        r = R (F (m1, m2.matrix_value ())); \
       } \
     return r; \
   }
 
-#define SPARSE_MSM_BIN_OP_2(R, F, OP, M1, M2, ZERO) \
+#define SPARSE_MSM_BIN_OP_2(R, F, OP, M1, M2) \
   R \
   F (const M1& m1, const M2& m2) \
   { \
@@ -1122,29 +1119,32 @@
       gripe_nonconformant (#F, m1_nr, m1_nc, m2_nr, m2_nc); \
     else \
       { \
-        /* Count num of non-zero elements */ \
-        octave_idx_type nel = 0; \
-        for (octave_idx_type j = 0; j < m1_nc; j++) \
-          for (octave_idx_type i = 0; i < m1_nr; i++) \
-            if ((m1.elem(i, j) OP m2.elem(i, j)) != ZERO) \
-              nel++; \
-        \
-        r = R (m1_nr, m1_nc, nel); \
-        \
-        octave_idx_type ii = 0; \
-        r.cidx (0) = 0; \
-        for (octave_idx_type j = 0 ; j < m1_nc ; j++) \
+        if (do_mx_check (m1, mx_inline_all_finite)) \
           { \
-            for (octave_idx_type i = 0 ; i < m1_nr ; i++)       \
+            /* Sparsity pattern is preserved. */ \
+            octave_idx_type m2_nz = m2.nnz (); \
+            r = R (m2_nr, m2_nc, m2_nz); \
+            for (octave_idx_type j = 0, k = 0; j < m2_nc; j++) \
               { \
-                if ((m1.elem(i, j) OP m2.elem(i, j)) != ZERO) \
+                octave_quit (); \
+                for (octave_idx_type i = m2.cidx(j); i < m2.cidx(j+1); i++) \
                   { \
-                    r.data (ii) = m1.elem(i, j) OP m2.elem(i,j); \
-                    r.ridx (ii++) = i; \
+                    octave_idx_type mri = m2.ridx(i); \
+                    R::element_type x = m1(mri, j) OP m2.data(i); \
+                    if (x != 0.0) \
+                      { \
+                        r.xdata(k) = x; \
+                        r.xridx(k) = m2.ridx(i); \
+                        k++; \
+                      } \
                   } \
+                r.xcidx(j+1) = k; \
               } \
-            r.cidx(j+1) = ii; \
+            r.maybe_compress (false); \
+            return r; \
           } \
+        else \
+          r = R (F (m1, m2.matrix_value ())); \
       } \
  \
     return r; \
@@ -1154,8 +1154,8 @@
 #define SPARSE_MSM_BIN_OPS(R1, R2, M1, M2) \
   SPARSE_MSM_BIN_OP_1 (R1, operator +,  +, M1, M2) \
   SPARSE_MSM_BIN_OP_1 (R1, operator -,  -, M1, M2) \
-  SPARSE_MSM_BIN_OP_2 (R2, product,     *, M1, M2, 0.0) \
-  SPARSE_MSM_BIN_OP_2 (R2, quotient,    /, M1, M2, 0.0)
+  SPARSE_MSM_BIN_OP_2 (R2, product,     *, M1, M2) \
+  SPARSE_MSM_BIN_OP_1 (R2, quotient,    /, M1, M2)
 
 #define SPARSE_MSM_CMP_OP_DECLS(M1, M2, API) \
   SPARSE_CMP_OP_DECL (mx_el_lt, M1, M2, API); \
@@ -1329,16 +1329,20 @@
       gripe_nonconformant (#F, m1_nr, m1_nc, m2_nr, m2_nc); \
     else \
       { \
-        r = R (m1_nr, m1_nc); \
-        \
-        for (octave_idx_type j = 0; j < m1_nc; j++) \
-          for (octave_idx_type i = 0; i < m1_nr; i++) \
-            r.elem (i, j) = m1.elem (i, j) OP m2.elem (i, j); \
+        r = R (m1.matrix_value () OP m2); \
       } \
     return r; \
   }
 
-#define SPARSE_SMM_BIN_OP_2(R, F, OP, M1, M2, ZERO) \
+// sm .* m preserves sparsity if m (named m2 where this macro expands) contains no Infs or NaNs.
+#define SPARSE_SMM_BIN_OP_2_CHECK_product \
+  do_mx_check (m2, mx_inline_all_finite)
+
+// sm ./ m preserves sparsity if m (named m2 where this macro expands) contains no NaNs or zeros; parenthesized so it expands as one expression.
+#define SPARSE_SMM_BIN_OP_2_CHECK_quotient \
+  (! do_mx_check (m2, mx_inline_any_nan) && m2.nnz () == m2.numel ())
+
+#define SPARSE_SMM_BIN_OP_2(R, F, OP, M1, M2) \
   R \
   F (const M1& m1, const M2& m2) \
   { \
@@ -1356,40 +1360,42 @@
       gripe_nonconformant (#F, m1_nr, m1_nc, m2_nr, m2_nc); \
     else \
       { \
-        /* Count num of non-zero elements */ \
-        octave_idx_type nel = 0; \
-        for (octave_idx_type j = 0; j < m1_nc; j++) \
-          for (octave_idx_type i = 0; i < m1_nr; i++) \
-            if ((m1.elem(i, j) OP m2.elem(i, j)) != ZERO) \
-              nel++; \
-        \
-        r = R (m1_nr, m1_nc, nel); \
-        \
-        octave_idx_type ii = 0; \
-        r.cidx (0) = 0; \
-        for (octave_idx_type j = 0 ; j < m1_nc ; j++) \
+        if (SPARSE_SMM_BIN_OP_2_CHECK_ ## F) \
           { \
-            for (octave_idx_type i = 0 ; i < m1_nr ; i++)       \
+            /* Sparsity pattern is preserved. */ \
+            octave_idx_type m1_nz = m1.nnz (); \
+            r = R (m1_nr, m1_nc, m1_nz); \
+            for (octave_idx_type j = 0, k = 0; j < m1_nc; j++) \
               { \
-                if ((m1.elem(i, j) OP m2.elem(i, j)) != ZERO) \
+                octave_quit (); \
+                for (octave_idx_type i = m1.cidx(j); i < m1.cidx(j+1); i++) \
                   { \
-                    r.data (ii) = m1.elem(i, j) OP m2.elem(i,j); \
-                    r.ridx (ii++) = i; \
+                    octave_idx_type mri = m1.ridx(i); \
+                    R::element_type x = m1.data(i) OP m2(mri, j); \
+                    if (x != 0.0) \
+                      { \
+                        r.xdata(k) = x; \
+                        r.xridx(k) = m1.ridx(i); \
+                        k++; \
+                      } \
                   } \
+                r.xcidx(j+1) = k; \
               } \
-            r.cidx(j+1) = ii; \
+            r.maybe_compress (false); \
+            return r; \
           } \
+        else \
+          r = R (F (m1.matrix_value (), m2)); \
       } \
  \
     return r; \
   }
 
-// FIXME Pass a specific ZERO value
 #define SPARSE_SMM_BIN_OPS(R1, R2, M1, M2) \
   SPARSE_SMM_BIN_OP_1 (R1, operator +,  +, M1, M2) \
   SPARSE_SMM_BIN_OP_1 (R1, operator -,  -, M1, M2) \
-  SPARSE_SMM_BIN_OP_2 (R2, product,     *, M1, M2, 0.0) \
-  SPARSE_SMM_BIN_OP_2 (R2, quotient,    /, M1, M2, 0.0)
+  SPARSE_SMM_BIN_OP_2 (R2, product,     *, M1, M2) \
+  SPARSE_SMM_BIN_OP_2 (R2, quotient,    /, M1, M2)
 
 #define SPARSE_SMM_CMP_OP_DECLS(M1, M2, API) \
   SPARSE_CMP_OP_DECL (mx_el_lt, M1, M2, API); \
--- a/liboctave/mx-inlines.cc
+++ b/liboctave/mx-inlines.cc
@@ -185,6 +185,19 @@
   return false;
 }
 
+template <class T>
+inline bool
+mx_inline_all_finite (size_t n, const T* x)  throw ()  // true iff xfinite accepts each of the n elements starting at x
+{
+  for (size_t i = 0; i < n; i++)
+    {
+      if (! xfinite (x[i]))  // stop scanning at the first element xfinite rejects
+        return false;
+    }
+
+  return true;  // no element was rejected; this is also the result when n == 0
+}
+
 template <class T> 
 inline bool 
 mx_inline_any_negative (size_t n, const T* x) throw ()