Mercurial > hg > octave-shane
changeset 7660:5f6e11567f70
Allow convolving real data with complex data
author | sh@sh-laptop |
---|---|
date | Thu, 27 Mar 2008 16:15:36 -0400 |
parents | 4ab2488ab2b4 |
children | f3493c40a0bd |
files | src/ChangeLog src/DLD-FUNCTIONS/__convn__.cc |
diffstat | 2 files changed, 79 insertions(+), 20 deletions(-) [+] |
line wrap: on
line diff
--- a/src/ChangeLog +++ b/src/ChangeLog @@ -1,3 +1,13 @@ +2008-03-27 John W. Eaton <jwe@octave.org> + + * DLD-FUNCTIONS/__convn__.cc (convn): Use traits class and + typedefs to allow all types to be deduced from argument types. + +2008-03-27 Soren Hauberg <hauberg@gmail.com> + + * DLD-FUNCTIONS/__convn__.cc (Fconvn): Allow convolving real data with + complex data. + 2008-03-26 John W. Eaton <jwe@octave.org> * ov-range.h (octave_range::subsref (const std::string&,
--- a/src/DLD-FUNCTIONS/__convn__.cc +++ b/src/DLD-FUNCTIONS/__convn__.cc @@ -31,10 +31,32 @@ #include "defun-dld.h" +template <class T1, class T2> +class +octave_convn_traits +{ +public: + // The return type for a T1 by T2 convn operation. + typedef T1 TR; +}; + +#define OCTAVE_CONVN_TRAIT(T1, T2, T3) \ + template<> \ + class octave_convn_traits <T1, T2> \ + { \ + public: \ + typedef T3 TR; \ + } + +OCTAVE_CONVN_TRAIT (NDArray, NDArray, NDArray); +OCTAVE_CONVN_TRAIT (ComplexNDArray, NDArray, ComplexNDArray); +OCTAVE_CONVN_TRAIT (NDArray, ComplexNDArray, ComplexNDArray); +OCTAVE_CONVN_TRAIT (ComplexNDArray, ComplexNDArray, ComplexNDArray); + // FIXME -- this function should maybe be available in liboctave? -template <class MT, class ST> +template <class MTa, class MTb> octave_value -convn (const MT& a, const MT& b) +convn (const MTa& a, const MTb& b) { octave_value retval; @@ -56,7 +78,9 @@ for (octave_idx_type n = 0; n < ndims; n++) out_size(n) = std::max (a_size(n) - b_size(n) + 1, 0); - MT out = MT (out_size); + typedef typename octave_convn_traits<MTa, MTb>::TR MTout; + + MTout out (out_size); const octave_idx_type out_numel = out.numel (); @@ -72,7 +96,7 @@ OCTAVE_QUIT; // For each neighbour - ST sum = 0; + typename MTout::element_type sum = 0; for (octave_idx_type n = 0; n < ndims; n++) b_idx(n) = 0; @@ -108,24 +132,49 @@ if (args.length () == 2) { - if (args(0).is_real_type () && args(1).is_real_type ()) - { - const NDArray a = args (0).array_value (); - const NDArray b = args (1).array_value (); + if (args(0).is_real_type ()) + { + if (args(1).is_real_type ()) + { + const NDArray a = args (0).array_value (); + const NDArray b = args (1).array_value (); + + if (! error_state) + retval = convn (a, b); + } + else if (args(1).is_complex_type ()) + { + const NDArray a = args (0).array_value (); + const ComplexNDArray b = args (1).complex_array_value (); - if (! error_state) - retval = convn<NDArray, double> (a, b); - } - else if (args(0).is_complex_type () && args(1).is_complex_type ()) - { - const ComplexNDArray a = args (0).complex_array_value (); - const ComplexNDArray b = args (1).complex_array_value (); + if (! error_state) + retval = convn (a, b); + } + else + error ("__convn__: invalid call"); + } + else if (args(0).is_complex_type ()) + { + if (args(1).is_complex_type ()) + { + const ComplexNDArray a = args (0).complex_array_value (); + const ComplexNDArray b = args (1).complex_array_value (); - if (! error_state) - retval = convn<ComplexNDArray, Complex> (a, b); - } - else - error ("__convn__: first and second input should be real, or complex arrays"); + if (! error_state) + retval = convn (a, b); + } + else if (args(1).is_real_type ()) + { + const ComplexNDArray a = args (0).complex_array_value (); + const NDArray b = args (1).array_value (); + + if (! error_state) + retval = convn (a, b); + } + else + error ("__convn__: invalid call"); + } + error ("__convn__: invalid call"); } else print_usage ();