comparison liboctave/fMatrix.cc @ 7804:a0c550b22e61

compound ops for float matrices
author Jaroslav Hajek <highegg@gmail.com>
date Wed, 21 May 2008 19:25:08 +0200
parents f42c6f8d6d8e
children 935be827eaf8
comparison
equal deleted inserted replaced
7803:9bcb31cc56be 7804:a0c550b22e61
1 // Matrix manipulations. 1 // Matrix manipulations.
2 /* 2 /*
3 3
4 Copyright (C) 1994, 1995, 1996, 1997, 1998, 1999, 2000, 2001, 2002, 4 Copyright (C) 1994, 1995, 1996, 1997, 1998, 1999, 2000, 2001, 2002,
5 2003, 2004, 2005, 2006, 2007 John W. Eaton 5 2003, 2004, 2005, 2006, 2007 John W. Eaton
6 Copyright (C) 2008 Jaroslav Hajek
6 7
7 This file is part of Octave. 8 This file is part of Octave.
8 9
9 Octave is free software; you can redistribute it and/or modify it 10 Octave is free software; you can redistribute it and/or modify it
10 under the terms of the GNU General Public License as published by the 11 under the terms of the GNU General Public License as published by the
101 F77_CHAR_ARG_LEN_DECL); 102 F77_CHAR_ARG_LEN_DECL);
102 103
103 F77_RET_T 104 F77_RET_T
104 F77_FUNC (xsdot, XSDOT) (const octave_idx_type&, const float*, const octave_idx_type&, 105 F77_FUNC (xsdot, XSDOT) (const octave_idx_type&, const float*, const octave_idx_type&,
105 const float*, const octave_idx_type&, float&); 106 const float*, const octave_idx_type&, float&);
107
108 F77_RET_T
109 F77_FUNC (ssyrk, SSYRK) (F77_CONST_CHAR_ARG_DECL,
110 F77_CONST_CHAR_ARG_DECL,
111 const octave_idx_type&, const octave_idx_type&,
112 const float&, const float*, const octave_idx_type&,
113 const float&, float*, const octave_idx_type&
114 F77_CHAR_ARG_LEN_DECL
115 F77_CHAR_ARG_LEN_DECL);
106 116
107 F77_RET_T 117 F77_RET_T
108 F77_FUNC (sgetrf, SGETRF) (const octave_idx_type&, const octave_idx_type&, float*, const octave_idx_type&, 118 F77_FUNC (sgetrf, SGETRF) (const octave_idx_type&, const octave_idx_type&, float*, const octave_idx_type&,
109 octave_idx_type*, octave_idx_type&); 119 octave_idx_type*, octave_idx_type&);
110 120
3359 %!assert([M*cv,M*cv],M*[cv,cv],1e-14) 3369 %!assert([M*cv,M*cv],M*[cv,cv],1e-14)
3360 %!assert([rv*M;rv*M],[rv;rv]*M,1e-14) 3370 %!assert([rv*M;rv*M],[rv;rv]*M,1e-14)
3361 %!assert(2*rv*cv,[rv,rv]*[cv;cv],1e-14) 3371 %!assert(2*rv*cv,[rv,rv]*[cv;cv],1e-14)
3362 */ 3372 */
3363 3373
3364 3374 static const char *
3365 FloatMatrix 3375 get_blas_trans_arg (bool trans)
3366 operator * (const FloatMatrix& m, const FloatMatrix& a) 3376 {
3377 static char blas_notrans = 'N', blas_trans = 'T';
3378 return (trans) ? &blas_trans : &blas_notrans;
3379 }
3380
3381 // the general GEMM operation
3382
3383 FloatMatrix
3384 xgemm (bool transa, const FloatMatrix& a, bool transb, const FloatMatrix& b)
3367 { 3385 {
3368 FloatMatrix retval; 3386 FloatMatrix retval;
3369 3387
3370 octave_idx_type nr = m.rows (); 3388 octave_idx_type a_nr = transa ? a.cols () : a.rows ();
3371 octave_idx_type nc = m.cols (); 3389 octave_idx_type a_nc = transa ? a.rows () : a.cols ();
3372 3390
3373 octave_idx_type a_nr = a.rows (); 3391 octave_idx_type b_nr = transb ? b.cols () : b.rows ();
3374 octave_idx_type a_nc = a.cols (); 3392 octave_idx_type b_nc = transb ? b.rows () : b.cols ();
3375 3393
3376 if (nc != a_nr) 3394 if (a_nc != b_nr)
3377 gripe_nonconformant ("operator *", nr, nc, a_nr, a_nc); 3395 gripe_nonconformant ("operator *", a_nr, a_nc, b_nr, b_nc);
3378 else 3396 else
3379 { 3397 {
3380 if (nr == 0 || nc == 0 || a_nc == 0) 3398 if (a_nr == 0 || a_nc == 0 || b_nc == 0)
3381 retval.resize (nr, a_nc, 0.0); 3399 retval.resize (a_nr, b_nc, 0.0);
3400 else if (a.data () == b.data () && a_nr == b_nc && transa != transb)
3401 {
3402 octave_idx_type lda = a.rows ();
3403
3404 retval.resize (a_nr, b_nc);
3405 float *c = retval.fortran_vec ();
3406
3407 const char *ctransa = get_blas_trans_arg (transa);
3408 F77_XFCN (ssyrk, SSYRK, (F77_CONST_CHAR_ARG2 ("U", 1),
3409 F77_CONST_CHAR_ARG2 (ctransa, 1),
3410 a_nr, a_nc, 1.0,
3411 a.data (), lda, 0.0, c, a_nr
3412 F77_CHAR_ARG_LEN (1)
3413 F77_CHAR_ARG_LEN (1)));
3414 for (int j = 0; j < a_nr; j++)
3415 for (int i = 0; i < j; i++)
3416 retval.xelem (j,i) = retval.xelem (i,j);
3417
3418 }
3382 else 3419 else
3383 { 3420 {
3384 octave_idx_type ld = nr; 3421 octave_idx_type lda = a.rows (), tda = a.cols ();
3385 octave_idx_type lda = a_nr; 3422 octave_idx_type ldb = b.rows (), tdb = b.cols ();
3386 3423
3387 retval.resize (nr, a_nc); 3424 retval.resize (a_nr, b_nc);
3388 float *c = retval.fortran_vec (); 3425 float *c = retval.fortran_vec ();
3389 3426
3390 if (a_nc == 1) 3427 if (b_nc == 1)
3391 { 3428 {
3392 if (nr == 1) 3429 if (a_nr == 1)
3393 F77_FUNC (xsdot, XSDOT) (nc, m.data (), 1, a.data (), 1, *c); 3430 F77_FUNC (xsdot, XSDOT) (a_nc, a.data (), 1, b.data (), 1, *c);
3394 else 3431 else
3395 { 3432 {
3396 F77_XFCN (sgemv, SGEMV, (F77_CONST_CHAR_ARG2 ("N", 1), 3433 const char *ctransa = get_blas_trans_arg (transa);
3397 nr, nc, 1.0, m.data (), ld, 3434 F77_XFCN (sgemv, SGEMV, (F77_CONST_CHAR_ARG2 (ctransa, 1),
3398 a.data (), 1, 0.0, c, 1 3435 lda, tda, 1.0, a.data (), lda,
3436 b.data (), 1, 0.0, c, 1
3399 F77_CHAR_ARG_LEN (1))); 3437 F77_CHAR_ARG_LEN (1)));
3400 } 3438 }
3401 } 3439 }
3440 else if (a_nr == 1)
3441 {
3442 const char *crevtransb = get_blas_trans_arg (! transb);
3443 F77_XFCN (sgemv, SGEMV, (F77_CONST_CHAR_ARG2 (crevtransb, 1),
3444 ldb, tdb, 1.0, b.data (), ldb,
3445 a.data (), 1, 0.0, c, 1
3446 F77_CHAR_ARG_LEN (1)));
3447 }
3402 else 3448 else
3403 { 3449 {
3404 F77_XFCN (sgemm, SGEMM, (F77_CONST_CHAR_ARG2 ("N", 1), 3450 const char *ctransa = get_blas_trans_arg (transa);
3405 F77_CONST_CHAR_ARG2 ("N", 1), 3451 const char *ctransb = get_blas_trans_arg (transb);
3406 nr, a_nc, nc, 1.0, m.data (), 3452 F77_XFCN (sgemm, SGEMM, (F77_CONST_CHAR_ARG2 (ctransa, 1),
3407 ld, a.data (), lda, 0.0, c, nr 3453 F77_CONST_CHAR_ARG2 (ctransb, 1),
3454 a_nr, b_nc, a_nc, 1.0, a.data (),
3455 lda, b.data (), ldb, 0.0, c, a_nr
3408 F77_CHAR_ARG_LEN (1) 3456 F77_CHAR_ARG_LEN (1)
3409 F77_CHAR_ARG_LEN (1))); 3457 F77_CHAR_ARG_LEN (1)));
3410 } 3458 }
3411 } 3459 }
3412 } 3460 }
3413 3461
3414 return retval; 3462 return retval;
3463 }
3464
3465 FloatMatrix
3466 operator * (const FloatMatrix& a, const FloatMatrix& b)
3467 {
3468 return xgemm (false, a, false, b);
3415 } 3469 }
3416 3470
3417 // FIXME -- it would be nice to share code among the min/max 3471 // FIXME -- it would be nice to share code among the min/max
3418 // functions below. 3472 // functions below.
3419 3473