changeset 14557:e8e86ae3abbc

make diag (x, m, n) return a proper diagonal matrix object (bug #36099) * Array.h, Array.cc (Array<T>::diag (octave_idx_type, octave_idx_type) const): New function. * CMatrix.h, CMatrix.cc (ComplexMatrix::diag (octave_idx_type, octave_idx_type) const): New function. * dMatrix.h, dMatrix.cc (Matrix::diag (octave_idx_type, octave_idx_type) const): New function. * fCMatrix.h, fCMatrix.cc (FloatComplexMatrix::diag (octave_idx_type, octave_idx_type) const): New function. * fMatrix.h, fMatrix.cc (FloatMatrix::diag (octave_idx_type, octave_idx_type) const): New function. * CNDArray.cc, CNDArray.h (ComplexNDArray::diag (octave_idx_type, octave_idx_type) const): New forwarding function. * boolNDArray.cc, boolNDArray.h (boolNDArray::diag (octave_idx_type, octave_idx_type) const): New forwarding function. * chNDArray.cc, chNDArray.h (charNDArray::diag (octave_idx_type, octave_idx_type) const): New forwarding function. * dNDArray.cc, dNDArray.h (NDArray::diag (octave_idx_type, octave_idx_type) const): New forwarding function. * fCNDArray.cc, fCNDArray.h (FloatComplexNDArray::diag (octave_idx_type, octave_idx_type) const): New forwarding function. * fNDArray.cc, fNDArray.h (FloatNDArray::diag (octave_idx_type, octave_idx_type) const): New forwarding function. * intNDArray.cc, intNDArray.h (intNDArray<T>::diag (octave_idx_type, octave_idx_type) const): New forwarding function. * Cell.cc, Cell.h (Cell::diag (octave_idx_type, octave_idx_type) const): New function. * ov.h (octave_value::diag (octave_idx_type, octave_idx_type)): New function. * ov-base.h, ov-base.cc (octave_base_value::diag (octave_idx_type, octave_idx_type) const): New virtual function and default implementation. * ov-base-mat.h (octave_base_matrix<T>::diag (octave_idx_type, octave_idx_type) const): New function. * ov-base-scalar.cc, ov-base-scalar.h (octave_base_scalar<T>::diag (octave_idx_type, octave_idx_type)): New function. * ov-complex.cc, ov-complex.h (octave_complex::diag (octave_idx_type, octave_idx_type) const): New function. * ov-cx-mat.cc, ov-complex.h (octave_complex_matrix::diag (octave_idx_type, octave_idx_type) const): New function. * ov-float.cc, ov-float.h (octave_float_scalar::diag (octave_idx_type, octave_idx_type) const): New function. * ov-flt-complex.cc, ov-flt-complex.h (octave_float_complex::diag (octave_idx_type, octave_idx_type) const): New function. * ov-flt-cx-mat.cc, ov-flt-cx-mat.h (octave_float_complex_matrix::diag (octave_idx_type, octave_idx_type) const): New function. * ov-flt-re-mat.cc, ov-flt-re-mat.h (octave_float_matrix::diag (octave_idx_type, octave_idx_type) const): New function. * ov-range.cc, ov-range.h (octave_range::diag (octave_idx_type, octave_idx_type) const): New function. * ov-re-mat.cc, ov-re-mat.h (octave_matrix::diag (octave_idx_type, octave_idx_type) const): New function. * ov-scalar.cc, ov-scalar.h (octave_scalar::diag (octave_idx_type, octave_idx_type) const): New function. * data.cc (Fdiag): Use two-arg octave_value::diag method for dispatching. New tests.
author John W. Eaton <jwe@octave.org>
date Thu, 12 Apr 2012 16:27:39 -0400
parents 15e4ec503cfd
children 0c9c85e702ca
files liboctave/Array.cc liboctave/Array.h liboctave/CMatrix.cc liboctave/CMatrix.h liboctave/CNDArray.cc liboctave/CNDArray.h liboctave/boolNDArray.cc liboctave/boolNDArray.h liboctave/chNDArray.cc liboctave/chNDArray.h liboctave/dMatrix.cc liboctave/dMatrix.h liboctave/dNDArray.cc liboctave/dNDArray.h liboctave/fCMatrix.cc liboctave/fCMatrix.h liboctave/fCNDArray.cc liboctave/fCNDArray.h liboctave/fMatrix.cc liboctave/fMatrix.h liboctave/fNDArray.cc liboctave/fNDArray.h liboctave/intNDArray.cc liboctave/intNDArray.h src/Cell.cc src/Cell.h src/data.cc src/ov-base-mat.h src/ov-base-scalar.cc src/ov-base-scalar.h src/ov-base.cc src/ov-base.h src/ov-complex.cc src/ov-complex.h src/ov-cx-mat.cc src/ov-cx-mat.h src/ov-float.cc src/ov-float.h src/ov-flt-complex.cc src/ov-flt-complex.h src/ov-flt-cx-mat.cc src/ov-flt-cx-mat.h src/ov-flt-re-mat.cc src/ov-flt-re-mat.h src/ov-range.cc src/ov-range.h src/ov-re-mat.cc src/ov-re-mat.h src/ov-scalar.cc src/ov-scalar.h src/ov.h
diffstat 51 files changed, 322 insertions(+), 4 deletions(-) [+]
line wrap: on
line diff
--- a/liboctave/Array.cc
+++ b/liboctave/Array.cc
@@ -2524,6 +2524,26 @@
 
 template <class T>
 Array<T>
+Array<T>::diag (octave_idx_type m, octave_idx_type n) const
+{
+  Array<T> retval;
+
+  if (ndims () == 2 && (rows () == 1 || cols () == 1))
+    {
+      retval.resize (m, n, resize_fill_value ());
+
+      for (octave_idx_type i = 0; i < numel (); i++)
+        retval.xelem (i, i) = xelem (i);
+    }
+  else
+    (*current_liboctave_error_handler)
+      ("cat: invalid dimension");
+
+  return retval;
+}
+
+template <class T>
+Array<T>
 Array<T>::cat (int dim, octave_idx_type n, const Array<T> *array_list)
 {
   // Default concatenation.
--- a/liboctave/Array.h
+++ b/liboctave/Array.h
@@ -562,6 +562,8 @@
 
   Array<T> diag (octave_idx_type k = 0) const;
 
+  Array<T> diag (octave_idx_type m, octave_idx_type n) const;
+
   // Concatenation along a specified (0-based) dimension, equivalent to cat().
   // dim = -1 corresponds to dim = 0 and dim = -2 corresponds to dim = 1,
   // but apply the looser matching rules of vertcat/horzcat.
--- a/liboctave/CMatrix.cc
+++ b/liboctave/CMatrix.cc
@@ -3239,6 +3239,23 @@
   return MArray<Complex>::diag (k);
 }
 
+ComplexDiagMatrix
+ComplexMatrix::diag (octave_idx_type m, octave_idx_type n) const
+{
+  ComplexDiagMatrix retval;
+
+  octave_idx_type nr = rows ();
+  octave_idx_type nc = cols ();
+
+  if (nr == 1 || nc == 1)
+    retval = ComplexDiagMatrix (*this, m, n);
+  else
+    (*current_liboctave_error_handler)
+      ("diag: expecting vector argument");
+
+  return retval;
+}
+
 bool
 ComplexMatrix::row_is_real_only (octave_idx_type i) const
 {
--- a/liboctave/CMatrix.h
+++ b/liboctave/CMatrix.h
@@ -357,6 +357,8 @@
 
   ComplexMatrix diag (octave_idx_type k = 0) const;
 
+  ComplexDiagMatrix diag (octave_idx_type m, octave_idx_type n) const;
+
   bool row_is_real_only (octave_idx_type) const;
   bool column_is_real_only (octave_idx_type) const;
 
--- a/liboctave/CNDArray.cc
+++ b/liboctave/CNDArray.cc
@@ -862,6 +862,12 @@
   return MArray<Complex>::diag (k);
 }
 
+ComplexNDArray
+ComplexNDArray::diag (octave_idx_type m, octave_idx_type n) const
+{
+  return MArray<Complex>::diag (m, n);
+}
+
 // This contains no information on the array structure !!!
 std::ostream&
 operator << (std::ostream& os, const ComplexNDArray& a)
--- a/liboctave/CNDArray.h
+++ b/liboctave/CNDArray.h
@@ -142,6 +142,8 @@
 
   ComplexNDArray diag (octave_idx_type k = 0) const;
 
+  ComplexNDArray diag (octave_idx_type m, octave_idx_type n) const;
+
   ComplexNDArray& changesign (void)
     {
       MArray<Complex>::changesign ();
--- a/liboctave/boolNDArray.cc
+++ b/liboctave/boolNDArray.cc
@@ -134,6 +134,12 @@
   return Array<bool>::diag (k);
 }
 
+boolNDArray
+boolNDArray::diag (octave_idx_type m, octave_idx_type n) const
+{
+  return Array<bool>::diag (m, n);
+}
+
 NDND_BOOL_OPS (boolNDArray, boolNDArray)
 NDND_CMP_OPS (boolNDArray, boolNDArray)
 
--- a/liboctave/boolNDArray.h
+++ b/liboctave/boolNDArray.h
@@ -103,6 +103,7 @@
 
   boolNDArray diag (octave_idx_type k = 0) const;
 
+  boolNDArray diag (octave_idx_type m, octave_idx_type n) const;
 };
 
 NDND_BOOL_OP_DECLS (boolNDArray, boolNDArray, OCTAVE_API)
--- a/liboctave/chNDArray.cc
+++ b/liboctave/chNDArray.cc
@@ -133,6 +133,12 @@
   return Array<char>::diag (k);
 }
 
+charNDArray
+charNDArray::diag (octave_idx_type m, octave_idx_type n) const
+{
+  return Array<char>::diag (m, n);
+}
+
 NDS_CMP_OPS (charNDArray, char)
 NDS_BOOL_OPS (charNDArray, char)
 
--- a/liboctave/chNDArray.h
+++ b/liboctave/chNDArray.h
@@ -96,6 +96,7 @@
 
   charNDArray diag (octave_idx_type k = 0) const;
 
+  charNDArray diag (octave_idx_type m, octave_idx_type n) const;
 };
 
 NDS_CMP_OP_DECLS (charNDArray, char, OCTAVE_API)
--- a/liboctave/dMatrix.cc
+++ b/liboctave/dMatrix.cc
@@ -2783,6 +2783,23 @@
   return MArray<double>::diag (k);
 }
 
+DiagMatrix
+Matrix::diag (octave_idx_type m, octave_idx_type n) const
+{
+  DiagMatrix retval;
+
+  octave_idx_type nr = rows ();
+  octave_idx_type nc = cols ();
+
+  if (nr == 1 || nc == 1)
+    retval = DiagMatrix (*this, m, n);
+  else
+    (*current_liboctave_error_handler)
+      ("diag: expecting vector argument");
+
+  return retval;
+}
+
 ColumnVector
 Matrix::row_min (void) const
 {
--- a/liboctave/dMatrix.h
+++ b/liboctave/dMatrix.h
@@ -316,6 +316,8 @@
 
   Matrix diag (octave_idx_type k = 0) const;
 
+  DiagMatrix diag (octave_idx_type m, octave_idx_type n) const;
+
   ColumnVector row_min (void) const;
   ColumnVector row_max (void) const;
 
--- a/liboctave/dNDArray.cc
+++ b/liboctave/dNDArray.cc
@@ -877,6 +877,12 @@
   return MArray<double>::diag (k);
 }
 
+NDArray
+NDArray::diag (octave_idx_type m, octave_idx_type n) const
+{
+  return MArray<double>::diag (m, n);
+}
+
 // This contains no information on the array structure !!!
 std::ostream&
 operator << (std::ostream& os, const NDArray& a)
--- a/liboctave/dNDArray.h
+++ b/liboctave/dNDArray.h
@@ -154,6 +154,8 @@
 
   NDArray diag (octave_idx_type k = 0) const;
 
+  NDArray diag (octave_idx_type m, octave_idx_type n) const;
+
   NDArray& changesign (void)
     {
       MArray<double>::changesign ();
--- a/liboctave/fCMatrix.cc
+++ b/liboctave/fCMatrix.cc
@@ -3235,6 +3235,23 @@
   return MArray<FloatComplex>::diag (k);
 }
 
+FloatComplexDiagMatrix
+FloatComplexMatrix::diag (octave_idx_type m, octave_idx_type n) const
+{
+  FloatComplexDiagMatrix retval;
+
+  octave_idx_type nr = rows ();
+  octave_idx_type nc = cols ();
+
+  if (nr == 1 || nc == 1)
+    retval = FloatComplexDiagMatrix (*this, m, n);
+  else
+    (*current_liboctave_error_handler)
+      ("diag: expecting vector argument");
+
+  return retval;
+}
+
 bool
 FloatComplexMatrix::row_is_real_only (octave_idx_type i) const
 {
--- a/liboctave/fCMatrix.h
+++ b/liboctave/fCMatrix.h
@@ -362,6 +362,8 @@
 
   FloatComplexMatrix diag (octave_idx_type k = 0) const;
 
+  FloatComplexDiagMatrix diag (octave_idx_type m, octave_idx_type n) const;
+
   bool row_is_real_only (octave_idx_type) const;
   bool column_is_real_only (octave_idx_type) const;
 
--- a/liboctave/fCNDArray.cc
+++ b/liboctave/fCNDArray.cc
@@ -859,6 +859,12 @@
   return MArray<FloatComplex>::diag (k);
 }
 
+FloatComplexNDArray
+FloatComplexNDArray::diag (octave_idx_type m, octave_idx_type n) const
+{
+  return MArray<FloatComplex>::diag (m, n);
+}
+
 // This contains no information on the array structure !!!
 std::ostream&
 operator << (std::ostream& os, const FloatComplexNDArray& a)
--- a/liboctave/fCNDArray.h
+++ b/liboctave/fCNDArray.h
@@ -142,6 +142,8 @@
 
   FloatComplexNDArray diag (octave_idx_type k = 0) const;
 
+  FloatComplexNDArray diag (octave_idx_type m, octave_idx_type n) const;
+
   FloatComplexNDArray& changesign (void)
     {
       MArray<FloatComplex>::changesign ();
--- a/liboctave/fMatrix.cc
+++ b/liboctave/fMatrix.cc
@@ -2783,6 +2783,23 @@
   return MArray<float>::diag (k);
 }
 
+FloatDiagMatrix
+FloatMatrix::diag (octave_idx_type m, octave_idx_type n) const
+{
+  FloatDiagMatrix retval;
+
+  octave_idx_type nr = rows ();
+  octave_idx_type nc = cols ();
+
+  if (nr == 1 || nc == 1)
+    retval = FloatDiagMatrix (*this, m, n);
+  else
+    (*current_liboctave_error_handler)
+      ("diag: expecting vector argument");
+
+  return retval;
+}
+
 FloatColumnVector
 FloatMatrix::row_min (void) const
 {
--- a/liboctave/fMatrix.h
+++ b/liboctave/fMatrix.h
@@ -316,6 +316,8 @@
 
   FloatMatrix diag (octave_idx_type k = 0) const;
 
+  FloatDiagMatrix diag (octave_idx_type m, octave_idx_type n) const;
+
   FloatColumnVector row_min (void) const;
   FloatColumnVector row_max (void) const;
 
--- a/liboctave/fNDArray.cc
+++ b/liboctave/fNDArray.cc
@@ -837,6 +837,12 @@
   return MArray<float>::diag (k);
 }
 
+FloatNDArray
+FloatNDArray::diag (octave_idx_type m, octave_idx_type n) const
+{
+  return MArray<float>::diag (m, n);
+}
+
 // This contains no information on the array structure !!!
 std::ostream&
 operator << (std::ostream& os, const FloatNDArray& a)
--- a/liboctave/fNDArray.h
+++ b/liboctave/fNDArray.h
@@ -151,6 +151,8 @@
 
   FloatNDArray diag (octave_idx_type k = 0) const;
 
+  FloatNDArray diag (octave_idx_type m, octave_idx_type n) const;
+
   FloatNDArray& changesign (void)
     {
       MArray<float>::changesign ();
--- a/liboctave/intNDArray.cc
+++ b/liboctave/intNDArray.cc
@@ -69,6 +69,13 @@
   return MArray<T>::diag (k);
 }
 
+template <class T>
+intNDArray<T>
+intNDArray<T>::diag (octave_idx_type m, octave_idx_type n) const
+{
+  return MArray<T>::diag (m, n);
+}
+
 // FIXME -- this is not quite the right thing.
 
 template <class T>
--- a/liboctave/intNDArray.h
+++ b/liboctave/intNDArray.h
@@ -66,6 +66,8 @@
 
   intNDArray diag (octave_idx_type k = 0) const;
 
+  intNDArray diag (octave_idx_type m, octave_idx_type n) const;
+
   intNDArray& changesign (void)
     {
       MArray<T>::changesign ();
--- a/src/Cell.cc
+++ b/src/Cell.cc
@@ -315,3 +315,9 @@
 {
   return Array<octave_value>::diag (k);
 }
+
+Cell
+Cell::diag (octave_idx_type m, octave_idx_type n) const
+{
+  return Array<octave_value>::diag (m, n);
+}
--- a/src/Cell.h
+++ b/src/Cell.h
@@ -114,6 +114,8 @@
 
   Cell diag (octave_idx_type k = 0) const;
 
+  Cell diag (octave_idx_type m, octave_idx_type n) const;
+
   Cell xisalnum (void) const { return map (&octave_value::xisalnum); }
   Cell xisalpha (void) const { return map (&octave_value::xisalpha); }
   Cell xisascii (void) const { return map (&octave_value::xisascii); }
--- a/src/data.cc
+++ b/src/data.cc
@@ -1281,11 +1281,14 @@
   else if (nargin == 3)
     {
       octave_value arg0 = args(0);
-      if (arg0.ndims () == 2 && (args(0).rows () == 1 || args(0).columns () == 1))
+
+      if (arg0.ndims () == 2 && (arg0.rows () == 1 || arg0.columns () == 1))
         {
-          octave_idx_type m = args(1).int_value (), n = args(2).int_value ();
+          octave_idx_type m = args(1).int_value ();
+          octave_idx_type n = args(2).int_value ();
+
           if (! error_state)
-            retval = arg0.diag ().resize (dim_vector (m, n), true);
+            retval = arg0.diag (m, n);
           else
             error ("diag: invalid dimensions");
         }
@@ -1341,7 +1344,15 @@
 %!error diag (ones (2), 3, 3)
 %!error diag (1:3, -4, 3)
 
- */
+%!assert (diag (1, 3, 3), diag ([1, 0, 0]))
+%!assert (diag (i, 3, 3), diag ([i, 0, 0]))
+%!assert (diag (single (1), 3, 3), diag ([single(1), 0, 0]))
+%!assert (diag (single (i), 3, 3), diag ([single(i), 0, 0]))
+%!assert (diag ([1, 2], 3, 3), diag ([1, 2, 0]))
+%!assert (diag ([1, 2]*i, 3, 3), diag ([1, 2, 0]*i))
+%!assert (diag (single ([1, 2]), 3, 3), diag (single ([1, 2, 0])))
+%!assert (diag (single ([1, 2]*i), 3, 3), diag (single ([1, 2, 0]*i)))
+*/
 
 DEFUN (prod, args, ,
   "-*- texinfo -*-\n\
--- a/src/ov-base-mat.h
+++ b/src/ov-base-mat.h
@@ -123,6 +123,9 @@
   octave_value diag (octave_idx_type k = 0) const
     { return octave_value (matrix.diag (k)); }
 
+  octave_value diag (octave_idx_type m, octave_idx_type n) const
+    { return octave_value (matrix.diag (m, n)); }
+
   octave_value sort (octave_idx_type dim = 0, sortmode mode = ASCENDING) const
     { return octave_value (matrix.sort (dim, mode)); }
   octave_value sort (Array<octave_idx_type> &sidx, octave_idx_type dim = 0,
--- a/src/ov-base-scalar.cc
+++ b/src/ov-base-scalar.cc
@@ -121,6 +121,13 @@
 }
 
 template <class ST>
+octave_value
+octave_base_scalar<ST>::diag (octave_idx_type m, octave_idx_type n) const
+{
+  return Array<ST> (dim_vector (1, 1), scalar).diag (m, n);
+}
+
+template <class ST>
 bool
 octave_base_scalar<ST>::is_true (void) const
 {
--- a/src/ov-base-scalar.h
+++ b/src/ov-base-scalar.h
@@ -98,6 +98,8 @@
 
   octave_value diag (octave_idx_type k = 0) const;
 
+  octave_value diag (octave_idx_type m, octave_idx_type n) const;
+
   octave_value sort (octave_idx_type, sortmode) const
     { return octave_value (scalar); }
   octave_value sort (Array<octave_idx_type> &sidx, octave_idx_type,
--- a/src/ov-base.cc
+++ b/src/ov-base.cc
@@ -1121,6 +1121,14 @@
 }
 
 octave_value
+octave_base_value::diag (octave_idx_type, octave_idx_type) const
+{
+  gripe_wrong_type_arg ("octave_base_value::diag ()", type_name ());
+
+  return octave_value();
+}
+
+octave_value
 octave_base_value::sort (octave_idx_type, sortmode) const
 {
   gripe_wrong_type_arg ("octave_base_value::sort ()", type_name ());
--- a/src/ov-base.h
+++ b/src/ov-base.h
@@ -646,6 +646,8 @@
 
   virtual octave_value diag (octave_idx_type k = 0) const;
 
+  virtual octave_value diag (octave_idx_type m, octave_idx_type n) const;
+
   virtual octave_value sort (octave_idx_type dim = 0,
                              sortmode mode = ASCENDING) const;
   virtual octave_value sort (Array<octave_idx_type> &sidx,
--- a/src/ov-complex.cc
+++ b/src/ov-complex.cc
@@ -243,6 +243,12 @@
     }
 }
 
+octave_value
+octave_complex::diag (octave_idx_type m, octave_idx_type n) const
+{
+  return ComplexDiagMatrix (Array<Complex> (dim_vector (1, 1), scalar), m, n);
+}
+
 bool
 octave_complex::save_ascii (std::ostream& os)
 {
--- a/src/ov-complex.h
+++ b/src/ov-complex.h
@@ -163,6 +163,8 @@
     return boolNDArray (dim_vector (1, 1), scalar != 0.0);
   }
 
+  octave_value diag (octave_idx_type m, octave_idx_type n) const;
+
   void increment (void) { scalar += 1.0; }
 
   void decrement (void) { scalar -= 1.0; }
--- a/src/ov-cx-mat.cc
+++ b/src/ov-cx-mat.cc
@@ -292,6 +292,24 @@
   return retval;
 }
 
+octave_value
+octave_complex_matrix::diag (octave_idx_type m, octave_idx_type n) const
+{
+  octave_value retval;
+
+  if (matrix.ndims () == 2
+      && (matrix.rows () == 1 || matrix.columns () == 1))
+    {
+      ComplexMatrix mat = matrix.matrix_value ();
+
+      retval = mat.diag (m, n);
+    }
+  else
+    error ("diag: expecting vector argument");
+
+  return retval;
+}
+
 bool
 octave_complex_matrix::save_ascii (std::ostream& os)
 {
--- a/src/ov-cx-mat.h
+++ b/src/ov-cx-mat.h
@@ -135,6 +135,8 @@
 
   octave_value diag (octave_idx_type k = 0) const;
 
+  octave_value diag (octave_idx_type m, octave_idx_type n) const;
+
   void increment (void) { matrix += Complex (1.0); }
 
   void decrement (void) { matrix -= Complex (1.0); }
--- a/src/ov-float.cc
+++ b/src/ov-float.cc
@@ -98,6 +98,12 @@
 }
 
 octave_value
+octave_float_scalar::diag (octave_idx_type m, octave_idx_type n) const
+{
+  return FloatDiagMatrix (Array<float> (dim_vector (1, 1), scalar), m, n);
+}
+
+octave_value
 octave_float_scalar::convert_to_str_internal (bool, bool, char type) const
 {
   octave_value retval;
--- a/src/ov-float.h
+++ b/src/ov-float.h
@@ -211,6 +211,8 @@
     return boolNDArray (dim_vector (1, 1), scalar);
   }
 
+  octave_value diag (octave_idx_type m, octave_idx_type n) const;
+
   octave_value convert_to_str_internal (bool pad, bool force, char type) const;
 
   void increment (void) { ++scalar; }
--- a/src/ov-flt-complex.cc
+++ b/src/ov-flt-complex.cc
@@ -228,6 +228,12 @@
     }
 }
 
+octave_value
+octave_float_complex::diag (octave_idx_type m, octave_idx_type n) const
+{
+  return FloatComplexDiagMatrix (Array<FloatComplex> (dim_vector (1, 1), scalar), m, n);
+}
+
 bool
 octave_float_complex::save_ascii (std::ostream& os)
 {
--- a/src/ov-flt-complex.h
+++ b/src/ov-flt-complex.h
@@ -152,6 +152,8 @@
     return boolNDArray (dim_vector (1, 1), scalar != 1.0f);
   }
 
+  octave_value diag (octave_idx_type m, octave_idx_type n) const;
+
   void increment (void) { scalar += 1.0; }
 
   void decrement (void) { scalar -= 1.0; }
--- a/src/ov-flt-cx-mat.cc
+++ b/src/ov-flt-cx-mat.cc
@@ -281,6 +281,24 @@
   return retval;
 }
 
+octave_value
+octave_float_complex_matrix::diag (octave_idx_type m, octave_idx_type n) const
+{
+  octave_value retval;
+
+  if (matrix.ndims () == 2
+      && (matrix.rows () == 1 || matrix.columns () == 1))
+    {
+      FloatComplexMatrix mat = matrix.matrix_value ();
+
+      retval = mat.diag (m, n);
+    }
+  else
+    error ("diag: expecting vector argument");
+
+  return retval;
+}
+
 bool
 octave_float_complex_matrix::save_ascii (std::ostream& os)
 {
--- a/src/ov-flt-cx-mat.h
+++ b/src/ov-flt-cx-mat.h
@@ -133,6 +133,8 @@
 
   octave_value diag (octave_idx_type k = 0) const;
 
+  octave_value diag (octave_idx_type m, octave_idx_type n) const;
+
   void increment (void) { matrix += FloatComplex (1.0); }
 
   void decrement (void) { matrix -= FloatComplex (1.0); }
--- a/src/ov-flt-re-mat.cc
+++ b/src/ov-flt-re-mat.cc
@@ -264,6 +264,24 @@
 }
 
 octave_value
+octave_float_matrix::diag (octave_idx_type m, octave_idx_type n) const
+{
+  octave_value retval;
+
+  if (matrix.ndims () == 2
+      && (matrix.rows () == 1 || matrix.columns () == 1))
+    {
+      FloatMatrix mat = matrix.matrix_value ();
+
+      retval = mat.diag (m, n);
+    }
+  else
+    error ("diag: expecting vector argument");
+
+  return retval;
+}
+
+octave_value
 octave_float_matrix::convert_to_str_internal (bool, bool, char type) const
 {
   octave_value retval;
--- a/src/ov-flt-re-mat.h
+++ b/src/ov-flt-re-mat.h
@@ -164,6 +164,8 @@
 
   octave_value diag (octave_idx_type k = 0) const;
 
+  octave_value diag (octave_idx_type m, octave_idx_type n) const;
+
   // Use matrix_ref here to clear index cache.
   void increment (void) { matrix_ref () += 1.0; }
 
--- a/src/ov-range.cc
+++ b/src/ov-range.cc
@@ -248,6 +248,13 @@
           : octave_value (range.diag (k)));
 }
 
+octave_value
+octave_range::diag (octave_idx_type m, octave_idx_type n) const
+{
+  Matrix mat = range.matrix_value ();
+
+  return mat.diag (m, n);
+}
 
 bool
 octave_range::is_true (void) const
--- a/src/ov-range.h
+++ b/src/ov-range.h
@@ -139,6 +139,8 @@
 
   octave_value diag (octave_idx_type k = 0) const;
 
+  octave_value diag (octave_idx_type m, octave_idx_type n) const;
+
   octave_value sort (octave_idx_type dim = 0, sortmode mode = ASCENDING) const
     { return range.sort (dim, mode); }
 
--- a/src/ov-re-mat.cc
+++ b/src/ov-re-mat.cc
@@ -272,6 +272,24 @@
   return retval;
 }
 
+octave_value
+octave_matrix::diag (octave_idx_type m, octave_idx_type n) const
+{
+  octave_value retval;
+
+  if (matrix.ndims () == 2
+      && (matrix.rows () == 1 || matrix.columns () == 1))
+    {
+      Matrix mat = matrix.matrix_value ();
+
+      retval = mat.diag (m, n);
+    }
+  else
+    error ("diag: expecting vector argument");
+
+  return retval;
+}
+
 // We override these two functions to allow reshaping both
 // the matrix and the index cache.
 octave_value
--- a/src/ov-re-mat.h
+++ b/src/ov-re-mat.h
@@ -178,6 +178,8 @@
 
   octave_value diag (octave_idx_type k = 0) const;
 
+  octave_value diag (octave_idx_type m, octave_idx_type n) const;
+
   octave_value reshape (const dim_vector& new_dims) const;
 
   octave_value squeeze (void) const;
--- a/src/ov-scalar.cc
+++ b/src/ov-scalar.cc
@@ -113,6 +113,12 @@
 }
 
 octave_value
+octave_scalar::diag (octave_idx_type m, octave_idx_type n) const
+{
+  return DiagMatrix (Array<double> (dim_vector (1, 1), scalar), m, n);
+}
+
+octave_value
 octave_scalar::convert_to_str_internal (bool, bool, char type) const
 {
   octave_value retval;
--- a/src/ov-scalar.h
+++ b/src/ov-scalar.h
@@ -212,6 +212,8 @@
     return boolNDArray (dim_vector (1, 1), scalar);
   }
 
+  octave_value diag (octave_idx_type m, octave_idx_type n) const;
+
   octave_value convert_to_str_internal (bool pad, bool force, char type) const;
 
   void increment (void) { ++scalar; }
--- a/src/ov.h
+++ b/src/ov.h
@@ -1075,6 +1075,9 @@
   octave_value diag (octave_idx_type k = 0) const
     { return rep->diag (k); }
 
+  octave_value diag (octave_idx_type m, octave_idx_type n) const
+    { return rep->diag (m, n); }
+
   octave_value sort (octave_idx_type dim = 0, sortmode mode = ASCENDING) const
     { return rep->sort (dim, mode); }
   octave_value sort (Array<octave_idx_type> &sidx, octave_idx_type dim = 0,