changeset 11021:0ee74d581c00

optimize & rewrite gcd
author Jaroslav Hajek <highegg@gmail.com>
date Sat, 25 Sep 2010 14:18:29 +0200
parents a2773763e3ff
children a5bee81bb69f
files src/ChangeLog src/DLD-FUNCTIONS/gcd.cc
diffstat 2 files changed, 288 insertions(+), 420 deletions(-) [+]
line wrap: on
line diff
--- a/src/ChangeLog
+++ b/src/ChangeLog
@@ -1,3 +1,8 @@
+2010-09-25  Jaroslav Hajek  <highegg@gmail.com>
+
+	* DLD-FUNCTIONS/gcd.cc (Fgcd): Rewrite. Use more efficient code.
+	Specialize for simple/extended Euclid. Handle integer arguments.
+
 2010-09-24  Rik <octave@nomad.inbox5.com>
 
 	* DLD-FUNCTIONS/regexp.cc (regexp, regexprep): Update docstring to 
--- a/src/DLD-FUNCTIONS/gcd.cc
+++ b/src/DLD-FUNCTIONS/gcd.cc
@@ -1,6 +1,7 @@
 /*
 
 Copyright (C) 2004, 2005, 2006, 2007, 2008, 2009 David Bateman
+Copyirght (C) 2010 Jaroslav Hajek
 
 This file is part of Octave.
 
@@ -29,44 +30,287 @@
 #include "fNDArray.h"
 #include "fCNDArray.h"
 #include "lo-mappers.h"
-#include "oct-locbuf.h"
+#include "oct-binmap.h"
 
 #include "defun-dld.h"
 #include "error.h"
 #include "oct-obj.h"
 
-// FIXME -- should probably handle Inf, NaN.
+static double 
+simple_gcd (double a, double b)
+{
+  if (! xisinteger (a) || ! xisinteger (b))
+    (*current_liboctave_error_handler)
+      ("gcd: all values must be integers");
+
+  double aa = fabs (a), bb = fabs (b);
+  while (bb != 0)
+    {
+      double tt = fmod (aa, bb);
+      aa = bb; bb = tt;
+    }
+
+  return aa;
+}
+
+template <class T>
+static octave_int<T>
+simple_gcd (const octave_int<T>& a, const octave_int<T>& b)
+{
+  T aa = a.abs ().value (), bb = b.abs ().value ();
+  while (bb != 0)
+    {
+      T tt = aa % bb;
+      aa = bb; bb = tt;
+    }
+
+  return aa;
+}
+
+static double 
+extended_gcd (double a, double b, double& x, double& y)
+{
+  if (! xisinteger (a) || ! xisinteger (b))
+    (*current_liboctave_error_handler)
+      ("gcd: all values must be integers");
+
+  double aa = fabs (a), bb = fabs (b);
+  double xx = 0, lx = 1, yy = 0, ly = 1;
+  while (bb != 0)
+    {
+      double qq = floor (aa / bb);
+      double tt = fmod (aa, bb);
+      aa = bb; bb = tt;
+
+      double tx = lx - qq*xx;
+      lx = xx; xx = tx;
+
+      double ty = ly - qq*yy;
+      ly = yy; yy = ty;
+    }
 
-static inline bool
-is_integer_value (double x)
+  x = a >= 0 ? lx : -lx; 
+  y = b >= 0 ? ly : -ly;
+
+  return aa;
+}
+
+template <class T>
+static octave_int<T>
+extended_gcd (const octave_int<T>& a, const octave_int<T>& b,
+              octave_int<T>& x, octave_int<T>& y)
 {
-  return x == std::floor (x);
+  T aa = a.abs ().value (), bb = b.abs ().value ();
+  T xx = 0, lx = 1, yy = 0, ly = 1;
+  while (bb != 0)
+    {
+      T qq = aa / bb;
+      T tt = aa % bb;
+      aa = bb; bb = tt;
+
+      T tx = lx - qq*xx;
+      lx = xx; xx = tx;
+
+      T ty = ly - qq*yy;
+      ly = yy; yy = ty;
+    }
+
+  x = octave_int<T> (lx) * a.signum ();
+  y = octave_int<T> (ly) * b.signum ();
+
+  return aa;
+}
+
+template<class NDA>
+static octave_value
+do_simple_gcd (const octave_value& a, const octave_value& b)
+{
+  typedef typename NDA::element_type T;
+  octave_value retval;
+
+  if (a.is_scalar_type () && b.is_scalar_type ())
+    {
+      // Optimize scalar case.
+      T aa = octave_value_extract<T> (a);
+      T bb = octave_value_extract<T> (b);
+      retval = simple_gcd (aa, bb);
+    }
+  else
+    {
+      NDA aa = octave_value_extract<NDA> (a);
+      NDA bb = octave_value_extract<NDA> (b);
+      retval = binmap<T> (aa, bb, simple_gcd, "gcd");
+    }
+
+  return retval;
 }
 
-static inline bool
-is_integer_value (float x)
+// Dispatcher
+static octave_value
+do_simple_gcd (const octave_value& a, const octave_value& b)
+{
+  octave_value retval;
+  builtin_type_t btyp = btyp_mixed_numeric (a.builtin_type (),
+                                            b.builtin_type ());
+  switch (btyp)
+    {
+    case btyp_double:
+      if (a.is_sparse_type () && b.is_sparse_type ())
+        {
+          retval = do_simple_gcd<SparseMatrix> (a, b);
+          break;
+        }
+      else { } // fall through!
+    case btyp_float:
+      retval = do_simple_gcd<NDArray> (a, b);
+      break;
+#define MAKE_INT_BRANCH(X) \
+    case btyp_ ## X: \
+      retval = do_simple_gcd<X ## NDArray> (a, b); \
+      break
+    MAKE_INT_BRANCH (int8);
+    MAKE_INT_BRANCH (int16);
+    MAKE_INT_BRANCH (int32);
+    MAKE_INT_BRANCH (int64);
+    MAKE_INT_BRANCH (uint8);
+    MAKE_INT_BRANCH (uint16);
+    MAKE_INT_BRANCH (uint32);
+    MAKE_INT_BRANCH (uint64);
+#undef MAKE_INT_BRANCH
+    case btyp_complex:
+    case btyp_float_complex:
+      error ("gcd: complex numbers not allowed");
+    default:
+      error ("gcd: invalid class combination for gcd: %s and %s\n",
+             a.class_name ().c_str (), b.class_name ().c_str ());
+    }
+
+  if (btyp == btyp_float)
+    retval = retval.float_array_value ();
+
+  return retval;
+}
+
+template<class NDA>
+static octave_value
+do_extended_gcd (const octave_value& a, const octave_value& b,
+                 octave_value& x, octave_value& y)
 {
-  return x == std::floor (x);
+  typedef typename NDA::element_type T;
+  octave_value retval;
+
+  if (a.is_scalar_type () && b.is_scalar_type ())
+    {
+      // Optimize scalar case.
+      T aa = octave_value_extract<T> (a);
+      T bb = octave_value_extract<T> (b);
+      T xx, yy;
+      retval = extended_gcd (aa, bb, xx, yy);
+      x = xx;
+      y = yy;
+    }
+  else
+    {
+      NDA aa = octave_value_extract<NDA> (a);
+      NDA bb = octave_value_extract<NDA> (b);
+
+      dim_vector dv = aa.dims ();
+      if (aa.numel () == 1)
+        dv = bb.dims ();
+      else if (bb.numel () != 1 && bb.dims () != dv)
+        gripe_nonconformant ("gcd", a.dims (), b.dims ());
+
+      NDA gg (dv), xx (dv), yy (dv);
+
+      const T *aptr = aa.fortran_vec (), *bptr = bb.fortran_vec ();
+
+      bool inca = aa.numel () != 1, incb = bb.numel () != 1;
+
+      T *gptr = gg.fortran_vec ();
+      T *xptr = xx.fortran_vec (), *yptr = yy.fortran_vec ();
+
+      octave_idx_type n = gg.numel ();
+      for (octave_idx_type i = 0; i < n; i++)
+        {
+          octave_quit ();
+
+          *gptr++ = extended_gcd (*aptr, *bptr, *xptr++, *yptr++);
+          aptr += inca;
+          bptr += incb;
+        }
+
+      x = xx;
+      y = yy;
+      retval = gg;
+    }
+
+  return retval;
 }
 
+// Dispatcher
+static octave_value
+do_extended_gcd (const octave_value& a, const octave_value& b,
+                 octave_value& x, octave_value& y)
+{
+  octave_value retval;
+  builtin_type_t btyp = btyp_mixed_numeric (a.builtin_type (),
+                                            b.builtin_type ());
+  switch (btyp)
+    {
+    case btyp_double:
+    case btyp_float:
+      retval = do_extended_gcd<NDArray> (a, b, x, y);
+      break;
+#define MAKE_INT_BRANCH(X) \
+    case btyp_ ## X: \
+      retval = do_extended_gcd<X ## NDArray> (a, b, x, y); \
+      break
+    MAKE_INT_BRANCH (int8);
+    MAKE_INT_BRANCH (int16);
+    MAKE_INT_BRANCH (int32);
+    MAKE_INT_BRANCH (int64);
+    MAKE_INT_BRANCH (uint8);
+    MAKE_INT_BRANCH (uint16);
+    MAKE_INT_BRANCH (uint32);
+    MAKE_INT_BRANCH (uint64);
+#undef MAKE_INT_BRANCH
+    case btyp_complex:
+    case btyp_float_complex:
+      error ("gcd: complex numbers not allowed");
+    default:
+      error ("gcd: invalid class combination for gcd: %s and %s\n",
+             a.class_name ().c_str (), b.class_name ().c_str ());
+    }
+
+  // For consistency.
+  if (! error_state && a.is_sparse_type () && b.is_sparse_type ())
+    {
+      retval = retval.sparse_matrix_value ();
+      x = x.sparse_matrix_value ();
+      y = y.sparse_matrix_value ();
+    }
+
+  if (btyp == btyp_float)
+    {
+      retval = retval.float_array_value ();
+      x = x.float_array_value ();
+      y = y.float_array_value ();
+    }
+
+  return retval;
+}
+
+
 DEFUN_DLD (gcd, args, nargout,
   "-*- texinfo -*-\n\
-@deftypefn  {Loadable Function} {@var{g} =} gcd (@var{a})\n\
-@deftypefnx {Loadable Function} {@var{g} =} gcd (@var{a1}, @var{a2}, @dots{})\n\
+@deftypefn {Loadable Function} {@var{g} =} gcd (@var{a1}, @var{a2}, @dots{})\n\
 @deftypefnx {Loadable Function} {[@var{g}, @var{v1}, @dots{}] =} gcd (@var{a1}, @var{a2}, @dots{})\n\
 \n\
-Compute the greatest common divisor of the elements of @var{a}.  If more\n\
+Compute the greatest common divisor of @var{a1}, @var{a2}, @dots{}. If more\n\
 than one argument is given all arguments must be the same size or scalar.\n\
   In this case the greatest common divisor is calculated for each element\n\
 individually.  All elements must be integers.  For example,\n\
 \n\
-@example\n\
-@group\n\
-gcd ([15, 20])\n\
-    @result{}  5\n\
-@end group\n\
-@end example\n\
-\n\
 @noindent\n\
 and\n\
 \n\
@@ -91,432 +335,51 @@
 \n\
 @end ifnottex\n\
 \n\
-For backward compatibility with previous versions of this function, when\n\
-all arguments are scalar, a single return argument @var{v1} containing\n\
-all of the values of @var{v1}, @dots{} is acceptable.\n\
 @seealso{lcm, factor}\n\
 @end deftypefn")
 {
   octave_value_list retval;
-
   int nargin = args.length ();
 
-  if (nargin == 0)
+  if (nargin > 1)
     {
-      print_usage ();
-      return retval;
-    }
-
-  bool all_args_scalar = true;
-  bool any_single = false;
-
-  dim_vector dv (1, 1);
-
-  for (int i = 0; i < nargin; i++)
-    {
-      if (! args(i).is_scalar_type ())
+      if (nargout > 1)
         {
-          if (! args(i).is_matrix_type ())
-            {
-              error ("gcd: invalid argument type");
-              return retval;
-            }
-
-          if (all_args_scalar)
-            {
-              all_args_scalar = false;
-              dv = args(i).dims ();
-            }
-          else
-            {
-              if (dv != args(i).dims ())
-                {
-                  error ("gcd: all arguments must be the same size or scalar");
-                  return retval;
-                }
-            }
-        }
-      if (!any_single && args(i).is_single_type ())
-        any_single = true;
-    }
-
-  if (any_single)
-    {
-      if (nargin == 1)
-        {
-          FloatNDArray gg = args(0).float_array_value ();
-
-          int nel = dv.numel ();
-
-          FloatNDArray v (dv);
-
-          FloatRowVector x (3);
-          FloatRowVector y (3);
-
-          float g = std::abs (gg(0));
-
-          if (! is_integer_value (g))
-            {
-              error ("gcd: all arguments must be integer");
-              return retval;
-            }
-
-          v(0) = signum (gg(0));
-      
-          for (int k = 1; k < nel; k++)
+          retval.resize (nargin + 1);
+          retval(0) = do_extended_gcd (args(0), args(1), retval(1), retval(2));
+          for (int j = 2; j < nargin; j++)
             {
-              x(0) = g;
-              x(1) = 1;
-              x(2) = 0;
-
-              y(0) = std::abs (gg(k));
-              y(1) = 0;
-              y(2) = 1;
-
-              if (! is_integer_value (y(0)))
-                {
-                  error ("gcd: all arguments must be integer");
-                  return retval;
-                }
-
-              while (y(0) > 0)
-                {
-                  FloatRowVector r = x - y * std::floor (x(0) / y(0));
-                  x = y;
-                  y = r;
-                }
-
-              g = x(0);
-
-              for (int i = 0; i < k; i++) 
-                v(i) *= x(1);
-
-              v(k) = x(2) * signum (gg(k));
-            }
-
-          retval (1) = v;
-          retval (0) = g;
-        }
-      else if (all_args_scalar && nargout < 3)
-        {
-          float g = args(0).float_value ();
-
-          if (error_state || ! is_integer_value (g))
-            {
-              error ("gcd: all arguments must be integer");
-              return retval;
+              octave_value x;
+              retval(0) = do_extended_gcd (retval(0), args(1), 
+                                           x, retval(j+1));
+              for (int i = 0; i < j; i++)
+                retval(i).assign (octave_value::op_el_mul_eq, x);
+              if (error_state)
+                break;
             }
-
-          FloatRowVector v (nargin, 0);
-          FloatRowVector x (3);
-          FloatRowVector y (3);
-
-          v(0) = signum (g);
-
-          g = std::abs(g);
-      
-          for (int k = 1; k < nargin; k++)
-            {
-              x(0) = g;
-              x(1) = 1;
-              x(2) = 0;
-
-              y(0) = args(k).float_value ();
-              y(1) = 0;
-              y(2) = 1;
-
-              float sgn = signum (y(0));
-
-              y(0) = std::abs (y(0));
-
-              if (error_state || ! is_integer_value (g))
-                {
-                  error ("gcd: all arguments must be integer");
-                  return retval;
-                }
-
-              while (y(0) > 0)
-                {
-                  FloatRowVector r = x - y * std::floor (x(0) / y(0));
-                  x = y;
-                  y = r;
-                }
-
-              g = x(0);
-
-              for (int i = 0; i < k; i++) 
-                v(i) *= x(1);
-
-              v(k) = x(2) * sgn;
-            }
-
-          retval (1) = v;
-          retval (0) = g;
         }
       else
         {
-          // FIXME -- we need to handle a possible mixture of scalar and
-          // array values here.
-
-          FloatNDArray g = args(0).float_array_value ();
-
-          OCTAVE_LOCAL_BUFFER (FloatNDArray, v, nargin);
-
-          int nel = dv.numel ();
-
-          v[0].resize(dv);
-
-          for (int i = 0; i < nel; i++)
-            {
-              v[0](i) = signum (g(i));
-              g(i) = std::abs (g(i));
-
-              if (! is_integer_value (g(i)))
-                {
-                  error ("gcd: all arguments must be integer");
-                  return retval;
-                }
-            }
-
-          FloatRowVector x (3);
-          FloatRowVector y (3);
-
-          for (int k = 1; k < nargin; k++)
+          retval(0) = do_simple_gcd (args(0), args(1));
+          for (int j = 2; j < nargin; j++)
             {
-              FloatNDArray gnew = args(k).float_array_value ();
-
-              v[k].resize(dv);
-
-              for (int n = 0; n < nel; n++)
-                {
-                  x(0) = g(n);
-                  x(1) = 1;
-                  x(2) = 0;
-
-                  y(0) = std::abs (gnew(n));
-                  y(1) = 0;
-                  y(2) = 1; 
-
-                  if (! is_integer_value (y(0)))
-                    {
-                      error ("gcd: all arguments must be integer");
-                      return retval;
-                    }
-
-                  while (y(0) > 0)
-                    {
-                      FloatRowVector r = x - y * std::floor (x(0) / y(0));
-                      x = y;
-                      y = r;
-                    }
-
-                  g(n) = x(0);
-
-                  for (int i = 0; i < k; i++) 
-                    v[i](n) *= x(1);
-
-                  v[k](n) = x(2) * signum (gnew(n));
-                }
+              retval(0) = do_simple_gcd (retval(0), args(j));
+              if (error_state)
+                break;
             }
-
-          for (int k = 0; k < nargin; k++)
-            retval(1+k) = v[k];
-
-          retval (0) = g;
         }
     }
-  else if (nargin == 1)
-    {
-      NDArray gg = args(0).array_value ();
-
-      int nel = dv.numel ();
-
-      NDArray v (dv);
-
-      RowVector x (3);
-      RowVector y (3);
-
-      double g = std::abs (gg(0));
-
-      if (! is_integer_value (g))
-        {
-          error ("gcd: all arguments must be integer");
-          return retval;
-        }
-
-      v(0) = signum (gg(0));
-      
-      for (int k = 1; k < nel; k++)
-        {
-          x(0) = g;
-          x(1) = 1;
-          x(2) = 0;
-
-          y(0) = std::abs (gg(k));
-          y(1) = 0;
-          y(2) = 1;
-
-          if (! is_integer_value (y(0)))
-            {
-              error ("gcd: all arguments must be integer");
-              return retval;
-            }
-
-          while (y(0) > 0)
-            {
-              RowVector r = x - y * std::floor (x(0) / y(0));
-              x = y;
-              y = r;
-            }
-
-          g = x(0);
-
-          for (int i = 0; i < k; i++) 
-            v(i) *= x(1);
-
-          v(k) = x(2) * signum (gg(k));
-        }
-
-      retval (1) = v;
-      retval (0) = g;
-    }
-  else if (all_args_scalar && nargout < 3)
-    {
-      double g = args(0).double_value ();
-
-      if (error_state || ! is_integer_value (g))
-        {
-          error ("gcd: all arguments must be integer");
-          return retval;
-        }
-
-      RowVector v (nargin, 0);
-      RowVector x (3);
-      RowVector y (3);
-
-      v(0) = signum (g);
-
-      g = std::abs(g);
-      
-      for (int k = 1; k < nargin; k++)
-        {
-          x(0) = g;
-          x(1) = 1;
-          x(2) = 0;
-
-          y(0) = args(k).double_value ();
-          y(1) = 0;
-          y(2) = 1;
-
-          double sgn = signum (y(0));
-
-          y(0) = std::abs (y(0));
-
-          if (error_state || ! is_integer_value (g))
-            {
-              error ("gcd: all arguments must be integer");
-              return retval;
-            }
-
-          while (y(0) > 0)
-            {
-              RowVector r = x - y * std::floor (x(0) / y(0));
-              x = y;
-              y = r;
-            }
-
-          g = x(0);
-
-          for (int i = 0; i < k; i++) 
-            v(i) *= x(1);
-
-          v(k) = x(2) * sgn;
-        }
-
-      retval (1) = v;
-      retval (0) = g;
-    }
   else
-    {
-      // FIXME -- we need to handle a possible mixture of scalar and
-      // array values here.
-
-      NDArray g = args(0).array_value ();
-
-      OCTAVE_LOCAL_BUFFER (NDArray, v, nargin);
-
-      int nel = dv.numel ();
-
-      v[0].resize(dv);
-
-      for (int i = 0; i < nel; i++)
-        {
-          v[0](i) = signum (g(i));
-          g(i) = std::abs (g(i));
-
-          if (! is_integer_value (g(i)))
-            {
-              error ("gcd: all arguments must be integer");
-              return retval;
-            }
-        }
-
-      RowVector x (3);
-      RowVector y (3);
-
-      for (int k = 1; k < nargin; k++)
-        {
-          NDArray gnew = args(k).array_value ();
-
-          v[k].resize(dv);
-
-          for (int n = 0; n < nel; n++)
-            {
-              x(0) = g(n);
-              x(1) = 1;
-              x(2) = 0;
-
-              y(0) = std::abs (gnew(n));
-              y(1) = 0;
-              y(2) = 1; 
-
-              if (! is_integer_value (y(0)))
-                {
-                  error ("gcd: all arguments must be integer");
-                  return retval;
-                }
-
-              while (y(0) > 0)
-                {
-                  RowVector r = x - y * std::floor (x(0) / y(0));
-                  x = y;
-                  y = r;
-                }
-
-              g(n) = x(0);
-
-              for (int i = 0; i < k; i++) 
-                v[i](n) *= x(1);
-
-              v[k](n) = x(2) * signum (gnew(n));
-            }
-        }
-
-      for (int k = 0; k < nargin; k++)
-        retval(1+k) = v[k];
-
-      retval (0) = g;
-    }
+    print_usage ();
 
   return retval;
 }
 
 /*
 
-%!assert(gcd (200, 300, 50, 35), gcd ([200, 300, 50, 35]))
-%!assert(gcd ([200, 300, 50, 35]), 5);
-%!assert(gcd (single(200), single(300), single(50), single(35)), gcd (single([200, 300, 50, 35])))
-%!assert(gcd (single([200, 300, 50, 35])), single(5));
+%!assert(gcd (200, 300, 50, 35), 5)
+%!assert(gcd (int16(200), int16(300), int16(50), int16(35)), int16(5))
+%!assert(gcd (uint64(200), uint64(300), uint64(50), uint64(35)), uint64(5))
 
 %!error <Invalid call to gcd.*> gcd ();