diff liboctave/numeric/floatCHOL.cc @ 20672:5ce959c55cc0

Propagate 'lower' in chol(a, 'lower') to underlying library function. * chol.cc (chol): Send 'L' parameter correctly when chol is called with 'lower'. * floatCHOL.cc (init): Propagate 'lower' to underlying library function. * floatCHOL.h: Modify the prototype of methods. * fMatrix.cc (inverse): Invoke chol with additional parameter. * dbleCHOL.cc (init): Propagate 'lower' to underlying library function. * dbleCHOL.h: Modify the prototype of methods. * dMatrix.cc (inverse): Invoke chol with additional parameter. * CmplxCHOL.cc (init): Propagate 'lower' to underlying library function. * CmplxCHOL.h: Modify the prototype of methods. * CMatrix.cc (inverse): Invoke chol with additional parameter.
author PrasannaKumar Muralidharan <prasannatsmkumar@gmail.com>
date Sun, 24 Aug 2014 19:35:06 +0530
parents a9574e3c6e9e
children dcfbf4c1c3c8
line wrap: on
line diff
--- a/liboctave/numeric/floatCHOL.cc
+++ b/liboctave/numeric/floatCHOL.cc
@@ -87,7 +87,7 @@
 }
 
 octave_idx_type
-FloatCHOL::init (const FloatMatrix& a, bool calc_cond)
+FloatCHOL::init (const FloatMatrix& a, bool upper, bool calc_cond)
 {
   octave_idx_type a_nr = a.rows ();
   octave_idx_type a_nc = a.cols ();
@@ -101,14 +101,30 @@
   octave_idx_type n = a_nc;
   octave_idx_type info;
 
+  is_upper = upper;
+
   chol_mat.clear (n, n);
-  for (octave_idx_type j = 0; j < n; j++)
+  if (is_upper)
     {
-      for (octave_idx_type i = 0; i <= j; i++)
-        chol_mat.xelem (i, j) = a(i, j);
-      for (octave_idx_type i = j+1; i < n; i++)
-        chol_mat.xelem (i, j) = 0.0f;
+      for (octave_idx_type j = 0; j < n; j++)
+        {
+          for (octave_idx_type i = 0; i <= j; i++)
+            chol_mat.xelem (i, j) = a(i, j);
+          for (octave_idx_type i = j+1; i < n; i++)
+            chol_mat.xelem (i, j) = 0.0f;
+        }
     }
+  else
+    {
+      for (octave_idx_type j = 0; j < n; j++)
+        {
+          for (octave_idx_type i = 0; i <= j; i++)
+            chol_mat.xelem (i, j) = 0.0f;
+          for (octave_idx_type i = j+1; i < n; i++)
+            chol_mat.xelem (i, j) = a(i, j);
+        }
+    }
+
   float *h = chol_mat.fortran_vec ();
 
   // Calculate the norm of the matrix, for later use.
@@ -116,9 +132,18 @@
   if (calc_cond)
     anorm = xnorm (a, 1);
 
-  F77_XFCN (spotrf, SPOTRF, (F77_CONST_CHAR_ARG2 ("U", 1),
-                             n, h, n, info
-                             F77_CHAR_ARG_LEN (1)));
+  if (is_upper)
+    {
+      F77_XFCN (spotrf, SPOTRF, (F77_CONST_CHAR_ARG2 ("U", 1),
+                                 n, h, n, info
+                                 F77_CHAR_ARG_LEN (1)));   
+    }
+  else
+    {
+      F77_XFCN (spotrf, SPOTRF, (F77_CONST_CHAR_ARG2 ("L", 1),
+                                 n, h, n, info
+                                 F77_CHAR_ARG_LEN (1)));   
+    }
 
   xrcond = 0.0;
   if (info > 0)
@@ -132,9 +157,19 @@
       float *pz = z.fortran_vec ();
       Array<octave_idx_type> iz (dim_vector (n, 1));
       octave_idx_type *piz = iz.fortran_vec ();
-      F77_XFCN (spocon, SPOCON, (F77_CONST_CHAR_ARG2 ("U", 1), n, h,
-                                 n, anorm, xrcond, pz, piz, spocon_info
-                                 F77_CHAR_ARG_LEN (1)));
+      if (is_upper)
+        {
+          F77_XFCN (spocon, SPOCON, (F77_CONST_CHAR_ARG2 ("U", 1), n, h,
+                                     n, anorm, xrcond, pz, piz, spocon_info
+                                     F77_CHAR_ARG_LEN (1)));       
+        }
+      else
+        {
+          F77_XFCN (spocon, SPOCON, (F77_CONST_CHAR_ARG2 ("L", 1), n, h,
+                                     n, anorm, xrcond, pz, piz, spocon_info
+                                     F77_CHAR_ARG_LEN (1)));       
+        }
+
 
       if (spocon_info != 0)
         info = -1;
@@ -144,7 +179,7 @@
 }
 
 static FloatMatrix
-chol2inv_internal (const FloatMatrix& r)
+chol2inv_internal (const FloatMatrix& r, bool is_upper = true)
 {
   FloatMatrix retval;
 
@@ -161,17 +196,37 @@
 
       if (info == 0)
         {
-          F77_XFCN (spotri, SPOTRI, (F77_CONST_CHAR_ARG2 ("U", 1), n,
-                                     v, n, info
-                                     F77_CHAR_ARG_LEN (1)));
+          if (is_upper)
+            {
+              F77_XFCN (spotri, SPOTRI, (F77_CONST_CHAR_ARG2 ("U", 1), n,
+                                         v, n, info
+                                         F77_CHAR_ARG_LEN (1)));
+            }
+          else
+            {
+              F77_XFCN (spotri, SPOTRI, (F77_CONST_CHAR_ARG2 ("L", 1), n,
+                                         v, n, info
+                                         F77_CHAR_ARG_LEN (1)));
+            }
 
           // If someone thinks of a more graceful way of doing this (or
           // faster for that matter :-)), please let me know!
 
           if (n > 1)
-            for (octave_idx_type j = 0; j < r_nc; j++)
-              for (octave_idx_type i = j+1; i < r_nr; i++)
-                tmp.xelem (i, j) = tmp.xelem (j, i);
+            {
+              if (is_upper)
+                {
+                  for (octave_idx_type j = 0; j < r_nc; j++)
+                    for (octave_idx_type i = j+1; i < r_nr; i++)
+                      tmp.xelem (i, j) = tmp.xelem (j, i); 
+                }
+              else
+                {
+                  for (octave_idx_type j = 0; j < r_nc; j++)
+                    for (octave_idx_type i = j+1; i < r_nr; i++)
+                      tmp.xelem (j, i) = tmp.xelem (i, j);
+                }
+            }
 
           retval = tmp;
         }
@@ -186,7 +241,7 @@
 FloatMatrix
 FloatCHOL::inverse (void) const
 {
-  return chol2inv_internal (chol_mat);
+  return chol2inv_internal (chol_mat, is_upper);
 }
 
 void