diff src/DLD-FUNCTIONS/lu.cc @ 9708:6f3ffe11d926

implement luupdate
author Jaroslav Hajek <highegg@gmail.com>
date Thu, 08 Oct 2009 16:05:53 +0200
parents 923c7cb7f13f
children f8e2e9fdaa8f
line wrap: on
line diff
--- a/src/DLD-FUNCTIONS/lu.cc
+++ b/src/DLD-FUNCTIONS/lu.cc
@@ -584,6 +584,149 @@
 
  */
 
+static
+bool check_lu_dims (const octave_value& l, const octave_value& u,
+                    const octave_value& p)
+{
+  octave_idx_type m = l.rows (), k = u.rows (), n = u.columns ();
+  return ((l.ndims () == 2 && u.ndims () == 2 && k == l.columns ())
+            && k == std::min (m, n) &&
+            (p.is_undefined () || p.rows () == m));
+}
+
+DEFUN_DLD (luupdate, args, nargout,
+  "-*- texinfo -*-\n\
+@deftypefn {Loadable Function} {[@var{l}, @var{u}] =} luupdate (@var{l}, @var{u}, @var{x}, @var{y})\n\
+@deftypefn {Loadable Function} {[@var{l}, @var{u}, @var{p}] =}\
+luupdate (@var{l}, @var{u}, @var{p}, @var{x}, @var{y})\n\
+")
+{
+  octave_idx_type nargin = args.length ();
+  octave_value_list retval;
+
+  bool pivoted = nargin == 5;
+
+  if (nargin != 4 && nargin != 5)
+    {
+      print_usage ();
+      return retval;
+    }
+
+  octave_value argl = args(0);
+  octave_value argu = args(1);
+  octave_value argp = pivoted ? args(2) : octave_value ();
+  octave_value argx = args(2 + pivoted);
+  octave_value argy = args(3 + pivoted);
+
+  if (argl.is_numeric_type () && argu.is_numeric_type () 
+      && argx.is_numeric_type () && argy.is_numeric_type ()
+      && (! pivoted || argp.is_perm_matrix ()))
+    {
+      if (check_lu_dims (argl, argu, argp))
+        {
+          PermMatrix P = (pivoted 
+                          ? argp.perm_matrix_value () 
+                          : PermMatrix::eye (argl.rows ()));
+
+          if (argl.is_real_type () 
+	      && argu.is_real_type () 
+	      && argx.is_real_type () 
+	      && argy.is_real_type ())
+            {
+	      // all real case
+	      if (argl.is_single_type () 
+		  || argu.is_single_type () 
+		  || argx.is_single_type () 
+		  || argy.is_single_type ())
+		{
+		  FloatMatrix L = argl.float_matrix_value ();
+		  FloatMatrix U = argu.float_matrix_value ();
+		  FloatMatrix x = argx.float_matrix_value ();
+		  FloatMatrix y = argy.float_matrix_value ();
+
+		  FloatLU fact (L, U, P);
+                  if (pivoted)
+                    fact.update_piv (x, y);
+                  else
+                    fact.update (x, y);
+
+                  if (pivoted)
+                    retval(2) = fact.P ();
+		  retval(1) = fact.U ();
+		  retval(0) = fact.L ();
+		}
+	      else
+		{
+		  Matrix L = argl.matrix_value ();
+		  Matrix U = argu.matrix_value ();
+		  Matrix x = argx.matrix_value ();
+		  Matrix y = argy.matrix_value ();
+
+		  LU fact (L, U, P);
+                  if (pivoted)
+                    fact.update_piv (x, y);
+                  else
+                    fact.update (x, y);
+
+                  if (pivoted)
+                    retval(2) = fact.P ();
+		  retval(1) = fact.U ();
+		  retval(0) = fact.L ();
+		}
+            }
+          else
+            {
+              // complex case
+	      if (argl.is_single_type () 
+		  || argu.is_single_type () 
+		  || argx.is_single_type () 
+		  || argy.is_single_type ())
+		{
+		  FloatComplexMatrix L = argl.float_complex_matrix_value ();
+		  FloatComplexMatrix U = argu.float_complex_matrix_value ();
+		  FloatComplexMatrix x = argx.float_complex_matrix_value ();
+		  FloatComplexMatrix y = argy.float_complex_matrix_value ();
+
+		  FloatComplexLU fact (L, U, P);
+                  if (pivoted)
+                    fact.update_piv (x, y);
+                  else
+                    fact.update (x, y);
+              
+                  if (pivoted)
+                    retval(2) = fact.P ();
+		  retval(1) = fact.U ();
+		  retval(0) = fact.L ();
+		}
+	      else
+		{
+		  ComplexMatrix L = argl.complex_matrix_value ();
+		  ComplexMatrix U = argu.complex_matrix_value ();
+		  ComplexMatrix x = argx.complex_matrix_value ();
+		  ComplexMatrix y = argy.complex_matrix_value ();
+
+		  ComplexLU fact (L, U, P);
+                  if (pivoted)
+                    fact.update_piv (x, y);
+                  else
+                    fact.update (x, y);
+              
+                  if (pivoted)
+                    retval(2) = fact.P ();
+		  retval(1) = fact.U ();
+		  retval(0) = fact.L ();
+		}
+            }
+        }
+      else
+	error ("luupdate: dimensions mismatch");
+    }
+  else
+    error ("luupdate: expecting numeric arguments");
+
+  return retval;
+}
+
 /*
 ;;; Local Variables: ***
 ;;; mode: C++ ***