Mercurial > hg > octave-lyh
diff liboctave/floatQRP.cc @ 8597:c86718093c1b
improve & fix QR classes
author | Jaroslav Hajek <highegg@gmail.com> |
---|---|
date | Tue, 27 Jan 2009 12:40:06 +0100 |
parents | e3c9102431a9 |
children | 20dfb885f877 |
line wrap: on
line diff
--- a/liboctave/floatQRP.cc +++ b/liboctave/floatQRP.cc @@ -30,6 +30,7 @@ #include "floatQRP.h" #include "f77-fcn.h" #include "lo-error.h" +#include "oct-locbuf.h" extern "C" { @@ -37,11 +38,6 @@ F77_FUNC (sgeqp3, SGEQP3) (const octave_idx_type&, const octave_idx_type&, float*, const octave_idx_type&, octave_idx_type*, float*, float*, const octave_idx_type&, octave_idx_type&); - - F77_RET_T - F77_FUNC (sorgqr, SORGQR) (const octave_idx_type&, const octave_idx_type&, const octave_idx_type&, - float*, const octave_idx_type&, float*, float*, - const octave_idx_type&, octave_idx_type&); } // It would be best to share some of this code with QR class... @@ -60,38 +56,32 @@ octave_idx_type m = a.rows (); octave_idx_type n = a.cols (); - if (m == 0 || n == 0) - { - (*current_liboctave_error_handler) ("QR must have non-empty matrix"); - return; - } - octave_idx_type min_mn = m < n ? m : n; - Array<float> tau (min_mn); - float *ptau = tau.fortran_vec (); + OCTAVE_LOCAL_BUFFER (float, tau, min_mn); octave_idx_type info = 0; - FloatMatrix A_fact = a; - if (m > n && qr_type != QR::economy) - A_fact.resize (m, m, 0.0); - - float *tmp_data = A_fact.fortran_vec (); + FloatMatrix afact = a; + if (m > n && qr_type == QR::std) + afact.resize (m, m); MArray<octave_idx_type> jpvt (n, 0); - octave_idx_type *pjpvt = jpvt.fortran_vec (); - float rlwork = 0; - // Workspace query... - F77_XFCN (sgeqp3, SGEQP3, (m, n, tmp_data, m, pjpvt, ptau, &rlwork, -1, info)); + if (m > 0) + { + // workspace query. + float rlwork; + F77_XFCN (sgeqp3, SGEQP3, (m, n, afact.fortran_vec (), m, jpvt.fortran_vec (), + tau, &rlwork, -1, info)); - octave_idx_type lwork = rlwork; - Array<float> work (lwork); - float *pwork = work.fortran_vec (); - - // Code to enforce a certain permutation could go here... - - F77_XFCN (sgeqp3, SGEQP3, (m, n, tmp_data, m, pjpvt, ptau, pwork, lwork, info)); + // allocate buffer and do the job. + octave_idx_type lwork = rlwork; lwork = std::max (lwork, 1); + OCTAVE_LOCAL_BUFFER (float, work, lwork); + F77_XFCN (sgeqp3, SGEQP3, (m, n, afact.fortran_vec (), m, jpvt.fortran_vec (), + tau, work, lwork, info)); + } + else + for (octave_idx_type i = 0; i < n; i++) jpvt(i) = i+1; // Form Permutation matrix (if economy is requested, return the // indices only!) @@ -99,25 +89,8 @@ jpvt -= 1; p = PermMatrix (jpvt, true); - octave_idx_type n2 = (qr_type == QR::economy) ? min_mn : m; - if (qr_type == QR::economy && m > n) - r.resize (n, n, 0.0); - else - r.resize (m, n, 0.0); - - for (octave_idx_type j = 0; j < n; j++) - { - octave_idx_type limit = j < min_mn-1 ? j : min_mn-1; - for (octave_idx_type i = 0; i <= limit; i++) - r.elem (i, j) = A_fact.elem (i, j); - } - - F77_XFCN (sorgqr, SORGQR, (m, n2, min_mn, tmp_data, m, ptau, - pwork, lwork, info)); - - q = A_fact; - q.resize (m, n2); + form (n, afact, tau, qr_type); } FloatColumnVector