changeset 15068:df4538e3b50b

ND scalar indexing in JIT. * src/jit-ir.cc (jit_magic_end::jit_magic_end): Use jit_magic_end::context. * src/jit-ir.h (jit_call::jit_call): New overload. (jit_magic_end::context): New class. (jit_magic_end::jit_magic_end): moved to src/jit-ir.cc. * src/jit-typeinfo.cc (octave_jit_paren_scalar): New function. (jit_typeinfo::jit_typeinfo): Generate ND scalar indexing. (jit_typeinfo::gen_subsref): New function. * src/jit-typeinfo.h (jit_typeinfo::gen_subsref): New declaration. * src/pt-jit.cc (jit_convert::visit_index_expression, jit_convert::do_assign): Update resolve call. (jit_convert::resolve): Resolve ND indices. * src/pt-jit.h (jit_convert::resolve): Change function signature.
author Max Brister <max@2bass.com>
date Tue, 31 Jul 2012 11:51:01 -0500
parents 6451a584305e
children f57d7578c1a6
files src/jit-ir.cc src/jit-ir.h src/jit-typeinfo.cc src/jit-typeinfo.h src/pt-jit.cc src/pt-jit.h
diffstat 6 files changed, 143 insertions(+), 28 deletions(-) [+]
line wrap: on
line diff
--- a/src/jit-ir.cc
+++ b/src/jit-ir.cc
@@ -599,6 +599,22 @@
 }
 
 // -------------------- jit_magic_end --------------------
+jit_magic_end::jit_magic_end (const std::vector<context>& full_context)
+{
+  // For now we only support end in 1 dimensional (linear) indexing; every context entry is validated in the loop below.
+  resize_arguments (full_context.size ());  // one argument slot per context entry
+
+  size_t i;
+  std::vector<context>::const_iterator iter;
+  for (iter = full_context.begin (), i = 0; iter != full_context.end (); ++iter,
+         ++i)
+    {
+      if (iter->count != 1)  // an entry whose count is not 1 is rejected as a non-linear context
+        throw jit_fail_exception ("end is only supported in linear contexts");
+      stash_argument (i, iter->value);  // argument i is the value recorded in context entry i
+    }
+}
+
 const jit_function&
 jit_magic_end::overload () const
 {
--- a/src/jit-ir.h
+++ b/src/jit-ir.h
@@ -1074,6 +1074,10 @@
 
 #undef JIT_CALL_CONST
 
+  jit_call (const jit_operation& aoperation,  // overload taking the call arguments as a vector
+            const std::vector<jit_value *>& args)  // args is forwarded unchanged to jit_instruction
+  : jit_instruction (args), moperation (aoperation)
+  {}
 
   const jit_operation& operation (void) const { return moperation; }
 
@@ -1151,9 +1155,23 @@
 jit_magic_end : public jit_instruction
 {
 public:
-  jit_magic_end (const std::vector<jit_value *>& context)
-    : jit_instruction (context)
-  {}
+  class  // one entry of the context stack handed to the jit_magic_end constructor
+  context
+  {
+  public:
+    context (void) : value (0), index (0), count (0)  // empty entry: null value, zero index and count
+    {}
+
+    context (jit_value *avalue, size_t aindex, size_t acount)
+      : value (avalue), index (aindex), count (acount)
+    {}
+
+    jit_value *value;  // stashed as an argument by the jit_magic_end constructor; jit_convert::resolve stores the indexed object here
+    size_t index;  // jit_convert::resolve stores the zero-based position of the index argument being visited
+    size_t count;  // jit_convert::resolve stores the total number of index arguments; the constructor requires 1
+  };
+
+  jit_magic_end (const std::vector<context>& full_context);
 
   const jit_function& overload () const;
 
--- a/src/jit-typeinfo.cc
+++ b/src/jit-typeinfo.cc
@@ -243,6 +243,27 @@
   *ret = *mat;
 }
 
+extern "C" double  // C linkage; the address of this function is registered with the engine via add_mapping
+octave_jit_paren_scalar (jit_matrix *mat, double *indicies,  // indicies: array of idx_count scalar subscripts
+                         octave_idx_type idx_count)
+{
+  // FIXME: Replace this with a more optimal version (this one builds a temporary Array<idx_vector> and a result Array per call)
+  try
+    {
+      Array<idx_vector> idx (dim_vector (1, idx_count));  // 1 x idx_count row of index vectors
+      for (octave_idx_type i = 0; i < idx_count; ++i)
+        idx(i) = idx_vector (indicies[i]);  // one idx_vector built from each double subscript
+
+      Array<double> ret = mat->array->index (idx);
+      return ret.xelem (0);  // only the first element of the indexing result is returned
+    }
+  catch (const octave_execution_exception&)
+    {
+      gripe_library_execution_error ();
+      return 0;  // value returned on failure, after gripe_library_execution_error () has been called
+    }
+}
+
 extern "C" void
 octave_jit_paren_subsasgn_matrix_range (jit_matrix *result, jit_matrix *mat,
                                         jit_range *index, double value)
@@ -789,6 +810,9 @@
   boolean = new_type ("bool", any, bool_t);
   index = new_type ("index", any, index_t);
 
+  // a fake type for interfacing with C++
+  jit_type *scalar_ptr = new_type ("scalar_ptr", 0, scalar_t->getPointerTo ());
+
   create_int (8);
   create_int (16);
   create_int (32);
@@ -1310,6 +1334,18 @@
   }
   paren_subsref_fn.add_overload (fn);
 
+  // generate () subsref for ND indexing of matrices with scalars
+  jit_function paren_scalar = create_function (jit_convention::external,
+                                               "octave_jit_paren_scalar",
+                                               scalar, matrix, scalar_ptr,
+                                               index);
+  paren_scalar.add_mapping (engine, &octave_jit_paren_scalar);
+  paren_scalar.mark_can_error ();
+
+  // FIXME: Generate this on the fly
+  for (size_t i = 2; i < 10; ++i)
+    gen_subsref (paren_scalar, i);
+
   // paren subsasgn
   paren_subsasgn_fn.stash_name ("()subsasgn");
 
@@ -1831,4 +1867,37 @@
   return ret;
 }
 
+void
+jit_typeinfo::gen_subsref (const jit_function& paren_scalar, size_t n)  // adds a paren_subsref_fn overload taking a matrix and n scalars
+{
+  std::stringstream name;
+  name << "jit_paren_subsref_matrix_scalar" << n;  // function name carries the index count as a suffix
+  std::vector<jit_type *> args (n + 1, scalar);  // n scalar parameters plus one leading slot
+  args[0] = matrix;  // the leading parameter is the matrix
+  jit_function fn = create_function (jit_convention::internal, name.str (),
+                                     scalar, args);
+  fn.mark_can_error ();
+  llvm::BasicBlock *body = fn.new_block ();
+  builder.SetInsertPoint (body);
+
+  llvm::Type *scalar_t = scalar->to_llvm ();
+  llvm::ArrayType *array_t = llvm::ArrayType::get (scalar_t, n);  // array type holding n scalars
+  llvm::Value *array = llvm::UndefValue::get (array_t);  // start from an undefined aggregate and fill every slot
+  for (size_t i = 0; i < n; ++i)
+    {
+      llvm::Value *idx = fn.argument (builder, i + 1);  // argument 0 is the matrix, so index i is argument i + 1
+      array = builder.CreateInsertValue (array, idx, i);
+    }
+
+  llvm::Value *array_mem = builder.CreateAlloca (array_t);  // storage for the array so a pointer to it can be passed
+  builder.CreateStore (array, array_mem);
+  array = builder.CreateBitCast (array_mem, scalar_t->getPointerTo ());  // array pointer cast to a scalar pointer
+
+  llvm::Value *nelem = llvm::ConstantInt::get (index->to_llvm (), n);  // n as a constant of the index type
+  llvm::Value *mat = fn.argument (builder, 0);
+  llvm::Value *ret = paren_scalar.call (builder, mat, array, nelem);  // call paren_scalar with (matrix, pointer, count)
+  fn.do_return (builder, ret);
+  paren_subsref_fn.add_overload (fn);  // register the generated function as an overload of paren_subsref_fn
+}
+
 #endif
--- a/src/jit-typeinfo.h
+++ b/src/jit-typeinfo.h
@@ -631,6 +631,8 @@
 
   jit_type *intN (size_t nbits) const;
 
+  void gen_subsref (const jit_function& paren_scalar, size_t n);
+
   static jit_typeinfo *instance;
 
   llvm::Module *module;
--- a/src/pt-jit.cc
+++ b/src/pt-jit.cc
@@ -518,11 +518,7 @@
 void
 jit_convert::visit_index_expression (tree_index_expression& exp)
 {
-  std::pair<jit_value *, jit_value *> res = resolve (exp);
-  jit_value *object = res.first;
-  jit_value *index = res.second;
-
-  result = create_checked (jit_typeinfo::paren_subsref, object, index);
+  result = resolve (jit_typeinfo::paren_subsref (), exp);
 }
 
 void
@@ -813,8 +809,8 @@
   return ss.str ();
 }
 
-std::pair<jit_value *, jit_value *>
-jit_convert::resolve (tree_index_expression& exp)
+jit_instruction *
+jit_convert::resolve (const jit_operation& fres, tree_index_expression& exp)
 {
   std::string type = exp.type_tags ();
   if (! (type.size () == 1 && type[0] == '('))
@@ -828,21 +824,27 @@
   if (! arg_list)
     throw jit_fail_exception ("null argument list");
 
-  if (arg_list->size () != 1)
-    throw jit_fail_exception ("Bad number of arguments in arg_list");
+  if (arg_list->size () < 1)
+    throw jit_fail_exception ("Empty arg_list");
 
   tree_expression *tree_object = exp.expression ();
   jit_value *object = visit (tree_object);
 
-  end_context.push_back (object);
-
-  unwind_protect prot;
-  prot.add_method (&end_context, &std::vector<jit_value *>::pop_back);
+  size_t narg = arg_list->size ();
+  tree_argument_list::iterator iter = arg_list->begin ();
+  std::vector<jit_value *> call_args (narg + 1);
+  call_args[0] = object;
 
-  tree_expression *arg0 = arg_list->front ();
-  jit_value *index = visit (arg0);
+  for (size_t idx = 0; iter != arg_list->end (); ++idx, ++iter)
+    {
+      unwind_protect prot;
+      prot.add_method (&end_context,
+                       &std::vector<jit_magic_end::context>::pop_back);
+      end_context.push_back (jit_magic_end::context (object, idx, narg));
+      call_args[idx + 1] = visit (*iter);
+    }
 
-  return std::make_pair (object, index);
+  return create_checked (fres, call_args);
 }
 
 jit_value *
@@ -856,14 +858,8 @@
   else if (tree_index_expression *idx
            = dynamic_cast<tree_index_expression *> (exp))
     {
-      std::pair<jit_value *, jit_value *> res = resolve (*idx);
-      jit_value *object = res.first;
-      jit_value *index = res.second;
-      jit_call *new_object = create<jit_call> (&jit_typeinfo::paren_subsasgn,
-                                               object, index, rhs);
-      block->append (new_object);
+      jit_value *new_object = resolve (jit_typeinfo::paren_subsasgn (), *idx);
       do_assign (idx->expression (), new_object, true);
-      create_check (new_object);
 
       // FIXME: Will not work for values that must be release/grabed
       return rhs;
@@ -1853,4 +1849,17 @@
 %! endfor
 %! assert (result == m(end) * niter);
 
+%!test
+%! ndim = 100;
+%! result = 0;
+%! m = zeros (ndim);
+%! m(:) = 1:ndim^2;
+%! i = 1;
+%! while (i <= ndim)
+%!   for j = 1:ndim
+%!     result = result + m(i, j);
+%!    endfor
+%!   i = i + 1;
+%! endwhile
+%! assert (result == sum (sum (m)));
 */
--- a/src/pt-jit.h
+++ b/src/pt-jit.h
@@ -244,7 +244,7 @@
 
   std::list<jit_value *> all_values;
 
-  std::vector<jit_value *> end_context;
+  std::vector<jit_magic_end::context> end_context;
 
   size_t iterator_count;
   size_t for_bounds_count;
@@ -296,7 +296,8 @@
 
   std::string next_name (const char *prefix, size_t& count, bool inc);
 
-  std::pair<jit_value *, jit_value *> resolve (tree_index_expression& exp);
+  jit_instruction *resolve (const jit_operation& fres,
+                            tree_index_expression& exp);
 
   jit_value *do_assign (tree_expression *exp, jit_value *rhs,
                         bool artificial = false);