diff scripts/general/accumarray.m @ 10268:9a16a61ed43d

new optimizations for accumarray
author Jaroslav Hajek <highegg@gmail.com>
date Fri, 05 Feb 2010 12:09:21 +0100
parents 5919f2bd9a99
children 217d36560dfa
line wrap: on
line diff
--- a/scripts/general/accumarray.m
+++ b/scripts/general/accumarray.m
@@ -54,65 +54,107 @@
 ## @end example
 ## @end deftypefn
 
-function A = accumarray (subs, val, sz, func, fillval, isspar)  
+function A = accumarray (subs, val, sz = [], func = [], fillval = [], isspar = [])  
 
   if (nargin < 2 || nargin > 6)
     print_usage ();
   endif
 
   if (iscell (subs))
-    subs = cell2mat (cellfun (@(x) x(:), subs, "UniformOutput", false));
+    subs = cellfun (@(x) x(:), subs, "UniformOutput", false);
+    ndims = numel (subs);
+    if (ndims == 1)
+      subs = subs{1};
+    endif
+  else
+    ndims = columns (subs);
   endif
-  ndims = size (subs, 2);
 
-  if (nargin < 5 || isempty (fillval))
+  if (isempty (fillval))
     fillval = 0;
   endif
 
-  if (nargin < 6 || isempty (isspar))
+  if (isempty (isspar))
     isspar = false;
   endif
 
-  if (isspar && ndims > 2)
-    error ("accumarray: sparse matrices limited to 2 dimensions");
-  endif
+  if (isspar)
+
+    ## Sparse case. Avoid linearizing the subscripts, because it could overflow.
+
+    if (fillval != 0)
+      error ("accumarray: fillval must be zero in the sparse case");
+    endif
+
+    ## Ensure subscripts are a two-column matrix.
+    if (iscell (subs))
+      subs = [subs{:}];
+    endif
 
-  if (nargin < 4 || isempty (func))
-    func = @sum;
-    ## This is the fast summation case. Unlike the general case,
-    ## this case will be handled using an O(N) algorithm.
+    ## Validate dimensions.
+    if (ndims == 1)
+      subs(:,2) = 1;
+    elseif (ndims != 2)
+      error ("accumarray: in the sparse case, needs 1 or 2 subscripts");
+    endif
+
+    if (isnumeric (val) || islogical (val))
+      vals = double (val);
+    else
+      error ("accumarray: in the sparse case, values must be numeric or logical");
+    endif
+
+    if (! (isempty (func) || func == @sum))
 
-    if (isspar && fillval == 0)
-      ## The "sparse" function can handle this case.
+      ## Reduce values. This is not needed if we're about to sum them, because
+      ## "sparse" can do that.
+      
+      ## Sort indices.
+      [subs, idx] = sortrows (subs);
+      n = rows (subs);
+      ## Identify runs.
+      jdx = find (any (diff (subs, 1, 1), 2));
+      jdx = [jdx; n];
+
+      val = cellfun (func, mat2cell (val(:)(idx), diff ([0; jdx])));
+      subs = subs(jdx, :);
+    endif
 
-      if ((nargin < 3 || isempty (sz)))
-        A = sparse (subs(:,1), subs(:,2), val);
-      elseif (length (sz) == 2)
-        A = sparse (subs(:,1), subs(:,2), val, sz(1), sz(2));
-      else
+    ## Form the sparse matrix.
+    if (isempty (sz))
+      A = sparse (subs(:,1), subs(:,2), val);
+    elseif (length (sz) == 2)
+      A = sparse (subs(:,1), subs(:,2), val, sz(1), sz(2));
+    else
+      error ("accumarray: dimensions mismatch")
+    endif
+
+  else
+
+    ## Linearize subscripts.
+    if (ndims > 1)
+      if (isempty (sz))
+        if (iscell (subs))
+          sz = cellfun (@max, subs);
+        else
+          sz = max (subs, [], 1);
+        endif
+      elseif (ndims != length (sz))
         error ("accumarray: dimensions mismatch")
       endif
-    else
-      ## This case is handled by an internal function.
 
-      if (ndims > 1)
-        if ((nargin < 3 || isempty (sz)))
-          sz = max (subs);
-        elseif (ndims != length (sz))
-          error ("accumarray: dimensions mismatch")
-        elseif (any (max (subs) > sz))
-          error ("accumarray: index out of range")
-        endif
+      ## Convert multidimensional subscripts.
+      if (ismatrix (subs))
+        subs = num2cell (subs, 1);
+      endif
+      subs = sub2ind (sz, subs{:});
+    endif
 
-        ## Convert multidimensional subscripts.
-        subs = sub2ind (sz, num2cell (subs, 1){:});
-      elseif (nargin < 3)
-        ## In case of linear indexing, the fast built-in accumulator
-        ## will determine the extent for us.
-        sz = [];
-      endif
+
+    ## Some built-in reductions handled efficiently.
 
-      ## Call the built-in accumulator.
+    if (isempty (func) || func == @sum)
+      ## Fast summation.
       if (isempty (sz))
         A = __accumarray_sum__ (subs, val);
       else
@@ -127,47 +169,87 @@
         mask(subs) = false;
         A(mask) = fillval;
       endif
-    endif
+    elseif (func == @max)
+      ## Fast maximization.
 
-    return
-  endif
+      if (isinteger (val))
+        zero = intmin (class (val));
+      elseif (fillval == 0 && all (val(:) >= 0))
+        ## This is a common case - fillval is zero, all numbers nonnegative.
+        zero = 0;
+      else
+        zero = NaN; # Neutral value.
+      endif
+
+      if (isempty (sz))
+        A = __accumarray_max__ (subs, val, zero);
+      else
+        A = __accumarray_max__ (subs, val, zero, prod (sz));
+        A = reshape (A, sz);
+      endif
 
-  if (nargin < 3 || isempty (sz))
-    sz = max (subs);
-    if (isscalar(sz))
-      sz = [sz, 1];
-    endif
-  elseif (length (sz) != ndims
-	  && (ndims != 1 || length (sz) != 2 || sz(2) != 1))
-    error ("accumarray: inconsistent dimensions");
-  endif
-  
-  [subs, idx] = sortrows (subs);
+      if (fillval != zero && isnan (fillval) != isnan (zero))
+        mask = true (size (A));
+        mask(subs) = false;
+        A(mask) = fillval;
+      endif
+    elseif (func == @min)
+      ## Fast minimization.
+
+      if (isinteger (val))
+        zero = intmax (class (val));
+      else
+        zero = NaN; # Neutral value.
+      endif
+
+      if (isempty (sz))
+        A = __accumarray_min__ (subs, val, zero);
+      else
+        A = __accumarray_min__ (subs, val, zero, prod (sz));
+        A = reshape (A, sz);
+      endif
 
-  if (isscalar (val))
-    val = repmat (size (idx));
-  else
-    val = val(idx);
-  endif
-  cidx = find ([true; (sum (abs (diff (subs)), 2) != 0)]);
-  idx = cell (1, ndims);
-  for i = 1:ndims
-    idx{i} = subs (cidx, i);
-  endfor
-  x = cellfun (func, mat2cell (val(:), diff ([cidx; length(val) + 1])));
-  if (isspar && fillval == 0)
-    A = sparse (idx{1}, idx{2}, x, sz(1), sz(2));
-  else
-    if (iscell (x))
-      ## Why did matlab choose to reverse the order of the elements
-      x = cellfun (@(x) flipud (x(:)), x, "UniformOutput", false);
-      A = cell (sz);
-    elseif (fillval == 0)
-      A = zeros (sz, class (x));
-    else 
-      A = fillval .* ones (sz);
+      if (fillval != zero && isnan (fillval) != isnan (zero))
+        mask = true (size (A));
+        mask(subs) = false;
+        A(mask) = fillval;
+      endif
+    else
+
+      ## The general case. Reduce values. 
+      n = rows (subs);
+      if (numel (val) == 1)
+        val = val(ones (1, n), 1);
+      else
+        val = val(:);
+      endif
+      
+      ## Sort indices.
+      [subs, idx] = sort (subs);
+      ## Identify runs.
+      jdx = find (diff (subs, 1, 1));
+      jdx = [jdx; n];
+      val = mat2cell (val(idx), diff ([0; jdx]));
+      ## Optimize the case when function is @(x) {x}, i.e. we just want to
+      ## collect the values to cells.
+      persistent simple_cell_str = func2str (@(x) {x});
+      if (! strcmp (func2str (func), simple_cell_str))
+        val = cellfun (func, val);
+      endif
+      subs = subs(jdx);
+
+      ## Construct matrix of fillvals.
+      if (iscell (val))
+        A = cell (sz);
+      elseif (fillval == 0)
+        A = zeros (sz, class (val));
+      else
+        A = repmat (fillval, sz);
+      endif
+
+      ## Set the reduced values.
+      A(subs) = val;
     endif
-    A(sub2ind (sz, idx{:})) = x;
   endif
 endfunction
 
@@ -183,4 +265,24 @@
 %!assert (accumarray ([1 1; 2 1; 2 3; 2 1; 2 3],101:105,[2,4],@(x)length(x)>1),[false,false,false,false;true,false,true,false])
 %!test
 %! A = accumarray ([1 1; 2 1; 2 3; 2 1; 2 3],101:105,[2,4],@(x){x});
-%! assert (A{2},[104;102])
+%! assert (A{2},[102;104])
+%!test
+%! subs = ceil (rand (2000, 3)*10);
+%! val = rand (2000, 1);
+%! assert (accumarray (subs, val, [], @max), accumarray (subs, val, [], @(x) max (x)));
+%!test
+%! subs = ceil (rand (2000, 1)*100);
+%! val = rand (2000, 1);
+%! assert (accumarray (subs, val, [100, 1], @min, NaN), accumarray (subs, val, [100, 1], @(x) min (x), NaN));
+%!test
+%! subs = ceil (rand (2000, 2)*30);
+%! subsc = num2cell (subs, 1);
+%! val = rand (2000, 1);
+%! assert (accumarray (subsc, val, [], [], 0, true), accumarray (subs, val, [], [], 0, true));
+%!test
+%! subs = ceil (rand (2000, 3)*10);
+%! subsc = num2cell (subs, 1);
+%! val = rand (2000, 1);
+%! assert (accumarray (subsc, val, [], @max), accumarray (subs, val, [], @max));
+
+