changeset 15029:7aa103a1c8ae

Merge in Doug's changes
author Jordi Gutiérrez Hermoso <jordigh@octave.org>
date Thu, 26 Jul 2012 18:45:52 -0400
parents 397f0d80bd47 (current diff) 741d2dbcc117 (diff)
children 86a95d6ada0d
files
diffstat 4 files changed, 214 insertions(+), 60 deletions(-) [+]
line wrap: on
line diff
--- a/src/jit-typeinfo.cc
+++ b/src/jit-typeinfo.cc
@@ -138,6 +138,25 @@
   obv->release ();
 }
 
+extern "C" octave_base_value *
+octave_jit_cast_any_range (jit_range *rng)
+{
+  Range temp (*rng);
+  octave_value ret (temp);
+  octave_base_value *rep = ret.internal_rep ();
+  rep->grab ();
+
+  return rep;
+}
+extern "C" void
+octave_jit_cast_range_any (jit_range *ret, octave_base_value *obv)
+{
+
+  jit_range r (obv->range_value ());
+  *ret = r;
+  obv->release ();
+}
+
 extern "C" double
 octave_jit_cast_scalar_any (octave_base_value *obv)
 {
@@ -210,8 +229,8 @@
 }
 
 extern "C" void
-octave_jit_paren_subsasgn_impl (jit_matrix *mat, octave_idx_type index,
-                                double value)
+octave_jit_paren_subsasgn_impl (jit_matrix *ret, jit_matrix *mat,
+                                octave_idx_type index, double value)
 {
   NDArray *array = mat->array;
   if (array->nelem () < index)
@@ -221,6 +240,7 @@
   data[index - 1] = value;
 
   mat->update ();
+  *ret = *mat;
 }
 
 extern "C" void
@@ -1291,7 +1311,8 @@
 
   jit_function resize_paren_subsasgn
     = create_function (jit_convention::external,
-                       "octave_jit_paren_subsasgn_impl", matrix, index, scalar);
+                       "octave_jit_paren_subsasgn_impl", matrix, matrix, index,
+                       scalar);
   resize_paren_subsasgn.add_mapping (engine, &octave_jit_paren_subsasgn_impl);
   fn = create_function (jit_convention::internal, "octave_jit_paren_subsasgn",
                         matrix, matrix, scalar, scalar);
@@ -1336,8 +1357,8 @@
 
     // resize on out of bounds access
     builder.SetInsertPoint (bounds_error);
-    llvm::Value *resize_result = resize_paren_subsasgn.call (builder, int_idx,
-                                                             value);
+    llvm::Value *resize_result = resize_paren_subsasgn.call (builder, mat,
+                                                             int_idx, value);
     builder.CreateBr (done);
 
     builder.SetInsertPoint (success);
@@ -1369,6 +1390,7 @@
   casts[scalar->type_id ()].stash_name ("(scalar)");
   casts[complex->type_id ()].stash_name ("(complex)");
   casts[matrix->type_id ()].stash_name ("(matrix)");
+  casts[any->type_id ()].stash_name ("(range)");
 
   // cast any <- matrix
   fn = create_function (jit_convention::external, "octave_jit_cast_any_matrix",
@@ -1382,6 +1404,18 @@
   fn.add_mapping (engine, &octave_jit_cast_matrix_any);
   casts[matrix->type_id ()].add_overload (fn);
 
+  // cast any <- range
+  fn = create_function (jit_convention::external, "octave_jit_cast_any_range",
+                        any, range);
+  fn.add_mapping (engine, &octave_jit_cast_any_range);
+  casts[any->type_id ()].add_overload (fn);
+
+  // cast range <- any
+  fn = create_function (jit_convention::external, "octave_jit_cast_range_any",
+                        range, any);
+  fn.add_mapping (engine, &octave_jit_cast_range_any);
+  casts[range->type_id ()].add_overload (fn);
+
   // cast any <- scalar
   fn = create_function (jit_convention::external, "octave_jit_cast_any_scalar",
                         any, scalar);
--- a/src/pt-eval.cc
+++ b/src/pt-eval.cc
@@ -296,11 +296,6 @@
   if (debug_mode)
     do_breakpoint (cmd.is_breakpoint ());
 
-#if HAVE_LLVM
-  if (jiter.execute (cmd))
-    return;
-#endif
-
   // FIXME -- need to handle PARFOR loops here using cmd.in_parallel ()
   // and cmd.maxproc_expr ();
 
@@ -314,6 +309,11 @@
 
   octave_value rhs = expr->rvalue1 ();
 
+#if HAVE_LLVM
+  if (jiter.execute (cmd, rhs))
+    return;
+#endif
+
   if (error_state || rhs.is_undefined ())
     return;
 
--- a/src/pt-jit.cc
+++ b/src/pt-jit.cc
@@ -57,8 +57,9 @@
 static llvm::LLVMContext& context = llvm::getGlobalContext ();
 
 // -------------------- jit_convert --------------------
-jit_convert::jit_convert (llvm::Module *module, tree &tee)
-  : iterator_count (0), short_count (0), breaking (false)
+jit_convert::jit_convert (llvm::Module *module, tree &tee,
+                          jit_type *for_bounds)
+  : iterator_count (0), for_bounds_count (0), short_count (0), breaking (false)
 {
   jit_instruction::reset_ids ();
 
@@ -67,6 +68,10 @@
   append (entry_block);
   entry_block->mark_alive ();
   block = entry_block;
+
+  if (for_bounds)
+    create_variable (next_for_bounds (false), for_bounds);
+
   visit (tee);
 
   // FIXME: Remove if we no longer only compile loops
@@ -175,10 +180,7 @@
       assert (boole);
       bool is_and = boole->op_type () == tree_boolean_expression::bool_and;
 
-      std::stringstream ss;
-      ss << "#short_result" << short_count++;
-
-      std::string short_name = ss.str ();
+      std::string short_name = next_shortcircut_result ();
       jit_variable *short_result = create<jit_variable> (short_name);
       vmap[short_name] = short_result;
 
@@ -302,10 +304,9 @@
   continues.clear ();
 
   // we need a variable for our iterator, because it is used in multiple blocks
-  std::stringstream ss;
-  ss << "#iter" << iterator_count++;
-  std::string iter_name = ss.str ();
+  std::string iter_name = next_iterator ();
   jit_variable *iterator = create<jit_variable> (iter_name);
+  create<jit_variable> (iter_name);
   vmap[iter_name] = iterator;
 
   jit_block *body = create<jit_block> ("for_body");
@@ -314,7 +315,10 @@
   jit_block *tail = create<jit_block> ("for_tail");
 
   // do control expression, iter init, and condition check in prev_block (block)
-  jit_value *control = visit (cmd.control_expr ());
+  // if we are the top level for loop, the bounds is an input argument.
+  jit_value *control = find_variable (next_for_bounds ());
+  if (! control)
+    control = visit (cmd.control_expr ());
   jit_call *init_iter = create<jit_call> (jit_typeinfo::for_init, control);
   block->append (init_iter);
   block->append (create<jit_assign> (iterator, init_iter));
@@ -762,21 +766,43 @@
 }
 
 jit_variable *
+jit_convert::find_variable (const std::string& vname) const
+{
+  vmap_t::const_iterator iter;
+  iter = vmap.find (vname);
+  return iter != vmap.end () ? iter->second : 0;
+}
+
+jit_variable *
 jit_convert::get_variable (const std::string& vname)
 {
-  vmap_t::iterator iter;
-  iter = vmap.find (vname);
-  if (iter != vmap.end ())
-    return iter->second;
+  jit_variable *ret = find_variable (vname);
+  if (ret)
+    return ret;
 
-  jit_variable *var = create<jit_variable> (vname);
   octave_value val = symbol_table::find (vname);
   jit_type *type = jit_typeinfo::type_of (val);
+  return create_variable (vname, type);
+}
+
+jit_variable *
+jit_convert::create_variable (const std::string& vname, jit_type *type)
+{
+  jit_variable *var = create<jit_variable> (vname);
   jit_extract_argument *extract;
   extract = create<jit_extract_argument> (type, var);
   entry_block->prepend (extract);
+  return vmap[vname] = var;
+}
 
-  return vmap[vname] = var;
+std::string
+jit_convert::next_name (const char *prefix, size_t& count, bool inc)
+{
+  std::stringstream ss;
+  ss << prefix << count;
+  if (inc)
+    ++count;
+  return ss.str ();
 }
 
 std::pair<jit_value *, jit_value *>
@@ -1462,20 +1488,29 @@
 {}
 
 bool
-tree_jit::execute (tree_simple_for_command& cmd)
+tree_jit::execute (tree_simple_for_command& cmd, const octave_value& bounds)
 {
-  if (! initialize ())
+  const size_t MIN_TRIP_COUNT = 1000;
+
+  size_t tc = trip_count (bounds);
+  if (! tc || ! initialize ())
     return false;
 
+  jit_info::vmap extra_vars;
+  extra_vars["#for_bounds0"] = &bounds;
+
   jit_info *info = cmd.get_info ();
-  if (! info || ! info->match ())
+  if (! info || ! info->match (extra_vars))
     {
+      if (tc < MIN_TRIP_COUNT)
+        return false;
+
       delete info;
-      info = new jit_info (*this, cmd);
+      info = new jit_info (*this, cmd, bounds);
       cmd.stash_info (info);
     }
 
-  return info->execute ();
+  return info->execute (extra_vars);
 }
 
 bool
@@ -1531,6 +1566,19 @@
   return true;
 }
 
+size_t
+tree_jit::trip_count (const octave_value& bounds) const
+{
+  if (bounds.is_range ())
+    {
+      Range rng = bounds.range_value ();
+      return rng.nelem ();
+    }
+
+  // unsupported type
+  return 0;
+}
+
 
 void
 tree_jit::optimize (llvm::Function *fn)
@@ -1548,14 +1596,12 @@
 
 // -------------------- jit_info --------------------
 jit_info::jit_info (tree_jit& tjit, tree& tee)
-  : engine (tjit.get_engine ()), llvm_function (0)
+  : engine (tjit.get_engine ()), function (0), llvm_function (0)
 {
   try
     {
       jit_convert conv (tjit.get_module (), tee);
-      llvm_function = conv.get_function ();
-      arguments = conv.get_arguments ();
-      bounds = conv.get_bounds ();
+      initialize (tjit, conv);
     }
   catch (const jit_fail_exception& e)
     {
@@ -1564,24 +1610,24 @@
         std::cout << "jit fail: " << e.what () << std::endl;
 #endif
     }
-
-  if (! llvm_function)
-    {
-      function = 0;
-      return;
-    }
-
-  tjit.optimize (llvm_function);
+}
 
+jit_info::jit_info (tree_jit& tjit, tree& tee, const octave_value& for_bounds)
+  : engine (tjit.get_engine ()), function (0), llvm_function (0)
+{
+  try
+    {
+      jit_convert conv (tjit.get_module (), tee,
+                        jit_typeinfo::type_of (for_bounds));
+      initialize (tjit, conv);
+    }
+  catch (const jit_fail_exception& e)
+    {
 #ifdef OCTAVE_JIT_DEBUG
-  std::cout << "-------------------- optimized llvm ir --------------------\n";
-  llvm::raw_os_ostream llvm_cout (std::cout);
-  llvm_function->print (llvm_cout);
-  std::cout << std::endl;
+      if (e.known ())
+        std::cout << "jit fail: " << e.what () << std::endl;
 #endif
-
-  void *void_fn = engine->getPointerToFunction (llvm_function);
-  function = reinterpret_cast<jited_function> (void_fn);
+    }
 }
 
 jit_info::~jit_info (void)
@@ -1591,7 +1637,7 @@
 }
 
 bool
-jit_info::execute (void) const
+jit_info::execute (const vmap& extra_vars) const
 {
   if (! function)
     return false;
@@ -1601,24 +1647,29 @@
     {
       if (arguments[i].second)
         {
-          octave_value &current = symbol_table::varref (arguments[i].first);
+          octave_value current = find (extra_vars, arguments[i].first);
           octave_base_value *obv = current.internal_rep ();
           obv->grab ();
           real_arguments[i] = obv;
-          current = octave_value ();
         }
     }
 
   function (&real_arguments[0]);
 
   for (size_t i = 0; i < arguments.size (); ++i)
-    symbol_table::varref (arguments[i].first) = real_arguments[i];
+    {
+      const std::string& name = arguments[i].first;
+
+      // do not store for loop bounds temporary
+      if (name.size () && name[0] != '#')
+        symbol_table::varref (arguments[i].first) = real_arguments[i];
+    }
 
   return true;
 }
 
 bool
-jit_info::match (void) const
+jit_info::match (const vmap& extra_vars) const
 {
   if (! function)
     return true;
@@ -1626,7 +1677,7 @@
   for (size_t i = 0; i < bounds.size (); ++i)
     {
       const std::string& arg_name = bounds[i].second;
-      octave_value value = symbol_table::find (arg_name);
+      octave_value value = find (extra_vars, arg_name);
       jit_type *type = jit_typeinfo::type_of (value);
 
       // FIXME: Check for a parent relationship
@@ -1636,6 +1687,40 @@
 
   return true;
 }
+
+void
+jit_info::initialize (tree_jit& tjit, jit_convert& conv)
+{
+  llvm_function = conv.get_function ();
+  arguments = conv.get_arguments ();
+  bounds = conv.get_bounds ();
+
+  if (llvm_function)
+    {
+      tjit.optimize (llvm_function);
+
+#ifdef OCTAVE_JIT_DEBUG
+      std::cout << "-------------------- optimized llvm ir "
+                << "--------------------\n";
+      llvm::raw_os_ostream llvm_cout (std::cout);
+      llvm_function->print (llvm_cout);
+      llvm_cout.flush ();
+      std::cout << std::endl;
+#endif
+
+      void *void_fn = engine->getPointerToFunction (llvm_function);
+      function = reinterpret_cast<jited_function> (void_fn);
+    }
+}
+
+octave_value
+jit_info::find (const vmap& extra_vars, const std::string& vname) const
+{
+  vmap::const_iterator iter = extra_vars.find (vname);
+  return iter == extra_vars.end () ? symbol_table::varval (vname)
+    : *iter->second;
+}
+
 #endif
 
 
--- a/src/pt-jit.h
+++ b/src/pt-jit.h
@@ -64,7 +64,7 @@
   typedef std::pair<jit_type *, std::string> type_bound;
   typedef std::vector<type_bound> type_bound_vector;
 
-  jit_convert (llvm::Module *module, tree &tee);
+  jit_convert (llvm::Module *module, tree &tee, jit_type *for_bounds = 0);
 
   ~jit_convert (void);
 
@@ -245,6 +245,7 @@
   std::list<jit_value *> all_values;
 
   size_t iterator_count;
+  size_t for_bounds_count;
   size_t short_count;
 
   typedef std::map<std::string, jit_variable *> vmap_t;
@@ -268,8 +269,31 @@
     return ret;
   }
 
+  // get an existing vairable. If the variable does not exist, it will not be
+  // created
+  jit_variable *find_variable (const std::string& vname) const;
+
+  // get a variable, create it if it does not exist. The type will default to
+  // the variable's current type in the symbol table.
   jit_variable *get_variable (const std::string& vname);
 
+  // create a variable of the given name and given type. Will also insert an
+  // extract statement
+  jit_variable *create_variable (const std::string& vname, jit_type *type);
+
+  // The name of the next for loop iterator. If inc is false, then the iterator
+  // counter will not be incremented.
+  std::string next_iterator (bool inc = true)
+  { return next_name ("#iter", iterator_count, inc); }
+
+  std::string next_for_bounds (bool inc = true)
+  { return next_name ("#for_bounds", for_bounds_count, inc); }
+
+  std::string next_shortcircut_result (bool inc = true)
+  { return next_name ("#shortcircut_result", short_count, inc); }
+
+  std::string next_name (const char *prefix, size_t& count, bool inc);
+
   std::pair<jit_value *, jit_value *> resolve (tree_index_expression& exp);
 
   jit_value *do_assign (tree_expression *exp, jit_value *rhs,
@@ -404,7 +428,7 @@
 
   ~tree_jit (void);
 
-  bool execute (tree_simple_for_command& cmd);
+  bool execute (tree_simple_for_command& cmd, const octave_value& bounds);
 
   bool execute (tree_while_command& cmd);
 
@@ -416,6 +440,8 @@
  private:
   bool initialize (void);
 
+  size_t trip_count (const octave_value& bounds) const;
+
   // FIXME: Temorary hack to test
   typedef std::map<tree *, jit_info *> compiled_map;
   llvm::Module *module;
@@ -428,18 +454,27 @@
 jit_info
 {
 public:
+  // we use a pointer here so we don't have to include ov.h
+  typedef std::map<std::string, const octave_value *> vmap;
+
   jit_info (tree_jit& tjit, tree& tee);
 
+  jit_info (tree_jit& tjit, tree& tee, const octave_value& for_bounds);
+
   ~jit_info (void);
 
-  bool execute (void) const;
+  bool execute (const vmap& extra_vars = vmap ()) const;
 
-  bool match (void) const;
+  bool match (const vmap& extra_vars = vmap ()) const;
 private:
   typedef jit_convert::type_bound type_bound;
   typedef jit_convert::type_bound_vector type_bound_vector;
   typedef void (*jited_function)(octave_base_value**);
 
+  void initialize (tree_jit& tjit, jit_convert& conv);
+
+  octave_value find (const vmap& extra_vars, const std::string& vname) const;
+
   llvm::ExecutionEngine *engine;
   jited_function function;
   llvm::Function *llvm_function;