Mercurial > hg > octave-lyh
diff liboctave/floatQR.cc @ 8547:d66c9b6e506a
imported patch qrupdate.diff
author | Jaroslav Hajek <highegg@gmail.com> |
---|---|
date | Tue, 20 Jan 2009 21:16:42 +0100 |
parents | 4976f66d469b |
children | a6edd5c23cb5 |
line wrap: on
line diff
--- a/liboctave/floatQR.cc +++ b/liboctave/floatQR.cc @@ -2,6 +2,7 @@ Copyright (C) 1994, 1995, 1996, 1997, 2002, 2003, 2004, 2005, 2007 John W. Eaton +Copyright (C) 2008, 2009 Jaroslav Hajek This file is part of Octave. @@ -30,6 +31,7 @@ #include "lo-error.h" #include "Range.h" #include "idx-vector.h" +#include "oct-locbuf.h" extern "C" { @@ -41,33 +43,40 @@ F77_FUNC (sorgqr, SORGQR) (const octave_idx_type&, const octave_idx_type&, const octave_idx_type&, float*, const octave_idx_type&, float*, float*, const octave_idx_type&, octave_idx_type&); - // these come from qrupdate +#ifdef HAVE_QRUPDATE F77_RET_T F77_FUNC (sqr1up, SQR1UP) (const octave_idx_type&, const octave_idx_type&, const octave_idx_type&, - float*, float*, const float*, const float*); + float*, const octave_idx_type&, float*, const octave_idx_type&, + float*, float*, float*); F77_RET_T F77_FUNC (sqrinc, SQRINC) (const octave_idx_type&, const octave_idx_type&, const octave_idx_type&, - float*, const float*, float*, const octave_idx_type&, const float*); + float*, const octave_idx_type&, float*, const octave_idx_type&, + const octave_idx_type&, const float*, float*); F77_RET_T F77_FUNC (sqrdec, SQRDEC) (const octave_idx_type&, const octave_idx_type&, const octave_idx_type&, - float*, const float*, float*, const octave_idx_type&); + float*, const octave_idx_type&, float*, const octave_idx_type&, + const octave_idx_type&, float*); F77_RET_T F77_FUNC (sqrinr, SQRINR) (const octave_idx_type&, const octave_idx_type&, - const float*, float*, const float*, float*, - const octave_idx_type&, const float*); + float*, const octave_idx_type&, float*, const octave_idx_type&, + const octave_idx_type&, const float*, float*); F77_RET_T F77_FUNC (sqrder, SQRDER) (const octave_idx_type&, const octave_idx_type&, - const float*, float*, const float*, float *, - const octave_idx_type&); + float*, const octave_idx_type&, float*, const octave_idx_type&, + const octave_idx_type&, float*); F77_RET_T F77_FUNC (sqrshc, SQRSHC) (const octave_idx_type&, const octave_idx_type&, const octave_idx_type&, - float*, float*, const octave_idx_type&, const octave_idx_type&); + float*, const octave_idx_type&, float*, const octave_idx_type&, + const octave_idx_type&, const octave_idx_type&, + float*); + +#endif } FloatQR::FloatQR (const FloatMatrix& a, QR::type qr_type) @@ -160,6 +169,26 @@ this->r = r_arg; } +#ifdef HAVE_QRUPDATE + +void +FloatQR::update (const FloatColumnVector& u, const FloatColumnVector& v) +{ + octave_idx_type m = q.rows (); + octave_idx_type n = r.columns (); + octave_idx_type k = q.columns (); + + if (u.length () == m && v.length () == n) + { + FloatColumnVector utmp = u, vtmp = v; + OCTAVE_LOCAL_BUFFER (float, w, 2*k); + F77_XFCN (sqr1up, SQR1UP, (m, n, k, q.fortran_vec (), m, r.fortran_vec (), k, + utmp.fortran_vec (), vtmp.fortran_vec (), w)); + } + else + (*current_liboctave_error_handler) ("QR update dimensions mismatch"); +} + void FloatQR::update (const FloatMatrix& u, const FloatMatrix& v) { @@ -167,32 +196,93 @@ octave_idx_type n = r.columns (); octave_idx_type k = q.columns (); - if (u.length () == m && v.length () == n) - F77_XFCN (sqr1up, SQR1UP, (m, n, k, q.fortran_vec (), r.fortran_vec (), - u.data (), v.data ())); + if (u.rows () == m && v.rows () == n && u.cols () == v.cols ()) + { + OCTAVE_LOCAL_BUFFER (float, w, 2*k); + for (octave_idx_type i = 0; i < u.cols (); i++) + { + FloatColumnVector utmp = u.column (i), vtmp = v.column (i); + F77_XFCN (sqr1up, SQR1UP, (m, n, k, q.fortran_vec (), m, r.fortran_vec (), k, + utmp.fortran_vec (), vtmp.fortran_vec (), w)); + } + } else - (*current_liboctave_error_handler) ("QR update dimensions mismatch"); + (*current_liboctave_error_handler) ("qrupdate: dimensions mismatch"); } void -FloatQR::insert_col (const FloatMatrix& u, octave_idx_type j) +FloatQR::insert_col (const FloatColumnVector& u, octave_idx_type j) { octave_idx_type m = q.rows (); octave_idx_type n = r.columns (); octave_idx_type k = q.columns (); if (u.length () != m) - (*current_liboctave_error_handler) ("QR insert dimensions mismatch"); + (*current_liboctave_error_handler) ("qrinsert: dimensions mismatch"); else if (j < 0 || j > n) - (*current_liboctave_error_handler) ("QR insert index out of range"); + (*current_liboctave_error_handler) ("qrinsert: index out of range"); else { - FloatMatrix r1 (m, n+1); + if (k < m) + { + q.resize (m, k+1); + r.resize (k+1, n+1); + } + else + { + r.resize (k, n+1); + } + + FloatColumnVector utmp = u; + OCTAVE_LOCAL_BUFFER (float, w, k); + F77_XFCN (sqrinc, SQRINC, (m, n, k, q.fortran_vec (), q.rows (), + r.fortran_vec (), r.rows (), j + 1, + utmp.data (), w)); + } +} + +void +FloatQR::insert_col (const FloatMatrix& u, const Array<octave_idx_type>& j) +{ + octave_idx_type m = q.rows (); + octave_idx_type n = r.columns (); + octave_idx_type k = q.columns (); - F77_XFCN (sqrinc, SQRINC, (m, n, k, q.fortran_vec (), r.data (), - r1.fortran_vec (), j+1, u.data ())); + Array<octave_idx_type> jsi; + Array<octave_idx_type> js = j.sort (jsi, ASCENDING); + octave_idx_type nj = js.length (); + bool dups = false; + for (octave_idx_type i = 0; i < nj - 1; i++) + dups = dups && js(i) == js(i+1); - r = r1; + if (dups) + (*current_liboctave_error_handler) ("qrinsert: duplicate index detected"); + else if (u.length () != m || u.columns () != nj) + (*current_liboctave_error_handler) ("qrinsert: dimensions mismatch"); + else if (nj > 0 && (js(0) < 0 || js(nj-1) > n)) + (*current_liboctave_error_handler) ("qrinsert: index out of range"); + else if (nj > 0) + { + octave_idx_type kmax = std::min (k + nj, m); + if (k < m) + { + q.resize (m, kmax); + r.resize (kmax, n + nj); + } + else + { + r.resize (k, n + nj); + } + + OCTAVE_LOCAL_BUFFER (float, w, kmax); + for (octave_idx_type i = 0; i < js.length (); i++) + { + FloatColumnVector utmp = u.column (jsi(i)); + F77_XFCN (sqrinc, SQRINC, (m, n + i, std::min (kmax, k + i), + q.fortran_vec (), q.rows (), + r.fortran_vec (), r.rows (), js(i) + 1, + utmp.data (), w)); + } } } @@ -203,41 +293,87 @@ octave_idx_type k = r.rows (); octave_idx_type n = r.columns (); - if (k < m && k < n) - (*current_liboctave_error_handler) ("QR delete dimensions mismatch"); - else if (j < 0 || j > n-1) - (*current_liboctave_error_handler) ("QR delete index out of range"); + if (j < 0 || j > n-1) + (*current_liboctave_error_handler) ("qrdelete: index out of range"); else { - FloatMatrix r1 (k, n-1); + OCTAVE_LOCAL_BUFFER (float, w, k); + F77_XFCN (sqrdec, SQRDEC, (m, n, k, q.fortran_vec (), q.rows (), + r.fortran_vec (), r.rows (), j + 1, w)); - F77_XFCN (sqrdec, SQRDEC, (m, n, k, q.fortran_vec (), r.data (), - r1.fortran_vec (), j+1)); - - r = r1; + if (k < m) + { + q.resize (m, k-1); + r.resize (k-1, n-1); + } + else + { + r.resize (k, n-1); + } } } void -FloatQR::insert_row (const FloatMatrix& u, octave_idx_type j) +FloatQR::delete_col (const Array<octave_idx_type>& j) +{ + octave_idx_type m = q.rows (); + octave_idx_type n = r.columns (); + octave_idx_type k = q.columns (); + + Array<octave_idx_type> jsi; + Array<octave_idx_type> js = j.sort (jsi, DESCENDING); + octave_idx_type nj = js.length (); + bool dups = false; + for (octave_idx_type i = 0; i < nj - 1; i++) + dups = dups && js(i) == js(i+1); + + if (dups) + (*current_liboctave_error_handler) ("qrinsert: duplicate index detected"); + else if (nj > 0 && (js(0) > n-1 || js(nj-1) < 0)) + (*current_liboctave_error_handler) ("qrinsert: index out of range"); + else if (nj > 0) + { + OCTAVE_LOCAL_BUFFER (float, w, k); + for (octave_idx_type i = 0; i < js.length (); i++) + { + F77_XFCN (sqrdec, SQRDEC, (m, n - i, k == m ? k : k - i, + q.fortran_vec (), q.rows (), + r.fortran_vec (), r.rows (), js(i) + 1, w)); + } + if (k < m) + { + q.resize (m, k - nj); + r.resize (k - nj, n - nj); + } + else + { + r.resize (k, n - nj); + } + + } +} + +void +FloatQR::insert_row (const FloatRowVector& u, octave_idx_type j) { octave_idx_type m = r.rows (); octave_idx_type n = r.columns (); + octave_idx_type k = std::min (m, n); if (! q.is_square () || u.length () != n) - (*current_liboctave_error_handler) ("QR insert dimensions mismatch"); + (*current_liboctave_error_handler) ("qrinsert: dimensions mismatch"); else if (j < 0 || j > m) - (*current_liboctave_error_handler) ("QR insert index out of range"); + (*current_liboctave_error_handler) ("qrinsert: index out of range"); else { - FloatMatrix q1 (m+1, m+1); - FloatMatrix r1 (m+1, n); + q.resize (m + 1, m + 1); + r.resize (m + 1, n); + FloatRowVector utmp = u; + OCTAVE_LOCAL_BUFFER (float, w, k); + F77_XFCN (sqrinr, SQRINR, (m, n, q.fortran_vec (), q.rows (), + r.fortran_vec (), r.rows (), + j + 1, utmp.fortran_vec (), w)); - F77_XFCN (sqrinr, SQRINR, (m, n, q.data (), q1.fortran_vec (), - r.data (), r1.fortran_vec (), j+1, u.data ())); - - q = q1; - r = r1; } } @@ -248,19 +384,18 @@ octave_idx_type n = r.columns (); if (! q.is_square ()) - (*current_liboctave_error_handler) ("QR delete dimensions mismatch"); + (*current_liboctave_error_handler) ("qrdelete: dimensions mismatch"); else if (j < 0 || j > m-1) - (*current_liboctave_error_handler) ("QR delete index out of range"); + (*current_liboctave_error_handler) ("qrdelete: index out of range"); else { - FloatMatrix q1 (m-1, m-1); - FloatMatrix r1 (m-1, n); + OCTAVE_LOCAL_BUFFER (float, w, 2*m); + F77_XFCN (sqrder, SQRDER, (m, n, q.fortran_vec (), q.rows (), + r.fortran_vec (), r.rows (), j + 1, + w)); - F77_XFCN (sqrder, SQRDER, (m, n, q.data (), q1.fortran_vec (), - r.data (), r1.fortran_vec (), j+1 )); - - q = q1; - r = r1; + q.resize (m - 1, m - 1); + r.resize (m - 1, n); } } @@ -272,22 +407,18 @@ octave_idx_type n = r.columns (); if (i < 0 || i > n-1 || j < 0 || j > n-1) - (*current_liboctave_error_handler) ("QR shift index out of range"); + (*current_liboctave_error_handler) ("qrshift: index out of range"); else - F77_XFCN (sqrshc, SQRSHC, (m, n, k, q.fortran_vec (), r.fortran_vec (), i+1, j+1)); + { + OCTAVE_LOCAL_BUFFER (float, w, 2*k); + F77_XFCN (sqrshc, SQRSHC, (m, n, k, + q.fortran_vec (), q.rows (), + r.fortran_vec (), r.rows (), + i + 1, j + 1, w)); + } } -void -FloatQR::economize (void) -{ - octave_idx_type r_nc = r.columns (); - - if (r.rows () > r_nc) - { - q.resize (q.rows (), r_nc); - r.resize (r_nc, r_nc); - } -} +#endif /*