# HG changeset patch # User Max Brister # Date 1343753461 18000 # Node ID df4538e3b50bd0cd737c57c0ae82f46920dbd9b3 # Parent 6451a584305e11ab1e0728a074a404083fff199d ND scalar indexing in JIT. * src/jit-ir.cc (jit_magic_end::jit_magic_end): Use jit_magic_end::context. * src/jit-ir.h (jit_call::jit_call): New overload. (jit_magic_end::context): New class. (jit_magic_end::jit_magic_end): Moved to src/jit-ir.cc. * src/jit-typeinfo.cc (octave_jit_paren_scalar): New function. (jit_typeinfo::jit_typeinfo): Generate ND scalar indexing. (jit_typeinfo::gen_subsref): New function. * src/jit-typeinfo.h (jit_typeinfo::gen_subsref): New declaration. * src/pt-jit.cc (jit_convert::visit_index_expression, jit_convert::do_assign): Update resolve call. (jit_convert::resolve): Resolve ND indices. * src/pt-jit.h (jit_convert::resolve): Change function signature. diff --git a/src/jit-ir.cc b/src/jit-ir.cc --- a/src/jit-ir.cc +++ b/src/jit-ir.cc @@ -599,6 +599,22 @@ } // -------------------- jit_magic_end -------------------- +jit_magic_end::jit_magic_end (const std::vector& full_context) +{ + // for now we only support end in 1 dimensional indexing + resize_arguments (full_context.size ()); + + size_t i; + std::vector::const_iterator iter; + for (iter = full_context.begin (), i = 0; iter != full_context.end (); ++iter, + ++i) + { + if (iter->count != 1) + throw jit_fail_exception ("end is only supported in linear contexts"); + stash_argument (i, iter->value); + } +} + const jit_function& jit_magic_end::overload () const { diff --git a/src/jit-ir.h b/src/jit-ir.h --- a/src/jit-ir.h +++ b/src/jit-ir.h @@ -1074,6 +1074,10 @@ #undef JIT_CALL_CONST + jit_call (const jit_operation& aoperation, + const std::vector& args) + : jit_instruction (args), moperation (aoperation) + {} const jit_operation& operation (void) const { return moperation; } @@ -1151,9 +1155,23 @@ jit_magic_end : public jit_instruction { public: - jit_magic_end (const std::vector& context) - : jit_instruction (context) - {} +
class + context + { + public: + context (void) : value (0), index (0), count (0) + {} + + context (jit_value *avalue, size_t aindex, size_t acount) + : value (avalue), index (aindex), count (acount) + {} + + jit_value *value; + size_t index; + size_t count; + }; + + jit_magic_end (const std::vector& full_context); const jit_function& overload () const; diff --git a/src/jit-typeinfo.cc b/src/jit-typeinfo.cc --- a/src/jit-typeinfo.cc +++ b/src/jit-typeinfo.cc @@ -243,6 +243,27 @@ *ret = *mat; } +extern "C" double +octave_jit_paren_scalar (jit_matrix *mat, double *indicies, + octave_idx_type idx_count) +{ + // FIXME: Replace this with a more optimal version + try + { + Array idx (dim_vector (1, idx_count)); + for (octave_idx_type i = 0; i < idx_count; ++i) + idx(i) = idx_vector (indicies[i]); + + Array ret = mat->array->index (idx); + return ret.xelem (0); + } + catch (const octave_execution_exception&) + { + gripe_library_execution_error (); + return 0; + } +} + extern "C" void octave_jit_paren_subsasgn_matrix_range (jit_matrix *result, jit_matrix *mat, jit_range *index, double value) @@ -789,6 +810,9 @@ boolean = new_type ("bool", any, bool_t); index = new_type ("index", any, index_t); + // a fake type for interfacing with C++ + jit_type *scalar_ptr = new_type ("scalar_ptr", 0, scalar_t->getPointerTo ()); + create_int (8); create_int (16); create_int (32); @@ -1310,6 +1334,18 @@ } paren_subsref_fn.add_overload (fn); + // generate () subsref for ND indexing of matrices with scalars + jit_function paren_scalar = create_function (jit_convention::external, + "octave_jit_paren_scalar", + scalar, matrix, scalar_ptr, + index); + paren_scalar.add_mapping (engine, &octave_jit_paren_scalar); + paren_scalar.mark_can_error (); + + // FIXME: Generate this on the fly + for (size_t i = 2; i < 10; ++i) + gen_subsref (paren_scalar, i); + // paren subsasgn paren_subsasgn_fn.stash_name ("()subsasgn"); @@ -1831,4 +1867,37 @@ return ret; } +void +jit_typeinfo::gen_subsref (const
jit_function& paren_scalar, size_t n) +{ + std::stringstream name; + name << "jit_paren_subsref_matrix_scalar" << n; + std::vector args (n + 1, scalar); + args[0] = matrix; + jit_function fn = create_function (jit_convention::internal, name.str (), + scalar, args); + fn.mark_can_error (); + llvm::BasicBlock *body = fn.new_block (); + builder.SetInsertPoint (body); + + llvm::Type *scalar_t = scalar->to_llvm (); + llvm::ArrayType *array_t = llvm::ArrayType::get (scalar_t, n); + llvm::Value *array = llvm::UndefValue::get (array_t); + for (size_t i = 0; i < n; ++i) + { + llvm::Value *idx = fn.argument (builder, i + 1); + array = builder.CreateInsertValue (array, idx, i); + } + + llvm::Value *array_mem = builder.CreateAlloca (array_t); + builder.CreateStore (array, array_mem); + array = builder.CreateBitCast (array_mem, scalar_t->getPointerTo ()); + + llvm::Value *nelem = llvm::ConstantInt::get (index->to_llvm (), n); + llvm::Value *mat = fn.argument (builder, 0); + llvm::Value *ret = paren_scalar.call (builder, mat, array, nelem); + fn.do_return (builder, ret); + paren_subsref_fn.add_overload (fn); +} + #endif diff --git a/src/jit-typeinfo.h b/src/jit-typeinfo.h --- a/src/jit-typeinfo.h +++ b/src/jit-typeinfo.h @@ -631,6 +631,8 @@ jit_type *intN (size_t nbits) const; + void gen_subsref (const jit_function& paren_scalar, size_t n); + static jit_typeinfo *instance; llvm::Module *module; diff --git a/src/pt-jit.cc b/src/pt-jit.cc --- a/src/pt-jit.cc +++ b/src/pt-jit.cc @@ -518,11 +518,7 @@ void jit_convert::visit_index_expression (tree_index_expression& exp) { - std::pair res = resolve (exp); - jit_value *object = res.first; - jit_value *index = res.second; - - result = create_checked (jit_typeinfo::paren_subsref, object, index); + result = resolve (jit_typeinfo::paren_subsref (), exp); } void @@ -813,8 +809,8 @@ return ss.str (); } -std::pair -jit_convert::resolve (tree_index_expression& exp) +jit_instruction * +jit_convert::resolve (const jit_operation& fres, 
tree_index_expression& exp) { std::string type = exp.type_tags (); if (! (type.size () == 1 && type[0] == '(')) @@ -828,21 +824,27 @@ if (! arg_list) throw jit_fail_exception ("null argument list"); - if (arg_list->size () != 1) - throw jit_fail_exception ("Bad number of arguments in arg_list"); + if (arg_list->size () < 1) + throw jit_fail_exception ("Empty arg_list"); tree_expression *tree_object = exp.expression (); jit_value *object = visit (tree_object); - end_context.push_back (object); - - unwind_protect prot; - prot.add_method (&end_context, &std::vector::pop_back); + size_t narg = arg_list->size (); + tree_argument_list::iterator iter = arg_list->begin (); + std::vector call_args (narg + 1); + call_args[0] = object; - tree_expression *arg0 = arg_list->front (); - jit_value *index = visit (arg0); + for (size_t idx = 0; iter != arg_list->end (); ++idx, ++iter) + { + unwind_protect prot; + prot.add_method (&end_context, + &std::vector::pop_back); + end_context.push_back (jit_magic_end::context (object, idx, narg)); + call_args[idx + 1] = visit (*iter); + } - return std::make_pair (object, index); + return create_checked (fres, call_args); } jit_value * @@ -856,14 +858,8 @@ else if (tree_index_expression *idx = dynamic_cast (exp)) { - std::pair res = resolve (*idx); - jit_value *object = res.first; - jit_value *index = res.second; - jit_call *new_object = create (&jit_typeinfo::paren_subsasgn, - object, index, rhs); - block->append (new_object); + jit_value *new_object = resolve (jit_typeinfo::paren_subsasgn (), *idx); do_assign (idx->expression (), new_object, true); - create_check (new_object); // FIXME: Will not work for values that must be release/grabed return rhs; @@ -1853,4 +1849,17 @@ %! endfor %! assert (result == m(end) * niter); +%!test +%! ndim = 100; +%! result = 0; +%! m = zeros (ndim); +%! m(:) = 1:ndim^2; +%! i = 1; +%! while (i <= ndim) +%! for j = 1:ndim +%! result = result + m(i, j); +%! endfor +%! i = i + 1; +%! endwhile +%! 
assert (result == sum (sum (m))); */ diff --git a/src/pt-jit.h b/src/pt-jit.h --- a/src/pt-jit.h +++ b/src/pt-jit.h @@ -244,7 +244,7 @@ std::list all_values; - std::vector end_context; + std::vector end_context; size_t iterator_count; size_t for_bounds_count; @@ -296,7 +296,8 @@ std::string next_name (const char *prefix, size_t& count, bool inc); - std::pair resolve (tree_index_expression& exp); + jit_instruction *resolve (const jit_operation& fres, + tree_index_expression& exp); jit_value *do_assign (tree_expression *exp, jit_value *rhs, bool artificial = false);