changeset 583:319cadd33084

Further speedups to bwlabeln
author jordigh
date Wed, 05 Sep 2012 19:14:14 +0000
parents 4355b74e0b17
children 01e312295406
files src/Makefile src/bwlabeln.cc src/union-find.h++
diffstat 3 files changed, 50 insertions(+), 84 deletions(-) [+]
line wrap: on
line diff
--- a/src/Makefile
+++ b/src/Makefile
@@ -5,7 +5,7 @@
 %.oct: %.cc
 	mkoctfile -Wall $<
 
-bwlabeln.oct: bwlabeln.cc
+bwlabeln.oct: bwlabeln.cc union-find.h++
 	CXXFLAGS='-g -O2 -std=c++0x -Wall' mkoctfile $<
 
 clean:
--- a/src/bwlabeln.cc
+++ b/src/bwlabeln.cc
@@ -18,6 +18,7 @@
 #include <oct.h>
 #include <set>
 #include "union-find.h++"
+#include <unordered_map>
 
 typedef Array<octave_idx_type> coord;
 
@@ -295,8 +296,6 @@
 {
   octave_value_list rval;
 
-  union_find<octave_idx_type> u_f;
-
   octave_idx_type nargin = args.length ();
 
   if (nargin < 1 || nargin > 2)
@@ -373,6 +372,7 @@
   L.insert(BW, coord (dim_vector (size_vec.length (), 1), 1));
 
   double* L_vec = L.fortran_vec ();
+  union_find u_f (L.nelem ());
 
   for (octave_idx_type BWidx = 0; BWidx < BW.nelem (); BWidx++)
     {
@@ -381,7 +381,7 @@
       if (L_vec[Lidx])
         {
           //Insert this one into its group
-          u_f.find_id(Lidx);
+          u_f.find (Lidx);
 
           //Replace this with C++0x range-based for loop later
           //(implemented in gcc 4.6)
@@ -395,17 +395,16 @@
     }
 
 
-  unordered_map<octave_idx_type, octave_idx_type> ids_to_label;
+  std::unordered_map<octave_idx_type, octave_idx_type> ids_to_label;
   octave_idx_type next_label = 1;
 
-  auto idxs  = u_f.get_objects ();
+  auto idxs  = u_f.get_ids ();
 
   //C++0x foreach later
-
   for (auto idx = idxs.begin (); idx != idxs.end (); idx++)
     {
       octave_idx_type label;
-      octave_idx_type id = u_f.find_id (idx->first);
+      octave_idx_type id = u_f.find (*idx);
       auto try_label = ids_to_label.find (id);
       if( try_label == ids_to_label.end ())
         {
@@ -417,7 +416,7 @@
           label = try_label -> second;
         }
 
-      L_vec[idx->first] = label;
+      L_vec[*idx] = label;
     }
 
   rval(0) = L;
--- a/src/union-find.h++
+++ b/src/union-find.h++
@@ -15,118 +15,85 @@
 
 // union-find.h++
 
-#include <unordered_map>
-#include <list>
+#include <vector>
 
-using std::unordered_map;
-using std::list;
+struct voxel{
+  octave_idx_type rank;    //Starts at 0; compared and incremented by union_find::unite
+  octave_idx_type parent;  //Index of the parent element; a root stores its own index
+};
 
-// T - type of object we're union-finding for
-// H - hash for the map
-template <typename T, typename H = std::hash<T> >
 class union_find
 {
 
-//Dramatis personae
 private:
-
-  //Each root has rank.
-  unordered_map<octave_idx_type, octave_idx_type, H> num_ranks;
+  std::vector<voxel*> voxels;
 
-  //Each object points to its parent, possibly itself.
-  unordered_map<octave_idx_type, octave_idx_type, H> parent_pointers;
-
-  //Represent each object by a number and vice versa.
-  unordered_map<octave_idx_type, T, H>      num_to_objects;
-  unordered_map<T, octave_idx_type, H>      objects_to_num;
-
-// Act 1
 public:
 
-  //Insert a collection of objects
-  void insert_objects (const list<T>& objects)
+  union_find (octave_idx_type s) : voxels (s, NULL) {};  //s slots, all NULL: no index inserted yet
+
+  ~union_find ()  //Frees every voxel that find allocated
   {
-    for (auto i = objects.begin (); i != objects.end (); i++)
-      {
-        find (*i);
-      }
+    for (auto v = voxels.begin(); v != voxels.end(); v++)
+      delete *v;  //Slots never passed to find are still NULL, which delete ignores
   }
 
-
   //Give the root representative id for this object, or insert into a
   //new set if none is found
-  octave_idx_type find_id (const T& object)
+  octave_idx_type find (octave_idx_type idx)  //idx indexes voxels directly; no bounds check is done here
   {
 
     //Insert new element if not found
-    if (objects_to_num.find (object) == objects_to_num.end () )
+    auto v = voxels[idx];
+    if (!v)  //NULL slot: idx has not been inserted yet
       {
-        //Assign number serially to objects
-        octave_idx_type obj_num = objects_to_num.size ()+1;
-
-        num_ranks[obj_num] = 0;
-        objects_to_num[object] = obj_num;
-        num_to_objects[obj_num] = object;
-        parent_pointers[obj_num] = obj_num;
-        return obj_num;
+        voxel* new_voxel = new voxel;
+        new_voxel->rank = 0;
+        new_voxel->parent = idx;  //A fresh element is the root of its own singleton set
+        voxels[idx] = new_voxel;
+        return idx;
       }
 
-    //Path from this element to its root, we'll build it.
-    list<octave_idx_type> path (1, objects_to_num[object]);
-    octave_idx_type par = parent_pointers[path.back ()];
-    while ( par != path.back () )
-      {
-        path.push_back (par);
-        par = parent_pointers[par];
-      }
+    voxel* elt = v;
+    if (elt->parent != idx)
+      elt->parent = find (elt->parent);  //Recurse to the root and re-point this node straight at it
 
-    //Update everything we've seen to point to the root.
-    for (auto i = path.begin (); i != path.end (); i++)
-      {
-        parent_pointers[*i] = par;
-      }
-
-    return par;
-  }
-
-  T find( const T& object)
-  {
-    return num_to_objects[find_id (object)];
+    return elt->parent;
   }
 
   //Given two objects, unite the sets to which they belong
-  void unite (const T& obj1, const T& obj2)
+  void unite (octave_idx_type idx1, octave_idx_type idx2)
   {
-    octave_idx_type on1 = find_id(obj1), on2 = find_id(obj2);
+    octave_idx_type root1 = find (idx1), root2 = find (idx2);  //find also inserts an unseen index
 
     //Check if any union needs to be done, maybe they already are
     //in the same set.
-    if (on1 != on2)
+    voxel *v1 = voxels[root1], *v2 = voxels[root2];
+    if (root1 != root2)
       {
-        octave_idx_type r1 = num_ranks[on1], r2 = num_ranks[on2];
+        //Union by rank: the lower-ranked root goes under the higher-ranked one.
 
-        if ( r1 < r2)
-          {
-            parent_pointers[on1] = on2;
-            num_ranks.erase (on1); //Only root nodes need a rank
-          }
-        else if (r2 > r1)
-          {
-            parent_pointers[on2] = on1;
-            num_ranks.erase (on2);
-          }
+        if ( v1->rank < v2->rank)
+          v1->parent = root2;
+        else if (v1->rank > v2->rank)
+          v2->parent = root1;
         else
           {
-            parent_pointers[on2] = on1;
-            num_ranks.erase (on2);
-            num_ranks[on1]++;
+            v2->parent = root1;  //Equal ranks: keep root1 and bump its rank
+            v1->rank++;
           }
       }
   }
 
-  const unordered_map<T, octave_idx_type, H>& get_objects()
+  std::vector<octave_idx_type> get_ids()
   {
-    return objects_to_num;
+    std::vector<octave_idx_type> ids;
+    //Collect, in increasing order, every index whose slot has been filled by find
+    for (size_t i = 0; i < voxels.size (); i++)
+      if (voxels[i])
+        ids.push_back (i);
+
+    return ids;
   };
 
 };