diff src/data.cc @ 9721:192d94cff6c1

improve sum & implement the 'extra' option, refactor some code
author Jaroslav Hajek <highegg@gmail.com>
date Tue, 13 Oct 2009 12:22:50 +0200
parents 71160b139b07
children f426899f4b9c
line wrap: on
line diff
--- a/src/data.cc
+++ b/src/data.cc
@@ -1600,7 +1600,118 @@
 @seealso{sum, cumprod}\n\
 @end deftypefn")
 {
-  NATIVE_REDUCTION (cumsum, cumsum);
+  octave_value retval;
+
+  int nargin = args.length ();
+
+  bool isnative = false;
+  bool isdouble = false;
+
+  if (nargin > 1 && args(nargin - 1).is_string ())
+    {
+      std::string str = args(nargin - 1).string_value ();
+
+      if (! error_state)
+	{
+	  if (str == "native")
+	    isnative = true;
+	  else if (str == "double")
+            isdouble = true;
+          else
+	    error ("sum: unrecognized string argument");
+          nargin --;
+	}
+    }
+
+  if (error_state)
+    return retval;
+
+  if (nargin == 1 || nargin == 2)
+    {
+      octave_value arg = args(0);
+
+      int dim = -1;
+      if (nargin == 2)
+        {
+          dim = args(1).int_value () - 1;
+          if (dim < 0)
+	    error ("cumsum: invalid dimension argument = %d", dim + 1);
+        }
+
+      if (! error_state)
+	{
+          switch (arg.builtin_type ())
+            {
+            case btyp_double:
+              if (arg.is_sparse_type ())
+                retval = arg.sparse_matrix_value ().cumsum (dim);
+              else
+                retval = arg.array_value ().cumsum (dim);
+              break;
+            case btyp_complex:
+              if (arg.is_sparse_type ())
+                retval = arg.sparse_complex_matrix_value ().cumsum (dim);
+              else
+                retval = arg.complex_array_value ().cumsum (dim);
+              break;
+            case btyp_float:
+              if (isdouble)
+                retval = arg.array_value ().cumsum (dim);
+              else
+                retval = arg.float_array_value ().cumsum (dim);
+              break;
+            case btyp_float_complex:
+              if (isdouble)
+                retval = arg.complex_array_value ().cumsum (dim);
+              else
+                retval = arg.float_complex_array_value ().cumsum (dim);
+              break;
+
+#define MAKE_INT_BRANCH(X) \
+            case btyp_ ## X: \
+              if (isnative) \
+                retval = arg.X ## _array_value ().cumsum (dim); \
+              else \
+                retval = arg.array_value ().cumsum (dim); \
+              break
+            MAKE_INT_BRANCH (int8);
+            MAKE_INT_BRANCH (int16);
+            MAKE_INT_BRANCH (int32);
+            MAKE_INT_BRANCH (int64);
+            MAKE_INT_BRANCH (uint8);
+            MAKE_INT_BRANCH (uint16);
+            MAKE_INT_BRANCH (uint32);
+            MAKE_INT_BRANCH (uint64);
+#undef MAKE_INT_BRANCH
+
+            case btyp_bool:
+              if (arg.is_sparse_type ())
+                {
+                  SparseMatrix cs = arg.sparse_matrix_value ().cumsum (dim);
+                  if (isnative)
+                    retval = cs != 0.0;
+                  else
+                    retval = cs;
+                }
+              else
+                {
+                  NDArray cs = arg.bool_array_value ().cumsum (dim);
+                  if (isnative)
+                    retval = cs != 0.0;
+                  else
+                    retval = cs;
+                }
+              break;
+
+            default:
+              gripe_wrong_type_arg ("cumsum", arg);
+            }
+	}
+    }
+  else
+    print_usage ();
+
+  return retval;
 }
 
 /*
@@ -2553,6 +2664,8 @@
 @deftypefn  {Built-in Function} {} sum (@var{x})\n\
 @deftypefnx {Built-in Function} {} sum (@var{x}, @var{dim})\n\
 @deftypefnx {Built-in Function} {} sum (@dots{}, 'native')\n\
+@deftypefnx {Built-in Function} {} sum (@dots{}, 'double')\n\
+@deftypefnx {Built-in Function} {} sum (@dots{}, 'extra')\n\
 Sum of elements along dimension @var{dim}.  If @var{dim} is\n\
 omitted, it defaults to 1 (column-wise sum).\n\
 \n\
@@ -2571,10 +2684,136 @@
   @result{} true\n\
 @end group\n\
 @end example\n\
+On the contrary, if 'double' is given, the sum is performed in double precision\n\
+even for single precision inputs.\n\
+For double precision inputs, 'extra' indicates that a more accurate algorithm\n\
+than straightforward summation is to be used. For single precision inputs, 'extra' is\n\
+the same as 'double'. Otherwise, 'extra' has no effect.\n\
 @seealso{cumsum, sumsq, prod}\n\
 @end deftypefn")
 {
-  NATIVE_REDUCTION (sum, any);
+  octave_value retval;
+
+  int nargin = args.length ();
+
+  bool isnative = false;
+  bool isdouble = false;
+  bool isextra = false;
+
+  if (nargin > 1 && args(nargin - 1).is_string ())
+    {
+      std::string str = args(nargin - 1).string_value ();
+
+      if (! error_state)
+	{
+	  if (str == "native")
+	    isnative = true;
+	  else if (str == "double")
+            isdouble = true;
+          else if (str == "extra")
+            isextra = true;
+          else
+	    error ("sum: unrecognized string argument");
+          nargin --;
+	}
+    }
+
+  if (error_state)
+    return retval;
+
+  if (nargin == 1 || nargin == 2)
+    {
+      octave_value arg = args(0);
+
+      int dim = -1;
+      if (nargin == 2)
+        {
+          dim = args(1).int_value () - 1;
+          if (dim < 0)
+	    error ("sum: invalid dimension argument = %d", dim + 1);
+        }
+
+      if (! error_state)
+	{
+          switch (arg.builtin_type ())
+            {
+            case btyp_double:
+              if (arg.is_sparse_type ())
+                {
+                  if (isextra)
+                    warning ("sum: 'extra' not yet implemented for sparse matrices");
+                  retval = arg.sparse_matrix_value ().sum (dim);
+                }
+              else if (isextra)
+                retval = arg.array_value ().xsum (dim);
+              else
+                retval = arg.array_value ().sum (dim);
+              break;
+            case btyp_complex:
+              if (arg.is_sparse_type ())
+                {
+                  if (isextra)
+                    warning ("sum: 'extra' not yet implemented for sparse matrices");
+                  retval = arg.sparse_complex_matrix_value ().sum (dim);
+                }
+              else if (isextra)
+                retval = arg.complex_array_value ().xsum (dim);
+              else
+                retval = arg.complex_array_value ().sum (dim);
+              break;
+            case btyp_float:
+              if (isdouble || isextra)
+                retval = arg.float_array_value ().dsum (dim);
+              else
+                retval = arg.float_array_value ().sum (dim);
+              break;
+            case btyp_float_complex:
+              if (isdouble || isextra)
+                retval = arg.float_complex_array_value ().dsum (dim);
+              else
+                retval = arg.float_complex_array_value ().sum (dim);
+              break;
+
+#define MAKE_INT_BRANCH(X) \
+            case btyp_ ## X: \
+              if (isnative) \
+                retval = arg.X ## _array_value ().sum (dim); \
+              else \
+                retval = arg.X ## _array_value ().dsum (dim); \
+              break
+            MAKE_INT_BRANCH (int8);
+            MAKE_INT_BRANCH (int16);
+            MAKE_INT_BRANCH (int32);
+            MAKE_INT_BRANCH (int64);
+            MAKE_INT_BRANCH (uint8);
+            MAKE_INT_BRANCH (uint16);
+            MAKE_INT_BRANCH (uint32);
+            MAKE_INT_BRANCH (uint64);
+#undef MAKE_INT_BRANCH
+
+            case btyp_bool:
+              if (arg.is_sparse_type ())
+                {
+                  if (isnative)
+                    retval = arg.sparse_bool_matrix_value ().any (dim);
+                  else
+                    retval = arg.sparse_matrix_value ().sum (dim);
+                }
+              else if (isnative)
+                retval = arg.bool_array_value ().any (dim);
+              else
+                retval = arg.bool_array_value ().sum (dim);
+              break;
+
+            default:
+              gripe_wrong_type_arg ("sum", arg);
+            }
+	}
+    }
+  else
+    print_usage ();
+
+  return retval;
 }
 
 /*