Mercurial > hg > octave-lyh
diff src/pt-jit.cc @ 14899:f25d2224fa02
Initial JIT support
build-aux/common.mk: Add llvm flags.
configure.ac: Link with llvm.
src/Makefile: Add pt-jit.
src/link-deps.mk: Link with llvm.
src/oct-conf.in.h: Add llvm flags.
src/toplev.cc: Add llvm flags.
src/pt-eval.cc: Try to jit statements.
src/pt-jit.cc: New file.
src/pt-jit.h: New file
author | Max Brister <max@2bass.com> |
---|---|
date | Sun, 06 May 2012 20:17:30 -0600 |
parents | |
children | 516b4a15b775 |
line wrap: on
line diff
new file mode 100644 --- /dev/null +++ b/src/pt-jit.cc @@ -0,0 +1,703 @@ +/* + +Copyright (C) 2009-2012 John W. Eaton + +This file is part of Octave. + +Octave is free software; you can redistribute it and/or modify it +under the terms of the GNU General Public License as published by the +Free Software Foundation; either version 3 of the License, or (at your +option) any later version. + +Octave is distributed in the hope that it will be useful, but WITHOUT +ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License +for more details. + +You should have received a copy of the GNU General Public License +along with Octave; see the file COPYING. If not, see +<http://www.gnu.org/licenses/>. + +*/ + +#define __STDC_LIMIT_MACROS +#define __STDC_CONSTANT_MACROS + +#ifdef HAVE_CONFIG_H +#include <config.h> +#endif + +#include "pt-jit.h" + +#include <typeinfo> + +#include <llvm/LLVMContext.h> +#include <llvm/Module.h> +#include <llvm/Function.h> +#include <llvm/BasicBlock.h> +#include <llvm/Support/IRBuilder.h> +#include <llvm/ExecutionEngine/ExecutionEngine.h> +#include <llvm/ExecutionEngine/JIT.h> +#include <llvm/PassManager.h> +#include <llvm/Analysis/Verifier.h> +#include <llvm/Analysis/Passes.h> +#include <llvm/Target/TargetData.h> +#include <llvm/Transforms/Scalar.h> +#include <llvm/Support/TargetSelect.h> +#include <llvm/Support/raw_os_ostream.h> + +#include "ov-fcn-handle.h" +#include "ov-usr-fcn.h" +#include "pt-all.h" + +using namespace llvm; + +//FIXME: Move into tree_jit +static IRBuilder<> builder (getGlobalContext ()); + +extern "C" void +octave_print_double (const char *name, double value) +{ + // FIXME: We should avoid allocating a new octave_scalar each time + octave_value ov (value); + ov.print_with_name (octave_stdout, name); +} + +tree_jit::tree_jit (void) : context (getGlobalContext ()), engine (0) +{ + InitializeNativeTarget (); + InitializeNativeTargetAsmPrinter (); + module = new Module ("octave", context); +} + +tree_jit::~tree_jit (void) +{ + delete module; +} + +bool +tree_jit::execute (tree& tee) +{ + if (!engine) + { + engine = ExecutionEngine::createJIT (module); + + // initialize pass manager + pass_manager = new FunctionPassManager (module); + pass_manager->add (new TargetData(*engine->getTargetData ())); + pass_manager->add (createBasicAliasAnalysisPass ()); + pass_manager->add (createPromoteMemoryToRegisterPass ()); + pass_manager->add (createInstructionCombiningPass ()); + pass_manager->add (createReassociatePass ()); + pass_manager->add (createGVNPass ()); + pass_manager->add (createCFGSimplificationPass ()); + pass_manager->doInitialization (); + + // create external functions + Type *vtype = Type::getVoidTy (context); + std::vector<Type*> pd_args (2); + pd_args[0] = Type::getInt8PtrTy (context); + pd_args[1] = Type::getDoubleTy (context); + FunctionType *print_double_ty = FunctionType::get (vtype, pd_args, false); + print_double = Function::Create (print_double_ty, + Function::ExternalLinkage, + "octave_print_double", module); + engine->addGlobalMapping (print_double, + reinterpret_cast<void*>(&octave_print_double)); + } + + if (!engine) + // sometimes this fails during early initialization + return false; + + // find function + function_info *finfo; + finfo_map_iterator iter = compiled_functions.find (&tee); + + if (iter == compiled_functions.end ()) + finfo = compile (tee); + else + finfo = iter->second; + + return finfo->execute (); +} + +tree_jit::function_info* +tree_jit::compile (tree& tee) +{ + value_stack.clear (); + variables.clear (); + + // setup function + std::vector<Type*> args (2); + args[0] = Type::getInt1PtrTy (context); + args[1] = Type::getDoublePtrTy (context); + FunctionType *ft = FunctionType::get (Type::getVoidTy (context), args, false); + Function *compiling = Function::Create (ft, Function::ExternalLinkage, + "test_fn", module); + + entry_block = BasicBlock::Create (context, "entry", compiling); + BasicBlock *body = BasicBlock::Create (context, "body", + compiling); + builder.SetInsertPoint (body); + + // convert tree to LLVM IR + try + { + tee.accept (*this); + } + catch (const jit_fail_exception&) + { + //FIXME: cleanup + return compiled_functions[&tee] = new function_info (); + } + + // copy input arguments + builder.SetInsertPoint (entry_block); + + Function::arg_iterator arg_iter = compiling->arg_begin (); + Value *arg_defined = arg_iter; + Value *arg_value = ++arg_iter; + + arg_defined->setName ("arg_defined"); + arg_value->setName ("arg_value"); + + size_t idx = 0; + std::vector<std::string> arg_names; + std::vector<bool> arg_used; + for (var_map_iterator iter = variables.begin (); iter != variables.end (); + ++iter, ++idx) + { + arg_names.push_back (iter->first); + arg_used.push_back (iter->second.use); + + Value *gep_defined = builder.CreateConstInBoundsGEP1_32 (arg_defined, idx); + Value *defined = builder.CreateLoad (gep_defined); + builder.CreateStore (defined, iter->second.defined); + + Value *gep_value = builder.CreateConstInBoundsGEP1_32 (arg_value, idx); + Value *value = builder.CreateLoad (gep_value); + builder.CreateStore (value, iter->second.value); + } + builder.CreateBr (body); + + // copy output arguments + BasicBlock *cleanup = BasicBlock::Create (context, "cleanup", compiling); + builder.SetInsertPoint (body); + builder.CreateBr (cleanup); + builder.SetInsertPoint (cleanup); + + idx = 0; + for (var_map_iterator iter = variables.begin (); iter != variables.end (); + ++iter, ++idx) + { + Value *gep_defined = builder.CreateConstInBoundsGEP1_32 (arg_defined, idx); + Value *defined = builder.CreateLoad (iter->second.defined); + builder.CreateStore (defined, gep_defined); + + Value *gep_value = builder.CreateConstInBoundsGEP1_32 (arg_value, idx); + Value *value = builder.CreateLoad (iter->second.value, iter->first); + builder.CreateStore (value, gep_value); + } + + builder.CreateRetVoid (); + + // print what we compiled (for debugging) + // we leave this in for now, as other people might want to view the ir created + // should be removed eventually though + const bool debug_print_ir = false; + if (debug_print_ir) + { + raw_os_ostream os (std::cout); + std::cout << "Compiling --------------------\n"; + tree_print_code tpc (std::cout); + std::cout << typeid (tee).name () << std::endl; + tee.accept (tpc); + std::cout << "\n--------------------\n"; + + std::cout << "llvm_ir\n"; + compiling->print (os); + std::cout << "--------------------\n"; + } + + // compile code + verifyFunction (*compiling); + pass_manager->run (*compiling); + + if (debug_print_ir) + { + raw_os_ostream os (std::cout); + std::cout << "optimized llvm_ir\n"; + compiling->print (os); + std::cout << "--------------------\n"; + } + + jit_function fun = + reinterpret_cast<jit_function> (engine->getPointerToFunction (compiling)); + + return compiled_functions[&tee] = new function_info (fun, arg_names, arg_used); +} + +tree_jit::variable_info +tree_jit::find (const std::string &name, bool use) +{ + var_map_iterator iter = variables.find (name); + if (iter == variables.end ()) + { + // we currently just assume everything is a double + Type *dbl = Type::getDoubleTy (context); + Type *bol = Type::getInt1Ty (context); + IRBuilder<> tmpB (entry_block, entry_block->begin ()); + + variable_info vinfo; + vinfo.defined = tmpB.CreateAlloca (bol, 0); + vinfo.value = tmpB.CreateAlloca (dbl, 0, name); + vinfo.use = use; + variables[name] = vinfo; + return vinfo; + } + else + { + iter->second.use = iter->second.use || use; + return iter->second; + } +} + +void +tree_jit::do_assign (variable_info vinfo, llvm::Value *value) +{ + // create assign expression + Value *result = builder.CreateStore (value, vinfo.value); + value_stack.push_back (result); + + // update defined for lhs + Type *btype = Type::getInt1Ty (context); + Value *btrue = ConstantInt::get (btype, APInt (1, 1)); + builder.CreateStore (btrue, vinfo.defined); +} + +void +tree_jit::emit_print (const std::string& vname, llvm::Value *value) +{ + Value *pname = builder.CreateGlobalStringPtr (vname); + builder.CreateCall2 (print_double, pname, value); +} + +void +tree_jit::visit_anon_fcn_handle (tree_anon_fcn_handle&) +{ + fail (); +} + +void +tree_jit::visit_argument_list (tree_argument_list&) +{ + fail (); +} + +void +tree_jit::visit_binary_expression (tree_binary_expression& be) +{ + tree_expression *lhs = be.lhs (); + tree_expression *rhs = be.rhs (); + if (lhs && rhs) + { + lhs->accept (*this); + rhs->accept (*this); + + Value *lhsv = value_stack.back (); + value_stack.pop_back (); + + Value *rhsv = value_stack.back (); + value_stack.pop_back (); + + Value *result; + switch (be.op_type ()) + { + case octave_value::op_add: + result = builder.CreateFAdd (lhsv, rhsv); + break; + case octave_value::op_sub: + result = builder.CreateFSub (lhsv, rhsv); + break; + case octave_value::op_mul: + result = builder.CreateFMul (lhsv, rhsv); + break; + case octave_value::op_div: + result = builder.CreateFDiv (lhsv, rhsv); + break; + default: + fail (); + } + + value_stack.push_back (result); + } + else + fail (); +} + +void +tree_jit::visit_break_command (tree_break_command&) +{ + fail (); +} + +void +tree_jit::visit_colon_expression (tree_colon_expression&) +{ + fail (); +} + +void +tree_jit::visit_continue_command (tree_continue_command&) +{ + fail (); +} + +void +tree_jit::visit_global_command (tree_global_command&) +{ + fail (); +} + +void +tree_jit::visit_persistent_command (tree_persistent_command&) +{ + fail (); +} + +void +tree_jit::visit_decl_elt (tree_decl_elt&) +{ + fail (); +} + +void +tree_jit::visit_decl_init_list (tree_decl_init_list&) +{ + fail (); +} + +void +tree_jit::visit_simple_for_command (tree_simple_for_command&) +{ + fail (); +} + +void +tree_jit::visit_complex_for_command (tree_complex_for_command&) +{ + fail (); +} + +void +tree_jit::visit_octave_user_script (octave_user_script&) +{ + fail (); +} + +void +tree_jit::visit_octave_user_function (octave_user_function&) +{ + fail (); +} + +void +tree_jit::visit_octave_user_function_header (octave_user_function&) +{ + fail (); +} + +void +tree_jit::visit_octave_user_function_trailer (octave_user_function&) +{ + fail (); +} + +void +tree_jit::visit_function_def (tree_function_def&) +{ + fail (); +} + +void +tree_jit::visit_identifier (tree_identifier& ti) +{ + octave_value ov = ti.do_lookup (); + if (ov.is_function ()) + fail (); + + std::string name = ti.name (); + variable_info vinfo = find (ti.name (), true); + + // TODO check defined + + Value *load_value = builder.CreateLoad (vinfo.value, name); + value_stack.push_back (load_value); +} + +void +tree_jit::visit_if_clause (tree_if_clause&) +{ + fail (); +} + +void +tree_jit::visit_if_command (tree_if_command&) +{ + fail (); +} + +void +tree_jit::visit_if_command_list (tree_if_command_list&) +{ + fail (); +} + +void +tree_jit::visit_index_expression (tree_index_expression&) +{ + fail (); +} + +void +tree_jit::visit_matrix (tree_matrix&) +{ + fail (); +} + +void +tree_jit::visit_cell (tree_cell&) +{ + fail (); +} + +void +tree_jit::visit_multi_assignment (tree_multi_assignment&) +{ + fail (); +} + +void +tree_jit::visit_no_op_command (tree_no_op_command&) +{ + fail (); +} + +void +tree_jit::visit_constant (tree_constant& tc) +{ + octave_value v = tc.rvalue1 (); + if (v.is_real_scalar () && v.is_double_type ()) + { + double dv = v.double_value (); + Value *lv = ConstantFP::get (context, APFloat (dv)); + value_stack.push_back (lv); + } + else + fail (); +} + +void +tree_jit::visit_fcn_handle (tree_fcn_handle&) +{ + fail (); +} + +void +tree_jit::visit_parameter_list (tree_parameter_list&) +{ + fail (); +} + +void +tree_jit::visit_postfix_expression (tree_postfix_expression&) +{ + fail (); +} + +void +tree_jit::visit_prefix_expression (tree_prefix_expression&) +{ + fail (); +} + +void +tree_jit::visit_return_command (tree_return_command&) +{ + fail (); +} + +void +tree_jit::visit_return_list (tree_return_list&) +{ + fail (); +} + +void +tree_jit::visit_simple_assignment (tree_simple_assignment& tsa) +{ + // only support an identifier as lhs + tree_identifier *lhs = dynamic_cast<tree_identifier*> (tsa.left_hand_side ()); + if (!lhs) + fail (); + + variable_info lhsv = find (lhs->name (), false); + + // resolve rhs as normal + tree_expression *rhs = tsa.right_hand_side (); + rhs->accept (*this); + + Value *rhsv = value_stack.back (); + value_stack.pop_back (); + + do_assign (lhsv, rhsv); + + if (tsa.print_result ()) + emit_print (lhs->name (), rhsv); +} + +void +tree_jit::visit_statement (tree_statement& stmt) +{ + tree_command *cmd = stmt.command (); + tree_expression *expr = stmt.expression (); + + if (cmd) + cmd->accept (*this); + else + { + // TODO deal with printing + + // stolen from tree_evaluator::visit_statement + bool do_bind_ans = false; + + if (expr->is_identifier ()) + { + tree_identifier *id = dynamic_cast<tree_identifier *> (expr); + + do_bind_ans = (! id->is_variable ()); + } + else + do_bind_ans = (! expr->is_assignment_expression ()); + + expr->accept (*this); + + if (do_bind_ans) + { + Value *rhs = value_stack.back (); + value_stack.pop_back (); + + variable_info ans = find ("ans", false); + do_assign (ans, rhs); + } + else if (expr->is_identifier () && expr->print_result ()) + { + // FIXME: ugly hack, we need to come up with a way to pass + // nargout to visit_identifier + emit_print (expr->name (), value_stack.back ()); + } + + + value_stack.pop_back (); + } +} + +void +tree_jit::visit_statement_list (tree_statement_list&) +{ + fail (); +} + +void +tree_jit::visit_switch_case (tree_switch_case&) +{ + fail (); +} + +void +tree_jit::visit_switch_case_list (tree_switch_case_list&) +{ + fail (); +} + +void +tree_jit::visit_switch_command (tree_switch_command&) +{ + fail (); +} + +void +tree_jit::visit_try_catch_command (tree_try_catch_command&) +{ + fail (); +} + +void +tree_jit::visit_unwind_protect_command (tree_unwind_protect_command&) +{ + fail (); +} + +void +tree_jit::visit_while_command (tree_while_command&) +{ + fail (); +} + +void +tree_jit::visit_do_until_command (tree_do_until_command&) +{ + fail (); +} + +void +tree_jit::fail (void) +{ + throw jit_fail_exception (); +} + +tree_jit::function_info::function_info (void) : function (0) +{} + +tree_jit::function_info::function_info (jit_function fun, + const std::vector<std::string>& args, + const std::vector<bool>& arg_used) : + function (fun), arguments (args), argument_used (arg_used) +{} + +bool tree_jit::function_info::execute () +{ + if (! function) + return false; + + // FIXME: we are doing hash lookups every time, this has got to be slow + unwind_protect up; + bool *args_defined = new bool[arguments.size ()]; // vector<bool> sucks + up.add_delete (args_defined); + + std::vector<double> args_values (arguments.size ()); + for (size_t i = 0; i < arguments.size (); ++i) + { + octave_value ov = symbol_table::varval (arguments[i]); + + if (argument_used[i]) + { + if (! (ov.is_double_type () && ov.is_real_scalar ())) + return false; + + args_defined[i] = ov.is_defined (); + args_values[i] = ov.double_value (); + } + else + args_defined[i] = false; + } + + function (args_defined, &args_values[0]); + + for (size_t i = 0; i < arguments.size (); ++i) + if (args_defined[i]) + symbol_table::varref (arguments[i]) = octave_value (args_values[i]); + + return true; +}