changeset 13005:4061106b1c4b

Enable automatic bsxfun for power operators * bsxfun.h: Put #include guards * int8NDArray.cc: Define bsxfun power operator for integral types. * int16NDArray.cc: Ditto. * int32NDArray.cc: Ditto. * int64NDArray.cc: fDitto. * uint8ADArray.cc: Ditto. * uint16NDArray.cc: Ditto. * uint32NDArray.cc: Ditto. * uint64NDArray.cc: Ditto. * mx-inlines.cc: Let the compiler decide to use Octave's own integral pow. * op-int.h: Call bsxfun for integral operators. * xpow.cc: Call bsxfun for float operators.
author Jordi Gutiérrez Hermoso <jordigh@gmail.com>
date Wed, 24 Aug 2011 23:12:28 -0500
parents d9d65c3017c3
children 61be447052c3
files liboctave/bsxfun.h liboctave/dNDArray.cc liboctave/dNDArray.h liboctave/int16NDArray.cc liboctave/int32NDArray.cc liboctave/int64NDArray.cc liboctave/int8NDArray.cc liboctave/mx-inlines.cc liboctave/uint16NDArray.cc liboctave/uint32NDArray.cc liboctave/uint64NDArray.cc liboctave/uint8NDArray.cc src/OPERATORS/op-int.h src/xpow.cc
diffstat 14 files changed, 118 insertions(+), 19 deletions(-) [+]
line wrap: on
line diff
--- a/liboctave/bsxfun.h
+++ b/liboctave/bsxfun.h
@@ -19,6 +19,8 @@
 <http://www.gnu.org/licenses/>.
 
 */
+#if !defined (bsxfun_h)
+#define bsxfun_h 1
 
 #include <algorithm>
 
@@ -38,3 +40,5 @@
     }
   return true;
 }
+
+#endif
--- a/liboctave/dNDArray.cc
+++ b/liboctave/dNDArray.cc
@@ -925,3 +925,5 @@
 BSXFUN_OP_DEF_MXLOOP (pow, NDArray, mx_inline_pow)
 BSXFUN_OP2_DEF_MXLOOP (pow, ComplexNDArray, ComplexNDArray,
                        NDArray, mx_inline_pow)
+BSXFUN_OP2_DEF_MXLOOP (pow, ComplexNDArray, NDArray,
+                       ComplexNDArray, mx_inline_pow)
--- a/liboctave/dNDArray.h
+++ b/liboctave/dNDArray.h
@@ -185,5 +185,7 @@
 BSXFUN_OP_DECL (pow, NDArray, OCTAVE_API)
 BSXFUN_OP2_DECL (pow, ComplexNDArray, ComplexNDArray,
                  NDArray, OCTAVE_API)
+BSXFUN_OP2_DECL (pow, ComplexNDArray, NDArray,
+                 ComplexNDArray, OCTAVE_API)
 
 #endif
--- a/liboctave/int16NDArray.cc
+++ b/liboctave/int16NDArray.cc
@@ -54,3 +54,5 @@
 
 BSXFUN_STDOP_DEFS_MXLOOP (int16NDArray)
 BSXFUN_STDREL_DEFS_MXLOOP (int16NDArray)
+
+BSXFUN_OP_DEF_MXLOOP (pow, int16NDArray, mx_inline_pow)
--- a/liboctave/int32NDArray.cc
+++ b/liboctave/int32NDArray.cc
@@ -54,3 +54,5 @@
 
 BSXFUN_STDOP_DEFS_MXLOOP (int32NDArray)
 BSXFUN_STDREL_DEFS_MXLOOP (int32NDArray)
+
+BSXFUN_OP_DEF_MXLOOP (pow, int32NDArray, mx_inline_pow)
--- a/liboctave/int64NDArray.cc
+++ b/liboctave/int64NDArray.cc
@@ -54,3 +54,5 @@
 
 BSXFUN_STDOP_DEFS_MXLOOP (int64NDArray)
 BSXFUN_STDREL_DEFS_MXLOOP (int64NDArray)
+
+BSXFUN_OP_DEF_MXLOOP (pow, int64NDArray, mx_inline_pow)
--- a/liboctave/int8NDArray.cc
+++ b/liboctave/int8NDArray.cc
@@ -54,3 +54,5 @@
 
 BSXFUN_STDOP_DEFS_MXLOOP (int8NDArray)
 BSXFUN_STDREL_DEFS_MXLOOP (int8NDArray)
+
+BSXFUN_OP_DEF_MXLOOP (pow, int8NDArray, mx_inline_pow)
--- a/liboctave/mx-inlines.cc
+++ b/liboctave/mx-inlines.cc
@@ -288,7 +288,10 @@
 inline void F (size_t n, R *r, X x, const Y *y) throw () \
 { for (size_t i = 0; i < n; i++) r[i] = FUN (x, y[i]); }
 
-DEFMXMAPPER2X (mx_inline_pow, std::pow)
+// Let the compiler decide which pow to use, whichever best matches the
+// arguments provided.
+using std::pow;
+DEFMXMAPPER2X (mx_inline_pow, pow)
 
 // Arbitrary function appliers. The function is a template parameter to enable
 // inlining.
--- a/liboctave/uint16NDArray.cc
+++ b/liboctave/uint16NDArray.cc
@@ -54,3 +54,5 @@
 
 BSXFUN_STDOP_DEFS_MXLOOP (uint16NDArray)
 BSXFUN_STDREL_DEFS_MXLOOP (uint16NDArray)
+
+BSXFUN_OP_DEF_MXLOOP (pow, uint16NDArray, mx_inline_pow)
--- a/liboctave/uint32NDArray.cc
+++ b/liboctave/uint32NDArray.cc
@@ -54,3 +54,5 @@
 
 BSXFUN_STDOP_DEFS_MXLOOP (uint32NDArray)
 BSXFUN_STDREL_DEFS_MXLOOP (uint32NDArray)
+
+BSXFUN_OP_DEF_MXLOOP (pow, uint32NDArray, mx_inline_pow)
--- a/liboctave/uint64NDArray.cc
+++ b/liboctave/uint64NDArray.cc
@@ -54,3 +54,5 @@
 
 BSXFUN_STDOP_DEFS_MXLOOP (uint64NDArray)
 BSXFUN_STDREL_DEFS_MXLOOP (uint64NDArray)
+
+BSXFUN_OP_DEF_MXLOOP (pow, uint64NDArray, mx_inline_pow)
--- a/liboctave/uint8NDArray.cc
+++ b/liboctave/uint8NDArray.cc
@@ -54,3 +54,5 @@
 
 BSXFUN_STDOP_DEFS_MXLOOP (uint8NDArray)
 BSXFUN_STDREL_DEFS_MXLOOP (uint8NDArray)
+
+BSXFUN_OP_DEF_MXLOOP (pow, uint8NDArray, mx_inline_pow)
--- a/src/OPERATORS/op-int.h
+++ b/src/OPERATORS/op-int.h
@@ -21,6 +21,7 @@
 */
 
 #include "quit.h"
+#include "bsxfun.h"
 
 #define DEFINTBINOP_OP(name, t1, t2, op, t3) \
   BINOPDECL (name, a1, a2) \
@@ -703,8 +704,15 @@
     dim_vector b_dims = b.dims (); \
     if (a_dims != b_dims) \
       { \
-        gripe_nonconformant ("operator .^", a_dims, b_dims); \
-        return octave_value (); \
+        if (is_valid_bsxfun (a_dims, b_dims)) \
+          { \
+            return bsxfun_pow (a, b); \
+          } \
+        else \
+          { \
+            gripe_nonconformant ("operator .^", a_dims, b_dims);  \
+            return octave_value (); \
+          } \
       } \
     T1 ## NDArray result (a_dims); \
     for (int i = 0; i < a.length (); i++) \
@@ -722,8 +730,15 @@
     dim_vector b_dims = b.dims (); \
     if (a_dims != b_dims) \
       { \
-        gripe_nonconformant ("operator .^", a_dims, b_dims); \
-        return octave_value (); \
+        if (is_valid_bsxfun (a_dims, b_dims)) \
+          { \
+            return bsxfun_pow (a, static_cast<T1 ## NDArray> (b)); \
+          } \
+        else \
+          { \
+            gripe_nonconformant ("operator .^", a_dims, b_dims);  \
+            return octave_value (); \
+          } \
       } \
     T1 ## NDArray result (a_dims); \
     for (int i = 0; i < a.length (); i++) \
@@ -741,8 +756,15 @@
     dim_vector b_dims = b.dims (); \
     if (a_dims != b_dims) \
       { \
-        gripe_nonconformant ("operator .^", a_dims, b_dims); \
-        return octave_value (); \
+        if (is_valid_bsxfun (a_dims, b_dims)) \
+          { \
+            return bsxfun_pow (static_cast<T2 ## NDArray> (a), b); \
+          } \
+        else \
+          { \
+            gripe_nonconformant ("operator .^", a_dims, b_dims);  \
+            return octave_value (); \
+          } \
       } \
     T2 ## NDArray result (a_dims); \
     for (int i = 0; i < a.length (); i++) \
@@ -760,8 +782,15 @@
     dim_vector b_dims = b.dims (); \
     if (a_dims != b_dims) \
       { \
-        gripe_nonconformant ("operator .^", a_dims, b_dims); \
-        return octave_value (); \
+        if (is_valid_bsxfun (a_dims, b_dims)) \
+          { \
+            return bsxfun_pow (a, static_cast<T1 ## NDArray> (b)); \
+          } \
+        else \
+          { \
+            gripe_nonconformant ("operator .^", a_dims, b_dims);  \
+            return octave_value (); \
+          } \
       } \
     T1 ## NDArray result (a_dims); \
     for (int i = 0; i < a.length (); i++) \
@@ -779,8 +808,15 @@
     dim_vector b_dims = b.dims (); \
     if (a_dims != b_dims) \
       { \
-        gripe_nonconformant ("operator .^", a_dims, b_dims); \
-        return octave_value (); \
+        if (is_valid_bsxfun (a_dims, b_dims)) \
+          { \
+            return bsxfun_pow (static_cast<T1 ## NDArray> (a), b); \
+          } \
+        else \
+          { \
+            gripe_nonconformant ("operator .^", a_dims, b_dims);  \
+            return octave_value (); \
+          } \
       } \
     T2 ## NDArray result (a_dims); \
     for (int i = 0; i < a.length (); i++) \
--- a/src/xpow.cc
+++ b/src/xpow.cc
@@ -49,6 +49,8 @@
 #include "utils.h"
 #include "xpow.h"
 
+#include "bsxfun.h"
+
 #ifdef _OPENMP
 #include <omp.h>
 #endif
@@ -1243,8 +1245,21 @@
 
   if (a_dims != b_dims)
     {
-      gripe_nonconformant ("operator .^", a_dims, b_dims);
-      return octave_value ();
+      if (is_valid_bsxfun (a_dims, b_dims))
+        {
+          //Potentially complex results
+          NDArray xa = octave_value_extract<NDArray> (a);
+          NDArray xb = octave_value_extract<NDArray> (b);
+          if (! xb.all_integers () && xa.any_element_is_negative ())
+            return octave_value (bsxfun_pow (ComplexNDArray (xa), xb));
+          else
+            return octave_value (bsxfun_pow (xa, xb));
+        }
+      else
+        {
+          gripe_nonconformant ("operator .^", a_dims, b_dims);
+          return octave_value ();
+        }
     }
 
   int len = a.length ();
@@ -1318,8 +1333,15 @@
 
   if (a_dims != b_dims)
     {
-      gripe_nonconformant ("operator .^", a_dims, b_dims);
-      return octave_value ();
+      if (is_valid_bsxfun (a_dims, b_dims))
+        {
+          return bsxfun_pow (a, b);
+        }
+      else
+        {
+          gripe_nonconformant ("operator .^", a_dims, b_dims);
+          return octave_value ();
+        }
     }
 
   ComplexNDArray result (a_dims);
@@ -1410,8 +1432,15 @@
 
   if (a_dims != b_dims)
     {
-      gripe_nonconformant ("operator .^", a_dims, b_dims);
-      return octave_value ();
+      if (is_valid_bsxfun (a_dims, b_dims))
+        {
+          return bsxfun_pow (a, b);
+        }
+      else
+        {
+          gripe_nonconformant ("operator .^", a_dims, b_dims);
+          return octave_value ();
+        }
     }
 
   ComplexNDArray result (a_dims);
@@ -1453,8 +1482,15 @@
 
   if (a_dims != b_dims)
     {
-      gripe_nonconformant ("operator .^", a_dims, b_dims);
-      return octave_value ();
+      if (is_valid_bsxfun (a_dims, b_dims))
+        {
+          return bsxfun_pow (a, b);
+        }
+      else
+        {
+          gripe_nonconformant ("operator .^", a_dims, b_dims);
+          return octave_value ();
+        }
     }
 
   ComplexNDArray result (a_dims);