diff src/data.cc @ 7515:f3c00dc0912b

Eliminate the rest of the dispatched sparse functions
author David Bateman <dbateman@free.fr>
date Fri, 22 Feb 2008 15:50:51 +0100
parents 798b0a00e80c
children 7ebdc99a0bab
line wrap: on
line diff
--- a/src/data.cc
+++ b/src/data.cc
@@ -455,20 +455,11 @@
 
 	      if (! error_state)
 		{
-		  if (arg_x.is_sparse_type ())
-		    {
-		      SparseMatrix x = arg_x.sparse_matrix_value ();
-
-		      if (! error_state)
-			retval = map_d_s (atan2, y, x);
-		    }
-		  else
-		    {
-		      NDArray x = arg_x.array_value ();
-
-		      if (! error_state)
-			retval = map_d_m (atan2, y, x);
-		    }
+		  // Even if x is sparse return a full matrix here
+		  NDArray x = arg_x.array_value ();
+
+		  if (! error_state)
+		    retval = map_d_m (atan2, y, x);
 		}
 	    }
 	  else if (x_is_scalar)
@@ -480,7 +471,7 @@
 		  if (! error_state)
 		    {
 		      double x = arg_x.double_value ();
-
+		      
 		      if (! error_state)
 			retval = map_s_d (atan2, y, x);
 		    }
@@ -500,7 +491,8 @@
 	    }
 	  else if (y_dims == x_dims)
 	    {
-	      if (arg_y.is_sparse_type () || arg_x.is_sparse_type ())
+	      // Even if y is sparse return a full matrix here
+	      if (arg_x.is_sparse_type ())
 		{
 		  SparseMatrix y = arg_y.sparse_matrix_value ();
 
@@ -712,21 +704,67 @@
 	{ \
 	  if (dim >= -1) \
 	    { \
-              if (isnative) \
-                { \
-                  if NATIVE_REDUCTION_1 (FCN, uint8, dim) \
-                  else if NATIVE_REDUCTION_1 (FCN, uint16, dim) \
-                  else if NATIVE_REDUCTION_1 (FCN, uint32, dim) \
-                  else if NATIVE_REDUCTION_1 (FCN, uint64, dim) \
-                  else if NATIVE_REDUCTION_1 (FCN, int8, dim) \
-                  else if NATIVE_REDUCTION_1 (FCN, int16, dim) \
-                  else if NATIVE_REDUCTION_1 (FCN, int32, dim) \
-                  else if NATIVE_REDUCTION_1 (FCN, int64, dim) \
-                  else if NATIVE_REDUCTION_1 (FCN, bool, dim) \
-                  else if (arg.is_char_matrix ()) \
-                    { \
-                       error (#FCN, ": invalid char type"); \
+	      if (arg.is_sparse_type ()) \
+		{ \
+		  if (arg.is_real_type ()) \
+		    { \
+		      SparseMatrix tmp = arg.sparse_matrix_value (); \
+		      \
+		      if (! error_state) \
+			retval = tmp.FCN (dim); \
+		    } \
+		  else \
+		    { \
+		      SparseComplexMatrix tmp = arg.sparse_complex_matrix_value (); \
+                      \
+		      if (! error_state) \
+			retval = tmp.FCN (dim); \
+		    } \
+		} \
+	      else \
+		{ \
+		  if (isnative)	\
+		    { \
+		      if NATIVE_REDUCTION_1 (FCN, uint8, dim) \
+		      else if NATIVE_REDUCTION_1 (FCN, uint16, dim) \
+                      else if NATIVE_REDUCTION_1 (FCN, uint32, dim) \
+                      else if NATIVE_REDUCTION_1 (FCN, uint64, dim) \
+                      else if NATIVE_REDUCTION_1 (FCN, int8, dim) \
+                      else if NATIVE_REDUCTION_1 (FCN, int16, dim) \
+                      else if NATIVE_REDUCTION_1 (FCN, int32, dim) \
+                      else if NATIVE_REDUCTION_1 (FCN, int64, dim) \
+                      else if NATIVE_REDUCTION_1 (FCN, bool, dim) \
+                      else if (arg.is_char_matrix ()) \
+                        { \
+			  error (#FCN, ": invalid char type"); \
+			} \
+	              else if (arg.is_complex_type ()) \
+		        { \
+		          ComplexNDArray tmp = arg.complex_array_value (); \
+                          \
+		          if (! error_state) \
+		            retval = tmp.FCN (dim); \
+		        } \
+	              else if (arg.is_real_type ()) \
+		        { \
+		          NDArray tmp = arg.array_value (); \
+                          \
+		          if (! error_state) \
+		            retval = tmp.FCN (dim); \
+		        } \
+                      else \
+		        { \
+		          gripe_wrong_type_arg (#FCN, arg); \
+		          return retval; \
+		        } \
                     } \
+	          else if (arg.is_real_type ()) \
+		    { \
+		      NDArray tmp = arg.array_value (); \
+                      \
+		      if (! error_state) \
+		        retval = tmp.FCN (dim); \
+		    } \
 	          else if (arg.is_complex_type ()) \
 		    { \
 		      ComplexNDArray tmp = arg.complex_array_value (); \
@@ -734,37 +772,11 @@
 		      if (! error_state) \
 		        retval = tmp.FCN (dim); \
 		    } \
-	          else if (arg.is_real_type ()) \
-		    { \
-		      NDArray tmp = arg.array_value (); \
-                      \
-		      if (! error_state) \
-		        retval = tmp.FCN (dim); \
-		    } \
-                  else \
+	          else \
 		    { \
 		      gripe_wrong_type_arg (#FCN, arg); \
 		      return retval; \
 		    } \
-                } \
-	      else if (arg.is_real_type ()) \
-		{ \
-		  NDArray tmp = arg.array_value (); \
-                  \
-		  if (! error_state) \
-		    retval = tmp.FCN (dim); \
-		} \
-	      else if (arg.is_complex_type ()) \
-		{ \
-		  ComplexNDArray tmp = arg.complex_array_value (); \
-                  \
-		  if (! error_state) \
-		    retval = tmp.FCN (dim); \
-		} \
-	      else \
-		{ \
-		  gripe_wrong_type_arg (#FCN, arg); \
-		  return retval; \
 		} \
 	    } \
 	  else \
@@ -795,17 +807,37 @@
 	    { \
 	      if (arg.is_real_type ()) \
 		{ \
-		  NDArray tmp = arg.array_value (); \
+		  if (arg.is_sparse_type ()) \
+		    { \
+		      SparseMatrix tmp = arg.sparse_matrix_value (); \
  \
-		  if (! error_state) \
-		    retval = tmp.FCN (dim); \
+		      if (! error_state) \
+			retval = tmp.FCN (dim); \
+		    } \
+		  else \
+		    { \
+		      NDArray tmp = arg.array_value (); \
+ \
+		      if (! error_state) \
+			retval = tmp.FCN (dim); \
+		    } \
 		} \
 	      else if (arg.is_complex_type ()) \
 		{ \
-		  ComplexNDArray tmp = arg.complex_array_value (); \
+		  if (arg.is_sparse_type ()) \
+		    { \
+		      SparseComplexMatrix tmp = arg.sparse_complex_matrix_value (); \
  \
-		  if (! error_state) \
-		    retval = tmp.FCN (dim); \
+		      if (! error_state) \
+			retval = tmp.FCN (dim); \
+		    } \
+		  else \
+		    { \
+		      ComplexNDArray tmp = arg.complex_array_value (); \
+ \
+		      if (! error_state) \
+			retval = tmp.FCN (dim); \
+		    } \
 		} \
 	      else \
 		{ \
@@ -946,6 +978,94 @@
 make_diag (const uint64NDArray& v, octave_idx_type k);
 #endif
 
+template <class T>
+static octave_value
+make_spdiag (const T& v, octave_idx_type k)
+{
+  octave_value retval;
+  dim_vector dv = v.dims ();
+  octave_idx_type nr = dv (0);
+  octave_idx_type nc = dv (1);
+
+  if (nr == 0 || nc == 0)
+    retval = T ();
+  else if (nr != 1 && nc != 1)
+    retval = v.diag (k);
+  else
+    {
+      octave_idx_type roff = 0;
+      octave_idx_type coff = 0;
+      if (k > 0) 
+	{
+	  roff = 0;
+	  coff = k;
+	} 
+      else if (k < 0) 
+	{
+	  roff = -k;
+	  coff = 0;
+	}
+
+      if (nr == 1) 
+	{
+	  octave_idx_type n = nc + std::abs (k);
+	  octave_idx_type nz = v.nzmax ();
+	  T r (n, n, nz);
+	  for (octave_idx_type i = 0; i < coff+1; i++)
+	    r.xcidx (i) = 0;
+	  for (octave_idx_type j = 0; j < nc; j++)
+	    {
+	      for (octave_idx_type i = v.cidx(j); i < v.cidx(j+1); i++)
+		{
+		  r.xdata (i) = v.data (i);
+		  r.xridx (i) = j + roff;
+		}
+	      r.xcidx (j+coff+1) = v.cidx(j+1);
+	    }
+	  for (octave_idx_type i = nc+coff+1; i < n+1; i++)
+	    r.xcidx (i) = nz;
+	  retval = r;
+	} 
+      else 
+	{
+	  octave_idx_type n = nr + std::abs (k);
+	  octave_idx_type nz = v.nzmax ();
+	  octave_idx_type ii = 0;
+	  octave_idx_type ir = v.ridx(0);
+	  T r (n, n, nz);
+	  for (octave_idx_type i = 0; i < coff+1; i++)
+	    r.xcidx (i) = 0;
+	  for (octave_idx_type i = 0; i < nr; i++)
+	    {
+	      if (ir == i)
+		{
+		  r.xdata (ii) = v.data (ii);
+		  r.xridx (ii++) = ir + roff;
+		  if (ii != nz)
+		    ir = v.ridx (ii);
+		}
+	      r.xcidx (i+coff+1) = ii;
+	    }
+	  for (octave_idx_type i = nr+coff+1; i < n+1; i++)
+	    r.xcidx (i) = nz;
+	  retval = r;
+	}
+    }
+
+  return retval;
+}
+
+#if !defined (CXX_NEW_FRIEND_TEMPLATE_DECL)
+static octave_value
+make_spdiag (const SparseMatrix& v, octave_idx_type k);
+
+static octave_value
+make_spdiag (const SparseComplexMatrix& v, octave_idx_type k);
+
+static octave_value
+make_spdiag (const SparseBoolMatrix& v, octave_idx_type k);
+#endif
+
 static octave_value
 make_diag (const octave_value& a, octave_idx_type k)
 {
@@ -954,17 +1074,35 @@
 
   if (result_type == "double")
     {
-      if (a.is_real_type ())
+      if (a.is_sparse_type ())
 	{
-	  Matrix m = a.matrix_value ();
-	  if (!error_state)
-	    retval = make_diag (m, k);
+	  if (a.is_real_type ())
+	    {
+	      SparseMatrix m = a.sparse_matrix_value ();
+	      if (!error_state)
+		retval = make_spdiag (m, k);
+	    }
+	  else
+	    {
+	      SparseComplexMatrix m = a.sparse_complex_matrix_value ();
+	      if (!error_state)
+		retval = make_spdiag (m, k);
+	    }
 	}
       else
 	{
-	  ComplexMatrix m = a.complex_matrix_value ();
-	  if (!error_state)
-	    retval = make_diag (m, k);
+	  if (a.is_real_type ())
+	    {
+	      Matrix m = a.matrix_value ();
+	      if (!error_state)
+		retval = make_diag (m, k);
+	    }
+	  else
+	    {
+	      ComplexMatrix m = a.complex_matrix_value ();
+	      if (!error_state)
+		retval = make_diag (m, k);
+	    }
 	}
     }
 #if 0
@@ -983,9 +1121,18 @@
     }
   else if (result_type == "logical")
     {
-      boolMatrix m = a.bool_matrix_value ();
-      if (!error_state)
-	retval = make_diag (m, k);
+      if (a.is_sparse_type ())
+	{
+	  SparseBoolMatrix m = a.sparse_bool_matrix_value ();
+	  if (!error_state)
+	    retval = make_spdiag (m, k);
+	}
+      else
+	{
+	  boolMatrix m = a.bool_matrix_value ();
+	  if (!error_state)
+	    retval = make_diag (m, k);
+	}
     }
   else if (result_type == "int8")
     retval = make_diag (a.int8_array_value (), k);