comparison src/pt-jit.cc @ 14985:f5925478bc15

More support for complex numbers in JIT * src/pt-jit.cc (octave_jit_cast_complex_any): Return result directly. (octave_jit_complex_div, jit_typeinfo::wrap_complex, jit_typeinfo::pack_complex, jit_typeinfo::unpack_complex): New functions. (jit_typeinfo::jit_typeinfo): Support more complex functionality. (tree_jit::optimize): Write LLVM bitcode to a file when debugging. * src/pt-jit.h (jit_typeinfo::wrap_complex, jit_typeinfo::pack_complex, jit_typeinfo::unpack_complex): New declarations.
author Max Brister <max@2bass.com>
date Tue, 10 Jul 2012 15:55:05 -0500
parents 561aad6a9e4b
children 70ff15b6d996
comparison
equal deleted inserted replaced
14984:561aad6a9e4b 14985:f5925478bc15
48 #include <llvm/Target/TargetData.h> 48 #include <llvm/Target/TargetData.h>
49 #include <llvm/Transforms/Scalar.h> 49 #include <llvm/Transforms/Scalar.h>
50 #include <llvm/Transforms/IPO.h> 50 #include <llvm/Transforms/IPO.h>
51 #include <llvm/Support/TargetSelect.h> 51 #include <llvm/Support/TargetSelect.h>
52 #include <llvm/Support/raw_os_ostream.h> 52 #include <llvm/Support/raw_os_ostream.h>
53 #include <llvm/Support/FormattedStream.h>
54 #include <llvm/Bitcode/ReaderWriter.h>
53 55
54 #include "octave.h" 56 #include "octave.h"
55 #include "ov-fcn-handle.h" 57 #include "ov-fcn-handle.h"
56 #include "ov-usr-fcn.h" 58 #include "ov-usr-fcn.h"
57 #include "ov-builtin.h" 59 #include "ov-builtin.h"
197 octave_jit_cast_any_scalar (double value) 199 octave_jit_cast_any_scalar (double value)
198 { 200 {
199 return new octave_scalar (value); 201 return new octave_scalar (value);
200 } 202 }
201 203
202 extern "C" void 204 extern "C" Complex
203 octave_jit_cast_complex_any (double *dest, octave_base_value *obv) 205 octave_jit_cast_complex_any (octave_base_value *obv)
204 { 206 {
205 Complex ret = obv->complex_value (); 207 Complex ret = obv->complex_value ();
206 obv->release (); 208 obv->release ();
207 dest[0] = ret.real (); 209 return ret;
208 dest[1] = ret.imag ();
209 } 210 }
210 211
211 extern "C" octave_base_value * 212 extern "C" octave_base_value *
212 octave_jit_cast_any_complex (double real, double imag) 213 octave_jit_cast_any_complex (Complex c)
213 { 214 {
214 if (imag == 0) 215 if (c.imag () == 0)
215 return new octave_scalar (real); 216 return new octave_scalar (c.real ());
216 else 217 else
217 return new octave_complex (Complex (real, imag)); 218 return new octave_complex (c);
218 } 219 }
219 220
220 extern "C" void 221 extern "C" void
221 octave_jit_gripe_nan_to_logical_conversion (void) 222 octave_jit_gripe_nan_to_logical_conversion (void)
222 { 223 {
318 } 319 }
319 320
320 result->update (array); 321 result->update (array);
321 } 322 }
322 323
324 extern "C" Complex
325 octave_jit_complex_div (Complex lhs, Complex rhs)
326 {
327 // see src/OPERATORS/op-cs-cs.cc
328 if (rhs == 0.0)
329 gripe_divide_by_zero ();
330
331 return lhs / rhs;
332 }
333
323 extern "C" void 334 extern "C" void
324 octave_jit_print_matrix (jit_matrix *m) 335 octave_jit_print_matrix (jit_matrix *m)
325 { 336 {
326 std::cout << *m << std::endl; 337 std::cout << *m << std::endl;
327 } 338 }
519 matrix_contents[3] = index_t->getPointerTo (); 530 matrix_contents[3] = index_t->getPointerTo ();
520 matrix_contents[4] = string_t; 531 matrix_contents[4] = string_t;
521 matrix_t->setBody (llvm::makeArrayRef (matrix_contents, 5)); 532 matrix_t->setBody (llvm::makeArrayRef (matrix_contents, 5));
522 533
523 llvm::Type *complex_t = llvm::VectorType::get (scalar_t, 2); 534 llvm::Type *complex_t = llvm::VectorType::get (scalar_t, 2);
535
536 // this is the structure that C functions return. Use this in order to get calling
537 // conventions right.
538 complex_ret = llvm::StructType::create (context, "complex_ret");
539 llvm::Type *complex_ret_contents[] = {scalar_t, scalar_t};
540 complex_ret->setBody (complex_ret_contents);
524 541
525 // create types 542 // create types
526 any = new_type ("any", 0, any_t); 543 any = new_type ("any", 0, any_t);
527 matrix = new_type ("matrix", any, matrix_t); 544 matrix = new_type ("matrix", any, matrix_t);
528 scalar = new_type ("scalar", any, scalar_t); 545 scalar = new_type ("scalar", any, scalar_t);
732 llvm::Value *mres = builder.CreateFMul (mlhs, mrhs); 749 llvm::Value *mres = builder.CreateFMul (mlhs, mrhs);
733 llvm::Value *ret = llvm::UndefValue::get (complex_t); 750 llvm::Value *ret = llvm::UndefValue::get (complex_t);
734 llvm::Value *tlhs = builder.CreateExtractElement (mres, zero); 751 llvm::Value *tlhs = builder.CreateExtractElement (mres, zero);
735 llvm::Value *trhs = builder.CreateExtractElement (mres, one); 752 llvm::Value *trhs = builder.CreateExtractElement (mres, one);
736 temp = builder.CreateFSub (tlhs, trhs); 753 temp = builder.CreateFSub (tlhs, trhs);
737 //temp = llvm::ConstantFP::get (scalar_t, 123);
738 ret = builder.CreateInsertElement (ret, temp, zero); 754 ret = builder.CreateInsertElement (ret, temp, zero);
739 755
740 tlhs = builder.CreateExtractElement (mres, two); 756 tlhs = builder.CreateExtractElement (mres, two);
741 trhs = builder.CreateExtractElement (mres, three); 757 trhs = builder.CreateExtractElement (mres, three);
742 temp = builder.CreateFAdd (tlhs, trhs); 758 temp = builder.CreateFAdd (tlhs, trhs);
743 //temp = llvm::ConstantFP::get (scalar_t, 123);
744 ret = builder.CreateInsertElement (ret, temp, one); 759 ret = builder.CreateInsertElement (ret, temp, one);
745 builder.CreateRet (ret); 760 builder.CreateRet (ret);
746 761
747 jit_operation::overload ol (fn, false, complex, complex, complex); 762 jit_operation::overload ol (fn, false, complex, complex, complex);
748 binary_ops[octave_value::op_mul].add_overload (ol); 763 binary_ops[octave_value::op_mul].add_overload (ol);
749 binary_ops[octave_value::op_el_mul].add_overload (ol); 764 binary_ops[octave_value::op_el_mul].add_overload (ol);
765 }
766 llvm::verifyFunction (*fn);
767
768 fn = create_function ("octave_jit_*_scalar_complex", complex, scalar,
769 complex);
770 llvm::Function *mul_scalar_complex = fn;
771 body = llvm::BasicBlock::Create (context, "body", fn);
772 builder.SetInsertPoint (body);
773 {
774 llvm::Value *lhs = fn->arg_begin ();
775 llvm::Value *tlhs = llvm::UndefValue::get (complex_t);
776 tlhs = builder.CreateInsertElement (tlhs, lhs, builder.getInt32 (0));
777 tlhs = builder.CreateInsertElement (tlhs, lhs, builder.getInt32 (1));
778
779 llvm::Value *rhs = ++fn->arg_begin ();
780 builder.CreateRet (builder.CreateFMul (tlhs, rhs));
781
782 jit_operation::overload ol (fn, false, complex, scalar, complex);
783 binary_ops[octave_value::op_mul].add_overload (ol);
784 binary_ops[octave_value::op_el_mul].add_overload (ol);
785 }
786 llvm::verifyFunction (*fn);
787
788 fn = create_function ("octave_jit_*_complex_scalar", complex, complex,
789 scalar);
790 body = llvm::BasicBlock::Create (context, "body", fn);
791 builder.SetInsertPoint (body);
792 {
793 llvm::Value *ret = builder.CreateCall2 (mul_scalar_complex,
794 ++fn->arg_begin (),
795 fn->arg_begin ());
796 builder.CreateRet (ret);
797
798 jit_operation::overload ol (fn, false, complex, complex, scalar);
799 binary_ops[octave_value::op_mul].add_overload (ol);
800 binary_ops[octave_value::op_el_mul].add_overload (ol);
801 }
802 llvm::verifyFunction (*fn);
803
804 llvm::Function *complex_div = create_function ("octave_jit_complex_div",
805 complex_ret, complex_ret,
806 complex_ret);
807 engine->addGlobalMapping (complex_div,
808 reinterpret_cast<void *> (&octave_jit_complex_div));
809 complex_div = wrap_complex (complex_div);
810 {
811 jit_operation::overload ol (complex_div, true, complex, complex, complex);
812 binary_ops[octave_value::op_div].add_overload (ol);
813 binary_ops[octave_value::op_ldiv].add_overload (ol);
814 }
815
816 fn = create_function ("octave_jit_\\_complex_complex", complex, complex,
817 complex);
818 body = llvm::BasicBlock::Create (context, "body", fn);
819 builder.SetInsertPoint (body);
820 {
821 builder.CreateRet (builder.CreateCall2 (complex_div, ++fn->arg_begin (),
822 fn->arg_begin ()));
823 jit_operation::overload ol (fn, true, complex, complex, complex);
824 binary_ops[octave_value::op_ldiv].add_overload (ol);
825 binary_ops[octave_value::op_el_ldiv].add_overload (ol);
750 } 826 }
751 llvm::verifyFunction (*fn); 827 llvm::verifyFunction (*fn);
752 828
753 // now for binary index operators 829 // now for binary index operators
754 add_binary_op (index, octave_value::op_add, llvm::Instruction::Add); 830 add_binary_op (index, octave_value::op_add, llvm::Instruction::Add);
1087 fn = create_function ("octave_jit_cast_scalar_any", scalar, any); 1163 fn = create_function ("octave_jit_cast_scalar_any", scalar, any);
1088 engine->addGlobalMapping (fn, reinterpret_cast<void*> (&octave_jit_cast_scalar_any)); 1164 engine->addGlobalMapping (fn, reinterpret_cast<void*> (&octave_jit_cast_scalar_any));
1089 casts[scalar->type_id ()].add_overload (fn, false, scalar, any); 1165 casts[scalar->type_id ()].add_overload (fn, false, scalar, any);
1090 1166
1091 // cast any <- complex 1167 // cast any <- complex
1092 llvm::Function *any_complex = create_function ("octave_jit_cast_any_complex", 1168 fn = create_function ("octave_jit_cast_any_complex", any_t, complex_ret);
1093 any, scalar, scalar); 1169 engine->addGlobalMapping (fn, reinterpret_cast<void*> (&octave_jit_cast_any_complex));
1094 engine->addGlobalMapping (any_complex, reinterpret_cast<void*> (&octave_jit_cast_any_complex)); 1170 casts[any->type_id ()].add_overload (wrap_complex (fn), false, any, complex);
1095 fn = create_function ("cast_any_complex", any, complex);
1096 body = llvm::BasicBlock::Create (context, "body", fn);
1097 builder.SetInsertPoint (body);
1098 {
1099 llvm::Value *zero = builder.getInt32 (0);
1100 llvm::Value *one = builder.getInt32 (1);
1101
1102 llvm::Value *cmplx = fn->arg_begin ();
1103 llvm::Value *real = builder.CreateExtractElement (cmplx, zero);
1104 llvm::Value *imag = builder.CreateExtractElement (cmplx, one);
1105 llvm::Value *ret = builder.CreateCall2 (any_complex, real, imag);
1106 builder.CreateRet (ret);
1107 }
1108 llvm::verifyFunction (*fn);
1109 casts[any->type_id ()].add_overload (fn, false, any, complex);
1110 1171
1111 // cast complex <- any 1172 // cast complex <- any
1112 llvm::Function *complex_any = create_function ("octave_jit_cast_complex_any", 1173 fn = create_function ("octave_jit_cast_complex_any", complex_ret, any_t);
1113 void_t, 1174 engine->addGlobalMapping (fn, reinterpret_cast<void *> (&octave_jit_cast_complex_any));
1114 complex_t->getPointerTo (), 1175 casts[complex->type_id ()].add_overload (wrap_complex (fn), false, complex,
1115 any_t); 1176 any);
1116 fn = create_function ("cast_complex_any", complex, any);
1117 body = llvm::BasicBlock::Create (context, "body", fn);
1118 builder.SetInsertPoint (body);
1119 {
1120 llvm::Value *result = builder.CreateAlloca (complex_t);
1121 builder.CreateCall2 (complex_any, result, fn->arg_begin ());
1122 builder.CreateRet (builder.CreateLoad (result));
1123 }
1124 llvm::verifyFunction (*fn);
1125 casts[complex->type_id ()].add_overload (fn, false, complex, any);
1126 1177
1127 // cast any <- any 1178 // cast any <- any
1128 fn = create_identity (any); 1179 fn = create_identity (any);
1129 casts[any->type_id ()].add_overload (fn, false, any, any); 1180 casts[any->type_id ()].add_overload (fn, false, any, any);
1130 1181
1359 void 1410 void
1360 jit_typeinfo::register_generic (const std::string&, jit_type *, 1411 jit_typeinfo::register_generic (const std::string&, jit_type *,
1361 const std::vector<jit_type *>&) 1412 const std::vector<jit_type *>&)
1362 { 1413 {
1363 // FIXME: Implement 1414 // FIXME: Implement
1415 }
1416
1417 llvm::Function *
1418 jit_typeinfo::wrap_complex (llvm::Function *wrap)
1419 {
1420 llvm::SmallVector<llvm::Type *, 5> new_args;
1421 new_args.reserve (wrap->arg_size ());
1422 llvm::Type *complex_t = complex->to_llvm ();
1423 for (llvm::Function::arg_iterator iter = wrap->arg_begin ();
1424 iter != wrap->arg_end (); ++iter)
1425 {
1426 llvm::Value *value = iter;
1427 llvm::Type *type = value->getType ();
1428 new_args.push_back (type == complex_ret ? complex_t : type);
1429 }
1430
1431 llvm::FunctionType *wrap_type = wrap->getFunctionType ();
1432 bool convert_ret = wrap_type->getReturnType () == complex_ret;
1433 llvm::Type *rtype = convert_ret ? complex_t : wrap->getReturnType ();
1434 llvm::FunctionType *ft = llvm::FunctionType::get (rtype, new_args, false);
1435 llvm::Function *fn = llvm::Function::Create (ft,
1436 llvm::Function::ExternalLinkage,
1437 wrap->getName () + "_wrap",
1438 module);
1439 llvm::BasicBlock *body = llvm::BasicBlock::Create (context, "body", fn);
1440 builder.SetInsertPoint (body);
1441
1442 llvm::SmallVector<llvm::Value *, 5> converted (new_args.size ());
1443 llvm::Function::arg_iterator witer = wrap->arg_begin ();
1444 llvm::Function::arg_iterator fiter = fn->arg_begin ();
1445 for (size_t i = 0; i < new_args.size (); ++i, ++witer, ++fiter)
1446 {
1447 llvm::Value *warg = witer;
1448 llvm::Value *arg = fiter;
1449 converted[i] = warg->getType () == arg->getType () ? arg
1450 : pack_complex (arg);
1451 }
1452
1453 llvm::Value *ret = builder.CreateCall (wrap, converted);
1454 if (wrap_type->getReturnType () != builder.getVoidTy ())
1455 {
1456 if (convert_ret)
1457 ret = unpack_complex (ret);
1458 builder.CreateRet (ret);
1459 }
1460 else
1461 builder.CreateRetVoid ();
1462
1463 llvm::verifyFunction (*fn);
1464 return fn;
1465 }
1466
1467 llvm::Value *
1468 jit_typeinfo::pack_complex (llvm::Value *cplx)
1469 {
1470 llvm::Value *real = builder.CreateExtractElement (cplx, builder.getInt32 (0));
1471 llvm::Value *imag = builder.CreateExtractElement (cplx, builder.getInt32 (1));
1472 llvm::Value *ret = llvm::UndefValue::get (complex_ret);
1473 ret = builder.CreateInsertValue (ret, real, 0);
1474 return builder.CreateInsertValue (ret, imag, 1);
1475 }
1476
1477 llvm::Value *
1478 jit_typeinfo::unpack_complex (llvm::Value *result)
1479 {
1480 llvm::Type *complex_t = complex->to_llvm ();
1481 llvm::Value *real = builder.CreateExtractValue (result, 0);
1482 llvm::Value *imag = builder.CreateExtractValue (result, 1);
1483 llvm::Value *ret = llvm::UndefValue::get (complex_t);
1484 ret = builder.CreateInsertElement (ret, real, builder.getInt32 (0));
1485 return builder.CreateInsertElement (ret, imag, builder.getInt32 (1));
1364 } 1486 }
1365 1487
1366 jit_type * 1488 jit_type *
1367 jit_typeinfo::do_type_of (const octave_value &ov) const 1489 jit_typeinfo::do_type_of (const octave_value &ov) const
1368 { 1490 {
3444 void 3566 void
3445 tree_jit::optimize (llvm::Function *fn) 3567 tree_jit::optimize (llvm::Function *fn)
3446 { 3568 {
3447 module_pass_manager->run (*module); 3569 module_pass_manager->run (*module);
3448 pass_manager->run (*fn); 3570 pass_manager->run (*fn);
3571
3572 #ifdef OCTAVE_JIT_DEBUG
3573 std::string error;
3574 llvm::raw_fd_ostream fout ("test.bc", error,
3575 llvm::raw_fd_ostream::F_Binary);
3576 llvm::WriteBitcodeToFile (module, fout);
3577 #endif
3449 } 3578 }
3450 3579
3451 // -------------------- jit_info -------------------- 3580 // -------------------- jit_info --------------------
3452 jit_info::jit_info (tree_jit& tjit, tree& tee) 3581 jit_info::jit_info (tree_jit& tjit, tree& tee)
3453 : engine (tjit.get_engine ()), llvm_function (0) 3582 : engine (tjit.get_engine ()), llvm_function (0)