changeset 13061:addfc0ae69c0

Make bicgstab interface more compatible * bicgstab.m: Add the possibility to pass a function handle for the coefficient matrix. Also add more tests.
author Carlo de Falco <kingcrimson@tiscali.it>
date Sat, 03 Sep 2011 20:06:30 +0200
parents 85dd509673e7
children b3a8b75dfec3
files scripts/sparse/bicgstab.m
diffstat 1 files changed, 183 insertions(+), 125 deletions(-) [+]
line wrap: on
line diff
--- a/scripts/sparse/bicgstab.m
+++ b/scripts/sparse/bicgstab.m
@@ -1,4 +1,5 @@
 ## Copyright (C) 2008-2011 Radek Salac
+## Copyright (C) 2011 Carlo de Falco
 ##
 ## This file is part of Octave.
 ##
@@ -17,155 +18,179 @@
 ## <http://www.gnu.org/licenses/>.
 
 ## -*- texinfo -*-
-## @deftypefn  {Function File} {} bicgstab (@var{A}, @var{b})
-## @deftypefnx {Function File} {} bicgstab (@var{A}, @var{b}, @var{tol}, @var{maxit}, @var{M1}, @var{M2}, @var{x0})
-## This procedure attempts to solve a system of linear equations A*x = b for x.
-## The @var{A} must be square, symmetric and positive definite real matrix N*N.
-## The @var{b} must be a one column vector with a length of N.
-## The @var{tol} specifies the tolerance of the method, the default value is
-## 1e-6.
-## The @var{maxit} specifies the maximum number of iterations, the default value
-## is min(20,N).
-## The @var{M1} specifies a preconditioner, can also be a function handler which
-## returns M\X.
-## The @var{M2} combined with @var{M1} defines preconditioner as
-## preconditioner=M1*M2.
-## The @var{x0} is the initial guess, the default value is zeros(N,1).
+##  
+## @deftypefn {Function File} {@var{x} =} bicgstab (@var{A}, @var{b}, @var{rtol}, @var{maxit}, @var{M1}, @var{M2}, @var{x0})
+## @deftypefnx {Function File} {@var{x} =} bicgstab (@var{A}, @var{b}, @var{rtol}, @var{maxit}, @var{P})
+## @deftypefnx {Function File} {[@var{x}, @var{flag}, @var{relres}, @var{iter}, @var{resvec}] =} bicgstab (@var{A}, @var{b}, ...)
+##
+##   Solve @code{A x = b} using the stabilizied Bi-conjugate gradient iterative method.
+##
+##   @itemize @minus
+##   @item @var{rtol} is the relative tolerance, if not given or set to [] the default value 1e-6 is used.
+##   @item @var{maxit} the maximum number of outer iterations, if not given or set to [] the default value @code{min (20, numel (b))} is used.
+##   @item @var{x0} the initial guess, if not given or set to [] the default value @code{zeros (size (b))} is used. 
+##   @end itemize
+##
+##   @var{A} can be passed as a matrix or as a function handle or 
+##   inline function @code{f} such that @code{f(x) = A*x}.
 ##
-## The value @var{x} is a computed result of this procedure.
-## The value @var{flag} can be 0 when we reach tolerance in @var{maxit}
-## iterations, 1 when
-## we don't reach tolerance in @var{maxit} iterations and 3 when the procedure
-## stagnates.
-## The value @var{relres} is a relative residual - norm(b-A*x)/norm(b).
-## The value @var{iter} is an iteration number in which x was computed.
-## The value @var{resvec} is a vector of @var{relres} for each iteration.
+##   The preconditioner @var{P} is given as @code{P = M1 * M2}. 
+##   Both @var{M1} and @var{M2} can be passed as a matrix or as a function handle or 
+##   inline function @code{g} such that @code{g(x) = M1 \ x} or @code{g(x) = M2 \ x}.
+##
+##   If called with more than one output parameter
+##
+##   @itemize @minus
+##   @item @var{flag} indicates the exit status:
+##   @itemize @minus
+##     @item 0: iteration converged to the within the chosen tolerance
+##     @item 1: the maximum number of iterations was reached before convergence
+##     @item 3: the algorithm reached stagnation
+##   @end itemize
+##   (the value 2 is unused but skipped for compatibility).
+##   @item @var{relres} is the final value of the relative residual.
+##   @item @var{iter} is the number of iterations performed. 
+##   @item @var{resvec} is a vector containing the relative residual at each iteration.
+##   @end itemize
+##
+##   @seealso{pcg,cgs,bigc,gmres}
 ##
 ## @end deftypefn
 
-function [x, flag, relres, iter, resvec] = bicgstab (A, b, tol, maxit, M1, M2, x0)
+function [x, flag, relres, iter, resvec] = bicgstab (A, b, tol, maxit, 
+                                                     M1, M2, x0)
 
-  if (nargin < 2 || nargin > 7 || nargout > 5)
-    print_usage ();
-  elseif (!(isnumeric (A) && issquare (A)))
-    error ("bicgstab: A must be a square numeric matrix");
-  elseif (!isvector (b))
-    error ("bicgstab: B must be a vector");
-  elseif (!any (b))
-    error ("bicgstab: B must not be a vector of all zeros");
-  elseif (rows (A) != rows (b))
-    error ("bicgstab: A and B must have the same number of rows");
-  elseif (nargin > 2 && !isscalar (tol))
-    error ("bicgstab: TOL must be a scalar");
-  elseif (nargin > 3 && !isscalar (maxit))
-    error ("bicgstab: MAXIT must be a scalar");
-  elseif (nargin > 4 && ismatrix (M1) && (rows (M1) != rows (A) || columns (M1) != columns (A)))
-    error ("bicgstab: M1 must have the same number of rows and columns as A");
-  elseif (nargin > 5 && (!ismatrix (M2) || rows (M2) != rows (A) || columns (M2) != columns (A)))
-    error ("bicgstab: M2 must have the same number of rows and columns as A");
-  elseif (nargin > 6 && !isvector (x0))
-    error ("bicgstab: X0 must be a vector");
-  elseif (nargin > 6 && rows (x0) != rows (b))
-    error ("bicgstab: X0 must have the same number of rows as B");
-  endif
+  if ((nargin >= 2) && (nargin <= 7) && isvector (full (b)))
+    
+    if (ischar (A))
+      A = str2func (A);
+    elseif (ismatrix (A))
+      Ax  = @(x) A  * x;
+    elseif (isa (A, "function_handle"))
+      Ax  = @(x) feval (A, x);
+    else
+      error (["bicgstab: first argument is expected " ...
+              "to be a function or a square matrix"]);
+    endif
+    
+    if ((nargin < 3) || (isempty (tol)))
+      tol = 1e-6;
+    endif
 
-  ## Default tolerance.
-  if (nargin < 3)
-    tol = 1e-6;
-  endif
-
-  ## Default maximum number of iteration.
-  if (nargin < 4)
-    maxit = min (rows (b), 20);
-  endif
+    if ((nargin < 4) || (isempty (maxit)))
+      maxit = min (rows (b), 20);
+    endif
 
-  ## Left preconditioner.
-  if (nargin == 5)
-    if (isnumeric (M1))
-      precon = @(x) M1 \ x;
+    if ((nargin < 5) || isempty (M1))
+      M1m1x = @(x) x;
+    elseif (ischar (M1))
+      M1m1x = str2func (M1);
+    elseif (ismatrix (M1))
+      M1m1x  = @(x) M1  \ x;
+    elseif (isa (M1, "function_handle"))
+      M1m1x  = @(x) feval (M1, x);
+    else
+      error (["bicgstab: preconditioner is " ...
+              "expected to be a function or matrix"]);
     endif
-  elseif (nargin > 5)
-    if (issparse (M1) && issparse (M2))
-      precon = @(x) M2 \ (M1 \ x);
+    
+    if ((nargin < 6) || isempty (M2))
+      M2m1x = @(x) x;
+    elseif (ischar (M2))
+      M2m1x = str2func (M2);
+    elseif (ismatrix (M2))
+      M2m1x  = @(x) M2  \ x;
+    elseif (isa (M2, "function_handle"))
+      M2m1x  = @(x) feval (M2, x);
     else
-      M = M1*M2;
-      precon = @(x) M \ x;
+      error (["bicgstab: preconditioner is "...
+              "expected to be a function or matrix"]);
     endif
-  else
-    precon = @(x) x;
-  endif
 
-  ## specifies initial estimate x0
-  if (nargin < 7)
-    x = zeros (rows (b), 1);
-  else
-    x = x0;
-  endif
+    precon  = @(x) M2m1x (M1m1x (x));
 
-  norm_b = norm (b);
-
-  res = b - A*x;
-  rr = res;
+    if ((nargin < 7) || (isempty (x0)))
+      x0 = zeros (size (b));
+    endif
 
-  ## Vector of the residual norms for each iteration.
-  resvec = [norm(res)/norm_b];
 
-  ## Default behaviour we don't reach tolerance tol within maxit iterations.
-  flag = 1;
-
-  for iter = 1:maxit
-    rho_1 = res' * rr;
-
-    if (iter == 1)
-      p = res;
+    ## specifies initial estimate x0
+    if (nargin < 7)
+      x = zeros (rows (b), 1);
     else
-      beta = (rho_1 / rho_2) * (alpha / omega);
-      p = res + beta * (p - omega * v);
+      x = x0;
     endif
 
-    phat = precon (p);
+    norm_b = norm (b);
+
+    res = b - Ax (x);
+    rr = res;
 
-    v = A * phat;
-    alpha = rho_1 / (rr' * v);
-    s = res - alpha * v;
+    ## Vector of the residual norms for each iteration.
+    resvec = norm(res) / norm_b;
 
-    shat = precon (s);
+    ## Default behaviour we don't reach tolerance tol within maxit iterations.
+    flag = 1;
+
+    for iter = 1:maxit
+      rho_1 = res' * rr;
 
-    t = A * shat;
-    omega = (t' * s) / (t' * t);
-    x = x + alpha * phat + omega * shat;
-    res = s - omega * t;
-    rho_2 = rho_1;
+      if (iter == 1)
+        p = res;
+      else
+        beta = (rho_1 / rho_2) * (alpha / omega);
+        p = res + beta * (p - omega * v);
+      endif
+
+      phat = precon (p);
 
-    relres = norm (res) / norm_b;
-    resvec = [resvec; relres];
+      v = Ax (phat);
+      alpha = rho_1 / (rr' * v);
+      s = res - alpha * v;
+
+      shat = precon (s);
+
+      t = Ax (shat);
+      omega = (t' * s) / (t' * t);
+      x = x + alpha * phat + omega * shat;
+      res = s - omega * t;
+      rho_2 = rho_1;
 
-    if (relres <= tol)
-      ## We reach tolerance tol within maxit iterations.
-      flag = 0;
-      break;
-    elseif (resvec (end) == resvec (end - 1))
-      ## The method stagnates.
-      flag = 3;
-      break;
-    endif
-  endfor
+      relres = norm (res) / norm_b;
+      resvec = [resvec; relres];
+
+      if (relres <= tol)
+        ## We reach tolerance tol within maxit iterations.
+        flag = 0;
+        break;
+      elseif (resvec(end) == resvec(end - 1))
+        ## The method stagnates.
+        flag = 3;
+        break;
+      endif
+    endfor
 
-  if (nargout < 2)
-    if (flag == 0)
-      printf (["bicgstab converged at iteration %i ",
-      "to a solution with relative residual %e\n"],iter,relres);
-    elseif (flag == 3)
-      printf (["bicgstab stopped at iteration %i ",
-      "without converging to the desired tolerance %e\n",
-      "because the method stagnated.\n",
-      "The iterate returned (number %i) has relative residual %e\n"],iter,tol,iter,relres);
-    else
-      printf (["bicgstab stopped at iteration %i ",
-      "without converging to the desired toleranc %e\n",
-      "because the maximum number of iterations was reached.\n",
-      "The iterate returned (number %i) has relative residual %e\n"],iter,tol,iter,relres);
+    if (nargout < 2)
+      if (flag == 0)
+        printf ("bicgstab converged at iteration %i ", iter);
+        printf ("to a solution with relative residual %e\n", relres);
+      elseif (flag == 3)
+        printf ("bicgstab stopped at iteration %i ", iter);
+        printf ("without converging to the desired tolerance %e\n", tol);
+        printf ("because the method stagnated.\n");
+        printf ("The iterate returned (number %i) ", iter);
+        printf ("has relative residual %e\n", relres);
+      else
+        printf ("bicgstab stopped at iteration %i ", iter);
+        printf ("without converging to the desired toleranc %e\n", tol);
+        printf ("because the maximum number of iterations was reached.\n");
+        printf ("The iterate returned (number %i) ", iter);
+        printf ("has relative residual %e\n", relres);
+      endif
     endif
+
+  else
+    print_usage ();
   endif
 
 endfunction
@@ -176,3 +201,36 @@
 %! b = [7;-1;4]
 %! [x, flag, relres, iter, resvec] = bicgstab(A, b)
 
+%!shared A, b, n, M1, M2
+%!
+%!test
+%! n = 100; 
+%! A = spdiags ([-2*ones(n,1) 4*ones(n,1) -ones(n,1)], -1:1, n, n);
+%! b = sum (A, 2); 
+%! tol = 1e-8; 
+%! maxit = 15;
+%! M1 = spdiags ([ones(n,1)/(-2) ones(n,1)],-1:0, n, n); 
+%! M2 = spdiags ([4*ones(n,1) -ones(n,1)], 0:1, n, n); 
+%! [x, flag, relres, iter, resvec] = bicgstab (A, b, tol, maxit, M1, M2);
+%! assert (x, ones (size (b)), 1e-7);
+%!
+%!test
+%! tol = 1e-8; 
+%! maxit = 15;
+%!
+%! function y = afun (x, a)
+%!     y = a * x;
+%! endfunction
+%!
+%! [x, flag, relres, iter, resvec] = bicgstab (@(x) afun (x, A), b, 
+%!                                             tol, maxit, M1, M2);
+%! assert (x, ones (size (b)), 1e-7);
+
+%!test
+%! n = 100; 
+%! tol = 1e-8; 
+%! a = sprand (n, n, .1);
+%! A = a'*a + 100 * eye (n);
+%! b = sum (A, 2); 
+%! [x, flag, relres, iter, resvec] = bicgstab (A, b, tol, [], diag (diag (A)));
+%! assert (x, ones (size (b)), 1e-7);