Mercurial > hg > octave-max
changeset 8597:c86718093c1b
improve & fix QR classes
author | Jaroslav Hajek <highegg@gmail.com> |
---|---|
date | Tue, 27 Jan 2009 12:40:06 +0100 |
parents | 8833c0b18eb2 |
children | 11cf7bc4a871 |
files | liboctave/ChangeLog liboctave/CmplxQR.cc liboctave/CmplxQR.h liboctave/CmplxQRP.cc liboctave/dbleQR.cc liboctave/dbleQR.h liboctave/dbleQRP.cc liboctave/fCmplxQR.cc liboctave/fCmplxQR.h liboctave/fCmplxQRP.cc liboctave/floatQR.cc liboctave/floatQR.h liboctave/floatQRP.cc src/ChangeLog src/DLD-FUNCTIONS/qr.cc |
diffstat | 15 files changed, 394 insertions(+), 376 deletions(-) [+] |
line wrap: on
line diff
--- a/liboctave/ChangeLog +++ b/liboctave/ChangeLog @@ -1,3 +1,33 @@ +2009-01-27 Jaroslav Hajek <highegg@gmail.com> + + * dbleQR.cc (QR::init): Use form. Use local buffers. + Query for optimal block size. + (QR::form): New function. + * dbleQR.h: Declare it. + * dbleQRP.cc (QRP::init):Use form. Use local buffers. + Query for optimal block size. + + * floatQR.cc (FloatQR::init): Use form. Use local buffers. + Query for optimal block size. + (FloatQR::form): New function. + * floatQR.h: Declare it. + * floatQRP.cc (FloatQRP::init):Use form. Use local buffers. + Query for optimal block size. + + * CmplxQR.cc (ComplexQR::init): Use form. Use local buffers. + Query for optimal block size. + (ComplexQR::form): New function. + * CmplxQR.h: Declare it. + * CmplxQRP.cc (ComplexQRP::init):Use form. Use local buffers. + Query for optimal block size. + + * fCmplxQR.cc (FloatComplexQR::init): Use form. Use local buffers. + Query for optimal block size. + (FloatComplexQR::form): New function. + * fCmplxQR.h: Declare it. + * fCmplxQRP.cc (FloatComplexQRP::init):Use form. Use local buffers. + Query for optimal block size. + 2009-01-23 Jaroslav Hajek <highegg@gmail.com> * Array.cc (Array<T>::assign (const idx_vector&, const Array<T>&)):
--- a/liboctave/CmplxQR.cc +++ b/liboctave/CmplxQR.cc @@ -3,6 +3,7 @@ Copyright (C) 1994, 1995, 1996, 1997, 2002, 2003, 2004, 2005, 2007 John W. Eaton Copyright (C) 2008, 2009 Jaroslav Hajek +Copyright (C) 2009 VZLU Prague This file is part of Octave. @@ -93,36 +94,35 @@ octave_idx_type m = a.rows (); octave_idx_type n = a.cols (); - if (m == 0 || n == 0) - { - (*current_liboctave_error_handler) - ("ComplexQR must have non-empty matrix"); - return; - } - octave_idx_type min_mn = m < n ? m : n; - - Array<Complex> tau (min_mn); - Complex *ptau = tau.fortran_vec (); - - octave_idx_type lwork = 32*n; - Array<Complex> work (lwork); - Complex *pwork = work.fortran_vec (); + OCTAVE_LOCAL_BUFFER (Complex, tau, min_mn); octave_idx_type info = 0; - ComplexMatrix A_fact; - if (m > n && qr_type != QR::economy) + ComplexMatrix afact = a; + if (m > n && qr_type == QR::std) + afact.resize (m, m); + + if (m > 0) { - A_fact.resize (m, m); - A_fact.insert (a, 0, 0); + // workspace query. + Complex clwork; + F77_XFCN (zgeqrf, ZGEQRF, (m, n, afact.fortran_vec (), m, tau, &clwork, -1, info)); + + // allocate buffer and do the job. + octave_idx_type lwork = clwork.real (); lwork = std::max (lwork, 1); + OCTAVE_LOCAL_BUFFER (Complex, work, lwork); + F77_XFCN (zgeqrf, ZGEQRF, (m, n, afact.fortran_vec (), m, tau, work, lwork, info)); } - else - A_fact = a; + + form (n, afact, tau, qr_type); +} - Complex *tmp_data = A_fact.fortran_vec (); - - F77_XFCN (zgeqrf, ZGEQRF, (m, n, tmp_data, m, ptau, pwork, lwork, info)); +void ComplexQR::form (octave_idx_type n, ComplexMatrix& afact, + Complex *tau, QR::type qr_type) +{ + octave_idx_type m = afact.rows (), min_mn = std::min (m, n); + octave_idx_type info; if (qr_type == QR::raw) { @@ -130,39 +130,58 @@ { octave_idx_type limit = j < min_mn - 1 ? j : min_mn - 1; for (octave_idx_type i = limit + 1; i < m; i++) - A_fact.elem (i, j) *= tau.elem (j); + afact.elem (i, j) *= tau[j]; } - r = A_fact; - - if (m > n) - r.resize (m, n); + r = afact; } else { - octave_idx_type n2 = (qr_type == QR::economy) ? min_mn : m; - - if (qr_type == QR::economy && m > n) - r.resize (n, n, 0.0); + // Attempt to minimize copying. + if (m >= n) + { + // afact will become q. + q = afact; + octave_idx_type k = qr_type == QR::economy ? n : m; + r = ComplexMatrix (k, n); + for (octave_idx_type j = 0; j < n; j++) + { + octave_idx_type i = 0; + for (; i <= j; i++) + r.xelem (i, j) = afact.xelem (i, j); + for (;i < k; i++) + r.xelem (i, j) = 0; + } + afact = ComplexMatrix (); // optimize memory + } else - r.resize (m, n, 0.0); + { + // afact will become r. + q = ComplexMatrix (m, m); + for (octave_idx_type j = 0; j < m; j++) + for (octave_idx_type i = j + 1; i < m; i++) + { + q.xelem (i, j) = afact.xelem (i, j); + afact.xelem (i, j) = 0; + } + r = afact; + } - for (octave_idx_type j = 0; j < n; j++) - { - octave_idx_type limit = j < min_mn-1 ? j : min_mn-1; - for (octave_idx_type i = 0; i <= limit; i++) - r.elem (i, j) = A_fact.elem (i, j); - } - lwork = 32 * n2; - work.resize (lwork); - Complex *pwork2 = work.fortran_vec (); + if (m > 0) + { + octave_idx_type k = q.columns (); + // workspace query. + Complex clwork; + F77_XFCN (zungqr, ZUNGQR, (m, k, min_mn, q.fortran_vec (), m, tau, + &clwork, -1, info)); - F77_XFCN (zungqr, ZUNGQR, (m, n2, min_mn, tmp_data, m, ptau, - pwork2, lwork, info)); - - q = A_fact; - q.resize (m, n2); + // allocate buffer and do the job. + octave_idx_type lwork = clwork.real (); lwork = std::max (lwork, 1); + OCTAVE_LOCAL_BUFFER (Complex, work, lwork); + F77_XFCN (zungqr, ZUNGQR, (m, k, min_mn, q.fortran_vec (), m, tau, + work, lwork, info)); + } } }
--- a/liboctave/CmplxQR.h +++ b/liboctave/CmplxQR.h @@ -89,6 +89,9 @@ protected: + void form (octave_idx_type n, ComplexMatrix& afact, + Complex *tau, QR::type qr_type); + ComplexMatrix q; ComplexMatrix r; };
--- a/liboctave/CmplxQRP.cc +++ b/liboctave/CmplxQRP.cc @@ -30,6 +30,7 @@ #include "CmplxQRP.h" #include "f77-fcn.h" #include "lo-error.h" +#include "oct-locbuf.h" extern "C" { @@ -37,11 +38,6 @@ F77_FUNC (zgeqp3, ZGEQP3) (const octave_idx_type&, const octave_idx_type&, Complex*, const octave_idx_type&, octave_idx_type*, Complex*, Complex*, const octave_idx_type&, double*, octave_idx_type&); - - F77_RET_T - F77_FUNC (zungqr, ZUNGQR) (const octave_idx_type&, const octave_idx_type&, const octave_idx_type&, - Complex*, const octave_idx_type&, Complex*, - Complex*, const octave_idx_type&, octave_idx_type&); } // It would be best to share some of this code with ComplexQR class... @@ -60,44 +56,34 @@ octave_idx_type m = a.rows (); octave_idx_type n = a.cols (); - if (m == 0 || n == 0) - { - (*current_liboctave_error_handler) - ("ComplexQR must have non-empty matrix"); - return; - } - octave_idx_type min_mn = m < n ? m : n; - Array<Complex> tau (min_mn); - Complex *ptau = tau.fortran_vec (); + OCTAVE_LOCAL_BUFFER (Complex, tau, min_mn); octave_idx_type info = 0; - ComplexMatrix A_fact = a; - if (m > n && qr_type != QR::economy) - A_fact.resize (m, m, 0.0); - - Complex *tmp_data = A_fact.fortran_vec (); - - Array<double> rwork (2*n); - double *prwork = rwork.fortran_vec (); + ComplexMatrix afact = a; + if (m > n && qr_type == QR::std) + afact.resize (m, m); MArray<octave_idx_type> jpvt (n, 0); - octave_idx_type *pjpvt = jpvt.fortran_vec (); - Complex rlwork = 0; - // Workspace query... - F77_XFCN (zgeqp3, ZGEQP3, (m, n, tmp_data, m, pjpvt, ptau, &rlwork, - -1, prwork, info)); + if (m > 0) + { + OCTAVE_LOCAL_BUFFER (double, rwork, 2*n); - octave_idx_type lwork = rlwork.real (); - Array<Complex> work (lwork); - Complex *pwork = work.fortran_vec (); + // workspace query. + Complex clwork; + F77_XFCN (zgeqp3, ZGEQP3, (m, n, afact.fortran_vec (), m, jpvt.fortran_vec (), + tau, &clwork, -1, rwork, info)); - // Code to enforce a certain permutation could go here... - - F77_XFCN (zgeqp3, ZGEQP3, (m, n, tmp_data, m, pjpvt, ptau, pwork, - lwork, prwork, info)); + // allocate buffer and do the job. + octave_idx_type lwork = clwork.real (); lwork = std::max (lwork, 1); + OCTAVE_LOCAL_BUFFER (Complex, work, lwork); + F77_XFCN (zgeqp3, ZGEQP3, (m, n, afact.fortran_vec (), m, jpvt.fortran_vec (), + tau, work, lwork, rwork, info)); + } + else + for (octave_idx_type i = 0; i < n; i++) jpvt(i) = i+1; // Form Permutation matrix (if economy is requested, return the // indices only!) @@ -105,25 +91,8 @@ jpvt -= 1; p = PermMatrix (jpvt, true); - octave_idx_type n2 = (qr_type == QR::economy) ? min_mn : m; - if (qr_type == QR::economy && m > n) - r.resize (n, n, 0.0); - else - r.resize (m, n, 0.0); - - for (octave_idx_type j = 0; j < n; j++) - { - octave_idx_type limit = j < min_mn-1 ? j : min_mn-1; - for (octave_idx_type i = 0; i <= limit; i++) - r.elem (i, j) = A_fact.elem (i, j); - } - - F77_XFCN (zungqr, ZUNGQR, (m, n2, min_mn, tmp_data, m, ptau, - pwork, lwork, info)); - - q = A_fact; - q.resize (m, n2); + form (n, afact, tau, qr_type); } ColumnVector
--- a/liboctave/dbleQR.cc +++ b/liboctave/dbleQR.cc @@ -3,6 +3,7 @@ Copyright (C) 1994, 1995, 1996, 1997, 2002, 2003, 2004, 2005, 2007 John W. Eaton Copyright (C) 2008, 2009 Jaroslav Hajek +Copyright (C) 2009 VZLU Prague This file is part of Octave. @@ -91,29 +92,35 @@ octave_idx_type m = a.rows (); octave_idx_type n = a.cols (); - if (m == 0 || n == 0) - { - (*current_liboctave_error_handler) ("QR must have non-empty matrix"); - return; - } - octave_idx_type min_mn = m < n ? m : n; - Array<double> tau (min_mn); - double *ptau = tau.fortran_vec (); - - octave_idx_type lwork = 32*n; - Array<double> work (lwork); - double *pwork = work.fortran_vec (); + OCTAVE_LOCAL_BUFFER (double, tau, min_mn); octave_idx_type info = 0; - Matrix A_fact = a; - if (m > n && qr_type != QR::economy) - A_fact.resize (m, m, 0.0); + Matrix afact = a; + if (m > n && qr_type == QR::std) + afact.resize (m, m); + + if (m > 0) + { + // workspace query. + double rlwork; + F77_XFCN (dgeqrf, DGEQRF, (m, n, afact.fortran_vec (), m, tau, &rlwork, -1, info)); - double *tmp_data = A_fact.fortran_vec (); + // allocate buffer and do the job. + octave_idx_type lwork = rlwork; lwork = std::max (lwork, 1); + OCTAVE_LOCAL_BUFFER (double, work, lwork); + F77_XFCN (dgeqrf, DGEQRF, (m, n, afact.fortran_vec (), m, tau, work, lwork, info)); + } - F77_XFCN (dgeqrf, DGEQRF, (m, n, tmp_data, m, ptau, pwork, lwork, info)); + form (n, afact, tau, qr_type); +} + +void QR::form (octave_idx_type n, Matrix& afact, + double *tau, QR::type qr_type) +{ + octave_idx_type m = afact.rows (), min_mn = std::min (m, n); + octave_idx_type info; if (qr_type == QR::raw) { @@ -121,39 +128,58 @@ { octave_idx_type limit = j < min_mn - 1 ? j : min_mn - 1; for (octave_idx_type i = limit + 1; i < m; i++) - A_fact.elem (i, j) *= tau.elem (j); + afact.elem (i, j) *= tau[j]; } - r = A_fact; - - if (m > n) - r.resize (m, n); + r = afact; } else { - octave_idx_type n2 = (qr_type == QR::economy) ? min_mn : m; - - if (qr_type == QR::economy && m > n) - r.resize (n, n, 0.0); + // Attempt to minimize copying. + if (m >= n) + { + // afact will become q. + q = afact; + octave_idx_type k = qr_type == QR::economy ? n : m; + r = Matrix (k, n); + for (octave_idx_type j = 0; j < n; j++) + { + octave_idx_type i = 0; + for (; i <= j; i++) + r.xelem (i, j) = afact.xelem (i, j); + for (;i < k; i++) + r.xelem (i, j) = 0; + } + afact = Matrix (); // optimize memory + } else - r.resize (m, n, 0.0); + { + // afact will become r. + q = Matrix (m, m); + for (octave_idx_type j = 0; j < m; j++) + for (octave_idx_type i = j + 1; i < m; i++) + { + q.xelem (i, j) = afact.xelem (i, j); + afact.xelem (i, j) = 0; + } + r = afact; + } - for (octave_idx_type j = 0; j < n; j++) - { - octave_idx_type limit = j < min_mn-1 ? j : min_mn-1; - for (octave_idx_type i = 0; i <= limit; i++) - r.elem (i, j) = tmp_data[m*j+i]; - } - lwork = 32 * n2; - work.resize (lwork); - double *pwork2 = work.fortran_vec (); + if (m > 0) + { + octave_idx_type k = q.columns (); + // workspace query. + double rlwork; + F77_XFCN (dorgqr, DORGQR, (m, k, min_mn, q.fortran_vec (), m, tau, + &rlwork, -1, info)); - F77_XFCN (dorgqr, DORGQR, (m, n2, min_mn, tmp_data, m, ptau, - pwork2, lwork, info)); - - q = A_fact; - q.resize (m, n2); + // allocate buffer and do the job. + octave_idx_type lwork = rlwork; lwork = std::max (lwork, 1); + OCTAVE_LOCAL_BUFFER (double, work, lwork); + F77_XFCN (dorgqr, DORGQR, (m, k, min_mn, q.fortran_vec (), m, tau, + work, lwork, info)); + } } }
--- a/liboctave/dbleQR.h +++ b/liboctave/dbleQR.h @@ -94,6 +94,9 @@ protected: + void form (octave_idx_type n, Matrix& afact, + double *tau, QR::type qr_type); + Matrix q; Matrix r; };
--- a/liboctave/dbleQRP.cc +++ b/liboctave/dbleQRP.cc @@ -30,6 +30,7 @@ #include "dbleQRP.h" #include "f77-fcn.h" #include "lo-error.h" +#include "oct-locbuf.h" extern "C" { @@ -37,11 +38,6 @@ F77_FUNC (dgeqp3, DGEQP3) (const octave_idx_type&, const octave_idx_type&, double*, const octave_idx_type&, octave_idx_type*, double*, double*, const octave_idx_type&, octave_idx_type&); - - F77_RET_T - F77_FUNC (dorgqr, DORGQR) (const octave_idx_type&, const octave_idx_type&, const octave_idx_type&, - double*, const octave_idx_type&, double*, double*, - const octave_idx_type&, octave_idx_type&); } // It would be best to share some of this code with QR class... @@ -60,38 +56,32 @@ octave_idx_type m = a.rows (); octave_idx_type n = a.cols (); - if (m == 0 || n == 0) - { - (*current_liboctave_error_handler) ("QR must have non-empty matrix"); - return; - } - octave_idx_type min_mn = m < n ? m : n; - Array<double> tau (min_mn); - double *ptau = tau.fortran_vec (); + OCTAVE_LOCAL_BUFFER (double, tau, min_mn); octave_idx_type info = 0; - Matrix A_fact = a; - if (m > n && qr_type != QR::economy) - A_fact.resize (m, m, 0.0); - - double *tmp_data = A_fact.fortran_vec (); + Matrix afact = a; + if (m > n && qr_type == QR::std) + afact.resize (m, m); MArray<octave_idx_type> jpvt (n, 0); - octave_idx_type *pjpvt = jpvt.fortran_vec (); - double rlwork = 0; - // Workspace query... - F77_XFCN (dgeqp3, DGEQP3, (m, n, tmp_data, m, pjpvt, ptau, &rlwork, -1, info)); + if (m > 0) + { + // workspace query. + double rlwork; + F77_XFCN (dgeqp3, DGEQP3, (m, n, afact.fortran_vec (), m, jpvt.fortran_vec (), + tau, &rlwork, -1, info)); - octave_idx_type lwork = rlwork; - Array<double> work (lwork); - double *pwork = work.fortran_vec (); - - // Code to enforce a certain permutation could go here... - - F77_XFCN (dgeqp3, DGEQP3, (m, n, tmp_data, m, pjpvt, ptau, pwork, lwork, info)); + // allocate buffer and do the job. + octave_idx_type lwork = rlwork; lwork = std::max (lwork, 1); + OCTAVE_LOCAL_BUFFER (double, work, lwork); + F77_XFCN (dgeqp3, DGEQP3, (m, n, afact.fortran_vec (), m, jpvt.fortran_vec (), + tau, work, lwork, info)); + } + else + for (octave_idx_type i = 0; i < n; i++) jpvt(i) = i+1; // Form Permutation matrix (if economy is requested, return the // indices only!) @@ -99,25 +89,8 @@ jpvt -= 1; p = PermMatrix (jpvt, true); - octave_idx_type n2 = (qr_type == QR::economy) ? min_mn : m; - if (qr_type == QR::economy && m > n) - r.resize (n, n, 0.0); - else - r.resize (m, n, 0.0); - - for (octave_idx_type j = 0; j < n; j++) - { - octave_idx_type limit = j < min_mn-1 ? j : min_mn-1; - for (octave_idx_type i = 0; i <= limit; i++) - r.elem (i, j) = A_fact.elem (i, j); - } - - F77_XFCN (dorgqr, DORGQR, (m, n2, min_mn, tmp_data, m, ptau, - pwork, lwork, info)); - - q = A_fact; - q.resize (m, n2); + form (n, afact, tau, qr_type); } ColumnVector
--- a/liboctave/fCmplxQR.cc +++ b/liboctave/fCmplxQR.cc @@ -3,6 +3,7 @@ Copyright (C) 1994, 1995, 1996, 1997, 2002, 2003, 2004, 2005, 2007 John W. Eaton Copyright (C) 2008, 2009 Jaroslav Hajek +Copyright (C) 2009 VZLU Prague This file is part of Octave. @@ -93,36 +94,35 @@ octave_idx_type m = a.rows (); octave_idx_type n = a.cols (); - if (m == 0 || n == 0) - { - (*current_liboctave_error_handler) - ("FloatComplexQR must have non-empty matrix"); - return; - } - octave_idx_type min_mn = m < n ? m : n; - - Array<FloatComplex> tau (min_mn); - FloatComplex *ptau = tau.fortran_vec (); - - octave_idx_type lwork = 32*n; - Array<FloatComplex> work (lwork); - FloatComplex *pwork = work.fortran_vec (); + OCTAVE_LOCAL_BUFFER (FloatComplex, tau, min_mn); octave_idx_type info = 0; - FloatComplexMatrix A_fact; - if (m > n && qr_type != QR::economy) + FloatComplexMatrix afact = a; + if (m > n && qr_type == QR::std) + afact.resize (m, m); + + if (m > 0) { - A_fact.resize (m, m); - A_fact.insert (a, 0, 0); + // workspace query. + FloatComplex clwork; + F77_XFCN (cgeqrf, CGEQRF, (m, n, afact.fortran_vec (), m, tau, &clwork, -1, info)); + + // allocate buffer and do the job. + octave_idx_type lwork = clwork.real (); lwork = std::max (lwork, 1); + OCTAVE_LOCAL_BUFFER (FloatComplex, work, lwork); + F77_XFCN (cgeqrf, CGEQRF, (m, n, afact.fortran_vec (), m, tau, work, lwork, info)); } - else - A_fact = a; + + form (n, afact, tau, qr_type); +} - FloatComplex *tmp_data = A_fact.fortran_vec (); - - F77_XFCN (cgeqrf, CGEQRF, (m, n, tmp_data, m, ptau, pwork, lwork, info)); +void FloatComplexQR::form (octave_idx_type n, FloatComplexMatrix& afact, + FloatComplex *tau, QR::type qr_type) +{ + octave_idx_type m = afact.rows (), min_mn = std::min (m, n); + octave_idx_type info; if (qr_type == QR::raw) { @@ -130,39 +130,58 @@ { octave_idx_type limit = j < min_mn - 1 ? j : min_mn - 1; for (octave_idx_type i = limit + 1; i < m; i++) - A_fact.elem (i, j) *= tau.elem (j); + afact.elem (i, j) *= tau[j]; } - r = A_fact; - - if (m > n) - r.resize (m, n); + r = afact; } else { - octave_idx_type n2 = (qr_type == QR::economy) ? min_mn : m; - - if (qr_type == QR::economy && m > n) - r.resize (n, n, 0.0); + // Attempt to minimize copying. + if (m >= n) + { + // afact will become q. + q = afact; + octave_idx_type k = qr_type == QR::economy ? n : m; + r = FloatComplexMatrix (k, n); + for (octave_idx_type j = 0; j < n; j++) + { + octave_idx_type i = 0; + for (; i <= j; i++) + r.xelem (i, j) = afact.xelem (i, j); + for (;i < k; i++) + r.xelem (i, j) = 0; + } + afact = FloatComplexMatrix (); // optimize memory + } else - r.resize (m, n, 0.0); + { + // afact will become r. + q = FloatComplexMatrix (m, m); + for (octave_idx_type j = 0; j < m; j++) + for (octave_idx_type i = j + 1; i < m; i++) + { + q.xelem (i, j) = afact.xelem (i, j); + afact.xelem (i, j) = 0; + } + r = afact; + } - for (octave_idx_type j = 0; j < n; j++) - { - octave_idx_type limit = j < min_mn-1 ? j : min_mn-1; - for (octave_idx_type i = 0; i <= limit; i++) - r.elem (i, j) = A_fact.elem (i, j); - } - lwork = 32 * n2; - work.resize (lwork); - FloatComplex *pwork2 = work.fortran_vec (); + if (m > 0) + { + octave_idx_type k = q.columns (); + // workspace query. + FloatComplex clwork; + F77_XFCN (cungqr, CUNGQR, (m, k, min_mn, q.fortran_vec (), m, tau, + &clwork, -1, info)); - F77_XFCN (cungqr, CUNGQR, (m, n2, min_mn, tmp_data, m, ptau, - pwork2, lwork, info)); - - q = A_fact; - q.resize (m, n2); + // allocate buffer and do the job. + octave_idx_type lwork = clwork.real (); lwork = std::max (lwork, 1); + OCTAVE_LOCAL_BUFFER (FloatComplex, work, lwork); + F77_XFCN (cungqr, CUNGQR, (m, k, min_mn, q.fortran_vec (), m, tau, + work, lwork, info)); + } } }
--- a/liboctave/fCmplxQR.h +++ b/liboctave/fCmplxQR.h @@ -90,6 +90,9 @@ protected: + void form (octave_idx_type n, FloatComplexMatrix& afact, + FloatComplex *tau, QR::type qr_type); + FloatComplexMatrix q; FloatComplexMatrix r; };
--- a/liboctave/fCmplxQRP.cc +++ b/liboctave/fCmplxQRP.cc @@ -30,6 +30,7 @@ #include "fCmplxQRP.h" #include "f77-fcn.h" #include "lo-error.h" +#include "oct-locbuf.h" extern "C" { @@ -37,11 +38,6 @@ F77_FUNC (cgeqp3, CGEQP3) (const octave_idx_type&, const octave_idx_type&, FloatComplex*, const octave_idx_type&, octave_idx_type*, FloatComplex*, FloatComplex*, const octave_idx_type&, float*, octave_idx_type&); - - F77_RET_T - F77_FUNC (cungqr, CUNGQR) (const octave_idx_type&, const octave_idx_type&, const octave_idx_type&, - FloatComplex*, const octave_idx_type&, FloatComplex*, - FloatComplex*, const octave_idx_type&, octave_idx_type&); } // It would be best to share some of this code with FloatComplexQR class... @@ -60,44 +56,34 @@ octave_idx_type m = a.rows (); octave_idx_type n = a.cols (); - if (m == 0 || n == 0) - { - (*current_liboctave_error_handler) - ("FloatComplexQR must have non-empty matrix"); - return; - } - octave_idx_type min_mn = m < n ? m : n; - Array<FloatComplex> tau (min_mn); - FloatComplex *ptau = tau.fortran_vec (); + OCTAVE_LOCAL_BUFFER (FloatComplex, tau, min_mn); octave_idx_type info = 0; - FloatComplexMatrix A_fact = a; - if (m > n && qr_type != QR::economy) - A_fact.resize (m, m, 0.0); - - FloatComplex *tmp_data = A_fact.fortran_vec (); - - Array<float> rwork (2*n); - float *prwork = rwork.fortran_vec (); + FloatComplexMatrix afact = a; + if (m > n && qr_type == QR::std) + afact.resize (m, m); MArray<octave_idx_type> jpvt (n, 0); - octave_idx_type *pjpvt = jpvt.fortran_vec (); - FloatComplex rlwork = 0; - // Workspace query... - F77_XFCN (cgeqp3, CGEQP3, (m, n, tmp_data, m, pjpvt, ptau, &rlwork, - -1, prwork, info)); + if (m > 0) + { + OCTAVE_LOCAL_BUFFER (float, rwork, 2*n); - octave_idx_type lwork = rlwork.real (); - Array<FloatComplex> work (lwork); - FloatComplex *pwork = work.fortran_vec (); + // workspace query. + FloatComplex clwork; + F77_XFCN (cgeqp3, CGEQP3, (m, n, afact.fortran_vec (), m, jpvt.fortran_vec (), + tau, &clwork, -1, rwork, info)); - // Code to enforce a certain permutation could go here... - - F77_XFCN (cgeqp3, CGEQP3, (m, n, tmp_data, m, pjpvt, ptau, pwork, - lwork, prwork, info)); + // allocate buffer and do the job. + octave_idx_type lwork = clwork.real (); lwork = std::max (lwork, 1); + OCTAVE_LOCAL_BUFFER (FloatComplex, work, lwork); + F77_XFCN (cgeqp3, CGEQP3, (m, n, afact.fortran_vec (), m, jpvt.fortran_vec (), + tau, work, lwork, rwork, info)); + } + else + for (octave_idx_type i = 0; i < n; i++) jpvt(i) = i+1; // Form Permutation matrix (if economy is requested, return the // indices only!) @@ -105,25 +91,8 @@ jpvt -= 1; p = PermMatrix (jpvt, true); - octave_idx_type n2 = (qr_type == QR::economy) ? min_mn : m; - if (qr_type == QR::economy && m > n) - r.resize (n, n, 0.0); - else - r.resize (m, n, 0.0); - - for (octave_idx_type j = 0; j < n; j++) - { - octave_idx_type limit = j < min_mn-1 ? j : min_mn-1; - for (octave_idx_type i = 0; i <= limit; i++) - r.elem (i, j) = A_fact.elem (i, j); - } - - F77_XFCN (cungqr, CUNGQR, (m, n2, min_mn, tmp_data, m, ptau, - pwork, lwork, info)); - - q = A_fact; - q.resize (m, n2); + form (n, afact, tau, qr_type); } FloatColumnVector
--- a/liboctave/floatQR.cc +++ b/liboctave/floatQR.cc @@ -3,6 +3,7 @@ Copyright (C) 1994, 1995, 1996, 1997, 2002, 2003, 2004, 2005, 2007 John W. Eaton Copyright (C) 2008, 2009 Jaroslav Hajek +Copyright (C) 2009 VZLU Prague This file is part of Octave. @@ -91,29 +92,35 @@ octave_idx_type m = a.rows (); octave_idx_type n = a.cols (); - if (m == 0 || n == 0) - { - (*current_liboctave_error_handler) ("QR must have non-empty matrix"); - return; - } - octave_idx_type min_mn = m < n ? m : n; - Array<float> tau (min_mn); - float *ptau = tau.fortran_vec (); - - octave_idx_type lwork = 32*n; - Array<float> work (lwork); - float *pwork = work.fortran_vec (); + OCTAVE_LOCAL_BUFFER (float, tau, min_mn); octave_idx_type info = 0; - FloatMatrix A_fact = a; - if (m > n && qr_type != QR::economy) - A_fact.resize (m, m, 0.0); + FloatMatrix afact = a; + if (m > n && qr_type == QR::std) + afact.resize (m, m); + + if (m > 0) + { + // workspace query. + float rlwork; + F77_XFCN (sgeqrf, SGEQRF, (m, n, afact.fortran_vec (), m, tau, &rlwork, -1, info)); - float *tmp_data = A_fact.fortran_vec (); + // allocate buffer and do the job. + octave_idx_type lwork = rlwork; lwork = std::max (lwork, 1); + OCTAVE_LOCAL_BUFFER (float, work, lwork); + F77_XFCN (sgeqrf, SGEQRF, (m, n, afact.fortran_vec (), m, tau, work, lwork, info)); + } - F77_XFCN (sgeqrf, SGEQRF, (m, n, tmp_data, m, ptau, pwork, lwork, info)); + form (n, afact, tau, qr_type); +} + +void FloatQR::form (octave_idx_type n, FloatMatrix& afact, + float *tau, QR::type qr_type) +{ + octave_idx_type m = afact.rows (), min_mn = std::min (m, n); + octave_idx_type info; if (qr_type == QR::raw) { @@ -121,39 +128,58 @@ { octave_idx_type limit = j < min_mn - 1 ? j : min_mn - 1; for (octave_idx_type i = limit + 1; i < m; i++) - A_fact.elem (i, j) *= tau.elem (j); + afact.elem (i, j) *= tau[j]; } - r = A_fact; - - if (m > n) - r.resize (m, n); + r = afact; } else { - octave_idx_type n2 = (qr_type == QR::economy) ? min_mn : m; - - if (qr_type == QR::economy && m > n) - r.resize (n, n, 0.0); + // Attempt to minimize copying. + if (m >= n) + { + // afact will become q. + q = afact; + octave_idx_type k = qr_type == QR::economy ? n : m; + r = FloatMatrix (k, n); + for (octave_idx_type j = 0; j < n; j++) + { + octave_idx_type i = 0; + for (; i <= j; i++) + r.xelem (i, j) = afact.xelem (i, j); + for (;i < k; i++) + r.xelem (i, j) = 0; + } + afact = FloatMatrix (); // optimize memory + } else - r.resize (m, n, 0.0); + { + // afact will become r. + q = FloatMatrix (m, m); + for (octave_idx_type j = 0; j < m; j++) + for (octave_idx_type i = j + 1; i < m; i++) + { + q.xelem (i, j) = afact.xelem (i, j); + afact.xelem (i, j) = 0; + } + r = afact; + } - for (octave_idx_type j = 0; j < n; j++) - { - octave_idx_type limit = j < min_mn-1 ? j : min_mn-1; - for (octave_idx_type i = 0; i <= limit; i++) - r.elem (i, j) = tmp_data[m*j+i]; - } - lwork = 32 * n2; - work.resize (lwork); - float *pwork2 = work.fortran_vec (); + if (m > 0) + { + octave_idx_type k = q.columns (); + // workspace query. + float rlwork; + F77_XFCN (sorgqr, SORGQR, (m, k, min_mn, q.fortran_vec (), m, tau, + &rlwork, -1, info)); - F77_XFCN (sorgqr, SORGQR, (m, n2, min_mn, tmp_data, m, ptau, - pwork2, lwork, info)); - - q = A_fact; - q.resize (m, n2); + // allocate buffer and do the job. + octave_idx_type lwork = rlwork; lwork = std::max (lwork, 1); + OCTAVE_LOCAL_BUFFER (float, work, lwork); + F77_XFCN (sorgqr, SORGQR, (m, k, min_mn, q.fortran_vec (), m, tau, + work, lwork, info)); + } } }
--- a/liboctave/floatQR.h +++ b/liboctave/floatQR.h @@ -88,6 +88,9 @@ protected: + void form (octave_idx_type n, FloatMatrix& afact, + float *tau, QR::type qr_type); + FloatMatrix q; FloatMatrix r; };
--- a/liboctave/floatQRP.cc +++ b/liboctave/floatQRP.cc @@ -30,6 +30,7 @@ #include "floatQRP.h" #include "f77-fcn.h" #include "lo-error.h" +#include "oct-locbuf.h" extern "C" { @@ -37,11 +38,6 @@ F77_FUNC (sgeqp3, SGEQP3) (const octave_idx_type&, const octave_idx_type&, float*, const octave_idx_type&, octave_idx_type*, float*, float*, const octave_idx_type&, octave_idx_type&); - - F77_RET_T - F77_FUNC (sorgqr, SORGQR) (const octave_idx_type&, const octave_idx_type&, const octave_idx_type&, - float*, const octave_idx_type&, float*, float*, - const octave_idx_type&, octave_idx_type&); } // It would be best to share some of this code with QR class... @@ -60,38 +56,32 @@ octave_idx_type m = a.rows (); octave_idx_type n = a.cols (); - if (m == 0 || n == 0) - { - (*current_liboctave_error_handler) ("QR must have non-empty matrix"); - return; - } - octave_idx_type min_mn = m < n ? m : n; - Array<float> tau (min_mn); - float *ptau = tau.fortran_vec (); + OCTAVE_LOCAL_BUFFER (float, tau, min_mn); octave_idx_type info = 0; - FloatMatrix A_fact = a; - if (m > n && qr_type != QR::economy) - A_fact.resize (m, m, 0.0); - - float *tmp_data = A_fact.fortran_vec (); + FloatMatrix afact = a; + if (m > n && qr_type == QR::std) + afact.resize (m, m); MArray<octave_idx_type> jpvt (n, 0); - octave_idx_type *pjpvt = jpvt.fortran_vec (); - float rlwork = 0; - // Workspace query... - F77_XFCN (sgeqp3, SGEQP3, (m, n, tmp_data, m, pjpvt, ptau, &rlwork, -1, info)); + if (m > 0) + { + // workspace query. + float rlwork; + F77_XFCN (sgeqp3, SGEQP3, (m, n, afact.fortran_vec (), m, jpvt.fortran_vec (), + tau, &rlwork, -1, info)); - octave_idx_type lwork = rlwork; - Array<float> work (lwork); - float *pwork = work.fortran_vec (); - - // Code to enforce a certain permutation could go here... - - F77_XFCN (sgeqp3, SGEQP3, (m, n, tmp_data, m, pjpvt, ptau, pwork, lwork, info)); + // allocate buffer and do the job. + octave_idx_type lwork = rlwork; lwork = std::max (lwork, 1); + OCTAVE_LOCAL_BUFFER (float, work, lwork); + F77_XFCN (sgeqp3, SGEQP3, (m, n, afact.fortran_vec (), m, jpvt.fortran_vec (), + tau, work, lwork, info)); + } + else + for (octave_idx_type i = 0; i < n; i++) jpvt(i) = i+1; // Form Permutation matrix (if economy is requested, return the // indices only!) @@ -99,25 +89,8 @@ jpvt -= 1; p = PermMatrix (jpvt, true); - octave_idx_type n2 = (qr_type == QR::economy) ? min_mn : m; - if (qr_type == QR::economy && m > n) - r.resize (n, n, 0.0); - else - r.resize (m, n, 0.0); - - for (octave_idx_type j = 0; j < n; j++) - { - octave_idx_type limit = j < min_mn-1 ? j : min_mn-1; - for (octave_idx_type i = 0; i <= limit; i++) - r.elem (i, j) = A_fact.elem (i, j); - } - - F77_XFCN (sorgqr, SORGQR, (m, n2, min_mn, tmp_data, m, ptau, - pwork, lwork, info)); - - q = A_fact; - q.resize (m, n2); + form (n, afact, tau, qr_type); } FloatColumnVector