changeset 15136:bd6bb87e2bea

Support sin, cos, and exp with matrix arguments in JIT * src/interp-core/jit-typeinfo.cc (jit_operation::generate): Remove unused parameter name. (jit_typeinfo::jit_typeinfo): Create any_call function. (jit_typeinfo::register_generic): Implement. * src/interp-core/jit-typeinfo.h (jit_typeinfo): New field, any_call. * src/interp-core/pt-jit.cc: New test.
author Max Brister <max@2bass.com>
date Thu, 09 Aug 2012 15:45:59 -0500
parents edae65062740
children eeaaac7c86b6
files src/interp-core/jit-typeinfo.cc src/interp-core/jit-typeinfo.h src/interp-core/pt-jit.cc
diffstat 3 files changed, 62 insertions(+), 4 deletions(-) [+]
line wrap: on
line diff
--- a/src/interp-core/jit-typeinfo.cc
+++ b/src/interp-core/jit-typeinfo.cc
@@ -837,7 +837,7 @@
 }
 
 jit_function *
-jit_operation::generate (const signature_vec& types) const
+jit_operation::generate (const signature_vec&) const
 {
   return 0;
 }
@@ -1041,6 +1041,7 @@
   complex = new_type ("complex", any, complex_t);
   scalar = new_type ("scalar", complex, scalar_t);
   scalar_ptr = new_type ("scalar_ptr", 0, scalar_t->getPointerTo ());
+  any_ptr = new_type ("any_ptr", 0, any_t->getPointerTo ());
   range = new_type ("range", any, range_t);
   string = new_type ("string", any, string_t);
   boolean = new_type ("bool", any, bool_t);
@@ -1080,6 +1081,14 @@
   engine->addGlobalMapping (lerror_state,
                             reinterpret_cast<void *> (&error_state));
 
+  // generic call function
+  {
+    jit_type *int_t = intN (sizeof (octave_builtin::fcn) * 8);
+    any_call = create_function (jit_convention::external, "octave_jit_call",
+                                any, int_t, int_t, any_ptr, int_t);
+    any_call.add_mapping (engine, &octave_jit_call);
+  }
+
   // any with anything is an any op
   jit_function fn;
   jit_type *binary_op_type = intN (sizeof (octave_value::binary_op) * 8);
@@ -1974,10 +1983,48 @@
 }
 
 void
-jit_typeinfo::register_generic (const std::string&, jit_type *,
-                                const std::vector<jit_type *>&)
+jit_typeinfo::register_generic (const std::string& name, jit_type *result,
+                                const std::vector<jit_type *>& args)
 {
-  // FIXME: Implement
+  octave_builtin *builtin = find_builtin (name);
+  if (! builtin)
+    return;
+
+  std::vector<jit_type *> fn_args (args.size () + 1);
+  fn_args[0] = builtins[name];
+  std::copy (args.begin (), args.end (), fn_args.begin () + 1);
+  jit_function fn = create_function (jit_convention::internal, name, result,
+                                     fn_args);
+  llvm::BasicBlock *block = fn.new_block ();
+  builder.SetInsertPoint (block);
+  llvm::Type *any_t = any->to_llvm ();
+  llvm::ArrayType *array_t = llvm::ArrayType::get (any_t, args.size ());
+  llvm::Value *array = llvm::UndefValue::get (array_t);
+  for (size_t i = 0; i < args.size (); ++i)
+    {
+      llvm::Value *arg = fn.argument (builder, i + 1);
+      jit_function agrab = get_grab (args[i]);
+      llvm::Value *garg = agrab.call (builder, arg);
+      jit_function acast = cast (any, args[i]);
+      array = builder.CreateInsertValue (array, acast.call (builder, garg), i);
+    }
+
+  llvm::Value *array_mem = builder.CreateAlloca (array_t);
+  builder.CreateStore (array, array_mem);
+  array = builder.CreateBitCast (array_mem, any_t->getPointerTo ());
+
+  jit_type *jintTy = intN (sizeof (octave_builtin::fcn) * 8);
+  llvm::Type *intTy = jintTy->to_llvm ();
+  size_t fcn_int = reinterpret_cast<size_t> (builtin->function ());
+  llvm::Value *fcn = llvm::ConstantInt::get (intTy, fcn_int);
+  llvm::Value *nargin = llvm::ConstantInt::get (intTy, args.size ());
+  size_t result_int = reinterpret_cast<size_t> (result);
+  llvm::Value *res_llvm = llvm::ConstantInt::get (intTy, result_int);
+  llvm::Value *ret = any_call.call (builder, fcn, nargin, array, res_llvm);
+
+  jit_function cast_result = cast (result, any);
+  fn.do_return (builder, cast_result.call (builder, ret));
+  paren_subsref_fn.add_overload (fn);
 }
 
 jit_function
--- a/src/interp-core/jit-typeinfo.h
+++ b/src/interp-core/jit-typeinfo.h
@@ -724,6 +724,7 @@
   jit_type *matrix;
   jit_type *scalar;
   jit_type *scalar_ptr; // a fake type for interfacing with C++
+  jit_type *any_ptr; // a fake type for interfacing with C++
   jit_type *range;
   jit_type *string;
   jit_type *boolean;
@@ -749,6 +750,8 @@
   jit_operation end1_fn;
   jit_operation end_fn;
 
+  jit_function any_call;
+
   // type id -> cast function TO that type
   std::vector<jit_operation> casts;
 
--- a/src/interp-core/pt-jit.cc
+++ b/src/interp-core/pt-jit.cc
@@ -1940,4 +1940,12 @@
 %! m2(2, :) = 1:1001;
 %! assert (m, m2);
 
+%!test
+%! m = [1 2 3];
+%! for i=1:1001
+%!   m = sin (m);
+%!   break;
+%! endfor
+%! assert (m == sin ([1  2 3]));
+
 */