# HG changeset patch # User Max Brister # Date 1337893689 21600 # Node ID 232d8ab07932c0ac00a65a4440a0f3f95156bfe2 # Parent 0b0569667939451c8d7332fb5563c3a4df2666b6 Rewrite pt-jit.* adding new low level octave IR * src/pt-eval.cc (tree_evaluator::visit_simple_for_command): Remove jit (tree_evaluator::visit_statement): Add jit * src/pt-jit.h: Rewrite * src/pt-jit.cc: Rewrite diff --git a/src/pt-eval.cc b/src/pt-eval.cc --- a/src/pt-eval.cc +++ b/src/pt-eval.cc @@ -310,9 +310,6 @@ if (error_state || rhs.is_undefined ()) return; - if (jiter.execute (cmd, rhs)) - return; - { tree_expression *lhs = cmd.left_hand_side (); @@ -687,6 +684,9 @@ tree_command *cmd = stmt.command (); tree_expression *expr = stmt.expression (); + if (jiter.execute (stmt)) + return; + if (cmd || expr) { if (statement_context == function || statement_context == script) diff --git a/src/pt-jit.cc b/src/pt-jit.cc --- a/src/pt-jit.cc +++ b/src/pt-jit.cc @@ -52,6 +52,7 @@ #include "octave.h" #include "ov-fcn-handle.h" #include "ov-usr-fcn.h" +#include "ov-scalar.h" #include "pt-all.h" // FIXME: Remove eventually @@ -60,6 +61,10 @@ static llvm::IRBuilder<> builder (llvm::getGlobalContext ()); +static llvm::LLVMContext& context = llvm::getGlobalContext (); + +jit_typeinfo *jit_typeinfo::instance; + // thrown when we should give up on JIT and interpret class jit_fail_exception : public std::exception {}; @@ -102,10 +107,25 @@ obv->release (); } -extern "C" void +extern "C" octave_base_value * octave_jit_grab_any (octave_base_value *obv) { obv->grab (); + return obv; +} + +extern "C" double +octave_jit_cast_scalar_any (octave_base_value *obv) +{ + double ret = obv->double_value (); + obv->release (); + return ret; +} + +extern "C" octave_base_value * +octave_jit_cast_any_scalar (double value) +{ + return new octave_scalar (value); } // -------------------- jit_type -------------------- @@ -155,6 +175,10 @@ if (types.size () >= overloads.size ()) return null_overload; + for (size_t i =0; i < types.size (); ++i) + if (! types[i]) + return null_overload; + const Array& over = overloads[types.size ()]; dim_vector dv (over.dims ()); Array idx = to_idx (types); @@ -187,46 +211,56 @@ } // -------------------- jit_typeinfo -------------------- +void +jit_typeinfo::initialize (llvm::Module *m, llvm::ExecutionEngine *e) +{ + instance = new jit_typeinfo (m, e); +} + jit_typeinfo::jit_typeinfo (llvm::Module *m, llvm::ExecutionEngine *e) : module (m), engine (e), next_id (0) { // FIXME: We should be registering types like in octave_value_typeinfo - llvm::LLVMContext &ctx = m->getContext (); - - ov_t = llvm::StructType::create (ctx, "octave_base_value"); + ov_t = llvm::StructType::create (context, "octave_base_value"); ov_t = ov_t->getPointerTo (); - llvm::Type *dbl = llvm::Type::getDoubleTy (ctx); - llvm::Type *bool_t = llvm::Type::getInt1Ty (ctx); + llvm::Type *dbl = llvm::Type::getDoubleTy (context); + llvm::Type *bool_t = llvm::Type::getInt1Ty (context); + llvm::Type *string_t = llvm::Type::getInt8Ty (context); + string_t = string_t->getPointerTo (); llvm::Type *index_t = 0; switch (sizeof(octave_idx_type)) { case 4: - index_t = llvm::Type::getInt32Ty (ctx); + index_t = llvm::Type::getInt32Ty (context); break; case 8: - index_t = llvm::Type::getInt64Ty (ctx); + index_t = llvm::Type::getInt64Ty (context); break; default: assert (false && "Unrecognized index type size"); } - llvm::StructType *range_t = llvm::StructType::create (ctx, "range"); + llvm::StructType *range_t = llvm::StructType::create (context, "range"); std::vector range_contents (4, dbl); range_contents[3] = index_t; range_t->setBody (range_contents); // create types - any = new_type ("any", true, 0, ov_t); - scalar = new_type ("scalar", false, any, dbl); - range = new_type ("range", false, any, range_t); - boolean = new_type ("bool", false, any, bool_t); - index = new_type ("index", false, any, index_t); + any = new_type ("any", 0, ov_t); + scalar = new_type ("scalar", any, dbl); + range = new_type ("range", any, range_t); + string = new_type ("string", any, string_t); + boolean = new_type ("bool", any, bool_t); + index = new_type ("index", any, index_t); + + casts.resize (next_id + 1); + identities.resize (next_id + 1, 0); // any with anything is an any op llvm::Function *fn; llvm::Type *binary_op_type - = llvm::Type::getIntNTy (ctx, sizeof (octave_value::binary_op)); + = llvm::Type::getIntNTy (context, sizeof (octave_value::binary_op)); llvm::Function *any_binary = create_function ("octave_jit_binary_any_any", any->to_llvm (), binary_op_type, any->to_llvm (), any->to_llvm ()); @@ -234,12 +268,19 @@ reinterpret_cast(&octave_jit_binary_any_any)); binary_ops.resize (octave_value::num_binary_ops); + for (size_t i = 0; i < octave_value::num_binary_ops; ++i) + { + octave_value::binary_op op = static_cast (i); + std::string op_name = octave_value::binary_op_as_string (op); + binary_ops[i].stash_name ("binary" + op_name); + } + for (int op = 0; op < octave_value::num_binary_ops; ++op) { llvm::Twine fn_name ("octave_jit_binary_any_any_"); fn_name = fn_name + llvm::Twine (op); fn = create_function (fn_name, any, any, any); - llvm::BasicBlock *block = llvm::BasicBlock::Create (ctx, "body", fn); + llvm::BasicBlock *block = llvm::BasicBlock::Create (context, "body", fn); builder.SetInsertPoint (block); llvm::APInt op_int(sizeof (octave_value::binary_op), op, std::numeric_limits::is_signed); @@ -255,18 +296,28 @@ binary_ops[op].add_overload (overload); } - llvm::Type *void_t = llvm::Type::getVoidTy (ctx); + llvm::Type *void_t = llvm::Type::getVoidTy (context); // grab any - fn = create_function ("octave_jit_grab_any", void_t, any->to_llvm ()); + fn = create_function ("octave_jit_grab_any", any, any); engine->addGlobalMapping (fn, reinterpret_cast(&octave_jit_grab_any)); - grab_fn.add_overload (fn, false, 0, any); + grab_fn.add_overload (fn, false, any, any); + grab_fn.stash_name ("grab"); + + // grab scalar + fn = create_identity (scalar); + grab_fn.add_overload (fn, false, scalar, scalar); // release any fn = create_function ("octave_jit_release_any", void_t, any->to_llvm ()); engine->addGlobalMapping (fn, reinterpret_cast(&octave_jit_release_any)); release_fn.add_overload (fn, false, 0, any); + release_fn.stash_name ("release"); + + // release scalar + fn = create_identity (scalar); + release_fn.add_overload (fn, false, 0, scalar); // now for binary scalar operations // FIXME: Finish all operations @@ -287,12 +338,13 @@ add_binary_fcmp (scalar, octave_value::op_ne, llvm::CmpInst::FCMP_UNE); // now for printing functions + print_fn.stash_name ("print"); add_print (any, reinterpret_cast (&octave_jit_print_any)); add_print (scalar, reinterpret_cast (&octave_jit_print_double)); // bounds check for for loop fn = create_function ("octave_jit_simple_for_range", boolean, range, index); - llvm::BasicBlock *body = llvm::BasicBlock::Create (ctx, "body", fn); + llvm::BasicBlock *body = llvm::BasicBlock::Create (context, "body", fn); builder.SetInsertPoint (body); { llvm::Value *nelem @@ -307,7 +359,7 @@ // increment for for loop fn = create_function ("octave_jit_imple_for_range_incr", index, index); - body = llvm::BasicBlock::Create (ctx, "body", fn); + body = llvm::BasicBlock::Create (context, "body", fn); builder.SetInsertPoint (body); { llvm::Value *one = llvm::ConstantInt::get (index_t, 1); @@ -320,7 +372,7 @@ // index variabe for for loop fn = create_function ("octave_jit_simple_for_idx", scalar, range, index); - body = llvm::BasicBlock::Create (ctx, "body", fn); + body = llvm::BasicBlock::Create (context, "body", fn); builder.SetInsertPoint (body); { llvm::Value *idx = ++fn->arg_begin (); @@ -339,7 +391,7 @@ // logically true // FIXME: Check for NaN fn = create_function ("octave_logically_true_scalar", boolean, scalar); - body = llvm::BasicBlock::Create (ctx, "body", fn); + body = llvm::BasicBlock::Create (context, "body", fn); builder.SetInsertPoint (body); { llvm::Value *zero = llvm::ConstantFP::get (scalar->to_llvm (), 0); @@ -350,11 +402,33 @@ logically_true.add_overload (fn, true, boolean, scalar); fn = create_function ("octave_logically_true_bool", boolean, boolean); - body = llvm::BasicBlock::Create (ctx, "body", fn); + body = llvm::BasicBlock::Create (context, "body", fn); builder.SetInsertPoint (body); builder.CreateRet (fn->arg_begin ()); llvm::verifyFunction (*fn); logically_true.add_overload (fn, false, boolean, boolean); + logically_true.stash_name ("logically_true"); + + casts[any->type_id ()].stash_name ("(any)"); + casts[scalar->type_id ()].stash_name ("(scalar)"); + + // cast any <- scalar + fn = create_function ("octave_jit_cast_any_scalar", any, scalar); + engine->addGlobalMapping (fn, reinterpret_cast (&octave_jit_cast_any_scalar)); + casts[any->type_id ()].add_overload (fn, false, any, scalar); + + // cast scalar <- any + fn = create_function ("octave_jit_cast_scalar_any", scalar, any); + engine->addGlobalMapping (fn, reinterpret_cast (&octave_jit_cast_scalar_any)); + casts[scalar->type_id ()].add_overload (fn, false, scalar, any); + + // cast any <- any + fn = create_identity (any); + casts[any->type_id ()].add_overload (fn, false, any, any); + + // cast scalar <- scalar + fn = create_identity (scalar); + casts[scalar->type_id ()].add_overload (fn, false, scalar, scalar); } void @@ -363,14 +437,13 @@ std::stringstream name; name << "octave_jit_print_" << ty->name (); - llvm::LLVMContext &ctx = llvm::getGlobalContext (); - llvm::Type *void_t = llvm::Type::getVoidTy (ctx); + llvm::Type *void_t = llvm::Type::getVoidTy (context); llvm::Function *fn = create_function (name.str (), void_t, - llvm::Type::getInt8PtrTy (ctx), + llvm::Type::getInt8PtrTy (context), ty->to_llvm ()); engine->addGlobalMapping (fn, call); - jit_function::overload ol (fn, false, 0, ty); + jit_function::overload ol (fn, false, 0, string, ty); print_fn.add_overload (ol); } @@ -383,9 +456,8 @@ fname << "octave_jit_" << octave_value::binary_op_as_string (ov_op) << "_" << ty->name (); - llvm::LLVMContext &ctx = llvm::getGlobalContext (); llvm::Function *fn = create_function (fname.str (), ty, ty, ty); - llvm::BasicBlock *block = llvm::BasicBlock::Create (ctx, "body", fn); + llvm::BasicBlock *block = llvm::BasicBlock::Create (context, "body", fn); builder.SetInsertPoint (block); llvm::Instruction::BinaryOps temp = static_cast(llvm_op); @@ -406,9 +478,8 @@ fname << "octave_jit" << octave_value::binary_op_as_string (ov_op) << "_" << ty->name (); - llvm::LLVMContext &ctx = llvm::getGlobalContext (); llvm::Function *fn = create_function (fname.str (), boolean, ty, ty); - llvm::BasicBlock *block = llvm::BasicBlock::Create (ctx, "body", fn); + llvm::BasicBlock *block = llvm::BasicBlock::Create (context, "body", fn); builder.SetInsertPoint (block); llvm::CmpInst::Predicate temp = static_cast(llvm_op); @@ -429,9 +500,8 @@ fname << "octave_jit" << octave_value::binary_op_as_string (ov_op) << "_" << ty->name (); - llvm::LLVMContext &ctx = llvm::getGlobalContext (); llvm::Function *fn = create_function (fname.str (), boolean, ty, ty); - llvm::BasicBlock *block = llvm::BasicBlock::Create (ctx, "body", fn); + llvm::BasicBlock *block = llvm::BasicBlock::Create (context, "body", fn); builder.SetInsertPoint (block); llvm::CmpInst::Predicate temp = static_cast(llvm_op); @@ -454,10 +524,30 @@ name, module); fn->addFnAttr (llvm::Attribute::AlwaysInline); return fn; -} +} + +llvm::Function * +jit_typeinfo::create_identity (jit_type *type) +{ + size_t id = type->type_id (); + if (id >= identities.size ()) + identities.resize (id + 1, 0); -jit_type* -jit_typeinfo::type_of (const octave_value &ov) const + if (! identities[id]) + { + llvm::Function *fn = create_function ("id", type, type); + llvm::BasicBlock *body = llvm::BasicBlock::Create (context, "body", fn); + builder.SetInsertPoint (body); + builder.CreateRet (fn->arg_begin ()); + llvm::verifyFunction (*fn); + identities[id] = fn; + } + + return identities[id]; +} + +jit_type * +jit_typeinfo::do_type_of (const octave_value &ov) const { if (ov.is_undefined () || ov.is_function ()) return 0; @@ -471,34 +561,21 @@ return get_any (); } -const jit_function& -jit_typeinfo::binary_op (int op) const -{ - assert (static_cast(op) < binary_ops.size ()); - return binary_ops[op]; -} - -const jit_function::overload& -jit_typeinfo::print_value (jit_type *to_print) const -{ - return print_fn.get_overload (to_print); -} - void -jit_typeinfo::to_generic (jit_type *type, llvm::GenericValue& gv) +jit_typeinfo::do_to_generic (jit_type *type, llvm::GenericValue& gv) { if (type == any) - to_generic (type, gv, octave_value ()); + do_to_generic (type, gv, octave_value ()); else if (type == scalar) - to_generic (type, gv, octave_value (0)); + do_to_generic (type, gv, octave_value (0)); else if (type == range) - to_generic (type, gv, octave_value (Range ())); + do_to_generic (type, gv, octave_value (Range ())); else assert (false && "Type not supported yet"); } void -jit_typeinfo::to_generic (jit_type *type, llvm::GenericValue& gv, octave_value ov) +jit_typeinfo::do_to_generic (jit_type *type, llvm::GenericValue& gv, octave_value ov) { if (type == any) { @@ -522,7 +599,7 @@ } octave_value -jit_typeinfo::to_octave_value (jit_type *type, llvm::GenericValue& gv) +jit_typeinfo::do_to_octave_value (jit_type *type, llvm::GenericValue& gv) { if (type == any) { @@ -545,7 +622,7 @@ } void -jit_typeinfo::reset_generic (void) +jit_typeinfo::do_reset_generic (void) { scalar_out.clear (); ov_out.clear (); @@ -553,926 +630,373 @@ } jit_type* -jit_typeinfo::new_type (const std::string& name, bool force_init, - jit_type *parent, llvm::Type *llvm_type) +jit_typeinfo::new_type (const std::string& name, jit_type *parent, + llvm::Type *llvm_type) { - jit_type *ret = new jit_type (name, force_init, parent, llvm_type, next_id++); + jit_type *ret = new jit_type (name, parent, llvm_type, next_id++); id_to_type.push_back (ret); return ret; } -// -------------------- jit_infer -------------------- -void -jit_infer::infer (tree_simple_for_command& cmd, jit_type *bounds) -{ - infer_simple_for (cmd, bounds); -} - -void -jit_infer::visit_anon_fcn_handle (tree_anon_fcn_handle&) -{ - fail (); -} - -void -jit_infer::visit_argument_list (tree_argument_list&) -{ - fail (); -} - -void -jit_infer::visit_binary_expression (tree_binary_expression& be) +// -------------------- jit_block -------------------- +llvm::BasicBlock * +jit_block::to_llvm (void) const { - if (is_lvalue) - fail (); - - if (be.op_type () >= octave_value::num_binary_ops) - fail (); - - tree_expression *lhs = be.lhs (); - lhs->accept (*this); - jit_type *tlhs = type_stack.back (); - type_stack.pop_back (); - - tree_expression *rhs = be.rhs (); - rhs->accept (*this); - jit_type *trhs = type_stack.back (); - - jit_type *result = tinfo->binary_op_result (be.op_type (), tlhs, trhs); - if (! result) - fail (); - - type_stack.push_back (result); -} - -void -jit_infer::visit_break_command (tree_break_command&) -{ - fail (); + return llvm::cast (llvm_value); } -void -jit_infer::visit_colon_expression (tree_colon_expression&) -{ - fail (); -} - -void -jit_infer::visit_continue_command (tree_continue_command&) -{ - fail (); -} - -void -jit_infer::visit_global_command (tree_global_command&) -{ - fail (); -} - -void -jit_infer::visit_persistent_command (tree_persistent_command&) -{ - fail (); -} - -void -jit_infer::visit_decl_elt (tree_decl_elt&) +// -------------------- jit_call -------------------- +bool +jit_call::infer (void) { - fail (); -} - -void -jit_infer::visit_decl_init_list (tree_decl_init_list&) -{ - fail (); -} - -void -jit_infer::visit_simple_for_command (tree_simple_for_command& cmd) -{ - tree_expression *control = cmd.control_expr (); - control->accept (*this); - - jit_type *control_t = type_stack.back (); - type_stack.pop_back (); - - // FIXME: We should improve type inference so we don't have to do this - // to generate nested for loop code - - // quick hack, check if the for loop bounds are const. If we - // run at least one, we don't have to merge types - bool atleast_once = false; - if (control->is_constant ()) + // FIXME explain algorithm + jit_type *current = type (); + for (size_t i = 0; i < argument_count (); ++i) { - octave_value over = control->rvalue1 (); - if (over.is_range ()) + jit_type *arg_type = argument_type (i); + jit_type *todo = jit_typeinfo::difference (arg_type, already_infered[i]); + if (todo) { - Range rng = over.range_value (); - atleast_once = rng.nelem () > 0; + already_infered[i] = todo; + jit_type *fresult = mfunction.get_result (already_infered); + current = jit_typeinfo::tunion (current, fresult); + already_infered[i] = arg_type; } } - if (atleast_once) - infer_simple_for (cmd, control_t); - else + if (current != type ()) { - type_map fallthrough = types; - infer_simple_for (cmd, control_t); - merge (types, fallthrough); + stash_type (current); + return true; + } + + return false; +} + +// -------------------- jit_convert -------------------- +jit_convert::jit_convert (llvm::Module *module, tree &tee) +{ + jit_instruction::reset_ids (); + + entry_block = new jit_block ("entry"); + blocks.push_back (entry_block); + block = new jit_block ("body"); + blocks.push_back (block); + + final_block = new jit_block ("final"); + visit (tee); + blocks.push_back (final_block); + + entry_block->append (new jit_break (block)); + block->append (new jit_break (final_block)); + + for (variable_map::iterator iter = variables.begin (); + iter != variables.end (); ++iter) + final_block->append (new jit_store_argument (iter->first, iter->second)); + + // FIXME: Maybe we should remove dead code here? + + // initialize the worklist to instructions derived from constants + for (std::list::iterator iter = constants.begin (); + iter != constants.end (); ++iter) + append_users (*iter); + + // FIXME: Describe algorithm here + while (worklist.size ()) + { + jit_instruction *next = worklist.front (); + worklist.pop_front (); + + if (next->infer ()) + append_users (next); + } + + if (debug_print) + { + std::cout << "-------------------- Compiling tree --------------------\n"; + std::cout << tee.str_print_code () << std::endl; + std::cout << "-------------------- octave jit ir --------------------\n"; + for (std::list::iterator iter = blocks.begin (); + iter != blocks.end (); ++iter) + (*iter)->print (std::cout, 0); + std::cout << std::endl; + } + + convert_llvm to_llvm; + function = to_llvm.convert (module, arguments, blocks, constants); + + if (debug_print) + { + std::cout << "-------------------- llvm ir --------------------"; + llvm::raw_os_ostream llvm_cout (std::cout); + function->print (llvm_cout); + std::cout << std::endl; } } void -jit_infer::visit_complex_for_command (tree_complex_for_command&) -{ - fail (); -} - -void -jit_infer::visit_octave_user_script (octave_user_script&) -{ - fail (); -} - -void -jit_infer::visit_octave_user_function (octave_user_function&) -{ - fail (); -} - -void -jit_infer::visit_octave_user_function_header (octave_user_function&) -{ - fail (); -} - -void -jit_infer::visit_octave_user_function_trailer (octave_user_function&) -{ - fail (); -} - -void -jit_infer::visit_function_def (tree_function_def&) -{ - fail (); -} - -void -jit_infer::visit_identifier (tree_identifier& ti) -{ - symbol_table::symbol_record_ref record = ti.symbol (); - handle_identifier (record); -} - -void -jit_infer::visit_if_clause (tree_if_clause&) +jit_convert::visit_anon_fcn_handle (tree_anon_fcn_handle&) { fail (); } void -jit_infer::visit_if_command (tree_if_command& cmd) -{ - if (is_lvalue) - fail (); - - tree_if_command_list *lst = cmd.cmd_list (); - assert (lst); - lst->accept (*this); -} - -void -jit_infer::visit_if_command_list (tree_if_command_list& lst) -{ - // determine the types on each branch of the if seperatly, then merge - type_map fallthrough = types, last; - bool first_time = true; - for (tree_if_command_list::iterator p = lst.begin (); p != lst.end(); ++p) - { - tree_if_clause *tic = *p; - - if (! first_time) - types = fallthrough; - - if (! tic->is_else_clause ()) - { - tree_expression *expr = tic->condition (); - expr->accept (*this); - } - - fallthrough = types; - - tree_statement_list *stmt_lst = tic->commands (); - assert (stmt_lst); - stmt_lst->accept (*this); - - if (first_time) - last = types; - else - merge (last, types); - } - - types = last; - - tree_if_clause *last_clause = lst.back (); - if (! last_clause->is_else_clause ()) - merge (types, fallthrough); -} - -void -jit_infer::visit_index_expression (tree_index_expression&) +jit_convert::visit_argument_list (tree_argument_list&) { fail (); } void -jit_infer::visit_matrix (tree_matrix&) +jit_convert::visit_binary_expression (tree_binary_expression& be) { - fail (); -} - -void -jit_infer::visit_cell (tree_cell&) -{ - fail (); -} + if (be.op_type () >= octave_value::num_binary_ops) + // this is the case for bool_or and bool_and + fail (); -void -jit_infer::visit_multi_assignment (tree_multi_assignment&) -{ - fail (); -} + tree_expression *lhs = be.lhs (); + jit_value *lhsv = visit (lhs); -void -jit_infer::visit_no_op_command (tree_no_op_command&) -{ - fail (); + tree_expression *rhs = be.rhs (); + jit_value *rhsv = visit (rhs); + + const jit_function& fn = jit_typeinfo::binary_op (be.op_type ()); + result = block->append (new jit_call (fn, lhsv, rhsv)); } void -jit_infer::visit_constant (tree_constant& tc) -{ - if (is_lvalue) - fail (); - - octave_value v = tc.rvalue1 (); - jit_type *type = tinfo->type_of (v); - if (! type) - fail (); - - type_stack.push_back (type); -} - -void -jit_infer::visit_fcn_handle (tree_fcn_handle&) -{ - fail (); -} - -void -jit_infer::visit_parameter_list (tree_parameter_list&) -{ - fail (); -} - -void -jit_infer::visit_postfix_expression (tree_postfix_expression&) -{ - fail (); -} - -void -jit_infer::visit_prefix_expression (tree_prefix_expression&) -{ - fail (); -} - -void -jit_infer::visit_return_command (tree_return_command&) +jit_convert::visit_break_command (tree_break_command&) { fail (); } void -jit_infer::visit_return_list (tree_return_list&) +jit_convert::visit_colon_expression (tree_colon_expression&) +{ + fail (); +} + +void +jit_convert::visit_continue_command (tree_continue_command&) +{ + fail (); +} + +void +jit_convert::visit_global_command (tree_global_command&) { fail (); } void -jit_infer::visit_simple_assignment (tree_simple_assignment& tsa) -{ - if (is_lvalue) - fail (); - - // resolve rhs - is_lvalue = false; - tree_expression *rhs = tsa.right_hand_side (); - rhs->accept (*this); - - jit_type *trhs = type_stack.back (); - type_stack.pop_back (); - - // resolve lhs - is_lvalue = true; - rvalue_type = trhs; - tree_expression *lhs = tsa.left_hand_side (); - lhs->accept (*this); - - // we don't pop back here, as the resulting type should be the rhs type - // which is equal to the lhs type anways - jit_type *tlhs = type_stack.back (); - if (tlhs != trhs) - fail (); - - is_lvalue = false; - rvalue_type = 0; -} - -void -jit_infer::visit_statement (tree_statement& stmt) -{ - if (is_lvalue) - fail (); - - tree_command *cmd = stmt.command (); - tree_expression *expr = stmt.expression (); - - if (cmd) - cmd->accept (*this); - else - { - // ok, this check for ans appears three times as cp - bool do_bind_ans = false; - - if (expr->is_identifier ()) - { - tree_identifier *id = dynamic_cast (expr); - - do_bind_ans = (! id->is_variable ()); - } - else - do_bind_ans = (! expr->is_assignment_expression ()); - - expr->accept (*this); - - if (do_bind_ans) - { - is_lvalue = true; - rvalue_type = type_stack.back (); - type_stack.pop_back (); - - symbol_table::symbol_record_ref record (symbol_table::insert ("ans")); - handle_identifier (record); - - if (rvalue_type != type_stack.back ()) - fail (); - - is_lvalue = false; - rvalue_type = 0; - } - - type_stack.pop_back (); - } -} - -void -jit_infer::visit_statement_list (tree_statement_list& lst) -{ - tree_statement_list::iterator iter; - for (iter = lst.begin (); iter != lst.end (); ++iter) - { - tree_statement *stmt = *iter; - assert (stmt); // FIXME: jwe can this be null? - stmt->accept (*this); - } -} - -void -jit_infer::visit_switch_case (tree_switch_case&) +jit_convert::visit_persistent_command (tree_persistent_command&) { fail (); } void -jit_infer::visit_switch_case_list (tree_switch_case_list&) -{ - fail (); -} - -void -jit_infer::visit_switch_command (tree_switch_command&) +jit_convert::visit_decl_elt (tree_decl_elt&) { fail (); } void -jit_infer::visit_try_catch_command (tree_try_catch_command&) +jit_convert::visit_decl_init_list (tree_decl_init_list&) { fail (); } void -jit_infer::visit_unwind_protect_command (tree_unwind_protect_command&) -{ - fail (); -} - -void -jit_infer::visit_while_command (tree_while_command&) -{ - fail (); -} - -void -jit_infer::visit_do_until_command (tree_do_until_command&) +jit_convert::visit_simple_for_command (tree_simple_for_command&) { fail (); } void -jit_infer::infer_simple_for (tree_simple_for_command& cmd, - jit_type *bounds) +jit_convert::visit_complex_for_command (tree_complex_for_command&) { - if (is_lvalue) - fail (); - - jit_type *iter = tinfo->get_simple_for_index_result (bounds); - if (! iter) - fail (); - - is_lvalue = true; - rvalue_type = iter; - tree_expression *lhs = cmd.left_hand_side (); - lhs->accept (*this); - if (type_stack.back () != iter) - fail (); - type_stack.pop_back (); - is_lvalue = false; - rvalue_type = 0; - - tree_statement_list *body = cmd.body (); - body->accept (*this); + fail (); } void -jit_infer::handle_identifier (const symbol_table::symbol_record_ref& record) +jit_convert::visit_octave_user_script (octave_user_script&) { - type_map::iterator iter = types.find (record); - if (iter == types.end ()) - { - jit_type *ty = tinfo->type_of (record->find ()); - bool argin = false; - if (is_lvalue) - { - if (! ty) - ty = rvalue_type; - } - else - { - if (! ty) - fail (); - argin = true; - } - - types[record] = type_entry (argin, ty); - type_stack.push_back (ty); - } - else - type_stack.push_back (iter->second.second); + fail (); } void -jit_infer::merge (type_map& dest, const type_map& src) -{ - if (dest.size () != src.size ()) - fail (); - - type_map::iterator dest_iter; - type_map::const_iterator src_iter; - for (dest_iter = dest.begin (), src_iter = src.begin (); - dest_iter != dest.end (); ++dest_iter, ++src_iter) - { - if (dest_iter->first.name () != src_iter->first.name () - || dest_iter->second.second != src_iter->second.second) - fail (); - - // require argin if one path requires argin - dest_iter->second.first = dest_iter->second.first - || src_iter->second.first; - } -} - -// -------------------- jit_generator -------------------- -jit_generator::jit_generator (jit_typeinfo *ti, llvm::Module *mod, - tree_simple_for_command& cmd, jit_type *bounds, - const type_map& infered_types) - : tinfo (ti), module (mod), is_lvalue (false) +jit_convert::visit_octave_user_function (octave_user_function&) { - // create new vectors that include bounds - std::vector names (infered_types.size () + 1); - std::vector argin (infered_types.size () + 1); - std::vector types (infered_types.size () + 1); - names[0] = "#bounds"; - argin[0] = true; - types[0] = bounds; - size_t i; - type_map::const_iterator iter; - for (i = 1, iter = infered_types.begin (); iter != infered_types.end (); - ++i, ++iter) - { - names[i] = iter->first.name (); - argin[i] = iter->second.first; - types[i] = iter->second.second; - } - - initialize (names, argin, types); - - try - { - value var_bounds = variables["#bounds"]; - var_bounds.second = builder.CreateLoad (var_bounds.second); - emit_simple_for (cmd, var_bounds, true); - } - catch (const jit_fail_exception&) - { - function->eraseFromParent (); - function = 0; - return; - } - - finalize (names); + fail (); } void -jit_generator::visit_anon_fcn_handle (tree_anon_fcn_handle&) +jit_convert::visit_octave_user_function_header (octave_user_function&) +{ + fail (); +} + +void +jit_convert::visit_octave_user_function_trailer (octave_user_function&) +{ + fail (); +} + +void +jit_convert::visit_function_def (tree_function_def&) { fail (); } void -jit_generator::visit_argument_list (tree_argument_list&) +jit_convert::visit_identifier (tree_identifier& ti) { - fail (); + std::string name = ti.name (); + variable_map::iterator iter = variables.find (name); + jit_value *var; + if (iter == variables.end ()) + { + octave_value var_value = ti.do_lookup (); + jit_type *var_type = jit_typeinfo::type_of (var_value); + var = entry_block->append (new jit_extract_argument (var_type, name)); + constants.push_back (var); + bounds.push_back (std::make_pair (var_type, name)); + variables[name] = var; + arguments.push_back (std::make_pair (name, true)); + } + else + var = iter->second; + + const jit_function& fn = jit_typeinfo::grab (); + result = block->append (new jit_call (fn, var)); } void -jit_generator::visit_binary_expression (tree_binary_expression& be) -{ - tree_expression *lhs = be.lhs (); - lhs->accept (*this); - value lhsv = value_stack.back (); - value_stack.pop_back (); - - tree_expression *rhs = be.rhs (); - rhs->accept (*this); - value rhsv = value_stack.back (); - value_stack.pop_back (); - - const jit_function::overload& ol - = tinfo->binary_op_overload (be.op_type (), lhsv.first, rhsv.first); - - if (! ol.function) - fail (); - - llvm::Value *result = builder.CreateCall2 (ol.function, lhsv.second, - rhsv.second); - push_value (ol.result, result); -} - -void -jit_generator::visit_break_command (tree_break_command&) -{ - fail (); -} - -void -jit_generator::visit_colon_expression (tree_colon_expression&) -{ - fail (); -} - -void -jit_generator::visit_continue_command (tree_continue_command&) -{ - fail (); -} - -void -jit_generator::visit_global_command (tree_global_command&) +jit_convert::visit_if_clause (tree_if_clause&) { fail (); } void -jit_generator::visit_persistent_command (tree_persistent_command&) -{ - fail (); -} - -void -jit_generator::visit_decl_elt (tree_decl_elt&) -{ - fail (); -} - -void -jit_generator::visit_decl_init_list (tree_decl_init_list&) +jit_convert::visit_if_command (tree_if_command&) { fail (); } void -jit_generator::visit_simple_for_command (tree_simple_for_command& cmd) -{ - if (is_lvalue) - fail (); - - tree_expression *control = cmd.control_expr (); - assert (control); // FIXME: jwe, can this be null? - - control->accept (*this); - value over = value_stack.back (); - value_stack.pop_back (); - - emit_simple_for (cmd, over, false); -} - -void -jit_generator::visit_complex_for_command (tree_complex_for_command&) +jit_convert::visit_if_command_list (tree_if_command_list&) { fail (); } void -jit_generator::visit_octave_user_script (octave_user_script&) +jit_convert::visit_index_expression (tree_index_expression&) { fail (); } void -jit_generator::visit_octave_user_function (octave_user_function&) -{ - fail (); -} - -void -jit_generator::visit_octave_user_function_header (octave_user_function&) -{ - fail (); -} - -void -jit_generator::visit_octave_user_function_trailer (octave_user_function&) +jit_convert::visit_matrix (tree_matrix&) { fail (); } void -jit_generator::visit_function_def (tree_function_def&) +jit_convert::visit_cell (tree_cell&) { fail (); } void -jit_generator::visit_identifier (tree_identifier& ti) +jit_convert::visit_multi_assignment (tree_multi_assignment&) { - std::string name = ti.name (); - value variable = variables[name]; - if (is_lvalue) - { - value_stack.push_back (variable); - - const jit_function::overload& ol = tinfo->release (variable.first); - if (ol.function) - { - llvm::Value *load = builder.CreateLoad (variable.second, name); - builder.CreateCall (ol.function, load); - } - } - else - { - llvm::Value *load = builder.CreateLoad (variable.second, name); - push_value (variable.first, load); - - const jit_function::overload& ol = tinfo->grab (variable.first); - if (ol.function) - builder.CreateCall (ol.function, load); - } + fail (); } void -jit_generator::visit_if_clause (tree_if_clause&) +jit_convert::visit_no_op_command (tree_no_op_command&) { fail (); } void -jit_generator::visit_if_command (tree_if_command& cmd) +jit_convert::visit_constant (tree_constant& tc) { - tree_if_command_list *lst = cmd.cmd_list (); - assert (lst); - lst->accept (*this); + octave_value v = tc.rvalue1 (); + if (v.is_real_scalar () && v.is_double_type ()) + { + double dv = v.double_value (); + result = get_scalar (dv); + } + else if (v.is_range ()) + fail (); + else + fail (); } void -jit_generator::visit_if_command_list (tree_if_command_list& lst) -{ - llvm::LLVMContext &ctx = llvm::getGlobalContext (); - llvm::BasicBlock *tail = llvm::BasicBlock::Create (ctx, "if_tail", function); - std::vector clause_entry (lst.size ()); - tree_if_command_list::iterator p; - size_t i; - for (p = lst.begin (), i = 0; p != lst.end (); ++p, ++i) - { - tree_if_clause *tic = *p; - if (tic->is_else_clause ()) - clause_entry[i] = llvm::BasicBlock::Create (ctx, "else_body", function, - tail); - else - clause_entry[i] = llvm::BasicBlock::Create (ctx, "if_cond", function, - tail); - } - - builder.CreateBr (clause_entry[0]); - - for (p = lst.begin (), i = 0; p != lst.end (); ++p, ++i) - { - tree_if_clause *tic = *p; - llvm::BasicBlock *body; - if (tic->is_else_clause ()) - body = clause_entry[i]; - else - { - llvm::BasicBlock *cond = clause_entry[i]; - builder.SetInsertPoint (cond); - - tree_expression *expr = tic->condition (); - expr->accept (*this); - - // FIXME: Handle undefined case - value condv = value_stack.back (); - value_stack.pop_back (); - - const jit_function::overload& ol = tinfo->get_logically_true (condv.first); - if (! ol.function) - fail (); - - bool last = i + 1 == clause_entry.size (); - llvm::BasicBlock *next = last ? tail : clause_entry[i + 1]; - body = llvm::BasicBlock::Create (ctx, "if_body", function, tail); - - llvm::Value *is_true = builder.CreateCall (ol.function, condv.second); - builder.CreateCondBr (is_true, body, next); - } - - tree_statement_list *stmt_lst = tic->commands (); - builder.SetInsertPoint (body); - stmt_lst->accept (*this); - builder.CreateBr (tail); - } - - builder.SetInsertPoint (tail); -} - -void -jit_generator::visit_index_expression (tree_index_expression&) +jit_convert::visit_fcn_handle (tree_fcn_handle&) { fail (); } void -jit_generator::visit_matrix (tree_matrix&) +jit_convert::visit_parameter_list (tree_parameter_list&) { fail (); } void -jit_generator::visit_cell (tree_cell&) -{ - fail (); -} - -void -jit_generator::visit_multi_assignment (tree_multi_assignment&) -{ - fail (); -} - -void -jit_generator::visit_no_op_command (tree_no_op_command&) +jit_convert::visit_postfix_expression (tree_postfix_expression&) { fail (); } void -jit_generator::visit_constant (tree_constant& tc) -{ - octave_value v = tc.rvalue1 (); - llvm::LLVMContext& ctx = llvm::getGlobalContext (); - if (v.is_real_scalar () && v.is_double_type ()) - { - double dv = v.double_value (); - llvm::Value *lv = llvm::ConstantFP::get (ctx, llvm::APFloat (dv)); - push_value (tinfo->get_scalar (), lv); - } - else if (v.is_range ()) - { - Range rng = v.range_value (); - llvm::Type *range = tinfo->get_range_llvm (); - llvm::Type *scalar = tinfo->get_scalar_llvm (); - llvm::Type *index = tinfo->get_index_llvm (); - - std::vector values (4); - values[0] = llvm::ConstantFP::get (scalar, rng.base ()); - values[1] = llvm::ConstantFP::get (scalar, rng.limit ()); - values[2] = llvm::ConstantFP::get (scalar, rng.inc ()); - values[3] = llvm::ConstantInt::get (index, rng.nelem ()); - - llvm::StructType *llvm_range = llvm::cast(range); - llvm::Value *lv = llvm::ConstantStruct::get (llvm_range, values); - push_value (tinfo->get_range (), lv); - } - else - fail (); -} - -void -jit_generator::visit_fcn_handle (tree_fcn_handle&) +jit_convert::visit_prefix_expression (tree_prefix_expression&) { fail (); } void -jit_generator::visit_parameter_list (tree_parameter_list&) +jit_convert::visit_return_command (tree_return_command&) +{ + fail (); +} + +void +jit_convert::visit_return_list (tree_return_list&) { fail (); } void -jit_generator::visit_postfix_expression (tree_postfix_expression&) +jit_convert::visit_simple_assignment (tree_simple_assignment& tsa) { - fail (); -} - -void -jit_generator::visit_prefix_expression (tree_prefix_expression&) -{ - fail (); -} + // resolve rhs + tree_expression *rhs = tsa.right_hand_side (); + jit_value *rhsv = visit (rhs); -void -jit_generator::visit_return_command (tree_return_command&) -{ - fail (); -} + // resolve lhs + tree_expression *lhs = tsa.left_hand_side (); + if (! lhs->is_identifier ()) + fail (); -void -jit_generator::visit_return_list (tree_return_list&) -{ - fail (); + std::string lhs_name = lhs->name (); + do_assign (lhs_name, rhsv, tsa.print_result ()); + result = rhsv; + + if (jit_instruction *instr = dynamic_cast(rhsv)) + instr->stash_tag (lhs_name); } void -jit_generator::visit_simple_assignment (tree_simple_assignment& tsa) -{ - if (is_lvalue) - fail (); - - // resolve rhs - tree_expression *rhs = tsa.right_hand_side (); - rhs->accept (*this); - - value rhsv = value_stack.back (); - value_stack.pop_back (); - - // resolve lhs - is_lvalue = true; - tree_expression *lhs = tsa.left_hand_side (); - lhs->accept (*this); - is_lvalue = false; - - value lhsv = value_stack.back (); - value_stack.pop_back (); - - // do assign, then keep rhs as the result - builder.CreateStore (rhsv.second, lhsv.second); - - if (tsa.print_result ()) - emit_print (lhs->name (), rhsv); - - value_stack.push_back (rhsv); -} - -void -jit_generator::visit_statement (tree_statement& stmt) +jit_convert::visit_statement (tree_statement& stmt) { tree_command *cmd = stmt.command (); tree_expression *expr = stmt.expression (); if (cmd) - cmd->accept (*this); + visit (cmd); else { // stolen from tree_evaluator::visit_statement @@ -1487,208 +1011,243 @@ else do_bind_ans = (! expr->is_assignment_expression ()); - expr->accept (*this); + jit_value *expr_result = visit (expr); if (do_bind_ans) - { - value rhs = value_stack.back (); - value ans = variables["ans"]; - if (ans.first != rhs.first) - fail (); - - builder.CreateStore (rhs.second, ans.second); - - if (expr->print_result ()) - emit_print ("ans", rhs); - } + do_assign ("ans", expr_result, expr->print_result ()); else if (expr->is_identifier () && expr->print_result ()) { // FIXME: ugly hack, we need to come up with a way to pass // nargout to visit_identifier - emit_print (expr->name (), value_stack.back ()); + const jit_function& fn = jit_typeinfo::print_value (); + jit_const_string *name = get_string (expr->name ()); + block->append (new jit_call (fn, name, expr_result)); } - - - value_stack.pop_back (); } } void -jit_generator::visit_statement_list (tree_statement_list& lst) +jit_convert::visit_statement_list (tree_statement_list&) { - tree_statement_list::iterator iter; - for (iter = lst.begin (); iter != lst.end (); ++iter) - { - tree_statement *stmt = *iter; - assert (stmt); // FIXME: jwe can this be null? - stmt->accept (*this); - } + fail (); } void -jit_generator::visit_switch_case (tree_switch_case&) +jit_convert::visit_switch_case (tree_switch_case&) { fail (); } void -jit_generator::visit_switch_case_list (tree_switch_case_list&) +jit_convert::visit_switch_case_list (tree_switch_case_list&) { fail (); } void -jit_generator::visit_switch_command (tree_switch_command&) +jit_convert::visit_switch_command (tree_switch_command&) { fail (); } void -jit_generator::visit_try_catch_command (tree_try_catch_command&) +jit_convert::visit_try_catch_command (tree_try_catch_command&) { fail (); } void -jit_generator::visit_unwind_protect_command (tree_unwind_protect_command&) +jit_convert::visit_unwind_protect_command (tree_unwind_protect_command&) { fail (); } void -jit_generator::visit_while_command (tree_while_command&) +jit_convert::visit_while_command (tree_while_command&) { fail (); } void -jit_generator::visit_do_until_command (tree_do_until_command&) +jit_convert::visit_do_until_command (tree_do_until_command&) { fail (); } void -jit_generator::emit_simple_for (tree_simple_for_command& cmd, value over, - bool atleast_once) +jit_convert::do_assign (const std::string& lhs, jit_value *rhs, bool print) { - if (is_lvalue) - fail (); + variable_map::iterator iter = variables.find (lhs); + if (iter == variables.end ()) + arguments.push_back (std::make_pair (lhs, false)); + else + { + const jit_function& fn = jit_typeinfo::release (); + block->append (new jit_call (fn, iter->second)); + } - jit_type *index = tinfo->get_index (); - llvm::Value *init_index = 0; - if (over.first == tinfo->get_range ()) - init_index = llvm::ConstantInt::get (index->to_llvm (), 0); - else - fail (); + variables[lhs] = rhs; - llvm::Value *llvm_index = builder.CreateAlloca (index->to_llvm (), 0, "index"); - builder.CreateStore (init_index, llvm_index); + if (print) + { + const jit_function& fn = jit_typeinfo::print_value (); + jit_const_string *name = get_string (lhs); + block->append (new jit_call (fn, name, rhs)); + } +} + +jit_value * +jit_convert::visit (tree& tee) +{ + result = 0; + tee.accept (*this); - // FIXME: Support break - llvm::LLVMContext &ctx = llvm::getGlobalContext (); - llvm::BasicBlock *body = llvm::BasicBlock::Create (ctx, "for_body", function); - llvm::BasicBlock *cond_check = llvm::BasicBlock::Create (ctx, "for_check", function); - llvm::BasicBlock *tail = llvm::BasicBlock::Create (ctx, "for_tail", function); + jit_value *ret = result; + result = 0; + return ret; +} - // initialize the iter from the index - if (atleast_once) - builder.CreateBr (body); - else - builder.CreateBr (cond_check); - - builder.SetInsertPoint (body); +// -------------------- jit_convert::convert_llvm -------------------- +llvm::Function * +jit_convert::convert_llvm::convert (llvm::Module *module, + const std::vector >& args, + const std::list& blocks, + const std::list& constants) +{ + jit_type *any = jit_typeinfo::get_any (); - is_lvalue = true; - tree_expression *lhs = cmd.left_hand_side (); - lhs->accept (*this); - is_lvalue = false; + // argument is an array of octave_base_value*, or octave_base_value** + llvm::Type *arg_type = any->to_llvm (); // this is octave_base_value* + arg_type = arg_type->getPointerTo (); + llvm::FunctionType *ft = llvm::FunctionType::get (llvm::Type::getVoidTy (context), + arg_type, false); + llvm::Function *function = llvm::Function::Create (ft, + llvm::Function::ExternalLinkage, + "foobar", module); - value lhsv = value_stack.back (); - value_stack.pop_back (); + try + { + llvm::BasicBlock *prelude = llvm::BasicBlock::Create (context, "prelude", + function); + builder.SetInsertPoint (prelude); - const jit_function::overload& index_ol = tinfo->get_simple_for_index (over.first); - llvm::Value *lindex = builder.CreateLoad (llvm_index); - llvm::Value *llvm_iter = builder.CreateCall2 (index_ol.function, over.second, lindex); - value iter(index_ol.result, llvm_iter); - builder.CreateStore (iter.second, lhsv.second); + llvm::Value *arg = function->arg_begin (); + for (size_t i = 0; i < args.size (); ++i) + { + llvm::Value *loaded_arg = builder.CreateConstInBoundsGEP1_32 (arg, i); + arguments[args[i].first] = loaded_arg; + } - tree_statement_list *lst = cmd.body (); - lst->accept (*this); + // we need to generate llvm values for constants, as these don't appear in + // a block + for (std::list::const_iterator iter = constants.begin (); + iter != constants.end (); ++iter) + { + jit_value *constant = *iter; + if (! dynamic_cast (constant)) + visit (constant); + } - llvm::Value *one = llvm::ConstantInt::get (index->to_llvm (), 1); - lindex = builder.CreateLoad (llvm_index); - lindex = builder.CreateAdd (lindex, one); - builder.CreateStore (lindex, llvm_index); - builder.CreateBr (cond_check); + std::list::const_iterator biter; + for (biter = blocks.begin (); biter != blocks.end (); ++biter) + { + jit_block *jblock = *biter; + llvm::BasicBlock *block = llvm::BasicBlock::Create (context, jblock->name (), + function); + jblock->stash_llvm (block); + } + + jit_block *first = *blocks.begin (); + builder.CreateBr (first->to_llvm ()); - builder.SetInsertPoint (cond_check); - lindex = builder.CreateLoad (llvm_index); - const jit_function::overload& check_ol = tinfo->get_simple_for_check (over.first); - llvm::Value *cond = builder.CreateCall2 (check_ol.function, over.second, lindex); - builder.CreateCondBr (cond, body, tail); + for (biter = blocks.begin (); biter != blocks.end (); ++biter) + visit (*biter); - builder.SetInsertPoint (tail); + builder.CreateRetVoid (); + } catch (const jit_fail_exception&) + { + function->eraseFromParent (); + throw; + } + + llvm::verifyFunction (*function); + + return function; } void -jit_generator::emit_print (const std::string& name, const value& v) +jit_convert::convert_llvm::visit_const_string (jit_const_string& cs) +{ + cs.stash_llvm (builder.CreateGlobalStringPtr (cs.value ())); +} + +void +jit_convert::convert_llvm::visit_const_scalar (jit_const_scalar& cs) +{ + llvm::Type *dbl = llvm::Type::getDoubleTy (context); + cs.stash_llvm (llvm::ConstantFP::get (dbl, cs.value ())); +} + +void +jit_convert::convert_llvm::visit_block (jit_block& b) { - const jit_function::overload& ol = tinfo->print_value (v.first); - if (! ol.function) - fail (); + llvm::BasicBlock *block = b.to_llvm (); + builder.SetInsertPoint (block); + for (jit_block::iterator iter = b.begin (); iter != b.end (); ++iter) + visit (*iter); +} - llvm::Value *str = builder.CreateGlobalStringPtr (name); - builder.CreateCall2 (ol.function, str, v.second); +void +jit_convert::convert_llvm::visit_break (jit_break& b) +{ + builder.CreateBr (b.sucessor_llvm ()); +} + +void +jit_convert::convert_llvm::visit_cond_break (jit_cond_break& cb) +{ + llvm::Value *cond = cb.cond_llvm (); + builder.CreateCondBr (cond, cb.sucessor_llvm (0), cb.sucessor_llvm (1)); } void -jit_generator::initialize (const std::vector& names, - const std::vector& argin, - const std::vector types) +jit_convert::convert_llvm::visit_call (jit_call& call) { - std::vector arg_types (names.size ()); - for (size_t i = 0; i < types.size (); ++i) - arg_types[i] = types[i]->to_llvm_arg (); - - llvm::LLVMContext &ctx = llvm::getGlobalContext (); - llvm::Type *tvoid = llvm::Type::getVoidTy (ctx); - llvm::FunctionType *ft = llvm::FunctionType::get (tvoid, arg_types, false); - function = llvm::Function::Create (ft, llvm::Function::ExternalLinkage, - "foobar", module); + const jit_function::overload& ol = call.overload (); + if (! ol.function) + fail (); + + std::vector args (call.argument_count ()); + for (size_t i = 0; i < call.argument_count (); ++i) + args[i] = call.argument_llvm (i); - // create variables and copy initial values - llvm::BasicBlock *body = llvm::BasicBlock::Create (ctx, "body", function); - builder.SetInsertPoint (body); - llvm::Function::arg_iterator arg_iter = function->arg_begin(); - for (size_t i = 0; i < names.size (); ++i, ++arg_iter) - { - llvm::Type *vartype = types[i]->to_llvm (); - const std::string& name = names[i]; - llvm::Value *var = builder.CreateAlloca (vartype, 0, name); - variables[name] = value (types[i], var); - - if (argin[i] || types[i]->force_init ()) - { - llvm::Value *loaded_arg = builder.CreateLoad (arg_iter); - builder.CreateStore (loaded_arg, var); - } - } + call.stash_llvm (builder.CreateCall (ol.function, args)); } void -jit_generator::finalize (const std::vector& names) +jit_convert::convert_llvm::visit_extract_argument (jit_extract_argument& extract) { - // copy computed values back into arguments - // we use names instead of looping through variables because order is - // important - llvm::Function::arg_iterator arg_iter = function->arg_begin(); - for (size_t i = 0; i < names.size (); ++i, ++arg_iter) - { - llvm::Value *var = variables[names[i]].second; - llvm::Value *loaded_var = builder.CreateLoad (var); - builder.CreateStore (loaded_var, arg_iter); - } - builder.CreateRetVoid (); + const jit_function::overload& ol = extract.overload (); + if (! ol.function) + fail (); + + llvm::Value *arg = arguments[extract.tag ()]; + arg = builder.CreateLoad (arg); + extract.stash_llvm (builder.CreateCall (ol.function, arg)); +} + +void +jit_convert::convert_llvm::visit_store_argument (jit_store_argument& store) +{ + llvm::Value *arg_value = store.result_llvm (); + const jit_function::overload& ol = store.overload (); + if (! ol.function) + fail (); + + arg_value = builder.CreateCall (ol.function, arg_value); + + llvm::Value *arg = arguments[store.tag ()]; + store.stash_llvm (builder.CreateStore (arg_value, arg)); } // -------------------- tree_jit -------------------- @@ -1700,25 +1259,33 @@ } tree_jit::~tree_jit (void) -{ - delete tinfo; -} +{} bool -tree_jit::execute (tree_simple_for_command& cmd, const octave_value& bounds) +tree_jit::execute (tree& cmd) { if (! initialize ()) return false; - jit_type *bounds_t = tinfo->type_of (bounds); - jit_info *jinfo = cmd.get_info (bounds_t); + compiled_map::iterator iter = compiled.find (&cmd); + jit_info *jinfo = 0; + if (iter != compiled.end ()) + { + jinfo = iter->second; + if (! jinfo->match ()) + { + delete jinfo; + jinfo = 0; + } + } + if (! jinfo) { - jinfo = new jit_info (*this, cmd, bounds_t); - cmd.stash_info (bounds_t, jinfo); + jinfo = new jit_info (*this, cmd); + compiled[&cmd] = jinfo; } - return jinfo->execute (bounds); + return jinfo->execute (); } bool @@ -1746,7 +1313,7 @@ pass_manager->add (llvm::createCFGSimplificationPass ()); pass_manager->doInitialization (); - tinfo = new jit_typeinfo (module, engine); + jit_typeinfo::initialize (module, engine); return true; } @@ -1760,106 +1327,80 @@ } // -------------------- jit_info -------------------- -jit_info::jit_info (tree_jit& tjit, tree_simple_for_command& cmd, - jit_type *bounds) : tinfo (tjit.get_typeinfo ()), - engine (tjit.get_engine ()), - bounds_t (bounds) +jit_info::jit_info (tree_jit& tjit, tree& tee) + : engine (tjit.get_engine ()) { - jit_infer infer(tinfo); - + llvm::Function *fun = 0; try { - infer.infer (cmd, bounds); + jit_convert conv (tjit.get_module (), tee); + fun = conv.get_function (); + arguments = conv.get_arguments (); + bounds = conv.get_bounds (); } catch (const jit_fail_exception&) + {} + + if (! fun) { function = 0; return; } - types = infer.get_types (); - - jit_generator gen(tinfo, tjit.get_module (), cmd, bounds, types); - function = gen.get_function (); - - if (function) - { - if (debug_print) - { - std::cout << "Compiled code:\n"; - std::cout << cmd.str_print_code () << std::endl; - - std::cout << "Before optimization:\n"; + tjit.optimize (fun); - llvm::raw_os_ostream os (std::cout); - function->print (os); - } - llvm::verifyFunction (*function); - tjit.optimize (function); + if (debug_print) + { + std::cout << "-------------------- optimized llvm ir --------------------\n"; + llvm::raw_os_ostream llvm_cout (std::cout); + fun->print (llvm_cout); + std::cout << std::endl; + } - if (debug_print) - { - std::cout << "After optimization:\n"; - - llvm::raw_os_ostream os (std::cout); - function->print (os); - } - } + function = reinterpret_cast(engine->getPointerToFunction (fun)); } bool -jit_info::execute (const octave_value& bounds) const +jit_info::execute (void) const { if (! function) return false; - std::vector args (types.size () + 1); - tinfo->to_generic (bounds_t, args[0], bounds); - - size_t idx; - type_map::const_iterator iter; - for (idx = 1, iter = types.begin (); iter != types.end (); ++iter, ++idx) + std::vector real_arguments (arguments.size ()); + for (size_t i = 0; i < arguments.size (); ++i) { - if (iter->second.first) // argin? + if (arguments[i].second) { - octave_value ov = iter->first->varval (); - tinfo->to_generic (iter->second.second, args[idx], ov); + octave_value current = symbol_table::varval (arguments[i].first); + octave_base_value *obv = current.internal_rep (); + obv->grab (); + real_arguments[i] = obv; } - else - tinfo->to_generic (iter->second.second, args[idx]); } - engine->runFunction (function, args); + function (&real_arguments[0]); - for (idx = 1, iter = types.begin (); iter != types.end (); ++iter, ++idx) - { - octave_value result = tinfo->to_octave_value (iter->second.second, args[idx]); - octave_value &ref = iter->first->varref (); - ref = result; - } - - tinfo->reset_generic (); + for (size_t i = 0; i < arguments.size (); ++i) + symbol_table::varref (arguments[i].first) = real_arguments[i]; return true; } bool -jit_info::match () const +jit_info::match (void) const { - for (type_map::const_iterator iter = types.begin (); iter != types.end (); - ++iter) - + if (! function) + return true; + + for (size_t i = 0; i < bounds.size (); ++i) { - if (iter->second.first) // argin? - { - jit_type *required_type = iter->second.second; - octave_value val = iter->first->varval (); - jit_type *current_type = tinfo->type_of (val); + const std::string& arg_name = bounds[i].second; + octave_value value = symbol_table::varval (arg_name); + jit_type *type = jit_typeinfo::type_of (value); - // FIXME: should be: ! required_type->is_parent (current_type) - if (required_type != current_type) - return false; - } + // FIXME: Check for a parent relationship + if (type != bounds[i].first) + return false; } return true; diff --git a/src/pt-jit.h b/src/pt-jit.h --- a/src/pt-jit.h +++ b/src/pt-jit.h @@ -46,18 +46,25 @@ // b = a + a; // will compile to do_binary_op (a, a). // -// for loops with ranges compile. For example, -// for i=1:1000 -// result = i + 1; -// endfor -// Will compile. Nested for loops with constant bounds are also supported. +// for loops and if statements no longer compile! This is because work has been +// done to introduce a new lower level IR for octave. The low level IR looks +// a lot like llvm's IR, but it makes it much easier to infer types. You can set +// debug_print to true in pt-jit.cc to view the IRs that are created. // -// If statements/comparisons compile, but && and || do not. +// The octave low level IR is a linear IR, it works by converting everything to +// calls to jit_functions. This turns expressions like c = a + b into +// c = call binary+ (a, b) +// The jit_functions contain information about overloads for differnt types. For +// example, if we know a and b are scalars, then c must also be a scalar. +// +// You will currently see a LARGE slowdown, as every statement is compiled +// seperatly! // // TODO: -// 1. Support iteration over matricies -// 2. Check error state -// 3. ... +// 1. Support for loops +// 2. Support if statements +// 3. Cleanup/documentation +// 4. ... // --------------------------------------------------------- @@ -113,37 +120,44 @@ jit_type { public: - jit_type (const std::string& n, bool fi, jit_type *mparent, llvm::Type *lt, - int tid) : - mname (n), finit (fi), p (mparent), llvm_type (lt), id (tid) + jit_type (const std::string& aname, jit_type *aparent, llvm::Type *allvm_type, + int aid) : + mname (aname), mparent (aparent), llvm_type (allvm_type), mid (aid), + mdepth (aparent ? aparent->mdepth + 1 : 0) {} // a user readable type name const std::string& name (void) const { return mname; } - // do we need to initialize variables of this type, even if they are not - // input arguments? - bool force_init (void) const { return finit; } - // a unique id for the type - int type_id (void) const { return id; } + int type_id (void) const { return mid; } // An abstract base type, may be null - jit_type *parent (void) const { return p; } + jit_type *parent (void) const { return mparent; } // convert to an llvm type llvm::Type *to_llvm (void) const { return llvm_type; } // how this type gets passed as a function argument llvm::Type *to_llvm_arg (void) const; + + size_t depth (void) const { return mdepth; } private: std::string mname; - bool finit; - jit_type *p; + jit_type *mparent; llvm::Type *llvm_type; - int id; + int mid; + size_t mdepth; }; +// seperate print function to allow easy printing if type is null +static std::ostream& jit_print (std::ostream& os, jit_type *atype) +{ + if (! atype) + return os << "null"; + return os << atype->name (); +} + // Keeps track of overloads for a builtin function. Used for both type inference // and code generation. class @@ -223,10 +237,16 @@ const overload& temp = get_overload (arg0, arg1); return temp.result; } + + const std::string& name (void) const { return mname; } + + void stash_name (const std::string& aname) { mname = aname; } private: Array to_idx (const std::vector& types) const; std::vector > overloads; + + std::string mname; }; // Get information and manipulate jit types. @@ -234,84 +254,160 @@ jit_typeinfo { public: + static void initialize (llvm::Module *m, llvm::ExecutionEngine *e); + + static jit_type *tunion (jit_type *lhs, jit_type *rhs) + { + return instance->do_union (lhs, rhs); + } + + static jit_type *difference (jit_type *lhs, jit_type *rhs) + { + return instance->do_difference (lhs, rhs); + } + + static jit_type *get_any (void) { return instance->any; } + + static jit_type *get_scalar (void) { return instance->scalar; } + + static jit_type *get_range (void) { return instance->range; } + + static jit_type *get_string (void) { return instance->string; } + + static jit_type *get_bool (void) { return instance->boolean; } + + static jit_type *get_index (void) { return instance->index; } + + static jit_type *type_of (const octave_value& ov) + { + return instance->do_type_of (ov); + } + + static const jit_function& binary_op (int op) + { + return instance->do_binary_op (op); + } + + static const jit_function& grab (void) { return instance->grab_fn; } + + static const jit_function& release (void) + { + return instance->release_fn; + } + + static const jit_function& print_value (void) + { + return instance->print_fn; + } + + static const jit_function& cast (jit_type *result) + { + return instance->do_cast (result); + } + + static const jit_function::overload& cast (jit_type *to, jit_type *from) + { + return instance->do_cast (to, from); + } + + static void to_generic (jit_type *type, llvm::GenericValue& gv) + { + return instance->do_to_generic (type, gv); + } + + static void to_generic (jit_type *type, llvm::GenericValue& gv, octave_value ov) + { + return instance->do_to_generic (type, gv, ov); + } + + static octave_value to_octave_value (jit_type *type, llvm::GenericValue& gv) + { + return instance->do_to_octave_value (type, gv); + } + + static void reset_generic (void) + { + instance->do_reset_generic (); + } +private: jit_typeinfo (llvm::Module *m, llvm::ExecutionEngine *e); - jit_type *get_any (void) const { return any; } - - jit_type *get_scalar (void) const { return scalar; } - - llvm::Type *get_scalar_llvm (void) const { return scalar->to_llvm (); } - - jit_type *get_range (void) const { return range; } + // FIXME: Do these methods really need to be in jit_typeinfo? + jit_type *do_union (jit_type *lhs, jit_type *rhs) + { + // FIXME: Actually introduce a union type - llvm::Type *get_range_llvm (void) const { return range->to_llvm (); } - - jit_type *get_bool (void) const { return boolean; } + // empty case + if (! lhs) + return rhs; - jit_type *get_index (void) const { return index; } - - llvm::Type *get_index_llvm (void) const { return index->to_llvm (); } - - jit_type *type_of (const octave_value& ov) const; + if (! rhs) + return lhs; - const jit_function& binary_op (int op) const; - - const jit_function::overload& binary_op_overload (int op, jit_type *lhs, - jit_type *rhs) const - { - const jit_function& jf = binary_op (op); - return jf.get_overload (lhs, rhs); - } + // check for a shared parent + while (lhs != rhs) + { + if (lhs->depth () > rhs->depth ()) + lhs = lhs->parent (); + else if (lhs->depth () < rhs->depth ()) + rhs = rhs->parent (); + else + { + // we MUST have depth > 0 as any is the base type of everything + do + { + lhs = lhs->parent (); + rhs = rhs->parent (); + } + while (lhs != rhs); + } + } - jit_type *binary_op_result (int op, jit_type *lhs, jit_type *rhs) const - { - const jit_function::overload& ol = binary_op_overload (op, lhs, rhs); - return ol.result; - } - - const jit_function::overload& grab (jit_type *ty) const - { - return grab_fn.get_overload (ty); + return lhs; } - const jit_function::overload& release (jit_type *ty) const + jit_type *do_difference (jit_type *lhs, jit_type *) { - return release_fn.get_overload (ty); + // FIXME: Maybe we can do something smarter? + return lhs; } - const jit_function::overload& print_value (jit_type *to_print) const; - - const jit_function::overload& get_simple_for_check (jit_type *bounds) const + jit_type *do_type_of (const octave_value &ov) const; + + const jit_function& do_binary_op (int op) const { - return simple_for_check.get_overload (bounds, index); - } - - const jit_function::overload& get_simple_for_index (jit_type *bounds) const - { - return simple_for_index.get_overload (bounds, index); + assert (static_cast(op) < binary_ops.size ()); + return binary_ops[op]; } - jit_type *get_simple_for_index_result (jit_type *bounds) const + void do_to_generic (jit_type *type, llvm::GenericValue& gv); + + void do_to_generic (jit_type *type, llvm::GenericValue& gv, octave_value ov); + + octave_value do_to_octave_value (jit_type *type, llvm::GenericValue& gv); + + void do_reset_generic (void); + + const jit_function& do_cast (jit_type *to) { - const jit_function::overload& ol = get_simple_for_index (bounds); - return ol.result; - } + static jit_function null_function; + if (! to) + return null_function; - const jit_function::overload& get_logically_true (jit_type *conv) const - { - return logically_true.get_overload (conv); + size_t id = to->type_id (); + if (id >= casts.size ()) + return null_function; + return casts[id]; } - void to_generic (jit_type *type, llvm::GenericValue& gv); - - void to_generic (jit_type *type, llvm::GenericValue& gv, octave_value ov); - - octave_value to_octave_value (jit_type *type, llvm::GenericValue& gv); - - void reset_generic (void); -private: - jit_type *new_type (const std::string& name, bool force_init, - jit_type *parent, llvm::Type *llvm_type); + const jit_function::overload& do_cast (jit_type *to, jit_type *from) + { + return do_cast (to).get_overload (from); + } + + jit_type *new_type (const std::string& name, jit_type *parent, + llvm::Type *llvm_type); + void add_print (jit_type *ty, void *call); @@ -372,6 +468,10 @@ llvm::Function *create_function (const llvm::Twine& name, llvm::Type *ret, const std::vector& args); + llvm::Function *create_identity (jit_type *type); + + static jit_typeinfo *instance; + llvm::Module *module; llvm::ExecutionEngine *engine; int next_id; @@ -382,6 +482,7 @@ jit_type *any; jit_type *scalar; jit_type *range; + jit_type *string; jit_type *boolean; jit_type *index; @@ -394,27 +495,566 @@ jit_function simple_for_index; jit_function logically_true; + // type id -> cast function TO that type + std::vector casts; + + // type id -> identity function + std::vector identities; + std::list scalar_out; std::list ov_out; std::list range_out; }; +// The low level octave jit ir +// this ir is close to llvm, but contains information for doing type inference. +// We convert the octave parse tree to this IR directly. + +#define JIT_VISIT_IR_CLASSES \ + JIT_METH(const_string); \ + JIT_METH(const_scalar); \ + JIT_METH(block); \ + JIT_METH(break); \ + JIT_METH(cond_break); \ + JIT_METH(call); \ + JIT_METH(extract_argument); \ + JIT_METH(store_argument) + + +#define JIT_METH(clname) class jit_ ## clname +JIT_VISIT_IR_CLASSES; +#undef JIT_METH + class -jit_infer : public tree_walker +jit_ir_walker +{ +public: + virtual ~jit_ir_walker () {} + +#define JIT_METH(clname) \ + virtual void visit_ ## clname (jit_ ## clname&) = 0 + + JIT_VISIT_IR_CLASSES; + +#undef JIT_METH +}; + +class jit_use; + +class +jit_value +{ + friend class jit_use; +public: + jit_value (void) : llvm_value (0), ty (0), use_head (0) {} + + virtual ~jit_value (void) {} + + jit_type *type () const { return ty; } + + void stash_type (jit_type *new_ty) { ty = new_ty; } + + jit_use *first_use (void) const { return use_head; } + + size_t use_count (void) const { return myuse_count; } + + virtual std::ostream& print (std::ostream& os, size_t indent = 0) = 0; + + virtual std::ostream& short_print (std::ostream& os) + { return print (os); } + + virtual void accept (jit_ir_walker& walker) = 0; + + llvm::Value *to_llvm (void) const + { + return llvm_value; + } + + void stash_llvm (llvm::Value *compiled) + { + llvm_value = compiled; + } +protected: + std::ostream& print_indent (std::ostream& os, size_t indent) + { + for (size_t i = 0; i < indent; ++i) + os << "\t"; + return os; + } + + llvm::Value *llvm_value; +private: + jit_type *ty; + jit_use *use_head; + size_t myuse_count; +}; + +// defnie accept methods for subclasses +#define JIT_VALUE_ACCEPT(clname) \ + virtual void accept (jit_ir_walker& walker) \ + { \ + walker.visit_ ## clname (*this); \ + } + +class +jit_const_string : public jit_value +{ +public: + jit_const_string (const std::string& v) : val (v) + { + stash_type (jit_typeinfo::get_string ()); + } + + const std::string& value (void) const { return val; } + + virtual std::ostream& print (std::ostream& os, size_t indent) + { + return print_indent (os, indent) << "string: \"" << val << "\""; + } + + JIT_VALUE_ACCEPT (const_string) +private: + std::string val; +}; + +class +jit_const_scalar : public jit_value +{ +public: + jit_const_scalar (double avalue) : mvalue (avalue) + { + stash_type (jit_typeinfo::get_scalar ()); + } + + double value (void) const { return mvalue; } + + virtual std::ostream& print (std::ostream& os, size_t indent) + { + return print_indent (os, indent) << "scalar: \"" << mvalue << "\""; + } + + JIT_VALUE_ACCEPT (const_scalar) +private: + double mvalue; +}; + +class jit_instruction; + +class +jit_use +{ +public: + jit_use (void) : used (0), next_use (0), prev_use (0) {} + + ~jit_use (void) { remove (); } + + jit_value *value (void) const { return used; } + + size_t index (void) const { return idx; } + + jit_instruction *user (void) const { return usr; } + + void stash_value (jit_value *new_value, jit_instruction *u = 0, + size_t use_idx = -1) + { + remove (); + + used = new_value; + + if (used) + { + if (used->use_head) + { + used->use_head->prev_use = this; + next_use = used->use_head; + } + + used->use_head = this; + ++used->myuse_count; + } + + idx = use_idx; + usr = u; + } + + jit_use *next (void) const { return next_use; } + + jit_use *prev (void) const { return prev_use; } +private: + void remove (void) + { + if (used) + { + if (this == used->use_head) + used->use_head = next_use; + + if (prev_use) + prev_use->next_use = next_use; + + if (next_use) + next_use->prev_use = prev_use; + + next_use = prev_use = 0; + --used->myuse_count; + } + } + + jit_value *used; + jit_use *next_use; + jit_use *prev_use; + jit_instruction *usr; + size_t idx; +}; + +class +jit_instruction : public jit_value { public: - // pair - typedef std::pair type_entry; - typedef std::map type_map; + // FIXME: this code could be so much pretier with varadic templates... +#define JIT_EXTRACT_ARG(idx) arguments[idx].stash_value (arg ## idx, this, idx) + + jit_instruction (void) : id (next_id ()) + { + } + + jit_instruction (jit_value *arg0) + : already_infered (1, reinterpret_cast(0)), arguments (1), + id (next_id ()) + { + JIT_EXTRACT_ARG (0); + } + + jit_instruction (jit_value *arg0, jit_value *arg1) + : already_infered (2, reinterpret_cast(0)), arguments (2), + id (next_id ()) + { + JIT_EXTRACT_ARG (0); + JIT_EXTRACT_ARG (1); + } + + jit_instruction (jit_value *arg0, jit_value *arg1, jit_value *arg2) + : already_infered (3, reinterpret_cast(0)), arguments (3), + id (next_id ()) + { + JIT_EXTRACT_ARG (0); + JIT_EXTRACT_ARG (1); + JIT_EXTRACT_ARG (2); + } + +#undef JIT_EXTRACT_ARG + + static void reset_ids (void) + { + next_id (true); + } + + jit_value *argument (size_t i) const + { + return arguments[i].value (); + } + + llvm::Value *argument_llvm (size_t i) const + { + return arguments[i].value ()->to_llvm (); + } + + jit_type *argument_type (size_t i) const + { + return arguments[i].value ()->type (); + } + + size_t argument_count (void) const + { + return arguments.size (); + } + + // argument types which have been infered already + const std::vector& argument_types (void) const + { return already_infered; } + + virtual bool infer (void) { return false; } + + virtual std::ostream& short_print (std::ostream& os) + { + if (mtag.empty ()) + jit_print (os, type ()) << ": #" << id; + else + jit_print (os, type ()) << ": " << mtag << "." << id; + + return os; + } + + const std::string& tag (void) const { return mtag; } + + void stash_tag (const std::string& atag) { mtag = atag; } +protected: + std::vector already_infered; +private: + static size_t next_id (bool reset = false) + { + static size_t ret = 0; + if (reset) + return ret = 0; + + return ret++; + } + + std::vector arguments; // DO NOT resize + + std::string mtag; + size_t id; +}; + +class +jit_block : public jit_value +{ +public: + typedef std::list instruction_list; + typedef instruction_list::iterator iterator; + typedef instruction_list::const_iterator const_iterator; + + jit_block (const std::string& n) : nm (n) {} + + virtual ~jit_block () + { + for (instruction_list::iterator iter = instructions.begin (); + iter != instructions.end (); ++iter) + delete *iter; + } + + const std::string& name (void) const { return nm; } + + jit_instruction *prepend (jit_instruction *instr) + { + instructions.push_front (instr); + return instr; + } + + jit_instruction *append (jit_instruction *instr) + { + instructions.push_back (instr); + return instr; + } + + iterator begin () { return instructions.begin (); } + + const_iterator begin () const { return instructions.begin (); } + + iterator end () { return instructions.end (); } + + const_iterator end () const { return instructions.begin (); } + + virtual std::ostream& print (std::ostream& os, size_t indent) + { + print_indent (os, indent) << nm << ":" << std::endl; + for (iterator iter = begin (); iter != end (); ++iter) + { + jit_instruction *instr = *iter; + instr->print (os, indent + 1) << std::endl; + } + return os; + } + + llvm::BasicBlock *to_llvm (void) const; + + JIT_VALUE_ACCEPT (block) +private: + std::string nm; + instruction_list instructions; +}; + +class jit_terminator : public jit_instruction +{ +public: + jit_terminator (jit_value *arg0) : jit_instruction (arg0) {} + + jit_terminator (jit_value *arg0, jit_value *arg1, jit_value *arg2) + : jit_instruction (arg0, arg1, arg2) {} + + virtual jit_block *sucessor (size_t idx = 0) const = 0; + + llvm::BasicBlock *sucessor_llvm (size_t idx = 0) const + { + return sucessor (idx)->to_llvm (); + } + + virtual size_t sucessor_count (void) const = 0; +}; - jit_infer (jit_typeinfo *ti) : tinfo (ti), is_lvalue (false), - rvalue_type (0) - {} +class +jit_break : public jit_terminator +{ +public: + jit_break (jit_block *succ) : jit_terminator (succ) {} + + jit_block *sucessor (size_t idx = 0) const + { + jit_value *arg = argument (idx); + return reinterpret_cast (arg); + } + + size_t sucessor_count (void) const { return 1; } + + virtual std::ostream& print (std::ostream& os, size_t indent) + { + jit_block *succ = sucessor (); + return print_indent (os, indent) << "break: " << succ->name (); + } + + JIT_VALUE_ACCEPT (break) +}; + +class +jit_cond_break : public jit_terminator +{ +public: + jit_cond_break (jit_value *c, jit_block *ctrue, jit_block *cfalse) + : jit_terminator (c, ctrue, cfalse) {} + + jit_value *cond (void) const { return argument (0); } + + llvm::Value *cond_llvm (void) const + { + return cond ()->to_llvm (); + } + + jit_block *sucessor (size_t idx) const + { + jit_value *arg = argument (idx + 1); + return reinterpret_cast (arg); + } + + size_t sucessor_count (void) const { return 2; } + + JIT_VALUE_ACCEPT (cond_break) +}; + +class +jit_call : public jit_instruction +{ +public: + jit_call (const jit_function& afunction, + jit_value *arg0) : jit_instruction (arg0), mfunction (afunction) {} + + jit_call (const jit_function& afunction, + jit_value *arg0, jit_value *arg1) : jit_instruction (arg0, arg1), + mfunction (afunction) {} + + const jit_function& function (void) const { return mfunction; } + + const jit_function::overload& overload (void) const + { + return mfunction.get_overload (argument_types ()); + } + + virtual std::ostream& print (std::ostream& os, size_t indent) + { + print_indent (os, indent); + + if (use_count ()) + short_print (os) << " = "; + os << "call " << mfunction.name () << " ("; + + for (size_t i = 0; i < argument_count (); ++i) + { + jit_value *arg = argument (i); + arg->short_print (os); + if (i + 1 < argument_count ()) + os << ", "; + } + return os << ")"; + } + + virtual bool infer (void); - const type_map& get_types () const { return types; } + JIT_VALUE_ACCEPT (call) +private: + const jit_function& mfunction; +}; + +class +jit_extract_argument : public jit_instruction +{ +public: + jit_extract_argument (jit_type *atype, const std::string& aname) + : jit_instruction () + { + stash_type (atype); + stash_tag (aname); + } + + const jit_function::overload& overload (void) const + { + return jit_typeinfo::cast (type (), jit_typeinfo::get_any ()); + } + + virtual std::ostream& print (std::ostream& os, size_t indent) + { + print_indent (os, indent); + return short_print (os) << " = extract: " << tag (); + } + + JIT_VALUE_ACCEPT (extract_argument) +}; + +class +jit_store_argument : public jit_instruction +{ +public: + jit_store_argument (const std::string& aname, jit_value *aresult) + : jit_instruction (aresult) + { + stash_tag (aname); + } - void infer (tree_simple_for_command& cmd, jit_type *bounds); + const jit_function::overload& overload (void) const + { + return jit_typeinfo::cast (jit_typeinfo::get_any (), result_type ()); + } + + jit_value *result (void) const + { + return argument (0); + } + + jit_type *result_type (void) const + { + return result ()->type (); + } + + llvm::Value *result_llvm (void) const + { + return result ()->to_llvm (); + } + + virtual std::ostream& print (std::ostream& os, size_t indent) + { + jit_value *res = result (); + print_indent (os, indent) << tag () << " <- "; + return res->short_print (os); + } + + JIT_VALUE_ACCEPT (store_argument) +}; + +// convert between IRs +// FIXME: Class relationships are messy from here on down. They need to be +// cleaned up. +class +jit_convert : public tree_walker +{ +public: + typedef std::pair type_bound; + typedef std::vector type_bound_vector; + + jit_convert (llvm::Module *module, tree &tee); + + llvm::Function *get_function (void) const { return function; } + + const std::vector >& get_arguments(void) const + { return arguments; } + + const type_bound_vector& get_bounds (void) const { return bounds; } void visit_anon_fcn_handle (tree_anon_fcn_handle&); @@ -502,146 +1142,87 @@ void visit_do_until_command (tree_do_until_command&); private: - void infer_simple_for (tree_simple_for_command& cmd, - jit_type *bounds); + std::vector > arguments; + type_bound_vector bounds; + + typedef std::map variable_map; + variable_map variables; + + // used instead of return values from visit_* functions + jit_value *result; + + jit_block *block; + jit_block *entry_block; + jit_block *final_block; + + llvm::Function *function; - void handle_identifier (const symbol_table::symbol_record_ref& record); + std::list blocks; + + std::list worklist; + + std::list constants; + + void do_assign (const std::string& lhs, jit_value *rhs, bool print); - void merge (type_map& dest, const type_map& src); + jit_value *visit (tree *tee) { return visit (*tee); } + + jit_value *visit (tree& tee); + + void append_users (jit_value *v) + { + for (jit_use *use = v->first_use (); use; use = use->next ()) + worklist.push_back (use->user ()); + } - jit_typeinfo *tinfo; + jit_const_scalar *get_scalar (double v) + { + jit_const_scalar *ret = new jit_const_scalar (v); + constants.push_back (ret); + return ret; + } + + jit_const_string *get_string (const std::string& v) + { + jit_const_string *ret = new jit_const_string (v); + constants.push_back (ret); + return ret; + } - bool is_lvalue; - jit_type *rvalue_type; + // this case is much simpler, just convert from the jit ir to llvm + class + convert_llvm : public jit_ir_walker + { + public: + llvm::Function *convert (llvm::Module *module, + const std::vector >& args, + const std::list& blocks, + const std::list& constants); + +#define JIT_METH(clname) \ + virtual void visit_ ## clname (jit_ ## clname&); + + JIT_VISIT_IR_CLASSES; - type_map types; +#undef JIT_METH + private: + // name -> llvm argument + std::map arguments; + - std::vector type_stack; + void visit (jit_value *jvalue) + { + return visit (*jvalue); + } + + void visit (jit_value &jvalue) + { + jvalue.accept (*this); + } + }; }; -class -jit_generator : public tree_walker -{ -public: - typedef jit_infer::type_map type_map; - - jit_generator (jit_typeinfo *ti, llvm::Module *mod, tree_simple_for_command &cmd, - jit_type *bounds, const type_map& infered_types); - - llvm::Function *get_function () const { return function; } - - void visit_anon_fcn_handle (tree_anon_fcn_handle&); - - void visit_argument_list (tree_argument_list&); - - void visit_binary_expression (tree_binary_expression&); - - void visit_break_command (tree_break_command&); - - void visit_colon_expression (tree_colon_expression&); - - void visit_continue_command (tree_continue_command&); - - void visit_global_command (tree_global_command&); - - void visit_persistent_command (tree_persistent_command&); - - void visit_decl_elt (tree_decl_elt&); - - void visit_decl_init_list (tree_decl_init_list&); - - void visit_simple_for_command (tree_simple_for_command&); - - void visit_complex_for_command (tree_complex_for_command&); - - void visit_octave_user_script (octave_user_script&); - - void visit_octave_user_function (octave_user_function&); - - void visit_octave_user_function_header (octave_user_function&); - - void visit_octave_user_function_trailer (octave_user_function&); - - void visit_function_def (tree_function_def&); - - void visit_identifier (tree_identifier&); - - void visit_if_clause (tree_if_clause&); - - void visit_if_command (tree_if_command&); - - void visit_if_command_list (tree_if_command_list&); - - void visit_index_expression (tree_index_expression&); - - void visit_matrix (tree_matrix&); - - void visit_cell (tree_cell&); - - void visit_multi_assignment (tree_multi_assignment&); - - void visit_no_op_command (tree_no_op_command&); - - void visit_constant (tree_constant&); - - void visit_fcn_handle (tree_fcn_handle&); - - void visit_parameter_list (tree_parameter_list&); - - void visit_postfix_expression (tree_postfix_expression&); - - void visit_prefix_expression (tree_prefix_expression&); - - void visit_return_command (tree_return_command&); - - void visit_return_list (tree_return_list&); - - void visit_simple_assignment (tree_simple_assignment&); - - void visit_statement (tree_statement&); - - void visit_statement_list (tree_statement_list&); - - void visit_switch_case (tree_switch_case&); - - void visit_switch_case_list (tree_switch_case_list&); - - void visit_switch_command (tree_switch_command&); - - void visit_try_catch_command (tree_try_catch_command&); - - void visit_unwind_protect_command (tree_unwind_protect_command&); - - void visit_while_command (tree_while_command&); - - void visit_do_until_command (tree_do_until_command&); -private: - typedef std::pair value; - - void emit_simple_for (tree_simple_for_command& cmd, value over, - bool atleast_once); - - void emit_print (const std::string& name, const value& v); - - void push_value (jit_type *type, llvm::Value *v) - { - value_stack.push_back (value (type, v)); - } - - void initialize (const std::vector& names, - const std::vector& argin, - const std::vector types); - - void finalize (const std::vector& names); - - jit_typeinfo *tinfo; - llvm::Module *module; - llvm::Function *function; - - bool is_lvalue; - std::map variables; - std::vector value_stack; -}; +class jit_info; class tree_jit @@ -651,9 +1232,7 @@ ~tree_jit (void); - bool execute (tree_simple_for_command& cmd, const octave_value& bounds); - - jit_typeinfo *get_typeinfo (void) const { return tinfo; } + bool execute (tree& cmd); llvm::ExecutionEngine *get_engine (void) const { return engine; } @@ -663,32 +1242,36 @@ private: bool initialize (void); + // FIXME: Temorary hack to test + typedef std::map compiled_map; + compiled_map compiled; + llvm::LLVMContext &context; llvm::Module *module; llvm::PassManager *module_pass_manager; llvm::FunctionPassManager *pass_manager; llvm::ExecutionEngine *engine; - - jit_typeinfo *tinfo; }; class jit_info { public: - typedef jit_infer::type_map type_map; + jit_info (tree_jit& tjit, tree& tee); - jit_info (tree_jit& tjit, tree_simple_for_command& cmd, jit_type *bounds); - - bool execute (const octave_value& bounds) const; + bool execute (void) const; bool match (void) const; private: - jit_typeinfo *tinfo; + typedef jit_convert::type_bound type_bound; + typedef jit_convert::type_bound_vector type_bound_vector; + typedef void (*jited_function)(octave_base_value**); + llvm::ExecutionEngine *engine; - type_map types; - llvm::Function *function; - jit_type *bounds_t; + jited_function function; + + std::vector > arguments; + type_bound_vector bounds; }; #endif