comparison liboctave/dMatrix.cc @ 9665:1dba57e9d08d

use blas_trans_type for xgemm
author Jaroslav Hajek <highegg@gmail.com>
date Sat, 26 Sep 2009 10:41:07 +0200
parents 0d3b248f4ab6
children f80c566bc751
comparison
equal deleted inserted replaced
9664:2c5169034035 9665:1dba57e9d08d
3202 } 3202 }
3203 3203
3204 // the general GEMM operation 3204 // the general GEMM operation
3205 3205
3206 Matrix 3206 Matrix
3207 xgemm (bool transa, const Matrix& a, bool transb, const Matrix& b) 3207 xgemm (const Matrix& a, const Matrix& b,
3208 blas_trans_type transa, blas_trans_type transb)
3208 { 3209 {
3209 Matrix retval; 3210 Matrix retval;
3210 3211
3211 octave_idx_type a_nr = transa ? a.cols () : a.rows (); 3212 bool tra = transa != blas_no_trans, trb = transb != blas_no_trans;
3212 octave_idx_type a_nc = transa ? a.rows () : a.cols (); 3213
3213 3214 octave_idx_type a_nr = tra ? a.cols () : a.rows ();
3214 octave_idx_type b_nr = transb ? b.cols () : b.rows (); 3215 octave_idx_type a_nc = tra ? a.rows () : a.cols ();
3215 octave_idx_type b_nc = transb ? b.rows () : b.cols (); 3216
3217 octave_idx_type b_nr = trb ? b.cols () : b.rows ();
3218 octave_idx_type b_nc = trb ? b.rows () : b.cols ();
3216 3219
3217 if (a_nc != b_nr) 3220 if (a_nc != b_nr)
3218 gripe_nonconformant ("operator *", a_nr, a_nc, b_nr, b_nc); 3221 gripe_nonconformant ("operator *", a_nr, a_nc, b_nr, b_nc);
3219 else 3222 else
3220 { 3223 {
3221 if (a_nr == 0 || a_nc == 0 || b_nc == 0) 3224 if (a_nr == 0 || a_nc == 0 || b_nc == 0)
3222 retval = Matrix (a_nr, b_nc, 0.0); 3225 retval = Matrix (a_nr, b_nc, 0.0);
3223 else if (a.data () == b.data () && a_nr == b_nc && transa != transb) 3226 else if (a.data () == b.data () && a_nr == b_nc && tra != trb)
3224 { 3227 {
3225 octave_idx_type lda = a.rows (); 3228 octave_idx_type lda = a.rows ();
3226 3229
3227 retval = Matrix (a_nr, b_nc); 3230 retval = Matrix (a_nr, b_nc);
3228 double *c = retval.fortran_vec (); 3231 double *c = retval.fortran_vec ();
3229 3232
3230 const char *ctransa = get_blas_trans_arg (transa); 3233 const char *ctra = get_blas_trans_arg (tra);
3231 F77_XFCN (dsyrk, DSYRK, (F77_CONST_CHAR_ARG2 ("U", 1), 3234 F77_XFCN (dsyrk, DSYRK, (F77_CONST_CHAR_ARG2 ("U", 1),
3232 F77_CONST_CHAR_ARG2 (ctransa, 1), 3235 F77_CONST_CHAR_ARG2 (ctra, 1),
3233 a_nr, a_nc, 1.0, 3236 a_nr, a_nc, 1.0,
3234 a.data (), lda, 0.0, c, a_nr 3237 a.data (), lda, 0.0, c, a_nr
3235 F77_CHAR_ARG_LEN (1) 3238 F77_CHAR_ARG_LEN (1)
3236 F77_CHAR_ARG_LEN (1))); 3239 F77_CHAR_ARG_LEN (1)));
3237 for (int j = 0; j < a_nr; j++) 3240 for (int j = 0; j < a_nr; j++)
3251 { 3254 {
3252 if (a_nr == 1) 3255 if (a_nr == 1)
3253 F77_FUNC (xddot, XDDOT) (a_nc, a.data (), 1, b.data (), 1, *c); 3256 F77_FUNC (xddot, XDDOT) (a_nc, a.data (), 1, b.data (), 1, *c);
3254 else 3257 else
3255 { 3258 {
3256 const char *ctransa = get_blas_trans_arg (transa); 3259 const char *ctra = get_blas_trans_arg (tra);
3257 F77_XFCN (dgemv, DGEMV, (F77_CONST_CHAR_ARG2 (ctransa, 1), 3260 F77_XFCN (dgemv, DGEMV, (F77_CONST_CHAR_ARG2 (ctra, 1),
3258 lda, tda, 1.0, a.data (), lda, 3261 lda, tda, 1.0, a.data (), lda,
3259 b.data (), 1, 0.0, c, 1 3262 b.data (), 1, 0.0, c, 1
3260 F77_CHAR_ARG_LEN (1))); 3263 F77_CHAR_ARG_LEN (1)));
3261 } 3264 }
3262 } 3265 }
3263 else if (a_nr == 1) 3266 else if (a_nr == 1)
3264 { 3267 {
3265 const char *crevtransb = get_blas_trans_arg (! transb); 3268 const char *crevtrb = get_blas_trans_arg (! trb);
3266 F77_XFCN (dgemv, DGEMV, (F77_CONST_CHAR_ARG2 (crevtransb, 1), 3269 F77_XFCN (dgemv, DGEMV, (F77_CONST_CHAR_ARG2 (crevtrb, 1),
3267 ldb, tdb, 1.0, b.data (), ldb, 3270 ldb, tdb, 1.0, b.data (), ldb,
3268 a.data (), 1, 0.0, c, 1 3271 a.data (), 1, 0.0, c, 1
3269 F77_CHAR_ARG_LEN (1))); 3272 F77_CHAR_ARG_LEN (1)));
3270 } 3273 }
3271 else 3274 else
3272 { 3275 {
3273 const char *ctransa = get_blas_trans_arg (transa); 3276 const char *ctra = get_blas_trans_arg (tra);
3274 const char *ctransb = get_blas_trans_arg (transb); 3277 const char *ctrb = get_blas_trans_arg (trb);
3275 F77_XFCN (dgemm, DGEMM, (F77_CONST_CHAR_ARG2 (ctransa, 1), 3278 F77_XFCN (dgemm, DGEMM, (F77_CONST_CHAR_ARG2 (ctra, 1),
3276 F77_CONST_CHAR_ARG2 (ctransb, 1), 3279 F77_CONST_CHAR_ARG2 (ctrb, 1),
3277 a_nr, b_nc, a_nc, 1.0, a.data (), 3280 a_nr, b_nc, a_nc, 1.0, a.data (),
3278 lda, b.data (), ldb, 0.0, c, a_nr 3281 lda, b.data (), ldb, 0.0, c, a_nr
3279 F77_CHAR_ARG_LEN (1) 3282 F77_CHAR_ARG_LEN (1)
3280 F77_CHAR_ARG_LEN (1))); 3283 F77_CHAR_ARG_LEN (1)));
3281 } 3284 }
3286 } 3289 }
3287 3290
3288 Matrix 3291 Matrix
3289 operator * (const Matrix& a, const Matrix& b) 3292 operator * (const Matrix& a, const Matrix& b)
3290 { 3293 {
3291 return xgemm (false, a, false, b); 3294 return xgemm (a, b);
3292 } 3295 }
3293 3296
3294 // FIXME -- it would be nice to share code among the min/max 3297 // FIXME -- it would be nice to share code among the min/max
3295 // functions below. 3298 // functions below.
3296 3299