Mercurial > hg > octave-lyh
changeset 15428:fd5c0159b588 stable
Fix diag handling of diagvectors (bug #37411)
* DiagArray2.h (extract_diag): New function
* DiagArray2.cc (extract_diag): Ditto
* ov.h (octave_value): New constructors for DiagArray2<T> types.
* ov.cc (octave_value): Ditto
* ov-base-diag.h (octave_base_diag<DMT,MT>::diag): Remove definition.
* ov-base-diag.cc (octave_base_diag<DMT,MT>::diag) Rewrite to check
for special diagvector case.
* data.cc: Add test for this bug
author | Jordi Gutiérrez Hermoso <jordigh@octave.org> |
---|---|
date | Fri, 21 Sep 2012 16:42:33 -0400 |
parents | 197774b411ec |
children | 4db96357fec9 c9954a15bc03 |
files | liboctave/DiagArray2.cc liboctave/DiagArray2.h src/data.cc src/ov-base-diag.cc src/ov-base-diag.h src/ov.cc src/ov.h |
diffstat | 7 files changed, 73 insertions(+), 3 deletions(-) [+] |
line wrap: on
line diff
--- a/liboctave/DiagArray2.cc +++ b/liboctave/DiagArray2.cc @@ -48,6 +48,13 @@ template <class T> Array<T> +DiagArray2<T>::extract_diag (octave_idx_type k) const +{ + return diag (k); +} + +template <class T> +Array<T> DiagArray2<T>::diag (octave_idx_type k) const { Array<T> d;
--- a/liboctave/DiagArray2.h +++ b/liboctave/DiagArray2.h @@ -64,7 +64,7 @@ template <class U> DiagArray2 (const DiagArray2<U>& a) - : Array<T> (a.diag ()), d1 (a.dim1 ()), d2 (a.dim2 ()) { } + : Array<T> (a.extract_diag ()), d1 (a.dim1 ()), d2 (a.dim2 ()) { } ~DiagArray2 (void) { } @@ -98,6 +98,11 @@ dim_vector dims (void) const { return dim_vector (d1, d2); } Array<T> diag (octave_idx_type k = 0) const; + Array<T> extract_diag (octave_idx_type k = 0) const; + DiagArray2<T> build_diag_matrix () const + { + return DiagArray2<T> (array_value ()); + } // Warning: the non-const two-index versions will silently ignore assignments // to off-diagonal elements.
--- a/src/data.cc +++ b/src/data.cc @@ -1354,6 +1354,11 @@ %!assert(diag (int8([0, 1, 0, 0; 0, 0, 2, 0; 0, 0, 0, 3; 0, 0, 0, 0]), 1), int8([1; 2; 3])); %!assert(diag (int8([0, 0, 0, 0; 1, 0, 0, 0; 0, 2, 0, 0; 0, 0, 3, 0]), -1), int8([1; 2; 3])); +## bug #37411 +%!assert (diag (diag ([5, 2, 3])(:,1)), diag([5 0 0 ])) +%!assert (diag (diag ([5, 2, 3])(:,1), 2), [0 0 5 0 0; zeros(4, 5)]) +%!assert (diag (diag ([5, 2, 3])(:,1), -2), [[0 0 5 0 0]', zeros(5, 4)]) + ## Test non-square size %!assert(diag ([1,2,3], 6, 3), [1 0 0; 0 2 0; 0 0 3; 0 0 0; 0 0 0; 0 0 0]) %!assert (diag (1, 2, 3), [1,0,0; 0,0,0]);
--- a/src/ov-base-diag.cc +++ b/src/ov-base-diag.cc @@ -67,6 +67,32 @@ return retval.next_subsref (type, idx); } + +template <class DMT, class MT> +octave_value +octave_base_diag<DMT,MT>::diag (octave_idx_type k) const +{ + octave_value retval; + if (matrix.rows () == 1 || matrix.cols () == 1) + { + // Rather odd special case. This is a row or column vector + // represented as a diagonal matrix with a single nonzero entry, but + // Fdiag semantics are to product a diagonal matrix for vector + // inputs. + if (k == 0) + // Returns Diag2Array<T> with nnz <= 1. + retval = matrix.build_diag_matrix (); + else + // Returns Array<T> matrix + retval = matrix.array_value ().diag (k); + } + else + // Returns Array<T> vector + retval = matrix.extract_diag (k); + return retval; +} + + template <class DMT, class MT> octave_value octave_base_diag<DMT, MT>::do_index_op (const octave_value_list& idx,
--- a/src/ov-base-diag.h +++ b/src/ov-base-diag.h @@ -97,8 +97,7 @@ MatrixType matrix_type (const MatrixType&) const { return matrix_type (); } - octave_value diag (octave_idx_type k = 0) const - { return octave_value (matrix.diag (k)); } + octave_value diag (octave_idx_type k = 0) const; octave_value sort (octave_idx_type dim = 0, sortmode mode = ASCENDING) const { return to_dense ().sort (dim, mode); }
--- a/src/ov.cc +++ b/src/ov.cc @@ -630,6 +630,30 @@ maybe_mutate (); } +octave_value::octave_value (const DiagArray2<double>& d) + : rep (new octave_diag_matrix (d)) +{ + maybe_mutate (); +} + +octave_value::octave_value (const DiagArray2<float>& d) + : rep (new octave_float_diag_matrix (d)) +{ + maybe_mutate (); +} + +octave_value::octave_value (const DiagArray2<Complex>& d) + : rep (new octave_complex_diag_matrix (d)) +{ + maybe_mutate (); +} + +octave_value::octave_value (const DiagArray2<FloatComplex>& d) + : rep (new octave_float_complex_diag_matrix (d)) +{ + maybe_mutate (); +} + octave_value::octave_value (const DiagMatrix& d) : rep (new octave_diag_matrix (d)) {
--- a/src/ov.h +++ b/src/ov.h @@ -201,6 +201,10 @@ octave_value (const Array<double>& m); octave_value (const Array<float>& m); octave_value (const DiagMatrix& d); + octave_value (const DiagArray2<double>& d); + octave_value (const DiagArray2<float>& d); + octave_value (const DiagArray2<Complex>& d); + octave_value (const DiagArray2<FloatComplex>& d); octave_value (const FloatDiagMatrix& d); octave_value (const RowVector& v); octave_value (const FloatRowVector& v);