# HG changeset patch # User Max Brister # Date 1337292441 21600 # Node ID 3f81e8b42955e416d1a476a5f77ae111b9ecbe5c # Parent 05bf75eaea2abde089def7cc696f601a40ea6da4 JIT for loops over ranges * src/pt-eval.cc (tree_evaluator::visit_statment): Removed jit. (tree_evaluator::visit_simple_for_command): Added jit. * src/pt-jit.cc: Implement JIT of range based for loops. * src/pt-jit.h: Implement JI of range based for loops. * src/pt-loop.h (tree_simple_for_command::get_info, tree_simple_for_command::stash_info): New functions. * src/pt-loop.cc (tree_simple_for_command::~tree_simple_for_command): Delete stashed info. diff --git a/src/pt-eval.cc b/src/pt-eval.cc --- a/src/pt-eval.cc +++ b/src/pt-eval.cc @@ -310,6 +310,9 @@ if (error_state || rhs.is_undefined ()) return; + if (jiter.execute (cmd, rhs)) + return; + { tree_expression *lhs = cmd.left_hand_side (); @@ -684,9 +687,6 @@ tree_command *cmd = stmt.command (); tree_expression *expr = stmt.expression (); - if (! Vdebugging && ! Vecho_executing_commands && jiter.execute (stmt)) - return; - if (cmd || expr) { if (statement_context == function || statement_context == script) diff --git a/src/pt-jit.cc b/src/pt-jit.cc --- a/src/pt-jit.cc +++ b/src/pt-jit.cc @@ -61,6 +61,15 @@ //FIXME: Move into tree_jit static llvm::IRBuilder<> builder (llvm::getGlobalContext ()); +// thrown when we should give up on JIT and interpret +class jit_fail_exception : public std::exception {}; + +static void +fail (void) +{ + throw jit_fail_exception (); +} + // function that jit code calls extern "C" void octave_jit_print_any (const char *name, octave_base_value *obv) @@ -177,19 +186,43 @@ } // -------------------- jit_typeinfo -------------------- -jit_typeinfo::jit_typeinfo (llvm::Module *m, llvm::ExecutionEngine *e, llvm::Type *ov) - : module (m), engine (e), next_id (0), ov_t (ov) +jit_typeinfo::jit_typeinfo (llvm::Module *m, llvm::ExecutionEngine *e) + : module (m), engine (e), next_id (0) { // FIXME: We should be registering types like in octave_value_typeinfo llvm::LLVMContext &ctx = m->getContext (); + ov_t = llvm::StructType::create (ctx, "octave_base_value"); + ov_t = ov_t->getPointerTo (); + + llvm::Type *dbl = llvm::Type::getDoubleTy (ctx); + llvm::Type *bool_t = llvm::Type::getInt1Ty (ctx); + llvm::Type *index_t = 0; + switch (sizeof(octave_idx_type)) + { + case 4: + index_t = llvm::Type::getInt32Ty (ctx); + break; + case 8: + index_t = llvm::Type::getInt64Ty (ctx); + break; + default: + assert (false && "Unrecognized index type size"); + } + + llvm::StructType *range_t = llvm::StructType::create (ctx, "range"); + std::vector range_contents (4, dbl); + range_contents[3] = index_t; + range_t->setBody (range_contents); + // create types any = new_type ("any", true, 0, ov_t); - scalar = new_type ("scalar", false, any, llvm::Type::getDoubleTy (ctx)); + scalar = new_type ("scalar", false, any, dbl); + range = new_type ("range", false, any, range_t); + boolean = new_type ("bool", false, any, bool_t); + index = new_type ("index", false, any, index_t); // any with anything is an any op - llvm::IRBuilder<> fn_builder (ctx); - llvm::Type *binary_op_type = llvm::Type::getIntNTy (ctx, sizeof (octave_value::binary_op)); std::vector args (3); @@ -211,22 +244,22 @@ for (int op = 0; op < octave_value::num_binary_ops; ++op) { llvm::FunctionType *ftype = llvm::FunctionType::get (ov_t, args, false); - + llvm::Twine fn_name ("octave_jit_binary_any_any_"); fn_name = fn_name + llvm::Twine (op); llvm::Function *fn = llvm::Function::Create (ftype, llvm::Function::ExternalLinkage, fn_name, module); llvm::BasicBlock *block = llvm::BasicBlock::Create (ctx, "body", fn); - fn_builder.SetInsertPoint (block); + builder.SetInsertPoint (block); llvm::APInt op_int(sizeof (octave_value::binary_op), op, std::numeric_limits::is_signed); llvm::Value *op_as_llvm = llvm::ConstantInt::get (binary_op_type, op_int); - llvm::Value *ret = fn_builder.CreateCall3 (any_binary, + llvm::Value *ret = builder.CreateCall3 (any_binary, op_as_llvm, fn->arg_begin (), ++fn->arg_begin ()); - fn_builder.CreateRet (ret); + builder.CreateRet (ret); jit_function::overload overload (fn, true, any, any, any); for (octave_idx_type i = 0; i < next_id; ++i) @@ -234,11 +267,11 @@ } // assign any = any - llvm::Type *tvoid = llvm::Type::getVoidTy (ctx); + llvm::Type *void_t = llvm::Type::getVoidTy (ctx); args.resize (2); args[0] = any->to_llvm (); args[1] = any->to_llvm (); - llvm::FunctionType *ft = llvm::FunctionType::get (tvoid, args, false); + llvm::FunctionType *ft = llvm::FunctionType::get (void_t, args, false); llvm::Function *fn_help = llvm::Function::Create (ft, llvm::Function::ExternalLinkage, "octave_jit_assign_any_any_help", module); @@ -248,17 +281,17 @@ args.resize (2); args[0] = any->to_llvm_arg (); args[1] = any->to_llvm (); - ft = llvm::FunctionType::get (tvoid, args, false); + ft = llvm::FunctionType::get (void_t, args, false); llvm::Function *fn = llvm::Function::Create (ft, llvm::Function::ExternalLinkage, "octave_jit_assign_any_any", module); fn->addFnAttr (llvm::Attribute::AlwaysInline); llvm::BasicBlock *body = llvm::BasicBlock::Create (ctx, "body", fn); - fn_builder.SetInsertPoint (body); - llvm::Value *value = fn_builder.CreateLoad (fn->arg_begin ()); - fn_builder.CreateCall2 (fn_help, value, ++fn->arg_begin ()); - fn_builder.CreateStore (++fn->arg_begin (), fn->arg_begin ()); - fn_builder.CreateRetVoid (); + builder.SetInsertPoint (body); + llvm::Value *value = builder.CreateLoad (fn->arg_begin ()); + builder.CreateCall2 (fn_help, value, ++fn->arg_begin ()); + builder.CreateStore (++fn->arg_begin (), fn->arg_begin ()); + builder.CreateRetVoid (); llvm::verifyFunction (*fn); assign_fn.add_overload (fn, false, 0, any, any); @@ -266,14 +299,14 @@ args.resize (2); args[0] = scalar->to_llvm_arg (); args[1] = scalar->to_llvm (); - ft = llvm::FunctionType::get (tvoid, args, false); + ft = llvm::FunctionType::get (void_t, args, false); fn = llvm::Function::Create (ft, llvm::Function::ExternalLinkage, "octave_jit_assign_scalar_scalar", module); fn->addFnAttr (llvm::Attribute::AlwaysInline); body = llvm::BasicBlock::Create (ctx, "body", fn); - fn_builder.SetInsertPoint (body); - fn_builder.CreateStore (++fn->arg_begin (), fn->arg_begin ()); - fn_builder.CreateRetVoid (); + builder.SetInsertPoint (body); + builder.CreateStore (++fn->arg_begin (), fn->arg_begin ()); + builder.CreateRetVoid (); llvm::verifyFunction (*fn); assign_fn.add_overload (fn, false, 0, scalar, scalar); @@ -291,13 +324,76 @@ // now for printing functions add_print (any, reinterpret_cast (&octave_jit_print_any)); add_print (scalar, reinterpret_cast (&octave_jit_print_double)); + + // bounds check for for loop + args.resize (2); + args[0] = range->to_llvm (); + args[1] = index->to_llvm (); + ft = llvm::FunctionType::get (bool_t, args, false); + fn = llvm::Function::Create (ft, llvm::Function::ExternalLinkage, + "octave_jit_simple_for_range", module); + fn->addFnAttr (llvm::Attribute::AlwaysInline); + body = llvm::BasicBlock::Create (ctx, "body", fn); + builder.SetInsertPoint (body); + { + llvm::Value *nelem + = builder.CreateExtractValue (fn->arg_begin (), 3); + // llvm::Value *idx = builder.CreateLoad (++fn->arg_begin ()); + llvm::Value *idx = ++fn->arg_begin (); + llvm::Value *ret = builder.CreateICmpULT (idx, nelem); + builder.CreateRet (ret); + } + llvm::verifyFunction (*fn); + simple_for_check.add_overload (fn, false, boolean, range, index); + + // increment for for loop + args.resize (1); + args[0] = index->to_llvm (); + ft = llvm::FunctionType::get (index->to_llvm (), args, false); + fn = llvm::Function::Create (ft, llvm::Function::ExternalLinkage, + "octave_jit_imple_for_range_incr", module); + fn->addFnAttr (llvm::Attribute::AlwaysInline); + body = llvm::BasicBlock::Create (ctx, "body", fn); + builder.SetInsertPoint (body); + { + llvm::Value *one = llvm::ConstantInt::get (index_t, 1); + llvm::Value *idx = fn->arg_begin (); + llvm::Value *ret = builder.CreateAdd (idx, one); + builder.CreateRet (ret); + } + llvm::verifyFunction (*fn); + simple_for_incr.add_overload (fn, false, index, index); + + // index variabe for for loop + args.resize (2); + args[0] = range->to_llvm (); + args[1] = index->to_llvm (); + ft = llvm::FunctionType::get (dbl, args, false); + fn = llvm::Function::Create (ft, llvm::Function::ExternalLinkage, + "octave_jit_simple_for_idx", module); + fn->addFnAttr (llvm::Attribute::AlwaysInline); + body = llvm::BasicBlock::Create (ctx, "body", fn); + builder.SetInsertPoint (body); + { + llvm::Value *idx = ++fn->arg_begin (); + llvm::Value *didx = builder.CreateUIToFP (idx, dbl); + llvm::Value *rng = fn->arg_begin (); + llvm::Value *base = builder.CreateExtractValue (rng, 0); + llvm::Value *inc = builder.CreateExtractValue (rng, 2); + + llvm::Value *ret = builder.CreateFMul (didx, inc); + ret = builder.CreateFAdd (base, ret); + builder.CreateRet (ret); + } + llvm::verifyFunction (*fn); + simple_for_index.add_overload (fn, false, scalar, range, index); } void jit_typeinfo::add_print (jit_type *ty, void *call) { llvm::LLVMContext& ctx = llvm::getGlobalContext (); - llvm::Type *tvoid = llvm::Type::getVoidTy (ctx); + llvm::Type *void_t = llvm::Type::getVoidTy (ctx); std::vector args (2); args[0] = llvm::Type::getInt8PtrTy (ctx); args[1] = ty->to_llvm (); @@ -305,7 +401,7 @@ std::stringstream name; name << "octave_jit_print_" << ty->name (); - llvm::FunctionType *print_ty = llvm::FunctionType::get (tvoid, args, false); + llvm::FunctionType *print_ty = llvm::FunctionType::get (void_t, args, false); llvm::Function *fn = llvm::Function::Create (print_ty, llvm::Function::ExternalLinkage, name.str (), module); @@ -327,7 +423,7 @@ octave_value::binary_op ov_op = static_cast(op); fname << "octave_jit_" << octave_value::binary_op_as_string (ov_op) << "_" << ty->name (); - + llvm::Function *fn = llvm::Function::Create (ft, llvm::Function::ExternalLinkage, fname.str (), @@ -355,6 +451,9 @@ if (ov.is_double_type () && ov.is_real_scalar ()) return get_scalar (); + if (ov.is_range ()) + return get_range (); + return get_any (); } @@ -385,6 +484,8 @@ to_generic (type, gv, octave_value ()); else if (type == scalar) to_generic (type, gv, octave_value (0)); + else if (type == range) + to_generic (type, gv, octave_value (Range ())); else assert (false && "Type not supported yet"); } @@ -396,14 +497,21 @@ { octave_base_value *obv = ov.internal_rep (); obv->grab (); - ov_out[ov_out_idx] = obv; - gv.PointerVal = &ov_out[ov_out_idx++]; + ov_out[ov_idx] = obv; + gv.PointerVal = &ov_out[ov_idx++]; + } + else if (type == scalar) + { + scalar_out[scalar_idx] = ov.double_value (); + gv.PointerVal = &scalar_out[scalar_idx++]; + } + else if (type == range) + { + range_out[range_idx] = ov.range_value (); + gv.PointerVal = &range_out[range_idx++]; } else - { - scalar_out[scalar_out_idx] = ov.double_value (); - gv.PointerVal = &scalar_out[scalar_out_idx++]; - } + assert (false && "Type not supported yet"); } octave_value @@ -411,14 +519,20 @@ { if (type == any) { - octave_base_value **ptr = reinterpret_cast(gv.PointerVal); + octave_base_value **ptr = reinterpret_cast(gv.PointerVal); return octave_value (*ptr); } else if (type == scalar) { - double *ptr = reinterpret_cast(gv.PointerVal); + double *ptr = reinterpret_cast(gv.PointerVal); return octave_value (*ptr); } + else if (type == range) + { + jit_range *ptr = reinterpret_cast(gv.PointerVal); + Range rng = *ptr; + return octave_value (rng); + } else assert (false && "Type not supported yet"); } @@ -426,14 +540,18 @@ void jit_typeinfo::reset_generic (size_t nargs) { - scalar_out_idx = 0; - ov_out_idx = 0; + scalar_idx = 0; + ov_idx = 0; + range_idx = 0; if (scalar_out.size () < nargs) scalar_out.resize (nargs); if (ov_out.size () < nargs) ov_out.resize (nargs); + + if (range_out.size () < nargs) + range_out.resize (nargs); } jit_type* @@ -445,91 +563,30 @@ return ret; } -tree_jit::tree_jit (void) : context (llvm::getGlobalContext ()), engine (0) -{ - llvm::InitializeNativeTarget (); - module = new llvm::Module ("octave", context); - assert (module); -} - -tree_jit::~tree_jit (void) +// -------------------- jit_infer -------------------- +void +jit_infer::infer (tree_simple_for_command& cmd, jit_type *bounds) { - for (compiled_map::iterator iter = compiled.begin (); iter != compiled.end (); - ++iter) - { - function_list& flist = iter->second; - for (function_list::iterator fiter = flist.begin (); fiter != flist.end (); - ++fiter) - delete *fiter; - } - - delete tinfo; -} - -bool -tree_jit::execute (tree& tee) -{ - // something funny happens during initialization with the engine - bool need_init = false; - if (! engine) - { - need_init = true; - engine = llvm::ExecutionEngine::createJIT (module); - } - - if (! engine) - return false; + argin.insert ("#bounds"); + types["#bounds"] = bounds; - if (need_init) - { - module_pass_manager = new llvm::PassManager (); - module_pass_manager->add (llvm::createAlwaysInlinerPass ()); - - pass_manager = new llvm::FunctionPassManager (module); - pass_manager->add (new llvm::TargetData(*engine->getTargetData ())); - pass_manager->add (llvm::createBasicAliasAnalysisPass ()); - pass_manager->add (llvm::createPromoteMemoryToRegisterPass ()); - pass_manager->add (llvm::createInstructionCombiningPass ()); - pass_manager->add (llvm::createReassociatePass ()); - pass_manager->add (llvm::createGVNPass ()); - pass_manager->add (llvm::createCFGSimplificationPass ()); - pass_manager->doInitialization (); - - llvm::Type *ov_t = llvm::StructType::create (context, "octave_base_value"); - ov_t = ov_t->getPointerTo (); - - tinfo = new jit_typeinfo (module, engine, ov_t); - } - - function_list& fnlist = compiled[&tee]; - for (function_list::iterator iter = fnlist.begin (); iter != fnlist.end (); - ++iter) - { - function_info& fi = **iter; - if (fi.match ()) - return fi.execute (); - } - - function_info *fi = new function_info (*this, tee); - fnlist.push_back (fi); - - return fi->execute (); + infer_simple_for (cmd, bounds); } void -tree_jit::type_infer::visit_anon_fcn_handle (tree_anon_fcn_handle&) +jit_infer::visit_anon_fcn_handle (tree_anon_fcn_handle&) { fail (); } void -tree_jit::type_infer::visit_argument_list (tree_argument_list&) +jit_infer::visit_argument_list (tree_argument_list&) { fail (); } void -tree_jit::type_infer::visit_binary_expression (tree_binary_expression& be) +jit_infer::visit_binary_expression (tree_binary_expression& be) { if (is_lvalue) fail (); @@ -551,145 +608,151 @@ } void -tree_jit::type_infer::visit_break_command (tree_break_command&) +jit_infer::visit_break_command (tree_break_command&) { fail (); } void -tree_jit::type_infer::visit_colon_expression (tree_colon_expression&) +jit_infer::visit_colon_expression (tree_colon_expression&) { fail (); } void -tree_jit::type_infer::visit_continue_command (tree_continue_command&) +jit_infer::visit_continue_command (tree_continue_command&) { fail (); } void -tree_jit::type_infer::visit_global_command (tree_global_command&) +jit_infer::visit_global_command (tree_global_command&) { fail (); } void -tree_jit::type_infer::visit_persistent_command (tree_persistent_command&) +jit_infer::visit_persistent_command (tree_persistent_command&) { fail (); } void -tree_jit::type_infer::visit_decl_elt (tree_decl_elt&) +jit_infer::visit_decl_elt (tree_decl_elt&) { fail (); } void -tree_jit::type_infer::visit_decl_init_list (tree_decl_init_list&) +jit_infer::visit_decl_init_list (tree_decl_init_list&) { fail (); } void -tree_jit::type_infer::visit_simple_for_command (tree_simple_for_command&) +jit_infer::visit_simple_for_command (tree_simple_for_command& cmd) +{ + tree_expression *control = cmd.control_expr (); + control->accept (*this); + + jit_type *control_t = type_stack.back (); + type_stack.pop_back (); + + infer_simple_for (cmd, control_t); +} + +void +jit_infer::visit_complex_for_command (tree_complex_for_command&) { fail (); } void -tree_jit::type_infer::visit_complex_for_command (tree_complex_for_command&) +jit_infer::visit_octave_user_script (octave_user_script&) { fail (); } void -tree_jit::type_infer::visit_octave_user_script (octave_user_script&) -{ - fail (); -} - -void -tree_jit::type_infer::visit_octave_user_function (octave_user_function&) +jit_infer::visit_octave_user_function (octave_user_function&) { fail (); } void -tree_jit::type_infer::visit_octave_user_function_header (octave_user_function&) +jit_infer::visit_octave_user_function_header (octave_user_function&) { fail (); } void -tree_jit::type_infer::visit_octave_user_function_trailer (octave_user_function&) +jit_infer::visit_octave_user_function_trailer (octave_user_function&) { fail (); } void -tree_jit::type_infer::visit_function_def (tree_function_def&) +jit_infer::visit_function_def (tree_function_def&) { fail (); } void -tree_jit::type_infer::visit_identifier (tree_identifier& ti) +jit_infer::visit_identifier (tree_identifier& ti) { handle_identifier (ti.name (), ti.do_lookup ()); } void -tree_jit::type_infer::visit_if_clause (tree_if_clause&) +jit_infer::visit_if_clause (tree_if_clause&) { fail (); } void -tree_jit::type_infer::visit_if_command (tree_if_command&) +jit_infer::visit_if_command (tree_if_command&) { fail (); } void -tree_jit::type_infer::visit_if_command_list (tree_if_command_list&) +jit_infer::visit_if_command_list (tree_if_command_list&) { fail (); } void -tree_jit::type_infer::visit_index_expression (tree_index_expression&) +jit_infer::visit_index_expression (tree_index_expression&) { fail (); } void -tree_jit::type_infer::visit_matrix (tree_matrix&) +jit_infer::visit_matrix (tree_matrix&) { fail (); } void -tree_jit::type_infer::visit_cell (tree_cell&) +jit_infer::visit_cell (tree_cell&) { fail (); } void -tree_jit::type_infer::visit_multi_assignment (tree_multi_assignment&) +jit_infer::visit_multi_assignment (tree_multi_assignment&) { fail (); } void -tree_jit::type_infer::visit_no_op_command (tree_no_op_command&) +jit_infer::visit_no_op_command (tree_no_op_command&) { fail (); } void -tree_jit::type_infer::visit_constant (tree_constant& tc) +jit_infer::visit_constant (tree_constant& tc) { if (is_lvalue) fail (); @@ -703,43 +766,43 @@ } void -tree_jit::type_infer::visit_fcn_handle (tree_fcn_handle&) +jit_infer::visit_fcn_handle (tree_fcn_handle&) { fail (); } void -tree_jit::type_infer::visit_parameter_list (tree_parameter_list&) +jit_infer::visit_parameter_list (tree_parameter_list&) { fail (); } void -tree_jit::type_infer::visit_postfix_expression (tree_postfix_expression&) +jit_infer::visit_postfix_expression (tree_postfix_expression&) { fail (); } void -tree_jit::type_infer::visit_prefix_expression (tree_prefix_expression&) +jit_infer::visit_prefix_expression (tree_prefix_expression&) { fail (); } void -tree_jit::type_infer::visit_return_command (tree_return_command&) +jit_infer::visit_return_command (tree_return_command&) { fail (); } void -tree_jit::type_infer::visit_return_list (tree_return_list&) +jit_infer::visit_return_list (tree_return_list&) { fail (); } void -tree_jit::type_infer::visit_simple_assignment (tree_simple_assignment& tsa) +jit_infer::visit_simple_assignment (tree_simple_assignment& tsa) { if (is_lvalue) fail (); @@ -769,7 +832,7 @@ } void -tree_jit::type_infer::visit_statement (tree_statement& stmt) +jit_infer::visit_statement (tree_statement& stmt) { if (is_lvalue) fail (); @@ -814,55 +877,86 @@ } void -tree_jit::type_infer::visit_statement_list (tree_statement_list&) +jit_infer::visit_statement_list (tree_statement_list& lst) +{ + tree_statement_list::iterator iter; + for (iter = lst.begin (); iter != lst.end (); ++iter) + { + tree_statement *stmt = *iter; + assert (stmt); // FIXME: jwe can this be null? + stmt->accept (*this); + } +} + +void +jit_infer::visit_switch_case (tree_switch_case&) { fail (); } void -tree_jit::type_infer::visit_switch_case (tree_switch_case&) +jit_infer::visit_switch_case_list (tree_switch_case_list&) { fail (); } void -tree_jit::type_infer::visit_switch_case_list (tree_switch_case_list&) +jit_infer::visit_switch_command (tree_switch_command&) { fail (); } void -tree_jit::type_infer::visit_switch_command (tree_switch_command&) +jit_infer::visit_try_catch_command (tree_try_catch_command&) { fail (); } void -tree_jit::type_infer::visit_try_catch_command (tree_try_catch_command&) +jit_infer::visit_unwind_protect_command (tree_unwind_protect_command&) { fail (); } void -tree_jit::type_infer::visit_unwind_protect_command (tree_unwind_protect_command&) +jit_infer::visit_while_command (tree_while_command&) +{ + fail (); +} + +void +jit_infer::visit_do_until_command (tree_do_until_command&) { fail (); } void -tree_jit::type_infer::visit_while_command (tree_while_command&) +jit_infer::infer_simple_for (tree_simple_for_command& cmd, + jit_type *bounds) { - fail (); + if (is_lvalue) + fail (); + + jit_type *iter = tinfo->get_simple_for_index_result (bounds); + if (! iter) + fail (); + + is_lvalue = true; + rvalue_type = iter; + tree_expression *lhs = cmd.left_hand_side (); + lhs->accept (*this); + if (type_stack.back () != iter) + fail (); + type_stack.pop_back (); + is_lvalue = false; + rvalue_type = 0; + + tree_statement_list *body = cmd.body (); + body->accept (*this); } void -tree_jit::type_infer::visit_do_until_command (tree_do_until_command&) -{ - fail (); -} - -void -tree_jit::type_infer::handle_identifier (const std::string& name, octave_value v) +jit_infer::handle_identifier (const std::string& name, octave_value v) { type_map::iterator iter = types.find (name); if (iter == types.end ()) @@ -888,12 +982,11 @@ type_stack.push_back (iter->second); } -tree_jit::code_generator::code_generator (jit_typeinfo *ti, llvm::Module *module, - tree &tee, - const std::set& argin, - const type_map& infered_types) +// -------------------- jit_generator -------------------- +jit_generator::jit_generator (jit_typeinfo *ti, llvm::Module *module, tree &tee, + const std::set& argin, + const type_map& infered_types, bool have_bounds) : tinfo (ti), is_lvalue (false) - { // determine the function type through the type of all variables std::vector arg_types (infered_types.size ()); @@ -915,7 +1008,7 @@ llvm::Function::arg_iterator arg_iter = function->arg_begin(); for (iter = infered_types.begin (); iter != infered_types.end (); ++iter, ++arg_iter) - + { llvm::Type *vartype = iter->second->to_llvm (); llvm::Value *var = builder.CreateAlloca (vartype, 0, iter->first); @@ -931,7 +1024,15 @@ // generate body try { - tee.accept (*this); + tree_simple_for_command *cmd = dynamic_cast(&tee); + if (have_bounds && cmd) + { + value bounds = variables["#bounds"]; + bounds.second = builder.CreateLoad (bounds.second); + emit_simple_for (*cmd, bounds, true); + } + else + tee.accept (*this); } catch (const jit_fail_exception&) { @@ -953,19 +1054,19 @@ } void -tree_jit::code_generator::visit_anon_fcn_handle (tree_anon_fcn_handle&) +jit_generator::visit_anon_fcn_handle (tree_anon_fcn_handle&) { fail (); } void -tree_jit::code_generator::visit_argument_list (tree_argument_list&) +jit_generator::visit_argument_list (tree_argument_list&) { fail (); } void -tree_jit::code_generator::visit_binary_expression (tree_binary_expression& be) +jit_generator::visit_binary_expression (tree_binary_expression& be) { tree_expression *lhs = be.lhs (); lhs->accept (*this); @@ -989,91 +1090,101 @@ } void -tree_jit::code_generator::visit_break_command (tree_break_command&) +jit_generator::visit_break_command (tree_break_command&) { fail (); } void -tree_jit::code_generator::visit_colon_expression (tree_colon_expression&) +jit_generator::visit_colon_expression (tree_colon_expression&) { fail (); } void -tree_jit::code_generator::visit_continue_command (tree_continue_command&) +jit_generator::visit_continue_command (tree_continue_command&) { fail (); } void -tree_jit::code_generator::visit_global_command (tree_global_command&) +jit_generator::visit_global_command (tree_global_command&) { fail (); } void -tree_jit::code_generator::visit_persistent_command (tree_persistent_command&) +jit_generator::visit_persistent_command (tree_persistent_command&) { fail (); } void -tree_jit::code_generator::visit_decl_elt (tree_decl_elt&) +jit_generator::visit_decl_elt (tree_decl_elt&) { fail (); } void -tree_jit::code_generator::visit_decl_init_list (tree_decl_init_list&) +jit_generator::visit_decl_init_list (tree_decl_init_list&) { fail (); } void -tree_jit::code_generator::visit_simple_for_command (tree_simple_for_command&) +jit_generator::visit_simple_for_command (tree_simple_for_command& cmd) { - fail (); + if (is_lvalue) + fail (); + + tree_expression *control = cmd.control_expr (); + assert (control); // FIXME: jwe, can this be null? + + control->accept (*this); + value over = value_stack.back (); + value_stack.pop_back (); + + emit_simple_for (cmd, over, false); } void -tree_jit::code_generator::visit_complex_for_command (tree_complex_for_command&) -{ - fail (); -} - -void -tree_jit::code_generator::visit_octave_user_script (octave_user_script&) +jit_generator::visit_complex_for_command (tree_complex_for_command&) { fail (); } void -tree_jit::code_generator::visit_octave_user_function (octave_user_function&) +jit_generator::visit_octave_user_script (octave_user_script&) { fail (); } void -tree_jit::code_generator::visit_octave_user_function_header (octave_user_function&) +jit_generator::visit_octave_user_function (octave_user_function&) { fail (); } void -tree_jit::code_generator::visit_octave_user_function_trailer (octave_user_function&) +jit_generator::visit_octave_user_function_header (octave_user_function&) { fail (); } void -tree_jit::code_generator::visit_function_def (tree_function_def&) +jit_generator::visit_octave_user_function_trailer (octave_user_function&) { fail (); } void -tree_jit::code_generator::visit_identifier (tree_identifier& ti) +jit_generator::visit_function_def (tree_function_def&) +{ + fail (); +} + +void +jit_generator::visit_identifier (tree_identifier& ti) { std::string name = ti.name (); value variable = variables[name]; @@ -1087,106 +1198,123 @@ } void -tree_jit::code_generator::visit_if_clause (tree_if_clause&) +jit_generator::visit_if_clause (tree_if_clause&) +{ + fail (); +} + +void +jit_generator::visit_if_command (tree_if_command&) { fail (); } void -tree_jit::code_generator::visit_if_command (tree_if_command&) +jit_generator::visit_if_command_list (tree_if_command_list&) { fail (); } void -tree_jit::code_generator::visit_if_command_list (tree_if_command_list&) +jit_generator::visit_index_expression (tree_index_expression&) { fail (); } void -tree_jit::code_generator::visit_index_expression (tree_index_expression&) +jit_generator::visit_matrix (tree_matrix&) +{ + fail (); +} + +void +jit_generator::visit_cell (tree_cell&) { fail (); } void -tree_jit::code_generator::visit_matrix (tree_matrix&) +jit_generator::visit_multi_assignment (tree_multi_assignment&) { fail (); } void -tree_jit::code_generator::visit_cell (tree_cell&) +jit_generator::visit_no_op_command (tree_no_op_command&) { fail (); } void -tree_jit::code_generator::visit_multi_assignment (tree_multi_assignment&) -{ - fail (); -} - -void -tree_jit::code_generator::visit_no_op_command (tree_no_op_command&) -{ - fail (); -} - -void -tree_jit::code_generator::visit_constant (tree_constant& tc) +jit_generator::visit_constant (tree_constant& tc) { octave_value v = tc.rvalue1 (); + llvm::LLVMContext& ctx = llvm::getGlobalContext (); if (v.is_real_scalar () && v.is_double_type ()) { - llvm::LLVMContext& ctx = llvm::getGlobalContext (); double dv = v.double_value (); llvm::Value *lv = llvm::ConstantFP::get (ctx, llvm::APFloat (dv)); push_value (tinfo->get_scalar (), lv); } + else if (v.is_range ()) + { + Range rng = v.range_value (); + llvm::Type *range = tinfo->get_range_llvm (); + llvm::Type *scalar = tinfo->get_scalar_llvm (); + llvm::Type *index = tinfo->get_index_llvm (); + + std::vector values (4); + values[0] = llvm::ConstantFP::get (scalar, rng.base ()); + values[1] = llvm::ConstantFP::get (scalar, rng.limit ()); + values[2] = llvm::ConstantFP::get (scalar, rng.inc ()); + values[3] = llvm::ConstantInt::get (index, rng.nelem ()); + + llvm::StructType *llvm_range = llvm::cast(range); + llvm::Value *lv = llvm::ConstantStruct::get (llvm_range, values); + push_value (tinfo->get_range (), lv); + } else fail (); } void -tree_jit::code_generator::visit_fcn_handle (tree_fcn_handle&) +jit_generator::visit_fcn_handle (tree_fcn_handle&) { fail (); } void -tree_jit::code_generator::visit_parameter_list (tree_parameter_list&) +jit_generator::visit_parameter_list (tree_parameter_list&) { fail (); } void -tree_jit::code_generator::visit_postfix_expression (tree_postfix_expression&) +jit_generator::visit_postfix_expression (tree_postfix_expression&) { fail (); } void -tree_jit::code_generator::visit_prefix_expression (tree_prefix_expression&) +jit_generator::visit_prefix_expression (tree_prefix_expression&) { fail (); } void -tree_jit::code_generator::visit_return_command (tree_return_command&) +jit_generator::visit_return_command (tree_return_command&) { fail (); } void -tree_jit::code_generator::visit_return_list (tree_return_list&) +jit_generator::visit_return_list (tree_return_list&) { fail (); } void -tree_jit::code_generator::visit_simple_assignment (tree_simple_assignment& tsa) +jit_generator::visit_simple_assignment (tree_simple_assignment& tsa) { if (is_lvalue) fail (); @@ -1218,7 +1346,7 @@ } void -tree_jit::code_generator::visit_statement (tree_statement& stmt) +jit_generator::visit_statement (tree_statement& stmt) { tree_command *cmd = stmt.command (); tree_expression *expr = stmt.expression (); @@ -1266,55 +1394,126 @@ } void -tree_jit::code_generator::visit_statement_list (tree_statement_list&) +jit_generator::visit_statement_list (tree_statement_list& lst) +{ + tree_statement_list::iterator iter; + for (iter = lst.begin (); iter != lst.end (); ++iter) + { + tree_statement *stmt = *iter; + assert (stmt); // FIXME: jwe can this be null? + stmt->accept (*this); + } +} + +void +jit_generator::visit_switch_case (tree_switch_case&) +{ + fail (); +} + +void +jit_generator::visit_switch_case_list (tree_switch_case_list&) { fail (); } void -tree_jit::code_generator::visit_switch_case (tree_switch_case&) +jit_generator::visit_switch_command (tree_switch_command&) +{ + fail (); +} + +void +jit_generator::visit_try_catch_command (tree_try_catch_command&) { fail (); } void -tree_jit::code_generator::visit_switch_case_list (tree_switch_case_list&) +jit_generator::visit_unwind_protect_command (tree_unwind_protect_command&) { fail (); } void -tree_jit::code_generator::visit_switch_command (tree_switch_command&) +jit_generator::visit_while_command (tree_while_command&) +{ + fail (); +} + +void +jit_generator::visit_do_until_command (tree_do_until_command&) { fail (); } void -tree_jit::code_generator::visit_try_catch_command (tree_try_catch_command&) +jit_generator::emit_simple_for (tree_simple_for_command& cmd, value over, + bool atleast_once) { - fail (); -} + if (is_lvalue) + fail (); + + jit_type *index = tinfo->get_index (); + llvm::Value *init_index = 0; + if (over.first == tinfo->get_range ()) + init_index = llvm::ConstantInt::get (index->to_llvm (), 0); + else + fail (); + + llvm::Value *llvm_index = builder.CreateAlloca (index->to_llvm (), 0, "index"); + builder.CreateStore (init_index, llvm_index); + + // FIXME: Support break + llvm::LLVMContext &ctx = llvm::getGlobalContext (); + llvm::BasicBlock *body = llvm::BasicBlock::Create (ctx, "for_body", function); + llvm::BasicBlock *cond_check = llvm::BasicBlock::Create (ctx, "for_check", function); + llvm::BasicBlock *tail = llvm::BasicBlock::Create (ctx, "for_tail", function); + + // initialize the iter from the index + if (atleast_once) + builder.CreateBr (body); + else + builder.CreateBr (cond_check); + + builder.SetInsertPoint (body); -void -tree_jit::code_generator::visit_unwind_protect_command (tree_unwind_protect_command&) -{ - fail (); + is_lvalue = true; + tree_expression *lhs = cmd.left_hand_side (); + lhs->accept (*this); + is_lvalue = false; + + value lhsv = value_stack.back (); + value_stack.pop_back (); + + const jit_function::overload& index_ol = tinfo->get_simple_for_index (over.first); + llvm::Value *lindex = builder.CreateLoad (llvm_index); + llvm::Value *llvm_iter = builder.CreateCall2 (index_ol.function, over.second, lindex); + value iter(index_ol.result, llvm_iter); + + jit_function::overload assign = tinfo->assign_op (lhsv.first, iter.first); + builder.CreateCall2 (assign.function, lhsv.second, iter.second); + + tree_statement_list *lst = cmd.body (); + lst->accept (*this); + + llvm::Value *one = llvm::ConstantInt::get (index->to_llvm (), 1); + lindex = builder.CreateLoad (llvm_index); + lindex = builder.CreateAdd (lindex, one); + builder.CreateStore (lindex, llvm_index); + builder.CreateBr (cond_check); + + builder.SetInsertPoint (cond_check); + lindex = builder.CreateLoad (llvm_index); + const jit_function::overload& check_ol = tinfo->get_simple_for_check (over.first); + llvm::Value *cond = builder.CreateCall2 (check_ol.function, over.second, lindex); + builder.CreateCondBr (cond, body, tail); + + builder.SetInsertPoint (tail); } void -tree_jit::code_generator::visit_while_command (tree_while_command&) -{ - fail (); -} - -void -tree_jit::code_generator::visit_do_until_command (tree_do_until_command&) -{ - fail (); -} - -void -tree_jit::code_generator::emit_print (const std::string& name, const value& v) +jit_generator::emit_print (const std::string& name, const value& v) { const jit_function::overload& ol = tinfo->print_value (v.first); if (! ol.function) @@ -1324,14 +1523,84 @@ builder.CreateCall2 (ol.function, str, v.second); } -tree_jit::function_info::function_info (tree_jit& tjit, tree& tee) : - tinfo (tjit.tinfo), engine (tjit.engine) +// -------------------- tree_jit -------------------- + +tree_jit::tree_jit (void) : context (llvm::getGlobalContext ()), engine (0) +{ + llvm::InitializeNativeTarget (); + module = new llvm::Module ("octave", context); +} + +tree_jit::~tree_jit (void) +{ + delete tinfo; +} + +bool +tree_jit::execute (tree_simple_for_command& cmd, const octave_value& bounds) +{ + if (! initialize ()) + return false; + + jit_type *bounds_t = tinfo->type_of (bounds); + jit_info *jinfo = cmd.get_info (bounds_t); + if (! jinfo) + { + jinfo = new jit_info (*this, cmd, bounds_t); + cmd.stash_info (bounds_t, jinfo); + } + + return jinfo->execute (bounds); +} + +bool +tree_jit::initialize (void) { - type_infer infer(tjit.tinfo); + if (engine) + return true; + + // sometimes this fails pre main + engine = llvm::ExecutionEngine::createJIT (module); + + if (! engine) + return false; + + module_pass_manager = new llvm::PassManager (); + module_pass_manager->add (llvm::createAlwaysInlinerPass ()); + + pass_manager = new llvm::FunctionPassManager (module); + pass_manager->add (new llvm::TargetData(*engine->getTargetData ())); + pass_manager->add (llvm::createBasicAliasAnalysisPass ()); + pass_manager->add (llvm::createPromoteMemoryToRegisterPass ()); + pass_manager->add (llvm::createInstructionCombiningPass ()); + pass_manager->add (llvm::createReassociatePass ()); + pass_manager->add (llvm::createGVNPass ()); + pass_manager->add (llvm::createCFGSimplificationPass ()); + pass_manager->doInitialization (); + + tinfo = new jit_typeinfo (module, engine); + + return true; +} + + +void +tree_jit::optimize (llvm::Function *fn) +{ + module_pass_manager->run (*module); + pass_manager->run (*fn); +} + +// -------------------- jit_info -------------------- +jit_info::jit_info (tree_jit& tjit, tree_simple_for_command& cmd, + jit_type *bounds) : tinfo (tjit.get_typeinfo ()), + engine (tjit.get_engine ()) +{ + jit_infer infer(tinfo); try { - tee.accept (infer); + infer.infer (cmd, bounds); } catch (const jit_fail_exception&) { @@ -1342,21 +1611,27 @@ argin = infer.get_argin (); types = infer.get_types (); - code_generator gen(tjit.tinfo, tjit.module, tee, argin, types); + jit_generator gen(tinfo, tjit.get_module (), cmd, argin, types); function = gen.get_function (); if (function) { + if (debug_print) + { + std::cout << "Compiled code:\n"; + std::cout << cmd.str_print_code () << std::endl; + + std::cout << "Before optimization:\n"; + + llvm::raw_os_ostream os (std::cout); + function->print (os); + } llvm::verifyFunction (*function); - tjit.module_pass_manager->run (*tjit.module); - tjit.pass_manager->run (*function); + tjit.optimize (function); if (debug_print) { - std::cout << "Compiled:\n"; - std::cout << tee.str_print_code () << std::endl; - - std::cout << "Code:\n"; + std::cout << "After optimization:\n"; llvm::raw_os_ostream os (std::cout); function->print (os); @@ -1365,7 +1640,7 @@ } bool -tree_jit::function_info::execute () const +jit_info::execute (const octave_value& bounds) const { if (! function) return false; @@ -1379,7 +1654,12 @@ { if (argin.count (iter->first)) { - octave_value ov = symbol_table::varval (iter->first); + octave_value ov; + if (iter->first == "#bounds") + ov = bounds; + else + ov = symbol_table::varval (iter->first); + tinfo->to_generic (iter->second, args[idx], ov); } else @@ -1398,11 +1678,14 @@ } bool -tree_jit::function_info::match () const +jit_info::match () const { for (std::set::iterator iter = argin.begin (); iter != argin.end (); ++iter) { + if (*iter == "#bounds") + continue; + jit_type *required_type = types.find (*iter)->second; octave_value val = symbol_table::varref (*iter); jit_type *current_type = tinfo->type_of (val); diff --git a/src/pt-jit.h b/src/pt-jit.h --- a/src/pt-jit.h +++ b/src/pt-jit.h @@ -30,6 +30,7 @@ #include #include "Array.h" +#include "Range.h" #include "pt-walk.h" // -------------------- Current status -------------------- @@ -43,6 +44,19 @@ // a = [1 2 3] // b = a + a; // will compile to do_binary_op (a, a). +// +// for loops with ranges compile. For example, +// for i=1:1000 +// result = i + 1; +// endfor +// Will compile. Nested for loops with constant bounds are also supported. +// +// TODO: +// 1. Cleanup +// 2. Support if statements +// 3. Support iteration over matricies +// 4. Check error state +// 5. ... // --------------------------------------------------------- @@ -66,8 +80,27 @@ class octave_value; class tree; -// thrown when we should give up on JIT and interpret -class jit_fail_exception : public std::exception {}; +// jit_range is compatable with the llvm range structure +struct +OCTINTERP_API +jit_range +{ + jit_range (void) {} + + jit_range (const Range& from) : base (from.base ()), limit (from.limit ()), + inc (from.inc ()), nelem (from.nelem ()) + {} + + operator Range () const + { + return Range (base, limit, inc); + } + + double base; + double limit; + double inc; + octave_idx_type nelem; +}; // Used to keep track of estimated (infered) types during JIT. This is a // hierarchical type system which includes both concrete and abstract types. @@ -109,10 +142,8 @@ jit_type *p; llvm::Type *llvm_type; int id; - int depth; }; - // Keeps track of overloads for a builtin function. Used for both type inference // and code generation. class @@ -148,6 +179,12 @@ add_overload (func, func.arguments); } + void add_overload (llvm::Function *f, bool e, jit_type *r, jit_type *arg0) + { + overload ol (f, e, r, arg0); + add_overload (ol); + } + void add_overload (llvm::Function *f, bool e, jit_type *r, jit_type *arg0, jit_type *arg1) { @@ -198,12 +235,24 @@ jit_typeinfo { public: - jit_typeinfo (llvm::Module *m, llvm::ExecutionEngine *e, llvm::Type *ov); + jit_typeinfo (llvm::Module *m, llvm::ExecutionEngine *e); jit_type *get_any (void) const { return any; } jit_type *get_scalar (void) const { return scalar; } + llvm::Type *get_scalar_llvm (void) const { return scalar->to_llvm (); } + + jit_type *get_range (void) const { return range; } + + llvm::Type *get_range_llvm (void) const { return range->to_llvm (); } + + jit_type *get_bool (void) const { return boolean; } + + jit_type *get_index (void) const { return index; } + + llvm::Type *get_index_llvm (void) const { return index->to_llvm (); } + jit_type *type_of (const octave_value& ov) const; const jit_function& binary_op (int op) const; @@ -225,6 +274,22 @@ const jit_function::overload& print_value (jit_type *to_print) const; + const jit_function::overload& get_simple_for_check (jit_type *bounds) const + { + return simple_for_check.get_overload (bounds, index); + } + + const jit_function::overload& get_simple_for_index (jit_type *bounds) const + { + return simple_for_index.get_overload (bounds, index); + } + + jit_type *get_simple_for_index_result (jit_type *bounds) const + { + const jit_function::overload& ol = get_simple_for_index (bounds); + return ol.result; + } + // FIXME: generic creation should probably be handled seperatly void to_generic (jit_type *type, llvm::GenericValue& gv); void to_generic (jit_type *type, llvm::GenericValue& gv, octave_value ov); @@ -233,6 +298,8 @@ void reset_generic (size_t nargs); private: + typedef std::map type_map; + jit_type *new_type (const std::string& name, bool force_init, jit_type *parent, llvm::Type *llvm_type); @@ -249,16 +316,261 @@ std::vector id_to_type; jit_type *any; jit_type *scalar; + jit_type *range; + jit_type *boolean; + jit_type *index; std::vector binary_ops; jit_function assign_fn; jit_function print_fn; + jit_function simple_for_check; + jit_function simple_for_incr; + jit_function simple_for_index; - size_t scalar_out_idx; + size_t scalar_idx; std::vector scalar_out; - size_t ov_out_idx; + size_t ov_idx; std::vector ov_out; + + size_t range_idx; + std::vector range_out; +}; + +class +OCTINTERP_API +jit_infer : public tree_walker +{ + typedef std::map type_map; +public: + jit_infer (jit_typeinfo *ti) : tinfo (ti), is_lvalue (false), + rvalue_type (0) + {} + + const std::set& get_argin () const { return argin; } + + const type_map& get_types () const { return types; } + + void infer (tree_simple_for_command& cmd, jit_type *bounds); + + void visit_anon_fcn_handle (tree_anon_fcn_handle&); + + void visit_argument_list (tree_argument_list&); + + void visit_binary_expression (tree_binary_expression&); + + void visit_break_command (tree_break_command&); + + void visit_colon_expression (tree_colon_expression&); + + void visit_continue_command (tree_continue_command&); + + void visit_global_command (tree_global_command&); + + void visit_persistent_command (tree_persistent_command&); + + void visit_decl_elt (tree_decl_elt&); + + void visit_decl_init_list (tree_decl_init_list&); + + void visit_simple_for_command (tree_simple_for_command&); + + void visit_complex_for_command (tree_complex_for_command&); + + void visit_octave_user_script (octave_user_script&); + + void visit_octave_user_function (octave_user_function&); + + void visit_octave_user_function_header (octave_user_function&); + + void visit_octave_user_function_trailer (octave_user_function&); + + void visit_function_def (tree_function_def&); + + void visit_identifier (tree_identifier&); + + void visit_if_clause (tree_if_clause&); + + void visit_if_command (tree_if_command&); + + void visit_if_command_list (tree_if_command_list&); + + void visit_index_expression (tree_index_expression&); + + void visit_matrix (tree_matrix&); + + void visit_cell (tree_cell&); + + void visit_multi_assignment (tree_multi_assignment&); + + void visit_no_op_command (tree_no_op_command&); + + void visit_constant (tree_constant&); + + void visit_fcn_handle (tree_fcn_handle&); + + void visit_parameter_list (tree_parameter_list&); + + void visit_postfix_expression (tree_postfix_expression&); + + void visit_prefix_expression (tree_prefix_expression&); + + void visit_return_command (tree_return_command&); + + void visit_return_list (tree_return_list&); + + void visit_simple_assignment (tree_simple_assignment&); + + void visit_statement (tree_statement&); + + void visit_statement_list (tree_statement_list&); + + void visit_switch_case (tree_switch_case&); + + void visit_switch_case_list (tree_switch_case_list&); + + void visit_switch_command (tree_switch_command&); + + void visit_try_catch_command (tree_try_catch_command&); + + void visit_unwind_protect_command (tree_unwind_protect_command&); + + void visit_while_command (tree_while_command&); + + void visit_do_until_command (tree_do_until_command&); +private: + void infer_simple_for (tree_simple_for_command& cmd, + jit_type *bounds); + + void handle_identifier (const std::string& name, octave_value v); + + jit_typeinfo *tinfo; + + bool is_lvalue; + jit_type *rvalue_type; + + type_map types; + std::set argin; + + std::vector type_stack; +}; + +class +OCTINTERP_API +jit_generator : public tree_walker +{ + typedef std::map type_map; +public: + jit_generator (jit_typeinfo *ti, llvm::Module *module, tree &tee, + const std::set& argin, + const type_map& infered_types, bool have_bounds = true); + + llvm::Function *get_function () const { return function; } + + void visit_anon_fcn_handle (tree_anon_fcn_handle&); + + void visit_argument_list (tree_argument_list&); + + void visit_binary_expression (tree_binary_expression&); + + void visit_break_command (tree_break_command&); + + void visit_colon_expression (tree_colon_expression&); + + void visit_continue_command (tree_continue_command&); + + void visit_global_command (tree_global_command&); + + void visit_persistent_command (tree_persistent_command&); + + void visit_decl_elt (tree_decl_elt&); + + void visit_decl_init_list (tree_decl_init_list&); + + void visit_simple_for_command (tree_simple_for_command&); + + void visit_complex_for_command (tree_complex_for_command&); + + void visit_octave_user_script (octave_user_script&); + + void visit_octave_user_function (octave_user_function&); + + void visit_octave_user_function_header (octave_user_function&); + + void visit_octave_user_function_trailer (octave_user_function&); + + void visit_function_def (tree_function_def&); + + void visit_identifier (tree_identifier&); + + void visit_if_clause (tree_if_clause&); + + void visit_if_command (tree_if_command&); + + void visit_if_command_list (tree_if_command_list&); + + void visit_index_expression (tree_index_expression&); + + void visit_matrix (tree_matrix&); + + void visit_cell (tree_cell&); + + void visit_multi_assignment (tree_multi_assignment&); + + void visit_no_op_command (tree_no_op_command&); + + void visit_constant (tree_constant&); + + void visit_fcn_handle (tree_fcn_handle&); + + void visit_parameter_list (tree_parameter_list&); + + void visit_postfix_expression (tree_postfix_expression&); + + void visit_prefix_expression (tree_prefix_expression&); + + void visit_return_command (tree_return_command&); + + void visit_return_list (tree_return_list&); + + void visit_simple_assignment (tree_simple_assignment&); + + void visit_statement (tree_statement&); + + void visit_statement_list (tree_statement_list&); + + void visit_switch_case (tree_switch_case&); + + void visit_switch_case_list (tree_switch_case_list&); + + void visit_switch_command (tree_switch_command&); + + void visit_try_catch_command (tree_try_catch_command&); + + void visit_unwind_protect_command (tree_unwind_protect_command&); + + void visit_while_command (tree_while_command&); + + void visit_do_until_command (tree_do_until_command&); +private: + typedef std::pair value; + + void emit_simple_for (tree_simple_for_command& cmd, value over, + bool atleast_once); + + void emit_print (const std::string& name, const value& v); + + void push_value (jit_type *type, llvm::Value *v) + { + value_stack.push_back (value (type, v)); + } + + jit_typeinfo *tinfo; + llvm::Function *function; + + bool is_lvalue; + std::map variables; + std::vector value_stack; }; class @@ -270,257 +582,17 @@ ~tree_jit (void); - bool execute (tree& tee); - private: - typedef std::map type_map; - - class - type_infer : public tree_walker - { - public: - type_infer (jit_typeinfo *ti) : tinfo (ti), is_lvalue (false), - rvalue_type (0) - {} - - const std::set& get_argin () const { return argin; } - - const type_map& get_types () const { return types; } - - void visit_anon_fcn_handle (tree_anon_fcn_handle&); - - void visit_argument_list (tree_argument_list&); - - void visit_binary_expression (tree_binary_expression&); - - void visit_break_command (tree_break_command&); - - void visit_colon_expression (tree_colon_expression&); - - void visit_continue_command (tree_continue_command&); - - void visit_global_command (tree_global_command&); - - void visit_persistent_command (tree_persistent_command&); - - void visit_decl_elt (tree_decl_elt&); - - void visit_decl_init_list (tree_decl_init_list&); - - void visit_simple_for_command (tree_simple_for_command&); - - void visit_complex_for_command (tree_complex_for_command&); - - void visit_octave_user_script (octave_user_script&); - - void visit_octave_user_function (octave_user_function&); - - void visit_octave_user_function_header (octave_user_function&); - - void visit_octave_user_function_trailer (octave_user_function&); - - void visit_function_def (tree_function_def&); - - void visit_identifier (tree_identifier&); - - void visit_if_clause (tree_if_clause&); - - void visit_if_command (tree_if_command&); - - void visit_if_command_list (tree_if_command_list&); - - void visit_index_expression (tree_index_expression&); - - void visit_matrix (tree_matrix&); + bool execute (tree_simple_for_command& cmd, const octave_value& bounds); - void visit_cell (tree_cell&); - - void visit_multi_assignment (tree_multi_assignment&); - - void visit_no_op_command (tree_no_op_command&); - - void visit_constant (tree_constant&); - - void visit_fcn_handle (tree_fcn_handle&); - - void visit_parameter_list (tree_parameter_list&); - - void visit_postfix_expression (tree_postfix_expression&); - - void visit_prefix_expression (tree_prefix_expression&); - - void visit_return_command (tree_return_command&); - - void visit_return_list (tree_return_list&); - - void visit_simple_assignment (tree_simple_assignment&); - - void visit_statement (tree_statement&); - - void visit_statement_list (tree_statement_list&); - - void visit_switch_case (tree_switch_case&); - - void visit_switch_case_list (tree_switch_case_list&); + jit_typeinfo *get_typeinfo (void) const { return tinfo; } - void visit_switch_command (tree_switch_command&); - - void visit_try_catch_command (tree_try_catch_command&); - - void visit_unwind_protect_command (tree_unwind_protect_command&); - - void visit_while_command (tree_while_command&); - - void visit_do_until_command (tree_do_until_command&); - private: - void handle_identifier (const std::string& name, octave_value v); - - jit_typeinfo *tinfo; - - bool is_lvalue; - jit_type *rvalue_type; - - type_map types; - std::set argin; - - std::vector type_stack; - }; - - class - code_generator : public tree_walker - { - public: - code_generator (jit_typeinfo *ti, llvm::Module *module, tree &tee, - const std::set& argin, - const type_map& infered_types); - - llvm::Function *get_function () const { return function; } + llvm::ExecutionEngine *get_engine (void) const { return engine; } - void visit_anon_fcn_handle (tree_anon_fcn_handle&); - - void visit_argument_list (tree_argument_list&); - - void visit_binary_expression (tree_binary_expression&); - - void visit_break_command (tree_break_command&); - - void visit_colon_expression (tree_colon_expression&); - - void visit_continue_command (tree_continue_command&); - - void visit_global_command (tree_global_command&); - - void visit_persistent_command (tree_persistent_command&); - - void visit_decl_elt (tree_decl_elt&); - - void visit_decl_init_list (tree_decl_init_list&); - - void visit_simple_for_command (tree_simple_for_command&); - - void visit_complex_for_command (tree_complex_for_command&); - - void visit_octave_user_script (octave_user_script&); - - void visit_octave_user_function (octave_user_function&); - - void visit_octave_user_function_header (octave_user_function&); - - void visit_octave_user_function_trailer (octave_user_function&); - - void visit_function_def (tree_function_def&); - - void visit_identifier (tree_identifier&); - - void visit_if_clause (tree_if_clause&); - - void visit_if_command (tree_if_command&); - - void visit_if_command_list (tree_if_command_list&); - - void visit_index_expression (tree_index_expression&); - - void visit_matrix (tree_matrix&); - - void visit_cell (tree_cell&); - - void visit_multi_assignment (tree_multi_assignment&); - - void visit_no_op_command (tree_no_op_command&); - - void visit_constant (tree_constant&); - - void visit_fcn_handle (tree_fcn_handle&); - - void visit_parameter_list (tree_parameter_list&); - - void visit_postfix_expression (tree_postfix_expression&); - - void visit_prefix_expression (tree_prefix_expression&); + llvm::Module *get_module (void) const { return module; } - void visit_return_command (tree_return_command&); - - void visit_return_list (tree_return_list&); - - void visit_simple_assignment (tree_simple_assignment&); - - void visit_statement (tree_statement&); - - void visit_statement_list (tree_statement_list&); - - void visit_switch_case (tree_switch_case&); - - void visit_switch_case_list (tree_switch_case_list&); - - void visit_switch_command (tree_switch_command&); - - void visit_try_catch_command (tree_try_catch_command&); - - void visit_unwind_protect_command (tree_unwind_protect_command&); - - void visit_while_command (tree_while_command&); - - void visit_do_until_command (tree_do_until_command&); - private: - typedef std::pair value; - - void emit_print (const std::string& name, const value& v); - - void push_value (jit_type *type, llvm::Value *v) - { - value_stack.push_back (value (type, v)); - } - - jit_typeinfo *tinfo; - llvm::Function *function; - - bool is_lvalue; - std::map variables; - std::vector value_stack; - }; - - class function_info - { - public: - function_info (tree_jit& tjit, tree& tee); - - bool execute () const; - - bool match () const; - private: - jit_typeinfo *tinfo; - llvm::ExecutionEngine *engine; - std::set argin; - type_map types; - llvm::Function *function; - }; - - typedef std::list function_list; - typedef std::map compiled_map; - - static void fail (void) - { - throw jit_fail_exception (); - } + void optimize (llvm::Function *fn); + private: + bool initialize (void); llvm::LLVMContext &context; llvm::Module *module; @@ -529,8 +601,26 @@ llvm::ExecutionEngine *engine; jit_typeinfo *tinfo; +}; - compiled_map compiled; +class +OCTINTERP_API +jit_info +{ +public: + jit_info (tree_jit& tjit, tree_simple_for_command& cmd, jit_type *bounds); + + bool execute (const octave_value& bounds) const; + + bool match (void) const; +private: + typedef std::map type_map; + + jit_typeinfo *tinfo; + llvm::ExecutionEngine *engine; + std::set argin; + type_map types; + llvm::Function *function; }; #endif diff --git a/src/pt-loop.cc b/src/pt-loop.cc --- a/src/pt-loop.cc +++ b/src/pt-loop.cc @@ -35,6 +35,7 @@ #include "pt-bp.h" #include "pt-cmd.h" #include "pt-exp.h" +#include "pt-jit.h" #include "pt-jump.h" #include "pt-loop.h" #include "pt-stmt.h" @@ -97,6 +98,10 @@ delete list; delete lead_comm; delete trail_comm; + + for (compiled_map::iterator iter = compiled.begin (); iter != compiled.end (); + ++iter) + delete iter->second; } tree_command * diff --git a/src/pt-loop.h b/src/pt-loop.h --- a/src/pt-loop.h +++ b/src/pt-loop.h @@ -36,6 +36,9 @@ #include "pt-cmd.h" #include "symtab.h" +class jit_info; +class jit_type; + // While. class @@ -180,7 +183,20 @@ void accept (tree_walker& tw); + // some functions use by tree_jit + jit_info *get_info (jit_type *type) const + { + compiled_map::const_iterator iter = compiled.find (type); + return iter != compiled.end () ? iter->second : 0; + } + + void stash_info (jit_type *type, jit_info *jinfo) + { + compiled[type] = jinfo; + } + private: + typedef std::map compiled_map; // TRUE means operate in parallel (subject to the value of the // maxproc expression). @@ -205,6 +221,9 @@ // Comment preceding ENDFOR token. octave_comment_list *trail_comm; + // a map from iterator types -> compiled functions + compiled_map compiled; + // No copying! tree_simple_for_command (const tree_simple_for_command&);