Mercurial > hg > octave-lyh
changeset 14923:168cb10bb9c5
If, ifelse, and else statements JIT compile now
author | Max Brister <max@2bass.com> |
---|---|
date | Mon, 28 May 2012 23:19:41 -0500 |
parents | 2e6f83b2f2b9 |
children | d4d9a64db6aa |
files | src/pt-jit.cc src/pt-jit.h |
diffstat | 2 files changed, 272 insertions(+), 92 deletions(-) [+] |
line wrap: on
line diff
--- a/src/pt-jit.cc +++ b/src/pt-jit.cc @@ -612,7 +612,7 @@ jit_block * jit_use::user_parent (void) const { - return usr->parent (); + return muser->parent (); } // -------------------- jit_value -------------------- @@ -660,6 +660,23 @@ return dynamic_cast<jit_terminator *> (last); } +jit_block * +jit_block::pred (size_t idx) const +{ + // FIXME: Make this O(1) + + // here we get the use in backwards order. This means we preserve phi + // information when new blocks are added + assert (idx < use_count ()); + jit_use *use; + size_t real_idx = use_count () - idx - 1; + size_t i; + for (use = first_use (), i = 0; use && i < real_idx; ++i, + use = use->next ()); + + return use->user_parent (); +} + llvm::Value * jit_block::pred_terminator_llvm (size_t idx) const { @@ -667,6 +684,17 @@ return term ? term->to_llvm () : 0; } +size_t +jit_block::pred_index (jit_block *apred) const +{ + for (size_t i = 0; i < pred_count (); ++i) + if (pred (i) == apred) + return i; + + fail ("No such predecessor"); + return 0; // silly compiler, why you warn? +} + void jit_block::create_merge (llvm::Function *inside, size_t pred_idx) { @@ -704,6 +732,21 @@ return term ? term->sucessor_count () : 0; } +jit_phi * +jit_block::search_phi (const std::string& tag_name, jit_value *adefault) +{ + jit_phi *ret; + for (iterator iter = begin (); iter != end () + && (ret = dynamic_cast<jit_phi *> (*iter)); ++iter) + if (ret->tag () == tag_name) + return ret; + + ret = new jit_phi (pred_count (), adefault); + ret->stash_tag (tag_name); + prepend (ret); + return ret; +} + llvm::BasicBlock * jit_block::to_llvm (void) const { @@ -952,7 +995,7 @@ // we need to do iter phi manually, for_map handles the rest jit_phi *iter_phi = new jit_phi (2); iter_phi->stash_tag ("#iter"); - iter_phi->stash_argument (1, init_iter); + iter_phi->stash_argument (0, init_iter); body->append (iter_phi); variable_map *merge_vars = variables; @@ -978,19 +1021,19 @@ check = block->append (new jit_call (jit_typeinfo::for_check, control, iter_inc)); block->append (new jit_cond_break (check, body, tail)); - iter_phi->stash_argument (0, iter_inc); + iter_phi->stash_argument (1, iter_inc); body_vars.finish_phi (*variables); + merge (tail, *merge_vars, block, body_vars); blocks.push_back (tail); prot_tail.discard (); block = tail; + variables = merge_vars; - variables = merge_vars; - merge (body_vars); iter_phi = new jit_phi (2); iter_phi->stash_tag ("#iter"); - iter_phi->stash_argument (0, iter_inc); - iter_phi->stash_argument (1, init_iter); + iter_phi->stash_argument (0, init_iter); + iter_phi->stash_argument (1, iter_inc); block->append (iter_phi); block->append (new jit_call (jit_typeinfo::release, iter_phi)); } @@ -1046,15 +1089,126 @@ } void -jit_convert::visit_if_command (tree_if_command&) +jit_convert::visit_if_command (tree_if_command& cmd) { - fail (); + tree_if_command_list *lst = cmd.cmd_list (); + assert (lst); // jwe: Can this be null? + lst->accept (*this); } void -jit_convert::visit_if_command_list (tree_if_command_list&) +jit_convert::visit_if_command_list (tree_if_command_list& lst) { - fail (); + // Example code: + // if a == 1 + // c = c + 1; + // elseif b == 1 + // c = c + 2; + // else + // c = c + 3; + // endif + + // Generates: + // prev_block0: % pred - ? + // #temp.0 = call binary== (a.0, 1) + // cond_break #temp.0, if_body1, ifelse_cond2 + // if_body1: + // c.1 = call binary+ (c.0, 1) + // break if_tail5 + // ifelse_cond2: + // #temp.1 = call binary== (b.0, 1) + // cond_break #temp.1, ifelse_body3, else4 + // ifelse_body3: + // c.2 = call binary+ (c.0, 2) + // break if_tail5 + // else4: + // c.3 = call binary+ (c.0, 3) + // break if_tail5 + // if_tail5: + // c.4 = phi | if_body1 -> c.1 + // | ifelse_body3 -> c.2 + // | else4 -> c.3 + + + tree_if_clause *last = lst.back (); + size_t last_else = static_cast<size_t> (last->is_else_clause ()); + + // entry_blocks represents the block you need to enter in order to execute + // the condition check for the ith clause. For the else, it is simple the + // else body. If there is no else body, then it is padded with the tail + std::vector<jit_block *> entry_blocks (lst.size () + 1 - last_else); + std::vector<variable_map *> branch_variables (lst.size (), 0); + std::vector<jit_block *> branch_blocks (lst.size (), 0); // final blocks + entry_blocks[0] = block; + + // we need to construct blocks first, because they have jumps to eachother + tree_if_command_list::iterator iter = lst.begin (); + ++iter; + for (size_t i = 1; iter != lst.end (); ++iter, ++i) + { + tree_if_clause *tic = *iter; + if (tic->is_else_clause ()) + entry_blocks[i] = new jit_block ("else"); + else + entry_blocks[i] = new jit_block ("ifelse_cond"); + cleanup_blocks.push_back (entry_blocks[i]); + } + + jit_block *tail = new jit_block ("if_tail"); + if (! last_else) + entry_blocks[entry_blocks.size () - 1] = tail; + + // actually fill out the contents of our blocks. We store the variable maps + // at the end of each branch, this allows us to merge them in the tail + variable_map *prev_map = variables; + iter = lst.begin (); + for (size_t i = 0; iter != lst.end (); ++iter, ++i) + { + tree_if_clause *tic = *iter; + block = entry_blocks[i]; + assert (block); + variables = prev_map; + + if (i) // the first block is prev_block, so it has already been added + blocks.push_back (entry_blocks[i]); + + if (! tic->is_else_clause ()) + { + tree_expression *expr = tic->condition (); + jit_value *cond = visit (expr); + + jit_block *body = new jit_block (i == 0 ? "if_body" : "ifelse_body"); + blocks.push_back (body); + + jit_instruction *br = new jit_cond_break (cond, body, + entry_blocks[i + 1]); + block->append (br); + block = body; + + variables = new compound_map (variables); + branch_variables[i] = variables; + } + + tree_statement_list *stmt_lst = tic->commands (); + assert (stmt_lst); // jwe: Can this be null? + stmt_lst->accept (*this); + + branch_variables[i] = variables; + branch_blocks[i] = block; + block->append (new jit_break (tail)); + } + + blocks.push_back (tail); + + // We create phi nodes in the tail to merge blocks + for (size_t i = 0; i < branch_variables.size () - last_else; ++i) + { + merge (tail, *prev_map, branch_blocks[i], *branch_variables[i]); + delete branch_variables[i]; + } + + variables = prev_map; + block = tail; } void @@ -1281,22 +1435,28 @@ } void -jit_convert::merge (const variable_map& ref) +jit_convert::merge (jit_block *merge_block, variable_map& merge_vars, + jit_block *incomming_block, + const variable_map& incomming_vars) { - assert (variables->size () == ref.size ()); - variable_map::iterator viter = variables->begin (); - variable_map::const_iterator riter = ref.begin (); - for (; viter != variables->end (); ++viter, ++riter) + size_t merge_idx = merge_block->pred_index (incomming_block); + for (variable_map::const_iterator iter = incomming_vars.begin (); + iter != incomming_vars.end (); ++iter) { - assert (viter->first == riter->first); - if (viter->second != riter->second) + const std::string& vname = iter->first; + jit_value *merge_val = merge_vars.get (vname); + jit_value *inc_val = iter->second; + + if (merge_val != inc_val) { - jit_phi *phi = new jit_phi (2); - phi->stash_tag (viter->first); - block->prepend (phi); - phi->stash_argument (0, riter->second); - phi->stash_argument (1, viter->second); - viter->second = phi; + jit_phi *phi = dynamic_cast<jit_phi *> (merge_val); + if (! (phi && phi->parent () == merge_block)) + { + phi = merge_block->search_phi (vname, merge_val); + merge_vars.set (vname, phi); + } + + phi->stash_argument (merge_idx, inc_val); } } }
--- a/src/pt-jit.h +++ b/src/pt-jit.h @@ -46,7 +46,9 @@ // b = a + a; // will compile to do_binary_op (a, a). // -// For loops are compiled again! Additionally, make check passes using jit. +// For loops are compiled again! +// if, elseif, and else statements compile again! +// Additionally, make check passes using jit. // // The octave low level IR is a linear IR, it works by converting everything to // calls to jit_functions. This turns expressions like c = a + b into @@ -56,8 +58,8 @@ // // // TODO: -// 1. Support if statements -// 2. Support error cases +// 1. Support error cases +// 2. Support break/continue // 3. Fix memory leaks in JIT // 4. Cleanup/documentation // 5. ... @@ -566,6 +568,7 @@ private: jit_type *ty; jit_use *use_head; + jit_use *use_tail; size_t myuse_count; }; @@ -625,68 +628,82 @@ jit_use { public: - jit_use (void) : used (0), next_use (0), prev_use (0) {} + jit_use (void) : mvalue (0), mnext (0), mprev (0), muser (0), mindex (0) {} + + // we should really have a move operator, but not until c++11 :( + jit_use (const jit_use& use) : mvalue (0), mnext (0), mprev (0), muser (0), + mindex (0) + { + *this = use; + } ~jit_use (void) { remove (); } - jit_value *value (void) const { return used; } + jit_use& operator= (const jit_use& use) + { + stash_value (use.value (), use.user (), use.index ()); + return *this; + } - size_t index (void) const { return idx; } + jit_value *value (void) const { return mvalue; } - jit_instruction *user (void) const { return usr; } + size_t index (void) const { return mindex; } + + jit_instruction *user (void) const { return muser; } jit_block *user_parent (void) const; - void stash_value (jit_value *new_value, jit_instruction *u = 0, - size_t use_idx = -1) + void stash_value (jit_value *avalue, jit_instruction *auser = 0, + size_t aindex = -1) { remove (); - used = new_value; + mvalue = avalue; - if (used) + if (mvalue) { - if (used->use_head) + if (mvalue->use_head) { - used->use_head->prev_use = this; - next_use = used->use_head; + mvalue->use_head->mprev = this; + mnext = mvalue->use_head; } - used->use_head = this; - ++used->myuse_count; + mvalue->use_head = this; + ++mvalue->myuse_count; } - idx = use_idx; - usr = u; + mindex = aindex; + muser = auser; } - jit_use *next (void) const { return next_use; } + jit_use *next (void) const { return mnext; } - jit_use *prev (void) const { return prev_use; } + jit_use *prev (void) const { return mprev; } private: void remove (void) { - if (used) + if (mvalue) { - if (this == used->use_head) - used->use_head = next_use; + if (this == mvalue->use_head) + mvalue->use_head = mnext; - if (prev_use) - prev_use->next_use = next_use; + if (mprev) + mprev->mnext = mnext; - if (next_use) - next_use->prev_use = prev_use; + if (mnext) + mnext->mprev = mprev; - next_use = prev_use = 0; - --used->myuse_count; + mnext = mprev = 0; + --mvalue->myuse_count; + mvalue = 0; } } - jit_value *used; - jit_use *next_use; - jit_use *prev_use; - jit_instruction *usr; - size_t idx; + jit_value *mvalue; + jit_use *mnext; + jit_use *mprev; + jit_instruction *muser; + size_t mindex; }; class @@ -697,10 +714,14 @@ jit_instruction (void) : id (next_id ()), mparent (0) {} - jit_instruction (size_t nargs) + jit_instruction (size_t nargs, jit_value *adefault = 0) : already_infered (nargs, reinterpret_cast<jit_type *>(0)), arguments (nargs), id (next_id ()), mparent (0) - {} + { + if (adefault) + for (size_t i = 0; i < nargs; ++i) + stash_argument (i, adefault); + } jit_instruction (jit_value *arg0) : already_infered (1, reinterpret_cast<jit_type *>(0)), arguments (1), @@ -772,6 +793,16 @@ return arguments.size (); } + void resize_arguments (size_t acount, jit_value *adefault = 0) + { + size_t old = arguments.size (); + arguments.resize (acount); + + if (adefault) + for (size_t i = old; i < acount; ++i) + stash_argument (i, adefault); + } + // argument types which have been infered already const std::vector<jit_type *>& argument_types (void) const { return already_infered; } @@ -813,7 +844,7 @@ return ret++; } - std::vector<jit_use> arguments; // DO NOT resize + std::vector<jit_use> arguments; std::string mtag; size_t id; @@ -821,6 +852,7 @@ }; class jit_terminator; +class jit_phi; class jit_block : public jit_value @@ -848,18 +880,7 @@ jit_terminator *terminator (void) const; - jit_block *pred (size_t idx) const - { - // FIXME: We should probably make this O(1) - jit_use *puse = first_use (); - for (size_t i = 0; i < idx; ++i) - { - assert (puse); - puse = puse->next (); - } - - return puse->user_parent (); - } + jit_block *pred (size_t idx) const; jit_terminator *pred_terminator (size_t idx) const { @@ -876,7 +897,7 @@ // takes into account for the addition of phi merges llvm::BasicBlock *pred_llvm (size_t idx) const { - if (mpred_llvm.size () <= idx) + if (mpred_llvm.size () < pred_count ()) mpred_llvm.resize (pred_count ()); return mpred_llvm[idx] ? mpred_llvm[idx] : pred (idx)->to_llvm (); @@ -887,19 +908,7 @@ return pred_llvm (pred_index (apred)); } - size_t pred_index (jit_block *apred) const - { - jit_use *puse = first_use (); - size_t idx = 0; - while (puse->user_parent () != apred) - { - assert (puse); - puse = puse->next (); - ++idx; - } - - return idx; - } + size_t pred_index (jit_block *apred) const; // create llvm phi merge blocks for all predecessors (if required) void create_merge (llvm::Function *inside, size_t pred_idx); @@ -916,6 +925,10 @@ const_iterator end (void) const { return instructions.begin (); } + // search for the phi function with the given tag_name, if no function + // exists then a new phi node is created + jit_phi *search_phi (const std::string& tag_name, jit_value *adefault); + virtual std::ostream& print (std::ostream& os, size_t indent) { print_indent (os, indent) << mname << ":\tpred = "; @@ -953,7 +966,8 @@ jit_phi : public jit_instruction { public: - jit_phi (size_t npred) : jit_instruction (npred) + jit_phi (size_t npred, jit_value *adefault = 0) + : jit_instruction (npred, adefault) {} virtual bool infer (void) @@ -1347,6 +1361,8 @@ mblock (ablock) {} + virtual ~variable_map () {} + variable_map *parent (void) const { return mparent; } jit_block *block (void) const { return mblock; } @@ -1419,8 +1435,8 @@ iter != for_body->end () && dynamic_cast<jit_phi *> (*iter); ++iter) { jit_instruction *node = *iter; - if (! node->argument (0)) - node->stash_argument (0, from.get (node->tag ())); + if (! node->argument (1)) + node->stash_argument (1, from.get (node->tag ())); } } protected: @@ -1429,7 +1445,7 @@ jit_phi *ret = new jit_phi (2); ret->stash_tag (name); block ()->prepend (ret); - ret->stash_argument (1, pval); + ret->stash_argument (0, pval); return vars[name] = ret; } }; @@ -1460,6 +1476,8 @@ std::list<jit_block *> blocks; + std::list<jit_block *> cleanup_blocks; + std::list<jit_instruction *> worklist; std::list<jit_value *> constants; @@ -1486,7 +1504,9 @@ // place phi nodes in the current block to merge ref with variables // we assume the same number of deffinitions - void merge (const variable_map& ref); + void merge (jit_block *merge_block, variable_map& merge_vars, + jit_block *incomming_block, + const variable_map& incomming_vars); // this case is much simpler, just convert from the jit ir to llvm class