Mercurial > hg > octave-lyh
comparison liboctave/fCMatrix.cc @ 9665:1dba57e9d08d
use blas_trans_type for xgemm
author | Jaroslav Hajek <highegg@gmail.com> |
---|---|
date | Sat, 26 Sep 2009 10:41:07 +0200 |
parents | 7e5b4de5fbfe |
children | f80c566bc751 |
comparison
equal
deleted
inserted
replaced
9664:2c5169034035 | 9665:1dba57e9d08d |
---|---|
3775 } | 3775 } |
3776 | 3776 |
3777 // the general GEMM operation | 3777 // the general GEMM operation |
3778 | 3778 |
3779 FloatComplexMatrix | 3779 FloatComplexMatrix |
3780 xgemm (bool transa, bool conja, const FloatComplexMatrix& a, | 3780 xgemm (const FloatComplexMatrix& a, const FloatComplexMatrix& b, |
3781 bool transb, bool conjb, const FloatComplexMatrix& b) | 3781 blas_trans_type transa, blas_trans_type transb) |
3782 { | 3782 { |
3783 FloatComplexMatrix retval; | 3783 FloatComplexMatrix retval; |
3784 | 3784 |
3785 // conjugacy is ignored if no transpose | 3785 bool tra = transa != blas_no_trans, trb = transb != blas_no_trans; |
3786 conja = conja && transa; | 3786 bool cja = transa == blas_conj_trans, cjb = transb == blas_conj_trans; |
3787 conjb = conjb && transb; | 3787 |
3788 | 3788 octave_idx_type a_nr = tra ? a.cols () : a.rows (); |
3789 octave_idx_type a_nr = transa ? a.cols () : a.rows (); | 3789 octave_idx_type a_nc = tra ? a.rows () : a.cols (); |
3790 octave_idx_type a_nc = transa ? a.rows () : a.cols (); | 3790 |
3791 | 3791 octave_idx_type b_nr = trb ? b.cols () : b.rows (); |
3792 octave_idx_type b_nr = transb ? b.cols () : b.rows (); | 3792 octave_idx_type b_nc = trb ? b.rows () : b.cols (); |
3793 octave_idx_type b_nc = transb ? b.rows () : b.cols (); | |
3794 | 3793 |
3795 if (a_nc != b_nr) | 3794 if (a_nc != b_nr) |
3796 gripe_nonconformant ("operator *", a_nr, a_nc, b_nr, b_nc); | 3795 gripe_nonconformant ("operator *", a_nr, a_nc, b_nr, b_nc); |
3797 else | 3796 else |
3798 { | 3797 { |
3799 if (a_nr == 0 || a_nc == 0 || b_nc == 0) | 3798 if (a_nr == 0 || a_nc == 0 || b_nc == 0) |
3800 retval = FloatComplexMatrix (a_nr, b_nc, 0.0); | 3799 retval = FloatComplexMatrix (a_nr, b_nc, 0.0); |
3801 else if (a.data () == b.data () && a_nr == b_nc && transa != transb) | 3800 else if (a.data () == b.data () && a_nr == b_nc && tra != trb) |
3802 { | 3801 { |
3803 octave_idx_type lda = a.rows (); | 3802 octave_idx_type lda = a.rows (); |
3804 | 3803 |
3805 retval = FloatComplexMatrix (a_nr, b_nc); | 3804 retval = FloatComplexMatrix (a_nr, b_nc); |
3806 FloatComplex *c = retval.fortran_vec (); | 3805 FloatComplex *c = retval.fortran_vec (); |
3807 | 3806 |
3808 const char *ctransa = get_blas_trans_arg (transa, conja); | 3807 const char *ctra = get_blas_trans_arg (tra, cja); |
3809 if (conja || conjb) | 3808 if (cja || cjb) |
3810 { | 3809 { |
3811 F77_XFCN (cherk, CHERK, (F77_CONST_CHAR_ARG2 ("U", 1), | 3810 F77_XFCN (cherk, CHERK, (F77_CONST_CHAR_ARG2 ("U", 1), |
3812 F77_CONST_CHAR_ARG2 (ctransa, 1), | 3811 F77_CONST_CHAR_ARG2 (ctra, 1), |
3813 a_nr, a_nc, 1.0, | 3812 a_nr, a_nc, 1.0, |
3814 a.data (), lda, 0.0, c, a_nr | 3813 a.data (), lda, 0.0, c, a_nr |
3815 F77_CHAR_ARG_LEN (1) | 3814 F77_CHAR_ARG_LEN (1) |
3816 F77_CHAR_ARG_LEN (1))); | 3815 F77_CHAR_ARG_LEN (1))); |
3817 for (int j = 0; j < a_nr; j++) | 3816 for (int j = 0; j < a_nr; j++) |
3819 retval.xelem (j,i) = std::conj (retval.xelem (i,j)); | 3818 retval.xelem (j,i) = std::conj (retval.xelem (i,j)); |
3820 } | 3819 } |
3821 else | 3820 else |
3822 { | 3821 { |
3823 F77_XFCN (csyrk, CSYRK, (F77_CONST_CHAR_ARG2 ("U", 1), | 3822 F77_XFCN (csyrk, CSYRK, (F77_CONST_CHAR_ARG2 ("U", 1), |
3824 F77_CONST_CHAR_ARG2 (ctransa, 1), | 3823 F77_CONST_CHAR_ARG2 (ctra, 1), |
3825 a_nr, a_nc, 1.0, | 3824 a_nr, a_nc, 1.0, |
3826 a.data (), lda, 0.0, c, a_nr | 3825 a.data (), lda, 0.0, c, a_nr |
3827 F77_CHAR_ARG_LEN (1) | 3826 F77_CHAR_ARG_LEN (1) |
3828 F77_CHAR_ARG_LEN (1))); | 3827 F77_CHAR_ARG_LEN (1))); |
3829 for (int j = 0; j < a_nr; j++) | 3828 for (int j = 0; j < a_nr; j++) |
3841 retval = FloatComplexMatrix (a_nr, b_nc); | 3840 retval = FloatComplexMatrix (a_nr, b_nc); |
3842 FloatComplex *c = retval.fortran_vec (); | 3841 FloatComplex *c = retval.fortran_vec (); |
3843 | 3842 |
3844 if (b_nc == 1 && a_nr == 1) | 3843 if (b_nc == 1 && a_nr == 1) |
3845 { | 3844 { |
3846 if (conja == conjb) | 3845 if (cja == cjb) |
3847 { | 3846 { |
3848 F77_FUNC (xcdotu, XCDOTU) (a_nc, a.data (), 1, b.data (), 1, *c); | 3847 F77_FUNC (xcdotu, XCDOTU) (a_nc, a.data (), 1, b.data (), 1, *c); |
3849 if (conja) *c = std::conj (*c); | 3848 if (cja) *c = std::conj (*c); |
3850 } | 3849 } |
3851 else if (conja) | 3850 else if (cja) |
3852 F77_FUNC (xcdotc, XCDOTC) (a_nc, a.data (), 1, b.data (), 1, *c); | 3851 F77_FUNC (xcdotc, XCDOTC) (a_nc, a.data (), 1, b.data (), 1, *c); |
3853 else | 3852 else |
3854 F77_FUNC (xcdotc, XCDOTC) (a_nc, b.data (), 1, a.data (), 1, *c); | 3853 F77_FUNC (xcdotc, XCDOTC) (a_nc, b.data (), 1, a.data (), 1, *c); |
3855 } | 3854 } |
3856 else if (b_nc == 1 && ! conjb) | 3855 else if (b_nc == 1 && ! cjb) |
3857 { | 3856 { |
3858 const char *ctransa = get_blas_trans_arg (transa, conja); | 3857 const char *ctra = get_blas_trans_arg (tra, cja); |
3859 F77_XFCN (cgemv, CGEMV, (F77_CONST_CHAR_ARG2 (ctransa, 1), | 3858 F77_XFCN (cgemv, CGEMV, (F77_CONST_CHAR_ARG2 (ctra, 1), |
3860 lda, tda, 1.0, a.data (), lda, | 3859 lda, tda, 1.0, a.data (), lda, |
3861 b.data (), 1, 0.0, c, 1 | 3860 b.data (), 1, 0.0, c, 1 |
3862 F77_CHAR_ARG_LEN (1))); | 3861 F77_CHAR_ARG_LEN (1))); |
3863 } | 3862 } |
3864 else if (a_nr == 1 && ! conja && ! conjb) | 3863 else if (a_nr == 1 && ! cja && ! cjb) |
3865 { | 3864 { |
3866 const char *crevtransb = get_blas_trans_arg (! transb, conjb); | 3865 const char *crevtrb = get_blas_trans_arg (! trb, cjb); |
3867 F77_XFCN (cgemv, CGEMV, (F77_CONST_CHAR_ARG2 (crevtransb, 1), | 3866 F77_XFCN (cgemv, CGEMV, (F77_CONST_CHAR_ARG2 (crevtrb, 1), |
3868 ldb, tdb, 1.0, b.data (), ldb, | 3867 ldb, tdb, 1.0, b.data (), ldb, |
3869 a.data (), 1, 0.0, c, 1 | 3868 a.data (), 1, 0.0, c, 1 |
3870 F77_CHAR_ARG_LEN (1))); | 3869 F77_CHAR_ARG_LEN (1))); |
3871 } | 3870 } |
3872 else | 3871 else |
3873 { | 3872 { |
3874 const char *ctransa = get_blas_trans_arg (transa, conja); | 3873 const char *ctra = get_blas_trans_arg (tra, cja); |
3875 const char *ctransb = get_blas_trans_arg (transb, conjb); | 3874 const char *ctrb = get_blas_trans_arg (trb, cjb); |
3876 F77_XFCN (cgemm, CGEMM, (F77_CONST_CHAR_ARG2 (ctransa, 1), | 3875 F77_XFCN (cgemm, CGEMM, (F77_CONST_CHAR_ARG2 (ctra, 1), |
3877 F77_CONST_CHAR_ARG2 (ctransb, 1), | 3876 F77_CONST_CHAR_ARG2 (ctrb, 1), |
3878 a_nr, b_nc, a_nc, 1.0, a.data (), | 3877 a_nr, b_nc, a_nc, 1.0, a.data (), |
3879 lda, b.data (), ldb, 0.0, c, a_nr | 3878 lda, b.data (), ldb, 0.0, c, a_nr |
3880 F77_CHAR_ARG_LEN (1) | 3879 F77_CHAR_ARG_LEN (1) |
3881 F77_CHAR_ARG_LEN (1))); | 3880 F77_CHAR_ARG_LEN (1))); |
3882 } | 3881 } |
3887 } | 3886 } |
3888 | 3887 |
3889 FloatComplexMatrix | 3888 FloatComplexMatrix |
3890 operator * (const FloatComplexMatrix& a, const FloatComplexMatrix& b) | 3889 operator * (const FloatComplexMatrix& a, const FloatComplexMatrix& b) |
3891 { | 3890 { |
3892 return xgemm (false, false, a, false, false, b); | 3891 return xgemm (a, b); |
3893 } | 3892 } |
3894 | 3893 |
3895 // FIXME -- it would be nice to share code among the min/max | 3894 // FIXME -- it would be nice to share code among the min/max |
3896 // functions below. | 3895 // functions below. |
3897 | 3896 |