Mercurial > hg > octave-nkf
changeset 9756:b134960cea23
implement built-in tril/triu
author | Jaroslav Hajek <highegg@gmail.com> |
---|---|
date | Fri, 23 Oct 2009 10:10:37 +0200 |
parents | 4f4873f6f873 |
children | 95ad9c2a27e2 |
files | scripts/ChangeLog scripts/general/Makefile.in scripts/general/tril.m scripts/general/triu.m src/ChangeLog src/DLD-FUNCTIONS/tril.cc src/Makefile.in |
diffstat | 7 files changed, 436 insertions(+), 185 deletions(-) [+] |
line wrap: on
line diff
--- a/scripts/ChangeLog +++ b/scripts/ChangeLog @@ -1,3 +1,8 @@ +2009-10-23 Jaroslav Hajek <highegg@gmail.com> + + * general/tril.m, general/triu.m: Remove sources. + * general/Makefile.in: Update. + 2009-10-20 Soren Hauberg <hauberg@gmail.com> * general/interp2.m: improved error checking and support for bicubic
--- a/scripts/general/Makefile.in +++ b/scripts/general/Makefile.in @@ -46,7 +46,7 @@ polyarea.m postpad.m prepad.m quadgk.m quadl.m quadv.m rat.m \ rem.m repmat.m rot90.m rotdim.m runlength.m saveobj.m shift.m shiftdim.m \ sortrows.m sph2cart.m strerror.m structfun.m subsindex.m \ - triplequad.m trapz.m tril.m triu.m + triplequad.m trapz.m DISTFILES = $(addprefix $(srcdir)/, Makefile.in $(SOURCES))
deleted file mode 100644 --- a/scripts/general/tril.m +++ /dev/null @@ -1,112 +0,0 @@ -## Copyright (C) 1993, 1994, 1995, 1996, 1997, 1999, 2000, 2004, 2005, -## 2006, 2007, 2008 John W. Eaton -## -## This file is part of Octave. -## -## Octave is free software; you can redistribute it and/or modify it -## under the terms of the GNU General Public License as published by -## the Free Software Foundation; either version 3 of the License, or (at -## your option) any later version. -## -## Octave is distributed in the hope that it will be useful, but -## WITHOUT ANY WARRANTY; without even the implied warranty of -## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -## General Public License for more details. -## -## You should have received a copy of the GNU General Public License -## along with Octave; see the file COPYING. If not, see -## <http://www.gnu.org/licenses/>. - -## -*- texinfo -*- -## @deftypefn {Function File} {} tril (@var{a}, @var{k}) -## @deftypefnx {Function File} {} triu (@var{a}, @var{k}) -## Return a new matrix formed by extracting the lower (@code{tril}) -## or upper (@code{triu}) triangular part of the matrix @var{a}, and -## setting all other elements to zero. The second argument is optional, -## and specifies how many diagonals above or below the main diagonal should -## also be set to zero. -## -## The default value of @var{k} is zero, so that @code{triu} and -## @code{tril} normally include the main diagonal as part of the result -## matrix. -## -## If the value of @var{k} is negative, additional elements above (for -## @code{tril}) or below (for @code{triu}) the main diagonal are also -## selected. -## -## The absolute value of @var{k} must not be greater than the number of -## sub- or super-diagonals. -## -## For example, -## -## @example -## @group -## tril (ones (3), -1) -## @result{} 0 0 0 -## 1 0 0 -## 1 1 0 -## @end group -## @end example -## -## @noindent -## and -## -## @example -## @group -## tril (ones (3), 1) -## @result{} 1 1 0 -## 1 1 1 -## 1 1 1 -## @end group -## @end example -## @seealso{triu, diag} -## @end deftypefn - -## Author: jwe - -function retval = tril (x, k) - - if (nargin > 0) - if (isstruct (x)) - error ("tril: structure arrays not supported"); - endif - [nr, nc] = size (x); - endif - - if (nargin == 1) - k = 0; - elseif (nargin == 2) - if ((k > 0 && k > nc) || (k < 0 && k < -nr)) - error ("tril: requested diagonal out of range"); - endif - else - print_usage (); - endif - - retval = resize (resize (x, 0), nr, nc); - for j = 1 : min (nc, nr+k) - nr_limit = max (1, j-k); - retval (nr_limit:nr, j) = x (nr_limit:nr, j); - endfor - -endfunction - -%!test -%! a = [1, 2, 3; 4, 5, 6; 7, 8, 9; 10, 11, 12]; -%! -%! l0 = [1, 0, 0; 4, 5, 0; 7, 8, 9; 10, 11, 12]; -%! l1 = [1, 2, 0; 4, 5, 6; 7, 8, 9; 10, 11, 12]; -%! l2 = [1, 2, 3; 4, 5, 6; 7, 8, 9; 10, 11, 12]; -%! lm1 = [0, 0, 0; 4, 0, 0; 7, 8, 0; 10, 11, 12]; -%! lm2 = [0, 0, 0; 0, 0, 0; 7, 0, 0; 10, 11, 0]; -%! lm3 = [0, 0, 0; 0, 0, 0; 0, 0, 0; 10, 0, 0]; -%! lm4 = [0, 0, 0; 0, 0, 0; 0, 0, 0; 0, 0, 0]; -%! -%! assert((tril (a, -4) == lm4 && tril (a, -3) == lm3 -%! && tril (a, -2) == lm2 && tril (a, -1) == lm1 -%! && tril (a) == l0 && tril (a, 1) == l1 && tril (a, 2) == l2)); - -%!error tril (); - -%!error tril (1, 2, 3); -
deleted file mode 100644 --- a/scripts/general/triu.m +++ /dev/null @@ -1,71 +0,0 @@ -## Copyright (C) 1993, 1994, 1995, 1996, 1997, 1999, 2000, 2005, 2006, -## 2007, 2008 John W. Eaton -## -## This file is part of Octave. -## -## Octave is free software; you can redistribute it and/or modify it -## under the terms of the GNU General Public License as published by -## the Free Software Foundation; either version 3 of the License, or (at -## your option) any later version. -## -## Octave is distributed in the hope that it will be useful, but -## WITHOUT ANY WARRANTY; without even the implied warranty of -## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -## General Public License for more details. -## -## You should have received a copy of the GNU General Public License -## along with Octave; see the file COPYING. If not, see -## <http://www.gnu.org/licenses/>. - -## -*- texinfo -*- -## @deftypefn {Function File} {} triu (@var{a}, @var{k}) -## See tril. -## @end deftypefn - -## Author: jwe - -function retval = triu (x, k) - - if (nargin > 0) - if (isstruct (x)) - error ("tril: structure arrays not supported"); - endif - [nr, nc] = size (x); - endif - if (nargin == 1) - k = 0; - elseif (nargin == 2) - if ((k > 0 && k > nc) || (k < 0 && k < -nr)) - error ("triu: requested diagonal out of range"); - endif - else - print_usage (); - endif - - retval = resize (resize (x, 0), nr, nc); - for j = max (1, k+1) : nc - nr_limit = min (nr, j-k); - retval (1:nr_limit, j) = x (1:nr_limit, j); - endfor - -endfunction - -%!test -%! a = [1, 2, 3; 4, 5, 6; 7, 8, 9; 10, 11, 12]; -%! -%! u0 = [1, 2, 3; 0, 5, 6; 0, 0, 9; 0, 0, 0]; -%! u1 = [0, 2, 3; 0, 0, 6; 0, 0, 0; 0, 0, 0]; -%! u2 = [0, 0, 3; 0, 0, 0; 0, 0, 0; 0, 0, 0]; -%! u3 = [0, 0, 0; 0, 0, 0; 0, 0, 0; 0, 0, 0]; -%! um1 = [1, 2, 3; 4, 5, 6; 0, 8, 9; 0, 0, 12]; -%! um2 = [1, 2, 3; 4, 5, 6; 7, 8, 9; 0, 11, 12]; -%! um3 = [1, 2, 3; 4, 5, 6; 7, 8, 9; 10, 11, 12]; -%! -%! assert((triu (a, -3) == um3 && triu (a, -2) == um2 -%! && triu (a, -1) == um1 && triu (a) == u0 && triu (a, 1) == u1 -%! && triu (a, 2) == u2 && triu (a, 3) == u3)); - -%!error triu (); - -%!error triu (1, 2, 3); -
--- a/src/ChangeLog +++ b/src/ChangeLog @@ -1,3 +1,8 @@ +2009-10-23 Jaroslav Hajek <highegg@gmail.com> + + * DLD-FUNCTIONS/tril.cc: New source. + * Makefile.in: Include it. + 2009-10-22 Jaroslav Hajek <highegg@gmail.com> * error.cc (verror (bool, std::ostream&, ..., bool)): Add optional
new file mode 100644 --- /dev/null +++ b/src/DLD-FUNCTIONS/tril.cc @@ -0,0 +1,424 @@ +/* + +Copyright (C) 2004, 2007 David Bateman +Copyright (C) 2009 VZLU Prague + +This program is free software; you can redistribute it and/or modify it +under the terms of the GNU General Public License as published by +the Free Software Foundation; either version 3, or (at your option) +any later version. + +This program is distributed in the hope that it will be useful, but +WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +General Public License for more details. + +You should have received a copy of the GNU General Public License +along with Octave; see the file COPYING. If not, see +<http://www.gnu.org/licenses/>. + +*/ + +#ifdef HAVE_CONFIG_H +#include <config.h> +#endif + +#include <algorithm> +#include "Array.h" +#include "Sparse.h" +#include "mx-base.h" + +#include "ov.h" +#include "Cell.h" + +#include "defun-dld.h" +#include "error.h" +#include "oct-obj.h" + +// The bulk of the work. +template <class T> +static Array<T> +do_tril (const Array<T>& a, octave_idx_type k, bool pack) +{ + octave_idx_type nr = a.rows (), nc = a.columns (); + const T *avec = a.fortran_vec (); + + if (pack) + { + octave_idx_type j1 = std::min (std::max (0, k), nc); + octave_idx_type j2 = std::min (std::max (0, nr + k), nc); + octave_idx_type n = j1 * nr + ((j2 - j1) * (nr-(j1-k) + nr-(j2-1-k))) / 2; + Array<T> r (n); + T *rvec = r.fortran_vec (); + for (octave_idx_type j = 0; j < nc; j++) + { + octave_idx_type ii = std::min (std::max (0, j - k), nr); + rvec = std::copy (avec + ii, avec + nr, rvec); + avec += nr; + } + + return r; + } + else + { + Array<T> r (a.dims ()); + T *rvec = r.fortran_vec (); + for (octave_idx_type j = 0; j < nc; j++) + { + octave_idx_type ii = std::min (std::max (0, j - k), nr); + std::fill (rvec, rvec + ii, T()); + std::copy (avec + ii, avec + nr, rvec + ii); + avec += nr; + rvec += nr; + } + + return r; + } +} + +template <class T> +static Array<T> +do_triu (const Array<T>& a, octave_idx_type k, bool pack) +{ + octave_idx_type nr = a.rows (), nc = a.columns (); + const T *avec = a.fortran_vec (); + + if (pack) + { + octave_idx_type j1 = std::min (std::max (0, k), nc); + octave_idx_type j2 = std::min (std::max (0, nr + k), nc); + octave_idx_type n = ((j2 - j1) * ((j1+1-k) + (j2-k))) / 2 + (nc - j2) * nr; + Array<T> r (n); + T *rvec = r.fortran_vec (); + for (octave_idx_type j = 0; j < nc; j++) + { + octave_idx_type ii = std::min (std::max (0, j + 1 - k), nr); + rvec = std::copy (avec, avec + ii, rvec); + avec += nr; + } + + return r; + } + else + { + NoAlias<Array<T> > r (a.dims ()); + T *rvec = r.fortran_vec (); + for (octave_idx_type j = 0; j < nc; j++) + { + octave_idx_type ii = std::min (std::max (0, j + 1 - k), nr); + std::copy (avec, avec + ii, rvec); + std::fill (rvec + ii, rvec + nr, T()); + avec += nr; + rvec += nr; + } + + return r; + } +} + +// These two are by David Bateman. +// FIXME: optimizations possible. "pack" support missing. + +template <class T> +static Sparse<T> +do_tril (const Sparse<T>& a, octave_idx_type k, bool pack) +{ + if (pack) // FIXME + { + error ("tril: \"pack\" not implemented for sparse matrices"); + return Sparse<T> (); + } + + Sparse<T> m = a; + octave_idx_type nc = m.cols(); + + for (octave_idx_type j = 0; j < nc; j++) + for (octave_idx_type i = m.cidx(j); i < m.cidx(j+1); i++) + if (m.ridx(i) < j-k) + m.data(i) = 0.; + + m.maybe_compress (true); + return m; +} + +template <class T> +static Sparse<T> +do_triu (const Sparse<T>& a, octave_idx_type k, bool pack) +{ + if (pack) // FIXME + { + error ("triu: \"pack\" not implemented for sparse matrices"); + return Sparse<T> (); + } + + Sparse<T> m = a; + octave_idx_type nc = m.cols(); + + for (octave_idx_type j = 0; j < nc; j++) + for (octave_idx_type i = m.cidx(j); i < m.cidx(j+1); i++) + if (m.ridx(i) > j-k) + m.data(i) = 0.; + + m.maybe_compress (true); + return m; +} + +// Convenience dispatchers. +template <class T> +static Array<T> +do_trilu (const Array<T>& a, octave_idx_type k, bool lower, bool pack) +{ + return lower ? do_tril (a, k, pack) : do_triu (a, k, pack); +} + +template <class T> +static Sparse<T> +do_trilu (const Sparse<T>& a, octave_idx_type k, bool lower, bool pack) +{ + return lower ? do_tril (a, k, pack) : do_triu (a, k, pack); +} + +static octave_value +do_trilu (const std::string& name, + const octave_value_list& args) +{ + bool lower = name == "tril"; + + octave_value retval; + int nargin = args.length (); + octave_idx_type k = 0; + bool pack = false; + if (nargin >= 2 && args(nargin-1).is_string ()) + { + pack = args(nargin-1).string_value () == "pack"; + nargin--; + } + + if (nargin == 2) + { + k = args(1).int_value (true); + + if (error_state) + return retval; + } + + if (nargin < 1 || nargin > 2) + print_usage (); + else + { + octave_value arg = args (0); + + dim_vector dims = arg.dims (); + if (dims.length () != 2) + error ("%s: needs a 2D matrix", name.c_str ()); + else if (k < -dims (0) || k > dims(1)) + error ("%s: requested diagonal out of range", name.c_str ()); + else + { + switch (arg.builtin_type ()) + { + case btyp_double: + if (arg.is_sparse_type ()) + retval = do_trilu (arg.sparse_matrix_value (), k, lower, pack); + else + retval = do_trilu (arg.array_value (), k, lower, pack); + break; + case btyp_complex: + if (arg.is_sparse_type ()) + retval = do_trilu (arg.sparse_complex_matrix_value (), k, lower, pack); + else + retval = do_trilu (arg.complex_array_value (), k, lower, pack); + break; + case btyp_bool: + if (arg.is_sparse_type ()) + retval = do_trilu (arg.sparse_bool_matrix_value (), k, lower, pack); + else + retval = do_trilu (arg.bool_array_value (), k, lower, pack); + break; +#define ARRAYCASE(TYP) \ + case btyp_ ## TYP: \ + retval = do_trilu (arg.TYP ## _array_value (), k, lower, pack); \ + break + ARRAYCASE (float); + ARRAYCASE (float_complex); + ARRAYCASE (int8); + ARRAYCASE (int16); + ARRAYCASE (int32); + ARRAYCASE (int64); + ARRAYCASE (uint8); + ARRAYCASE (uint16); + ARRAYCASE (uint32); + ARRAYCASE (uint64); + ARRAYCASE (char); +#undef ARRAYCASE + default: + { + // Generic code that works on octave-values, that is slow + // but will also work on arbitrary user types + + if (pack) // FIXME + { + error ("%s: \"pack\" not implemented for class %s", + name.c_str (), arg.class_name ().c_str ()); + return octave_value (); + } + + octave_value tmp = arg; + if (arg.numel () == 0) + return arg; + + octave_idx_type nr = dims(0), nc = dims (1); + + // The sole purpose of the below is to force the correct + // matrix size. This would not be necessary if the + // octave_value resize function allowed a fill_value. + // It also allows odd attributes in some user types + // to be handled. With a fill_value ot should be replaced + // with + // + // octave_value_list ov_idx; + // tmp = tmp.resize(dim_vector (0,0)).resize (dims, fill_value); + + octave_value_list ov_idx; + std::list<octave_value_list> idx_tmp; + ov_idx(1) = static_cast<double> (nc+1); + ov_idx(0) = Range (1, nr); + idx_tmp.push_back (ov_idx); + ov_idx(1) = static_cast<double> (nc); + tmp = tmp.resize (dim_vector (0,0)); + tmp = tmp.subsasgn("(",idx_tmp, arg.do_index_op (ov_idx)); + tmp = tmp.resize(dims); + + if (lower) + { + octave_idx_type st = nc < nr + k ? nc : nr + k; + + for (octave_idx_type j = 1; j <= st; j++) + { + octave_idx_type nr_limit = 1 > j - k ? 1 : j - k; + ov_idx(1) = static_cast<double> (j); + ov_idx(0) = Range (nr_limit, nr); + std::list<octave_value_list> idx; + idx.push_back (ov_idx); + + tmp = tmp.subsasgn ("(", idx, arg.do_index_op(ov_idx)); + + if (error_state) + return retval; + } + } + else + { + octave_idx_type st = k + 1 > 1 ? k + 1 : 1; + + for (octave_idx_type j = st; j <= nc; j++) + { + octave_idx_type nr_limit = nr < j - k ? nr : j - k; + ov_idx(1) = static_cast<double> (j); + ov_idx(0) = Range (1, nr_limit); + std::list<octave_value_list> idx; + idx.push_back (ov_idx); + + tmp = tmp.subsasgn ("(", idx, arg.do_index_op(ov_idx)); + + if (error_state) + return retval; + } + } + + retval = tmp; + } + } + } + } + + return retval; +} + +DEFUN_DLD (tril, args, , + "-*- texinfo -*-\n\ +@deftypefn {Function File} {} tril (@var{a}, @var{k}, @var{pack})\n\ +@deftypefnx {Function File} {} triu (@var{a}, @var{k}, @var{pack})\n\ +Return a new matrix formed by extracting extract the lower (@code{tril})\n\ +or upper (@code{triu}) triangular part of the matrix @var{a}, and\n\ +setting all other elements to zero. The second argument is optional,\n\ +and specifies how many diagonals above or below the main diagonal should\n\ +also be set to zero.\n\ +\n\ +The default value of @var{k} is zero, so that @code{triu} and\n\ +@code{tril} normally include the main diagonal as part of the result\n\ +matrix.\n\ +\n\ +If the value of @var{k} is negative, additional elements above (for\n\ +@code{tril}) or below (for @code{triu}) the main diagonal are also\n\ +selected.\n\ +\n\ +The absolute value of @var{k} must not be greater than the number of\n\ +sub- or super-diagonals.\n\ +\n\ +For example,\n\ +\n\ +@example\n\ +@group\n\ +tril (ones (3), -1)\n\ + @result{} 0 0 0\n\ + 1 0 0\n\ + 1 1 0\n\ +@end group\n\ +@end example\n\ +\n\ +@noindent\n\ +and\n\ +\n\ +@example\n\ +@group\n\ +tril (ones (3), 1)\n\ + @result{} 1 1 0\n\ + 1 1 1\n\ + 1 1 1\n\ +@end group\n\ +@end example\n\ +\n\ +If the option \"pack\" is given as third argument, the extracted elements\n\ +are not inserted into a matrix, but rather stacked column-wise one above other.\n\ +@seealso{triu, diag}\n\ +@end deftypefn") +{ + return do_trilu ("tril", args); +} + +DEFUN_DLD (triu, args, , + "-*- texinfo -*-\n\ +@deftypefn {Function File} {} triu (@var{a}, @var{k})\n\ +See tril.\n\ +@end deftypefn") +{ + return do_trilu ("triu", args); +} + +/* + +%!test +%! a = [1, 2, 3; 4, 5, 6; 7, 8, 9; 10, 11, 12]; +%! +%! l0 = [1, 0, 0; 4, 5, 0; 7, 8, 9; 10, 11, 12]; +%! l1 = [1, 2, 0; 4, 5, 6; 7, 8, 9; 10, 11, 12]; +%! l2 = [1, 2, 3; 4, 5, 6; 7, 8, 9; 10, 11, 12]; +%! lm1 = [0, 0, 0; 4, 0, 0; 7, 8, 0; 10, 11, 12]; +%! lm2 = [0, 0, 0; 0, 0, 0; 7, 0, 0; 10, 11, 0]; +%! lm3 = [0, 0, 0; 0, 0, 0; 0, 0, 0; 10, 0, 0]; +%! lm4 = [0, 0, 0; 0, 0, 0; 0, 0, 0; 0, 0, 0]; +%! +%! assert((tril (a, -4) == lm4 && tril (a, -3) == lm3 +%! && tril (a, -2) == lm2 && tril (a, -1) == lm1 +%! && tril (a) == l0 && tril (a, 1) == l1 && tril (a, 2) == l2)); + +%!error tril (); + +*/ + +/* +;;; Local Variables: *** +;;; mode: C++ *** +;;; End: *** +*/
--- a/src/Makefile.in +++ b/src/Makefile.in @@ -73,7 +73,7 @@ lsode.cc lu.cc luinc.cc matrix_type.cc max.cc md5sum.cc \ pinv.cc qr.cc quad.cc qz.cc rand.cc rcond.cc regexp.cc \ schur.cc sparse.cc spparms.cc sqrtm.cc sub2ind.cc svd.cc syl.cc \ - symrcm.cc symbfact.cc time.cc tsearch.cc typecast.cc \ + symrcm.cc symbfact.cc time.cc tril.cc tsearch.cc typecast.cc \ urlwrite.cc __contourc__.cc __delaunayn__.cc \ __dsearchn__.cc __glpk__.cc __lin_interpn__.cc \ __magick_read__.cc __pchip_deriv__.cc __qp__.cc \