Mercurial > hg > octave-nkf
comparison liboctave/dMatrix.cc @ 9665:1dba57e9d08d
use blas_trans_type for xgemm
author | Jaroslav Hajek <highegg@gmail.com> |
---|---|
date | Sat, 26 Sep 2009 10:41:07 +0200 |
parents | 0d3b248f4ab6 |
children | f80c566bc751 |
comparison legend: equal | deleted | inserted | replaced
9664:2c5169034035 | 9665:1dba57e9d08d |
---|---|
3202 } | 3202 } |
3203 | 3203 |
3204 // the general GEMM operation | 3204 // the general GEMM operation |
3205 | 3205 |
3206 Matrix | 3206 Matrix |
3207 xgemm (bool transa, const Matrix& a, bool transb, const Matrix& b) | 3207 xgemm (const Matrix& a, const Matrix& b, |
3208 blas_trans_type transa, blas_trans_type transb) | |
3208 { | 3209 { |
3209 Matrix retval; | 3210 Matrix retval; |
3210 | 3211 |
3211 octave_idx_type a_nr = transa ? a.cols () : a.rows (); | 3212 bool tra = transa != blas_no_trans, trb = transb != blas_no_trans; |
3212 octave_idx_type a_nc = transa ? a.rows () : a.cols (); | 3213 |
3213 | 3214 octave_idx_type a_nr = tra ? a.cols () : a.rows (); |
3214 octave_idx_type b_nr = transb ? b.cols () : b.rows (); | 3215 octave_idx_type a_nc = tra ? a.rows () : a.cols (); |
3215 octave_idx_type b_nc = transb ? b.rows () : b.cols (); | 3216 |
3217 octave_idx_type b_nr = trb ? b.cols () : b.rows (); | |
3218 octave_idx_type b_nc = trb ? b.rows () : b.cols (); | |
3216 | 3219 |
3217 if (a_nc != b_nr) | 3220 if (a_nc != b_nr) |
3218 gripe_nonconformant ("operator *", a_nr, a_nc, b_nr, b_nc); | 3221 gripe_nonconformant ("operator *", a_nr, a_nc, b_nr, b_nc); |
3219 else | 3222 else |
3220 { | 3223 { |
3221 if (a_nr == 0 || a_nc == 0 || b_nc == 0) | 3224 if (a_nr == 0 || a_nc == 0 || b_nc == 0) |
3222 retval = Matrix (a_nr, b_nc, 0.0); | 3225 retval = Matrix (a_nr, b_nc, 0.0); |
3223 else if (a.data () == b.data () && a_nr == b_nc && transa != transb) | 3226 else if (a.data () == b.data () && a_nr == b_nc && tra != trb) |
3224 { | 3227 { |
3225 octave_idx_type lda = a.rows (); | 3228 octave_idx_type lda = a.rows (); |
3226 | 3229 |
3227 retval = Matrix (a_nr, b_nc); | 3230 retval = Matrix (a_nr, b_nc); |
3228 double *c = retval.fortran_vec (); | 3231 double *c = retval.fortran_vec (); |
3229 | 3232 |
3230 const char *ctransa = get_blas_trans_arg (transa); | 3233 const char *ctra = get_blas_trans_arg (tra); |
3231 F77_XFCN (dsyrk, DSYRK, (F77_CONST_CHAR_ARG2 ("U", 1), | 3234 F77_XFCN (dsyrk, DSYRK, (F77_CONST_CHAR_ARG2 ("U", 1), |
3232 F77_CONST_CHAR_ARG2 (ctransa, 1), | 3235 F77_CONST_CHAR_ARG2 (ctra, 1), |
3233 a_nr, a_nc, 1.0, | 3236 a_nr, a_nc, 1.0, |
3234 a.data (), lda, 0.0, c, a_nr | 3237 a.data (), lda, 0.0, c, a_nr |
3235 F77_CHAR_ARG_LEN (1) | 3238 F77_CHAR_ARG_LEN (1) |
3236 F77_CHAR_ARG_LEN (1))); | 3239 F77_CHAR_ARG_LEN (1))); |
3237 for (int j = 0; j < a_nr; j++) | 3240 for (int j = 0; j < a_nr; j++) |
3251 { | 3254 { |
3252 if (a_nr == 1) | 3255 if (a_nr == 1) |
3253 F77_FUNC (xddot, XDDOT) (a_nc, a.data (), 1, b.data (), 1, *c); | 3256 F77_FUNC (xddot, XDDOT) (a_nc, a.data (), 1, b.data (), 1, *c); |
3254 else | 3257 else |
3255 { | 3258 { |
3256 const char *ctransa = get_blas_trans_arg (transa); | 3259 const char *ctra = get_blas_trans_arg (tra); |
3257 F77_XFCN (dgemv, DGEMV, (F77_CONST_CHAR_ARG2 (ctransa, 1), | 3260 F77_XFCN (dgemv, DGEMV, (F77_CONST_CHAR_ARG2 (ctra, 1), |
3258 lda, tda, 1.0, a.data (), lda, | 3261 lda, tda, 1.0, a.data (), lda, |
3259 b.data (), 1, 0.0, c, 1 | 3262 b.data (), 1, 0.0, c, 1 |
3260 F77_CHAR_ARG_LEN (1))); | 3263 F77_CHAR_ARG_LEN (1))); |
3261 } | 3264 } |
3262 } | 3265 } |
3263 else if (a_nr == 1) | 3266 else if (a_nr == 1) |
3264 { | 3267 { |
3265 const char *crevtransb = get_blas_trans_arg (! transb); | 3268 const char *crevtrb = get_blas_trans_arg (! trb); |
3266 F77_XFCN (dgemv, DGEMV, (F77_CONST_CHAR_ARG2 (crevtransb, 1), | 3269 F77_XFCN (dgemv, DGEMV, (F77_CONST_CHAR_ARG2 (crevtrb, 1), |
3267 ldb, tdb, 1.0, b.data (), ldb, | 3270 ldb, tdb, 1.0, b.data (), ldb, |
3268 a.data (), 1, 0.0, c, 1 | 3271 a.data (), 1, 0.0, c, 1 |
3269 F77_CHAR_ARG_LEN (1))); | 3272 F77_CHAR_ARG_LEN (1))); |
3270 } | 3273 } |
3271 else | 3274 else |
3272 { | 3275 { |
3273 const char *ctransa = get_blas_trans_arg (transa); | 3276 const char *ctra = get_blas_trans_arg (tra); |
3274 const char *ctransb = get_blas_trans_arg (transb); | 3277 const char *ctrb = get_blas_trans_arg (trb); |
3275 F77_XFCN (dgemm, DGEMM, (F77_CONST_CHAR_ARG2 (ctransa, 1), | 3278 F77_XFCN (dgemm, DGEMM, (F77_CONST_CHAR_ARG2 (ctra, 1), |
3276 F77_CONST_CHAR_ARG2 (ctransb, 1), | 3279 F77_CONST_CHAR_ARG2 (ctrb, 1), |
3277 a_nr, b_nc, a_nc, 1.0, a.data (), | 3280 a_nr, b_nc, a_nc, 1.0, a.data (), |
3278 lda, b.data (), ldb, 0.0, c, a_nr | 3281 lda, b.data (), ldb, 0.0, c, a_nr |
3279 F77_CHAR_ARG_LEN (1) | 3282 F77_CHAR_ARG_LEN (1) |
3280 F77_CHAR_ARG_LEN (1))); | 3283 F77_CHAR_ARG_LEN (1))); |
3281 } | 3284 } |
3286 } | 3289 } |
3287 | 3290 |
3288 Matrix | 3291 Matrix |
3289 operator * (const Matrix& a, const Matrix& b) | 3292 operator * (const Matrix& a, const Matrix& b) |
3290 { | 3293 { |
3291 return xgemm (false, a, false, b); | 3294 return xgemm (a, b); |
3292 } | 3295 } |
3293 | 3296 |
3294 // FIXME -- it would be nice to share code among the min/max | 3297 // FIXME -- it would be nice to share code among the min/max |
3295 // functions below. | 3298 // functions below. |
3296 | 3299 |