diff liboctave/dbleCHOL.cc @ 6486:e978a9233cf6

[project @ 2007-04-04 15:16:46 by jwe]
author jwe
date Wed, 04 Apr 2007 15:17:51 +0000
parents 15843d76156d
children 93c65f2a5668
line wrap: on
line diff
--- a/liboctave/dbleCHOL.cc
+++ b/liboctave/dbleCHOL.cc
@@ -25,6 +25,7 @@
 #include <config.h>
 #endif
 
+#include "dRowVector.h"
 #include "dbleCHOL.h"
 #include "f77-fcn.h"
 #include "lo-error.h"
@@ -40,10 +41,16 @@
   F77_FUNC (dpotri, DPOTRI) (F77_CONST_CHAR_ARG_DECL, const octave_idx_type&,
 			     double*, const octave_idx_type&, octave_idx_type&
 			     F77_CHAR_ARG_LEN_DECL);
+
+  F77_RET_T
+  F77_FUNC (dpocon, DPOCON) (F77_CONST_CHAR_ARG_DECL, const octave_idx_type&,
+			     double*, const octave_idx_type&, const double&,
+			     double&, double*, octave_idx_type*, 
+			     octave_idx_type& F77_CHAR_ARG_LEN_DECL);
 }
 
 octave_idx_type
-CHOL::init (const Matrix& a)
+CHOL::init (const Matrix& a, bool calc_cond)
 {
   octave_idx_type a_nr = a.rows ();
   octave_idx_type a_nc = a.cols ();
@@ -60,6 +67,11 @@
   chol_mat = a;
   double *h = chol_mat.fortran_vec ();
 
+  // Calculate the norm of the matrix, for later use.
+  double anorm = 0;
+  if (calc_cond) 
+    anorm = chol_mat.abs().sum().row(static_cast<octave_idx_type>(0)).max();
+
   F77_XFCN (dpotrf, DPOTRF, (F77_CONST_CHAR_ARG2 ("U", 1),
 			     n, h, n, info
 			     F77_CHAR_ARG_LEN (1)));
@@ -68,13 +80,39 @@
     (*current_liboctave_error_handler) ("unrecoverable error in dpotrf");
   else
     {
-      // If someone thinks of a more graceful way of doing this (or
-      // faster for that matter :-)), please let me know!
+      xrcond = 0.0;
+      if (info != 0)
+	info = -1;
+      else if (calc_cond) 
+	{
+	  octave_idx_type dpocon_info = 0;
+
+	  // Now calculate the condition number for non-singular matrix.
+	  Array<double> z (3*n);
+	  double *pz = z.fortran_vec ();
+	  Array<octave_idx_type> iz (n);
+	  octave_idx_type *piz = iz.fortran_vec ();
+	  F77_XFCN (dpocon, DPOCON, (F77_CONST_CHAR_ARG2 ("U", 1), n, h,
+				     n, anorm, xrcond, pz, piz, dpocon_info
+				     F77_CHAR_ARG_LEN (1)));
 
-      if (n > 1)
-	for (octave_idx_type j = 0; j < a_nc; j++)
-	  for (octave_idx_type i = j+1; i < a_nr; i++)
-	    chol_mat.xelem (i, j) = 0.0;
+	  if (f77_exception_encountered)
+	    (*current_liboctave_error_handler) 
+	      ("unrecoverable error in dpocon");
+
+	  if (dpocon_info != 0) 
+	    info = -1;
+	}
+      else
+	{
+	  // If someone thinks of a more graceful way of doing this (or
+	  // faster for that matter :-)), please let me know!
+
+	  if (n > 1)
+	    for (octave_idx_type j = 0; j < a_nc; j++)
+	      for (octave_idx_type i = j+1; i < a_nr; i++)
+		chol_mat.xelem (i, j) = 0.0;
+	}
     }
 
   return info;
@@ -91,27 +129,32 @@
   if (r_nr == r_nc)
     {
       octave_idx_type n = r_nc;
-      octave_idx_type info;
+      octave_idx_type info = 0;
 
       Matrix tmp = r;
+      double *v = tmp.fortran_vec();
 
-      F77_XFCN (dpotri, DPOTRI, (F77_CONST_CHAR_ARG2 ("U", 1), n,
-				 tmp.fortran_vec (), n, info
-				 F77_CHAR_ARG_LEN (1)));
-
-      if (f77_exception_encountered)
-	(*current_liboctave_error_handler) ("unrecoverable error in dpotri");
-      else
+      if (info == 0)
 	{
-	  // If someone thinks of a more graceful way of doing this (or
-	  // faster for that matter :-)), please let me know!
+	  F77_XFCN (dpotri, DPOTRI, (F77_CONST_CHAR_ARG2 ("U", 1), n,
+				     v, n, info
+				     F77_CHAR_ARG_LEN (1)));
 
-	  if (n > 1)
-	    for (octave_idx_type j = 0; j < r_nc; j++)
-	      for (octave_idx_type i = j+1; i < r_nr; i++)
-		tmp.xelem (i, j) = tmp.xelem (j, i);
+	  if (f77_exception_encountered)
+	    (*current_liboctave_error_handler) 
+	      ("unrecoverable error in dpotri");
+	  else
+	    {
+	      // If someone thinks of a more graceful way of doing this (or
+	      // faster for that matter :-)), please let me know!
 
-	  retval = tmp;
+	      if (n > 1)
+		for (octave_idx_type j = 0; j < r_nc; j++)
+		  for (octave_idx_type i = j+1; i < r_nr; i++)
+		    tmp.xelem (i, j) = tmp.xelem (j, i);
+
+	      retval = tmp;
+	    }
 	}
     }
   else