Mercurial > hg > octave-lyh
changeset 7804:a0c550b22e61
compound ops for float matrices
author | Jaroslav Hajek <highegg@gmail.com> |
---|---|
date | Wed, 21 May 2008 19:25:08 +0200 |
parents | 9bcb31cc56be |
children | 62affb34e648 |
files | liboctave/ChangeLog liboctave/fCMatrix.cc liboctave/fCMatrix.h liboctave/fMatrix.cc liboctave/fMatrix.h src/ChangeLog src/OPERATORS/op-fcm-fcm.cc src/OPERATORS/op-fm-fm.cc |
diffstat | 8 files changed, 284 insertions(+), 57 deletions(-) [+] |
line wrap: on
line diff
--- a/liboctave/ChangeLog +++ b/liboctave/ChangeLog @@ -1,5 +1,17 @@ 2008-05-21 Jaroslav Hajek <highegg@gmail.com> + * fCMatrix.h (xgemm): Provide decl. + (xcdotc, csyrk, cherk): New F77 decls. + * fMatrix.cc (xgemm): New function. + (operator * (const FloatMatrix&, const FloatMatrix&)): Simplify. + (get_blas_trans_arg): New function. + * fCMatrix.h (xgemm): Provide decl. + (ssyrk): New F77 decl. + * fCMatrix.cc (xgemm): New function. + (operator * (const FloatComplexMatrix&, const + FloatComplexMatrix&)): Simplify. + (get_blas_trans_arg): New function. + * dMatrix.cc, CMatrix.cc, Sparse-op-defs.h: Add missing copyright. * Sparse-op-defs.h (SPARSE_FULL_MUL): Simplify scalar*matrix case.
--- a/liboctave/fCMatrix.cc +++ b/liboctave/fCMatrix.cc @@ -107,6 +107,28 @@ const FloatComplex*, const octave_idx_type&, FloatComplex&); F77_RET_T + F77_FUNC (xcdotc, XCDOTC) (const octave_idx_type&, const FloatComplex*, const octave_idx_type&, + const FloatComplex*, const octave_idx_type&, FloatComplex&); + + F77_RET_T + F77_FUNC (csyrk, CSYRK) (F77_CONST_CHAR_ARG_DECL, + F77_CONST_CHAR_ARG_DECL, + const octave_idx_type&, const octave_idx_type&, + const FloatComplex&, const FloatComplex*, const octave_idx_type&, + const FloatComplex&, FloatComplex*, const octave_idx_type& + F77_CHAR_ARG_LEN_DECL + F77_CHAR_ARG_LEN_DECL); + + F77_RET_T + F77_FUNC (cherk, CHERK) (F77_CONST_CHAR_ARG_DECL, + F77_CONST_CHAR_ARG_DECL, + const octave_idx_type&, const octave_idx_type&, + const FloatComplex&, const FloatComplex*, const octave_idx_type&, + const FloatComplex&, FloatComplex*, const octave_idx_type& + F77_CHAR_ARG_LEN_DECL + F77_CHAR_ARG_LEN_DECL); + + F77_RET_T F77_FUNC (cgetrf, CGETRF) (const octave_idx_type&, const octave_idx_type&, FloatComplex*, const octave_idx_type&, octave_idx_type*, octave_idx_type&); @@ -3984,49 +4006,116 @@ %!assert(2*rv*cv,[rv,rv]*[cv;cv],1e-14) */ +static const char * +get_blas_trans_arg (bool trans, bool conj) +{ + static char blas_notrans = 'N', blas_trans = 'T', blas_conj_trans = 'C'; + return trans ? (conj ? &blas_conj_trans : &blas_trans) : &blas_notrans; +} + +// the general GEMM operation + FloatComplexMatrix -operator * (const FloatComplexMatrix& m, const FloatComplexMatrix& a) +xgemm (bool transa, bool conja, const FloatComplexMatrix& a, + bool transb, bool conjb, const FloatComplexMatrix& b) { FloatComplexMatrix retval; - octave_idx_type nr = m.rows (); - octave_idx_type nc = m.cols (); - - octave_idx_type a_nr = a.rows (); - octave_idx_type a_nc = a.cols (); - - if (nc != a_nr) - gripe_nonconformant ("operator *", nr, nc, a_nr, a_nc); + // conjugacy is ignored if no transpose + conja = conja && transa; + conjb = conjb && transb; + + octave_idx_type a_nr = transa ? a.cols () : a.rows (); + octave_idx_type a_nc = transa ? a.rows () : a.cols (); + + octave_idx_type b_nr = transb ? b.cols () : b.rows (); + octave_idx_type b_nc = transb ? b.rows () : b.cols (); + + if (a_nc != b_nr) + gripe_nonconformant ("operator *", a_nr, a_nc, b_nr, b_nc); else { - if (nr == 0 || nc == 0 || a_nc == 0) - retval.resize (nr, a_nc, 0.0); + if (a_nr == 0 || a_nc == 0 || b_nc == 0) + retval.resize (a_nr, b_nc, 0.0); + else if (a.data () == b.data () && a_nr == b_nc && transa != transb) + { + octave_idx_type lda = a.rows (); + + retval.resize (a_nr, b_nc); + FloatComplex *c = retval.fortran_vec (); + + const char *ctransa = get_blas_trans_arg (transa, conja); + if (conja || conjb) + { + F77_XFCN (cherk, CHERK, (F77_CONST_CHAR_ARG2 ("U", 1), + F77_CONST_CHAR_ARG2 (ctransa, 1), + a_nr, a_nc, 1.0, + a.data (), lda, 0.0, c, a_nr + F77_CHAR_ARG_LEN (1) + F77_CHAR_ARG_LEN (1))); + for (int j = 0; j < a_nr; j++) + for (int i = 0; i < j; i++) + retval.xelem (j,i) = std::conj (retval.xelem (i,j)); + } + else + { + F77_XFCN (csyrk, CSYRK, (F77_CONST_CHAR_ARG2 ("U", 1), + F77_CONST_CHAR_ARG2 (ctransa, 1), + a_nr, a_nc, 1.0, + a.data (), lda, 0.0, c, a_nr + F77_CHAR_ARG_LEN (1) + F77_CHAR_ARG_LEN (1))); + for (int j = 0; j < a_nr; j++) + for (int i = 0; i < j; i++) + retval.xelem (j,i) = retval.xelem (i,j); + + } + + } else { - octave_idx_type ld = nr; - octave_idx_type lda = a.rows (); - - retval.resize (nr, a_nc); + octave_idx_type lda = a.rows (), tda = a.cols (); + octave_idx_type ldb = b.rows (), tdb = b.cols (); + + retval.resize (a_nr, b_nc); FloatComplex *c = retval.fortran_vec (); - if (a_nc == 1) + if (b_nc == 1 && a_nr == 1) { - if (nr == 1) - F77_FUNC (xcdotu, XCDOTU) (nc, m.data (), 1, a.data (), 1, *c); - else - { - F77_XFCN (cgemv, CGEMV, (F77_CONST_CHAR_ARG2 ("N", 1), - nr, nc, 1.0, m.data (), ld, - a.data (), 1, 0.0, c, 1 - F77_CHAR_ARG_LEN (1))); - } - } + if (conja == conjb) + { + F77_FUNC (xcdotu, XCDOTU) (a_nc, a.data (), 1, b.data (), 1, *c); + if (conja) *c = std::conj (*c); + } + else if (conjb) + F77_FUNC (xcdotc, XCDOTC) (a_nc, a.data (), 1, b.data (), 1, *c); + else + F77_FUNC (xcdotc, XCDOTC) (a_nc, b.data (), 1, a.data (), 1, *c); + } + else if (b_nc == 1 && ! conjb) + { + const char *ctransa = get_blas_trans_arg (transa, conja); + F77_XFCN (cgemv, CGEMV, (F77_CONST_CHAR_ARG2 (ctransa, 1), + lda, tda, 1.0, a.data (), lda, + b.data (), 1, 0.0, c, 1 + F77_CHAR_ARG_LEN (1))); + } + else if (a_nr == 1 && ! conja) + { + const char *crevtransb = get_blas_trans_arg (! transb, conjb); + F77_XFCN (cgemv, CGEMV, (F77_CONST_CHAR_ARG2 (crevtransb, 1), + ldb, tdb, 1.0, b.data (), ldb, + a.data (), 1, 0.0, c, 1 + F77_CHAR_ARG_LEN (1))); + } else { - F77_XFCN (cgemm, CGEMM, (F77_CONST_CHAR_ARG2 ("N", 1), - F77_CONST_CHAR_ARG2 ("N", 1), - nr, a_nc, nc, 1.0, m.data (), - ld, a.data (), lda, 0.0, c, nr + const char *ctransa = get_blas_trans_arg (transa, conja); + const char *ctransb = get_blas_trans_arg (transb, conjb); + F77_XFCN (cgemm, CGEMM, (F77_CONST_CHAR_ARG2 (ctransa, 1), + F77_CONST_CHAR_ARG2 (ctransb, 1), + a_nr, b_nc, a_nc, 1.0, a.data (), + lda, b.data (), ldb, 0.0, c, a_nr F77_CHAR_ARG_LEN (1) F77_CHAR_ARG_LEN (1))); } @@ -4036,6 +4125,12 @@ return retval; } +FloatComplexMatrix +operator * (const FloatComplexMatrix& a, const FloatComplexMatrix& b) +{ + return xgemm (false, false, a, false, false, b); +} + // FIXME -- it would be nice to share code among the min/max // functions below.
--- a/liboctave/fCMatrix.h +++ b/liboctave/fCMatrix.h @@ -388,6 +388,10 @@ extern OCTAVE_API FloatComplexMatrix Sylvester (const FloatComplexMatrix&, const FloatComplexMatrix&, const FloatComplexMatrix&); +extern OCTAVE_API FloatComplexMatrix +xgemm (bool transa, bool conja, const FloatComplexMatrix& a, + bool transb, bool conjb, const FloatComplexMatrix& b); + extern OCTAVE_API FloatComplexMatrix operator * (const FloatMatrix&, const FloatComplexMatrix&); extern OCTAVE_API FloatComplexMatrix operator * (const FloatComplexMatrix&, const FloatMatrix&); extern OCTAVE_API FloatComplexMatrix operator * (const FloatComplexMatrix&, const FloatComplexMatrix&);
--- a/liboctave/fMatrix.cc +++ b/liboctave/fMatrix.cc @@ -3,6 +3,7 @@ Copyright (C) 1994, 1995, 1996, 1997, 1998, 1999, 2000, 2001, 2002, 2003, 2004, 2005, 2006, 2007 John W. Eaton +Copyright (C) 2008 Jaroslav Hajek This file is part of Octave. @@ -105,6 +106,15 @@ const float*, const octave_idx_type&, float&); F77_RET_T + F77_FUNC (ssyrk, SSYRK) (F77_CONST_CHAR_ARG_DECL, + F77_CONST_CHAR_ARG_DECL, + const octave_idx_type&, const octave_idx_type&, + const float&, const float*, const octave_idx_type&, + const float&, float*, const octave_idx_type& + F77_CHAR_ARG_LEN_DECL + F77_CHAR_ARG_LEN_DECL); + + F77_RET_T F77_FUNC (sgetrf, SGETRF) (const octave_idx_type&, const octave_idx_type&, float*, const octave_idx_type&, octave_idx_type*, octave_idx_type&); @@ -3361,50 +3371,88 @@ %!assert(2*rv*cv,[rv,rv]*[cv;cv],1e-14) */ - -FloatMatrix -operator * (const FloatMatrix& m, const FloatMatrix& a) +static const char * +get_blas_trans_arg (bool trans) +{ + static char blas_notrans = 'N', blas_trans = 'T'; + return (trans) ? &blas_trans : &blas_notrans; +} + +// the general GEMM operation + +FloatMatrix +xgemm (bool transa, const FloatMatrix& a, bool transb, const FloatMatrix& b) { FloatMatrix retval; - octave_idx_type nr = m.rows (); - octave_idx_type nc = m.cols (); - - octave_idx_type a_nr = a.rows (); - octave_idx_type a_nc = a.cols (); - - if (nc != a_nr) - gripe_nonconformant ("operator *", nr, nc, a_nr, a_nc); + octave_idx_type a_nr = transa ? a.cols () : a.rows (); + octave_idx_type a_nc = transa ? a.rows () : a.cols (); + + octave_idx_type b_nr = transb ? b.cols () : b.rows (); + octave_idx_type b_nc = transb ? b.rows () : b.cols (); + + if (a_nc != b_nr) + gripe_nonconformant ("operator *", a_nr, a_nc, b_nr, b_nc); else { - if (nr == 0 || nc == 0 || a_nc == 0) - retval.resize (nr, a_nc, 0.0); + if (a_nr == 0 || a_nc == 0 || b_nc == 0) + retval.resize (a_nr, b_nc, 0.0); + else if (a.data () == b.data () && a_nr == b_nc && transa != transb) + { + octave_idx_type lda = a.rows (); + + retval.resize (a_nr, b_nc); + float *c = retval.fortran_vec (); + + const char *ctransa = get_blas_trans_arg (transa); + F77_XFCN (ssyrk, SSYRK, (F77_CONST_CHAR_ARG2 ("U", 1), + F77_CONST_CHAR_ARG2 (ctransa, 1), + a_nr, a_nc, 1.0, + a.data (), lda, 0.0, c, a_nr + F77_CHAR_ARG_LEN (1) + F77_CHAR_ARG_LEN (1))); + for (int j = 0; j < a_nr; j++) + for (int i = 0; i < j; i++) + retval.xelem (j,i) = retval.xelem (i,j); + + } else { - octave_idx_type ld = nr; - octave_idx_type lda = a_nr; - - retval.resize (nr, a_nc); + octave_idx_type lda = a.rows (), tda = a.cols (); + octave_idx_type ldb = b.rows (), tdb = b.cols (); + + retval.resize (a_nr, b_nc); float *c = retval.fortran_vec (); - if (a_nc == 1) + if (b_nc == 1) { - if (nr == 1) - F77_FUNC (xsdot, XSDOT) (nc, m.data (), 1, a.data (), 1, *c); + if (a_nr == 1) + F77_FUNC (xsdot, XSDOT) (a_nc, a.data (), 1, b.data (), 1, *c); else { - F77_XFCN (sgemv, SGEMV, (F77_CONST_CHAR_ARG2 ("N", 1), - nr, nc, 1.0, m.data (), ld, - a.data (), 1, 0.0, c, 1 + const char *ctransa = get_blas_trans_arg (transa); + F77_XFCN (sgemv, SGEMV, (F77_CONST_CHAR_ARG2 (ctransa, 1), + lda, tda, 1.0, a.data (), lda, + b.data (), 1, 0.0, c, 1 F77_CHAR_ARG_LEN (1))); } } + else if (a_nr == 1) + { + const char *crevtransb = get_blas_trans_arg (! transb); + F77_XFCN (sgemv, SGEMV, (F77_CONST_CHAR_ARG2 (crevtransb, 1), + ldb, tdb, 1.0, b.data (), ldb, + a.data (), 1, 0.0, c, 1 + F77_CHAR_ARG_LEN (1))); + } else { - F77_XFCN (sgemm, SGEMM, (F77_CONST_CHAR_ARG2 ("N", 1), - F77_CONST_CHAR_ARG2 ("N", 1), - nr, a_nc, nc, 1.0, m.data (), - ld, a.data (), lda, 0.0, c, nr + const char *ctransa = get_blas_trans_arg (transa); + const char *ctransb = get_blas_trans_arg (transb); + F77_XFCN (sgemm, SGEMM, (F77_CONST_CHAR_ARG2 (ctransa, 1), + F77_CONST_CHAR_ARG2 (ctransb, 1), + a_nr, b_nc, a_nc, 1.0, a.data (), + lda, b.data (), ldb, 0.0, c, a_nr F77_CHAR_ARG_LEN (1) F77_CHAR_ARG_LEN (1))); } @@ -3414,6 +3462,12 @@ return retval; } +FloatMatrix +operator * (const FloatMatrix& a, const FloatMatrix& b) +{ + return xgemm (false, a, false, b); +} + // FIXME -- it would be nice to share code among the min/max // functions below.
--- a/liboctave/fMatrix.h +++ b/liboctave/fMatrix.h @@ -339,6 +339,8 @@ extern OCTAVE_API FloatMatrix Sylvester (const FloatMatrix&, const FloatMatrix&, const FloatMatrix&); +extern OCTAVE_API FloatMatrix xgemm (bool transa, const FloatMatrix& a, bool transb, const FloatMatrix& b); + extern OCTAVE_API FloatMatrix operator * (const FloatMatrix& a, const FloatMatrix& b); extern OCTAVE_API FloatMatrix min (float d, const FloatMatrix& m);
--- a/src/ChangeLog +++ b/src/ChangeLog @@ -1,5 +1,11 @@ 2008-05-21 Jaroslav Hajek <highegg@gmail.com> + * OPERATORS/op-fcm-fcm.cc (trans_mul, mul_trans, herm_mul, mul_herm): + New functions. + (install_fcm_fcm_ops): Install them. + * OPERATORS/op-fm-fm.cc (trans_mul, mul_trans): New functions. + (install_fm_fm_ops): Install them. + * OPERATORS/op-sm-m.cc (trans_mul): New function. (install_sm_m_ops): Register it. * OPERATORS/op-m-sm.cc (mul_trans): New function.
--- a/src/OPERATORS/op-fcm-fcm.cc +++ b/src/OPERATORS/op-fcm-fcm.cc @@ -111,6 +111,34 @@ return ret; } +DEFBINOP (trans_mul, float_complex_matrix, float_complex_matrix) +{ + CAST_BINOP_ARGS (const octave_float_complex_matrix&, const octave_float_complex_matrix&); + return octave_value(xgemm (true, false, v1.float_complex_matrix_value (), + false, false, v2.float_complex_matrix_value ())); +} + +DEFBINOP (mul_trans, float_complex_matrix, float_complex_matrix) +{ + CAST_BINOP_ARGS (const octave_float_complex_matrix&, const octave_float_complex_matrix&); + return octave_value(xgemm (false, false, v1.float_complex_matrix_value (), + true, false, v2.float_complex_matrix_value ())); +} + +DEFBINOP (herm_mul, float_complex_matrix, float_complex_matrix) +{ + CAST_BINOP_ARGS (const octave_float_complex_matrix&, const octave_float_complex_matrix&); + return octave_value(xgemm (true, true, v1.float_complex_matrix_value (), + false, false, v2.float_complex_matrix_value ())); +} + +DEFBINOP (mul_herm, float_complex_matrix, float_complex_matrix) +{ + CAST_BINOP_ARGS (const octave_float_complex_matrix&, const octave_float_complex_matrix&); + return octave_value(xgemm (false, false, v1.float_complex_matrix_value (), + true, true, v2.float_complex_matrix_value ())); +} + DEFNDBINOP_FN (lt, float_complex_matrix, float_complex_matrix, float_complex_array, float_complex_array, mx_el_lt) DEFNDBINOP_FN (le, float_complex_matrix, float_complex_matrix, @@ -183,6 +211,14 @@ octave_float_complex_matrix, pow); INSTALL_BINOP (op_ldiv, octave_float_complex_matrix, octave_float_complex_matrix, ldiv); + INSTALL_BINOP (op_trans_mul, octave_float_complex_matrix, + octave_float_complex_matrix, trans_mul); + INSTALL_BINOP (op_mul_trans, octave_float_complex_matrix, + octave_float_complex_matrix, mul_trans); + INSTALL_BINOP (op_herm_mul, octave_float_complex_matrix, + octave_float_complex_matrix, herm_mul); + INSTALL_BINOP (op_mul_herm, octave_float_complex_matrix, + octave_float_complex_matrix, mul_herm); INSTALL_BINOP (op_lt, octave_float_complex_matrix, octave_float_complex_matrix, lt); INSTALL_BINOP (op_le, octave_float_complex_matrix,
--- a/src/OPERATORS/op-fm-fm.cc +++ b/src/OPERATORS/op-fm-fm.cc @@ -94,6 +94,20 @@ return ret; } +DEFBINOP (trans_mul, float_matrix, float_matrix) +{ + CAST_BINOP_ARGS (const octave_float_matrix&, const octave_float_matrix&); + return octave_value(xgemm (true, v1.float_matrix_value (), + false, v2.float_matrix_value ())); +} + +DEFBINOP (mul_trans, float_matrix, float_matrix) +{ + CAST_BINOP_ARGS (const octave_float_matrix&, const octave_float_matrix&); + return octave_value(xgemm (false, v1.float_matrix_value (), + true, v2.float_matrix_value ())); +} + DEFNDBINOP_FN (lt, float_matrix, float_matrix, float_array, float_array, mx_el_lt) DEFNDBINOP_FN (le, float_matrix, float_matrix, float_array, @@ -171,6 +185,10 @@ INSTALL_BINOP (op_el_ldiv, octave_float_matrix, octave_float_matrix, el_ldiv); INSTALL_BINOP (op_el_and, octave_float_matrix, octave_float_matrix, el_and); INSTALL_BINOP (op_el_or, octave_float_matrix, octave_float_matrix, el_or); + INSTALL_BINOP (op_trans_mul, octave_float_matrix, octave_float_matrix, trans_mul); + INSTALL_BINOP (op_mul_trans, octave_float_matrix, octave_float_matrix, mul_trans); + INSTALL_BINOP (op_herm_mul, octave_float_matrix, octave_float_matrix, trans_mul); + INSTALL_BINOP (op_mul_herm, octave_float_matrix, octave_float_matrix, mul_trans); INSTALL_CATOP (octave_float_matrix, octave_float_matrix, fm_fm);