diff liboctave/floatQR.cc @ 8597:c86718093c1b

improve & fix QR classes
author Jaroslav Hajek <highegg@gmail.com>
date Tue, 27 Jan 2009 12:40:06 +0100
parents a6edd5c23cb5
children e9cb742df9eb
line wrap: on
line diff
--- a/liboctave/floatQR.cc
+++ b/liboctave/floatQR.cc
@@ -3,6 +3,7 @@
 Copyright (C) 1994, 1995, 1996, 1997, 2002, 2003, 2004, 2005, 2007
               John W. Eaton
 Copyright (C) 2008, 2009 Jaroslav Hajek
+Copyright (C) 2009 VZLU Prague
 
 This file is part of Octave.
 
@@ -91,29 +92,35 @@
   octave_idx_type m = a.rows ();
   octave_idx_type n = a.cols ();
 
-  if (m == 0 || n == 0)
-    {
-      (*current_liboctave_error_handler) ("QR must have non-empty matrix");
-      return;
-    }
-
   octave_idx_type min_mn = m < n ? m : n;
-  Array<float> tau (min_mn);
-  float *ptau = tau.fortran_vec ();
-
-  octave_idx_type lwork = 32*n;
-  Array<float> work (lwork);
-  float *pwork = work.fortran_vec ();
+  OCTAVE_LOCAL_BUFFER (float, tau, min_mn);
 
   octave_idx_type info = 0;
 
-  FloatMatrix A_fact = a;
-  if (m > n && qr_type != QR::economy)
-      A_fact.resize (m, m, 0.0);
+  FloatMatrix afact = a;
+  if (m > n && qr_type == QR::std)
+    afact.resize (m, m);
+
+  if (m > 0)
+    {
+      // Workspace query: lwork = -1 makes SGEQRF only report its optimal work size in rlwork.
+      float rlwork;
+      F77_XFCN (sgeqrf, SGEQRF, (m, n, afact.fortran_vec (), m, tau, &rlwork, -1, info));
 
-  float *tmp_data = A_fact.fortran_vec ();
+      // Allocate at least one work element (max<> is typed so any octave_idx_type compiles), then factorize.
+      octave_idx_type lwork = rlwork; lwork = std::max<octave_idx_type> (lwork, 1);
+      OCTAVE_LOCAL_BUFFER (float, work, lwork);
+      F77_XFCN (sgeqrf, SGEQRF, (m, n, afact.fortran_vec (), m, tau, work, lwork, info));
+    }
 
-  F77_XFCN (sgeqrf, SGEQRF, (m, n, tmp_data, m, ptau, pwork, lwork, info));
+  form (n, afact, tau, qr_type);
+}
+
+void FloatQR::form (octave_idx_type n, FloatMatrix& afact, 
+                    float *tau, QR::type qr_type)
+{
+  octave_idx_type m = afact.rows (), min_mn = std::min (m, n);
+  octave_idx_type info;
 
   if (qr_type == QR::raw)
     {
@@ -121,39 +128,58 @@
 	{
 	  octave_idx_type limit = j < min_mn - 1 ? j : min_mn - 1;
 	  for (octave_idx_type i = limit + 1; i < m; i++)
-	    A_fact.elem (i, j) *= tau.elem (j);
+	    afact.elem (i, j) *= tau[j];
 	}
 
-      r = A_fact;
-
-      if (m > n)
-	r.resize (m, n);
+      r = afact;
     }
   else
     {
-      octave_idx_type n2 = (qr_type == QR::economy) ? min_mn : m;
-
-      if (qr_type == QR::economy && m > n)
-	r.resize (n, n, 0.0);
+      // Attempt to minimize copying.
+      if (m >= n)
+        {
+          // afact will become q.
+          q = afact;
+          octave_idx_type k = qr_type == QR::economy ? n : m;
+          r = FloatMatrix (k, n);
+          for (octave_idx_type j = 0; j < n; j++)
+            {
+              octave_idx_type i = 0;
+              for (; i <= j; i++)
+                r.xelem (i, j) = afact.xelem (i, j);
+              for (;i < k; i++)
+                r.xelem (i, j) = 0;
+            }
+          afact = FloatMatrix (); // optimize memory
+        }
       else
-	r.resize (m, n, 0.0);
+        {
+          // afact will become r.
+          q = FloatMatrix (m, m);
+          for (octave_idx_type j = 0; j < m; j++)
+            for (octave_idx_type i = j + 1; i < m; i++)
+              {
+                q.xelem (i, j) = afact.xelem (i, j);
+                afact.xelem (i, j) = 0;
+              }
+          r = afact;
+        }
 
-      for (octave_idx_type j = 0; j < n; j++)
-	{
-	  octave_idx_type limit = j < min_mn-1 ? j : min_mn-1;
-	  for (octave_idx_type i = 0; i <= limit; i++)
-	    r.elem (i, j) = tmp_data[m*j+i];
-	}
 
-      lwork = 32 * n2;
-      work.resize (lwork);
-      float *pwork2 = work.fortran_vec ();
+      if (m > 0)
+        {
+          octave_idx_type k = q.columns ();
+          // Workspace query: lwork = -1 makes SORGQR only report its optimal work size in rlwork.
+          float rlwork;
+          F77_XFCN (sorgqr, SORGQR, (m, k, min_mn, q.fortran_vec (), m, tau,
+                                     &rlwork, -1, info));
 
-      F77_XFCN (sorgqr, SORGQR, (m, n2, min_mn, tmp_data, m, ptau,
-				 pwork2, lwork, info));
-
-      q = A_fact;
-      q.resize (m, n2);
+          // Allocate at least one work element (max<> is typed so any octave_idx_type compiles), then form Q in q.
+          octave_idx_type lwork = rlwork; lwork = std::max<octave_idx_type> (lwork, 1);
+          OCTAVE_LOCAL_BUFFER (float, work, lwork);
+          F77_XFCN (sorgqr, SORGQR, (m, k, min_mn, q.fortran_vec (), m, tau,
+                                     work, lwork, info));
+        }
     }
 }