# HG changeset patch # User sh@sh-laptop # Date 1206648936 14400 # Node ID 5f6e11567f7065a763de2895f3cfac49ca4c9863 # Parent 4ab2488ab2b4c0be1f62c47f09279170e36de619 Allow convolving real data with complex data diff --git a/src/ChangeLog b/src/ChangeLog --- a/src/ChangeLog +++ b/src/ChangeLog @@ -1,3 +1,13 @@ +2008-03-27 John W. Eaton + + * DLD-FUNCTIONS/__convn__.cc (convn): Use traits class and + typedefs to allow all types to be deduced from argument types. + +2008-03-27 Soren Hauberg + + * DLD-FUNCTIONS/__convn__.cc (Fconvn): Allow convolving real data with + complex data. + 2008-03-26 John W. Eaton * ov-range.h (octave_range::subsref (const std::string&, diff --git a/src/DLD-FUNCTIONS/__convn__.cc b/src/DLD-FUNCTIONS/__convn__.cc --- a/src/DLD-FUNCTIONS/__convn__.cc +++ b/src/DLD-FUNCTIONS/__convn__.cc @@ -31,10 +31,32 @@ #include "defun-dld.h" +template +class +octave_convn_traits +{ +public: + // The return type for a T1 by T2 convn operation. + typedef T1 TR; +}; + +#define OCTAVE_CONVN_TRAIT(T1, T2, T3) \ + template<> \ + class octave_convn_traits \ + { \ + public: \ + typedef T3 TR; \ + } + +OCTAVE_CONVN_TRAIT (NDArray, NDArray, NDArray); +OCTAVE_CONVN_TRAIT (ComplexNDArray, NDArray, ComplexNDArray); +OCTAVE_CONVN_TRAIT (NDArray, ComplexNDArray, ComplexNDArray); +OCTAVE_CONVN_TRAIT (ComplexNDArray, ComplexNDArray, ComplexNDArray); + // FIXME -- this function should maybe be available in liboctave? -template +template octave_value -convn (const MT& a, const MT& b) +convn (const MTa& a, const MTb& b) { octave_value retval; @@ -56,7 +78,9 @@ for (octave_idx_type n = 0; n < ndims; n++) out_size(n) = std::max (a_size(n) - b_size(n) + 1, 0); - MT out = MT (out_size); + typedef typename octave_convn_traits::TR MTout; + + MTout out (out_size); const octave_idx_type out_numel = out.numel (); @@ -72,7 +96,7 @@ OCTAVE_QUIT; // For each neighbour - ST sum = 0; + typename MTout::element_type sum = 0; for (octave_idx_type n = 0; n < ndims; n++) b_idx(n) = 0; @@ -108,24 +132,49 @@ if (args.length () == 2) { - if (args(0).is_real_type () && args(1).is_real_type ()) - { - const NDArray a = args (0).array_value (); - const NDArray b = args (1).array_value (); + if (args(0).is_real_type ()) + { + if (args(1).is_real_type ()) + { + const NDArray a = args (0).array_value (); + const NDArray b = args (1).array_value (); + + if (! error_state) + retval = convn (a, b); + } + else if (args(1).is_complex_type ()) + { + const NDArray a = args (0).array_value (); + const ComplexNDArray b = args (1).complex_array_value (); - if (! error_state) - retval = convn (a, b); - } - else if (args(0).is_complex_type () && args(1).is_complex_type ()) - { - const ComplexNDArray a = args (0).complex_array_value (); - const ComplexNDArray b = args (1).complex_array_value (); + if (! error_state) + retval = convn (a, b); + } + else + error ("__convn__: invalid call"); + } + else if (args(0).is_complex_type ()) + { + if (args(1).is_complex_type ()) + { + const ComplexNDArray a = args (0).complex_array_value (); + const ComplexNDArray b = args (1).complex_array_value (); - if (! error_state) - retval = convn (a, b); - } - else - error ("__convn__: first and second input should be real, or complex arrays"); + if (! error_state) + retval = convn (a, b); + } + else if (args(1).is_real_type ()) + { + const ComplexNDArray a = args (0).complex_array_value (); + const NDArray b = args (1).array_value (); + + if (! error_state) + retval = convn (a, b); + } + else + error ("__convn__: invalid call"); + } + error ("__convn__: invalid call"); } else print_usage ();