diff liboctave/floatQR.cc @ 8547:d66c9b6e506a

imported patch qrupdate.diff
author Jaroslav Hajek <highegg@gmail.com>
date Tue, 20 Jan 2009 21:16:42 +0100
parents 4976f66d469b
children a6edd5c23cb5
line wrap: on
line diff
--- a/liboctave/floatQR.cc
+++ b/liboctave/floatQR.cc
@@ -2,6 +2,7 @@
 
 Copyright (C) 1994, 1995, 1996, 1997, 2002, 2003, 2004, 2005, 2007
               John W. Eaton
+Copyright (C) 2008, 2009 Jaroslav Hajek
 
 This file is part of Octave.
 
@@ -30,6 +31,7 @@
 #include "lo-error.h"
 #include "Range.h"
 #include "idx-vector.h"
+#include "oct-locbuf.h"
 
 extern "C"
 {
@@ -41,33 +43,40 @@
   F77_FUNC (sorgqr, SORGQR) (const octave_idx_type&, const octave_idx_type&, const octave_idx_type&, float*,
 			     const octave_idx_type&, float*, float*, const octave_idx_type&, octave_idx_type&);
 
-  // these come from qrupdate
+#ifdef HAVE_QRUPDATE
 
   F77_RET_T
   F77_FUNC (sqr1up, SQR1UP) (const octave_idx_type&, const octave_idx_type&, const octave_idx_type&, 
-                             float*, float*, const float*, const float*);
+                             float*, const octave_idx_type&, float*, const octave_idx_type&,
+                             float*, float*, float*);
 
   F77_RET_T
   F77_FUNC (sqrinc, SQRINC) (const octave_idx_type&, const octave_idx_type&, const octave_idx_type&, 
-                             float*, const float*, float*, const octave_idx_type&, const float*);
+                             float*, const octave_idx_type&, float*, const octave_idx_type&,
+                             const octave_idx_type&, const float*, float*);
 
   F77_RET_T
   F77_FUNC (sqrdec, SQRDEC) (const octave_idx_type&, const octave_idx_type&, const octave_idx_type&, 
-                             float*, const float*, float*, const octave_idx_type&);
+                             float*, const octave_idx_type&, float*, const octave_idx_type&,
+                             const octave_idx_type&, float*);
 
   F77_RET_T
   F77_FUNC (sqrinr, SQRINR) (const octave_idx_type&, const octave_idx_type&, 
-                             const float*, float*, const float*, float*, 
-                             const octave_idx_type&, const float*);
+                             float*, const octave_idx_type&, float*, const octave_idx_type&,
+                             const octave_idx_type&, const float*, float*);
 
   F77_RET_T
   F77_FUNC (sqrder, SQRDER) (const octave_idx_type&, const octave_idx_type&, 
-                             const float*, float*, const float*, float *, 
-                             const octave_idx_type&);
+                             float*, const octave_idx_type&, float*, const octave_idx_type&,
+                             const octave_idx_type&, float*);
 
   F77_RET_T
   F77_FUNC (sqrshc, SQRSHC) (const octave_idx_type&, const octave_idx_type&, const octave_idx_type&,
-                             float*, float*, const octave_idx_type&, const octave_idx_type&);
+                             float*, const octave_idx_type&, float*, const octave_idx_type&,
+                             const octave_idx_type&, const octave_idx_type&,
+                             float*);
+
+#endif
 }
 
 FloatQR::FloatQR (const FloatMatrix& a, QR::type qr_type)
@@ -160,6 +169,26 @@
   this->r = r_arg;
 }
 
+#ifdef HAVE_QRUPDATE
+
+void
+FloatQR::update (const FloatColumnVector& u, const FloatColumnVector& v)
+{
+  octave_idx_type m = q.rows ();
+  octave_idx_type n = r.columns ();
+  octave_idx_type k = q.columns ();
+
+  if (u.length () == m && v.length () == n)
+    {
+      FloatColumnVector utmp = u, vtmp = v;
+      OCTAVE_LOCAL_BUFFER (float, w, 2*k);
+      F77_XFCN (sqr1up, SQR1UP, (m, n, k, q.fortran_vec (), m, r.fortran_vec (), k,
+                                 utmp.fortran_vec (), vtmp.fortran_vec (), w));
+    }
+  else
+    (*current_liboctave_error_handler) ("QR update dimensions mismatch");
+}
+
 void
 FloatQR::update (const FloatMatrix& u, const FloatMatrix& v)
 {
@@ -167,32 +196,93 @@
   octave_idx_type n = r.columns ();
   octave_idx_type k = q.columns ();
 
-  if (u.length () == m && v.length () == n)
-    F77_XFCN (sqr1up, SQR1UP, (m, n, k, q.fortran_vec (), r.fortran_vec (), 
-			       u.data (), v.data ()));
+  if (u.rows () == m && v.rows () == n && u.cols () == v.cols ())
+    {
+      OCTAVE_LOCAL_BUFFER (float, w, 2*k);
+      for (octave_idx_type i = 0; i < u.cols (); i++)
+        {
+          FloatColumnVector utmp = u.column (i), vtmp = v.column (i);
+          F77_XFCN (sqr1up, SQR1UP, (m, n, k, q.fortran_vec (), m, r.fortran_vec (), k,
+                                     utmp.fortran_vec (), vtmp.fortran_vec (), w));
+        }
+    }
   else
-    (*current_liboctave_error_handler) ("QR update dimensions mismatch");
+    (*current_liboctave_error_handler) ("qrupdate: dimensions mismatch");
 }
 
 void
-FloatQR::insert_col (const FloatMatrix& u, octave_idx_type j)
+FloatQR::insert_col (const FloatColumnVector& u, octave_idx_type j)
 {
   octave_idx_type m = q.rows ();
   octave_idx_type n = r.columns ();
   octave_idx_type k = q.columns ();
 
   if (u.length () != m)
-    (*current_liboctave_error_handler) ("QR insert dimensions mismatch");
+    (*current_liboctave_error_handler) ("qrinsert: dimensions mismatch");
   else if (j < 0 || j > n) 
-    (*current_liboctave_error_handler) ("QR insert index out of range");
+    (*current_liboctave_error_handler) ("qrinsert: index out of range");
   else
     {
-      FloatMatrix r1 (m, n+1);
+      if (k < m)
+        {
+          q.resize (m, k+1);
+          r.resize (k+1, n+1);
+        }
+      else
+        {
+          r.resize (k, n+1);
+        }
+
+      FloatColumnVector utmp = u;
+      OCTAVE_LOCAL_BUFFER (float, w, k);
+      F77_XFCN (sqrinc, SQRINC, (m, n, k, q.fortran_vec (), q.rows (),
+                                 r.fortran_vec (), r.rows (), j + 1, 
+                                 utmp.data (), w));
+    }
+}
+
+void
+FloatQR::insert_col (const FloatMatrix& u, const Array<octave_idx_type>& j)
+{
+  octave_idx_type m = q.rows ();
+  octave_idx_type n = r.columns ();
+  octave_idx_type k = q.columns ();
 
-      F77_XFCN (sqrinc, SQRINC, (m, n, k, q.fortran_vec (), r.data (),
-				 r1.fortran_vec (), j+1, u.data ()));
+  Array<octave_idx_type> jsi;
+  Array<octave_idx_type> js = j.sort (jsi, ASCENDING);
+  octave_idx_type nj = js.length ();
+  bool dups = false;
+  for (octave_idx_type i = 0; i < nj - 1; i++)
+    dups = dups && js(i) == js(i+1);
 
-      r = r1;
+  if (dups)
+    (*current_liboctave_error_handler) ("qrinsert: duplicate index detected");
+  else if (u.length () != m || u.columns () != nj)
+    (*current_liboctave_error_handler) ("qrinsert: dimensions mismatch");
+  else if (nj > 0 && (js(0) < 0 || js(nj-1) > n))
+    (*current_liboctave_error_handler) ("qrinsert: index out of range");
+  else if (nj > 0)
+    {
+      octave_idx_type kmax = std::min (k + nj, m);
+      if (k < m)
+        {
+          q.resize (m, kmax);
+          r.resize (kmax, n + nj);
+        }
+      else
+        {
+          r.resize (k, n + nj);
+        }
+
+      OCTAVE_LOCAL_BUFFER (float, w, kmax);
+      for (octave_idx_type i = 0; i < js.length (); i++)
+        {
+          FloatColumnVector utmp = u.column (jsi(i));
+          F77_XFCN (sqrinc, SQRINC, (m, n + i, std::min (kmax, k + i), 
+                                     q.fortran_vec (), q.rows (),
+                                     r.fortran_vec (), r.rows (), js(i) + 1, 
+                                     utmp.data (), w));
+        }
     }
 }
 
@@ -203,41 +293,87 @@
   octave_idx_type k = r.rows ();
   octave_idx_type n = r.columns ();
 
-  if (k < m && k < n) 
-    (*current_liboctave_error_handler) ("QR delete dimensions mismatch");
-  else if (j < 0 || j > n-1) 
-    (*current_liboctave_error_handler) ("QR delete index out of range");
+  if (j < 0 || j > n-1) 
+    (*current_liboctave_error_handler) ("qrdelete: index out of range");
   else
     {
-      FloatMatrix r1 (k, n-1);
+      OCTAVE_LOCAL_BUFFER (float, w, k);
+      F77_XFCN (sqrdec, SQRDEC, (m, n, k, q.fortran_vec (), q.rows (),
+				 r.fortran_vec (), r.rows (), j + 1, w));
 
-      F77_XFCN (sqrdec, SQRDEC, (m, n, k, q.fortran_vec (), r.data (),
-				 r1.fortran_vec (), j+1));
-
-      r = r1;
+      if (k < m)
+        {
+          q.resize (m, k-1);
+          r.resize (k-1, n-1);
+        }
+      else
+        {
+          r.resize (k, n-1);
+        }
     }
 }
 
 void
-FloatQR::insert_row (const FloatMatrix& u, octave_idx_type j)
+FloatQR::delete_col (const Array<octave_idx_type>& j)
+{
+  octave_idx_type m = q.rows ();
+  octave_idx_type n = r.columns ();
+  octave_idx_type k = q.columns ();
+
+  Array<octave_idx_type> jsi;
+  Array<octave_idx_type> js = j.sort (jsi, DESCENDING);
+  octave_idx_type nj = js.length ();
+  bool dups = false;
+  for (octave_idx_type i = 0; i < nj - 1; i++)
+    dups = dups && js(i) == js(i+1);
+
+  if (dups)
+    (*current_liboctave_error_handler) ("qrinsert: duplicate index detected");
+  else if (nj > 0 && (js(0) > n-1 || js(nj-1) < 0))
+    (*current_liboctave_error_handler) ("qrinsert: index out of range");
+  else if (nj > 0)
+    {
+      OCTAVE_LOCAL_BUFFER (float, w, k);
+      for (octave_idx_type i = 0; i < js.length (); i++)
+        {
+          F77_XFCN (sqrdec, SQRDEC, (m, n - i, k == m ? k : k - i, 
+                                     q.fortran_vec (), q.rows (),
+                                     r.fortran_vec (), r.rows (), js(i) + 1, w));
+        }
+      if (k < m)
+        {
+          q.resize (m, k - nj);
+          r.resize (k - nj, n - nj);
+        }
+      else
+        {
+          r.resize (k, n - nj);
+        }
+
+    }
+}
+
+void
+FloatQR::insert_row (const FloatRowVector& u, octave_idx_type j)
 {
   octave_idx_type m = r.rows ();
   octave_idx_type n = r.columns ();
+  octave_idx_type k = std::min (m, n);
 
   if (! q.is_square () || u.length () != n)
-    (*current_liboctave_error_handler) ("QR insert dimensions mismatch");
+    (*current_liboctave_error_handler) ("qrinsert: dimensions mismatch");
   else if (j < 0 || j > m) 
-    (*current_liboctave_error_handler) ("QR insert index out of range");
+    (*current_liboctave_error_handler) ("qrinsert: index out of range");
   else
     {
-      FloatMatrix q1 (m+1, m+1);
-      FloatMatrix r1 (m+1, n);
+      q.resize (m + 1, m + 1);
+      r.resize (m + 1, n);
+      FloatRowVector utmp = u;
+      OCTAVE_LOCAL_BUFFER (float, w, k);
+      F77_XFCN (sqrinr, SQRINR, (m, n, q.fortran_vec (), q.rows (),
+				 r.fortran_vec (), r.rows (), 
+                                 j + 1, utmp.fortran_vec (), w));
 
-      F77_XFCN (sqrinr, SQRINR, (m, n, q.data (), q1.fortran_vec (), 
-				 r.data (), r1.fortran_vec (), j+1, u.data ()));
-
-      q = q1;
-      r = r1;
     }
 }
 
@@ -248,19 +384,18 @@
   octave_idx_type n = r.columns ();
 
   if (! q.is_square ())
-    (*current_liboctave_error_handler) ("QR delete dimensions mismatch");
+    (*current_liboctave_error_handler) ("qrdelete: dimensions mismatch");
   else if (j < 0 || j > m-1) 
-    (*current_liboctave_error_handler) ("QR delete index out of range");
+    (*current_liboctave_error_handler) ("qrdelete: index out of range");
   else
     {
-      FloatMatrix q1 (m-1, m-1);
-      FloatMatrix r1 (m-1, n);
+      OCTAVE_LOCAL_BUFFER (float, w, 2*m);
+      F77_XFCN (sqrder, SQRDER, (m, n, q.fortran_vec (), q.rows (),
+				 r.fortran_vec (), r.rows (), j + 1,
+                                 w));
 
-      F77_XFCN (sqrder, SQRDER, (m, n, q.data (), q1.fortran_vec (), 
-				 r.data (), r1.fortran_vec (), j+1 ));
-
-      q = q1;
-      r = r1;
+      q.resize (m - 1, m - 1);
+      r.resize (m - 1, n);
     }
 }
 
@@ -272,22 +407,18 @@
   octave_idx_type n = r.columns ();
 
   if (i < 0 || i > n-1 || j < 0 || j > n-1) 
-    (*current_liboctave_error_handler) ("QR shift index out of range");
+    (*current_liboctave_error_handler) ("qrshift: index out of range");
   else
-    F77_XFCN (sqrshc, SQRSHC, (m, n, k, q.fortran_vec (), r.fortran_vec (), i+1, j+1));
+    {
+      OCTAVE_LOCAL_BUFFER (float, w, 2*k);
+      F77_XFCN (sqrshc, SQRSHC, (m, n, k, 
+                                 q.fortran_vec (), q.rows (),
+                                 r.fortran_vec (), r.rows (),
+                                 i + 1, j + 1, w));
+    }
 }
 
-void
-FloatQR::economize (void)
-{
-  octave_idx_type r_nc = r.columns ();
-
-  if (r.rows () > r_nc)
-    {
-      q.resize (q.rows (), r_nc);
-      r.resize (r_nc, r_nc);
-    }
-}
+#endif
 
 
 /*