diff src/pt-jit.cc @ 14913:c7071907a641

Use symbol_record_ref instead of names in JIT * src/pt-id.h (tree_identifier::symbol): New function. * src/symtab.h (tree_identifier::symbol_record_ref::operator->): Added const variant. * src/pt-jit.h: Use symbol_record_ref * src/pt-jit.cc: Use symbol_record_ref
author Max Brister <max@2bass.com>
date Fri, 18 May 2012 10:22:34 -0600
parents 1e2196d0bea4
children cba58541954c
line wrap: on
line diff
--- a/src/pt-jit.cc
+++ b/src/pt-jit.cc
@@ -478,7 +478,6 @@
 void
 jit_typeinfo::to_generic (jit_type *type, llvm::GenericValue& gv)
 {
-  // duplication here can probably be removed somehow
   if (type == any)
     to_generic (type, gv, octave_value ());
   else if (type == scalar)
@@ -557,9 +556,6 @@
 void
 jit_infer::infer (tree_simple_for_command& cmd, jit_type *bounds)
 {
-  argin.insert ("#bounds");
-  types["#bounds"] = bounds;
-
   infer_simple_for (cmd, bounds);
 }
 
@@ -690,7 +686,8 @@
 void
 jit_infer::visit_identifier (tree_identifier& ti)
 {
-  handle_identifier (ti.name (), ti.do_lookup ());
+  symbol_table::symbol_record_ref record = ti.symbol ();
+  handle_identifier (record);
 }
 
 void
@@ -853,7 +850,9 @@
           is_lvalue = true;
           rvalue_type = type_stack.back ();
           type_stack.pop_back ();
-          handle_identifier ("ans", symbol_table::varval ("ans"));
+
+          symbol_table::symbol_record_ref record (symbol_table::insert ("ans"));
+          handle_identifier (record);
 
           if (rvalue_type != type_stack.back ())
             fail ();
@@ -946,12 +945,13 @@
 }
 
 void
-jit_infer::handle_identifier (const std::string& name, octave_value v)
+jit_infer::handle_identifier (const symbol_table::symbol_record_ref& record)
 {
-  type_map::iterator iter = types.find (name);
+  type_map::iterator iter = types.find (record);
   if (iter == types.end ())
     {
-      jit_type *ty = tinfo->type_of (v);
+      jit_type *ty = tinfo->type_of (record->find ());
+      bool argin = false;
       if (is_lvalue)
         {
           if (! ty)
@@ -961,68 +961,46 @@
         {
           if (! ty)
             fail ();
-
-          argin.insert (name);
+          argin = true;
         }
 
-      types[name] = ty;
+      types[record] = type_entry (argin, ty);
       type_stack.push_back (ty);
     }
   else
-    type_stack.push_back (iter->second);
+    type_stack.push_back (iter->second.second);
 }
 
 // -------------------- jit_generator --------------------
-jit_generator::jit_generator (jit_typeinfo *ti, llvm::Module *module, tree &tee,
-                              const std::set<std::string>& argin,
-                              const type_map& infered_types, bool have_bounds)
-  : tinfo (ti), is_lvalue (false)
+jit_generator::jit_generator (jit_typeinfo *ti, llvm::Module *mod,
+                              tree_simple_for_command& cmd, jit_type *bounds,
+                              const type_map& infered_types)
+  : tinfo (ti), module (mod), is_lvalue (false)
 {
-  // determine the function type through the type of all variables
-  std::vector<llvm::Type *> arg_types (infered_types.size ());
-  size_t idx = 0;
+  // create new vectors that include bounds
+  std::vector<std::string> names (infered_types.size () + 1);
+  std::vector<bool> argin (infered_types.size () + 1);
+  std::vector<jit_type *> types (infered_types.size () + 1);
+  names[0] = "#bounds";
+  argin[0] = true;
+  types[0] = bounds;
+  size_t i;
   type_map::const_iterator iter;
-  for (iter = infered_types.begin (); iter != infered_types.end (); ++iter, ++idx)
-    arg_types[idx] = iter->second->to_llvm_arg ();
-
-  // now create the LLVM function from our determined types
-  llvm::LLVMContext &ctx = llvm::getGlobalContext ();
-  llvm::Type *tvoid = llvm::Type::getVoidTy (ctx);
-  llvm::FunctionType *ft = llvm::FunctionType::get (tvoid, arg_types, false);
-  function = llvm::Function::Create (ft, llvm::Function::ExternalLinkage,
-                                     "foobar", module);
-
-  // declare each argument and copy its initial value
-  llvm::BasicBlock *body = llvm::BasicBlock::Create (ctx, "body", function);
-  builder.SetInsertPoint (body);
-  llvm::Function::arg_iterator arg_iter = function->arg_begin();
-  for (iter = infered_types.begin (); iter != infered_types.end ();
-       ++iter, ++arg_iter)
-
+  for (i = 1, iter = infered_types.begin (); iter != infered_types.end ();
+       ++i, ++iter)
     {
-      llvm::Type *vartype = iter->second->to_llvm ();
-      llvm::Value *var = builder.CreateAlloca (vartype, 0, iter->first);
-      variables[iter->first] = value (iter->second, var);
-
-      if (iter->second->force_init () || argin.count (iter->first))
-        {
-          llvm::Value *loaded_arg = builder.CreateLoad (arg_iter);
-          builder.CreateStore (loaded_arg, var);
-        }
+      names[i] = iter->first.name ();
+      argin[i] = iter->second.first;
+      types[i] = iter->second.second;
     }
 
-  // generate body
+  initialize (names, argin, types);
+
   try
     {
-      tree_simple_for_command *cmd = dynamic_cast<tree_simple_for_command*>(&tee);
-      if (have_bounds && cmd)
-        {
-          value bounds = variables["#bounds"];
-          bounds.second = builder.CreateLoad (bounds.second);
-          emit_simple_for (*cmd, bounds, true);
-        }
-      else
-        tee.accept (*this);
+      value var_bounds = variables["#bounds"];
+      var_bounds.second = builder.CreateLoad (var_bounds.second);
+      emit_simple_for (cmd, var_bounds, true);
     }
   catch (const jit_fail_exception&)
     {
@@ -1031,16 +1009,7 @@
       return;
     }
 
-  // copy computed values back into arguments
-  arg_iter = function->arg_begin ();
-  for (iter = infered_types.begin (); iter != infered_types.end ();
-       ++iter, ++arg_iter)
-    {
-      llvm::Value *var = variables[iter->first].second;
-      llvm::Value *loaded_var = builder.CreateLoad (var);
-      builder.CreateStore (loaded_var, arg_iter);
-    }
-  builder.CreateRetVoid ();
+  finalize (names);
 }
 
 void
@@ -1513,6 +1482,56 @@
   builder.CreateCall2 (ol.function, str, v.second);
 }
 
+void
+jit_generator::initialize (const std::vector<std::string>& names,
+                           const std::vector<bool>& argin,
+                           const std::vector<jit_type *> types)
+{
+  std::vector<llvm::Type *> arg_types (names.size ());
+  for (size_t i = 0; i < types.size (); ++i)
+    arg_types[i] = types[i]->to_llvm_arg ();
+
+  llvm::LLVMContext &ctx = llvm::getGlobalContext ();
+  llvm::Type *tvoid = llvm::Type::getVoidTy (ctx);
+  llvm::FunctionType *ft = llvm::FunctionType::get (tvoid, arg_types, false);
+  function = llvm::Function::Create (ft, llvm::Function::ExternalLinkage,
+                                     "foobar", module);
+
+  // create variables and copy initial values
+  llvm::BasicBlock *body = llvm::BasicBlock::Create (ctx, "body", function);
+  builder.SetInsertPoint (body);
+  llvm::Function::arg_iterator arg_iter = function->arg_begin();
+  for (size_t i = 0; i < names.size (); ++i, ++arg_iter)
+    {
+      llvm::Type *vartype = types[i]->to_llvm ();
+      const std::string& name = names[i];
+      llvm::Value *var = builder.CreateAlloca (vartype, 0, name);
+      variables[name] = value (types[i], var);
+
+      if (argin[i] || types[i]->force_init ())
+        {
+          llvm::Value *loaded_arg = builder.CreateLoad (arg_iter);
+          builder.CreateStore (loaded_arg, var);
+        }
+    }
+}
+
+void
+jit_generator::finalize (const std::vector<std::string>& names)
+{
+  // copy computed values back into arguments
+  // we use names instead of looping through variables because order is
+  // important
+  llvm::Function::arg_iterator arg_iter = function->arg_begin();
+  for (size_t i = 0; i < names.size (); ++i, ++arg_iter)
+    {
+      llvm::Value *var = variables[names[i]].second;
+      llvm::Value *loaded_var = builder.CreateLoad (var);
+      builder.CreateStore (loaded_var, arg_iter);
+    }
+  builder.CreateRetVoid ();
+}
+
 // -------------------- tree_jit --------------------
 
 tree_jit::tree_jit (void) : context (llvm::getGlobalContext ()), engine (0)
@@ -1584,7 +1603,8 @@
 // -------------------- jit_info --------------------
 jit_info::jit_info (tree_jit& tjit, tree_simple_for_command& cmd,
                     jit_type *bounds) : tinfo (tjit.get_typeinfo ()),
-                                        engine (tjit.get_engine ())
+                                        engine (tjit.get_engine ()),
+                                        bounds_t (bounds)
 {
   jit_infer infer(tinfo);
 
@@ -1598,10 +1618,9 @@
       return;
     }
 
-  argin = infer.get_argin ();
   types = infer.get_types ();
 
-  jit_generator gen(tinfo, tjit.get_module (), cmd, argin, types);
+  jit_generator gen(tinfo, tjit.get_module (), cmd, bounds, types);
   function = gen.get_function ();
 
   if (function)
@@ -1635,31 +1654,29 @@
   if (! function)
     return false;
 
-  std::vector<llvm::GenericValue> args (types.size ());
+  std::vector<llvm::GenericValue> args (types.size () + 1);
+  tinfo->to_generic (bounds_t, args[0], bounds);
+
   size_t idx;
   type_map::const_iterator iter;
-  for (idx = 0, iter = types.begin (); iter != types.end (); ++iter, ++idx)
+  for (idx = 1, iter = types.begin (); iter != types.end (); ++iter, ++idx)
     {
-      if (argin.count (iter->first))
+      if (iter->second.first) // argin?
         {
-          octave_value ov;
-          if (iter->first == "#bounds")
-            ov = bounds;
-          else
-            ov = symbol_table::varval (iter->first);
-
-          tinfo->to_generic (iter->second, args[idx], ov);
+          octave_value ov = iter->first->varval ();
+          tinfo->to_generic (iter->second.second, args[idx], ov);
         }
       else
-        tinfo->to_generic (iter->second, args[idx]);
+        tinfo->to_generic (iter->second.second, args[idx]);
     }
 
   engine->runFunction (function, args);
 
-  for (idx = 0, iter = types.begin (); iter != types.end (); ++iter, ++idx)
+  for (idx = 1, iter = types.begin (); iter != types.end (); ++iter, ++idx)
     {
-      octave_value result = tinfo->to_octave_value (iter->second, args[idx]);
-      symbol_table::varref (iter->first) = result;
+      octave_value result = tinfo->to_octave_value (iter->second.second, args[idx]);
+      octave_value &ref = iter->first->varref ();
+      ref = result;
     }
 
   tinfo->reset_generic ();
@@ -1670,19 +1687,20 @@
 bool
 jit_info::match () const
 {
-  for (std::set<std::string>::iterator iter = argin.begin ();
-       iter != argin.end (); ++iter)
+  for (type_map::const_iterator iter = types.begin (); iter != types.end ();
+       ++iter)
+       
     {
-      if (*iter == "#bounds")
-        continue;
+      if (iter->second.first) // argin?
+        {
+          jit_type *required_type = iter->second.second;
+          octave_value val = iter->first->varval ();
+          jit_type *current_type = tinfo->type_of (val);
 
-      jit_type *required_type = types.find (*iter)->second;
-      octave_value val = symbol_table::varref (*iter);
-      jit_type *current_type = tinfo->type_of (val);
-
-      // FIXME: should be: ! required_type->is_parent (current_type)
-      if (required_type != current_type)
-        return false;
+          // FIXME: should be: ! required_type->is_parent (current_type)
+          if (required_type != current_type)
+            return false;
+        }
     }
 
   return true;