Mercurial > hg > octave-lyh
diff src/pt-jit.cc @ 14913:c7071907a641
Use symbol_record_ref instead of names in JIT
* src/pt-id.h (tree_identifier::symbol): New function.
* src/symtab.h (tree_identifier::symbol_record_ref::operator->):
Added const variant.
* src/pt-jit.h: Use symbol_record_ref
* src/pt-jit.cc: Use symbol_record_ref
author | Max Brister <max@2bass.com> |
---|---|
date | Fri, 18 May 2012 10:22:34 -0600 |
parents | 1e2196d0bea4 |
children | cba58541954c |
line wrap: on
line diff
--- a/src/pt-jit.cc +++ b/src/pt-jit.cc @@ -478,7 +478,6 @@ void jit_typeinfo::to_generic (jit_type *type, llvm::GenericValue& gv) { - // duplication here can probably be removed somehow if (type == any) to_generic (type, gv, octave_value ()); else if (type == scalar) @@ -557,9 +556,6 @@ void jit_infer::infer (tree_simple_for_command& cmd, jit_type *bounds) { - argin.insert ("#bounds"); - types["#bounds"] = bounds; - infer_simple_for (cmd, bounds); } @@ -690,7 +686,8 @@ void jit_infer::visit_identifier (tree_identifier& ti) { - handle_identifier (ti.name (), ti.do_lookup ()); + symbol_table::symbol_record_ref record = ti.symbol (); + handle_identifier (record); } void @@ -853,7 +850,9 @@ is_lvalue = true; rvalue_type = type_stack.back (); type_stack.pop_back (); - handle_identifier ("ans", symbol_table::varval ("ans")); + + symbol_table::symbol_record_ref record (symbol_table::insert ("ans")); + handle_identifier (record); if (rvalue_type != type_stack.back ()) fail (); @@ -946,12 +945,13 @@ } void -jit_infer::handle_identifier (const std::string& name, octave_value v) +jit_infer::handle_identifier (const symbol_table::symbol_record_ref& record) { - type_map::iterator iter = types.find (name); + type_map::iterator iter = types.find (record); if (iter == types.end ()) { - jit_type *ty = tinfo->type_of (v); + jit_type *ty = tinfo->type_of (record->find ()); + bool argin = false; if (is_lvalue) { if (! ty) @@ -961,68 +961,46 @@ { if (! ty) fail (); - - argin.insert (name); + argin = true; } - types[name] = ty; + types[record] = type_entry (argin, ty); type_stack.push_back (ty); } else - type_stack.push_back (iter->second); + type_stack.push_back (iter->second.second); } // -------------------- jit_generator -------------------- -jit_generator::jit_generator (jit_typeinfo *ti, llvm::Module *module, tree &tee, - const std::set<std::string>& argin, - const type_map& infered_types, bool have_bounds) - : tinfo (ti), is_lvalue (false) +jit_generator::jit_generator (jit_typeinfo *ti, llvm::Module *mod, + tree_simple_for_command& cmd, jit_type *bounds, + const type_map& infered_types) + : tinfo (ti), module (mod), is_lvalue (false) { - // determine the function type through the type of all variables - std::vector<llvm::Type *> arg_types (infered_types.size ()); - size_t idx = 0; + // create new vectors that include bounds + std::vector<std::string> names (infered_types.size () + 1); + std::vector<bool> argin (infered_types.size () + 1); + std::vector<jit_type *> types (infered_types.size () + 1); + names[0] = "#bounds"; + argin[0] = true; + types[0] = bounds; + size_t i; type_map::const_iterator iter; - for (iter = infered_types.begin (); iter != infered_types.end (); ++iter, ++idx) - arg_types[idx] = iter->second->to_llvm_arg (); - - // now create the LLVM function from our determined types - llvm::LLVMContext &ctx = llvm::getGlobalContext (); - llvm::Type *tvoid = llvm::Type::getVoidTy (ctx); - llvm::FunctionType *ft = llvm::FunctionType::get (tvoid, arg_types, false); - function = llvm::Function::Create (ft, llvm::Function::ExternalLinkage, - "foobar", module); - - // declare each argument and copy its initial value - llvm::BasicBlock *body = llvm::BasicBlock::Create (ctx, "body", function); - builder.SetInsertPoint (body); - llvm::Function::arg_iterator arg_iter = function->arg_begin(); - for (iter = infered_types.begin (); iter != infered_types.end (); - ++iter, ++arg_iter) - + for (i = 1, iter = infered_types.begin (); iter != infered_types.end (); + ++i, ++iter) { - llvm::Type *vartype = iter->second->to_llvm (); - llvm::Value *var = builder.CreateAlloca (vartype, 0, iter->first); - variables[iter->first] = value (iter->second, var); - - if (iter->second->force_init () || argin.count (iter->first)) - { - llvm::Value *loaded_arg = builder.CreateLoad (arg_iter); - builder.CreateStore (loaded_arg, var); - } + names[i] = iter->first.name (); + argin[i] = iter->second.first; + types[i] = iter->second.second; } - // generate body + initialize (names, argin, types); + try { - tree_simple_for_command *cmd = dynamic_cast<tree_simple_for_command*>(&tee); - if (have_bounds && cmd) - { - value bounds = variables["#bounds"]; - bounds.second = builder.CreateLoad (bounds.second); - emit_simple_for (*cmd, bounds, true); - } - else - tee.accept (*this); + value var_bounds = variables["#bounds"]; + var_bounds.second = builder.CreateLoad (var_bounds.second); + emit_simple_for (cmd, var_bounds, true); } catch (const jit_fail_exception&) { @@ -1031,16 +1009,7 @@ return; } - // copy computed values back into arguments - arg_iter = function->arg_begin (); - for (iter = infered_types.begin (); iter != infered_types.end (); - ++iter, ++arg_iter) - { - llvm::Value *var = variables[iter->first].second; - llvm::Value *loaded_var = builder.CreateLoad (var); - builder.CreateStore (loaded_var, arg_iter); - } - builder.CreateRetVoid (); + finalize (names); } void @@ -1513,6 +1482,56 @@ builder.CreateCall2 (ol.function, str, v.second); } +void +jit_generator::initialize (const std::vector<std::string>& names, + const std::vector<bool>& argin, + const std::vector<jit_type *> types) +{ + std::vector<llvm::Type *> arg_types (names.size ()); + for (size_t i = 0; i < types.size (); ++i) + arg_types[i] = types[i]->to_llvm_arg (); + + llvm::LLVMContext &ctx = llvm::getGlobalContext (); + llvm::Type *tvoid = llvm::Type::getVoidTy (ctx); + llvm::FunctionType *ft = llvm::FunctionType::get (tvoid, arg_types, false); + function = llvm::Function::Create (ft, llvm::Function::ExternalLinkage, + "foobar", module); + + // create variables and copy initial values + llvm::BasicBlock *body = llvm::BasicBlock::Create (ctx, "body", function); + builder.SetInsertPoint (body); + llvm::Function::arg_iterator arg_iter = function->arg_begin(); + for (size_t i = 0; i < names.size (); ++i, ++arg_iter) + { + llvm::Type *vartype = types[i]->to_llvm (); + const std::string& name = names[i]; + llvm::Value *var = builder.CreateAlloca (vartype, 0, name); + variables[name] = value (types[i], var); + + if (argin[i] || types[i]->force_init ()) + { + llvm::Value *loaded_arg = builder.CreateLoad (arg_iter); + builder.CreateStore (loaded_arg, var); + } + } +} + +void +jit_generator::finalize (const std::vector<std::string>& names) +{ + // copy computed values back into arguments + // we use names instead of looping through variables because order is + // important + llvm::Function::arg_iterator arg_iter = function->arg_begin(); + for (size_t i = 0; i < names.size (); ++i, ++arg_iter) + { + llvm::Value *var = variables[names[i]].second; + llvm::Value *loaded_var = builder.CreateLoad (var); + builder.CreateStore (loaded_var, arg_iter); + } + builder.CreateRetVoid (); +} + // -------------------- tree_jit -------------------- tree_jit::tree_jit (void) : context (llvm::getGlobalContext ()), engine (0) @@ -1584,7 +1603,8 @@ // -------------------- jit_info -------------------- jit_info::jit_info (tree_jit& tjit, tree_simple_for_command& cmd, jit_type *bounds) : tinfo (tjit.get_typeinfo ()), - engine (tjit.get_engine ()) + engine (tjit.get_engine ()), + bounds_t (bounds) { jit_infer infer(tinfo); @@ -1598,10 +1618,9 @@ return; } - argin = infer.get_argin (); types = infer.get_types (); - jit_generator gen(tinfo, tjit.get_module (), cmd, argin, types); + jit_generator gen(tinfo, tjit.get_module (), cmd, bounds, types); function = gen.get_function (); if (function) @@ -1635,31 +1654,29 @@ if (! function) return false; - std::vector<llvm::GenericValue> args (types.size ()); + std::vector<llvm::GenericValue> args (types.size () + 1); + tinfo->to_generic (bounds_t, args[0], bounds); + size_t idx; type_map::const_iterator iter; - for (idx = 0, iter = types.begin (); iter != types.end (); ++iter, ++idx) + for (idx = 1, iter = types.begin (); iter != types.end (); ++iter, ++idx) { - if (argin.count (iter->first)) + if (iter->second.first) // argin? { - octave_value ov; - if (iter->first == "#bounds") - ov = bounds; - else - ov = symbol_table::varval (iter->first); - - tinfo->to_generic (iter->second, args[idx], ov); + octave_value ov = iter->first->varval (); + tinfo->to_generic (iter->second.second, args[idx], ov); } else - tinfo->to_generic (iter->second, args[idx]); + tinfo->to_generic (iter->second.second, args[idx]); } engine->runFunction (function, args); - for (idx = 0, iter = types.begin (); iter != types.end (); ++iter, ++idx) + for (idx = 1, iter = types.begin (); iter != types.end (); ++iter, ++idx) { - octave_value result = tinfo->to_octave_value (iter->second, args[idx]); - symbol_table::varref (iter->first) = result; + octave_value result = tinfo->to_octave_value (iter->second.second, args[idx]); + octave_value &ref = iter->first->varref (); + ref = result; } tinfo->reset_generic (); @@ -1670,19 +1687,20 @@ bool jit_info::match () const { - for (std::set<std::string>::iterator iter = argin.begin (); - iter != argin.end (); ++iter) + for (type_map::const_iterator iter = types.begin (); iter != types.end (); + ++iter) + { - if (*iter == "#bounds") - continue; + if (iter->second.first) // argin? + { + jit_type *required_type = iter->second.second; + octave_value val = iter->first->varval (); + jit_type *current_type = tinfo->type_of (val); - jit_type *required_type = types.find (*iter)->second; - octave_value val = symbol_table::varref (*iter); - jit_type *current_type = tinfo->type_of (val); - - // FIXME: should be: ! required_type->is_parent (current_type) - if (required_type != current_type) - return false; + // FIXME: should be: ! required_type->is_parent (current_type) + if (required_type != current_type) + return false; + } } return true;