# HG changeset patch # User Carlo de Falco # Date 1315073190 -7200 # Node ID addfc0ae69c0561849e0d4b9bff9aea1d6e71e24 # Parent 85dd509673e76de53820a04805d642ae1a94d72d Make bicgstab interface more compatible * bicgstab.m: Add the possibility to pass a function handle for the coefficient matrix. Also add more tests. diff --git a/scripts/sparse/bicgstab.m b/scripts/sparse/bicgstab.m --- a/scripts/sparse/bicgstab.m +++ b/scripts/sparse/bicgstab.m @@ -1,4 +1,5 @@ ## Copyright (C) 2008-2011 Radek Salac +## Copyright (C) 2011 Carlo de Falco ## ## This file is part of Octave. ## @@ -17,155 +18,179 @@ ## . ## -*- texinfo -*- -## @deftypefn {Function File} {} bicgstab (@var{A}, @var{b}) -## @deftypefnx {Function File} {} bicgstab (@var{A}, @var{b}, @var{tol}, @var{maxit}, @var{M1}, @var{M2}, @var{x0}) -## This procedure attempts to solve a system of linear equations A*x = b for x. -## The @var{A} must be square, symmetric and positive definite real matrix N*N. -## The @var{b} must be a one column vector with a length of N. -## The @var{tol} specifies the tolerance of the method, the default value is -## 1e-6. -## The @var{maxit} specifies the maximum number of iterations, the default value -## is min(20,N). -## The @var{M1} specifies a preconditioner, can also be a function handler which -## returns M\X. -## The @var{M2} combined with @var{M1} defines preconditioner as -## preconditioner=M1*M2. -## The @var{x0} is the initial guess, the default value is zeros(N,1). +## +## @deftypefn {Function File} {@var{x} =} bicgstab (@var{A}, @var{b}, @var{rtol}, @var{maxit}, @var{M1}, @var{M2}, @var{x0}) +## @deftypefnx {Function File} {@var{x} =} bicgstab (@var{A}, @var{b}, @var{rtol}, @var{maxit}, @var{P}) +## @deftypefnx {Function File} {[@var{x}, @var{flag}, @var{relres}, @var{iter}, @var{resvec}] =} bicgstab (@var{A}, @var{b}, ...) +## +## Solve @code{A x = b} using the stabilizied Bi-conjugate gradient iterative method. +## +## @itemize @minus +## @item @var{rtol} is the relative tolerance, if not given or set to [] the default value 1e-6 is used. +## @item @var{maxit} the maximum number of outer iterations, if not given or set to [] the default value @code{min (20, numel (b))} is used. +## @item @var{x0} the initial guess, if not given or set to [] the default value @code{zeros (size (b))} is used. +## @end itemize +## +## @var{A} can be passed as a matrix or as a function handle or +## inline function @code{f} such that @code{f(x) = A*x}. ## -## The value @var{x} is a computed result of this procedure. -## The value @var{flag} can be 0 when we reach tolerance in @var{maxit} -## iterations, 1 when -## we don't reach tolerance in @var{maxit} iterations and 3 when the procedure -## stagnates. -## The value @var{relres} is a relative residual - norm(b-A*x)/norm(b). -## The value @var{iter} is an iteration number in which x was computed. -## The value @var{resvec} is a vector of @var{relres} for each iteration. +## The preconditioner @var{P} is given as @code{P = M1 * M2}. +## Both @var{M1} and @var{M2} can be passed as a matrix or as a function handle or +## inline function @code{g} such that @code{g(x) = M1 \ x} or @code{g(x) = M2 \ x}. +## +## If called with more than one output parameter +## +## @itemize @minus +## @item @var{flag} indicates the exit status: +## @itemize @minus +## @item 0: iteration converged to the within the chosen tolerance +## @item 1: the maximum number of iterations was reached before convergence +## @item 3: the algorithm reached stagnation +## @end itemize +## (the value 2 is unused but skipped for compatibility). +## @item @var{relres} is the final value of the relative residual. +## @item @var{iter} is the number of iterations performed. +## @item @var{resvec} is a vector containing the relative residual at each iteration. +## @end itemize +## +## @seealso{pcg,cgs,bigc,gmres} ## ## @end deftypefn -function [x, flag, relres, iter, resvec] = bicgstab (A, b, tol, maxit, M1, M2, x0) +function [x, flag, relres, iter, resvec] = bicgstab (A, b, tol, maxit, + M1, M2, x0) - if (nargin < 2 || nargin > 7 || nargout > 5) - print_usage (); - elseif (!(isnumeric (A) && issquare (A))) - error ("bicgstab: A must be a square numeric matrix"); - elseif (!isvector (b)) - error ("bicgstab: B must be a vector"); - elseif (!any (b)) - error ("bicgstab: B must not be a vector of all zeros"); - elseif (rows (A) != rows (b)) - error ("bicgstab: A and B must have the same number of rows"); - elseif (nargin > 2 && !isscalar (tol)) - error ("bicgstab: TOL must be a scalar"); - elseif (nargin > 3 && !isscalar (maxit)) - error ("bicgstab: MAXIT must be a scalar"); - elseif (nargin > 4 && ismatrix (M1) && (rows (M1) != rows (A) || columns (M1) != columns (A))) - error ("bicgstab: M1 must have the same number of rows and columns as A"); - elseif (nargin > 5 && (!ismatrix (M2) || rows (M2) != rows (A) || columns (M2) != columns (A))) - error ("bicgstab: M2 must have the same number of rows and columns as A"); - elseif (nargin > 6 && !isvector (x0)) - error ("bicgstab: X0 must be a vector"); - elseif (nargin > 6 && rows (x0) != rows (b)) - error ("bicgstab: X0 must have the same number of rows as B"); - endif + if ((nargin >= 2) && (nargin <= 7) && isvector (full (b))) + + if (ischar (A)) + A = str2func (A); + elseif (ismatrix (A)) + Ax = @(x) A * x; + elseif (isa (A, "function_handle")) + Ax = @(x) feval (A, x); + else + error (["bicgstab: first argument is expected " ... + "to be a function or a square matrix"]); + endif + + if ((nargin < 3) || (isempty (tol))) + tol = 1e-6; + endif - ## Default tolerance. - if (nargin < 3) - tol = 1e-6; - endif - - ## Default maximum number of iteration. - if (nargin < 4) - maxit = min (rows (b), 20); - endif + if ((nargin < 4) || (isempty (maxit))) + maxit = min (rows (b), 20); + endif - ## Left preconditioner. - if (nargin == 5) - if (isnumeric (M1)) - precon = @(x) M1 \ x; + if ((nargin < 5) || isempty (M1)) + M1m1x = @(x) x; + elseif (ischar (M1)) + M1m1x = str2func (M1); + elseif (ismatrix (M1)) + M1m1x = @(x) M1 \ x; + elseif (isa (M1, "function_handle")) + M1m1x = @(x) feval (M1, x); + else + error (["bicgstab: preconditioner is " ... + "expected to be a function or matrix"]); endif - elseif (nargin > 5) - if (issparse (M1) && issparse (M2)) - precon = @(x) M2 \ (M1 \ x); + + if ((nargin < 6) || isempty (M2)) + M2m1x = @(x) x; + elseif (ischar (M2)) + M2m1x = str2func (M2); + elseif (ismatrix (M2)) + M2m1x = @(x) M2 \ x; + elseif (isa (M2, "function_handle")) + M2m1x = @(x) feval (M2, x); else - M = M1*M2; - precon = @(x) M \ x; + error (["bicgstab: preconditioner is "... + "expected to be a function or matrix"]); endif - else - precon = @(x) x; - endif - ## specifies initial estimate x0 - if (nargin < 7) - x = zeros (rows (b), 1); - else - x = x0; - endif + precon = @(x) M2m1x (M1m1x (x)); - norm_b = norm (b); - - res = b - A*x; - rr = res; + if ((nargin < 7) || (isempty (x0))) + x0 = zeros (size (b)); + endif - ## Vector of the residual norms for each iteration. - resvec = [norm(res)/norm_b]; - ## Default behaviour we don't reach tolerance tol within maxit iterations. - flag = 1; - - for iter = 1:maxit - rho_1 = res' * rr; - - if (iter == 1) - p = res; + ## specifies initial estimate x0 + if (nargin < 7) + x = zeros (rows (b), 1); else - beta = (rho_1 / rho_2) * (alpha / omega); - p = res + beta * (p - omega * v); + x = x0; endif - phat = precon (p); + norm_b = norm (b); + + res = b - Ax (x); + rr = res; - v = A * phat; - alpha = rho_1 / (rr' * v); - s = res - alpha * v; + ## Vector of the residual norms for each iteration. + resvec = norm(res) / norm_b; - shat = precon (s); + ## Default behaviour we don't reach tolerance tol within maxit iterations. + flag = 1; + + for iter = 1:maxit + rho_1 = res' * rr; - t = A * shat; - omega = (t' * s) / (t' * t); - x = x + alpha * phat + omega * shat; - res = s - omega * t; - rho_2 = rho_1; + if (iter == 1) + p = res; + else + beta = (rho_1 / rho_2) * (alpha / omega); + p = res + beta * (p - omega * v); + endif + + phat = precon (p); - relres = norm (res) / norm_b; - resvec = [resvec; relres]; + v = Ax (phat); + alpha = rho_1 / (rr' * v); + s = res - alpha * v; + + shat = precon (s); + + t = Ax (shat); + omega = (t' * s) / (t' * t); + x = x + alpha * phat + omega * shat; + res = s - omega * t; + rho_2 = rho_1; - if (relres <= tol) - ## We reach tolerance tol within maxit iterations. - flag = 0; - break; - elseif (resvec (end) == resvec (end - 1)) - ## The method stagnates. - flag = 3; - break; - endif - endfor + relres = norm (res) / norm_b; + resvec = [resvec; relres]; + + if (relres <= tol) + ## We reach tolerance tol within maxit iterations. + flag = 0; + break; + elseif (resvec(end) == resvec(end - 1)) + ## The method stagnates. + flag = 3; + break; + endif + endfor - if (nargout < 2) - if (flag == 0) - printf (["bicgstab converged at iteration %i ", - "to a solution with relative residual %e\n"],iter,relres); - elseif (flag == 3) - printf (["bicgstab stopped at iteration %i ", - "without converging to the desired tolerance %e\n", - "because the method stagnated.\n", - "The iterate returned (number %i) has relative residual %e\n"],iter,tol,iter,relres); - else - printf (["bicgstab stopped at iteration %i ", - "without converging to the desired toleranc %e\n", - "because the maximum number of iterations was reached.\n", - "The iterate returned (number %i) has relative residual %e\n"],iter,tol,iter,relres); + if (nargout < 2) + if (flag == 0) + printf ("bicgstab converged at iteration %i ", iter); + printf ("to a solution with relative residual %e\n", relres); + elseif (flag == 3) + printf ("bicgstab stopped at iteration %i ", iter); + printf ("without converging to the desired tolerance %e\n", tol); + printf ("because the method stagnated.\n"); + printf ("The iterate returned (number %i) ", iter); + printf ("has relative residual %e\n", relres); + else + printf ("bicgstab stopped at iteration %i ", iter); + printf ("without converging to the desired toleranc %e\n", tol); + printf ("because the maximum number of iterations was reached.\n"); + printf ("The iterate returned (number %i) ", iter); + printf ("has relative residual %e\n", relres); + endif endif + + else + print_usage (); endif endfunction @@ -176,3 +201,36 @@ %! b = [7;-1;4] %! [x, flag, relres, iter, resvec] = bicgstab(A, b) +%!shared A, b, n, M1, M2 +%! +%!test +%! n = 100; +%! A = spdiags ([-2*ones(n,1) 4*ones(n,1) -ones(n,1)], -1:1, n, n); +%! b = sum (A, 2); +%! tol = 1e-8; +%! maxit = 15; +%! M1 = spdiags ([ones(n,1)/(-2) ones(n,1)],-1:0, n, n); +%! M2 = spdiags ([4*ones(n,1) -ones(n,1)], 0:1, n, n); +%! [x, flag, relres, iter, resvec] = bicgstab (A, b, tol, maxit, M1, M2); +%! assert (x, ones (size (b)), 1e-7); +%! +%!test +%! tol = 1e-8; +%! maxit = 15; +%! +%! function y = afun (x, a) +%! y = a * x; +%! endfunction +%! +%! [x, flag, relres, iter, resvec] = bicgstab (@(x) afun (x, A), b, +%! tol, maxit, M1, M2); +%! assert (x, ones (size (b)), 1e-7); + +%!test +%! n = 100; +%! tol = 1e-8; +%! a = sprand (n, n, .1); +%! A = a'*a + 100 * eye (n); +%! b = sum (A, 2); +%! [x, flag, relres, iter, resvec] = bicgstab (A, b, tol, [], diag (diag (A))); +%! assert (x, ones (size (b)), 1e-7);