Mercurial > hg > octave-nkf
changeset 17251:684ccccbc15d
assert.m: Add vector checking for relative errors.
* scripts/testfun/assert.m: Print vector of outputs failing relative error
checking. Don't clear() persistent variables. Eliminate unneeded variable
assert_error_occurred. Use field widths in sprintf for centering output text.
Put input validation first in function. Wrap long lines < 80 chars. Don't
use space between variable and parenthesis when indexing.
author | Rik <rik@octave.org> |
---|---|
date | Wed, 14 Aug 2013 11:42:35 -0700 |
parents | afd235a206a2 |
children | dac81d4b8ce1 |
files | scripts/testfun/assert.m |
diffstat | 1 files changed, 176 insertions(+), 171 deletions(-) [+] |
line wrap: on
line diff
--- a/scripts/testfun/assert.m +++ b/scripts/testfun/assert.m @@ -56,13 +56,17 @@ function assert (cond, varargin) - if (exist ("assert_call_depth", "var")) - assert_call_depth++; + if (nargin == 0 || nargin > 3) + print_usage (); + endif + + persistent call_depth = 0; + persistent errmsg; + + if (call_depth > 0) + call_depth++; else - persistent assert_call_depth = 0; - persistent assert_error_occurred; - assert_error_occurred = 0; - persistent errmsg; + call_depth = 0; errmsg = ""; end @@ -75,17 +79,13 @@ if (nargin == 1 || (nargin > 1 && islogical (cond) && ischar (varargin{1}))) if ((! isnumeric (cond) && ! islogical (cond)) || ! all (cond(:))) if (nargin == 1) - ## Say which elements failed? + ## Perhaps, say which elements failed? error ("assert %s failed", in); else error (varargin{:}); endif endif else - if (nargin < 2 || nargin > 3) - print_usage (); - endif - expected = varargin{1}; if (nargin < 3) tol = 0; @@ -93,11 +93,7 @@ tol = varargin{2}; endif - if (exist ("argn") == 0) - argn = " "; - endif - - ## Add to lists as the errors accumulate. If empty at end then no erros. + ## Add to list as the errors accumulate. If empty at end then no errors. err.index = {}; err.observed = {}; err.expected = {}; @@ -105,61 +101,62 @@ if (ischar (expected)) if (! ischar (cond)) - err.index{end + 1} = "[]"; - err.expected{end + 1} = expected; + err.index{end+1} = "[]"; + err.expected{end+1} = expected; if (isnumeric (cond)) - err.observed{end + 1} = num2str (cond); - err.reason{end + 1} = "Expected string, but observed number"; + err.observed{end+1} = num2str (cond); + err.reason{end+1} = "Expected string, but observed number"; elseif (iscell (cond)) - err.observed{end + 1} = "{}"; - err.reason{end + 1} = "Expected string, but observed cell"; + err.observed{end+1} = "{}"; + err.reason{end+1} = "Expected string, but observed cell"; else - err.observed{end + 1} = "[]"; - err.reason{end + 1} = "Expected string, but observed struct"; + err.observed{end+1} = "[]"; + err.reason{end+1} = "Expected string, but observed struct"; end elseif (! strcmp (cond, expected)) - err.index{end + 1} = "[]"; - err.observed{end + 1} = cond; - err.expected{end + 1} = expected; - err.reason{end + 1} = "Strings don't match"; + err.index{end+1} = "[]"; + err.observed{end+1} = cond; + err.expected{end+1} = expected; + err.reason{end+1} = "Strings don't match"; endif + elseif (iscell (expected)) if (! iscell (cond) || any (size (cond) != size (expected))) - err.index{end + 1} = "{}"; - err.observed{end + 1} = "O"; - err.expected{end + 1} = "E"; - err.reason{end + 1} = "Cell sizes don't match"; + err.index{end+1} = "{}"; + err.observed{end+1} = "O"; + err.expected{end+1} = "E"; + err.reason{end+1} = "Cell sizes don't match"; else try + ## Recursively compare cell arrays for i = 1:length (expected(:)) assert (cond{i}, expected{i}, tol); endfor catch - err.index{end + 1} = "{}"; - err.observed{end + 1} = "O"; - err.expected{end + 1} = "E"; - err.reason{end + 1} = "Cell configuration error"; + err.index{end+1} = "{}"; + err.observed{end+1} = "O"; + err.expected{end+1} = "E"; + err.reason{end+1} = "Cell configuration error"; end_try_catch endif elseif (isstruct (expected)) if (! isstruct (cond) || any (size (cond) != size (expected)) || rows (fieldnames (cond)) != rows (fieldnames (expected))) - err.index{end + 1} = "{}"; - err.observed{end + 1} = "O"; - err.expected{end + 1} = "E"; - err.reason{end + 1} = "Structure sizes don't match"; + err.index{end+1} = "{}"; + err.observed{end+1} = "O"; + err.expected{end+1} = "E"; + err.reason{end+1} = "Structure sizes don't match"; else try - #empty = numel (cond) == 0; empty = isempty (cond); normal = (numel (cond) == 1); for [v, k] = cond if (! isfield (expected, k)) - err.index{end + 1} = "."; - err.observed{end + 1} = "O"; - err.expected{end + 1} = "E"; - err.reason{end + 1} = ["'" k "'" " is not an expected field"]; + err.index{end+1} = "."; + err.observed{end+1} = "O"; + err.expected{end+1} = "E"; + err.reason{end+1} = ["'" k "'" " is not an expected field"]; endif if (empty) v = {}; @@ -168,49 +165,50 @@ else v = v(:)'; endif + ## Recursively call assert for struct array values assert (v, {expected.(k)}, tol); endfor catch - err.index{end + 1} = "."; - err.observed{end + 1} = "O"; - err.expected{end + 1} = "E"; - err.reason{end + 1} = "Structure configuration error"; + err.index{end+1} = "."; + err.observed{end+1} = "O"; + err.expected{end+1} = "E"; + err.reason{end+1} = "Structure configuration error"; end_try_catch endif elseif (ndims (cond) != ndims (expected) || any (size (cond) != size (expected))) - err.index{end + 1} = "."; - err.observed{end + 1} = ["O(" (sprintf ("%dx", size (cond)) (1:end-1)) ")"]; - err.expected{end + 1} = ["E(" (sprintf ("%dx", size (expected)) (1:end-1)) ")"]; - err.reason{end + 1} = "Dimensions don't match"; + err.index{end+1} = "."; + err.observed{end+1} = ["O(" sprintf("%dx", size(cond))(1:end-1) ")"]; + err.expected{end+1} = ["E(" sprintf("%dx", size(expected))(1:end-1) ")"]; + err.reason{end+1} = "Dimensions don't match"; - else + else # Numeric comparison if (nargin < 3) ## Without explicit tolerance, be more strict. if (! strcmp (class (cond), class (expected))) - err.index{end + 1} = "()"; - err.observed{end + 1} = "O"; - err.expected{end + 1} = "E"; - err.reason{end + 1} = cstrcat("Class ", class (cond), " != ", class(expected)); + err.index{end+1} = "()"; + err.observed{end+1} = "O"; + err.expected{end+1} = "E"; + err.reason{end+1} = ["Class " class(cond) " != " class(expected)]; elseif (isnumeric (cond)) if (issparse (cond) != issparse (expected)) - err.index{end + 1} = "()"; - err.observed{end + 1} = "O"; - err.expected{end + 1} = "E"; + err.index{end+1} = "()"; + err.observed{end+1} = "O"; + err.expected{end+1} = "E"; if (issparse (cond)) - err.reason{end + 1} = "sparse != non-sparse"; + err.reason{end+1} = "sparse != non-sparse"; else - err.reason{end + 1} = "non-sparse != sparse"; + err.reason{end+1} = "non-sparse != sparse"; endif elseif (iscomplex (cond) != iscomplex (expected)) - err.index{end + 1} = "()"; - err.observed{end + 1} = "O"; - err.expected{end + 1} = "E"; + err.index{end+1} = "()"; + err.observed{end+1} = "O"; + err.expected{end+1} = "E"; if (iscomplex (cond)) - err.reason{end + 1} = "complex != real"; + err.reason{end+1} = "complex != real"; else - err.reason{end + 1} = "real != complex"; + err.reason{end+1} = "real != complex"; endif endif endif @@ -218,103 +216,129 @@ if (isempty (err.index)) - ## Numeric. A = cond; B = expected; ## Check exceptional values. - erridx = find (isna (real (A)) != isna (real (B)) | isna (imag (A)) != isna (imag (B))); + erridx = find ( isna (real (A)) != isna (real (B)) + | isna (imag (A)) != isna (imag (B))); if (! isempty (erridx)) - err.index (end + 1:end + length (erridx)) = construct_indeces (size (A), erridx); - err.observed (end + 1:end + length (erridx)) = ... - strtrim (cellstr (num2str (A (erridx) (:)))); - err.expected (end + 1:end + length (erridx)) = ... - strtrim (cellstr (num2str (B (erridx) (:)))); - err.reason (end + 1:end + length (erridx)) = cellstr (repmat ("'NA' mismatch", length (erridx), 1)); + err.index(end+1:end + length (erridx)) = ... + ind2tuple (size (A), erridx); + err.observed(end+1:end + length (erridx)) = ... + strtrim (cellstr (num2str (A(erridx) (:)))); + err.expected(end+1:end + length (erridx)) = ... + strtrim (cellstr (num2str (B(erridx) (:)))); + err.reason(end+1:end + length (erridx)) = ... + cellstr (repmat ("'NA' mismatch", length (erridx), 1)); endif - erridx = find (isnan (real (A)) != isnan (real (B)) | isnan (imag (A)) != isnan (imag (B))); + erridx = find ( isnan (real (A)) != isnan (real (B)) + | isnan (imag (A)) != isnan (imag (B))); if (! isempty (erridx)) - err.index (end + 1:end + length (erridx)) = construct_indeces (size (A), erridx); - err.observed (end + 1:end + length (erridx)) = ... - strtrim (cellstr (num2str (A (erridx) (:)))); - err.expected (end + 1:end + length (erridx)) = ... - strtrim (cellstr (num2str (B (erridx) (:)))); - err.reason (end + 1:end + length (erridx)) = cellstr (repmat ("'NaN' mismatch", length (erridx), 1)); + err.index(end+1:end + length (erridx)) = ... + ind2tuple (size (A), erridx); + err.observed(end+1:end + length (erridx)) = ... + strtrim (cellstr (num2str (A(erridx) (:)))); + err.expected(end+1:end + length (erridx)) = ... + strtrim (cellstr (num2str (B(erridx) (:)))); + err.reason(end+1:end + length (erridx)) = ... + cellstr (repmat ("'NaN' mismatch", length (erridx), 1)); endif - erridx = find (((isinf (real (A)) | isinf (real (B))) & real (A) != real (B)) | ... - ((isinf (imag (A)) | isinf (imag (B))) & imag (A) != imag (B))); + erridx = find (((isinf (real (A)) | isinf (real (B))) ... + & real (A) != real (B)) ... + | ((isinf (imag (A)) | isinf (imag (B))) + & imag (A) != imag (B))); if (! isempty (erridx)) - err.index (end + 1:end + length (erridx)) = construct_indeces (size (A), erridx); - err.observed (end + 1:end + length (erridx)) = ... - strtrim (cellstr (num2str (A (erridx) (:)))); - err.expected (end + 1:end + length (erridx)) = ... - strtrim (cellstr (num2str (B (erridx) (:)))); - err.reason (end + 1:end + length (erridx)) = cellstr (repmat ("'Inf' mismatch", length (erridx), 1)); + err.index(end+1:end + length (erridx)) = ... + ind2tuple (size (A), erridx); + err.observed(end+1:end + length (erridx)) = ... + strtrim (cellstr (num2str (A(erridx) (:)))); + err.expected(end+1:end + length (erridx)) = ... + strtrim (cellstr (num2str (B(erridx) (:)))); + err.reason(end+1:end + length (erridx)) = ... + cellstr (repmat ("'Inf' mismatch", length (erridx), 1)); endif - ## Check normal values. Replace all exceptional values by zero. + ## Check normal values. + ## Replace exceptional values already checked above by zero. A_null_real = real (A); B_null_real = real (B); exclude = ! isfinite (A_null_real) & ! isfinite (B_null_real); - A_null_real (exclude) = 0; - B_null_real (exclude) = 0; + A_null_real(exclude) = 0; + B_null_real(exclude) = 0; A_null_imag = imag (A); B_null_imag = imag (B); exclude = ! isfinite (A_null_real) & ! isfinite (B_null_real); - A_null_imag (exclude) = 0; - B_null_imag (exclude) = 0; + A_null_imag(exclude) = 0; + B_null_imag(exclude) = 0; A_null = complex (A_null_real, A_null_imag); B_null = complex (B_null_real, B_null_imag); if (isscalar (tol)) - mtol = ones (size (A)) * tol; + mtol = repmat (tol, size (A)); else mtol = tol; endif - k = (mtol == 0 & isfinite (A_null) & isfinite (B_null)); - erridx = find (A_null != B_null & k); + k = (mtol == 0); + erridx = find ((A_null != B_null) & k); if (! isempty (erridx)) - err.index (end + 1:end + length (erridx)) = ... - construct_indeces (size (A), erridx); - err.observed (end + 1:end + length (erridx)) = ... - strtrim (cellstr (num2str (A (erridx) (:)))); - err.expected (end + 1:end + length (erridx)) = ... - strtrim (cellstr (num2str (B (erridx) (:)))); - err.reason (end + 1:end + length (erridx)) = ... - strsplit (deblank (sprintf ("Abs err %g exceeds tol %g\n", ... - [(abs (A_null (erridx) - B_null (erridx))) (mtol (erridx))]')), "\n"); + err.index(end+1:end + length (erridx)) = ... + ind2tuple (size (A), erridx); + err.observed(end+1:end + length (erridx)) = ... + strtrim (cellstr (num2str (A(erridx) (:)))); + err.expected(end+1:end + length (erridx)) = ... + strtrim (cellstr (num2str (B(erridx) (:)))); + err.reason(end+1:end + length (erridx)) = ... + ostrsplit (deblank (sprintf ("Abs err %g exceeds tol %g\n", ... + [abs(A_null(erridx) - B_null(erridx)) mtol(erridx)]')), "\n"); endif - k = (mtol > 0 & isfinite (A_null) & isfinite (B_null)); - erridx = find (abs (A_null - B_null) > mtol & k); + k = (mtol > 0); + erridx = find ((abs (A_null - B_null) > mtol) & k); if (! isempty (erridx)) - err.index (end + 1:end + length (erridx)) = ... - construct_indeces (size (A), erridx); - err.observed (end + 1:end + length (erridx)) = ... - strtrim (cellstr (num2str (A (erridx) (:)))); - err.expected (end + 1:end + length (erridx)) = ... - strtrim (cellstr (num2str (B (erridx) (:)))); - err.reason (end + 1:end + length (erridx)) = ... - strsplit (deblank (sprintf ("Abs err %g exceeds tol %g\n", ... - [(abs (A_null (erridx) - B_null (erridx))) (mtol (erridx))]')), "\n"); + err.index(end+1:end + length (erridx)) = ... + ind2tuple (size (A), erridx); + err.observed(end+1:end + length (erridx)) = ... + strtrim (cellstr (num2str (A(erridx) (:)))); + err.expected(end+1:end + length (erridx)) = ... + strtrim (cellstr (num2str (B(erridx) (:)))); + err.reason(end+1:end + length (erridx)) = ... + ostrsplit (deblank (sprintf ("Abs err %g exceeds tol %g\n", ... + [abs(A_null(erridx) - B_null(erridx)) mtol(erridx)]')), "\n"); endif k = (mtol < 0); if (any (k)) - AA = A_null (k); - BB = B_null (k); - abserr = max (abs (AA(BB == 0))); - AA = AA(BB != 0); - BB = BB(BB != 0); - relerr = max (abs (AA - BB) ./ abs (BB)); - maxerr = max ([abserr; relerr]); - if (maxerr > abs (tol)) - err.index{end + 1} = "()"; - err.observed{end + 1} = "O"; - err.expected{end + 1} = "E"; - err.reason{end + 1} = sprintf ("Max rel err %g exceeds tol %g", maxerr, abs (tol)); + ## Test for absolute error where relative error can't be calculated. + erridx = find ((B_null == 0) & abs (A_null) > abs (mtol) & k); + if (! isempty (erridx)) + err.index(end+1:end + length (erridx)) = ... + ind2tuple (size (A), erridx); + err.observed(end+1:end + length (erridx)) = ... + strtrim (cellstr (num2str (A(erridx) (:)))); + err.expected(end+1:end + length (erridx)) = ... + strtrim (cellstr (num2str (B(erridx) (:)))); + err.reason(end+1:end + length (erridx)) = ... + ostrsplit (deblank (sprintf ("Abs err %g exceeds tol %g\n", + [abs(A_null(erridx) - B_null(erridx)) -mtol(erridx)]')), "\n"); + endif + ## Test for relative error + Bdiv = Inf (size (B_null)); + Bdiv(k & (B_null != 0)) = B_null(k & (B_null != 0)); + relerr = abs ((A_null - B_null) ./ abs (Bdiv)); + erridx = find ((relerr > abs (mtol)) & k); + if (! isempty (erridx)) + err.index(end+1:end + length (erridx)) = ... + ind2tuple (size (A), erridx); + err.observed(end+1:end + length (erridx)) = ... + strtrim (cellstr (num2str (A(erridx) (:)))); + err.expected(end+1:end + length (erridx)) = ... + strtrim (cellstr (num2str (B(erridx) (:)))); + err.reason(end+1:end + length (erridx)) = ... + ostrsplit (deblank (sprintf ("Rel err %g exceeds tol %g\n", + [relerr(erridx) -mtol(erridx)]')), "\n"); endif endif endif @@ -323,27 +347,21 @@ ## Print any errors if (! isempty (err.index)) - assert_error_occurred = 1; if (! isempty (errmsg)) - errmsg = cstrcat (errmsg, "\n"); + errmsg = [errmsg "\n"]; endif - errmsg = cstrcat (errmsg, pprint (in, err)); + errmsg = [errmsg, pprint(in, err)]; end endif - if (assert_call_depth == 0) - - ## Remove from the variable space to indicate end of recursion - clear -v assert_call_depth; - + if (call_depth == 0) ## Last time through. If there were any errors on any pass, raise a flag. - if (assert_error_occurred) - error ("%s", errmsg); + if (! isempty (errmsg)) + error (errmsg); endif - else - assert_call_depth--; + call_depth--; endif endfunction @@ -471,19 +489,19 @@ %! fail ("assert (y1, y2 + eps*1e-70, eps (y1))"); ## test input validation -%!error assert +%!error assert () %!error assert (1,2,3,4) -## Convert all indeces into tuple format -function cout = construct_indeces (matsize, erridx) +## Convert all error indices into tuple format +function cout = ind2tuple (matsize, erridx) cout = cell (numel (erridx), 1); tmp = cell (1, numel (matsize)); [tmp{:}] = ind2sub (matsize, erridx (:)); subs = [tmp{:}]; if (numel (matsize) == 2) - subs = subs (:, matsize != 1); + subs = subs(:, matsize != 1); endif for i = 1:numel (erridx) loc = sprintf ("%d,", subs(i,:)); @@ -496,31 +514,18 @@ ## Pretty print the various errors in a condensed tabular format. function str = pprint (in, err) - str = sprintf (cstrcat ("ASSERT errors for: assert ", in, "\n")); - str = cstrcat (str, sprintf ("\n Location | Observed | Expected | Reason\n")); - prespace = zeros (3); - postspace = zeros (3); + str = ["ASSERT errors for: assert " in "\n"]; + str = [str, "\n Location | Observed | Expected | Reason\n"]; for i = 1:length (err.index) - len = length (err.index{i}); - prespace (1) = floor ((10 - len) / 2); - postspace (1) = 10 - len - prespace (1); - len = length (err.observed{i}); - prespace (2) = floor ((10 - len) / 2); - postspace (2) = 10 - len - prespace (2); - len = length (err.expected{i}); - prespace (3) = floor ((10 - len) / 2); - postspace (3) = 10 - len - prespace (3); - str = cstrcat (str, sprintf ("%s %s %s %s %s %s %s %s %s %s\n", ... - repmat (' ', 1, prespace (1)), ... - err.index{i}, ... - repmat (' ', 1, postspace (1)), ... - repmat (' ', 1, prespace (2)), ... - err.observed{i}, ... - repmat (' ', 1, postspace (2)), ... - repmat (' ', 1, prespace (3)), ... - err.expected{i}, ... - repmat (' ', 1, postspace (3)), ... - err.reason{i})); + leni = length (err.index{i}); + leno = length (err.observed{i}); + lene = length (err.expected{i}); + str = [str, sprintf("%*s%*s %*s%*s %*s%*s %s\n", + 6+fix(leni/2), err.index{i} , 6-fix(leni/2), "", + 6+fix(leno/2), err.observed{i}, 6-fix(leno/2), "", + 6+fix(lene/2), err.expected{i}, 6-fix(lene/2), "", + err.reason{i})]; endfor endfunction +