# HG changeset patch # User Max Brister # Date 1341871614 18000 # Node ID 561aad6a9e4b553aba9df6be83dcce6ee46913b7 # Parent a5f75de0dab1c0a021a06668c35fe823e5d7de03 Initial support for complex numbers in JIT * src/pt-jit.cc (octave_jit_cast_complex_any, octave_jit_cast_any_complex): New function. (jit_typeinfo::jit_typeinfo, jit_typeinfo::type_of, jit_convert::convert_llvm::visit): Support complex numbers. * src/pt-jit.h (jit_typeinfo::get_complex): New function. (jit_const_complex): New typedef. diff --git a/src/pt-jit.cc b/src/pt-jit.cc --- a/src/pt-jit.cc +++ b/src/pt-jit.cc @@ -56,6 +56,7 @@ #include "ov-usr-fcn.h" #include "ov-builtin.h" #include "ov-scalar.h" +#include "ov-complex.h" #include "pt-all.h" static llvm::IRBuilder<> builder (llvm::getGlobalContext ()); @@ -199,6 +200,24 @@ } extern "C" void +octave_jit_cast_complex_any (double *dest, octave_base_value *obv) +{ + Complex ret = obv->complex_value (); + obv->release (); + dest[0] = ret.real (); + dest[1] = ret.imag (); +} + +extern "C" octave_base_value * +octave_jit_cast_any_complex (double real, double imag) +{ + if (imag == 0) + return new octave_scalar (real); + else + return new octave_complex (Complex (real, imag)); +} + +extern "C" void octave_jit_gripe_nan_to_logical_conversion (void) { try @@ -501,6 +520,8 @@ matrix_contents[4] = string_t; matrix_t->setBody (llvm::makeArrayRef (matrix_contents, 5)); + llvm::Type *complex_t = llvm::VectorType::get (scalar_t, 2); + // create types any = new_type ("any", 0, any_t); matrix = new_type ("matrix", any, matrix_t); @@ -509,6 +530,7 @@ string = new_type ("string", any, string_t); boolean = new_type ("bool", any, bool_t); index = new_type ("index", any, index_t); + complex = new_type ("complex", any, complex_t); casts.resize (next_id + 1); identities.resize (next_id + 1, 0); @@ -595,6 +617,10 @@ fn = create_identity (scalar); release_fn.add_overload (fn, false, 0, scalar); + // release complex + fn = create_identity (complex); + release_fn.add_overload (fn, false, 0, complex); + // release index fn = create_identity (index); release_fn.add_overload (fn, false, 0, index); @@ -663,6 +689,67 @@ } llvm::verifyFunction (*fn); + // now for binary complex operations + add_binary_op (complex, octave_value::op_add, llvm::Instruction::FAdd); + add_binary_op (complex, octave_value::op_sub, llvm::Instruction::FSub); + + fn = create_function ("octave_jit_*_complex_complex", complex, complex, + complex); + body = llvm::BasicBlock::Create (context, "body", fn); + builder.SetInsertPoint (body); + { + // (x0*x1 - y0*y1, x0*y1 + y0*x1) = (x0,y0) * (x1,y1) + // We compute this in one vectorized multiplication, a subtraction, and an + // addition. + llvm::Value *lhs = fn->arg_begin (); + llvm::Value *rhs = ++fn->arg_begin (); + + // FIXME: We need a better way of doing this, working with llvm's IR + // directly is sort of a pain. + llvm::Value *zero = builder.getInt32 (0); + llvm::Value *one = builder.getInt32 (1); + llvm::Value *two = builder.getInt32 (2); + llvm::Value *three = builder.getInt32 (3); + + llvm::Type *vec4 = llvm::VectorType::get (scalar_t, 4); + llvm::Value *mlhs = llvm::UndefValue::get (vec4); + llvm::Value *mrhs = mlhs; + + llvm::Value *temp = builder.CreateExtractElement (lhs, zero); + mlhs = builder.CreateInsertElement (mlhs, temp, zero); + mlhs = builder.CreateInsertElement (mlhs, temp, two); + temp = builder.CreateExtractElement (lhs, one); + mlhs = builder.CreateInsertElement (mlhs, temp, one); + mlhs = builder.CreateInsertElement (mlhs, temp, three); + + temp = builder.CreateExtractElement (rhs, zero); + mrhs = builder.CreateInsertElement (mrhs, temp, zero); + mrhs = builder.CreateInsertElement (mrhs, temp, three); + temp = builder.CreateExtractElement (rhs, one); + mrhs = builder.CreateInsertElement (mrhs, temp, one); + mrhs = builder.CreateInsertElement (mrhs, temp, two); + + llvm::Value *mres = builder.CreateFMul (mlhs, mrhs); + llvm::Value *ret = llvm::UndefValue::get (complex_t); + llvm::Value *tlhs = builder.CreateExtractElement (mres, zero); + llvm::Value *trhs = builder.CreateExtractElement (mres, one); + temp = builder.CreateFSub (tlhs, trhs); + //temp = llvm::ConstantFP::get (scalar_t, 123); + ret = builder.CreateInsertElement (ret, temp, zero); + + tlhs = builder.CreateExtractElement (mres, two); + trhs = builder.CreateExtractElement (mres, three); + temp = builder.CreateFAdd (tlhs, trhs); + //temp = llvm::ConstantFP::get (scalar_t, 123); + ret = builder.CreateInsertElement (ret, temp, one); + builder.CreateRet (ret); + + jit_operation::overload ol (fn, false, complex, complex, complex); + binary_ops[octave_value::op_mul].add_overload (ol); + binary_ops[octave_value::op_el_mul].add_overload (ol); + } + llvm::verifyFunction (*fn); + // now for binary index operators add_binary_op (index, octave_value::op_add, llvm::Instruction::Add); @@ -974,6 +1061,8 @@ casts[any->type_id ()].stash_name ("(any)"); casts[scalar->type_id ()].stash_name ("(scalar)"); + casts[complex->type_id ()].stash_name ("(complex)"); + casts[matrix->type_id ()].stash_name ("(matrix)"); // cast any <- matrix fn = create_function ("octave_jit_cast_any_matrix", any_t, @@ -999,6 +1088,42 @@ engine->addGlobalMapping (fn, reinterpret_cast (&octave_jit_cast_scalar_any)); casts[scalar->type_id ()].add_overload (fn, false, scalar, any); + // cast any <- complex + llvm::Function *any_complex = create_function ("octave_jit_cast_any_complex", + any, scalar, scalar); + engine->addGlobalMapping (any_complex, reinterpret_cast (&octave_jit_cast_any_complex)); + fn = create_function ("cast_any_complex", any, complex); + body = llvm::BasicBlock::Create (context, "body", fn); + builder.SetInsertPoint (body); + { + llvm::Value *zero = builder.getInt32 (0); + llvm::Value *one = builder.getInt32 (1); + + llvm::Value *cmplx = fn->arg_begin (); + llvm::Value *real = builder.CreateExtractElement (cmplx, zero); + llvm::Value *imag = builder.CreateExtractElement (cmplx, one); + llvm::Value *ret = builder.CreateCall2 (any_complex, real, imag); + builder.CreateRet (ret); + } + llvm::verifyFunction (*fn); + casts[any->type_id ()].add_overload (fn, false, any, complex); + + // cast complex <- any + llvm::Function *complex_any = create_function ("octave_jit_cast_complex_any", + void_t, + complex_t->getPointerTo (), + any_t); + fn = create_function ("cast_complex_any", complex, any); + body = llvm::BasicBlock::Create (context, "body", fn); + builder.SetInsertPoint (body); + { + llvm::Value *result = builder.CreateAlloca (complex_t); + builder.CreateCall2 (complex_any, result, fn->arg_begin ()); + builder.CreateRet (builder.CreateLoad (result)); + } + llvm::verifyFunction (*fn); + casts[complex->type_id ()].add_overload (fn, false, complex, any); + // cast any <- any fn = create_identity (any); casts[any->type_id ()].add_overload (fn, false, any, any); @@ -1007,6 +1132,9 @@ fn = create_identity (scalar); casts[scalar->type_id ()].add_overload (fn, false, scalar, scalar); + // cast complex <- complex + fn = create_identity (complex); + casts[complex->type_id ()].add_overload (fn, false, complex, complex); // -------------------- builtin functions -------------------- add_builtin ("sin"); @@ -1259,6 +1387,9 @@ return get_matrix (); } + if (ov.is_complex_scalar ()) + return get_complex (); + return get_any (); } @@ -2324,6 +2455,11 @@ Range rv = v.range_value (); result = create (rv); } + else if (v.is_complex_scalar ()) + { + Complex cv = v.complex_value (); + result = create (cv); + } else fail ("Unknown constant"); } @@ -3036,6 +3172,17 @@ cs.stash_llvm (llvm::ConstantFP::get (cs.type_llvm (), cs.value ())); } +void +jit_convert::convert_llvm::visit (jit_const_complex& cc) +{ + llvm::Type *scalar_t = jit_typeinfo::get_scalar_llvm (); + llvm::Constant *values[2]; + Complex value = cc.value (); + values[0] = llvm::ConstantFP::get (scalar_t, value.real ()); + values[1] = llvm::ConstantFP::get (scalar_t, value.imag ()); + cc.stash_llvm (llvm::ConstantVector::get (values)); +} + void jit_convert::convert_llvm::visit (jit_const_index& ci) { ci.stash_llvm (llvm::ConstantInt::get (ci.type_llvm (), ci.value ())); diff --git a/src/pt-jit.h b/src/pt-jit.h --- a/src/pt-jit.h +++ b/src/pt-jit.h @@ -448,6 +448,8 @@ static llvm::Type *get_index_llvm (void) { return instance->index->to_llvm (); } + static jit_type *get_complex (void) { return instance->complex; } + static jit_type *type_of (const octave_value& ov) { return instance->do_type_of (ov); @@ -687,6 +689,7 @@ jit_type *string; jit_type *boolean; jit_type *index; + jit_type *complex; std::map builtins; std::vector binary_ops; @@ -728,6 +731,7 @@ #define JIT_VISIT_IR_CONST \ JIT_METH(const_bool); \ JIT_METH(const_scalar); \ + JIT_METH(const_complex); \ JIT_METH(const_index); \ JIT_METH(const_string); \ JIT_METH(const_range) @@ -756,6 +760,7 @@ typedef jit_const jit_const_bool; typedef jit_const jit_const_scalar; +typedef jit_const jit_const_complex; typedef jit_const jit_const_index; typedef jit_const