# HG changeset patch # User Max Brister # Date 1337358154 21600 # Node ID c7071907a641ea2926ec3b875d06aa23449a2ebe # Parent 3d3c002ccc60f8c2edfd6a29c5435b2684caa84d Use symbol_record_ref instead of names in JIT * src/pt-id.h (tree_identifier::symbol): New function. * src/symtab.h (tree_identifier::symbol_record_ref::operator->): Added const variant. * src/pt-jit.h: Use symbol_record_ref * src/pt-jit.cc: Use symbol_record_ref diff --git a/src/pt-id.h b/src/pt-id.h --- a/src/pt-id.h +++ b/src/pt-id.h @@ -114,6 +114,10 @@ void accept (tree_walker& tw); + symbol_table::symbol_record_ref symbol (void) const + { + return sym; + } private: // The symbol record that this identifier references. diff --git a/src/pt-jit.cc b/src/pt-jit.cc --- a/src/pt-jit.cc +++ b/src/pt-jit.cc @@ -478,7 +478,6 @@ void jit_typeinfo::to_generic (jit_type *type, llvm::GenericValue& gv) { - // duplication here can probably be removed somehow if (type == any) to_generic (type, gv, octave_value ()); else if (type == scalar) @@ -557,9 +556,6 @@ void jit_infer::infer (tree_simple_for_command& cmd, jit_type *bounds) { - argin.insert ("#bounds"); - types["#bounds"] = bounds; - infer_simple_for (cmd, bounds); } @@ -690,7 +686,8 @@ void jit_infer::visit_identifier (tree_identifier& ti) { - handle_identifier (ti.name (), ti.do_lookup ()); + symbol_table::symbol_record_ref record = ti.symbol (); + handle_identifier (record); } void @@ -853,7 +850,9 @@ is_lvalue = true; rvalue_type = type_stack.back (); type_stack.pop_back (); - handle_identifier ("ans", symbol_table::varval ("ans")); + + symbol_table::symbol_record_ref record (symbol_table::insert ("ans")); + handle_identifier (record); if (rvalue_type != type_stack.back ()) fail (); @@ -946,12 +945,13 @@ } void -jit_infer::handle_identifier (const std::string& name, octave_value v) +jit_infer::handle_identifier (const symbol_table::symbol_record_ref& record) { - type_map::iterator iter = types.find (name); + type_map::iterator iter = types.find (record); if (iter == types.end ()) { - jit_type *ty = tinfo->type_of (v); + jit_type *ty = tinfo->type_of (record->find ()); + bool argin = false; if (is_lvalue) { if (! ty) @@ -961,68 +961,46 @@ { if (! ty) fail (); - - argin.insert (name); + argin = true; } - types[name] = ty; + types[record] = type_entry (argin, ty); type_stack.push_back (ty); } else - type_stack.push_back (iter->second); + type_stack.push_back (iter->second.second); } // -------------------- jit_generator -------------------- -jit_generator::jit_generator (jit_typeinfo *ti, llvm::Module *module, tree &tee, - const std::set& argin, - const type_map& infered_types, bool have_bounds) - : tinfo (ti), is_lvalue (false) +jit_generator::jit_generator (jit_typeinfo *ti, llvm::Module *mod, + tree_simple_for_command& cmd, jit_type *bounds, + const type_map& infered_types) + : tinfo (ti), module (mod), is_lvalue (false) { - // determine the function type through the type of all variables - std::vector arg_types (infered_types.size ()); - size_t idx = 0; + // create new vectors that include bounds + std::vector names (infered_types.size () + 1); + std::vector argin (infered_types.size () + 1); + std::vector types (infered_types.size () + 1); + names[0] = "#bounds"; + argin[0] = true; + types[0] = bounds; + size_t i; type_map::const_iterator iter; - for (iter = infered_types.begin (); iter != infered_types.end (); ++iter, ++idx) - arg_types[idx] = iter->second->to_llvm_arg (); - - // now create the LLVM function from our determined types - llvm::LLVMContext &ctx = llvm::getGlobalContext (); - llvm::Type *tvoid = llvm::Type::getVoidTy (ctx); - llvm::FunctionType *ft = llvm::FunctionType::get (tvoid, arg_types, false); - function = llvm::Function::Create (ft, llvm::Function::ExternalLinkage, - "foobar", module); - - // declare each argument and copy its initial value - llvm::BasicBlock *body = llvm::BasicBlock::Create (ctx, "body", function); - builder.SetInsertPoint (body); - llvm::Function::arg_iterator arg_iter = function->arg_begin(); - for (iter = infered_types.begin (); iter != infered_types.end (); - ++iter, ++arg_iter) - + for (i = 1, iter = infered_types.begin (); iter != infered_types.end (); + ++i, ++iter) { - llvm::Type *vartype = iter->second->to_llvm (); - llvm::Value *var = builder.CreateAlloca (vartype, 0, iter->first); - variables[iter->first] = value (iter->second, var); - - if (iter->second->force_init () || argin.count (iter->first)) - { - llvm::Value *loaded_arg = builder.CreateLoad (arg_iter); - builder.CreateStore (loaded_arg, var); - } + names[i] = iter->first.name (); + argin[i] = iter->second.first; + types[i] = iter->second.second; } - // generate body + initialize (names, argin, types); + try { - tree_simple_for_command *cmd = dynamic_cast(&tee); - if (have_bounds && cmd) - { - value bounds = variables["#bounds"]; - bounds.second = builder.CreateLoad (bounds.second); - emit_simple_for (*cmd, bounds, true); - } - else - tee.accept (*this); + value var_bounds = variables["#bounds"]; + var_bounds.second = builder.CreateLoad (var_bounds.second); + emit_simple_for (cmd, var_bounds, true); } catch (const jit_fail_exception&) { @@ -1031,16 +1009,7 @@ return; } - // copy computed values back into arguments - arg_iter = function->arg_begin (); - for (iter = infered_types.begin (); iter != infered_types.end (); - ++iter, ++arg_iter) - { - llvm::Value *var = variables[iter->first].second; - llvm::Value *loaded_var = builder.CreateLoad (var); - builder.CreateStore (loaded_var, arg_iter); - } - builder.CreateRetVoid (); + finalize (names); } void @@ -1513,6 +1482,56 @@ builder.CreateCall2 (ol.function, str, v.second); } +void +jit_generator::initialize (const std::vector& names, + const std::vector& argin, + const std::vector types) +{ + std::vector arg_types (names.size ()); + for (size_t i = 0; i < types.size (); ++i) + arg_types[i] = types[i]->to_llvm_arg (); + + llvm::LLVMContext &ctx = llvm::getGlobalContext (); + llvm::Type *tvoid = llvm::Type::getVoidTy (ctx); + llvm::FunctionType *ft = llvm::FunctionType::get (tvoid, arg_types, false); + function = llvm::Function::Create (ft, llvm::Function::ExternalLinkage, + "foobar", module); + + // create variables and copy initial values + llvm::BasicBlock *body = llvm::BasicBlock::Create (ctx, "body", function); + builder.SetInsertPoint (body); + llvm::Function::arg_iterator arg_iter = function->arg_begin(); + for (size_t i = 0; i < names.size (); ++i, ++arg_iter) + { + llvm::Type *vartype = types[i]->to_llvm (); + const std::string& name = names[i]; + llvm::Value *var = builder.CreateAlloca (vartype, 0, name); + variables[name] = value (types[i], var); + + if (argin[i] || types[i]->force_init ()) + { + llvm::Value *loaded_arg = builder.CreateLoad (arg_iter); + builder.CreateStore (loaded_arg, var); + } + } +} + +void +jit_generator::finalize (const std::vector& names) +{ + // copy computed values back into arguments + // we use names instead of looping through variables because order is + // important + llvm::Function::arg_iterator arg_iter = function->arg_begin(); + for (size_t i = 0; i < names.size (); ++i, ++arg_iter) + { + llvm::Value *var = variables[names[i]].second; + llvm::Value *loaded_var = builder.CreateLoad (var); + builder.CreateStore (loaded_var, arg_iter); + } + builder.CreateRetVoid (); +} + // -------------------- tree_jit -------------------- tree_jit::tree_jit (void) : context (llvm::getGlobalContext ()), engine (0) @@ -1584,7 +1603,8 @@ // -------------------- jit_info -------------------- jit_info::jit_info (tree_jit& tjit, tree_simple_for_command& cmd, jit_type *bounds) : tinfo (tjit.get_typeinfo ()), - engine (tjit.get_engine ()) + engine (tjit.get_engine ()), + bounds_t (bounds) { jit_infer infer(tinfo); @@ -1598,10 +1618,9 @@ return; } - argin = infer.get_argin (); types = infer.get_types (); - jit_generator gen(tinfo, tjit.get_module (), cmd, argin, types); + jit_generator gen(tinfo, tjit.get_module (), cmd, bounds, types); function = gen.get_function (); if (function) @@ -1635,31 +1654,29 @@ if (! function) return false; - std::vector args (types.size ()); + std::vector args (types.size () + 1); + tinfo->to_generic (bounds_t, args[0], bounds); + size_t idx; type_map::const_iterator iter; - for (idx = 0, iter = types.begin (); iter != types.end (); ++iter, ++idx) + for (idx = 1, iter = types.begin (); iter != types.end (); ++iter, ++idx) { - if (argin.count (iter->first)) + if (iter->second.first) // argin? { - octave_value ov; - if (iter->first == "#bounds") - ov = bounds; - else - ov = symbol_table::varval (iter->first); - - tinfo->to_generic (iter->second, args[idx], ov); + octave_value ov = iter->first->varval (); + tinfo->to_generic (iter->second.second, args[idx], ov); } else - tinfo->to_generic (iter->second, args[idx]); + tinfo->to_generic (iter->second.second, args[idx]); } engine->runFunction (function, args); - for (idx = 0, iter = types.begin (); iter != types.end (); ++iter, ++idx) + for (idx = 1, iter = types.begin (); iter != types.end (); ++iter, ++idx) { - octave_value result = tinfo->to_octave_value (iter->second, args[idx]); - symbol_table::varref (iter->first) = result; + octave_value result = tinfo->to_octave_value (iter->second.second, args[idx]); + octave_value &ref = iter->first->varref (); + ref = result; } tinfo->reset_generic (); @@ -1670,19 +1687,20 @@ bool jit_info::match () const { - for (std::set::iterator iter = argin.begin (); - iter != argin.end (); ++iter) + for (type_map::const_iterator iter = types.begin (); iter != types.end (); + ++iter) + { - if (*iter == "#bounds") - continue; + if (iter->second.first) // argin? + { + jit_type *required_type = iter->second.second; + octave_value val = iter->first->varval (); + jit_type *current_type = tinfo->type_of (val); - jit_type *required_type = types.find (*iter)->second; - octave_value val = symbol_table::varref (*iter); - jit_type *current_type = tinfo->type_of (val); - - // FIXME: should be: ! required_type->is_parent (current_type) - if (required_type != current_type) - return false; + // FIXME: should be: ! required_type->is_parent (current_type) + if (required_type != current_type) + return false; + } } return true; diff --git a/src/pt-jit.h b/src/pt-jit.h --- a/src/pt-jit.h +++ b/src/pt-jit.h @@ -32,6 +32,7 @@ #include "Array.h" #include "Range.h" #include "pt-walk.h" +#include "symtab.h" // -------------------- Current status -------------------- // Simple binary operations (+-*/) on octave_scalar's (doubles) are optimized. @@ -295,8 +296,6 @@ void reset_generic (void); private: - typedef std::map type_map; - jit_type *new_type (const std::string& name, bool force_init, jit_type *parent, llvm::Type *llvm_type); @@ -332,14 +331,16 @@ class jit_infer : public tree_walker { - typedef std::map type_map; public: + // pair + typedef std::pair type_entry; + typedef std::map type_map; + jit_infer (jit_typeinfo *ti) : tinfo (ti), is_lvalue (false), rvalue_type (0) {} - const std::set& get_argin () const { return argin; } - const type_map& get_types () const { return types; } void infer (tree_simple_for_command& cmd, jit_type *bounds); @@ -433,7 +434,7 @@ void infer_simple_for (tree_simple_for_command& cmd, jit_type *bounds); - void handle_identifier (const std::string& name, octave_value v); + void handle_identifier (const symbol_table::symbol_record_ref& record); jit_typeinfo *tinfo; @@ -441,7 +442,6 @@ jit_type *rvalue_type; type_map types; - std::set argin; std::vector type_stack; }; @@ -449,11 +449,11 @@ class jit_generator : public tree_walker { - typedef std::map type_map; public: - jit_generator (jit_typeinfo *ti, llvm::Module *module, tree &tee, - const std::set& argin, - const type_map& infered_types, bool have_bounds = true); + typedef jit_infer::type_map type_map; + + jit_generator (jit_typeinfo *ti, llvm::Module *mod, tree_simple_for_command &cmd, + jit_type *bounds, const type_map& infered_types); llvm::Function *get_function () const { return function; } @@ -555,7 +555,14 @@ value_stack.push_back (value (type, v)); } + void initialize (const std::vector& names, + const std::vector& argin, + const std::vector types); + + void finalize (const std::vector& names); + jit_typeinfo *tinfo; + llvm::Module *module; llvm::Function *function; bool is_lvalue; @@ -596,19 +603,19 @@ jit_info { public: + typedef jit_infer::type_map type_map; + jit_info (tree_jit& tjit, tree_simple_for_command& cmd, jit_type *bounds); bool execute (const octave_value& bounds) const; bool match (void) const; private: - typedef std::map type_map; - jit_typeinfo *tinfo; llvm::ExecutionEngine *engine; - std::set argin; type_map types; llvm::Function *function; + jit_type *bounds_t; }; #endif diff --git a/src/symtab.h b/src/symtab.h --- a/src/symtab.h +++ b/src/symtab.h @@ -610,6 +610,12 @@ return &sym; } + symbol_record *operator-> (void) const + { + update (); + return &sym; + } + // can be used to place symbol_record_ref in maps, we don't overload < as // it doesn't make any sense for symbol_record_ref struct comparator @@ -621,7 +627,7 @@ } }; private: - void update (void) + void update (void) const { scope_id curr_scope = symbol_table::current_scope (); if (scope != curr_scope || ! sym.is_valid ()) @@ -631,8 +637,8 @@ } } - scope_id scope; - symbol_record sym; + mutable scope_id scope; + mutable symbol_record sym; }; class