# HG changeset patch # User Max Brister # Date 1338082228 18000 # Node ID 51d4b1018efb23f6aa33ffc595e50d6e374dda87 # Parent f0499b0af64605ede75609f4cd4be264f4379cd2 For loops compile with new IR * src/pt-eval.cc (tree_evaluator::visit_simple_for_command): Compile loops. (tree_evaluator::visit_statement): No longer compile individual statements. * src/pt-loop.h (tree_simple_for_command::get_info): Remove type map. (tree_simple_for_command::stash_info): Remove type map. * src/pt-loop.cc (tree_simple_for_command::~tree_simple_for_command): Delete compiled code instead of map. diff --git a/build-aux/mkinstalldirs b/build-aux/mkinstalldirs --- a/build-aux/mkinstalldirs +++ b/build-aux/mkinstalldirs @@ -1,7 +1,7 @@ #! /bin/sh # mkinstalldirs --- make directory hierarchy -scriptversion=2009-04-28.21; # UTC +scriptversion=2012-05-25.20; # UTC # Original author: Noah Friedman # Created: 1993-05-16 diff --git a/src/pt-eval.cc b/src/pt-eval.cc --- a/src/pt-eval.cc +++ b/src/pt-eval.cc @@ -294,6 +294,9 @@ if (debug_mode) do_breakpoint (cmd.is_breakpoint ()); + if (jiter.execute (cmd)) + return; + // FIXME -- need to handle PARFOR loops here using cmd.in_parallel () // and cmd.maxproc_expr (); @@ -684,9 +687,6 @@ tree_command *cmd = stmt.command (); tree_expression *expr = stmt.expression (); - if (jiter.execute (stmt)) - return; - if (cmd || expr) { if (statement_context == function || statement_context == script) diff --git a/src/pt-jit.cc b/src/pt-jit.cc --- a/src/pt-jit.cc +++ b/src/pt-jit.cc @@ -65,7 +65,18 @@ jit_typeinfo *jit_typeinfo::instance; // thrown when we should give up on JIT and interpret -class jit_fail_exception : public std::exception {}; +class jit_fail_exception : public std::runtime_error +{ +public: + jit_fail_exception (void) : std::runtime_error ("unknown"), mknown (false) {} + jit_fail_exception (const std::string& reason) : std::runtime_error (reason), + mknown (true) + {} + + bool known (void) const { return mknown; } +private: + bool mknown; +}; static void fail (void) @@ -73,6 +84,12 @@ throw jit_fail_exception (); } +static void +fail (const std::string& reason) +{ + throw jit_fail_exception (reason); +} + // function that jit code calls extern "C" void octave_jit_print_any (const char *name, octave_base_value *obv) @@ -127,6 +144,14 @@ return new octave_scalar (value); } +// -------------------- jit_range -------------------- +std::ostream& +operator<< (std::ostream& os, const jit_range& rng) +{ + return os << "Range[" << rng.base << ", " << rng.limit << ", " << rng.inc + << ", " << rng.nelem << "]"; +} + // -------------------- jit_type -------------------- llvm::Type * jit_type::to_llvm_arg (void) const @@ -308,6 +333,10 @@ fn = create_identity (scalar); grab_fn.add_overload (fn, false, scalar, scalar); + // grab index + fn = create_identity (index); + grab_fn.add_overload (fn, false, index, index); + // release any fn = create_function ("octave_jit_release_any", void_t, any->to_llvm ()); engine->addGlobalMapping (fn, reinterpret_cast(&octave_jit_release_any)); @@ -318,6 +347,10 @@ fn = create_identity (scalar); release_fn.add_overload (fn, false, 0, scalar); + // release index + fn = create_identity (index); + release_fn.add_overload (fn, false, 0, index); + // now for binary scalar operations // FIXME: Finish all operations add_binary_op (scalar, octave_value::op_add, llvm::Instruction::FAdd); @@ -336,41 +369,47 @@ add_binary_fcmp (scalar, octave_value::op_gt, llvm::CmpInst::FCMP_UGT); add_binary_fcmp (scalar, octave_value::op_ne, llvm::CmpInst::FCMP_UNE); + // now for binary index operators + add_binary_op (index, octave_value::op_add, llvm::Instruction::Add); + // now for printing functions print_fn.stash_name ("print"); add_print (any, reinterpret_cast (&octave_jit_print_any)); add_print (scalar, reinterpret_cast (&octave_jit_print_double)); + // initialize for loop + for_init_fn.stash_name ("for_init"); + + fn = create_function ("octave_jit_for_range_init", index, range); + llvm::BasicBlock *body = llvm::BasicBlock::Create (context, "body", fn); + builder.SetInsertPoint (body); + { + llvm::Value *zero = llvm::ConstantInt::get (index_t, 0); + builder.CreateRet (zero); + } + llvm::verifyFunction (*fn); + for_init_fn.add_overload (fn, false, index, range); + // bounds check for for loop - fn = create_function ("octave_jit_simple_for_range", boolean, range, index); - llvm::BasicBlock *body = llvm::BasicBlock::Create (context, "body", fn); + for_check_fn.stash_name ("for_check"); + + fn = create_function ("octave_jit_for_range_check", boolean, range, index); + body = llvm::BasicBlock::Create (context, "body", fn); builder.SetInsertPoint (body); { llvm::Value *nelem = builder.CreateExtractValue (fn->arg_begin (), 3); - // llvm::Value *idx = builder.CreateLoad (++fn->arg_begin ()); llvm::Value *idx = ++fn->arg_begin (); llvm::Value *ret = builder.CreateICmpULT (idx, nelem); builder.CreateRet (ret); } llvm::verifyFunction (*fn); - simple_for_check.add_overload (fn, false, boolean, range, index); - - // increment for for loop - fn = create_function ("octave_jit_imple_for_range_incr", index, index); - body = llvm::BasicBlock::Create (context, "body", fn); - builder.SetInsertPoint (body); - { - llvm::Value *one = llvm::ConstantInt::get (index_t, 1); - llvm::Value *idx = fn->arg_begin (); - llvm::Value *ret = builder.CreateAdd (idx, one); - builder.CreateRet (ret); - } - llvm::verifyFunction (*fn); - simple_for_incr.add_overload (fn, false, index, index); + for_check_fn.add_overload (fn, false, boolean, range, index); // index variabe for for loop - fn = create_function ("octave_jit_simple_for_idx", scalar, range, index); + for_index_fn.stash_name ("for_index"); + + fn = create_function ("octave_jit_for_range_idx", scalar, range, index); body = llvm::BasicBlock::Create (context, "body", fn); builder.SetInsertPoint (body); { @@ -385,7 +424,7 @@ builder.CreateRet (ret); } llvm::verifyFunction (*fn); - simple_for_index.add_overload (fn, false, scalar, range, index); + for_index_fn.add_overload (fn, false, scalar, range, index); // logically true // FIXME: Check for NaN @@ -569,7 +608,102 @@ return ret; } +// -------------------- jit_use -------------------- +jit_block * +jit_use::user_parent (void) const +{ + return usr->parent (); +} + +// -------------------- jit_value -------------------- +#define JIT_METH(clname) \ + void \ + jit_ ## clname::accept (jit_ir_walker& walker) \ + { \ + walker.visit (*this); \ + } + +JIT_VISIT_IR_NOTEMPLATE +#undef JIT_METH + +// -------------------- jit_instruction -------------------- +llvm::BasicBlock * +jit_instruction::parent_llvm (void) const +{ + return mparent->to_llvm (); +} + // -------------------- jit_block -------------------- +jit_instruction * +jit_block::prepend (jit_instruction *instr) +{ + instructions.push_front (instr); + instr->stash_parent (this); + return instr; +} + +jit_instruction * +jit_block::append (jit_instruction *instr) +{ + instructions.push_back (instr); + instr->stash_parent (this); + return instr; +} + +jit_terminator * +jit_block::terminator (void) const +{ + if (instructions.empty ()) + return 0; + + jit_instruction *last = instructions.back (); + return dynamic_cast (last); +} + +llvm::Value * +jit_block::pred_terminator_llvm (size_t idx) const +{ + jit_terminator *term = pred_terminator (idx); + return term ? term->to_llvm () : 0; +} + +void +jit_block::create_merge (llvm::Function *inside, size_t pred_idx) +{ + mpred_llvm.resize (pred_count ()); + + jit_block *ipred = pred (pred_idx); + if (! mpred_llvm[pred_idx] && ipred->pred_count () > 1) + { + llvm::BasicBlock *merge; + merge = llvm::BasicBlock::Create (context, "phi_merge", inside, + to_llvm ()); + + // fix the predecessor jump if it has been created + llvm::Value *term = pred_terminator_llvm (pred_idx); + if (term) + { + llvm::TerminatorInst *branch = llvm::cast (term); + for (size_t i = 0; i < branch->getNumSuccessors (); ++i) + { + if (branch->getSuccessor (i) == to_llvm ()) + branch->setSuccessor (i, merge); + } + } + + llvm::IRBuilder<> temp (merge); + temp.CreateBr (to_llvm ()); + mpred_llvm[pred_idx] = merge; + } +} + +size_t +jit_block::succ_count (void) const +{ + jit_terminator *term = terminator (); + return term ? term->sucessor_count () : 0; +} + llvm::BasicBlock * jit_block::to_llvm (void) const { @@ -609,20 +743,20 @@ { jit_instruction::reset_ids (); - entry_block = new jit_block ("entry"); - blocks.push_back (entry_block); - block = new jit_block ("body"); + jit_block *entry_block = new jit_block ("body"); + block = entry_block; blocks.push_back (block); + toplevel_map tlevel (block); + variables = &tlevel; final_block = new jit_block ("final"); visit (tee); + blocks.push_back (final_block); - - entry_block->append (new jit_break (block)); block->append (new jit_break (final_block)); - for (variable_map::iterator iter = variables.begin (); - iter != variables.end (); ++iter) + for (variable_map::iterator iter = variables->begin (); + iter != variables->end (); ++iter) final_block->append (new jit_store_argument (iter->first, iter->second)); // FIXME: Maybe we should remove dead code here? @@ -632,6 +766,19 @@ iter != constants.end (); ++iter) append_users (*iter); + // also get anything from jit_extract_argument, as these have constant types + for (jit_block::iterator iter = entry_block->begin (); + iter != entry_block->end (); ++iter) + { + jit_instruction *instr = *iter; + if (jit_extract_argument *extract = dynamic_cast(instr)) + { + if (! extract->type ()) + fail (); // we depend on an unknown type + append_users (extract); + } + } + // FIXME: Describe algorithm here while (worklist.size ()) { @@ -653,6 +800,15 @@ std::cout << std::endl; } + // for now just init arguments from entry, later we will have to do something + // more interesting + for (jit_block::iterator iter = entry_block->begin (); + iter != entry_block->end (); ++iter) + { + if (jit_extract_argument *extract = dynamic_cast (*iter)) + arguments.push_back (std::make_pair (extract->tag (), true)); + } + convert_llvm to_llvm; function = to_llvm.convert (module, arguments, blocks, constants); @@ -662,6 +818,7 @@ llvm::raw_os_ostream llvm_cout (std::cout); function->print (llvm_cout); std::cout << std::endl; + llvm::verifyFunction (*function); } } @@ -737,9 +894,103 @@ } void -jit_convert::visit_simple_for_command (tree_simple_for_command&) +jit_convert::visit_simple_for_command (tree_simple_for_command& cmd) { - fail (); + // how a for statement is compiled. Note we do an initial check + // to see if the loop will run atleast once. This allows us to get + // better type inference bounds on variables defined and used only + // inside the for loop (e.g. the index variable) + + // prev_block: % pred = ? + // #control.0 = % compute_control (note this will just be a temp) + // #iter.0 = call for_init (#control.0) % Let type of control decide iter + // % initial value and type + // #temp.0 = call for_check (control.0, #iter.0) + // cond_break #temp.0, for_body, for_tail + // for_body: % pred = for_init, for_cond + // idxvar.2 = phi | for_init -> idxvar.1 + // | for_body -> idxvar.3 + // #iter.1 = phi | for_init -> #iter.0 + // | for_body -> #iter.2 + // idxvar.3 = call for_index (#control.0, #iter.1) + // % do loop body + // #iter.2 = #iter.1 + 1 % release is implicit in iter reuse + // #check = call for_check (#control.0, iter.2) + // cond_break #check for_body, for_tail + // for_tail: % pred = prev_block, for_body + // #iter.3 = phi | prev_block -> #iter.0 + // | for_body -> #iter.2 + // idxvar.4 = phi | prev_block -> idxvar.0 + // | for_body -> idxvar.3 + // call release (#iter.3) + // % rest of code + + // FIXME: one of these days we will introduce proper lvalues... + tree_identifier *lhs = dynamic_cast(cmd.left_hand_side ()); + if (! lhs) + fail (); + std::string lhs_name = lhs->name (); + + jit_block *body = new jit_block ("for_body"); + blocks.push_back (body); + + jit_block *tail = new jit_block ("for_tail"); + unwind_protect prot_tail; + prot_tail.add_delete (tail); // incase we fail before adding tail to blocks + + // do control expression, iter init, and condition check in prev_block (block) + jit_value *control = visit (cmd.control_expr ()); + jit_call *init_iter = new jit_call (jit_typeinfo::for_init, control); + init_iter->stash_tag ("#iter"); + block->append (init_iter); + jit_value *check = block->append (new jit_call (jit_typeinfo::for_check, + control, init_iter)); + block->append (new jit_cond_break (check, body, tail)); + + // we need to do iter phi manually, for_map handles the rest + jit_phi *iter_phi = new jit_phi (2); + iter_phi->stash_tag ("#iter"); + iter_phi->stash_argument (1, init_iter); + body->append (iter_phi); + + variable_map *merge_vars = variables; + for_map body_vars (variables, body); + variables = &body_vars; + block = body; + + // first thing we do in the for loop is bind our index from our itertor + jit_call *idx_rhs = new jit_call (jit_typeinfo::for_index, control, iter_phi); + block->append (idx_rhs); + idx_rhs->stash_tag (lhs_name); + do_assign (lhs_name, idx_rhs, false); + + tree_statement_list *pt_body = cmd.body (); + pt_body->accept (*this); + + // increment iterator, check conditional, and repeat + const jit_function& add_fn = jit_typeinfo::binary_op (octave_value::op_add); + jit_call *iter_inc = new jit_call (add_fn, iter_phi, + get_const (1)); + iter_inc->stash_tag ("#iter"); + block->append (iter_inc); + check = block->append (new jit_call (jit_typeinfo::for_check, control, + iter_inc)); + block->append (new jit_cond_break (check, body, tail)); + iter_phi->stash_argument (0, iter_inc); + body_vars.finish_phi (*variables); + + blocks.push_back (tail); + prot_tail.discard (); + block = tail; + + variables = merge_vars; + merge (body_vars); + iter_phi = new jit_phi (2); + iter_phi->stash_tag ("#iter"); + iter_phi->stash_argument (0, iter_inc); + iter_phi->stash_argument (1, init_iter); + block->append (iter_phi); + block->append (new jit_call (jit_typeinfo::release, iter_phi)); } void @@ -781,23 +1032,8 @@ void jit_convert::visit_identifier (tree_identifier& ti) { - std::string name = ti.name (); - variable_map::iterator iter = variables.find (name); - jit_value *var; - if (iter == variables.end ()) - { - octave_value var_value = ti.do_lookup (); - jit_type *var_type = jit_typeinfo::type_of (var_value); - var = entry_block->append (new jit_extract_argument (var_type, name)); - constants.push_back (var); - bounds.push_back (std::make_pair (var_type, name)); - variables[name] = var; - arguments.push_back (std::make_pair (name, true)); - } - else - var = iter->second; - const jit_function& fn = jit_typeinfo::grab (); + jit_value *var = variables->get (ti.name ()); result = block->append (new jit_call (fn, var)); } @@ -856,10 +1092,13 @@ if (v.is_real_scalar () && v.is_double_type ()) { double dv = v.double_value (); - result = get_scalar (dv); + result = get_const (dv); } else if (v.is_range ()) - fail (); + { + Range rv = v.range_value (); + result = get_const (rv); + } else fail (); } @@ -951,16 +1190,23 @@ // FIXME: ugly hack, we need to come up with a way to pass // nargout to visit_identifier const jit_function& fn = jit_typeinfo::print_value (); - jit_const_string *name = get_string (expr->name ()); + jit_const_string *name = get_const (expr->name ()); block->append (new jit_call (fn, name, expr_result)); } } } void -jit_convert::visit_statement_list (tree_statement_list&) +jit_convert::visit_statement_list (tree_statement_list& lst) { - fail (); + for (tree_statement_list::iterator iter = lst.begin (); iter != lst.end(); + ++iter) + { + tree_statement *elt = *iter; + // jwe: Can this ever be null? + assert (elt); + elt->accept (*this); + } } void @@ -1008,22 +1254,16 @@ void jit_convert::do_assign (const std::string& lhs, jit_value *rhs, bool print) { - variable_map::iterator iter = variables.find (lhs); - if (iter == variables.end ()) - arguments.push_back (std::make_pair (lhs, false)); - else - { - const jit_function& fn = jit_typeinfo::release (); - block->append (new jit_call (fn, iter->second)); - } - - variables[lhs] = rhs; + const jit_function& release = jit_typeinfo::release (); + jit_value *current = variables->get (lhs); + block->append (new jit_call (release, current)); + variables->set (lhs, rhs); if (print) { - const jit_function& fn = jit_typeinfo::print_value (); - jit_const_string *name = get_string (lhs); - block->append (new jit_call (fn, name, rhs)); + const jit_function& print_fn = jit_typeinfo::print_value (); + jit_const_string *name = get_const (lhs); + block->append (new jit_call (print_fn, name, rhs)); } } @@ -1038,6 +1278,40 @@ return ret; } +void +jit_convert::merge (const variable_map& ref) +{ + assert (variables->size () == ref.size ()); + variable_map::iterator viter = variables->begin (); + variable_map::const_iterator riter = ref.begin (); + for (; viter != variables->end (); ++viter, ++riter) + { + assert (viter->first == riter->first); + if (viter->second != riter->second) + { + jit_phi *phi = new jit_phi (2); + phi->stash_tag (viter->first); + block->prepend (phi); + phi->stash_argument (0, riter->second); + phi->stash_argument (1, viter->second); + viter->second = phi; + } + } +} + +// -------------------- jit_convert::toplevel_map -------------------- +jit_value * +jit_convert::toplevel_map::insert (const std::string& name, jit_value *pval) +{ + assert (pval == 0); // we have no parent + + jit_block *entry = block (); + octave_value val = symbol_table::find (name); + jit_type *type = jit_typeinfo::type_of (val); + jit_instruction *ret = new jit_extract_argument (type, name); + return vars[name] = entry->prepend (ret); +} + // -------------------- jit_convert::convert_llvm -------------------- llvm::Function * jit_convert::convert_llvm::convert (llvm::Module *module, @@ -1052,9 +1326,8 @@ arg_type = arg_type->getPointerTo (); llvm::FunctionType *ft = llvm::FunctionType::get (llvm::Type::getVoidTy (context), arg_type, false); - llvm::Function *function = llvm::Function::Create (ft, - llvm::Function::ExternalLinkage, - "foobar", module); + function = llvm::Function::Create (ft, llvm::Function::ExternalLinkage, + "foobar", module); try { @@ -1073,11 +1346,7 @@ // a block for (std::list::const_iterator iter = constants.begin (); iter != constants.end (); ++iter) - { - jit_value *constant = *iter; - if (! dynamic_cast (constant)) - visit (constant); - } + visit (*iter); std::list::const_iterator biter; for (biter = blocks.begin (); biter != blocks.end (); ++biter) @@ -1091,36 +1360,105 @@ jit_block *first = *blocks.begin (); builder.CreateBr (first->to_llvm ()); + // convert all instructions for (biter = blocks.begin (); biter != blocks.end (); ++biter) visit (*biter); + // now finish phi nodes + for (biter = blocks.begin (); biter != blocks.end (); ++biter) + { + jit_block& block = **biter; + for (jit_block::iterator piter = block.begin (); + piter != block.end () && dynamic_cast (*piter); ++piter) + { + // our phi nodes don't have to have the same incomming type, + // so we do casts here + jit_instruction *phi = *piter; + jit_block *pblock = phi->parent (); + llvm::PHINode *llvm_phi = llvm::cast (phi->to_llvm ()); + for (size_t i = 0; i < phi->argument_count (); ++i) + { + llvm::BasicBlock *pred = pblock->pred_llvm (i); + if (phi->argument_type_llvm (i) == phi->type_llvm ()) + { + llvm_phi->addIncoming (phi->argument_llvm (i), pred); + } + else + { + // add cast right before pred terminator + builder.SetInsertPoint (--pred->end ()); + + const jit_function::overload& ol + = jit_typeinfo::cast (phi->type (), + phi->argument_type (i)); + if (! ol.function) + { + std::stringstream ss; + ss << "No cast for phi(" << i << "): "; + phi->print (ss); + fail (ss.str ()); + } + + llvm::Value *casted; + casted = builder.CreateCall (ol.function, + phi->argument_llvm (i)); + llvm_phi->addIncoming (casted, pred); + } + } + } + } + + jit_block *last = blocks.back (); + builder.SetInsertPoint (last->to_llvm ()); builder.CreateRetVoid (); - } catch (const jit_fail_exception&) + } catch (const jit_fail_exception& e) { function->eraseFromParent (); throw; } - llvm::verifyFunction (*function); - return function; } void -jit_convert::convert_llvm::visit_const_string (jit_const_string& cs) +jit_convert::convert_llvm::visit (jit_const_string& cs) { cs.stash_llvm (builder.CreateGlobalStringPtr (cs.value ())); } void -jit_convert::convert_llvm::visit_const_scalar (jit_const_scalar& cs) +jit_convert::convert_llvm::visit (jit_const_scalar& cs) { - llvm::Type *dbl = llvm::Type::getDoubleTy (context); - cs.stash_llvm (llvm::ConstantFP::get (dbl, cs.value ())); + cs.stash_llvm (llvm::ConstantFP::get (cs.type_llvm (), cs.value ())); +} + +void jit_convert::convert_llvm::visit (jit_const_index& ci) +{ + ci.stash_llvm (llvm::ConstantInt::get (ci.type_llvm (), ci.value ())); } void -jit_convert::convert_llvm::visit_block (jit_block& b) +jit_convert::convert_llvm::visit (jit_const_range& cr) +{ + llvm::StructType *stype = llvm::cast(cr.type_llvm ()); + llvm::Type *dbl = jit_typeinfo::get_scalar_llvm (); + llvm::Type *idx = jit_typeinfo::get_index_llvm (); + const jit_range& rng = cr.value (); + + llvm::Constant *constants[4]; + constants[0] = llvm::ConstantFP::get (dbl, rng.base); + constants[1] = llvm::ConstantFP::get (dbl, rng.limit); + constants[2] = llvm::ConstantFP::get (dbl, rng.inc); + constants[3] = llvm::ConstantInt::get (idx, rng.nelem); + + llvm::Value *as_llvm; + as_llvm = llvm::ConstantStruct::get (stype, + llvm::makeArrayRef (constants, 4)); + cr.stash_llvm (as_llvm); +} + +void +jit_convert::convert_llvm::visit (jit_block& b) { llvm::BasicBlock *block = b.to_llvm (); builder.SetInsertPoint (block); @@ -1129,46 +1467,49 @@ } void -jit_convert::convert_llvm::visit_break (jit_break& b) +jit_convert::convert_llvm::visit (jit_break& b) { - builder.CreateBr (b.sucessor_llvm ()); + b.stash_llvm (builder.CreateBr (b.sucessor_llvm ())); } void -jit_convert::convert_llvm::visit_cond_break (jit_cond_break& cb) +jit_convert::convert_llvm::visit (jit_cond_break& cb) { llvm::Value *cond = cb.cond_llvm (); - builder.CreateCondBr (cond, cb.sucessor_llvm (0), cb.sucessor_llvm (1)); + llvm::Value *br; + br = builder.CreateCondBr (cond, cb.sucessor_llvm (0), cb.sucessor_llvm (1)); + cb.stash_llvm (br); } void -jit_convert::convert_llvm::visit_call (jit_call& call) +jit_convert::convert_llvm::visit (jit_call& call) { const jit_function::overload& ol = call.overload (); if (! ol.function) - fail (); + fail ("No overload for: " + call.print_string ()); std::vector args (call.argument_count ()); for (size_t i = 0; i < call.argument_count (); ++i) args[i] = call.argument_llvm (i); - call.stash_llvm (builder.CreateCall (ol.function, args)); + call.stash_llvm (builder.CreateCall (ol.function, args, call.tag ())); } void -jit_convert::convert_llvm::visit_extract_argument (jit_extract_argument& extract) +jit_convert::convert_llvm::visit (jit_extract_argument& extract) { const jit_function::overload& ol = extract.overload (); if (! ol.function) fail (); llvm::Value *arg = arguments[extract.tag ()]; + assert (arg); arg = builder.CreateLoad (arg); - extract.stash_llvm (builder.CreateCall (ol.function, arg)); + extract.stash_llvm (builder.CreateCall (ol.function, arg, extract.tag ())); } void -jit_convert::convert_llvm::visit_store_argument (jit_store_argument& store) +jit_convert::convert_llvm::visit (jit_store_argument& store) { llvm::Value *arg_value = store.result_llvm (); const jit_function::overload& ol = store.overload (); @@ -1181,42 +1522,47 @@ store.stash_llvm (builder.CreateStore (arg_value, arg)); } +void +jit_convert::convert_llvm::visit (jit_phi& phi) +{ + // we might not have converted all incoming branches, so we don't + // set incomming branches now + llvm::PHINode *node = llvm::PHINode::Create (phi.type_llvm (), + phi.argument_count (), + phi.tag ()); + builder.Insert (node); + phi.stash_llvm (node); + + jit_block *parent = phi.parent (); + for (size_t i = 0; i < phi.argument_count (); ++i) + if (phi.argument_type (i) != phi.type ()) + parent->create_merge (function, i); +} + // -------------------- tree_jit -------------------- -tree_jit::tree_jit (void) : context (llvm::getGlobalContext ()), engine (0) +tree_jit::tree_jit (void) : module (0), engine (0) { - llvm::InitializeNativeTarget (); - module = new llvm::Module ("octave", context); } tree_jit::~tree_jit (void) {} bool -tree_jit::execute (tree& cmd) +tree_jit::execute (tree_simple_for_command& cmd) { if (! initialize ()) return false; - compiled_map::iterator iter = compiled.find (&cmd); - jit_info *jinfo = 0; - if (iter != compiled.end ()) + jit_info *info = cmd.get_info (); + if (! info || ! info->match ()) { - jinfo = iter->second; - if (! jinfo->match ()) - { - delete jinfo; - jinfo = 0; - } + delete info; + info = new jit_info (*this, cmd); + cmd.stash_info (info); } - if (! jinfo) - { - jinfo = new jit_info (*this, cmd); - compiled[&cmd] = jinfo; - } - - return jinfo->execute (); + return info->execute (); } bool @@ -1225,6 +1571,12 @@ if (engine) return true; + if (! module) + { + llvm::InitializeNativeTarget (); + module = new llvm::Module ("octave", context); + } + // sometimes this fails pre main engine = llvm::ExecutionEngine::createJIT (module); @@ -1269,8 +1621,11 @@ arguments = conv.get_arguments (); bounds = conv.get_bounds (); } - catch (const jit_fail_exception&) - {} + catch (const jit_fail_exception& e) + { + if (debug_print && e.known ()) + std::cout << "jit fail: " << e.what () << std::endl; + } if (! fun) { @@ -1326,7 +1681,7 @@ for (size_t i = 0; i < bounds.size (); ++i) { const std::string& arg_name = bounds[i].second; - octave_value value = symbol_table::varval (arg_name); + octave_value value = symbol_table::find (arg_name); jit_type *type = jit_typeinfo::type_of (value); // FIXME: Check for a parent relationship diff --git a/src/pt-jit.h b/src/pt-jit.h --- a/src/pt-jit.h +++ b/src/pt-jit.h @@ -46,10 +46,7 @@ // b = a + a; // will compile to do_binary_op (a, a). // -// for loops and if statements no longer compile! This is because work has been -// done to introduce a new lower level IR for octave. The low level IR looks -// a lot like llvm's IR, but it makes it much easier to infer types. You can set -// debug_print to true in pt-jit.cc to view the IRs that are created. +// For loops are compiled again! Additionally, make check passes using jit. // // The octave low level IR is a linear IR, it works by converting everything to // calls to jit_functions. This turns expressions like c = a + b into @@ -57,14 +54,13 @@ // The jit_functions contain information about overloads for differnt types. For // example, if we know a and b are scalars, then c must also be a scalar. // -// You will currently see a LARGE slowdown, as every statement is compiled -// seperatly! // // TODO: -// 1. Support for loops -// 2. Support if statements -// 3. Cleanup/documentation -// 4. ... +// 1. Support if statements +// 2. Support error cases +// 3. Fix memory leaks in JIT +// 4. Cleanup/documentation +// 5. ... // --------------------------------------------------------- @@ -109,6 +105,8 @@ octave_idx_type nelem; }; +std::ostream& operator<< (std::ostream& os, const jit_range& rng); + // Used to keep track of estimated (infered) types during JIT. This is a // hierarchical type system which includes both concrete and abstract types. // @@ -269,6 +267,8 @@ static jit_type *get_scalar (void) { return instance->scalar; } + static llvm::Type *get_scalar_llvm (void) { return instance->scalar->to_llvm (); } + static jit_type *get_range (void) { return instance->range; } static jit_type *get_string (void) { return instance->string; } @@ -277,6 +277,8 @@ static jit_type *get_index (void) { return instance->index; } + static llvm::Type *get_index_llvm (void) { return instance->index->to_llvm (); } + static jit_type *type_of (const octave_value& ov) { return instance->do_type_of (ov); @@ -299,6 +301,21 @@ return instance->print_fn; } + static const jit_function& for_init (void) + { + return instance->for_init_fn; + } + + static const jit_function& for_check (void) + { + return instance->for_check_fn; + } + + static const jit_function& for_index (void) + { + return instance->for_index_fn; + } + static const jit_function& cast (jit_type *result) { return instance->do_cast (result); @@ -461,9 +478,9 @@ jit_function grab_fn; jit_function release_fn; jit_function print_fn; - jit_function simple_for_check; - jit_function simple_for_incr; - jit_function simple_for_index; + jit_function for_init_fn; + jit_function for_check_fn; + jit_function for_index_fn; jit_function logically_true; // type id -> cast function TO that type @@ -477,35 +494,21 @@ // this ir is close to llvm, but contains information for doing type inference. // We convert the octave parse tree to this IR directly. -#define JIT_VISIT_IR_CLASSES \ - JIT_METH(const_string); \ - JIT_METH(const_scalar); \ +#define JIT_VISIT_IR_NOTEMPLATE \ JIT_METH(block); \ JIT_METH(break); \ JIT_METH(cond_break); \ JIT_METH(call); \ JIT_METH(extract_argument); \ - JIT_METH(store_argument) + JIT_METH(store_argument); \ + JIT_METH(phi) + +#define JIT_VISIT_IR_CLASSES \ + JIT_VISIT_IR_NOTEMPLATE; \ + JIT_VISIT_IR_CONST -#define JIT_METH(clname) class jit_ ## clname -JIT_VISIT_IR_CLASSES; -#undef JIT_METH - -class -jit_ir_walker -{ -public: - virtual ~jit_ir_walker () {} - -#define JIT_METH(clname) \ - virtual void visit_ ## clname (jit_ ## clname&) = 0 - - JIT_VISIT_IR_CLASSES; - -#undef JIT_METH -}; - +class jit_ir_walker; class jit_use; class @@ -513,11 +516,21 @@ { friend class jit_use; public: - jit_value (void) : llvm_value (0), ty (0), use_head (0) {} + jit_value (void) : llvm_value (0), ty (0), use_head (0), myuse_count (0) {} virtual ~jit_value (void) {} - jit_type *type () const { return ty; } + jit_type *type (void) const { return ty; } + + llvm::Type *type_llvm (void) const + { + return ty ? ty->to_llvm () : 0; + } + + const std::string& type_name (void) const + { + return ty->name (); + } void stash_type (jit_type *new_ty) { ty = new_ty; } @@ -525,6 +538,13 @@ size_t use_count (void) const { return myuse_count; } + std::string print_string (void) + { + std::stringstream ss; + print (ss); + return ss.str (); + } + virtual std::ostream& print (std::ostream& os, size_t indent = 0) = 0; virtual std::ostream& short_print (std::ostream& os) @@ -558,54 +578,55 @@ // defnie accept methods for subclasses #define JIT_VALUE_ACCEPT(clname) \ - virtual void accept (jit_ir_walker& walker) \ - { \ - walker.visit_ ## clname (*this); \ + virtual void accept (jit_ir_walker& walker); + +template +class +jit_const : public jit_value +{ +public: + typedef PASS_T pass_t; + + jit_const (PASS_T avalue) : mvalue (avalue) + { + stash_type (EXTRACT_T ()); } -class -jit_const_string : public jit_value -{ -public: - jit_const_string (const std::string& v) : val (v) - { - stash_type (jit_typeinfo::get_string ()); - } - - const std::string& value (void) const { return val; } + PASS_T value (void) const { return mvalue; } virtual std::ostream& print (std::ostream& os, size_t indent) { - return print_indent (os, indent) << "string: \"" << val << "\""; + print_indent (os, indent) << type_name () << ": "; + if (QUOTE) + os << "\""; + os << mvalue; + if (QUOTE) + os << "\""; + return os; } - JIT_VALUE_ACCEPT (const_string) + JIT_VALUE_ACCEPT (jit_const); private: - std::string val; + T mvalue; }; -class -jit_const_scalar : public jit_value -{ -public: - jit_const_scalar (double avalue) : mvalue (avalue) - { - stash_type (jit_typeinfo::get_scalar ()); - } +typedef jit_const jit_const_scalar; +typedef jit_const jit_const_index; - double value (void) const { return mvalue; } +typedef jit_const +jit_const_string; +typedef jit_const +jit_const_range; - virtual std::ostream& print (std::ostream& os, size_t indent) - { - return print_indent (os, indent) << "scalar: \"" << mvalue << "\""; - } - - JIT_VALUE_ACCEPT (const_scalar) -private: - double mvalue; -}; +#define JIT_VISIT_IR_CONST \ + JIT_METH(const_scalar); \ + JIT_METH(const_index); \ + JIT_METH(const_string); \ + JIT_METH(const_range) class jit_instruction; +class jit_block; class jit_use @@ -621,6 +642,8 @@ jit_instruction *user (void) const { return usr; } + jit_block *user_parent (void) const; + void stash_value (jit_value *new_value, jit_instruction *u = 0, size_t use_idx = -1) { @@ -678,38 +701,38 @@ { public: // FIXME: this code could be so much pretier with varadic templates... -#define JIT_EXTRACT_ARG(idx) arguments[idx].stash_value (arg ## idx, this, idx) + jit_instruction (void) : id (next_id ()), mparent (0) + {} - jit_instruction (void) : id (next_id ()) - { - } + jit_instruction (size_t nargs) + : already_infered (nargs, reinterpret_cast(0)), arguments (nargs), + id (next_id ()), mparent (0) + {} jit_instruction (jit_value *arg0) : already_infered (1, reinterpret_cast(0)), arguments (1), - id (next_id ()) + id (next_id ()), mparent (0) { - JIT_EXTRACT_ARG (0); + stash_argument (0, arg0); } jit_instruction (jit_value *arg0, jit_value *arg1) : already_infered (2, reinterpret_cast(0)), arguments (2), - id (next_id ()) + id (next_id ()), mparent (0) { - JIT_EXTRACT_ARG (0); - JIT_EXTRACT_ARG (1); + stash_argument (0, arg0); + stash_argument (1, arg1); } jit_instruction (jit_value *arg0, jit_value *arg1, jit_value *arg2) : already_infered (3, reinterpret_cast(0)), arguments (3), - id (next_id ()) + id (next_id ()), mparent (0) { - JIT_EXTRACT_ARG (0); - JIT_EXTRACT_ARG (1); - JIT_EXTRACT_ARG (2); + stash_argument (0, arg0); + stash_argument (1, arg1); + stash_argument (2, arg2); } -#undef JIT_EXTRACT_ARG - static void reset_ids (void) { next_id (true); @@ -722,12 +745,33 @@ llvm::Value *argument_llvm (size_t i) const { - return arguments[i].value ()->to_llvm (); + assert (argument (i)); + return argument (i)->to_llvm (); } jit_type *argument_type (size_t i) const { - return arguments[i].value ()->type (); + assert (argument (i)); + return argument (i)->type (); + } + + llvm::Type *argument_type_llvm (size_t i) const + { + assert (argument (i)); + return argument_type (i)->to_llvm (); + } + + std::ostream& print_argument (std::ostream& os, size_t i) const + { + if (argument (i)) + return argument (i)->short_print (os); + else + return os << "NULL"; + } + + void stash_argument (size_t i, jit_value *arg) + { + arguments[i].stash_value (arg, this, i); } size_t argument_count (void) const @@ -754,6 +798,16 @@ const std::string& tag (void) const { return mtag; } void stash_tag (const std::string& atag) { mtag = atag; } + + jit_block *parent (void) const { return mparent; } + + llvm::BasicBlock *parent_llvm (void) const; + + void stash_parent (jit_block *aparent) + { + assert (! mparent); + mparent = aparent; + } protected: std::vector already_infered; private: @@ -770,8 +824,11 @@ std::string mtag; size_t id; + jit_block *mparent; }; +class jit_terminator; + class jit_block : public jit_value { @@ -780,7 +837,8 @@ typedef instruction_list::iterator iterator; typedef instruction_list::const_iterator const_iterator; - jit_block (const std::string& n) : nm (n) {} + jit_block (const std::string& aname) : mname (aname) + {} virtual ~jit_block () { @@ -789,31 +847,93 @@ delete *iter; } - const std::string& name (void) const { return nm; } + const std::string& name (void) const { return mname; } + + jit_instruction *prepend (jit_instruction *instr); + + jit_instruction *append (jit_instruction *instr); + + jit_terminator *terminator (void) const; - jit_instruction *prepend (jit_instruction *instr) + jit_block *pred (size_t idx) const { - instructions.push_front (instr); - return instr; + // FIXME: We should probably make this O(1) + jit_use *puse = first_use (); + for (size_t i = 0; i < idx; ++i) + { + assert (puse); + puse = puse->next (); + } + + return puse->user_parent (); + } + + jit_terminator *pred_terminator (size_t idx) const + { + return pred (idx)->terminator (); + } + + llvm::Value *pred_terminator_llvm (size_t idx) const; + + std::ostream& print_pred (std::ostream& os, size_t idx) + { + return pred (idx)->short_print (os); } - jit_instruction *append (jit_instruction *instr) + // takes into account for the addition of phi merges + llvm::BasicBlock *pred_llvm (size_t idx) const { - instructions.push_back (instr); - return instr; + if (mpred_llvm.size () <= idx) + mpred_llvm.resize (pred_count ()); + + return mpred_llvm[idx] ? mpred_llvm[idx] : pred (idx)->to_llvm (); + } + + llvm::BasicBlock *pred_llvm (jit_block *apred) const + { + return pred_llvm (pred_index (apred)); } - iterator begin () { return instructions.begin (); } + size_t pred_index (jit_block *apred) const + { + jit_use *puse = first_use (); + size_t idx = 0; + while (puse->user_parent () != apred) + { + assert (puse); + puse = puse->next (); + ++idx; + } - const_iterator begin () const { return instructions.begin (); } + return idx; + } - iterator end () { return instructions.end (); } + // create llvm phi merge blocks for all predecessors (if required) + void create_merge (llvm::Function *inside, size_t pred_idx); + + size_t pred_count (void) const { return use_count (); } + + size_t succ_count (void) const; - const_iterator end () const { return instructions.begin (); } + iterator begin (void) { return instructions.begin (); } + + const_iterator begin (void) const { return instructions.begin (); } + + iterator end (void) { return instructions.end (); } + + const_iterator end (void) const { return instructions.begin (); } virtual std::ostream& print (std::ostream& os, size_t indent) { - print_indent (os, indent) << nm << ":" << std::endl; + print_indent (os, indent) << mname << ":\tpred = "; + for (size_t i = 0; i < pred_count (); ++i) + { + print_pred (os, i); + if (i + 1 < pred_count ()) + os << ", "; + } + os << std::endl; + for (iterator iter = begin (); iter != end (); ++iter) { jit_instruction *instr = *iter; @@ -822,15 +942,73 @@ return os; } + virtual std::ostream& short_print (std::ostream& os) + { + return os << mname; + } + llvm::BasicBlock *to_llvm (void) const; JIT_VALUE_ACCEPT (block) private: - std::string nm; + std::string mname; instruction_list instructions; + mutable std::vector mpred_llvm; }; -class jit_terminator : public jit_instruction +class +jit_phi : public jit_instruction +{ +public: + jit_phi (size_t npred) : jit_instruction (npred) + {} + + virtual bool infer (void) + { + jit_type *infered = 0; + for (size_t i = 0; i < argument_count (); ++i) + infered = jit_typeinfo::tunion (infered, argument_type (i)); + + if (infered != type ()) + { + stash_type (infered); + return true; + } + + return false; + } + + virtual std::ostream& print (std::ostream& os, size_t indent) + { + std::stringstream ss; + print_indent (ss, indent); + short_print (ss) << " phi "; + std::string ss_str = ss.str (); + std::string indent_str (ss_str.size () + 7, ' '); + os << ss_str; + + jit_block *pblock = parent (); + for (size_t i = 0; i < argument_count (); ++i) + { + if (i > 0) + os << indent_str; + os << "| "; + + pblock->print_pred (os, i) << " -> "; + print_argument (os, i); + + if (i + 1 < argument_count ()) + os << std::endl; + } + + return os; + } + + JIT_VALUE_ACCEPT (phi); +}; + +class +jit_terminator : public jit_instruction { public: jit_terminator (jit_value *arg0) : jit_instruction (arg0) {} @@ -840,9 +1018,20 @@ virtual jit_block *sucessor (size_t idx = 0) const = 0; + // return either our sucessors block directly, or the phi merge block + // between us and our sucessor llvm::BasicBlock *sucessor_llvm (size_t idx = 0) const { - return sucessor (idx)->to_llvm (); + jit_block *succ = sucessor (idx); + llvm::BasicBlock *pllvm = parent_llvm (); + llvm::BasicBlock *spred_llvm = succ->pred_llvm (parent ()); + llvm::BasicBlock *succ_llvm = succ->to_llvm (); + return pllvm == spred_llvm ? succ_llvm : spred_llvm; + } + + std::ostream& print_sucessor (std::ostream& os, size_t idx = 0) + { + return sucessor (idx)->short_print (os); } virtual size_t sucessor_count (void) const = 0; @@ -857,15 +1046,15 @@ jit_block *sucessor (size_t idx = 0) const { jit_value *arg = argument (idx); - return reinterpret_cast (arg); + return static_cast (arg); } size_t sucessor_count (void) const { return 1; } virtual std::ostream& print (std::ostream& os, size_t indent) { - jit_block *succ = sucessor (); - return print_indent (os, indent) << "break: " << succ->name (); + print_indent (os, indent) << "break: "; + return print_sucessor (os); } JIT_VALUE_ACCEPT (break) @@ -880,6 +1069,11 @@ jit_value *cond (void) const { return argument (0); } + std::ostream& print_cond (std::ostream& os) + { + return cond ()->short_print (os); + } + llvm::Value *cond_llvm (void) const { return cond ()->to_llvm (); @@ -888,11 +1082,19 @@ jit_block *sucessor (size_t idx) const { jit_value *arg = argument (idx + 1); - return reinterpret_cast (arg); + return static_cast (arg); } size_t sucessor_count (void) const { return 2; } + virtual std::ostream& print (std::ostream& os, size_t indent) + { + print_indent (os, indent) << "cond_break: "; + print_cond (os) << ", "; + print_sucessor (os, 0) << ", "; + return print_sucessor (os, 1); + } + JIT_VALUE_ACCEPT (cond_break) }; @@ -903,10 +1105,17 @@ jit_call (const jit_function& afunction, jit_value *arg0) : jit_instruction (arg0), mfunction (afunction) {} + jit_call (const jit_function& (*afunction) (void), + jit_value *arg0) : jit_instruction (arg0), mfunction (afunction ()) {} + jit_call (const jit_function& afunction, jit_value *arg0, jit_value *arg1) : jit_instruction (arg0, arg1), mfunction (afunction) {} + jit_call (const jit_function& (*afunction) (void), + jit_value *arg0, jit_value *arg1) : jit_instruction (arg0, arg1), + mfunction (afunction ()) {} + const jit_function& function (void) const { return mfunction; } const jit_function::overload& overload (void) const @@ -924,8 +1133,7 @@ for (size_t i = 0; i < argument_count (); ++i) { - jit_value *arg = argument (i); - arg->short_print (os); + print_argument (os, i); if (i + 1 < argument_count ()) os << ", "; } @@ -1004,6 +1212,27 @@ JIT_VALUE_ACCEPT (store_argument) }; +class +jit_ir_walker +{ +public: + virtual ~jit_ir_walker () {} + +#define JIT_METH(clname) \ + virtual void visit (jit_ ## clname&) = 0 + + JIT_VISIT_IR_CLASSES; + +#undef JIT_METH +}; + +template +void +jit_const::accept (jit_ir_walker& walker) +{ + walker.visit (*this); +} + // convert between IRs // FIXME: Class relationships are messy from here on down. They need to be // cleaned up. @@ -1112,14 +1341,126 @@ std::vector > arguments; type_bound_vector bounds; - typedef std::map variable_map; - variable_map variables; + class + variable_map + { + // internal variable map + typedef std::map ivar_map; + public: + typedef ivar_map::iterator iterator; + typedef ivar_map::const_iterator const_iterator; + + variable_map (variable_map *aparent, jit_block *ablock) : mparent (aparent), + mblock (ablock) + {} + + variable_map *parent (void) const { return mparent; } + + jit_block *block (void) const { return mblock; } + + jit_value *get (const std::string& name) + { + ivar_map::iterator iter = vars.find (name); + if (iter != vars.end ()) + return iter->second; + + if (mparent) + { + jit_value *pval = mparent->get (name); + return insert (name, pval); + } + + return insert (name, 0); + } + + jit_value *set (const std::string& name, jit_value *val) + { + get (name); // force insertion + return vars[name] = val; + } + + iterator begin (void) { return vars.begin (); } + const_iterator begin (void) const { return vars.begin (); } + + iterator end (void) { return vars.end (); } + const_iterator end (void) const { return vars.end (); } + + size_t size (void) const { return vars.size (); } + protected: + virtual jit_value *insert (const std::string& name, jit_value *pval) = 0; + + ivar_map vars; + private: + variable_map *mparent; + jit_block *mblock; + }; + + class + toplevel_map : public variable_map + { + public: + toplevel_map (jit_block *aentry) : variable_map (0, aentry) {} + protected: + virtual jit_value *insert (const std::string& name, jit_value *pval); + }; + + class + for_map : public variable_map + { + public: + typedef variable_map::iterator iterator; + typedef variable_map::const_iterator const_iterator; + + for_map (variable_map *aparent, jit_block *ablock) + : variable_map (aparent, ablock) + { + // force insertion of all phi nodes + for (iterator iter = aparent->begin (); iter != aparent->end (); ++iter) + get (iter->first); + } + + void finish_phi (variable_map& from) + { + jit_block *for_body = block (); + for (jit_block::iterator iter = for_body->begin (); + iter != for_body->end () && dynamic_cast (*iter); ++iter) + { + jit_instruction *node = *iter; + if (! node->argument (0)) + node->stash_argument (0, from.get (node->tag ())); + } + } + protected: + virtual jit_value *insert (const std::string& name, jit_value *pval) + { + jit_phi *ret = new jit_phi (2); + ret->stash_tag (name); + block ()->prepend (ret); + ret->stash_argument (1, pval); + return vars[name] = ret; + } + }; + + class + compound_map : public variable_map + { + public: + compound_map (variable_map *aparent) : variable_map (aparent, 0) + {} + protected: + virtual jit_value *insert (const std::string&, jit_value *pval) + { + return pval; + } + }; + + + variable_map *variables; // used instead of return values from visit_* functions jit_value *result; jit_block *block; - jit_block *entry_block; jit_block *final_block; llvm::Function *function; @@ -1142,19 +1483,17 @@ worklist.push_back (use->user ()); } - jit_const_scalar *get_scalar (double v) + template + CONST_T *get_const (typename CONST_T::pass_t v) { - jit_const_scalar *ret = new jit_const_scalar (v); + CONST_T *ret = new CONST_T (v); constants.push_back (ret); return ret; } - jit_const_string *get_string (const std::string& v) - { - jit_const_string *ret = new jit_const_string (v); - constants.push_back (ret); - return ret; - } + // place phi nodes in the current block to merge ref with variables + // we assume the same number of deffinitions + void merge (const variable_map& ref); // this case is much simpler, just convert from the jit ir to llvm class @@ -1167,7 +1506,7 @@ const std::list& constants); #define JIT_METH(clname) \ - virtual void visit_ ## clname (jit_ ## clname&); + virtual void visit (jit_ ## clname&); JIT_VISIT_IR_CLASSES; @@ -1186,6 +1525,8 @@ { jvalue.accept (*this); } + private: + llvm::Function *function; }; }; @@ -1199,7 +1540,7 @@ ~tree_jit (void); - bool execute (tree& cmd); + bool execute (tree_simple_for_command& cmd); llvm::ExecutionEngine *get_engine (void) const { return engine; } @@ -1211,9 +1552,6 @@ // FIXME: Temorary hack to test typedef std::map compiled_map; - compiled_map compiled; - - llvm::LLVMContext &context; llvm::Module *module; llvm::PassManager *module_pass_manager; llvm::FunctionPassManager *pass_manager; diff --git a/src/pt-loop.cc b/src/pt-loop.cc --- a/src/pt-loop.cc +++ b/src/pt-loop.cc @@ -98,10 +98,7 @@ delete list; delete lead_comm; delete trail_comm; - - for (compiled_map::iterator iter = compiled.begin (); iter != compiled.end (); - ++iter) - delete iter->second; + delete compiled; } tree_command * diff --git a/src/pt-loop.h b/src/pt-loop.h --- a/src/pt-loop.h +++ b/src/pt-loop.h @@ -37,7 +37,6 @@ #include "symtab.h" class jit_info; -class jit_type; // While. @@ -149,7 +148,7 @@ tree_simple_for_command (int l = -1, int c = -1) : tree_command (l, c), parallel (false), lhs (0), expr (0), - maxproc (0), list (0), lead_comm (0), trail_comm (0) { } + maxproc (0), list (0), lead_comm (0), trail_comm (0), compiled (0) { } tree_simple_for_command (bool parallel_arg, tree_expression *le, tree_expression *re, @@ -160,7 +159,7 @@ int l = -1, int c = -1) : tree_command (l, c), parallel (parallel_arg), lhs (le), expr (re), maxproc (maxproc_arg), list (lst), - lead_comm (lc), trail_comm (tc) { } + lead_comm (lc), trail_comm (tc), compiled (0) { } ~tree_simple_for_command (void); @@ -184,20 +183,17 @@ void accept (tree_walker& tw); // some functions use by tree_jit - jit_info *get_info (jit_type *type) const + jit_info *get_info (void) const { - compiled_map::const_iterator iter = compiled.find (type); - return iter != compiled.end () ? iter->second : 0; + return compiled; } - void stash_info (jit_type *type, jit_info *jinfo) + void stash_info (jit_info *jinfo) { - compiled[type] = jinfo; + compiled = jinfo; } private: - typedef std::map compiled_map; - // TRUE means operate in parallel (subject to the value of the // maxproc expression). bool parallel; @@ -221,8 +217,8 @@ // Comment preceding ENDFOR token. octave_comment_list *trail_comm; - // a map from iterator types -> compiled functions - compiled_map compiled; + // compiled version of the loop + jit_info *compiled; // No copying!