Mercurial > hg > octave-lyh
diff src/DLD-FUNCTIONS/lu.cc @ 9708:6f3ffe11d926
implement luupdate
author | Jaroslav Hajek <highegg@gmail.com> |
---|---|
date | Thu, 08 Oct 2009 16:05:53 +0200 |
parents | 923c7cb7f13f |
children | f8e2e9fdaa8f |
line wrap: on
line diff
--- a/src/DLD-FUNCTIONS/lu.cc +++ b/src/DLD-FUNCTIONS/lu.cc @@ -584,6 +584,149 @@ */ +static +bool check_lu_dims (const octave_value& l, const octave_value& u, + const octave_value& p) +{ + octave_idx_type m = l.rows (), k = u.rows (), n = u.columns (); + return ((l.ndims () == 2 && u.ndims () == 2 && k == l.columns ()) + && k == std::min (m, n) && + (p.is_undefined () || p.rows () == m)); +} + +DEFUN_DLD (luupdate, args, nargout, + "-*- texinfo -*-\n\ +@deftypefn {Loadable Function} {[@var{l}, @var{u}] =} luupdate (@var{l}, @var{u}, @var{x}, @var{y})\n\ +@deftypefn {Loadable Function} {[@var{l}, @var{u}, @var{p}] =}\ +luupdate (@var{l}, @var{u}, @var{p}, @var{x}, @var{y})\n\ +") +{ + octave_idx_type nargin = args.length (); + octave_value_list retval; + + bool pivoted = nargin == 5; + + if (nargin != 4 && nargin != 5) + { + print_usage (); + return retval; + } + + octave_value argl = args(0); + octave_value argu = args(1); + octave_value argp = pivoted ? args(2) : octave_value (); + octave_value argx = args(2 + pivoted); + octave_value argy = args(3 + pivoted); + + if (argl.is_numeric_type () && argu.is_numeric_type () + && argx.is_numeric_type () && argy.is_numeric_type () + && (! pivoted || argp.is_perm_matrix ())) + { + if (check_lu_dims (argl, argu, argp)) + { + PermMatrix P = (pivoted + ? argp.perm_matrix_value () + : PermMatrix::eye (argl.rows ())); + + if (argl.is_real_type () + && argu.is_real_type () + && argx.is_real_type () + && argy.is_real_type ()) + { + // all real case + if (argl.is_single_type () + || argu.is_single_type () + || argx.is_single_type () + || argy.is_single_type ()) + { + FloatMatrix L = argl.float_matrix_value (); + FloatMatrix U = argu.float_matrix_value (); + FloatMatrix x = argx.float_matrix_value (); + FloatMatrix y = argy.float_matrix_value (); + + FloatLU fact (L, U, P); + if (pivoted) + fact.update_piv (x, y); + else + fact.update (x, y); + + if (pivoted) + retval(2) = fact.P (); + retval(1) = fact.U (); + retval(0) = fact.L (); + } + else + { + Matrix L = argl.matrix_value (); + Matrix U = argu.matrix_value (); + Matrix x = argx.matrix_value (); + Matrix y = argy.matrix_value (); + + LU fact (L, U, P); + if (pivoted) + fact.update_piv (x, y); + else + fact.update (x, y); + + if (pivoted) + retval(2) = fact.P (); + retval(1) = fact.U (); + retval(0) = fact.L (); + } + } + else + { + // complex case + if (argl.is_single_type () + || argu.is_single_type () + || argx.is_single_type () + || argy.is_single_type ()) + { + FloatComplexMatrix L = argl.float_complex_matrix_value (); + FloatComplexMatrix U = argu.float_complex_matrix_value (); + FloatComplexMatrix x = argx.float_complex_matrix_value (); + FloatComplexMatrix y = argy.float_complex_matrix_value (); + + FloatComplexLU fact (L, U, P); + if (pivoted) + fact.update_piv (x, y); + else + fact.update (x, y); + + if (pivoted) + retval(2) = fact.P (); + retval(1) = fact.U (); + retval(0) = fact.L (); + } + else + { + ComplexMatrix L = argl.complex_matrix_value (); + ComplexMatrix U = argu.complex_matrix_value (); + ComplexMatrix x = argx.complex_matrix_value (); + ComplexMatrix y = argy.complex_matrix_value (); + + ComplexLU fact (L, U, P); + if (pivoted) + fact.update_piv (x, y); + else + fact.update (x, y); + + if (pivoted) + retval(2) = fact.P (); + retval(1) = fact.U (); + retval(0) = fact.L (); + } + } + } + else + error ("luupdate: dimensions mismatch"); + } + else + error ("luupdate: expecting numeric arguments"); + + return retval; +} + /* ;;; Local Variables: *** ;;; mode: C++ ***