diff src/pt-jit.cc @ 14945:591aeec5c520

Remove uneeded error checks
author Max Brister <max@2bass.com>
date Fri, 08 Jun 2012 22:31:57 -0500
parents c0a5ab3b9278
children 3564bb141396
line wrap: on
line diff
--- a/src/pt-jit.cc
+++ b/src/pt-jit.cc
@@ -356,7 +356,7 @@
                                                  fn->arg_begin (),
                                                  ++fn->arg_begin ());
       builder.CreateRet (ret);
-      binary_ops[op].add_overload (fn, true, true, any, any, any);
+      binary_ops[op].add_overload (fn, true, any, any, any);
     }
 
   llvm::Type *void_t = llvm::Type::getVoidTy (context);
@@ -365,30 +365,30 @@
   fn = create_function ("octave_jit_grab_any", any, any);
                         
   engine->addGlobalMapping (fn, reinterpret_cast<void*>(&octave_jit_grab_any));
-  grab_fn.add_overload (fn, false, false, any, any);
+  grab_fn.add_overload (fn, false, any, any);
   grab_fn.stash_name ("grab");
 
   // grab scalar
   fn = create_identity (scalar);
-  grab_fn.add_overload (fn, false, false, scalar, scalar);
+  grab_fn.add_overload (fn, false, scalar, scalar);
 
   // grab index
   fn = create_identity (index);
-  grab_fn.add_overload (fn, false, false, index, index);
+  grab_fn.add_overload (fn, false, index, index);
 
   // release any
   fn = create_function ("octave_jit_release_any", void_t, any->to_llvm ());
   engine->addGlobalMapping (fn, reinterpret_cast<void*>(&octave_jit_release_any));
-  release_fn.add_overload (fn, false, false, 0, any);
+  release_fn.add_overload (fn, false, 0, any);
   release_fn.stash_name ("release");
 
   // release scalar
   fn = create_identity (scalar);
-  release_fn.add_overload (fn, false, false, 0, scalar);
+  release_fn.add_overload (fn, false, 0, scalar);
 
   // release index
   fn = create_identity (index);
-  release_fn.add_overload (fn, false, false, 0, index);
+  release_fn.add_overload (fn, false, 0, index);
 
   // now for binary scalar operations
   // FIXME: Finish all operations
@@ -428,7 +428,7 @@
     llvm::Value *ret = builder.CreateFDiv (fn->arg_begin (), ++fn->arg_begin ());
     builder.CreateRet (ret);
 
-    jit_function::overload ol (fn, true, true, scalar, scalar, scalar);
+    jit_function::overload ol (fn, true, scalar, scalar, scalar);
     binary_ops[octave_value::op_div].add_overload (ol);
     binary_ops[octave_value::op_el_div].add_overload (ol);
   }
@@ -444,7 +444,7 @@
                                             fn->arg_begin ());
     builder.CreateRet (ret);
 
-    jit_function::overload ol (fn, true, true, scalar, scalar, scalar);
+    jit_function::overload ol (fn, true, scalar, scalar, scalar);
     binary_ops[octave_value::op_ldiv].add_overload (ol);
     binary_ops[octave_value::op_el_ldiv].add_overload (ol);
   }
@@ -469,7 +469,7 @@
     builder.CreateRet (zero);
   }
   llvm::verifyFunction (*fn);
-  for_init_fn.add_overload (fn, false, false, index, range);
+  for_init_fn.add_overload (fn, false, index, range);
 
   // bounds check for for loop
   for_check_fn.stash_name ("for_check");
@@ -485,7 +485,7 @@
     builder.CreateRet (ret);
   }
   llvm::verifyFunction (*fn);
-  for_check_fn.add_overload (fn, false, false, boolean, range, index);
+  for_check_fn.add_overload (fn, false, boolean, range, index);
 
   // index variabe for for loop
   for_index_fn.stash_name ("for_index");
@@ -505,7 +505,7 @@
     builder.CreateRet (ret);
   }
   llvm::verifyFunction (*fn);
-  for_index_fn.add_overload (fn, false, false, scalar, range, index);
+  for_index_fn.add_overload (fn, false, scalar, range, index);
 
   // logically true
   logically_true_fn.stash_name ("logically_true");
@@ -534,14 +534,14 @@
     builder.CreateRet (ret);
   }
   llvm::verifyFunction (*fn);
-  logically_true_fn.add_overload (fn, true, false, boolean, scalar);
+  logically_true_fn.add_overload (fn, true, boolean, scalar);
 
   fn = create_function ("octave_logically_true_bool", boolean, boolean);
   body = llvm::BasicBlock::Create (context, "body", fn);
   builder.SetInsertPoint (body);
   builder.CreateRet (fn->arg_begin ());
   llvm::verifyFunction (*fn);
-  logically_true_fn.add_overload (fn, false, false, boolean, boolean);
+  logically_true_fn.add_overload (fn, false, boolean, boolean);
 
   // make_range
   // FIXME: May be benificial to implement all in LLVM
@@ -572,7 +572,7 @@
     builder.CreateRet (rng);
   }
   llvm::verifyFunction (*fn);
-  make_range_fn.add_overload (fn, false, false, range, scalar, scalar, scalar);
+  make_range_fn.add_overload (fn, false, range, scalar, scalar, scalar);
 
   casts[any->type_id ()].stash_name ("(any)");
   casts[scalar->type_id ()].stash_name ("(scalar)");
@@ -580,20 +580,20 @@
   // cast any <- scalar
   fn = create_function ("octave_jit_cast_any_scalar", any, scalar);
   engine->addGlobalMapping (fn, reinterpret_cast<void*> (&octave_jit_cast_any_scalar));
-  casts[any->type_id ()].add_overload (fn, false, false, any, scalar);
+  casts[any->type_id ()].add_overload (fn, false, any, scalar);
 
   // cast scalar <- any
   fn = create_function ("octave_jit_cast_scalar_any", scalar, any);
   engine->addGlobalMapping (fn, reinterpret_cast<void*> (&octave_jit_cast_scalar_any));
-  casts[scalar->type_id ()].add_overload (fn, false, false, scalar, any);
+  casts[scalar->type_id ()].add_overload (fn, false, scalar, any);
 
   // cast any <- any
   fn = create_identity (any);
-  casts[any->type_id ()].add_overload (fn, false, false, any, any);
+  casts[any->type_id ()].add_overload (fn, false, any, any);
 
   // cast scalar <- scalar
   fn = create_identity (scalar);
-  casts[scalar->type_id ()].add_overload (fn, false, false, scalar, scalar);
+  casts[scalar->type_id ()].add_overload (fn, false, scalar, scalar);
 }
 
 void
@@ -608,7 +608,7 @@
                                         ty->to_llvm ());
   engine->addGlobalMapping (fn, call);
 
-  jit_function::overload ol (fn, false, true, 0, string, ty);
+  jit_function::overload ol (fn, false, 0, string, ty);
   print_fn.add_overload (ol);
 }
 
@@ -631,7 +631,7 @@
   builder.CreateRet (ret);
   llvm::verifyFunction (*fn);
 
-  jit_function::overload ol(fn, false, false, ty, ty, ty);
+  jit_function::overload ol(fn, false, ty, ty, ty);
   binary_ops[op].add_overload (ol);
 }
 
@@ -653,7 +653,7 @@
   builder.CreateRet (ret);
   llvm::verifyFunction (*fn);
 
-  jit_function::overload ol (fn, false, false, boolean, ty, ty);
+  jit_function::overload ol (fn, false, boolean, ty, ty);
   binary_ops[op].add_overload (ol);
 }
 
@@ -675,7 +675,7 @@
   builder.CreateRet (ret);
   llvm::verifyFunction (*fn);
 
-  jit_function::overload ol (fn, false, false, boolean, ty, ty);
+  jit_function::overload ol (fn, false, boolean, ty, ty);
   binary_ops[op].add_overload (ol);
 }
 
@@ -787,6 +787,7 @@
 {
   if (mparent)
     mparent->remove (mlocation);
+  resize_arguments (0);
 }
 
 llvm::BasicBlock *
@@ -878,12 +879,11 @@
   return append (instr);
 }
 
-jit_instruction *
-jit_block::append (jit_instruction *instr)
+void
+jit_block::internal_append (jit_instruction *instr)
 {
   instructions.push_back (instr);
   instr->stash_parent (this, --instructions.end ());
-  return instr;
 }
 
 jit_instruction *
@@ -917,19 +917,19 @@
 jit_block::pred (size_t idx) const
 {
   // FIXME: Make this O(1)
-  
-  // here we get the use in backwards order. This means we preserve phi
-  // information when new blocks are added
   assert (idx < use_count ());
   jit_use *use;
-  size_t real_idx = use_count () - idx - 1;
   size_t i;
-  for (use = first_use (), i = 0; use && i < real_idx; ++i,
-         use = use->next ());
-    
+  for (use = first_use (), i = 0; use && i < idx; ++i, use = use->next ());
   return use->user_parent ();
 }
 
+bool
+jit_block::branch_alive (jit_block *asucc) const
+{
+  return terminator ()->alive (asucc);
+}
+
 size_t
 jit_block::pred_index (jit_block *apred) const
 {
@@ -941,12 +941,12 @@
 }
 
 void
-jit_block::create_merge (llvm::Function *inside, size_t pred_idx)
+jit_block::create_merge (llvm::Function *inside, jit_block *apred)
 {
   mpred_llvm.resize (pred_count ());
 
-  jit_block *ipred = pred (pred_idx);
-  if (! mpred_llvm[pred_idx] && ipred->pred_count () > 1)
+  size_t pred_idx = pred_index (apred);
+  if (! mpred_llvm[pred_idx] && apred->pred_count () > 1)
     {
       llvm::BasicBlock *amerge;
       amerge = llvm::BasicBlock::Create (context, "phi_merge", inside,
@@ -1122,20 +1122,99 @@
   return i;
 }
 
-// -------------------- jit_call --------------------
+// -------------------- jit_phi --------------------
 bool
-jit_call::dead (void) const
+jit_phi::prune (void)
 {
-  return ! has_side_effects () && use_count () == 0;
+  jit_block *p = parent ();
+  size_t new_idx = 0;
+  for (size_t i = 0; i < argument_count (); ++i)
+    {
+      jit_block *inc = incomming (i);
+      if (inc->branch_alive (p))
+        {
+          if (new_idx != i)
+            {
+              stash_argument (new_idx, argument (i));
+              mincomming[new_idx] = mincomming[i];
+            }
+
+          ++new_idx;
+        }
+    }
+
+  if (new_idx != argument_count ())
+    {
+      resize_arguments (new_idx);
+      mincomming.resize (new_idx);
+    }
+
+  assert (argument_count () > 0);
+  if (argument_count () == 1)
+    {
+      replace_with (argument (0));
+      return true;
+    }
+
+  return false;
 }
 
 bool
-jit_call::almost_dead (void) const
+jit_phi::infer (void)
+{
+  jit_block *p = parent ();
+  if (! p->alive ())
+    return false;
+
+  jit_type *infered = 0;
+  for (size_t i = 0; i < argument_count (); ++i)
+    {
+      jit_block *inc = mincomming[i];
+      if (inc->branch_alive (p))
+        infered = jit_typeinfo::join (infered, argument_type (i));
+    }
+  
+  if (infered != type ())
+    {
+      stash_type (infered);
+      return true;
+    }
+
+  return false;
+}
+
+// -------------------- jit_terminator --------------------
+bool
+jit_terminator::alive (const jit_block *asucessor) const
 {
-  return ! has_side_effects () && use_count () <= 1;
+  size_t scount = sucessor_count ();
+  for (size_t i = 0; i < scount; ++i)
+    if (sucessor (i) == asucessor)
+      return malive[i];
+
+  panic_impossible ();
 }
 
 bool
+jit_terminator::infer (void)
+{
+  if (! parent ()->alive ())
+    return false;
+
+  bool changed = false;
+  for (size_t i = 0; i < malive.size (); ++i)
+    if (! malive[i] && check_alive (i))
+      {
+        changed = true;
+        malive[i] = true;
+        sucessor (i)->mark_alive ();
+      }
+
+  return changed;
+}
+
+// -------------------- jit_call --------------------
+bool
 jit_call::infer (void)
 {
   // FIXME: explain algorithm
@@ -1173,6 +1252,7 @@
   entry_block = create<jit_block> ("body");
   final_block = create<jit_block> ("final");
   blocks.push_back (entry_block);
+  entry_block->mark_alive ();
   block = entry_block;
   visit (tee);
 
@@ -1207,9 +1287,17 @@
       worklist.pop_front ();
 
       if (next->infer ())
-        append_users (next);
+        {
+          // terminators need to be handles specially
+          if (jit_terminator *term = dynamic_cast<jit_terminator *> (next))
+            append_users_term (term);
+          else
+            append_users (next);
+        }
     }
 
+  remove_dead ();
+
   place_releases ();
 
 #ifdef OCTAVE_JIT_DEBUG
@@ -1270,12 +1358,13 @@
   jit_value *rhsv = visit (rhs);
 
   const jit_function& fn = jit_typeinfo::binary_op (be.op_type ());
-  result = block->append (create<jit_call> (fn, lhsv, rhsv));
-
-  jit_block *normal = create<jit_block> (block->name () + "a");
-  block->append (create<jit_check_error> (normal, final_block));
+  jit_call *call = block->append (create<jit_call> (fn, lhsv, rhsv));
+
+  jit_block *normal = create<jit_block> (block->name ());
+  block->append (create<jit_check_error> (call, normal, final_block));
   blocks.push_back (normal);
   block = normal;
+  result = call;
 }
 
 void
@@ -1524,12 +1613,12 @@
         {
           tree_expression *expr = tic->condition ();
           jit_value *cond = visit (expr);
-          jit_instruction *check = create<jit_call> (&jit_typeinfo::logically_true, cond);
+          jit_call *check = create<jit_call> (&jit_typeinfo::logically_true, cond);
           block->append (check);
 
-          jit_block *next = create<jit_block> (block->name () + "a");
+          jit_block *next = create<jit_block> (block->name ());
           blocks.push_back (next);
-          block->append (create<jit_check_error> (next, final_block));
+          block->append (create<jit_check_error> (check, next, final_block));
           block = next;
 
           jit_block *body = create<jit_block> (i == 0 ? "if_body" : "ifelse_body");
@@ -1806,6 +1895,24 @@
 }
 
 void
+jit_convert::append_users_term (jit_terminator *term)
+{
+  for (size_t i = 0; i < term->sucessor_count (); ++i)
+    {
+      if (term->alive (i))
+        {
+          jit_block *succ = term->sucessor (i);
+          for (jit_block::iterator iter = succ->begin (); iter != succ->end ()
+                 && isa<jit_phi> (*iter); ++iter)
+            worklist.push_back (*iter);
+
+          if (succ->terminator ())
+            worklist.push_back (succ->terminator ());
+        }
+    }
+}
+
+void
 jit_convert::merge_blocks (void)
 {
   for (block_list::iterator iter = blocks.begin (); iter != blocks.end ();
@@ -1903,7 +2010,6 @@
   for (size_t i = 0; i < block.succ_count (); ++i)
     {
       jit_block *finish = block.succ (i);
-      size_t pred_idx = finish->pred_index (&block);
 
       for (jit_block::iterator iter = finish->begin (); iter != finish->end ()
              && isa<jit_phi> (*iter);)
@@ -1912,7 +2018,7 @@
           jit_variable *var = phi->dest ();
           if (var->has_top ())
             {
-              phi->stash_argument (pred_idx, var->top ());
+              phi->add_incomming (&block, var->top ());
               ++iter;
             }
           else
@@ -1927,6 +2033,50 @@
 }
 
 void
+jit_convert::remove_dead ()
+{
+  block_list::iterator biter;
+  for (biter = blocks.begin (); biter != blocks.end (); ++biter)
+    {
+      jit_block *b = *biter;
+      if (b->alive ())
+        {
+          for (jit_block::iterator iter = b->begin (); iter != b->end ()
+                 && isa<jit_phi> (*iter);)
+            {
+              jit_phi *phi = static_cast<jit_phi *> (*iter);
+              if (phi->prune ())
+                iter = b->remove (iter);
+              else
+                ++iter;
+            }
+        }
+    }
+
+  for (biter = blocks.begin (); biter != blocks.end ();)
+    {
+      jit_block *b = *biter;
+      if (b->alive ())
+        {
+          // FIXME: A special case for jit_check_error, if we generalize to
+          // we will need to change!
+          jit_terminator *term = b->terminator ();
+          if (term && term->sucessor_count () == 2 && ! term->alive (1))
+            {
+              jit_block *succ = term->sucessor (0);
+              term->remove ();
+              jit_break *abreak = b->append (create<jit_break> (succ));
+              abreak->infer ();
+            }
+
+          ++biter;
+        }
+      else
+        biter = blocks.erase (biter);
+    }
+}
+
+void
 jit_convert::place_releases (void)
 {
   release_placer placer (*this);
@@ -2056,9 +2206,9 @@
 }
 
 void
-jit_convert::convert_llvm::finish_phi (jit_instruction *phi)
+jit_convert::convert_llvm::finish_phi (jit_instruction *aphi)
 {
-  jit_block *pblock = phi->parent ();
+  jit_phi *phi = static_cast<jit_phi *> (aphi);
   llvm::PHINode *llvm_phi = llvm::cast<llvm::PHINode> (phi->to_llvm ());
 
   bool can_remove = ! phi->use_count ();
@@ -2086,7 +2236,7 @@
           jit_value *arg = phi->argument (i);
           if (arg->has_llvm () && phi->argument_type (i) != phi->type ())
             {
-              llvm::BasicBlock *pred = pblock->pred_llvm (i);
+              llvm::BasicBlock *pred = phi->incomming_llvm (i);
               builder.SetInsertPoint (--pred->end ());
               const jit_function::overload& ol
                 = jit_typeinfo::get_release (phi->argument_type (i));
@@ -2106,7 +2256,7 @@
     {
       for (size_t i = 0; i < phi->argument_count (); ++i)
         {
-          llvm::BasicBlock *pred = pblock->pred_llvm (i);
+          llvm::BasicBlock *pred = phi->incomming_llvm (i);
           if (phi->argument_type (i) == phi->type ())
             llvm_phi->addIncoming (phi->argument_llvm (i), pred);
           else
@@ -2249,7 +2399,7 @@
   jit_block *parent = phi.parent ();
   for (size_t i = 0; i < phi.argument_count (); ++i)
     if (phi.argument_type (i) != phi.type ())
-      parent->create_merge (function, i);
+      parent->create_merge (function, phi.incomming (i));
 }
 
 void