diff src/data.cc @ 10436:00219bdd2d17

implement built-in rem and mod
author Jaroslav Hajek <highegg@gmail.com>
date Tue, 23 Mar 2010 13:01:34 +0100
parents 6a271334750c
children 8615b55b5caf
line wrap: on
line diff
--- a/src/data.cc
+++ b/src/data.cc
@@ -473,13 +473,20 @@
 %! assert (e(1:2,:), [0,1; 2,3]);
 */
 
-DEFUN (fmod, args, ,
+DEFUN (rem, args, ,
   "-*- texinfo -*-\n\
-@deftypefn {Mapping Function} {} fmod (@var{x}, @var{y})\n\
-Compute the floating point remainder of dividing @var{x} by @var{y}\n\
-using the C library function @code{fmod}.  The result has the same\n\
-sign as @var{x}.  If @var{y} is zero, the result is implementation-dependent.\n\
-@seealso{mod, rem}\n\
+@deftypefn {Mapping Function} {} rem (@var{x}, @var{y})\n\
+@deftypefnx {Mapping Function} {} fmod (@var{x}, @var{y})\n\
+Return the remainder of the division @code{@var{x} / @var{y}}, computed \n\
+using the expression\n\
+\n\
+@example\n\
+x - y .* fix (x ./ y)\n\
+@end example\n\
+\n\
+An error message is printed if the dimensions of the arguments do not\n\
+agree, or if either of the arguments is complex.\n\
+@seealso{mod}\n\
 @end deftypefn")
 {
   octave_value retval;
@@ -489,11 +496,49 @@
   if (nargin == 2)
     {
       if (! args(0).is_numeric_type ())
-        gripe_wrong_type_arg ("fmod", args(0));
+        gripe_wrong_type_arg ("rem", args(0));
       else if (! args(1).is_numeric_type ())
-        gripe_wrong_type_arg ("fmod", args(1));
+        gripe_wrong_type_arg ("rem", args(1));
       else if (args(0).is_complex_type () || args(1).is_complex_type ())
-        error ("fmod: not defined for complex numbers");
+        error ("rem: not defined for complex numbers");
+      else if (args(0).is_integer_type () || args(1).is_integer_type ())
+        {
+          builtin_type_t btyp0 = args(0).builtin_type ();
+          builtin_type_t btyp1 = args(1).builtin_type ();
+          if (btyp0 == btyp_double || btyp0 == btyp_float)
+            btyp0 = btyp1;
+          if (btyp1 == btyp_double || btyp1 == btyp_float)
+            btyp1 = btyp0;
+
+          if (btyp0 == btyp1)
+            {
+              switch (btyp0)
+                {
+#define MAKE_INT_BRANCH(X) \
+                case btyp_ ## X: \
+                    { \
+                    X##NDArray a0 = args(0).X##_array_value (); \
+                    X##NDArray a1 = args(1).X##_array_value (); \
+                    retval = binmap<octave_##X> (a0, a1, rem, "rem"); \
+                    } \
+                  break
+                MAKE_INT_BRANCH (int8);
+                MAKE_INT_BRANCH (int16);
+                MAKE_INT_BRANCH (int32);
+                MAKE_INT_BRANCH (int64);
+                MAKE_INT_BRANCH (uint8);
+                MAKE_INT_BRANCH (uint16);
+                MAKE_INT_BRANCH (uint32);
+                MAKE_INT_BRANCH (uint64);
+#undef MAKE_INT_BRANCH
+                default:
+                  panic_impossible ();
+                }
+            }
+          else
+            error ("rem: cannot combine %s and %d", 
+                   args(0).class_name ().c_str (), args(1).class_name ().c_str ());
+        }
       else if (args(0).is_single_type () || args(1).is_single_type ())
         {
           if (args(0).is_scalar_type () && args(1).is_scalar_type ())
@@ -502,7 +547,7 @@
             {
               FloatNDArray a0 = args(0).float_array_value ();
               FloatNDArray a1 = args(1).float_array_value ();
-              retval = binmap<float> (a0, a1, ::fmodf, "fmod");
+              retval = binmap<float> (a0, a1, fmodf, "rem");
             }
         }
       else
@@ -516,13 +561,13 @@
             {
               SparseMatrix m0 = args(0).sparse_matrix_value ();
               SparseMatrix m1 = args(1).sparse_matrix_value ();
-              retval = binmap<double> (m0, m1, ::fmod, "fmod");
+              retval = binmap<double> (m0, m1, fmod, "rem");
             }
           else
             {
               NDArray a0 = args(0).array_value ();
               NDArray a1 = args(1).array_value ();
-              retval = binmap<double> (a0, a1, ::fmod, "fmod");
+              retval = binmap<double> (a0, a1, fmod, "rem");
             }
         }
     }
@@ -533,6 +578,21 @@
 }
 
 /*
+
+%!assert(rem ([1, 2, 3; -1, -2, -3], 2), [1, 0, 1; -1, 0, -1]);
+%!assert(rem ([1, 2, 3; -1, -2, -3], 2 * ones (2, 3)),[1, 0, 1; -1, 0, -1]);
+%!error rem ();
+%!error rem (1, 2, 3);
+%!error rem ([1, 2], [3, 4, 5]);
+%!error rem (i, 1);
+%!assert(rem (uint8([1, 2, 3; -1, -2, -3]), uint8 (2)), uint8([1, 0, 1; -1, 0, -1]));
+%!assert(uint8(rem ([1, 2, 3; -1, -2, -3], 2 * ones (2, 3))),uint8([1, 0, 1; -1, 0, -1]));
+%!error rem (uint(8),int8(5));
+%!error rem (uint8([1, 2]), uint8([3, 4, 5]));
+
+*/
+
+/*
 %!assert (size (fmod (zeros (0, 2), zeros (0, 2))), [0, 2])
 %!assert (size (fmod (rand (2, 3, 4), zeros (2, 3, 4))), [2, 3, 4])
 %!assert (size (fmod (rand (2, 3, 4), 1)), [2, 3, 4])
@@ -540,6 +600,154 @@
 %!assert (size (fmod (1, 2)), [1, 1])
 */
 
+DEFALIAS (fmod, rem)
+
+DEFUN (mod, args, ,
+  "-*- texinfo -*-\n\
+@deftypefn {Mapping Function} {} mod (@var{x}, @var{y})\n\
+Compute the modulo of @var{x} and @var{y}.  Conceptually this is given by\n\
+\n\
+@example\n\
+x - y .* floor (x ./ y)\n\
+@end example\n\
+\n\
+and is written such that the correct modulus is returned for\n\
+integer types.  This function handles negative values correctly.  That\n\
+is, @code{mod (-1, 3)} is 2, not -1, as @code{rem (-1, 3)} returns.\n\
+@code{mod (@var{x}, 0)} returns @var{x}.\n\
+\n\
+An error results if the dimensions of the arguments do not agree, or if\n\
+either of the arguments is complex.\n\
+@seealso{rem}\n\
+@end deftypefn")
+{
+  octave_value retval;
+
+  int nargin = args.length ();
+
+  if (nargin == 2)
+    {
+      if (! args(0).is_numeric_type ())
+        gripe_wrong_type_arg ("mod", args(0));
+      else if (! args(1).is_numeric_type ())
+        gripe_wrong_type_arg ("mod", args(1));
+      else if (args(0).is_complex_type () || args(1).is_complex_type ())
+        error ("mod: not defined for complex numbers");
+      else if (args(0).is_integer_type () || args(1).is_integer_type ())
+        {
+          builtin_type_t btyp0 = args(0).builtin_type ();
+          builtin_type_t btyp1 = args(1).builtin_type ();
+          if (btyp0 == btyp_double || btyp0 == btyp_float)
+            btyp0 = btyp1;
+          if (btyp1 == btyp_double || btyp1 == btyp_float)
+            btyp1 = btyp0;
+
+          if (btyp0 == btyp1)
+            {
+              switch (btyp0)
+                {
+#define MAKE_INT_BRANCH(X) \
+                case btyp_ ## X: \
+                    { \
+                    X##NDArray a0 = args(0).X##_array_value (); \
+                    X##NDArray a1 = args(1).X##_array_value (); \
+                    retval = binmap<octave_##X> (a0, a1, mod, "mod"); \
+                    } \
+                  break
+                MAKE_INT_BRANCH (int8);
+                MAKE_INT_BRANCH (int16);
+                MAKE_INT_BRANCH (int32);
+                MAKE_INT_BRANCH (int64);
+                MAKE_INT_BRANCH (uint8);
+                MAKE_INT_BRANCH (uint16);
+                MAKE_INT_BRANCH (uint32);
+                MAKE_INT_BRANCH (uint64);
+#undef MAKE_INT_BRANCH
+                default:
+                  panic_impossible ();
+                }
+            }
+          else
+            error ("mod: cannot combine %s and %d", 
+                   args(0).class_name ().c_str (), args(1).class_name ().c_str ());
+        }
+      else if (args(0).is_single_type () || args(1).is_single_type ())
+        {
+          if (args(0).is_scalar_type () && args(1).is_scalar_type ())
+            retval = mod (args(0).float_value (), args(1).float_value ());
+          else
+            {
+              FloatNDArray a0 = args(0).float_array_value ();
+              FloatNDArray a1 = args(1).float_array_value ();
+              retval = binmap<float> (a0, a1, mod, "mod");
+            }
+        }
+      else
+        {
+          bool a0_scalar = args(0).is_scalar_type ();
+          bool a1_scalar = args(1).is_scalar_type ();
+          if (a0_scalar && a1_scalar)
+            retval = mod (args(0).scalar_value (), args(1).scalar_value ());
+          else if ((a0_scalar || args(0).is_sparse_type ()) 
+                   && (a1_scalar || args(1).is_sparse_type ()))
+            {
+              SparseMatrix m0 = args(0).sparse_matrix_value ();
+              SparseMatrix m1 = args(1).sparse_matrix_value ();
+              retval = binmap<double> (m0, m1, mod, "mod");
+            }
+          else
+            {
+              NDArray a0 = args(0).array_value ();
+              NDArray a1 = args(1).array_value ();
+              retval = binmap<double> (a0, a1, mod, "mod");
+            }
+        }
+    }
+  else
+    print_usage ();
+
+  return retval;
+}
+
+/*
+## empty input test
+%!assert (isempty(mod([], [])));
+
+## x mod y, y != 0 tests
+%!assert (mod(5, 3), 2);
+%!assert (mod(-5, 3), 1);
+%!assert (mod(0, 3), 0);
+%!assert (mod([-5, 5, 0], [3, 3, 3]), [1, 2, 0]);
+%!assert (mod([-5; 5; 0], [3; 3; 3]), [1; 2; 0]);
+%!assert (mod([-5, 5; 0, 3], [3, 3 ; 3, 1]), [1, 2 ; 0, 0]);
+
+## x mod 0 tests
+%!assert (mod(5, 0), 5);
+%!assert (mod(-5, 0), -5);
+%!assert (mod([-5, 5, 0], [3, 0, 3]), [1, 5, 0]);
+%!assert (mod([-5; 5; 0], [3; 0; 3]), [1; 5; 0]);
+%!assert (mod([-5, 5; 0, 3], [3, 0 ; 3, 1]), [1, 5 ; 0, 0]);
+%!assert (mod([-5, 5; 0, 3], [0, 0 ; 0, 0]), [-5, 5; 0, 3]);
+
+## mixed scalar/matrix tests
+%!assert (mod([-5, 5; 0, 3], 0), [-5, 5; 0, 3]); 
+%!assert (mod([-5, 5; 0, 3], 3), [1, 2; 0, 0]);
+%!assert (mod(-5,[0,0; 0,0]), [-5, -5; -5, -5]);
+%!assert (mod(-5,[3,0; 3,1]), [1, -5; 1, 0]);
+%!assert (mod(-5,[3,2; 3,1]), [1, 1; 1, 0]);
+
+## integer types
+%!assert (mod(uint8(5),uint8(4)),uint8(1))
+%!assert (mod(uint8([1:5]),uint8(4)),uint8([1,2,3,0,1]))
+%!assert (mod(uint8([1:5]),uint8(0)),uint8([1:5]))
+%!error (mod(uint8(5),int8(4)))
+
+## mixed integer/real types
+%!assert (mod(uint8(5),4),uint8(1))
+%!assert (mod(5,uint8(4)),uint8(1))
+%!assert (mod(uint8([1:5]),4),uint8([1,2,3,0,1]))
+*/
+
 // FIXME Need to convert the reduction functions of this file for single precision
 
 #define NATIVE_REDUCTION_1(FCN, TYPE, DIM) \