Mercurial > hg > octave-nkf
diff liboctave/floatQR.cc @ 8597:c86718093c1b
improve & fix QR classes
author | Jaroslav Hajek <highegg@gmail.com> |
---|---|
date | Tue, 27 Jan 2009 12:40:06 +0100 |
parents | a6edd5c23cb5 |
children | e9cb742df9eb |
line wrap: on
line diff
--- a/liboctave/floatQR.cc +++ b/liboctave/floatQR.cc @@ -3,6 +3,7 @@ Copyright (C) 1994, 1995, 1996, 1997, 2002, 2003, 2004, 2005, 2007 John W. Eaton Copyright (C) 2008, 2009 Jaroslav Hajek +Copyright (C) 2009 VZLU Prague This file is part of Octave. @@ -91,29 +92,35 @@ octave_idx_type m = a.rows (); octave_idx_type n = a.cols (); - if (m == 0 || n == 0) - { - (*current_liboctave_error_handler) ("QR must have non-empty matrix"); - return; - } - octave_idx_type min_mn = m < n ? m : n; - Array<float> tau (min_mn); - float *ptau = tau.fortran_vec (); - - octave_idx_type lwork = 32*n; - Array<float> work (lwork); - float *pwork = work.fortran_vec (); + OCTAVE_LOCAL_BUFFER (float, tau, min_mn); octave_idx_type info = 0; - FloatMatrix A_fact = a; - if (m > n && qr_type != QR::economy) - A_fact.resize (m, m, 0.0); + FloatMatrix afact = a; + if (m > n && qr_type == QR::std) + afact.resize (m, m); + + if (m > 0) + { + // workspace query. + float rlwork; + F77_XFCN (sgeqrf, SGEQRF, (m, n, afact.fortran_vec (), m, tau, &rlwork, -1, info)); - float *tmp_data = A_fact.fortran_vec (); + // allocate buffer and do the job. + octave_idx_type lwork = rlwork; lwork = std::max (lwork, 1); + OCTAVE_LOCAL_BUFFER (float, work, lwork); + F77_XFCN (sgeqrf, SGEQRF, (m, n, afact.fortran_vec (), m, tau, work, lwork, info)); + } - F77_XFCN (sgeqrf, SGEQRF, (m, n, tmp_data, m, ptau, pwork, lwork, info)); + form (n, afact, tau, qr_type); +} + +void FloatQR::form (octave_idx_type n, FloatMatrix& afact, + float *tau, QR::type qr_type) +{ + octave_idx_type m = afact.rows (), min_mn = std::min (m, n); + octave_idx_type info; if (qr_type == QR::raw) { @@ -121,39 +128,58 @@ { octave_idx_type limit = j < min_mn - 1 ? j : min_mn - 1; for (octave_idx_type i = limit + 1; i < m; i++) - A_fact.elem (i, j) *= tau.elem (j); + afact.elem (i, j) *= tau[j]; } - r = A_fact; - - if (m > n) - r.resize (m, n); + r = afact; } else { - octave_idx_type n2 = (qr_type == QR::economy) ? min_mn : m; - - if (qr_type == QR::economy && m > n) - r.resize (n, n, 0.0); + // Attempt to minimize copying. + if (m >= n) + { + // afact will become q. + q = afact; + octave_idx_type k = qr_type == QR::economy ? n : m; + r = FloatMatrix (k, n); + for (octave_idx_type j = 0; j < n; j++) + { + octave_idx_type i = 0; + for (; i <= j; i++) + r.xelem (i, j) = afact.xelem (i, j); + for (;i < k; i++) + r.xelem (i, j) = 0; + } + afact = FloatMatrix (); // optimize memory + } else - r.resize (m, n, 0.0); + { + // afact will become r. + q = FloatMatrix (m, m); + for (octave_idx_type j = 0; j < m; j++) + for (octave_idx_type i = j + 1; i < m; i++) + { + q.xelem (i, j) = afact.xelem (i, j); + afact.xelem (i, j) = 0; + } + r = afact; + } - for (octave_idx_type j = 0; j < n; j++) - { - octave_idx_type limit = j < min_mn-1 ? j : min_mn-1; - for (octave_idx_type i = 0; i <= limit; i++) - r.elem (i, j) = tmp_data[m*j+i]; - } - lwork = 32 * n2; - work.resize (lwork); - float *pwork2 = work.fortran_vec (); + if (m > 0) + { + octave_idx_type k = q.columns (); + // workspace query. + float rlwork; + F77_XFCN (sorgqr, SORGQR, (m, k, min_mn, q.fortran_vec (), m, tau, + &rlwork, -1, info)); - F77_XFCN (sorgqr, SORGQR, (m, n2, min_mn, tmp_data, m, ptau, - pwork2, lwork, info)); - - q = A_fact; - q.resize (m, n2); + // allocate buffer and do the job. + octave_idx_type lwork = rlwork; lwork = std::max (lwork, 1); + OCTAVE_LOCAL_BUFFER (float, work, lwork); + F77_XFCN (sorgqr, SORGQR, (m, k, min_mn, q.fortran_vec (), m, tau, + work, lwork, info)); + } } }