changeset 7128:73308b8f8777

[project @ 2007-11-08 03:55:04 by jwe]
author jwe
date Thu, 08 Nov 2007 03:55:04 +0000
parents d31f017af3a4
children 363ffc8a5c80
files scripts/ChangeLog scripts/set/ismember.m
diffstat 2 files changed, 245 insertions(+), 92 deletions(-) [+]
line wrap: on
line diff
--- a/scripts/ChangeLog
+++ b/scripts/ChangeLog
@@ -1,3 +1,9 @@
+2007-11-07  Ben Abbott  <bpabbott@mac.com>
+
+	* set/ismember.m: Call cell_ismember to handle cellstr args.
+	New tests.
+	(cell_ismember): New function.
+
 2007-11-07  John W. Eaton  <jwe@octave.org>
 
 	* control/base/__bodquist__.m, control/base/__freqresp__.m,
--- a/scripts/set/ismember.m
+++ b/scripts/set/ismember.m
@@ -17,118 +17,265 @@
 ## <http://www.gnu.org/licenses/>.
 
 ## -*- texinfo -*-
-## @deftypefn {Function File} {} ismember (@var{A}, @var{S})
-## Return a matrix the same shape as @var{A} which has 1 if
-## @code{A(i,j)} is in @var{S} or 0 if it isn't.
+## @deftypefn {Function File} {[@var{tf}, @var{a_idx}] =} ismember (@var{A}, @var{S}) 
+## Return a matrix @var{tf} the same shape as @var{A} which has 1 if 
+## @code{A(i,j)} is in @var{S} or 0 if it isn't. If a second output argument 
+## is requested, the indexes into  @var{S} of the matching elements is 
+## also returned. 
+##
+## @example
+## @group
+## a = [3, 10, 1];
+## s = [0:9];
+## [tf, a_idx] = residue (a, s);
+##      @result{} tf = [1, 0, 1]
+##      @result{} a_idx = [4, 0, 2]
+## @end group
+## @end example
+##
+## The inputs, @var{A} and @var{S}, may also be cell arrays.
+##
+## @example
+## @group
+## a = @{'abc'@};
+## s = @{'abc', 'def'@};
+## [tf, a_idx] = residue (a, s);
+##      @result{} tf = [1, 0]
+##      @result{} a_idx = [1, 0]
+## @end group
+## @end example
+##
+## @deftypefnx {Function File} {[@var{tf}, @var{a_idx}] =} ismember (@var{A}, @var{S}, 'rows')
+## When @var{A} and @var{S} are matrices with the same number of columes,
+## the row vectors may matched.
+##
+## @example
+## @group
+## a = [1:3; 5:7; 4:6];
+## s = [0:2; 1:3; 2:4; 3:5; 4:6];
+## [tf, a_idx] = ismember(a, s, 'rows');
+##      @result{} tf = logical ([1; 0; 1])
+##      @result{} a_idx = [2; 0; 5];
+## @end group
+## @end example
+##
 ## @seealso{unique, union, intersection, setxor, setdiff}
 ## @end deftypefn
 
 ## Author: Paul Kienzle
+## Author: Søren Hauberg
+## Author: Ben Abbott
 ## Adapted-by: jwe
 
-function c = ismember (a, S)
+function [tf, a_idx] = ismember (a, s, rows_opt) 
 
-  if (nargin != 2)
+  if (nargin == 2 || nargin == 3) 
+    if (iscell(a) || iscell(s))
+      if (nargin == 3)
+        error ("ismember: with 'rows' both sets must be matrices"); 
+      else
+        [tf, a_idx] = cell_ismember (a, s);
+      endif
+    else
+      if (nargin == 3) 
+        ## The 'rows' argument is handled in a fairly ugly way. A better
+        ## solution would be to vectorize this loop over 'r' below.
+        if (strcmpi (rows_opt, "rows") && ismatrix (a) && ismatrix (s) && ...
+            columns (a) == columns (s)) 
+          rs = rows (s);
+          ra = rows (a);
+          a_idx = zeros (ra, 1);
+          for r = 1:ra
+           tmp = ones (rs, 1) * a(r,:);
+            f = find (all (tmp' == s'), 1);
+            if ! isempty (f)
+              a_idx(r) = f;
+            endif
+          endfor
+          tf = logical (a_idx);
+        elseif strcmpi (rows_opt, "rows")
+          error ("ismember: with 'rows' both sets must be matrices with an equal number of columns"); 
+        else
+          error ("ismember: invalid input"); 
+        endif
+      else
+        ## Input checking 
+        if ( ! isa (a, class (s)) ) 
+          error ("ismember: both input arguments must be the same type");
+        elseif ( ! ischar (a) && ! isnumeric (a) )
+          error ("ismember: input arguments must be arrays, cell arrays, or strings"); 
+        elseif ( ischar (a) && ischar (s) ) 
+          a = uint8 (a);
+          s = uint8 (s);
+        endif
+        ## Convert matrices to vectors 
+        if (all (size (a) > 1))
+          a = a(:);
+        endif 
+        if (all (size (s) > 1))
+          s = s(:);
+        endif 
+        ## Do the actual work
+        if (isempty (a) || isempty (s))
+          tf = zeros (size (a), "logical");
+          a_idx = []; 
+        elseif (numel (s) == 1) 
+          tf = (a == s);
+          a_idx = double (tf);
+        elseif (numel (a) == 1) 
+          f = find (a == s, 1); 
+          tf = !isempty (f); 
+          a_idx = f; 
+          if (isempty (a_idx))
+            a_idx = 0;
+          endif 
+        else
+          ## Magic:  the following code determines for each a, the index i 
+          ## such that s(i)<= a < s(i+1).  It does this by sorting the a 
+          ## into s and remembering the source index where each element came 
+          ## from.  Since all the a's originally came after all the s's, if 
+          ## the source index is less than the length of s, then the element 
+          ## came from s.  We can then do a cumulative sum on the indices to 
+          ## figure out which element of s each a comes after. 
+          ## E.g., s=[2 4 6], a=[1 2 3 4 5 6 7] 
+          ##    unsorted [s a]  = [ 2 4 6 1 2 3 4 5 6 7 ] 
+          ##    sorted [s a]    = [ 1 2 2 3 4 4 5 6 6 7 ] 
+          ##    source index p  = [ 4 1 5 6 2 7 8 3 9 10 ] 
+          ##    boolean p<=l(s) = [ 0 1 0 0 1 0 0 1 0 0 ] 
+          ##    cumsum(p<=l(s)) = [ 0 1 1 1 2 2 2 3 3 3 ] 
+          ## Note that this leaves a(1) coming after s(0) which doesn't 
+          ## exist.  So arbitrarily, we will dump all elements less than 
+          ## s(1) into the interval after s(1).  We do this by dropping s(1) 
+          ## from the sort!  E.g., s=[2 4 6], a=[1 2 3 4 5 6 7] 
+          ##    unsorted [s(2:3) a] =[4 6 1 2 3 4 5 6 7 ] 
+          ##    sorted [s(2:3) a] = [ 1 2 3 4 4 5 6 6 7 ] 
+          ##    source index p    = [ 3 4 5 1 6 7 2 8 9 ] 
+          ##    boolean p<=l(s)-1 = [ 0 0 0 1 0 0 1 0 0 ] 
+          ##    cumsum(p<=l(s)-1) = [ 0 0 0 1 1 1 2 2 2 ] 
+          ## Now we can use Octave's lvalue indexing to "invert" the sort, 
+          ## and assign all these indices back to the appropriate a and s, 
+          ## giving s_idx = [ -- 1 2], a_idx = [ 0 0 0 1 1 2 2 ].  Add 1 to 
+          ## a_idx, and we know which interval s(i) contains a.  It is 
+          ## easy to now check membership by comparing s(a_idx) == a.  This 
+          ## magic works because s starts out sorted, and because sort 
+          ## preserves the relative order of identical elements. 
+          lt = numel(s); 
+          [s, sidx] = sort (s); 
+          [v, p] = sort ([s(2:lt)(:); a(:)]); 
+          idx(p) = cumsum (p <= lt-1) + 1; 
+          idx = idx(lt:end); 
+          tf = (a == reshape (s (idx), size (a))); 
+          a_idx = zeros (size(tf)); 
+          a_idx(tf) = sidx(idx(tf));
+        endif
+        ## Resize result to the original size of 'a' 
+        size_a = size(a);
+        tf = reshape (tf, size_a); 
+        a_idx = reshape (a_idx, size_a);
+      endif
+    endif
+  else
     print_usage ();
   endif
 
-  if (isempty (a) || isempty (S))
-    c = zeros (size (a), "logical");
-  else
-    if (iscell (a) && ! iscell (S))
-      tmp{1} = S;
-      S = tmp;
-    endif
-    if (! iscell (a) && iscell (S))
-      tmp{1} = a;
-      a = tmp;
-    endif
-    S = unique (S(:));
-    lt = length (S);
-    if (lt == 1)
-      if (iscell (a) || iscell (S))
-        c = cellfun ("length", a) == cellfun ("length", S);
-        idx = find (c);
-	if (isempty (idx))
-	  c = zeros (size (a), "logical");
-	else
-	  c(idx) = all (char (a(idx)) == repmat (char (S), length (idx), 1), 2);
-	endif
+endfunction
+
+function [tf, a_idx] = cell_ismember (a, s)
+  if (nargin == 2)
+    if (ischar (a) && iscellstr (s)) 
+      if (isempty (a)) # Work around bug in 'cellstr' 
+        a = {''};
+      else
+        a = cellstr(a);
+      endif
+    elseif (iscellstr (a) && ischar (s))
+      if (isempty (s)) # Work around bug in 'cellstr' 
+        s = {''};
       else
-        c = (a == S);
+        s = cellstr(s);
       endif
-    elseif (numel (a) == 1)
-      if (iscell (a) || iscell (S))
-        c = cellfun ("length", a) == cellfun ("length", S);
-        idx = find (c);
-	if (isempty (idx))
-	  c = zeros (size (a), "logical");
-	else
-          c(idx) = all (repmat (char (a), length (idx), 1) == char (S(idx)), 2);
-          c = any(c);
-	endif
-      else
-        c = any (a == S);
+    endif 
+    if (iscellstr (a) && iscellstr (s))
+      ## Do the actual work
+      if (isempty (a) || isempty (s))
+        tf = zeros (size (a), "logical");
+        a_idx = []; 
+      elseif (numel (s) == 1) 
+        tf = strcmp (a, s);
+        a_idx = double (tf);
+      elseif (numel (a) == 1) 
+        f = find (strcmp (a, s), 1); 
+        tf = !isempty (f);
+        a_idx = f; 
+        if (isempty (a_idx))
+          a_idx = 0;
+        endif 
+      else 
+        lt = numel(s);
+        [s, sidx] = sort (s);
+        [v, p] = sort ([s(2:lt)(:); a(:)]);
+        idx(p) = cumsum (p <= lt-1) + 1;
+        idx = idx(lt:end);
+        tf = (cellfun ("length", a) 
+            == reshape (cellfun ("length", s(idx)), size (a)));
+        idx2 = find (tf);
+        tf(idx2) = all (char (a(idx2)) == char (s(idx)(idx2)), 2);
+        a_idx = zeros (size (tf));
+        a_idx(tf) = sidx(idx(tf));
       endif
     else
-      ## Magic:  the following code determines for each a, the index i
-      ## such that S(i)<= a < S(i+1).  It does this by sorting the a
-      ## into S and remembering the source index where each element came
-      ## from.  Since all the a's originally came after all the S's, if 
-      ## the source index is less than the length of S, then the element
-      ## came from S.  We can then do a cumulative sum on the indices to
-      ## figure out which element of S each a comes after.
-      ## E.g., S=[2 4 6], a=[1 2 3 4 5 6 7]
-      ##    unsorted [S a]  = [ 2 4 6 1 2 3 4 5 6 7 ]
-      ##    sorted [ S a ]  = [ 1 2 2 3 4 4 5 6 6 7 ] 
-      ##    source index p  = [ 4 1 5 6 2 7 8 3 9 10 ]
-      ##    boolean p<=l(S) = [ 0 1 0 0 1 0 0 1 0 0 ]
-      ##    cumsum(p<=l(S)) = [ 0 1 1 1 2 2 2 3 3 3 ]
-      ## Note that this leaves a(1) coming after S(0) which doesn't
-      ## exist.  So arbitrarily, we will dump all elements less than
-      ## S(1) into the interval after S(1).  We do this by dropping S(1)
-      ## from the sort!  E.g., S=[2 4 6], a=[1 2 3 4 5 6 7]
-      ##    unsorted [S(2:3) a] =[4 6 1 2 3 4 5 6 7 ]
-      ##    sorted [S(2:3) a] = [ 1 2 3 4 4 5 6 6 7 ]
-      ##    source index p    = [ 3 4 5 1 6 7 2 8 9 ]
-      ##    boolean p<=l(S)-1 = [ 0 0 0 1 0 0 1 0 0 ]
-      ##    cumsum(p<=l(S)-1) = [ 0 0 0 1 1 1 2 2 2 ]
-      ## Now we can use Octave's lvalue indexing to "invert" the sort,
-      ## and assign all these indices back to the appropriate A and S,
-      ## giving S_idx = [ -- 1 2], a_idx = [ 0 0 0 1 1 2 2 ].  Add 1 to
-      ## a_idx, and we know which interval S(i) contains a.  It is
-      ## easy to now check membership by comparing S(a_idx) == a.  This
-      ## magic works because S starts out sorted, and because sort
-      ## preserves the relative order of identical elements.
-      [v, p] = sort ([S(2:lt); a(:)]);
-      idx(p) = cumsum (p <= lt-1) + 1;
-      idx = idx(lt:end);
-      if (iscell (a) || iscell (S))
-        c = (cellfun ("length", a)
-	     == reshape (cellfun ("length", S(idx)), size (a)));
-        idx2 = find (c);
-        c(idx2) = all (char (a(idx2)) == char (S(idx)(idx2)), 2);
-      else
-        c = (a == reshape (S (idx), size (a)));
-      endif
+      error ("cell_ismember: arguments must be cell arrays of character strings");
     endif
+  else
+    print_usage ();
   endif
-
+  ## Resize result to the original size of 'a' 
+  size_a = size(a);
+  tf = reshape (tf, size_a); 
+  a_idx = reshape (a_idx, size_a); 
 endfunction
 
 %!assert (ismember ({''}, {'abc', 'def'}), false);
 %!assert (ismember ('abc', {'abc', 'def'}), true);
 %!assert (isempty (ismember ([], [1, 2])), true);
 %!assert (isempty (ismember ({}, {'a', 'b'})), true);
-%!xtest assert (ismember ('', {'abc', 'def'}), false);
-%!xtest fail ('ismember ([], {1, 2})', 'error:.*');
+%!assert (ismember ('', {'abc', 'def'}), false);
+%!fail ('ismember ([], {1, 2})', 'error:.*');
 %!fail ('ismember ({[]}, {1, 2})', 'error:.*');
-%!xtest fail ('ismember ({}, {1, 2})', 'error:.*');
-%!assert (ismember ({'foo', 'bar'}, {'foobar'}), logical ([0, 0]))
-%!assert (ismember ({'foo'}, {'foobar'}), false)
-%!assert (ismember ({'bar'}, {'foobar'}), false)
-%!assert (ismember ({'bar'}, {'foobar', 'bar'}), true)
-%!assert (ismember ({'foo', 'bar'}, {'foobar', 'bar'}), logical ([0, 1]))
-%!assert (ismember ({'xfb', 'f', 'b'}, {'fb', 'b'}), logical ([0, 0, 1]))
-%!assert (ismember ("1", "0123456789."), true)
-%!assert (ismember ("1.1", "0123456789."), logical ([1, 1, 1]))
+%!fail ('ismember ({}, {1, 2})', 'error:.*');
+%!fail ('ismember ({1}, {''1'', ''2''})', 'error:.*');
+%!fail ('ismember (1, ''abc'')', 'error:.*');
+%!fail ('ismember ({''1''}, {''1'', ''2''},''rows'')', 'error:.*');
+%!fail ('ismember ([1 2 3], [5 4 3 1], ''rows'')', 'error:.*');
+%!assert (ismember ({'foo', 'bar'}, {'foobar'}), logical ([0, 0]));
+%!assert (ismember ({'foo'}, {'foobar'}), false);
+%!assert (ismember ({'bar'}, {'foobar'}), false);
+%!assert (ismember ({'bar'}, {'foobar', 'bar'}), true);
+%!assert (ismember ({'foo', 'bar'}, {'foobar', 'bar'}), logical ([0, 1]));
+%!assert (ismember ({'xfb', 'f', 'b'}, {'fb', 'b'}), logical ([0, 0, 1]));
+%!assert (ismember ("1", "0123456789."), true);
+
+%!test
+%! [result, a_idx] = ismember([1 2 3 4 5], [3]);
+%! assert (all (result == logical ([0 0 1 0 0])) && all (a_idx == [0 0 1 0 0]));
+
+%!test
+%! [result, a_idx] = ismember([1 6], [1 2 3 4 5 1 6 1]);
+%! assert (all (result == logical ([1 1])) && all (a_idx == [8 7]));
+
+%!test
+%! [result, a_idx] = ismember ([3,10,1], [0,1,2,3,4,5,6,7,8,9]);
+%! assert (all (result == logical ([1, 0, 1])) && all (a_idx == [4, 0, 2]));
+
+%!test
+%! [result, a_idx] = ismember ("1.1", "0123456789.1");
+%! assert (all (result == logical ([1, 1, 1])) && all (a_idx == [12, 11, 12]));
+
+%!test
+%! [result, a_idx] = ismember([1:3; 5:7; 4:6], [0:2; 1:3; 2:4; 3:5; 4:6], 'rows');
+%! assert (all (result == logical ([1; 0; 1])) && all (a_idx == [2; 0; 5]));
+
+%!test
+%! [result, a_idx] = ismember([1.1,1.2,1.3; 2.1,2.2,2.3; 10,11,12], [1.1,1.2,1.3; 10,11,12; 2.12,2.22,2.32], 'rows');
+%! assert (all (result == logical ([1; 0; 1])) && all (a_idx == [1; 0; 2]));
+