Mercurial > hg > octave-lyh
diff scripts/general/accumarray.m @ 8934:c2099a4d12ea
partially optimize accumarray
author | Jaroslav Hajek <highegg@gmail.com> |
---|---|
date | Mon, 09 Mar 2009 10:59:19 +0100 |
parents | eb63fbe60fab |
children | 1bf0ce0930be |
line wrap: on
line diff
--- a/scripts/general/accumarray.m +++ b/scripts/general/accumarray.m @@ -1,4 +1,5 @@ ## Copyright (C) 2007, 2008, 2009 David Bateman +## Copyright (C) 2009 VZLU Prague ## ## This file is part of Octave. ## @@ -64,20 +65,6 @@ endif ndims = size (subs, 2); - if (nargin < 3 || isempty (sz)) - sz = max (subs); - if (isscalar(sz)) - sz = [sz, 1]; - endif - elseif (length (sz) != ndims - && (ndims != 1 || length (sz) != 2 || sz(2) != 1)) - error ("accumarray: inconsistent dimensions"); - endif - - if (nargin < 4 || isempty (fun)) - fun = @sum; - endif - if (nargin < 5 || isempty (fillval)) fillval = 0; endif @@ -90,6 +77,71 @@ error ("accumarray: sparse matrices limited to 2 dimensions"); endif + if (nargin < 4 || isempty (fun)) + fun = @sum; + ## This is the fast summation case. Unlike the general case, + ## this case will be handled using an O(N) algorithm. + + if (isspar && fillval == 0) + ## The "sparse" function can handle this case. + + if ((nargin < 3 || isempty (sz))) + A = sparse (subs(:,1), subs(:,2), val); + elseif (length (sz) == 2) + A = sparse (subs(:,1), subs(:,2), val, sz(1), sz(2)); + else + error ("accumarray: dimensions mismatch") + endif + else + ## This case is handled by an internal function. + + if (ndims > 1) + if ((nargin < 3 || isempty (sz))) + sz = max (subs); + elseif (ndims != length (sz)) + error ("accumarray: dimensions mismatch") + elseif (any (max (subs) > sz)) + error ("accumarray: index out of range") + endif + + ## Convert multidimensional subscripts. + subs = sub2ind (sz, mat2cell (subs, rows (subs), ones (1, ndims)){:}); + elseif (nargin < 3) + ## In case of linear indexing, the fast built-in accumulator + ## will determine the extent for us. + sz = []; + endif + + ## Call the built-in accumulator. + if (isempty (sz)) + A = __accumarray_sum__ (subs, val); + else + A = __accumarray_sum__ (subs, val, prod (sz)); + ## set proper shape. + A = reshape (A, sz); + endif + + ## we fill in nonzero fill value. + if (fillval != 0) + mask = true (size (A)); + mask(subs) = false; + A(mask) = fillval; + endif + endif + + return + endif + + if (nargin < 3 || isempty (sz)) + sz = max (subs); + if (isscalar(sz)) + sz = [sz, 1]; + endif + elseif (length (sz) != ndims + && (ndims != 1 || length (sz) != 2 || sz(2) != 1)) + error ("accumarray: inconsistent dimensions"); + endif + [subs, idx] = sortrows (subs); if (isscalar (val))