Mercurial > hg > octave-nkf
diff scripts/set/setdiff.m @ 7920:e56bb65186f6
improve set functions for Matlab compatibility
author | Jaroslav Hajek <highegg@gmail.com> |
---|---|
date | Wed, 25 Jun 2008 22:11:07 +0200 |
parents | a1dbe9d80eee |
children | 1bf0ce0930be |
line wrap: on
line diff
--- a/scripts/set/setdiff.m +++ b/scripts/set/setdiff.m @@ -1,4 +1,5 @@ ## Copyright (C) 2000, 2005, 2006, 2007 Paul Kienzle +## Copyright (C) 2008 Jaroslav Hajek ## ## This file is part of Octave. ## @@ -19,19 +20,22 @@ ## -*- texinfo -*- ## @deftypefn {Function File} {} setdiff (@var{a}, @var{b}) ## @deftypefnx {Function File} {} setdiff (@var{a}, @var{b}, "rows") +## @deftypefnx {Function File} {[@var{c}, @var{i}] = } setdiff (@var{a}, @var{b}) ## Return the elements in @var{a} that are not in @var{b}, sorted in ## ascending order. If @var{a} and @var{b} are both column vectors ## return a column vector, otherwise return a row vector. ## ## Given the optional third argument @samp{"rows"}, return the rows in ## @var{a} that are not in @var{b}, sorted in ascending order by rows. +## +## If requested, return @var{i} such that @code{c = a(i)}. ## @seealso{unique, union, intersect, setxor, ismember} ## @end deftypefn ## Author: Paul Kienzle ## Adapted-by: jwe -function c = setdiff (a, b, byrows_arg) +function [c, i] = setdiff (a, b, byrows_arg) if (nargin < 2 || nargin > 3) print_usage (); @@ -50,7 +54,11 @@ endif if (byrows) - c = unique (a, "rows"); + if (nargout > 1) + [c, i] = unique (a, "rows"); + else + c = unique (a, "rows"); + endif if (! isempty (c) && ! isempty (b)) ## Form a and b into combined set. b = unique (b, "rows"); @@ -58,9 +66,16 @@ ## Eliminate those elements of a that are the same as in b. dups = find (all (dummy(1:end-1,:) == dummy(2:end,:), 2)); c(idx(dups),:) = []; + if (nargout > 1) + i(idx(dups),:) = []; + endif endif else - c = unique (a); + if (nargout > 1) + [c, i] = unique (a); + else + c = unique (a); + endif if (! isempty (c) && ! isempty (b)) ## Form a and b into combined set. b = unique (b); @@ -72,6 +87,9 @@ dups = find (dummy(1:end-1) == dummy(2:end)); endif c(idx(dups)) = []; + if (nargout > 1) + i(idx(dups)) = []; + endif ## Reshape if necessary. if (size (c, 1) != 1 && size (b, 1) == 1) c = c.'; @@ -88,3 +106,9 @@ %!assert(setdiff([1; 2; 3; 4], [1; 2; 4], "rows"), 3) %!assert(setdiff([1, 2; 3, 4], [1, 2; 3, 6], "rows"), [3, 4]) %!assert(setdiff({"one","two";"three","four"},{"one","two";"three","six"}), {"four"}) + +%!test +%! a = [3, 1, 4, 1, 5]; b = [1, 2, 3, 4]; +%! [y, i] = setdiff (a, b.'); +%! assert(y, [5]); +%! assert(y, a(i));