Mercurial > hg > octave-terminal
changeset 15056:bc32288f4a42
Support the end keyword for one dimentional indexing in JIT.
* src/jit-ir.cc (jit_magic_end): New class.
* src/jit-ir.h (jit_magic_end): New class.
(jit_instruction::jit_instruction): New overload.
* src/jit-typeinfo.cc (jit_function::call): Throw jit_fail_exception if invalid.
(jit_typeinfo::jit_typeinfo): Initialize end_fn.
* src/jit-typeinfo.h (jit_typeinfo::end): New function.
* src/pt-jit.cc (jit_convert::visit_identifier): Handle magic_end.
(jit_convert::resolve): Keep track of end context.
(jit_convert::convert_llvm::visit): New overload.
* src/pt-jit.h (jit_convert): Add end_context.
author | Max Brister <max@2bass.com> |
---|---|
date | Mon, 30 Jul 2012 13:05:29 -0500 |
parents | a6d4965ef04b |
children | 46b19589b593 6130d87495b8 |
files | src/jit-ir.cc src/jit-ir.h src/jit-typeinfo.cc src/jit-typeinfo.h src/pt-jit.cc src/pt-jit.h |
diffstat | 6 files changed, 133 insertions(+), 6 deletions(-) [+] |
line wrap: on
line diff
--- a/src/jit-ir.cc +++ b/src/jit-ir.cc @@ -598,4 +598,36 @@ return false; } +// -------------------- jit_magic_end -------------------- +const jit_function& +jit_magic_end::overload () const +{ + jit_value *ctx = resolve_context (); + if (ctx) + return jit_typeinfo::end (ctx->type ()); + + static jit_function null_ret; + return null_ret; +} + +jit_value * +jit_magic_end::resolve_context (void) const +{ + // FIXME: We need to have a way of marking functions so we can skip them here + return argument_count () ? argument (0) : 0; +} + +bool +jit_magic_end::infer (void) +{ + jit_type *new_type = overload ().result (); + if (new_type != type ()) + { + stash_type (new_type); + return true; + } + + return false; +} + #endif
--- a/src/jit-ir.h +++ b/src/jit-ir.h @@ -46,7 +46,8 @@ JIT_METH(variable); \ JIT_METH(error_check); \ JIT_METH(assign) \ - JIT_METH(argument) + JIT_METH(argument) \ + JIT_METH(magic_end) #define JIT_VISIT_IR_CONST \ JIT_METH(const_bool); \ @@ -256,6 +257,14 @@ #undef STASH_ARG #undef JIT_INSTRUCTION_CTOR + jit_instruction (const std::vector<jit_value *>& aarguments) + : already_infered (aarguments.size ()), marguments (aarguments.size ()), + mid (next_id ()), mparent (0) + { + for (size_t i = 0; i < aarguments.size (); ++i) + stash_argument (i, aarguments[i]); + } + static void reset_ids (void) { next_id (true); @@ -1137,6 +1146,34 @@ } }; +// for now only handles the 1D case +class +jit_magic_end : public jit_instruction +{ +public: + jit_magic_end (const std::vector<jit_value *>& context) + : jit_instruction (context) + {} + + const jit_function& overload () const; + + jit_value *resolve_context (void) const; + + virtual bool infer (void); + + virtual std::ostream& short_print (std::ostream& os) const + { + return os << "magic_end"; + } + + virtual std::ostream& print (std::ostream& os, size_t indent = 0) const + { + return short_print (print_indent (os, indent)); + } + + JIT_VALUE_ACCEPT; +}; + class jit_extract_argument : public jit_assign_base {
--- a/src/jit-typeinfo.cc +++ b/src/jit-typeinfo.cc @@ -522,8 +522,10 @@ jit_function::call (llvm::IRBuilderD& builder, const std::vector<jit_value *>& in_args) const { + if (! valid ()) + throw jit_fail_exception ("Call not implemented"); + assert (in_args.size () == args.size ()); - std::vector<llvm::Value *> llvm_args (args.size ()); for (size_t i = 0; i < in_args.size (); ++i) llvm_args[i] = in_args[i]->to_llvm (); @@ -535,7 +537,9 @@ jit_function::call (llvm::IRBuilderD& builder, const std::vector<llvm::Value *>& in_args) const { - assert (valid ()); + if (! valid ()) + throw jit_fail_exception ("Call not implemented"); + assert (in_args.size () == args.size ()); llvm::Function *stacksave = llvm::Intrinsic::getDeclaration (module, llvm::Intrinsic::stacksave); @@ -1342,8 +1346,7 @@ builder.CreateBr (done); builder.SetInsertPoint (normal); - llvm::Value *len = builder.CreateExtractValue (mat, - llvm::ArrayRef<unsigned> (2)); + llvm::Value *len = builder.CreateExtractValue (mat, 2); cond0 = builder.CreateICmpSGT (int_idx, len); llvm::Value *rcount = builder.CreateExtractValue (mat, 0); @@ -1386,6 +1389,18 @@ fn.mark_can_error (); paren_subsasgn_fn.add_overload (fn); + end_fn.stash_name ("end"); + fn = create_function (jit_convention::internal, "octave_jit_end_matrix", + scalar, matrix); + body = fn.new_block (); + builder.SetInsertPoint (body); + { + llvm::Value *mat = fn.argument (builder, 0); + llvm::Value *ret = builder.CreateExtractValue (mat, 2); + fn.do_return (builder, builder.CreateSIToFP (ret, scalar_t)); + } + end_fn.add_overload (fn); + casts[any->type_id ()].stash_name ("(any)"); casts[scalar->type_id ()].stash_name ("(scalar)"); casts[complex->type_id ()].stash_name ("(complex)");
--- a/src/jit-typeinfo.h +++ b/src/jit-typeinfo.h @@ -471,6 +471,16 @@ { return instance->do_insert_error_check (bld); } + + static const jit_operation& end (void) + { + return instance->end_fn; + } + + static const jit_function& end (jit_type *ty) + { + return instance->end_fn.overload (ty); + } private: jit_typeinfo (llvm::Module *m, llvm::ExecutionEngine *e); @@ -655,6 +665,7 @@ jit_operation make_range_fn; jit_operation paren_subsref_fn; jit_operation paren_subsasgn_fn; + jit_operation end_fn; // type id -> cast function TO that type std::vector<jit_operation> casts;
--- a/src/pt-jit.cc +++ b/src/pt-jit.cc @@ -412,7 +412,14 @@ void jit_convert::visit_identifier (tree_identifier& ti) { - result = get_variable (ti.name ()); + if (ti.has_magic_end ()) + { + if (!end_context.size ()) + throw jit_fail_exception ("Illegal end"); + result = block->append (create<jit_magic_end> (end_context)); + } + else + result = get_variable (ti.name ()); } void @@ -826,6 +833,12 @@ tree_expression *tree_object = exp.expression (); jit_value *object = visit (tree_object); + + end_context.push_back (object); + + unwind_protect prot; + prot.add_method (&end_context, &std::vector<jit_value *>::pop_back); + tree_expression *arg0 = arg_list->front (); jit_value *index = visit (arg0); @@ -1479,6 +1492,14 @@ jit_convert::convert_llvm::visit (jit_argument&) {} +void +jit_convert::convert_llvm::visit (jit_magic_end& me) +{ + const jit_function& ol = me.overload (); + llvm::Value *ret = ol.call (builder, me.resolve_context ()); + me.stash_llvm (ret); +} + // -------------------- tree_jit -------------------- tree_jit::tree_jit (void) : module (0), engine (0) @@ -1823,4 +1844,13 @@ %! endwhile %! assert (i == niter); +%!test +%! niter = 1001; +%! result = 0; +%! m = [5 10]; +%! for i=1:niter +%! result = result + m(end); +%! endfor +%! assert (result == m(end) * niter); + */