Mercurial > hg > octave-nkf
diff liboctave/Sparse-op-defs.h @ 7802:1a446f28ce68
implement optimized sparse-dense transposed multiplication
author | Jaroslav Hajek <highegg@gmail.com> |
---|---|
date | Sun, 18 May 2008 20:23:31 +0200 |
parents | 288614c6634d |
children | 9bcb31cc56be |
line wrap: on
line diff
--- a/liboctave/Sparse-op-defs.h +++ b/liboctave/Sparse-op-defs.h @@ -1904,15 +1904,7 @@ \ if (nr == 1 && nc == 1) \ { \ - RET_TYPE retval (a_nr, a_nc, ZERO); \ - for (octave_idx_type i = 0; i < a_nc ; i++) \ - { \ - for (octave_idx_type j = 0; j < a_nr; j++) \ - { \ - OCTAVE_QUIT; \ - retval.elem (j,i) += a.elem(j,i) * m.elem(0,0); \ - } \ - } \ + RET_TYPE retval = m.elem (0,0) * a; \ return retval; \ } \ else if (nc != a_nr) \ @@ -1925,15 +1917,51 @@ RET_TYPE retval (nr, a_nc, ZERO); \ \ for (octave_idx_type i = 0; i < a_nc ; i++) \ - { \ - for (octave_idx_type j = 0; j < a_nr; j++) \ - { \ + { \ + for (octave_idx_type j = 0; j < a_nr; j++) \ + { \ OCTAVE_QUIT; \ - \ + \ EL_TYPE tmpval = a.elem(j,i); \ - for (octave_idx_type k = m.cidx(j) ; k < m.cidx(j+1); k++) \ - retval.elem (m.ridx(k),i) += tmpval * m.data(k); \ - } \ + for (octave_idx_type k = m.cidx(j) ; k < m.cidx(j+1); k++) \ + retval.elem (m.ridx(k),i) += tmpval * m.data(k); \ + } \ + } \ + return retval; \ + } + +#define SPARSE_FULL_TRANS_MUL( RET_TYPE, EL_TYPE, ZERO, CONJ_OP ) \ + octave_idx_type nr = m.rows (); \ + octave_idx_type nc = m.cols (); \ + \ + octave_idx_type a_nr = a.rows (); \ + octave_idx_type a_nc = a.cols (); \ + \ + if (nr == 1 && nc == 1) \ + { \ + RET_TYPE retval = CONJ_OP (m.elem(0,0)) * a; \ + return retval; \ + } \ + else if (nr != a_nr) \ + { \ + gripe_nonconformant ("operator *", nc, nr, a_nr, a_nc); \ + return RET_TYPE (); \ + } \ + else \ + { \ + RET_TYPE retval (nc, a_nc); \ + \ + for (octave_idx_type i = 0; i < a_nc ; i++) \ + { \ + for (octave_idx_type j = 0; j < nc; j++) \ + { \ + OCTAVE_QUIT; \ + \ + EL_TYPE acc = ZERO; \ + for (octave_idx_type k = m.cidx(j) ; k < m.cidx(j+1); k++) \ + acc += a.elem (m.ridx(k),i) * CONJ_OP (m.data(k)); \ + retval.xelem (j,i) = acc; \ + } \ } \ return retval; \ } @@ -1947,15 +1975,7 @@ \ if (a_nr == 1 && a_nc == 1) \ { \ - RET_TYPE retval (nr, nc, ZERO); \ - for (octave_idx_type i = 0; i < nc ; i++) \ - { \ - for (octave_idx_type j = 0; j < nr; j++) \ - { \ - OCTAVE_QUIT; \ - retval.elem (j,i) += a.elem(0,0) * m.elem(j,i); \ - } \ - } \ + RET_TYPE retval = m * a.elem (0,0); \ return retval; \ } \ else if (nc != a_nr) \ @@ -1968,16 +1988,51 @@ RET_TYPE retval (nr, a_nc, ZERO); \ \ for (octave_idx_type i = 0; i < a_nc ; i++) \ - { \ - for (octave_idx_type j = a.cidx(i); j < a.cidx(i+1); j++) \ - { \ - octave_idx_type col = a.ridx(j); \ - EL_TYPE tmpval = a.data(j); \ - OCTAVE_QUIT; \ - \ - for (octave_idx_type k = 0 ; k < nr; k++) \ - retval.elem (k,i) += tmpval * m.elem(k,col); \ - } \ + { \ + OCTAVE_QUIT; \ + for (octave_idx_type j = a.cidx(i); j < a.cidx(i+1); j++) \ + { \ + octave_idx_type col = a.ridx(j); \ + EL_TYPE tmpval = a.data(j); \ + \ + for (octave_idx_type k = 0 ; k < nr; k++) \ + retval.xelem (k,i) += tmpval * m.elem(k,col); \ + } \ + } \ + return retval; \ + } + +#define FULL_SPARSE_MUL_TRANS( RET_TYPE, EL_TYPE, ZERO, CONJ_OP ) \ + octave_idx_type nr = m.rows (); \ + octave_idx_type nc = m.cols (); \ + \ + octave_idx_type a_nr = a.rows (); \ + octave_idx_type a_nc = a.cols (); \ + \ + if (a_nr == 1 && a_nc == 1) \ + { \ + RET_TYPE retval = m * CONJ_OP (a.elem(0,0)); \ + return retval; \ + } \ + else if (nc != a_nc) \ + { \ + gripe_nonconformant ("operator *", nr, nc, a_nc, a_nr); \ + return RET_TYPE (); \ + } \ + else \ + { \ + RET_TYPE retval (nr, a_nr, ZERO); \ + \ + for (octave_idx_type i = 0; i < a_nc ; i++) \ + { \ + OCTAVE_QUIT; \ + for (octave_idx_type j = a.cidx(i); j < a.cidx(i+1); j++) \ + { \ + octave_idx_type col = a.ridx(j); \ + EL_TYPE tmpval = CONJ_OP (a.data(j)); \ + for (octave_idx_type k = 0 ; k < nr; k++) \ + retval.xelem (k,col) += tmpval * m.elem(k,i); \ + } \ } \ return retval; \ }