Mercurial > hg > octave-lyh
diff src/pt-jit.cc @ 14920:51d4b1018efb
For loops compile with new IR
* src/pt-eval.cc (tree_evaluator::visit_simple_for_command): Compile loops.
(tree_evaluator::visit_statement): No longer compile individual statements.
* src/pt-loop.h (tree_simple_for_command::get_info): Remove type map.
(tree_simple_for_command::stash_info): Remove type map.
* src/pt-loop.cc (tree_simple_for_command::~tree_simple_for_command):
Delete compiled code instead of map.
author | Max Brister <max@2bass.com> |
---|---|
date | Sat, 26 May 2012 20:30:28 -0500 |
parents | 13465aab507f |
children | 2e6f83b2f2b9 |
line wrap: on
line diff
--- a/src/pt-jit.cc +++ b/src/pt-jit.cc @@ -65,7 +65,18 @@ jit_typeinfo *jit_typeinfo::instance; // thrown when we should give up on JIT and interpret -class jit_fail_exception : public std::exception {}; +class jit_fail_exception : public std::runtime_error +{ +public: + jit_fail_exception (void) : std::runtime_error ("unknown"), mknown (false) {} + jit_fail_exception (const std::string& reason) : std::runtime_error (reason), + mknown (true) + {} + + bool known (void) const { return mknown; } +private: + bool mknown; +}; static void fail (void) @@ -73,6 +84,12 @@ throw jit_fail_exception (); } +static void +fail (const std::string& reason) +{ + throw jit_fail_exception (reason); +} + // function that jit code calls extern "C" void octave_jit_print_any (const char *name, octave_base_value *obv) @@ -127,6 +144,14 @@ return new octave_scalar (value); } +// -------------------- jit_range -------------------- +std::ostream& +operator<< (std::ostream& os, const jit_range& rng) +{ + return os << "Range[" << rng.base << ", " << rng.limit << ", " << rng.inc + << ", " << rng.nelem << "]"; +} + // -------------------- jit_type -------------------- llvm::Type * jit_type::to_llvm_arg (void) const @@ -308,6 +333,10 @@ fn = create_identity (scalar); grab_fn.add_overload (fn, false, scalar, scalar); + // grab index + fn = create_identity (index); + grab_fn.add_overload (fn, false, index, index); + // release any fn = create_function ("octave_jit_release_any", void_t, any->to_llvm ()); engine->addGlobalMapping (fn, reinterpret_cast<void*>(&octave_jit_release_any)); @@ -318,6 +347,10 @@ fn = create_identity (scalar); release_fn.add_overload (fn, false, 0, scalar); + // release index + fn = create_identity (index); + release_fn.add_overload (fn, false, 0, index); + // now for binary scalar operations // FIXME: Finish all operations add_binary_op (scalar, octave_value::op_add, llvm::Instruction::FAdd); @@ -336,41 +369,47 @@ add_binary_fcmp (scalar, octave_value::op_gt, llvm::CmpInst::FCMP_UGT); add_binary_fcmp (scalar, octave_value::op_ne, llvm::CmpInst::FCMP_UNE); + // now for binary index operators + add_binary_op (index, octave_value::op_add, llvm::Instruction::Add); + // now for printing functions print_fn.stash_name ("print"); add_print (any, reinterpret_cast<void*> (&octave_jit_print_any)); add_print (scalar, reinterpret_cast<void*> (&octave_jit_print_double)); + // initialize for loop + for_init_fn.stash_name ("for_init"); + + fn = create_function ("octave_jit_for_range_init", index, range); + llvm::BasicBlock *body = llvm::BasicBlock::Create (context, "body", fn); + builder.SetInsertPoint (body); + { + llvm::Value *zero = llvm::ConstantInt::get (index_t, 0); + builder.CreateRet (zero); + } + llvm::verifyFunction (*fn); + for_init_fn.add_overload (fn, false, index, range); + // bounds check for for loop - fn = create_function ("octave_jit_simple_for_range", boolean, range, index); - llvm::BasicBlock *body = llvm::BasicBlock::Create (context, "body", fn); + for_check_fn.stash_name ("for_check"); + + fn = create_function ("octave_jit_for_range_check", boolean, range, index); + body = llvm::BasicBlock::Create (context, "body", fn); builder.SetInsertPoint (body); { llvm::Value *nelem = builder.CreateExtractValue (fn->arg_begin (), 3); - // llvm::Value *idx = builder.CreateLoad (++fn->arg_begin ()); llvm::Value *idx = ++fn->arg_begin (); llvm::Value *ret = builder.CreateICmpULT (idx, nelem); builder.CreateRet (ret); } llvm::verifyFunction (*fn); - simple_for_check.add_overload (fn, false, boolean, range, index); - - // increment for for loop - fn = create_function ("octave_jit_imple_for_range_incr", index, index); - body = llvm::BasicBlock::Create (context, "body", fn); - builder.SetInsertPoint (body); - { - llvm::Value *one = llvm::ConstantInt::get (index_t, 1); - llvm::Value *idx = fn->arg_begin (); - llvm::Value *ret = builder.CreateAdd (idx, one); - builder.CreateRet (ret); - } - llvm::verifyFunction (*fn); - simple_for_incr.add_overload (fn, false, index, index); + for_check_fn.add_overload (fn, false, boolean, range, index); // index variabe for for loop - fn = create_function ("octave_jit_simple_for_idx", scalar, range, index); + for_index_fn.stash_name ("for_index"); + + fn = create_function ("octave_jit_for_range_idx", scalar, range, index); body = llvm::BasicBlock::Create (context, "body", fn); builder.SetInsertPoint (body); { @@ -385,7 +424,7 @@ builder.CreateRet (ret); } llvm::verifyFunction (*fn); - simple_for_index.add_overload (fn, false, scalar, range, index); + for_index_fn.add_overload (fn, false, scalar, range, index); // logically true // FIXME: Check for NaN @@ -569,7 +608,102 @@ return ret; } +// -------------------- jit_use -------------------- +jit_block * +jit_use::user_parent (void) const +{ + return usr->parent (); +} + +// -------------------- jit_value -------------------- +#define JIT_METH(clname) \ + void \ + jit_ ## clname::accept (jit_ir_walker& walker) \ + { \ + walker.visit (*this); \ + } + +JIT_VISIT_IR_NOTEMPLATE +#undef JIT_METH + +// -------------------- jit_instruction -------------------- +llvm::BasicBlock * +jit_instruction::parent_llvm (void) const +{ + return mparent->to_llvm (); +} + // -------------------- jit_block -------------------- +jit_instruction * +jit_block::prepend (jit_instruction *instr) +{ + instructions.push_front (instr); + instr->stash_parent (this); + return instr; +} + +jit_instruction * +jit_block::append (jit_instruction *instr) +{ + instructions.push_back (instr); + instr->stash_parent (this); + return instr; +} + +jit_terminator * +jit_block::terminator (void) const +{ + if (instructions.empty ()) + return 0; + + jit_instruction *last = instructions.back (); + return dynamic_cast<jit_terminator *> (last); +} + +llvm::Value * +jit_block::pred_terminator_llvm (size_t idx) const +{ + jit_terminator *term = pred_terminator (idx); + return term ? term->to_llvm () : 0; +} + +void +jit_block::create_merge (llvm::Function *inside, size_t pred_idx) +{ + mpred_llvm.resize (pred_count ()); + + jit_block *ipred = pred (pred_idx); + if (! mpred_llvm[pred_idx] && ipred->pred_count () > 1) + { + llvm::BasicBlock *merge; + merge = llvm::BasicBlock::Create (context, "phi_merge", inside, + to_llvm ()); + + // fix the predecessor jump if it has been created + llvm::Value *term = pred_terminator_llvm (pred_idx); + if (term) + { + llvm::TerminatorInst *branch = llvm::cast<llvm::TerminatorInst> (term); + for (size_t i = 0; i < branch->getNumSuccessors (); ++i) + { + if (branch->getSuccessor (i) == to_llvm ()) + branch->setSuccessor (i, merge); + } + } + + llvm::IRBuilder<> temp (merge); + temp.CreateBr (to_llvm ()); + mpred_llvm[pred_idx] = merge; + } +} + +size_t +jit_block::succ_count (void) const +{ + jit_terminator *term = terminator (); + return term ? term->sucessor_count () : 0; +} + llvm::BasicBlock * jit_block::to_llvm (void) const { @@ -609,20 +743,20 @@ { jit_instruction::reset_ids (); - entry_block = new jit_block ("entry"); - blocks.push_back (entry_block); - block = new jit_block ("body"); + jit_block *entry_block = new jit_block ("body"); + block = entry_block; blocks.push_back (block); + toplevel_map tlevel (block); + variables = &tlevel; final_block = new jit_block ("final"); visit (tee); + blocks.push_back (final_block); - - entry_block->append (new jit_break (block)); block->append (new jit_break (final_block)); - for (variable_map::iterator iter = variables.begin (); - iter != variables.end (); ++iter) + for (variable_map::iterator iter = variables->begin (); + iter != variables->end (); ++iter) final_block->append (new jit_store_argument (iter->first, iter->second)); // FIXME: Maybe we should remove dead code here? @@ -632,6 +766,19 @@ iter != constants.end (); ++iter) append_users (*iter); + // also get anything from jit_extract_argument, as these have constant types + for (jit_block::iterator iter = entry_block->begin (); + iter != entry_block->end (); ++iter) + { + jit_instruction *instr = *iter; + if (jit_extract_argument *extract = dynamic_cast<jit_extract_argument *>(instr)) + { + if (! extract->type ()) + fail (); // we depend on an unknown type + append_users (extract); + } + } + // FIXME: Describe algorithm here while (worklist.size ()) { @@ -653,6 +800,15 @@ std::cout << std::endl; } + // for now just init arguments from entry, later we will have to do something + // more interesting + for (jit_block::iterator iter = entry_block->begin (); + iter != entry_block->end (); ++iter) + { + if (jit_extract_argument *extract = dynamic_cast<jit_extract_argument *> (*iter)) + arguments.push_back (std::make_pair (extract->tag (), true)); + } + convert_llvm to_llvm; function = to_llvm.convert (module, arguments, blocks, constants); @@ -662,6 +818,7 @@ llvm::raw_os_ostream llvm_cout (std::cout); function->print (llvm_cout); std::cout << std::endl; + llvm::verifyFunction (*function); } } @@ -737,9 +894,103 @@ } void -jit_convert::visit_simple_for_command (tree_simple_for_command&) +jit_convert::visit_simple_for_command (tree_simple_for_command& cmd) { - fail (); + // how a for statement is compiled. Note we do an initial check + // to see if the loop will run atleast once. This allows us to get + // better type inference bounds on variables defined and used only + // inside the for loop (e.g. the index variable) + + // prev_block: % pred = ? + // #control.0 = % compute_control (note this will just be a temp) + // #iter.0 = call for_init (#control.0) % Let type of control decide iter + // % initial value and type + // #temp.0 = call for_check (control.0, #iter.0) + // cond_break #temp.0, for_body, for_tail + // for_body: % pred = for_init, for_cond + // idxvar.2 = phi | for_init -> idxvar.1 + // | for_body -> idxvar.3 + // #iter.1 = phi | for_init -> #iter.0 + // | for_body -> #iter.2 + // idxvar.3 = call for_index (#control.0, #iter.1) + // % do loop body + // #iter.2 = #iter.1 + 1 % release is implicit in iter reuse + // #check = call for_check (#control.0, iter.2) + // cond_break #check for_body, for_tail + // for_tail: % pred = prev_block, for_body + // #iter.3 = phi | prev_block -> #iter.0 + // | for_body -> #iter.2 + // idxvar.4 = phi | prev_block -> idxvar.0 + // | for_body -> idxvar.3 + // call release (#iter.3) + // % rest of code + + // FIXME: one of these days we will introduce proper lvalues... + tree_identifier *lhs = dynamic_cast<tree_identifier *>(cmd.left_hand_side ()); + if (! lhs) + fail (); + std::string lhs_name = lhs->name (); + + jit_block *body = new jit_block ("for_body"); + blocks.push_back (body); + + jit_block *tail = new jit_block ("for_tail"); + unwind_protect prot_tail; + prot_tail.add_delete (tail); // incase we fail before adding tail to blocks + + // do control expression, iter init, and condition check in prev_block (block) + jit_value *control = visit (cmd.control_expr ()); + jit_call *init_iter = new jit_call (jit_typeinfo::for_init, control); + init_iter->stash_tag ("#iter"); + block->append (init_iter); + jit_value *check = block->append (new jit_call (jit_typeinfo::for_check, + control, init_iter)); + block->append (new jit_cond_break (check, body, tail)); + + // we need to do iter phi manually, for_map handles the rest + jit_phi *iter_phi = new jit_phi (2); + iter_phi->stash_tag ("#iter"); + iter_phi->stash_argument (1, init_iter); + body->append (iter_phi); + + variable_map *merge_vars = variables; + for_map body_vars (variables, body); + variables = &body_vars; + block = body; + + // first thing we do in the for loop is bind our index from our itertor + jit_call *idx_rhs = new jit_call (jit_typeinfo::for_index, control, iter_phi); + block->append (idx_rhs); + idx_rhs->stash_tag (lhs_name); + do_assign (lhs_name, idx_rhs, false); + + tree_statement_list *pt_body = cmd.body (); + pt_body->accept (*this); + + // increment iterator, check conditional, and repeat + const jit_function& add_fn = jit_typeinfo::binary_op (octave_value::op_add); + jit_call *iter_inc = new jit_call (add_fn, iter_phi, + get_const<jit_const_index> (1)); + iter_inc->stash_tag ("#iter"); + block->append (iter_inc); + check = block->append (new jit_call (jit_typeinfo::for_check, control, + iter_inc)); + block->append (new jit_cond_break (check, body, tail)); + iter_phi->stash_argument (0, iter_inc); + body_vars.finish_phi (*variables); + + blocks.push_back (tail); + prot_tail.discard (); + block = tail; + + variables = merge_vars; + merge (body_vars); + iter_phi = new jit_phi (2); + iter_phi->stash_tag ("#iter"); + iter_phi->stash_argument (0, iter_inc); + iter_phi->stash_argument (1, init_iter); + block->append (iter_phi); + block->append (new jit_call (jit_typeinfo::release, iter_phi)); } void @@ -781,23 +1032,8 @@ void jit_convert::visit_identifier (tree_identifier& ti) { - std::string name = ti.name (); - variable_map::iterator iter = variables.find (name); - jit_value *var; - if (iter == variables.end ()) - { - octave_value var_value = ti.do_lookup (); - jit_type *var_type = jit_typeinfo::type_of (var_value); - var = entry_block->append (new jit_extract_argument (var_type, name)); - constants.push_back (var); - bounds.push_back (std::make_pair (var_type, name)); - variables[name] = var; - arguments.push_back (std::make_pair (name, true)); - } - else - var = iter->second; - const jit_function& fn = jit_typeinfo::grab (); + jit_value *var = variables->get (ti.name ()); result = block->append (new jit_call (fn, var)); } @@ -856,10 +1092,13 @@ if (v.is_real_scalar () && v.is_double_type ()) { double dv = v.double_value (); - result = get_scalar (dv); + result = get_const<jit_const_scalar> (dv); } else if (v.is_range ()) - fail (); + { + Range rv = v.range_value (); + result = get_const<jit_const_range> (rv); + } else fail (); } @@ -951,16 +1190,23 @@ // FIXME: ugly hack, we need to come up with a way to pass // nargout to visit_identifier const jit_function& fn = jit_typeinfo::print_value (); - jit_const_string *name = get_string (expr->name ()); + jit_const_string *name = get_const<jit_const_string> (expr->name ()); block->append (new jit_call (fn, name, expr_result)); } } } void -jit_convert::visit_statement_list (tree_statement_list&) +jit_convert::visit_statement_list (tree_statement_list& lst) { - fail (); + for (tree_statement_list::iterator iter = lst.begin (); iter != lst.end(); + ++iter) + { + tree_statement *elt = *iter; + // jwe: Can this ever be null? + assert (elt); + elt->accept (*this); + } } void @@ -1008,22 +1254,16 @@ void jit_convert::do_assign (const std::string& lhs, jit_value *rhs, bool print) { - variable_map::iterator iter = variables.find (lhs); - if (iter == variables.end ()) - arguments.push_back (std::make_pair (lhs, false)); - else - { - const jit_function& fn = jit_typeinfo::release (); - block->append (new jit_call (fn, iter->second)); - } - - variables[lhs] = rhs; + const jit_function& release = jit_typeinfo::release (); + jit_value *current = variables->get (lhs); + block->append (new jit_call (release, current)); + variables->set (lhs, rhs); if (print) { - const jit_function& fn = jit_typeinfo::print_value (); - jit_const_string *name = get_string (lhs); - block->append (new jit_call (fn, name, rhs)); + const jit_function& print_fn = jit_typeinfo::print_value (); + jit_const_string *name = get_const<jit_const_string> (lhs); + block->append (new jit_call (print_fn, name, rhs)); } } @@ -1038,6 +1278,40 @@ return ret; } +void +jit_convert::merge (const variable_map& ref) +{ + assert (variables->size () == ref.size ()); + variable_map::iterator viter = variables->begin (); + variable_map::const_iterator riter = ref.begin (); + for (; viter != variables->end (); ++viter, ++riter) + { + assert (viter->first == riter->first); + if (viter->second != riter->second) + { + jit_phi *phi = new jit_phi (2); + phi->stash_tag (viter->first); + block->prepend (phi); + phi->stash_argument (0, riter->second); + phi->stash_argument (1, viter->second); + viter->second = phi; + } + } +} + +// -------------------- jit_convert::toplevel_map -------------------- +jit_value * +jit_convert::toplevel_map::insert (const std::string& name, jit_value *pval) +{ + assert (pval == 0); // we have no parent + + jit_block *entry = block (); + octave_value val = symbol_table::find (name); + jit_type *type = jit_typeinfo::type_of (val); + jit_instruction *ret = new jit_extract_argument (type, name); + return vars[name] = entry->prepend (ret); +} + // -------------------- jit_convert::convert_llvm -------------------- llvm::Function * jit_convert::convert_llvm::convert (llvm::Module *module, @@ -1052,9 +1326,8 @@ arg_type = arg_type->getPointerTo (); llvm::FunctionType *ft = llvm::FunctionType::get (llvm::Type::getVoidTy (context), arg_type, false); - llvm::Function *function = llvm::Function::Create (ft, - llvm::Function::ExternalLinkage, - "foobar", module); + function = llvm::Function::Create (ft, llvm::Function::ExternalLinkage, + "foobar", module); try { @@ -1073,11 +1346,7 @@ // a block for (std::list<jit_value *>::const_iterator iter = constants.begin (); iter != constants.end (); ++iter) - { - jit_value *constant = *iter; - if (! dynamic_cast<jit_instruction *> (constant)) - visit (constant); - } + visit (*iter); std::list<jit_block *>::const_iterator biter; for (biter = blocks.begin (); biter != blocks.end (); ++biter) @@ -1091,36 +1360,105 @@ jit_block *first = *blocks.begin (); builder.CreateBr (first->to_llvm ()); + // convert all instructions for (biter = blocks.begin (); biter != blocks.end (); ++biter) visit (*biter); + // now finish phi nodes + for (biter = blocks.begin (); biter != blocks.end (); ++biter) + { + jit_block& block = **biter; + for (jit_block::iterator piter = block.begin (); + piter != block.end () && dynamic_cast<jit_phi *> (*piter); ++piter) + { + // our phi nodes don't have to have the same incomming type, + // so we do casts here + jit_instruction *phi = *piter; + jit_block *pblock = phi->parent (); + llvm::PHINode *llvm_phi = llvm::cast<llvm::PHINode> (phi->to_llvm ()); + for (size_t i = 0; i < phi->argument_count (); ++i) + { + llvm::BasicBlock *pred = pblock->pred_llvm (i); + if (phi->argument_type_llvm (i) == phi->type_llvm ()) + { + llvm_phi->addIncoming (phi->argument_llvm (i), pred); + } + else + { + // add cast right before pred terminator + builder.SetInsertPoint (--pred->end ()); + + const jit_function::overload& ol + = jit_typeinfo::cast (phi->type (), + phi->argument_type (i)); + if (! ol.function) + { + std::stringstream ss; + ss << "No cast for phi(" << i << "): "; + phi->print (ss); + fail (ss.str ()); + } + + llvm::Value *casted; + casted = builder.CreateCall (ol.function, + phi->argument_llvm (i)); + llvm_phi->addIncoming (casted, pred); + } + } + } + } + + jit_block *last = blocks.back (); + builder.SetInsertPoint (last->to_llvm ()); builder.CreateRetVoid (); - } catch (const jit_fail_exception&) + } catch (const jit_fail_exception& e) { function->eraseFromParent (); throw; } - llvm::verifyFunction (*function); - return function; } void -jit_convert::convert_llvm::visit_const_string (jit_const_string& cs) +jit_convert::convert_llvm::visit (jit_const_string& cs) { cs.stash_llvm (builder.CreateGlobalStringPtr (cs.value ())); } void -jit_convert::convert_llvm::visit_const_scalar (jit_const_scalar& cs) +jit_convert::convert_llvm::visit (jit_const_scalar& cs) { - llvm::Type *dbl = llvm::Type::getDoubleTy (context); - cs.stash_llvm (llvm::ConstantFP::get (dbl, cs.value ())); + cs.stash_llvm (llvm::ConstantFP::get (cs.type_llvm (), cs.value ())); +} + +void jit_convert::convert_llvm::visit (jit_const_index& ci) +{ + ci.stash_llvm (llvm::ConstantInt::get (ci.type_llvm (), ci.value ())); } void -jit_convert::convert_llvm::visit_block (jit_block& b) +jit_convert::convert_llvm::visit (jit_const_range& cr) +{ + llvm::StructType *stype = llvm::cast<llvm::StructType>(cr.type_llvm ()); + llvm::Type *dbl = jit_typeinfo::get_scalar_llvm (); + llvm::Type *idx = jit_typeinfo::get_index_llvm (); + const jit_range& rng = cr.value (); + + llvm::Constant *constants[4]; + constants[0] = llvm::ConstantFP::get (dbl, rng.base); + constants[1] = llvm::ConstantFP::get (dbl, rng.limit); + constants[2] = llvm::ConstantFP::get (dbl, rng.inc); + constants[3] = llvm::ConstantInt::get (idx, rng.nelem); + + llvm::Value *as_llvm; + as_llvm = llvm::ConstantStruct::get (stype, + llvm::makeArrayRef (constants, 4)); + cr.stash_llvm (as_llvm); +} + +void +jit_convert::convert_llvm::visit (jit_block& b) { llvm::BasicBlock *block = b.to_llvm (); builder.SetInsertPoint (block); @@ -1129,46 +1467,49 @@ } void -jit_convert::convert_llvm::visit_break (jit_break& b) +jit_convert::convert_llvm::visit (jit_break& b) { - builder.CreateBr (b.sucessor_llvm ()); + b.stash_llvm (builder.CreateBr (b.sucessor_llvm ())); } void -jit_convert::convert_llvm::visit_cond_break (jit_cond_break& cb) +jit_convert::convert_llvm::visit (jit_cond_break& cb) { llvm::Value *cond = cb.cond_llvm (); - builder.CreateCondBr (cond, cb.sucessor_llvm (0), cb.sucessor_llvm (1)); + llvm::Value *br; + br = builder.CreateCondBr (cond, cb.sucessor_llvm (0), cb.sucessor_llvm (1)); + cb.stash_llvm (br); } void -jit_convert::convert_llvm::visit_call (jit_call& call) +jit_convert::convert_llvm::visit (jit_call& call) { const jit_function::overload& ol = call.overload (); if (! ol.function) - fail (); + fail ("No overload for: " + call.print_string ()); std::vector<llvm::Value *> args (call.argument_count ()); for (size_t i = 0; i < call.argument_count (); ++i) args[i] = call.argument_llvm (i); - call.stash_llvm (builder.CreateCall (ol.function, args)); + call.stash_llvm (builder.CreateCall (ol.function, args, call.tag ())); } void -jit_convert::convert_llvm::visit_extract_argument (jit_extract_argument& extract) +jit_convert::convert_llvm::visit (jit_extract_argument& extract) { const jit_function::overload& ol = extract.overload (); if (! ol.function) fail (); llvm::Value *arg = arguments[extract.tag ()]; + assert (arg); arg = builder.CreateLoad (arg); - extract.stash_llvm (builder.CreateCall (ol.function, arg)); + extract.stash_llvm (builder.CreateCall (ol.function, arg, extract.tag ())); } void -jit_convert::convert_llvm::visit_store_argument (jit_store_argument& store) +jit_convert::convert_llvm::visit (jit_store_argument& store) { llvm::Value *arg_value = store.result_llvm (); const jit_function::overload& ol = store.overload (); @@ -1181,42 +1522,47 @@ store.stash_llvm (builder.CreateStore (arg_value, arg)); } +void +jit_convert::convert_llvm::visit (jit_phi& phi) +{ + // we might not have converted all incoming branches, so we don't + // set incomming branches now + llvm::PHINode *node = llvm::PHINode::Create (phi.type_llvm (), + phi.argument_count (), + phi.tag ()); + builder.Insert (node); + phi.stash_llvm (node); + + jit_block *parent = phi.parent (); + for (size_t i = 0; i < phi.argument_count (); ++i) + if (phi.argument_type (i) != phi.type ()) + parent->create_merge (function, i); +} + // -------------------- tree_jit -------------------- -tree_jit::tree_jit (void) : context (llvm::getGlobalContext ()), engine (0) +tree_jit::tree_jit (void) : module (0), engine (0) { - llvm::InitializeNativeTarget (); - module = new llvm::Module ("octave", context); } tree_jit::~tree_jit (void) {} bool -tree_jit::execute (tree& cmd) +tree_jit::execute (tree_simple_for_command& cmd) { if (! initialize ()) return false; - compiled_map::iterator iter = compiled.find (&cmd); - jit_info *jinfo = 0; - if (iter != compiled.end ()) + jit_info *info = cmd.get_info (); + if (! info || ! info->match ()) { - jinfo = iter->second; - if (! jinfo->match ()) - { - delete jinfo; - jinfo = 0; - } + delete info; + info = new jit_info (*this, cmd); + cmd.stash_info (info); } - if (! jinfo) - { - jinfo = new jit_info (*this, cmd); - compiled[&cmd] = jinfo; - } - - return jinfo->execute (); + return info->execute (); } bool @@ -1225,6 +1571,12 @@ if (engine) return true; + if (! module) + { + llvm::InitializeNativeTarget (); + module = new llvm::Module ("octave", context); + } + // sometimes this fails pre main engine = llvm::ExecutionEngine::createJIT (module); @@ -1269,8 +1621,11 @@ arguments = conv.get_arguments (); bounds = conv.get_bounds (); } - catch (const jit_fail_exception&) - {} + catch (const jit_fail_exception& e) + { + if (debug_print && e.known ()) + std::cout << "jit fail: " << e.what () << std::endl; + } if (! fun) { @@ -1326,7 +1681,7 @@ for (size_t i = 0; i < bounds.size (); ++i) { const std::string& arg_name = bounds[i].second; - octave_value value = symbol_table::varval (arg_name); + octave_value value = symbol_table::find (arg_name); jit_type *type = jit_typeinfo::type_of (value); // FIXME: Check for a parent relationship