Mercurial > hg > octave-lyh
diff liboctave/CMatrix.cc @ 7800:5861b95e9879
support for compound operators, implement trans_mul, mul_trans, herm_mul and mul_herm
author | Jaroslav Hajek <highegg@gmail.com> |
---|---|
date | Wed, 07 May 2008 16:33:15 +0200 |
parents | 82be108cc558 |
children | 776791438957 |
line wrap: on
line diff
--- a/liboctave/CMatrix.cc
+++ b/liboctave/CMatrix.cc
@@ -108,6 +108,10 @@
                              const Complex*, const octave_idx_type&, Complex&);
 
   F77_RET_T
+  F77_FUNC (xzdotc, XZDOTC) (const octave_idx_type&, const Complex*, const octave_idx_type&,
+                             const Complex*, const octave_idx_type&, Complex&);
+
+  F77_RET_T
   F77_FUNC (zgetrf, ZGETRF) (const octave_idx_type&, const octave_idx_type&, Complex*, const octave_idx_type&,
                              octave_idx_type*, octave_idx_type&);
 
@@ -3950,49 +3954,81 @@
 %!assert(2*rv*cv,[rv,rv]*[cv;cv],1e-14)
 */
 
+static const char *
+get_blas_trans_arg (bool trans, bool conj)
+{
+  static char blas_notrans = 'N', blas_trans = 'T', blas_conj_trans = 'C';
+  return trans ? (conj ? &blas_conj_trans : &blas_trans) : &blas_notrans;
+}
+
+// the general GEMM operation
+
 ComplexMatrix
-operator * (const ComplexMatrix& m, const ComplexMatrix& a)
+xgemm (bool transa, bool conja, const ComplexMatrix& a,
+       bool transb, bool conjb, const ComplexMatrix& b)
 {
   ComplexMatrix retval;
 
-  octave_idx_type nr = m.rows ();
-  octave_idx_type nc = m.cols ();
-
-  octave_idx_type a_nr = a.rows ();
-  octave_idx_type a_nc = a.cols ();
-
-  if (nc != a_nr)
-    gripe_nonconformant ("operator *", nr, nc, a_nr, a_nc);
+  // conjugacy is ignored if no transpose
+  conja = conja && transa;
+  conjb = conjb && transb;
+
+  octave_idx_type a_nr = transa ? a.cols () : a.rows ();
+  octave_idx_type a_nc = transa ? a.rows () : a.cols ();
+
+  octave_idx_type b_nr = transb ? b.cols () : b.rows ();
+  octave_idx_type b_nc = transb ? b.rows () : b.cols ();
+
+  if (a_nc != b_nr)
+    gripe_nonconformant ("operator *", a_nr, a_nc, b_nr, b_nc);
   else
     {
-      if (nr == 0 || nc == 0 || a_nc == 0)
-        retval.resize (nr, a_nc, 0.0);
+      if (a_nr == 0 || a_nc == 0 || b_nc == 0)
+        retval.resize (a_nr, b_nc, 0.0);
       else
         {
-          octave_idx_type ld = nr;
-          octave_idx_type lda = a.rows ();
-
-          retval.resize (nr, a_nc);
+          octave_idx_type lda = a.rows (), tda = a.cols ();
+          octave_idx_type ldb = b.rows (), tdb = b.cols ();
+
+          retval.resize (a_nr, b_nc);
           Complex *c = retval.fortran_vec ();
 
-          if (a_nc == 1)
+          if (b_nc == 1 && a_nr == 1)
             {
-              if (nr == 1)
-                F77_FUNC (xzdotu, XZDOTU) (nc, m.data (), 1, a.data (), 1, *c);
-              else
-                {
-                  F77_XFCN (zgemv, ZGEMV, (F77_CONST_CHAR_ARG2 ("N", 1),
-                                           nr, nc, 1.0, m.data (), ld,
-                                           a.data (), 1, 0.0, c, 1
-                                           F77_CHAR_ARG_LEN (1)));
-                }
-            }
+              if (conja == conjb)
+                {
+                  F77_FUNC (xzdotu, XZDOTU) (a_nc, a.data (), 1, b.data (), 1, *c);
+                  if (conja) *c = std::conj (*c);
+                }
+              else if (conjb)
+                F77_FUNC (xzdotc, XZDOTC) (a_nc, a.data (), 1, b.data (), 1, *c);
+              else
+                F77_FUNC (xzdotc, XZDOTC) (a_nc, b.data (), 1, a.data (), 1, *c);
+            }
+          else if (b_nc == 1 && ! conjb)
+            {
+              const char *ctransa = get_blas_trans_arg (transa, conja);
+              F77_XFCN (zgemv, ZGEMV, (F77_CONST_CHAR_ARG2 (ctransa, 1),
+                                       lda, tda, 1.0, a.data (), lda,
+                                       b.data (), 1, 0.0, c, 1
+                                       F77_CHAR_ARG_LEN (1)));
+            }
+          else if (a_nr == 1 && ! conja)
+            {
+              const char *crevtransb = get_blas_trans_arg (! transb, conjb);
+              F77_XFCN (zgemv, ZGEMV, (F77_CONST_CHAR_ARG2 (crevtransb, 1),
+                                       ldb, tdb, 1.0, b.data (), ldb,
+                                       a.data (), 1, 0.0, c, 1
+                                       F77_CHAR_ARG_LEN (1)));
+            }
           else
             {
-              F77_XFCN (zgemm, ZGEMM, (F77_CONST_CHAR_ARG2 ("N", 1),
-                                       F77_CONST_CHAR_ARG2 ("N", 1),
-                                       nr, a_nc, nc, 1.0, m.data (),
-                                       ld, a.data (), lda, 0.0, c, nr
+              const char *ctransa = get_blas_trans_arg (transa, conja);
+              const char *ctransb = get_blas_trans_arg (transb, conjb);
+              F77_XFCN (zgemm, ZGEMM, (F77_CONST_CHAR_ARG2 (ctransa, 1),
+                                       F77_CONST_CHAR_ARG2 (ctransb, 1),
+                                       a_nr, b_nc, a_nc, 1.0, a.data (),
+                                       lda, b.data (), ldb, 0.0, c, a_nr
                                        F77_CHAR_ARG_LEN (1)
                                        F77_CHAR_ARG_LEN (1)));
             }
@@ -4002,6 +4038,12 @@
   return retval;
 }
 
+ComplexMatrix
+operator * (const ComplexMatrix& a, const ComplexMatrix& b)
+{
+  return xgemm (false, false, a, false, false, b);
+}
+
 // FIXME -- it would be nice to share code among the min/max
 // functions below.
 