diff liboctave/Sparse-op-defs.h @ 7802:1a446f28ce68

implement optimized sparse-dense transposed multiplication
author Jaroslav Hajek <highegg@gmail.com>
date Sun, 18 May 2008 20:23:31 +0200
parents 288614c6634d
children 9bcb31cc56be
line wrap: on
line diff
--- a/liboctave/Sparse-op-defs.h
+++ b/liboctave/Sparse-op-defs.h
@@ -1904,15 +1904,7 @@
   \
   if (nr == 1 && nc == 1) \
     { \
-      RET_TYPE retval (a_nr, a_nc, ZERO); \
-      for (octave_idx_type i = 0; i < a_nc ; i++) \
-	{ \
-	  for (octave_idx_type j = 0; j < a_nr; j++) \
-	    { \
-              OCTAVE_QUIT; \
-	      retval.elem (j,i) += a.elem(j,i) * m.elem(0,0); \
-	    } \
-        } \
+      RET_TYPE retval = m.elem (0,0) * a; \
       return retval; \
     } \
   else if (nc != a_nr) \
@@ -1925,15 +1917,51 @@
       RET_TYPE retval (nr, a_nc, ZERO); \
       \
       for (octave_idx_type i = 0; i < a_nc ; i++) \
-	{ \
-	  for (octave_idx_type j = 0; j < a_nr; j++) \
-	    { \
+        { \
+          for (octave_idx_type j = 0; j < a_nr; j++) \
+            { \
               OCTAVE_QUIT; \
-	      \
+              \
               EL_TYPE tmpval = a.elem(j,i); \
-	      for (octave_idx_type k = m.cidx(j) ; k < m.cidx(j+1); k++) \
-	        retval.elem (m.ridx(k),i) += tmpval * m.data(k); \
-	    } \
+              for (octave_idx_type k = m.cidx(j) ; k < m.cidx(j+1); k++) \
+                retval.elem (m.ridx(k),i) += tmpval * m.data(k); \
+            } \
+        } \
+      return retval; \
+    }
+
+#define SPARSE_FULL_TRANS_MUL( RET_TYPE, EL_TYPE, ZERO, CONJ_OP ) \
+  octave_idx_type nr = m.rows (); \
+  octave_idx_type nc = m.cols (); \
+  \
+  octave_idx_type a_nr = a.rows (); \
+  octave_idx_type a_nc = a.cols (); \
+  \
+  if (nr == 1 && nc == 1) \
+    { \
+      RET_TYPE retval = CONJ_OP (m.elem(0,0)) * a; \
+      return retval; \
+    } \
+  else if (nr != a_nr) \
+    { \
+      gripe_nonconformant ("operator *", nc, nr, a_nr, a_nc); \
+      return RET_TYPE (); \
+    } \
+  else \
+    { \
+      RET_TYPE retval (nc, a_nc); \
+      \
+      for (octave_idx_type i = 0; i < a_nc ; i++) \
+        { \
+          for (octave_idx_type j = 0; j < nc; j++) \
+            { \
+              OCTAVE_QUIT; \
+              \
+              EL_TYPE acc = ZERO; \
+              for (octave_idx_type k = m.cidx(j) ; k < m.cidx(j+1); k++) \
+                acc += a.elem (m.ridx(k),i) * CONJ_OP (m.data(k)); \
+              retval.xelem (j,i) = acc; \
+            } \
         } \
       return retval; \
     }
@@ -1947,15 +1975,7 @@
   \
   if (a_nr == 1 && a_nc == 1) \
     { \
-      RET_TYPE retval (nr, nc, ZERO); \
-      for (octave_idx_type i = 0; i < nc ; i++) \
-	{ \
-	  for (octave_idx_type j = 0; j < nr; j++) \
-	    { \
-              OCTAVE_QUIT; \
-	      retval.elem (j,i) += a.elem(0,0) * m.elem(j,i); \
-	    } \
-        } \
+      RET_TYPE retval = m * a.elem (0,0); \
       return retval; \
     } \
   else if (nc != a_nr) \
@@ -1968,16 +1988,51 @@
       RET_TYPE retval (nr, a_nc, ZERO); \
       \
       for (octave_idx_type i = 0; i < a_nc ; i++) \
-	{ \
-	   for (octave_idx_type j = a.cidx(i); j < a.cidx(i+1); j++) \
-	     { \
-	        octave_idx_type col = a.ridx(j); \
-	        EL_TYPE tmpval = a.data(j); \
-                OCTAVE_QUIT; \
-	        \
-	        for (octave_idx_type k = 0 ; k < nr; k++) \
-	          retval.elem (k,i) += tmpval * m.elem(k,col); \
-	    } \
+        { \
+          OCTAVE_QUIT; \
+          for (octave_idx_type j = a.cidx(i); j < a.cidx(i+1); j++) \
+            { \
+              octave_idx_type col = a.ridx(j); \
+              EL_TYPE tmpval = a.data(j); \
+              \
+              for (octave_idx_type k = 0 ; k < nr; k++) \
+                retval.xelem (k,i) += tmpval * m.elem(k,col); \
+            } \
+        } \
+      return retval; \
+    }
+
+#define FULL_SPARSE_MUL_TRANS( RET_TYPE, EL_TYPE, ZERO, CONJ_OP ) \
+  octave_idx_type nr = m.rows (); \
+  octave_idx_type nc = m.cols (); \
+  \
+  octave_idx_type a_nr = a.rows (); \
+  octave_idx_type a_nc = a.cols (); \
+  \
+  if (a_nr == 1 && a_nc == 1) \
+    { \
+      RET_TYPE retval = m * CONJ_OP (a.elem(0,0)); \
+      return retval; \
+    } \
+  else if (nc != a_nc) \
+    { \
+      gripe_nonconformant ("operator *", nr, nc, a_nc, a_nr); \
+      return RET_TYPE (); \
+    } \
+  else \
+    { \
+      RET_TYPE retval (nr, a_nr, ZERO); \
+      \
+      for (octave_idx_type i = 0; i < a_nc ; i++) \
+        { \
+          OCTAVE_QUIT; \
+          for (octave_idx_type j = a.cidx(i); j < a.cidx(i+1); j++) \
+            { \
+              octave_idx_type col = a.ridx(j); \
+              EL_TYPE tmpval = CONJ_OP (a.data(j)); \
+              for (octave_idx_type k = 0 ; k < nr; k++) \
+                retval.xelem (k,col) += tmpval * m.elem(k,i); \
+            } \
         } \
       return retval; \
     }