diff liboctave/mx-inlines.cc @ 8736:53b4fdeacc2e

improve reduction functions
author Jaroslav Hajek <highegg@gmail.com>
date Fri, 13 Feb 2009 21:04:50 +0100
parents a1ae2aae903e
children 1bd918cfb6e2
line wrap: on
line diff
--- a/liboctave/mx-inlines.cc
+++ b/liboctave/mx-inlines.cc
@@ -282,6 +282,220 @@
 OP_DUP_FCN (imag, mx_inline_imag_dup, float,  FloatComplex)
 OP_DUP_FCN (conj, mx_inline_conj_dup, FloatComplex, FloatComplex)
 
+// Squared magnitude of the complex value C, formed directly from its
+// real and imaginary parts as re*re + im*im.
+// NOTE: std::norm is NOT equivalent
+template <class T>
+T cabsq (const std::complex<T>& c)
+{
+  const T re = c.real ();
+  const T im = c.imag ();
+  return re * re + im * im;
+}
+
+// Accumulation steps for the reduction kernels: each one folds the element
+// EL into the accumulator AC.  The parameters are parenthesized so that a
+// compound argument such as a + b is treated as one operand; without the
+// parentheses OP_RED_SUMSQ (ac, a + b) would expand to ac += a + b*a + b.
+// OP_RED_SUMSQ evaluates EL twice, so EL must not have side effects.
+#define OP_RED_SUM(ac, el) (ac) += (el)
+#define OP_RED_PROD(ac, el) (ac) *= (el)
+#define OP_RED_SUMSQ(ac, el) (ac) += (el) * (el)
+#define OP_RED_SUMSQC(ac, el) (ac) += cabsq (el)
+
+// Generate F (v, n): fold the N elements of V, in increasing index order,
+// into one accumulator that starts at ZERO and is updated by
+// OP (accumulator, element).  The accumulator is returned.
+#define OP_RED_FCN(F, TSRC, OP, ZERO) \
+template <class T> \
+inline T \
+F (const TSRC* v, octave_idx_type n) \
+{ \
+  T acc = ZERO; \
+  octave_idx_type k = 0; \
+  while (k < n) \
+    { \
+      OP(acc, v[k]); \
+      ++k; \
+    } \
+  return acc; \
+}
+
+// Sum starts from 0 and product from 1.  Sum of squares comes in two
+// flavours: el*el for a source of type T, cabsq (el) for std::complex<T>.
+OP_RED_FCN (mx_inline_sum, T, OP_RED_SUM, 0)
+OP_RED_FCN (mx_inline_prod, T, OP_RED_PROD, 1)
+OP_RED_FCN (mx_inline_sumsq, T, OP_RED_SUMSQ, 0)
+OP_RED_FCN (mx_inline_sumsq, std::complex<T>, OP_RED_SUMSQC, 0)
+
+// Generate F (v, r, m, n): V is read as N consecutive runs of M elements.
+// R[0..M-1] is first set to ZERO; then, for each run in order, element i
+// of the run is folded into R[i] with OP.
+#define OP_RED_FCN2(F, TSRC, OP, ZERO) \
+template <class T> \
+inline void \
+F (const TSRC* v, T *r, octave_idx_type m, octave_idx_type n) \
+{ \
+  for (octave_idx_type k = 0; k < m; k++) \
+    r[k] = ZERO; \
+  const TSRC *run = v; \
+  for (octave_idx_type j = 0; j < n; j++, run += m) \
+    { \
+      for (octave_idx_type k = 0; k < m; k++) \
+        OP(r[k], run[k]); \
+    } \
+}
+
+// Same set of operations as the single-run kernels: sum, product and the
+// two sum-of-squares flavours.
+OP_RED_FCN2 (mx_inline_sum, T, OP_RED_SUM, 0)
+OP_RED_FCN2 (mx_inline_prod, T, OP_RED_PROD, 1)
+OP_RED_FCN2 (mx_inline_sumsq, T, OP_RED_SUMSQ, 0)
+OP_RED_FCN2 (mx_inline_sumsq, std::complex<T>, OP_RED_SUMSQC, 0)
+
+// Generate the driver F (v, r, l, n, u), which dispatches on L:
+//  * l == 1: each of the U consecutive runs of N source elements is
+//    reduced to one result element with F (v, n);
+//  * otherwise: each of the U consecutive slabs of L*N source elements is
+//    reduced to L result elements with F (v, r, l, n).
+#define OP_RED_FCNN(F, TSRC) \
+template <class T> \
+inline void \
+F (const TSRC *v, T *r, octave_idx_type l, \
+   octave_idx_type n, octave_idx_type u) \
+{ \
+  if (l != 1) \
+    { \
+      for (octave_idx_type k = 0; k < u; k++, v += l*n, r += l) \
+        F (v, r, l, n); \
+    } \
+  else \
+    { \
+      for (octave_idx_type k = 0; k < u; k++, v += n) \
+        r[k] = F (v, n); \
+    } \
+}
+
+// Drivers for sum, product and both sum-of-squares flavours.
+OP_RED_FCNN (mx_inline_sum, T)
+OP_RED_FCNN (mx_inline_prod, T)
+OP_RED_FCNN (mx_inline_sumsq, T)
+OP_RED_FCNN (mx_inline_sumsq, std::complex<T>)
+
+// Generate F (v, r, n): running OP over the N elements of V.  R[0] is a
+// copy of V[0] and each later R[k] is (running value) OP V[k].  Nothing is
+// read or written when N is zero.
+#define OP_CUM_FCN(F, OP) \
+template <class T> \
+inline void \
+F (const T *v, T *r, octave_idx_type n) \
+{ \
+  if (! n) \
+    return; \
+  T run = r[0] = v[0]; \
+  for (octave_idx_type k = 1; k < n; k++) \
+    { \
+      run = run OP v[k]; \
+      r[k] = run; \
+    } \
+}
+
+// Cumulative sum and cumulative product.
+OP_CUM_FCN (mx_inline_cumsum, +)
+OP_CUM_FCN (mx_inline_cumprod, *)
+
+// Generate F (v, r, m, n): V and R are both read as N consecutive runs of
+// M elements.  The first run of R is a copy of the first run of V; every
+// later run of R is the matching run of V combined element-wise (V OP
+// previous R) with the run of R just before it.  Nothing is done when N
+// is zero.
+#define OP_CUM_FCN2(F, OP) \
+template <class T> \
+inline void \
+F (const T *v, T *r, octave_idx_type m, octave_idx_type n) \
+{ \
+  if (! n) \
+    return; \
+  for (octave_idx_type k = 0; k < m; k++) \
+    r[k] = v[k]; \
+  for (octave_idx_type j = 1; j < n; j++) \
+    { \
+      const T *prev = r; \
+      r += m; \
+      v += m; \
+      for (octave_idx_type k = 0; k < m; k++) \
+        r[k] = v[k] OP prev[k]; \
+    } \
+}
+
+// Cumulative sum and cumulative product over runs.
+OP_CUM_FCN2 (mx_inline_cumsum, +)
+OP_CUM_FCN2 (mx_inline_cumprod, *)
+
+// Generate the driver F (v, r, l, n, u), which dispatches on L:
+//  * l == 1: F (v, r, n) is applied to each of the U consecutive runs of
+//    N elements;
+//  * otherwise: F (v, r, l, n) is applied to each of the U consecutive
+//    slabs of L*N elements.
+// Source and result advance by the same amount, since a cumulative
+// operation writes one result element per source element.
+#define OP_CUM_FCNN(F) \
+template <class T> \
+inline void \
+F (const T *v, T *r, octave_idx_type l, \
+   octave_idx_type n, octave_idx_type u) \
+{ \
+  if (l != 1) \
+    { \
+      for (octave_idx_type k = 0; k < u; k++, v += l*n, r += l*n) \
+        F (v, r, l, n); \
+    } \
+  else \
+    { \
+      for (octave_idx_type k = 0; k < u; k++, v += n, r += n) \
+        F (v, r, n); \
+    } \
+}
+
+// Drivers for cumulative sum and cumulative product.
+OP_CUM_FCNN (mx_inline_cumsum)
+OP_CUM_FCNN (mx_inline_cumprod)
+
+// Helper: split DIMS around dimension DIM into the extent triplet
+// (L, N, U) consumed by the reduction and cumulative drivers.
+//
+//  * DIM >= number of dimensions: L is the total element count and
+//    N = U = 1.
+//  * DIM < 0: DIM is replaced (for the caller too, it is passed by
+//    reference) by the first dimension whose extent is not 1, or by the
+//    last dimension if every earlier extent is 1.
+//  * Otherwise L is the product of the extents before DIM, N the extent
+//    at DIM and U the product of the extents after DIM.
+
+inline void
+get_extent_triplet (const dim_vector& dims, int& dim,
+                    octave_idx_type& l, octave_idx_type& n,
+                    octave_idx_type& u)
+{
+  const octave_idx_type ndims = dims.length ();
+
+  if (dim >= ndims)
+    {
+      l = dims.numel ();
+      n = 1;
+      u = 1;
+      return;
+    }
+
+  if (dim < 0)
+    {
+      // find first non-singleton dim
+      dim = 0;
+      while (dims(dim) == 1 && dim < ndims - 1)
+        dim++;
+    }
+
+  // calculate extent triplet.
+  l = 1;
+  n = dims(dim);
+  u = 1;
+
+  for (octave_idx_type i = 0; i < dim; i++)
+    l *= dims (i);
+
+  for (octave_idx_type i = dim + 1; i < ndims; i++)
+    u *= dims (i);
+}
+
+// Appliers.
+// FIXME: is this the best design? C++ gives a lot of options here...
+// maybe it can be done without an explicit parameter?
+
+// Apply the reduction kernel MX_RED_OP to SRC along DIM and return the
+// result as an ArrayType.  The extent triplet is obtained from
+// get_extent_triplet (which may also rewrite DIM).  The result has the
+// dimensions of SRC with the extent at DIM set to 1, when DIM is within
+// range, and trailing singletons chopped.
+template <class ArrayType, class T>
+inline ArrayType
+do_mx_red_op (const Array<T>& src, int dim,
+              void (*mx_red_op) (const T *, typename ArrayType::element_type *,
+                                 octave_idx_type, octave_idx_type, octave_idx_type))
+{
+  dim_vector res_dims = src.dims ();
+
+  octave_idx_type len_lo, len_red, len_hi;
+  get_extent_triplet (res_dims, dim, len_lo, len_red, len_hi);
+
+  // Reduction operation reduces the array size.
+  if (dim < res_dims.length ())
+    res_dims(dim) = 1;
+
+  res_dims.chop_trailing_singletons ();
+
+  ArrayType result (res_dims);
+  mx_red_op (src.data (), result.fortran_vec (), len_lo, len_red, len_hi);
+
+  return result;
+}
+
+// Apply the cumulative kernel MX_CUM_OP to SRC along DIM and return the
+// result as an ArrayType.  The extent triplet is obtained from
+// get_extent_triplet (which may also rewrite DIM), and the result is
+// constructed with the same dimensions as SRC.
+template <class ArrayType, class T>
+inline ArrayType
+do_mx_cum_op (const Array<T>& src, int dim,
+              void (*mx_cum_op) (const T *, typename ArrayType::element_type *,
+                                 octave_idx_type, octave_idx_type, octave_idx_type))
+{
+  dim_vector src_dims = src.dims ();
+
+  octave_idx_type len_lo, len_cum, len_hi;
+  get_extent_triplet (src_dims, dim, len_lo, len_cum, len_hi);
+
+  // Cumulative operation doesn't reduce the array size.
+  ArrayType result (src_dims);
+  mx_cum_op (src.data (), result.fortran_vec (), len_lo, len_cum, len_hi);
+
+  return result;
+}
+
 // Avoid some code duplication.  Maybe we should use templates.
 
 #define MX_CUMULATIVE_OP(RET_TYPE, ELT_TYPE, OP) \