Mercurial > hg > octave-lyh
changeset 9665:1dba57e9d08d
use blas_trans_type for xgemm
author | Jaroslav Hajek <highegg@gmail.com> |
---|---|
date | Sat, 26 Sep 2009 10:41:07 +0200 |
parents | 2c5169034035 |
children | a531dec450c4 |
files | liboctave/CMatrix.cc liboctave/CMatrix.h liboctave/ChangeLog liboctave/dMatrix.cc liboctave/dMatrix.h liboctave/fCMatrix.cc liboctave/fCMatrix.h liboctave/fMatrix.cc liboctave/fMatrix.h src/ChangeLog src/OPERATORS/op-cm-cm.cc src/OPERATORS/op-cm-m.cc src/OPERATORS/op-fcm-fcm.cc src/OPERATORS/op-fcm-fm.cc src/OPERATORS/op-fm-fcm.cc src/OPERATORS/op-fm-fm.cc src/OPERATORS/op-m-cm.cc src/OPERATORS/op-m-m.cc |
diffstat | 18 files changed, 180 insertions(+), 132 deletions(-) [+] |
line wrap: on
line diff
--- a/liboctave/CMatrix.cc +++ b/liboctave/CMatrix.cc @@ -3784,20 +3784,19 @@ // the general GEMM operation ComplexMatrix -xgemm (bool transa, bool conja, const ComplexMatrix& a, - bool transb, bool conjb, const ComplexMatrix& b) +xgemm (const ComplexMatrix& a, const ComplexMatrix& b, + blas_trans_type transa, blas_trans_type transb) { ComplexMatrix retval; - // conjugacy is ignored if no transpose - conja = conja && transa; - conjb = conjb && transb; - - octave_idx_type a_nr = transa ? a.cols () : a.rows (); - octave_idx_type a_nc = transa ? a.rows () : a.cols (); - - octave_idx_type b_nr = transb ? b.cols () : b.rows (); - octave_idx_type b_nc = transb ? b.rows () : b.cols (); + bool tra = transa != blas_no_trans, trb = transb != blas_no_trans; + bool cja = transa == blas_conj_trans, cjb = transb == blas_conj_trans; + + octave_idx_type a_nr = tra ? a.cols () : a.rows (); + octave_idx_type a_nc = tra ? a.rows () : a.cols (); + + octave_idx_type b_nr = trb ? b.cols () : b.rows (); + octave_idx_type b_nc = trb ? b.rows () : b.cols (); if (a_nc != b_nr) gripe_nonconformant ("operator *", a_nr, a_nc, b_nr, b_nc); @@ -3805,18 +3804,18 @@ { if (a_nr == 0 || a_nc == 0 || b_nc == 0) retval = ComplexMatrix (a_nr, b_nc, 0.0); - else if (a.data () == b.data () && a_nr == b_nc && transa != transb) + else if (a.data () == b.data () && a_nr == b_nc && tra != trb) { octave_idx_type lda = a.rows (); retval = ComplexMatrix (a_nr, b_nc); Complex *c = retval.fortran_vec (); - const char *ctransa = get_blas_trans_arg (transa, conja); - if (conja || conjb) + const char *ctra = get_blas_trans_arg (tra, cja); + if (cja || cjb) { F77_XFCN (zherk, ZHERK, (F77_CONST_CHAR_ARG2 ("U", 1), - F77_CONST_CHAR_ARG2 (ctransa, 1), + F77_CONST_CHAR_ARG2 (ctra, 1), a_nr, a_nc, 1.0, a.data (), lda, 0.0, c, a_nr F77_CHAR_ARG_LEN (1) @@ -3828,7 +3827,7 @@ else { F77_XFCN (zsyrk, ZSYRK, (F77_CONST_CHAR_ARG2 ("U", 1), - F77_CONST_CHAR_ARG2 (ctransa, 1), + F77_CONST_CHAR_ARG2 (ctra, 1), a_nr, a_nc, 1.0, a.data (), lda, 0.0, c, a_nr F77_CHAR_ARG_LEN (1) @@ -3850,38 +3849,38 @@ if (b_nc == 1 && a_nr == 1) { - if (conja == conjb) + if (cja == cjb) { F77_FUNC (xzdotu, XZDOTU) (a_nc, a.data (), 1, b.data (), 1, *c); - if (conja) *c = std::conj (*c); + if (cja) *c = std::conj (*c); } - else if (conja) + else if (cja) F77_FUNC (xzdotc, XZDOTC) (a_nc, a.data (), 1, b.data (), 1, *c); else F77_FUNC (xzdotc, XZDOTC) (a_nc, b.data (), 1, a.data (), 1, *c); } - else if (b_nc == 1 && ! conjb) + else if (b_nc == 1 && ! cjb) { - const char *ctransa = get_blas_trans_arg (transa, conja); - F77_XFCN (zgemv, ZGEMV, (F77_CONST_CHAR_ARG2 (ctransa, 1), + const char *ctra = get_blas_trans_arg (tra, cja); + F77_XFCN (zgemv, ZGEMV, (F77_CONST_CHAR_ARG2 (ctra, 1), lda, tda, 1.0, a.data (), lda, b.data (), 1, 0.0, c, 1 F77_CHAR_ARG_LEN (1))); } - else if (a_nr == 1 && ! conja && ! conjb) + else if (a_nr == 1 && ! cja && ! cjb) { - const char *crevtransb = get_blas_trans_arg (! transb, conjb); - F77_XFCN (zgemv, ZGEMV, (F77_CONST_CHAR_ARG2 (crevtransb, 1), + const char *crevtrb = get_blas_trans_arg (! trb, cjb); + F77_XFCN (zgemv, ZGEMV, (F77_CONST_CHAR_ARG2 (crevtrb, 1), ldb, tdb, 1.0, b.data (), ldb, a.data (), 1, 0.0, c, 1 F77_CHAR_ARG_LEN (1))); } else { - const char *ctransa = get_blas_trans_arg (transa, conja); - const char *ctransb = get_blas_trans_arg (transb, conjb); - F77_XFCN (zgemm, ZGEMM, (F77_CONST_CHAR_ARG2 (ctransa, 1), - F77_CONST_CHAR_ARG2 (ctransb, 1), + const char *ctra = get_blas_trans_arg (tra, cja); + const char *ctrb = get_blas_trans_arg (trb, cjb); + F77_XFCN (zgemm, ZGEMM, (F77_CONST_CHAR_ARG2 (ctra, 1), + F77_CONST_CHAR_ARG2 (ctrb, 1), a_nr, b_nc, a_nc, 1.0, a.data (), lda, b.data (), ldb, 0.0, c, a_nr F77_CHAR_ARG_LEN (1) @@ -3896,7 +3895,7 @@ ComplexMatrix operator * (const ComplexMatrix& a, const ComplexMatrix& b) { - return xgemm (false, false, a, false, false, b); + return xgemm (a, b); } // FIXME -- it would be nice to share code among the min/max
--- a/liboctave/CMatrix.h +++ b/liboctave/CMatrix.h @@ -406,8 +406,9 @@ Sylvester (const ComplexMatrix&, const ComplexMatrix&, const ComplexMatrix&); extern OCTAVE_API ComplexMatrix -xgemm (bool transa, bool conja, const ComplexMatrix& a, - bool transb, bool conjb, const ComplexMatrix& b); +xgemm (const ComplexMatrix& a, const ComplexMatrix& b, + blas_trans_type transa = blas_no_trans, + blas_trans_type transb = blas_no_trans); extern OCTAVE_API ComplexMatrix operator * (const Matrix&, const ComplexMatrix&); extern OCTAVE_API ComplexMatrix operator * (const ComplexMatrix&, const Matrix&);
--- a/liboctave/ChangeLog +++ b/liboctave/ChangeLog @@ -1,3 +1,18 @@ +2009-09-26 Jaroslav Hajek <highegg@gmail.com> + + * dMatrix.cc (xgemm): Use blas_trans_type to indicate transposes. + (operator *(const Matrix&, const Matrix&)): Update. + * fMatrix.cc (xgemm): Use blas_trans_type to indicate transposes. + (operator *(const FloatMatrix&, const FloatMatrix&)): Update. + * CMatrix.cc (xgemm): Use blas_trans_type to indicate transposes. + (operator *(const ComplexMatrix&, const ComplexMatrix&)): Update. + * fCMatrix.cc (xgemm): Use blas_trans_type to indicate transposes. + (operator *(const FloatComplexMatrix&, const FloatComplexMatrix&)): Update. + * dMatrix.h: Update decl. + * fMatrix.h: Update decl. + * CMatrix.h: Update decl. + * fCMatrix.h: Update decl. + 2009-09-23 Jaroslav Hajek <highegg@gmail.com> * CMatrix.cc (ComplexMatrix::ComplexMatrix (const Matrix&, const
--- a/liboctave/dMatrix.cc +++ b/liboctave/dMatrix.cc @@ -3204,15 +3204,18 @@ // the general GEMM operation Matrix -xgemm (bool transa, const Matrix& a, bool transb, const Matrix& b) +xgemm (const Matrix& a, const Matrix& b, + blas_trans_type transa, blas_trans_type transb) { Matrix retval; - octave_idx_type a_nr = transa ? a.cols () : a.rows (); - octave_idx_type a_nc = transa ? a.rows () : a.cols (); - - octave_idx_type b_nr = transb ? b.cols () : b.rows (); - octave_idx_type b_nc = transb ? b.rows () : b.cols (); + bool tra = transa != blas_no_trans, trb = transb != blas_no_trans; + + octave_idx_type a_nr = tra ? a.cols () : a.rows (); + octave_idx_type a_nc = tra ? a.rows () : a.cols (); + + octave_idx_type b_nr = trb ? b.cols () : b.rows (); + octave_idx_type b_nc = trb ? b.rows () : b.cols (); if (a_nc != b_nr) gripe_nonconformant ("operator *", a_nr, a_nc, b_nr, b_nc); @@ -3220,16 +3223,16 @@ { if (a_nr == 0 || a_nc == 0 || b_nc == 0) retval = Matrix (a_nr, b_nc, 0.0); - else if (a.data () == b.data () && a_nr == b_nc && transa != transb) + else if (a.data () == b.data () && a_nr == b_nc && tra != trb) { octave_idx_type lda = a.rows (); retval = Matrix (a_nr, b_nc); double *c = retval.fortran_vec (); - const char *ctransa = get_blas_trans_arg (transa); + const char *ctra = get_blas_trans_arg (tra); F77_XFCN (dsyrk, DSYRK, (F77_CONST_CHAR_ARG2 ("U", 1), - F77_CONST_CHAR_ARG2 (ctransa, 1), + F77_CONST_CHAR_ARG2 (ctra, 1), a_nr, a_nc, 1.0, a.data (), lda, 0.0, c, a_nr F77_CHAR_ARG_LEN (1) @@ -3253,8 +3256,8 @@ F77_FUNC (xddot, XDDOT) (a_nc, a.data (), 1, b.data (), 1, *c); else { - const char *ctransa = get_blas_trans_arg (transa); - F77_XFCN (dgemv, DGEMV, (F77_CONST_CHAR_ARG2 (ctransa, 1), + const char *ctra = get_blas_trans_arg (tra); + F77_XFCN (dgemv, DGEMV, (F77_CONST_CHAR_ARG2 (ctra, 1), lda, tda, 1.0, a.data (), lda, b.data (), 1, 0.0, c, 1 F77_CHAR_ARG_LEN (1))); @@ -3262,18 +3265,18 @@ } else if (a_nr == 1) { - const char *crevtransb = get_blas_trans_arg (! transb); - F77_XFCN (dgemv, DGEMV, (F77_CONST_CHAR_ARG2 (crevtransb, 1), + const char *crevtrb = get_blas_trans_arg (! trb); + F77_XFCN (dgemv, DGEMV, (F77_CONST_CHAR_ARG2 (crevtrb, 1), ldb, tdb, 1.0, b.data (), ldb, a.data (), 1, 0.0, c, 1 F77_CHAR_ARG_LEN (1))); } else { - const char *ctransa = get_blas_trans_arg (transa); - const char *ctransb = get_blas_trans_arg (transb); - F77_XFCN (dgemm, DGEMM, (F77_CONST_CHAR_ARG2 (ctransa, 1), - F77_CONST_CHAR_ARG2 (ctransb, 1), + const char *ctra = get_blas_trans_arg (tra); + const char *ctrb = get_blas_trans_arg (trb); + F77_XFCN (dgemm, DGEMM, (F77_CONST_CHAR_ARG2 (ctra, 1), + F77_CONST_CHAR_ARG2 (ctrb, 1), a_nr, b_nc, a_nc, 1.0, a.data (), lda, b.data (), ldb, 0.0, c, a_nr F77_CHAR_ARG_LEN (1) @@ -3288,7 +3291,7 @@ Matrix operator * (const Matrix& a, const Matrix& b) { - return xgemm (false, a, false, b); + return xgemm (a, b); } // FIXME -- it would be nice to share code among the min/max
--- a/liboctave/dMatrix.h +++ b/liboctave/dMatrix.h @@ -354,7 +354,9 @@ extern OCTAVE_API Matrix Sylvester (const Matrix&, const Matrix&, const Matrix&); -extern OCTAVE_API Matrix xgemm (bool transa, const Matrix& a, bool transb, const Matrix& b); +extern OCTAVE_API Matrix xgemm (const Matrix& a, const Matrix& b, + blas_trans_type transa = blas_no_trans, + blas_trans_type transb = blas_no_trans); extern OCTAVE_API Matrix operator * (const Matrix& a, const Matrix& b);
--- a/liboctave/fCMatrix.cc +++ b/liboctave/fCMatrix.cc @@ -3777,20 +3777,19 @@ // the general GEMM operation FloatComplexMatrix -xgemm (bool transa, bool conja, const FloatComplexMatrix& a, - bool transb, bool conjb, const FloatComplexMatrix& b) +xgemm (const FloatComplexMatrix& a, const FloatComplexMatrix& b, + blas_trans_type transa, blas_trans_type transb) { FloatComplexMatrix retval; - // conjugacy is ignored if no transpose - conja = conja && transa; - conjb = conjb && transb; - - octave_idx_type a_nr = transa ? a.cols () : a.rows (); - octave_idx_type a_nc = transa ? a.rows () : a.cols (); - - octave_idx_type b_nr = transb ? b.cols () : b.rows (); - octave_idx_type b_nc = transb ? b.rows () : b.cols (); + bool tra = transa != blas_no_trans, trb = transb != blas_no_trans; + bool cja = transa == blas_conj_trans, cjb = transb == blas_conj_trans; + + octave_idx_type a_nr = tra ? a.cols () : a.rows (); + octave_idx_type a_nc = tra ? a.rows () : a.cols (); + + octave_idx_type b_nr = trb ? b.cols () : b.rows (); + octave_idx_type b_nc = trb ? b.rows () : b.cols (); if (a_nc != b_nr) gripe_nonconformant ("operator *", a_nr, a_nc, b_nr, b_nc); @@ -3798,18 +3797,18 @@ { if (a_nr == 0 || a_nc == 0 || b_nc == 0) retval = FloatComplexMatrix (a_nr, b_nc, 0.0); - else if (a.data () == b.data () && a_nr == b_nc && transa != transb) + else if (a.data () == b.data () && a_nr == b_nc && tra != trb) { octave_idx_type lda = a.rows (); retval = FloatComplexMatrix (a_nr, b_nc); FloatComplex *c = retval.fortran_vec (); - const char *ctransa = get_blas_trans_arg (transa, conja); - if (conja || conjb) + const char *ctra = get_blas_trans_arg (tra, cja); + if (cja || cjb) { F77_XFCN (cherk, CHERK, (F77_CONST_CHAR_ARG2 ("U", 1), - F77_CONST_CHAR_ARG2 (ctransa, 1), + F77_CONST_CHAR_ARG2 (ctra, 1), a_nr, a_nc, 1.0, a.data (), lda, 0.0, c, a_nr F77_CHAR_ARG_LEN (1) @@ -3821,7 +3820,7 @@ else { F77_XFCN (csyrk, CSYRK, (F77_CONST_CHAR_ARG2 ("U", 1), - F77_CONST_CHAR_ARG2 (ctransa, 1), + F77_CONST_CHAR_ARG2 (ctra, 1), a_nr, a_nc, 1.0, a.data (), lda, 0.0, c, a_nr F77_CHAR_ARG_LEN (1) @@ -3843,38 +3842,38 @@ if (b_nc == 1 && a_nr == 1) { - if (conja == conjb) + if (cja == cjb) { F77_FUNC (xcdotu, XCDOTU) (a_nc, a.data (), 1, b.data (), 1, *c); - if (conja) *c = std::conj (*c); + if (cja) *c = std::conj (*c); } - else if (conja) + else if (cja) F77_FUNC (xcdotc, XCDOTC) (a_nc, a.data (), 1, b.data (), 1, *c); else F77_FUNC (xcdotc, XCDOTC) (a_nc, b.data (), 1, a.data (), 1, *c); } - else if (b_nc == 1 && ! conjb) + else if (b_nc == 1 && ! cjb) { - const char *ctransa = get_blas_trans_arg (transa, conja); - F77_XFCN (cgemv, CGEMV, (F77_CONST_CHAR_ARG2 (ctransa, 1), + const char *ctra = get_blas_trans_arg (tra, cja); + F77_XFCN (cgemv, CGEMV, (F77_CONST_CHAR_ARG2 (ctra, 1), lda, tda, 1.0, a.data (), lda, b.data (), 1, 0.0, c, 1 F77_CHAR_ARG_LEN (1))); } - else if (a_nr == 1 && ! conja && ! conjb) + else if (a_nr == 1 && ! cja && ! cjb) { - const char *crevtransb = get_blas_trans_arg (! transb, conjb); - F77_XFCN (cgemv, CGEMV, (F77_CONST_CHAR_ARG2 (crevtransb, 1), + const char *crevtrb = get_blas_trans_arg (! trb, cjb); + F77_XFCN (cgemv, CGEMV, (F77_CONST_CHAR_ARG2 (crevtrb, 1), ldb, tdb, 1.0, b.data (), ldb, a.data (), 1, 0.0, c, 1 F77_CHAR_ARG_LEN (1))); } else { - const char *ctransa = get_blas_trans_arg (transa, conja); - const char *ctransb = get_blas_trans_arg (transb, conjb); - F77_XFCN (cgemm, CGEMM, (F77_CONST_CHAR_ARG2 (ctransa, 1), - F77_CONST_CHAR_ARG2 (ctransb, 1), + const char *ctra = get_blas_trans_arg (tra, cja); + const char *ctrb = get_blas_trans_arg (trb, cjb); + F77_XFCN (cgemm, CGEMM, (F77_CONST_CHAR_ARG2 (ctra, 1), + F77_CONST_CHAR_ARG2 (ctrb, 1), a_nr, b_nc, a_nc, 1.0, a.data (), lda, b.data (), ldb, 0.0, c, a_nr F77_CHAR_ARG_LEN (1) @@ -3889,7 +3888,7 @@ FloatComplexMatrix operator * (const FloatComplexMatrix& a, const FloatComplexMatrix& b) { - return xgemm (false, false, a, false, false, b); + return xgemm (a, b); } // FIXME -- it would be nice to share code among the min/max
--- a/liboctave/fCMatrix.h +++ b/liboctave/fCMatrix.h @@ -406,8 +406,9 @@ Sylvester (const FloatComplexMatrix&, const FloatComplexMatrix&, const FloatComplexMatrix&); extern OCTAVE_API FloatComplexMatrix -xgemm (bool transa, bool conja, const FloatComplexMatrix& a, - bool transb, bool conjb, const FloatComplexMatrix& b); +xgemm (const FloatComplexMatrix& a, const FloatComplexMatrix& b, + blas_trans_type transa = blas_no_trans, + blas_trans_type transb = blas_no_trans); extern OCTAVE_API FloatComplexMatrix operator * (const FloatMatrix&, const FloatComplexMatrix&); extern OCTAVE_API FloatComplexMatrix operator * (const FloatComplexMatrix&, const FloatMatrix&);
--- a/liboctave/fMatrix.cc +++ b/liboctave/fMatrix.cc @@ -3203,15 +3203,18 @@ // the general GEMM operation FloatMatrix -xgemm (bool transa, const FloatMatrix& a, bool transb, const FloatMatrix& b) +xgemm (const FloatMatrix& a, const FloatMatrix& b, + blas_trans_type transa, blas_trans_type transb) { FloatMatrix retval; - octave_idx_type a_nr = transa ? a.cols () : a.rows (); - octave_idx_type a_nc = transa ? a.rows () : a.cols (); - - octave_idx_type b_nr = transb ? b.cols () : b.rows (); - octave_idx_type b_nc = transb ? b.rows () : b.cols (); + bool tra = transa != blas_no_trans, trb = transb != blas_no_trans; + + octave_idx_type a_nr = tra ? a.cols () : a.rows (); + octave_idx_type a_nc = tra ? a.rows () : a.cols (); + + octave_idx_type b_nr = trb ? b.cols () : b.rows (); + octave_idx_type b_nc = trb ? b.rows () : b.cols (); if (a_nc != b_nr) gripe_nonconformant ("operator *", a_nr, a_nc, b_nr, b_nc); @@ -3219,16 +3222,16 @@ { if (a_nr == 0 || a_nc == 0 || b_nc == 0) retval = FloatMatrix (a_nr, b_nc, 0.0); - else if (a.data () == b.data () && a_nr == b_nc && transa != transb) + else if (a.data () == b.data () && a_nr == b_nc && tra != trb) { octave_idx_type lda = a.rows (); retval = FloatMatrix (a_nr, b_nc); float *c = retval.fortran_vec (); - const char *ctransa = get_blas_trans_arg (transa); + const char *ctra = get_blas_trans_arg (tra); F77_XFCN (ssyrk, SSYRK, (F77_CONST_CHAR_ARG2 ("U", 1), - F77_CONST_CHAR_ARG2 (ctransa, 1), + F77_CONST_CHAR_ARG2 (ctra, 1), a_nr, a_nc, 1.0, a.data (), lda, 0.0, c, a_nr F77_CHAR_ARG_LEN (1) @@ -3252,8 +3255,8 @@ F77_FUNC (xsdot, XSDOT) (a_nc, a.data (), 1, b.data (), 1, *c); else { - const char *ctransa = get_blas_trans_arg (transa); - F77_XFCN (sgemv, SGEMV, (F77_CONST_CHAR_ARG2 (ctransa, 1), + const char *ctra = get_blas_trans_arg (tra); + F77_XFCN (sgemv, SGEMV, (F77_CONST_CHAR_ARG2 (ctra, 1), lda, tda, 1.0, a.data (), lda, b.data (), 1, 0.0, c, 1 F77_CHAR_ARG_LEN (1))); @@ -3261,18 +3264,18 @@ } else if (a_nr == 1) { - const char *crevtransb = get_blas_trans_arg (! transb); - F77_XFCN (sgemv, SGEMV, (F77_CONST_CHAR_ARG2 (crevtransb, 1), + const char *crevtrb = get_blas_trans_arg (! trb); + F77_XFCN (sgemv, SGEMV, (F77_CONST_CHAR_ARG2 (crevtrb, 1), ldb, tdb, 1.0, b.data (), ldb, a.data (), 1, 0.0, c, 1 F77_CHAR_ARG_LEN (1))); } else { - const char *ctransa = get_blas_trans_arg (transa); - const char *ctransb = get_blas_trans_arg (transb); - F77_XFCN (sgemm, SGEMM, (F77_CONST_CHAR_ARG2 (ctransa, 1), - F77_CONST_CHAR_ARG2 (ctransb, 1), + const char *ctra = get_blas_trans_arg (tra); + const char *ctrb = get_blas_trans_arg (trb); + F77_XFCN (sgemm, SGEMM, (F77_CONST_CHAR_ARG2 (ctra, 1), + F77_CONST_CHAR_ARG2 (ctrb, 1), a_nr, b_nc, a_nc, 1.0, a.data (), lda, b.data (), ldb, 0.0, c, a_nr F77_CHAR_ARG_LEN (1) @@ -3287,7 +3290,7 @@ FloatMatrix operator * (const FloatMatrix& a, const FloatMatrix& b) { - return xgemm (false, a, false, b); + return xgemm (a, b); } // FIXME -- it would be nice to share code among the min/max
--- a/liboctave/fMatrix.h +++ b/liboctave/fMatrix.h @@ -354,7 +354,9 @@ extern OCTAVE_API FloatMatrix Sylvester (const FloatMatrix&, const FloatMatrix&, const FloatMatrix&); -extern OCTAVE_API FloatMatrix xgemm (bool transa, const FloatMatrix& a, bool transb, const FloatMatrix& b); +extern OCTAVE_API FloatMatrix xgemm (const FloatMatrix& a, const FloatMatrix& b, + blas_trans_type transa = blas_no_trans, + blas_trans_type transb = blas_no_trans); extern OCTAVE_API FloatMatrix operator * (const FloatMatrix& a, const FloatMatrix& b);
--- a/src/ChangeLog +++ b/src/ChangeLog @@ -1,3 +1,14 @@ +2009-09-26 Jaroslav Hajek <highegg@gmail.com> + + * OPERATORS/op-m-m.cc (trans_mul, mul_trans): Update. + * OPERATORS/op-fm-fm.cc (trans_mul, mul_trans): Update. + * OPERATORS/op-cm-cm.cc (trans_mul, mul_trans, herm_mul, mul_herm): Update. + * OPERATORS/op-fcm-fcm.cc (trans_mul, mul_trans, herm_mul, mul_herm): Update. + * OPERATORS/op-m-cm.cc (trans_mul): Update. + * OPERATORS/op-cm-m.cc (mul_trans): Update. + * OPERATORS/op-fm-fcm.cc (trans_mul): Update. + * OPERATORS/op-fcm-fm.cc (mul_trans): Update. + 2009-09-23 Jaroslav Hajek <highegg@gmail.com> * OPERATORS/op-m-cm.cc: Declare and install trans_mul operator.
--- a/src/OPERATORS/op-cm-cm.cc +++ b/src/OPERATORS/op-cm-cm.cc @@ -112,29 +112,33 @@ DEFBINOP (trans_mul, complex_matrix, complex_matrix) { CAST_BINOP_ARGS (const octave_complex_matrix&, const octave_complex_matrix&); - return octave_value(xgemm (true, false, v1.complex_matrix_value (), - false, false, v2.complex_matrix_value ())); + return octave_value(xgemm (v1.complex_matrix_value (), + v2.complex_matrix_value (), + blas_trans, blas_no_trans)); } DEFBINOP (mul_trans, complex_matrix, complex_matrix) { CAST_BINOP_ARGS (const octave_complex_matrix&, const octave_complex_matrix&); - return octave_value(xgemm (false, false, v1.complex_matrix_value (), - true, false, v2.complex_matrix_value ())); + return octave_value(xgemm (v1.complex_matrix_value (), + v2.complex_matrix_value (), + blas_no_trans, blas_trans)); } DEFBINOP (herm_mul, complex_matrix, complex_matrix) { CAST_BINOP_ARGS (const octave_complex_matrix&, const octave_complex_matrix&); - return octave_value(xgemm (true, true, v1.complex_matrix_value (), - false, false, v2.complex_matrix_value ())); + return octave_value(xgemm (v1.complex_matrix_value (), + v2.complex_matrix_value (), + blas_conj_trans, blas_no_trans)); } DEFBINOP (mul_herm, complex_matrix, complex_matrix) { CAST_BINOP_ARGS (const octave_complex_matrix&, const octave_complex_matrix&); - return octave_value(xgemm (false, false, v1.complex_matrix_value (), - true, true, v2.complex_matrix_value ())); + return octave_value(xgemm (v1.complex_matrix_value (), + v2.complex_matrix_value (), + blas_no_trans, blas_conj_trans)); } DEFBINOP (trans_ldiv, complex_matrix, complex_matrix)
--- a/src/OPERATORS/op-cm-m.cc +++ b/src/OPERATORS/op-cm-m.cc @@ -54,8 +54,8 @@ ComplexMatrix m1 = v1.complex_matrix_value (); Matrix m2 = v2.matrix_value (); - return ComplexMatrix (xgemm (false, real (m1), true, m2), - xgemm (false, imag (m1), true, m2)); + return ComplexMatrix (xgemm (real (m1), m2, blas_no_trans, blas_trans), + xgemm (imag (m1), m2, blas_no_trans, blas_trans)); } DEFBINOP (div, complex_matrix, matrix)
--- a/src/OPERATORS/op-fcm-fcm.cc +++ b/src/OPERATORS/op-fcm-fcm.cc @@ -116,29 +116,33 @@ DEFBINOP (trans_mul, float_complex_matrix, float_complex_matrix) { CAST_BINOP_ARGS (const octave_float_complex_matrix&, const octave_float_complex_matrix&); - return octave_value(xgemm (true, false, v1.float_complex_matrix_value (), - false, false, v2.float_complex_matrix_value ())); + return octave_value(xgemm (v1.float_complex_matrix_value (), + v2.float_complex_matrix_value (), + blas_trans, blas_no_trans)); } DEFBINOP (mul_trans, float_complex_matrix, float_complex_matrix) { CAST_BINOP_ARGS (const octave_float_complex_matrix&, const octave_float_complex_matrix&); - return octave_value(xgemm (false, false, v1.float_complex_matrix_value (), - true, false, v2.float_complex_matrix_value ())); + return octave_value(xgemm (v1.float_complex_matrix_value (), + v2.float_complex_matrix_value (), + blas_no_trans, blas_trans)); } DEFBINOP (herm_mul, float_complex_matrix, float_complex_matrix) { CAST_BINOP_ARGS (const octave_float_complex_matrix&, const octave_float_complex_matrix&); - return octave_value(xgemm (true, true, v1.float_complex_matrix_value (), - false, false, v2.float_complex_matrix_value ())); + return octave_value(xgemm (v1.float_complex_matrix_value (), + v2.float_complex_matrix_value (), + blas_conj_trans, blas_no_trans)); } DEFBINOP (mul_herm, float_complex_matrix, float_complex_matrix) { CAST_BINOP_ARGS (const octave_float_complex_matrix&, const octave_float_complex_matrix&); - return octave_value(xgemm (false, false, v1.float_complex_matrix_value (), - true, true, v2.float_complex_matrix_value ())); + return octave_value(xgemm (v1.float_complex_matrix_value (), + v2.float_complex_matrix_value (), + blas_no_trans, blas_conj_trans)); } DEFBINOP (trans_ldiv, float_complex_matrix, float_complex_matrix)
--- a/src/OPERATORS/op-fcm-fm.cc +++ b/src/OPERATORS/op-fcm-fm.cc @@ -56,8 +56,8 @@ FloatComplexMatrix m1 = v1.float_complex_matrix_value (); FloatMatrix m2 = v2.float_matrix_value (); - return FloatComplexMatrix (xgemm (false, real (m1), true, m2), - xgemm (false, imag (m1), true, m2)); + return FloatComplexMatrix (xgemm (real (m1), m2, blas_no_trans, blas_trans), + xgemm (imag (m1), m2, blas_no_trans, blas_trans)); } DEFBINOP (div, float_complex_matrix, float_matrix)
--- a/src/OPERATORS/op-fm-fcm.cc +++ b/src/OPERATORS/op-fm-fcm.cc @@ -58,8 +58,8 @@ FloatMatrix m1 = v1.float_matrix_value (); FloatComplexMatrix m2 = v2.float_complex_matrix_value (); - return FloatComplexMatrix (xgemm (true, m1, false, real (m2)), - xgemm (true, m1, false, imag (m2))); + return FloatComplexMatrix (xgemm (m1, real (m2), blas_trans, blas_no_trans), + xgemm (m1, imag (m2), blas_trans, blas_no_trans)); } DEFBINOP (div, float_matrix, float_complex_matrix)
--- a/src/OPERATORS/op-fm-fm.cc +++ b/src/OPERATORS/op-fm-fm.cc @@ -99,15 +99,17 @@ DEFBINOP (trans_mul, float_matrix, float_matrix) { CAST_BINOP_ARGS (const octave_float_matrix&, const octave_float_matrix&); - return octave_value(xgemm (true, v1.float_matrix_value (), - false, v2.float_matrix_value ())); + return octave_value(xgemm (v1.float_matrix_value (), + v2.float_matrix_value (), + blas_trans, blas_no_trans)); } DEFBINOP (mul_trans, float_matrix, float_matrix) { CAST_BINOP_ARGS (const octave_float_matrix&, const octave_float_matrix&); - return octave_value(xgemm (false, v1.float_matrix_value (), - true, v2.float_matrix_value ())); + return octave_value(xgemm (v1.float_matrix_value (), + v2.float_matrix_value (), + blas_no_trans, blas_trans)); } DEFBINOP (trans_ldiv, float_matrix, float_matrix)
--- a/src/OPERATORS/op-m-cm.cc +++ b/src/OPERATORS/op-m-cm.cc @@ -56,8 +56,8 @@ Matrix m1 = v1.matrix_value (); ComplexMatrix m2 = v2.complex_matrix_value (); - return ComplexMatrix (xgemm (true, m1, false, real (m2)), - xgemm (true, m1, false, imag (m2))); + return ComplexMatrix (xgemm (m1, real (m2), blas_trans, blas_no_trans), + xgemm (m1, imag (m2), blas_trans, blas_no_trans)); } DEFBINOP (div, matrix, complex_matrix)
--- a/src/OPERATORS/op-m-m.cc +++ b/src/OPERATORS/op-m-m.cc @@ -97,13 +97,15 @@ DEFBINOP (trans_mul, matrix, matrix) { CAST_BINOP_ARGS (const octave_matrix&, const octave_matrix&); - return octave_value(xgemm (true, v1.matrix_value (), false, v2.matrix_value ())); + return octave_value(xgemm (v1.matrix_value (), v2.matrix_value (), + blas_trans, blas_no_trans)); } DEFBINOP (mul_trans, matrix, matrix) { CAST_BINOP_ARGS (const octave_matrix&, const octave_matrix&); - return octave_value(xgemm (false, v1.matrix_value (), true, v2.matrix_value ())); + return octave_value(xgemm (v1.matrix_value (), v2.matrix_value (), + blas_no_trans, blas_trans)); } DEFBINOP (trans_ldiv, matrix, matrix)