Mercurial > hg > octave-nkf
diff liboctave/MSparse.cc @ 13264:11c8b60f1b68
Eliminate duplicate code for op+= and op-= for MSparse
author | Jordi Gutiérrez Hermoso <jordigh@octave.org> |
---|---|
date | Mon, 03 Oct 2011 00:15:00 -0500 |
parents | 12df7854fa7c |
children | 89789bc755a1 |
line wrap: on
line diff
--- a/liboctave/MSparse.cc +++ b/liboctave/MSparse.cc @@ -25,6 +25,8 @@ #include <config.h> #endif +#include <functional> + #include "quit.h" #include "lo-error.h" #include "MArray.h" @@ -37,9 +39,9 @@ // Element by element MSparse by MSparse ops. -template <class T> +template <class T, class OP> MSparse<T>& -operator += (MSparse<T>& a, const MSparse<T>& b) +plus_or_minus (MSparse<T>& a, const MSparse<T>& b, OP op, const char* op_name) { MSparse<T> r; @@ -50,7 +52,7 @@ octave_idx_type b_nc = b.cols (); if (a_nr != b_nr || a_nc != b_nc) - gripe_nonconformant ("operator +=" , a_nr, a_nc, b_nr, b_nc); + gripe_nonconformant (op_name , a_nr, a_nc, b_nr, b_nc); else { r = MSparse<T> (a_nr, a_nc, (a.nnz () + b.nnz ())); @@ -73,7 +75,7 @@ (ja_lt_max && (a.ridx(ja) < b.ridx(jb)))) { r.ridx(jx) = a.ridx(ja); - r.data(jx) = a.data(ja) + 0.; + r.data(jx) = op (a.data(ja), 0.); jx++; ja++; ja_lt_max= ja < ja_max; @@ -82,16 +84,16 @@ (jb_lt_max && (b.ridx(jb) < a.ridx(ja)) ) ) { r.ridx(jx) = b.ridx(jb); - r.data(jx) = 0. + b.data(jb); + r.data(jx) = op (0., b.data(jb)); jx++; jb++; jb_lt_max= jb < jb_max; } else { - if ((a.data(ja) + b.data(jb)) != 0.) + if (op (a.data(ja), b.data(jb)) != 0.) { - r.data(jx) = a.data(ja) + b.data(jb); + r.data(jx) = op (a.data(ja), b.data(jb)); r.ridx(jx) = a.ridx(ja); jx++; } @@ -110,78 +112,20 @@ return a; } -template <class T> +template <typename T> +MSparse<T>& +operator += (MSparse<T>& a, const MSparse<T>& b) +{ + return plus_or_minus (a, b, std::plus<T> (), "operator +="); +} + +template <typename T> MSparse<T>& operator -= (MSparse<T>& a, const MSparse<T>& b) { - MSparse<T> r; - - octave_idx_type a_nr = a.rows (); - octave_idx_type a_nc = a.cols (); - - octave_idx_type b_nr = b.rows (); - octave_idx_type b_nc = b.cols (); - - if (a_nr != b_nr || a_nc != b_nc) - gripe_nonconformant ("operator -=" , a_nr, a_nc, b_nr, b_nc); - else - { - r = MSparse<T> (a_nr, a_nc, (a.nnz () + b.nnz ())); - - octave_idx_type jx = 0; - for (octave_idx_type i = 0 ; i < a_nc ; i++) - { - octave_idx_type ja = a.cidx(i); - octave_idx_type ja_max = a.cidx(i+1); - bool ja_lt_max= ja < ja_max; - - octave_idx_type jb = b.cidx(i); - octave_idx_type jb_max = b.cidx(i+1); - bool jb_lt_max = jb < jb_max; + return plus_or_minus (a, b, std::minus<T> (), "operator -="); +} - while (ja_lt_max || jb_lt_max ) - { - octave_quit (); - if ((! jb_lt_max) || - (ja_lt_max && (a.ridx(ja) < b.ridx(jb)))) - { - r.ridx(jx) = a.ridx(ja); - r.data(jx) = a.data(ja) - 0.; - jx++; - ja++; - ja_lt_max= ja < ja_max; - } - else if (( !ja_lt_max ) || - (jb_lt_max && (b.ridx(jb) < a.ridx(ja)) ) ) - { - r.ridx(jx) = b.ridx(jb); - r.data(jx) = 0. - b.data(jb); - jx++; - jb++; - jb_lt_max= jb < jb_max; - } - else - { - if ((a.data(ja) - b.data(jb)) != 0.) - { - r.data(jx) = a.data(ja) - b.data(jb); - r.ridx(jx) = a.ridx(ja); - jx++; - } - ja++; - ja_lt_max= ja < ja_max; - jb++; - jb_lt_max= jb < jb_max; - } - } - r.cidx(i+1) = jx; - } - - a = r.maybe_compress (); - } - - return a; -} // Element by element MSparse by scalar ops.