# HG changeset patch # User Jaroslav Hajek # Date 1228741928 -3600 # Node ID 9b20a484705611d4d8f64638de3fcbeb1b45b5cc # Parent ad896677a2e2e72038a598ce5ccb422766820e64 implement scalar powers of diag matrices diff --git a/src/ChangeLog b/src/ChangeLog --- a/src/ChangeLog +++ b/src/ChangeLog @@ -1,3 +1,17 @@ +2008-12-08 Jaroslav Hajek + + * xpow.cc ( xpow (const DiagMatrix& a, double b), + xpow (const DiagMatrix& a, const Complex& b), + xpow (const ComplexDiagMatrix& a, double b), + xpow (const ComplexDiagMatrix& a, const Complex& b), + xpow (const FloatDiagMatrix& a, float b), + xpow (const FloatDiagMatrix& a, const FloatComplex& b), + xpow (const FloatComplexDiagMatrix& a, float b), + xpow (const FloatComplexDiagMatrix& a, const FloatComplex& b)): + New methods. + * xpow.h: Declare them. + * OPERATORS/op-dms-template.cc: Support diagonal matrix ^ scalar. + 2008-12-08 Jaroslav Hajek * ov-re-diag.cc (octave_diag_matrix::save_binary, diff --git a/src/OPERATORS/op-dms-template.cc b/src/OPERATORS/op-dms-template.cc --- a/src/OPERATORS/op-dms-template.cc +++ b/src/OPERATORS/op-dms-template.cc @@ -25,6 +25,7 @@ #endif #include "ops.h" +#include "xpow.h" #include SINCLUDE #include MINCLUDE @@ -58,6 +59,13 @@ return v2.MATRIX_VALUE () / v1.SCALAR_VALUE (); } +DEFBINOP (dmspow, MATRIX, SCALAR) +{ + CAST_BINOP_ARGS (const OCTAVE_MATRIX&, const OCTAVE_SCALAR&); + + return xpow (v1.MATRIX_VALUE (), v2.SCALAR_VALUE ()); +} + #define SHORT_NAME CONCAT3(MSHORT, _, SSHORT) #define INST_NAME CONCAT3(install_, SHORT_NAME, _ops) @@ -72,4 +80,5 @@ INSTALL_BINOP (op_sub, OCTAVE_SCALAR, OCTAVE_MATRIX, sdmsub); INSTALL_BINOP (op_mul, OCTAVE_SCALAR, OCTAVE_MATRIX, sdmmul); INSTALL_BINOP (op_ldiv, OCTAVE_SCALAR, OCTAVE_MATRIX, sdmldiv); + INSTALL_BINOP (op_pow, OCTAVE_MATRIX, OCTAVE_SCALAR, dmspow); } diff --git a/src/xpow.cc b/src/xpow.cc --- a/src/xpow.cc +++ b/src/xpow.cc @@ -31,10 +31,12 @@ #include "Array-util.h" #include "CColVector.h" #include "CDiagMatrix.h" +#include "fCDiagMatrix.h" #include "CMatrix.h" #include "EIG.h" #include "fEIG.h" #include "dDiagMatrix.h" +#include "fDiagMatrix.h" #include "dMatrix.h" #include "mx-cm-cdm.h" #include "oct-cmplx.h" @@ -262,6 +264,38 @@ return retval; } +// -*- 5d -*- +octave_value +xpow (const DiagMatrix& a, double b) +{ + octave_value retval; + + octave_idx_type nr = a.rows (); + octave_idx_type nc = a.cols (); + + if (nr == 0 || nc == 0 || nr != nc) + error ("for A^b, A must be square"); + else + { + if (static_cast (b) == b) + { + DiagMatrix r (nr, nc); + for (octave_idx_type i = 0; i < nc; i++) + r(i, i) = std::pow (a(i, i), b); + retval = r; + } + else + { + ComplexDiagMatrix r (nr, nc); + for (octave_idx_type i = 0; i < nc; i++) + r(i, i) = std::pow (static_cast (a(i, i)), b); + retval = r; + } + } + + return retval; +} + // -*- 6 -*- octave_value xpow (const Matrix& a, const Complex& b) @@ -517,6 +551,42 @@ return retval; } +// -*- 12d -*- +octave_value +xpow (const ComplexDiagMatrix& a, const Complex& b) +{ + octave_value retval; + + octave_idx_type nr = a.rows (); + octave_idx_type nc = a.cols (); + + if (nr == 0 || nc == 0 || nr != nc) + error ("for A^b, A must be square"); + else + { + ComplexDiagMatrix r (nr, nc); + for (octave_idx_type i = 0; i < nc; i++) + r(i, i) = std::pow (a(i, i), b); + retval = r; + } + + return retval; +} + +// mixed +octave_value +xpow (const ComplexDiagMatrix& a, double b) +{ + return xpow (a, static_cast (b)); +} + +octave_value +xpow (const DiagMatrix& a, const Complex& b) +{ + return xpow (ComplexDiagMatrix (a), b); +} + + // Safer pow functions that work elementwise for matrices. // // op2 \ op1: s m cs cm @@ -1474,6 +1544,38 @@ return retval; } +// -*- 5d -*- +octave_value +xpow (const FloatDiagMatrix& a, float b) +{ + octave_value retval; + + octave_idx_type nr = a.rows (); + octave_idx_type nc = a.cols (); + + if (nr == 0 || nc == 0 || nr != nc) + error ("for A^b, A must be square"); + else + { + if (static_cast (b) == b) + { + FloatDiagMatrix r (nr, nc); + for (octave_idx_type i = 0; i < nc; i++) + r(i, i) = std::pow (a(i, i), b); + retval = r; + } + else + { + FloatComplexDiagMatrix r (nr, nc); + for (octave_idx_type i = 0; i < nc; i++) + r(i, i) = std::pow (static_cast (a(i, i)), b); + retval = r; + } + } + + return retval; +} + // -*- 6 -*- octave_value xpow (const FloatMatrix& a, const FloatComplex& b) @@ -1729,6 +1831,41 @@ return retval; } +// -*- 12d -*- +octave_value +xpow (const FloatComplexDiagMatrix& a, const FloatComplex& b) +{ + octave_value retval; + + octave_idx_type nr = a.rows (); + octave_idx_type nc = a.cols (); + + if (nr == 0 || nc == 0 || nr != nc) + error ("for A^b, A must be square"); + else + { + FloatComplexDiagMatrix r (nr, nc); + for (octave_idx_type i = 0; i < nc; i++) + r(i, i) = std::pow (a(i, i), b); + retval = r; + } + + return retval; +} + +// mixed +octave_value +xpow (const FloatComplexDiagMatrix& a, float b) +{ + return xpow (a, static_cast (b)); +} + +octave_value +xpow (const FloatDiagMatrix& a, const FloatComplex& b) +{ + return xpow (FloatComplexDiagMatrix (a), b); +} + // Safer pow functions that work elementwise for matrices. // // op2 \ op1: s m cs cm diff --git a/src/xpow.h b/src/xpow.h --- a/src/xpow.h +++ b/src/xpow.h @@ -30,6 +30,14 @@ class ComplexMatrix; class FloatMatrix; class FloatComplexMatrix; +class DiagMatrix; +class ComplexDiagMatrix; +class FloatDiagMatrix; +class FloatComplexDiagMatrix; +class NDArray; +class FloatNDArray; +class ComplexNDArray; +class FloatComplexNDArray; class octave_value; extern octave_value xpow (double a, double b); @@ -40,6 +48,9 @@ extern octave_value xpow (const Matrix& a, double b); extern octave_value xpow (const Matrix& a, const Complex& b); +extern octave_value xpow (const DiagMatrix& a, double b); +extern octave_value xpow (const DiagMatrix& a, const Complex& b); + extern octave_value xpow (const Complex& a, double b); extern octave_value xpow (const Complex& a, const Matrix& b); extern octave_value xpow (const Complex& a, const Complex& b); @@ -48,6 +59,9 @@ extern octave_value xpow (const ComplexMatrix& a, double b); extern octave_value xpow (const ComplexMatrix& a, const Complex& b); +extern octave_value xpow (const ComplexDiagMatrix& a, double b); +extern octave_value xpow (const ComplexDiagMatrix& a, const Complex& b); + extern octave_value elem_xpow (double a, const Matrix& b); extern octave_value elem_xpow (double a, const ComplexMatrix& b); @@ -89,6 +103,9 @@ extern octave_value xpow (const FloatMatrix& a, float b); extern octave_value xpow (const FloatMatrix& a, const FloatComplex& b); +extern octave_value xpow (const FloatDiagMatrix& a, float b); +extern octave_value xpow (const FloatDiagMatrix& a, const FloatComplex& b); + extern octave_value xpow (const FloatComplex& a, float b); extern octave_value xpow (const FloatComplex& a, const FloatMatrix& b); extern octave_value xpow (const FloatComplex& a, const FloatComplex& b); @@ -97,6 +114,9 @@ extern octave_value xpow (const FloatComplexMatrix& a, float b); extern octave_value xpow (const FloatComplexMatrix& a, const FloatComplex& b); +extern octave_value xpow (const FloatComplexDiagMatrix& a, float b); +extern octave_value xpow (const FloatComplexDiagMatrix& a, const FloatComplex& b); + extern octave_value elem_xpow (float a, const FloatMatrix& b); extern octave_value elem_xpow (float a, const FloatComplexMatrix& b);