comparison liboctave/fCMatrix.cc @ 9665:1dba57e9d08d

use blas_trans_type for xgemm
author Jaroslav Hajek <highegg@gmail.com>
date Sat, 26 Sep 2009 10:41:07 +0200
parents 7e5b4de5fbfe
children f80c566bc751
comparison
equal deleted inserted replaced
9664:2c5169034035 9665:1dba57e9d08d
3775 } 3775 }
3776 3776
3777 // the general GEMM operation 3777 // the general GEMM operation
3778 3778
3779 FloatComplexMatrix 3779 FloatComplexMatrix
3780 xgemm (bool transa, bool conja, const FloatComplexMatrix& a, 3780 xgemm (const FloatComplexMatrix& a, const FloatComplexMatrix& b,
3781 bool transb, bool conjb, const FloatComplexMatrix& b) 3781 blas_trans_type transa, blas_trans_type transb)
3782 { 3782 {
3783 FloatComplexMatrix retval; 3783 FloatComplexMatrix retval;
3784 3784
3785 // conjugacy is ignored if no transpose 3785 bool tra = transa != blas_no_trans, trb = transb != blas_no_trans;
3786 conja = conja && transa; 3786 bool cja = transa == blas_conj_trans, cjb = transb == blas_conj_trans;
3787 conjb = conjb && transb; 3787
3788 3788 octave_idx_type a_nr = tra ? a.cols () : a.rows ();
3789 octave_idx_type a_nr = transa ? a.cols () : a.rows (); 3789 octave_idx_type a_nc = tra ? a.rows () : a.cols ();
3790 octave_idx_type a_nc = transa ? a.rows () : a.cols (); 3790
3791 3791 octave_idx_type b_nr = trb ? b.cols () : b.rows ();
3792 octave_idx_type b_nr = transb ? b.cols () : b.rows (); 3792 octave_idx_type b_nc = trb ? b.rows () : b.cols ();
3793 octave_idx_type b_nc = transb ? b.rows () : b.cols ();
3794 3793
3795 if (a_nc != b_nr) 3794 if (a_nc != b_nr)
3796 gripe_nonconformant ("operator *", a_nr, a_nc, b_nr, b_nc); 3795 gripe_nonconformant ("operator *", a_nr, a_nc, b_nr, b_nc);
3797 else 3796 else
3798 { 3797 {
3799 if (a_nr == 0 || a_nc == 0 || b_nc == 0) 3798 if (a_nr == 0 || a_nc == 0 || b_nc == 0)
3800 retval = FloatComplexMatrix (a_nr, b_nc, 0.0); 3799 retval = FloatComplexMatrix (a_nr, b_nc, 0.0);
3801 else if (a.data () == b.data () && a_nr == b_nc && transa != transb) 3800 else if (a.data () == b.data () && a_nr == b_nc && tra != trb)
3802 { 3801 {
3803 octave_idx_type lda = a.rows (); 3802 octave_idx_type lda = a.rows ();
3804 3803
3805 retval = FloatComplexMatrix (a_nr, b_nc); 3804 retval = FloatComplexMatrix (a_nr, b_nc);
3806 FloatComplex *c = retval.fortran_vec (); 3805 FloatComplex *c = retval.fortran_vec ();
3807 3806
3808 const char *ctransa = get_blas_trans_arg (transa, conja); 3807 const char *ctra = get_blas_trans_arg (tra, cja);
3809 if (conja || conjb) 3808 if (cja || cjb)
3810 { 3809 {
3811 F77_XFCN (cherk, CHERK, (F77_CONST_CHAR_ARG2 ("U", 1), 3810 F77_XFCN (cherk, CHERK, (F77_CONST_CHAR_ARG2 ("U", 1),
3812 F77_CONST_CHAR_ARG2 (ctransa, 1), 3811 F77_CONST_CHAR_ARG2 (ctra, 1),
3813 a_nr, a_nc, 1.0, 3812 a_nr, a_nc, 1.0,
3814 a.data (), lda, 0.0, c, a_nr 3813 a.data (), lda, 0.0, c, a_nr
3815 F77_CHAR_ARG_LEN (1) 3814 F77_CHAR_ARG_LEN (1)
3816 F77_CHAR_ARG_LEN (1))); 3815 F77_CHAR_ARG_LEN (1)));
3817 for (int j = 0; j < a_nr; j++) 3816 for (int j = 0; j < a_nr; j++)
3819 retval.xelem (j,i) = std::conj (retval.xelem (i,j)); 3818 retval.xelem (j,i) = std::conj (retval.xelem (i,j));
3820 } 3819 }
3821 else 3820 else
3822 { 3821 {
3823 F77_XFCN (csyrk, CSYRK, (F77_CONST_CHAR_ARG2 ("U", 1), 3822 F77_XFCN (csyrk, CSYRK, (F77_CONST_CHAR_ARG2 ("U", 1),
3824 F77_CONST_CHAR_ARG2 (ctransa, 1), 3823 F77_CONST_CHAR_ARG2 (ctra, 1),
3825 a_nr, a_nc, 1.0, 3824 a_nr, a_nc, 1.0,
3826 a.data (), lda, 0.0, c, a_nr 3825 a.data (), lda, 0.0, c, a_nr
3827 F77_CHAR_ARG_LEN (1) 3826 F77_CHAR_ARG_LEN (1)
3828 F77_CHAR_ARG_LEN (1))); 3827 F77_CHAR_ARG_LEN (1)));
3829 for (int j = 0; j < a_nr; j++) 3828 for (int j = 0; j < a_nr; j++)
3841 retval = FloatComplexMatrix (a_nr, b_nc); 3840 retval = FloatComplexMatrix (a_nr, b_nc);
3842 FloatComplex *c = retval.fortran_vec (); 3841 FloatComplex *c = retval.fortran_vec ();
3843 3842
3844 if (b_nc == 1 && a_nr == 1) 3843 if (b_nc == 1 && a_nr == 1)
3845 { 3844 {
3846 if (conja == conjb) 3845 if (cja == cjb)
3847 { 3846 {
3848 F77_FUNC (xcdotu, XCDOTU) (a_nc, a.data (), 1, b.data (), 1, *c); 3847 F77_FUNC (xcdotu, XCDOTU) (a_nc, a.data (), 1, b.data (), 1, *c);
3849 if (conja) *c = std::conj (*c); 3848 if (cja) *c = std::conj (*c);
3850 } 3849 }
3851 else if (conja) 3850 else if (cja)
3852 F77_FUNC (xcdotc, XCDOTC) (a_nc, a.data (), 1, b.data (), 1, *c); 3851 F77_FUNC (xcdotc, XCDOTC) (a_nc, a.data (), 1, b.data (), 1, *c);
3853 else 3852 else
3854 F77_FUNC (xcdotc, XCDOTC) (a_nc, b.data (), 1, a.data (), 1, *c); 3853 F77_FUNC (xcdotc, XCDOTC) (a_nc, b.data (), 1, a.data (), 1, *c);
3855 } 3854 }
3856 else if (b_nc == 1 && ! conjb) 3855 else if (b_nc == 1 && ! cjb)
3857 { 3856 {
3858 const char *ctransa = get_blas_trans_arg (transa, conja); 3857 const char *ctra = get_blas_trans_arg (tra, cja);
3859 F77_XFCN (cgemv, CGEMV, (F77_CONST_CHAR_ARG2 (ctransa, 1), 3858 F77_XFCN (cgemv, CGEMV, (F77_CONST_CHAR_ARG2 (ctra, 1),
3860 lda, tda, 1.0, a.data (), lda, 3859 lda, tda, 1.0, a.data (), lda,
3861 b.data (), 1, 0.0, c, 1 3860 b.data (), 1, 0.0, c, 1
3862 F77_CHAR_ARG_LEN (1))); 3861 F77_CHAR_ARG_LEN (1)));
3863 } 3862 }
3864 else if (a_nr == 1 && ! conja && ! conjb) 3863 else if (a_nr == 1 && ! cja && ! cjb)
3865 { 3864 {
3866 const char *crevtransb = get_blas_trans_arg (! transb, conjb); 3865 const char *crevtrb = get_blas_trans_arg (! trb, cjb);
3867 F77_XFCN (cgemv, CGEMV, (F77_CONST_CHAR_ARG2 (crevtransb, 1), 3866 F77_XFCN (cgemv, CGEMV, (F77_CONST_CHAR_ARG2 (crevtrb, 1),
3868 ldb, tdb, 1.0, b.data (), ldb, 3867 ldb, tdb, 1.0, b.data (), ldb,
3869 a.data (), 1, 0.0, c, 1 3868 a.data (), 1, 0.0, c, 1
3870 F77_CHAR_ARG_LEN (1))); 3869 F77_CHAR_ARG_LEN (1)));
3871 } 3870 }
3872 else 3871 else
3873 { 3872 {
3874 const char *ctransa = get_blas_trans_arg (transa, conja); 3873 const char *ctra = get_blas_trans_arg (tra, cja);
3875 const char *ctransb = get_blas_trans_arg (transb, conjb); 3874 const char *ctrb = get_blas_trans_arg (trb, cjb);
3876 F77_XFCN (cgemm, CGEMM, (F77_CONST_CHAR_ARG2 (ctransa, 1), 3875 F77_XFCN (cgemm, CGEMM, (F77_CONST_CHAR_ARG2 (ctra, 1),
3877 F77_CONST_CHAR_ARG2 (ctransb, 1), 3876 F77_CONST_CHAR_ARG2 (ctrb, 1),
3878 a_nr, b_nc, a_nc, 1.0, a.data (), 3877 a_nr, b_nc, a_nc, 1.0, a.data (),
3879 lda, b.data (), ldb, 0.0, c, a_nr 3878 lda, b.data (), ldb, 0.0, c, a_nr
3880 F77_CHAR_ARG_LEN (1) 3879 F77_CHAR_ARG_LEN (1)
3881 F77_CHAR_ARG_LEN (1))); 3880 F77_CHAR_ARG_LEN (1)));
3882 } 3881 }
3887 } 3886 }
3888 3887
3889 FloatComplexMatrix 3888 FloatComplexMatrix
3890 operator * (const FloatComplexMatrix& a, const FloatComplexMatrix& b) 3889 operator * (const FloatComplexMatrix& a, const FloatComplexMatrix& b)
3891 { 3890 {
3892 return xgemm (false, false, a, false, false, b); 3891 return xgemm (a, b);
3893 } 3892 }
3894 3893
3895 // FIXME -- it would be nice to share code among the min/max 3894 // FIXME -- it would be nice to share code among the min/max
3896 // functions below. 3895 // functions below.
3897 3896