changeset 15371:8355fddce815

Use sret and do not use save/restore stack (bug #37308) * jit-typeinfo.cc (octave_jit_grab_matrix, octave_jit_cast_matrix_any, octave_jit_paren_subsasgn_impl, octave_jit_paren_scalar_subsasgn, octave_jit_paren_subsasgn_matrix_range): Return matrix directly. (octave_jit_cast_range_any): Return range directly. (jit_function::jit_function): Maybe mark llvm function return as sret. (jit_function::call): Maybe mark llvm call sret and place allocas at function entry. (jit_function::do_return): Handle new parameter, verify. (jit_typeinfo::jit_typeinfo): Match C++ std::complex type better, pass jit_convention::external explicitly, and disable right complex division. (jit_typeinfo::create_identity): Improve name. (jit_typeinfo::pack_complex, jit_typeinfo::unpack_complex): Handle changed complex format. * jit-typeinfo.h (jit_array::jit_array): New overload. (jit_type::mark_sret, jit_type::mark_pointer_arg): Remove default convention. (jit_function::do_return): Add verify parameter. * pt-jit.cc (jit_convert_llvm::convert_function): Store the jit_function. (jit_convert_llvm::visit): Call do_return if converting a function. * pt-jit.h (jit_convert_llvm::creating): New member variable.
author Max Brister <max@2bass.com>
date Wed, 12 Sep 2012 19:18:51 -0600
parents 715220d2b511
children eec0d1fcba4f
files libinterp/interp-core/jit-typeinfo.cc libinterp/interp-core/jit-typeinfo.h libinterp/interp-core/pt-jit.cc libinterp/interp-core/pt-jit.h
diffstat 4 files changed, 111 insertions(+), 69 deletions(-) [+]
line wrap: on
line diff
--- a/libinterp/interp-core/jit-typeinfo.cc
+++ b/libinterp/interp-core/jit-typeinfo.cc
@@ -113,10 +113,10 @@
   return obv;
 }
 
-extern "C" void
-octave_jit_grab_matrix (jit_matrix *result, jit_matrix *m)
+extern "C" jit_matrix
+octave_jit_grab_matrix (jit_matrix *m)
 {
-  *result = *m->array;
+  return *m->array;
 }
 
 extern "C" octave_base_value *
@@ -130,12 +130,12 @@
   return rep;
 }
 
-extern "C" void
-octave_jit_cast_matrix_any (jit_matrix *ret, octave_base_value *obv)
+extern "C" jit_matrix
+octave_jit_cast_matrix_any (octave_base_value *obv)
 {
   NDArray m = obv->array_value ();
-  *ret = m;
   obv->release ();
+  return m;
 }
 
 extern "C" octave_base_value *
@@ -148,13 +148,13 @@
 
   return rep;
 }
-extern "C" void
-octave_jit_cast_range_any (jit_range *ret, octave_base_value *obv)
+extern "C" jit_range
+octave_jit_cast_range_any (octave_base_value *obv)
 {
 
   jit_range r (obv->range_value ());
-  *ret = r;
   obv->release ();
+  return r;
 }
 
 extern "C" double
@@ -228,9 +228,9 @@
     }
 }
 
-extern "C" void
-octave_jit_paren_subsasgn_impl (jit_matrix *ret, jit_matrix *mat,
-                                octave_idx_type index, double value)
+extern "C" jit_matrix
+octave_jit_paren_subsasgn_impl (jit_matrix *mat, octave_idx_type index,
+                                double value)
 {
   NDArray *array = mat->array;
   if (array->nelem () < index)
@@ -240,7 +240,7 @@
   data[index - 1] = value;
 
   mat->update ();
-  *ret = *mat;
+  return *mat;
 }
 
 static void
@@ -272,12 +272,12 @@
     }
 }
 
-extern "C" void
-octave_jit_paren_scalar_subsasgn (jit_matrix *ret, jit_matrix *mat,
-                                  double *indices, octave_idx_type idx_count,
-                                  double value)
+extern "C" jit_matrix
+octave_jit_paren_scalar_subsasgn (jit_matrix *mat, double *indices,
+                                  octave_idx_type idx_count, double value)
 {
   // FIXME: Replace this with a more optimal version
+  jit_matrix ret;
   try
     {
       Array<idx_vector> idx;
@@ -286,17 +286,19 @@
       Matrix temp (1, 1);
       temp.xelem(0) = value;
       mat->array->assign (idx, temp);
-      ret->update (mat->array);
+      ret.update (mat->array);
     }
   catch (const octave_execution_exception&)
     {
       gripe_library_execution_error ();
     }
+
+  return ret;
 }
 
-extern "C" void
-octave_jit_paren_subsasgn_matrix_range (jit_matrix *result, jit_matrix *mat,
-                                        jit_range *index, double value)
+extern "C" jit_matrix
+octave_jit_paren_subsasgn_matrix_range (jit_matrix *mat, jit_range *index,
+                                        double value)
 {
   NDArray *array = mat->array;
   bool done = false;
@@ -340,7 +342,9 @@
       array->assign (idx, avalue);
     }
 
-  result->update (array);
+  jit_matrix ret;
+  ret.update (array);
+  return ret;
 }
 
 extern "C" double
@@ -562,6 +566,10 @@
   llvm::FunctionType *ft = llvm::FunctionType::get (rtype, llvm_args, false);
   llvm_function = llvm::Function::Create (ft, llvm::Function::ExternalLinkage,
                                           aname, module);
+
+  if (sret ())
+    llvm_function->addAttribute (1, llvm::Attribute::StructRet);
+
   if (call_conv == jit_convention::internal)
     llvm_function->addFnAttr (llvm::Attribute::AlwaysInline);
 }
@@ -620,12 +628,18 @@
   llvm::SmallVector<llvm::Value *, 10> llvm_args;
   llvm_args.reserve (in_args.size () + sret ());
 
-  llvm::Value *sret_mem = 0;
-  llvm::Value *saved_stack = 0;
+  llvm::BasicBlock *insert_block = builder.GetInsertBlock ();
+  llvm::Function *parent = insert_block->getParent ();
+  assert (parent);
+
+  // we insert allocas inside the prelude block to prevent stack overflows
+  llvm::BasicBlock& prelude = parent->getEntryBlock ();
+  llvm::IRBuilder<> pre_builder (&prelude, prelude.begin ());
+
+  llvm::AllocaInst *sret_mem = 0;
   if (sret ())
     {
-      saved_stack = builder.CreateCall (stacksave);
-      sret_mem = builder.CreateAlloca (mresult->packed_type (call_conv));
+      sret_mem = pre_builder.CreateAlloca (mresult->packed_type (call_conv));
       llvm_args.push_back (sret_mem);
     }
 
@@ -638,19 +652,23 @@
 
       if (args[i]->pointer_arg (call_conv))
         {
-          if (! saved_stack)
-            saved_stack = builder.CreateCall (stacksave);
-
-          arg = builder.CreateAlloca (args[i]->to_llvm ());
-          builder.CreateStore (in_args[i], arg);
+          llvm::Type *ty = args[i]->packed_type (call_conv);
+          llvm::Value *alloca = pre_builder.CreateAlloca (ty);
+          builder.CreateStore (arg, alloca);
+          arg = alloca;
         }
 
       llvm_args.push_back (arg);
     }
 
-  llvm::Value *ret = builder.CreateCall (llvm_function, llvm_args);
-  if (sret_mem)
-    ret = builder.CreateLoad (sret_mem);
+  llvm::CallInst *callinst = builder.CreateCall (llvm_function, llvm_args);
+  llvm::Value *ret = callinst;
+
+  if (sret ())
+    {
+      callinst->addAttribute (1, llvm::Attribute::StructRet);
+      ret = builder.CreateLoad (sret_mem);
+    }
 
   if (mresult)
     {
@@ -659,14 +677,6 @@
         ret = unpack (builder, ret);
     }
 
-  if (saved_stack)
-    {
-      llvm::Function *stackrestore
-        = llvm::Intrinsic::getDeclaration (module,
-                                           llvm::Intrinsic::stackrestore);
-      builder.CreateCall (stackrestore, saved_stack);
-    }
-
   return ret;
 }
 
@@ -691,7 +701,8 @@
 }
 
 void
-jit_function::do_return (llvm::IRBuilderD& builder, llvm::Value *rval)
+jit_function::do_return (llvm::IRBuilderD& builder, llvm::Value *rval,
+                         bool verify)
 {
   assert (! rval == ! mresult);
 
@@ -702,14 +713,18 @@
         rval = convert (builder, rval);
 
       if (sret ())
-        builder.CreateStore (rval, llvm_function->arg_begin ());
+        {
+          builder.CreateStore (rval, llvm_function->arg_begin ());
+          builder.CreateRetVoid ();
+        }
       else
         builder.CreateRet (rval);
     }
   else
     builder.CreateRetVoid ();
 
-  llvm::verifyFunction (*llvm_function);
+  if (verify)
+    llvm::verifyFunction (*llvm_function);
 }
 
 void
@@ -1032,9 +1047,14 @@
 
   // complex_ret is what is passed to C functions in order to get calling
   // convention right
+  llvm::Type *cmplx_inner_cont[] = {scalar_t, scalar_t};
+  llvm::StructType *cmplx_inner = llvm::StructType::create (cmplx_inner_cont);
+
   complex_ret = llvm::StructType::create (context, "complex_ret");
-  llvm::Type *complex_ret_contents[] = {scalar_t, scalar_t};
-  complex_ret->setBody (complex_ret_contents);
+  {
+    llvm::Type *contents[] = {cmplx_inner};
+    complex_ret->setBody (contents);
+  }
 
   // create types
   any = new_type ("any", 0, any_t);
@@ -1059,18 +1079,18 @@
   // specify calling conventions
   // FIXME: We should detect architecture and do something sane based on that
   // here we assume x86 or x86_64
-  matrix->mark_sret ();
-  matrix->mark_pointer_arg ();
+  matrix->mark_sret (jit_convention::external);
+  matrix->mark_pointer_arg (jit_convention::external);
 
-  range->mark_sret ();
-  range->mark_pointer_arg ();
+  range->mark_sret (jit_convention::external);
+  range->mark_pointer_arg (jit_convention::external);
 
   complex->set_pack (jit_convention::external, &jit_typeinfo::pack_complex);
   complex->set_unpack (jit_convention::external, &jit_typeinfo::unpack_complex);
   complex->set_packed_type (jit_convention::external, complex_ret);
 
   if (sizeof (void *) == 4)
-    complex->mark_sret ();
+    complex->mark_sret (jit_convention::external);
 
   paren_subsref_fn.initialize (module, engine);
   paren_subsasgn_fn.initialize (module, engine);
@@ -1333,9 +1353,9 @@
   binary_ops[octave_value::op_div].add_overload (fn);
   binary_ops[octave_value::op_ldiv].add_overload (fn);
 
-  fn = mirror_binary (complex_div);
-  binary_ops[octave_value::op_ldiv].add_overload (fn);
-  binary_ops[octave_value::op_el_ldiv].add_overload (fn);
+  // fn = mirror_binary (complex_div);
+  // binary_ops[octave_value::op_ldiv].add_overload (fn);
+  // binary_ops[octave_value::op_el_ldiv].add_overload (fn);
 
   fn = create_function (jit_convention::external,
                         "octave_jit_pow_complex_complex", complex, complex,
@@ -1990,8 +2010,11 @@
 
   if (! identities[id].valid ())
     {
-      jit_function fn = create_function (jit_convention::internal, "id", type,
-                                         type);
+      std::stringstream name;
+      name << "id_" << type->name ();
+      jit_function fn = create_function (jit_convention::internal, name.str (),
+                                         type, type);
+
       llvm::BasicBlock *body = fn.new_block ();
       builder.SetInsertPoint (body);
       fn.do_return (builder, fn.argument (builder, 0));
@@ -2141,17 +2164,24 @@
   llvm::Value *real = bld.CreateExtractElement (cplx, bld.getInt32 (0));
   llvm::Value *imag = bld.CreateExtractElement (cplx, bld.getInt32 (1));
   llvm::Value *ret = llvm::UndefValue::get (complex_ret);
-  ret = bld.CreateInsertValue (ret, real, 0);
-  return bld.CreateInsertValue (ret, imag, 1);
+
+  unsigned int re_idx[] = {0, 0};
+  unsigned int im_idx[] = {0, 1};
+  ret = bld.CreateInsertValue (ret, real, re_idx);
+  return bld.CreateInsertValue (ret, imag, im_idx);
 }
 
 llvm::Value *
 jit_typeinfo::unpack_complex (llvm::IRBuilderD& bld, llvm::Value *result)
 {
+  unsigned int re_idx[] = {0, 0};
+  unsigned int im_idx[] = {0, 1};
+
   llvm::Type *complex_t = get_complex ()->to_llvm ();
-  llvm::Value *real = bld.CreateExtractValue (result, 0);
-  llvm::Value *imag = bld.CreateExtractValue (result, 1);
+  llvm::Value *real = bld.CreateExtractValue (result, re_idx);
+  llvm::Value *imag = bld.CreateExtractValue (result, im_idx);
   llvm::Value *ret = llvm::UndefValue::get (complex_t);
+
   ret = bld.CreateInsertElement (ret, real, bld.getInt32 (0));
   return bld.CreateInsertElement (ret, imag, bld.getInt32 (1));
 }
--- a/libinterp/interp-core/jit-typeinfo.h
+++ b/libinterp/interp-core/jit-typeinfo.h
@@ -66,6 +66,8 @@
 struct
 jit_array
 {
+  jit_array () : array (0) {}
+
   jit_array (T& from) : array (new T (from))
   {
     update ();
@@ -161,7 +163,7 @@
   // retval. (on the stack)
   bool sret (jit_convention::type cc) const { return msret[cc]; }
 
-  void mark_sret (jit_convention::type cc = jit_convention::external)
+  void mark_sret (jit_convention::type cc)
   { msret[cc] = true; }
 
   // A function like: void foo (mytype arg0)
@@ -169,7 +171,7 @@
   // Basically just pass by reference.
   bool pointer_arg (jit_convention::type cc) const { return mpointer_arg[cc]; }
 
-  void mark_pointer_arg (jit_convention::type cc = jit_convention::external)
+  void mark_pointer_arg (jit_convention::type cc)
   { mpointer_arg[cc] = true; }
 
   // Convert into an equivalent form before calling. For example, complex is
@@ -278,7 +280,8 @@
 
   llvm::Value *argument (llvm::IRBuilderD& builder, size_t idx) const;
 
-  void do_return (llvm::IRBuilderD& builder, llvm::Value *rval = 0);
+  void do_return (llvm::IRBuilderD& builder, llvm::Value *rval = 0,
+                  bool verify = true);
 
   llvm::Function *to_llvm (void) const { return llvm_function; }
 
--- a/libinterp/interp-core/pt-jit.cc
+++ b/libinterp/interp-core/pt-jit.cc
@@ -1075,8 +1075,8 @@
   jit_return *ret = dynamic_cast<jit_return *> (final_block->back ());
   assert (ret);
 
-  jit_function creating = jit_function (module, jit_convention::internal,
-                                        "foobar", ret->result_type (), args);
+  creating = jit_function (module, jit_convention::internal,
+                           "foobar", ret->result_type (), args);
   function = creating.to_llvm ();
 
   try
@@ -1280,10 +1280,16 @@
 jit_convert_llvm::visit (jit_return& ret)
 {
   jit_value *res = ret.result ();
-  if (res)
-    builder.CreateRet (res->to_llvm ());
+
+  if (converting_function)
+    creating.do_return (builder, res->to_llvm (), false);
   else
-    builder.CreateRetVoid ();
+    {
+      if (res)
+        builder.CreateRet (res->to_llvm ());
+      else
+        builder.CreateRetVoid ();
+    }
 }
 
 void
--- a/libinterp/interp-core/pt-jit.h
+++ b/libinterp/interp-core/pt-jit.h
@@ -276,6 +276,9 @@
 
   bool converting_function;
 
+  // only used if we are converting a function
+  jit_function creating;
+
   llvm::Function *function;
   llvm::BasicBlock *prelude;