Mercurial > hg > octave-nkf
changeset 14903:54ea692b8ab5
Reworking JIT implementation
src/TEMPLATE-INST/Array-jit.cc: New file.
src/TEMPLATE-INST/module.mk: Add Array-jit.cc.
src/ov-base.h (octave_base_value::grab,
octave_base_value::release): New functions.
src/pt-jit.cc: Rewrite.
src/pt-jit.h: Rewrite.
author | Max Brister <max@2bass.com> |
---|---|
date | Sat, 12 May 2012 19:24:32 -0600 |
parents | a21bbb5f34d4 |
children | 3513df68d580 |
files | src/TEMPLATE-INST/Array-jit.cc src/TEMPLATE-INST/module.mk src/ov-base.h src/pt-jit.cc src/pt-jit.h |
diffstat | 5 files changed, 1659 insertions(+), 561 deletions(-) [+] |
line wrap: on
line diff
new file mode 100644 --- /dev/null +++ b/src/TEMPLATE-INST/Array-jit.cc @@ -0,0 +1,34 @@ +/* + +Copyright (C) 2012 Max Brister <max@2bass.com> + +This file is part of Octave. + +Octave is free software; you can redistribute it and/or modify it +under the terms of the GNU General Public License as published by the +Free Software Foundation; either version 3 of the License, or (at your +option) any later version. + +Octave is distributed in the hope that it will be useful, but WITHOUT +ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License +for more details. + +You should have received a copy of the GNU General Public License +along with Octave; see the file COPYING. If not, see +<http://www.gnu.org/licenses/>. + +*/ + +#ifdef HAVE_CONFIG_H +#include <config.h> +#endif + +#include "Array.h" +#include "Array.cc" + +#include "pt-jit.h" + +NO_INSTANTIATE_ARRAY_SORT (jit_function::overload); + +INSTANTIATE_ARRAY (jit_function::overload, OCTINTERP_API);
--- a/src/TEMPLATE-INST/module.mk +++ b/src/TEMPLATE-INST/module.mk @@ -2,4 +2,5 @@ TEMPLATE_INST_SRC = \ TEMPLATE-INST/Array-os.cc \ - TEMPLATE-INST/Array-tc.cc + TEMPLATE-INST/Array-tc.cc \ + TEMPLATE-INST/Array-jit.cc
--- a/src/ov-base.h +++ b/src/ov-base.h @@ -755,6 +755,21 @@ virtual bool fast_elem_insert_self (void *where, builtin_type_t btyp) const; + // Grab the reference count. For use by jit. + void + grab (void) + { + ++count; + } + + // Release the reference count. For use by jit. + void + release (void) + { + if (--count == 0) + delete this; + } + protected: // This should only be called for derived types.
--- a/src/pt-jit.cc +++ b/src/pt-jit.cc @@ -40,525 +40,1185 @@ #include <llvm/ExecutionEngine/JIT.h> #include <llvm/PassManager.h> #include <llvm/Analysis/Verifier.h> +#include <llvm/Analysis/CallGraph.h> #include <llvm/Analysis/Passes.h> #include <llvm/Target/TargetData.h> #include <llvm/Transforms/Scalar.h> +#include <llvm/Transforms/IPO.h> #include <llvm/Support/TargetSelect.h> #include <llvm/Support/raw_os_ostream.h> +#include <llvm/ExecutionEngine/GenericValue.h> +#include "octave.h" #include "ov-fcn-handle.h" #include "ov-usr-fcn.h" #include "pt-all.h" -using namespace llvm; +// FIXME: Remove eventually +// For now we leave this in so people tell when JIT actually happens +static const bool debug_print = false; //FIXME: Move into tree_jit -static IRBuilder<> builder (getGlobalContext ()); +static llvm::IRBuilder<> builder (llvm::getGlobalContext ()); + +// function that jit code calls +extern "C" void +octave_jit_print_any (const char *name, octave_base_value *obv) +{ + obv->print_with_name (octave_stdout, name, true); +} extern "C" void -octave_print_double (const char *name, double value) +octave_jit_print_double (const char *name, double value) { // FIXME: We should avoid allocating a new octave_scalar each time octave_value ov (value); ov.print_with_name (octave_stdout, name); } -tree_jit::tree_jit (void) : context (getGlobalContext ()), engine (0) +extern "C" octave_base_value* +octave_jit_binary_any_any (octave_value::binary_op op, octave_base_value *lhs, + octave_base_value *rhs) +{ + octave_value olhs (lhs, true); + octave_value orhs (rhs, true); + octave_value result = do_binary_op (op, olhs, orhs); + octave_base_value *rep = result.internal_rep (); + rep->grab (); + return rep; +} + +extern "C" void +octave_jit_assign_any_any_help (octave_base_value *lhs, octave_base_value *rhs) +{ + if (lhs != rhs) + { + rhs->grab (); + lhs->release (); + } +} + +// -------------------- jit_type -------------------- +llvm::Type * +jit_type::to_llvm_arg (void) const +{ + return llvm_type ? llvm_type->getPointerTo () : 0; +} + +// -------------------- jit_function -------------------- +void +jit_function::add_overload (const overload& func, + const std::vector<jit_type*>& args) +{ + if (args.size () >= overloads.size ()) + overloads.resize (args.size () + 1); + + Array<overload>& over = overloads[args.size ()]; + dim_vector dv (over.dims ()); + Array<octave_idx_type> idx = to_idx (args); + bool must_resize = false; + + if (dv.length () != idx.numel ()) + { + dv.resize (idx.numel ()); + must_resize = true; + } + + for (octave_idx_type i = 0; i < dv.length (); ++i) + if (dv(i) <= idx(i)) + { + must_resize = true; + dv(i) = idx(i) + 1; + } + + if (must_resize) + over.resize (dv); + + over(idx) = func; +} + +const jit_function::overload& +jit_function::get_overload (const std::vector<jit_type*>& types) const +{ + // FIXME: We should search for the next best overload on failure + static overload null_overload; + if (types.size () >= overloads.size ()) + return null_overload; + + const Array<overload>& over = overloads[types.size ()]; + dim_vector dv (over.dims ()); + Array<octave_idx_type> idx = to_idx (types); + for (octave_idx_type i = 0; i < dv.length (); ++i) + if (idx(i) >= dv(i)) + return null_overload; + + return over(idx); +} + +Array<octave_idx_type> +jit_function::to_idx (const std::vector<jit_type*>& types) const +{ + octave_idx_type numel = types.size (); + if (numel == 1) + numel = 2; + + Array<octave_idx_type> idx (dim_vector (1, numel)); + for (octave_idx_type i = 0; i < static_cast<octave_idx_type> (types.size ()); + ++i) + idx(i) = types[i]->type_id (); + + if (types.size () == 1) + { + idx(1) = idx(0); + idx(0) = 0; + } + + return idx; +} + +// -------------------- jit_typeinfo -------------------- +jit_typeinfo::jit_typeinfo (llvm::Module *m, llvm::ExecutionEngine *e, llvm::Type *ov) + : module (m), engine (e), next_id (0), ov_t (ov) { - InitializeNativeTarget (); - InitializeNativeTargetAsmPrinter (); - module = new Module ("octave", context); + // FIXME: We should be registering types like in octave_value_typeinfo + llvm::LLVMContext &ctx = m->getContext (); + + // create types + any = new_type ("any", true, 0, ov_t); + scalar = new_type ("scalar", false, any, llvm::Type::getDoubleTy (ctx)); + + // any with anything is an any op + llvm::IRBuilder<> fn_builder (ctx); + + llvm::Type *binary_op_type + = llvm::Type::getIntNTy (ctx, sizeof (octave_value::binary_op)); + std::vector<llvm::Type*> args (3); + args[0] = binary_op_type; + args[1] = args[2] = any->to_llvm (); + llvm::FunctionType *any_binary_t = llvm::FunctionType::get (ov_t, args, false); + llvm::Function *any_binary = llvm::Function::Create (any_binary_t, + llvm::Function::ExternalLinkage, + "octave_jit_binary_any_any", + module); + engine->addGlobalMapping (any_binary, + reinterpret_cast<void*>(&octave_jit_binary_any_any)); + + args.resize (2); + args[0] = any->to_llvm (); + args[1] = any->to_llvm (); + + binary_ops.resize (octave_value::num_binary_ops); + for (int op = 0; op < octave_value::num_binary_ops; ++op) + { + llvm::FunctionType *ftype = llvm::FunctionType::get (ov_t, args, false); + + llvm::Twine fn_name ("octave_jit_binary_any_any_"); + fn_name = fn_name + llvm::Twine (op); + llvm::Function *fn = llvm::Function::Create (ftype, + llvm::Function::ExternalLinkage, + fn_name, module); + llvm::BasicBlock *block = llvm::BasicBlock::Create (ctx, "body", fn); + fn_builder.SetInsertPoint (block); + llvm::APInt op_int(sizeof (octave_value::binary_op), op, + std::numeric_limits<octave_value::binary_op>::is_signed); + llvm::Value *op_as_llvm = llvm::ConstantInt::get (binary_op_type, op_int); + llvm::Value *ret = fn_builder.CreateCall3 (any_binary, + op_as_llvm, + fn->arg_begin (), + ++fn->arg_begin ()); + fn_builder.CreateRet (ret); + + jit_function::overload overload (fn, true, any, any, any); + for (octave_idx_type i = 0; i < next_id; ++i) + binary_ops[op].add_overload (overload); + } + + // assign any = any + llvm::Type *tvoid = llvm::Type::getVoidTy (ctx); + args.resize (2); + args[0] = any->to_llvm (); + args[1] = any->to_llvm (); + llvm::FunctionType *ft = llvm::FunctionType::get (tvoid, args, false); + llvm::Function *fn_help = llvm::Function::Create (ft, llvm::Function::ExternalLinkage, + "octave_jit_assign_any_any_help", + module); + engine->addGlobalMapping (fn_help, + reinterpret_cast<void*>(&octave_jit_assign_any_any_help)); + + args.resize (2); + args[0] = any->to_llvm_arg (); + args[1] = any->to_llvm (); + ft = llvm::FunctionType::get (tvoid, args, false); + llvm::Function *fn = llvm::Function::Create (ft, llvm::Function::ExternalLinkage, + "octave_jit_assign_any_any", + module); + fn->addFnAttr (llvm::Attribute::AlwaysInline); + llvm::BasicBlock *body = llvm::BasicBlock::Create (ctx, "body", fn); + fn_builder.SetInsertPoint (body); + llvm::Value *value = fn_builder.CreateLoad (fn->arg_begin ()); + fn_builder.CreateCall2 (fn_help, value, ++fn->arg_begin ()); + fn_builder.CreateStore (++fn->arg_begin (), fn->arg_begin ()); + fn_builder.CreateRetVoid (); + llvm::verifyFunction (*fn); + assign_fn.add_overload (fn, false, 0, any, any); + + // assign scalar = scalar + args.resize (2); + args[0] = scalar->to_llvm_arg (); + args[1] = scalar->to_llvm (); + ft = llvm::FunctionType::get (tvoid, args, false); + fn = llvm::Function::Create (ft, llvm::Function::ExternalLinkage, + "octave_jit_assign_scalar_scalar", module); + fn->addFnAttr (llvm::Attribute::AlwaysInline); + body = llvm::BasicBlock::Create (ctx, "body", fn); + fn_builder.SetInsertPoint (body); + fn_builder.CreateStore (++fn->arg_begin (), fn->arg_begin ()); + fn_builder.CreateRetVoid (); + llvm::verifyFunction (*fn); + assign_fn.add_overload (fn, false, 0, scalar, scalar); + + // now for binary scalar operations + // FIXME: Finish all operations + add_binary_op (scalar, octave_value::op_add, llvm::Instruction::FAdd); + add_binary_op (scalar, octave_value::op_sub, llvm::Instruction::FSub); + add_binary_op (scalar, octave_value::op_mul, llvm::Instruction::FMul); + add_binary_op (scalar, octave_value::op_el_mul, llvm::Instruction::FMul); + + // FIXME: Warn if rhs is zero + add_binary_op (scalar, octave_value::op_div, llvm::Instruction::FDiv); + add_binary_op (scalar, octave_value::op_el_div, llvm::Instruction::FDiv); + + // now for printing functions + add_print (any, reinterpret_cast<void*> (&octave_jit_print_any)); + add_print (scalar, reinterpret_cast<void*> (&octave_jit_print_double)); +} + +void +jit_typeinfo::add_print (jit_type *ty, void *call) +{ + llvm::LLVMContext& ctx = llvm::getGlobalContext (); + llvm::Type *tvoid = llvm::Type::getVoidTy (ctx); + std::vector<llvm::Type *> args (2); + args[0] = llvm::Type::getInt8PtrTy (ctx); + args[1] = ty->to_llvm (); + + std::stringstream name; + name << "octave_jit_print_" << ty->name (); + + llvm::FunctionType *print_ty = llvm::FunctionType::get (tvoid, args, false); + llvm::Function *fn = llvm::Function::Create (print_ty, + llvm::Function::ExternalLinkage, + name.str (), module); + engine->addGlobalMapping (fn, call); + + jit_function::overload ol (fn, false, 0, ty); + print_fn.add_overload (ol); +} + +void +jit_typeinfo::add_binary_op (jit_type *ty, int op, int llvm_op) +{ + llvm::LLVMContext& ctx = llvm::getGlobalContext (); + std::vector<llvm::Type *> args (2, ty->to_llvm ()); + llvm::FunctionType *ft = llvm::FunctionType::get (ty->to_llvm (), args, + false); + + std::stringstream fname; + octave_value::binary_op ov_op = static_cast<octave_value::binary_op>(op); + fname << "octave_jit_" << octave_value::binary_op_as_string (ov_op) + << "_" << ty->name (); + + llvm::Function *fn = llvm::Function::Create (ft, + llvm::Function::ExternalLinkage, + fname.str (), + module); + fn->addFnAttr (llvm::Attribute::AlwaysInline); + llvm::BasicBlock *block = llvm::BasicBlock::Create (ctx, "body", fn); + llvm::IRBuilder<> fn_builder (block); + llvm::Instruction::BinaryOps temp + = static_cast<llvm::Instruction::BinaryOps>(llvm_op); + llvm::Value *ret = fn_builder.CreateBinOp (temp, fn->arg_begin (), + ++fn->arg_begin ()); + fn_builder.CreateRet (ret); + llvm::verifyFunction (*fn); + + jit_function::overload ol(fn, false, ty, ty, ty); + binary_ops[op].add_overload (ol); +} + +jit_type* +jit_typeinfo::type_of (const octave_value &ov) const +{ + if (ov.is_undefined () || ov.is_function ()) + return 0; + + if (ov.is_double_type () && ov.is_real_scalar ()) + return get_scalar (); + + return get_any (); +} + +const jit_function& +jit_typeinfo::binary_op (int op) const +{ + return binary_ops[op]; +} + +const jit_function::overload& +jit_typeinfo::assign_op (jit_type *lhs, jit_type *rhs) const +{ + assert (lhs == rhs); + return assign_fn.get_overload (lhs, rhs); +} + +const jit_function::overload& +jit_typeinfo::print_value (jit_type *to_print) const +{ + return print_fn.get_overload (to_print); +} + +void +jit_typeinfo::to_generic (jit_type *type, llvm::GenericValue& gv) +{ + // duplication here can probably be removed somehow + if (type == any) + to_generic (type, gv, octave_value ()); + else if (type == scalar) + to_generic (type, gv, octave_value (0)); + else + assert (false && "Type not supported yet"); +} + +void +jit_typeinfo::to_generic (jit_type *type, llvm::GenericValue& gv, octave_value ov) +{ + if (type == any) + { + octave_base_value *obv = ov.internal_rep (); + obv->grab (); + ov_out[ov_out_idx] = obv; + gv.PointerVal = &ov_out[ov_out_idx++]; + } + else + { + scalar_out[scalar_out_idx] = ov.double_value (); + gv.PointerVal = &scalar_out[scalar_out_idx++]; + } +} + +octave_value +jit_typeinfo::to_octave_value (jit_type *type, llvm::GenericValue& gv) +{ + if (type == any) + { + octave_base_value **ptr = reinterpret_cast<octave_base_value**>(gv.PointerVal); + return octave_value (*ptr); + } + else if (type == scalar) + { + double *ptr = reinterpret_cast<double*>(gv.PointerVal); + return octave_value (*ptr); + } + else + assert (false && "Type not supported yet"); +} + +void +jit_typeinfo::reset_generic (size_t nargs) +{ + scalar_out_idx = 0; + ov_out_idx = 0; + + if (scalar_out.size () < nargs) + scalar_out.resize (nargs); + + if (ov_out.size () < nargs) + ov_out.resize (nargs); +} + +jit_type* +jit_typeinfo::new_type (const std::string& name, bool force_init, + jit_type *parent, llvm::Type *llvm_type) +{ + jit_type *ret = new jit_type (name, force_init, parent, llvm_type, next_id++); + id_to_type.push_back (ret); + return ret; +} + +tree_jit::tree_jit (void) : context (llvm::getGlobalContext ()), engine (0) +{ + llvm::InitializeNativeTarget (); + module = new llvm::Module ("octave", context); + assert (module); } tree_jit::~tree_jit (void) { - delete module; + for (compiled_map::iterator iter = compiled.begin (); iter != compiled.end (); + ++iter) + { + function_list& flist = iter->second; + for (function_list::iterator fiter = flist.begin (); fiter != flist.end (); + ++fiter) + delete *fiter; + } + + delete tinfo; } bool tree_jit::execute (tree& tee) { - if (!engine) + // something funny happens during initialization with the engine + bool need_init = false; + if (! engine) { - engine = ExecutionEngine::createJIT (module); - - // initialize pass manager - pass_manager = new FunctionPassManager (module); - pass_manager->add (new TargetData(*engine->getTargetData ())); - pass_manager->add (createBasicAliasAnalysisPass ()); - pass_manager->add (createPromoteMemoryToRegisterPass ()); - pass_manager->add (createInstructionCombiningPass ()); - pass_manager->add (createReassociatePass ()); - pass_manager->add (createGVNPass ()); - pass_manager->add (createCFGSimplificationPass ()); - pass_manager->doInitialization (); - - // create external functions - Type *vtype = Type::getVoidTy (context); - std::vector<Type*> pd_args (2); - pd_args[0] = Type::getInt8PtrTy (context); - pd_args[1] = Type::getDoubleTy (context); - FunctionType *print_double_ty = FunctionType::get (vtype, pd_args, false); - print_double = Function::Create (print_double_ty, - Function::ExternalLinkage, - "octave_print_double", module); - engine->addGlobalMapping (print_double, - reinterpret_cast<void*>(&octave_print_double)); + need_init = true; + engine = llvm::ExecutionEngine::createJIT (module); } - if (!engine) - // sometimes this fails during early initialization + if (! engine) return false; - // find function - function_info *finfo; - finfo_map_iterator iter = compiled_functions.find (&tee); + if (need_init) + { + module_pass_manager = new llvm::PassManager (); + module_pass_manager->add (llvm::createAlwaysInlinerPass ()); + + pass_manager = new llvm::FunctionPassManager (module); + pass_manager->add (new llvm::TargetData(*engine->getTargetData ())); + pass_manager->add (llvm::createBasicAliasAnalysisPass ()); + pass_manager->add (llvm::createPromoteMemoryToRegisterPass ()); + pass_manager->add (llvm::createInstructionCombiningPass ()); + pass_manager->add (llvm::createReassociatePass ()); + pass_manager->add (llvm::createGVNPass ()); + pass_manager->add (llvm::createCFGSimplificationPass ()); + pass_manager->doInitialization (); + + llvm::Type *ov_t = llvm::StructType::create (context, "octave_base_value"); + ov_t = ov_t->getPointerTo (); + + tinfo = new jit_typeinfo (module, engine, ov_t); + } + + function_list& fnlist = compiled[&tee]; + for (function_list::iterator iter = fnlist.begin (); iter != fnlist.end (); + ++iter) + { + function_info& fi = **iter; + if (fi.match ()) + return fi.execute (); + } + + function_info *fi = new function_info (*this, tee); + fnlist.push_back (fi); + + return fi->execute (); +} + +void +tree_jit::type_infer::visit_anon_fcn_handle (tree_anon_fcn_handle&) +{ + fail (); +} + +void +tree_jit::type_infer::visit_argument_list (tree_argument_list&) +{ + fail (); +} + +void +tree_jit::type_infer::visit_binary_expression (tree_binary_expression& be) +{ + if (is_lvalue) + fail (); + + tree_expression *lhs = be.lhs (); + lhs->accept (*this); + jit_type *tlhs = type_stack.back (); + type_stack.pop_back (); + + tree_expression *rhs = be.rhs (); + rhs->accept (*this); + jit_type *trhs = type_stack.back (); + + jit_type *result = tinfo->binary_op_result (be.op_type (), tlhs, trhs); + if (! result) + fail (); + + type_stack.push_back (result); +} + +void +tree_jit::type_infer::visit_break_command (tree_break_command&) +{ + fail (); +} + +void +tree_jit::type_infer::visit_colon_expression (tree_colon_expression&) +{ + fail (); +} + +void +tree_jit::type_infer::visit_continue_command (tree_continue_command&) +{ + fail (); +} + +void +tree_jit::type_infer::visit_global_command (tree_global_command&) +{ + fail (); +} + +void +tree_jit::type_infer::visit_persistent_command (tree_persistent_command&) +{ + fail (); +} + +void +tree_jit::type_infer::visit_decl_elt (tree_decl_elt&) +{ + fail (); +} - if (iter == compiled_functions.end ()) - finfo = compile (tee); - else - finfo = iter->second; +void +tree_jit::type_infer::visit_decl_init_list (tree_decl_init_list&) +{ + fail (); +} + +void +tree_jit::type_infer::visit_simple_for_command (tree_simple_for_command&) +{ + fail (); +} + +void +tree_jit::type_infer::visit_complex_for_command (tree_complex_for_command&) +{ + fail (); +} + +void +tree_jit::type_infer::visit_octave_user_script (octave_user_script&) +{ + fail (); +} + +void +tree_jit::type_infer::visit_octave_user_function (octave_user_function&) +{ + fail (); +} + +void +tree_jit::type_infer::visit_octave_user_function_header (octave_user_function&) +{ + fail (); +} + +void +tree_jit::type_infer::visit_octave_user_function_trailer (octave_user_function&) +{ + fail (); +} + +void +tree_jit::type_infer::visit_function_def (tree_function_def&) +{ + fail (); +} + +void +tree_jit::type_infer::visit_identifier (tree_identifier& ti) +{ + handle_identifier (ti.name (), ti.do_lookup ()); +} - return finfo->execute (); +void +tree_jit::type_infer::visit_if_clause (tree_if_clause&) +{ + fail (); +} + +void +tree_jit::type_infer::visit_if_command (tree_if_command&) +{ + fail (); +} + +void +tree_jit::type_infer::visit_if_command_list (tree_if_command_list&) +{ + fail (); +} + +void +tree_jit::type_infer::visit_index_expression (tree_index_expression&) +{ + fail (); +} + +void +tree_jit::type_infer::visit_matrix (tree_matrix&) +{ + fail (); +} + +void +tree_jit::type_infer::visit_cell (tree_cell&) +{ + fail (); +} + +void +tree_jit::type_infer::visit_multi_assignment (tree_multi_assignment&) +{ + fail (); +} + +void +tree_jit::type_infer::visit_no_op_command (tree_no_op_command&) +{ + fail (); +} + +void +tree_jit::type_infer::visit_constant (tree_constant& tc) +{ + if (is_lvalue) + fail (); + + octave_value v = tc.rvalue1 (); + jit_type *type = tinfo->type_of (v); + if (! type) + fail (); + + type_stack.push_back (type); } -tree_jit::function_info* -tree_jit::compile (tree& tee) +void +tree_jit::type_infer::visit_fcn_handle (tree_fcn_handle&) +{ + fail (); +} + +void +tree_jit::type_infer::visit_parameter_list (tree_parameter_list&) +{ + fail (); +} + +void +tree_jit::type_infer::visit_postfix_expression (tree_postfix_expression&) +{ + fail (); +} + +void +tree_jit::type_infer::visit_prefix_expression (tree_prefix_expression&) +{ + fail (); +} + +void +tree_jit::type_infer::visit_return_command (tree_return_command&) +{ + fail (); +} + +void +tree_jit::type_infer::visit_return_list (tree_return_list&) +{ + fail (); +} + +void +tree_jit::type_infer::visit_simple_assignment (tree_simple_assignment& tsa) +{ + if (is_lvalue) + fail (); + + // resolve rhs + is_lvalue = false; + tree_expression *rhs = tsa.right_hand_side (); + rhs->accept (*this); + + jit_type *trhs = type_stack.back (); + type_stack.pop_back (); + + // resolve lhs + is_lvalue = true; + rvalue_type = trhs; + tree_expression *lhs = tsa.left_hand_side (); + lhs->accept (*this); + + // we don't pop back here, as the resulting type should be the rhs type + // which is equal to the lhs type anways + jit_type *tlhs = type_stack.back (); + if (tlhs != trhs) + fail (); + + is_lvalue = false; + rvalue_type = 0; +} + +void +tree_jit::type_infer::visit_statement (tree_statement& stmt) { - value_stack.clear (); - variables.clear (); + if (is_lvalue) + fail (); + + tree_command *cmd = stmt.command (); + tree_expression *expr = stmt.expression (); + + if (cmd) + cmd->accept (*this); + else + { + // ok, this check for ans appears three times as cp + bool do_bind_ans = false; + + if (expr->is_identifier ()) + { + tree_identifier *id = dynamic_cast<tree_identifier *> (expr); + + do_bind_ans = (! id->is_variable ()); + } + else + do_bind_ans = (! expr->is_assignment_expression ()); + + expr->accept (*this); + + if (do_bind_ans) + { + is_lvalue = true; + rvalue_type = type_stack.back (); + type_stack.pop_back (); + handle_identifier ("ans", symbol_table::varval ("ans")); + + if (rvalue_type != type_stack.back ()) + fail (); + + is_lvalue = false; + rvalue_type = 0; + } + + type_stack.pop_back (); + } +} + +void +tree_jit::type_infer::visit_statement_list (tree_statement_list&) +{ + fail (); +} + +void +tree_jit::type_infer::visit_switch_case (tree_switch_case&) +{ + fail (); +} + +void +tree_jit::type_infer::visit_switch_case_list (tree_switch_case_list&) +{ + fail (); +} + +void +tree_jit::type_infer::visit_switch_command (tree_switch_command&) +{ + fail (); +} + +void +tree_jit::type_infer::visit_try_catch_command (tree_try_catch_command&) +{ + fail (); +} + +void +tree_jit::type_infer::visit_unwind_protect_command (tree_unwind_protect_command&) +{ + fail (); +} + +void +tree_jit::type_infer::visit_while_command (tree_while_command&) +{ + fail (); +} + +void +tree_jit::type_infer::visit_do_until_command (tree_do_until_command&) +{ + fail (); +} + +void +tree_jit::type_infer::handle_identifier (const std::string& name, octave_value v) +{ + type_map::iterator iter = types.find (name); + if (iter == types.end ()) + { + jit_type *ty = tinfo->type_of (v); + if (is_lvalue) + { + if (! ty) + ty = rvalue_type; + } + else + { + if (! ty) + fail (); + + argin.insert (name); + } + + types[name] = ty; + type_stack.push_back (ty); + } + else + type_stack.push_back (iter->second); +} + +tree_jit::code_generator::code_generator (jit_typeinfo *ti, llvm::Module *module, + tree &tee, + const std::set<std::string>& argin, + const type_map& infered_types) + : tinfo (ti), is_lvalue (false) - // setup function - std::vector<Type*> args (2); - args[0] = Type::getInt1PtrTy (context); - args[1] = Type::getDoublePtrTy (context); - FunctionType *ft = FunctionType::get (Type::getVoidTy (context), args, false); - Function *compiling = Function::Create (ft, Function::ExternalLinkage, - "test_fn", module); +{ + // determine the function type through the type of all variables + std::vector<llvm::Type *> arg_types (infered_types.size ()); + size_t idx = 0; + type_map::const_iterator iter; + for (iter = infered_types.begin (); iter != infered_types.end (); ++iter, ++idx) + arg_types[idx] = iter->second->to_llvm_arg (); + + // now create the LLVM function from our determined types + llvm::LLVMContext &ctx = llvm::getGlobalContext (); + llvm::Type *tvoid = llvm::Type::getVoidTy (ctx); + llvm::FunctionType *ft = llvm::FunctionType::get (tvoid, arg_types, false); + function = llvm::Function::Create (ft, llvm::Function::ExternalLinkage, + "foobar", module); - entry_block = BasicBlock::Create (context, "entry", compiling); - BasicBlock *body = BasicBlock::Create (context, "body", - compiling); + // declare each argument and copy its initial value + llvm::BasicBlock *body = llvm::BasicBlock::Create (ctx, "body", function); builder.SetInsertPoint (body); + llvm::Function::arg_iterator arg_iter = function->arg_begin(); + for (iter = infered_types.begin (); iter != infered_types.end (); + ++iter, ++arg_iter) + + { + llvm::Type *vartype = iter->second->to_llvm (); + llvm::Value *var = builder.CreateAlloca (vartype, 0, iter->first); + variables[iter->first] = value (iter->second, var); - // convert tree to LLVM IR + if (iter->second->force_init () || argin.count (iter->first)) + { + llvm::Value *loaded_arg = builder.CreateLoad (arg_iter); + builder.CreateStore (loaded_arg, var); + } + } + + // generate body try { tee.accept (*this); } catch (const jit_fail_exception&) { - //FIXME: cleanup - return compiled_functions[&tee] = new function_info (); - } - - // copy input arguments - builder.SetInsertPoint (entry_block); - - Function::arg_iterator arg_iter = compiling->arg_begin (); - Value *arg_defined = arg_iter; - Value *arg_value = ++arg_iter; - - arg_defined->setName ("arg_defined"); - arg_value->setName ("arg_value"); - - size_t idx = 0; - std::vector<std::string> arg_names; - std::vector<bool> arg_used; - for (var_map_iterator iter = variables.begin (); iter != variables.end (); - ++iter, ++idx) - { - arg_names.push_back (iter->first); - arg_used.push_back (iter->second.use); - - Value *gep_defined = builder.CreateConstInBoundsGEP1_32 (arg_defined, idx); - Value *defined = builder.CreateLoad (gep_defined); - builder.CreateStore (defined, iter->second.defined); - - Value *gep_value = builder.CreateConstInBoundsGEP1_32 (arg_value, idx); - Value *value = builder.CreateLoad (gep_value); - builder.CreateStore (value, iter->second.value); - } - builder.CreateBr (body); - - // copy output arguments - BasicBlock *cleanup = BasicBlock::Create (context, "cleanup", compiling); - builder.SetInsertPoint (body); - builder.CreateBr (cleanup); - builder.SetInsertPoint (cleanup); - - idx = 0; - for (var_map_iterator iter = variables.begin (); iter != variables.end (); - ++iter, ++idx) - { - Value *gep_defined = builder.CreateConstInBoundsGEP1_32 (arg_defined, idx); - Value *defined = builder.CreateLoad (iter->second.defined); - builder.CreateStore (defined, gep_defined); - - Value *gep_value = builder.CreateConstInBoundsGEP1_32 (arg_value, idx); - Value *value = builder.CreateLoad (iter->second.value, iter->first); - builder.CreateStore (value, gep_value); + function->eraseFromParent (); + function = 0; + return; } - builder.CreateRetVoid (); - - // print what we compiled (for debugging) - // we leave this in for now, as other people might want to view the ir created - // should be removed eventually though - const bool debug_print_ir = false; - if (debug_print_ir) + // copy computed values back into arguments + arg_iter = function->arg_begin (); + for (iter = infered_types.begin (); iter != infered_types.end (); + ++iter, ++arg_iter) { - raw_os_ostream os (std::cout); - std::cout << "Compiling --------------------\n"; - tree_print_code tpc (std::cout); - std::cout << typeid (tee).name () << std::endl; - tee.accept (tpc); - std::cout << "\n--------------------\n"; - - std::cout << "llvm_ir\n"; - compiling->print (os); - std::cout << "--------------------\n"; + llvm::Value *var = variables[iter->first].second; + llvm::Value *loaded_var = builder.CreateLoad (var); + builder.CreateStore (loaded_var, arg_iter); } - - // compile code - verifyFunction (*compiling); - pass_manager->run (*compiling); - - if (debug_print_ir) - { - raw_os_ostream os (std::cout); - std::cout << "optimized llvm_ir\n"; - compiling->print (os); - std::cout << "--------------------\n"; - } - - jit_function fun = - reinterpret_cast<jit_function> (engine->getPointerToFunction (compiling)); - - return compiled_functions[&tee] = new function_info (fun, arg_names, arg_used); + builder.CreateRetVoid (); } -tree_jit::variable_info -tree_jit::find (const std::string &name, bool use) +void +tree_jit::code_generator::visit_anon_fcn_handle (tree_anon_fcn_handle&) { - var_map_iterator iter = variables.find (name); - if (iter == variables.end ()) - { - // we currently just assume everything is a double - Type *dbl = Type::getDoubleTy (context); - Type *bol = Type::getInt1Ty (context); - IRBuilder<> tmpB (entry_block, entry_block->begin ()); + fail (); +} - variable_info vinfo; - vinfo.defined = tmpB.CreateAlloca (bol, 0); - vinfo.value = tmpB.CreateAlloca (dbl, 0, name); - vinfo.use = use; - variables[name] = vinfo; - return vinfo; - } - else - { - iter->second.use = iter->second.use || use; - return iter->second; - } +void +tree_jit::code_generator::visit_argument_list (tree_argument_list&) +{ + fail (); } void -tree_jit::do_assign (variable_info vinfo, llvm::Value *value) +tree_jit::code_generator::visit_binary_expression (tree_binary_expression& be) { - // create assign expression - Value *result = builder.CreateStore (value, vinfo.value); - value_stack.push_back (result); + tree_expression *lhs = be.lhs (); + lhs->accept (*this); + value lhsv = value_stack.back (); + value_stack.pop_back (); + + tree_expression *rhs = be.rhs (); + rhs->accept (*this); + value rhsv = value_stack.back (); + value_stack.pop_back (); + + const jit_function::overload& ol + = tinfo->binary_op_overload (be.op_type (), lhsv.first, rhsv.first); - // update defined for lhs - Type *btype = Type::getInt1Ty (context); - Value *btrue = ConstantInt::get (btype, APInt (1, 1)); - builder.CreateStore (btrue, vinfo.defined); + if (! ol.function) + fail (); + + llvm::Value *result = builder.CreateCall2 (ol.function, lhsv.second, + rhsv.second); + push_value (ol.result, result); +} + +void +tree_jit::code_generator::visit_break_command (tree_break_command&) +{ + fail (); } void -tree_jit::emit_print (const std::string& vname, llvm::Value *value) +tree_jit::code_generator::visit_colon_expression (tree_colon_expression&) +{ + fail (); +} + +void +tree_jit::code_generator::visit_continue_command (tree_continue_command&) { - Value *pname = builder.CreateGlobalStringPtr (vname); - builder.CreateCall2 (print_double, pname, value); + fail (); +} + +void +tree_jit::code_generator::visit_global_command (tree_global_command&) +{ + fail (); } void -tree_jit::visit_anon_fcn_handle (tree_anon_fcn_handle&) +tree_jit::code_generator::visit_persistent_command (tree_persistent_command&) +{ + fail (); +} + +void +tree_jit::code_generator::visit_decl_elt (tree_decl_elt&) +{ + fail (); +} + +void +tree_jit::code_generator::visit_decl_init_list (tree_decl_init_list&) +{ + fail (); +} + +void +tree_jit::code_generator::visit_simple_for_command (tree_simple_for_command&) { fail (); } void -tree_jit::visit_argument_list (tree_argument_list&) +tree_jit::code_generator::visit_complex_for_command (tree_complex_for_command&) +{ + fail (); +} + +void +tree_jit::code_generator::visit_octave_user_script (octave_user_script&) +{ + fail (); +} + +void +tree_jit::code_generator::visit_octave_user_function (octave_user_function&) +{ + fail (); +} + +void +tree_jit::code_generator::visit_octave_user_function_header (octave_user_function&) +{ + fail (); +} + +void +tree_jit::code_generator::visit_octave_user_function_trailer (octave_user_function&) +{ + fail (); +} + +void +tree_jit::code_generator::visit_function_def (tree_function_def&) { fail (); } void -tree_jit::visit_binary_expression (tree_binary_expression& be) +tree_jit::code_generator::visit_identifier (tree_identifier& ti) { - tree_expression *lhs = be.lhs (); - tree_expression *rhs = be.rhs (); - if (lhs && rhs) + std::string name = ti.name (); + value variable = variables[name]; + if (is_lvalue) + value_stack.push_back (variable); + else { - lhs->accept (*this); - rhs->accept (*this); + llvm::Value *load = builder.CreateLoad (variable.second, name); + push_value (variable.first, load); + } +} + +void +tree_jit::code_generator::visit_if_clause (tree_if_clause&) +{ + fail (); +} - Value *lhsv = value_stack.back (); - value_stack.pop_back (); +void +tree_jit::code_generator::visit_if_command (tree_if_command&) +{ + fail (); +} - Value *rhsv = value_stack.back (); - value_stack.pop_back (); +void +tree_jit::code_generator::visit_if_command_list (tree_if_command_list&) +{ + fail (); +} - Value *result; - switch (be.op_type ()) - { - case octave_value::op_add: - result = builder.CreateFAdd (lhsv, rhsv); - break; - case octave_value::op_sub: - result = builder.CreateFSub (lhsv, rhsv); - break; - case octave_value::op_mul: - result = builder.CreateFMul (lhsv, rhsv); - break; - case octave_value::op_div: - result = builder.CreateFDiv (lhsv, rhsv); - break; - default: - fail (); - } +void +tree_jit::code_generator::visit_index_expression (tree_index_expression&) +{ + fail (); +} + +void +tree_jit::code_generator::visit_matrix (tree_matrix&) +{ + fail (); +} + +void +tree_jit::code_generator::visit_cell (tree_cell&) +{ + fail (); +} - value_stack.push_back (result); +void +tree_jit::code_generator::visit_multi_assignment (tree_multi_assignment&) +{ + fail (); +} + +void +tree_jit::code_generator::visit_no_op_command (tree_no_op_command&) +{ + fail (); +} + +void +tree_jit::code_generator::visit_constant (tree_constant& tc) +{ + octave_value v = tc.rvalue1 (); + if (v.is_real_scalar () && v.is_double_type ()) + { + llvm::LLVMContext& ctx = llvm::getGlobalContext (); + double dv = v.double_value (); + llvm::Value *lv = llvm::ConstantFP::get (ctx, llvm::APFloat (dv)); + push_value (tinfo->get_scalar (), lv); } else fail (); } void -tree_jit::visit_break_command (tree_break_command&) -{ - fail (); -} - -void -tree_jit::visit_colon_expression (tree_colon_expression&) -{ - fail (); -} - -void -tree_jit::visit_continue_command (tree_continue_command&) +tree_jit::code_generator::visit_fcn_handle (tree_fcn_handle&) { fail (); } void -tree_jit::visit_global_command (tree_global_command&) -{ - fail (); -} - -void -tree_jit::visit_persistent_command (tree_persistent_command&) -{ - fail (); -} - -void -tree_jit::visit_decl_elt (tree_decl_elt&) -{ - fail (); -} - -void -tree_jit::visit_decl_init_list (tree_decl_init_list&) -{ - fail (); -} - -void -tree_jit::visit_simple_for_command (tree_simple_for_command&) +tree_jit::code_generator::visit_parameter_list (tree_parameter_list&) { fail (); } void -tree_jit::visit_complex_for_command (tree_complex_for_command&) -{ - fail (); -} - -void -tree_jit::visit_octave_user_script (octave_user_script&) -{ - fail (); -} - -void -tree_jit::visit_octave_user_function (octave_user_function&) -{ - fail (); -} - -void -tree_jit::visit_octave_user_function_header (octave_user_function&) +tree_jit::code_generator::visit_postfix_expression (tree_postfix_expression&) { fail (); } void -tree_jit::visit_octave_user_function_trailer (octave_user_function&) -{ - fail (); -} - -void -tree_jit::visit_function_def (tree_function_def&) +tree_jit::code_generator::visit_prefix_expression (tree_prefix_expression&) { fail (); } void -tree_jit::visit_identifier (tree_identifier& ti) -{ - octave_value ov = ti.do_lookup (); - if (ov.is_function ()) - fail (); - - std::string name = ti.name (); - variable_info vinfo = find (ti.name (), true); - - // TODO check defined - - Value *load_value = builder.CreateLoad (vinfo.value, name); - value_stack.push_back (load_value); -} - -void -tree_jit::visit_if_clause (tree_if_clause&) +tree_jit::code_generator::visit_return_command (tree_return_command&) { fail (); } void -tree_jit::visit_if_command (tree_if_command&) -{ - fail (); -} - -void -tree_jit::visit_if_command_list (tree_if_command_list&) -{ - fail (); -} - -void -tree_jit::visit_index_expression (tree_index_expression&) -{ - fail (); -} - -void -tree_jit::visit_matrix (tree_matrix&) -{ - fail (); -} - -void -tree_jit::visit_cell (tree_cell&) -{ - fail (); -} - -void -tree_jit::visit_multi_assignment (tree_multi_assignment&) -{ - fail (); -} - -void -tree_jit::visit_no_op_command (tree_no_op_command&) +tree_jit::code_generator::visit_return_list (tree_return_list&) { fail (); } void -tree_jit::visit_constant (tree_constant& tc) +tree_jit::code_generator::visit_simple_assignment (tree_simple_assignment& tsa) { - octave_value v = tc.rvalue1 (); - if (v.is_real_scalar () && v.is_double_type ()) - { - double dv = v.double_value (); - Value *lv = ConstantFP::get (context, APFloat (dv)); - value_stack.push_back (lv); - } - else + if (is_lvalue) fail (); -} + + // resolve lhs + is_lvalue = true; + tree_expression *lhs = tsa.left_hand_side (); + lhs->accept (*this); + + value lhsv = value_stack.back (); + value_stack.pop_back (); -void -tree_jit::visit_fcn_handle (tree_fcn_handle&) -{ - fail (); -} + // resolve rhs + is_lvalue = false; + tree_expression *rhs = tsa.right_hand_side (); + rhs->accept (*this); + + value rhsv = value_stack.back (); + value_stack.pop_back (); -void -tree_jit::visit_parameter_list (tree_parameter_list&) -{ - fail (); -} + // do assign, then store rhs as the result + jit_function::overload ol = tinfo->assign_op (lhsv.first, rhsv.first); + builder.CreateCall2 (ol.function, lhsv.second, rhsv.second); -void -tree_jit::visit_postfix_expression (tree_postfix_expression&) -{ - fail (); + if (tsa.print_result ()) + emit_print (lhs->name (), rhsv); + + value_stack.push_back (rhsv); } void -tree_jit::visit_prefix_expression (tree_prefix_expression&) -{ - fail (); -} - -void -tree_jit::visit_return_command (tree_return_command&) -{ - fail (); -} - -void -tree_jit::visit_return_list (tree_return_list&) -{ - fail (); -} - -void -tree_jit::visit_simple_assignment (tree_simple_assignment& tsa) -{ - // only support an identifier as lhs - tree_identifier *lhs = dynamic_cast<tree_identifier*> (tsa.left_hand_side ()); - if (!lhs) - fail (); - - variable_info lhsv = find (lhs->name (), false); - - // resolve rhs as normal - tree_expression *rhs = tsa.right_hand_side (); - rhs->accept (*this); - - Value *rhsv = value_stack.back (); - value_stack.pop_back (); - - do_assign (lhsv, rhsv); - - if (tsa.print_result ()) - emit_print (lhs->name (), rhsv); -} - -void -tree_jit::visit_statement (tree_statement& stmt) +tree_jit::code_generator::visit_statement (tree_statement& stmt) { tree_command *cmd = stmt.command (); tree_expression *expr = stmt.expression (); @@ -567,8 +1227,6 @@ cmd->accept (*this); else { - // TODO deal with printing - // stolen from tree_evaluator::visit_statement bool do_bind_ans = false; @@ -585,11 +1243,15 @@ if (do_bind_ans) { - Value *rhs = value_stack.back (); - value_stack.pop_back (); + value rhs = value_stack.back (); + value ans = variables["ans"]; + if (ans.first != rhs.first) + fail (); - variable_info ans = find ("ans", false); - do_assign (ans, rhs); + builder.CreateStore (rhs.second, ans.second); + + if (expr->print_result ()) + emit_print ("ans", rhs); } else if (expr->is_identifier () && expr->print_result ()) { @@ -604,100 +1266,151 @@ } void -tree_jit::visit_statement_list (tree_statement_list&) +tree_jit::code_generator::visit_statement_list (tree_statement_list&) { fail (); } void -tree_jit::visit_switch_case (tree_switch_case&) +tree_jit::code_generator::visit_switch_case (tree_switch_case&) { fail (); } void -tree_jit::visit_switch_case_list (tree_switch_case_list&) +tree_jit::code_generator::visit_switch_case_list (tree_switch_case_list&) { fail (); } void -tree_jit::visit_switch_command (tree_switch_command&) +tree_jit::code_generator::visit_switch_command (tree_switch_command&) { fail (); } void -tree_jit::visit_try_catch_command (tree_try_catch_command&) +tree_jit::code_generator::visit_try_catch_command (tree_try_catch_command&) { fail (); } void -tree_jit::visit_unwind_protect_command (tree_unwind_protect_command&) +tree_jit::code_generator::visit_unwind_protect_command (tree_unwind_protect_command&) { fail (); } void -tree_jit::visit_while_command (tree_while_command&) +tree_jit::code_generator::visit_while_command (tree_while_command&) { fail (); } void -tree_jit::visit_do_until_command (tree_do_until_command&) +tree_jit::code_generator::visit_do_until_command (tree_do_until_command&) { fail (); } void -tree_jit::fail (void) +tree_jit::code_generator::emit_print (const std::string& name, const value& v) { - throw jit_fail_exception (); + const jit_function::overload& ol = tinfo->print_value (v.first); + if (! ol.function) + fail (); + + llvm::Value *str = builder.CreateGlobalStringPtr (name); + builder.CreateCall2 (ol.function, str, v.second); } -tree_jit::function_info::function_info (void) : function (0) -{} +tree_jit::function_info::function_info (tree_jit& tjit, tree& tee) : + tinfo (tjit.tinfo), engine (tjit.engine) +{ + type_infer infer(tjit.tinfo); + + try + { + tee.accept (infer); + } + catch (const jit_fail_exception&) + { + function = 0; + return; + } + + argin = infer.get_argin (); + types = infer.get_types (); + + code_generator gen(tjit.tinfo, tjit.module, tee, argin, types); + function = gen.get_function (); -tree_jit::function_info::function_info (jit_function fun, - const std::vector<std::string>& args, - const std::vector<bool>& arg_used) : - function (fun), arguments (args), argument_used (arg_used) -{} + if (function) + { + llvm::verifyFunction (*function); + tjit.module_pass_manager->run (*tjit.module); + tjit.pass_manager->run (*function); + + if (debug_print) + { + std::cout << "Compiled:\n"; + std::cout << tee.str_print_code () << std::endl; -bool tree_jit::function_info::execute () + std::cout << "Code:\n"; + + llvm::raw_os_ostream os (std::cout); + function->print (os); + } + } +} + +bool +tree_jit::function_info::execute () const { if (! function) return false; - // FIXME: we are doing hash lookups every time, this has got to be slow - unwind_protect up; - bool *args_defined = new bool[arguments.size ()]; // vector<bool> sucks - up.add_delete (args_defined); + tinfo->reset_generic (types.size ()); - std::vector<double> args_values (arguments.size ()); - for (size_t i = 0; i < arguments.size (); ++i) + std::vector<llvm::GenericValue> args (types.size ()); + size_t idx; + type_map::const_iterator iter; + for (idx = 0, iter = types.begin (); iter != types.end (); ++iter, ++idx) { - octave_value ov = symbol_table::varval (arguments[i]); - - if (argument_used[i]) + if (argin.count (iter->first)) { - if (! (ov.is_double_type () && ov.is_real_scalar ())) - return false; - - args_defined[i] = ov.is_defined (); - args_values[i] = ov.double_value (); + octave_value ov = symbol_table::varval (iter->first); + tinfo->to_generic (iter->second, args[idx], ov); } else - args_defined[i] = false; + tinfo->to_generic (iter->second, args[idx]); } - function (args_defined, &args_values[0]); + engine->runFunction (function, args); - for (size_t i = 0; i < arguments.size (); ++i) - if (args_defined[i]) - symbol_table::varref (arguments[i]) = octave_value (args_values[i]); + for (idx = 0, iter = types.begin (); iter != types.end (); ++iter, ++idx) + { + octave_value result = tinfo->to_octave_value (iter->second, args[idx]); + symbol_table::varref (iter->first) = result; + } return true; } + +bool +tree_jit::function_info::match () const +{ + for (std::set<std::string>::iterator iter = argin.begin (); + iter != argin.end (); ++iter) + { + jit_type *required_type = types.find (*iter)->second; + octave_value val = symbol_table::varref (*iter); + jit_type *current_type = tinfo->type_of (val); + + // FIXME: should be: ! required_type->is_parent (current_type) + if (required_type != current_type) + return false; + } + + return true; +}
--- a/src/pt-jit.h +++ b/src/pt-jit.h @@ -23,31 +23,247 @@ #if !defined (octave_tree_jit_h) #define octave_tree_jit_h 1 +#include <list> #include <map> +#include <set> #include <stdexcept> #include <vector> +#include "Array.h" #include "pt-walk.h" -class jit_fail_exception : public std::exception {}; +// -------------------- Current status -------------------- +// Simple binary operations (+-*/) on octave_scalar's (doubles) are optimized. +// However, there is no warning emitted on divide by 0. For example, +// a = 5; +// b = a * 5 + a; +// +// For other types all binary operations are compiled but not optimized. For +// example, +// a = [1 2 3] +// b = a + a; +// will compile to do_binary_op (a, a). +// --------------------------------------------------------- -// LLVM forward declares + +// we don't want to include llvm headers here, as they require __STDC_LIMIT_MACROS +// and __STDC_CONSTANT_MACROS be defined in the entire compilation unit namespace llvm { class Value; class Module; class FunctionPassManager; + class PassManager; class ExecutionEngine; class Function; class BasicBlock; class LLVMContext; + class Type; + class GenericValue; } +class octave_base_value; +class octave_value; class tree; +// thrown when we should give up on JIT and interpret +class jit_fail_exception : public std::exception {}; + +// Used to keep track of estimated (infered) types during JIT. This is a +// hierarchical type system which includes both concrete and abstract types. +// +// Current, we only support any and scalar types. If we can't figure out what +// type a variable is, we assign it the any type. This allows us to generate +// code even for the case of poor type inference. +class +OCTINTERP_API +jit_type +{ +public: + jit_type (const std::string& n, bool fi, jit_type *mparent, llvm::Type *lt, + int tid) : + mname (n), finit (fi), p (mparent), llvm_type (lt), id (tid) + {} + + // a user readable type name + const std::string& name (void) const { return mname; } + + // do we need to initialize variables of this type, even if they are not + // input arguments? + bool force_init (void) const { return finit; } + + // a unique id for the type + int type_id (void) const { return id; } + + // An abstract base type, may be null + jit_type *parent (void) const { return p; } + + // convert to an llvm type + llvm::Type *to_llvm (void) const { return llvm_type; } + + // how this type gets passed as a function argument + llvm::Type *to_llvm_arg (void) const; +private: + std::string mname; + bool finit; + jit_type *p; + llvm::Type *llvm_type; + int id; + int depth; +}; + + +// Keeps track of overloads for a builtin function. Used for both type inference +// and code generation. +class +jit_function +{ +public: + struct overload + { + overload (void) : function (0), can_error (true), result (0) {} + + overload (llvm::Function *f, bool e, jit_type *r, jit_type *arg0) : + function (f), can_error (e), result (r), arguments (1) + { + arguments[0] = arg0; + } + + overload (llvm::Function *f, bool e, jit_type *r, jit_type *arg0, + jit_type *arg1) : function (f), can_error (e), result (r), + arguments (2) + { + arguments[0] = arg0; + arguments[1] = arg1; + } + + llvm::Function *function; + bool can_error; + jit_type *result; + std::vector<jit_type*> arguments; + }; + + void add_overload (const overload& func) + { + add_overload (func, func.arguments); + } + + void add_overload (llvm::Function *f, bool e, jit_type *r, jit_type *arg0, + jit_type *arg1) + { + overload ol (f, e, r, arg0, arg1); + add_overload (ol); + } + + void add_overload (const overload& func, + const std::vector<jit_type*>& args); + + const overload& get_overload (const std::vector<jit_type *>& types) const; + + const overload& get_overload (jit_type *arg0) const + { + std::vector<jit_type *> types (1); + types[0] = arg0; + return get_overload (types); + } + + const overload& get_overload (jit_type *arg0, jit_type *arg1) const + { + std::vector<jit_type *> types (2); + types[0] = arg0; + types[1] = arg1; + return get_overload (types); + } + + jit_type *get_result (const std::vector<jit_type *>& types) const + { + const overload& temp = get_overload (types); + return temp.result; + } + + jit_type *get_result (jit_type *arg0, jit_type *arg1) const + { + const overload& temp = get_overload (arg0, arg1); + return temp.result; + } +private: + Array<octave_idx_type> to_idx (const std::vector<jit_type*>& types) const; + + std::vector<Array<overload> > overloads; +}; + +// Get information and manipulate jit types. +class +OCTINTERP_API +jit_typeinfo +{ +public: + jit_typeinfo (llvm::Module *m, llvm::ExecutionEngine *e, llvm::Type *ov); + + jit_type *get_any (void) const { return any; } + + jit_type *get_scalar (void) const { return scalar; } + + jit_type *type_of (const octave_value& ov) const; + + const jit_function& binary_op (int op) const; + + const jit_function::overload& binary_op_overload (int op, jit_type *lhs, + jit_type *rhs) const + { + const jit_function& jf = binary_op (op); + return jf.get_overload (lhs, rhs); + } + + jit_type *binary_op_result (int op, jit_type *lhs, jit_type *rhs) const + { + const jit_function::overload& ol = binary_op_overload (op, lhs, rhs); + return ol.result; + } + + const jit_function::overload& assign_op (jit_type *lhs, jit_type *rhs) const; + + const jit_function::overload& print_value (jit_type *to_print) const; + + // FIXME: generic creation should probably be handled seperatly + void to_generic (jit_type *type, llvm::GenericValue& gv); + void to_generic (jit_type *type, llvm::GenericValue& gv, octave_value ov); + + octave_value to_octave_value (jit_type *type, llvm::GenericValue& gv); + + void reset_generic (size_t nargs); +private: + jit_type *new_type (const std::string& name, bool force_init, + jit_type *parent, llvm::Type *llvm_type); + + void add_print (jit_type *ty, void *call); + + void add_binary_op (jit_type *ty, int op, int llvm_op); + + llvm::Module *module; + llvm::ExecutionEngine *engine; + int next_id; + + llvm::Type *ov_t; + + std::vector<jit_type*> id_to_type; + jit_type *any; + jit_type *scalar; + + std::vector<jit_function> binary_ops; + jit_function assign_fn; + jit_function print_fn; + + size_t scalar_out_idx; + std::vector<double> scalar_out; + + size_t ov_out_idx; + std::vector<octave_base_value*> ov_out; +}; + class OCTINTERP_API -tree_jit : private tree_walker +tree_jit { public: tree_jit (void); @@ -56,146 +272,265 @@ bool execute (tree& tee); private: - typedef void (*jit_function)(bool*, double*); + typedef std::map<std::string, jit_type *> type_map; + + class + type_infer : public tree_walker + { + public: + type_infer (jit_typeinfo *ti) : tinfo (ti), is_lvalue (false), + rvalue_type (0) + {} + + const std::set<std::string>& get_argin () const { return argin; } + + const type_map& get_types () const { return types; } + + void visit_anon_fcn_handle (tree_anon_fcn_handle&); + + void visit_argument_list (tree_argument_list&); + + void visit_binary_expression (tree_binary_expression&); + + void visit_break_command (tree_break_command&); + + void visit_colon_expression (tree_colon_expression&); + + void visit_continue_command (tree_continue_command&); + + void visit_global_command (tree_global_command&); + + void visit_persistent_command (tree_persistent_command&); + + void visit_decl_elt (tree_decl_elt&); + + void visit_decl_init_list (tree_decl_init_list&); + + void visit_simple_for_command (tree_simple_for_command&); + + void visit_complex_for_command (tree_complex_for_command&); + + void visit_octave_user_script (octave_user_script&); + + void visit_octave_user_function (octave_user_function&); + + void visit_octave_user_function_header (octave_user_function&); + + void visit_octave_user_function_trailer (octave_user_function&); + + void visit_function_def (tree_function_def&); + + void visit_identifier (tree_identifier&); + + void visit_if_clause (tree_if_clause&); + + void visit_if_command (tree_if_command&); + + void visit_if_command_list (tree_if_command_list&); + + void visit_index_expression (tree_index_expression&); + + void visit_matrix (tree_matrix&); + + void visit_cell (tree_cell&); + + void visit_multi_assignment (tree_multi_assignment&); + + void visit_no_op_command (tree_no_op_command&); + + void visit_constant (tree_constant&); + + void visit_fcn_handle (tree_fcn_handle&); + + void visit_parameter_list (tree_parameter_list&); + + void visit_postfix_expression (tree_postfix_expression&); + + void visit_prefix_expression (tree_prefix_expression&); + + void visit_return_command (tree_return_command&); + + void visit_return_list (tree_return_list&); + + void visit_simple_assignment (tree_simple_assignment&); + + void visit_statement (tree_statement&); + + void visit_statement_list (tree_statement_list&); + + void visit_switch_case (tree_switch_case&); + + void visit_switch_case_list (tree_switch_case_list&); + + void visit_switch_command (tree_switch_command&); + + void visit_try_catch_command (tree_try_catch_command&); + + void visit_unwind_protect_command (tree_unwind_protect_command&); + + void visit_while_command (tree_while_command&); + + void visit_do_until_command (tree_do_until_command&); + private: + void handle_identifier (const std::string& name, octave_value v); + + jit_typeinfo *tinfo; + + bool is_lvalue; + jit_type *rvalue_type; + + type_map types; + std::set<std::string> argin; + + std::vector<jit_type *> type_stack; + }; + + class + code_generator : public tree_walker + { + public: + code_generator (jit_typeinfo *ti, llvm::Module *module, tree &tee, + const std::set<std::string>& argin, + const type_map& infered_types); + + llvm::Function *get_function () const { return function; } + + void visit_anon_fcn_handle (tree_anon_fcn_handle&); + + void visit_argument_list (tree_argument_list&); + + void visit_binary_expression (tree_binary_expression&); + + void visit_break_command (tree_break_command&); + + void visit_colon_expression (tree_colon_expression&); + + void visit_continue_command (tree_continue_command&); + + void visit_global_command (tree_global_command&); + + void visit_persistent_command (tree_persistent_command&); + + void visit_decl_elt (tree_decl_elt&); + + void visit_decl_init_list (tree_decl_init_list&); + + void visit_simple_for_command (tree_simple_for_command&); + + void visit_complex_for_command (tree_complex_for_command&); + + void visit_octave_user_script (octave_user_script&); + + void visit_octave_user_function (octave_user_function&); + + void visit_octave_user_function_header (octave_user_function&); + + void visit_octave_user_function_trailer (octave_user_function&); + + void visit_function_def (tree_function_def&); + + void visit_identifier (tree_identifier&); + + void visit_if_clause (tree_if_clause&); + + void visit_if_command (tree_if_command&); + + void visit_if_command_list (tree_if_command_list&); + + void visit_index_expression (tree_index_expression&); + + void visit_matrix (tree_matrix&); + + void visit_cell (tree_cell&); + + void visit_multi_assignment (tree_multi_assignment&); + + void visit_no_op_command (tree_no_op_command&); + + void visit_constant (tree_constant&); + + void visit_fcn_handle (tree_fcn_handle&); + + void visit_parameter_list (tree_parameter_list&); + + void visit_postfix_expression (tree_postfix_expression&); + + void visit_prefix_expression (tree_prefix_expression&); + + void visit_return_command (tree_return_command&); + + void visit_return_list (tree_return_list&); + + void visit_simple_assignment (tree_simple_assignment&); + + void visit_statement (tree_statement&); + + void visit_statement_list (tree_statement_list&); + + void visit_switch_case (tree_switch_case&); + + void visit_switch_case_list (tree_switch_case_list&); + + void visit_switch_command (tree_switch_command&); + + void visit_try_catch_command (tree_try_catch_command&); + + void visit_unwind_protect_command (tree_unwind_protect_command&); + + void visit_while_command (tree_while_command&); + + void visit_do_until_command (tree_do_until_command&); + private: + typedef std::pair<jit_type *, llvm::Value *> value; + + void emit_print (const std::string& name, const value& v); + + void push_value (jit_type *type, llvm::Value *v) + { + value_stack.push_back (value (type, v)); + } + + jit_typeinfo *tinfo; + llvm::Function *function; + + bool is_lvalue; + std::map<std::string, value> variables; + std::vector<value> value_stack; + }; class function_info { public: - function_info (void); - function_info (jit_function fn, const std::vector<std::string>& args, - const std::vector<bool>& args_used); + function_info (tree_jit& tjit, tree& tee); - bool execute (); + bool execute () const; + + bool match () const; private: - jit_function function; - std::vector<std::string> arguments; - - // is the argument used? or is it just declared? - std::vector<bool> argument_used; - }; - - struct variable_info - { - llvm::Value *defined; - llvm::Value *value; - bool use; + jit_typeinfo *tinfo; + llvm::ExecutionEngine *engine; + std::set<std::string> argin; + type_map types; + llvm::Function *function; }; - function_info *compile (tree& tee); - - variable_info find (const std::string &name, bool use); - - void do_assign (variable_info vinfo, llvm::Value *value); - - void emit_print (const std::string& vname, llvm::Value *value); - - // tree_walker - void visit_anon_fcn_handle (tree_anon_fcn_handle&); - - void visit_argument_list (tree_argument_list&); - - void visit_binary_expression (tree_binary_expression&); - - void visit_break_command (tree_break_command&); - - void visit_colon_expression (tree_colon_expression&); - - void visit_continue_command (tree_continue_command&); - - void visit_global_command (tree_global_command&); - - void visit_persistent_command (tree_persistent_command&); - - void visit_decl_elt (tree_decl_elt&); - - void visit_decl_init_list (tree_decl_init_list&); - - void visit_simple_for_command (tree_simple_for_command&); - - void visit_complex_for_command (tree_complex_for_command&); - - void visit_octave_user_script (octave_user_script&); - - void visit_octave_user_function (octave_user_function&); - - void visit_octave_user_function_header (octave_user_function&); - - void visit_octave_user_function_trailer (octave_user_function&); - - void visit_function_def (tree_function_def&); - - void visit_identifier (tree_identifier&); - - void visit_if_clause (tree_if_clause&); - - void visit_if_command (tree_if_command&); - - void visit_if_command_list (tree_if_command_list&); - - void visit_index_expression (tree_index_expression&); + typedef std::list<function_info *> function_list; + typedef std::map<tree *, function_list> compiled_map; - void visit_matrix (tree_matrix&); - - void visit_cell (tree_cell&); - - void visit_multi_assignment (tree_multi_assignment&); - - void visit_no_op_command (tree_no_op_command&); - - void visit_constant (tree_constant&); - - void visit_fcn_handle (tree_fcn_handle&); - - void visit_parameter_list (tree_parameter_list&); - - void visit_postfix_expression (tree_postfix_expression&); - - void visit_prefix_expression (tree_prefix_expression&); - - void visit_return_command (tree_return_command&); - - void visit_return_list (tree_return_list&); - - void visit_simple_assignment (tree_simple_assignment&); - - void visit_statement (tree_statement&); - - void visit_statement_list (tree_statement_list&); - - void visit_switch_case (tree_switch_case&); - - void visit_switch_case_list (tree_switch_case_list&); - - void visit_switch_command (tree_switch_command&); - - void visit_try_catch_command (tree_try_catch_command&); - - void do_unwind_protect_cleanup_code (tree_statement_list *list); - - void visit_unwind_protect_command (tree_unwind_protect_command&); - - void visit_while_command (tree_while_command&); - - void visit_do_until_command (tree_do_until_command&); - - void fail (void); - - typedef std::map<std::string, variable_info> var_map; - typedef var_map::iterator var_map_iterator; - typedef std::map<tree*, function_info*> finfo_map; - typedef finfo_map::iterator finfo_map_iterator; - - std::vector<llvm::Value*> value_stack; - var_map variables; - finfo_map compiled_functions; + static void fail (void) + { + throw jit_fail_exception (); + } llvm::LLVMContext &context; llvm::Module *module; + llvm::PassManager *module_pass_manager; llvm::FunctionPassManager *pass_manager; llvm::ExecutionEngine *engine; - llvm::BasicBlock *entry_block; - llvm::Function *print_double; + jit_typeinfo *tinfo; + + compiled_map compiled; }; #endif