Mercurial > hg > octave-terminal
changeset 15068:f57d7578c1a6
Support ND matrix indexing with scalar assignment in JIT.
* src/jit-typeinfo.cc (make_indices, octave_jit_paren_scalar_subsasgn,
jit_typeinfo::gen_subsasgn): New functions.
(octave_jit_paren_scalar): Use make_indices.
(jit_typeinfo::jit_typeinfo): Call gen_subsasgn.
* src/jit-typeinfo.h (jit_typeinfo::gen_subsasgn): New declaration.
* src/pt-jit.cc (jit_convert::resolve): Add extra_arg argument.
(jit_convert::do_assign): Pass rhs to resolve.
* src/pt-jit.h (jit_convert::resolve): Change function signature.
author | Max Brister <max@2bass.com> |
---|---|
date | Tue, 31 Jul 2012 15:40:52 -0500 |
parents | df4538e3b50b |
children | 7a3957ca99c3 |
files | src/jit-typeinfo.cc src/jit-typeinfo.h src/pt-jit.cc src/pt-jit.h |
diffstat | 4 files changed, 106 insertions(+), 8 deletions(-) [+] |
line wrap: on
line diff
--- a/src/jit-typeinfo.cc +++ b/src/jit-typeinfo.cc @@ -243,6 +243,15 @@ *ret = *mat; } +static void +make_indices (double *indices, octave_idx_type idx_count, + Array<idx_vector>& result) +{ + result.resize (dim_vector (1, idx_count)); + for (octave_idx_type i = 0; i < idx_count; ++i) + result(i) = idx_vector (indices[i]); +} + extern "C" double octave_jit_paren_scalar (jit_matrix *mat, double *indicies, octave_idx_type idx_count) @@ -250,9 +259,8 @@ // FIXME: Replace this with a more optimal version try { - Array<idx_vector> idx (dim_vector (1, idx_count)); - for (octave_idx_type i = 0; i < idx_count; ++i) - idx(i) = idx_vector (indicies[i]); + Array<idx_vector> idx; + make_indices (indicies, idx_count, idx); Array<double> ret = mat->array->index (idx); return ret.xelem (0); @@ -265,6 +273,28 @@ } extern "C" void +octave_jit_paren_scalar_subsasgn (jit_matrix *ret, jit_matrix *mat, + double *indices, octave_idx_type idx_count, + double value) +{ + // FIXME: Replace this with a more optimal version + try + { + Array<idx_vector> idx; + make_indices (indices, idx_count, idx); + + Matrix temp (1, 1); + temp.xelem(0) = value; + mat->array->assign (idx, temp); + ret->update (mat->array); + } + catch (const octave_execution_exception&) + { + gripe_library_execution_error (); + } +} + +extern "C" void octave_jit_paren_subsasgn_matrix_range (jit_matrix *result, jit_matrix *mat, jit_range *index, double value) { @@ -1342,9 +1372,19 @@ paren_scalar.add_mapping (engine, &octave_jit_paren_scalar); paren_scalar.mark_can_error (); + jit_function paren_scalar_subsasgn + = create_function (jit_convention::external, + "octave_jit_paren_scalar_subsasgn", matrix, matrix, + scalar_ptr, index, scalar); + paren_scalar_subsasgn.add_mapping (engine, &octave_jit_paren_scalar_subsasgn); + paren_scalar_subsasgn.mark_can_error (); + // FIXME: Generate this on the fly for (size_t i = 2; i < 10; ++i) - gen_subsref (paren_scalar, i); + { + gen_subsref (paren_scalar, i); + gen_subsasgn 
(paren_scalar_subsasgn, i); + } // paren subsasgn paren_subsasgn_fn.stash_name ("()subsasgn"); @@ -1900,4 +1940,38 @@ paren_subsref_fn.add_overload (fn); } +void +jit_typeinfo::gen_subsasgn (const jit_function& paren_scalar, size_t n) +{ + std::stringstream name; + name << "jit_paren_subsasgn_matrix_scalar" << n; + std::vector<jit_type *> args (n + 2, scalar); + args[0] = matrix; + jit_function fn = create_function (jit_convention::internal, name.str (), + matrix, args); + fn.mark_can_error (); + llvm::BasicBlock *body = fn.new_block (); + builder.SetInsertPoint (body); + + llvm::Type *scalar_t = scalar->to_llvm (); + llvm::ArrayType *array_t = llvm::ArrayType::get (scalar_t, n); + llvm::Value *array = llvm::UndefValue::get (array_t); + for (size_t i = 0; i < n; ++i) + { + llvm::Value *idx = fn.argument (builder, i + 1); + array = builder.CreateInsertValue (array, idx, i); + } + + llvm::Value *array_mem = builder.CreateAlloca (array_t); + builder.CreateStore (array, array_mem); + array = builder.CreateBitCast (array_mem, scalar_t->getPointerTo ()); + + llvm::Value *nelem = llvm::ConstantInt::get (index->to_llvm (), n); + llvm::Value *mat = fn.argument (builder, 0); + llvm::Value *value = fn.argument (builder, n + 1); + llvm::Value *ret = paren_scalar.call (builder, mat, array, nelem, value); + fn.do_return (builder, ret); + paren_subsasgn_fn.add_overload (fn); +} + #endif
--- a/src/jit-typeinfo.h +++ b/src/jit-typeinfo.h @@ -633,6 +633,8 @@ void gen_subsref (const jit_function& paren_scalar, size_t n); + void gen_subsasgn (const jit_function& paren_scalar, size_t n); + static jit_typeinfo *instance; llvm::Module *module;
--- a/src/pt-jit.cc +++ b/src/pt-jit.cc @@ -810,7 +810,8 @@ } jit_instruction * -jit_convert::resolve (const jit_operation& fres, tree_index_expression& exp) +jit_convert::resolve (const jit_operation& fres, tree_index_expression& exp, + jit_value *extra_arg) { std::string type = exp.type_tags (); if (! (type.size () == 1 && type[0] == '(')) @@ -832,7 +833,8 @@ size_t narg = arg_list->size (); tree_argument_list::iterator iter = arg_list->begin (); - std::vector<jit_value *> call_args (narg + 1); + bool have_extra = extra_arg; + std::vector<jit_value *> call_args (narg + 1 + have_extra); call_args[0] = object; for (size_t idx = 0; iter != arg_list->end (); ++idx, ++iter) @@ -844,6 +846,9 @@ call_args[idx + 1] = visit (*iter); } + if (extra_arg) + call_args[call_args.size () - 1] = extra_arg; + return create_checked (fres, call_args); } @@ -858,7 +863,8 @@ else if (tree_index_expression *idx = dynamic_cast<tree_index_expression *> (exp)) { - jit_value *new_object = resolve (jit_typeinfo::paren_subsasgn (), *idx); + jit_value *new_object = resolve (jit_typeinfo::paren_subsasgn (), *idx, + rhs); do_assign (idx->expression (), new_object, true); // FIXME: Will not work for values that must be release/grabed @@ -1862,4 +1868,19 @@ %! i = i + 1; %! endwhile %! assert (result == sum (sum (m))); + +%!test +%! ndim = 100; +%! m = zeros (ndim); +%! i = 1; +%! while (i <= ndim) +%! for j = 1:ndim +%! m(i, j) = (j - 1) * ndim + i; +%! endfor +%! i = i + 1; +%! endwhile +%! m2 = zeros (ndim); +%! m2(:) = 1:(ndim^2); +%! assert (all (m == m2)); + */
--- a/src/pt-jit.h +++ b/src/pt-jit.h @@ -297,7 +297,8 @@ std::string next_name (const char *prefix, size_t& count, bool inc); jit_instruction *resolve (const jit_operation& fres, - tree_index_expression& exp); + tree_index_expression& exp, + jit_value *extra_arg = 0); jit_value *do_assign (tree_expression *exp, jit_value *rhs, bool artificial = false);