Mercurial > hg > octave-lyh
changeset 9721:192d94cff6c1
improve sum & implement the 'extra' option, refactor some code
author | Jaroslav Hajek <highegg@gmail.com> |
---|---|
date | Tue, 13 Oct 2009 12:22:50 +0200 |
parents | 2997727398d1 |
children | 97d683d8b9ff |
files | liboctave/CNDArray.cc liboctave/CNDArray.h liboctave/ChangeLog liboctave/dNDArray.cc liboctave/dNDArray.h liboctave/fCNDArray.cc liboctave/fCNDArray.h liboctave/fNDArray.cc liboctave/fNDArray.h liboctave/intNDArray.cc liboctave/intNDArray.h liboctave/lo-traits.h liboctave/mx-inlines.cc src/ChangeLog src/data.cc |
diffstat | 15 files changed, 414 insertions(+), 2 deletions(-) [+] |
line wrap: on
line diff
--- a/liboctave/CNDArray.cc +++ b/liboctave/CNDArray.cc @@ -662,6 +662,12 @@ } ComplexNDArray +ComplexNDArray::xsum (int dim) const +{ + return do_mx_red_op<ComplexNDArray, Complex> (*this, dim, mx_inline_xsum); +} + +ComplexNDArray ComplexNDArray::sumsq (int dim) const { return do_mx_red_op<NDArray, Complex> (*this, dim, mx_inline_sumsq);
--- a/liboctave/CNDArray.h +++ b/liboctave/CNDArray.h @@ -81,6 +81,7 @@ ComplexNDArray cumsum (int dim = -1) const; ComplexNDArray prod (int dim = -1) const; ComplexNDArray sum (int dim = -1) const; + ComplexNDArray xsum (int dim = -1) const; ComplexNDArray sumsq (int dim = -1) const; ComplexNDArray concat (const ComplexNDArray& rb, const Array<octave_idx_type>& ra_idx); ComplexNDArray concat (const NDArray& rb, const Array<octave_idx_type>& ra_idx);
--- a/liboctave/ChangeLog +++ b/liboctave/ChangeLog @@ -1,3 +1,20 @@ +2009-10-13 Jaroslav Hajek <highegg@gmail.com> + + * lo-traits.h (equal_types, is_instance, subst_template_param): New + traits classes. + * mx-inlines.cc (op_dble_sum, twosum_accum): New helper funcs. + (mx_inline_dsum, mx_inline_xsum): New reduction loops. + * fNDArray.cc (FloatNDArray::dsum): New method. + * fNDArray.h: Declare it. + * fCNDArray.cc (FloatComplexNDArray::dsum): New method. + * fCNDArray.h: Declare it. + * dNDArray.cc (NDArray::xsum): New method. + * dNDArray.h: Declare it. + * CNDArray.cc (ComplexNDArray::xsum): New method. + * CNDArray.h: Declare it. + * intNDArray.cc (intNDArray::dsum): New method. + * intNDArray.h: Declare it. + 2009-10-12 Jaroslav Hajek <highegg@gmail.com> * base-qr.cc (base_qr::regular): New method.
--- a/liboctave/dNDArray.cc +++ b/liboctave/dNDArray.cc @@ -726,6 +726,12 @@ } NDArray +NDArray::xsum (int dim) const +{ + return do_mx_red_op<NDArray, double> (*this, dim, mx_inline_xsum); +} + +NDArray NDArray::sumsq (int dim) const { return do_mx_red_op<NDArray, double> (*this, dim, mx_inline_sumsq);
--- a/liboctave/dNDArray.h +++ b/liboctave/dNDArray.h @@ -92,6 +92,7 @@ NDArray cumsum (int dim = -1) const; NDArray prod (int dim = -1) const; NDArray sum (int dim = -1) const; + NDArray xsum (int dim = -1) const; NDArray sumsq (int dim = -1) const; NDArray concat (const NDArray& rb, const Array<octave_idx_type>& ra_idx); ComplexNDArray concat (const ComplexNDArray& rb, const Array<octave_idx_type>& ra_idx);
--- a/liboctave/fCNDArray.cc +++ b/liboctave/fCNDArray.cc @@ -656,6 +656,12 @@ return do_mx_red_op<FloatComplexNDArray, FloatComplex> (*this, dim, mx_inline_sum); } +ComplexNDArray +FloatComplexNDArray::dsum (int dim) const +{ + return do_mx_red_op<ComplexNDArray, FloatComplex> (*this, dim, mx_inline_dsum); +} + FloatComplexNDArray FloatComplexNDArray::sumsq (int dim) const {
--- a/liboctave/fCNDArray.h +++ b/liboctave/fCNDArray.h @@ -81,6 +81,7 @@ FloatComplexNDArray cumsum (int dim = -1) const; FloatComplexNDArray prod (int dim = -1) const; FloatComplexNDArray sum (int dim = -1) const; + ComplexNDArray dsum (int dim = -1) const; FloatComplexNDArray sumsq (int dim = -1) const; FloatComplexNDArray concat (const FloatComplexNDArray& rb, const Array<octave_idx_type>& ra_idx); FloatComplexNDArray concat (const FloatNDArray& rb, const Array<octave_idx_type>& ra_idx);
--- a/liboctave/fNDArray.cc +++ b/liboctave/fNDArray.cc @@ -683,6 +683,12 @@ return do_mx_red_op<FloatNDArray, float> (*this, dim, mx_inline_sum); } +NDArray +FloatNDArray::dsum (int dim) const +{ + return do_mx_red_op<NDArray, float> (*this, dim, mx_inline_dsum); +} + FloatNDArray FloatNDArray::sumsq (int dim) const {
--- a/liboctave/fNDArray.h +++ b/liboctave/fNDArray.h @@ -89,6 +89,7 @@ FloatNDArray cumsum (int dim = -1) const; FloatNDArray prod (int dim = -1) const; FloatNDArray sum (int dim = -1) const; + NDArray dsum (int dim = -1) const; FloatNDArray sumsq (int dim = -1) const; FloatNDArray concat (const FloatNDArray& rb, const Array<octave_idx_type>& ra_idx); FloatComplexNDArray concat (const FloatComplexNDArray& rb, const Array<octave_idx_type>& ra_idx);
--- a/liboctave/intNDArray.cc +++ b/liboctave/intNDArray.cc @@ -209,6 +209,13 @@ } template <class T> +NDArray +intNDArray<T>::dsum (int dim) const +{ + return do_mx_red_op<NDArray , T> (*this, dim, mx_inline_dsum); +} + +template <class T> intNDArray<T> intNDArray<T>::cumsum (int dim) const {
--- a/liboctave/intNDArray.h +++ b/liboctave/intNDArray.h @@ -25,6 +25,7 @@ #include "MArrayN.h" #include "boolNDArray.h" +class NDArray; template <class T> class @@ -90,6 +91,7 @@ intNDArray cummin (ArrayN<octave_idx_type>& index, int dim = 0) const; intNDArray sum (int dim) const; + NDArray dsum (int dim) const; intNDArray cumsum (int dim) const; intNDArray diff (octave_idx_type order = 1, int dim = 0) const;
--- a/liboctave/lo-traits.h +++ b/liboctave/lo-traits.h @@ -48,6 +48,41 @@ typedef T2 result; }; +// Determine whether two types are equal. +template <class T1, class T2> +class equal_types +{ +public: + + static const bool value = false; +}; + +template <class T> +class equal_types <T, T> +{ +public: + + static const bool value = true; +}; + +// Determine whether a type is an instance of a template. + +template <template <class> class Template, class T> +class is_instance +{ +public: + + static const bool value = false; +}; + +template <template <class> class Template, class T> +class is_instance <Template, Template<T> > +{ +public: + + static const bool value = true; +}; + // Determine whether a template parameter is a class type. template<typename T1> @@ -98,6 +133,23 @@ typedef T type; }; +// Will turn TemplatedClass<T> to TemplatedClass<S>, T to S otherwise. +// Useful for generic promotions. + +template<template<typename> class TemplatedClass, typename T, typename S> +class subst_template_param +{ +public: + typedef S type; +}; + +template<template<typename> class TemplatedClass, typename T, typename S> +class subst_template_param<TemplatedClass, TemplatedClass<T>, S> +{ +public: + typedef TemplatedClass<S> type; +}; + #endif /*
--- a/liboctave/mx-inlines.cc +++ b/liboctave/mx-inlines.cc @@ -415,6 +415,14 @@ #define OP_RED_SUMSQ(ac, el) ac += el*el #define OP_RED_SUMSQC(ac, el) ac += cabsq (el) +inline void op_dble_sum(double& ac, float el) +{ ac += el; } +inline void op_dble_sum(Complex& ac, const FloatComplex& el) +{ ac += el; } // FIXME: guaranteed? +template <class T> +inline void op_dble_sum(double& ac, const octave_int<T>& el) +{ ac += el.double_value (); } + // The following two implement a simple short-circuiting. #define OP_RED_ANYC(ac, el) if (xis_true (el)) { ac = true; break; } else continue #define OP_RED_ALLC(ac, el) if (xis_false (el)) { ac = false; break; } else continue @@ -430,7 +438,10 @@ return ac; \ } +#define PROMOTE_DOUBLE(T) typename subst_template_param<std::complex, T, double>::type + OP_RED_FCN (mx_inline_sum, T, T, OP_RED_SUM, 0) +OP_RED_FCN (mx_inline_dsum, T, PROMOTE_DOUBLE(T), op_dble_sum, 0.0) OP_RED_FCN (mx_inline_count, bool, T, OP_RED_SUM, 0) OP_RED_FCN (mx_inline_prod, T, T, OP_RED_PROD, 1) OP_RED_FCN (mx_inline_sumsq, T, T, OP_RED_SUMSQ, 0) @@ -455,6 +466,7 @@ } OP_RED_FCN2 (mx_inline_sum, T, T, OP_RED_SUM, 0) +OP_RED_FCN2 (mx_inline_dsum, T, PROMOTE_DOUBLE(T), op_dble_sum, 0.0) OP_RED_FCN2 (mx_inline_count, bool, T, OP_RED_SUM, 0) OP_RED_FCN2 (mx_inline_prod, T, T, OP_RED_PROD, 1) OP_RED_FCN2 (mx_inline_sumsq, T, T, OP_RED_SUMSQ, 0) @@ -518,6 +530,7 @@ } OP_RED_FCNN (mx_inline_sum, T, T) +OP_RED_FCNN (mx_inline_dsum, T, PROMOTE_DOUBLE(T)) OP_RED_FCNN (mx_inline_count, bool, T) OP_RED_FCNN (mx_inline_prod, T, T) OP_RED_FCNN (mx_inline_sumsq, T, T) @@ -1238,6 +1251,54 @@ return ret; } +// Fast extra-precise summation. According to +// T. Ogita, S. M. Rump, S. Oishi: +// Accurate Sum And Dot Product, +// SIAM J. Sci. Computing, Vol. 
26, 2005 + +template <class T> +inline void twosum_accum (T& s, T& e, + const T& x) +{ + FLOAT_TRUNCATE T s1 = s + x, t = s1 - s, e1 = (s - (s1 - t)) + (x - t); + s = s1; + e += e1; +} + +template <class T> +inline T +mx_inline_xsum (const T *v, octave_idx_type n) +{ + T s = 0, e = 0; + for (octave_idx_type i = 0; i < n; i++) + twosum_accum (s, e, v[i]); + + return s + e; +} + +template <class T> +inline void +mx_inline_xsum (const T *v, T *r, + octave_idx_type m, octave_idx_type n) +{ + OCTAVE_LOCAL_BUFFER (T, e, m); + for (octave_idx_type i = 0; i < m; i++) + e[i] = r[i] = T (); + + for (octave_idx_type j = 0; j < n; j++) + { + for (octave_idx_type i = 0; i < m; i++) + twosum_accum (r[i], e[i], v[i]); + + v += m; + } + + for (octave_idx_type i = 0; i < m; i++) + r[i] += e[i]; +} + +OP_RED_FCNN (mx_inline_xsum, T, T) + #endif /*
--- a/src/ChangeLog +++ b/src/ChangeLog @@ -1,3 +1,9 @@ +2009-10-13 Jaroslav Hajek <highegg@gmail.com> + + * data.cc (Fsum): Rewrite. + (Fcumsum): Rewrite. + (NATIVE_REDUCTION, NATIVE_REDUCTION_1): Remove. + 2009-10-12 Jaroslav Hajek <highegg@gmail.com> * pt-binop.cc, pt-unop.cc: Revert the effect of 1be3c73ed7b5.
--- a/src/data.cc +++ b/src/data.cc @@ -1600,7 +1600,118 @@ @seealso{sum, cumprod}\n\ @end deftypefn") { - NATIVE_REDUCTION (cumsum, cumsum); + octave_value retval; + + int nargin = args.length (); + + bool isnative = false; + bool isdouble = false; + + if (nargin > 1 && args(nargin - 1).is_string ()) + { + std::string str = args(nargin - 1).string_value (); + + if (! error_state) + { + if (str == "native") + isnative = true; + else if (str == "double") + isdouble = true; + else + error ("cumsum: unrecognized string argument"); + nargin --; + } + } + + if (error_state) + return retval; + + if (nargin == 1 || nargin == 2) + { + octave_value arg = args(0); + + int dim = -1; + if (nargin == 2) + { + dim = args(1).int_value () - 1; + if (dim < 0) + error ("cumsum: invalid dimension argument = %d", dim + 1); + } + + if (! error_state) + { + switch (arg.builtin_type ()) + { + case btyp_double: + if (arg.is_sparse_type ()) + retval = arg.sparse_matrix_value ().cumsum (dim); + else + retval = arg.array_value ().cumsum (dim); + break; + case btyp_complex: + if (arg.is_sparse_type ()) + retval = arg.sparse_complex_matrix_value ().cumsum (dim); + else + retval = arg.complex_array_value ().cumsum (dim); + break; + case btyp_float: + if (isdouble) + retval = arg.array_value ().cumsum (dim); + else + retval = arg.float_array_value ().cumsum (dim); + break; + case btyp_float_complex: + if (isdouble) + retval = arg.complex_array_value ().cumsum (dim); + else + retval = arg.float_complex_array_value ().cumsum (dim); + break; + +#define MAKE_INT_BRANCH(X) \ + case btyp_ ## X: \ + if (isnative) \ + retval = arg.X ## _array_value ().cumsum (dim); \ + else \ + retval = arg.array_value ().cumsum (dim); \ + break + MAKE_INT_BRANCH (int8); + MAKE_INT_BRANCH (int16); + MAKE_INT_BRANCH (int32); + MAKE_INT_BRANCH (int64); + MAKE_INT_BRANCH (uint8); + MAKE_INT_BRANCH (uint16); + MAKE_INT_BRANCH (uint32); + MAKE_INT_BRANCH (uint64); +#undef MAKE_INT_BRANCH + + case btyp_bool: + if 
(arg.is_sparse_type ()) + { + SparseMatrix cs = arg.sparse_matrix_value ().cumsum (dim); + if (isnative) + retval = cs != 0.0; + else + retval = cs; + } + else + { + NDArray cs = arg.bool_array_value ().cumsum (dim); + if (isnative) + retval = cs != 0.0; + else + retval = cs; + } + break; + + default: + gripe_wrong_type_arg ("cumsum", arg); + } + } + } + else + print_usage (); + + return retval; } /* @@ -2553,6 +2664,8 @@ @deftypefn {Built-in Function} {} sum (@var{x})\n\ @deftypefnx {Built-in Function} {} sum (@var{x}, @var{dim})\n\ @deftypefnx {Built-in Function} {} sum (@dots{}, 'native')\n\ +@deftypefnx {Built-in Function} {} sum (@dots{}, 'double')\n\ +@deftypefnx {Built-in Function} {} sum (@dots{}, 'extra')\n\ Sum of elements along dimension @var{dim}. If @var{dim} is\n\ omitted, it defaults to 1 (column-wise sum).\n\ \n\ @@ -2571,10 +2684,136 @@ @result{} true\n\ @end group\n\ @end example\n\ +On the contrary, if 'double' is given, the sum is performed in double precision\n\ +even for single precision inputs.\n\ +For double precision inputs, 'extra' indicates that a more accurate algorithm\n\ +than straightforward summation is to be used. For single precision inputs, 'extra' is\n\ +the same as 'double'. Otherwise, 'extra' has no effect.\n\ @seealso{cumsum, sumsq, prod}\n\ @end deftypefn") { - NATIVE_REDUCTION (sum, any); + octave_value retval; + + int nargin = args.length (); + + bool isnative = false; + bool isdouble = false; + bool isextra = false; + + if (nargin > 1 && args(nargin - 1).is_string ()) + { + std::string str = args(nargin - 1).string_value (); + + if (! 
error_state) + { + if (str == "native") + isnative = true; + else if (str == "double") + isdouble = true; + else if (str == "extra") + isextra = true; + else + error ("sum: unrecognized string argument"); + nargin --; + } + } + + if (error_state) + return retval; + + if (nargin == 1 || nargin == 2) + { + octave_value arg = args(0); + + int dim = -1; + if (nargin == 2) + { + dim = args(1).int_value () - 1; + if (dim < 0) + error ("sum: invalid dimension argument = %d", dim + 1); + } + + if (! error_state) + { + switch (arg.builtin_type ()) + { + case btyp_double: + if (arg.is_sparse_type ()) + { + if (isextra) + warning ("sum: 'extra' not yet implemented for sparse matrices"); + retval = arg.sparse_matrix_value ().sum (dim); + } + else if (isextra) + retval = arg.array_value ().xsum (dim); + else + retval = arg.array_value ().sum (dim); + break; + case btyp_complex: + if (arg.is_sparse_type ()) + { + if (isextra) + warning ("sum: 'extra' not yet implemented for sparse matrices"); + retval = arg.sparse_complex_matrix_value ().sum (dim); + } + else if (isextra) + retval = arg.complex_array_value ().xsum (dim); + else + retval = arg.complex_array_value ().sum (dim); + break; + case btyp_float: + if (isdouble || isextra) + retval = arg.float_array_value ().dsum (dim); + else + retval = arg.float_array_value ().sum (dim); + break; + case btyp_float_complex: + if (isdouble || isextra) + retval = arg.float_complex_array_value ().dsum (dim); + else + retval = arg.float_complex_array_value ().sum (dim); + break; + +#define MAKE_INT_BRANCH(X) \ + case btyp_ ## X: \ + if (isnative) \ + retval = arg.X ## _array_value ().sum (dim); \ + else \ + retval = arg.X ## _array_value ().dsum (dim); \ + break + MAKE_INT_BRANCH (int8); + MAKE_INT_BRANCH (int16); + MAKE_INT_BRANCH (int32); + MAKE_INT_BRANCH (int64); + MAKE_INT_BRANCH (uint8); + MAKE_INT_BRANCH (uint16); + MAKE_INT_BRANCH (uint32); + MAKE_INT_BRANCH (uint64); +#undef MAKE_INT_BRANCH + + case btyp_bool: + if 
(arg.is_sparse_type ()) + { + if (isnative) + retval = arg.sparse_bool_matrix_value ().any (dim); + else + retval = arg.sparse_matrix_value ().sum (dim); + } + else if (isnative) + retval = arg.bool_array_value ().any (dim); + else + retval = arg.bool_array_value ().sum (dim); + break; + + default: + gripe_wrong_type_arg ("sum", arg); + } + } + } + else + print_usage (); + + return retval; } /*