# HG changeset patch # User Max Brister # Date 1343216415 18000 # Node ID 094bc0a145a180637b5b78a377c2f90f947acb1b # Parent c753ee2287081248d9898018c67533d30770c5d1 Take into account different calling conventions * src/TEMPLATE-INST/Array-jit.cc: Instantiate jit_function instead of jit_operation::overload. * src/pt-jit.cc: New test cases. (operator<<): New overload. (jit_type::jit_type): Initialize new fields. (jit_function::jit_function, jit_function::name, jit_function::new_block, jit_function::call, jit_function::argument, jit_function::do_return, jit_typeinfo::create_int, jit_typeinfo::intN): New function. (jit_operation::add_overload, jit_typeinfo::add_print, jit_typeinfo::add_binary_op, jit_typeinfo::add_binary_icmp, jit_typeinfo::add_binary_fcmp, jit_typeinfo::create_function, jit_typeinfo::create_identity, jit_typeinfo::register_intrinsic, jit_typeinfo::mirror_binary): Use jit_function. (jit_operation::overload): Renamed from get_overload and use jit_function. (jit_typeinfo::initialize): Do not assign to instance. (jit_typeinfo::jit_typeinfo): Assign to instance and deal with calling conventions using jit_function. (jit_typeinfo::wrap_complex, jit_convert::convert_llvm::create_call): Removed function. (jit_call::infer): Call result instead of get_result. (jit_convert::convert_llvm::visit): Use jit_function and jit_function::call. * src/pt-jit.h (operator<<): New declaration. (jit_convention::type): New enumeration. (jit_type::jit_type): Move implementation to src/pt-jit.cc. (jit_type::sret, jit_type::mark_sret, jit_type::pointer_arg, jit_type::mark_pointer_arg, jit_type::pack, jit_type::set_pack, jit_type::unpack, jit_type::set_unpack, jit_type::packed_type, jit_type::set_packed_type): New functions. (ASSIGN_ARG, JIT_EXPAND): New convenience macros. (jit_function): New class. (jit_operation::overload): Removed class. (jit_operation::add_overload): Accept a jit_function instead of a jit_operation::overload. (jit_operation::result): Rename from jit_operation::get_result. (jit_operation::overload): Rename from jit_operation::get_overload. (jit_operation::add_overload): Remove several overloads. (jit_typeinfo::get_grab, jit_typeinfo::get_release, jit_typeinfo::cast, jit_typeinfo::do_cast, jit_typeinfo::mirror_binary, jit_call::overload, jit_call::needs_release, jit_extract_argument::overload, jit_store_argument::overload): Use jit_function (jit_typeinfo::add_print): Remove unneeded parameter. (jit_typeinfo::create_function): Use new parameters. (jit_typeinfo::pack_complex, jit_typeinfo::unpack_complex): Make static. (jit_typeinfo::intN, jit_typeinfo::create_int): New declarations. (jit_call::jit_call): Use rename parameter. (jit_call::operation): Rename from jit_call::function. (jit_call::can_error): overload ().can_error is now a function. (jit_call::print): Use moperation instead of mfunction. (jit_convert::create_call): Removed function and declarations. diff --git a/src/TEMPLATE-INST/Array-jit.cc b/src/TEMPLATE-INST/Array-jit.cc --- a/src/TEMPLATE-INST/Array-jit.cc +++ b/src/TEMPLATE-INST/Array-jit.cc @@ -33,8 +33,8 @@ #include "pt-jit.h" -NO_INSTANTIATE_ARRAY_SORT (jit_operation::overload); +NO_INSTANTIATE_ARRAY_SORT (jit_function); -INSTANTIATE_ARRAY (jit_operation::overload, OCTINTERP_API); +INSTANTIATE_ARRAY (jit_function, OCTINTERP_API); #endif diff --git a/src/pt-jit.cc b/src/pt-jit.cc --- a/src/pt-jit.cc +++ b/src/pt-jit.cc @@ -431,6 +431,14 @@ return 0; } +std::ostream& +operator<< (std::ostream& os, const llvm::Value& v) +{ + llvm::raw_os_ostream llvm_out (os); + v.print (llvm_out); + return os; +} + // -------------------- jit_range -------------------- bool jit_range::all_elements_are_ints () const @@ -457,21 +465,233 @@ } // -------------------- jit_type -------------------- +jit_type::jit_type (const std::string& aname, jit_type *aparent, + llvm::Type *allvm_type, int aid) : + mname (aname), mparent (aparent), llvm_type (allvm_type), mid (aid), + mdepth (aparent ? aparent->mdepth + 1 : 0) +{ + std::memset (msret, 0, sizeof (msret)); + std::memset (mpointer_arg, 0, sizeof (mpointer_arg)); + std::memset (mpack, 0, sizeof (mpack)); + std::memset (munpack, 0, sizeof (munpack)); + + for (size_t i = 0; i < jit_convention::length; ++i) + mpacked_type[i] = llvm_type; +} + llvm::Type * jit_type::to_llvm_arg (void) const { return llvm_type ? llvm_type->getPointerTo () : 0; } +// -------------------- jit_function -------------------- +jit_function::jit_function () : module (0), llvm_function (0), mresult (0), + call_conv (jit_convention::length), + mcan_error (false) +{} + +jit_function::jit_function (llvm::Module *amodule, + jit_convention::type acall_conv, + const llvm::Twine& aname, jit_type *aresult, + const std::vector& aargs) + : module (amodule), mresult (aresult), args (aargs), call_conv (acall_conv), + mcan_error (false) +{ + llvm::SmallVector llvm_args; + + llvm::Type *rtype = builder.getVoidTy (); + if (mresult) + { + rtype = mresult->packed_type (call_conv); + if (sret ()) + { + llvm_args.push_back (rtype->getPointerTo ()); + rtype = builder.getVoidTy (); + } + } + + for (std::vector::const_iterator iter = args.begin (); + iter != args.end (); ++iter) + { + jit_type *ty = *iter; + assert (ty); + llvm::Type *argty = ty->packed_type (call_conv); + if (ty->pointer_arg (call_conv)) + argty = argty->getPointerTo (); + + llvm_args.push_back (argty); + } + + // we mark all functinos as external linkage because this prevents llvm + // from getting rid of always inline functions + llvm::FunctionType *ft = llvm::FunctionType::get (rtype, llvm_args, false); + llvm_function = llvm::Function::Create (ft, llvm::Function::ExternalLinkage, + aname, module); + if (call_conv == jit_convention::internal) + llvm_function->addFnAttr (llvm::Attribute::AlwaysInline); +} + +jit_function::jit_function (const jit_function& fn, jit_type *aresult, + const std::vector& aargs) + : module (fn.module), llvm_function (fn.llvm_function), mresult (aresult), + args (aargs), call_conv (fn.call_conv), mcan_error (fn.mcan_error) +{ +} + +jit_function::jit_function (const jit_function& fn) + : module (fn.module), llvm_function (fn.llvm_function), mresult (fn.mresult), + args (fn.args), call_conv (fn.call_conv), mcan_error (fn.mcan_error) +{} + +std::string +jit_function::name (void) const +{ + return llvm_function->getName (); +} + +llvm::BasicBlock * +jit_function::new_block (const std::string& aname, + llvm::BasicBlock *insert_before) +{ + return llvm::BasicBlock::Create (context, aname, llvm_function, + insert_before); +} + +llvm::Value * +jit_function::call (const std::vector& in_args) const +{ + assert (in_args.size () == args.size ()); + + std::vector llvm_args (args.size ()); + for (size_t i = 0; i < in_args.size (); ++i) + llvm_args[i] = in_args[i]->to_llvm (); + + return call (llvm_args); +} + +llvm::Value * +jit_function::call (const std::vector& in_args) const +{ + assert (valid ()); + assert (in_args.size () == args.size ()); + llvm::Function *stacksave + = llvm::Intrinsic::getDeclaration (module, llvm::Intrinsic::stacksave); + llvm::SmallVector llvm_args; + llvm_args.reserve (in_args.size () + sret ()); + + llvm::Value *sret_mem = 0; + llvm::Value *saved_stack = 0; + if (sret ()) + { + saved_stack = builder.CreateCall (stacksave); + sret_mem = builder.CreateAlloca (mresult->packed_type (call_conv)); + llvm_args.push_back (sret_mem); + } + + for (size_t i = 0; i < in_args.size (); ++i) + { + llvm::Value *arg = in_args[i]; + jit_type::convert_fn convert = args[i]->pack (call_conv); + if (convert) + arg = convert (arg); + + if (args[i]->pointer_arg (call_conv)) + { + if (! saved_stack) + saved_stack = builder.CreateCall (stacksave); + + arg = builder.CreateAlloca (args[i]->to_llvm ()); + builder.CreateStore (in_args[i], arg); + } + + llvm_args.push_back (arg); + } + + llvm::Value *ret = builder.CreateCall (llvm_function, llvm_args); + if (sret_mem) + ret = builder.CreateLoad (sret_mem); + + if (mresult) + { + jit_type::convert_fn unpack = mresult->unpack (call_conv); + if (unpack) + ret = unpack (ret); + } + + if (saved_stack) + { + llvm::Function *stackrestore + = llvm::Intrinsic::getDeclaration (module, + llvm::Intrinsic::stackrestore); + builder.CreateCall (stackrestore, saved_stack); + } + + return ret; +} + +llvm::Value * +jit_function::argument (size_t idx) const +{ + assert (idx < args.size ()); + + // FIXME: We should be treating arguments like a list, not a vector. Shouldn't + // matter much for now, as the number of arguments shouldn't be much bigger + // than 4 + llvm::Function::arg_iterator iter = llvm_function->arg_begin (); + if (sret ()) + ++iter; + + for (size_t i = 0; i < idx; ++i, ++iter); + + if (args[idx]->pointer_arg (call_conv)) + return builder.CreateLoad (iter); + + return iter; +} + +void +jit_function::do_return (llvm::Value *rval) +{ + assert (! rval == ! mresult); + + if (rval) + { + jit_type::convert_fn convert = mresult->pack (call_conv); + if (convert) + rval = convert (rval); + + if (sret ()) + builder.CreateStore (rval, llvm_function->arg_begin ()); + else + builder.CreateRet (rval); + } + else + builder.CreateRetVoid (); + + llvm::verifyFunction (*llvm_function); +} + +std::ostream& +operator<< (std::ostream& os, const jit_function& fn) +{ + llvm::Function *lfn = fn.to_llvm (); + os << "jit_function: cc=" << fn.call_conv; + llvm::raw_os_ostream llvm_out (os); + lfn->print (llvm_out); + llvm_out.flush (); + return os; +} + // -------------------- jit_operation -------------------- void -jit_operation::add_overload (const overload& func, +jit_operation::add_overload (const jit_function& func, const std::vector& args) { if (args.size () >= overloads.size ()) overloads.resize (args.size () + 1); - Array& over = overloads[args.size ()]; + Array& over = overloads[args.size ()]; dim_vector dv (over.dims ()); Array idx = to_idx (args); bool must_resize = false; @@ -495,11 +715,11 @@ over(idx) = func; } -const jit_operation::overload& -jit_operation::get_overload (const std::vector& types) const +const jit_function& +jit_operation::overload (const std::vector& types) const { // FIXME: We should search for the next best overload on failure - static overload null_overload; + static jit_function null_overload; if (types.size () >= overloads.size ()) return null_overload; @@ -507,7 +727,7 @@ if (! types[i]) return null_overload; - const Array& over = overloads[types.size ()]; + const Array& over = overloads[types.size ()]; dim_vector dv (over.dims ()); Array idx = to_idx (types); for (octave_idx_type i = 0; i < dv.length (); ++i) @@ -542,12 +762,14 @@ void jit_typeinfo::initialize (llvm::Module *m, llvm::ExecutionEngine *e) { - instance = new jit_typeinfo (m, e); + new jit_typeinfo (m, e); } jit_typeinfo::jit_typeinfo (llvm::Module *m, llvm::ExecutionEngine *e) : module (m), engine (e), next_id (0) { + instance = this; + // FIXME: We should be registering types like in octave_value_typeinfo llvm::Type *any_t = llvm::StructType::create (context, "octave_base_value"); any_t = any_t->getPointerTo (); @@ -565,7 +787,6 @@ range_t->setBody (range_contents); llvm::Type *refcount_t = llvm::Type::getIntNTy (context, sizeof(int) * 8); - llvm::Type *int_t = refcount_t; llvm::StructType *matrix_t = llvm::StructType::create (context, "matrix"); llvm::Type *matrix_contents[5]; @@ -594,8 +815,29 @@ boolean = new_type ("bool", any, bool_t); index = new_type ("index", any, index_t); + create_int (8); + create_int (16); + create_int (32); + create_int (64); + casts.resize (next_id + 1); - identities.resize (next_id + 1, 0); + identities.resize (next_id + 1); + + // specify calling conventions + // FIXME: We should detect architecture and do something sane based on that + // here we assume x86 or x86_64 + matrix->mark_sret (); + matrix->mark_pointer_arg (); + + range->mark_sret (); + range->mark_pointer_arg (); + + complex->set_pack (jit_convention::external, &jit_typeinfo::pack_complex); + complex->set_unpack (jit_convention::external, &jit_typeinfo::unpack_complex); + complex->set_packed_type (jit_convention::external, complex_ret); + + if (sizeof (void *) == 4) + complex->mark_sret (); // bind global variables lerror_state = new llvm::GlobalVariable (*module, bool_t, false, @@ -605,15 +847,13 @@ reinterpret_cast (&error_state)); // any with anything is an any op - llvm::Function *fn; - llvm::Type *binary_op_type - = llvm::Type::getIntNTy (context, sizeof (octave_value::binary_op)); - llvm::Function *any_binary = create_function ("octave_jit_binary_any_any", - any_t, binary_op_type, - any_t, any_t); - engine->addGlobalMapping (any_binary, - reinterpret_cast(&octave_jit_binary_any_any)); - + jit_function fn; + jit_type *binary_op_type = intN (sizeof (octave_value::binary_op) * 8); + llvm::Type *llvm_bo_type = binary_op_type->to_llvm (); + jit_function any_binary = create_function (jit_convention::external, + "octave_jit_binary_any_any", + any, binary_op_type, any, any); + any_binary.mark_can_error (); binary_ops.resize (octave_value::num_binary_ops); for (size_t i = 0; i < octave_value::num_binary_ops; ++i) { @@ -626,66 +866,53 @@ { llvm::Twine fn_name ("octave_jit_binary_any_any_"); fn_name = fn_name + llvm::Twine (op); - fn = create_function (fn_name, any, any, any); - llvm::BasicBlock *block = llvm::BasicBlock::Create (context, "body", fn); + + fn = create_function (jit_convention::internal, fn_name, any, any, any); + fn.mark_can_error (); + llvm::BasicBlock *block = fn.new_block (); builder.SetInsertPoint (block); - llvm::APInt op_int(sizeof (octave_value::binary_op), op, + llvm::APInt op_int(sizeof (octave_value::binary_op) * 8, op, std::numeric_limits::is_signed); - llvm::Value *op_as_llvm = llvm::ConstantInt::get (binary_op_type, op_int); - llvm::Value *ret = builder.CreateCall3 (any_binary, - op_as_llvm, - fn->arg_begin (), - ++fn->arg_begin ()); - builder.CreateRet (ret); - binary_ops[op].add_overload (fn, true, any, any, any); + llvm::Value *op_as_llvm = llvm::ConstantInt::get (llvm_bo_type, op_int); + llvm::Value *ret = any_binary.call (op_as_llvm, fn.argument (0), + fn.argument (1)); + fn.do_return (ret); + binary_ops[op].add_overload (fn); } - llvm::Type *void_t = llvm::Type::getVoidTy (context); - // grab any - fn = create_function ("octave_jit_grab_any", any, any); - engine->addGlobalMapping (fn, reinterpret_cast(&octave_jit_grab_any)); - grab_fn.add_overload (fn, false, any, any); + fn = create_function (jit_convention::external, "octave_jit_grab_any", any, + any); + grab_fn.add_overload (fn); grab_fn.stash_name ("grab"); // grab matrix - llvm::Function *print_matrix = create_function ("octave_jit_print_matrix", - void_t, - matrix_t->getPointerTo ()); - engine->addGlobalMapping (print_matrix, - reinterpret_cast(&octave_jit_print_matrix)); - fn = create_function ("octave_jit_grab_matrix", void_t, - matrix_t->getPointerTo (), matrix_t->getPointerTo ()); - engine->addGlobalMapping (fn, - reinterpret_cast (&octave_jit_grab_matrix)); - grab_fn.add_overload (fn, false, matrix, matrix); + fn = create_function (jit_convention::external, "octave_jit_grab_matrix", + matrix, matrix); + grab_fn.add_overload (fn); // release any - fn = create_function ("octave_jit_release_any", void_t, any_t); - llvm::Function *release_any = fn; - engine->addGlobalMapping (fn, - reinterpret_cast(&octave_jit_release_any)); - release_fn.add_overload (fn, false, 0, any); + fn = create_function (jit_convention::external, "octave_jit_release_any", 0, + any); + release_fn.add_overload (fn); release_fn.stash_name ("release"); // release matrix - fn = create_function ("octave_jit_release_matrix", void_t, - matrix_t->getPointerTo ()); - engine->addGlobalMapping (fn, - reinterpret_cast (&octave_jit_release_matrix)); - release_fn.add_overload (fn, false, 0, matrix); + fn = create_function (jit_convention::external, "octave_jit_release_matrix", + 0, matrix); + release_fn.add_overload (fn); // release scalar fn = create_identity (scalar); - release_fn.add_overload (fn, false, 0, scalar); + release_fn.add_overload (fn); // release complex fn = create_identity (complex); - release_fn.add_overload (fn, false, 0, complex); + release_fn.add_overload (fn); // release index fn = create_identity (index); - release_fn.add_overload (fn, false, 0, index); + release_fn.add_overload (fn); // now for binary scalar operations // FIXME: Finish all operations @@ -701,74 +928,66 @@ add_binary_fcmp (scalar, octave_value::op_gt, llvm::CmpInst::FCMP_UGT); add_binary_fcmp (scalar, octave_value::op_ne, llvm::CmpInst::FCMP_UNE); - llvm::Function *gripe_div0 = create_function ("gripe_divide_by_zero", void_t); - engine->addGlobalMapping (gripe_div0, - reinterpret_cast (&gripe_divide_by_zero)); + jit_function gripe_div0 = create_function (jit_convention::external, + "gripe_divide_by_zero", 0); + gripe_div0.mark_can_error (); // divide is annoying because it might error - fn = create_function ("octave_jit_div_scalar_scalar", scalar, scalar, scalar); - llvm::BasicBlock *body = llvm::BasicBlock::Create (context, "body", fn); + fn = create_function (jit_convention::internal, + "octave_jit_div_scalar_scalar", scalar, scalar, scalar); + fn.mark_can_error (); + + llvm::BasicBlock *body = fn.new_block (); builder.SetInsertPoint (body); { - llvm::BasicBlock *warn_block = llvm::BasicBlock::Create (context, "warn", - fn); - llvm::BasicBlock *normal_block = llvm::BasicBlock::Create (context, - "normal", fn); + llvm::BasicBlock *warn_block = fn.new_block ("warn"); + llvm::BasicBlock *normal_block = fn.new_block ("normal"); llvm::Value *zero = llvm::ConstantFP::get (scalar_t, 0); - llvm::Value *check = builder.CreateFCmpUEQ (zero, ++fn->arg_begin ()); + llvm::Value *check = builder.CreateFCmpUEQ (zero, fn.argument (0)); builder.CreateCondBr (check, warn_block, normal_block); builder.SetInsertPoint (warn_block); - builder.CreateCall (gripe_div0); + gripe_div0.call (); builder.CreateBr (normal_block); builder.SetInsertPoint (normal_block); - llvm::Value *ret = builder.CreateFDiv (fn->arg_begin (), - ++fn->arg_begin ()); - builder.CreateRet (ret); - - jit_operation::overload ol (fn, true, scalar, scalar, scalar); - binary_ops[octave_value::op_div].add_overload (ol); - binary_ops[octave_value::op_el_div].add_overload (ol); + llvm::Value *ret = builder.CreateFDiv (fn.argument (0), + fn.argument (1)); + fn.do_return (ret); } - llvm::verifyFunction (*fn); + binary_ops[octave_value::op_div].add_overload (fn); + binary_ops[octave_value::op_el_div].add_overload (fn); // ldiv is the same as div with the operators reversed fn = mirror_binary (fn); - { - jit_operation::overload ol (fn, true, scalar, scalar, scalar); - binary_ops[octave_value::op_ldiv].add_overload (ol); - binary_ops[octave_value::op_el_ldiv].add_overload (ol); - } + binary_ops[octave_value::op_ldiv].add_overload (fn); + binary_ops[octave_value::op_el_ldiv].add_overload (fn); // In general, the result of scalar ^ scalar is a complex number. We might be // able to improve on this if we keep track of the range of values varaibles // can take on. - fn = create_function ("octave_jit_pow_scalar_scalar", complex_ret, scalar_t, - scalar_t); - engine->addGlobalMapping (fn, reinterpret_cast (octave_jit_pow_scalar_scalar)); - { - jit_operation::overload ol (wrap_complex (fn), false, complex, scalar, - scalar); - binary_ops[octave_value::op_pow].add_overload (ol); - binary_ops[octave_value::op_el_pow].add_overload (ol); - } + fn = create_function (jit_convention::external, + "octave_jit_pow_scalar_scalar", complex, scalar, + scalar); + binary_ops[octave_value::op_pow].add_overload (fn); + binary_ops[octave_value::op_el_pow].add_overload (fn); // now for binary complex operations add_binary_op (complex, octave_value::op_add, llvm::Instruction::FAdd); add_binary_op (complex, octave_value::op_sub, llvm::Instruction::FSub); - fn = create_function ("octave_jit_*_complex_complex", complex, complex, + fn = create_function (jit_convention::internal, + "octave_jit_*_complex_complex", complex, complex, complex); - body = llvm::BasicBlock::Create (context, "body", fn); + body = fn.new_block (); builder.SetInsertPoint (body); { // (x0*x1 - y0*y1, x0*y1 + y0*x1) = (x0,y0) * (x1,y1) // We compute this in one vectorized multiplication, a subtraction, and an // addition. - llvm::Value *lhs = fn->arg_begin (); - llvm::Value *rhs = ++fn->arg_begin (); + llvm::Value *lhs = fn.argument (0); + llvm::Value *rhs = fn.argument (1); // FIXME: We need a better way of doing this, working with llvm's IR // directly is sort of a pain. @@ -803,134 +1022,99 @@ tlhs = builder.CreateExtractElement (mres, two); trhs = builder.CreateExtractElement (mres, three); llvm::Value *ret_imag = builder.CreateFAdd (tlhs, trhs); - builder.CreateRet (complex_new (ret_real, ret_imag)); - - jit_operation::overload ol (fn, false, complex, complex, complex); - binary_ops[octave_value::op_mul].add_overload (ol); - binary_ops[octave_value::op_el_mul].add_overload (ol); + fn.do_return (complex_new (ret_real, ret_imag)); } - llvm::verifyFunction (*fn); - - llvm::Function *complex_div = create_function ("octave_jit_complex_div", - complex_ret, complex_ret, - complex_ret); - engine->addGlobalMapping (complex_div, - reinterpret_cast (&octave_jit_complex_div)); - complex_div = wrap_complex (complex_div); - { - jit_operation::overload ol (complex_div, true, complex, complex, complex); - binary_ops[octave_value::op_div].add_overload (ol); - binary_ops[octave_value::op_ldiv].add_overload (ol); - } + + binary_ops[octave_value::op_mul].add_overload (fn); + binary_ops[octave_value::op_el_mul].add_overload (fn); + + jit_function complex_div = create_function (jit_convention::external, + "octave_jit_complex_div", + complex, complex, complex); + complex_div.mark_can_error (); + binary_ops[octave_value::op_div].add_overload (fn); + binary_ops[octave_value::op_ldiv].add_overload (fn); fn = mirror_binary (complex_div); - { - jit_operation::overload ol (fn, true, complex, complex, complex); - binary_ops[octave_value::op_ldiv].add_overload (ol); - binary_ops[octave_value::op_el_ldiv].add_overload (ol); - } - - fn = create_function ("octave_jit_pow_complex_complex", complex_ret, - complex_ret, complex_ret); - engine->addGlobalMapping (fn, reinterpret_cast (octave_jit_pow_complex_complex)); - { - jit_operation::overload ol (wrap_complex (fn), false, complex, complex, - complex); - binary_ops[octave_value::op_pow].add_overload (ol); - binary_ops[octave_value::op_el_pow].add_overload (ol); - } - - fn = create_function ("octave_jit_*_scalar_complex", complex, scalar, + binary_ops[octave_value::op_ldiv].add_overload (fn); + binary_ops[octave_value::op_el_ldiv].add_overload (fn); + + fn = create_function (jit_convention::external, + "octave_jit_pow_complex_complex", complex, complex, complex); - llvm::Function *mul_scalar_complex = fn; - body = llvm::BasicBlock::Create (context, "body", fn); + binary_ops[octave_value::op_pow].add_overload (fn); + binary_ops[octave_value::op_el_pow].add_overload (fn); + + fn = create_function (jit_convention::internal, + "octave_jit_*_scalar_complex", complex, scalar, + complex); + jit_function mul_scalar_complex = fn; + body = fn.new_block (); builder.SetInsertPoint (body); { - llvm::Value *lhs = fn->arg_begin (); + llvm::Value *lhs = fn.argument (0); llvm::Value *tlhs = complex_new (lhs, lhs); - llvm::Value *rhs = ++fn->arg_begin (); - builder.CreateRet (builder.CreateFMul (tlhs, rhs)); - - jit_operation::overload ol (fn, false, complex, scalar, complex); - binary_ops[octave_value::op_mul].add_overload (ol); - binary_ops[octave_value::op_el_mul].add_overload (ol); + llvm::Value *rhs = fn.argument (1); + fn.do_return (builder.CreateFMul (tlhs, rhs)); } - llvm::verifyFunction (*fn); + binary_ops[octave_value::op_mul].add_overload (fn); + binary_ops[octave_value::op_el_mul].add_overload (fn); + fn = mirror_binary (mul_scalar_complex); - { - jit_operation::overload ol (fn, false, complex, complex, scalar); - binary_ops[octave_value::op_mul].add_overload (ol); - binary_ops[octave_value::op_el_mul].add_overload (ol); - } - - fn = create_function ("octave_jit_+_scalar_complex", complex, scalar, - complex); - body = llvm::BasicBlock::Create (context, "body", fn); + binary_ops[octave_value::op_mul].add_overload (fn); + binary_ops[octave_value::op_el_mul].add_overload (fn); + + fn = create_function (jit_convention::internal, "octave_jit_+_scalar_complex", + complex, scalar, complex); + body = fn.new_block (); builder.SetInsertPoint (body); { - llvm::Value *lhs = fn->arg_begin (); - llvm::Value *rhs = ++fn->arg_begin (); + llvm::Value *lhs = fn.argument (0); + llvm::Value *rhs = fn.argument (1); llvm::Value *real = builder.CreateFAdd (lhs, complex_real (rhs)); - builder.CreateRet (complex_real (rhs, real)); - llvm::verifyFunction (*fn); - - binary_ops[octave_value::op_add].add_overload (fn, false, complex, scalar, - complex); - fn = mirror_binary (fn); - binary_ops[octave_value::op_add].add_overload (fn, false, complex, complex, - scalar); + fn.do_return (complex_real (rhs, real)); } - - fn = create_function ("octave_jit_-_complex_scalar", complex, complex, - scalar); - body = llvm::BasicBlock::Create (context, "body", fn); + binary_ops[octave_value::op_add].add_overload (fn); + + fn = mirror_binary (fn); + binary_ops[octave_value::op_add].add_overload (fn); + + fn = create_function (jit_convention::internal, "octave_jit_-_complex_scalar", + complex, complex, scalar); + body = fn.new_block (); builder.SetInsertPoint (body); { - llvm::Value *lhs = fn->arg_begin (); - llvm::Value *rhs = ++fn->arg_begin (); + llvm::Value *lhs = fn.argument (0); + llvm::Value *rhs = fn.argument (1); llvm::Value *real = builder.CreateFSub (complex_real (lhs), rhs); - builder.CreateRet (complex_real (lhs, real)); - llvm::verifyFunction (*fn); - - binary_ops[octave_value::op_sub].add_overload (fn, false, complex, complex, - scalar); + fn.do_return (complex_real (lhs, real)); } - - fn = create_function ("octave_jit_-_scalar_complex", complex, scalar, - complex); - body = llvm::BasicBlock::Create (context, "body", fn); + binary_ops[octave_value::op_sub].add_overload (fn); + + fn = create_function (jit_convention::internal, "octave_jit_-_scalar_complex", + complex, scalar, complex); + body = fn.new_block (); builder.SetInsertPoint (body); { - llvm::Value *lhs = fn->arg_begin (); - llvm::Value *rhs = ++fn->arg_begin (); + llvm::Value *lhs = fn.argument (0); + llvm::Value *rhs = fn.argument (1); llvm::Value *real = builder.CreateFSub (lhs, complex_real (rhs)); - builder.CreateRet (complex_real (rhs, real)); - llvm::verifyFunction (*fn); - - binary_ops[octave_value::op_sub].add_overload (fn, false, complex, scalar, - complex); + fn.do_return (complex_real (rhs, real)); } - - fn = create_function ("octave_jit_pow_scalar_complex", complex_ret, - scalar_t, complex_ret); - engine->addGlobalMapping (fn, reinterpret_cast (octave_jit_pow_scalar_complex)); - { - jit_operation::overload ol (wrap_complex (fn), false, complex, scalar, - complex); - binary_ops[octave_value::op_pow].add_overload (ol); - binary_ops[octave_value::op_el_pow].add_overload (ol); - } - - fn = create_function ("octave_jit_pow_complex_scalar", complex_ret, - complex_ret, scalar_t); - engine->addGlobalMapping (fn, reinterpret_cast (octave_jit_pow_complex_complex)); - { - jit_operation::overload ol (wrap_complex (fn), false, complex, complex, - scalar); - binary_ops[octave_value::op_pow].add_overload (ol); - binary_ops[octave_value::op_el_pow].add_overload (ol); - } + binary_ops[octave_value::op_sub].add_overload (fn); + + fn = create_function (jit_convention::external, + "octave_jit_pow_scalar_complex", complex, scalar, + complex); + binary_ops[octave_value::op_pow].add_overload (fn); + binary_ops[octave_value::op_el_pow].add_overload (fn); + + fn = create_function (jit_convention::external, + "octave_jit_pow_complex_scalar", complex, complex, + scalar); + binary_ops[octave_value::op_pow].add_overload (fn); + binary_ops[octave_value::op_el_pow].add_overload (fn); // now for binary index operators add_binary_op (index, octave_value::op_add, llvm::Instruction::Add); @@ -941,115 +1125,111 @@ // now for printing functions print_fn.stash_name ("print"); - add_print (any, reinterpret_cast (&octave_jit_print_any)); - add_print (scalar, reinterpret_cast (&octave_jit_print_double)); + add_print (any); + add_print (scalar); // initialize for loop for_init_fn.stash_name ("for_init"); - fn = create_function ("octave_jit_for_range_init", index, range); - body = llvm::BasicBlock::Create (context, "body", fn); + fn = create_function (jit_convention::internal, "octave_jit_for_range_init", + index, range); + body = fn.new_block (); builder.SetInsertPoint (body); { llvm::Value *zero = llvm::ConstantInt::get (index_t, 0); - builder.CreateRet (zero); + fn.do_return (zero); } - llvm::verifyFunction (*fn); - for_init_fn.add_overload (fn, false, index, range); + for_init_fn.add_overload (fn); // bounds check for for loop for_check_fn.stash_name ("for_check"); - fn = create_function ("octave_jit_for_range_check", boolean, range, index); - body = llvm::BasicBlock::Create (context, "body", fn); + fn = create_function (jit_convention::internal, "octave_jit_for_range_check", + boolean, range, index); + body = fn.new_block (); builder.SetInsertPoint (body); { llvm::Value *nelem - = builder.CreateExtractValue (fn->arg_begin (), 3); - llvm::Value *idx = ++fn->arg_begin (); + = builder.CreateExtractValue (fn.argument (0), 3); + llvm::Value *idx = fn.argument (1); llvm::Value *ret = builder.CreateICmpULT (idx, nelem); - builder.CreateRet (ret); + fn.do_return (ret); } - llvm::verifyFunction (*fn); - for_check_fn.add_overload (fn, false, boolean, range, index); + for_check_fn.add_overload (fn); // index variabe for for loop for_index_fn.stash_name ("for_index"); - fn = create_function ("octave_jit_for_range_idx", scalar, range, index); - body = llvm::BasicBlock::Create (context, "body", fn); + fn = create_function (jit_convention::internal, "octave_jit_for_range_idx", + scalar, range, index); + body = fn.new_block (); builder.SetInsertPoint (body); { - llvm::Value *idx = ++fn->arg_begin (); + llvm::Value *idx = fn.argument (1); llvm::Value *didx = builder.CreateSIToFP (idx, scalar_t); - llvm::Value *rng = fn->arg_begin (); + llvm::Value *rng = fn.argument (0); llvm::Value *base = builder.CreateExtractValue (rng, 0); llvm::Value *inc = builder.CreateExtractValue (rng, 2); llvm::Value *ret = builder.CreateFMul (didx, inc); ret = builder.CreateFAdd (base, ret); - builder.CreateRet (ret); + fn.do_return (ret); } - llvm::verifyFunction (*fn); - for_index_fn.add_overload (fn, false, scalar, range, index); + for_index_fn.add_overload (fn); // logically true logically_true_fn.stash_name ("logically_true"); - llvm::Function *gripe_nantl - = create_function ("octave_jit_gripe_nan_to_logical_conversion", void_t); - engine->addGlobalMapping (gripe_nantl, reinterpret_cast (&octave_jit_gripe_nan_to_logical_conversion)); - - fn = create_function ("octave_jit_logically_true_scalar", boolean, scalar); - body = llvm::BasicBlock::Create (context, "body", fn); + jit_function gripe_nantl + = create_function (jit_convention::external, + "octave_jit_gripe_nan_to_logical_conversion", 0); + gripe_nantl.mark_can_error (); + + fn = create_function (jit_convention::internal, + "octave_jit_logically_true_scalar", boolean, scalar); + fn.mark_can_error (); + + body = fn.new_block (); builder.SetInsertPoint (body); { - llvm::BasicBlock *error_block = llvm::BasicBlock::Create (context, "error", - fn); - llvm::BasicBlock *normal_block = llvm::BasicBlock::Create (context, - "normal", fn); - - llvm::Value *check = builder.CreateFCmpUNE (fn->arg_begin (), - fn->arg_begin ()); + llvm::BasicBlock *error_block = fn.new_block ("error"); + llvm::BasicBlock *normal_block = fn.new_block ("normal"); + + llvm::Value *check = builder.CreateFCmpUNE (fn.argument (0), + fn.argument (0)); builder.CreateCondBr (check, error_block, normal_block); builder.SetInsertPoint (error_block); - builder.CreateCall (gripe_nantl); + gripe_nantl.call (); builder.CreateBr (normal_block); builder.SetInsertPoint (normal_block); llvm::Value *zero = llvm::ConstantFP::get (scalar_t, 0); - llvm::Value *ret = builder.CreateFCmpONE (fn->arg_begin (), zero); - builder.CreateRet (ret); + llvm::Value *ret = builder.CreateFCmpONE (fn.argument (0), zero); + fn.do_return (ret); } - llvm::verifyFunction (*fn); - logically_true_fn.add_overload (fn, true, boolean, scalar); - - fn = create_function ("octave_logically_true_bool", boolean, boolean); - body = llvm::BasicBlock::Create (context, "body", fn); - builder.SetInsertPoint (body); - builder.CreateRet (fn->arg_begin ()); - llvm::verifyFunction (*fn); - logically_true_fn.add_overload (fn, false, boolean, boolean); + logically_true_fn.add_overload (fn); + + // logically_true boolean + fn = create_identity (boolean); + logically_true_fn.add_overload (fn); // make_range // FIXME: May be benificial to implement all in LLVM make_range_fn.stash_name ("make_range"); - llvm::Function *compute_nelem - = create_function ("octave_jit_compute_nelem", index, scalar, scalar, - scalar); - engine->addGlobalMapping (compute_nelem, - reinterpret_cast (&octave_jit_compute_nelem)); - - fn = create_function ("octave_jit_make_range", range, scalar, scalar, scalar); - body = llvm::BasicBlock::Create (context, "body", fn); + jit_function compute_nelem + = create_function (jit_convention::external, "octave_jit_compute_nelem", + index, scalar, scalar, scalar); + + fn = create_function (jit_convention::internal, "octave_jit_make_range", + range, scalar, scalar, scalar); + body = fn.new_block (); builder.SetInsertPoint (body); { - llvm::Function::arg_iterator args = fn->arg_begin (); - llvm::Value *base = args; - llvm::Value *limit = ++args; - llvm::Value *inc = ++args; - llvm::Value *nelem = builder.CreateCall3 (compute_nelem, base, limit, inc); + llvm::Value *base = fn.argument (0); + llvm::Value *limit = fn.argument (1); + llvm::Value *inc = fn.argument (2); + llvm::Value *nelem = compute_nelem.call (base, limit, inc); llvm::Value *dzero = llvm::ConstantFP::get (scalar_t, 0); llvm::Value *izero = llvm::ConstantInt::get (index_t, 0); @@ -1059,25 +1239,26 @@ rng = builder.CreateInsertValue (rng, limit, 1); rng = builder.CreateInsertValue (rng, inc, 2); rng = builder.CreateInsertValue (rng, nelem, 3); - builder.CreateRet (rng); + fn.do_return (rng); } - llvm::verifyFunction (*fn); - make_range_fn.add_overload (fn, false, range, scalar, scalar, scalar); + make_range_fn.add_overload (fn); // paren_subsref - llvm::Function *ginvalid_index = create_function ("gipe_invalid_index", - void_t); - engine->addGlobalMapping (ginvalid_index, - reinterpret_cast (&octave_jit_ginvalid_index)); - - llvm::Function *gindex_range = create_function ("gripe_index_out_of_range", - void_t, int_t, int_t, index_t, - index_t); - engine->addGlobalMapping (gindex_range, - reinterpret_cast (&octave_jit_gindex_range)); - - fn = create_function ("()subsref", scalar, matrix, scalar); - body = llvm::BasicBlock::Create (context, "body", fn); + jit_type *jit_int = intN (sizeof (int) * 8); + llvm::Type *int_t = jit_int->to_llvm (); + jit_function ginvalid_index + = create_function (jit_convention::external, "octave_jit_ginvalid_index", + 0); + jit_function gindex_range = create_function (jit_convention::external, + "octave_jit_gindex_range", + 0, jit_int, jit_int, index, + index); + + fn = create_function (jit_convention::internal, "()subsref", scalar, matrix, + scalar); + fn.mark_can_error (); + + body = fn.new_block (); builder.SetInsertPoint (body); { llvm::Value *one = llvm::ConstantInt::get (index_t, 1); @@ -1088,10 +1269,8 @@ ione = llvm::ConstantInt::get (int_t, 1); llvm::Value *undef = llvm::UndefValue::get (scalar_t); - - llvm::Function::arg_iterator args = fn->arg_begin (); - llvm::Value *mat = args++; - llvm::Value *idx = args; + llvm::Value *mat = fn.argument (0); + llvm::Value *idx = fn.argument (1); // convert index to scalar to integer, and check index >= 1 llvm::Value *int_idx = builder.CreateFPToSI (idx, index_t); @@ -1100,17 +1279,13 @@ llvm::Value *cond1 = builder.CreateICmpSLT (int_idx, one); llvm::Value *cond = builder.CreateOr (cond0, cond1); - llvm::BasicBlock *done = llvm::BasicBlock::Create (context, "done", fn); - - llvm::BasicBlock *conv_error = llvm::BasicBlock::Create (context, - "conv_error", fn, - done); - llvm::BasicBlock *normal = llvm::BasicBlock::Create (context, "normal", fn, - done); + llvm::BasicBlock *done = fn.new_block ("done"); + llvm::BasicBlock *conv_error = fn.new_block ("conv_error", done); + llvm::BasicBlock *normal = fn.new_block ("normal", done); builder.CreateCondBr (cond, conv_error, normal); builder.SetInsertPoint (conv_error); - builder.CreateCall (ginvalid_index); + ginvalid_index.call (); builder.CreateBr (done); builder.SetInsertPoint (normal); @@ -1119,16 +1294,12 @@ cond = builder.CreateICmpSGT (int_idx, len); - llvm::BasicBlock *bounds_error = llvm::BasicBlock::Create (context, - "bounds_error", - fn, done); - - llvm::BasicBlock *success = llvm::BasicBlock::Create (context, "success", - fn, done); + llvm::BasicBlock *bounds_error = fn.new_block ("bounds_error", done); + llvm::BasicBlock *success = fn.new_block ("success", done); builder.CreateCondBr (cond, bounds_error, success); builder.SetInsertPoint (bounds_error); - builder.CreateCall4 (gindex_range, ione, ione, int_idx, len); + gindex_range.call (ione, ione, int_idx, len); builder.CreateBr (done); builder.SetInsertPoint (success); @@ -1145,31 +1316,27 @@ merge->addIncoming (undef, conv_error); merge->addIncoming (undef, bounds_error); merge->addIncoming (ret, success); - builder.CreateRet (merge); + fn.do_return (merge); } - llvm::verifyFunction (*fn); - paren_subsref_fn.add_overload (fn, true, scalar, matrix, scalar); + paren_subsref_fn.add_overload (fn); // paren subsasgn paren_subsasgn_fn.stash_name ("()subsasgn"); - llvm::Function *resize_paren_subsasgn - = create_function ("octave_jit_paren_subsasgn_impl", void_t, - matrix_t->getPointerTo (), index_t, scalar_t); - engine->addGlobalMapping (resize_paren_subsasgn, - reinterpret_cast (&octave_jit_paren_subsasgn_impl)); - - fn = create_function ("octave_jit_paren_subsasgn", matrix, matrix, scalar, - scalar); - body = llvm::BasicBlock::Create (context, "body", fn); + jit_function resize_paren_subsasgn + = create_function (jit_convention::external, + "octave_jit_paren_subsasgn_impl", matrix, index, scalar); + fn = create_function (jit_convention::internal, "octave_jit_paren_subsasgn", + matrix, matrix, scalar, scalar); + fn.mark_can_error (); + body = fn.new_block (); builder.SetInsertPoint (body); { llvm::Value *one = llvm::ConstantInt::get (index_t, 1); - llvm::Function::arg_iterator args = fn->arg_begin (); - llvm::Value *mat = args++; - llvm::Value *idx = args++; - llvm::Value *value = args; + llvm::Value *mat = fn.argument (0); + llvm::Value *idx = fn.argument (1); + llvm::Value *value = fn.argument (2); llvm::Value *int_idx = builder.CreateFPToSI (idx, index_t); llvm::Value *check_idx = builder.CreateSIToFP (int_idx, scalar_t); @@ -1177,16 +1344,13 @@ llvm::Value *cond1 = builder.CreateICmpSLT (int_idx, one); llvm::Value *cond = builder.CreateOr (cond0, cond1); - llvm::BasicBlock *done = llvm::BasicBlock::Create (context, "done", fn); - - llvm::BasicBlock *conv_error = llvm::BasicBlock::Create (context, - "conv_error", fn, - done); - llvm::BasicBlock *normal = llvm::BasicBlock::Create (context, "normal", fn, - done); + llvm::BasicBlock *done = fn.new_block ("done"); + + llvm::BasicBlock *conv_error = fn.new_block ("conv_error", done); + llvm::BasicBlock *normal = fn.new_block ("normal", done); builder.CreateCondBr (cond, conv_error, normal); builder.SetInsertPoint (conv_error); - builder.CreateCall (ginvalid_index); + ginvalid_index.call (); builder.CreateBr (done); builder.SetInsertPoint (normal); @@ -1199,20 +1363,13 @@ cond1 = builder.CreateICmpSGT (rcount, one); cond = builder.CreateOr (cond0, cond1); - llvm::BasicBlock *bounds_error = llvm::BasicBlock::Create (context, - "bounds_error", - fn, done); - - llvm::BasicBlock *success = llvm::BasicBlock::Create (context, "success", - fn, done); + llvm::BasicBlock *bounds_error = fn.new_block ("bounds_error", done); + llvm::BasicBlock *success = fn.new_block ("success", done); builder.CreateCondBr (cond, bounds_error, success); // resize on out of bounds access builder.SetInsertPoint (bounds_error); - llvm::Value *resize_result = builder.CreateAlloca (matrix_t); - builder.CreateStore (mat, resize_result); - builder.CreateCall3 (resize_paren_subsasgn, resize_result, int_idx, value); - resize_result = builder.CreateLoad (resize_result); + llvm::Value *resize_result = resize_paren_subsasgn.call (int_idx, value); builder.CreateBr (done); builder.SetInsertPoint (success); @@ -1229,17 +1386,15 @@ merge->addIncoming (mat, conv_error); merge->addIncoming (resize_result, bounds_error); merge->addIncoming (mat, success); - builder.CreateRet (merge); + fn.do_return (merge); } - llvm::verifyFunction (*fn); - paren_subsasgn_fn.add_overload (fn, true, matrix, matrix, scalar, scalar); - - fn = create_function ("octave_jit_paren_subsasgn_matrix_range", void_t, - matrix_t->getPointerTo (), matrix_t->getPointerTo (), - range_t->getPointerTo (), scalar_t); - engine->addGlobalMapping (fn, - reinterpret_cast (&octave_jit_paren_subsasgn_matrix_range)); - paren_subsasgn_fn.add_overload (fn, true, matrix, matrix, range, scalar); + paren_subsasgn_fn.add_overload (fn); + + fn = create_function (jit_convention::external, + "octave_jit_paren_subsasgn_matrix_range", matrix, + matrix, range, scalar); + fn.mark_can_error (); + paren_subsasgn_fn.add_overload (fn); casts[any->type_id ()].stash_name ("(any)"); casts[scalar->type_id ()].stash_name ("(scalar)"); @@ -1247,72 +1402,65 @@ casts[matrix->type_id ()].stash_name ("(matrix)"); // cast any <- matrix - fn = create_function ("octave_jit_cast_any_matrix", any_t, - matrix_t->getPointerTo ()); - engine->addGlobalMapping (fn, - reinterpret_cast (&octave_jit_cast_any_matrix)); - casts[any->type_id ()].add_overload (fn, false, any, matrix); + fn = create_function (jit_convention::external, "octave_jit_cast_any_matrix", + any, matrix); + casts[any->type_id ()].add_overload (fn); // cast matrix <- any - fn = create_function ("octave_jit_cast_matrix_any", void_t, - matrix_t->getPointerTo (), any_t); - engine->addGlobalMapping (fn, - reinterpret_cast (&octave_jit_cast_matrix_any)); - casts[matrix->type_id ()].add_overload (fn, false, matrix, any); + fn = create_function (jit_convention::external, "octave_jit_cast_matrix_any", + matrix, any); + casts[matrix->type_id ()].add_overload (fn); // cast any <- scalar - fn = create_function ("octave_jit_cast_any_scalar", any, scalar); - engine->addGlobalMapping (fn, reinterpret_cast (&octave_jit_cast_any_scalar)); - casts[any->type_id ()].add_overload (fn, false, any, scalar); + fn = create_function (jit_convention::external, "octave_jit_cast_any_scalar", + any, scalar); + casts[any->type_id ()].add_overload (fn); // cast scalar <- any - fn = create_function ("octave_jit_cast_scalar_any", scalar, any); - engine->addGlobalMapping (fn, reinterpret_cast (&octave_jit_cast_scalar_any)); - casts[scalar->type_id ()].add_overload (fn, false, scalar, any); + fn = create_function (jit_convention::external, "octave_jit_cast_scalar_any", + scalar, any); + casts[scalar->type_id ()].add_overload (fn); // cast any <- complex - fn = create_function ("octave_jit_cast_any_complex", any_t, complex_ret); - engine->addGlobalMapping (fn, reinterpret_cast (&octave_jit_cast_any_complex)); - casts[any->type_id ()].add_overload (wrap_complex (fn), false, any, complex); + fn = create_function (jit_convention::external, "octave_jit_cast_any_complex", + any, complex); + casts[any->type_id ()].add_overload (fn); // cast complex <- any - fn = create_function ("octave_jit_cast_complex_any", complex_ret, any_t); - engine->addGlobalMapping (fn, reinterpret_cast (&octave_jit_cast_complex_any)); - casts[complex->type_id ()].add_overload (wrap_complex (fn), false, complex, - any); + fn = create_function (jit_convention::external, "octave_jit_cast_complex_any", + complex, any); + casts[complex->type_id ()].add_overload (fn); // cast complex <- scalar - fn = create_function ("octave_jit_cast_complex_scalar", complex, scalar); - body = llvm::BasicBlock::Create (context, "body", fn); + fn = create_function (jit_convention::internal, + "octave_jit_cast_complex_scalar", complex, scalar); + body = fn.new_block (); builder.SetInsertPoint (body); { llvm::Value *zero = llvm::ConstantFP::get (scalar_t, 0); - builder.CreateRet (complex_new (fn->arg_begin (), zero)); - llvm::verifyFunction (*fn); + fn.do_return (complex_new (fn.argument (0), zero)); } - casts[complex->type_id ()].add_overload (fn, false, complex, scalar); + casts[complex->type_id ()].add_overload (fn); // cast scalar <- complex - fn = create_function ("octave_jit_cast_scalar_complex", scalar, complex); - body = llvm::BasicBlock::Create (context, "body", fn); + fn = create_function (jit_convention::internal, + "octave_jit_cast_scalar_complex", scalar, complex); + body = fn.new_block (); builder.SetInsertPoint (body); - { - builder.CreateRet (complex_real (fn->arg_begin ())); - llvm::verifyFunction (*fn); - } - casts[scalar->type_id ()].add_overload (fn, false, scalar, complex); + fn.do_return (complex_real (fn.argument (0))); + casts[scalar->type_id ()].add_overload (fn); // cast any <- any fn = create_identity (any); - casts[any->type_id ()].add_overload (fn, false, any, any); + casts[any->type_id ()].add_overload (fn); // cast scalar <- scalar fn = create_identity (scalar); - casts[scalar->type_id ()].add_overload (fn, false, scalar, scalar); + casts[scalar->type_id ()].add_overload (fn); // cast complex <- complex fn = create_identity (complex); - casts[complex->type_id ()].add_overload (fn, false, complex, complex); + casts[complex->type_id ()].add_overload (fn); // -------------------- builtin functions -------------------- add_builtin ("#unknown_function"); @@ -1331,31 +1479,34 @@ register_generic ("exp", matrix, matrix); casts.resize (next_id + 1); - fn = create_identity (any); + jit_function any_id = create_identity (any); + jit_function release_any = get_release (any); + std::vector args; + args.resize (1); + for (std::map::iterator iter = builtins.begin (); iter != builtins.end (); ++iter) { jit_type *btype = iter->second; - release_fn.add_overload (release_any, false, 0, btype); - casts[any->type_id ()].add_overload (fn, false, any, btype); - casts[btype->type_id ()].add_overload (fn, false, btype, any); + args[0] = btype; + + release_fn.add_overload (jit_function (release_any, 0, args)); + casts[any->type_id ()].add_overload (jit_function (any_id, any, args)); + + args[0] = any; + casts[btype->type_id ()].add_overload (jit_function (any_id, btype, + args)); } } void -jit_typeinfo::add_print (jit_type *ty, void *call) +jit_typeinfo::add_print (jit_type *ty) { std::stringstream name; name << "octave_jit_print_" << ty->name (); - - llvm::Type *void_t = llvm::Type::getVoidTy (context); - llvm::Function *fn = create_function (name.str (), void_t, - llvm::Type::getInt8PtrTy (context), - ty->to_llvm ()); - engine->addGlobalMapping (fn, call); - - jit_operation::overload ol (fn, false, 0, string, ty); - print_fn.add_overload (ol); + jit_function fn = create_function (jit_convention::external, name.str (), 0, + intN (8), ty); + print_fn.add_overload (fn); } // FIXME: cp between add_binary_op, add_binary_icmp, and add_binary_fcmp @@ -1367,18 +1518,17 @@ fname << "octave_jit_" << octave_value::binary_op_as_string (ov_op) << "_" << ty->name (); - llvm::Function *fn = create_function (fname.str (), ty, ty, ty); - llvm::BasicBlock *block = llvm::BasicBlock::Create (context, "body", fn); + jit_function fn = create_function (jit_convention::internal, fname.str (), + ty, ty, ty); + llvm::BasicBlock *block = fn.new_block (); builder.SetInsertPoint (block); llvm::Instruction::BinaryOps temp = static_cast(llvm_op); - llvm::Value *ret = builder.CreateBinOp (temp, fn->arg_begin (), - ++fn->arg_begin ()); - builder.CreateRet (ret); - llvm::verifyFunction (*fn); - - jit_operation::overload ol(fn, false, ty, ty, ty); - binary_ops[op].add_overload (ol); + + llvm::Value *ret = builder.CreateBinOp (temp, fn.argument (0), + fn.argument (1)); + fn.do_return (ret); + binary_ops[op].add_overload (fn); } void @@ -1389,18 +1539,16 @@ fname << "octave_jit" << octave_value::binary_op_as_string (ov_op) << "_" << ty->name (); - llvm::Function *fn = create_function (fname.str (), boolean, ty, ty); - llvm::BasicBlock *block = llvm::BasicBlock::Create (context, "body", fn); + jit_function fn = create_function (jit_convention::internal, fname.str (), + boolean, ty, ty); + llvm::BasicBlock *block = fn.new_block (); builder.SetInsertPoint (block); llvm::CmpInst::Predicate temp = static_cast(llvm_op); - llvm::Value *ret = builder.CreateICmp (temp, fn->arg_begin (), - ++fn->arg_begin ()); - builder.CreateRet (ret); - llvm::verifyFunction (*fn); - - jit_operation::overload ol (fn, false, boolean, ty, ty); - binary_ops[op].add_overload (ol); + llvm::Value *ret = builder.CreateICmp (temp, fn.argument (0), + fn.argument (1)); + fn.do_return (ret); + binary_ops[op].add_overload (fn); } void @@ -1411,60 +1559,42 @@ fname << "octave_jit" << octave_value::binary_op_as_string (ov_op) << "_" << ty->name (); - llvm::Function *fn = create_function (fname.str (), boolean, ty, ty); - llvm::BasicBlock *block = llvm::BasicBlock::Create (context, "body", fn); + jit_function fn = create_function (jit_convention::internal, fname.str (), + boolean, ty, ty); + llvm::BasicBlock *block = fn.new_block (); builder.SetInsertPoint (block); llvm::CmpInst::Predicate temp = static_cast(llvm_op); - llvm::Value *ret = builder.CreateFCmp (temp, fn->arg_begin (), - ++fn->arg_begin ()); - builder.CreateRet (ret); - llvm::verifyFunction (*fn); - - jit_operation::overload ol (fn, false, boolean, ty, ty); - binary_ops[op].add_overload (ol); + llvm::Value *ret = builder.CreateFCmp (temp, fn.argument (0), + fn.argument (1)); + fn.do_return (ret); + binary_ops[op].add_overload (fn); } -llvm::Function * -jit_typeinfo::create_function (const llvm::Twine& name, jit_type *ret, +jit_function +jit_typeinfo::create_function (jit_convention::type cc, const llvm::Twine& name, + jit_type *ret, const std::vector& args) { - llvm::Type *void_t = llvm::Type::getVoidTy (context); - std::vector llvm_args (args.size (), void_t); - for (size_t i = 0; i < args.size (); ++i) - if (args[i]) - llvm_args[i] = args[i]->to_llvm (); - - return create_function (name, ret ? ret->to_llvm () : void_t, llvm_args); + jit_function result (module, cc, name, ret, args); + return result; } -llvm::Function * -jit_typeinfo::create_function (const llvm::Twine& name, llvm::Type *ret, - const std::vector& args) -{ - llvm::FunctionType *ft = llvm::FunctionType::get (ret, args, false); - llvm::Function *fn = llvm::Function::Create (ft, - llvm::Function::ExternalLinkage, - name, module); - fn->addFnAttr (llvm::Attribute::AlwaysInline); - return fn; -} - -llvm::Function * +jit_function jit_typeinfo::create_identity (jit_type *type) { size_t id = type->type_id (); if (id >= identities.size ()) - identities.resize (id + 1, 0); - - if (! identities[id]) + identities.resize (id + 1); + + if (! identities[id].valid ()) { - llvm::Function *fn = create_function ("id", type, type); - llvm::BasicBlock *body = llvm::BasicBlock::Create (context, "body", fn); + jit_function fn = create_function (jit_convention::internal, "id", type, + type); + llvm::BasicBlock *body = fn.new_block (); builder.SetInsertPoint (body); - builder.CreateRet (fn->arg_begin ()); - llvm::verifyFunction (*fn); - identities[id] = fn; + fn.do_return (fn.argument (0)); + return identities[id] = fn; } return identities[id]; @@ -1511,21 +1641,18 @@ // The first argument will be the Octave function, but we already know that // the function call is the equivalent of the intrinsic, so we ignore it and // call the intrinsic with the remaining arguments. - llvm::Function *fn = create_function (fn_name.str (), result, args1); - llvm::BasicBlock *body = llvm::BasicBlock::Create (context, "body", fn); + jit_function fn = create_function (jit_convention::internal, fn_name.str (), + result, args1); + llvm::BasicBlock *body = fn.new_block (); builder.SetInsertPoint (body); llvm::SmallVector fargs (nargs); - llvm::Function::arg_iterator iter = fn->arg_begin (); - ++iter; - for (size_t i = 0; i < nargs; ++i, ++iter) - fargs[i] = iter; + for (size_t i = 0; i < nargs; ++i) + fargs[i] = fn.argument (i + 1); llvm::Value *ret = builder.CreateCall (ifun, fargs); - builder.CreateRet (ret); - llvm::verifyFunction (*fn); - - paren_subsref_fn.add_overload (fn, false, result, args1); + fn.do_return (ret); + paren_subsref_fn.add_overload (fn); } octave_builtin * @@ -1544,80 +1671,31 @@ // FIXME: Implement } -llvm::Function * -jit_typeinfo::mirror_binary (llvm::Function *fn) -{ - llvm::FunctionType *fn_type = fn->getFunctionType (); - llvm::Function *ret = create_function (fn->getName () + "_reverse", - fn_type->getReturnType (), - fn_type->getParamType (1), - fn_type->getParamType (0)); - llvm::BasicBlock *body = llvm::BasicBlock::Create (context, "body", ret); - builder.SetInsertPoint (body); - llvm::Value *result = builder.CreateCall2 (fn, ++ret->arg_begin (), - ret->arg_begin ()); - if (ret->getReturnType () == builder.getVoidTy ()) - builder.CreateRetVoid (); - else - builder.CreateRet (result); - - llvm::verifyFunction (*ret); - return ret; -} - -llvm::Function * -jit_typeinfo::wrap_complex (llvm::Function *wrap) +jit_function +jit_typeinfo::mirror_binary (const jit_function& fn) { - llvm::SmallVector new_args; - new_args.reserve (wrap->arg_size ()); - llvm::Type *complex_t = complex->to_llvm (); - for (llvm::Function::arg_iterator iter = wrap->arg_begin (); - iter != wrap->arg_end (); ++iter) - { - llvm::Value *value = iter; - llvm::Type *type = value->getType (); - new_args.push_back (type == complex_ret ? complex_t : type); - } - - llvm::FunctionType *wrap_type = wrap->getFunctionType (); - bool convert_ret = wrap_type->getReturnType () == complex_ret; - llvm::Type *rtype = convert_ret ? complex_t : wrap->getReturnType (); - llvm::FunctionType *ft = llvm::FunctionType::get (rtype, new_args, false); - llvm::Function *fn = llvm::Function::Create (ft, - llvm::Function::ExternalLinkage, - wrap->getName () + "_wrap", - module); - llvm::BasicBlock *body = llvm::BasicBlock::Create (context, "body", fn); + jit_function ret = create_function (jit_convention::internal, + fn.name () + "_reverse", + fn.result (), fn.argument_type (1), + fn.argument_type (0)); + if (fn.can_error ()) + ret.mark_can_error (); + + llvm::BasicBlock *body = ret.new_block (); builder.SetInsertPoint (body); - - llvm::SmallVector converted (new_args.size ()); - llvm::Function::arg_iterator witer = wrap->arg_begin (); - llvm::Function::arg_iterator fiter = fn->arg_begin (); - for (size_t i = 0; i < new_args.size (); ++i, ++witer, ++fiter) - { - llvm::Value *warg = witer; - llvm::Value *arg = fiter; - converted[i] = warg->getType () == arg->getType () ? arg - : pack_complex (arg); - } - - llvm::Value *ret = builder.CreateCall (wrap, converted); - if (wrap_type->getReturnType () != builder.getVoidTy ()) - { - if (convert_ret) - ret = unpack_complex (ret); - builder.CreateRet (ret); - } + llvm::Value *result = fn.call (ret.argument (1), ret.argument (0)); + if (ret.result ()) + ret.do_return (result); else - builder.CreateRetVoid (); - - llvm::verifyFunction (*fn); - return fn; + ret.do_return (); + + return ret; } llvm::Value * jit_typeinfo::pack_complex (llvm::Value *cplx) { + llvm::Type *complex_ret = instance->complex_ret; llvm::Value *real = builder.CreateExtractElement (cplx, builder.getInt32 (0)); llvm::Value *imag = builder.CreateExtractElement (cplx, builder.getInt32 (1)); llvm::Value *ret = llvm::UndefValue::get (complex_ret); @@ -1628,7 +1706,7 @@ llvm::Value * jit_typeinfo::unpack_complex (llvm::Value *result) { - llvm::Type *complex_t = complex->to_llvm (); + llvm::Type *complex_t = get_complex ()->to_llvm (); llvm::Value *real = builder.CreateExtractValue (result, 0); llvm::Value *imag = builder.CreateExtractValue (result, 1); llvm::Value *ret = llvm::UndefValue::get (complex_t); @@ -1668,6 +1746,25 @@ return complex_imag (ret, imag); } +void +jit_typeinfo::create_int (size_t nbits) +{ + std::stringstream tname; + tname << "int" << nbits; + ints[nbits] = new_type (tname.str (), any, llvm::Type::getIntNTy (context, + nbits)); +} + +jit_type * +jit_typeinfo::intN (size_t nbits) const +{ + std::map::const_iterator iter = ints.find (nbits); + if (iter != ints.end ()) + return iter->second; + + fail ("No such integer type"); +} + jit_type * jit_typeinfo::do_type_of (const octave_value &ov) const { @@ -2250,7 +2347,7 @@ return false; } - jit_type *infered = mfunction.get_result (already_infered); + jit_type *infered = moperation.result (already_infered); if (! infered && use_count ()) { std::stringstream ss; @@ -3576,7 +3673,13 @@ void jit_convert::convert_llvm::visit (jit_call& call) { - llvm::Value *ret = create_call (call.overload (), call.arguments ()); + const jit_function& ol = call.overload (); + + std::vector args (call.arguments ().size ()); + for (size_t i = 0; i < args.size (); ++i) + args[i] = call.argument (i); + + llvm::Value *ret = ol.call (args); call.stash_llvm (ret); } @@ -3587,15 +3690,15 @@ assert (arg); arg = builder.CreateLoad (arg); - jit_value *jarg = jthis.create (jit_typeinfo::get_any (), arg); - extract.stash_llvm (create_call (extract.overload (), jarg)); + const jit_function& ol = extract.overload (); + extract.stash_llvm (ol.call (arg)); } void jit_convert::convert_llvm::visit (jit_store_argument& store) { - llvm::Value *arg_value = create_call (store.overload (), store.result ()); - + const jit_function& ol = store.overload (); + llvm::Value *arg_value = ol.call (store.result ()); llvm::Value *arg = arguments[store.name ()]; store.stash_llvm (builder.CreateStore (arg_value, arg)); } @@ -3629,27 +3732,24 @@ void jit_convert::convert_llvm::visit (jit_assign& assign) { - assign.stash_llvm (assign.src ()->to_llvm ()); + jit_value *new_value = assign.src (); + assign.stash_llvm (new_value->to_llvm ()); if (assign.artificial ()) return; - jit_value *new_value = assign.src (); if (isa (new_value)) { - const jit_operation::overload& ol - = jit_typeinfo::get_grab (new_value->type ()); - if (ol.function) - assign.stash_llvm (create_call (ol, new_value)); + const jit_function& ol = jit_typeinfo::get_grab (new_value->type ()); + if (ol.valid ()) + assign.stash_llvm (ol.call (new_value)); } jit_value *overwrite = assign.overwrite (); if (isa (overwrite)) { - const jit_operation::overload& ol - = jit_typeinfo::get_release (overwrite->type ()); - if (ol.function) - create_call (ol, overwrite); + const jit_function& ol = jit_typeinfo::get_release (overwrite->type ()); + ol.call (overwrite); } } @@ -3657,67 +3757,6 @@ jit_convert::convert_llvm::visit (jit_argument&) {} -llvm::Value * -jit_convert::convert_llvm::create_call (const jit_operation::overload& ol, - const std::vector& jargs) -{ - llvm::IRBuilder<> alloca_inserter (prelude, prelude->begin ()); - - llvm::Function *fun = ol.function; - if (! fun) - fail ("Missing overload"); - - const llvm::Function::ArgumentListType& alist = fun->getArgumentList (); - size_t nargs = alist.size (); - bool sret = false; - if (nargs != jargs.size ()) - { - // first argument is the structure return value - assert (nargs == jargs.size () + 1); - sret = true; - } - - std::vector args (nargs); - llvm::Function::arg_iterator llvm_arg = fun->arg_begin (); - if (sret) - { - args[0] = alloca_inserter.CreateAlloca (ol.result->to_llvm ()); - ++llvm_arg; - } - - for (size_t i = 0; i < jargs.size (); ++i, ++llvm_arg) - { - llvm::Value *arg = jargs[i]->to_llvm (); - llvm::Type *arg_type = arg->getType (); - llvm::Type *llvm_arg_type = llvm_arg->getType (); - - if (arg_type == llvm_arg_type) - args[i + sret] = arg; - else - { - // pass structure by pointer - assert (arg_type->getPointerTo () == llvm_arg_type); - llvm::Value *new_arg = alloca_inserter.CreateAlloca (arg_type); - builder.CreateStore (arg, new_arg); - args[i + sret] = new_arg; - } - } - - llvm::Value *llvm_call = builder.CreateCall (fun, args); - return sret ? builder.CreateLoad (args[0]) : llvm_call; -} - -llvm::Value * -jit_convert::convert_llvm::create_call (const jit_operation::overload& ol, - const std::vector& uses) -{ - std::vector values (uses.size ()); - for (size_t i = 0; i < uses.size (); ++i) - values[i] = uses[i].value (); - - return create_call (ol, values); -} - // -------------------- tree_jit -------------------- tree_jit::tree_jit (void) : module (0), engine (0) @@ -3886,3 +3925,96 @@ return true; } #endif + + +/* +Test some simple cases that compile. + +%!test +%! inc = 1e-5; +%! result = 0; +%! for ii = 0:inc:1 +%! result = result + inc * (1/3 * ii * ii); +%! endfor +%! assert (abs (result - 1/9) < 1e-5); + +%!test +%! inc = 1e-5; +%! result = 0; +%! for ii = 0:inc:1 +%! # the ^ operator's result is complex +%! result = result + inc * (1/3 * ii ^ 2); +%! endfor +%! assert (abs (result - 1/9) < 1e-5); + +%!test +%! nr = 1001; +%! mat = zeros (1, nr); +%! for i = 1:nr +%! mat(i) = i; +%! endfor +%! assert (mat == 1:nr); + +%!test +%! nr = 1001; +%! mat = 1:nr; +%! mat(end) = 0; # force mat to a matrix +%! total = 0; +%! for i = 1:nr +%! total = mat(i) + total; +%! endfor +%! assert (sum (mat) == total); + +%!test +%! nr = 1001; +%! mat = [3 1 5]; +%! try +%! for i = 1:nr +%! if i > 500 +%! result = mat(100); +%! else +%! result = i; +%! endif +%! endfor +%! catch +%! end +%! assert (result == 500); + +%!function result = gen_test (n) +%! result = double (rand (1, n) > .01); +%!endfunction + +%!function z = vectorized (A, K) +%! temp = ones (1, K); +%! z = conv (A, temp); +%! z = z > K-1; +%! z = conv (z, temp); +%! z = z(K:end-K+1); +%! z = z >= 1; +%!endfunction + +%!function z = loopy (A, K) +%! z = A; +%! n = numel (A); +%! counter = 0; +%! for ii=1:n +%! if z(ii) +%! counter = counter + 1; +%! else +%! if counter > 0 && counter < K +%! z(ii-counter:ii-1) = 0; +%! endif +%! counter = 0; +%! endif +%! endfor +%! +%! if counter > 0 && counter < K +%! z(end-counter+1:end) = 0; +%! endif +%!endfunction + +%!test +%! test_set = gen_test (10000); +%! assert (all (vectorized (test_set, 3) == loopy (test_set, 3))); + +*/ diff --git a/src/pt-jit.h b/src/pt-jit.h --- a/src/pt-jit.h +++ b/src/pt-jit.h @@ -83,6 +83,9 @@ class PHINode; } +// llvm doesn't provide this, and it's really useful for debugging +std::ostream& operator<< (std::ostream& os, const llvm::Value& v); + class octave_base_value; class octave_builtin; class octave_value; @@ -260,6 +263,26 @@ std::ostream& operator<< (std::ostream& os, const jit_matrix& mat); +class jit_type; +class jit_value; + +// calling convention +namespace +jit_convention +{ + enum + type + { + // internal to jit + internal, + + // an external C call + external, + + length + }; +} + // Used to keep track of estimated (infered) types during JIT. This is a // hierarchical type system which includes both concrete and abstract types. // @@ -270,11 +293,10 @@ jit_type { public: + typedef llvm::Value *(*convert_fn) (llvm::Value *); + jit_type (const std::string& aname, jit_type *aparent, llvm::Type *allvm_type, - int aid) : - mname (aname), mparent (aparent), llvm_type (allvm_type), mid (aid), - mdepth (aparent ? aparent->mdepth + 1 : 0) - {} + int aid); // a user readable type name const std::string& name (void) const { return mname; } @@ -292,112 +314,177 @@ llvm::Type *to_llvm_arg (void) const; size_t depth (void) const { return mdepth; } + + bool sret (jit_convention::type cc) const { return msret[cc]; } + + void mark_sret (jit_convention::type cc = jit_convention::external) + { msret[cc] = true; } + + bool pointer_arg (jit_convention::type cc) const { return mpointer_arg[cc]; } + + void mark_pointer_arg (jit_convention::type cc = jit_convention::external) + { mpointer_arg[cc] = true; } + + convert_fn pack (jit_convention::type cc) { return mpack[cc]; } + + void set_pack (jit_convention::type cc, convert_fn fn) { mpack[cc] = fn; } + + convert_fn unpack (jit_convention::type cc) { return munpack[cc]; } + + void set_unpack (jit_convention::type cc, convert_fn fn) + { munpack[cc] = fn; } + + llvm::Type *packed_type (jit_convention::type cc) + { return mpacked_type[cc]; } + + void set_packed_type (jit_convention::type cc, llvm::Type *ty) + { mpacked_type[cc] = ty; } private: std::string mname; jit_type *mparent; llvm::Type *llvm_type; int mid; size_t mdepth; + + bool msret[jit_convention::length]; + bool mpointer_arg[jit_convention::length]; + + convert_fn mpack[jit_convention::length]; + convert_fn munpack[jit_convention::length]; + + llvm::Type *mpacked_type[jit_convention::length]; }; // seperate print function to allow easy printing if type is null std::ostream& jit_print (std::ostream& os, jit_type *atype); +#define ASSIGN_ARG(i) the_args[i] = arg ## i; +#define JIT_EXPAND(ret, fname, type, isconst, N) \ + ret fname (JIT_PARAM_ARGS OCT_MAKE_DECL_LIST (type, arg, N)) isconst \ + { \ + std::vector the_args (N); \ + OCT_ITERATE_MACRO (ASSIGN_ARG, N); \ + return fname (JIT_PARAMS the_args); \ + } + +// provides a mechanism for calling +class +jit_function +{ + friend std::ostream& operator<< (std::ostream& os, const jit_function& fn); +public: + jit_function (); + + jit_function (llvm::Module *amodule, jit_convention::type acall_conv, + const llvm::Twine& aname, jit_type *aresult, + const std::vector& aargs); + + jit_function (const jit_function& fn, jit_type *aresult, + const std::vector& aargs); + + jit_function (const jit_function& fn); + + bool valid (void) const { return llvm_function; } + + std::string name (void) const; + + llvm::BasicBlock *new_block (const std::string& aname = "body", + llvm::BasicBlock *insert_before = 0); + + llvm::Value *call (const std::vector& in_args) const; + + llvm::Value *call (const std::vector& in_args) const; + +#define JIT_PARAM_ARGS +#define JIT_PARAMS +#define JIT_CALL(N) JIT_EXPAND (llvm::Value *, call, llvm::Value *, const, N) + + JIT_CALL (0); + JIT_CALL (1); + JIT_CALL (2); + JIT_CALL (3); + JIT_CALL (4); + JIT_CALL (5); + +#undef JIT_CALL + +#define JIT_CALL(N) JIT_EXPAND (llvm::Value *, call, jit_value *, const, N) + + JIT_CALL (1); + JIT_CALL (2); + +#undef JIT_CALL +#undef JIT_PARAMS +#undef JIT_PARAM_ARGS + + llvm::Value *argument (size_t idx) const; + + void do_return (llvm::Value *rval = 0); + + llvm::Function *to_llvm (void) const { return llvm_function; } + + // If true, then the return value is passed as a pointer in the first argument + bool sret (void) const { return mresult && mresult->sret (call_conv); } + + bool can_error (void) const { return mcan_error; } + + void mark_can_error (void) { mcan_error = true; } + + jit_type *result (void) const { return mresult; } + + jit_type *argument_type (size_t idx) const + { + assert (idx < args.size ()); + return args[idx]; + } + + const std::vector& arguments (void) const { return args; } +private: + llvm::Module *module; + llvm::Function *llvm_function; + jit_type *mresult; + std::vector args; + jit_convention::type call_conv; + bool mcan_error; +}; + +std::ostream& operator<< (std::ostream& os, const jit_function& fn); + + // Keeps track of overloads for a builtin function. Used for both type inference // and code generation. class jit_operation { public: - struct - overload + void add_overload (const jit_function& func) { - overload (void) : function (0), can_error (false), result (0) {} - -#define ASSIGN_ARG(i) arguments[i] = arg ## i; -#define OVERLOAD_CTOR(N) \ - overload (llvm::Function *f, bool e, jit_type *ret, \ - OCT_MAKE_DECL_LIST (jit_type *, arg, N)) \ - : function (f), can_error (e), result (ret), arguments (N) \ - { \ - OCT_ITERATE_MACRO (ASSIGN_ARG, N); \ - } - - OVERLOAD_CTOR (1) - OVERLOAD_CTOR (2) - OVERLOAD_CTOR (3) - -#undef ASSIGN_ARG -#undef OVERLOAD_CTOR - - overload (llvm::Function *f, bool e, jit_type *r, - const std::vector& aarguments) - : function (f), can_error (e), result (r), arguments (aarguments) - {} - - llvm::Function *function; - bool can_error; - jit_type *result; - std::vector arguments; - }; - - void add_overload (const overload& func) - { - add_overload (func, func.arguments); - } - -#define ADD_OVERLOAD(N) \ - void add_overload (llvm::Function *f, bool e, jit_type *ret, \ - OCT_MAKE_DECL_LIST (jit_type *, arg, N)) \ - { \ - overload ol (f, e, ret, OCT_MAKE_ARG_LIST (arg, N)); \ - add_overload (ol); \ + add_overload (func, func.arguments ()); } - ADD_OVERLOAD (1); - ADD_OVERLOAD (2); - ADD_OVERLOAD (3); - -#undef ADD_OVERLOAD - - void add_overload (llvm::Function *f, bool e, jit_type *r, - const std::vector& args) - { - overload ol (f, e, r, args); - add_overload (ol); - } - - void add_overload (const overload& func, + void add_overload (const jit_function& func, const std::vector& args); - const overload& get_overload (const std::vector& types) const; - - const overload& get_overload (jit_type *arg0) const + const jit_function& overload (const std::vector& types) const; + + jit_type *result (const std::vector& types) const { - std::vector types (1); - types[0] = arg0; - return get_overload (types); - } - - const overload& get_overload (jit_type *arg0, jit_type *arg1) const - { - std::vector types (2); - types[0] = arg0; - types[1] = arg1; - return get_overload (types); + const jit_function& temp = overload (types); + return temp.result (); } - jit_type *get_result (const std::vector& types) const - { - const overload& temp = get_overload (types); - return temp.result; - } - - jit_type *get_result (jit_type *arg0, jit_type *arg1) const - { - const overload& temp = get_overload (arg0, arg1); - return temp.result; - } +#define JIT_PARAMS +#define JIT_PARAM_ARGS +#define JIT_OVERLOAD(N) \ + JIT_EXPAND (const jit_function&, overload, jit_type *, const, N) \ + JIT_EXPAND (jit_type *, result, jit_type *, const, N) + + JIT_OVERLOAD (1); + JIT_OVERLOAD (2); + JIT_OVERLOAD (3); + +#undef JIT_PARAMS +#undef JIT_PARAM_ARGS const std::string& name (void) const { return mname; } @@ -405,7 +492,7 @@ private: Array to_idx (const std::vector& types) const; - std::vector > overloads; + std::vector > overloads; std::string mname; }; @@ -456,9 +543,9 @@ static const jit_operation& grab (void) { return instance->grab_fn; } - static const jit_operation::overload& get_grab (jit_type *type) + static const jit_function& get_grab (jit_type *type) { - return instance->grab_fn.get_overload (type); + return instance->grab_fn.overload (type); } static const jit_operation& release (void) @@ -466,9 +553,9 @@ return instance->release_fn; } - static const jit_operation::overload& get_release (jit_type *type) + static const jit_function& get_release (jit_type *type) { - return instance->release_fn.get_overload (type); + return instance->release_fn.overload (type); } static const jit_operation& print_value (void) @@ -516,7 +603,7 @@ return instance->do_cast (result); } - static const jit_operation::overload& cast (jit_type *to, jit_type *from) + static const jit_function& cast (jit_type *to, jit_type *from) { return instance->do_cast (to, from); } @@ -586,16 +673,16 @@ return casts[id]; } - const jit_operation::overload& do_cast (jit_type *to, jit_type *from) + const jit_function& do_cast (jit_type *to, jit_type *from) { - return do_cast (to).get_overload (from); + return do_cast (to).overload (from); } jit_type *new_type (const std::string& name, jit_type *parent, llvm::Type *llvm_type); - void add_print (jit_type *ty, void *call); + void add_print (jit_type *ty); void add_binary_op (jit_type *ty, int op, int llvm_op); @@ -603,43 +690,27 @@ void add_binary_fcmp (jit_type *ty, int op, int llvm_op); - - llvm::Function *create_function (const llvm::Twine& name, llvm::Type *ret) - { - std::vector args; - return create_function (name, ret, args); - } - -#define ASSIGN_ARG(i) args[i] = arg ## i; -#define CREATE_FUNCTIONT(TYPE, N) \ - llvm::Function *create_function (const llvm::Twine& name, TYPE *ret, \ - OCT_MAKE_DECL_LIST (TYPE *, arg, N)) \ - { \ - std::vector args (N); \ - OCT_ITERATE_MACRO (ASSIGN_ARG, N); \ - return create_function (name, ret, args); \ - } - -#define CREATE_FUNCTION(N) \ - CREATE_FUNCTIONT(llvm::Type, N) \ - CREATE_FUNCTIONT(jit_type, N) - - CREATE_FUNCTION(1) - CREATE_FUNCTION(2) - CREATE_FUNCTION(3) - CREATE_FUNCTION(4) - -#undef ASSIGN_ARG -#undef CREATE_FUNCTIONT + jit_function create_function (jit_convention::type cc, + const llvm::Twine& name, jit_type *ret, + const std::vector& args + = std::vector ()); + +#define JIT_PARAM_ARGS jit_convention::type cc, const llvm::Twine& name, \ + jit_type *ret, +#define JIT_PARAMS cc, name, ret, +#define CREATE_FUNCTION(N) JIT_EXPAND(jit_function, create_function, \ + jit_type *, /* empty */, N) + + CREATE_FUNCTION(1); + CREATE_FUNCTION(2); + CREATE_FUNCTION(3); + CREATE_FUNCTION(4); + +#undef JIT_PARAM_ARGS +#undef JIT_PARAMS #undef CREATE_FUNCTION - llvm::Function *create_function (const llvm::Twine& name, jit_type *ret, - const std::vector& args); - - llvm::Function *create_function (const llvm::Twine& name, llvm::Type *ret, - const std::vector& args); - - llvm::Function *create_identity (jit_type *type); + jit_function create_identity (jit_type *type); llvm::Value *do_insert_error_check (void); @@ -667,13 +738,13 @@ octave_builtin *find_builtin (const std::string& name); - llvm::Function *mirror_binary (llvm::Function *fn); + jit_function mirror_binary (const jit_function& fn); llvm::Function *wrap_complex (llvm::Function *wrap); - llvm::Value *pack_complex (llvm::Value *cplx); - - llvm::Value *unpack_complex (llvm::Value *result); + static llvm::Value *pack_complex (llvm::Value *cplx); + + static llvm::Value *unpack_complex (llvm::Value *result); llvm::Value *complex_real (llvm::Value *cx); @@ -685,6 +756,10 @@ llvm::Value *complex_new (llvm::Value *real, llvm::Value *imag); + void create_int (size_t nbits); + + jit_type *intN (size_t nbits) const; + static jit_typeinfo *instance; llvm::Module *module; @@ -703,6 +778,7 @@ jit_type *index; jit_type *complex; jit_type *unknown_function; + std::map ints; std::map builtins; llvm::StructType *complex_ret; @@ -723,7 +799,7 @@ std::vector casts; // type id -> identity function - std::vector identities; + std::vector identities; }; // The low level octave jit ir @@ -1744,13 +1820,14 @@ { public: #define JIT_CALL_CONST(N) \ - jit_call (const jit_operation& afunction, \ + jit_call (const jit_operation& aoperation, \ OCT_MAKE_DECL_LIST (jit_value *, arg, N)) \ - : jit_instruction (OCT_MAKE_ARG_LIST (arg, N)), mfunction (afunction) {} \ + : jit_instruction (OCT_MAKE_ARG_LIST (arg, N)), moperation (aoperation) {} \ \ - jit_call (const jit_operation& (*afunction) (void), \ + jit_call (const jit_operation& (*aoperation) (void), \ OCT_MAKE_DECL_LIST (jit_value *, arg, N)) \ - : jit_instruction (OCT_MAKE_ARG_LIST (arg, N)), mfunction (afunction ()) {} + : jit_instruction (OCT_MAKE_ARG_LIST (arg, N)), moperation (aoperation ()) \ + {} JIT_CALL_CONST (1) JIT_CALL_CONST (2) @@ -1760,21 +1837,21 @@ #undef JIT_CALL_CONST - const jit_operation& function (void) const { return mfunction; } + const jit_operation& operation (void) const { return moperation; } bool can_error (void) const { - return overload ().can_error; + return overload ().can_error (); } - const jit_operation::overload& overload (void) const + const jit_function& overload (void) const { - return mfunction.get_overload (argument_types ()); + return moperation.overload (argument_types ()); } virtual bool needs_release (void) const { - return type () && jit_typeinfo::get_release (type ()).function; + return type () && jit_typeinfo::get_release (type ()).valid (); } virtual std::ostream& print (std::ostream& os, size_t indent = 0) const @@ -1783,7 +1860,7 @@ if (use_count ()) short_print (os) << " = "; - os << "call " << mfunction.name () << " ("; + os << "call " << moperation.name () << " ("; for (size_t i = 0; i < argument_count (); ++i) { @@ -1798,7 +1875,7 @@ JIT_VALUE_ACCEPT; private: - const jit_operation& mfunction; + const jit_operation& moperation; }; // FIXME: This is just ugly... @@ -1846,7 +1923,7 @@ return dest ()->name (); } - const jit_operation::overload& overload (void) const + const jit_function& overload (void) const { return jit_typeinfo::cast (type (), jit_typeinfo::get_any ()); } @@ -1874,7 +1951,7 @@ return dest->name (); } - const jit_operation::overload& overload (void) const + const jit_function& overload (void) const { return jit_typeinfo::cast (jit_typeinfo::get_any (), result_type ()); } @@ -2268,28 +2345,6 @@ { jvalue.accept (*this); } - - llvm::Value *create_call (const jit_operation::overload& ol, jit_value *arg0) - { - std::vector args (1, arg0); - return create_call (ol, args); - } - - llvm::Value *create_call (const jit_operation::overload& ol, jit_value *arg0, - jit_value *arg1) - { - std::vector args (2); - args[0] = arg0; - args[1] = arg1; - - return create_call (ol, args); - } - - llvm::Value *create_call (const jit_operation::overload& ol, - const std::vector& jargs); - - llvm::Value *create_call (const jit_operation::overload& ol, - const std::vector& uses); private: jit_convert &jthis; llvm::Function *function; @@ -2353,6 +2408,8 @@ #undef JIT_VISIT_IR_CLASSES #undef JIT_VISIT_IR_CONST #undef JIT_VALUE_ACCEPT +#undef ASSIGN_ARG +#undef JIT_EXPAND #endif #endif