changeset 893:f34897bc944f

connectivity: move the validate() static methods into the constructor. * conndef.cc (Fiptconncheck): this function performs a bunch of logic that is useful for a constructor from octave_value. But if we move it there, then all the validate static methods are no longer needed, and we can just check by calling the constructor. (invalid_connectivity): an exception class so we can give meaningful error messages on why the constructor failed (and the check failed). (Fconncheck): make the position arg optional. * conndef.h: declare new exception class for incorrect connectivity arguments, remove the validate methods and add them to the constructor
author Carnë Draug <carandraug@octave.org>
date Thu, 02 Oct 2014 16:39:05 +0100
parents a2140b980079
children 23e93653e1e5
files src/conndef.cc src/conndef.h
diffstat 2 files changed, 178 insertions(+), 77 deletions(-) [+]
line wrap: on
line diff
--- a/src/conndef.cc
+++ b/src/conndef.cc
@@ -23,8 +23,81 @@
 {
 }
 
+connectivity::connectivity (const octave_value& val)
+{
+  try
+    {
+      const double conn = double_value (val);
+      if (error_state)
+          throw invalid_connectivity ("must be in [4 6 8 18 26]");
+      ctor (conn);
+    }
+  catch (invalid_connectivity& e)
+    {
+      const boolNDArray mask = bool_array_value (val);
+      if (error_state)
+        throw invalid_connectivity ("must be logical or in [4 6 8 18 26]");
+      ctor (mask);
+    }
+  return;
+}
+
+
+connectivity::connectivity (const boolNDArray& mask)
+{
+  ctor (mask);
+  return;
+}
+
+void
+connectivity::ctor (const boolNDArray& mask)
+{
+  // Must be 1x1, 3x1, or 3x3x3x...x3
+  const octave_idx_type numel = mask.numel ();
+  const octave_idx_type ndims = mask.ndims ();
+  const dim_vector      dims  = mask.dims ();
+
+  if (ndims == 2)
+    {
+      // Don't forget 1x1, and 3x1 which are valid but arrays always
+      // have at least 2d
+      if (   (dims(1) != 3 && dims(2) != 3)
+          && (dims(1) != 3 && dims(2) != 1)
+          && (dims(1) != 1 && dims(2) != 1))
+        throw invalid_connectivity ("is not 1x1, 3x1, 3x3, or 3x3x...x3");
+    }
+  else
+    {
+      for (octave_idx_type i = 0; i < ndims; i++)
+        if (dims(i) != 3)
+          throw invalid_connectivity ("is not 3x3x...x3");
+    }
+
+  // Center must be true
+  const octave_idx_type center = floor (numel /2);
+  if (! mask(center))
+    throw invalid_connectivity ("center is not true");
+
+  // Must be symmetric relative to its center
+  const bool* start = mask.fortran_vec ();
+  const bool* end   = mask.fortran_vec () + (numel -1);
+  for (octave_idx_type i = 0; i < center; i++)
+    if (start[i] != end[-i])
+      throw invalid_connectivity ("is not symmetric relative to its center");
+
+  this->mask = mask;
+  return;
+}
+
 connectivity::connectivity (const octave_idx_type& conn)
 {
+  ctor (conn);
+  return;
+}
+
+void
+connectivity::ctor (const octave_idx_type& conn)
+{
   if (conn == 4)
     {
       mask = boolNDArray (dim_vector (3, 3), true);
@@ -64,7 +137,7 @@
   else if (conn == 26)
     mask = boolNDArray (dim_vector (3, 3, 3), true);
   else
-    error ("conndef: invalid CONN `%i'", conn);
+    throw invalid_connectivity ("must be in the set [4 6 8 18 26]");
 
   return;
 }
@@ -77,8 +150,10 @@
   if (ndims == 1)
     size = dim_vector (3, 1);
   else
-    size = dim_vector (3, 3);
-    size.resize (ndims, 3);
+    {
+      size = dim_vector (3, 3);
+      size.resize (ndims, 3);
+    }
 
   if (type == "maximal")
     {
@@ -99,11 +174,12 @@
         }
     }
   else
-    error ("conndef: invalid TYPE of connectivity '%s'", type.c_str ());
+    throw invalid_connectivity ("must be \"maximal\" or \"minimal\"");
 
   return;
 }
 
+
 Array<octave_idx_type>
 connectivity::offsets (const dim_vector& size) const
 {
@@ -114,7 +190,6 @@
   Array<octave_idx_type> offsets (dim_vector (nnz, 1)); // retval
   const dim_vector cum_size = size.cumulative ();
 
-
   Array<octave_idx_type> diff (dim_vector (ndims, 1));
 
   Array<octave_idx_type> sub (dim_vector (ndims, 1), 0);
@@ -129,6 +204,7 @@
           octave_idx_type off = diff(0);
           for (octave_idx_type dim = 1; dim < ndims; dim++)
             off += (diff(dim) * cum_size(dim-1));
+
           offsets(found) = off;
           found++;
         }
@@ -138,41 +214,28 @@
 }
 
 
-bool
-connectivity::validate (const double& conn)
+double
+connectivity::double_value (const octave_value& val)
 {
-  if (conn == 4 || conn == 8 || conn == 6 || conn == 18 || conn == 26
-      || conn == 1)
-    return true;
-
-  return false;
+  error_state = 0;
+  const double conn = val.double_value ();
+  // Check is_scalar_type because the warning Octave:array-to-scalar
+  // is off by default and we will get the first element only.
+  if (error_state || ! val.is_scalar_type ())
+    error_state = 1;
+  return conn;
 }
 
-
-bool
-connectivity::validate (const boolNDArray& mask)
+boolNDArray
+connectivity::bool_array_value (const octave_value& val)
 {
-  // Must be 3x3x3x...x3
-  const dim_vector dims = mask.dims ();
-  const octave_idx_type ndims = mask.ndims ();
-  for (octave_idx_type i = 0; i < ndims; i++)
-    if (dims(i) != 3)
-      return false;
-
-  // Center must be true
-  const octave_idx_type numel = mask.numel ();
-  const octave_idx_type center = floor (numel /2);
-  if (! mask(center))
-    return false;
-
-  // Must be symmetric relative to its center
-  const bool* start = mask.fortran_vec ();
-  const bool* end   = mask.fortran_vec () + (numel -1);
-  for (octave_idx_type i = 0; i < center; i++)
-    if (start[i] != end[-i])
-      return false;
-
-  return true;
+  error_state = 0;
+  const boolNDArray mask = val.bool_array_value ();
+  // bool_array_value converts anything other than 0 to true, which will
+  // then validate as conn array, hence any_element_not_one_or_zero()
+  if (val.array_value ().any_element_not_one_or_zero ())
+    error_state = 1;
+  return mask;
 }
 
 
@@ -246,7 +309,15 @@
 
   connectivity conn;
   if (nargin == 1)
-    conn = connectivity (arg0);
+    {
+      try
+        {conn = connectivity (arg0);}
+      catch (invalid_connectivity& e)
+        {
+          error ("conndef: CONN %s", e.what ());
+          return octave_value ();
+        }
+    }
   else
     {
       const std::string type = args(1).string_value ();
@@ -255,7 +326,13 @@
           error ("conndef: TYPE must be a string");
           return octave_value ();
         }
-      conn = connectivity (arg0, type);
+      try
+        {conn = connectivity (arg0, type);}
+      catch (invalid_connectivity& e)
+        {
+          error ("conndef: TYPE %s", e.what ());
+          return octave_value ();
+        }
     }
 
   // we must return an array of class double
@@ -264,6 +341,7 @@
 
 
 /*
+
 %!assert (conndef (1, "minimal"), [1; 1; 1]);
 %!assert (conndef (2, "minimal"), [0 1 0; 1 1 1; 0 1 0]);
 
@@ -311,11 +389,11 @@
 %!        [  122  284  338  356  362  364  365  366  368  374  392  446  608](:))
 
 %!error conndef ()
-%!error conndef (-2, "minimal")
+%!error <must be a positive integer> conndef (-2, "minimal")
 %!error conndef (char (2), "minimal")
-%!error <TYPE of connectivity> conndef (3, "invalid")
 %!error conndef ("minimal", 3)
-%!error <invalid CONN> conndef (10)
+%!error <TYPE must be "maximal" or "minimal"> conndef (3, "invalid")
+%!error <CONN must be in the set> conndef (10)
 
 %!assert (conndef (2, "minimal"), conndef (4))
 %!assert (conndef (2, "maximal"), conndef (8))
@@ -325,14 +403,14 @@
 %!assert (conndef (18), reshape ([0 1 0 1 1 1 0 1 0
 %!                                1 1 1 1 1 1 1 1 1
 %!                                0 1 0 1 1 1 0 1 0], [3 3 3]))
-
 */
 
 // PKG_ADD: autoload ("iptcheckconn", which ("conndef"));
 // PKG_DEL: autoload ("iptcheckconn", which ("conndef"), "remove");
 DEFUN_DLD(iptcheckconn, args, , "\
 -*- texinfo -*-\n\
-@deftypefn {Loadable Function} {} iptcheckconn (@var{con}, @var{func}, @var{var}, @var{pos})\n\
+@deftypefn  {Loadable Function} {} iptcheckconn (@var{conn}, @var{func}, @var{var})\n\
+@deftypefnx {Loadable Function} {} iptcheckconn (@var{conn}, @var{func}, @var{var}, @var{pos})\n\
 Check if argument is valid connectivity.\n\
 \n\
 If @var{conn} is not a valid connectivity argument, gives a properly\n\
@@ -342,7 +420,7 @@
 argument in the input.\n\
 \n\
 A valid connectivity argument must be either double or logical.  It must\n\
-also be either a scalar from set [1 4 6 8 18 26], or a symmetric matrix\n\
+also be either a scalar from set [4 6 8 18 26], or a symmetric matrix\n\
 with all dimensions of size 3, with only 0 or 1 as values, and 1 at its\n\
 center.\n\
 \n\
@@ -352,7 +430,7 @@
   const octave_idx_type nargin = args.length ();
 //  const octave_value rv = octave_value ();
 
-  if (nargin != 4)
+  if (nargin < 3 || nargin > 4)
     {
       print_usage ();
       return octave_value ();
@@ -370,34 +448,27 @@
       error ("iptcheckconn: VAR must be a string");
       return octave_value ();
     }
-  const octave_idx_type pos = args(3).idx_type_value (true);
-  if (error_state || pos < 1)
+  octave_idx_type pos (0);
+  if (nargin > 3)
     {
-      error ("iptcheckconn: POS must be a positive integer");
-      return octave_value ();
+      pos = args(3).idx_type_value (true);
+      if (error_state || pos < 1)
+        {
+          error ("iptcheckconn: POS must be a positive integer");
+          return octave_value ();
+        }
     }
 
-  bool bad = true;
-
-  const double conn = args(0).double_value ();
-  // check is_scalar_type because of the warning Octave:array-to-scalar
-  if (! error_state && args(0).is_scalar_type () //
-      && connectivity::validate (conn))
-    bad = false;
-  else
+  try
+    {const connectivity conn (args(0));}
+  catch (invalid_connectivity& e)
     {
-      const boolNDArray mask = args(0).bool_array_value ();
-      // bool_array_value converts anything other than 0 to true, which will
-      // then validate asconn array, hence any_element_not_one_or_zero
-      if (! error_state && connectivity::validate (mask)
-          && ! args(0).array_value ().any_element_not_one_or_zero ())
-        bad = false;
+      if (pos == 0)
+        error ("%s: %s %s", func.c_str (), var.c_str (), e.what ());
+      else
+        error ("%s: %s, at pos %i, %s",
+               func.c_str (), var.c_str (), pos, e.what ());
     }
-
-  if (bad)
-    error ("%s: %s at pos %i is not a valid connectivity array",
-           func.c_str (), var.c_str (), pos);
-
   return octave_value ();
 }
 
@@ -405,12 +476,21 @@
 // the complete error message should be "expected error <.> but got none",
 // but how to escape <> within the error message?
 
-%!error <expected error> fail ("iptcheckconn (4, 'func', 'var', 2)");
-%!error <expected error> fail ("iptcheckconn (ones (3, 3, 3, 3), 'func', 'var', 2)");
+%!error <expected error> fail ("iptcheckconn ( 4, 'func', 'var')");
+%!error <expected error> fail ("iptcheckconn ( 6, 'func', 'var')");
+%!error <expected error> fail ("iptcheckconn ( 8, 'func', 'var')");
+%!error <expected error> fail ("iptcheckconn (18, 'func', 'var')");
+%!error <expected error> fail ("iptcheckconn (26, 'func', 'var')");
 
-%!error <not a valid connectivity array> iptcheckconn (3, "func", "var", 2);
-%!error <not a valid connectivity array> iptcheckconn ([1 1 1; 1 0 1; 1 1 1], "func", "var", 2);
-%!error <not a valid connectivity array> iptcheckconn ([1 2 1; 1 1 1; 1 1 1], "func", "var", 2);
-%!error <not a valid connectivity array> iptcheckconn ([0 1 1; 1 1 1; 1 1 1], "func", "var", 2);
-%!error <not a valid connectivity array> iptcheckconn (ones (3, 3, 3, 4), "func", "var", 2);
+%!error <expected error> fail ("iptcheckconn (1, 'func', 'var')");
+%!error <expected error> fail ("iptcheckconn (ones (3, 1), 'func', 'var')");
+%!error <expected error> fail ("iptcheckconn (ones (3, 3), 'func', 'var')");
+%!error <expected error> fail ("iptcheckconn (ones (3, 3, 3), 'func', 'var')");
+%!error <expected error> fail ("iptcheckconn (ones (3, 3, 3, 3), 'func', 'var')");
+
+%!error <VAR must be logical or in> iptcheckconn (3, "func", "VAR");
+%!error <VAR center is not true> iptcheckconn ([1 1 1; 1 0 1; 1 1 1], "func", "VAR");
+%!error <VAR must be logical or in> iptcheckconn ([1 2 1; 1 1 1; 1 1 1], "func", "VAR");
+%!error <VAR is not symmetric relative to its center> iptcheckconn ([0 1 1; 1 1 1; 1 1 1], "func", "VAR");
+%!error <VAR is not 3x3x...x3> iptcheckconn (ones (3, 3, 3, 4), "func", "VAR");
 */
--- a/src/conndef.h
+++ b/src/conndef.h
@@ -17,6 +17,8 @@
 #ifndef OCTAVE_IMAGE_CONNDEF
 #define OCTAVE_IMAGE_CONNDEF
 
+#include <stdexcept>
+
 #include <octave/oct.h>
 
 namespace octave
@@ -27,6 +29,10 @@
     {
       public:
         connectivity ();
+
+        //! Will throw if val is bad
+        connectivity (const octave_value& val);
+        connectivity (const boolNDArray& mask);
         connectivity (const octave_idx_type& conn);
         connectivity (const octave_idx_type& ndims, const std::string& type);
 
@@ -36,8 +42,23 @@
         // connected elements (will have negative and positive values).
         Array<octave_idx_type> offsets (const dim_vector& size) const;
 
-        static bool validate (const boolNDArray& mask);
-        static bool validate (const double& conn);
+      private:
+        void ctor (const boolNDArray& mask);
+        void ctor (const octave_idx_type& conn);
+
+        //! Like octave_value::double_value() but actually checks if scalar.
+        static double double_value (const octave_value& val);
+
+        //! Like octave_value::bool_array_value() but actually checks if
+        //! all values are zeros and one.
+        static boolNDArray bool_array_value (const octave_value& val);
+    };
+
+    class invalid_connectivity : public std::invalid_argument
+    {
+      public:
+        invalid_connectivity (const std::string& what_arg)
+          : std::invalid_argument (what_arg) { }
     };
   }
 }