Mercurial > hg > octave-max
changeset 13005:4061106b1c4b
Enable automatic bsxfun for power operators
* bsxfun.h: Put #include guards
* int8NDArray.cc: Define bsxfun power operator for integral types.
* int16NDArray.cc: Ditto.
* int32NDArray.cc: Ditto.
* int64NDArray.cc: fDitto.
* uint8ADArray.cc: Ditto.
* uint16NDArray.cc: Ditto.
* uint32NDArray.cc: Ditto.
* uint64NDArray.cc: Ditto.
* mx-inlines.cc: Let the compiler decide to use Octave's own integral pow.
* op-int.h: Call bsxfun for integral operators.
* xpow.cc: Call bsxfun for float operators.
author | Jordi Gutiérrez Hermoso <jordigh@gmail.com> |
---|---|
date | Wed, 24 Aug 2011 23:12:28 -0500 |
parents | d9d65c3017c3 |
children | 61be447052c3 |
files | liboctave/bsxfun.h liboctave/dNDArray.cc liboctave/dNDArray.h liboctave/int16NDArray.cc liboctave/int32NDArray.cc liboctave/int64NDArray.cc liboctave/int8NDArray.cc liboctave/mx-inlines.cc liboctave/uint16NDArray.cc liboctave/uint32NDArray.cc liboctave/uint64NDArray.cc liboctave/uint8NDArray.cc src/OPERATORS/op-int.h src/xpow.cc |
diffstat | 14 files changed, 118 insertions(+), 19 deletions(-) [+] |
line wrap: on
line diff
--- a/liboctave/bsxfun.h +++ b/liboctave/bsxfun.h @@ -19,6 +19,8 @@ <http://www.gnu.org/licenses/>. */ +#if !defined (bsxfun_h) +#define bsxfun_h 1 #include <algorithm> @@ -38,3 +40,5 @@ } return true; } + +#endif
--- a/liboctave/dNDArray.cc +++ b/liboctave/dNDArray.cc @@ -925,3 +925,5 @@ BSXFUN_OP_DEF_MXLOOP (pow, NDArray, mx_inline_pow) BSXFUN_OP2_DEF_MXLOOP (pow, ComplexNDArray, ComplexNDArray, NDArray, mx_inline_pow) +BSXFUN_OP2_DEF_MXLOOP (pow, ComplexNDArray, NDArray, + ComplexNDArray, mx_inline_pow)
--- a/liboctave/dNDArray.h +++ b/liboctave/dNDArray.h @@ -185,5 +185,7 @@ BSXFUN_OP_DECL (pow, NDArray, OCTAVE_API) BSXFUN_OP2_DECL (pow, ComplexNDArray, ComplexNDArray, NDArray, OCTAVE_API) +BSXFUN_OP2_DECL (pow, ComplexNDArray, NDArray, + ComplexNDArray, OCTAVE_API) #endif
--- a/liboctave/int16NDArray.cc +++ b/liboctave/int16NDArray.cc @@ -54,3 +54,5 @@ BSXFUN_STDOP_DEFS_MXLOOP (int16NDArray) BSXFUN_STDREL_DEFS_MXLOOP (int16NDArray) + +BSXFUN_OP_DEF_MXLOOP (pow, int16NDArray, mx_inline_pow)
--- a/liboctave/int32NDArray.cc +++ b/liboctave/int32NDArray.cc @@ -54,3 +54,5 @@ BSXFUN_STDOP_DEFS_MXLOOP (int32NDArray) BSXFUN_STDREL_DEFS_MXLOOP (int32NDArray) + +BSXFUN_OP_DEF_MXLOOP (pow, int32NDArray, mx_inline_pow)
--- a/liboctave/int64NDArray.cc +++ b/liboctave/int64NDArray.cc @@ -54,3 +54,5 @@ BSXFUN_STDOP_DEFS_MXLOOP (int64NDArray) BSXFUN_STDREL_DEFS_MXLOOP (int64NDArray) + +BSXFUN_OP_DEF_MXLOOP (pow, int64NDArray, mx_inline_pow)
--- a/liboctave/int8NDArray.cc +++ b/liboctave/int8NDArray.cc @@ -54,3 +54,5 @@ BSXFUN_STDOP_DEFS_MXLOOP (int8NDArray) BSXFUN_STDREL_DEFS_MXLOOP (int8NDArray) + +BSXFUN_OP_DEF_MXLOOP (pow, int8NDArray, mx_inline_pow)
--- a/liboctave/mx-inlines.cc +++ b/liboctave/mx-inlines.cc @@ -288,7 +288,10 @@ inline void F (size_t n, R *r, X x, const Y *y) throw () \ { for (size_t i = 0; i < n; i++) r[i] = FUN (x, y[i]); } -DEFMXMAPPER2X (mx_inline_pow, std::pow) +// Let the compiler decide which pow to use, whichever best matches the +// arguments provided. +using std::pow; +DEFMXMAPPER2X (mx_inline_pow, pow) // Arbitrary function appliers. The function is a template parameter to enable // inlining.
--- a/liboctave/uint16NDArray.cc +++ b/liboctave/uint16NDArray.cc @@ -54,3 +54,5 @@ BSXFUN_STDOP_DEFS_MXLOOP (uint16NDArray) BSXFUN_STDREL_DEFS_MXLOOP (uint16NDArray) + +BSXFUN_OP_DEF_MXLOOP (pow, uint16NDArray, mx_inline_pow)
--- a/liboctave/uint32NDArray.cc +++ b/liboctave/uint32NDArray.cc @@ -54,3 +54,5 @@ BSXFUN_STDOP_DEFS_MXLOOP (uint32NDArray) BSXFUN_STDREL_DEFS_MXLOOP (uint32NDArray) + +BSXFUN_OP_DEF_MXLOOP (pow, uint32NDArray, mx_inline_pow)
--- a/liboctave/uint64NDArray.cc +++ b/liboctave/uint64NDArray.cc @@ -54,3 +54,5 @@ BSXFUN_STDOP_DEFS_MXLOOP (uint64NDArray) BSXFUN_STDREL_DEFS_MXLOOP (uint64NDArray) + +BSXFUN_OP_DEF_MXLOOP (pow, uint64NDArray, mx_inline_pow)
--- a/liboctave/uint8NDArray.cc +++ b/liboctave/uint8NDArray.cc @@ -54,3 +54,5 @@ BSXFUN_STDOP_DEFS_MXLOOP (uint8NDArray) BSXFUN_STDREL_DEFS_MXLOOP (uint8NDArray) + +BSXFUN_OP_DEF_MXLOOP (pow, uint8NDArray, mx_inline_pow)
--- a/src/OPERATORS/op-int.h +++ b/src/OPERATORS/op-int.h @@ -21,6 +21,7 @@ */ #include "quit.h" +#include "bsxfun.h" #define DEFINTBINOP_OP(name, t1, t2, op, t3) \ BINOPDECL (name, a1, a2) \ @@ -703,8 +704,15 @@ dim_vector b_dims = b.dims (); \ if (a_dims != b_dims) \ { \ - gripe_nonconformant ("operator .^", a_dims, b_dims); \ - return octave_value (); \ + if (is_valid_bsxfun (a_dims, b_dims)) \ + { \ + return bsxfun_pow (a, b); \ + } \ + else \ + { \ + gripe_nonconformant ("operator .^", a_dims, b_dims); \ + return octave_value (); \ + } \ } \ T1 ## NDArray result (a_dims); \ for (int i = 0; i < a.length (); i++) \ @@ -722,8 +730,15 @@ dim_vector b_dims = b.dims (); \ if (a_dims != b_dims) \ { \ - gripe_nonconformant ("operator .^", a_dims, b_dims); \ - return octave_value (); \ + if (is_valid_bsxfun (a_dims, b_dims)) \ + { \ + return bsxfun_pow (a, static_cast<T1 ## NDArray> (b)); \ + } \ + else \ + { \ + gripe_nonconformant ("operator .^", a_dims, b_dims); \ + return octave_value (); \ + } \ } \ T1 ## NDArray result (a_dims); \ for (int i = 0; i < a.length (); i++) \ @@ -741,8 +756,15 @@ dim_vector b_dims = b.dims (); \ if (a_dims != b_dims) \ { \ - gripe_nonconformant ("operator .^", a_dims, b_dims); \ - return octave_value (); \ + if (is_valid_bsxfun (a_dims, b_dims)) \ + { \ + return bsxfun_pow (static_cast<T2 ## NDArray> (a), b); \ + } \ + else \ + { \ + gripe_nonconformant ("operator .^", a_dims, b_dims); \ + return octave_value (); \ + } \ } \ T2 ## NDArray result (a_dims); \ for (int i = 0; i < a.length (); i++) \ @@ -760,8 +782,15 @@ dim_vector b_dims = b.dims (); \ if (a_dims != b_dims) \ { \ - gripe_nonconformant ("operator .^", a_dims, b_dims); \ - return octave_value (); \ + if (is_valid_bsxfun (a_dims, b_dims)) \ + { \ + return bsxfun_pow (a, static_cast<T1 ## NDArray> (b)); \ + } \ + else \ + { \ + gripe_nonconformant ("operator .^", a_dims, b_dims); \ + return octave_value (); \ + } \ } \ T1 ## NDArray result (a_dims); \ for (int i = 0; i < a.length (); i++) \ @@ -779,8 +808,15 @@ dim_vector b_dims = b.dims (); \ if (a_dims != b_dims) \ { \ - gripe_nonconformant ("operator .^", a_dims, b_dims); \ - return octave_value (); \ + if (is_valid_bsxfun (a_dims, b_dims)) \ + { \ + return bsxfun_pow (static_cast<T1 ## NDArray> (a), b); \ + } \ + else \ + { \ + gripe_nonconformant ("operator .^", a_dims, b_dims); \ + return octave_value (); \ + } \ } \ T2 ## NDArray result (a_dims); \ for (int i = 0; i < a.length (); i++) \
--- a/src/xpow.cc +++ b/src/xpow.cc @@ -49,6 +49,8 @@ #include "utils.h" #include "xpow.h" +#include "bsxfun.h" + #ifdef _OPENMP #include <omp.h> #endif @@ -1243,8 +1245,21 @@ if (a_dims != b_dims) { - gripe_nonconformant ("operator .^", a_dims, b_dims); - return octave_value (); + if (is_valid_bsxfun (a_dims, b_dims)) + { + //Potentially complex results + NDArray xa = octave_value_extract<NDArray> (a); + NDArray xb = octave_value_extract<NDArray> (b); + if (! xb.all_integers () && xa.any_element_is_negative ()) + return octave_value (bsxfun_pow (ComplexNDArray (xa), xb)); + else + return octave_value (bsxfun_pow (xa, xb)); + } + else + { + gripe_nonconformant ("operator .^", a_dims, b_dims); + return octave_value (); + } } int len = a.length (); @@ -1318,8 +1333,15 @@ if (a_dims != b_dims) { - gripe_nonconformant ("operator .^", a_dims, b_dims); - return octave_value (); + if (is_valid_bsxfun (a_dims, b_dims)) + { + return bsxfun_pow (a, b); + } + else + { + gripe_nonconformant ("operator .^", a_dims, b_dims); + return octave_value (); + } } ComplexNDArray result (a_dims); @@ -1410,8 +1432,15 @@ if (a_dims != b_dims) { - gripe_nonconformant ("operator .^", a_dims, b_dims); - return octave_value (); + if (is_valid_bsxfun (a_dims, b_dims)) + { + return bsxfun_pow (a, b); + } + else + { + gripe_nonconformant ("operator .^", a_dims, b_dims); + return octave_value (); + } } ComplexNDArray result (a_dims); @@ -1453,8 +1482,15 @@ if (a_dims != b_dims) { - gripe_nonconformant ("operator .^", a_dims, b_dims); - return octave_value (); + if (is_valid_bsxfun (a_dims, b_dims)) + { + return bsxfun_pow (a, b); + } + else + { + gripe_nonconformant ("operator .^", a_dims, b_dims); + return octave_value (); + } } ComplexNDArray result (a_dims);