# HG changeset patch # User Carlo de Falco # Date 1315073930 -7200 # Node ID 4b110dd204b9c3db9b16c198093bfa796c47810e # Parent b3a8b75dfec393f3c2864aa1bbdc7f41ef5c8ca3 codesprint: Allow passing a function handle for the coefficient matrix in cgs * cgs.m: Allow passing a function handle for the coefficient matrix. Also add more tests. diff --git a/scripts/sparse/cgs.m b/scripts/sparse/cgs.m --- a/scripts/sparse/cgs.m +++ b/scripts/sparse/cgs.m @@ -1,4 +1,5 @@ ## Copyright (C) 2008-2011 Radek Salac +## Copyright (C) 2011 Carlo de Falco ## ## This file is part of Octave. ## @@ -17,134 +18,163 @@ ## . ## -*- texinfo -*- -## @deftypefn {Function File} {} cgs (@var{A}, @var{b}) -## @deftypefnx {Function File} {} cgs (@var{A}, @var{b}, @var{tol}, @var{maxit}, @var{M1}, @var{M2}, @var{x0}) -## This procedure attempts to solve a system of linear equations A*x = b for x. -## The @var{A} must be square, symmetric and positive definite real matrix N*N. -## The @var{b} must be a one column vector with a length of N. -## The @var{tol} specifies the tolerance of the method, default value is 1e-6. -## The @var{maxit} specifies the maximum number of iteration, default value is -## MIN(20,N). -## The @var{M1} specifies a preconditioner, can also be a function handler which -## returns M\X. -## The @var{M2} combined with @var{M1} defines preconditioner as -## preconditioner=M1*M2. -## The @var{x0} is initial guess, default value is zeros(N,1). +## +## @deftypefn {Function File} {@var{x} =} cgs (@var{A}, @var{b}, @var{rtol}, @var{maxit}, @var{M1}, @var{M2}, @var{x0}) +## @deftypefnx {Function File} {@var{x} =} cgs (@var{A}, @var{b}, @var{rtol}, @var{maxit}, @var{P}) +## @deftypefnx {Function File} {[@var{x}, @var{flag}, @var{relres}, @var{iter}, @var{resvec}] =} cgs (@var{A}, @var{b}, ...) +## +## Solve @code{A x = b}, where @var{A} is a square matrix, using the Conjugate Gradients Squared method. +## +## @itemize @minus +## @item @var{rtol} is the relative tolerance, if not given or set to [] the default value 1e-6 is used. +## @item @var{maxit} the maximum number of outer iterations, if not given or set to [] the default value @code{min (20, numel (b))} is used. +## @item @var{x0} the initial guess, if not given or set to [] the default value @code{zeros (size (b))} is used. +## @end itemize +## +## @var{A} can be passed as a matrix or as a function handle or +## inline function @code{f} such that @code{f(x) = A*x}. +## +## The preconditioner @var{P} is given as @code{P = M1 * M2}. +## Both @var{M1} and @var{M2} can be passed as a matrix or as a function handle or +## inline function @code{g} such that @code{g(x) = M1 \ x} or @code{g(x) = M2 \ x}. +## +## If called with more than one output parameter +## +## @itemize @minus +## @item @var{flag} indicates the exit status: +## @itemize @minus +## @item 0: iteration converged to the within the chosen tolerance +## @item 1: the maximum number of iterations was reached before convergence +## @item 3: the algorithm reached stagnation +## @end itemize +## (the value 2 is unused but skipped for compatibility). +## @item @var{relres} is the final value of the relative residual. +## @item @var{iter} is the number of iterations performed. +## @item @var{resvec} is a vector containing the relative residual at each iteration. +## @end itemize +## +## @seealso{pcg,bicgstab,bicg,gmres} ## ## @end deftypefn function [x, flag, relres, iter, resvec] = cgs (A, b, tol, maxit, M1, M2, x0) - if (nargin < 2 || nargin > 7 || nargout > 5) - print_usage (); - elseif (!(isnumeric (A) && issquare (A))) - error ("cgs: A must be a square numeric matrix"); - elseif (!isvector (b)) - error ("cgs: B must be a vector"); - elseif (rows (A) != rows (b)) - error ("cgs: A and B must have the same number of rows"); - elseif (nargin > 2 && !isscalar (tol)) - error ("cgs: TOL must be a scalar"); - elseif (nargin > 3 && !isscalar (maxit)) - error ("cgs: MAXIT must be a scalar"); - elseif (nargin > 4 && ismatrix (M1) && (rows (M1) != rows (A) || columns (M1) != columns (A))) - error ("cgs: M1 must have the same number of rows and columns as A"); - elseif (nargin > 5 && (!ismatrix (M2) || rows (M2) != rows (A) || columns (M2) != columns (A))) - error ("cgs: M2 must have the same number of rows and columns as A"); - elseif (nargin > 6 && !isvector (x0)) - error ("cgs: X0 must be a vector"); - elseif (nargin > 6 && rows (x0) != rows (b)) - error ("cgs: X0 must have the same number of rows as B"); - endif - - ## Default tolerance. - if (nargin < 3) - tol = 1e-6; - endif + if ((nargin >= 2) && (nargin <= 7) && isvector (full (b))) + + if (ischar (A)) + A = str2func (A); + elseif (ismatrix (A)) + Ax = @(x) A * x; + elseif (isa (A, "function_handle")) + Ax = @(x) feval (A, x); + else + error (["cgs: first argument is expected to "... + "be a function or a square matrix"]); + endif + + if ((nargin < 3) || (isempty (tol))) + tol = 1e-6; + endif + + if ((nargin < 4) || (isempty (maxit))) + maxit = min (rows (b), 20); + endif + + if ((nargin < 5) || isempty (M1)) + M1m1x = @(x) x; + elseif (ischar (M1)) + M1m1x = str2func (M1); + elseif (ismatrix (M1)) + M1m1x = @(x) M1 \ x; + elseif (isa (M1, "function_handle")) + M1m1x = @(x) feval (M1, x); + else + error (["cgs: preconditioner is expected "... + "to be a function or matrix"]); + endif + + if ((nargin < 6) || isempty (M2)) + M2m1x = @(x) x; + elseif (ischar (M2)) + M2m1x = str2func (M2); + elseif (ismatrix (M2)) + M2m1x = @(x) M2 \ x; + elseif (isa (M2, "function_handle")) + M2m1x = @(x) feval (M2, x); + else + error ("cgs: preconditioner is expected to be a function or matrix"); + endif - ## Default maximum number of iteration. - if (nargin < 4) - maxit = min (rows (b),20); - endif + precon = @(x) M2m1x (M1m1x (x)); - ## Left preconditioner. - if (nargin == 5) - if (isnumeric (M1)) - precon = @(x) M1 \ x; + if ((nargin < 7) || (isempty (x0))) + x0 = zeros (size (b)); endif - elseif (nargin > 5) - if (issparse (M1) && issparse (M2)) - precon = @(x) M2 \ (M1 \ x); - else - M = M1*M2; - precon = @(x) M \ x; - endif - else - precon = @(x) x; - endif - - ## Specifies initial estimate x0. - if (nargin < 7) - x = zeros (rows (b), 1); - else - x = x0; - endif - res = b - A * x; - norm_b = norm (b); - ## Vector of the residual norms for each iteration. - resvec = [ norm(res)/norm_b ]; - ro = 0; - ## Default behavior we don't reach tolerance tol within maxit iterations. - flag = 1; - for iter = 1 : maxit + + x = x0; + + res = b - Ax (x); + norm_b = norm (b); + ## Vector of the residual norms for each iteration. + resvec = [ norm(res)/norm_b ]; + ro = 0; + ## Default behavior we don't reach tolerance tol within maxit iterations. + flag = 1; + for iter = 1 : maxit - z = precon (res); + z = precon (res); - ## Cache. - ro_old = ro; - ro = res' * z; - if (iter == 1) - p = z; - else - beta = ro / ro_old; - p = z + beta * p; - endif - ## Cache. - q = A * p; - alpha = ro / (p' * q); - x = x + alpha * p; + ## Cache. + ro_old = ro; + ro = res' * z; + if (iter == 1) + p = z; + else + beta = ro / ro_old; + p = z + beta * p; + endif + ## Cache. + q = Ax (p); + alpha = ro / (p' * q); + x = x + alpha * p; + + res = res - alpha * q; + relres = norm(res) / norm_b; + resvec = [resvec;relres]; - res = res - alpha * q; - relres = norm(res) / norm_b; - resvec = [resvec;relres]; - - if (relres <= tol) - ## We reach tolerance tol within maxit iterations. - flag = 0; - break; - elseif (resvec (end) == resvec (end - 1)) - ## The method stagnates. - flag = 3; - break; - endif - endfor; + if (relres <= tol) + ## We reach tolerance tol within maxit iterations. + flag = 0; + break; + elseif (resvec (end) == resvec (end - 1)) + ## The method stagnates. + flag = 3; + break; + endif + endfor; - if (nargout < 1) - if ( flag == 0 ) - printf (["cgs converged at iteration %i ", - "to a solution with relative residual %e\n"],iter,relres); - elseif (flag == 3) - printf (["cgs stopped at iteration %i ", - "without converging to the desired tolerance %e\n", - "because the method stagnated.\n", - "The iterate returned (number %i) has relative residual %e\n"],iter,tol,iter,relres); - else - printf (["cgs stopped at iteration %i ", - "without converging to the desired tolerance %e\n", - "because the maximum number of iterations was reached.\n", - "The iterate returned (number %i) has relative residual %e\n"],iter,tol,iter,relres); + if (nargout < 1) + if ( flag == 0 ) + printf (["cgs converged at iteration %i ", + "to a solution with relative residual %e\n"],iter,relres); + elseif (flag == 3) + printf (["cgs stopped at iteration %i ", + "without converging to the desired tolerance %e\n", + "because the method stagnated.\n", + "The iterate returned (number %i) has relative residual %e\n"], + iter,tol,iter,relres); + else + printf (["cgs stopped at iteration %i ", + "without converging to the desired tolerance %e\n", + "because the maximum number of iterations was reached.\n", + "The iterate returned (number %i) has relative residual %e\n"], + iter,tol,iter,relres); + endif endif + + else + print_usage (); endif endfunction @@ -156,3 +186,35 @@ %! A=[5 -1 3;-1 2 -2;3 -2 3] %! b=[7;-1;4] %! [a,b,c,d,e]=cgs(A,b) + +%!shared A, b, n, M +%! +%!test +%! n = 100; +%! A = spdiags ([-ones(n,1) 4*ones(n,1) -ones(n,1)], -1:1, n, n); +%! b = sum (A, 2); +%! tol = 1e-8; +%! maxit = 1000; +%! M = 4*eye (n); +%! [x, flag, relres, iter, resvec] = cgs (A, b, tol, maxit, M); +%! assert (x, ones (size (b)), 1e-7); +%! +%!test +%! tol = 1e-8; +%! maxit = 15; +%! +%! function y = afun (x, a) +%! y = a * x; +%! endfunction +%! +%! [x, flag, relres, iter, resvec] = cgs (@(x) afun (x, A), b, tol, maxit, M); +%! assert (x, ones (size (b)), 1e-7); + +%!test +%! n = 100; +%! tol = 1e-8; +%! a = sprand (n, n, .1); +%! A = a'*a + 100 * eye (n); +%! b = sum (A, 2); +%! [x, flag, relres, iter, resvec] = cgs (A, b, tol, [], diag (diag (A))); +%! assert (x, ones (size (b)), 1e-7);