diff liboctave/base-lu.cc @ 8367:445d27d79f4e

support permutation matrix objects
author Jaroslav Hajek <highegg@gmail.com>
date Thu, 04 Dec 2008 08:31:56 +0100
parents f3c00dc0912b
children eb63fbe60fab
line wrap: on
line diff
--- a/liboctave/base-lu.cc
+++ b/liboctave/base-lu.cc
@@ -26,9 +26,9 @@
 
 #include "base-lu.h"
 
-template <class lu_type, class lu_elt_type, class p_type, class p_elt_type>
+template <class lu_type>
 lu_type
-base_lu <lu_type, lu_elt_type, p_type, p_elt_type> :: L (void) const
+base_lu <lu_type> :: L (void) const
 {
   octave_idx_type a_nr = a_fact.rows ();
   octave_idx_type a_nc = a_fact.cols ();
@@ -48,9 +48,9 @@
   return l;
 }
 
-template <class lu_type, class lu_elt_type, class p_type, class p_elt_type>
+template <class lu_type>
 lu_type
-base_lu <lu_type, lu_elt_type, p_type, p_elt_type> :: U (void) const
+base_lu <lu_type> :: U (void) const
 {
   octave_idx_type a_nr = a_fact.rows ();
   octave_idx_type a_nc = a_fact.cols ();
@@ -67,9 +67,9 @@
   return u;
 }
 
-template <class lu_type, class lu_elt_type, class p_type, class p_elt_type>
-p_type
-base_lu <lu_type, lu_elt_type, p_type, p_elt_type> :: P (void) const
+template <class lu_type>
+Array<octave_idx_type>
+base_lu <lu_type> :: getp (void) const
 {
   octave_idx_type a_nr = a_fact.rows ();
 
@@ -90,38 +90,25 @@
 	}
     }
 
-  p_type p (a_nr, a_nr, p_elt_type (0.0));
-
-  for (octave_idx_type i = 0; i < a_nr; i++)
-    p.xelem (i, pvt.xelem (i)) = 1.0;
-
-  return p;
+  return pvt;
 }
 
-template <class lu_type, class lu_elt_type, class p_type, class p_elt_type>
+template <class lu_type>
+PermMatrix
+base_lu <lu_type> :: P (void) const
+{
+  return PermMatrix (getp (), false);
+}
+
+template <class lu_type>
 ColumnVector
-base_lu <lu_type, lu_elt_type, p_type, p_elt_type> :: P_vec (void) const
+base_lu <lu_type> :: P_vec (void) const
 {
   octave_idx_type a_nr = a_fact.rows ();
 
-  Array<octave_idx_type> pvt (a_nr);
-
-  for (octave_idx_type i = 0; i < a_nr; i++)
-    pvt.xelem (i) = i;
-
-  for (octave_idx_type i = 0; i < ipvt.length(); i++)
-    {
-      octave_idx_type k = ipvt.xelem (i);
+  ColumnVector p (a_nr);
 
-      if (k != i)
-	{
-	  octave_idx_type tmp = pvt.xelem (k);
-	  pvt.xelem (k) = pvt.xelem (i);
-	  pvt.xelem (i) = tmp;
-	}
-    }
-
-  ColumnVector p (a_nr);
+  Array<octave_idx_type> pvt = getp ();
 
   for (octave_idx_type i = 0; i < a_nr; i++)
     p.xelem (i) = static_cast<double> (pvt.xelem (i) + 1);