# HG changeset patch # User John W. Eaton # Date 1334262459 14400 # Node ID e8e86ae3abbc5f875a112a1967f9affca727f7b8 # Parent 15e4ec503cfdc20fb29dff460525cdfe72c32003 make diag (x, m, n) return a proper diagonal matrix object (bug #36099) * Array.h, Array.cc (Array::diag (octave_idx_type, octave_idx_type) const): New function. * CMatrix.h, CMatrix.cc (ComplexMatrix::diag (octave_idx_type, octave_idx_type) const): New function. * dMatrix.h, dMatrix.cc (Matrix::diag (octave_idx_type, octave_idx_type) const): New function. * fCMatrix.h, fCMatrix.cc (FloatComplexMatrix::diag (octave_idx_type, octave_idx_type) const): New function. * fMatrix.h, fMatrix.cc (FloatMatrix::diag (octave_idx_type, octave_idx_type) const): New function. * CNDArray.cc, CNDArray.h (ComplexNDArray::diag (octave_idx_type, octave_idx_type) const): New forwarding function. * boolNDArray.cc, boolNDArray.h (boolNDArray::diag (octave_idx_type, octave_idx_type) const): New forwarding function. * chNDArray.cc, chNDArray.h (charNDArray::diag (octave_idx_type, octave_idx_type) const): New forwarding function. * dNDArray.cc, dNDArray.h (NDArray::diag (octave_idx_type, octave_idx_type) const): New forwarding function. * fCNDArray.cc, fCNDArray.h (FloatComplexNDArray::diag (octave_idx_type, octave_idx_type) const): New forwarding function. * fNDArray.cc, fNDArray.h (FloatNDArray::diag (octave_idx_type, octave_idx_type) const): New forwarding function. * intNDArray.cc, intNDArray.h (intNDArray::diag (octave_idx_type, octave_idx_type) const): New forwarding function. * Cell.cc, Cell.h (Cell::diag (octave_idx_type, octave_idx_type) const): New function. * ov.h (octave_value::diag (octave_idx_type, octave_idx_type)): New function. * ov-base.h, ov-base.cc (octave_base_value::diag (octave_idx_type, octave_idx_type) const): New virtual function and default implementation. * ov-base-mat.h (octave_base_matrix::diag (octave_idx_type, octave_idx_type) const): New function. * ov-base-scalar.cc, ov-base-scalar.h (octave_base_scalar::diag (octave_idx_type, octave_idx_type)): New function. * ov-complex.cc, ov-complex.h (octave_complex::diag (octave_idx_type, octave_idx_type) const): New function. * ov-cx-mat.cc, ov-complex.h (octave_complex_matrix::diag (octave_idx_type, octave_idx_type) const): New function. * ov-float.cc, ov-float.h (octave_float_scalar::diag (octave_idx_type, octave_idx_type) const): New function. * ov-flt-complex.cc, ov-flt-complex.h (octave_float_complex::diag (octave_idx_type, octave_idx_type) const): New function. * ov-flt-cx-mat.cc, ov-flt-cx-mat.h (octave_float_complex_matrix::diag (octave_idx_type, octave_idx_type) const): New function. * ov-flt-re-mat.cc, ov-flt-re-mat.h (octave_float_matrix::diag (octave_idx_type, octave_idx_type) const): New function. * ov-range.cc, ov-range.h (octave_range::diag (octave_idx_type, octave_idx_type) const): New function. * ov-re-mat.cc, ov-re-mat.h (octave_matrix::diag (octave_idx_type, octave_idx_type) const): New function. * ov-scalar.cc, ov-scalar.h (octave_scalar::diag (octave_idx_type, octave_idx_type) const): New function. * data.cc (Fdiag): Use two-arg octave_value::diag method for dispatching. New tests. diff --git a/liboctave/Array.cc b/liboctave/Array.cc --- a/liboctave/Array.cc +++ b/liboctave/Array.cc @@ -2524,6 +2524,26 @@ template Array +Array::diag (octave_idx_type m, octave_idx_type n) const +{ + Array retval; + + if (ndims () == 2 && (rows () == 1 || cols () == 1)) + { + retval.resize (m, n, resize_fill_value ()); + + for (octave_idx_type i = 0; i < numel (); i++) + retval.xelem (i, i) = xelem (i); + } + else + (*current_liboctave_error_handler) + ("cat: invalid dimension"); + + return retval; +} + +template +Array Array::cat (int dim, octave_idx_type n, const Array *array_list) { // Default concatenation. diff --git a/liboctave/Array.h b/liboctave/Array.h --- a/liboctave/Array.h +++ b/liboctave/Array.h @@ -562,6 +562,8 @@ Array diag (octave_idx_type k = 0) const; + Array diag (octave_idx_type m, octave_idx_type n) const; + // Concatenation along a specified (0-based) dimension, equivalent to cat(). // dim = -1 corresponds to dim = 0 and dim = -2 corresponds to dim = 1, // but apply the looser matching rules of vertcat/horzcat. diff --git a/liboctave/CMatrix.cc b/liboctave/CMatrix.cc --- a/liboctave/CMatrix.cc +++ b/liboctave/CMatrix.cc @@ -3239,6 +3239,23 @@ return MArray::diag (k); } +ComplexDiagMatrix +ComplexMatrix::diag (octave_idx_type m, octave_idx_type n) const +{ + ComplexDiagMatrix retval; + + octave_idx_type nr = rows (); + octave_idx_type nc = cols (); + + if (nr == 1 || nc == 1) + retval = ComplexDiagMatrix (*this, m, n); + else + (*current_liboctave_error_handler) + ("diag: expecting vector argument"); + + return retval; +} + bool ComplexMatrix::row_is_real_only (octave_idx_type i) const { diff --git a/liboctave/CMatrix.h b/liboctave/CMatrix.h --- a/liboctave/CMatrix.h +++ b/liboctave/CMatrix.h @@ -357,6 +357,8 @@ ComplexMatrix diag (octave_idx_type k = 0) const; + ComplexDiagMatrix diag (octave_idx_type m, octave_idx_type n) const; + bool row_is_real_only (octave_idx_type) const; bool column_is_real_only (octave_idx_type) const; diff --git a/liboctave/CNDArray.cc b/liboctave/CNDArray.cc --- a/liboctave/CNDArray.cc +++ b/liboctave/CNDArray.cc @@ -862,6 +862,12 @@ return MArray::diag (k); } +ComplexNDArray +ComplexNDArray::diag (octave_idx_type m, octave_idx_type n) const +{ + return MArray::diag (m, n); +} + // This contains no information on the array structure !!! std::ostream& operator << (std::ostream& os, const ComplexNDArray& a) diff --git a/liboctave/CNDArray.h b/liboctave/CNDArray.h --- a/liboctave/CNDArray.h +++ b/liboctave/CNDArray.h @@ -142,6 +142,8 @@ ComplexNDArray diag (octave_idx_type k = 0) const; + ComplexNDArray diag (octave_idx_type m, octave_idx_type n) const; + ComplexNDArray& changesign (void) { MArray::changesign (); diff --git a/liboctave/boolNDArray.cc b/liboctave/boolNDArray.cc --- a/liboctave/boolNDArray.cc +++ b/liboctave/boolNDArray.cc @@ -134,6 +134,12 @@ return Array::diag (k); } +boolNDArray +boolNDArray::diag (octave_idx_type m, octave_idx_type n) const +{ + return Array::diag (m, n); +} + NDND_BOOL_OPS (boolNDArray, boolNDArray) NDND_CMP_OPS (boolNDArray, boolNDArray) diff --git a/liboctave/boolNDArray.h b/liboctave/boolNDArray.h --- a/liboctave/boolNDArray.h +++ b/liboctave/boolNDArray.h @@ -103,6 +103,7 @@ boolNDArray diag (octave_idx_type k = 0) const; + boolNDArray diag (octave_idx_type m, octave_idx_type n) const; }; NDND_BOOL_OP_DECLS (boolNDArray, boolNDArray, OCTAVE_API) diff --git a/liboctave/chNDArray.cc b/liboctave/chNDArray.cc --- a/liboctave/chNDArray.cc +++ b/liboctave/chNDArray.cc @@ -133,6 +133,12 @@ return Array::diag (k); } +charNDArray +charNDArray::diag (octave_idx_type m, octave_idx_type n) const +{ + return Array::diag (m, n); +} + NDS_CMP_OPS (charNDArray, char) NDS_BOOL_OPS (charNDArray, char) diff --git a/liboctave/chNDArray.h b/liboctave/chNDArray.h --- a/liboctave/chNDArray.h +++ b/liboctave/chNDArray.h @@ -96,6 +96,7 @@ charNDArray diag (octave_idx_type k = 0) const; + charNDArray diag (octave_idx_type m, octave_idx_type n) const; }; NDS_CMP_OP_DECLS (charNDArray, char, OCTAVE_API) diff --git a/liboctave/dMatrix.cc b/liboctave/dMatrix.cc --- a/liboctave/dMatrix.cc +++ b/liboctave/dMatrix.cc @@ -2783,6 +2783,23 @@ return MArray::diag (k); } +DiagMatrix +Matrix::diag (octave_idx_type m, octave_idx_type n) const +{ + DiagMatrix retval; + + octave_idx_type nr = rows (); + octave_idx_type nc = cols (); + + if (nr == 1 || nc == 1) + retval = DiagMatrix (*this, m, n); + else + (*current_liboctave_error_handler) + ("diag: expecting vector argument"); + + return retval; +} + ColumnVector Matrix::row_min (void) const { diff --git a/liboctave/dMatrix.h b/liboctave/dMatrix.h --- a/liboctave/dMatrix.h +++ b/liboctave/dMatrix.h @@ -316,6 +316,8 @@ Matrix diag (octave_idx_type k = 0) const; + DiagMatrix diag (octave_idx_type m, octave_idx_type n) const; + ColumnVector row_min (void) const; ColumnVector row_max (void) const; diff --git a/liboctave/dNDArray.cc b/liboctave/dNDArray.cc --- a/liboctave/dNDArray.cc +++ b/liboctave/dNDArray.cc @@ -877,6 +877,12 @@ return MArray::diag (k); } +NDArray +NDArray::diag (octave_idx_type m, octave_idx_type n) const +{ + return MArray::diag (m, n); +} + // This contains no information on the array structure !!! std::ostream& operator << (std::ostream& os, const NDArray& a) diff --git a/liboctave/dNDArray.h b/liboctave/dNDArray.h --- a/liboctave/dNDArray.h +++ b/liboctave/dNDArray.h @@ -154,6 +154,8 @@ NDArray diag (octave_idx_type k = 0) const; + NDArray diag (octave_idx_type m, octave_idx_type n) const; + NDArray& changesign (void) { MArray::changesign (); diff --git a/liboctave/fCMatrix.cc b/liboctave/fCMatrix.cc --- a/liboctave/fCMatrix.cc +++ b/liboctave/fCMatrix.cc @@ -3235,6 +3235,23 @@ return MArray::diag (k); } +FloatComplexDiagMatrix +FloatComplexMatrix::diag (octave_idx_type m, octave_idx_type n) const +{ + FloatComplexDiagMatrix retval; + + octave_idx_type nr = rows (); + octave_idx_type nc = cols (); + + if (nr == 1 || nc == 1) + retval = FloatComplexDiagMatrix (*this, m, n); + else + (*current_liboctave_error_handler) + ("diag: expecting vector argument"); + + return retval; +} + bool FloatComplexMatrix::row_is_real_only (octave_idx_type i) const { diff --git a/liboctave/fCMatrix.h b/liboctave/fCMatrix.h --- a/liboctave/fCMatrix.h +++ b/liboctave/fCMatrix.h @@ -362,6 +362,8 @@ FloatComplexMatrix diag (octave_idx_type k = 0) const; + FloatComplexDiagMatrix diag (octave_idx_type m, octave_idx_type n) const; + bool row_is_real_only (octave_idx_type) const; bool column_is_real_only (octave_idx_type) const; diff --git a/liboctave/fCNDArray.cc b/liboctave/fCNDArray.cc --- a/liboctave/fCNDArray.cc +++ b/liboctave/fCNDArray.cc @@ -859,6 +859,12 @@ return MArray::diag (k); } +FloatComplexNDArray +FloatComplexNDArray::diag (octave_idx_type m, octave_idx_type n) const +{ + return MArray::diag (m, n); +} + // This contains no information on the array structure !!! std::ostream& operator << (std::ostream& os, const FloatComplexNDArray& a) diff --git a/liboctave/fCNDArray.h b/liboctave/fCNDArray.h --- a/liboctave/fCNDArray.h +++ b/liboctave/fCNDArray.h @@ -142,6 +142,8 @@ FloatComplexNDArray diag (octave_idx_type k = 0) const; + FloatComplexNDArray diag (octave_idx_type m, octave_idx_type n) const; + FloatComplexNDArray& changesign (void) { MArray::changesign (); diff --git a/liboctave/fMatrix.cc b/liboctave/fMatrix.cc --- a/liboctave/fMatrix.cc +++ b/liboctave/fMatrix.cc @@ -2783,6 +2783,23 @@ return MArray::diag (k); } +FloatDiagMatrix +FloatMatrix::diag (octave_idx_type m, octave_idx_type n) const +{ + FloatDiagMatrix retval; + + octave_idx_type nr = rows (); + octave_idx_type nc = cols (); + + if (nr == 1 || nc == 1) + retval = FloatDiagMatrix (*this, m, n); + else + (*current_liboctave_error_handler) + ("diag: expecting vector argument"); + + return retval; +} + FloatColumnVector FloatMatrix::row_min (void) const { diff --git a/liboctave/fMatrix.h b/liboctave/fMatrix.h --- a/liboctave/fMatrix.h +++ b/liboctave/fMatrix.h @@ -316,6 +316,8 @@ FloatMatrix diag (octave_idx_type k = 0) const; + FloatDiagMatrix diag (octave_idx_type m, octave_idx_type n) const; + FloatColumnVector row_min (void) const; FloatColumnVector row_max (void) const; diff --git a/liboctave/fNDArray.cc b/liboctave/fNDArray.cc --- a/liboctave/fNDArray.cc +++ b/liboctave/fNDArray.cc @@ -837,6 +837,12 @@ return MArray::diag (k); } +FloatNDArray +FloatNDArray::diag (octave_idx_type m, octave_idx_type n) const +{ + return MArray::diag (m, n); +} + // This contains no information on the array structure !!! std::ostream& operator << (std::ostream& os, const FloatNDArray& a) diff --git a/liboctave/fNDArray.h b/liboctave/fNDArray.h --- a/liboctave/fNDArray.h +++ b/liboctave/fNDArray.h @@ -151,6 +151,8 @@ FloatNDArray diag (octave_idx_type k = 0) const; + FloatNDArray diag (octave_idx_type m, octave_idx_type n) const; + FloatNDArray& changesign (void) { MArray::changesign (); diff --git a/liboctave/intNDArray.cc b/liboctave/intNDArray.cc --- a/liboctave/intNDArray.cc +++ b/liboctave/intNDArray.cc @@ -69,6 +69,13 @@ return MArray::diag (k); } +template +intNDArray +intNDArray::diag (octave_idx_type m, octave_idx_type n) const +{ + return MArray::diag (m, n); +} + // FIXME -- this is not quite the right thing. template diff --git a/liboctave/intNDArray.h b/liboctave/intNDArray.h --- a/liboctave/intNDArray.h +++ b/liboctave/intNDArray.h @@ -66,6 +66,8 @@ intNDArray diag (octave_idx_type k = 0) const; + intNDArray diag (octave_idx_type m, octave_idx_type n) const; + intNDArray& changesign (void) { MArray::changesign (); diff --git a/src/Cell.cc b/src/Cell.cc --- a/src/Cell.cc +++ b/src/Cell.cc @@ -315,3 +315,9 @@ { return Array::diag (k); } + +Cell +Cell::diag (octave_idx_type m, octave_idx_type n) const +{ + return Array::diag (m, n); +} diff --git a/src/Cell.h b/src/Cell.h --- a/src/Cell.h +++ b/src/Cell.h @@ -114,6 +114,8 @@ Cell diag (octave_idx_type k = 0) const; + Cell diag (octave_idx_type m, octave_idx_type n) const; + Cell xisalnum (void) const { return map (&octave_value::xisalnum); } Cell xisalpha (void) const { return map (&octave_value::xisalpha); } Cell xisascii (void) const { return map (&octave_value::xisascii); } diff --git a/src/data.cc b/src/data.cc --- a/src/data.cc +++ b/src/data.cc @@ -1281,11 +1281,14 @@ else if (nargin == 3) { octave_value arg0 = args(0); - if (arg0.ndims () == 2 && (args(0).rows () == 1 || args(0).columns () == 1)) + + if (arg0.ndims () == 2 && (arg0.rows () == 1 || arg0.columns () == 1)) { - octave_idx_type m = args(1).int_value (), n = args(2).int_value (); + octave_idx_type m = args(1).int_value (); + octave_idx_type n = args(2).int_value (); + if (! error_state) - retval = arg0.diag ().resize (dim_vector (m, n), true); + retval = arg0.diag (m, n); else error ("diag: invalid dimensions"); } @@ -1341,7 +1344,15 @@ %!error diag (ones (2), 3, 3) %!error diag (1:3, -4, 3) - */ +%!assert (diag (1, 3, 3), diag ([1, 0, 0])) +%!assert (diag (i, 3, 3), diag ([i, 0, 0])) +%!assert (diag (single (1), 3, 3), diag ([single(1), 0, 0])) +%!assert (diag (single (i), 3, 3), diag ([single(i), 0, 0])) +%!assert (diag ([1, 2], 3, 3), diag ([1, 2, 0])) +%!assert (diag ([1, 2]*i, 3, 3), diag ([1, 2, 0]*i)) +%!assert (diag (single ([1, 2]), 3, 3), diag (single ([1, 2, 0]))) +%!assert (diag (single ([1, 2]*i), 3, 3), diag (single ([1, 2, 0]*i))) +*/ DEFUN (prod, args, , "-*- texinfo -*-\n\ diff --git a/src/ov-base-mat.h b/src/ov-base-mat.h --- a/src/ov-base-mat.h +++ b/src/ov-base-mat.h @@ -123,6 +123,9 @@ octave_value diag (octave_idx_type k = 0) const { return octave_value (matrix.diag (k)); } + octave_value diag (octave_idx_type m, octave_idx_type n) const + { return octave_value (matrix.diag (m, n)); } + octave_value sort (octave_idx_type dim = 0, sortmode mode = ASCENDING) const { return octave_value (matrix.sort (dim, mode)); } octave_value sort (Array &sidx, octave_idx_type dim = 0, diff --git a/src/ov-base-scalar.cc b/src/ov-base-scalar.cc --- a/src/ov-base-scalar.cc +++ b/src/ov-base-scalar.cc @@ -121,6 +121,13 @@ } template +octave_value +octave_base_scalar::diag (octave_idx_type m, octave_idx_type n) const +{ + return Array (dim_vector (1, 1), scalar).diag (m, n); +} + +template bool octave_base_scalar::is_true (void) const { diff --git a/src/ov-base-scalar.h b/src/ov-base-scalar.h --- a/src/ov-base-scalar.h +++ b/src/ov-base-scalar.h @@ -98,6 +98,8 @@ octave_value diag (octave_idx_type k = 0) const; + octave_value diag (octave_idx_type m, octave_idx_type n) const; + octave_value sort (octave_idx_type, sortmode) const { return octave_value (scalar); } octave_value sort (Array &sidx, octave_idx_type, diff --git a/src/ov-base.cc b/src/ov-base.cc --- a/src/ov-base.cc +++ b/src/ov-base.cc @@ -1121,6 +1121,14 @@ } octave_value +octave_base_value::diag (octave_idx_type, octave_idx_type) const +{ + gripe_wrong_type_arg ("octave_base_value::diag ()", type_name ()); + + return octave_value(); +} + +octave_value octave_base_value::sort (octave_idx_type, sortmode) const { gripe_wrong_type_arg ("octave_base_value::sort ()", type_name ()); diff --git a/src/ov-base.h b/src/ov-base.h --- a/src/ov-base.h +++ b/src/ov-base.h @@ -646,6 +646,8 @@ virtual octave_value diag (octave_idx_type k = 0) const; + virtual octave_value diag (octave_idx_type m, octave_idx_type n) const; + virtual octave_value sort (octave_idx_type dim = 0, sortmode mode = ASCENDING) const; virtual octave_value sort (Array &sidx, diff --git a/src/ov-complex.cc b/src/ov-complex.cc --- a/src/ov-complex.cc +++ b/src/ov-complex.cc @@ -243,6 +243,12 @@ } } +octave_value +octave_complex::diag (octave_idx_type m, octave_idx_type n) const +{ + return ComplexDiagMatrix (Array (dim_vector (1, 1), scalar), m, n); +} + bool octave_complex::save_ascii (std::ostream& os) { diff --git a/src/ov-complex.h b/src/ov-complex.h --- a/src/ov-complex.h +++ b/src/ov-complex.h @@ -163,6 +163,8 @@ return boolNDArray (dim_vector (1, 1), scalar != 0.0); } + octave_value diag (octave_idx_type m, octave_idx_type n) const; + void increment (void) { scalar += 1.0; } void decrement (void) { scalar -= 1.0; } diff --git a/src/ov-cx-mat.cc b/src/ov-cx-mat.cc --- a/src/ov-cx-mat.cc +++ b/src/ov-cx-mat.cc @@ -292,6 +292,24 @@ return retval; } +octave_value +octave_complex_matrix::diag (octave_idx_type m, octave_idx_type n) const +{ + octave_value retval; + + if (matrix.ndims () == 2 + && (matrix.rows () == 1 || matrix.columns () == 1)) + { + ComplexMatrix mat = matrix.matrix_value (); + + retval = mat.diag (m, n); + } + else + error ("diag: expecting vector argument"); + + return retval; +} + bool octave_complex_matrix::save_ascii (std::ostream& os) { diff --git a/src/ov-cx-mat.h b/src/ov-cx-mat.h --- a/src/ov-cx-mat.h +++ b/src/ov-cx-mat.h @@ -135,6 +135,8 @@ octave_value diag (octave_idx_type k = 0) const; + octave_value diag (octave_idx_type m, octave_idx_type n) const; + void increment (void) { matrix += Complex (1.0); } void decrement (void) { matrix -= Complex (1.0); } diff --git a/src/ov-float.cc b/src/ov-float.cc --- a/src/ov-float.cc +++ b/src/ov-float.cc @@ -98,6 +98,12 @@ } octave_value +octave_float_scalar::diag (octave_idx_type m, octave_idx_type n) const +{ + return FloatDiagMatrix (Array (dim_vector (1, 1), scalar), m, n); +} + +octave_value octave_float_scalar::convert_to_str_internal (bool, bool, char type) const { octave_value retval; diff --git a/src/ov-float.h b/src/ov-float.h --- a/src/ov-float.h +++ b/src/ov-float.h @@ -211,6 +211,8 @@ return boolNDArray (dim_vector (1, 1), scalar); } + octave_value diag (octave_idx_type m, octave_idx_type n) const; + octave_value convert_to_str_internal (bool pad, bool force, char type) const; void increment (void) { ++scalar; } diff --git a/src/ov-flt-complex.cc b/src/ov-flt-complex.cc --- a/src/ov-flt-complex.cc +++ b/src/ov-flt-complex.cc @@ -228,6 +228,12 @@ } } +octave_value +octave_float_complex::diag (octave_idx_type m, octave_idx_type n) const +{ + return FloatComplexDiagMatrix (Array (dim_vector (1, 1), scalar), m, n); +} + bool octave_float_complex::save_ascii (std::ostream& os) { diff --git a/src/ov-flt-complex.h b/src/ov-flt-complex.h --- a/src/ov-flt-complex.h +++ b/src/ov-flt-complex.h @@ -152,6 +152,8 @@ return boolNDArray (dim_vector (1, 1), scalar != 1.0f); } + octave_value diag (octave_idx_type m, octave_idx_type n) const; + void increment (void) { scalar += 1.0; } void decrement (void) { scalar -= 1.0; } diff --git a/src/ov-flt-cx-mat.cc b/src/ov-flt-cx-mat.cc --- a/src/ov-flt-cx-mat.cc +++ b/src/ov-flt-cx-mat.cc @@ -281,6 +281,24 @@ return retval; } +octave_value +octave_float_complex_matrix::diag (octave_idx_type m, octave_idx_type n) const +{ + octave_value retval; + + if (matrix.ndims () == 2 + && (matrix.rows () == 1 || matrix.columns () == 1)) + { + FloatComplexMatrix mat = matrix.matrix_value (); + + retval = mat.diag (m, n); + } + else + error ("diag: expecting vector argument"); + + return retval; +} + bool octave_float_complex_matrix::save_ascii (std::ostream& os) { diff --git a/src/ov-flt-cx-mat.h b/src/ov-flt-cx-mat.h --- a/src/ov-flt-cx-mat.h +++ b/src/ov-flt-cx-mat.h @@ -133,6 +133,8 @@ octave_value diag (octave_idx_type k = 0) const; + octave_value diag (octave_idx_type m, octave_idx_type n) const; + void increment (void) { matrix += FloatComplex (1.0); } void decrement (void) { matrix -= FloatComplex (1.0); } diff --git a/src/ov-flt-re-mat.cc b/src/ov-flt-re-mat.cc --- a/src/ov-flt-re-mat.cc +++ b/src/ov-flt-re-mat.cc @@ -264,6 +264,24 @@ } octave_value +octave_float_matrix::diag (octave_idx_type m, octave_idx_type n) const +{ + octave_value retval; + + if (matrix.ndims () == 2 + && (matrix.rows () == 1 || matrix.columns () == 1)) + { + FloatMatrix mat = matrix.matrix_value (); + + retval = mat.diag (m, n); + } + else + error ("diag: expecting vector argument"); + + return retval; +} + +octave_value octave_float_matrix::convert_to_str_internal (bool, bool, char type) const { octave_value retval; diff --git a/src/ov-flt-re-mat.h b/src/ov-flt-re-mat.h --- a/src/ov-flt-re-mat.h +++ b/src/ov-flt-re-mat.h @@ -164,6 +164,8 @@ octave_value diag (octave_idx_type k = 0) const; + octave_value diag (octave_idx_type m, octave_idx_type n) const; + // Use matrix_ref here to clear index cache. void increment (void) { matrix_ref () += 1.0; } diff --git a/src/ov-range.cc b/src/ov-range.cc --- a/src/ov-range.cc +++ b/src/ov-range.cc @@ -248,6 +248,13 @@ : octave_value (range.diag (k))); } +octave_value +octave_range::diag (octave_idx_type m, octave_idx_type n) const +{ + Matrix mat = range.matrix_value (); + + return mat.diag (m, n); +} bool octave_range::is_true (void) const diff --git a/src/ov-range.h b/src/ov-range.h --- a/src/ov-range.h +++ b/src/ov-range.h @@ -139,6 +139,8 @@ octave_value diag (octave_idx_type k = 0) const; + octave_value diag (octave_idx_type m, octave_idx_type n) const; + octave_value sort (octave_idx_type dim = 0, sortmode mode = ASCENDING) const { return range.sort (dim, mode); } diff --git a/src/ov-re-mat.cc b/src/ov-re-mat.cc --- a/src/ov-re-mat.cc +++ b/src/ov-re-mat.cc @@ -272,6 +272,24 @@ return retval; } +octave_value +octave_matrix::diag (octave_idx_type m, octave_idx_type n) const +{ + octave_value retval; + + if (matrix.ndims () == 2 + && (matrix.rows () == 1 || matrix.columns () == 1)) + { + Matrix mat = matrix.matrix_value (); + + retval = mat.diag (m, n); + } + else + error ("diag: expecting vector argument"); + + return retval; +} + // We override these two functions to allow reshaping both // the matrix and the index cache. octave_value diff --git a/src/ov-re-mat.h b/src/ov-re-mat.h --- a/src/ov-re-mat.h +++ b/src/ov-re-mat.h @@ -178,6 +178,8 @@ octave_value diag (octave_idx_type k = 0) const; + octave_value diag (octave_idx_type m, octave_idx_type n) const; + octave_value reshape (const dim_vector& new_dims) const; octave_value squeeze (void) const; diff --git a/src/ov-scalar.cc b/src/ov-scalar.cc --- a/src/ov-scalar.cc +++ b/src/ov-scalar.cc @@ -113,6 +113,12 @@ } octave_value +octave_scalar::diag (octave_idx_type m, octave_idx_type n) const +{ + return DiagMatrix (Array (dim_vector (1, 1), scalar), m, n); +} + +octave_value octave_scalar::convert_to_str_internal (bool, bool, char type) const { octave_value retval; diff --git a/src/ov-scalar.h b/src/ov-scalar.h --- a/src/ov-scalar.h +++ b/src/ov-scalar.h @@ -212,6 +212,8 @@ return boolNDArray (dim_vector (1, 1), scalar); } + octave_value diag (octave_idx_type m, octave_idx_type n) const; + octave_value convert_to_str_internal (bool pad, bool force, char type) const; void increment (void) { ++scalar; } diff --git a/src/ov.h b/src/ov.h --- a/src/ov.h +++ b/src/ov.h @@ -1075,6 +1075,9 @@ octave_value diag (octave_idx_type k = 0) const { return rep->diag (k); } + octave_value diag (octave_idx_type m, octave_idx_type n) const + { return rep->diag (m, n); } + octave_value sort (octave_idx_type dim = 0, sortmode mode = ASCENDING) const { return rep->sort (dim, mode); } octave_value sort (Array &sidx, octave_idx_type dim = 0,