Mercurial > hg > octave-lyh
diff libinterp/interp-core/pt-jit.cc @ 15337:3f43e9d6d86e
JIT compile anonymous functions
* jit-ir.h (jit_block::front, jit_block::back): New functions.
(jit_call::jit_call): New overloads.
(jit_return): New class.
* jit-typeinfo.cc (octave_jit_create_undef): New function.
(jit_operation::to_idx): Correctly handle empty type vector.
(jit_typeinfo::jit_typeinfo): Add destroy_fn and initialize create_undef.
* jit-typeinfo.h (jit_typeinfo::get_any_ptr, jit_typeinfo::destroy,
jit_typeinfo::create_undef): New functions.
* pt-jit.cc (jit_convert::jit_convert): Add overload and refactor.
(jit_convert::initialize, jit_convert_llvm::convert_loop,
jit_convert_llvm::convert_function, tree_jit::do_execute,
jit_function_info::jit_function_info, jit_function_info::execute,
jit_function_info::match): New functions.
(jit_convert::get_variable): Support function variable lookup.
(jit_convert_llvm::convert): Handle loop/function agnostic stuff.
(jit_convert_llvm::visit): Handle function creation as well.
(tree_jit::execute): Move implementation to tree_jit::do_execute.
(jit_info::compile): Call convert_loop instead of convert.
* pt-jit.h (jit_convert::jit_convert): New overload.
(jit_convert::initialize, jit_convert_llvm::convert_loop,
jit_convert_llvm::convert_function, tree_jit::do_execute): New functions.
(jit_convert::create_variable, jit_convert_llvm::initialize): Update signature.
(tree_jit::execute): Made static.
(tree_jit::tree_jit): Made private.
(jit_function_info): New class.
* ov-usr-fcn.cc (octave_user_function::~octave_user_function): Delete jit_info.
(octave_user_function::octave_user_function): Maybe JIT and use is_special_expr
and special_expr.
(octave_user_function::special_expr): New function.
* ov-usr-fcn.h (octave_user_function::is_special_expr,
octave_user_function::special_expr, octave_user_function::get_info,
octave_user_function::stash_info): New functions.
* pt-decl.h (tree_decl_elt::name): New function.
* pt-eval.cc (tree_evaluator::visit_simple_for_command,
tree_evaluator::visit_while_command): Use static tree_jit methods.
author | Max Brister <max@2bass.com> |
---|---|
date | Sun, 09 Sep 2012 00:29:00 -0600 |
parents | 8125773322d4 |
children | b49d707fe9d7 |
line wrap: on
line diff
--- a/libinterp/interp-core/pt-jit.cc +++ b/libinterp/interp-core/pt-jit.cc @@ -65,22 +65,16 @@ // -------------------- jit_convert -------------------- jit_convert::jit_convert (tree &tee, jit_type *for_bounds) - : iterator_count (0), for_bounds_count (0), short_count (0), breaking (false) + : converting_function (false) { - jit_instruction::reset_ids (); - - entry_block = factory.create<jit_block> ("body"); - final_block = factory.create<jit_block> ("final"); - blocks.push_back (entry_block); - entry_block->mark_alive (); - block = entry_block; + initialize (symbol_table::current_scope ()); if (for_bounds) create_variable (next_for_bounds (false), for_bounds); visit (tee); - // FIXME: Remove if we no longer only compile loops + // breaks must have been handled by the top level loop assert (! breaking); assert (breaks.empty ()); assert (continues.empty ()); @@ -95,6 +89,91 @@ if (name.size () && name[0] != '#') final_block->append (factory.create<jit_store_argument> (var)); } + + final_block->append (factory.create<jit_return> ()); +} + +jit_convert::jit_convert (octave_user_function& fcn, + const std::vector<jit_type *>& args) + : converting_function (true) +{ + initialize (fcn.scope ()); + + tree_parameter_list *plist = fcn.parameter_list (); + tree_parameter_list *rlist = fcn.return_list (); + if (plist && plist->takes_varargs ()) + throw jit_fail_exception ("varags not supported"); + + if (rlist && (rlist->size () > 1 || rlist->takes_varargs ())) + throw jit_fail_exception ("multiple returns not supported"); + + if (plist) + { + tree_parameter_list::iterator piter = plist->begin (); + for (size_t i = 0; i < args.size (); ++i, ++piter) + { + if (piter == plist->end ()) + throw jit_fail_exception ("Too many parameter to function"); + + tree_decl_elt *elt = *piter; + std::string name = elt->name (); + create_variable (name, args[i]); + } + } + + jit_value *return_value = 0; + if (fcn.is_special_expr ()) + { + tree_expression *expr = fcn.special_expr (); + if 
(expr) + { + jit_variable *retvar = get_variable ("#return"); + jit_value *retval = visit (expr); + block->append (factory.create<jit_assign> (retvar, retval)); + return_value = retvar; + } + } + else + visit_statement_list (*fcn.body ()); + + // the user may use break or continue to exit the function. Because the + // function does not start as a loop, we can have one continue, one break, or + // a regular fallthrough to exit the function + if (continues.size ()) + { + assert (! continues.size ()); + finish_breaks (final_block, continues); + } + else if (breaks.size ()) + finish_breaks (final_block, breaks); + else + block->append (factory.create<jit_branch> (final_block)); + blocks.push_back (final_block); + block = final_block; + + if (! return_value && rlist && rlist->size () == 1) + { + tree_decl_elt *elt = rlist->front (); + return_value = get_variable (elt->name ()); + } + + // FIXME: We should use live range analysis to delete variables where needed. + // For now we just delete everything at the end of the function. 
+ for (variable_map::iterator iter = vmap.begin (); iter != vmap.end (); ++iter) + { + if (iter->second != return_value) + { + jit_call *call; + call = factory.create<jit_call> (&jit_typeinfo::destroy, + iter->second); + final_block->append (call); + } + } + + if (return_value) + final_block->append (factory.create<jit_return> (return_value)); + else + final_block->append (factory.create<jit_return> ()); } void @@ -719,6 +798,23 @@ throw jit_fail_exception (); } +void +jit_convert::initialize (symbol_table::scope_id s) +{ + scope = s; + iterator_count = 0; + for_bounds_count = 0; + short_count = 0; + breaking = false; + jit_instruction::reset_ids (); + + entry_block = factory.create<jit_block> ("body"); + final_block = factory.create<jit_block> ("final"); + blocks.push_back (entry_block); + entry_block->mark_alive (); + block = entry_block; +} + jit_call * jit_convert::create_checked_impl (jit_call *ret) { @@ -749,20 +845,42 @@ if (ret) return ret; - octave_value val = symbol_table::find (vname); - jit_type *type = jit_typeinfo::type_of (val); - bounds.push_back (type_bound (type, vname)); + symbol_table::symbol_record record = symbol_table::find_symbol (vname, scope); + if (record.is_persistent () || record.is_global ()) + throw jit_fail_exception ("Persistent and global not yet supported"); - return create_variable (vname, type); + if (converting_function) + return create_variable (vname, jit_typeinfo::get_any (), false); + else + { + octave_value val = record.varval (); + jit_type *type = jit_typeinfo::type_of (val); + bounds.push_back (type_bound (type, vname)); + + return create_variable (vname, type); + } } jit_variable * -jit_convert::create_variable (const std::string& vname, jit_type *type) +jit_convert::create_variable (const std::string& vname, jit_type *type, + bool isarg) { jit_variable *var = factory.create<jit_variable> (vname); - jit_extract_argument *extract; - extract = factory.create<jit_extract_argument> (type, var); - entry_block->prepend 
(extract); + + if (isarg) + { + jit_extract_argument *extract; + extract = factory.create<jit_extract_argument> (type, var); + entry_block->prepend (extract); + } + else + { + jit_call *init = factory.create<jit_call> (&jit_typeinfo::create_undef); + jit_assign *assign = factory.create<jit_assign> (var, init); + entry_block->prepend (assign); + entry_block->prepend (init); + } + return vmap[vname] = var; } @@ -898,10 +1016,12 @@ // -------------------- jit_convert_llvm -------------------- llvm::Function * -jit_convert_llvm::convert (llvm::Module *module, - const jit_block_list& blocks, - const std::list<jit_value *>& constants) +jit_convert_llvm::convert_loop (llvm::Module *module, + const jit_block_list& blocks, + const std::list<jit_value *>& constants) { + converting_function = false; + // for now just init arguments from entry, later we will have to do something // more interesting jit_block *entry_block = blocks.front (); @@ -934,44 +1054,7 @@ arguments[argument_vec[i].first] = loaded_arg; } - std::list<jit_block *>::const_iterator biter; - for (biter = blocks.begin (); biter != blocks.end (); ++biter) - { - jit_block *jblock = *biter; - llvm::BasicBlock *block = llvm::BasicBlock::Create (context, - jblock->name (), - function); - jblock->stash_llvm (block); - } - - jit_block *first = *blocks.begin (); - builder.CreateBr (first->to_llvm ()); - - // constants aren't in the IR, we visit those first - for (std::list<jit_value *>::const_iterator iter = constants.begin (); - iter != constants.end (); ++iter) - if (! 
isa<jit_instruction> (*iter)) - visit (*iter); - - // convert all instructions - for (biter = blocks.begin (); biter != blocks.end (); ++biter) - visit (*biter); - - // now finish phi nodes - for (biter = blocks.begin (); biter != blocks.end (); ++biter) - { - jit_block& block = **biter; - for (jit_block::iterator piter = block.begin (); - piter != block.end () && isa<jit_phi> (*piter); ++piter) - { - jit_instruction *phi = *piter; - finish_phi (static_cast<jit_phi *> (phi)); - } - } - - jit_block *last = blocks.back (); - builder.SetInsertPoint (last->to_llvm ()); - builder.CreateRetVoid (); + convert (blocks, constants); } catch (const jit_fail_exception& e) { function->eraseFromParent (); @@ -981,6 +1064,92 @@ return function; } + +jit_function +jit_convert_llvm::convert_function (llvm::Module *module, + const jit_block_list& blocks, + const std::list<jit_value *>& constants, + octave_user_function& fcn, + const std::vector<jit_type *>& args) +{ + converting_function = true; + + jit_block *final_block = blocks.back (); + jit_return *ret = dynamic_cast<jit_return *> (final_block->back ()); + assert (ret); + + jit_function creating = jit_function (module, jit_convention::internal, + "foobar", ret->result_type (), args); + function = creating.to_llvm (); + + try + { + prelude = creating.new_block ("prelude"); + builder.SetInsertPoint (prelude); + + tree_parameter_list *plist = fcn.parameter_list (); + if (plist) + { + tree_parameter_list::iterator piter = plist->begin (); + tree_parameter_list::iterator pend = plist->end (); + for (size_t i = 0; i < args.size () && piter != pend; ++i, ++piter) + { + tree_decl_elt *elt = *piter; + std::string arg_name = elt->name (); + arguments[arg_name] = creating.argument (builder, i); + } + } + + convert (blocks, constants); + } catch (const jit_fail_exception& e) + { + function->eraseFromParent (); + throw; + } + + return creating; +} + +void +jit_convert_llvm::convert (const jit_block_list& blocks, + const std::list<jit_value 
*>& constants) +{ + std::list<jit_block *>::const_iterator biter; + for (biter = blocks.begin (); biter != blocks.end (); ++biter) + { + jit_block *jblock = *biter; + llvm::BasicBlock *block = llvm::BasicBlock::Create (context, + jblock->name (), + function); + jblock->stash_llvm (block); + } + + jit_block *first = *blocks.begin (); + builder.CreateBr (first->to_llvm ()); + + // constants aren't in the IR, we visit those first + for (std::list<jit_value *>::const_iterator iter = constants.begin (); + iter != constants.end (); ++iter) + if (! isa<jit_instruction> (*iter)) + visit (*iter); + + // convert all instructions + for (biter = blocks.begin (); biter != blocks.end (); ++biter) + visit (*biter); + + // now finish phi nodes + for (biter = blocks.begin (); biter != blocks.end (); ++biter) + { + jit_block& block = **biter; + for (jit_block::iterator piter = block.begin (); + piter != block.end () && isa<jit_phi> (*piter); ++piter) + { + jit_instruction *phi = *piter; + finish_phi (static_cast<jit_phi *> (phi)); + } + } +} + void jit_convert_llvm::finish_phi (jit_phi *phi) { @@ -1089,10 +1258,16 @@ { llvm::Value *arg = arguments[extract.name ()]; assert (arg); - arg = builder.CreateLoad (arg); - const jit_function& ol = extract.overload (); - extract.stash_llvm (ol.call (builder, arg)); + if (converting_function) + extract.stash_llvm (arg); + else + { + arg = builder.CreateLoad (arg); + + const jit_function& ol = extract.overload (); + extract.stash_llvm (ol.call (builder, arg)); + } } void @@ -1105,6 +1280,16 @@ } void +jit_convert_llvm::visit (jit_return& ret) +{ + jit_value *res = ret.result (); + if (res) + builder.CreateRet (res->to_llvm ()); + else + builder.CreateRetVoid (); +} + +void jit_convert_llvm::visit (jit_phi& phi) { // we might not have converted all incoming branches, so we don't @@ -1539,44 +1724,27 @@ bool tree_jit::execute (tree_simple_for_command& cmd, const octave_value& bounds) { - const size_t MIN_TRIP_COUNT = 1000; - - size_t tc = 
trip_count (bounds); - if (! tc || ! initialize ()) - return false; - - jit_info::vmap extra_vars; - extra_vars["#for_bounds0"] = &bounds; - - jit_info *info = cmd.get_info (); - if (! info || ! info->match (extra_vars)) - { - if (tc < MIN_TRIP_COUNT) - return false; - - delete info; - info = new jit_info (*this, cmd, bounds); - cmd.stash_info (info); - } - - return info->execute (extra_vars); + return instance ().do_execute (cmd, bounds); } bool tree_jit::execute (tree_while_command& cmd) { - if (! initialize ()) - return false; + return instance ().do_execute (cmd); +} - jit_info *info = cmd.get_info (); - if (! info || ! info->match ()) - { - delete info; - info = new jit_info (*this, cmd); - cmd.stash_info (info); - } +bool +tree_jit::execute (octave_user_function& fcn, const octave_value_list& args, + octave_value_list& retval) +{ + return instance ().do_execute (fcn, args, retval); +} - return info->execute (); +tree_jit& +tree_jit::instance (void) +{ + static tree_jit ret; + return ret; } bool @@ -1616,6 +1784,67 @@ return true; } +bool +tree_jit::do_execute (tree_simple_for_command& cmd, const octave_value& bounds) +{ + const size_t MIN_TRIP_COUNT = 1000; + + size_t tc = trip_count (bounds); + if (! tc || ! initialize ()) + return false; + + jit_info::vmap extra_vars; + extra_vars["#for_bounds0"] = &bounds; + + jit_info *info = cmd.get_info (); + if (! info || ! info->match (extra_vars)) + { + if (tc < MIN_TRIP_COUNT) + return false; + + delete info; + info = new jit_info (*this, cmd, bounds); + cmd.stash_info (info); + } + + return info->execute (extra_vars); +} + +bool +tree_jit::do_execute (tree_while_command& cmd) +{ + if (! initialize ()) + return false; + + jit_info *info = cmd.get_info (); + if (! info || ! 
info->match ()) + { + delete info; + info = new jit_info (*this, cmd); + cmd.stash_info (info); + } + + return info->execute (); +} + +bool +tree_jit::do_execute (octave_user_function& fcn, const octave_value_list& args, + octave_value_list& retval) +{ + if (! initialize ()) + return false; + + jit_function_info *info = fcn.get_info (); + if (! info || ! info->match (args)) + { + delete info; + info = new jit_function_info (*this, fcn, args); + fcn.stash_info (info); + } + + return info->execute (args, retval); +} + size_t tree_jit::trip_count (const octave_value& bounds) const { @@ -1644,6 +1873,163 @@ #endif } +// -------------------- jit_function_info -------------------- +jit_function_info::jit_function_info (tree_jit& tjit, + octave_user_function& fcn, + const octave_value_list& ov_args) + : argument_types (ov_args.length ()), function (0) +{ + size_t nargs = ov_args.length (); + for (size_t i = 0; i < nargs; ++i) + argument_types[i] = jit_typeinfo::type_of (ov_args(i)); + + try + { + jit_convert conv (fcn, argument_types); + jit_infer infer (conv.get_factory (), conv.get_blocks (), + conv.get_variable_map ()); + infer.infer (); + +#if OCTAVE_JIT_DEBUG + if (Venable_jit_debug) + { + jit_block_list& blocks = infer.get_blocks (); + jit_block *entry_block = blocks.front (); + entry_block->label (); + std::cout << "-------------------- Compiling function "; + std::cout << "--------------------\n"; + + tree_print_code tpc (std::cout); + tpc.visit_octave_user_function_header (fcn); + tpc.visit_statement_list (*fcn.body ()); + tpc.visit_octave_user_function_trailer (fcn); + blocks.print (std::cout, "octave jit ir"); + } +#endif + + jit_factory& factory = conv.get_factory (); + llvm::Module *module = tjit.get_module (); + jit_convert_llvm to_llvm; + jit_function raw_fn = to_llvm.convert_function (module, + infer.get_blocks (), + factory.constants (), + fcn, argument_types); + +#ifdef OCTAVE_JIT_DEBUG + if (Venable_jit_debug) + { + std::cout << "-------------------- 
raw function "; + std::cout << "--------------------\n"; + std::cout << *raw_fn.to_llvm () << std::endl; + } +#endif + + std::string wrapper_name = fcn.name () + "_wrapper"; + jit_type *any_t = jit_typeinfo::get_any (); + std::vector<jit_type *> wrapper_args (1, jit_typeinfo::get_any_ptr ()); + jit_function wrapper (module, jit_convention::internal, wrapper_name, + any_t, wrapper_args); + llvm::BasicBlock *wrapper_body = wrapper.new_block (); + builder.SetInsertPoint (wrapper_body); + + llvm::Value *wrapper_arg = wrapper.argument (builder, 0); + std::vector<llvm::Value *> raw_args (nargs); + for (size_t i = 0; i < nargs; ++i) + { + llvm::Value *arg; + arg = builder.CreateConstInBoundsGEP1_32 (wrapper_arg, i); + arg = builder.CreateLoad (arg); + + jit_type *arg_type = argument_types[i]; + const jit_function& cast = jit_typeinfo::cast (arg_type, any_t); + raw_args[i] = cast.call (builder, arg); + } + + llvm::Value *result = raw_fn.call (builder, raw_args); + if (raw_fn.result ()) + { + jit_type *raw_result_t = raw_fn.result (); + const jit_function& cast = jit_typeinfo::cast (any_t, raw_result_t); + result = cast.call (builder, result); + } + else + { + llvm::Value *zero = builder.getInt32 (0); + result = builder.CreateBitCast (zero, any_t->to_llvm ()); + } + + wrapper.do_return (builder, result); + + llvm::Function *llvm_function = wrapper.to_llvm (); + tjit.optimize (llvm_function); + +#ifdef OCTAVE_JIT_DEBUG + if (Venable_jit_debug) + { + std::cout << "-------------------- optimized and wrapped "; + std::cout << "--------------------\n"; + std::cout << *llvm_function << std::endl; + } +#endif + + llvm::ExecutionEngine* engine = tjit.get_engine (); + void *void_fn = engine->getPointerToFunction (llvm_function); + function = reinterpret_cast<jited_function> (void_fn); + } + catch (const jit_fail_exception& e) + { + argument_types.clear (); +#ifdef OCTAVE_JIT_DEBUG + if (Venable_jit_debug) + { + if (e.known ()) + std::cout << "jit fail: " << e.what () << std::endl; + 
} +#endif + } +} + +bool +jit_function_info::execute (const octave_value_list& ov_args, + octave_value_list& retval) const +{ + if (! function) + return false; + + // TODO figure out a way to delete ov_args so we avoid duplicating refcount + size_t nargs = ov_args.length (); + std::vector<octave_base_value *> args (nargs); + for (size_t i = 0; i < nargs; ++i) + { + octave_base_value *obv = ov_args(i).internal_rep (); + obv->grab (); + args[i] = obv; + } + + octave_base_value *ret = function (&args[0]); + if (ret) + retval(0) = octave_value (ret); + + return true; +} + +bool +jit_function_info::match (const octave_value_list& ov_args) const +{ + if (! function) + return true; + + size_t nargs = ov_args.length (); + if (nargs != argument_types.size ()) + return false; + + for (size_t i = 0; i < nargs; ++i) + if (jit_typeinfo::type_of (ov_args(i)) != argument_types[i]) + return false; + + return true; +} + // -------------------- jit_info -------------------- jit_info::jit_info (tree_jit& tjit, tree& tee) : engine (tjit.get_engine ()), function (0), llvm_function (0) @@ -1739,8 +2125,9 @@ jit_factory& factory = conv.get_factory (); jit_convert_llvm to_llvm; - llvm_function = to_llvm.convert (tjit.get_module (), infer.get_blocks (), - factory.constants ()); + llvm_function = to_llvm.convert_loop (tjit.get_module (), + infer.get_blocks (), + factory.constants ()); arguments = to_llvm.get_arguments (); bounds = conv.get_bounds (); } @@ -2126,4 +2513,13 @@ %!error <undefined near> (test_undef); +%!shared id +%! id = @(x) x; + +%!assert (id (1), 1); +%!assert (id (1+1i), 1+1i) +%!assert (id (1, 2), 1) +%!error <undefined> (id ()) + + */