Mercurial > hg > octave-lyh
diff src/pt-jit.cc @ 15027:741d2dbcc117
Check trip count before compiling for loops.
* src/jit-typeinfo.cc (octave_jit_cast_any_range, octave_jit_cast_range_any):
New function.
(octave_jit_paren_subsasgn_impl): Add return argument.
(jit_typeinfo::jit_typeinfo): Update octave_jit_paren_subsasgn_impl call and add
any <-> range casts.
* src/pt-eval.cc (tree_evaluator::visit_simple_for_command): Try jit after
computing loop bounds.
* src/pt-jit.cc (jit_convert::jit_convert): Add and handle for bounds argument.
(jit_convert::visit_binary_expression): Use next_shortcircut_result.
(jit_convert::visit_simple_for_command): Use next_iterator and check for
precomputed bounds.
(jit_convert::find_variable, jit_convert::create_variable,
jit_convert::next_name, tree_jit::trip_count, jit_info::initialize,
jit_info::find): New function.
(jit_convert::get_variable): Use find_variable and create_variable.
(tree_jit::execute): Allow for precomputed loop bounds and check trip count.
(jit_info::jit_info): Added new overload and defer work to initialize.
(jit_info::execute): Support precomputed bounds.
(jit_info::match): Support precomputed bounds.
* src/pt-jit.h (jit_convert::jit_convert, jit_convert::execute,
jit_info::execute, jit_info::match): New parameter.
(jit_convert::find_variable, jit_convert::create_variable,
tree_jit::trip_count, jit_info::initialize, jit_info::find): New declaration.
(jit_convert::next_iterator, jit_convert::next_for_bounds,
jit_convert::next_shortcircut_result, jit_convert::next_name): New function.
author | Max Brister <max@2bass.com> |
---|---|
date | Thu, 26 Jul 2012 17:03:15 -0500 |
parents | 75d1bc2fd6d2 |
children | 86a95d6ada0d |
line wrap: on
line diff
--- a/src/pt-jit.cc +++ b/src/pt-jit.cc @@ -57,8 +57,9 @@ static llvm::LLVMContext& context = llvm::getGlobalContext (); // -------------------- jit_convert -------------------- -jit_convert::jit_convert (llvm::Module *module, tree &tee) - : iterator_count (0), short_count (0), breaking (false) +jit_convert::jit_convert (llvm::Module *module, tree &tee, + jit_type *for_bounds) + : iterator_count (0), for_bounds_count (0), short_count (0), breaking (false) { jit_instruction::reset_ids (); @@ -67,6 +68,10 @@ append (entry_block); entry_block->mark_alive (); block = entry_block; + + if (for_bounds) + create_variable (next_for_bounds (false), for_bounds); + visit (tee); // FIXME: Remove if we no longer only compile loops @@ -175,10 +180,7 @@ assert (boole); bool is_and = boole->op_type () == tree_boolean_expression::bool_and; - std::stringstream ss; - ss << "#short_result" << short_count++; - - std::string short_name = ss.str (); + std::string short_name = next_shortcircut_result (); jit_variable *short_result = create<jit_variable> (short_name); vmap[short_name] = short_result; @@ -302,10 +304,9 @@ continues.clear (); // we need a variable for our iterator, because it is used in multiple blocks - std::stringstream ss; - ss << "#iter" << iterator_count++; - std::string iter_name = ss.str (); + std::string iter_name = next_iterator (); jit_variable *iterator = create<jit_variable> (iter_name); + create<jit_variable> (iter_name); vmap[iter_name] = iterator; jit_block *body = create<jit_block> ("for_body"); @@ -314,7 +315,10 @@ jit_block *tail = create<jit_block> ("for_tail"); // do control expression, iter init, and condition check in prev_block (block) - jit_value *control = visit (cmd.control_expr ()); + // if we are the top level for loop, the bounds is an input argument. + jit_value *control = find_variable (next_for_bounds ()); + if (! control) + control = visit (cmd.control_expr ()); jit_call *init_iter = create<jit_call> (jit_typeinfo::for_init, control); block->append (init_iter); block->append (create<jit_assign> (iterator, init_iter)); @@ -762,21 +766,43 @@ } jit_variable * +jit_convert::find_variable (const std::string& vname) const +{ + vmap_t::const_iterator iter; + iter = vmap.find (vname); + return iter != vmap.end () ? iter->second : 0; +} + +jit_variable * jit_convert::get_variable (const std::string& vname) { - vmap_t::iterator iter; - iter = vmap.find (vname); - if (iter != vmap.end ()) - return iter->second; + jit_variable *ret = find_variable (vname); + if (ret) + return ret; - jit_variable *var = create<jit_variable> (vname); octave_value val = symbol_table::find (vname); jit_type *type = jit_typeinfo::type_of (val); + return create_variable (vname, type); +} + +jit_variable * +jit_convert::create_variable (const std::string& vname, jit_type *type) +{ + jit_variable *var = create<jit_variable> (vname); jit_extract_argument *extract; extract = create<jit_extract_argument> (type, var); entry_block->prepend (extract); + return vmap[vname] = var; +} - return vmap[vname] = var; +std::string +jit_convert::next_name (const char *prefix, size_t& count, bool inc) +{ + std::stringstream ss; + ss << prefix << count; + if (inc) + ++count; + return ss.str (); } std::pair<jit_value *, jit_value *> @@ -1462,20 +1488,29 @@ {} bool -tree_jit::execute (tree_simple_for_command& cmd) +tree_jit::execute (tree_simple_for_command& cmd, const octave_value& bounds) { - if (! initialize ()) + const size_t MIN_TRIP_COUNT = 1000; + + size_t tc = trip_count (bounds); + if (! tc || ! initialize ()) return false; + jit_info::vmap extra_vars; + extra_vars["#for_bounds0"] = &bounds; + jit_info *info = cmd.get_info (); - if (! info || ! info->match ()) + if (! info || ! info->match (extra_vars)) { + if (tc < MIN_TRIP_COUNT) + return false; + delete info; - info = new jit_info (*this, cmd); + info = new jit_info (*this, cmd, bounds); cmd.stash_info (info); } - return info->execute (); + return info->execute (extra_vars); } bool @@ -1531,6 +1566,19 @@ return true; } +size_t +tree_jit::trip_count (const octave_value& bounds) const +{ + if (bounds.is_range ()) + { + Range rng = bounds.range_value (); + return rng.nelem (); + } + + // unsupported type + return 0; +} + void tree_jit::optimize (llvm::Function *fn) @@ -1548,14 +1596,12 @@ // -------------------- jit_info -------------------- jit_info::jit_info (tree_jit& tjit, tree& tee) - : engine (tjit.get_engine ()), llvm_function (0) + : engine (tjit.get_engine ()), function (0), llvm_function (0) { try { jit_convert conv (tjit.get_module (), tee); - llvm_function = conv.get_function (); - arguments = conv.get_arguments (); - bounds = conv.get_bounds (); + initialize (tjit, conv); } catch (const jit_fail_exception& e) { @@ -1564,24 +1610,24 @@ std::cout << "jit fail: " << e.what () << std::endl; #endif } - - if (! llvm_function) - { - function = 0; - return; - } - - tjit.optimize (llvm_function); +} +jit_info::jit_info (tree_jit& tjit, tree& tee, const octave_value& for_bounds) + : engine (tjit.get_engine ()), function (0), llvm_function (0) +{ + try + { + jit_convert conv (tjit.get_module (), tee, + jit_typeinfo::type_of (for_bounds)); + initialize (tjit, conv); + } + catch (const jit_fail_exception& e) + { #ifdef OCTAVE_JIT_DEBUG - std::cout << "-------------------- optimized llvm ir --------------------\n"; - llvm::raw_os_ostream llvm_cout (std::cout); - llvm_function->print (llvm_cout); - std::cout << std::endl; + if (e.known ()) + std::cout << "jit fail: " << e.what () << std::endl; #endif - - void *void_fn = engine->getPointerToFunction (llvm_function); - function = reinterpret_cast<jited_function> (void_fn); + } } jit_info::~jit_info (void) @@ -1591,7 +1637,7 @@ } bool -jit_info::execute (void) const +jit_info::execute (const vmap& extra_vars) const { if (! function) return false; @@ -1601,24 +1647,29 @@ { if (arguments[i].second) { - octave_value ¤t = symbol_table::varref (arguments[i].first); + octave_value current = find (extra_vars, arguments[i].first); octave_base_value *obv = current.internal_rep (); obv->grab (); real_arguments[i] = obv; - current = octave_value (); } } function (&real_arguments[0]); for (size_t i = 0; i < arguments.size (); ++i) - symbol_table::varref (arguments[i].first) = real_arguments[i]; + { + const std::string& name = arguments[i].first; + + // do not store for loop bounds temporary + if (name.size () && name[0] != '#') + symbol_table::varref (arguments[i].first) = real_arguments[i]; + } return true; } bool -jit_info::match (void) const +jit_info::match (const vmap& extra_vars) const { if (! function) return true; @@ -1626,7 +1677,7 @@ for (size_t i = 0; i < bounds.size (); ++i) { const std::string& arg_name = bounds[i].second; - octave_value value = symbol_table::find (arg_name); + octave_value value = find (extra_vars, arg_name); jit_type *type = jit_typeinfo::type_of (value); // FIXME: Check for a parent relationship @@ -1636,6 +1687,40 @@ return true; } + +void +jit_info::initialize (tree_jit& tjit, jit_convert& conv) +{ + llvm_function = conv.get_function (); + arguments = conv.get_arguments (); + bounds = conv.get_bounds (); + + if (llvm_function) + { + tjit.optimize (llvm_function); + +#ifdef OCTAVE_JIT_DEBUG + std::cout << "-------------------- optimized llvm ir " + << "--------------------\n"; + llvm::raw_os_ostream llvm_cout (std::cout); + llvm_function->print (llvm_cout); + llvm_cout.flush (); + std::cout << std::endl; +#endif + + void *void_fn = engine->getPointerToFunction (llvm_function); + function = reinterpret_cast<jited_function> (void_fn); + } +} + +octave_value +jit_info::find (const vmap& extra_vars, const std::string& vname) const +{ + vmap::const_iterator iter = extra_vars.find (vname); + return iter == extra_vars.end () ? symbol_table::varval (vname) + : *iter->second; +} + #endif