Mercurial > hg > octave-nkf
changeset 18535:c5a101de2d88
Allow pinv to work on Diagonal Matrices with a tolerance (bug #41546).
* pinv.cc (Fpinv): Validate tolerance argument and pass it through to
pseudo_inverse().
CDiagMatrix.h, dDiagMatrix.h, fCDiagMatrix.h, fDiagMatrix.h: Redefine
prototype for pseudo_inverse to accept a single argument for tolerance.
* CDiagMatrix.cc (pseudo_inverse), dDiagMatrix.cc(pseudo_inverse),
fCDiagMatrix.cc(pseudo_inverse), fDiagMatrix.cc(pseudo_inverse):
Use std::abs(elem) to get magnitude of element and only invert if
value is greater than tolerance.
author | Rik <rik@octave.org> |
---|---|
date | Sat, 15 Feb 2014 14:42:07 -0800 |
parents | a3611f3e80eb |
children | 0bfa7798c496 |
files | libinterp/corefcn/pinv.cc liboctave/array/CDiagMatrix.cc liboctave/array/CDiagMatrix.h liboctave/array/dDiagMatrix.cc liboctave/array/dDiagMatrix.h liboctave/array/fCDiagMatrix.cc liboctave/array/fCDiagMatrix.h liboctave/array/fDiagMatrix.cc liboctave/array/fDiagMatrix.h |
diffstat | 9 files changed, 53 insertions(+), 30 deletions(-) [+] |
line wrap: on
line diff
--- a/libinterp/corefcn/pinv.cc +++ b/libinterp/corefcn/pinv.cc @@ -76,22 +76,45 @@ if (arg.is_diag_matrix ()) { - if (nargin == 2) - warning ("pinv: tol is ignored for diagonal matrices"); - - if (arg.is_complex_type ()) + if (isfloat) { - if (isfloat) - retval = arg.float_complex_diag_matrix_value ().pseudo_inverse (); + float tol = 0.0; + if (nargin == 2) + tol = args(1).float_value (); + + if (error_state) + return retval; + + if (tol < 0.0) + { + error ("pinv: TOL must be greater than zero"); + return retval; + } + + if (arg.is_real_type ()) + retval = arg.float_diag_matrix_value ().pseudo_inverse (tol); else - retval = arg.complex_diag_matrix_value ().pseudo_inverse (); + retval = arg.float_complex_diag_matrix_value ().pseudo_inverse (tol); } else { - if (isfloat) - retval = arg.float_diag_matrix_value ().pseudo_inverse (); + double tol = 0.0; + if (nargin == 2) + tol = args(1).double_value (); + + if (error_state) + return retval; + + if (tol < 0.0) + { + error ("pinv: TOL must be greater than zero"); + return retval; + } + + if (arg.is_real_type ()) + retval = arg.diag_matrix_value ().pseudo_inverse (tol); else - retval = arg.diag_matrix_value ().pseudo_inverse (); + retval = arg.complex_diag_matrix_value ().pseudo_inverse (tol); } } else if (arg.is_perm_matrix ())
--- a/liboctave/array/CDiagMatrix.cc +++ b/liboctave/array/CDiagMatrix.cc @@ -383,7 +383,7 @@ } ComplexDiagMatrix -ComplexDiagMatrix::pseudo_inverse (void) const +ComplexDiagMatrix::pseudo_inverse (double tol) const { octave_idx_type r = rows (); octave_idx_type c = cols (); @@ -393,10 +393,10 @@ for (octave_idx_type i = 0; i < len; i++) { - if (elem (i, i) != 0.0) + if (std::abs (elem (i, i)) < tol) + retval.elem (i, i) = 0.0; + else retval.elem (i, i) = 1.0 / elem (i, i); - else - retval.elem (i, i) = 0.0; } return retval;
--- a/liboctave/array/CDiagMatrix.h +++ b/liboctave/array/CDiagMatrix.h @@ -116,7 +116,7 @@ ComplexDiagMatrix inverse (octave_idx_type& info) const; ComplexDiagMatrix inverse (void) const; - ComplexDiagMatrix pseudo_inverse (void) const; + ComplexDiagMatrix pseudo_inverse (double tol = 0.0) const; bool all_elements_are_real (void) const;
--- a/liboctave/array/dDiagMatrix.cc +++ b/liboctave/array/dDiagMatrix.cc @@ -292,7 +292,7 @@ } DiagMatrix -DiagMatrix::pseudo_inverse (void) const +DiagMatrix::pseudo_inverse (double tol) const { octave_idx_type r = rows (); octave_idx_type c = cols (); @@ -302,10 +302,10 @@ for (octave_idx_type i = 0; i < len; i++) { - if (elem (i, i) != 0.0) + if (std::abs (elem (i, i)) < tol) + retval.elem (i, i) = 0.0; + else retval.elem (i, i) = 1.0 / elem (i, i); - else - retval.elem (i, i) = 0.0; } return retval;
--- a/liboctave/array/dDiagMatrix.h +++ b/liboctave/array/dDiagMatrix.h @@ -98,7 +98,7 @@ DiagMatrix inverse (void) const; DiagMatrix inverse (octave_idx_type& info) const; - DiagMatrix pseudo_inverse (void) const; + DiagMatrix pseudo_inverse (double tol = 0.0) const; // other operations
--- a/liboctave/array/fCDiagMatrix.cc +++ b/liboctave/array/fCDiagMatrix.cc @@ -387,7 +387,7 @@ } FloatComplexDiagMatrix -FloatComplexDiagMatrix::pseudo_inverse (void) const +FloatComplexDiagMatrix::pseudo_inverse (float tol) const { octave_idx_type r = rows (); octave_idx_type c = cols (); @@ -397,10 +397,10 @@ for (octave_idx_type i = 0; i < len; i++) { - if (elem (i, i) != 0.0f) + if (std::abs (elem (i, i)) < tol) + retval.elem (i, i) = 0.0f; + else retval.elem (i, i) = 1.0f / elem (i, i); - else - retval.elem (i, i) = 0.0f; } return retval;
--- a/liboctave/array/fCDiagMatrix.h +++ b/liboctave/array/fCDiagMatrix.h @@ -122,7 +122,7 @@ FloatComplexDiagMatrix inverse (octave_idx_type& info) const; FloatComplexDiagMatrix inverse (void) const; - FloatComplexDiagMatrix pseudo_inverse (void) const; + FloatComplexDiagMatrix pseudo_inverse (float tol = 0.0f) const; bool all_elements_are_real (void) const;
--- a/liboctave/array/fDiagMatrix.cc +++ b/liboctave/array/fDiagMatrix.cc @@ -292,7 +292,7 @@ } FloatDiagMatrix -FloatDiagMatrix::pseudo_inverse (void) const +FloatDiagMatrix::pseudo_inverse (float tol) const { octave_idx_type r = rows (); octave_idx_type c = cols (); @@ -302,10 +302,10 @@ for (octave_idx_type i = 0; i < len; i++) { - if (elem (i, i) != 0.0f) + if (std::abs (elem (i, i)) < tol) + retval.elem (i, i) = 0.0f; + else retval.elem (i, i) = 1.0f / elem (i, i); - else - retval.elem (i, i) = 0.0f; } return retval;
--- a/liboctave/array/fDiagMatrix.h +++ b/liboctave/array/fDiagMatrix.h @@ -99,7 +99,7 @@ FloatDiagMatrix inverse (void) const; FloatDiagMatrix inverse (octave_idx_type& info) const; - FloatDiagMatrix pseudo_inverse (void) const; + FloatDiagMatrix pseudo_inverse (float tol = 0.0f) const; // other operations