Mercurial > hg > octave-lyh
diff liboctave/base-lu.cc @ 9694:50db3c5175b5
allow unpacked form of LU
author | Jaroslav Hajek <highegg@gmail.com> |
---|---|
date | Mon, 05 Oct 2009 15:39:44 +0200 |
parents | eb63fbe60fab |
children | 51c17bd18563 |
line wrap: on
line diff
--- a/liboctave/base-lu.cc +++ b/liboctave/base-lu.cc @@ -27,70 +27,121 @@ #include "base-lu.h" template <class lu_type> +base_lu<lu_type>::base_lu (const lu_type& l, const lu_type& u, + const PermMatrix& p) + : a_fact (u), l_fact (l), ipvt (p.pvec ()) +{ + if (l.columns () != u.rows ()) + (*current_liboctave_error_handler) ("lu: dimension mismatch"); +} + +template <class lu_type> +bool +base_lu <lu_type> :: packed (void) const +{ + return l_fact.dims () == dim_vector (); +} + +template <class lu_type> +void +base_lu <lu_type> :: unpack (void) +{ + if (packed ()) + { + l_fact = L (); + a_fact = U (); // FIXME: sub-optimal + } +} + +template <class lu_type> lu_type base_lu <lu_type> :: L (void) const { - octave_idx_type a_nr = a_fact.rows (); - octave_idx_type a_nc = a_fact.cols (); - octave_idx_type mn = (a_nr < a_nc ? a_nr : a_nc); - - lu_type l (a_nr, mn, lu_elt_type (0.0)); - - for (octave_idx_type i = 0; i < a_nr; i++) + if (packed ()) { - if (i < a_nc) - l.xelem (i, i) = 1.0; + octave_idx_type a_nr = a_fact.rows (); + octave_idx_type a_nc = a_fact.cols (); + octave_idx_type mn = (a_nr < a_nc ? a_nr : a_nc); + + lu_type l (a_nr, mn, lu_elt_type (0.0)); - for (octave_idx_type j = 0; j < (i < a_nc ? i : a_nc); j++) - l.xelem (i, j) = a_fact.xelem (i, j); + for (octave_idx_type i = 0; i < a_nr; i++) + { + if (i < a_nc) + l.xelem (i, i) = 1.0; + + for (octave_idx_type j = 0; j < (i < a_nc ? i : a_nc); j++) + l.xelem (i, j) = a_fact.xelem (i, j); + } + + return l; } - - return l; + else + return l_fact; } template <class lu_type> lu_type base_lu <lu_type> :: U (void) const { - octave_idx_type a_nr = a_fact.rows (); - octave_idx_type a_nc = a_fact.cols (); - octave_idx_type mn = (a_nr < a_nc ? a_nr : a_nc); + if (packed ()) + { + octave_idx_type a_nr = a_fact.rows (); + octave_idx_type a_nc = a_fact.cols (); + octave_idx_type mn = (a_nr < a_nc ? a_nr : a_nc); - lu_type u (mn, a_nc, lu_elt_type (0.0)); + lu_type u (mn, a_nc, lu_elt_type (0.0)); + + for (octave_idx_type i = 0; i < mn; i++) + { + for (octave_idx_type j = i; j < a_nc; j++) + u.xelem (i, j) = a_fact.xelem (i, j); + } - for (octave_idx_type i = 0; i < mn; i++) - { - for (octave_idx_type j = i; j < a_nc; j++) - u.xelem (i, j) = a_fact.xelem (i, j); + return u; } + else + return a_fact; +} - return u; +template <class lu_type> +lu_type +base_lu <lu_type> :: Y (void) const +{ + if (! packed ()) + (*current_liboctave_error_handler) ("lu: Y() not implemented for unpacked form."); + return a_fact; } template <class lu_type> Array<octave_idx_type> base_lu <lu_type> :: getp (void) const { - octave_idx_type a_nr = a_fact.rows (); - - Array<octave_idx_type> pvt (a_nr); - - for (octave_idx_type i = 0; i < a_nr; i++) - pvt.xelem (i) = i; - - for (octave_idx_type i = 0; i < ipvt.length(); i++) + if (packed ()) { - octave_idx_type k = ipvt.xelem (i); + octave_idx_type a_nr = a_fact.rows (); + + Array<octave_idx_type> pvt (a_nr); + + for (octave_idx_type i = 0; i < a_nr; i++) + pvt.xelem (i) = i; + + for (octave_idx_type i = 0; i < ipvt.length(); i++) + { + octave_idx_type k = ipvt.xelem (i); - if (k != i) - { - octave_idx_type tmp = pvt.xelem (k); - pvt.xelem (k) = pvt.xelem (i); - pvt.xelem (i) = tmp; - } + if (k != i) + { + octave_idx_type tmp = pvt.xelem (k); + pvt.xelem (k) = pvt.xelem (i); + pvt.xelem (i) = tmp; + } + } + + return pvt; } - - return pvt; + else + return ipvt; } template <class lu_type>