Mercurial > hg > octave-lyh
changeset 9694:50db3c5175b5
allow unpacked form of LU
author | Jaroslav Hajek <highegg@gmail.com> |
---|---|
date | Mon, 05 Oct 2009 15:39:44 +0200 |
parents | 1c19877799d3 |
children | 9fba7e1da785 |
files | liboctave/ChangeLog liboctave/CmplxLU.cc liboctave/base-lu.cc liboctave/base-lu.h liboctave/dbleLU.cc liboctave/dim-vector.h liboctave/fCmplxLU.cc liboctave/floatLU.cc |
diffstat | 8 files changed, 127 insertions(+), 49 deletions(-) [+] |
line wrap: on
line diff
--- a/liboctave/ChangeLog +++ b/liboctave/ChangeLog @@ -1,3 +1,12 @@ +2009-10-05 Jaroslav Hajek <highegg@gmail.com> + + * dim-vector.h (operator ==): Include fast case. + * base-lu.cc (base_lu::packed, base_lu::unpack): New methods. + (base_lu::L, base_lu::U, base_lu::Y, base_lu::getp): Distinguish + packed vs. unpacked case. + * base-lu.h: Update decls. + (base_lu::l_fact): New member field. + 2009-10-02 Jaroslav Hajek <highegg@gmail.com> * lo-traits.h (strip_template_param): New trait class.
--- a/liboctave/CmplxLU.cc +++ b/liboctave/CmplxLU.cc @@ -61,7 +61,8 @@ F77_XFCN (zgetrf, ZGETRF, (a_nr, a_nc, tmp_data, a_nr, pipvt, info)); - ipvt -= static_cast<octave_idx_type> (1); + for (octave_idx_type i = 0; i < mn; i++) + pipvt[i] -= 1; } /*
--- a/liboctave/base-lu.cc +++ b/liboctave/base-lu.cc @@ -27,70 +27,121 @@ #include "base-lu.h" template <class lu_type> +base_lu<lu_type>::base_lu (const lu_type& l, const lu_type& u, + const PermMatrix& p) + : a_fact (u), l_fact (l), ipvt (p.pvec ()) +{ + if (l.columns () != u.rows ()) + (*current_liboctave_error_handler) ("lu: dimension mismatch"); +} + +template <class lu_type> +bool +base_lu <lu_type> :: packed (void) const +{ + return l_fact.dims () == dim_vector (); +} + +template <class lu_type> +void +base_lu <lu_type> :: unpack (void) +{ + if (packed ()) + { + l_fact = L (); + a_fact = U (); // FIXME: sub-optimal + } +} + +template <class lu_type> lu_type base_lu <lu_type> :: L (void) const { - octave_idx_type a_nr = a_fact.rows (); - octave_idx_type a_nc = a_fact.cols (); - octave_idx_type mn = (a_nr < a_nc ? a_nr : a_nc); - - lu_type l (a_nr, mn, lu_elt_type (0.0)); - - for (octave_idx_type i = 0; i < a_nr; i++) + if (packed ()) { - if (i < a_nc) - l.xelem (i, i) = 1.0; + octave_idx_type a_nr = a_fact.rows (); + octave_idx_type a_nc = a_fact.cols (); + octave_idx_type mn = (a_nr < a_nc ? a_nr : a_nc); + + lu_type l (a_nr, mn, lu_elt_type (0.0)); - for (octave_idx_type j = 0; j < (i < a_nc ? i : a_nc); j++) - l.xelem (i, j) = a_fact.xelem (i, j); + for (octave_idx_type i = 0; i < a_nr; i++) + { + if (i < a_nc) + l.xelem (i, i) = 1.0; + + for (octave_idx_type j = 0; j < (i < a_nc ? i : a_nc); j++) + l.xelem (i, j) = a_fact.xelem (i, j); + } + + return l; } - - return l; + else + return l_fact; } template <class lu_type> lu_type base_lu <lu_type> :: U (void) const { - octave_idx_type a_nr = a_fact.rows (); - octave_idx_type a_nc = a_fact.cols (); - octave_idx_type mn = (a_nr < a_nc ? a_nr : a_nc); + if (packed ()) + { + octave_idx_type a_nr = a_fact.rows (); + octave_idx_type a_nc = a_fact.cols (); + octave_idx_type mn = (a_nr < a_nc ? a_nr : a_nc); - lu_type u (mn, a_nc, lu_elt_type (0.0)); + lu_type u (mn, a_nc, lu_elt_type (0.0)); + + for (octave_idx_type i = 0; i < mn; i++) + { + for (octave_idx_type j = i; j < a_nc; j++) + u.xelem (i, j) = a_fact.xelem (i, j); + } - for (octave_idx_type i = 0; i < mn; i++) - { - for (octave_idx_type j = i; j < a_nc; j++) - u.xelem (i, j) = a_fact.xelem (i, j); + return u; } + else + return a_fact; +} - return u; +template <class lu_type> +lu_type +base_lu <lu_type> :: Y (void) const +{ + if (! packed ()) + (*current_liboctave_error_handler) ("lu: Y() not implemented for unpacked form."); + return a_fact; } template <class lu_type> Array<octave_idx_type> base_lu <lu_type> :: getp (void) const { - octave_idx_type a_nr = a_fact.rows (); - - Array<octave_idx_type> pvt (a_nr); - - for (octave_idx_type i = 0; i < a_nr; i++) - pvt.xelem (i) = i; - - for (octave_idx_type i = 0; i < ipvt.length(); i++) + if (packed ()) { - octave_idx_type k = ipvt.xelem (i); + octave_idx_type a_nr = a_fact.rows (); + + Array<octave_idx_type> pvt (a_nr); + + for (octave_idx_type i = 0; i < a_nr; i++) + pvt.xelem (i) = i; + + for (octave_idx_type i = 0; i < ipvt.length(); i++) + { + octave_idx_type k = ipvt.xelem (i); - if (k != i) - { - octave_idx_type tmp = pvt.xelem (k); - pvt.xelem (k) = pvt.xelem (i); - pvt.xelem (i) = tmp; - } + if (k != i) + { + octave_idx_type tmp = pvt.xelem (k); + pvt.xelem (k) = pvt.xelem (i); + pvt.xelem (i) = tmp; + } + } + + return pvt; } - - return pvt; + else + return ipvt; } template <class lu_type>
--- a/liboctave/base-lu.h +++ b/liboctave/base-lu.h @@ -37,13 +37,18 @@ base_lu (void) { } - base_lu (const base_lu& a) : a_fact (a.a_fact), ipvt (a.ipvt) { } + base_lu (const base_lu& a) : + a_fact (a.a_fact), l_fact (a.l_fact), ipvt (a.ipvt) { } + + base_lu (const lu_type& l, const lu_type& u, + const PermMatrix& p); base_lu& operator = (const base_lu& a) { if (this != &a) { a_fact = a.a_fact; + l_fact = a.l_fact; ipvt = a.ipvt; } return *this; @@ -51,11 +56,15 @@ ~base_lu (void) { } + bool packed (void) const; + + void unpack (void); + lu_type L (void) const; lu_type U (void) const; - lu_type Y (void) const { return a_fact; } + lu_type Y (void) const; PermMatrix P (void) const; @@ -64,8 +73,8 @@ protected: Array<octave_idx_type> getp (void) const; - lu_type a_fact; - MArray<octave_idx_type> ipvt; + lu_type a_fact, l_fact; + Array<octave_idx_type> ipvt; }; #endif
--- a/liboctave/dbleLU.cc +++ b/liboctave/dbleLU.cc @@ -61,7 +61,8 @@ F77_XFCN (dgetrf, DGETRF, (a_nr, a_nc, tmp_data, a_nr, pipvt, info)); - ipvt -= static_cast<octave_idx_type> (1); + for (octave_idx_type i = 0; i < mn; i++) + pipvt[i] -= 1; } /*
--- a/liboctave/dim-vector.h +++ b/liboctave/dim-vector.h @@ -516,11 +516,16 @@ return def; } + friend bool operator == (const dim_vector& a, const dim_vector& b); }; -static inline bool +inline bool operator == (const dim_vector& a, const dim_vector& b) { + // Fast case. + if (a.rep == b.rep) + return true; + bool retval = true; int a_len = a.length (); @@ -543,7 +548,7 @@ return retval; } -static inline bool +inline bool operator != (const dim_vector& a, const dim_vector& b) { return ! operator == (a, b);