Mercurial > hg > octave-nkf
diff liboctave/numeric/floatCHOL.cc @ 20672:5ce959c55cc0
Propagate 'lower' in chol(a, 'lower') to underlying library function.
* chol.cc (chol): Send 'L' parameter correctly when chol is called with 'lower'.
* floatCHOL.cc (init): Propagate 'lower' to underlying library function.
* floatCHOL.h: Modify the prototype of methods.
* fMatrix.cc (inverse): Invoke chol with additional parameter.
* dbleCHOL.cc (init): Propagate 'lower' to underlying library function.
* dbleCHOL.h: Modify the prototype of methods.
* dMatrix.cc (inverse): Invoke chol with additional parameter.
* CmplxCHOL.cc (init): Propagate 'lower' to underlying library function.
* CmplxCHOL.h: Modify the prototype of methods.
* CMatrix.cc (inverse): Invoke chol with additional parameter.
author | PrasannaKumar Muralidharan <prasannatsmkumar@gmail.com> |
---|---|
date | Sun, 24 Aug 2014 19:35:06 +0530 |
parents | a9574e3c6e9e |
children | dcfbf4c1c3c8 |
line wrap: on
line diff
--- a/liboctave/numeric/floatCHOL.cc +++ b/liboctave/numeric/floatCHOL.cc @@ -87,7 +87,7 @@ } octave_idx_type -FloatCHOL::init (const FloatMatrix& a, bool calc_cond) +FloatCHOL::init (const FloatMatrix& a, bool upper, bool calc_cond) { octave_idx_type a_nr = a.rows (); octave_idx_type a_nc = a.cols (); @@ -101,14 +101,30 @@ octave_idx_type n = a_nc; octave_idx_type info; + is_upper = upper; + chol_mat.clear (n, n); - for (octave_idx_type j = 0; j < n; j++) + if (is_upper) { - for (octave_idx_type i = 0; i <= j; i++) - chol_mat.xelem (i, j) = a(i, j); - for (octave_idx_type i = j+1; i < n; i++) - chol_mat.xelem (i, j) = 0.0f; + for (octave_idx_type j = 0; j < n; j++) + { + for (octave_idx_type i = 0; i <= j; i++) + chol_mat.xelem (i, j) = a(i, j); + for (octave_idx_type i = j+1; i < n; i++) + chol_mat.xelem (i, j) = 0.0f; + } } + else + { + for (octave_idx_type j = 0; j < n; j++) + { + for (octave_idx_type i = 0; i <= j; i++) + chol_mat.xelem (i, j) = 0.0f; + for (octave_idx_type i = j+1; i < n; i++) + chol_mat.xelem (i, j) = a(i, j); + } + } + float *h = chol_mat.fortran_vec (); // Calculate the norm of the matrix, for later use. @@ -116,9 +132,18 @@ if (calc_cond) anorm = xnorm (a, 1); - F77_XFCN (spotrf, SPOTRF, (F77_CONST_CHAR_ARG2 ("U", 1), - n, h, n, info - F77_CHAR_ARG_LEN (1))); + if (is_upper) + { + F77_XFCN (spotrf, SPOTRF, (F77_CONST_CHAR_ARG2 ("U", 1), + n, h, n, info + F77_CHAR_ARG_LEN (1))); + } + else + { + F77_XFCN (spotrf, SPOTRF, (F77_CONST_CHAR_ARG2 ("L", 1), + n, h, n, info + F77_CHAR_ARG_LEN (1))); + } xrcond = 0.0; if (info > 0) @@ -132,9 +157,19 @@ float *pz = z.fortran_vec (); Array<octave_idx_type> iz (dim_vector (n, 1)); octave_idx_type *piz = iz.fortran_vec (); - F77_XFCN (spocon, SPOCON, (F77_CONST_CHAR_ARG2 ("U", 1), n, h, - n, anorm, xrcond, pz, piz, spocon_info - F77_CHAR_ARG_LEN (1))); + if (is_upper) + { + F77_XFCN (spocon, SPOCON, (F77_CONST_CHAR_ARG2 ("U", 1), n, h, + n, anorm, xrcond, pz, piz, spocon_info + F77_CHAR_ARG_LEN (1))); + } + else + { + F77_XFCN (spocon, SPOCON, (F77_CONST_CHAR_ARG2 ("L", 1), n, h, + n, anorm, xrcond, pz, piz, spocon_info + F77_CHAR_ARG_LEN (1))); + } + if (spocon_info != 0) info = -1; @@ -144,7 +179,7 @@ } static FloatMatrix -chol2inv_internal (const FloatMatrix& r) +chol2inv_internal (const FloatMatrix& r, bool is_upper = true) { FloatMatrix retval; @@ -161,17 +196,37 @@ if (info == 0) { - F77_XFCN (spotri, SPOTRI, (F77_CONST_CHAR_ARG2 ("U", 1), n, - v, n, info - F77_CHAR_ARG_LEN (1))); + if (is_upper) + { + F77_XFCN (spotri, SPOTRI, (F77_CONST_CHAR_ARG2 ("U", 1), n, + v, n, info + F77_CHAR_ARG_LEN (1))); + } + else + { + F77_XFCN (spotri, SPOTRI, (F77_CONST_CHAR_ARG2 ("L", 1), n, + v, n, info + F77_CHAR_ARG_LEN (1))); + } // If someone thinks of a more graceful way of doing this (or // faster for that matter :-)), please let me know! if (n > 1) - for (octave_idx_type j = 0; j < r_nc; j++) - for (octave_idx_type i = j+1; i < r_nr; i++) - tmp.xelem (i, j) = tmp.xelem (j, i); + { + if (is_upper) + { + for (octave_idx_type j = 0; j < r_nc; j++) + for (octave_idx_type i = j+1; i < r_nr; i++) + tmp.xelem (i, j) = tmp.xelem (j, i); + } + else + { + for (octave_idx_type j = 0; j < r_nc; j++) + for (octave_idx_type i = j+1; i < r_nr; i++) + tmp.xelem (j, i) = tmp.xelem (i, j); + } + } retval = tmp; } @@ -186,7 +241,7 @@ FloatMatrix FloatCHOL::inverse (void) const { - return chol2inv_internal (chol_mat); + return chol2inv_internal (chol_mat, is_upper); } void