# HG changeset patch # User Max Brister # Date 1341953705 18000 # Node ID f5925478bc15cc3502bee249637b954278829d4a # Parent 561aad6a9e4b553aba9df6be83dcce6ee46913b7 More support for complex numbers in JIT * src/pt-jit.cc (octave_jit_cast_complex_any): Return result directly. (octave_jit_complex_div, jit_typeinfo::wrap_complex, jit_typeinfo::pack_complex, jit_typeinfo::unpack_complex): New function. (jit_typeinfo::jit_typeinfo): Support more complex functionality. (tree_jit::optimize): Write llvm bytecode to a file when debugging. * src/pt-jit.h (jit_typeinfo::wrap_complex, jit_typeinfo::pack_complex, jit_typeinfo): New declarations. diff --git a/src/pt-jit.cc b/src/pt-jit.cc --- a/src/pt-jit.cc +++ b/src/pt-jit.cc @@ -50,6 +50,8 @@ #include #include #include +#include +#include #include "octave.h" #include "ov-fcn-handle.h" @@ -199,22 +201,21 @@ return new octave_scalar (value); } -extern "C" void -octave_jit_cast_complex_any (double *dest, octave_base_value *obv) +extern "C" Complex +octave_jit_cast_complex_any (octave_base_value *obv) { Complex ret = obv->complex_value (); obv->release (); - dest[0] = ret.real (); - dest[1] = ret.imag (); + return ret; } extern "C" octave_base_value * -octave_jit_cast_any_complex (double real, double imag) +octave_jit_cast_any_complex (Complex c) { - if (imag == 0) - return new octave_scalar (real); + if (c.imag () == 0) + return new octave_scalar (c.real ()); else - return new octave_complex (Complex (real, imag)); + return new octave_complex (c); } extern "C" void @@ -320,6 +321,16 @@ result->update (array); } +extern "C" Complex +octave_jit_complex_div (Complex lhs, Complex rhs) +{ + // see src/OPERATORS/op-cs-cs.cc + if (rhs == 0.0) + gripe_divide_by_zero (); + + return lhs / rhs; +} + extern "C" void octave_jit_print_matrix (jit_matrix *m) { @@ -522,6 +533,12 @@ llvm::Type *complex_t = llvm::VectorType::get (scalar_t, 2); + // this is the structure that C functions return. Use this in order to get calling + // conventions right. + complex_ret = llvm::StructType::create (context, "complex_ret"); + llvm::Type *complex_ret_contents[] = {scalar_t, scalar_t}; + complex_ret->setBody (complex_ret_contents); + // create types any = new_type ("any", 0, any_t); matrix = new_type ("matrix", any, matrix_t); @@ -734,13 +751,11 @@ llvm::Value *tlhs = builder.CreateExtractElement (mres, zero); llvm::Value *trhs = builder.CreateExtractElement (mres, one); temp = builder.CreateFSub (tlhs, trhs); - //temp = llvm::ConstantFP::get (scalar_t, 123); ret = builder.CreateInsertElement (ret, temp, zero); tlhs = builder.CreateExtractElement (mres, two); trhs = builder.CreateExtractElement (mres, three); temp = builder.CreateFAdd (tlhs, trhs); - //temp = llvm::ConstantFP::get (scalar_t, 123); ret = builder.CreateInsertElement (ret, temp, one); builder.CreateRet (ret); @@ -750,6 +765,67 @@ } llvm::verifyFunction (*fn); + fn = create_function ("octave_jit_*_scalar_complex", complex, scalar, + complex); + llvm::Function *mul_scalar_complex = fn; + body = llvm::BasicBlock::Create (context, "body", fn); + builder.SetInsertPoint (body); + { + llvm::Value *lhs = fn->arg_begin (); + llvm::Value *tlhs = llvm::UndefValue::get (complex_t); + tlhs = builder.CreateInsertElement (tlhs, lhs, builder.getInt32 (0)); + tlhs = builder.CreateInsertElement (tlhs, lhs, builder.getInt32 (1)); + + llvm::Value *rhs = ++fn->arg_begin (); + builder.CreateRet (builder.CreateFMul (tlhs, rhs)); + + jit_operation::overload ol (fn, false, complex, scalar, complex); + binary_ops[octave_value::op_mul].add_overload (ol); + binary_ops[octave_value::op_el_mul].add_overload (ol); + } + llvm::verifyFunction (*fn); + + fn = create_function ("octave_jit_*_complex_scalar", complex, complex, + scalar); + body = llvm::BasicBlock::Create (context, "body", fn); + builder.SetInsertPoint (body); + { + llvm::Value *ret = builder.CreateCall2 (mul_scalar_complex, + ++fn->arg_begin (), + fn->arg_begin ()); + builder.CreateRet (ret); + + jit_operation::overload ol (fn, false, complex, complex, scalar); + binary_ops[octave_value::op_mul].add_overload (ol); + binary_ops[octave_value::op_el_mul].add_overload (ol); + } + llvm::verifyFunction (*fn); + + llvm::Function *complex_div = create_function ("octave_jit_complex_div", + complex_ret, complex_ret, + complex_ret); + engine->addGlobalMapping (complex_div, + reinterpret_cast (&octave_jit_complex_div)); + complex_div = wrap_complex (complex_div); + { + jit_operation::overload ol (complex_div, true, complex, complex, complex); + binary_ops[octave_value::op_div].add_overload (ol); + binary_ops[octave_value::op_ldiv].add_overload (ol); + } + + fn = create_function ("octave_jit_\\_complex_complex", complex, complex, + complex); + body = llvm::BasicBlock::Create (context, "body", fn); + builder.SetInsertPoint (body); + { + builder.CreateRet (builder.CreateCall2 (complex_div, ++fn->arg_begin (), + fn->arg_begin ())); + jit_operation::overload ol (fn, true, complex, complex, complex); + binary_ops[octave_value::op_ldiv].add_overload (ol); + binary_ops[octave_value::op_el_ldiv].add_overload (ol); + } + llvm::verifyFunction (*fn); + // now for binary index operators add_binary_op (index, octave_value::op_add, llvm::Instruction::Add); @@ -1089,40 +1165,15 @@ casts[scalar->type_id ()].add_overload (fn, false, scalar, any); // cast any <- complex - llvm::Function *any_complex = create_function ("octave_jit_cast_any_complex", - any, scalar, scalar); - engine->addGlobalMapping (any_complex, reinterpret_cast (&octave_jit_cast_any_complex)); - fn = create_function ("cast_any_complex", any, complex); - body = llvm::BasicBlock::Create (context, "body", fn); - builder.SetInsertPoint (body); - { - llvm::Value *zero = builder.getInt32 (0); - llvm::Value *one = builder.getInt32 (1); - - llvm::Value *cmplx = fn->arg_begin (); - llvm::Value *real = builder.CreateExtractElement (cmplx, zero); - llvm::Value *imag = builder.CreateExtractElement (cmplx, one); - llvm::Value *ret = builder.CreateCall2 (any_complex, real, imag); - builder.CreateRet (ret); - } - llvm::verifyFunction (*fn); - casts[any->type_id ()].add_overload (fn, false, any, complex); + fn = create_function ("octave_jit_cast_any_complex", any_t, complex_ret); + engine->addGlobalMapping (fn, reinterpret_cast (&octave_jit_cast_any_complex)); + casts[any->type_id ()].add_overload (wrap_complex (fn), false, any, complex); // cast complex <- any - llvm::Function *complex_any = create_function ("octave_jit_cast_complex_any", - void_t, - complex_t->getPointerTo (), - any_t); - fn = create_function ("cast_complex_any", complex, any); - body = llvm::BasicBlock::Create (context, "body", fn); - builder.SetInsertPoint (body); - { - llvm::Value *result = builder.CreateAlloca (complex_t); - builder.CreateCall2 (complex_any, result, fn->arg_begin ()); - builder.CreateRet (builder.CreateLoad (result)); - } - llvm::verifyFunction (*fn); - casts[complex->type_id ()].add_overload (fn, false, complex, any); + fn = create_function ("octave_jit_cast_complex_any", complex_ret, any_t); + engine->addGlobalMapping (fn, reinterpret_cast (&octave_jit_cast_complex_any)); + casts[complex->type_id ()].add_overload (wrap_complex (fn), false, complex, + any); // cast any <- any fn = create_identity (any); @@ -1363,6 +1414,77 @@ // FIXME: Implement } +llvm::Function * +jit_typeinfo::wrap_complex (llvm::Function *wrap) +{ + llvm::SmallVector new_args; + new_args.reserve (wrap->arg_size ()); + llvm::Type *complex_t = complex->to_llvm (); + for (llvm::Function::arg_iterator iter = wrap->arg_begin (); + iter != wrap->arg_end (); ++iter) + { + llvm::Value *value = iter; + llvm::Type *type = value->getType (); + new_args.push_back (type == complex_ret ? complex_t : type); + } + + llvm::FunctionType *wrap_type = wrap->getFunctionType (); + bool convert_ret = wrap_type->getReturnType () == complex_ret; + llvm::Type *rtype = convert_ret ? complex_t : wrap->getReturnType (); + llvm::FunctionType *ft = llvm::FunctionType::get (rtype, new_args, false); + llvm::Function *fn = llvm::Function::Create (ft, + llvm::Function::ExternalLinkage, + wrap->getName () + "_wrap", + module); + llvm::BasicBlock *body = llvm::BasicBlock::Create (context, "body", fn); + builder.SetInsertPoint (body); + + llvm::SmallVector converted (new_args.size ()); + llvm::Function::arg_iterator witer = wrap->arg_begin (); + llvm::Function::arg_iterator fiter = fn->arg_begin (); + for (size_t i = 0; i < new_args.size (); ++i, ++witer, ++fiter) + { + llvm::Value *warg = witer; + llvm::Value *arg = fiter; + converted[i] = warg->getType () == arg->getType () ? arg + : pack_complex (arg); + } + + llvm::Value *ret = builder.CreateCall (wrap, converted); + if (wrap_type->getReturnType () != builder.getVoidTy ()) + { + if (convert_ret) + ret = unpack_complex (ret); + builder.CreateRet (ret); + } + else + builder.CreateRetVoid (); + + llvm::verifyFunction (*fn); + return fn; +} + +llvm::Value * +jit_typeinfo::pack_complex (llvm::Value *cplx) +{ + llvm::Value *real = builder.CreateExtractElement (cplx, builder.getInt32 (0)); + llvm::Value *imag = builder.CreateExtractElement (cplx, builder.getInt32 (1)); + llvm::Value *ret = llvm::UndefValue::get (complex_ret); + ret = builder.CreateInsertValue (ret, real, 0); + return builder.CreateInsertValue (ret, imag, 1); +} + +llvm::Value * +jit_typeinfo::unpack_complex (llvm::Value *result) +{ + llvm::Type *complex_t = complex->to_llvm (); + llvm::Value *real = builder.CreateExtractValue (result, 0); + llvm::Value *imag = builder.CreateExtractValue (result, 1); + llvm::Value *ret = llvm::UndefValue::get (complex_t); + ret = builder.CreateInsertElement (ret, real, builder.getInt32 (0)); + return builder.CreateInsertElement (ret, imag, builder.getInt32 (1)); +} + jit_type * jit_typeinfo::do_type_of (const octave_value &ov) const { @@ -3446,6 +3568,13 @@ { module_pass_manager->run (*module); pass_manager->run (*fn); + +#ifdef OCTAVE_JIT_DEBUG + std::string error; + llvm::raw_fd_ostream fout ("test.bc", error, + llvm::raw_fd_ostream::F_Binary); + llvm::WriteBitcodeToFile (module, fout); +#endif } // -------------------- jit_info -------------------- diff --git a/src/pt-jit.h b/src/pt-jit.h --- a/src/pt-jit.h +++ b/src/pt-jit.h @@ -77,6 +77,7 @@ class BasicBlock; class LLVMContext; class Type; + class StructType; class Twine; class GlobalVariable; class TerminatorInst; @@ -673,6 +674,12 @@ octave_builtin *find_builtin (const std::string& name); + llvm::Function *wrap_complex (llvm::Function *wrap); + + llvm::Value *pack_complex (llvm::Value *cplx); + + llvm::Value *unpack_complex (llvm::Value *result); + static jit_typeinfo *instance; llvm::Module *module; @@ -692,6 +699,8 @@ jit_type *complex; std::map builtins; + llvm::StructType *complex_ret; + std::vector binary_ops; jit_operation grab_fn; jit_operation release_fn;