Mercurial > hg > octave-lyh
view scripts/testfun/assert.m @ 17258:afd235a206a2
Allow vector/matrix tolerance and improve error messages for assert.m script
* assert.m
(assert): Document non-scalar tolerance option. Remove FIXME about format of
output. Remove 'coda' and 'iserror' spanning whole routine. Use structure
'err.index/expected/observed/reason' to keep track of multiple results and
recursions. Add persistent variables 'errmsg', 'assert_call_depth' and
'assert_error_occurred' to allow recursions and print only when all complete.
Place output formatting in pprint() function. Construct vector tolerance from
scalar tolerance. Add test illustrating recursions and multiple tables. Add
test illustrating variable tolerance. Add test illustrating multidimensional
matrices. Remove looping for constructing error information. Add thorough
tests for exceptional values by checking both real and imaginary parts. Place zeros
where exceptional values exist in real and imaginary parts of the two matrices.
Add tests illustrating exceptional values in real and/or imaginary part and
numerical mismatch in the other part.
(construct_indeces): Format linear indexing as tuple indexing, vectors (#),
scalars ().
(pprint): Sub function to format and print input command, index of failure,
expected and observed values at failure, and the reason for failure.
author | Daniel J Sebald <daniel.sebald@ieee.org> |
---|---|
date | Mon, 12 Aug 2013 15:44:40 -0500 |
parents | d6499c14021c |
children | 684ccccbc15d |
line wrap: on
line source
## Copyright (C) 2000-2012 Paul Kienzle
##
## This file is part of Octave.
##
## Octave is free software; you can redistribute it and/or modify it
## under the terms of the GNU General Public License as published by
## the Free Software Foundation; either version 3 of the License, or (at
## your option) any later version.
##
## Octave is distributed in the hope that it will be useful, but
## WITHOUT ANY WARRANTY; without even the implied warranty of
## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
## General Public License for more details.
##
## You should have received a copy of the GNU General Public License
## along with Octave; see the file COPYING.  If not, see
## <http://www.gnu.org/licenses/>.

## -*- texinfo -*-
## @deftypefn {Function File} {} assert (@var{cond})
## @deftypefnx {Function File} {} assert (@var{cond}, @var{errmsg}, @dots{})
## @deftypefnx {Function File} {} assert (@var{cond}, @var{msg_id}, @var{errmsg}, @dots{})
## @deftypefnx {Function File} {} assert (@var{observed}, @var{expected})
## @deftypefnx {Function File} {} assert (@var{observed}, @var{expected}, @var{tol})
##
## Produce an error if the specified condition is not met.  @code{assert} can
## be called in three different ways.
##
## @table @code
## @item assert (@var{cond})
## @itemx assert (@var{cond}, @var{errmsg}, @dots{})
## @itemx assert (@var{cond}, @var{msg_id}, @var{errmsg}, @dots{})
## Called with a single argument @var{cond}, @code{assert} produces an
## error if @var{cond} is not numeric or logical, or if any element of
## @var{cond} is zero.  When called with more than one argument the
## additional arguments are passed to the @code{error} function.  With more
## than one argument this form is selected only when @var{cond} is logical
## and the second argument is a string; otherwise the call is treated as a
## comparison of @var{observed} and @var{expected}.
##
## @item assert (@var{observed}, @var{expected})
## Produce an error if observed is not the same as expected.  Note that
## @var{observed} and @var{expected} can be scalars, vectors, matrices,
## strings, cell arrays, or structures.  Without a tolerance the comparison
## is strict: the classes must match and, for numeric data, so must the
## sparsity and the real/complex type.
##
## @item assert (@var{observed}, @var{expected}, @var{tol})
## Produce an error if observed is not the same as expected but equality
## comparison for numeric data uses a tolerance @var{tol}.
## If @var{tol} is positive then it is an absolute tolerance which will produce
## an error if @code{abs (@var{observed} - @var{expected}) > abs (@var{tol})}.
## If @var{tol} is negative then it is a relative tolerance which will produce
## an error if @code{abs (@var{observed} - @var{expected}) >
## abs (@var{tol} * @var{expected})}.  If @var{expected} is zero @var{tol} will
## always be interpreted as an absolute tolerance.  If @var{tol} is not scalar
## its dimensions must agree with those of @var{observed} and @var{expected}
## and tests are performed on an element-wise basis.
## @end table
##
## All mismatches found, including those inside cell arrays and structures,
## are collected and reported together in a single error.
## @seealso{test, fail, error}
## @end deftypefn

function assert (cond, varargin)

  ## Cell arrays and structures are compared by calling assert recursively.
  ## CALL_DEPTH is -1 when no assert call is in progress, 0 inside the
  ## outermost call and >= 1 inside nested calls.  Every level appends its
  ## report to ERRMSG; only the outermost call raises it.  The depth is
  ## restored with unwind_protect so that an error thrown anywhere inside
  ## assert cannot leave stale state behind for the next call.
  persistent call_depth = -1;
  persistent errmsg = "";

  ## Reconstruct the caller's argument text (one row of ARGN per argument)
  ## for use in the messages.
  if (exist ("argn", "var"))
    in = deblank (argn(1,:));
    for i = 2:rows (argn)
      in = [in "," deblank(argn(i,:))];
    endfor
  else
    in = "";
  endif
  in = ["(" in ")"];

  unwind_protect

    call_depth++;
    if (call_depth == 0)
      ## Outermost call: start with an empty report.
      errmsg = "";
    endif

    if (nargin == 1 || (nargin > 1 && islogical (cond) && ischar (varargin{1})))
      ## Condition form: COND must be numeric/logical with no zero element.
      if ((! isnumeric (cond) && ! islogical (cond)) || ! all (cond(:)))
        if (nargin == 1)
          ## Say which elements failed?
          error ("assert %s failed", in);
        else
          error (varargin{:});
        endif
      endif

    else
      if (nargin < 2 || nargin > 3)
        print_usage ();
      endif

      expected = varargin{1};
      if (nargin < 3)
        tol = 0;
      else
        tol = varargin{2};
      endif

      ## One entry per failure is appended to each of these lists.  If they
      ## are still empty at the end then no errors were found.
      err.index = {};
      err.observed = {};
      err.expected = {};
      err.reason = {};

      if (ischar (expected))
        if (! ischar (cond))
          err.index{end+1} = "[]";
          err.expected{end+1} = expected;
          if (isnumeric (cond))
            err.observed{end+1} = num2str (cond);
            err.reason{end+1} = "Expected string, but observed number";
          elseif (iscell (cond))
            err.observed{end+1} = "{}";
            err.reason{end+1} = "Expected string, but observed cell";
          else
            err.observed{end+1} = "[]";
            err.reason{end+1} = ["Expected string, but observed " class(cond)];
          endif
        elseif (! strcmp (cond, expected))
          err.index{end+1} = "[]";
          err.observed{end+1} = cond;
          err.expected{end+1} = expected;
          err.reason{end+1} = "Strings don't match";
        endif

      elseif (iscell (expected))
        ## Compare ndims first so SIZE vectors of different length are never
        ## compared element-wise.
        if (! iscell (cond) || ndims (cond) != ndims (expected) ...
            || any (size (cond) != size (expected)))
          err.index{end+1} = "{}";
          err.observed{end+1} = "O";
          err.expected{end+1} = "E";
          err.reason{end+1} = "Cell sizes don't match";
        else
          try
            ## Nested calls append their own mismatches to ERRMSG; a nested
            ## call that throws is reported here as a configuration error.
            for i = 1:numel (expected)
              assert (cond{i}, expected{i}, tol);
            endfor
          catch
            err.index{end+1} = "{}";
            err.observed{end+1} = "O";
            err.expected{end+1} = "E";
            err.reason{end+1} = "Cell configuration error";
          end_try_catch
        endif

      elseif (isstruct (expected))
        if (! isstruct (cond) || ndims (cond) != ndims (expected) ...
            || any (size (cond) != size (expected)) ...
            || rows (fieldnames (cond)) != rows (fieldnames (expected)))
          err.index{end+1} = "{}";
          err.observed{end+1} = "O";
          err.expected{end+1} = "E";
          err.reason{end+1} = "Structure sizes don't match";
        else
          try
            empty = isempty (cond);
            normal = (numel (cond) == 1);
            for [v, k] = cond
              if (! isfield (expected, k))
                err.index{end+1} = ".";
                err.observed{end+1} = "O";
                err.expected{end+1} = "E";
                err.reason{end+1} = ["'" k "'" " is not an expected field"];
              endif
              ## Wrap the field value(s) in a cell array so they can be
              ## compared with {expected.(k)}.
              if (empty)
                v = {};
              elseif (normal)
                v = {v};
              else
                v = v(:)';
              endif
              assert (v, {expected.(k)}, tol);
            endfor
          catch
            err.index{end+1} = ".";
            err.observed{end+1} = "O";
            err.expected{end+1} = "E";
            err.reason{end+1} = "Structure configuration error";
          end_try_catch
        endif

      elseif (ndims (cond) != ndims (expected) ...
              || any (size (cond) != size (expected)))
        osize = sprintf ("%dx", size (cond));
        esize = sprintf ("%dx", size (expected));
        err.index{end+1} = ".";
        err.observed{end+1} = ["O(" osize(1:end-1) ")"];
        err.expected{end+1} = ["E(" esize(1:end-1) ")"];
        err.reason{end+1} = "Dimensions don't match";

      else
        if (nargin < 3)
          ## Without explicit tolerance, be more strict.
          if (! strcmp (class (cond), class (expected)))
            err.index{end+1} = "()";
            err.observed{end+1} = "O";
            err.expected{end+1} = "E";
            err.reason{end+1} = cstrcat ("Class ", class (cond), " != ", class (expected));
          elseif (isnumeric (cond))
            if (issparse (cond) != issparse (expected))
              err.index{end+1} = "()";
              err.observed{end+1} = "O";
              err.expected{end+1} = "E";
              if (issparse (cond))
                err.reason{end+1} = "sparse != non-sparse";
              else
                err.reason{end+1} = "non-sparse != sparse";
              endif
            elseif (iscomplex (cond) != iscomplex (expected))
              err.index{end+1} = "()";
              err.observed{end+1} = "O";
              err.expected{end+1} = "E";
              if (iscomplex (cond))
                err.reason{end+1} = "complex != real";
              else
                err.reason{end+1} = "real != complex";
              endif
            endif
          endif
        endif

        if (isempty (err.index))
          err = compare_numeric (err, cond, expected, tol);
        endif
      endif

      ## Append this call's failures, if any, to the accumulated report.
      if (! isempty (err.index))
        if (! isempty (errmsg))
          errmsg = cstrcat (errmsg, "\n");
        endif
        errmsg = cstrcat (errmsg, pprint (in, err));
      endif
    endif

  unwind_protect_cleanup
    call_depth--;
  end_unwind_protect

  ## Last time through.  If there were any errors on any pass, raise them
  ## all in a single error.
  if (call_depth == -1 && ! isempty (errmsg))
    error ("%s", errmsg);
  endif

endfunction

## Compare numeric (or logical) arrays A (observed) and B (expected) of the
## same size element by element and append one entry to ERR per mismatch.
## Exceptional values (NA, NaN, Inf) must match exactly in both the real and
## the imaginary part.  The remaining finite values are compared with TOL, a
## scalar or an array the size of A: 0 means exact, > 0 absolute, < 0
## relative (absolute where B is zero).
function err = compare_numeric (err, A, B, tol)

  ## Exceptional values.  Each element is reported only in the first
  ## category where it mismatches (an NA is also a NaN, for instance).
  errvec = (isna (real (A)) != isna (real (B))) ...
           | (isna (imag (A)) != isna (imag (B)));
  err = append_errors (err, A, B, find (errvec), "'NA' mismatch");
  errseen = errvec;

  errvec = (isnan (real (A)) != isnan (real (B))) ...
           | (isnan (imag (A)) != isnan (imag (B)));
  err = append_errors (err, A, B, find (errvec & ! errseen), "'NaN' mismatch");
  errseen = errseen | errvec;

  errvec = ((isinf (real (A)) | isinf (real (B))) & (real (A) != real (B))) ...
           | ((isinf (imag (A)) | isinf (imag (B))) & (imag (A) != imag (B)));
  err = append_errors (err, A, B, find (errvec & ! errseen), "'Inf' mismatch");

  ## Normal values.  Where both the observed and the expected part are
  ## non-finite, replace that part by zero -- separately for the real and
  ## the imaginary part, so a matching exceptional value in one part does
  ## not hide a mismatch in the other part.
  A_real = real (A);
  B_real = real (B);
  exclude = ! isfinite (A_real) & ! isfinite (B_real);
  A_real(exclude) = 0;
  B_real(exclude) = 0;
  A_imag = imag (A);
  B_imag = imag (B);
  exclude = ! isfinite (A_imag) & ! isfinite (B_imag);
  A_imag(exclude) = 0;
  B_imag(exclude) = 0;
  A_null = complex (A_real, A_imag);
  B_null = complex (B_real, B_imag);

  if (isscalar (tol))
    mtol = tol * ones (size (A));
  else
    mtol = tol;
  endif

  ## Positions still holding a non-finite value were reported above.
  both_finite = isfinite (A_null) & isfinite (B_null);
  abserr = abs (A_null - B_null);

  ## Zero tolerance: exact equality.
  erridx = find ((A_null != B_null) & (mtol == 0) & both_finite);
  if (! isempty (erridx))
    err = append_errors (err, A, B, erridx, ...
                         tol_reasons ("Abs", abserr(erridx), mtol(erridx)));
  endif

  ## Positive tolerance: absolute error.
  erridx = find ((abserr > mtol) & (mtol > 0) & both_finite);
  if (! isempty (erridx))
    err = append_errors (err, A, B, erridx, ...
                         tol_reasons ("Abs", abserr(erridx), mtol(erridx)));
  endif

  ## Negative tolerance: relative error |A - B| / |B|, applied element-wise.
  ## Where B is zero, |tol| is used as an absolute tolerance instead.
  rel = (mtol < 0) & both_finite;
  erridx = find (rel & (B_null == 0) & (abserr > abs (mtol)));
  if (! isempty (erridx))
    err = append_errors (err, A, B, erridx, ...
                         tol_reasons ("Abs", abserr(erridx), abs (mtol(erridx))));
  endif
  idx = find (rel & (B_null != 0));
  if (! isempty (idx))
    relerr = abserr(idx) ./ abs (B_null(idx));
    over = (relerr > abs (mtol(idx)));
    erridx = idx(over);
    if (! isempty (erridx))
      err = append_errors (err, A, B, erridx, ...
                           tol_reasons ("Rel", relerr(over), abs (mtol(erridx))));
    endif
  endif

endfunction

## Append one entry to ERR for each linear index in ERRIDX: the position,
## the observed value (from A), the expected value (from B) and REASON.
## REASON is either a single string used for every entry or a cell array of
## strings with one element per entry.
function err = append_errors (err, A, B, erridx, reason)

  n = numel (erridx);
  if (n == 0)
    return;
  endif
  if (ischar (reason))
    reason = repmat ({reason}, n, 1);
  endif
  err.index(end+1:end+n) = construct_indeces (size (A), erridx);
  err.observed(end+1:end+n) = strtrim (cellstr (num2str (A(erridx)(:))));
  err.expected(end+1:end+n) = strtrim (cellstr (num2str (B(erridx)(:))));
  err.reason(end+1:end+n) = reason;

endfunction

## Build one "<LABEL> err X exceeds tol Y" string per element of ERRVAL and
## TOLVAL (which must have the same number of elements).
function reason = tol_reasons (label, errval, tolval)

  ## Pair the values as columns so the %g conversions alternate err/tol
  ## whatever the orientation of the inputs (row vectors included).
  reason = strsplit (deblank (sprintf ([label " err %g exceeds tol %g\n"], ...
                                       [errval(:), tolval(:)]')), "\n");

endfunction


## empty input
%!assert ([])
%!assert (zeros (3,0), zeros (3,0))
%!error assert (zeros (3,0), zeros (0,2))
%!error assert (zeros (3,0), [])
%!error <Dimensions don't match> assert (zeros (2,0,2), zeros (2,0))

## conditions
%!assert (isempty ([]))
%!assert (1)
%!error assert (0)
%!assert (ones (3,1))
%!assert (ones (1,3))
%!assert (ones (3,4))
%!error assert ([1,0,1])
%!error assert ([1;1;0])
%!error assert ([1,0;1,1])

## scalars
%!error assert (3, [3,3; 3,3])
%!error assert ([3,3; 3,3], 3)
%!assert (3, 3)
%!assert (3+eps, 3, eps)
%!assert (3, 3+eps, eps)
%!error assert (3+2*eps, 3, eps)
%!error assert (3, 3+2*eps, eps)

## vectors
%!assert ([1,2,3],[1,2,3]);
%!assert ([1;2;3],[1;2;3]);
%!error assert ([2,2,3,3],[1,2,3,4]);
%!error assert ([6;6;7;7],[5;6;7;8]);
%!error assert ([1,2,3],[1;2;3]);
%!error assert ([1,2],[1,2,3]);
%!error assert ([1;2;3],[1;2]);
%!assert ([1,2;3,4],[1,2;3,4]);
%!error assert ([1,4;3,4],[1,2;3,4])
%!error assert ([1,3;2,4;3,5],[1,2;3,4])

## matrices
%!test
%! A = [1 2 3]'*[1,2];
%! assert (A,A);
%! fail ("assert (A.*(A!=2),A)");
%! X = zeros (2,2,3);
%! Y = X;
%! Y (1,2,3) = 1;
%! fail ("assert (X,Y)");

## must give a small tolerance for floating point errors on relative
%!assert (100+100*eps, 100, -2*eps)
%!assert (100, 100+100*eps, -2*eps)
%!error assert (100+300*eps, 100, -2*eps)
%!error assert (100, 100+300*eps, -2*eps)
%!error assert (3, [3,3])
%!error assert (3, 4)

## a failed assertion must not leave stale recursion state behind
%!test
%! try
%!   assert (0);
%! catch
%! end_try_catch
%! fail ("assert (1, 2)");

## an exceptional value in one part must not mask a mismatch in the other
%!error assert (complex (1, NaN), complex (2, NaN))

## relative tolerances in a non-scalar TOL are applied element-wise
%!error assert ([1, 100], [1.5, 101], [1, -1e-3])

## each reported error is paired with its own tolerance
%!error <Abs err 1 exceeds tol 0> assert ([2,2,3,3], [1,2,3,4])

## test relative vs.
## absolute tolerances
%!test  assert (0.1+eps, 0.1, 2*eps);  # accept absolute
%!error assert (0.1+eps, 0.1, -2*eps); # fail relative
%!test  assert (100+100*eps, 100, -2*eps); # accept relative
%!error assert (100+100*eps, 100, 2*eps); # fail absolute

## exceptional values
%!assert ([NaN, NA, Inf, -Inf, 1+eps, eps], [NaN, NA, Inf, -Inf, 1, 0], eps)
%!error assert (NaN, 1)
%!error assert ([NaN 1], [1 NaN])
%!error assert (NA, 1)
%!error assert ([NA 1]', [1 NA]')
%!error assert ([(complex (NA, 1)) (complex (2, NA))], [(complex (NA, 2)) 2])
%!error assert (-Inf, Inf)
%!error assert ([-Inf Inf], [Inf -Inf])
%!error assert (complex (Inf, 0.2), complex (-Inf, 0.2 + 2*eps), eps)

## strings
%!assert ("dog", "dog")
%!error assert ("dog", "cat")
%!error assert ("dog", 3)
%!error assert (3, "dog")
%!error assert (cellstr ("dog"), "dog")
%!error assert (cell2struct ({"dog"; 3}, {"pet", "age"}, 1), "dog");

## structures
%!shared x,y
%! x.a = 1; x.b=[2, 2];
%! y.a = 1; y.b=[2, 2];
%!assert (x, y)
%!test y.b=3;
%!error assert (x, y)
%!error assert (3, x)
%!error assert (x, 3)
%!test
%! # Empty structures
%! x = resize (x, 0, 1);
%! y = resize (y, 0, 1);
%! assert (x, y);

## cell arrays
%!test
%! x = {[3], [1,2,3]; 100+100*eps, "dog"};
%! y = x;
%! assert (x, y);
%! y = x; y(1,1) = [2];
%! fail ("assert (x, y)");
%! y = x; y(1,2) = [0, 2, 3];
%! fail ("assert (x, y)");
%! y = x; y(2,1) = 101;
%! fail ("assert (x, y)");
%! y = x; y(2,2) = "cat";
%! fail ("assert (x, y)");
%! y = x; y(1,1) = [2]; y(1,2) = [0, 2, 3]; y(2,1) = 101; y(2,2) = "cat";
%! fail ("assert (x, y)");

## variable tolerance
%!test
%! x = [-40:0];
%! y1 = (10.^x).*(10.^x);
%! y2 = 10.^(2*x);
%! assert (y1, y2, eps (y1));
%! fail ("assert (y1, y2 + eps*1e-70, eps (y1))");

## the argument text is printed literally, even when it contains '%'
%!error <"%d"> assert (sprintf ("%d", 1), "2")

## test input validation
%!error assert
%!error assert (1,2,3,4)

## Convert all indeces into tuple format.
##
## ERRIDX holds linear indices into an array of size MATSIZE.  Returns a
## column cell array with one subscript string per index: "(i,j,...)" in
## general, "(k)" for 2-D arrays with one singleton dimension (that
## subscript is dropped) and "()" for scalars.
function cout = construct_indeces (matsize, erridx)

  cout = cell (numel (erridx), 1);
  tmp = cell (1, numel (matsize));
  [tmp{:}] = ind2sub (matsize, erridx(:));
  subs = [tmp{:}];
  if (numel (matsize) == 2)
    ## Drop the subscript of every singleton dimension.
    subs = subs(:, matsize != 1);
  endif
  for i = 1:numel (erridx)
    loc = sprintf ("%d,", subs(i,:));
    cout{i} = ["(" loc(1:end-1) ")"];
  endfor

endfunction

## Pretty print the various errors in a condensed tabular format.
##
## IN is the parenthesized argument text of the assert call.  ERR holds the
## cell arrays index, observed, expected and reason, one element per
## failure.  Returns the report as one string with one line per failure.
function str = pprint (in, err)

  ## IN is the caller's source text: pass it as data rather than as part of
  ## the format so that '%' and '\' inside it are printed literally.
  str = sprintf ("ASSERT errors for: assert %s\n", in);
  str = [str "\n Location | Observed | Expected | Reason\n"];

  for i = 1:numel (err.index)
    ## Build each row by concatenation.  sprintf silently skips empty
    ## arguments, which would shift the columns (and drop the newline)
    ## whenever a field or its padding is empty.
    str = [str, center_field(err.index{i}), " ", ...
           center_field(err.observed{i}), " ", ...
           center_field(err.expected{i}), " ", ...
           as_row(err.reason{i}), "\n"];
  endfor

endfunction

## Center text VAL in a 10-character column: VAL surrounded by single
## spaces, with the centering padding outside them.  Text longer than the
## column gets no padding.
function s = center_field (val)

  val = as_row (val);
  len = length (val);
  pre = floor ((10 - len) / 2);
  post = 10 - len - pre;
  s = [blanks(max (pre, 0)) " " val " " blanks(max (post, 0))];

endfunction

## Turn a multi-row character array (e.g., num2str of a matrix) into a
## single row by joining its rows with "; ".  Single-row input is returned
## unchanged.
function s = as_row (s)

  if (rows (s) > 1)
    c = cellstr (s);
    s = sprintf ("%s; ", c{:});
    s = s(1:end-2);
  endif

endfunction