Mercurial > hg > octave-lyh
diff liboctave/fCMatrix.cc @ 9665:1dba57e9d08d
use blas_trans_type for xgemm
author | Jaroslav Hajek <highegg@gmail.com> |
---|---|
date | Sat, 26 Sep 2009 10:41:07 +0200 |
parents | 7e5b4de5fbfe |
children | f80c566bc751 |
line wrap: on
line diff
--- a/liboctave/fCMatrix.cc +++ b/liboctave/fCMatrix.cc @@ -3777,20 +3777,19 @@ // the general GEMM operation FloatComplexMatrix -xgemm (bool transa, bool conja, const FloatComplexMatrix& a, - bool transb, bool conjb, const FloatComplexMatrix& b) +xgemm (const FloatComplexMatrix& a, const FloatComplexMatrix& b, + blas_trans_type transa, blas_trans_type transb) { FloatComplexMatrix retval; - // conjugacy is ignored if no transpose - conja = conja && transa; - conjb = conjb && transb; - - octave_idx_type a_nr = transa ? a.cols () : a.rows (); - octave_idx_type a_nc = transa ? a.rows () : a.cols (); - - octave_idx_type b_nr = transb ? b.cols () : b.rows (); - octave_idx_type b_nc = transb ? b.rows () : b.cols (); + bool tra = transa != blas_no_trans, trb = transb != blas_no_trans; + bool cja = transa == blas_conj_trans, cjb = transb == blas_conj_trans; + + octave_idx_type a_nr = tra ? a.cols () : a.rows (); + octave_idx_type a_nc = tra ? a.rows () : a.cols (); + + octave_idx_type b_nr = trb ? b.cols () : b.rows (); + octave_idx_type b_nc = trb ? b.rows () : b.cols (); if (a_nc != b_nr) gripe_nonconformant ("operator *", a_nr, a_nc, b_nr, b_nc); @@ -3798,18 +3797,18 @@ { if (a_nr == 0 || a_nc == 0 || b_nc == 0) retval = FloatComplexMatrix (a_nr, b_nc, 0.0); - else if (a.data () == b.data () && a_nr == b_nc && transa != transb) + else if (a.data () == b.data () && a_nr == b_nc && tra != trb) { octave_idx_type lda = a.rows (); retval = FloatComplexMatrix (a_nr, b_nc); FloatComplex *c = retval.fortran_vec (); - const char *ctransa = get_blas_trans_arg (transa, conja); - if (conja || conjb) + const char *ctra = get_blas_trans_arg (tra, cja); + if (cja || cjb) { F77_XFCN (cherk, CHERK, (F77_CONST_CHAR_ARG2 ("U", 1), - F77_CONST_CHAR_ARG2 (ctransa, 1), + F77_CONST_CHAR_ARG2 (ctra, 1), a_nr, a_nc, 1.0, a.data (), lda, 0.0, c, a_nr F77_CHAR_ARG_LEN (1) @@ -3821,7 +3820,7 @@ else { F77_XFCN (csyrk, CSYRK, (F77_CONST_CHAR_ARG2 ("U", 1), - F77_CONST_CHAR_ARG2 (ctransa, 1), + F77_CONST_CHAR_ARG2 (ctra, 1), a_nr, a_nc, 1.0, a.data (), lda, 0.0, c, a_nr F77_CHAR_ARG_LEN (1) @@ -3843,38 +3842,38 @@ if (b_nc == 1 && a_nr == 1) { - if (conja == conjb) + if (cja == cjb) { F77_FUNC (xcdotu, XCDOTU) (a_nc, a.data (), 1, b.data (), 1, *c); - if (conja) *c = std::conj (*c); + if (cja) *c = std::conj (*c); } - else if (conja) + else if (cja) F77_FUNC (xcdotc, XCDOTC) (a_nc, a.data (), 1, b.data (), 1, *c); else F77_FUNC (xcdotc, XCDOTC) (a_nc, b.data (), 1, a.data (), 1, *c); } - else if (b_nc == 1 && ! conjb) + else if (b_nc == 1 && ! cjb) { - const char *ctransa = get_blas_trans_arg (transa, conja); - F77_XFCN (cgemv, CGEMV, (F77_CONST_CHAR_ARG2 (ctransa, 1), + const char *ctra = get_blas_trans_arg (tra, cja); + F77_XFCN (cgemv, CGEMV, (F77_CONST_CHAR_ARG2 (ctra, 1), lda, tda, 1.0, a.data (), lda, b.data (), 1, 0.0, c, 1 F77_CHAR_ARG_LEN (1))); } - else if (a_nr == 1 && ! conja && ! conjb) + else if (a_nr == 1 && ! cja && ! cjb) { - const char *crevtransb = get_blas_trans_arg (! transb, conjb); - F77_XFCN (cgemv, CGEMV, (F77_CONST_CHAR_ARG2 (crevtransb, 1), + const char *crevtrb = get_blas_trans_arg (! trb, cjb); + F77_XFCN (cgemv, CGEMV, (F77_CONST_CHAR_ARG2 (crevtrb, 1), ldb, tdb, 1.0, b.data (), ldb, a.data (), 1, 0.0, c, 1 F77_CHAR_ARG_LEN (1))); } else { - const char *ctransa = get_blas_trans_arg (transa, conja); - const char *ctransb = get_blas_trans_arg (transb, conjb); - F77_XFCN (cgemm, CGEMM, (F77_CONST_CHAR_ARG2 (ctransa, 1), - F77_CONST_CHAR_ARG2 (ctransb, 1), + const char *ctra = get_blas_trans_arg (tra, cja); + const char *ctrb = get_blas_trans_arg (trb, cjb); + F77_XFCN (cgemm, CGEMM, (F77_CONST_CHAR_ARG2 (ctra, 1), + F77_CONST_CHAR_ARG2 (ctrb, 1), a_nr, b_nc, a_nc, 1.0, a.data (), lda, b.data (), ldb, 0.0, c, a_nr F77_CHAR_ARG_LEN (1) @@ -3889,7 +3888,7 @@ FloatComplexMatrix operator * (const FloatComplexMatrix& a, const FloatComplexMatrix& b) { - return xgemm (false, false, a, false, false, b); + return xgemm (a, b); } // FIXME -- it would be nice to share code among the min/max