# HG changeset patch # User Max Brister # Date 1340824460 18000 # Node ID e3cd4c9d7ccce44fb47502c56fbcfbde4f81ec0b # Parent 2960f1b2d6eac851032c0c2256fa75e955327ebd Generalize builtin specification in JIT and add support for cos and exp * src/ov-builtin.cc (octave_builtin::function): New function. * src/ov-builtin.h (octave_builtin::function): New declaration. * src/pt-jit.cc (gripe_bad_result, octave_jit_call, jit_typeinfo::add_builtin, jit_typeinfo::register_intrinsic, jit_typeinfo::find_builtin, jit_typeinfo::register_generic): New function. (jit_typeinfo::jit_typeinfo): Generalize builtin specification and add support for cos and exp. (jit_typeinfo::create_function): New overload. * src/pt-jit.h (overload::overload, jit_function::add_overload, jit_typeinfo::create_function): New overload. (jit_typeinfo::add_builtin, jit_typeinfo::register_intrinsic, jit_typeinfo::register_generic, jit_typeinfo::find_builtin): New declaration. diff --git a/src/ov-builtin.cc b/src/ov-builtin.cc --- a/src/ov-builtin.cc +++ b/src/ov-builtin.cc @@ -164,4 +164,10 @@ jtype = &type; } +octave_builtin::fcn +octave_builtin::function (void) const +{ + return f; +} + const std::list *octave_builtin::curr_lvalue_list = 0; diff --git a/src/ov-builtin.h b/src/ov-builtin.h --- a/src/ov-builtin.h +++ b/src/ov-builtin.h @@ -80,6 +80,8 @@ void stash_jit (jit_type& type); + fcn function (void) const; + static const std::list *curr_lvalue_list; protected: diff --git a/src/pt-jit.cc b/src/pt-jit.cc --- a/src/pt-jit.cc +++ b/src/pt-jit.cc @@ -261,6 +261,55 @@ std::cout << *m << std::endl; } +static void +gripe_bad_result (void) +{ + error ("incorrect type information given to the JIT compiler"); +} + +// FIXME: Add support for multiple outputs +extern "C" octave_base_value * +octave_jit_call (octave_builtin::fcn fn, size_t nargin, + octave_base_value **argin, jit_type *result_type) +{ + octave_value_list ovl (nargin); + for (size_t i = 0; i < nargin; ++i) + ovl.xelem (i) = octave_value (argin[i]); + + ovl = fn (ovl, 1); + + // These type checks are not strictly required, but I'm guessing that + // incorrect types will be entered on occasion. This will be very difficult to + // debug unless we do the sanity check here. + if (result_type) + { + if (ovl.length () != 1) + { + gripe_bad_result (); + return 0; + } + + octave_value& result = ovl.xelem (0); + jit_type *jtype = jit_typeinfo::join (jit_typeinfo::type_of (result), + result_type); + if (jtype != result_type) + { + gripe_bad_result (); + return 0; + } + + octave_base_value *ret = result.internal_rep (); + ret->grab (); + return ret; + } + + if (! (ovl.length () == 0 + || (ovl.length () == 1 && ovl.xelem (0).is_undefined ()))) + gripe_bad_result (); + + return 0; +} + // -------------------- jit_range -------------------- std::ostream& operator<< (std::ostream& os, const jit_range& rng) @@ -408,8 +457,6 @@ boolean = new_type ("bool", any, bool_t); index = new_type ("index", any, index_t); - sin_type = new_type ("sin", any, any_t); - casts.resize (next_id + 1); identities.resize (next_id + 1, 0); @@ -900,33 +947,27 @@ // -------------------- builtin functions -------------------- - - // FIXME: Handling this here is messy, but it's the easiest way for now - // FIXME: Come up with a nicer way of defining functions - octave_value ov_sin = symbol_table::builtin_find ("sin"); - octave_builtin *bsin - = dynamic_cast (ov_sin.internal_rep ()); - if (bsin) + add_builtin ("sin"); + register_intrinsic ("sin", llvm::Intrinsic::sin, scalar, scalar); + register_generic ("sin", matrix, matrix); + + add_builtin ("cos"); + register_intrinsic ("cos", llvm::Intrinsic::cos, scalar, scalar); + register_generic ("cos", matrix, matrix); + + add_builtin ("exp"); + register_intrinsic ("exp", llvm::Intrinsic::cos, scalar, scalar); + register_generic ("exp", matrix, matrix); + + casts.resize (next_id + 1); + fn = create_identity (any); + for (std::map::iterator iter = builtins.begin (); + iter != builtins.end (); ++iter) { - bsin->stash_jit (*sin_type); - - llvm::Function *isin - = llvm::Intrinsic::getDeclaration (module, llvm::Intrinsic::sin, - llvm::makeArrayRef (scalar_t)); - fn = create_function ("octave_jit_sin", scalar, any, scalar); - body = llvm::BasicBlock::Create (context, "body", fn); - builder.SetInsertPoint (body); - { - llvm::Value *ret = builder.CreateCall (isin, ++fn->arg_begin ()); - builder.CreateRet (ret); - } - llvm::verifyFunction (*fn); - paren_subsref_fn.add_overload (fn, false, scalar, sin_type, scalar); - release_fn.add_overload (release_any, false, 0, sin_type); - - fn = create_identity (any); - casts[any->type_id ()].add_overload (fn, false, any, sin_type); - casts[sin_type->type_id ()].add_overload (fn, false, sin_type, any); + jit_type *btype = iter->second; + release_fn.add_overload (release_any, false, 0, btype); + casts[any->type_id ()].add_overload (fn, false, any, btype); + casts[btype->type_id ()].add_overload (fn, false, btype, any); } } @@ -1014,6 +1055,19 @@ } llvm::Function * +jit_typeinfo::create_function (const llvm::Twine& name, jit_type *ret, + const std::vector& args) +{ + llvm::Type *void_t = llvm::Type::getVoidTy (context); + std::vector llvm_args (args.size (), void_t); + for (size_t i = 0; i < args.size (); ++i) + if (args[i]) + llvm_args[i] = args[i]->to_llvm (); + + return create_function (name, ret ? ret->to_llvm () : void_t, llvm_args); +} + +llvm::Function * jit_typeinfo::create_function (const llvm::Twine& name, llvm::Type *ret, const std::vector& args) { @@ -1051,6 +1105,74 @@ return builder.CreateLoad (lerror_state); } +void +jit_typeinfo::add_builtin (const std::string& name) +{ + jit_type *btype = new_type (name, any, any->to_llvm ()); + builtins[name] = btype; + + octave_builtin *ov_builtin = find_builtin (name); + if (ov_builtin) + ov_builtin->stash_jit (*btype); +} + +void +jit_typeinfo::register_intrinsic (const std::string& name, size_t iid, + jit_type *result, + const std::vector& args) +{ + jit_type *builtin_type = builtins[name]; + size_t nargs = args.size (); + llvm::SmallVector llvm_args (nargs); + for (size_t i = 0; i < nargs; ++i) + llvm_args[i] = args[i]->to_llvm (); + + llvm::Intrinsic::ID id = static_cast (iid); + llvm::Function *ifun = llvm::Intrinsic::getDeclaration (module, id, + llvm_args); + std::stringstream fn_name; + fn_name << "octave_jit_" << name; + + std::vector args1 (nargs + 1); + args1[0] = builtin_type; + std::copy (args.begin (), args.end (), args1.begin () + 1); + + // The first argument will be the Octave function, but we already know that + // the function call is the equivalent of the intrinsic, so we ignore it and + // call the intrinsic with the remaining arguments. + llvm::Function *fn = create_function (fn_name.str (), result, args1); + llvm::BasicBlock *body = llvm::BasicBlock::Create (context, "body", fn); + builder.SetInsertPoint (body); + + llvm::SmallVector fargs (nargs); + llvm::Function::arg_iterator iter = fn->arg_begin (); + ++iter; + for (size_t i = 0; i < nargs; ++i, ++iter) + fargs[i] = iter; + + llvm::Value *ret = builder.CreateCall (ifun, fargs); + builder.CreateRet (ret); + llvm::verifyFunction (*fn); + + paren_subsref_fn.add_overload (fn, false, result, args1); +} + +octave_builtin * +jit_typeinfo::find_builtin (const std::string& name) +{ + // FIXME: Finalize what we want to store in octave_builtin, then add functions + // to access these values in octave_value + octave_value ov_builtin = symbol_table::find (name); + return dynamic_cast (ov_builtin.internal_rep ()); +} + +void +jit_typeinfo::register_generic (const std::string&, jit_type *, + const std::vector&) +{ + // FIXME: Implement +} + jit_type * jit_typeinfo::do_type_of (const octave_value &ov) const { diff --git a/src/pt-jit.h b/src/pt-jit.h --- a/src/pt-jit.h +++ b/src/pt-jit.h @@ -84,6 +84,7 @@ } class octave_base_value; +class octave_builtin; class octave_value; class tree; class tree_expression; @@ -329,6 +330,11 @@ arguments[2] = arg2; } + overload (llvm::Function *f, bool e, jit_type *r, + const std::vector& aarguments) + : function (f), can_error (e), result (r), arguments (aarguments) + {} + llvm::Function *function; bool can_error; jit_type *result; @@ -360,6 +366,13 @@ add_overload (ol); } + void add_overload (llvm::Function *f, bool e, jit_type *r, + const std::vector& args) + { + overload ol (f, e, r, args); + add_overload (ol); + } + void add_overload (const overload& func, const std::vector& args); @@ -660,6 +673,9 @@ return create_function (name, ret, args); } + llvm::Function *create_function (const llvm::Twine& name, jit_type *ret, + const std::vector& args); + llvm::Function *create_function (const llvm::Twine& name, llvm::Type *ret, const std::vector& args); @@ -667,6 +683,30 @@ llvm::Value *do_insert_error_check (void); + void add_builtin (const std::string& name); + + void register_intrinsic (const std::string& name, size_t id, + jit_type *result, jit_type *arg0) + { + std::vector args (1, arg0); + register_intrinsic (name, id, result, args); + } + + void register_intrinsic (const std::string& name, size_t id, jit_type *result, + const std::vector& args); + + void register_generic (const std::string& name, jit_type *result, + jit_type *arg0) + { + std::vector args (1, arg0); + register_generic (name, result, args); + } + + void register_generic (const std::string& name, jit_type *result, + const std::vector& args); + + octave_builtin *find_builtin (const std::string& name); + static jit_typeinfo *instance; llvm::Module *module; @@ -683,7 +723,7 @@ jit_type *string; jit_type *boolean; jit_type *index; - jit_type *sin_type; + std::map builtins; std::vector binary_ops; jit_function grab_fn;