# HG changeset patch # User jwe # Date 1147898093 0 # Node ID e54c11df05245f51a9034f6bf365198a3eb32d31 # Parent 66a426e608cc360cde77888fdccb6626a2717ced [project @ 2006-05-17 20:34:52 by jwe] diff --git a/src/ChangeLog b/src/ChangeLog --- a/src/ChangeLog +++ b/src/ChangeLog @@ -1,3 +1,8 @@ +2006-05-04 David Bateman + + * DLD-FUNCTIONS/conv2.cc: New file. + * Makefile.in (DLD_XSRC): Add it to the list + 2006-05-17 Bill Denney * help.cc (keywords): Improve and Texinfoize. diff --git a/src/DLD-FUNCTIONS/conv2.cc b/src/DLD-FUNCTIONS/conv2.cc new file mode 100644 --- /dev/null +++ b/src/DLD-FUNCTIONS/conv2.cc @@ -0,0 +1,369 @@ +/* + +Copyright (C) 1999-2005 Andy Adler + +This file is part of Octave. + +Octave is free software; you can redistribute it and/or modify it +under the terms of the GNU General Public License as published by the +Free Software Foundation; either version 2, or (at your option) any +later version. + +Octave is distributed in the hope that it will be useful, but WITHOUT +ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License +for more details. + +You should have received a copy of the GNU General Public License +along with Octave; see the file COPYING. If not, write to the Free +Software Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA +02110-1301, USA. + +*/ + +#ifdef HAVE_CONFIG_H +#include +#endif + +#include "defun-dld.h" +#include "error.h" +#include "oct-obj.h" +#include "utils.h" + +#define MAX(a,b) ((a) > (b) ? (a) : (b)) + +enum Shape { SHAPE_FULL, SHAPE_SAME, SHAPE_VALID }; + +#if !defined (CXX_NEW_FRIEND_TEMPLATE_DECL) +extern MArray2 +conv2 (MArray&, MArray&, MArray2&, Shape); + +extern MArray2 +conv2 (MArray&, MArray&, MArray2&, Shape); +#endif + +template +MArray2 +conv2 (MArray& R, MArray& C, MArray2& A, Shape ishape) +{ + octave_idx_type Rn = R.length (); + octave_idx_type Cm = C.length (); + octave_idx_type Am = A.rows (); + octave_idx_type An = A.columns (); + + // Calculate the size of the output matrix: + // in order to stay Matlab compatible, it is based + // on the third parameter if it's separable, and the + // first if it's not + + octave_idx_type outM = 0; + octave_idx_type outN = 0; + octave_idx_type edgM = 0; + octave_idx_type edgN = 0; + + switch (ishape) + { + case SHAPE_FULL: + outM = Am + Cm - 1; + outN = An + Rn - 1; + edgM = Cm - 1; + edgN = Rn - 1; + break; + + case SHAPE_SAME: + outM = Am; + outN = An; + // Follow the Matlab convention (ie + instead of -) + edgM = (Cm - 1) /2; + edgN = (Rn - 1) /2; + break; + + case SHAPE_VALID: + outM = Am - Cm + 1; + outN = An - Rn + 1; + if (outM < 0) + outM = 0; + if (outN < 0) + outN = 0; + edgM = edgN = 0; + break; + + default: + error ("conv2: invalid value of parameter ishape"); + } + + MArray2 O (outM, outN); + + // X accumulates the 1-D conv for each row, before calculating + // the convolution in the other direction + // There is no efficiency advantage to doing it in either direction + // first + + MArray X (An); + + for (octave_idx_type oi = 0; oi < outM; oi++) + { + for (octave_idx_type oj = 0; oj < An; oj++) + { + T sum = 0; + + octave_idx_type ci = Cm - 1 - MAX(0, edgM-oi); + octave_idx_type ai = MAX(0, oi-edgM); + const T* Ad = A.data() + ai + Am*oj; + const T* Cd = C.data() + ci; + for ( ; ci >= 0 && ai < Am; ci--, Cd--, ai++, Ad++) + sum += (*Ad) * (*Cd); + + X(oj) = sum; + } + + for (octave_idx_type oj = 0; oj < outN; oj++) + { + T sum = 0; + + octave_idx_type rj = Rn - 1 - MAX(0, edgN-oj); + octave_idx_type aj = MAX(0, oj-edgN) ; + const T* Xd = X.data() + aj; + const T* Rd = R.data() + rj; + + for ( ; rj >= 0 && aj < An; rj--, Rd--, aj++, Xd++) + sum += (*Xd) * (*Rd); + + O(oi,oj)= sum; + } + } + + return O; +} + +#if !defined (CXX_NEW_FRIEND_TEMPLATE_DECL) +extern MArray2 +conv2 (MArray2&, MArray2&, Shape); + +extern MArray2 +conv2 (MArray2&, MArray2&, Shape); +#endif + +template +MArray2 +conv2 (MArray2&A, MArray2&B, Shape ishape) +{ + // Convolution works fastest if we choose the A matrix to be + // the largest. + + // Here we calculate the size of the output matrix, + // in order to stay Matlab compatible, it is based + // on the third parameter if it's separable, and the + // first if it's not + + // NOTE in order to be Matlab compatible, we give argueably + // wrong sizes for 'valid' if the smallest matrix is first + + octave_idx_type Am = A.rows (); + octave_idx_type An = A.columns (); + octave_idx_type Bm = B.rows (); + octave_idx_type Bn = B.columns (); + + octave_idx_type outM = 0; + octave_idx_type outN = 0; + octave_idx_type edgM = 0; + octave_idx_type edgN = 0; + + switch (ishape) + { + case SHAPE_FULL: + outM = Am + Bm - 1; + outN = An + Bn - 1; + edgM = Bm - 1; + edgN = Bn - 1; + break; + + case SHAPE_SAME: + outM = Am; + outN = An; + edgM = (Bm - 1) /2; + edgN = (Bn - 1) /2; + break; + + case SHAPE_VALID: + outM = Am - Bm + 1; + outN = An - Bn + 1; + if (outM < 0) + outM = 0; + if (outN < 0) + outN = 0; + edgM = edgN = 0; + break; + } + + MArray2 O (outM, outN); + + for (octave_idx_type oi = 0; oi < outM; oi++) + { + for (octave_idx_type oj = 0; oj < outN; oj++) + { + T sum = 0; + + for (octave_idx_type bj = Bn - 1 - MAX (0, edgN-oj), aj= MAX (0, oj-edgN); + bj >= 0 && aj < An; bj--, aj++) + { + octave_idx_type bi = Bm - 1 - MAX (0, edgM-oi); + octave_idx_type ai = MAX (0, oi-edgM); + const T* Ad = A.data () + ai + Am*aj; + const T* Bd = B.data () + bi + Bm*bj; + + for ( ; bi >= 0 && ai < Am; bi--, Bd--, ai++, Ad++) + { + sum += (*Ad) * (*Bd); + // Comment: it seems to be 2.5 x faster than this: + // sum+= A(ai,aj) * B(bi,bj); + } + } + + O(oi,oj) = sum; + } + } + + return O; +} + +/* +%!test +%! b = [0,1,2,3;1,8,12,12;4,20,24,21;7,22,25,18]; +%! assert(conv2([0,1;1,2],[1,2,3;4,5,6;7,8,9]),b); +*/ + +DEFUN_DLD (conv2, args, , + "-*- texinfo -*-\n\ +@deftypefn {Loadable Function} {y =} conv2 (@var{a}, @var{b}, @var{shape})\n\ +@deftypefnx {Loadable Function} {y =} conv2 (@var{v1}, @var{v2}, @var{M}, @var{shape})\n\ +\n\ +Returns 2D convolution of @var{a} and @var{b} where the size\n\ +of @var{c} is given by\n\ +\n\ +@table @asis\n\ +@item @var{shape}= 'full'\n\ +returns full 2-D convolution\n\ +@item @var{shape}= 'same'\n\ +same size as a. 'central' part of convolution\n\ +@item @var{shape}= 'valid'\n\ +only parts which do not include zero-padded edges\n\ +@end table\n\ +\n\ +By default @var{shape} is 'full'. When the third argument is a matrix\n\ +returns the convolution of the matrix @var{M} by the vector @var{v1}\n\ +in the column direction and by vector @var{v2} in the row direction\n\ +@end deftypefn") +{ + octave_value retval; + octave_value tmp; + int nargin = args.length (); + std::string shape= "full"; //default + bool separable= false; + Shape ishape; + + if (nargin < 2) + { + print_usage ("conv2"); + return retval; + } + else if (nargin == 3) + { + if args(2).is_string ()) + shape = args(2).string_value (); + else + separable = true; + } + else if (nargin >= 4) + { + separable = true; + shape = args(3).string_value (); + } + + if (shape == "full") + ishape = SHAPE_FULL; + else if (shape == "same") + ishape = SHAPE_SAME; + else if (shape == "valid") + ishape = SHAPE_VALID; + else + { + error("Shape type not valid"); + print_usage ("conv2"); + return retval; + } + + if (separable) + { + // If user requests separable, check first two params are vectors + + if (! (1 == args(0).rows () || 1 == args(0).columns ()) + || ! (1 == args(1).rows () || 1 == args(1).columns ())) + { + print_usage ("conv2"); + return retval; + } + + if (args(0).is_complex_type () + || args(1).is_complex_type () + || args(2).is_complex_type ()) + { + ComplexColumnVector v1 (args(0).complex_vector_value ()); + ComplexColumnVector v2 (args(1).complex_vector_value ()); + ComplexMatrix a (args(2).complex_matrix_value ()); + ComplexMatrix c (conv2 (v1, v2, a, ishape)); + if (! error_state) + retval = c; + } + else + { + ColumnVector v1 (args(0).vector_value ()); + ColumnVector v2 (args(1).vector_value ()); + Matrix a (args(2).matrix_value ()); + Matrix c (conv2 (v1, v2, a, ishape)); + if (! error_state) + retval = c; + } + } // if (separable) + else + { + if (args(0).is_complex_type () + || args(1).is_complex_type ()) + { + ComplexMatrix a (args(0).complex_matrix_value ()); + ComplexMatrix b (args(1).complex_matrix_value ()); + ComplexMatrix c (conv2 (a, b, ishape)); + if (! error_state) + retval = c; + } + else + { + Matrix a (args(0).matrix_value ()); + Matrix b (args(1).matrix_value ()); + Matrix c (conv2 (a, b, ishape)); + if (! error_state) + retval = c; + } + + } // if (separable) + + return retval; +} + +template MArray2 +conv2 (MArray&, MArray&, MArray2&, Shape); + +template MArray2 +conv2 (MArray2&, MArray2&, Shape); + +template MArray2 +conv2 (MArray&, MArray&, MArray2&, Shape); + +template MArray2 +conv2 (MArray2&, MArray2&, Shape); + +/* +;;; Local Variables: *** +;;; mode: C++ *** +;;; End: *** +*/ diff --git a/src/Makefile.in b/src/Makefile.in --- a/src/Makefile.in +++ b/src/Makefile.in @@ -43,15 +43,16 @@ OPT_HANDLERS := DASPK-opts.cc DASRT-opts.cc DASSL-opts.cc \ LSODE-opts.cc NLEqn-opts.cc Quad-opts.cc -DLD_XSRC := balance.cc besselj.cc betainc.cc cellfun.cc chol.cc ccolamd.cc \ - colamd.cc colloc.cc daspk.cc dasrt.cc dassl.cc det.cc dispatch.cc \ - eig.cc expm.cc fft.cc fft2.cc fftn.cc fftw_wisdom.cc \ - filter.cc find.cc fsolve.cc gammainc.cc gcd.cc getgrent.cc \ - getpwent.cc getrusage.cc givens.cc hess.cc inv.cc kron.cc \ - lpsolve.cc lsode.cc lu.cc luinc.cc matrix_type.cc minmax.cc \ - pinv.cc qr.cc quad.cc qz.cc rand.cc regexp.cc schur.cc sort.cc \ - sparse.cc spchol.cc spdet.cc spkron.cc splu.cc spparms.cc \ - spqr.cc sqrtm.cc svd.cc syl.cc time.cc \ +DLD_XSRC := balance.cc besselj.cc betainc.cc cellfun.cc chol.cc \ + ccolamd.cc colamd.cc colloc.cc conv2.cc daspk.cc dasrt.cc \ + dassl.cc det.cc dispatch.cc eig.cc expm.cc fft.cc fft2.cc \ + fftn.cc fftw_wisdom.cc filter.cc find.cc fsolve.cc \ + gammainc.cc gcd.cc getgrent.cc getpwent.cc getrusage.cc \ + givens.cc hess.cc inv.cc kron.cc lpsolve.cc lsode.cc \ + lu.cc luinc.cc matrix_type.cc minmax.cc pinv.cc qr.cc \ + quad.cc qz.cc rand.cc regexp.cc schur.cc sort.cc sparse.cc \ + spchol.cc spdet.cc spkron.cc splu.cc spparms.cc spqr.cc \ + sqrtm.cc svd.cc syl.cc time.cc \ __gnuplot_raw__.l __glpk__.cc __qp__.cc DLD_SRC := $(addprefix DLD-FUNCTIONS/, $(DLD_XSRC))