# HG changeset patch # User Max Brister # Date 1339091000 18000 # Node ID bab44e3ee2912d9d083f69c8832f5ceb1184b57f # Parent 78e1457c5bf55393f9b80f8d335f8e1a0f7fc33c Adding basic error support to JIT diff --git a/src/pt-jit.cc b/src/pt-jit.cc --- a/src/pt-jit.cc +++ b/src/pt-jit.cc @@ -301,6 +301,12 @@ casts.resize (next_id + 1); identities.resize (next_id + 1, 0); + // bind global variables + lerror_state = new llvm::GlobalVariable (*module, bool_t, false, + llvm::GlobalValue::ExternalLinkage, + 0, "error_state"); + engine->addGlobalMapping (lerror_state, reinterpret_cast (&error_state)); + // any with anything is an any op llvm::Function *fn; llvm::Type *binary_op_type @@ -334,10 +340,7 @@ fn->arg_begin (), ++fn->arg_begin ()); builder.CreateRet (ret); - - jit_function::overload overload (fn, true, any, any, any); - for (octave_idx_type i = 0; i < next_id; ++i) - binary_ops[op].add_overload (overload); + binary_ops[op].add_overload (fn, true, true, any, any, any); } llvm::Type *void_t = llvm::Type::getVoidTy (context); @@ -635,6 +638,12 @@ return identities[id]; } +llvm::Value * +jit_typeinfo::do_insert_error_check (void) +{ + return builder.CreateLoad (lerror_state); +} + jit_type * jit_typeinfo::do_type_of (const octave_value &ov) const { @@ -679,10 +688,7 @@ { jit_instruction *user = use_head->user (); size_t idx = use_head->index (); - if (idx < user->argument_count ()) - user->stash_argument (idx, value); - else - user->stash_tag (0); + user->stash_argument (idx, value); } } @@ -710,20 +716,6 @@ mparent->remove (mlocation); } -void -jit_instruction::push_variable (void) -{ - if (tag ()) - tag ()->push (this); -} - -void -jit_instruction::pop_variable (void) -{ - if (tag ()) - tag ()->pop (); -} - llvm::BasicBlock * jit_instruction::parent_llvm (void) const { @@ -735,24 +727,7 @@ { if (type ()) jit_print (os, type ()) << ": "; - - if (tag ()) - os << tag ()->name () << "." << id; - else - os << "#" << id; - return os; -} - -jit_variable * -jit_instruction::tag (void) const -{ - return reinterpret_cast (mtag.value ()); -} - -void -jit_instruction::stash_tag (jit_variable *atag) -{ - mtag.stash_value (atag, this); + return os << "#" << mid; } // -------------------- jit_block -------------------- @@ -967,13 +942,25 @@ for (size_t i = 0; i < pred_count (); ++i) changed = pred (i)->update_idom (visit_count) || changed; + if (! idom) + { + // one of our predecessors may have an idom of us, so if idom_intersect + // is called we need to have an idom. Assign idom to the pred with the + // lowest rpo id, as this prevents an infinite loop in idom_intersect + // FIXME: Textbook algorithm doesn't do this, ensure this is correct + size_t lowest_rpo = 0; + for (size_t i = 1; i < pred_count (); ++i) + if (pred (i)->id () < pred (lowest_rpo)->id ()) + lowest_rpo = i; + idom = pred (lowest_rpo); + changed = true; + } + jit_block *new_idom = pred (0); for (size_t i = 1; i < pred_count (); ++i) { jit_block *pidom = pred (i)->idom; - if (! new_idom) - new_idom = pidom; - else if (pidom) + if (pidom) new_idom = pidom->idom_intersect (new_idom); } @@ -1077,6 +1064,7 @@ jit_instruction::reset_ids (); entry_block = create ("body"); + final_block = create ("final"); blocks.push_back (entry_block); block = entry_block; visit (tee); @@ -1086,7 +1074,9 @@ assert (breaks.empty ()); assert (continues.empty ()); - jit_block *final_block = block; + block->append (create (final_block)); + blocks.push_back (final_block); + for (vmap_t::iterator iter = vmap.begin (); iter != vmap.end (); ++iter) { @@ -1096,9 +1086,7 @@ final_block->append (create (var)); } - print_blocks ("octave jit ir"); - - construct_ssa (final_block); + construct_ssa (); // initialize the worklist to instructions derived from constants for (std::list::iterator iter = constants.begin (); @@ -1176,6 +1164,11 @@ const jit_function& fn = jit_typeinfo::binary_op (be.op_type ()); result = block->append (create (fn, lhsv, rhsv)); + + jit_block *normal = create (block->name () + "a"); + block->append (create (normal, final_block)); + blocks.push_back (normal); + block = normal; } void @@ -1241,10 +1234,9 @@ void jit_convert::visit_simple_for_command (tree_simple_for_command& cmd) { - // how a for statement is compiled. Note we do an initial check - // to see if the loop will run atleast once. This allows us to get - // better type inference bounds on variables defined and used only - // inside the for loop (e.g. the index variable) + // Note we do an initial check to see if the loop will run atleast once. + // This allows us to get better type inference bounds on variables defined + // and used only inside the for loop (e.g. the index variable) // If we are a nested for loop we need to store the previous breaks assert (! breaking); @@ -1275,8 +1267,8 @@ // do control expression, iter init, and condition check in prev_block (block) jit_value *control = visit (cmd.control_expr ()); jit_call *init_iter = create (jit_typeinfo::for_init, control); - init_iter->stash_tag (iterator); block->append (init_iter); + block->append (create (iterator, init_iter)); jit_value *check = block->append (create (jit_typeinfo::for_check, control, iterator)); @@ -1316,8 +1308,8 @@ block->append (one); jit_call *iter_inc = create (add_fn, iterator, one); - iter_inc->stash_tag (iterator); block->append (iter_inc); + block->append (create (iterator, iter_inc)); check = block->append (create (jit_typeinfo::for_check, control, iterator)); block->append (create (check, body, tail)); @@ -1389,41 +1381,6 @@ void jit_convert::visit_if_command_list (tree_if_command_list& lst) { - // Example code: - // if a == 1 - // c = c + 1; - // elseif b == 1 - // c = c + 2; - // else - // c = c + 3; - // endif - - // ******************** - // FIXME: Documentation no longer reflects current version - // ******************** - - // Generates: - // prev_block0: % pred - ? - // #temp.0 = call binary== (a.0, 1) - // cond_break #temp.0, if_body1, ifelse_cond2 - // if_body1: - // c.1 = call binary+ (c.0, 1) - // break if_tail5 - // ifelse_cond2: - // #temp.1 = call binary== (b.0, 1) - // cond_break #temp.1, ifelse_body3, else4 - // ifelse_body3: - // c.2 = call binary+ (c.0, 2) - // break if_tail5 - // else4: - // c.3 = call binary+ (c.0, 3) - // break if_tail5 - // if_tail5: - // c.4 = phi | if_body1 -> c.1 - // | ifelse_body3 -> c.2 - // | else4 -> c.3 - - tree_if_clause *last = lst.back (); size_t last_else = static_cast (last->is_else_clause ()); @@ -1718,7 +1675,7 @@ bool print) { jit_variable *var = get_variable (lhs); - rhs->stash_tag (var); + block->append (create (var, rhs)); if (print) { @@ -1742,7 +1699,7 @@ } void -jit_convert::construct_ssa (jit_block *final_block) +jit_convert::construct_ssa (void) { final_block->label (); entry_block->compute_idom (final_block); @@ -1784,6 +1741,7 @@ } entry_block->visit_dom (&jit_convert::do_construct_ssa, &jit_block::pop_all); + print_dom (); } void @@ -1795,7 +1753,8 @@ jit_instruction *instr = *iter; if (! isa (instr)) { - for (size_t i = 0; i < instr->argument_count (); ++i) + for (size_t i = isa (instr); i < instr->argument_count (); + ++i) { jit_value *arg = instr->argument (i); jit_variable *var = dynamic_cast (arg); @@ -1814,11 +1773,22 @@ size_t pred_idx = finish->pred_index (&block); for (jit_block::iterator iter = finish->begin (); iter != finish->end () - && isa (*iter); ++iter) + && isa (*iter);) { - jit_instruction *phi = *iter; - jit_variable *var = phi->tag (); - phi->stash_argument (pred_idx, var->top ()); + jit_phi *phi = dynamic_cast (*iter); + jit_variable *var = phi->dest (); + if (var->has_top ()) + { + phi->stash_argument (pred_idx, var->top ()); + ++iter; + } + else + { + // temporaries may have extranious phi nodes which can be removed + assert (! phi->use_count ()); + assert (var->name ().size () && var->name ()[0] == '#'); + iter = finish->remove (iter); + } } } } @@ -1826,7 +1796,7 @@ void jit_convert::place_releases (void) { - jit_convert::release_placer placer (*this); + release_placer placer (*this); entry_block->visit_dom (placer, &jit_block::pop_all); } @@ -1848,24 +1818,22 @@ for (jit_block::iterator iter = block.begin (); iter != block.end (); ++iter) { jit_instruction *instr = *iter; + instr->stash_last_use (instr); + for (size_t i = 0; i < instr->argument_count (); ++i) { - jit_value *varg = instr->argument (i); - jit_instruction *arg = dynamic_cast (varg); - if (arg && arg->tag ()) - { - jit_variable *tag = arg->tag (); - tag->stash_last_use (instr); - } + jit_value *arg = instr->argument (i); + assert (arg); + arg->stash_last_use (instr); } - jit_variable *tag = instr->tag (); - if (tag && ! (isa (instr) || isa (instr)) - && tag->has_top ()) + jit_assign *assign = dynamic_cast (instr); + if (assign && assign->dest ()->has_top ()) { - jit_instruction *last_use = tag->last_use (); + jit_variable *var = assign->dest (); + jit_instruction *last_use = var->last_use (); jit_call *release = convert.create (jit_typeinfo::release, - tag->top ()); + var->top ()); release->infer (); if (last_use && last_use->parent () == &block && ! isa (last_use)) @@ -1953,7 +1921,7 @@ jit_block *pblock = phi->parent (); llvm::PHINode *llvm_phi = llvm::cast (phi->to_llvm ()); - bool can_remove = llvm_phi->use_empty (); + bool can_remove = ! phi->use_count (); if (! can_remove && llvm_phi->hasOneUse () && phi->use_count () == 1) { jit_instruction *user = phi->first_use ()->user (); @@ -2000,9 +1968,7 @@ { llvm::BasicBlock *pred = pblock->pred_llvm (i); if (phi->argument_type (i) == phi->type ()) - { - llvm_phi->addIncoming (phi->argument_llvm (i), pred); - } + llvm_phi->addIncoming (phi->argument_llvm (i), pred); else { // add cast right before pred terminator @@ -2152,6 +2118,19 @@ fail ("ERROR: SSA construction should remove all variables"); } +void +jit_convert::convert_llvm::visit (jit_check_error& check) +{ + llvm::Value *cond = jit_typeinfo::insert_error_check (); + llvm::Value *br = builder.CreateCondBr (cond, check.sucessor_llvm (1), + check.sucessor_llvm (0)); + check.stash_llvm (br); +} + +void +jit_convert::convert_llvm::visit (jit_assign&) +{} + // -------------------- tree_jit -------------------- tree_jit::tree_jit (void) : module (0), engine (0) diff --git a/src/pt-jit.h b/src/pt-jit.h --- a/src/pt-jit.h +++ b/src/pt-jit.h @@ -84,6 +84,7 @@ class LLVMContext; class Type; class Twine; + class GlobalVariable; } class octave_base_value; @@ -356,6 +357,11 @@ { return instance->do_cast (to, from); } + + static llvm::Value *insert_error_check (void) + { + return instance->do_insert_error_check (); + } private: jit_typeinfo (llvm::Module *m, llvm::ExecutionEngine *e); @@ -487,6 +493,8 @@ llvm::Function *create_identity (jit_type *type); + llvm::Value *do_insert_error_check (void); + static jit_typeinfo *instance; llvm::Module *module; @@ -494,6 +502,7 @@ int next_id; llvm::Type *ov_t; + llvm::GlobalVariable *lerror_state; std::vector id_to_type; jit_type *any; @@ -525,19 +534,21 @@ // We convert the octave parse tree to this IR directly. #define JIT_VISIT_IR_NOTEMPLATE \ - JIT_METH(block) \ - JIT_METH(break) \ - JIT_METH(cond_break) \ - JIT_METH(call) \ - JIT_METH(extract_argument) \ - JIT_METH(store_argument) \ - JIT_METH(phi) \ - JIT_METH(variable) + JIT_METH(block); \ + JIT_METH(break); \ + JIT_METH(cond_break); \ + JIT_METH(call); \ + JIT_METH(extract_argument); \ + JIT_METH(store_argument); \ + JIT_METH(phi); \ + JIT_METH(variable); \ + JIT_METH(check_error); \ + JIT_METH(assign) #define JIT_VISIT_IR_CONST \ - JIT_METH(const_scalar) \ - JIT_METH(const_index) \ - JIT_METH(const_string) \ + JIT_METH(const_scalar); \ + JIT_METH(const_index); \ + JIT_METH(const_string); \ JIT_METH(const_range) #define JIT_VISIT_IR_CLASSES \ @@ -576,7 +587,8 @@ { friend class jit_use; public: - jit_value (void) : llvm_value (0), ty (0), use_head (0), myuse_count (0) {} + jit_value (void) : llvm_value (0), ty (0), use_head (0), myuse_count (0), + mlast_use (0) {} virtual ~jit_value (void); @@ -608,6 +620,13 @@ return ss.str (); } + jit_instruction *last_use (void) const { return mlast_use; } + + void stash_last_use (jit_instruction *alast_use) + { + mlast_use = alast_use; + } + virtual std::ostream& print (std::ostream& os, size_t indent = 0) const = 0; virtual std::ostream& short_print (std::ostream& os) const @@ -644,6 +663,7 @@ jit_type *ty; jit_use *use_head; size_t myuse_count; + jit_instruction *mlast_use; }; std::ostream& operator<< (std::ostream& os, const jit_value& value); @@ -737,12 +757,12 @@ { public: // FIXME: this code could be so much pretier with varadic templates... - jit_instruction (void) : id (next_id ()), mparent (0) + jit_instruction (void) : mid (next_id ()), mparent (0) {} jit_instruction (size_t nargs, jit_value *adefault = 0) : already_infered (nargs, reinterpret_cast(0)), arguments (nargs), - id (next_id ()), mparent (0) + mid (next_id ()), mparent (0) { if (adefault) for (size_t i = 0; i < nargs; ++i) @@ -751,14 +771,14 @@ jit_instruction (jit_value *arg0) : already_infered (1, reinterpret_cast(0)), arguments (1), - id (next_id ()), mparent (0) + mid (next_id ()), mparent (0) { stash_argument (0, arg0); } jit_instruction (jit_value *arg0, jit_value *arg1) : already_infered (2, reinterpret_cast(0)), arguments (2), - id (next_id ()), mparent (0) + mid (next_id ()), mparent (0) { stash_argument (0, arg0); stash_argument (1, arg1); @@ -766,7 +786,7 @@ jit_instruction (jit_value *arg0, jit_value *arg1, jit_value *arg2) : already_infered (3, reinterpret_cast(0)), arguments (3), - id (next_id ()), mparent (0) + mid (next_id ()), mparent (0) { stash_argument (0, arg0); stash_argument (1, arg1); @@ -776,7 +796,7 @@ jit_instruction (jit_value *arg0, jit_value *arg1, jit_value *arg2, jit_value *arg3) : already_infered (3, reinterpret_cast(0)), arguments (4), - id (next_id ()), mparent (0) + mid (next_id ()), mparent (0) { stash_argument (0, arg0); stash_argument (1, arg1); @@ -847,14 +867,14 @@ virtual bool almost_dead (void) const { return false; } + virtual void push_variable (void) {} + + virtual void pop_variable (void) {} + virtual bool infer (void) { return false; } void remove (void); - void push_variable (void); - - void pop_variable (void); - virtual std::ostream& short_print (std::ostream& os) const; jit_block *parent (void) const { return mparent; } @@ -873,9 +893,7 @@ mlocation = alocation; } - jit_variable *tag (void) const; - - void stash_tag (jit_variable *atag); + size_t id (void) const { return mid; } protected: std::vector already_infered; private: @@ -890,9 +908,7 @@ std::vector arguments; - jit_use mtag; - - size_t id; + size_t mid; jit_block *mparent; std::list::iterator mlocation; }; @@ -960,11 +976,12 @@ jit_instruction *insert_after (iterator loc, jit_instruction *instr); - void remove (jit_block::iterator iter) + iterator remove (iterator iter) { jit_instruction *instr = *iter; - instructions.erase (iter); + iter = instructions.erase (iter); instr->stash_parent (0, instructions.end ()); + return iter; } jit_terminator *terminator (void) const; @@ -1046,8 +1063,7 @@ for (size_t i = 0; i < pred_count (); ++i) pred (i)->label (visit_count, number); - mid = number; - ++number; + mid = number++; } // See for idom computation algorithm @@ -1254,15 +1270,66 @@ }; class -jit_phi : public jit_instruction +jit_assign_base : public jit_instruction +{ +public: + jit_assign_base (jit_variable *adest) : jit_instruction (), mdest (adest) {} + + jit_assign_base (jit_variable *adest, size_t npred) : jit_instruction (npred), + mdest (adest) {} + + jit_assign_base (jit_variable *adest, jit_value *arg0, jit_value *arg1) + : jit_instruction (arg0, arg1), mdest (adest) {} + + jit_variable *dest (void) const { return mdest; } + + virtual void push_variable (void) + { + mdest->push (this); + } + + virtual void pop_variable (void) + { + mdest->pop (); + } +private: + jit_variable *mdest; +}; + +class +jit_assign : public jit_assign_base { public: - jit_phi (jit_variable *avariable, size_t npred) - : jit_instruction (npred) + jit_assign (jit_variable *adest, jit_instruction *asrc) + : jit_assign_base (adest, adest, asrc) {} + + jit_instruction *src (void) const + { + return static_cast (argument (1)); + } + + virtual void push_variable (void) + { + dest ()->push (src ()); + } + + virtual std::ostream& print (std::ostream& os, size_t indent = 0) const { - stash_tag (avariable); + return print_indent (os, indent) << *dest () << " = " << *src (); } + JIT_VALUE_ACCEPT (assign); +private: + jit_variable *mdest; +}; + +class +jit_phi : public jit_assign_base +{ +public: + jit_phi (jit_variable *adest, size_t npred) : jit_assign_base (adest, npred) + {} + virtual bool dead (void) const { return use_count () == 0; @@ -1314,6 +1381,15 @@ return os; } + virtual std::ostream& short_print (std::ostream& os) const + { + if (type ()) + jit_print (os, type ()) << ": "; + + dest ()->short_print (os); + return os << "#" << id (); + } + JIT_VALUE_ACCEPT (phi); }; @@ -1323,6 +1399,9 @@ public: jit_terminator (jit_value *arg0) : jit_instruction (arg0) {} + jit_terminator (jit_value *arg0, jit_value *arg1) + : jit_instruction (arg0, arg1) {} + jit_terminator (jit_value *arg0, jit_value *arg1, jit_value *arg2) : jit_instruction (arg0, arg1, arg2) {} @@ -1451,7 +1530,7 @@ { print_indent (os, indent); - if (use_count () || tag ()) + if (use_count ()) short_print (os) << " = "; os << "call " << mfunction.name () << " ("; @@ -1475,20 +1554,46 @@ const jit_function& mfunction; }; +// FIXME: This is just ugly... +// checks error_state, if error_state is false then goto the normal branche, +// otherwise goto the error branch class -jit_extract_argument : public jit_instruction +jit_check_error : public jit_terminator { public: - jit_extract_argument (jit_type *atype, jit_variable *var) - : jit_instruction () + jit_check_error (jit_block *normal, jit_block *error) + : jit_terminator (normal, error) {} + + jit_block *sucessor (size_t idx) const + { + return static_cast (argument (idx)); + } + + size_t sucessor_count (void) const { return 2; } + + virtual std::ostream& print (std::ostream& os, size_t indent = 0) const + { + print_indent (os, indent) << "check_error: normal: "; + print_sucessor (os, 0) << " error: "; + return print_sucessor (os, 1); + } + + JIT_VALUE_ACCEPT (jit_check_error) +}; + +class +jit_extract_argument : public jit_assign_base +{ +public: + jit_extract_argument (jit_type *atype, jit_variable *adest) + : jit_assign_base (adest) { stash_type (atype); - stash_tag (var); } const std::string& name (void) const { - return tag ()->name (); + return dest ()->name (); } const jit_function::overload& overload (void) const @@ -1512,14 +1617,12 @@ { public: jit_store_argument (jit_variable *var) - : jit_instruction (var) - { - stash_tag (var); - } + : jit_instruction (var), dest (var) + {} const std::string& name (void) const { - return tag ()->name (); + return dest->name (); } const jit_function::overload& overload (void) const @@ -1546,11 +1649,20 @@ { jit_value *res = result (); print_indent (os, indent) << "store "; - short_print (os) << " = "; - return res->short_print (os); + dest->short_print (os); + + if (! isa (res)) + { + os << " = "; + res->short_print (os); + } + + return os; } JIT_VALUE_ACCEPT (store_argument) +private: + jit_variable *dest; }; class @@ -1735,6 +1847,8 @@ jit_block *entry_block; + jit_block *final_block; + jit_block *block; llvm::Function *function; @@ -1774,7 +1888,7 @@ all_values.push_back (value); } - void construct_ssa (jit_block *final_block); + void construct_ssa (void); static void do_construct_ssa (jit_block& block);