changeset 14936:561aad6a9e4b

Initial support for complex numbers in JIT * src/pt-jit.cc (octave_jit_cast_complex_any, octave_jit_cast_any_complex): New function. (jit_typeinfo::jit_typeinfo, jit_typeinfo::type_of, jit_convert::convert_llvm::visit): Support complex numbers. * src/pt-jit.h (jit_typeinfo::get_complex): New function. (jit_const_complex): New typedef.
author Max Brister <max@2bass.com>
date Mon, 09 Jul 2012 17:06:54 -0500
parents a5f75de0dab1
children f5925478bc15
files src/pt-jit.cc src/pt-jit.h
diffstat 2 files changed, 152 insertions(+), 0 deletions(-) [+]
line wrap: on
line diff
--- a/src/pt-jit.cc
+++ b/src/pt-jit.cc
@@ -56,6 +56,7 @@
 #include "ov-usr-fcn.h"
 #include "ov-builtin.h"
 #include "ov-scalar.h"
+#include "ov-complex.h"
 #include "pt-all.h"
 
 static llvm::IRBuilder<> builder (llvm::getGlobalContext ());
@@ -199,6 +200,24 @@
 }
 
 extern "C" void
+octave_jit_cast_complex_any (double *dest, octave_base_value *obv)
+{
+  Complex ret = obv->complex_value ();
+  obv->release ();
+  dest[0] = ret.real ();
+  dest[1] = ret.imag ();
+}
+
+extern "C" octave_base_value *
+octave_jit_cast_any_complex (double real, double imag)
+{
+  if (imag == 0)
+    return new octave_scalar (real);
+  else
+    return new octave_complex (Complex (real, imag));
+}
+
+extern "C" void
 octave_jit_gripe_nan_to_logical_conversion (void)
 {
   try
@@ -501,6 +520,8 @@
   matrix_contents[4] = string_t;
   matrix_t->setBody (llvm::makeArrayRef (matrix_contents, 5));
 
+  llvm::Type *complex_t = llvm::VectorType::get (scalar_t, 2);
+
   // create types
   any = new_type ("any", 0, any_t);
   matrix = new_type ("matrix", any, matrix_t);
@@ -509,6 +530,7 @@
   string = new_type ("string", any, string_t);
   boolean = new_type ("bool", any, bool_t);
   index = new_type ("index", any, index_t);
+  complex = new_type ("complex", any, complex_t);
 
   casts.resize (next_id + 1);
   identities.resize (next_id + 1, 0);
@@ -595,6 +617,10 @@
   fn = create_identity (scalar);
   release_fn.add_overload (fn, false, 0, scalar);
 
+  // release complex
+  fn = create_identity (complex);
+  release_fn.add_overload (fn, false, 0, complex);
+
   // release index
   fn = create_identity (index);
   release_fn.add_overload (fn, false, 0, index);
@@ -663,6 +689,67 @@
   }
   llvm::verifyFunction (*fn);
 
+  // now for binary complex operations
+  add_binary_op (complex, octave_value::op_add, llvm::Instruction::FAdd);
+  add_binary_op (complex, octave_value::op_sub, llvm::Instruction::FSub);
+
+  fn = create_function ("octave_jit_*_complex_complex", complex, complex,
+                        complex);
+  body = llvm::BasicBlock::Create (context, "body", fn);
+  builder.SetInsertPoint (body);
+  {
+    // (x0*x1 - y0*y1, x0*y1 + y0*x1) = (x0,y0) * (x1,y1)
+    // We compute this in one vectorized multiplication, a subtraction, and an
+    // addition.
+    llvm::Value *lhs = fn->arg_begin ();
+    llvm::Value *rhs = ++fn->arg_begin ();
+
+    // FIXME: We need a better way of doing this, working with llvm's IR
+    // directly is sort of a pain.
+    llvm::Value *zero = builder.getInt32 (0);
+    llvm::Value *one = builder.getInt32 (1);
+    llvm::Value *two = builder.getInt32 (2);
+    llvm::Value *three = builder.getInt32 (3);
+
+    llvm::Type *vec4 = llvm::VectorType::get (scalar_t, 4);
+    llvm::Value *mlhs = llvm::UndefValue::get (vec4);
+    llvm::Value *mrhs = mlhs;
+
+    llvm::Value *temp = builder.CreateExtractElement (lhs, zero);
+    mlhs = builder.CreateInsertElement (mlhs, temp, zero);
+    mlhs = builder.CreateInsertElement (mlhs, temp, two);
+    temp = builder.CreateExtractElement (lhs, one);
+    mlhs = builder.CreateInsertElement (mlhs, temp, one);
+    mlhs = builder.CreateInsertElement (mlhs, temp, three);
+
+    temp = builder.CreateExtractElement (rhs, zero);
+    mrhs = builder.CreateInsertElement (mrhs, temp, zero);
+    mrhs = builder.CreateInsertElement (mrhs, temp, three);
+    temp = builder.CreateExtractElement (rhs, one);
+    mrhs = builder.CreateInsertElement (mrhs, temp, one);
+    mrhs = builder.CreateInsertElement (mrhs, temp, two);
+
+    llvm::Value *mres = builder.CreateFMul (mlhs, mrhs);
+    llvm::Value *ret = llvm::UndefValue::get (complex_t);
+    llvm::Value *tlhs = builder.CreateExtractElement (mres, zero);
+    llvm::Value *trhs = builder.CreateExtractElement (mres, one);
+    temp = builder.CreateFSub (tlhs, trhs);
+    //temp = llvm::ConstantFP::get (scalar_t, 123);
+    ret = builder.CreateInsertElement (ret, temp, zero);
+
+    tlhs = builder.CreateExtractElement (mres, two);
+    trhs = builder.CreateExtractElement (mres, three);
+    temp = builder.CreateFAdd (tlhs, trhs);
+    //temp = llvm::ConstantFP::get (scalar_t, 123);
+    ret = builder.CreateInsertElement (ret, temp, one);
+    builder.CreateRet (ret);
+
+    jit_operation::overload ol (fn, false, complex, complex, complex);
+    binary_ops[octave_value::op_mul].add_overload (ol);
+    binary_ops[octave_value::op_el_mul].add_overload (ol);
+  }
+  llvm::verifyFunction (*fn);
+
   // now for binary index operators
   add_binary_op (index, octave_value::op_add, llvm::Instruction::Add);
 
@@ -974,6 +1061,8 @@
 
   casts[any->type_id ()].stash_name ("(any)");
   casts[scalar->type_id ()].stash_name ("(scalar)");
+  casts[complex->type_id ()].stash_name ("(complex)");
+  casts[matrix->type_id ()].stash_name ("(matrix)");
 
   // cast any <- matrix
   fn = create_function ("octave_jit_cast_any_matrix", any_t,
@@ -999,6 +1088,42 @@
   engine->addGlobalMapping (fn, reinterpret_cast<void*> (&octave_jit_cast_scalar_any));
   casts[scalar->type_id ()].add_overload (fn, false, scalar, any);
 
+  // cast any <- complex
+  llvm::Function *any_complex = create_function ("octave_jit_cast_any_complex",
+                                                 any, scalar, scalar);
+  engine->addGlobalMapping (any_complex, reinterpret_cast<void*> (&octave_jit_cast_any_complex));
+  fn = create_function ("cast_any_complex", any, complex);
+  body = llvm::BasicBlock::Create (context, "body", fn);
+  builder.SetInsertPoint (body);
+  {
+    llvm::Value *zero = builder.getInt32 (0);
+    llvm::Value *one = builder.getInt32 (1);
+
+    llvm::Value *cmplx = fn->arg_begin ();
+    llvm::Value *real = builder.CreateExtractElement (cmplx, zero);
+    llvm::Value *imag = builder.CreateExtractElement (cmplx, one);
+    llvm::Value *ret = builder.CreateCall2 (any_complex, real, imag);
+    builder.CreateRet (ret);
+  }
+  llvm::verifyFunction (*fn);
+  casts[any->type_id ()].add_overload (fn, false, any, complex);
+
+  // cast complex <- any
+  llvm::Function *complex_any = create_function ("octave_jit_cast_complex_any",
+                                                 void_t,
+                                                 complex_t->getPointerTo (),
+                                                 any_t);
+  fn = create_function ("cast_complex_any", complex, any);
+  body = llvm::BasicBlock::Create (context, "body", fn);
+  builder.SetInsertPoint (body);
+  {
+    llvm::Value *result = builder.CreateAlloca (complex_t);
+    builder.CreateCall2 (complex_any, result, fn->arg_begin ());
+    builder.CreateRet (builder.CreateLoad (result));
+  }
+  llvm::verifyFunction (*fn);
+  casts[complex->type_id ()].add_overload (fn, false, complex, any);
+
   // cast any <- any
   fn = create_identity (any);
   casts[any->type_id ()].add_overload (fn, false, any, any);
@@ -1007,6 +1132,9 @@
   fn = create_identity (scalar);
   casts[scalar->type_id ()].add_overload (fn, false, scalar, scalar);
 
+  // cast complex <- complex
+  fn = create_identity (complex);
+  casts[complex->type_id ()].add_overload (fn, false, complex, complex);
 
   // -------------------- builtin functions --------------------
   add_builtin ("sin");
@@ -1259,6 +1387,9 @@
         return get_matrix ();
     }
 
+  if (ov.is_complex_scalar ())
+    return get_complex ();
+
   return get_any ();
 }
 
@@ -2324,6 +2455,11 @@
       Range rv = v.range_value ();
       result = create<jit_const_range> (rv);
     }
+  else if (v.is_complex_scalar ())
+    {
+      Complex cv = v.complex_value ();
+      result = create<jit_const_complex> (cv);
+    }
   else
     fail ("Unknown constant");
 }
@@ -3036,6 +3172,17 @@
   cs.stash_llvm (llvm::ConstantFP::get (cs.type_llvm (), cs.value ()));
 }
 
+void
+jit_convert::convert_llvm::visit (jit_const_complex& cc)
+{
+  llvm::Type *scalar_t = jit_typeinfo::get_scalar_llvm ();
+  llvm::Constant *values[2];
+  Complex value = cc.value ();
+  values[0] = llvm::ConstantFP::get (scalar_t, value.real ());
+  values[1] = llvm::ConstantFP::get (scalar_t, value.imag ());
+  cc.stash_llvm (llvm::ConstantVector::get (values));
+}
+
 void jit_convert::convert_llvm::visit (jit_const_index& ci)
 {
   ci.stash_llvm (llvm::ConstantInt::get (ci.type_llvm (), ci.value ()));
--- a/src/pt-jit.h
+++ b/src/pt-jit.h
@@ -448,6 +448,8 @@
   static llvm::Type *get_index_llvm (void)
   { return instance->index->to_llvm (); }
 
+  static jit_type *get_complex (void) { return instance->complex; }
+
   static jit_type *type_of (const octave_value& ov)
   {
     return instance->do_type_of (ov);
@@ -687,6 +689,7 @@
   jit_type *string;
   jit_type *boolean;
   jit_type *index;
+  jit_type *complex;
   std::map<std::string, jit_type *> builtins;
 
   std::vector<jit_operation> binary_ops;
@@ -728,6 +731,7 @@
 #define JIT_VISIT_IR_CONST                      \
   JIT_METH(const_bool);                         \
   JIT_METH(const_scalar);                       \
+  JIT_METH(const_complex);                      \
   JIT_METH(const_index);                        \
   JIT_METH(const_string);                       \
   JIT_METH(const_range)
@@ -756,6 +760,7 @@
 
 typedef jit_const<bool, jit_typeinfo::get_bool> jit_const_bool;
 typedef jit_const<double, jit_typeinfo::get_scalar> jit_const_scalar;
+typedef jit_const<Complex, jit_typeinfo::get_complex> jit_const_complex;
 typedef jit_const<octave_idx_type, jit_typeinfo::get_index> jit_const_index;
 
 typedef jit_const<std::string, jit_typeinfo::get_string, const std::string&,