Mercurial > hg > octave-nkf
diff src/DLD-FUNCTIONS/bsxfun.cc @ 6869:f9c893831e68
[project @ 2007-09-06 16:38:44 by dbateman]
author | dbateman |
---|---|
date | Thu, 06 Sep 2007 16:38:44 +0000 |
parents | |
children | cd2c6a69a70d |
line wrap: on
line diff
new file mode 100644 --- /dev/null +++ b/src/DLD-FUNCTIONS/bsxfun.cc @@ -0,0 +1,483 @@ +/* + +Copyright (C) 2007 David Bateman + +This file is part of Octave. + +Octave is free software; you can redistribute it and/or modify it +under the terms of the GNU General Public License as published by the +Free Software Foundation; either version 2, or (at your option) any +later version. + +Octave is distributed in the hope that it will be useful, but WITHOUT +ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License +for more details. + +You should have received a copy of the GNU General Public License +along with Octave; see the file COPYING. If not, write to the Free +Software Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA +02110-1301, USA. + +*/ + +#ifdef HAVE_CONFIG_H +#include <config.h> +#endif + +#include <string> +#include <vector> +#include <list> + +#include "lo-mappers.h" + +#include "oct-map.h" +#include "defun-dld.h" +#include "parse.h" +#include "variables.h" +#include "ov-colon.h" +#include "unwind-prot.h" + +static bool +maybe_update_column (octave_value& Ac, const octave_value& A, + const dim_vector& dva, const dim_vector& dvc, + octave_idx_type i, octave_value_list &idx) +{ + octave_idx_type nd = dva.length (); + + if (i == 0) + { + idx(0) = octave_value (':'); + for (octave_idx_type j = 1; j < nd; j++) + { + if (dva (j) == 1) + idx (j) = octave_value (1); + else + idx (j) = octave_value ((i % dvc(j)) + 1); + + i = i / dvc (j); + } + + Ac = A; + Ac = Ac.single_subsref ("(", idx); + return true; + } + else + { + bool is_changed = false; + octave_idx_type k = i; + octave_idx_type k1 = i - 1; + for (octave_idx_type j = 1; j < nd; j++) + { + if (dva(j) != 1 && k % dvc (j) != k1 % dvc (j)) + { + idx (j) = octave_value ((k % dvc(j)) + 1); + is_changed = true; + } + + k = k / dvc (j); + k1 = k1 / dvc (j); + } + + if (is_changed) + { + Ac = A; + Ac = Ac.single_subsref ("(", idx); + return true; + } + else + return false; + } +} + +static void +update_index (octave_value_list& idx, const dim_vector& dv, octave_idx_type i) +{ + octave_idx_type nd = dv.length (); + + if (i == 0) + { + for (octave_idx_type j = nd - 1; j > 0; j--) + idx(j) = octave_value (static_cast<double>(1)); + idx(0) = octave_value (':'); + } + else + { + for (octave_idx_type j = 1; j < nd; j++) + { + idx (j) = octave_value (i % dv (j) + 1); + i = i / dv (j); + } + } +} + +static void +update_index (Array<int>& idx, const dim_vector& dv, octave_idx_type i) +{ + octave_idx_type nd = dv.length (); + + idx(0) = 0; + for (octave_idx_type j = 1; j < nd; j++) + { + idx (j) = i % dv (j); + i = i / dv (j); + } +} + +DEFUN_DLD (bsxfun, args, nargout, + " -*- texinfo -*-\n\ +@deftypefn {Lodable Function} {} bsxfun (@var{f}, @var{a}, @var{b})\n\ +Applies a binary function @var{f} element-wise to two matrix arguments\n\ +@var{a} and @var{b}. The function @var{f} must be capable of accepting\n\ +two column vector arguments of equal length, or one column vector\n\ +argument and a scalar.\n\ +\n\ +The dimensions of @var{a} and @var{b} must be equal or singleton. The\n\ +singleton dimensions a the matirces will be expanded to the same\n\ +dimensioanlity as the other matrix.\n\ +\n\ +@seealso{arrayfun, cellfun}\n\ +@end deftypefn") +{ + int nargin = args.length (); + octave_value_list retval; + + if (nargin != 3) + print_usage (); + else + { + octave_function *func = 0; + std::string name; + std::string fcn_name; + + if (args(0).is_function_handle () || args(0).is_inline_function ()) + func = args(0).function_value (); + else if (args(0).is_string ()) + { + name = args(0).string_value (); + fcn_name = unique_symbol_name ("__bsxfun_fcn_"); + std::string fname = "function y = "; + fname.append (fcn_name); + fname.append ("(x) y = "); + func = extract_function (args(0), "bsxfun", fcn_name, fname, + "; endfunction"); + } + else + error ("bsxfun: first argument must be a string or function handle"); + + if (! error_state) + { + const octave_value A = args (1); + dim_vector dva = A.dims (); + octave_idx_type nda = dva.length (); + const octave_value B = args (2); + dim_vector dvb = B.dims (); + octave_idx_type ndb = dvb.length (); + octave_idx_type nd = nda; + + if (nda > ndb) + dvb.resize (nda, 1); + else if (nda < ndb) + { + dva.resize (ndb, 1); + nd = ndb; + } + + for (octave_idx_type i = 0; i < nd; i++) + if (dva (i) != dvb (i) && dva (i) != 1 && dvb (i) != 1) + { + error ("bsxfun: dimensions don't match"); + break; + } + + if (!error_state) + { + // Find the size of the output + dim_vector dvc; + dvc.resize (nd); + + for (octave_idx_type i = 0; i < nd; i++) + dvc (i) = (dva (i) < 1 ? dva (i) : (dvb (i) < 1 ? dvb (i) : + (dva (i) > dvb (i) ? dva (i) : dvb (i)))); + + if (dva == dvb || dva.numel () == 1 || dvb.numel () == 1) + { + octave_value_list inputs; + inputs (0) = A; + inputs (1) = B; + retval = feval (func, inputs, 1); + } + else if (dvc.numel () < 1) + { + octave_value_list inputs; + inputs (0) = A.resize (dvc); + inputs (1) = B.resize (dvc); + retval = feval (func, inputs, 1); + } + else + { + octave_idx_type ncount = 1; + for (octave_idx_type i = 1; i < nd; i++) + ncount *= dvc (i); + +#define BSXDEF(T) \ + T result_ ## T; \ + bool have_ ## T = false; + + BSXDEF(NDArray); + BSXDEF(ComplexNDArray); + BSXDEF(boolNDArray); + BSXDEF(int8NDArray); + BSXDEF(int16NDArray); + BSXDEF(int32NDArray); + BSXDEF(int64NDArray); + BSXDEF(uint8NDArray); + BSXDEF(uint16NDArray); + BSXDEF(uint32NDArray); + BSXDEF(uint64NDArray); + + octave_value Ac ; + octave_value_list idxA; + octave_value Bc; + octave_value_list idxB; + octave_value C; + octave_value_list inputs; + Array<int> ra_idx (dvc.length(), 0); + + + for (octave_idx_type i = 0; i < ncount; i++) + { + if (maybe_update_column (Ac, A, dva, dvc, i, idxA)) + inputs (0) = Ac; + + if (maybe_update_column (Bc, B, dvb, dvc, i, idxB)) + inputs (1) = Bc; + + octave_value_list tmp = feval (func, inputs, 1); + + if (error_state) + break; + +#define BSXINIT(T, CLS, EXTRACTOR) \ + (result_type == CLS) \ + { \ + have_ ## T = true; \ + result_ ## T = \ + tmp (0). EXTRACTOR ## _array_value (); \ + result_ ## T .resize (dvc); \ + } + + if (i == 0) + { + if (! tmp(0).is_sparse_type ()) + { + std::string result_type = tmp(0).class_name (); + if (result_type == "double") + { + if (tmp(0).is_real_type ()) + { + have_NDArray = true; + result_NDArray = tmp(0).array_value (); + result_NDArray.resize (dvc); + } + else + { + have_ComplexNDArray = true; + result_ComplexNDArray = + tmp(0).complex_array_value (); + result_ComplexNDArray.resize (dvc); + } + } + else if BSXINIT(boolNDArray, "logical", bool) + else if BSXINIT(int8NDArray, "int8", int8) + else if BSXINIT(int16NDArray, "int16", int16) + else if BSXINIT(int32NDArray, "int32", int32) + else if BSXINIT(int64NDArray, "int64", int64) + else if BSXINIT(uint8NDArray, "uint8", uint8) + else if BSXINIT(uint16NDArray, "uint16", uint16) + else if BSXINIT(uint32NDArray, "uint32", uint32) + else if BSXINIT(uint64NDArray, "uint64", uint64) + else + { + C = tmp (0); + C = C.resize (dvc); + } + } + } + else + { + update_index (ra_idx, dvc, i); + + if (have_NDArray) + { + if (tmp(0).class_name () != "double") + { + have_NDArray = false; + C = result_NDArray; + C = do_cat_op (C, tmp(0), ra_idx); + } + else if (tmp(0).is_real_type ()) + result_NDArray.insert (tmp(0).array_value(), + ra_idx); + else + { + result_ComplexNDArray = + ComplexNDArray (result_NDArray); + result_ComplexNDArray.insert + (tmp(0).complex_array_value(), ra_idx); + have_NDArray = false; + have_ComplexNDArray = true; + } + } + +#define BSXLOOP(T, CLS, EXTRACTOR) \ + (have_ ## T) \ + { \ + if (tmp (0).class_name () != CLS) \ + { \ + have_ ## T = false; \ + C = result_ ## T; \ + C = do_cat_op (C, tmp (0), ra_idx); \ + } \ + else \ + result_ ## T .insert \ + (tmp(0). EXTRACTOR ## _array_value (), \ + ra_idx); \ + } + + else if BSXLOOP(ComplexNDArray, "double", complex) + else if BSXLOOP(boolNDArray, "logical", bool) + else if BSXLOOP(int8NDArray, "int8", int8) + else if BSXLOOP(int16NDArray, "int16", int16) + else if BSXLOOP(int32NDArray, "int32", int32) + else if BSXLOOP(int64NDArray, "int64", int64) + else if BSXLOOP(uint8NDArray, "uint8", uint8) + else if BSXLOOP(uint16NDArray, "uint16", uint16) + else if BSXLOOP(uint32NDArray, "uint32", uint32) + else if BSXLOOP(uint64NDArray, "uint64", uint64) + else + C = do_cat_op (C, tmp(0), ra_idx); + } + } + +#define BSXEND(T) \ + (have_ ## T) \ + retval (0) = result_ ## T; + + if BSXEND(NDArray) + else if BSXEND(ComplexNDArray) + else if BSXEND(boolNDArray) + else if BSXEND(int8NDArray) + else if BSXEND(int16NDArray) + else if BSXEND(int32NDArray) + else if BSXEND(int64NDArray) + else if BSXEND(uint8NDArray) + else if BSXEND(uint16NDArray) + else if BSXEND(uint32NDArray) + else if BSXEND(uint64NDArray) + else + retval(0) = C; + } + } + } + + if (! fcn_name.empty ()) + clear_function (fcn_name); + } + + return retval; +} + +/* + +%!shared a, b, c, f +%! a = randn (4, 4); +%! b = mean (a, 1); +%! c = mean (a, 2); +%! f = @minus; +%!error(bsxfun (f)); +%!error(bsxfun (f, a)); +%!error(bsxfun (a, b)); +%!error(bsxfun (a, b, c)); +%!error(bsxfun (f, a, b, c)); +%!error(bsxfun (f, ones(4, 0), ones(4, 4))) +%!assert(bsxfun (f, ones(4, 0), ones(4, 1)), zeros(4, 0)); +%!assert(bsxfun (f, ones(1, 4), ones(4, 1)), zeros(4, 4)); +%!assert(bsxfun (f, a, b), a - repmat(b, 4, 1)); +%!assert(bsxfun (f, a, c), a - repmat(c, 1, 4)); +%!assert(bsxfun ("minus", ones(1, 4), ones(4, 1)), zeros(4, 4)); + +%!shared a, b, c, f +%! a = randn (4, 4); +%! a(1) *= 1i; +%! b = mean (a, 1); +%! c = mean (a, 2); +%! f = @minus; +%!error(bsxfun (f)); +%!error(bsxfun (f, a)); +%!error(bsxfun (a, b)); +%!error(bsxfun (a, b, c)); +%!error(bsxfun (f, a, b, c)); +%!error(bsxfun (f, ones(4, 0), ones(4, 4))) +%!assert(bsxfun (f, ones(4, 0), ones(4, 1)), zeros(4, 0)); +%!assert(bsxfun (f, ones(1, 4), ones(4, 1)), zeros(4, 4)); +%!assert(bsxfun (f, a, b), a - repmat(b, 4, 1)); +%!assert(bsxfun (f, a, c), a - repmat(c, 1, 4)); +%!assert(bsxfun ("minus", ones(1, 4), ones(4, 1)), zeros(4, 4)); + +%!shared a, b, c, f +%! a = randn (4, 4); +%! a(end) *= 1i; +%! b = mean (a, 1); +%! c = mean (a, 2); +%! f = @minus; +%!error(bsxfun (f)); +%!error(bsxfun (f, a)); +%!error(bsxfun (a, b)); +%!error(bsxfun (a, b, c)); +%!error(bsxfun (f, a, b, c)); +%!error(bsxfun (f, ones(4, 0), ones(4, 4))) +%!assert(bsxfun (f, ones(4, 0), ones(4, 1)), zeros(4, 0)); +%!assert(bsxfun (f, ones(1, 4), ones(4, 1)), zeros(4, 4)); +%!assert(bsxfun (f, a, b), a - repmat(b, 4, 1)); +%!assert(bsxfun (f, a, c), a - repmat(c, 1, 4)); +%!assert(bsxfun ("minus", ones(1, 4), ones(4, 1)), zeros(4, 4)); + +%!shared a, b, c, f +%! a = randn (4, 4); +%! b = a (1, :); +%! c = a (:, 1); +%! f = @(x, y) x == y; +%!error(bsxfun (f)); +%!error(bsxfun (f, a)); +%!error(bsxfun (a, b)); +%!error(bsxfun (a, b, c)); +%!error(bsxfun (f, a, b, c)); +%!error(bsxfun (f, ones(4, 0), ones(4, 4))) +%!assert(bsxfun (f, ones(4, 0), ones(4, 1)), zeros(4, 0, "logical")); +%!assert(bsxfun (f, ones(1, 4), ones(4, 1)), ones(4, 4, "logical")); +%!assert(bsxfun (f, a, b), a == repmat(b, 4, 1)); +%!assert(bsxfun (f, a, c), a == repmat(c, 1, 4)); + +%!shared a, b, c, d, f +%! a = randn (4, 4, 4); +%! b = mean (a, 1); +%! c = mean (a, 2); +%! d = mean (a, 3); +%! f = @minus; +%!error(bsxfun (f, ones([4, 0, 4]), ones([4, 4, 4]))); +%!assert(bsxfun (f, ones([4, 0, 4]), ones([4, 1, 4])), zeros([4, 0, 4])); +%!assert(bsxfun (f, ones([4, 4, 0]), ones([4, 1, 1])), zeros([4, 4, 0])); +%!assert(bsxfun (f, ones([1, 4, 4]), ones([4, 1, 4])), zeros([4, 4, 4])); +%!assert(bsxfun (f, ones([4, 4, 1]), ones([4, 1, 4])), zeros([4, 4, 4])); +%!assert(bsxfun (f, ones([4, 1, 4]), ones([1, 4, 4])), zeros([4, 4, 4])); +%!assert(bsxfun (f, ones([4, 1, 4]), ones([1, 4, 1])), zeros([4, 4, 4])); +%!assert(bsxfun (f, a, b), a - repmat(b, [4, 1, 1])); +%!assert(bsxfun (f, a, c), a - repmat(c, [1, 4, 1])); +%!assert(bsxfun (f, a, d), a - repmat(d, [1, 1, 4])); +%!assert(bsxfun ("minus", ones([4, 0, 4]), ones([4, 1, 4])), zeros([4, 0, 4])); + +%% The below is a very hard case to treat +%!assert(bsxfun (f, ones([4, 1, 4, 1]), ones([1, 4, 1, 4])), zeros([4, 4, 4, 4])); + +*/