diff liboctave/floatQR.cc @ 8562:a6edd5c23cb5

use replacement methods if qrupdate is not available
author Jaroslav Hajek <highegg@gmail.com>
date Thu, 22 Jan 2009 11:10:47 +0100
parents d66c9b6e506a
children c86718093c1b
line wrap: on
line diff
--- a/liboctave/floatQR.cc
+++ b/liboctave/floatQR.cc
@@ -159,14 +159,28 @@
 
 FloatQR::FloatQR (const FloatMatrix& q_arg, const FloatMatrix& r_arg)
 {
-  if (q_arg.columns () != r_arg.rows ()) 
+  octave_idx_type qr = q_arg.rows (), qc = q_arg.columns ();
+  octave_idx_type rr = r_arg.rows (), rc = r_arg.columns ();
+  if (qc == rr && (qr == qc || (qr > qc && rr == rc)))
     {
-      (*current_liboctave_error_handler) ("QR dimensions mismatch");
-      return;
+      q = q_arg;
+      r = r_arg;
     }
+  else
+    (*current_liboctave_error_handler) ("QR dimensions mismatch");
+}
 
-  this->q = q_arg;
-  this->r = r_arg;
+QR::type
+FloatQR::get_type (void) const
+{
+  QR::type retval;
+  if (!q.is_empty () && q.is_square ())
+    retval = QR::std;
+  else if (q.rows () > q.columns () && r.is_square ())
+    retval = QR::economy;
+  else
+    retval = QR::raw;
+  return retval;
 }
 
 #ifdef HAVE_QRUPDATE
@@ -186,7 +200,7 @@
                                  utmp.fortran_vec (), vtmp.fortran_vec (), w));
     }
   else
-    (*current_liboctave_error_handler) ("QR update dimensions mismatch");
+    (*current_liboctave_error_handler) ("qrupdate: dimensions mismatch");
 }
 
 void
@@ -418,6 +432,249 @@
     }
 }
 
+#else
+
+// Replacement update methods.
+
+void
+FloatQR::update (const FloatColumnVector& u, const FloatColumnVector& v)
+{
+  warn_qrupdate_once ();
+
+  octave_idx_type m = q.rows ();
+  octave_idx_type n = r.columns ();
+
+  if (u.length () == m && v.length () == n)
+    {
+      init(q*r + FloatMatrix (u) * FloatMatrix (v).transpose (), get_type ());
+    }
+  else
+    (*current_liboctave_error_handler) ("qrupdate: dimensions mismatch");
+}
+
+void
+FloatQR::update (const FloatMatrix& u, const FloatMatrix& v)
+{
+  warn_qrupdate_once ();
+
+  octave_idx_type m = q.rows ();
+  octave_idx_type n = r.columns ();
+
+  if (u.rows () == m && v.rows () == n && u.cols () == v.cols ())
+    {
+      init(q*r + u * v.transpose (), get_type ());
+    }
+  else
+    (*current_liboctave_error_handler) ("qrupdate: dimensions mismatch");
+}
+
+static
+FloatMatrix insert_col (const FloatMatrix& a, octave_idx_type i,
+                        const FloatColumnVector& x)
+{
+  FloatMatrix retval (a.rows (), a.columns () + 1);
+  retval.assign (idx_vector::colon, idx_vector (0, i),
+                 a.index (idx_vector::colon, idx_vector (0, i)));
+  retval.assign (idx_vector::colon, idx_vector (i), x);
+  retval.assign (idx_vector::colon, idx_vector (i+1, retval.columns ()),
+                 a.index (idx_vector::colon, idx_vector (i, a.columns ())));
+  return retval;
+}
+
+static
+FloatMatrix insert_row (const FloatMatrix& a, octave_idx_type i,
+                        const FloatRowVector& x)
+{
+  FloatMatrix retval (a.rows () + 1, a.columns ());
+  retval.assign (idx_vector (0, i), idx_vector::colon,
+                 a.index (idx_vector (0, i), idx_vector::colon));
+  retval.assign (idx_vector (i), idx_vector::colon, x);
+  retval.assign (idx_vector (i+1, retval.rows ()), idx_vector::colon,
+                 a.index (idx_vector (i, a.rows ()), idx_vector::colon));
+  return retval;
+}
+
+static
+FloatMatrix delete_col (const FloatMatrix& a, octave_idx_type i)
+{
+  FloatMatrix retval = a;
+  retval.delete_elements (1, idx_vector (i));
+  return retval;
+}
+
+static
+FloatMatrix delete_row (const FloatMatrix& a, octave_idx_type i)
+{
+  FloatMatrix retval = a;
+  retval.delete_elements (0, idx_vector (i));
+  return retval;
+}
+
+static
+FloatMatrix shift_cols (const FloatMatrix& a, 
+                        octave_idx_type i, octave_idx_type j)
+{
+  octave_idx_type n = a.columns ();
+  Array<octave_idx_type> p (n);
+  for (octave_idx_type k = 0; k < n; k++) p(k) = k;
+  if (i < j)
+    {
+      for (octave_idx_type k = i; k < j; k++) p(k) = k+1;
+      p(j) = i;
+    }
+  else if (j < i)
+    {
+      p(j) = i;
+      for (octave_idx_type k = j+1; k < i+1; k++) p(k) = k-1;
+    }
+
+  return a.index (idx_vector::colon, idx_vector (p));
+}
+
+void
+FloatQR::insert_col (const FloatColumnVector& u, octave_idx_type j)
+{
+  warn_qrupdate_once ();
+
+  octave_idx_type m = q.rows ();
+  octave_idx_type n = r.columns ();
+
+  if (u.length () != m)
+    (*current_liboctave_error_handler) ("qrinsert: dimensions mismatch");
+  else if (j < 0 || j > n) 
+    (*current_liboctave_error_handler) ("qrinsert: index out of range");
+  else
+    {
+      init (::insert_col (q*r, j, u), get_type ());
+    }
+}
+
+void
+FloatQR::insert_col (const FloatMatrix& u, const Array<octave_idx_type>& j)
+{
+  warn_qrupdate_once ();
+
+  octave_idx_type m = q.rows ();
+  octave_idx_type n = r.columns ();
+
+  Array<octave_idx_type> jsi;
+  Array<octave_idx_type> js = j.sort (jsi, ASCENDING);
+  octave_idx_type nj = js.length ();
+  bool dups = false;
+  for (octave_idx_type i = 0; i < nj - 1; i++)
+    dups = dups && js(i) == js(i+1);
+
+  if (dups)
+    (*current_liboctave_error_handler) ("qrinsert: duplicate index detected");
+  else if (u.length () != m || u.columns () != nj)
+    (*current_liboctave_error_handler) ("qrinsert: dimensions mismatch");
+  else if (nj > 0 && (js(0) < 0 || js(nj-1) > n))
+    (*current_liboctave_error_handler) ("qrinsert: index out of range");
+  else if (nj > 0)
+    {
+      FloatMatrix a = q*r;
+      for (octave_idx_type i = 0; i < js.length (); i++)
+        a = ::insert_col (a, js(i), u.column (i));
+      init (a, get_type ());
+    }
+}
+
+void
+FloatQR::delete_col (octave_idx_type j)
+{
+  warn_qrupdate_once ();
+
+  octave_idx_type m = q.rows ();
+  octave_idx_type n = r.columns ();
+
+  if (j < 0 || j > n-1) 
+    (*current_liboctave_error_handler) ("qrdelete: index out of range");
+  else
+    {
+      init (::delete_col (q*r, j), get_type ());
+    }
+}
+
+void
+FloatQR::delete_col (const Array<octave_idx_type>& j)
+{
+  warn_qrupdate_once ();
+
+  octave_idx_type m = q.rows ();
+  octave_idx_type n = r.columns ();
+
+  Array<octave_idx_type> jsi;
+  Array<octave_idx_type> js = j.sort (jsi, DESCENDING);
+  octave_idx_type nj = js.length ();
+  bool dups = false;
+  for (octave_idx_type i = 0; i < nj - 1; i++)
+    dups = dups && js(i) == js(i+1);
+
+  if (dups)
+    (*current_liboctave_error_handler) ("qrinsert: duplicate index detected");
+  else if (nj > 0 && (js(0) > n-1 || js(nj-1) < 0))
+    (*current_liboctave_error_handler) ("qrinsert: index out of range");
+  else if (nj > 0)
+    {
+      FloatMatrix a = q*r;
+      for (octave_idx_type i = 0; i < js.length (); i++)
+        a = ::delete_col (a, js(i));
+      init (a, get_type ());
+    }
+}
+
+void
+FloatQR::insert_row (const FloatRowVector& u, octave_idx_type j)
+{
+  warn_qrupdate_once ();
+
+  octave_idx_type m = r.rows ();
+  octave_idx_type n = r.columns ();
+
+  if (! q.is_square () || u.length () != n)
+    (*current_liboctave_error_handler) ("qrinsert: dimensions mismatch");
+  else if (j < 0 || j > m) 
+    (*current_liboctave_error_handler) ("qrinsert: index out of range");
+  else
+    {
+      init (::insert_row (q*r, j, u), get_type ());
+    }
+}
+
+void
+FloatQR::delete_row (octave_idx_type j)
+{
+  warn_qrupdate_once ();
+
+  octave_idx_type m = r.rows ();
+  octave_idx_type n = r.columns ();
+
+  if (! q.is_square ())
+    (*current_liboctave_error_handler) ("qrdelete: dimensions mismatch");
+  else if (j < 0 || j > m-1) 
+    (*current_liboctave_error_handler) ("qrdelete: index out of range");
+  else
+    {
+      init (::delete_row (q*r, j), get_type ());
+    }
+}
+
+void
+FloatQR::shift_cols (octave_idx_type i, octave_idx_type j)
+{
+  warn_qrupdate_once ();
+
+  octave_idx_type m = q.rows ();
+  octave_idx_type n = r.columns ();
+
+  if (i < 0 || i > n-1 || j < 0 || j > n-1) 
+    (*current_liboctave_error_handler) ("qrshift: index out of range");
+  else
+    {
+      init (::shift_cols (q*r, i, j), get_type ());
+    }
+}
+
 #endif