changeset 5988:d9ce802628e6

[project @ 2006-09-12 21:31:41 by dbateman]
author dbateman
date Tue, 12 Sep 2006 21:31:42 +0000
parents f1375e3f3b97
children e049385342f6
files src/ChangeLog src/DLD-FUNCTIONS/cellfun.cc
diffstat 2 files changed, 331 insertions(+), 36 deletions(-) [+]
line wrap: on
line diff
--- a/src/ChangeLog
+++ b/src/ChangeLog
@@ -1,3 +1,12 @@
+2006-08-22  David Bateman  <dbateman@free.fr>
+
+	* DLD-FUNCTIONS/cellfun.cc (Fcellfun): Allow matlab compatiable
+	'UniformOutput' and 'ErrorHandler' options. Change output when
+	called with function handler or inline function to by default have
+	'UniformOutput' set to true. Allow functions with multiple inputs
+	and outputs. Add test code. Replace some int with octave_idx_type.
+	(Fnum2cell): Replace some int with octave_idx_type. Add test code.
+	
 2006-09-11  Yozo Hida  <yozo@cs.berkeley.edu>
 
 	* DLD-FUNCTIONS/gcd.cc (Fgcd): Extend range by using std::floor
--- a/src/DLD-FUNCTIONS/cellfun.cc
+++ b/src/DLD-FUNCTIONS/cellfun.cc
@@ -26,21 +26,32 @@
 #endif
 
 #include <string>
+#include <vector>
+#include <list>
 
 #include "lo-mappers.h"
 
 #include "Cell.h"
+#include "oct-map.h"
 #include "defun-dld.h"
 #include "parse.h"
 #include "variables.h"
 #include "ov-colon.h"
+#include "unwind-prot.h"
 
-DEFUN_DLD (cellfun, args, ,
+extern octave_value_list 
+Flasterr (const octave_value_list& args_name, int nargout_name);
+
+DEFUN_DLD (cellfun, args, nargout,
   " -*- texinfo -*-\n\
 @deftypefn {Lodable Function} {} cellfun (@var{name}, @var{c})\n\
 @deftypefnx {Lodable Function} {} cellfun (\"size\", @var{c}, @var{k})\n\
 @deftypefnx {Lodable Function} {} cellfun (\"isclass\", @var{c}, @var{class})\n\
 @deftypefnx {Lodable Function} {} cellfun (@var{func}, @var{c})\n\
+@deftypefnx {Lodable Function} {} cellfun (@var{func}, @var{c}, @var{d})\n\
+@deftypefnx {Lodable Function} {[@var{a}, @var{b}]} = cellfun (@dots{})\n\
+@deftypefnx {Lodable Function} {} cellfun (@dots{}, 'ErrorHandler',@var{errfunc})\n\
+@deftypefnx {Lodable Function} {} cellfun (@dots{}, 'UniformOutput',@var{val})\n\
 \n\
 Evaluate the function named @var{name} on the elements of the cell array\n\
 @var{c}.  Elements in @var{c} are passed on to the named function\n\
@@ -67,26 +78,63 @@
 \n\
 Additionally, @code{cellfun} accepts an arbitrary function @var{func}\n\
 in the form of an inline function, function handle, or the name of a\n\
-function (in a character string).  The function should take a single\n\
-argument and return a single value, and in the case of a character string\n\
-argument, the argument must be named @var{x}.  For example\n\
+function (in a character string). In the case of a character string\n\
+argument, the function must accept a single argument named @var{x}, and\n\
+it must return a string value. The function can take one or more arguments,\n\
+with the inputs args given by @var{c}, @var{d}, etc. Equally the function\n\
+can return one or more output arguments. For example\n\
 \n\
 @example\n\
 @group\n\
-cellfun (\"tolower(x)\", @{\"Foo\", \"Bar\", \"FooBar\"@})\n\
+cellfun (@@atan2, @{1, 0@}, @{0, 1@})\n\
+@result{}ans = [1.57080   0.00000]\n\
+@end group\n\
+@end example\n\
+\n\
+Note that the default output argument is an array of the same size as the\n\
+input arguments.\n\
+\n\
+If the param 'UniformOutput' is set to true (the default), then the function\n\
+must return either a single element which will be concatenated into the\n\
+return value. If 'UniformOutput is false, the outputs are concatenated in\n\
+a cell array. For example\n\
+\n\
+@example\n\
+@group\n\
+cellfun (\"tolower(x)\", @{\"Foo\", \"Bar\", \"FooBar\"@},'UniformOutput',false)\n\
 @result{} ans = @{\"foo\", \"bar\", \"foobar\"@}\n\
 @end group\n\
 @end example\n\
+\n\
+Given the parameter 'ErrorHandler', then @var{errfunc} defines a function to\n\
+call in case @var{func} generates an error. The form of the function is\n\
+\n\
+@example\n\
+function [@dots{}] = errfunc (@var{s}, @dots{})\n\
+@end example\n\
+\n\
+where there is an additional input argument to @var{errfunc} relative to\n\
+@var{func}, given by @var{s}. This is a structure with the elements\n\
+'identifier', 'message' and 'index', giving respectively the error\n\
+identifier, the error message, and the index into the input arguments\n\
+of the element that caused the error. For example\n\
+\n\
+@example\n\
+@group\n\
+function y = foo (s, x), y = NaN; endfunction\n\
+cellfun (@@factorial, @{-1,2@},'ErrorHandler',@@foo)\n\
+@result{} ans = [NaN 2]\n\
+@end group\n\
+@end example\n\
+\n\
 @seealso{isempty, islogical, isreal, length, ndims, numel, size, isclass}\n\
 @end deftypefn")
 {
-  octave_value retval;
-
+  octave_value_list retval;
   std::string name = "function";
-
   octave_function *func = 0;
-
   int nargin = args.length ();
+  nargout = (nargout < 1 ? 1 : nargout);
 
   if (nargin < 2)
     {
@@ -119,49 +167,49 @@
   
   Cell f_args = args(1).cell_value ();
   
-  int k = f_args.numel ();
+  octave_idx_type k = f_args.numel ();
 
   if (name == "isempty")
     {      
       boolNDArray result (f_args.dims ());
-      for (int count = 0; count < k ; count++)
+      for (octave_idx_type count = 0; count < k ; count++)
         result(count) = f_args.elem(count).is_empty ();
-      retval = result;
+      retval(0) = result;
     }
   else if (name == "islogical")
     {
       boolNDArray result (f_args.dims ());
-      for (int  count= 0; count < k ; count++)
+      for (octave_idx_type  count= 0; count < k ; count++)
         result(count) = f_args.elem(count).is_bool_type ();
-      retval = result;
+      retval(0) = result;
     }
   else if (name == "isreal")
     {
       boolNDArray result (f_args.dims ());
-      for (int  count= 0; count < k ; count++)
+      for (octave_idx_type  count= 0; count < k ; count++)
         result(count) = f_args.elem(count).is_real_type ();
-      retval = result;
+      retval(0) = result;
     }
   else if (name == "length")
     {
       NDArray result (f_args.dims ());
-      for (int  count= 0; count < k ; count++)
+      for (octave_idx_type  count= 0; count < k ; count++)
         result(count) = static_cast<double> (f_args.elem(count).length ());
-      retval = result;
+      retval(0) = result;
     }
   else if (name == "ndims")
     {
       NDArray result (f_args.dims ());
-      for (int count = 0; count < k ; count++)
+      for (octave_idx_type count = 0; count < k ; count++)
         result(count) = static_cast<double> (f_args.elem(count).ndims ());
-      retval = result;
+      retval(0) = result;
     }
   else if (name == "prodofsize")
     {
       NDArray result (f_args.dims ());
-      for (int count = 0; count < k ; count++)
+      for (octave_idx_type count = 0; count < k ; count++)
         result(count) = static_cast<double> (f_args.elem(count).numel ());
-      retval = result;
+      retval(0) = result;
     }
   else if (name == "size")
     {
@@ -175,7 +223,7 @@
 	  if (!error_state)
             {
               NDArray result (f_args.dims ());
-              for (int count = 0; count < k ; count++)
+              for (octave_idx_type count = 0; count < k ; count++)
                 {
                   dim_vector dv = f_args.elem(count).dims ();
                   if (d < dv.length ())
@@ -183,7 +231,7 @@
                   else
 	            result(count) = 1.0;
                 }
-              retval = result;
+              retval(0) = result;
             }
         }
       else
@@ -195,16 +243,19 @@
         {
           std::string class_name = args(2).string_value();
           boolNDArray result (f_args.dims ());
-          for (int count = 0; count < k ; count++)
+          for (octave_idx_type count = 0; count < k ; count++)
             result(count) = (f_args.elem(count).class_name() == class_name);
           
-          retval = result;
+          retval(0) = result;
         }
       else
         error ("not enough arguments for `isclass'");
     }
   else 
     {
+      unwind_protect::begin_frame ("Fcellfun");
+      unwind_protect_int (buffer_error_messages);
+
       std::string fcn_name;
       
       if (! func)
@@ -221,29 +272,256 @@
 	error ("unknown function");
       else
 	{
-	  Cell result (f_args.dims ());
+	  octave_value_list idx;
+	  octave_value_list inputlist;
+	  octave_value_list errlist;
+	  bool UniformOutput = true;
+	  bool haveErrorHandler = false;
+	  std::string err_name;
+	  octave_function *ErrorHandler;
+	  int offset = 1;
+	  int i = 1;
+	  OCTAVE_LOCAL_BUFFER (Cell, inputs, nargin);
 
-          for (int count = 0; count < k ; count++)
+	  while (i < nargin)
 	    {
-	      octave_value_list tmp
-		= func->do_multi_index_op (1, f_args.elem (count));
-	      result(count) = tmp(0);
+	      if (args(i).is_string())
+		{
+		  std::string arg = args(i++).string_value();
+		  if (i == nargin)
+		    {
+		      error ("cellfun: parameter value is missing");
+		      goto cellfun_err;
+		    }
+
+		  std::transform (arg.begin (), arg.end (), 
+				  arg.begin (), tolower);
+
+		  if (arg == "uniformoutput")
+		    UniformOutput = args(i++).bool_value();
+		  else if (arg == "errorhandler")
+		    {
+		      if (args(i).is_function_handle () || 
+			  args(i).is_inline_function ())
+			{
+			  ErrorHandler = args(i).function_value ();
 
-	      if (error_state)
-		break;
+			  if (error_state)
+			    goto cellfun_err;
+			}
+		      else if (args(i).is_string ())
+			{
+			  err_name = unique_symbol_name ("__cellfun_fcn_");
+			  std::string fname = "function y = ";
+			  fname.append (fcn_name);
+			  fname.append ("(x) y = ");
+			  ErrorHandler = extract_function (args(i), "cellfun", 
+							   err_name, fname,
+							   "; endfunction");
+			}
+
+		      if (!ErrorHandler)
+			goto cellfun_err;
+
+		      haveErrorHandler = true;
+		      i++;
+		    }
+		  else
+		    {
+		      error ("cellfun: unrecognized parameter %s", 
+			     arg.c_str());
+		      goto cellfun_err;
+		    }
+		  offset += 2;
+		}
+	      else
+		{
+		  inputs[i-offset] = args(i).cell_value ();
+		  if (f_args.dims() != inputs[i-offset].dims())
+		    {
+		      error ("cellfun: Dimension mismatch");
+		      goto cellfun_err;
+
+		    }
+		  i++;
+		}
 	    }
 
-	  if (! error_state)
-	    retval = result;
+	  inputlist.resize(nargin-offset);
+
+	  if (haveErrorHandler)
+	    buffer_error_messages++;
+
+	  if (UniformOutput)
+	    {
+	      retval.resize(nargout);
+
+	      for (octave_idx_type count = 0; count < k ; count++)
+		{
+		  for (int i = 0; i < nargin-offset; i++)
+		    inputlist(i) = inputs[i](count);
+
+		  octave_value_list tmp = feval (func, inputlist, nargout);
+
+		  if (error_state && haveErrorHandler)
+		    {
+		      octave_value_list errtmp = 
+			Flasterr (octave_value_list (), 2);
+
+		      Octave_map msg;
+		      msg.assign ("identifier", errtmp(1));
+		      msg.assign ("message", errtmp(0));
+		      msg.assign ("index", octave_value(double (count)));
+		      octave_value_list errlist = inputlist;
+		      errlist.prepend (msg);
+		      buffer_error_messages--;
+		      error_state = 0;
+		      tmp = feval (ErrorHandler, errlist, nargout);
+		      buffer_error_messages++;
+
+		      if (error_state)
+			goto cellfun_err;
+		    }
+
+		  if (tmp.length() < nargout)
+		    {
+		      error ("cellfun: too many output arguments");
+		      goto cellfun_err;
+		    }
+
+		  if (error_state)
+		    break;
+
+		  if (count == 0)
+		    {
+		      for (int i = 0; i < nargout; i++)
+			{
+			  octave_value val;
+			  val = tmp(i);
+
+			  if (error_state)
+			    goto cellfun_err;
+
+			  val.resize(f_args.dims());
+			  retval(i) = val;
+			}
+		    }
+		  else
+		    {
+		      idx(0) = octave_value (static_cast<double>(count+1));
+		      for (int i = 0; i < nargout; i++)
+			retval(i) = 
+			  retval(i).subsasgn ("(", 
+					      std::list<octave_value_list> 
+					      (1, idx(0)), tmp(i));
+		    }
+
+		  if (error_state)
+		    break;
+		}
+	    }
+	  else
+	    {
+	      OCTAVE_LOCAL_BUFFER (Cell, results, nargout);
+	      for (int i = 0; i < nargout; i++)
+		results[i].resize(f_args.dims());
+
+	      for (octave_idx_type count = 0; count < k ; count++)
+		{
+		  for (int i = 0; i < nargin-offset; i++)
+		    inputlist(i) = inputs[i](count);
+
+		  octave_value_list tmp = feval (func, inputlist, nargout);
+
+		  if (error_state && haveErrorHandler)
+		    {
+		      octave_value_list errtmp = 
+			Flasterr (octave_value_list (), 2);
+
+		      Octave_map msg;
+		      msg.assign ("identifier", errtmp(1));
+		      msg.assign ("message", errtmp(0));
+		      msg.assign ("index", octave_value(double (count)));
+		      octave_value_list errlist = inputlist;
+		      errlist.prepend (msg);
+		      buffer_error_messages--;
+		      error_state = 0;
+		      tmp = feval (ErrorHandler, errlist, nargout);
+		      buffer_error_messages++;
+
+		      if (error_state)
+			goto cellfun_err;
+		    }
+
+		  if (tmp.length() < nargout)
+		    {
+		      error ("cellfun: too many output arguments");
+		      goto cellfun_err;
+		    }
+
+		  if (error_state)
+		    break;
+
+
+		  for (int i = 0; i < nargout; i++)
+		    results[i](count) = tmp(i);
+		}
+
+	      retval.resize(nargout);
+	      for (int i = 0; i < nargout; i++)
+		retval(i) = results[i];
+	    }
+
+	cellfun_err:
+	  if (error_state)
+	    retval = octave_value_list();
 
 	  if (! fcn_name.empty ())
 	    clear_function (fcn_name);
+
+	  if (! err_name.empty ())
+	    clear_function (err_name);
 	}
+
+      unwind_protect::run_frame ("Fcellfun");
     }
 
   return retval;
 }
 
+/*
+
+%!error(cellfun(1))
+%!error(cellfun('isclass',1))
+%!error(cellfun('size',1))
+%!error(cellfun(@sin,{[]},'BadParam',false))
+%!error(cellfun(@sin,{[]},'UniformOuput'))
+%!error(cellfun(@sin,{[]},'ErrorHandler'))
+%!assert(cellfun(@sin,{0,1}),sin([0,1]))
+%!assert(cellfun(inline('sin(x)'),{0,1}),sin([0,1]))
+%!assert(cellfun('sin',{0,1}),sin([0,1]))
+%!assert(cellfun('isempty',{1,[]}),[false,true])
+%!assert(cellfun('islogical',{false,pi}),[true,false])
+%!assert(cellfun('isreal',{1i,1}),[false,true])
+%!assert(cellfun('length',{zeros(2,2),1}),[2,1])
+%!assert(cellfun('prodofsize',{zeros(2,2),1}),[4,1])
+%!assert(cellfun('ndims',{zeros([2,2,2]),1}),[3,2])
+%!assert(cellfun('isclass',{zeros([2,2,2]),'test'},'double'),[true,false])
+%!assert(cellfun('size',{zeros([1,2,3]),1},1),[1,1])
+%!assert(cellfun('size',{zeros([1,2,3]),1},2),[2,1])
+%!assert(cellfun('size',{zeros([1,2,3]),1},3),[3,1])
+%!assert(cellfun(@atan2,{1,1},{1,2}),[atan2(1,1),atan2(1,2)])
+%!assert(cellfun(@atan2,{1,1},{1,2},'UniformOutput',false),{atan2(1,1),atan2(1,2)})
+%!error(cellfun(@factorial,{-1,3}))
+%!assert(cellfun(@factorial,{-1,3},'ErrorHandler',@(x,y) NaN),[NaN,6])
+%!test
+%! [a,b,c]=cellfun(@fileparts,{'/a/b/c.d','/e/f/g.h'},'UniformOutput',false);
+%! assert(a,{'/a/b','/e/f'})
+%! assert(b,{'c','g'})
+%! assert(c,{'.d','.h'})
+
+*/
+
 DEFUN_DLD (num2cell, args, ,
   "-*- texinfo -*-\n\
 @deftypefn {Loadable Function} {@var{c} =} num2cell (@var{m})\n\
@@ -271,7 +549,7 @@
 	  sings.resize (dsings.length());
 
 	  if (!error_state)
-	    for (int i = 0; i < dsings.length(); i++)
+	    for (octave_idx_type i = 0; i < dsings.length(); i++)
 	      if (dsings(i) > dv.length() || dsings(i) < 1 ||
 		  D_NINT(dsings(i)) != dsings(i))
 		{
@@ -332,6 +610,14 @@
   return retval;
 }
 
+/*
+
+%!assert(num2cell([1,2;3,4]),{1,2;3,4})
+%!assert(num2cell([1,2;3,4],1),{[1;3],[2;4]})
+%!assert(num2cell([1,2;3,4],2),{[1,2];[3,4]})
+
+*/
+
 DEFUN_DLD (mat2cell, args, ,
   "-*- texinfo -*-\n\
 @deftypefn {Loadable Function} {@var{b} =} mat2cell (@var{a}, @var{m}, @var{n})\n\