Mercurial > hg > octave-nkf
diff src/data.cc @ 8934:c2099a4d12ea
partially optimize accumarray
author | Jaroslav Hajek <highegg@gmail.com> |
---|---|
date | Mon, 09 Mar 2009 10:59:19 +0100 |
parents | eb63fbe60fab |
children | 1e4b3149365a |
line wrap: on
line diff
--- a/src/data.cc +++ b/src/data.cc @@ -5724,6 +5724,78 @@ return retval; } +template <class NDT> +static NDT +do_accumarray_sum (const idx_vector& idx, const NDT& vals, + octave_idx_type n = -1) +{ + typedef typename NDT::element_type T; + if (n < 0) + n = idx.extent (0); + else if (idx.extent (n) > n) + error ("accumarray: index out of range"); + + // FIXME: the class tree in liboctave is overly complicated, hence the + // following type gymnastics. + MArray<T> array; + + if (vals.numel () == 1) + { + array = MArray<T> (n, T ()); + array.idx_add (idx, vals (0)); + } + else if (vals.length () == idx.length (n)) + { + array = MArray<T> (n, T ()); + array.idx_add (idx, MArray<T> (vals)); + } + else + error ("accumarray: dimensions mismatch"); + + return NDT (MArrayN<T> (ArrayN<T> (array))); +} + +DEFUN (__accumarray_sum__, args, , + "-*- texinfo -*-\n\ +@deftypefn {Built-in Function} {} __accumarray_sum__ (@var{idx}, @var{vals}, @var{n})\n\ +Undocumented internal function.\n\ +@end deftypefn") +{ + octave_value retval; + int nargin = args.length (); + if (nargin >= 2 && nargin <= 3 && args(0).is_numeric_type ()) + { + idx_vector idx = args(0).index_vector (); + octave_idx_type n = -1; + if (nargin == 3) + n = args(2).idx_type_value (true); + + if (! error_state) + { + octave_value vals = args(1); + if (vals.is_single_type ()) + { + if (vals.is_complex_type ()) + retval = do_accumarray_sum (idx, vals.float_complex_array_value (), n); + else + retval = do_accumarray_sum (idx, vals.float_array_value (), n); + } + else if (vals.is_numeric_type ()) + { + if (vals.is_complex_type ()) + retval = do_accumarray_sum (idx, vals.complex_array_value (), n); + else + retval = do_accumarray_sum (idx, vals.array_value (), n); + } + else + gripe_wrong_type_arg ("accumarray", vals); + } + } + else + print_usage (); + + return retval; +} /*