diff liboctave/base-lu.cc @ 9694:50db3c5175b5

allow unpacked form of LU
author Jaroslav Hajek <highegg@gmail.com>
date Mon, 05 Oct 2009 15:39:44 +0200
parents eb63fbe60fab
children 51c17bd18563
line wrap: on
line diff
--- a/liboctave/base-lu.cc
+++ b/liboctave/base-lu.cc
@@ -27,70 +27,121 @@
 #include "base-lu.h"
 
 template <class lu_type>
+base_lu<lu_type>::base_lu (const lu_type& l, const lu_type& u, 
+                           const PermMatrix& p)
+  : a_fact (u), l_fact (l), ipvt (p.pvec ())
+{
+  if (l.columns () != u.rows ())
+    (*current_liboctave_error_handler) ("lu: dimension mismatch");
+}
+
+template <class lu_type>
+bool
+base_lu <lu_type> :: packed (void) const
+{
+  return l_fact.dims () == dim_vector ();
+}
+
+template <class lu_type>
+void
+base_lu <lu_type> :: unpack (void)
+{
+  if (packed ())
+    {
+      l_fact = L ();
+      a_fact = U (); // FIXME: sub-optimal
+    }
+}
+
+template <class lu_type>
 lu_type
 base_lu <lu_type> :: L (void) const
 {
-  octave_idx_type a_nr = a_fact.rows ();
-  octave_idx_type a_nc = a_fact.cols ();
-  octave_idx_type mn = (a_nr < a_nc ? a_nr : a_nc);
-
-  lu_type l (a_nr, mn, lu_elt_type (0.0));
-
-  for (octave_idx_type i = 0; i < a_nr; i++)
+  if (packed ())
     {
-      if (i < a_nc)
-	l.xelem (i, i) = 1.0;
+      octave_idx_type a_nr = a_fact.rows ();
+      octave_idx_type a_nc = a_fact.cols ();
+      octave_idx_type mn = (a_nr < a_nc ? a_nr : a_nc);
+
+      lu_type l (a_nr, mn, lu_elt_type (0.0));
 
-      for (octave_idx_type j = 0; j < (i < a_nc ? i : a_nc); j++)
-	l.xelem (i, j) = a_fact.xelem (i, j);
+      for (octave_idx_type i = 0; i < a_nr; i++)
+        {
+          if (i < a_nc)
+            l.xelem (i, i) = 1.0;
+
+          for (octave_idx_type j = 0; j < (i < a_nc ? i : a_nc); j++)
+            l.xelem (i, j) = a_fact.xelem (i, j);
+        }
+
+      return l;
     }
-
-  return l;
+  else
+    return l_fact;
 }
 
 template <class lu_type>
 lu_type
 base_lu <lu_type> :: U (void) const
 {
-  octave_idx_type a_nr = a_fact.rows ();
-  octave_idx_type a_nc = a_fact.cols ();
-  octave_idx_type mn = (a_nr < a_nc ? a_nr : a_nc);
+  if (packed ())
+    {
+      octave_idx_type a_nr = a_fact.rows ();
+      octave_idx_type a_nc = a_fact.cols ();
+      octave_idx_type mn = (a_nr < a_nc ? a_nr : a_nc);
 
-  lu_type u (mn, a_nc, lu_elt_type (0.0));
+      lu_type u (mn, a_nc, lu_elt_type (0.0));
+
+      for (octave_idx_type i = 0; i < mn; i++)
+        {
+          for (octave_idx_type j = i; j < a_nc; j++)
+            u.xelem (i, j) = a_fact.xelem (i, j);
+        }
 
-  for (octave_idx_type i = 0; i < mn; i++)
-    {
-      for (octave_idx_type j = i; j < a_nc; j++)
-	u.xelem (i, j) = a_fact.xelem (i, j);
+      return u;
     }
+  else
+    return a_fact;
+}
 
-  return u;
+template <class lu_type>
+lu_type
+base_lu <lu_type> :: Y (void) const
+{
+  if (! packed ())
+    (*current_liboctave_error_handler) ("lu: Y() not implemented for unpacked form.");
+  return a_fact;
 }
 
 template <class lu_type>
 Array<octave_idx_type>
 base_lu <lu_type> :: getp (void) const
 {
-  octave_idx_type a_nr = a_fact.rows ();
-
-  Array<octave_idx_type> pvt (a_nr);
-
-  for (octave_idx_type i = 0; i < a_nr; i++)
-    pvt.xelem (i) = i;
-
-  for (octave_idx_type i = 0; i < ipvt.length(); i++)
+  if (packed ())
     {
-      octave_idx_type k = ipvt.xelem (i);
+      octave_idx_type a_nr = a_fact.rows ();
+
+      Array<octave_idx_type> pvt (a_nr);
+
+      for (octave_idx_type i = 0; i < a_nr; i++)
+        pvt.xelem (i) = i;
+
+      for (octave_idx_type i = 0; i < ipvt.length(); i++)
+        {
+          octave_idx_type k = ipvt.xelem (i);
 
-      if (k != i)
-	{
-	  octave_idx_type tmp = pvt.xelem (k);
-	  pvt.xelem (k) = pvt.xelem (i);
-	  pvt.xelem (i) = tmp;
-	}
+          if (k != i)
+            {
+              octave_idx_type tmp = pvt.xelem (k);
+              pvt.xelem (k) = pvt.xelem (i);
+              pvt.xelem (i) = tmp;
+            }
+        }
+
+      return pvt;
     }
-
-  return pvt;
+  else
+    return ipvt;
 }
 
 template <class lu_type>