diff liboctave/CMatrix.cc @ 7800:5861b95e9879

support for compound operators, implement trans_mul, mul_trans, herm_mul and mul_herm
author Jaroslav Hajek <highegg@gmail.com>
date Wed, 07 May 2008 16:33:15 +0200
parents 82be108cc558
children 776791438957
line wrap: on
line diff
--- a/liboctave/CMatrix.cc
+++ b/liboctave/CMatrix.cc
@@ -108,6 +108,10 @@
 			     const Complex*, const octave_idx_type&, Complex&);
 
   F77_RET_T
+  F77_FUNC (xzdotc, XZDOTC) (const octave_idx_type&, const Complex*, const octave_idx_type&,
+			     const Complex*, const octave_idx_type&, Complex&);
+
+  F77_RET_T
   F77_FUNC (zgetrf, ZGETRF) (const octave_idx_type&, const octave_idx_type&, Complex*, const octave_idx_type&,
 			     octave_idx_type*, octave_idx_type&);
 
@@ -3950,49 +3954,81 @@
 %!assert(2*rv*cv,[rv,rv]*[cv;cv],1e-14)
 */
 
+static const char *
+get_blas_trans_arg (bool trans, bool conj)
+{
+  static char blas_notrans = 'N', blas_trans = 'T', blas_conj_trans = 'C';
+  return trans ? (conj ? &blas_conj_trans : &blas_trans) : &blas_notrans;
+}
+
+// the general GEMM operation
+
 ComplexMatrix
-operator * (const ComplexMatrix& m, const ComplexMatrix& a)
+xgemm (bool transa, bool conja, const ComplexMatrix& a, 
+       bool transb, bool conjb, const ComplexMatrix& b)
 {
   ComplexMatrix retval;
 
-  octave_idx_type nr = m.rows ();
-  octave_idx_type nc = m.cols ();
-
-  octave_idx_type a_nr = a.rows ();
-  octave_idx_type a_nc = a.cols ();
-
-  if (nc != a_nr)
-    gripe_nonconformant ("operator *", nr, nc, a_nr, a_nc);
+  // conjugacy is ignored if no transpose
+  conja = conja && transa;
+  conjb = conjb && transb;
+
+  octave_idx_type a_nr = transa ? a.cols () : a.rows ();
+  octave_idx_type a_nc = transa ? a.rows () : a.cols ();
+
+  octave_idx_type b_nr = transb ? b.cols () : b.rows ();
+  octave_idx_type b_nc = transb ? b.rows () : b.cols ();
+
+  if (a_nc != b_nr)
+    gripe_nonconformant ("operator *", a_nr, a_nc, b_nr, b_nc);
   else
     {
-      if (nr == 0 || nc == 0 || a_nc == 0)
-	retval.resize (nr, a_nc, 0.0);
+      if (a_nr == 0 || a_nc == 0 || b_nc == 0)
+	retval.resize (a_nr, b_nc, 0.0);
       else
 	{
-	  octave_idx_type ld  = nr;
-	  octave_idx_type lda = a.rows ();
-
-	  retval.resize (nr, a_nc);
+	  octave_idx_type lda = a.rows (), tda = a.cols ();
+	  octave_idx_type ldb = b.rows (), tdb = b.cols ();
+
+	  retval.resize (a_nr, b_nc);
 	  Complex *c = retval.fortran_vec ();
 
-	  if (a_nc == 1)
+	  if (b_nc == 1 && a_nr == 1)
 	    {
-	      if (nr == 1)
-		F77_FUNC (xzdotu, XZDOTU) (nc, m.data (), 1, a.data (), 1, *c);
-	      else
-		{
-		  F77_XFCN (zgemv, ZGEMV, (F77_CONST_CHAR_ARG2 ("N", 1),
-					   nr, nc, 1.0,  m.data (), ld,
-					   a.data (), 1, 0.0, c, 1
-					   F77_CHAR_ARG_LEN (1)));
-		}
-	    }
+              if (conja == conjb)
+                {
+                  F77_FUNC (xzdotu, XZDOTU) (a_nc, a.data (), 1, b.data (), 1, *c);
+                  if (conja) *c = std::conj (*c);
+                }
+              else if (conjb)
+                  F77_FUNC (xzdotc, XZDOTC) (a_nc, a.data (), 1, b.data (), 1, *c);
+              else
+                  F77_FUNC (xzdotc, XZDOTC) (a_nc, b.data (), 1, a.data (), 1, *c);
+            }
+          else if (b_nc == 1 && ! conjb)
+            {
+              const char *ctransa = get_blas_trans_arg (transa, conja);
+              F77_XFCN (zgemv, ZGEMV, (F77_CONST_CHAR_ARG2 (ctransa, 1),
+                                       lda, tda, 1.0,  a.data (), lda,
+                                       b.data (), 1, 0.0, c, 1
+                                       F77_CHAR_ARG_LEN (1)));
+            }
+          else if (a_nr == 1 && ! conja)
+            {
+              const char *crevtransb = get_blas_trans_arg (! transb, conjb);
+              F77_XFCN (zgemv, ZGEMV, (F77_CONST_CHAR_ARG2 (crevtransb, 1),
+                                       ldb, tdb, 1.0,  b.data (), ldb,
+                                       a.data (), 1, 0.0, c, 1
+                                       F77_CHAR_ARG_LEN (1)));
+            }
 	  else
 	    {
-	      F77_XFCN (zgemm, ZGEMM, (F77_CONST_CHAR_ARG2 ("N", 1),
-				       F77_CONST_CHAR_ARG2 ("N", 1),
-				       nr, a_nc, nc, 1.0, m.data (),
-				       ld, a.data (), lda, 0.0, c, nr
+              const char *ctransa = get_blas_trans_arg (transa, conja);
+              const char *ctransb = get_blas_trans_arg (transb, conjb);
+	      F77_XFCN (zgemm, ZGEMM, (F77_CONST_CHAR_ARG2 (ctransa, 1),
+				       F77_CONST_CHAR_ARG2 (ctransb, 1),
+				       a_nr, b_nc, a_nc, 1.0, a.data (),
+				       lda, b.data (), ldb, 0.0, c, a_nr
 				       F77_CHAR_ARG_LEN (1)
 				       F77_CHAR_ARG_LEN (1)));
 	    }
@@ -4002,6 +4038,12 @@
   return retval;
 }
 
+ComplexMatrix
+operator * (const ComplexMatrix& a, const ComplexMatrix& b)
+{
+  return xgemm (false, false, a, false, false, b);
+}
+
 // FIXME -- it would be nice to share code among the min/max
 // functions below.