Mercurial > hg > octave-nkf
diff liboctave/dMatrix.cc @ 9665:1dba57e9d08d
use blas_trans_type for xgemm
author | Jaroslav Hajek <highegg@gmail.com> |
---|---|
date | Sat, 26 Sep 2009 10:41:07 +0200 |
parents | 0d3b248f4ab6 |
children | f80c566bc751 |
line wrap: on
line diff
--- a/liboctave/dMatrix.cc +++ b/liboctave/dMatrix.cc @@ -3204,15 +3204,18 @@ // the general GEMM operation Matrix -xgemm (bool transa, const Matrix& a, bool transb, const Matrix& b) +xgemm (const Matrix& a, const Matrix& b, + blas_trans_type transa, blas_trans_type transb) { Matrix retval; - octave_idx_type a_nr = transa ? a.cols () : a.rows (); - octave_idx_type a_nc = transa ? a.rows () : a.cols (); - - octave_idx_type b_nr = transb ? b.cols () : b.rows (); - octave_idx_type b_nc = transb ? b.rows () : b.cols (); + bool tra = transa != blas_no_trans, trb = transb != blas_no_trans; + + octave_idx_type a_nr = tra ? a.cols () : a.rows (); + octave_idx_type a_nc = tra ? a.rows () : a.cols (); + + octave_idx_type b_nr = trb ? b.cols () : b.rows (); + octave_idx_type b_nc = trb ? b.rows () : b.cols (); if (a_nc != b_nr) gripe_nonconformant ("operator *", a_nr, a_nc, b_nr, b_nc); @@ -3220,16 +3223,16 @@ { if (a_nr == 0 || a_nc == 0 || b_nc == 0) retval = Matrix (a_nr, b_nc, 0.0); - else if (a.data () == b.data () && a_nr == b_nc && transa != transb) + else if (a.data () == b.data () && a_nr == b_nc && tra != trb) { octave_idx_type lda = a.rows (); retval = Matrix (a_nr, b_nc); double *c = retval.fortran_vec (); - const char *ctransa = get_blas_trans_arg (transa); + const char *ctra = get_blas_trans_arg (tra); F77_XFCN (dsyrk, DSYRK, (F77_CONST_CHAR_ARG2 ("U", 1), - F77_CONST_CHAR_ARG2 (ctransa, 1), + F77_CONST_CHAR_ARG2 (ctra, 1), a_nr, a_nc, 1.0, a.data (), lda, 0.0, c, a_nr F77_CHAR_ARG_LEN (1) @@ -3253,8 +3256,8 @@ F77_FUNC (xddot, XDDOT) (a_nc, a.data (), 1, b.data (), 1, *c); else { - const char *ctransa = get_blas_trans_arg (transa); - F77_XFCN (dgemv, DGEMV, (F77_CONST_CHAR_ARG2 (ctransa, 1), + const char *ctra = get_blas_trans_arg (tra); + F77_XFCN (dgemv, DGEMV, (F77_CONST_CHAR_ARG2 (ctra, 1), lda, tda, 1.0, a.data (), lda, b.data (), 1, 0.0, c, 1 F77_CHAR_ARG_LEN (1))); @@ -3262,18 +3265,18 @@ } else if (a_nr == 1) { - const char *crevtransb = get_blas_trans_arg (! transb); - F77_XFCN (dgemv, DGEMV, (F77_CONST_CHAR_ARG2 (crevtransb, 1), + const char *crevtrb = get_blas_trans_arg (! trb); + F77_XFCN (dgemv, DGEMV, (F77_CONST_CHAR_ARG2 (crevtrb, 1), ldb, tdb, 1.0, b.data (), ldb, a.data (), 1, 0.0, c, 1 F77_CHAR_ARG_LEN (1))); } else { - const char *ctransa = get_blas_trans_arg (transa); - const char *ctransb = get_blas_trans_arg (transb); - F77_XFCN (dgemm, DGEMM, (F77_CONST_CHAR_ARG2 (ctransa, 1), - F77_CONST_CHAR_ARG2 (ctransb, 1), + const char *ctra = get_blas_trans_arg (tra); + const char *ctrb = get_blas_trans_arg (trb); + F77_XFCN (dgemm, DGEMM, (F77_CONST_CHAR_ARG2 (ctra, 1), + F77_CONST_CHAR_ARG2 (ctrb, 1), a_nr, b_nc, a_nc, 1.0, a.data (), lda, b.data (), ldb, 0.0, c, a_nr F77_CHAR_ARG_LEN (1) @@ -3288,7 +3291,7 @@ Matrix operator * (const Matrix& a, const Matrix& b) { - return xgemm (false, a, false, b); + return xgemm (a, b); } // FIXME -- it would be nice to share code among the min/max