# HG changeset patch # User Jaroslav Hajek # Date 1228642174 -3600 # Node ID dbe67764e628c6a211bdef4410a69898b8b69723 # Parent ad8ed668e0a44e97983e2e05b80d3439be50ccb5 fix & improve speed of diagonal matrix multiplication diff --git a/liboctave/ChangeLog b/liboctave/ChangeLog --- a/liboctave/ChangeLog +++ b/liboctave/ChangeLog @@ -1,3 +1,12 @@ +2008-12-07 Jaroslav Hajek + + * mx-inlines.cc (mx_inline_fill_vs): New template function. + * mx-op-defs.h (everywhere): Replace int by octave_idx_type. + (MDM_MULTIPLY_OP): Use mx_inline_mul_vs and mx_inline_fill_vs. + (DMM_MULTIPLY_OP): Dtto. + * fDiagMatrix.cc (operator *): Remove redundant ifs. + * fCDiagMatrix.cc (operator *): Remove redundant ifs. + 2008-12-06 Jaroslav Hajek * oct-locbuf.h (OCTAVE_LOCAL_BUFFER_INIT): New macro. diff --git a/liboctave/fCDiagMatrix.cc b/liboctave/fCDiagMatrix.cc --- a/liboctave/fCDiagMatrix.cc +++ b/liboctave/fCDiagMatrix.cc @@ -455,14 +455,7 @@ FloatComplex a_element = a.elem (i, i); float b_element = b.elem (i, i); - if (a_element == static_cast (0.0) || b_element == static_cast (0.0)) - c.elem (i, i) = 0; - else if (a_element == static_cast (1.0)) - c.elem (i, i) = b_element; - else if (b_element == static_cast (1.0)) - c.elem (i, i) = a_element; - else - c.elem (i, i) = a_element * b_element; + c.elem (i, i) = a_element * b_element; } return c; @@ -495,14 +488,7 @@ float a_element = a.elem (i, i); FloatComplex b_element = b.elem (i, i); - if (a_element == static_cast (0.0) || b_element == static_cast (0.0)) - c.elem (i, i) = 0; - else if (a_element == static_cast (1.0)) - c.elem (i, i) = b_element; - else if (b_element == static_cast (1.0)) - c.elem (i, i) = a_element; - else - c.elem (i, i) = a_element * b_element; + c.elem (i, i) = a_element * b_element; } return c; @@ -535,14 +521,7 @@ FloatComplex a_element = a.elem (i, i); FloatComplex b_element = b.elem (i, i); - if (a_element == static_cast (0.0) || b_element == static_cast (0.0)) - c.elem (i, i) = 0; - else if (a_element == static_cast (1.0)) - c.elem (i, i) = b_element; - else if (b_element == static_cast (1.0)) - c.elem (i, i) = a_element; - else - c.elem (i, i) = a_element * b_element; + c.elem (i, i) = a_element * b_element; } return c; diff --git a/liboctave/fDiagMatrix.cc b/liboctave/fDiagMatrix.cc --- a/liboctave/fDiagMatrix.cc +++ b/liboctave/fDiagMatrix.cc @@ -334,14 +334,7 @@ float a_element = a.elem (i, i); float b_element = b.elem (i, i); - if (a_element == 0.0 || b_element == 0.0) - c.elem (i, i) = 0.0; - else if (a_element == 1.0) - c.elem (i, i) = b_element; - else if (b_element == 1.0) - c.elem (i, i) = a_element; - else - c.elem (i, i) = a_element * b_element; + c.elem (i, i) = a_element * b_element; } return c; diff --git a/liboctave/mx-inlines.cc b/liboctave/mx-inlines.cc --- a/liboctave/mx-inlines.cc +++ b/liboctave/mx-inlines.cc @@ -30,6 +30,14 @@ #include "oct-cmplx.h" +template +inline void +mx_inline_fill_vs (R *r, size_t n, S s) +{ + for (size_t i = 0; i < n; i++) + r[i] = s; +} + #define VS_OP_FCN(F, OP) \ template \ inline void \ diff --git a/liboctave/mx-op-defs.h b/liboctave/mx-op-defs.h --- a/liboctave/mx-op-defs.h +++ b/liboctave/mx-op-defs.h @@ -2,6 +2,7 @@ Copyright (C) 1996, 1997, 1998, 2000, 2001, 2003, 2004, 2005, 2006, 2007 John W. Eaton +Copyright (C) 2008 Jaroslav Hajek This file is part of Octave. @@ -24,6 +25,7 @@ #if !defined (octave_mx_op_defs_h) #define octave_mx_op_defs_h 1 +#include "oct-types.h" #include "mx-inlines.cc" #define BIN_OP_DECL(R, OP, X, Y, API) \ @@ -56,11 +58,11 @@ R \ F (const V& v, const S& s) \ { \ - int len = v.length (); \ + octave_idx_type len = v.length (); \ \ R r (len); \ \ - for (int i = 0; i < len; i++) \ + for (octave_idx_type i = 0; i < len; i++) \ r.elem(i) = v.elem(i) OP s; \ \ return r; \ @@ -87,11 +89,11 @@ R \ F (const S& s, const V& v) \ { \ - int len = v.length (); \ + octave_idx_type len = v.length (); \ \ R r (len); \ \ - for (int i = 0; i < len; i++) \ + for (octave_idx_type i = 0; i < len; i++) \ r.elem(i) = s OP v.elem(i); \ \ return r; \ @@ -120,8 +122,8 @@ { \ R r; \ \ - int v1_len = v1.length (); \ - int v2_len = v2.length (); \ + octave_idx_type v1_len = v1.length (); \ + octave_idx_type v2_len = v2.length (); \ \ if (v1_len != v2_len) \ gripe_nonconformant (#OP, v1_len, v2_len); \ @@ -129,7 +131,7 @@ { \ r.resize (v1_len); \ \ - for (int i = 0; i < v1_len; i++) \ + for (octave_idx_type i = 0; i < v1_len; i++) \ r.elem(i) = v1.elem(i) OP v2.elem(i); \ } \ \ @@ -157,8 +159,8 @@ R \ OP (const M& m, const S& s) \ { \ - int nr = m.rows (); \ - int nc = m.cols (); \ + octave_idx_type nr = m.rows (); \ + octave_idx_type nc = m.cols (); \ \ R r (nr, nc); \ \ @@ -188,15 +190,15 @@ { \ boolMatrix r; \ \ - int nr = m.rows (); \ - int nc = m.cols (); \ + octave_idx_type nr = m.rows (); \ + octave_idx_type nc = m.cols (); \ \ r.resize (nr, nc); \ \ if (nr > 0 && nc > 0) \ { \ - for (int j = 0; j < nc; j++) \ - for (int i = 0; i < nr; i++) \ + for (octave_idx_type j = 0; j < nc; j++) \ + for (octave_idx_type i = 0; i < nr; i++) \ r.elem(i, j) = MC (m.elem(i, j)) OP SC (s); \ } \ \ @@ -221,8 +223,8 @@ { \ boolMatrix r; \ \ - int nr = m.rows (); \ - int nc = m.cols (); \ + octave_idx_type nr = m.rows (); \ + octave_idx_type nc = m.cols (); \ \ if (nr != 0 && nc != 0) \ { \ @@ -233,8 +235,8 @@ else \ { \ \ - for (int j = 0; j < nc; j++) \ - for (int i = 0; i < nr; i++) \ + for (octave_idx_type j = 0; j < nc; j++) \ + for (octave_idx_type i = 0; i < nr; i++) \ if (xisnan (m.elem(i, j))) \ { \ gripe_nan_to_logical_conversion (); \ @@ -272,8 +274,8 @@ R \ OP (const S& s, const M& m) \ { \ - int nr = m.rows (); \ - int nc = m.cols (); \ + octave_idx_type nr = m.rows (); \ + octave_idx_type nc = m.cols (); \ \ R r (nr, nc); \ \ @@ -303,15 +305,15 @@ { \ boolMatrix r; \ \ - int nr = m.rows (); \ - int nc = m.cols (); \ + octave_idx_type nr = m.rows (); \ + octave_idx_type nc = m.cols (); \ \ r.resize (nr, nc); \ \ if (nr > 0 && nc > 0) \ { \ - for (int j = 0; j < nc; j++) \ - for (int i = 0; i < nr; i++) \ + for (octave_idx_type j = 0; j < nc; j++) \ + for (octave_idx_type i = 0; i < nr; i++) \ r.elem(i, j) = SC (s) OP MC (m.elem(i, j)); \ } \ \ @@ -336,8 +338,8 @@ { \ boolMatrix r; \ \ - int nr = m.rows (); \ - int nc = m.cols (); \ + octave_idx_type nr = m.rows (); \ + octave_idx_type nc = m.cols (); \ \ if (nr != 0 && nc != 0) \ { \ @@ -347,8 +349,8 @@ gripe_nan_to_logical_conversion (); \ else \ { \ - for (int j = 0; j < nc; j++) \ - for (int i = 0; i < nr; i++) \ + for (octave_idx_type j = 0; j < nc; j++) \ + for (octave_idx_type i = 0; i < nr; i++) \ if (xisnan (m.elem(i, j))) \ { \ gripe_nan_to_logical_conversion (); \ @@ -388,11 +390,11 @@ { \ R r; \ \ - int m1_nr = m1.rows (); \ - int m1_nc = m1.cols (); \ + octave_idx_type m1_nr = m1.rows (); \ + octave_idx_type m1_nc = m1.cols (); \ \ - int m2_nr = m2.rows (); \ - int m2_nc = m2.cols (); \ + octave_idx_type m2_nr = m2.rows (); \ + octave_idx_type m2_nc = m2.cols (); \ \ if (m1_nr != m2_nr || m1_nc != m2_nc) \ gripe_nonconformant (#OP, m1_nr, m1_nc, m2_nr, m2_nc); \ @@ -427,18 +429,18 @@ { \ boolMatrix r; \ \ - int m1_nr = m1.rows (); \ - int m1_nc = m1.cols (); \ + octave_idx_type m1_nr = m1.rows (); \ + octave_idx_type m1_nc = m1.cols (); \ \ - int m2_nr = m2.rows (); \ - int m2_nc = m2.cols (); \ + octave_idx_type m2_nr = m2.rows (); \ + octave_idx_type m2_nc = m2.cols (); \ \ if (m1_nr == m2_nr && m1_nc == m2_nc) \ { \ r.resize (m1_nr, m1_nc); \ \ - for (int j = 0; j < m1_nc; j++) \ - for (int i = 0; i < m1_nr; i++) \ + for (octave_idx_type j = 0; j < m1_nc; j++) \ + for (octave_idx_type i = 0; i < m1_nr; i++) \ r.elem(i, j) = C1 (m1.elem(i, j)) OP C2 (m2.elem(i, j)); \ } \ else \ @@ -465,11 +467,11 @@ { \ boolMatrix r; \ \ - int m1_nr = m1.rows (); \ - int m1_nc = m1.cols (); \ + octave_idx_type m1_nr = m1.rows (); \ + octave_idx_type m1_nc = m1.cols (); \ \ - int m2_nr = m2.rows (); \ - int m2_nc = m2.cols (); \ + octave_idx_type m2_nr = m2.rows (); \ + octave_idx_type m2_nc = m2.cols (); \ \ if (m1_nr == m2_nr && m1_nc == m2_nc) \ { \ @@ -477,8 +479,8 @@ { \ r.resize (m1_nr, m1_nc); \ \ - for (int j = 0; j < m1_nc; j++) \ - for (int i = 0; i < m1_nr; i++) \ + for (octave_idx_type j = 0; j < m1_nc; j++) \ + for (octave_idx_type i = 0; i < m1_nr; i++) \ if (xisnan (m1.elem(i, j)) || xisnan (m2.elem(i, j))) \ { \ gripe_nan_to_logical_conversion (); \ @@ -524,7 +526,7 @@ { \ R r (m.dims ()); \ \ - int len = m.length (); \ + octave_idx_type len = m.length (); \ \ if (len > 0) \ F ## _vs (r.fortran_vec (), m.data (), len, s); \ @@ -552,11 +554,11 @@ { \ boolNDArray r; \ \ - int len = m.length (); \ + octave_idx_type len = m.length (); \ \ r.resize (m.dims ()); \ \ - for (int i = 0; i < len; i++) \ + for (octave_idx_type i = 0; i < len; i++) \ r.elem(i) = NDC (m.elem(i)) OP SC (s); \ \ return r; \ @@ -576,11 +578,11 @@ { \ boolNDArray r; \ \ - int len = m.length (); \ + octave_idx_type len = m.length (); \ \ r.resize (m.dims ()); \ \ - for (int i = 0; i < len; i++) \ + for (octave_idx_type i = 0; i < len; i++) \ r.elem(i) = operator OP (NDC (m.elem(i)), SC (s)); \ \ return r; \ @@ -600,11 +602,11 @@ { \ boolNDArray r; \ \ - int len = m.length (); \ + octave_idx_type len = m.length (); \ \ r.resize (m.dims ()); \ \ - for (int i = 0; i < len; i++) \ + for (octave_idx_type i = 0; i < len; i++) \ r.elem(i) = operator OP (NDC (m.elem(i)), SC (s)); \ \ return r; \ @@ -628,7 +630,7 @@ { \ boolNDArray r; \ \ - int len = m.length (); \ + octave_idx_type len = m.length (); \ \ if (len > 0) \ { \ @@ -638,7 +640,7 @@ gripe_nan_to_logical_conversion (); \ else \ { \ - for (int i = 0; i < len; i++) \ + for (octave_idx_type i = 0; i < len; i++) \ if (xisnan (m.elem(i))) \ { \ gripe_nan_to_logical_conversion (); \ @@ -678,7 +680,7 @@ { \ R r (m.dims ()); \ \ - int len = m.length (); \ + octave_idx_type len = m.length (); \ \ if (len > 0) \ F ## _sv (r.fortran_vec (), s, m.data (), len); \ @@ -706,11 +708,11 @@ { \ boolNDArray r; \ \ - int len = m.length (); \ + octave_idx_type len = m.length (); \ \ r.resize (m.dims ()); \ \ - for (int i = 0; i < len; i++) \ + for (octave_idx_type i = 0; i < len; i++) \ r.elem(i) = SC (s) OP NDC (m.elem(i)); \ \ return r; \ @@ -730,11 +732,11 @@ { \ boolNDArray r; \ \ - int len = m.length (); \ + octave_idx_type len = m.length (); \ \ r.resize (m.dims ()); \ \ - for (int i = 0; i < len; i++) \ + for (octave_idx_type i = 0; i < len; i++) \ r.elem(i) = operator OP (SC (s), NDC (m.elem(i))); \ \ return r; \ @@ -754,11 +756,11 @@ { \ boolNDArray r; \ \ - int len = m.length (); \ + octave_idx_type len = m.length (); \ \ r.resize (m.dims ()); \ \ - for (int i = 0; i < len; i++) \ + for (octave_idx_type i = 0; i < len; i++) \ r.elem(i) = operator OP (SC (s), NDC (m.elem(i))); \ \ return r; \ @@ -782,7 +784,7 @@ { \ boolNDArray r; \ \ - int len = m.length (); \ + octave_idx_type len = m.length (); \ \ if (len > 0) \ { \ @@ -792,7 +794,7 @@ gripe_nan_to_logical_conversion (); \ else \ { \ - for (int i = 0; i < len; i++) \ + for (octave_idx_type i = 0; i < len; i++) \ if (xisnan (m.elem(i))) \ { \ gripe_nan_to_logical_conversion (); \ @@ -841,7 +843,7 @@ { \ r.resize (m1_dims); \ \ - int len = m1.length (); \ + octave_idx_type len = m1.length (); \ \ if (len > 0) \ F ## _vv (r.fortran_vec (), m1.data (), m2.data (), len); \ @@ -877,7 +879,7 @@ { \ r.resize (m1_dims); \ \ - for (int i = 0; i < m1.length (); i++) \ + for (octave_idx_type i = 0; i < m1.length (); i++) \ r.elem(i) = C1 (m1.elem(i)) OP C2 (m2.elem(i)); \ } \ else \ @@ -913,7 +915,7 @@ { \ r.resize (m1_dims); \ \ - for (int i = 0; i < m1.length (); i++) \ + for (octave_idx_type i = 0; i < m1.length (); i++) \ if (xisnan (m1.elem(i)) || xisnan (m2.elem(i))) \ { \ gripe_nan_to_logical_conversion (); \ @@ -951,12 +953,12 @@ R \ OP (const S& s, const DM& dm) \ { \ - int nr = dm.rows (); \ - int nc = dm.cols (); \ + octave_idx_type nr = dm.rows (); \ + octave_idx_type nc = dm.cols (); \ \ R r (nr, nc, s); \ \ - for (int i = 0; i < dm.length (); i++) \ + for (octave_idx_type i = 0; i < dm.length (); i++) \ r.elem(i, i) OPEQ dm.elem(i, i); \ \ return r; \ @@ -979,12 +981,12 @@ R \ OP (const DM& dm, const S& s) \ { \ - int nr = dm.rows (); \ - int nc = dm.cols (); \ + octave_idx_type nr = dm.rows (); \ + octave_idx_type nc = dm.cols (); \ \ R r (nr, nc, SGN s); \ \ - for (int i = 0; i < dm.length (); i++) \ + for (octave_idx_type i = 0; i < dm.length (); i++) \ r.elem(i, i) += dm.elem(i, i); \ \ return r; \ @@ -1010,11 +1012,11 @@ { \ R r; \ \ - int m_nr = m.rows (); \ - int m_nc = m.cols (); \ + octave_idx_type m_nr = m.rows (); \ + octave_idx_type m_nc = m.cols (); \ \ - int dm_nr = dm.rows (); \ - int dm_nc = dm.cols (); \ + octave_idx_type dm_nr = dm.rows (); \ + octave_idx_type dm_nc = dm.cols (); \ \ if (m_nr != dm_nr || m_nc != dm_nc) \ gripe_nonconformant (#OP, m_nr, m_nc, dm_nr, dm_nc); \ @@ -1026,9 +1028,9 @@ { \ r = R (m); \ \ - int len = dm.length (); \ + octave_idx_type len = dm.length (); \ \ - for (int i = 0; i < len; i++) \ + for (octave_idx_type i = 0; i < len; i++) \ r.elem(i, i) OPEQ dm.elem(i, i); \ } \ } \ @@ -1042,29 +1044,28 @@ { \ R r; \ \ - int m_nr = m.rows (); \ - int m_nc = m.cols (); \ + octave_idx_type m_nr = m.rows (); \ + octave_idx_type m_nc = m.cols (); \ \ - int dm_nr = dm.rows (); \ - int dm_nc = dm.cols (); \ + octave_idx_type dm_nr = dm.rows (); \ + octave_idx_type dm_nc = dm.cols (); \ \ if (m_nc != dm_nr) \ gripe_nonconformant ("operator *", m_nr, m_nc, dm_nr, dm_nc); \ else \ { \ r = R (m_nr, dm_nc); \ - \ - if (m_nr > 0 && m_nc > 0 && dm_nc > 0) \ - { \ - int len = dm.length (); \ + R::element_type *rd = r.fortran_vec (); \ + const M::element_type *md = m.data (); \ + const DM::element_type *dd = dm.data (); \ \ - for (int j = 0; j < len; j++) \ - { \ - const DM::element_type djj = dm.elem (j, j); \ - for (int i = 0; i < m_nr; i++) \ - r.xelem (i, j) = djj * m.elem (i, j); \ - } \ - } \ + octave_idx_type len = dm.length (); \ + for (octave_idx_type i = 0; i < len; i++) \ + { \ + mx_inline_multiply_vs (rd, md, m_nr, dd[i]); \ + rd += m_nr; md += m_nr; \ + } \ + mx_inline_fill_vs (rd, m_nr * (dm_nc - len), R_ZERO); \ } \ \ return r; \ @@ -1091,11 +1092,11 @@ { \ R r; \ \ - int dm_nr = dm.rows (); \ - int dm_nc = dm.cols (); \ + octave_idx_type dm_nr = dm.rows (); \ + octave_idx_type dm_nc = dm.cols (); \ \ - int m_nr = m.rows (); \ - int m_nc = m.cols (); \ + octave_idx_type m_nr = m.rows (); \ + octave_idx_type m_nc = m.cols (); \ \ if (dm_nr != m_nr || dm_nc != m_nc) \ gripe_nonconformant (#OP, dm_nr, dm_nc, m_nr, m_nc); \ @@ -1105,9 +1106,9 @@ { \ r = R (PREOP m); \ \ - int len = dm.length (); \ + octave_idx_type len = dm.length (); \ \ - for (int i = 0; i < len; i++) \ + for (octave_idx_type i = 0; i < len; i++) \ r.elem(i, i) OPEQ dm.elem(i, i); \ } \ else \ @@ -1123,28 +1124,29 @@ { \ R r; \ \ - int dm_nr = dm.rows (); \ - int dm_nc = dm.cols (); \ + octave_idx_type dm_nr = dm.rows (); \ + octave_idx_type dm_nc = dm.cols (); \ \ - int m_nr = m.rows (); \ - int m_nc = m.cols (); \ + octave_idx_type m_nr = m.rows (); \ + octave_idx_type m_nc = m.cols (); \ \ if (dm_nc != m_nr) \ gripe_nonconformant ("operator *", dm_nr, dm_nc, m_nr, m_nc); \ else \ { \ r = R (dm_nr, m_nc); \ - \ - if (dm_nr > 0 && dm_nc > 0 && m_nc > 0) \ - { \ - int len = dm.length (); \ + R::element_type *rd = r.fortran_vec (); \ + const M::element_type *md = m.data (); \ + const DM::element_type *dd = dm.data (); \ \ - for (int i = 0; i < len; i++) \ - { \ - for (int j = 0; j < m_nc; j++) \ - r.xelem (i, j) = dm.elem (i, i) * m.elem (i, j); \ - } \ - } \ + octave_idx_type len = dm.length (); \ + for (octave_idx_type i = 0; i < m_nc; i++) \ + { \ + mx_inline_multiply_vv (rd, md, dd, len); \ + rd += len; md += m_nr; \ + mx_inline_fill_vs (rd, dm_nr - len, R_ZERO); \ + rd += dm_nr - len; \ + } \ } \ \ return r; \ @@ -1171,11 +1173,11 @@ { \ R r; \ \ - int dm1_nr = dm1.rows (); \ - int dm1_nc = dm1.cols (); \ + octave_idx_type dm1_nr = dm1.rows (); \ + octave_idx_type dm1_nc = dm1.cols (); \ \ - int dm2_nr = dm2.rows (); \ - int dm2_nc = dm2.cols (); \ + octave_idx_type dm2_nr = dm2.rows (); \ + octave_idx_type dm2_nc = dm2.cols (); \ \ if (dm1_nr != dm2_nr || dm1_nc != dm2_nc) \ gripe_nonconformant (#OP, dm1_nr, dm1_nc, dm2_nr, dm2_nc); \