# HG changeset patch # User Jaroslav Hajek # Date 1210170795 -7200 # Node ID 5861b95e987906acdc486576113fa80ff06402df # Parent 199181592240840343d8350b6cdaa49d999a0616 support for compound operators, implement trans_mul, mul_trans, herm_mul and mul_herm diff --git a/liboctave/CMatrix.cc b/liboctave/CMatrix.cc --- a/liboctave/CMatrix.cc +++ b/liboctave/CMatrix.cc @@ -108,6 +108,10 @@ const Complex*, const octave_idx_type&, Complex&); F77_RET_T + F77_FUNC (xzdotc, XZDOTC) (const octave_idx_type&, const Complex*, const octave_idx_type&, + const Complex*, const octave_idx_type&, Complex&); + + F77_RET_T F77_FUNC (zgetrf, ZGETRF) (const octave_idx_type&, const octave_idx_type&, Complex*, const octave_idx_type&, octave_idx_type*, octave_idx_type&); @@ -3950,49 +3954,81 @@ %!assert(2*rv*cv,[rv,rv]*[cv;cv],1e-14) */ +static const char * +get_blas_trans_arg (bool trans, bool conj) +{ + static char blas_notrans = 'N', blas_trans = 'T', blas_conj_trans = 'C'; + return trans ? (conj ? &blas_conj_trans : &blas_trans) : &blas_notrans; +} + +// the general GEMM operation + ComplexMatrix -operator * (const ComplexMatrix& m, const ComplexMatrix& a) +xgemm (bool transa, bool conja, const ComplexMatrix& a, + bool transb, bool conjb, const ComplexMatrix& b) { ComplexMatrix retval; - octave_idx_type nr = m.rows (); - octave_idx_type nc = m.cols (); - - octave_idx_type a_nr = a.rows (); - octave_idx_type a_nc = a.cols (); - - if (nc != a_nr) - gripe_nonconformant ("operator *", nr, nc, a_nr, a_nc); + // conjugacy is ignored if no transpose + conja = conja && transa; + conjb = conjb && transb; + + octave_idx_type a_nr = transa ? a.cols () : a.rows (); + octave_idx_type a_nc = transa ? a.rows () : a.cols (); + + octave_idx_type b_nr = transb ? b.cols () : b.rows (); + octave_idx_type b_nc = transb ? b.rows () : b.cols (); + + if (a_nc != b_nr) + gripe_nonconformant ("operator *", a_nr, a_nc, b_nr, b_nc); else { - if (nr == 0 || nc == 0 || a_nc == 0) - retval.resize (nr, a_nc, 0.0); + if (a_nr == 0 || a_nc == 0 || b_nc == 0) + retval.resize (a_nr, b_nc, 0.0); else { - octave_idx_type ld = nr; - octave_idx_type lda = a.rows (); - - retval.resize (nr, a_nc); + octave_idx_type lda = a.rows (), tda = a.cols (); + octave_idx_type ldb = b.rows (), tdb = b.cols (); + + retval.resize (a_nr, b_nc); Complex *c = retval.fortran_vec (); - if (a_nc == 1) + if (b_nc == 1 && a_nr == 1) { - if (nr == 1) - F77_FUNC (xzdotu, XZDOTU) (nc, m.data (), 1, a.data (), 1, *c); - else - { - F77_XFCN (zgemv, ZGEMV, (F77_CONST_CHAR_ARG2 ("N", 1), - nr, nc, 1.0, m.data (), ld, - a.data (), 1, 0.0, c, 1 - F77_CHAR_ARG_LEN (1))); - } - } + if (conja == conjb) + { + F77_FUNC (xzdotu, XZDOTU) (a_nc, a.data (), 1, b.data (), 1, *c); + if (conja) *c = std::conj (*c); + } + else if (conjb) + F77_FUNC (xzdotc, XZDOTC) (a_nc, a.data (), 1, b.data (), 1, *c); + else + F77_FUNC (xzdotc, XZDOTC) (a_nc, b.data (), 1, a.data (), 1, *c); + } + else if (b_nc == 1 && ! conjb) + { + const char *ctransa = get_blas_trans_arg (transa, conja); + F77_XFCN (zgemv, ZGEMV, (F77_CONST_CHAR_ARG2 (ctransa, 1), + lda, tda, 1.0, a.data (), lda, + b.data (), 1, 0.0, c, 1 + F77_CHAR_ARG_LEN (1))); + } + else if (a_nr == 1 && ! conja) + { + const char *crevtransb = get_blas_trans_arg (! transb, conjb); + F77_XFCN (zgemv, ZGEMV, (F77_CONST_CHAR_ARG2 (crevtransb, 1), + ldb, tdb, 1.0, b.data (), ldb, + a.data (), 1, 0.0, c, 1 + F77_CHAR_ARG_LEN (1))); + } else { - F77_XFCN (zgemm, ZGEMM, (F77_CONST_CHAR_ARG2 ("N", 1), - F77_CONST_CHAR_ARG2 ("N", 1), - nr, a_nc, nc, 1.0, m.data (), - ld, a.data (), lda, 0.0, c, nr + const char *ctransa = get_blas_trans_arg (transa, conja); + const char *ctransb = get_blas_trans_arg (transb, conjb); + F77_XFCN (zgemm, ZGEMM, (F77_CONST_CHAR_ARG2 (ctransa, 1), + F77_CONST_CHAR_ARG2 (ctransb, 1), + a_nr, b_nc, a_nc, 1.0, a.data (), + lda, b.data (), ldb, 0.0, c, a_nr F77_CHAR_ARG_LEN (1) F77_CHAR_ARG_LEN (1))); } @@ -4002,6 +4038,12 @@ return retval; } +ComplexMatrix +operator * (const ComplexMatrix& a, const ComplexMatrix& b) +{ + return xgemm (false, false, a, false, false, b); +} + // FIXME -- it would be nice to share code among the min/max // functions below. diff --git a/liboctave/CMatrix.h b/liboctave/CMatrix.h --- a/liboctave/CMatrix.h +++ b/liboctave/CMatrix.h @@ -388,6 +388,10 @@ extern OCTAVE_API ComplexMatrix Sylvester (const ComplexMatrix&, const ComplexMatrix&, const ComplexMatrix&); +extern OCTAVE_API ComplexMatrix +xgemm (bool transa, bool conja, const ComplexMatrix& a, + bool transb, bool conjb, const ComplexMatrix& b); + extern OCTAVE_API ComplexMatrix operator * (const Matrix&, const ComplexMatrix&); extern OCTAVE_API ComplexMatrix operator * (const ComplexMatrix&, const Matrix&); extern OCTAVE_API ComplexMatrix operator * (const ComplexMatrix&, const ComplexMatrix&); diff --git a/liboctave/ChangeLog b/liboctave/ChangeLog --- a/liboctave/ChangeLog +++ b/liboctave/ChangeLog @@ -1,5 +1,14 @@ 2008-05-21 Jaroslav Hajek + * dMatrix.h (xgemm): Provide decl. + * dMatrix.cc (xgemm): New function. + (operator * (const Matrix&, const Matrix&)): Simplify. + (get_blas_trans_arg): New function. + * CMatrix.h (xgemm): Provide decl. + * CMatrix.cc (xgemm): New function. + (operator * (const ComplexMatrix&, const ComplexMatrix&)): Simplify. + (get_blas_trans_arg): New function. + * MatrixType.cc (matrix_real_probe, matrix_complex_probe): New template functions. (MatrixType::MatrixType (const Matrix&), diff --git a/liboctave/dMatrix.cc b/liboctave/dMatrix.cc --- a/liboctave/dMatrix.cc +++ b/liboctave/dMatrix.cc @@ -3362,50 +3362,69 @@ %!assert(2*rv*cv,[rv,rv]*[cv;cv],1e-14) */ - -Matrix -operator * (const Matrix& m, const Matrix& a) +static const char * +get_blas_trans_arg (bool trans) +{ + static char blas_notrans = 'N', blas_trans = 'T'; + return (trans) ? &blas_trans : &blas_notrans; +} + +// the general GEMM operation + +Matrix +xgemm (bool transa, const Matrix& a, bool transb, const Matrix& b) { Matrix retval; - octave_idx_type nr = m.rows (); - octave_idx_type nc = m.cols (); - - octave_idx_type a_nr = a.rows (); - octave_idx_type a_nc = a.cols (); - - if (nc != a_nr) - gripe_nonconformant ("operator *", nr, nc, a_nr, a_nc); + octave_idx_type a_nr = transa ? a.cols () : a.rows (); + octave_idx_type a_nc = transa ? a.rows () : a.cols (); + + octave_idx_type b_nr = transb ? b.cols () : b.rows (); + octave_idx_type b_nc = transb ? b.rows () : b.cols (); + + if (a_nc != b_nr) + gripe_nonconformant ("operator *", a_nr, a_nc, b_nr, b_nc); else { - if (nr == 0 || nc == 0 || a_nc == 0) - retval.resize (nr, a_nc, 0.0); + if (a_nr == 0 || a_nc == 0 || b_nc == 0) + retval.resize (a_nr, b_nc, 0.0); else { - octave_idx_type ld = nr; - octave_idx_type lda = a_nr; - - retval.resize (nr, a_nc); + octave_idx_type lda = a.rows (), tda = a.cols (); + octave_idx_type ldb = b.rows (), tdb = b.cols (); + + retval.resize (a_nr, b_nc); double *c = retval.fortran_vec (); - if (a_nc == 1) + if (b_nc == 1) { - if (nr == 1) - F77_FUNC (xddot, XDDOT) (nc, m.data (), 1, a.data (), 1, *c); + if (a_nr == 1) + F77_FUNC (xddot, XDDOT) (a_nc, a.data (), 1, b.data (), 1, *c); else { - F77_XFCN (dgemv, DGEMV, (F77_CONST_CHAR_ARG2 ("N", 1), - nr, nc, 1.0, m.data (), ld, - a.data (), 1, 0.0, c, 1 + const char *ctransa = get_blas_trans_arg (transa); + F77_XFCN (dgemv, DGEMV, (F77_CONST_CHAR_ARG2 (ctransa, 1), + lda, tda, 1.0, a.data (), lda, + b.data (), 1, 0.0, c, 1 F77_CHAR_ARG_LEN (1))); } } + else if (a_nr == 1) + { + const char *crevtransb = get_blas_trans_arg (! transb); + F77_XFCN (dgemv, DGEMV, (F77_CONST_CHAR_ARG2 (crevtransb, 1), + ldb, tdb, 1.0, b.data (), ldb, + a.data (), 1, 0.0, c, 1 + F77_CHAR_ARG_LEN (1))); + } else { - F77_XFCN (dgemm, DGEMM, (F77_CONST_CHAR_ARG2 ("N", 1), - F77_CONST_CHAR_ARG2 ("N", 1), - nr, a_nc, nc, 1.0, m.data (), - ld, a.data (), lda, 0.0, c, nr + const char *ctransa = get_blas_trans_arg (transa); + const char *ctransb = get_blas_trans_arg (transb); + F77_XFCN (dgemm, DGEMM, (F77_CONST_CHAR_ARG2 (ctransa, 1), + F77_CONST_CHAR_ARG2 (ctransb, 1), + a_nr, b_nc, a_nc, 1.0, a.data (), + lda, b.data (), ldb, 0.0, c, a_nr F77_CHAR_ARG_LEN (1) F77_CHAR_ARG_LEN (1))); } @@ -3415,6 +3434,12 @@ return retval; } +Matrix +operator * (const Matrix& a, const Matrix& b) +{ + return xgemm (false, a, false, b); +} + // FIXME -- it would be nice to share code among the min/max // functions below. diff --git a/liboctave/dMatrix.h b/liboctave/dMatrix.h --- a/liboctave/dMatrix.h +++ b/liboctave/dMatrix.h @@ -339,6 +339,8 @@ extern OCTAVE_API Matrix Sylvester (const Matrix&, const Matrix&, const Matrix&); +extern OCTAVE_API Matrix xgemm (bool transa, const Matrix& a, bool transb, const Matrix& b); + extern OCTAVE_API Matrix operator * (const Matrix& a, const Matrix& b); extern OCTAVE_API Matrix min (double d, const Matrix& m); diff --git a/src/ChangeLog b/src/ChangeLog --- a/src/ChangeLog +++ b/src/ChangeLog @@ -1,5 +1,67 @@ +2008-05-21 John W. Eaton + 2008-05-21 Jaroslav Hajek + * ov.h (octave_value::compound_binary_op): New enum. + (do_binary_op (octave_value::compound_binary_op, ...), + octave_value::binary_op_fcn_name (compound_binary_op), + octave_value::do_binary_op (compound_binary_op, ...)): + New declarations. + (OV_COMP_BINOP_FN): New macro (+ several expansions). + * ov.cc (octave_value::binary_op_fcn_name (compound_binary_op), + decompose_binary_op, do_binary_op (compound_binary_op, ...)): + New functions. + * ov-typeinfo.h (octave_value_typeinfo::register_binary_class_op + (octave_value::compound_binary_op, ...), + octave_value_typeinfo::register_binary_op + (octave_value::compound_binary_op, ...), + octave_value_typeinfo::do_register_binary_class_op + (octave_value::compound_binary_op, ...), + octave_value_typeinfo::do_register_binary_op + (octave_value::compound_binary_op, ...), + octave_value_typeinfo::do_lookup_binary_class_op + (octave_value::compound_binary_op), + octave_value_typeinfo::do_lookup_binary_op + (octave_value::compound_binary_op, ...)): + New declarations. + (octave_value_typeinfo::lookup_binary_class_op + (octave_value::compound_binary_op), + (octave_value_typeinfo::lookup_binary_op + (octave_value::compound_binary_op, ...)): + New functions. + (octave_value_typeinfo::compound_binary_class_ops, + octave_value_typeinfo::compound_binary_ops): + New fields. + * ov-typeinfo.cc (octave_value_typeinfo::register_binary_class_op + (octave_value::compound_binary_op, ...), + octave_value_typeinfo::register_binary_op + (octave_value::compound_binary_op, ...), + octave_value_typeinfo::do_register_binary_class_op + (octave_value::compound_binary_op, ...), + octave_value_typeinfo::do_register_binary_op + (octave_value::compound_binary_op, ...), + octave_value_typeinfo::do_lookup_binary_class_op + (octave_value::compound_binary_op), + octave_value_typeinfo::do_lookup_binary_op + (octave_value::compound_binary_op, ...)): + New functions. + (octave_value::do_register_type): Resize also compound_binary_ops + field. + * pt-exp.h (tree_expression::is_unary_expression): New virtual + function. + * pt-unop.h (tree_unary_expression::is_unary_expression): New virtual + override. + * pt-cbinop.h, pt-cbinop.cc: New files (implement + tree_compound_binary_expression class). + * pt-all.h: Include pt-cbinop.h. + * Makefile.in (PT_INCLUDES, PT_SRC): Include them in the lists. + * parse.y (make_binary_op): Call maybe_compound_binary_expression. + * OPERATORS/op-m-m.cc (trans_mul, mul_trans): New operator handlers. + (install_m_m_ops): Register them. + * OPERATORS/op-cm-cm.cc (trans_mul, mul_trans, herm_mul, mul_herm): + New operator handlers. + (install_cm_cm_ops): Register them. + * DLD-FUNCTIONS/matrix_type.cc: Fix tests relying on the older more optimistic hermitian check. diff --git a/src/Makefile.in b/src/Makefile.in --- a/src/Makefile.in +++ b/src/Makefile.in @@ -110,9 +110,9 @@ ov-base-sparse.h ov-bool-sparse.h ov-cx-sparse.h ov-re-sparse.h PT_INCLUDES := pt.h pt-all.h pt-arg-list.h pt-assign.h pt-binop.h \ - pt-bp.h pt-cell.h pt-check.h pt-cmd.h pt-colon.h pt-const.h \ - pt-decl.h pt-except.h pt-exp.h pt-fcn-handle.h pt-id.h pt-idx.h \ - pt-jump.h pt-loop.h pt-mat.h pt-misc.h \ + pt-bp.h pt-cbinop.h pt-cell.h pt-check.h pt-cmd.h pt-colon.h \ + pt-const.h pt-decl.h pt-except.h pt-exp.h pt-fcn-handle.h \ + pt-id.h pt-idx.h pt-jump.h pt-loop.h pt-mat.h pt-misc.h \ pt-pr-code.h pt-select.h pt-stmt.h pt-unop.h pt-walk.h \ INCLUDES := Cell.h base-list.h builtins.h c-file-ptr-stream.h \ @@ -189,9 +189,9 @@ $(OV_SPARSE_SRC) PT_SRC := pt.cc pt-arg-list.cc pt-assign.cc pt-bp.cc pt-binop.cc \ - pt-cell.cc pt-check.cc pt-cmd.cc pt-colon.cc pt-const.cc \ - pt-decl.cc pt-except.cc pt-exp.cc pt-fcn-handle.cc pt-id.cc \ - pt-idx.cc pt-jump.cc pt-loop.cc pt-mat.cc pt-misc.cc \ + pt-cbinop.cc pt-cell.cc pt-check.cc pt-cmd.cc pt-colon.cc \ + pt-const.cc pt-decl.cc pt-except.cc pt-exp.cc pt-fcn-handle.cc \ + pt-id.cc pt-idx.cc pt-jump.cc pt-loop.cc pt-mat.cc pt-misc.cc \ pt-pr-code.cc pt-select.cc pt-stmt.cc pt-unop.cc DIST_SRC := Cell.cc bitfcns.cc c-file-ptr-stream.cc comment-list.cc \ diff --git a/src/OPERATORS/op-cm-cm.cc b/src/OPERATORS/op-cm-cm.cc --- a/src/OPERATORS/op-cm-cm.cc +++ b/src/OPERATORS/op-cm-cm.cc @@ -107,6 +107,34 @@ return ret; } +DEFBINOP (trans_mul, complex_matrix, complex_matrix) +{ + CAST_BINOP_ARGS (const octave_complex_matrix&, const octave_complex_matrix&); + return octave_value(xgemm (true, false, v1.complex_matrix_value (), + false, false, v2.complex_matrix_value ())); +} + +DEFBINOP (mul_trans, complex_matrix, complex_matrix) +{ + CAST_BINOP_ARGS (const octave_complex_matrix&, const octave_complex_matrix&); + return octave_value(xgemm (false, false, v1.complex_matrix_value (), + true, false, v2.complex_matrix_value ())); +} + +DEFBINOP (herm_mul, complex_matrix, complex_matrix) +{ + CAST_BINOP_ARGS (const octave_complex_matrix&, const octave_complex_matrix&); + return octave_value(xgemm (true, true, v1.complex_matrix_value (), + false, false, v2.complex_matrix_value ())); +} + +DEFBINOP (mul_herm, complex_matrix, complex_matrix) +{ + CAST_BINOP_ARGS (const octave_complex_matrix&, const octave_complex_matrix&); + return octave_value(xgemm (false, false, v1.complex_matrix_value (), + true, true, v2.complex_matrix_value ())); +} + DEFNDBINOP_FN (lt, complex_matrix, complex_matrix, complex_array, complex_array, mx_el_lt) DEFNDBINOP_FN (le, complex_matrix, complex_matrix, complex_array, complex_array, mx_el_le) DEFNDBINOP_FN (eq, complex_matrix, complex_matrix, complex_array, complex_array, mx_el_eq) @@ -157,6 +185,10 @@ INSTALL_BINOP (op_div, octave_complex_matrix, octave_complex_matrix, div); INSTALL_BINOP (op_pow, octave_complex_matrix, octave_complex_matrix, pow); INSTALL_BINOP (op_ldiv, octave_complex_matrix, octave_complex_matrix, ldiv); + INSTALL_BINOP (op_trans_mul, octave_complex_matrix, octave_complex_matrix, trans_mul); + INSTALL_BINOP (op_mul_trans, octave_complex_matrix, octave_complex_matrix, mul_trans); + INSTALL_BINOP (op_herm_mul, octave_complex_matrix, octave_complex_matrix, herm_mul); + INSTALL_BINOP (op_mul_herm, octave_complex_matrix, octave_complex_matrix, mul_herm); INSTALL_BINOP (op_lt, octave_complex_matrix, octave_complex_matrix, lt); INSTALL_BINOP (op_le, octave_complex_matrix, octave_complex_matrix, le); INSTALL_BINOP (op_eq, octave_complex_matrix, octave_complex_matrix, eq); diff --git a/src/OPERATORS/op-m-m.cc b/src/OPERATORS/op-m-m.cc --- a/src/OPERATORS/op-m-m.cc +++ b/src/OPERATORS/op-m-m.cc @@ -92,6 +92,18 @@ return ret; } +DEFBINOP (trans_mul, matrix, matrix) +{ + CAST_BINOP_ARGS (const octave_matrix&, const octave_matrix&); + return octave_value(xgemm (true, v1.matrix_value (), false, v2.matrix_value ())); +} + +DEFBINOP (mul_trans, matrix, matrix) +{ + CAST_BINOP_ARGS (const octave_matrix&, const octave_matrix&); + return octave_value(xgemm (false, v1.matrix_value (), true, v2.matrix_value ())); +} + DEFNDBINOP_FN (lt, matrix, matrix, array, array, mx_el_lt) DEFNDBINOP_FN (le, matrix, matrix, array, array, mx_el_le) DEFNDBINOP_FN (eq, matrix, matrix, array, array, mx_el_eq) @@ -155,6 +167,10 @@ INSTALL_BINOP (op_el_ldiv, octave_matrix, octave_matrix, el_ldiv); INSTALL_BINOP (op_el_and, octave_matrix, octave_matrix, el_and); INSTALL_BINOP (op_el_or, octave_matrix, octave_matrix, el_or); + INSTALL_BINOP (op_trans_mul, octave_matrix, octave_matrix, trans_mul); + INSTALL_BINOP (op_mul_trans, octave_matrix, octave_matrix, mul_trans); + INSTALL_BINOP (op_herm_mul, octave_matrix, octave_matrix, trans_mul); + INSTALL_BINOP (op_mul_herm, octave_matrix, octave_matrix, mul_trans); INSTALL_CATOP (octave_matrix, octave_matrix, m_m); diff --git a/src/ov-typeinfo.cc b/src/ov-typeinfo.cc --- a/src/ov-typeinfo.cc +++ b/src/ov-typeinfo.cc @@ -142,6 +142,23 @@ } bool +octave_value_typeinfo::register_binary_class_op (octave_value::compound_binary_op op, + octave_value_typeinfo::binary_class_op_fcn f) +{ + return (instance_ok ()) + ? instance->do_register_binary_class_op (op, f) : false; +} + +bool +octave_value_typeinfo::register_binary_op (octave_value::compound_binary_op op, + int t1, int t2, + octave_value_typeinfo::binary_op_fcn f) +{ + return (instance_ok ()) + ? instance->do_register_binary_op (op, t1, t2, f) : false; +} + +bool octave_value_typeinfo::register_cat_op (int t1, int t2, octave_value_typeinfo::cat_op_fcn f) { return (instance_ok ()) @@ -223,6 +240,9 @@ binary_ops.resize (static_cast (octave_value::num_binary_ops), len, len, static_cast (0)); + compound_binary_ops.resize (static_cast (octave_value::num_compound_binary_ops), + len, len, static_cast (0)); + cat_ops.resize (len, len, static_cast (0)); assign_ops.resize (static_cast (octave_value::num_assign_ops), @@ -338,6 +358,43 @@ } bool +octave_value_typeinfo::do_register_binary_class_op (octave_value::compound_binary_op op, + octave_value_typeinfo::binary_class_op_fcn f) +{ + if (lookup_binary_class_op (op)) + { + std::string op_name = octave_value::binary_op_fcn_name (op); + + warning ("duplicate compound binary operator `%s' for class dispatch", + op_name.c_str ()); + } + + compound_binary_class_ops.checkelem (static_cast (op)) = f; + + return false; +} + +bool +octave_value_typeinfo::do_register_binary_op (octave_value::compound_binary_op op, + int t1, int t2, + octave_value_typeinfo::binary_op_fcn f) +{ + if (lookup_binary_op (op, t1, t2)) + { + std::string op_name = octave_value::binary_op_fcn_name (op); + std::string t1_name = types(t1); + std::string t2_name = types(t2); + + warning ("duplicate compound binary operator `%s' for types `%s' and `%s'", + op_name.c_str (), t1_name.c_str (), t1_name.c_str ()); + } + + compound_binary_ops.checkelem (static_cast (op), t1, t2) = f; + + return false; +} + +bool octave_value_typeinfo::do_register_cat_op (int t1, int t2, octave_value_typeinfo::cat_op_fcn f) { if (lookup_cat_op (t1, t2)) @@ -496,6 +553,19 @@ return binary_ops.checkelem (static_cast (op), t1, t2); } +octave_value_typeinfo::binary_class_op_fcn +octave_value_typeinfo::do_lookup_binary_class_op (octave_value::compound_binary_op op) +{ + return compound_binary_class_ops.checkelem (static_cast (op)); +} + +octave_value_typeinfo::binary_op_fcn +octave_value_typeinfo::do_lookup_binary_op (octave_value::compound_binary_op op, + int t1, int t2) +{ + return compound_binary_ops.checkelem (static_cast (op), t1, t2); +} + octave_value_typeinfo::cat_op_fcn octave_value_typeinfo::do_lookup_cat_op (int t1, int t2) { diff --git a/src/ov-typeinfo.h b/src/ov-typeinfo.h --- a/src/ov-typeinfo.h +++ b/src/ov-typeinfo.h @@ -80,6 +80,12 @@ static bool register_binary_op (octave_value::binary_op, int, int, binary_op_fcn); + static bool register_binary_class_op (octave_value::compound_binary_op, + binary_class_op_fcn); + + static bool register_binary_op (octave_value::compound_binary_op, int, int, + binary_op_fcn); + static bool register_cat_op (int, int, cat_op_fcn); static bool register_assign_op (octave_value::assign_op, int, int, @@ -132,6 +138,18 @@ return instance->do_lookup_binary_op (op, t1, t2); } + static binary_class_op_fcn + lookup_binary_class_op (octave_value::compound_binary_op op) + { + return instance->do_lookup_binary_class_op (op); + } + + static binary_op_fcn + lookup_binary_op (octave_value::compound_binary_op op, int t1, int t2) + { + return instance->do_lookup_binary_op (op, t1, t2); + } + static cat_op_fcn lookup_cat_op (int t1, int t2) { @@ -212,6 +230,10 @@ Array3 binary_ops; + Array compound_binary_class_ops; + + Array3 compound_binary_ops; + Array2 cat_ops; Array3 assign_ops; @@ -240,6 +262,12 @@ bool do_register_binary_op (octave_value::binary_op, int, int, binary_op_fcn); + bool do_register_binary_class_op (octave_value::compound_binary_op, + binary_class_op_fcn); + + bool do_register_binary_op (octave_value::compound_binary_op, int, int, + binary_op_fcn); + bool do_register_cat_op (int, int, cat_op_fcn); bool do_register_assign_op (octave_value::assign_op, int, int, @@ -267,6 +295,10 @@ binary_op_fcn do_lookup_binary_op (octave_value::binary_op, int, int); + binary_class_op_fcn do_lookup_binary_class_op (octave_value::compound_binary_op); + + binary_op_fcn do_lookup_binary_op (octave_value::compound_binary_op, int, int); + cat_op_fcn do_lookup_cat_op (int, int); assign_op_fcn do_lookup_assign_op (octave_value::assign_op, int, int); diff --git a/src/ov.cc b/src/ov.cc --- a/src/ov.cc +++ b/src/ov.cc @@ -350,6 +350,36 @@ } std::string +octave_value::binary_op_fcn_name (compound_binary_op op) +{ + std::string retval; + + switch (op) + { + case op_trans_mul: + retval = "transtimes"; + break; + + case op_mul_trans: + retval = "timestrans"; + break; + + case op_herm_mul: + retval = "hermtimes"; + break; + + case op_mul_herm: + retval = "timesherm"; + break; + + default: + break; + } + + return retval; +} + +std::string octave_value::assign_op_as_string (assign_op op) { std::string retval; @@ -2075,6 +2105,96 @@ return retval; } +static octave_value +decompose_binary_op (octave_value::compound_binary_op op, + const octave_value& v1, const octave_value& v2) +{ + octave_value retval; + + switch (op) + { + case octave_value::op_trans_mul: + retval = do_binary_op (octave_value::op_mul, + do_unary_op (octave_value::op_transpose, v1), + v2); + break; + case octave_value::op_mul_trans: + retval = do_binary_op (octave_value::op_mul, + v1, + do_unary_op (octave_value::op_transpose, v2)); + break; + case octave_value::op_herm_mul: + retval = do_binary_op (octave_value::op_mul, + do_unary_op (octave_value::op_hermitian, v1), + v2); + break; + case octave_value::op_mul_herm: + retval = do_binary_op (octave_value::op_mul, + v1, + do_unary_op (octave_value::op_hermitian, v2)); + break; + default: + error ("invalid compound operator"); + break; + } + + return retval; +} + +octave_value +do_binary_op (octave_value::compound_binary_op op, + const octave_value& v1, const octave_value& v2) +{ + octave_value retval; + + int t1 = v1.type_id (); + int t2 = v2.type_id (); + + if (t1 == octave_class::static_type_id () + || t2 == octave_class::static_type_id ()) + { + octave_value_typeinfo::binary_class_op_fcn f + = octave_value_typeinfo::lookup_binary_class_op (op); + + if (f) + { + try + { + retval = f (v1, v2); + } + catch (octave_execution_exception) + { + octave_exception_state = octave_no_exception; + error ("caught execution error in library function"); + } + } + else + retval = decompose_binary_op (op, v1, v2); + } + else + { + octave_value_typeinfo::binary_op_fcn f + = octave_value_typeinfo::lookup_binary_op (op, t1, t2); + + if (f) + { + try + { + retval = f (*v1.rep, *v2.rep); + } + catch (octave_execution_exception) + { + octave_exception_state = octave_no_exception; + error ("caught execution error in library function"); + } + } + else + retval = decompose_binary_op (op, v1, v2); + } + + return retval; +} + static void gripe_cat_op (const std::string& tn1, const std::string& tn2) { diff --git a/src/ov.h b/src/ov.h --- a/src/ov.h +++ b/src/ov.h @@ -110,6 +110,17 @@ unknown_binary_op }; + enum compound_binary_op + { + // ** compound operations ** + op_trans_mul, + op_mul_trans, + op_herm_mul, + op_mul_herm, + num_compound_binary_ops, + unknown_compound_binary_op + }; + enum assign_op { op_asn_eq, @@ -137,6 +148,8 @@ static std::string binary_op_as_string (binary_op); static std::string binary_op_fcn_name (binary_op); + static std::string binary_op_fcn_name (compound_binary_op); + static std::string assign_op_as_string (assign_op); static octave_value empty_conv (const std::string& type, @@ -894,6 +907,10 @@ const octave_value& a, const octave_value& b); + friend OCTINTERP_API octave_value do_binary_op (compound_binary_op op, + const octave_value& a, + const octave_value& b); + friend OCTINTERP_API octave_value do_cat_op (const octave_value& a, const octave_value& b, const Array& ra_idx); @@ -1043,6 +1060,10 @@ do_binary_op (octave_value::binary_op op, const octave_value& a, const octave_value& b); +extern OCTINTERP_API octave_value +do_binary_op (octave_value::compound_binary_op op, + const octave_value& a, const octave_value& b); + #define OV_UNOP_FN(name) \ inline octave_value \ name (const octave_value& a) \ @@ -1117,6 +1138,18 @@ OV_BINOP_FN (op_struct_ref) +#define OV_COMP_BINOP_FN(name) \ + inline octave_value \ + name (const octave_value& a1, const octave_value& a2) \ + { \ + return do_binary_op (octave_value::name, a1, a2); \ + } + +OV_COMP_BINOP_FN (op_trans_mul) +OV_COMP_BINOP_FN (op_mul_trans) +OV_COMP_BINOP_FN (op_herm_mul) +OV_COMP_BINOP_FN (op_mul_herm) + extern OCTINTERP_API void install_types (void); // FIXME -- these trait classes probably belong somehwere else... diff --git a/src/parse.y b/src/parse.y --- a/src/parse.y +++ b/src/parse.y @@ -1910,7 +1910,7 @@ int c = tok_val->column (); tree_binary_expression *e - = new tree_binary_expression (op1, op2, l, c, t); + = maybe_compound_binary_expression (op1, op2, l, c, t); return fold (e); } diff --git a/src/pt-all.h b/src/pt-all.h --- a/src/pt-all.h +++ b/src/pt-all.h @@ -29,6 +29,7 @@ #include "pt-assign.h" #include "pt-bp.h" #include "pt-binop.h" +#include "pt-cbinop.h" #include "pt-check.h" #include "pt-cmd.h" #include "pt-colon.h" diff --git a/src/pt-cbinop.cc b/src/pt-cbinop.cc new file mode 100644 --- /dev/null +++ b/src/pt-cbinop.cc @@ -0,0 +1,158 @@ +/* + +Copyright (C) 2008 Jaroslav Hajek + +This file is part of Octave. + +Octave is free software; you can redistribute it and/or modify it +under the terms of the GNU General Public License as published by the +Free Software Foundation; either version 3 of the License, or (at your +option) any later version. + +Octave is distributed in the hope that it will be useful, but WITHOUT +ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License +for more details. + +You should have received a copy of the GNU General Public License +along with Octave; see the file COPYING. If not, see +. + +*/ + +#ifdef HAVE_CONFIG_H +#include +#endif + +#include "error.h" +#include "oct-obj.h" +#include "ov.h" +#include "pt-cbinop.h" +#include "pt-bp.h" +#include "pt-unop.h" +#include "pt-walk.h" + +// If a tree expression is a transpose or hermitian transpose, return +// the argument and corresponding operator. + +static octave_value::unary_op +strip_trans_herm (tree_expression *&exp) +{ + if (exp->is_unary_expression ()) + { + tree_unary_expression *uexp = + dynamic_cast (exp); + + octave_value::unary_op op = uexp->op_type (); + + if (op == octave_value::op_transpose + || op == octave_value::op_hermitian) + exp = uexp->operand (); + else + op = octave_value::unknown_unary_op; + + return op; + } + else + return octave_value::unknown_unary_op; +} + +// Possibly convert multiplication to trans_mul, mul_trans, herm_mul, +// or mul_herm. + +static octave_value::compound_binary_op +simplify_mul_op (tree_expression *&a, tree_expression *&b) +{ + octave_value::compound_binary_op retop; + octave_value::unary_op opa = strip_trans_herm (a); + + if (opa == octave_value::op_hermitian) + retop = octave_value::op_herm_mul; + else if (opa == octave_value::op_transpose) + retop = octave_value::op_trans_mul; + else + { + octave_value::unary_op opb = strip_trans_herm (b); + + if (opb == octave_value::op_hermitian) + retop = octave_value::op_mul_herm; + else if (opb == octave_value::op_transpose) + retop = octave_value::op_mul_trans; + else + retop = octave_value::unknown_compound_binary_op; + } + + return retop; +} + +tree_binary_expression * +maybe_compound_binary_expression (tree_expression *a, tree_expression *b, + int l, int c, octave_value::binary_op t) +{ + tree_expression *ca = a, *cb = b; + octave_value::compound_binary_op ct; + + switch (t) + { + case octave_value::op_mul: + ct = simplify_mul_op (ca, cb); + break; + + default: + ct = octave_value::unknown_compound_binary_op; + break; + } + + tree_binary_expression *ret = (ct == octave_value::unknown_compound_binary_op) + ? new tree_binary_expression (a, b, l, c, t) + : new tree_compound_binary_expression (a, b, l, c, t, ca, cb, ct); + + return ret; +} + + +octave_value +tree_compound_binary_expression::rvalue (void) +{ + octave_value retval; + + MAYBE_DO_BREAKPOINT; + + if (error_state) + return retval; + + if (op_lhs) + { + octave_value a = op_lhs->rvalue (); + + if (error_state) + eval_error (); + else if (a.is_defined () && op_rhs) + { + octave_value b = op_rhs->rvalue (); + + if (error_state) + eval_error (); + else if (b.is_defined ()) + { + retval = ::do_binary_op (etype, a, b); + + if (error_state) + { + retval = octave_value (); + eval_error (); + } + } + else + eval_error (); + } + else + eval_error (); + } + else + eval_error (); + + return retval; +} + + diff --git a/src/pt-cbinop.h b/src/pt-cbinop.h new file mode 100644 --- /dev/null +++ b/src/pt-cbinop.h @@ -0,0 +1,78 @@ +/* + +Copyright (C) 2008 Jaroslav Hajek + +This file is part of Octave. + +Octave is free software; you can redistribute it and/or modify it +under the terms of the GNU General Public License as published by the +Free Software Foundation; either version 3 of the License, or (at your +option) any later version. + +Octave is distributed in the hope that it will be useful, but WITHOUT +ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License +for more details. + +You should have received a copy of the GNU General Public License +along with Octave; see the file COPYING. If not, see +. + +*/ + +#if !defined (octave_tree_cbinop_h) +#define octave_tree_cbinop_h 1 + +#include + +class tree_walker; + +class octave_value; +class octave_value_list; +class octave_lvalue; + +#include "ov.h" +#include "pt-binop.h" +#include "symtab.h" + +// Binary expressions that can be reduced to compound operations + +class +tree_compound_binary_expression : public tree_binary_expression +{ +public: + + tree_compound_binary_expression (tree_expression *a, tree_expression *b, + int l, int c, + octave_value::binary_op t, + tree_expression *ca, tree_expression *cb, + octave_value::compound_binary_op ct) + : tree_binary_expression (a, b, l, c, t), op_lhs (ca), op_rhs (cb), + etype (ct) { } + + octave_value rvalue (void); + + octave_value::compound_binary_op cop_type (void) const { return etype; } + +private: + + tree_expression *op_lhs; + tree_expression *op_rhs; + octave_value::compound_binary_op etype; +}; + +// a "virtual constructor" + +tree_binary_expression * +maybe_compound_binary_expression (tree_expression *a, tree_expression *b, + int l = -1, int c = -1, + octave_value::binary_op t + = octave_value::unknown_binary_op); + +#endif + +/* +;;; Local Variables: *** +;;; mode: C++ *** +;;; End: *** +*/ diff --git a/src/pt-exp.h b/src/pt-exp.h --- a/src/pt-exp.h +++ b/src/pt-exp.h @@ -62,6 +62,8 @@ virtual bool is_prefix_expression (void) const { return false; } + virtual bool is_unary_expression (void) const { return false; } + virtual bool is_binary_expression (void) const { return false; } virtual bool is_boolean_expression (void) const { return false; } diff --git a/src/pt-unop.h b/src/pt-unop.h --- a/src/pt-unop.h +++ b/src/pt-unop.h @@ -54,6 +54,8 @@ ~tree_unary_expression (void) { delete op; } + bool is_unary_expression (void) const { return true; } + bool has_magic_end (void) const { return (op && op->has_magic_end ()); } tree_expression *operand (void) { return op; }