changeset 14937:f5925478bc15

More support for complex numbers in JIT * src/pt-jit.cc (octave_jit_cast_complex_any): Return result directly. (octave_jit_complex_div, jit_typeinfo::wrap_complex, jit_typeinfo::pack_complex, jit_typeinfo::unpack_complex): New functions. (jit_typeinfo::jit_typeinfo): Support more complex-number functionality. (tree_jit::optimize): Write LLVM bitcode to a file when debugging. * src/pt-jit.h (jit_typeinfo::wrap_complex, jit_typeinfo::pack_complex, jit_typeinfo::unpack_complex): New declarations.
author Max Brister <max@2bass.com>
date Tue, 10 Jul 2012 15:55:05 -0500
parents 561aad6a9e4b
children 70ff15b6d996
files src/pt-jit.cc src/pt-jit.h
diffstat 2 files changed, 180 insertions(+), 42 deletions(-) [+]
line wrap: on
line diff
--- a/src/pt-jit.cc
+++ b/src/pt-jit.cc
@@ -50,6 +50,8 @@
 #include <llvm/Transforms/IPO.h>
 #include <llvm/Support/TargetSelect.h>
 #include <llvm/Support/raw_os_ostream.h>
+#include <llvm/Support/FormattedStream.h>
+#include <llvm/Bitcode/ReaderWriter.h>
 
 #include "octave.h"
 #include "ov-fcn-handle.h"
@@ -199,22 +201,21 @@
   return new octave_scalar (value);
 }
 
-extern "C" void
-octave_jit_cast_complex_any (double *dest, octave_base_value *obv)
+extern "C" Complex
+octave_jit_cast_complex_any (octave_base_value *obv)
 {
   Complex ret = obv->complex_value ();
   obv->release ();
-  dest[0] = ret.real ();
-  dest[1] = ret.imag ();
+  return ret;
 }
 
 extern "C" octave_base_value *
-octave_jit_cast_any_complex (double real, double imag)
+octave_jit_cast_any_complex (Complex c)
 {
-  if (imag == 0)
-    return new octave_scalar (real);
+  if (c.imag () == 0)
+    return new octave_scalar (c.real ());
   else
-    return new octave_complex (Complex (real, imag));
+    return new octave_complex (c);
 }
 
 extern "C" void
@@ -320,6 +321,16 @@
   result->update (array);
 }
 
+extern "C" Complex
+octave_jit_complex_div (Complex lhs, Complex rhs)
+{
+  // see src/OPERATORS/op-cs-cs.cc
+  if (rhs == 0.0)
+    gripe_divide_by_zero ();
+
+  return lhs / rhs;
+}
+
 extern "C" void
 octave_jit_print_matrix (jit_matrix *m)
 {
@@ -522,6 +533,12 @@
 
   llvm::Type *complex_t = llvm::VectorType::get (scalar_t, 2);
 
+  // Struct that C functions use to pass and return Complex values. Declare
+  // external functions with it so that the calling convention is right.
+  complex_ret = llvm::StructType::create (context, "complex_ret");
+  llvm::Type *complex_ret_contents[] = {scalar_t, scalar_t};
+  complex_ret->setBody (complex_ret_contents);
+
   // create types
   any = new_type ("any", 0, any_t);
   matrix = new_type ("matrix", any, matrix_t);
@@ -734,13 +751,11 @@
     llvm::Value *tlhs = builder.CreateExtractElement (mres, zero);
     llvm::Value *trhs = builder.CreateExtractElement (mres, one);
     temp = builder.CreateFSub (tlhs, trhs);
-    //temp = llvm::ConstantFP::get (scalar_t, 123);
     ret = builder.CreateInsertElement (ret, temp, zero);
 
     tlhs = builder.CreateExtractElement (mres, two);
     trhs = builder.CreateExtractElement (mres, three);
     temp = builder.CreateFAdd (tlhs, trhs);
-    //temp = llvm::ConstantFP::get (scalar_t, 123);
     ret = builder.CreateInsertElement (ret, temp, one);
     builder.CreateRet (ret);
 
@@ -750,6 +765,67 @@
   }
   llvm::verifyFunction (*fn);
 
+  fn = create_function ("octave_jit_*_scalar_complex", complex, scalar,
+                        complex);
+  llvm::Function *mul_scalar_complex = fn;
+  body = llvm::BasicBlock::Create (context, "body", fn);
+  builder.SetInsertPoint (body);
+  {
+    llvm::Value *lhs = fn->arg_begin ();
+    llvm::Value *tlhs = llvm::UndefValue::get (complex_t);
+    tlhs = builder.CreateInsertElement (tlhs, lhs, builder.getInt32 (0));
+    tlhs = builder.CreateInsertElement (tlhs, lhs, builder.getInt32 (1));
+
+    llvm::Value *rhs = ++fn->arg_begin ();
+    builder.CreateRet (builder.CreateFMul (tlhs, rhs));
+
+    jit_operation::overload ol (fn, false, complex, scalar, complex);
+    binary_ops[octave_value::op_mul].add_overload (ol);
+    binary_ops[octave_value::op_el_mul].add_overload (ol);
+  }
+  llvm::verifyFunction (*fn);
+
+  fn = create_function ("octave_jit_*_complex_scalar", complex, complex,
+                        scalar);
+  body = llvm::BasicBlock::Create (context, "body", fn);
+  builder.SetInsertPoint (body);
+  {
+    llvm::Value *ret = builder.CreateCall2 (mul_scalar_complex,
+                                            ++fn->arg_begin (),
+                                            fn->arg_begin ());
+    builder.CreateRet (ret);
+
+    jit_operation::overload ol (fn, false, complex, complex,  scalar);
+    binary_ops[octave_value::op_mul].add_overload (ol);
+    binary_ops[octave_value::op_el_mul].add_overload (ol);
+  }
+  llvm::verifyFunction (*fn);
+
+  llvm::Function *complex_div = create_function ("octave_jit_complex_div",
+                                                 complex_ret, complex_ret,
+                                                 complex_ret);
+  engine->addGlobalMapping (complex_div,
+                            reinterpret_cast<void *> (&octave_jit_complex_div));
+  complex_div = wrap_complex (complex_div);
+  {
+    jit_operation::overload ol (complex_div, true, complex, complex, complex);
+    binary_ops[octave_value::op_div].add_overload (ol);
+    binary_ops[octave_value::op_ldiv].add_overload (ol);
+  }
+
+  fn = create_function ("octave_jit_\\_complex_complex", complex, complex,
+                        complex);
+  body = llvm::BasicBlock::Create (context, "body", fn);
+  builder.SetInsertPoint (body);
+  {
+    builder.CreateRet (builder.CreateCall2 (complex_div, ++fn->arg_begin (),
+                                            fn->arg_begin ()));
+    jit_operation::overload ol (fn, true, complex, complex, complex);
+    binary_ops[octave_value::op_ldiv].add_overload (ol);
+    binary_ops[octave_value::op_el_ldiv].add_overload (ol);
+  }
+  llvm::verifyFunction (*fn);
+
   // now for binary index operators
   add_binary_op (index, octave_value::op_add, llvm::Instruction::Add);
 
@@ -1089,40 +1165,15 @@
   casts[scalar->type_id ()].add_overload (fn, false, scalar, any);
 
   // cast any <- complex
-  llvm::Function *any_complex = create_function ("octave_jit_cast_any_complex",
-                                                 any, scalar, scalar);
-  engine->addGlobalMapping (any_complex, reinterpret_cast<void*> (&octave_jit_cast_any_complex));
-  fn = create_function ("cast_any_complex", any, complex);
-  body = llvm::BasicBlock::Create (context, "body", fn);
-  builder.SetInsertPoint (body);
-  {
-    llvm::Value *zero = builder.getInt32 (0);
-    llvm::Value *one = builder.getInt32 (1);
-
-    llvm::Value *cmplx = fn->arg_begin ();
-    llvm::Value *real = builder.CreateExtractElement (cmplx, zero);
-    llvm::Value *imag = builder.CreateExtractElement (cmplx, one);
-    llvm::Value *ret = builder.CreateCall2 (any_complex, real, imag);
-    builder.CreateRet (ret);
-  }
-  llvm::verifyFunction (*fn);
-  casts[any->type_id ()].add_overload (fn, false, any, complex);
+  fn = create_function ("octave_jit_cast_any_complex", any_t, complex_ret);
+  engine->addGlobalMapping (fn, reinterpret_cast<void*> (&octave_jit_cast_any_complex));
+  casts[any->type_id ()].add_overload (wrap_complex (fn), false, any, complex);
 
   // cast complex <- any
-  llvm::Function *complex_any = create_function ("octave_jit_cast_complex_any",
-                                                 void_t,
-                                                 complex_t->getPointerTo (),
-                                                 any_t);
-  fn = create_function ("cast_complex_any", complex, any);
-  body = llvm::BasicBlock::Create (context, "body", fn);
-  builder.SetInsertPoint (body);
-  {
-    llvm::Value *result = builder.CreateAlloca (complex_t);
-    builder.CreateCall2 (complex_any, result, fn->arg_begin ());
-    builder.CreateRet (builder.CreateLoad (result));
-  }
-  llvm::verifyFunction (*fn);
-  casts[complex->type_id ()].add_overload (fn, false, complex, any);
+  fn = create_function ("octave_jit_cast_complex_any", complex_ret, any_t);
+  engine->addGlobalMapping (fn, reinterpret_cast<void *> (&octave_jit_cast_complex_any));
+  casts[complex->type_id ()].add_overload (wrap_complex (fn), false, complex,
+                                           any);
 
   // cast any <- any
   fn = create_identity (any);
@@ -1363,6 +1414,77 @@
   // FIXME: Implement
 }
 
+llvm::Function *
+jit_typeinfo::wrap_complex (llvm::Function *wrap)
+{
+  llvm::SmallVector<llvm::Type *, 5> new_args;
+  new_args.reserve (wrap->arg_size ());
+  llvm::Type *complex_t = complex->to_llvm ();
+  for (llvm::Function::arg_iterator iter = wrap->arg_begin ();
+       iter != wrap->arg_end (); ++iter)
+    {
+      llvm::Value *value = iter;
+      llvm::Type *type = value->getType ();
+      new_args.push_back (type == complex_ret ? complex_t : type);
+    }
+
+  llvm::FunctionType *wrap_type = wrap->getFunctionType ();
+  bool convert_ret = wrap_type->getReturnType () == complex_ret;
+  llvm::Type *rtype = convert_ret ? complex_t : wrap->getReturnType ();
+  llvm::FunctionType *ft = llvm::FunctionType::get (rtype, new_args, false);
+  llvm::Function *fn = llvm::Function::Create (ft,
+                                               llvm::Function::ExternalLinkage,
+                                               wrap->getName () + "_wrap",
+                                               module);
+  llvm::BasicBlock *body = llvm::BasicBlock::Create (context, "body", fn);
+  builder.SetInsertPoint (body);
+
+  llvm::SmallVector<llvm::Value *, 5> converted (new_args.size ());
+  llvm::Function::arg_iterator witer = wrap->arg_begin ();
+  llvm::Function::arg_iterator fiter = fn->arg_begin ();
+  for (size_t i = 0; i < new_args.size (); ++i, ++witer, ++fiter)
+    {
+      llvm::Value *warg = witer;
+      llvm::Value *arg = fiter;
+      converted[i] = warg->getType () == arg->getType () ? arg
+        : pack_complex (arg);
+    }
+
+  llvm::Value *ret = builder.CreateCall (wrap, converted);
+  if (wrap_type->getReturnType () != builder.getVoidTy ())
+    {
+      if (convert_ret)
+        ret = unpack_complex (ret);
+      builder.CreateRet (ret);
+    }
+  else
+    builder.CreateRetVoid ();
+
+  llvm::verifyFunction (*fn);
+  return fn;
+}
+
+llvm::Value *
+jit_typeinfo::pack_complex (llvm::Value *cplx)
+{
+  llvm::Value *real = builder.CreateExtractElement (cplx, builder.getInt32 (0));
+  llvm::Value *imag = builder.CreateExtractElement (cplx, builder.getInt32 (1));
+  llvm::Value *ret = llvm::UndefValue::get (complex_ret);
+  ret = builder.CreateInsertValue (ret, real, 0);
+  return builder.CreateInsertValue (ret, imag, 1);
+}
+
+llvm::Value *
+jit_typeinfo::unpack_complex (llvm::Value *result)
+{
+  llvm::Type *complex_t = complex->to_llvm ();
+  llvm::Value *real = builder.CreateExtractValue (result, 0);
+  llvm::Value *imag = builder.CreateExtractValue (result, 1);
+  llvm::Value *ret = llvm::UndefValue::get (complex_t);
+  ret = builder.CreateInsertElement (ret, real, builder.getInt32 (0));
+  return builder.CreateInsertElement (ret, imag, builder.getInt32 (1));
+}
+
 jit_type *
 jit_typeinfo::do_type_of (const octave_value &ov) const
 {
@@ -3446,6 +3568,13 @@
 {
   module_pass_manager->run (*module);
   pass_manager->run (*fn);
+
+#ifdef OCTAVE_JIT_DEBUG
+  std::string error;
+  llvm::raw_fd_ostream fout ("test.bc", error,
+                             llvm::raw_fd_ostream::F_Binary);
+  llvm::WriteBitcodeToFile (module, fout);
+#endif
 }
 
 // -------------------- jit_info --------------------
--- a/src/pt-jit.h
+++ b/src/pt-jit.h
@@ -77,6 +77,7 @@
   class BasicBlock;
   class LLVMContext;
   class Type;
+  class StructType;
   class Twine;
   class GlobalVariable;
   class TerminatorInst;
@@ -673,6 +674,12 @@
 
   octave_builtin *find_builtin (const std::string& name);
 
+  llvm::Function *wrap_complex (llvm::Function *wrap);
+
+  llvm::Value *pack_complex (llvm::Value *cplx);
+
+  llvm::Value *unpack_complex (llvm::Value *result);
+
   static jit_typeinfo *instance;
 
   llvm::Module *module;
@@ -692,6 +699,8 @@
   jit_type *complex;
   std::map<std::string, jit_type *> builtins;
 
+  llvm::StructType *complex_ret;
+
   std::vector<jit_operation> binary_ops;
   jit_operation grab_fn;
   jit_operation release_fn;