changeset 13004:d9d65c3017c3

Make bsxfun automatic for most binary operators. * MArray.cc: Give do_mm_binary_op two extra loop arguments so it can pass them to bsxfun. * MDiagArray2.cc: Ditto. * mx-op-defs.h: Ditto. * bsxfun.h: New file. * Makefile.am: Add bsxfun.h to includes. * mx-inlines.cc: Call do_bsxfun_op when appropriate.
author Jordi Gutiérrez Hermoso <jordigh@gmail.com>
date Wed, 24 Aug 2011 23:06:59 -0500
parents 74c5fa5cef47
children 4061106b1c4b
files liboctave/MArray.cc liboctave/MDiagArray2.cc liboctave/Makefile.am liboctave/bsxfun.h liboctave/mx-inlines.cc liboctave/mx-op-defs.h
diffstat 6 files changed, 59 insertions(+), 11 deletions(-) [+]
line wrap: on
line diff
--- a/liboctave/MArray.cc
+++ b/liboctave/MArray.cc
@@ -339,7 +339,7 @@
   MArray<T> \
   FCN (const MArray<T>& a, const MArray<T>& b) \
   { \
-    return do_mm_binary_op<T, T, T> (a, b, FN, #FCN); \
+    return do_mm_binary_op<T, T, T> (a, b, FN, FN, FN, #FCN); \
   }
 
 MARRAY_NDND_OP (operator +, +, mx_inline_add)
--- a/liboctave/MDiagArray2.cc
+++ b/liboctave/MDiagArray2.cc
@@ -82,7 +82,7 @@
   { \
     if (a.d1 != b.d1 || a.d2 != b.d2) \
       gripe_nonconformant (#FCN, a.d1, a.d2, b.d1, b.d2); \
-    return MDiagArray2<T> (do_mm_binary_op<T, T, T> (a, b, FN, #FCN), a.d1, a.d2); \
+    return MDiagArray2<T> (do_mm_binary_op<T, T, T> (a, b, FN, FN, FN, #FCN), a.d1, a.d2); \
   }
 
 MARRAY_DADA_OP (operator +, +, mx_inline_add)
--- a/liboctave/Makefile.am
+++ b/liboctave/Makefile.am
@@ -188,6 +188,7 @@
   base-dae.h \
   base-de.h \
   base-min.h \
+  bsxfun.h \
   byte-swap.h \
   caseless-str.h \
   cmd-edit.h \
new file mode 100644
--- /dev/null
+++ b/liboctave/bsxfun.h
@@ -0,0 +1,40 @@
+/*
+
+Copyright (C) 2011 Jordi GutiƩrrez Hermoso <jordigh@octave.org>
+
+This file is part of Octave.
+
+Octave is free software; you can redistribute it and/or modify it
+under the terms of the GNU General Public License as published by the
+Free Software Foundation; either version 3 of the License, or (at your
+option) any later version.
+
+Octave is distributed in the hope that it will be useful, but WITHOUT
+ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
+for more details.
+
+You should have received a copy of the GNU General Public License
+along with Octave; see the file COPYING.  If not, see
+<http://www.gnu.org/licenses/>.
+
+*/
+
+#include <algorithm>
+
+#include "Array.h"
+#include "dim-vector.h"
+
+#include "bsxfun-defs.cc"
+
+inline
+bool
+is_valid_bsxfun (const dim_vector& dx, const dim_vector& dy)
+{
+  for (int i = 0; i < std::min (dx.length (), dy.length ()); i++)
+    {
+      if ( dx(i) > 1 && dy(i) > 1 && dx(i) != dy(i))
+        return false;
+    }
+  return true;
+}
--- a/liboctave/mx-inlines.cc
+++ b/liboctave/mx-inlines.cc
@@ -37,6 +37,8 @@
 #include "Array.h"
 #include "Array-util.h"
 
+#include "bsxfun.h"
+
 // Provides some commonly repeated, basic loop templates.
 
 template <class R, class S>
@@ -336,11 +338,12 @@
   return r;
 }
 
-
 template <class R, class X, class Y>
 inline Array<R>
 do_mm_binary_op (const Array<X>& x, const Array<Y>& y,
                  void (*op) (size_t, R *, const X *, const Y *) throw (),
+                 void (*op1) (size_t, R *, X, const Y *) throw (),
+                 void (*op2) (size_t, R *, const X *, Y) throw (),
                  const char *opname)
 {
   dim_vector dx = x.dims (), dy = y.dims ();
@@ -350,6 +353,10 @@
       op (r.length (), r.fortran_vec (), x.data (), y.data ());
       return r;
     }
+  else if (is_valid_bsxfun (dx, dy))
+    {
+      return do_bsxfun_op (x, y, op, op1, op2);
+    }
   else
     {
       gripe_nonconformant (opname, dx, dy);
--- a/liboctave/mx-op-defs.h
+++ b/liboctave/mx-op-defs.h
@@ -72,7 +72,7 @@
   R \
   F (const V1& v1, const V2& v2) \
   { \
-    return do_mm_binary_op<R::element_type, V1::element_type, V2::element_type> (v1, v2, OP, #F); \
+    return do_mm_binary_op<R::element_type, V1::element_type, V2::element_type> (v1, v2, OP, OP, OP, #F); \
   }
 
 #define VV_BIN_OPS(R, V1, V2) \
@@ -173,7 +173,7 @@
   R \
   OP (const M1& m1, const M2& m2) \
   { \
-    return do_mm_binary_op<R::element_type, M1::element_type, M2::element_type> (m1, m2, F, #OP); \
+    return do_mm_binary_op<R::element_type, M1::element_type, M2::element_type> (m1, m2, F, F, F, #OP); \
   }
 
 #define MM_BIN_OPS(R, M1, M2) \
@@ -186,7 +186,7 @@
   boolMatrix \
   F (const M1& m1, const M2& m2) \
   { \
-    return do_mm_binary_op<bool, M1::element_type, M2::element_type> (m1, m2, OP, #F); \
+    return do_mm_binary_op<bool, M1::element_type, M2::element_type> (m1, m2, OP, OP, OP, #F); \
   }
 
 #define MM_CMP_OPS(M1, M2) \
@@ -203,7 +203,7 @@
   { \
     MNANCHK (m1, M1::element_type); \
     MNANCHK (m2, M2::element_type); \
-    return do_mm_binary_op<bool, M1::element_type, M2::element_type> (m1, m2, OP, #F); \
+    return do_mm_binary_op<bool, M1::element_type, M2::element_type> (m1, m2, OP, OP, OP, #F); \
   }
 
 #define MM_BOOL_OPS(M1, M2) \
@@ -310,7 +310,7 @@
   R \
   OP (const ND1& m1, const ND2& m2) \
   { \
-    return do_mm_binary_op<R::element_type, ND1::element_type, ND2::element_type> (m1, m2, F, #OP); \
+    return do_mm_binary_op<R::element_type, ND1::element_type, ND2::element_type> (m1, m2, F, F, F, #OP); \
   }
 
 #define NDND_BIN_OPS(R, ND1, ND2) \
@@ -323,7 +323,7 @@
   boolNDArray \
   F (const ND1& m1, const ND2& m2) \
   { \
-    return do_mm_binary_op<bool, ND1::element_type, ND2::element_type> (m1, m2, OP, #F); \
+    return do_mm_binary_op<bool, ND1::element_type, ND2::element_type> (m1, m2, OP, OP, OP, #F); \
   }
 
 #define NDND_CMP_OPS(ND1, ND2) \
@@ -340,7 +340,7 @@
   { \
     MNANCHK (m1, ND1::element_type); \
     MNANCHK (m2, ND2::element_type); \
-    return do_mm_binary_op<bool, ND1::element_type, ND2::element_type> (m1, m2, OP, #F); \
+    return do_mm_binary_op<bool, ND1::element_type, ND2::element_type> (m1, m2, OP, OP, OP, #F); \
   }
 
 #define NDND_BOOL_OPS(ND1, ND2) \
@@ -583,7 +583,7 @@
 T \
 FCN (const T& a, const T& b) \
 { \
-  return do_mm_binary_op<T::element_type, T::element_type, T::element_type> (a, b, mx_inline_x##FCN, #FCN); \
+  return do_mm_binary_op<T::element_type, T::element_type, T::element_type> (a, b, mx_inline_x##FCN, mx_inline_x##FCN, mx_inline_x##FCN, #FCN); \
 }
 
 #define MINMAX_FCNS(T, S) \