Mercurial > hg > octave-nkf
comparison liboctave/fMatrix.cc @ 7804:a0c550b22e61
compound ops for float matrices
author | Jaroslav Hajek <highegg@gmail.com> |
---|---|
date | Wed, 21 May 2008 19:25:08 +0200 |
parents | f42c6f8d6d8e |
children | 935be827eaf8 |
comparison
equal
deleted
inserted
replaced
7803:9bcb31cc56be | 7804:a0c550b22e61 |
---|---|
1 // Matrix manipulations. | 1 // Matrix manipulations. |
2 /* | 2 /* |
3 | 3 |
4 Copyright (C) 1994, 1995, 1996, 1997, 1998, 1999, 2000, 2001, 2002, | 4 Copyright (C) 1994, 1995, 1996, 1997, 1998, 1999, 2000, 2001, 2002, |
5 2003, 2004, 2005, 2006, 2007 John W. Eaton | 5 2003, 2004, 2005, 2006, 2007 John W. Eaton |
6 Copyright (C) 2008 Jaroslav Hajek | |
6 | 7 |
7 This file is part of Octave. | 8 This file is part of Octave. |
8 | 9 |
9 Octave is free software; you can redistribute it and/or modify it | 10 Octave is free software; you can redistribute it and/or modify it |
10 under the terms of the GNU General Public License as published by the | 11 under the terms of the GNU General Public License as published by the |
101 F77_CHAR_ARG_LEN_DECL); | 102 F77_CHAR_ARG_LEN_DECL); |
102 | 103 |
103 F77_RET_T | 104 F77_RET_T |
104 F77_FUNC (xsdot, XSDOT) (const octave_idx_type&, const float*, const octave_idx_type&, | 105 F77_FUNC (xsdot, XSDOT) (const octave_idx_type&, const float*, const octave_idx_type&, |
105 const float*, const octave_idx_type&, float&); | 106 const float*, const octave_idx_type&, float&); |
107 | |
108 F77_RET_T | |
109 F77_FUNC (ssyrk, SSYRK) (F77_CONST_CHAR_ARG_DECL, | |
110 F77_CONST_CHAR_ARG_DECL, | |
111 const octave_idx_type&, const octave_idx_type&, | |
112 const float&, const float*, const octave_idx_type&, | |
113 const float&, float*, const octave_idx_type& | |
114 F77_CHAR_ARG_LEN_DECL | |
115 F77_CHAR_ARG_LEN_DECL); | |
106 | 116 |
107 F77_RET_T | 117 F77_RET_T |
108 F77_FUNC (sgetrf, SGETRF) (const octave_idx_type&, const octave_idx_type&, float*, const octave_idx_type&, | 118 F77_FUNC (sgetrf, SGETRF) (const octave_idx_type&, const octave_idx_type&, float*, const octave_idx_type&, |
109 octave_idx_type*, octave_idx_type&); | 119 octave_idx_type*, octave_idx_type&); |
110 | 120 |
3359 %!assert([M*cv,M*cv],M*[cv,cv],1e-14) | 3369 %!assert([M*cv,M*cv],M*[cv,cv],1e-14) |
3360 %!assert([rv*M;rv*M],[rv;rv]*M,1e-14) | 3370 %!assert([rv*M;rv*M],[rv;rv]*M,1e-14) |
3361 %!assert(2*rv*cv,[rv,rv]*[cv;cv],1e-14) | 3371 %!assert(2*rv*cv,[rv,rv]*[cv;cv],1e-14) |
3362 */ | 3372 */ |
3363 | 3373 |
3364 | 3374 static const char * |
3365 FloatMatrix | 3375 get_blas_trans_arg (bool trans) |
3366 operator * (const FloatMatrix& m, const FloatMatrix& a) | 3376 { |
3377 static char blas_notrans = 'N', blas_trans = 'T'; | |
3378 return (trans) ? &blas_trans : &blas_notrans; | |
3379 } | |
3380 | |
3381 // the general GEMM operation | |
3382 | |
3383 FloatMatrix | |
3384 xgemm (bool transa, const FloatMatrix& a, bool transb, const FloatMatrix& b) | |
3367 { | 3385 { |
3368 FloatMatrix retval; | 3386 FloatMatrix retval; |
3369 | 3387 |
3370 octave_idx_type nr = m.rows (); | 3388 octave_idx_type a_nr = transa ? a.cols () : a.rows (); |
3371 octave_idx_type nc = m.cols (); | 3389 octave_idx_type a_nc = transa ? a.rows () : a.cols (); |
3372 | 3390 |
3373 octave_idx_type a_nr = a.rows (); | 3391 octave_idx_type b_nr = transb ? b.cols () : b.rows (); |
3374 octave_idx_type a_nc = a.cols (); | 3392 octave_idx_type b_nc = transb ? b.rows () : b.cols (); |
3375 | 3393 |
3376 if (nc != a_nr) | 3394 if (a_nc != b_nr) |
3377 gripe_nonconformant ("operator *", nr, nc, a_nr, a_nc); | 3395 gripe_nonconformant ("operator *", a_nr, a_nc, b_nr, b_nc); |
3378 else | 3396 else |
3379 { | 3397 { |
3380 if (nr == 0 || nc == 0 || a_nc == 0) | 3398 if (a_nr == 0 || a_nc == 0 || b_nc == 0) |
3381 retval.resize (nr, a_nc, 0.0); | 3399 retval.resize (a_nr, b_nc, 0.0); |
3400 else if (a.data () == b.data () && a_nr == b_nc && transa != transb) | |
3401 { | |
3402 octave_idx_type lda = a.rows (); | |
3403 | |
3404 retval.resize (a_nr, b_nc); | |
3405 float *c = retval.fortran_vec (); | |
3406 | |
3407 const char *ctransa = get_blas_trans_arg (transa); | |
3408 F77_XFCN (ssyrk, SSYRK, (F77_CONST_CHAR_ARG2 ("U", 1), | |
3409 F77_CONST_CHAR_ARG2 (ctransa, 1), | |
3410 a_nr, a_nc, 1.0, | |
3411 a.data (), lda, 0.0, c, a_nr | |
3412 F77_CHAR_ARG_LEN (1) | |
3413 F77_CHAR_ARG_LEN (1))); | |
3414 for (int j = 0; j < a_nr; j++) | |
3415 for (int i = 0; i < j; i++) | |
3416 retval.xelem (j,i) = retval.xelem (i,j); | |
3417 | |
3418 } | |
3382 else | 3419 else |
3383 { | 3420 { |
3384 octave_idx_type ld = nr; | 3421 octave_idx_type lda = a.rows (), tda = a.cols (); |
3385 octave_idx_type lda = a_nr; | 3422 octave_idx_type ldb = b.rows (), tdb = b.cols (); |
3386 | 3423 |
3387 retval.resize (nr, a_nc); | 3424 retval.resize (a_nr, b_nc); |
3388 float *c = retval.fortran_vec (); | 3425 float *c = retval.fortran_vec (); |
3389 | 3426 |
3390 if (a_nc == 1) | 3427 if (b_nc == 1) |
3391 { | 3428 { |
3392 if (nr == 1) | 3429 if (a_nr == 1) |
3393 F77_FUNC (xsdot, XSDOT) (nc, m.data (), 1, a.data (), 1, *c); | 3430 F77_FUNC (xsdot, XSDOT) (a_nc, a.data (), 1, b.data (), 1, *c); |
3394 else | 3431 else |
3395 { | 3432 { |
3396 F77_XFCN (sgemv, SGEMV, (F77_CONST_CHAR_ARG2 ("N", 1), | 3433 const char *ctransa = get_blas_trans_arg (transa); |
3397 nr, nc, 1.0, m.data (), ld, | 3434 F77_XFCN (sgemv, SGEMV, (F77_CONST_CHAR_ARG2 (ctransa, 1), |
3398 a.data (), 1, 0.0, c, 1 | 3435 lda, tda, 1.0, a.data (), lda, |
3436 b.data (), 1, 0.0, c, 1 | |
3399 F77_CHAR_ARG_LEN (1))); | 3437 F77_CHAR_ARG_LEN (1))); |
3400 } | 3438 } |
3401 } | 3439 } |
3440 else if (a_nr == 1) | |
3441 { | |
3442 const char *crevtransb = get_blas_trans_arg (! transb); | |
3443 F77_XFCN (sgemv, SGEMV, (F77_CONST_CHAR_ARG2 (crevtransb, 1), | |
3444 ldb, tdb, 1.0, b.data (), ldb, | |
3445 a.data (), 1, 0.0, c, 1 | |
3446 F77_CHAR_ARG_LEN (1))); | |
3447 } | |
3402 else | 3448 else |
3403 { | 3449 { |
3404 F77_XFCN (sgemm, SGEMM, (F77_CONST_CHAR_ARG2 ("N", 1), | 3450 const char *ctransa = get_blas_trans_arg (transa); |
3405 F77_CONST_CHAR_ARG2 ("N", 1), | 3451 const char *ctransb = get_blas_trans_arg (transb); |
3406 nr, a_nc, nc, 1.0, m.data (), | 3452 F77_XFCN (sgemm, SGEMM, (F77_CONST_CHAR_ARG2 (ctransa, 1), |
3407 ld, a.data (), lda, 0.0, c, nr | 3453 F77_CONST_CHAR_ARG2 (ctransb, 1), |
3454 a_nr, b_nc, a_nc, 1.0, a.data (), | |
3455 lda, b.data (), ldb, 0.0, c, a_nr | |
3408 F77_CHAR_ARG_LEN (1) | 3456 F77_CHAR_ARG_LEN (1) |
3409 F77_CHAR_ARG_LEN (1))); | 3457 F77_CHAR_ARG_LEN (1))); |
3410 } | 3458 } |
3411 } | 3459 } |
3412 } | 3460 } |
3413 | 3461 |
3414 return retval; | 3462 return retval; |
3463 } | |
3464 | |
3465 FloatMatrix | |
3466 operator * (const FloatMatrix& a, const FloatMatrix& b) | |
3467 { | |
3468 return xgemm (false, a, false, b); | |
3415 } | 3469 } |
3416 | 3470 |
3417 // FIXME -- it would be nice to share code among the min/max | 3471 // FIXME -- it would be nice to share code among the min/max |
3418 // functions below. | 3472 // functions below. |
3419 | 3473 |