Mercurial > hg > octave-lyh
changeset 15135:bd6bb87e2bea
Support sin, cos, and exp with matrix arguments in JIT
* src/interp-core/jit-typeinfo.cc (jit_operation::generate): Remove unused
parameter name.
(jit_typeinfo::jit_typeinfo): Create any_call function.
(jit_typeinfo::register_generic): Implement.
* src/interp-core/jit-typeinfo.h (jit_typeinfo): New field, any_call.
* src/interp-core/pt-jit.cc: New test.
author | Max Brister <max@2bass.com> |
---|---|
date | Thu, 09 Aug 2012 15:45:59 -0500 |
parents | edae65062740 |
children | eeaaac7c86b6 |
files | src/interp-core/jit-typeinfo.cc src/interp-core/jit-typeinfo.h src/interp-core/pt-jit.cc |
diffstat | 3 files changed, 62 insertions(+), 4 deletions(-) [+] |
line wrap: on
line diff
--- a/src/interp-core/jit-typeinfo.cc +++ b/src/interp-core/jit-typeinfo.cc @@ -837,7 +837,7 @@ } jit_function * -jit_operation::generate (const signature_vec& types) const +jit_operation::generate (const signature_vec&) const { return 0; } @@ -1041,6 +1041,7 @@ complex = new_type ("complex", any, complex_t); scalar = new_type ("scalar", complex, scalar_t); scalar_ptr = new_type ("scalar_ptr", 0, scalar_t->getPointerTo ()); + any_ptr = new_type ("any_ptr", 0, any_t->getPointerTo ()); range = new_type ("range", any, range_t); string = new_type ("string", any, string_t); boolean = new_type ("bool", any, bool_t); @@ -1080,6 +1081,14 @@ engine->addGlobalMapping (lerror_state, reinterpret_cast<void *> (&error_state)); + // generic call function + { + jit_type *int_t = intN (sizeof (octave_builtin::fcn) * 8); + any_call = create_function (jit_convention::external, "octave_jit_call", + any, int_t, int_t, any_ptr, int_t); + any_call.add_mapping (engine, &octave_jit_call); + } + // any with anything is an any op jit_function fn; jit_type *binary_op_type = intN (sizeof (octave_value::binary_op) * 8); @@ -1974,10 +1983,48 @@ } void -jit_typeinfo::register_generic (const std::string&, jit_type *, - const std::vector<jit_type *>&) +jit_typeinfo::register_generic (const std::string& name, jit_type *result, + const std::vector<jit_type *>& args) { - // FIXME: Implement + octave_builtin *builtin = find_builtin (name); + if (! builtin) + return; + + std::vector<jit_type *> fn_args (args.size () + 1); + fn_args[0] = builtins[name]; + std::copy (args.begin (), args.end (), fn_args.begin () + 1); + jit_function fn = create_function (jit_convention::internal, name, result, + fn_args); + llvm::BasicBlock *block = fn.new_block (); + builder.SetInsertPoint (block); + llvm::Type *any_t = any->to_llvm (); + llvm::ArrayType *array_t = llvm::ArrayType::get (any_t, args.size ()); + llvm::Value *array = llvm::UndefValue::get (array_t); + for (size_t i = 0; i < args.size (); ++i) + { + llvm::Value *arg = fn.argument (builder, i + 1); + jit_function agrab = get_grab (args[i]); + llvm::Value *garg = agrab.call (builder, arg); + jit_function acast = cast (any, args[i]); + array = builder.CreateInsertValue (array, acast.call (builder, garg), i); + } + + llvm::Value *array_mem = builder.CreateAlloca (array_t); + builder.CreateStore (array, array_mem); + array = builder.CreateBitCast (array_mem, any_t->getPointerTo ()); + + jit_type *jintTy = intN (sizeof (octave_builtin::fcn) * 8); + llvm::Type *intTy = jintTy->to_llvm (); + size_t fcn_int = reinterpret_cast<size_t> (builtin->function ()); + llvm::Value *fcn = llvm::ConstantInt::get (intTy, fcn_int); + llvm::Value *nargin = llvm::ConstantInt::get (intTy, args.size ()); + size_t result_int = reinterpret_cast<size_t> (result); + llvm::Value *res_llvm = llvm::ConstantInt::get (intTy, result_int); + llvm::Value *ret = any_call.call (builder, fcn, nargin, array, res_llvm); + + jit_function cast_result = cast (result, any); + fn.do_return (builder, cast_result.call (builder, ret)); + paren_subsref_fn.add_overload (fn); } jit_function
--- a/src/interp-core/jit-typeinfo.h +++ b/src/interp-core/jit-typeinfo.h @@ -724,6 +724,7 @@ jit_type *matrix; jit_type *scalar; jit_type *scalar_ptr; // a fake type for interfacing with C++ + jit_type *any_ptr; // a fake type for interfacing with C++ jit_type *range; jit_type *string; jit_type *boolean; @@ -749,6 +750,8 @@ jit_operation end1_fn; jit_operation end_fn; + jit_function any_call; + // type id -> cast function TO that type std::vector<jit_operation> casts;