# HG changeset patch # User Jaroslav Hajek # Date 1255010753 -7200 # Node ID 6f3ffe11d926840a83ffa13d039a19224bb31927 # Parent 6f5c4c82c5fca83880900734bd2e9d0d3178a187 implement luupdate diff --git a/ChangeLog b/ChangeLog --- a/ChangeLog +++ b/ChangeLog @@ -1,3 +1,7 @@ +2009-10-08 Jaroslav Hajek + + * configure.ac: Define HAVE_QRUPDATE_LUU if qrupdate supports LU updates. + 2009-10-08 Jaroslav Hajek * configure.ac: Use the LANG argument to OCTAVE_CHECK_LIBRARY to avoid diff --git a/configure.ac b/configure.ac --- a/configure.ac +++ b/configure.ac @@ -905,6 +905,20 @@ [], [sqr1up], [Fortran 77], [don't use qrupdate, disable QR & Cholesky updating functions]) + +if test "$octave_qrupdate_ok" = yes; then + LIBS="$LIBS $QRUPDATE_LIBS" + AC_LANG_PUSH([Fortran 77]) + AC_MSG_CHECKING([for slup1up in $QRUPDATE_LIBS]) + octave_qrupdate_luu=no + AC_LINK_IFELSE([AC_LANG_CALL([], [slup1up])], + [octave_qrupdate_luu=yes]) + AC_MSG_RESULT($octave_qrupdate_luu) + if test "$octave_qrupdate_luu" = yes; then + AC_DEFINE(HAVE_QRUPDATE_LUU, [1], [Define if qrupdate supports LU updates]) + fi + AC_LANG_POP([Fortran 77]) +fi LIBS="$save_LIBS" # Check for AMD library diff --git a/liboctave/ChangeLog b/liboctave/ChangeLog --- a/liboctave/ChangeLog +++ b/liboctave/ChangeLog @@ -1,3 +1,19 @@ +2009-10-08 Jaroslav Hajek + + * PermMatrix.cc (PermMatrix::eye): New method. + * PermMatrix.h: Declare it. + * dbleLU.cc (LU::update, LU::update_piv): New overloaded methods. + * dbleLU.h: Declare them. + * floatLU.cc (FloatLU::update, FloatLU::update_piv): New overloaded + methods. + * floatLU.h: Declare them. + * CmplxLU.cc (ComplexLU::update, ComplexLU::update_piv): New + overloaded methods. + * CmplxLU.h: Declare them. + * fCmplxLU.cc (FloatComplexLU::update, FloatComplexLU::update_piv): + New overloaded methods. + * fCmplxLU.h: Declare them. + 2009-10-07 John W. Eaton * mx-inlines.cc (mx_inline_diff): Avoid uninitialized variable warning. diff --git a/liboctave/CmplxLU.cc b/liboctave/CmplxLU.cc --- a/liboctave/CmplxLU.cc +++ b/liboctave/CmplxLU.cc @@ -2,6 +2,7 @@ Copyright (C) 1994, 1995, 1996, 1997, 1999, 2002, 2003, 2004, 2005, 2007, 2008 John W. Eaton +Copyright (C) 2009 VZLU Prague, a.s. This file is part of Octave. @@ -28,6 +29,8 @@ #include "CmplxLU.h" #include "f77-fcn.h" #include "lo-error.h" +#include "oct-locbuf.h" +#include "CColVector.h" // Instantiate the base LU class for the types we need. @@ -43,6 +46,21 @@ F77_RET_T F77_FUNC (zgetrf, ZGETRF) (const octave_idx_type&, const octave_idx_type&, Complex*, const octave_idx_type&, octave_idx_type*, octave_idx_type&); + +#ifdef HAVE_QRUPDATE_LUU + F77_RET_T + F77_FUNC (zlu1up, ZLU1UP) (const octave_idx_type&, const octave_idx_type&, + Complex *, const octave_idx_type&, + Complex *, const octave_idx_type&, + Complex *, Complex *); + + F77_RET_T + F77_FUNC (zlup1up, ZLUP1UP) (const octave_idx_type&, const octave_idx_type&, + Complex *, const octave_idx_type&, + Complex *, const octave_idx_type&, + octave_idx_type *, const Complex *, + const Complex *, Complex *); +#endif } ComplexLU::ComplexLU (const ComplexMatrix& a) @@ -65,6 +83,132 @@ pipvt[i] -= 1; } +#ifdef HAVE_QRUPDATE_LUU + +void ComplexLU::update (const ComplexColumnVector& u, const ComplexColumnVector& v) +{ + if (packed ()) + unpack (); + + ComplexMatrix& l = l_fact; + ComplexMatrix& r = a_fact; + + octave_idx_type m = l.rows (); + octave_idx_type n = r.columns (); + octave_idx_type k = l.columns (); + + if (u.length () == m && v.length () == n) + { + ComplexColumnVector utmp = u, vtmp = v; + F77_XFCN (zlu1up, ZLU1UP, (m, n, l.fortran_vec (), m, r.fortran_vec (), k, + utmp.fortran_vec (), vtmp.fortran_vec ())); + } + else + (*current_liboctave_error_handler) ("luupdate: dimensions mismatch"); +} + +void ComplexLU::update (const ComplexMatrix& u, const ComplexMatrix& v) +{ + if (packed ()) + unpack (); + + ComplexMatrix& l = l_fact; + ComplexMatrix& r = a_fact; + + octave_idx_type m = l.rows (); + octave_idx_type n = r.columns (); + octave_idx_type k = l.columns (); + + if (u.rows () == m && v.rows () == n && u.cols () == v.cols ()) + { + for (volatile octave_idx_type i = 0; i < u.cols (); i++) + { + ComplexColumnVector utmp = u.column (i), vtmp = v.column (i); + F77_XFCN (zlu1up, ZLU1UP, (m, n, l.fortran_vec (), m, r.fortran_vec (), k, + utmp.fortran_vec (), vtmp.fortran_vec ())); + } + } + else + (*current_liboctave_error_handler) ("luupdate: dimensions mismatch"); +} + +void ComplexLU::update_piv (const ComplexColumnVector& u, const ComplexColumnVector& v) +{ + if (packed ()) + unpack (); + + ComplexMatrix& l = l_fact; + ComplexMatrix& r = a_fact; + + octave_idx_type m = l.rows (); + octave_idx_type n = r.columns (); + octave_idx_type k = l.columns (); + + if (u.length () == m && v.length () == n) + { + ComplexColumnVector utmp = u, vtmp = v; + OCTAVE_LOCAL_BUFFER (Complex, w, m); + for (octave_idx_type i = 0; i < m; i++) ipvt(i) += 1; // increment + F77_XFCN (zlup1up, ZLUP1UP, (m, n, l.fortran_vec (), m, r.fortran_vec (), k, + ipvt.fortran_vec (), utmp.data (), vtmp.data (), w)); + for (octave_idx_type i = 0; i < m; i++) ipvt(i) -= 1; // decrement + } + else + (*current_liboctave_error_handler) ("luupdate: dimensions mismatch"); +} + +void ComplexLU::update_piv (const ComplexMatrix& u, const ComplexMatrix& v) +{ + if (packed ()) + unpack (); + + ComplexMatrix& l = l_fact; + ComplexMatrix& r = a_fact; + + octave_idx_type m = l.rows (); + octave_idx_type n = r.columns (); + octave_idx_type k = l.columns (); + + if (u.rows () == m && v.rows () == n && u.cols () == v.cols ()) + { + OCTAVE_LOCAL_BUFFER (Complex, w, m); + for (octave_idx_type i = 0; i < m; i++) ipvt(i) += 1; // increment + for (volatile octave_idx_type i = 0; i < u.cols (); i++) + { + ComplexColumnVector utmp = u.column (i), vtmp = v.column (i); + F77_XFCN (zlup1up, ZLUP1UP, (m, n, l.fortran_vec (), m, r.fortran_vec (), k, + ipvt.fortran_vec (), utmp.data (), vtmp.data (), w)); + } + for (octave_idx_type i = 0; i < m; i++) ipvt(i) -= 1; // decrement + } + else + (*current_liboctave_error_handler) ("luupdate: dimensions mismatch"); +} + +#else + +void ComplexLU::update (const ComplexColumnVector&, const ComplexColumnVector&) +{ + (*current_liboctave_error_handler) ("luupdate: not available in this version"); +} + +void ComplexLU::update (const ComplexMatrix&, const ComplexMatrix&) +{ + (*current_liboctave_error_handler) ("luupdate: not available in this version"); +} + +void ComplexLU::update_piv (const ComplexColumnVector&, const ComplexColumnVector&) +{ + (*current_liboctave_error_handler) ("luupdate: not available in this version"); +} + +void ComplexLU::update_piv (const ComplexMatrix&, const ComplexMatrix&) +{ + (*current_liboctave_error_handler) ("luupdate: not available in this version"); +} + +#endif + /* ;;; Local Variables: *** ;;; mode: C++ *** diff --git a/liboctave/CmplxLU.h b/liboctave/CmplxLU.h --- a/liboctave/CmplxLU.h +++ b/liboctave/CmplxLU.h @@ -42,6 +42,10 @@ ComplexLU (const ComplexLU& a) : base_lu (a) { } + ComplexLU (const ComplexMatrix& l, const ComplexMatrix& u, + const PermMatrix& p) + : base_lu (l, u, p) { } + ComplexLU& operator = (const ComplexLU& a) { if (this != &a) @@ -51,6 +55,14 @@ } ~ComplexLU (void) { } + + void update (const ComplexColumnVector& u, const ComplexColumnVector& v); + + void update (const ComplexMatrix& u, const ComplexMatrix& v); + + void update_piv (const ComplexColumnVector& u, const ComplexColumnVector& v); + + void update_piv (const ComplexMatrix& u, const ComplexMatrix& v); }; #endif diff --git a/liboctave/PermMatrix.cc b/liboctave/PermMatrix.cc --- a/liboctave/PermMatrix.cc +++ b/liboctave/PermMatrix.cc @@ -179,6 +179,16 @@ return PermMatrix (res_pvec, res_colp, false); } +PermMatrix +PermMatrix::eye (octave_idx_type n) +{ + Array p(n); + for (octave_idx_type i = 0; i < n; i++) + p(i) = i; + + return PermMatrix (p, false, false); +} + PermMatrix operator *(const PermMatrix& a, const PermMatrix& b) { diff --git a/liboctave/PermMatrix.h b/liboctave/PermMatrix.h --- a/liboctave/PermMatrix.h +++ b/liboctave/PermMatrix.h @@ -118,6 +118,8 @@ void print_info (std::ostream& os, const std::string& prefix) const { Array::print_info (os, prefix); } + static PermMatrix eye (octave_idx_type n); + private: bool _colp; }; diff --git a/liboctave/dbleLU.cc b/liboctave/dbleLU.cc --- a/liboctave/dbleLU.cc +++ b/liboctave/dbleLU.cc @@ -2,6 +2,7 @@ Copyright (C) 1994, 1995, 1996, 1997, 1999, 2002, 2003, 2004, 2005, 2007, 2008 John W. Eaton +Copyright (C) 2009 VZLU Prague, a.s. This file is part of Octave. @@ -28,6 +29,8 @@ #include "dbleLU.h" #include "f77-fcn.h" #include "lo-error.h" +#include "oct-locbuf.h" +#include "dColVector.h" // Instantiate the base LU class for the types we need. @@ -43,6 +46,21 @@ F77_RET_T F77_FUNC (dgetrf, DGETRF) (const octave_idx_type&, const octave_idx_type&, double*, const octave_idx_type&, octave_idx_type*, octave_idx_type&); + +#ifdef HAVE_QRUPDATE_LUU + F77_RET_T + F77_FUNC (dlu1up, DLU1UP) (const octave_idx_type&, const octave_idx_type&, + double *, const octave_idx_type&, + double *, const octave_idx_type&, + double *, double *); + + F77_RET_T + F77_FUNC (dlup1up, DLUP1UP) (const octave_idx_type&, const octave_idx_type&, + double *, const octave_idx_type&, + double *, const octave_idx_type&, + octave_idx_type *, const double *, + const double *, double *); +#endif } LU::LU (const Matrix& a) @@ -65,6 +83,132 @@ pipvt[i] -= 1; } +#ifdef HAVE_QRUPDATE_LUU + +void LU::update (const ColumnVector& u, const ColumnVector& v) +{ + if (packed ()) + unpack (); + + Matrix& l = l_fact; + Matrix& r = a_fact; + + octave_idx_type m = l.rows (); + octave_idx_type n = r.columns (); + octave_idx_type k = l.columns (); + + if (u.length () == m && v.length () == n) + { + ColumnVector utmp = u, vtmp = v; + F77_XFCN (dlu1up, DLU1UP, (m, n, l.fortran_vec (), m, r.fortran_vec (), k, + utmp.fortran_vec (), vtmp.fortran_vec ())); + } + else + (*current_liboctave_error_handler) ("luupdate: dimensions mismatch"); +} + +void LU::update (const Matrix& u, const Matrix& v) +{ + if (packed ()) + unpack (); + + Matrix& l = l_fact; + Matrix& r = a_fact; + + octave_idx_type m = l.rows (); + octave_idx_type n = r.columns (); + octave_idx_type k = l.columns (); + + if (u.rows () == m && v.rows () == n && u.cols () == v.cols ()) + { + for (volatile octave_idx_type i = 0; i < u.cols (); i++) + { + ColumnVector utmp = u.column (i), vtmp = v.column (i); + F77_XFCN (dlu1up, DLU1UP, (m, n, l.fortran_vec (), m, r.fortran_vec (), k, + utmp.fortran_vec (), vtmp.fortran_vec ())); + } + } + else + (*current_liboctave_error_handler) ("luupdate: dimensions mismatch"); +} + +void LU::update_piv (const ColumnVector& u, const ColumnVector& v) +{ + if (packed ()) + unpack (); + + Matrix& l = l_fact; + Matrix& r = a_fact; + + octave_idx_type m = l.rows (); + octave_idx_type n = r.columns (); + octave_idx_type k = l.columns (); + + if (u.length () == m && v.length () == n) + { + ColumnVector utmp = u, vtmp = v; + OCTAVE_LOCAL_BUFFER (double, w, m); + for (octave_idx_type i = 0; i < m; i++) ipvt(i) += 1; // increment + F77_XFCN (dlup1up, DLUP1UP, (m, n, l.fortran_vec (), m, r.fortran_vec (), k, + ipvt.fortran_vec (), utmp.data (), vtmp.data (), w)); + for (octave_idx_type i = 0; i < m; i++) ipvt(i) -= 1; // decrement + } + else + (*current_liboctave_error_handler) ("luupdate: dimensions mismatch"); +} + +void LU::update_piv (const Matrix& u, const Matrix& v) +{ + if (packed ()) + unpack (); + + Matrix& l = l_fact; + Matrix& r = a_fact; + + octave_idx_type m = l.rows (); + octave_idx_type n = r.columns (); + octave_idx_type k = l.columns (); + + if (u.rows () == m && v.rows () == n && u.cols () == v.cols ()) + { + OCTAVE_LOCAL_BUFFER (double, w, m); + for (octave_idx_type i = 0; i < m; i++) ipvt(i) += 1; // increment + for (volatile octave_idx_type i = 0; i < u.cols (); i++) + { + ColumnVector utmp = u.column (i), vtmp = v.column (i); + F77_XFCN (dlup1up, DLUP1UP, (m, n, l.fortran_vec (), m, r.fortran_vec (), k, + ipvt.fortran_vec (), utmp.data (), vtmp.data (), w)); + } + for (octave_idx_type i = 0; i < m; i++) ipvt(i) -= 1; // decrement + } + else + (*current_liboctave_error_handler) ("luupdate: dimensions mismatch"); +} + +#else + +void LU::update (const ColumnVector&, const ColumnVector&) +{ + (*current_liboctave_error_handler) ("luupdate: not available in this version"); +} + +void LU::update (const Matrix&, const Matrix&) +{ + (*current_liboctave_error_handler) ("luupdate: not available in this version"); +} + +void LU::update_piv (const ColumnVector&, const ColumnVector&) +{ + (*current_liboctave_error_handler) ("luupdate: not available in this version"); +} + +void LU::update_piv (const Matrix&, const Matrix&) +{ + (*current_liboctave_error_handler) ("luupdate: not available in this version"); +} + +#endif + /* ;;; Local Variables: *** ;;; mode: C++ *** diff --git a/liboctave/dbleLU.h b/liboctave/dbleLU.h --- a/liboctave/dbleLU.h +++ b/liboctave/dbleLU.h @@ -39,6 +39,9 @@ LU (const LU& a) : base_lu (a) { } + LU (const Matrix& l, const Matrix& u, const PermMatrix& p) + : base_lu (l, u, p) { } + LU& operator = (const LU& a) { if (this != &a) @@ -48,6 +51,14 @@ } ~LU (void) { } + + void update (const ColumnVector& u, const ColumnVector& v); + + void update (const Matrix& u, const Matrix& v); + + void update_piv (const ColumnVector& u, const ColumnVector& v); + + void update_piv (const Matrix& u, const Matrix& v); }; #endif diff --git a/liboctave/fCmplxLU.cc b/liboctave/fCmplxLU.cc --- a/liboctave/fCmplxLU.cc +++ b/liboctave/fCmplxLU.cc @@ -2,6 +2,7 @@ Copyright (C) 1994, 1995, 1996, 1997, 1999, 2002, 2003, 2004, 2005, 2007, 2008 John W. Eaton +Copyright (C) 2009 VZLU Prague, a.s. This file is part of Octave. @@ -28,6 +29,8 @@ #include "fCmplxLU.h" #include "f77-fcn.h" #include "lo-error.h" +#include "oct-locbuf.h" +#include "fCColVector.h" // Instantiate the base LU class for the types we need. @@ -43,6 +46,21 @@ F77_RET_T F77_FUNC (cgetrf, CGETRF) (const octave_idx_type&, const octave_idx_type&, FloatComplex*, const octave_idx_type&, octave_idx_type*, octave_idx_type&); + +#ifdef HAVE_QRUPDATE_LUU + F77_RET_T + F77_FUNC (clu1up, CLU1UP) (const octave_idx_type&, const octave_idx_type&, + FloatComplex *, const octave_idx_type&, + FloatComplex *, const octave_idx_type&, + FloatComplex *, FloatComplex *); + + F77_RET_T + F77_FUNC (clup1up, CLUP1UP) (const octave_idx_type&, const octave_idx_type&, + FloatComplex *, const octave_idx_type&, + FloatComplex *, const octave_idx_type&, + octave_idx_type *, const FloatComplex *, + const FloatComplex *, FloatComplex *); +#endif } FloatComplexLU::FloatComplexLU (const FloatComplexMatrix& a) @@ -65,6 +83,132 @@ pipvt[i] -= 1; } +#ifdef HAVE_QRUPDATE_LUU + +void FloatComplexLU::update (const FloatComplexColumnVector& u, const FloatComplexColumnVector& v) +{ + if (packed ()) + unpack (); + + FloatComplexMatrix& l = l_fact; + FloatComplexMatrix& r = a_fact; + + octave_idx_type m = l.rows (); + octave_idx_type n = r.columns (); + octave_idx_type k = l.columns (); + + if (u.length () == m && v.length () == n) + { + FloatComplexColumnVector utmp = u, vtmp = v; + F77_XFCN (clu1up, CLU1UP, (m, n, l.fortran_vec (), m, r.fortran_vec (), k, + utmp.fortran_vec (), vtmp.fortran_vec ())); + } + else + (*current_liboctave_error_handler) ("luupdate: dimensions mismatch"); +} + +void FloatComplexLU::update (const FloatComplexMatrix& u, const FloatComplexMatrix& v) +{ + if (packed ()) + unpack (); + + FloatComplexMatrix& l = l_fact; + FloatComplexMatrix& r = a_fact; + + octave_idx_type m = l.rows (); + octave_idx_type n = r.columns (); + octave_idx_type k = l.columns (); + + if (u.rows () == m && v.rows () == n && u.cols () == v.cols ()) + { + for (volatile octave_idx_type i = 0; i < u.cols (); i++) + { + FloatComplexColumnVector utmp = u.column (i), vtmp = v.column (i); + F77_XFCN (clu1up, CLU1UP, (m, n, l.fortran_vec (), m, r.fortran_vec (), k, + utmp.fortran_vec (), vtmp.fortran_vec ())); + } + } + else + (*current_liboctave_error_handler) ("luupdate: dimensions mismatch"); +} + +void FloatComplexLU::update_piv (const FloatComplexColumnVector& u, const FloatComplexColumnVector& v) +{ + if (packed ()) + unpack (); + + FloatComplexMatrix& l = l_fact; + FloatComplexMatrix& r = a_fact; + + octave_idx_type m = l.rows (); + octave_idx_type n = r.columns (); + octave_idx_type k = l.columns (); + + if (u.length () == m && v.length () == n) + { + FloatComplexColumnVector utmp = u, vtmp = v; + OCTAVE_LOCAL_BUFFER (FloatComplex, w, m); + for (octave_idx_type i = 0; i < m; i++) ipvt(i) += 1; // increment + F77_XFCN (clup1up, CLUP1UP, (m, n, l.fortran_vec (), m, r.fortran_vec (), k, + ipvt.fortran_vec (), utmp.data (), vtmp.data (), w)); + for (octave_idx_type i = 0; i < m; i++) ipvt(i) -= 1; // decrement + } + else + (*current_liboctave_error_handler) ("luupdate: dimensions mismatch"); +} + +void FloatComplexLU::update_piv (const FloatComplexMatrix& u, const FloatComplexMatrix& v) +{ + if (packed ()) + unpack (); + + FloatComplexMatrix& l = l_fact; + FloatComplexMatrix& r = a_fact; + + octave_idx_type m = l.rows (); + octave_idx_type n = r.columns (); + octave_idx_type k = l.columns (); + + if (u.rows () == m && v.rows () == n && u.cols () == v.cols ()) + { + OCTAVE_LOCAL_BUFFER (FloatComplex, w, m); + for (octave_idx_type i = 0; i < m; i++) ipvt(i) += 1; // increment + for (volatile octave_idx_type i = 0; i < u.cols (); i++) + { + FloatComplexColumnVector utmp = u.column (i), vtmp = v.column (i); + F77_XFCN (clup1up, CLUP1UP, (m, n, l.fortran_vec (), m, r.fortran_vec (), k, + ipvt.fortran_vec (), utmp.data (), vtmp.data (), w)); + } + for (octave_idx_type i = 0; i < m; i++) ipvt(i) -= 1; // decrement + } + else + (*current_liboctave_error_handler) ("luupdate: dimensions mismatch"); +} + +#else + +void FloatComplexLU::update (const FloatComplexColumnVector&, const FloatComplexColumnVector&) +{ + (*current_liboctave_error_handler) ("luupdate: not available in this version"); +} + +void FloatComplexLU::update (const FloatComplexMatrix&, const FloatComplexMatrix&) +{ + (*current_liboctave_error_handler) ("luupdate: not available in this version"); +} + +void FloatComplexLU::update_piv (const FloatComplexColumnVector&, const FloatComplexColumnVector&) +{ + (*current_liboctave_error_handler) ("luupdate: not available in this version"); +} + +void FloatComplexLU::update_piv (const FloatComplexMatrix&, const FloatComplexMatrix&) +{ + (*current_liboctave_error_handler) ("luupdate: not available in this version"); +} + +#endif + /* ;;; Local Variables: *** ;;; mode: C++ *** diff --git a/liboctave/fCmplxLU.h b/liboctave/fCmplxLU.h --- a/liboctave/fCmplxLU.h +++ b/liboctave/fCmplxLU.h @@ -42,6 +42,10 @@ FloatComplexLU (const FloatComplexLU& a) : base_lu (a) { } + FloatComplexLU (const FloatComplexMatrix& l, const FloatComplexMatrix& u, + const PermMatrix& p) + : base_lu (l, u, p) { } + FloatComplexLU& operator = (const FloatComplexLU& a) { if (this != &a) @@ -51,6 +55,14 @@ } ~FloatComplexLU (void) { } + + void update (const FloatComplexColumnVector& u, const FloatComplexColumnVector& v); + + void update (const FloatComplexMatrix& u, const FloatComplexMatrix& v); + + void update_piv (const FloatComplexColumnVector& u, const FloatComplexColumnVector& v); + + void update_piv (const FloatComplexMatrix& u, const FloatComplexMatrix& v); }; #endif diff --git a/liboctave/floatLU.cc b/liboctave/floatLU.cc --- a/liboctave/floatLU.cc +++ b/liboctave/floatLU.cc @@ -2,6 +2,7 @@ Copyright (C) 1994, 1995, 1996, 1997, 1999, 2002, 2003, 2004, 2005, 2007, 2008 John W. Eaton +Copyright (C) 2009 VZLU Prague, a.s. This file is part of Octave. @@ -28,6 +29,8 @@ #include "floatLU.h" #include "f77-fcn.h" #include "lo-error.h" +#include "oct-locbuf.h" +#include "fColVector.h" // Instantiate the base LU class for the types we need. @@ -43,6 +46,21 @@ F77_RET_T F77_FUNC (sgetrf, SGETRF) (const octave_idx_type&, const octave_idx_type&, float*, const octave_idx_type&, octave_idx_type*, octave_idx_type&); + +#ifdef HAVE_QRUPDATE_LUU + F77_RET_T + F77_FUNC (slu1up, SLU1UP) (const octave_idx_type&, const octave_idx_type&, + float *, const octave_idx_type&, + float *, const octave_idx_type&, + float *, float *); + + F77_RET_T + F77_FUNC (slup1up, SLUP1UP) (const octave_idx_type&, const octave_idx_type&, + float *, const octave_idx_type&, + float *, const octave_idx_type&, + octave_idx_type *, const float *, + const float *, float *); +#endif } FloatLU::FloatLU (const FloatMatrix& a) @@ -65,6 +83,132 @@ pipvt[i] -= 1; } +#ifdef HAVE_QRUPDATE_LUU + +void FloatLU::update (const FloatColumnVector& u, const FloatColumnVector& v) +{ + if (packed ()) + unpack (); + + FloatMatrix& l = l_fact; + FloatMatrix& r = a_fact; + + octave_idx_type m = l.rows (); + octave_idx_type n = r.columns (); + octave_idx_type k = l.columns (); + + if (u.length () == m && v.length () == n) + { + FloatColumnVector utmp = u, vtmp = v; + F77_XFCN (slu1up, SLU1UP, (m, n, l.fortran_vec (), m, r.fortran_vec (), k, + utmp.fortran_vec (), vtmp.fortran_vec ())); + } + else + (*current_liboctave_error_handler) ("luupdate: dimensions mismatch"); +} + +void FloatLU::update (const FloatMatrix& u, const FloatMatrix& v) +{ + if (packed ()) + unpack (); + + FloatMatrix& l = l_fact; + FloatMatrix& r = a_fact; + + octave_idx_type m = l.rows (); + octave_idx_type n = r.columns (); + octave_idx_type k = l.columns (); + + if (u.rows () == m && v.rows () == n && u.cols () == v.cols ()) + { + for (volatile octave_idx_type i = 0; i < u.cols (); i++) + { + FloatColumnVector utmp = u.column (i), vtmp = v.column (i); + F77_XFCN (slu1up, SLU1UP, (m, n, l.fortran_vec (), m, r.fortran_vec (), k, + utmp.fortran_vec (), vtmp.fortran_vec ())); + } + } + else + (*current_liboctave_error_handler) ("luupdate: dimensions mismatch"); +} + +void FloatLU::update_piv (const FloatColumnVector& u, const FloatColumnVector& v) +{ + if (packed ()) + unpack (); + + FloatMatrix& l = l_fact; + FloatMatrix& r = a_fact; + + octave_idx_type m = l.rows (); + octave_idx_type n = r.columns (); + octave_idx_type k = l.columns (); + + if (u.length () == m && v.length () == n) + { + FloatColumnVector utmp = u, vtmp = v; + OCTAVE_LOCAL_BUFFER (float, w, m); + for (octave_idx_type i = 0; i < m; i++) ipvt(i) += 1; // increment + F77_XFCN (slup1up, SLUP1UP, (m, n, l.fortran_vec (), m, r.fortran_vec (), k, + ipvt.fortran_vec (), utmp.data (), vtmp.data (), w)); + for (octave_idx_type i = 0; i < m; i++) ipvt(i) -= 1; // decrement + } + else + (*current_liboctave_error_handler) ("luupdate: dimensions mismatch"); +} + +void FloatLU::update_piv (const FloatMatrix& u, const FloatMatrix& v) +{ + if (packed ()) + unpack (); + + FloatMatrix& l = l_fact; + FloatMatrix& r = a_fact; + + octave_idx_type m = l.rows (); + octave_idx_type n = r.columns (); + octave_idx_type k = l.columns (); + + if (u.rows () == m && v.rows () == n && u.cols () == v.cols ()) + { + OCTAVE_LOCAL_BUFFER (float, w, m); + for (octave_idx_type i = 0; i < m; i++) ipvt(i) += 1; // increment + for (volatile octave_idx_type i = 0; i < u.cols (); i++) + { + FloatColumnVector utmp = u.column (i), vtmp = v.column (i); + F77_XFCN (slup1up, SLUP1UP, (m, n, l.fortran_vec (), m, r.fortran_vec (), k, + ipvt.fortran_vec (), utmp.data (), vtmp.data (), w)); + } + for (octave_idx_type i = 0; i < m; i++) ipvt(i) -= 1; // decrement + } + else + (*current_liboctave_error_handler) ("luupdate: dimensions mismatch"); +} + +#else + +void FloatLU::update (const FloatColumnVector&, const FloatColumnVector&) +{ + (*current_liboctave_error_handler) ("luupdate: not available in this version"); +} + +void FloatLU::update (const FloatMatrix&, const FloatMatrix&) +{ + (*current_liboctave_error_handler) ("luupdate: not available in this version"); +} + +void FloatLU::update_piv (const FloatColumnVector&, const FloatColumnVector&) +{ + (*current_liboctave_error_handler) ("luupdate: not available in this version"); +} + +void FloatLU::update_piv (const FloatMatrix&, const FloatMatrix&) +{ + (*current_liboctave_error_handler) ("luupdate: not available in this version"); +} + +#endif + /* ;;; Local Variables: *** ;;; mode: C++ *** diff --git a/liboctave/floatLU.h b/liboctave/floatLU.h --- a/liboctave/floatLU.h +++ b/liboctave/floatLU.h @@ -40,6 +40,10 @@ FloatLU (const FloatLU& a) : base_lu (a) { } + FloatLU (const FloatMatrix& l, const FloatMatrix& u, + const PermMatrix& p) + : base_lu (l, u, p) { } + FloatLU& operator = (const FloatLU& a) { if (this != &a) @@ -49,6 +53,14 @@ } ~FloatLU (void) { } + + void update (const FloatColumnVector& u, const FloatColumnVector& v); + + void update (const FloatMatrix& u, const FloatMatrix& v); + + void update_piv (const FloatColumnVector& u, const FloatColumnVector& v); + + void update_piv (const FloatMatrix& u, const FloatMatrix& v); }; #endif diff --git a/src/ChangeLog b/src/ChangeLog --- a/src/ChangeLog +++ b/src/ChangeLog @@ -1,3 +1,8 @@ +2009-10-08 Jaroslav Hajek + + * DLD-FUNCTIONS/lu.cc (Fluupdate): New DEFUN_DLD. + (check_lu_dims): New helper func. + 2009-10-08 Jaroslav Hajek * data.cc (Flength): Simplify. diff --git a/src/DLD-FUNCTIONS/lu.cc b/src/DLD-FUNCTIONS/lu.cc --- a/src/DLD-FUNCTIONS/lu.cc +++ b/src/DLD-FUNCTIONS/lu.cc @@ -584,6 +584,149 @@ */ +static +bool check_lu_dims (const octave_value& l, const octave_value& u, + const octave_value& p) +{ + octave_idx_type m = l.rows (), k = u.rows (), n = u.columns (); + return ((l.ndims () == 2 && u.ndims () == 2 && k == l.columns ()) + && k == std::min (m, n) && + (p.is_undefined () || p.rows () == m)); +} + +DEFUN_DLD (luupdate, args, nargout, + "-*- texinfo -*-\n\ +@deftypefn {Loadable Function} {[@var{l}, @var{u}] =} luupdate (@var{l}, @var{u}, @var{x}, @var{y})\n\ +@deftypefn {Loadable Function} {[@var{l}, @var{u}, @var{p}] =}\ +luupdate (@var{l}, @var{u}, @var{p}, @var{x}, @var{y})\n\ +") +{ + octave_idx_type nargin = args.length (); + octave_value_list retval; + + bool pivoted = nargin == 5; + + if (nargin != 4 && nargin != 5) + { + print_usage (); + return retval; + } + + octave_value argl = args(0); + octave_value argu = args(1); + octave_value argp = pivoted ? args(2) : octave_value (); + octave_value argx = args(2 + pivoted); + octave_value argy = args(3 + pivoted); + + if (argl.is_numeric_type () && argu.is_numeric_type () + && argx.is_numeric_type () && argy.is_numeric_type () + && (! pivoted || argp.is_perm_matrix ())) + { + if (check_lu_dims (argl, argu, argp)) + { + PermMatrix P = (pivoted + ? argp.perm_matrix_value () + : PermMatrix::eye (argl.rows ())); + + if (argl.is_real_type () + && argu.is_real_type () + && argx.is_real_type () + && argy.is_real_type ()) + { + // all real case + if (argl.is_single_type () + || argu.is_single_type () + || argx.is_single_type () + || argy.is_single_type ()) + { + FloatMatrix L = argl.float_matrix_value (); + FloatMatrix U = argu.float_matrix_value (); + FloatMatrix x = argx.float_matrix_value (); + FloatMatrix y = argy.float_matrix_value (); + + FloatLU fact (L, U, P); + if (pivoted) + fact.update_piv (x, y); + else + fact.update (x, y); + + if (pivoted) + retval(2) = fact.P (); + retval(1) = fact.U (); + retval(0) = fact.L (); + } + else + { + Matrix L = argl.matrix_value (); + Matrix U = argu.matrix_value (); + Matrix x = argx.matrix_value (); + Matrix y = argy.matrix_value (); + + LU fact (L, U, P); + if (pivoted) + fact.update_piv (x, y); + else + fact.update (x, y); + + if (pivoted) + retval(2) = fact.P (); + retval(1) = fact.U (); + retval(0) = fact.L (); + } + } + else + { + // complex case + if (argl.is_single_type () + || argu.is_single_type () + || argx.is_single_type () + || argy.is_single_type ()) + { + FloatComplexMatrix L = argl.float_complex_matrix_value (); + FloatComplexMatrix U = argu.float_complex_matrix_value (); + FloatComplexMatrix x = argx.float_complex_matrix_value (); + FloatComplexMatrix y = argy.float_complex_matrix_value (); + + FloatComplexLU fact (L, U, P); + if (pivoted) + fact.update_piv (x, y); + else + fact.update (x, y); + + if (pivoted) + retval(2) = fact.P (); + retval(1) = fact.U (); + retval(0) = fact.L (); + } + else + { + ComplexMatrix L = argl.complex_matrix_value (); + ComplexMatrix U = argu.complex_matrix_value (); + ComplexMatrix x = argx.complex_matrix_value (); + ComplexMatrix y = argy.complex_matrix_value (); + + ComplexLU fact (L, U, P); + if (pivoted) + fact.update_piv (x, y); + else + fact.update (x, y); + + if (pivoted) + retval(2) = fact.P (); + retval(1) = fact.U (); + retval(0) = fact.L (); + } + } + } + else + error ("luupdate: dimensions mismatch"); + } + else + error ("luupdate: expecting numeric arguments"); + + return retval; +} + /* ;;; Local Variables: *** ;;; mode: C++ ***