Mercurial > hg > octave-nkf
comparison src/pt-jit.cc @ 14985:f5925478bc15
More support for complex numbers in JIT
* src/pt-jit.cc (octave_jit_cast_complex_any): Return the Complex result
directly instead of writing it through an output pointer.
(octave_jit_complex_div, jit_typeinfo::wrap_complex,
jit_typeinfo::pack_complex, jit_typeinfo::unpack_complex): New functions.
(jit_typeinfo::jit_typeinfo): Add complex-number multiplication and
division overloads, and route the any <-> complex casts through
wrap_complex.
(tree_jit::optimize): Write the LLVM bitcode to test.bc when
OCTAVE_JIT_DEBUG is defined.
* src/pt-jit.h (jit_typeinfo::wrap_complex, jit_typeinfo::pack_complex,
jit_typeinfo::unpack_complex): New declarations.
author | Max Brister <max@2bass.com> |
---|---|
date | Tue, 10 Jul 2012 15:55:05 -0500 |
parents | 561aad6a9e4b |
children | 70ff15b6d996 |
comparison
equal
deleted
inserted
replaced
14984:561aad6a9e4b | 14985:f5925478bc15 |
---|---|
48 #include <llvm/Target/TargetData.h> | 48 #include <llvm/Target/TargetData.h> |
49 #include <llvm/Transforms/Scalar.h> | 49 #include <llvm/Transforms/Scalar.h> |
50 #include <llvm/Transforms/IPO.h> | 50 #include <llvm/Transforms/IPO.h> |
51 #include <llvm/Support/TargetSelect.h> | 51 #include <llvm/Support/TargetSelect.h> |
52 #include <llvm/Support/raw_os_ostream.h> | 52 #include <llvm/Support/raw_os_ostream.h> |
53 #include <llvm/Support/FormattedStream.h> | |
54 #include <llvm/Bitcode/ReaderWriter.h> | |
53 | 55 |
54 #include "octave.h" | 56 #include "octave.h" |
55 #include "ov-fcn-handle.h" | 57 #include "ov-fcn-handle.h" |
56 #include "ov-usr-fcn.h" | 58 #include "ov-usr-fcn.h" |
57 #include "ov-builtin.h" | 59 #include "ov-builtin.h" |
197 octave_jit_cast_any_scalar (double value) | 199 octave_jit_cast_any_scalar (double value) |
198 { | 200 { |
199 return new octave_scalar (value); | 201 return new octave_scalar (value); |
200 } | 202 } |
201 | 203 |
202 extern "C" void | 204 extern "C" Complex |
203 octave_jit_cast_complex_any (double *dest, octave_base_value *obv) | 205 octave_jit_cast_complex_any (octave_base_value *obv) |
204 { | 206 { |
205 Complex ret = obv->complex_value (); | 207 Complex ret = obv->complex_value (); |
206 obv->release (); | 208 obv->release (); |
207 dest[0] = ret.real (); | 209 return ret; |
208 dest[1] = ret.imag (); | |
209 } | 210 } |
210 | 211 |
211 extern "C" octave_base_value * | 212 extern "C" octave_base_value * |
212 octave_jit_cast_any_complex (double real, double imag) | 213 octave_jit_cast_any_complex (Complex c) |
213 { | 214 { |
214 if (imag == 0) | 215 if (c.imag () == 0) |
215 return new octave_scalar (real); | 216 return new octave_scalar (c.real ()); |
216 else | 217 else |
217 return new octave_complex (Complex (real, imag)); | 218 return new octave_complex (c); |
218 } | 219 } |
219 | 220 |
220 extern "C" void | 221 extern "C" void |
221 octave_jit_gripe_nan_to_logical_conversion (void) | 222 octave_jit_gripe_nan_to_logical_conversion (void) |
222 { | 223 { |
318 } | 319 } |
319 | 320 |
320 result->update (array); | 321 result->update (array); |
321 } | 322 } |
322 | 323 |
324 extern "C" Complex | |
325 octave_jit_complex_div (Complex lhs, Complex rhs) | |
326 { | |
327 // see src/OPERATORS/op-cs-cs.cc | |
328 if (rhs == 0.0) | |
329 gripe_divide_by_zero (); | |
330 | |
331 return lhs / rhs; | |
332 } | |
333 | |
323 extern "C" void | 334 extern "C" void |
324 octave_jit_print_matrix (jit_matrix *m) | 335 octave_jit_print_matrix (jit_matrix *m) |
325 { | 336 { |
326 std::cout << *m << std::endl; | 337 std::cout << *m << std::endl; |
327 } | 338 } |
519 matrix_contents[3] = index_t->getPointerTo (); | 530 matrix_contents[3] = index_t->getPointerTo (); |
520 matrix_contents[4] = string_t; | 531 matrix_contents[4] = string_t; |
521 matrix_t->setBody (llvm::makeArrayRef (matrix_contents, 5)); | 532 matrix_t->setBody (llvm::makeArrayRef (matrix_contents, 5)); |
522 | 533 |
523 llvm::Type *complex_t = llvm::VectorType::get (scalar_t, 2); | 534 llvm::Type *complex_t = llvm::VectorType::get (scalar_t, 2); |
535 | |
536 // this is the structure that C functions return. Use this in order to get calling | |
537 // conventions right. | |
538 complex_ret = llvm::StructType::create (context, "complex_ret"); | |
539 llvm::Type *complex_ret_contents[] = {scalar_t, scalar_t}; | |
540 complex_ret->setBody (complex_ret_contents); | |
524 | 541 |
525 // create types | 542 // create types |
526 any = new_type ("any", 0, any_t); | 543 any = new_type ("any", 0, any_t); |
527 matrix = new_type ("matrix", any, matrix_t); | 544 matrix = new_type ("matrix", any, matrix_t); |
528 scalar = new_type ("scalar", any, scalar_t); | 545 scalar = new_type ("scalar", any, scalar_t); |
732 llvm::Value *mres = builder.CreateFMul (mlhs, mrhs); | 749 llvm::Value *mres = builder.CreateFMul (mlhs, mrhs); |
733 llvm::Value *ret = llvm::UndefValue::get (complex_t); | 750 llvm::Value *ret = llvm::UndefValue::get (complex_t); |
734 llvm::Value *tlhs = builder.CreateExtractElement (mres, zero); | 751 llvm::Value *tlhs = builder.CreateExtractElement (mres, zero); |
735 llvm::Value *trhs = builder.CreateExtractElement (mres, one); | 752 llvm::Value *trhs = builder.CreateExtractElement (mres, one); |
736 temp = builder.CreateFSub (tlhs, trhs); | 753 temp = builder.CreateFSub (tlhs, trhs); |
737 //temp = llvm::ConstantFP::get (scalar_t, 123); | |
738 ret = builder.CreateInsertElement (ret, temp, zero); | 754 ret = builder.CreateInsertElement (ret, temp, zero); |
739 | 755 |
740 tlhs = builder.CreateExtractElement (mres, two); | 756 tlhs = builder.CreateExtractElement (mres, two); |
741 trhs = builder.CreateExtractElement (mres, three); | 757 trhs = builder.CreateExtractElement (mres, three); |
742 temp = builder.CreateFAdd (tlhs, trhs); | 758 temp = builder.CreateFAdd (tlhs, trhs); |
743 //temp = llvm::ConstantFP::get (scalar_t, 123); | |
744 ret = builder.CreateInsertElement (ret, temp, one); | 759 ret = builder.CreateInsertElement (ret, temp, one); |
745 builder.CreateRet (ret); | 760 builder.CreateRet (ret); |
746 | 761 |
747 jit_operation::overload ol (fn, false, complex, complex, complex); | 762 jit_operation::overload ol (fn, false, complex, complex, complex); |
748 binary_ops[octave_value::op_mul].add_overload (ol); | 763 binary_ops[octave_value::op_mul].add_overload (ol); |
749 binary_ops[octave_value::op_el_mul].add_overload (ol); | 764 binary_ops[octave_value::op_el_mul].add_overload (ol); |
765 } | |
766 llvm::verifyFunction (*fn); | |
767 | |
768 fn = create_function ("octave_jit_*_scalar_complex", complex, scalar, | |
769 complex); | |
770 llvm::Function *mul_scalar_complex = fn; | |
771 body = llvm::BasicBlock::Create (context, "body", fn); | |
772 builder.SetInsertPoint (body); | |
773 { | |
774 llvm::Value *lhs = fn->arg_begin (); | |
775 llvm::Value *tlhs = llvm::UndefValue::get (complex_t); | |
776 tlhs = builder.CreateInsertElement (tlhs, lhs, builder.getInt32 (0)); | |
777 tlhs = builder.CreateInsertElement (tlhs, lhs, builder.getInt32 (1)); | |
778 | |
779 llvm::Value *rhs = ++fn->arg_begin (); | |
780 builder.CreateRet (builder.CreateFMul (tlhs, rhs)); | |
781 | |
782 jit_operation::overload ol (fn, false, complex, scalar, complex); | |
783 binary_ops[octave_value::op_mul].add_overload (ol); | |
784 binary_ops[octave_value::op_el_mul].add_overload (ol); | |
785 } | |
786 llvm::verifyFunction (*fn); | |
787 | |
788 fn = create_function ("octave_jit_*_complex_scalar", complex, complex, | |
789 scalar); | |
790 body = llvm::BasicBlock::Create (context, "body", fn); | |
791 builder.SetInsertPoint (body); | |
792 { | |
793 llvm::Value *ret = builder.CreateCall2 (mul_scalar_complex, | |
794 ++fn->arg_begin (), | |
795 fn->arg_begin ()); | |
796 builder.CreateRet (ret); | |
797 | |
798 jit_operation::overload ol (fn, false, complex, complex, scalar); | |
799 binary_ops[octave_value::op_mul].add_overload (ol); | |
800 binary_ops[octave_value::op_el_mul].add_overload (ol); | |
801 } | |
802 llvm::verifyFunction (*fn); | |
803 | |
804 llvm::Function *complex_div = create_function ("octave_jit_complex_div", | |
805 complex_ret, complex_ret, | |
806 complex_ret); | |
807 engine->addGlobalMapping (complex_div, | |
808 reinterpret_cast<void *> (&octave_jit_complex_div)); | |
809 complex_div = wrap_complex (complex_div); | |
810 { | |
811 jit_operation::overload ol (complex_div, true, complex, complex, complex); | |
812 binary_ops[octave_value::op_div].add_overload (ol); | |
813 binary_ops[octave_value::op_ldiv].add_overload (ol); | |
814 } | |
815 | |
816 fn = create_function ("octave_jit_\\_complex_complex", complex, complex, | |
817 complex); | |
818 body = llvm::BasicBlock::Create (context, "body", fn); | |
819 builder.SetInsertPoint (body); | |
820 { | |
821 builder.CreateRet (builder.CreateCall2 (complex_div, ++fn->arg_begin (), | |
822 fn->arg_begin ())); | |
823 jit_operation::overload ol (fn, true, complex, complex, complex); | |
824 binary_ops[octave_value::op_ldiv].add_overload (ol); | |
825 binary_ops[octave_value::op_el_ldiv].add_overload (ol); | |
750 } | 826 } |
751 llvm::verifyFunction (*fn); | 827 llvm::verifyFunction (*fn); |
752 | 828 |
753 // now for binary index operators | 829 // now for binary index operators |
754 add_binary_op (index, octave_value::op_add, llvm::Instruction::Add); | 830 add_binary_op (index, octave_value::op_add, llvm::Instruction::Add); |
1087 fn = create_function ("octave_jit_cast_scalar_any", scalar, any); | 1163 fn = create_function ("octave_jit_cast_scalar_any", scalar, any); |
1088 engine->addGlobalMapping (fn, reinterpret_cast<void*> (&octave_jit_cast_scalar_any)); | 1164 engine->addGlobalMapping (fn, reinterpret_cast<void*> (&octave_jit_cast_scalar_any)); |
1089 casts[scalar->type_id ()].add_overload (fn, false, scalar, any); | 1165 casts[scalar->type_id ()].add_overload (fn, false, scalar, any); |
1090 | 1166 |
1091 // cast any <- complex | 1167 // cast any <- complex |
1092 llvm::Function *any_complex = create_function ("octave_jit_cast_any_complex", | 1168 fn = create_function ("octave_jit_cast_any_complex", any_t, complex_ret); |
1093 any, scalar, scalar); | 1169 engine->addGlobalMapping (fn, reinterpret_cast<void*> (&octave_jit_cast_any_complex)); |
1094 engine->addGlobalMapping (any_complex, reinterpret_cast<void*> (&octave_jit_cast_any_complex)); | 1170 casts[any->type_id ()].add_overload (wrap_complex (fn), false, any, complex); |
1095 fn = create_function ("cast_any_complex", any, complex); | |
1096 body = llvm::BasicBlock::Create (context, "body", fn); | |
1097 builder.SetInsertPoint (body); | |
1098 { | |
1099 llvm::Value *zero = builder.getInt32 (0); | |
1100 llvm::Value *one = builder.getInt32 (1); | |
1101 | |
1102 llvm::Value *cmplx = fn->arg_begin (); | |
1103 llvm::Value *real = builder.CreateExtractElement (cmplx, zero); | |
1104 llvm::Value *imag = builder.CreateExtractElement (cmplx, one); | |
1105 llvm::Value *ret = builder.CreateCall2 (any_complex, real, imag); | |
1106 builder.CreateRet (ret); | |
1107 } | |
1108 llvm::verifyFunction (*fn); | |
1109 casts[any->type_id ()].add_overload (fn, false, any, complex); | |
1110 | 1171 |
1111 // cast complex <- any | 1172 // cast complex <- any |
1112 llvm::Function *complex_any = create_function ("octave_jit_cast_complex_any", | 1173 fn = create_function ("octave_jit_cast_complex_any", complex_ret, any_t); |
1113 void_t, | 1174 engine->addGlobalMapping (fn, reinterpret_cast<void *> (&octave_jit_cast_complex_any)); |
1114 complex_t->getPointerTo (), | 1175 casts[complex->type_id ()].add_overload (wrap_complex (fn), false, complex, |
1115 any_t); | 1176 any); |
1116 fn = create_function ("cast_complex_any", complex, any); | |
1117 body = llvm::BasicBlock::Create (context, "body", fn); | |
1118 builder.SetInsertPoint (body); | |
1119 { | |
1120 llvm::Value *result = builder.CreateAlloca (complex_t); | |
1121 builder.CreateCall2 (complex_any, result, fn->arg_begin ()); | |
1122 builder.CreateRet (builder.CreateLoad (result)); | |
1123 } | |
1124 llvm::verifyFunction (*fn); | |
1125 casts[complex->type_id ()].add_overload (fn, false, complex, any); | |
1126 | 1177 |
1127 // cast any <- any | 1178 // cast any <- any |
1128 fn = create_identity (any); | 1179 fn = create_identity (any); |
1129 casts[any->type_id ()].add_overload (fn, false, any, any); | 1180 casts[any->type_id ()].add_overload (fn, false, any, any); |
1130 | 1181 |
1359 void | 1410 void |
1360 jit_typeinfo::register_generic (const std::string&, jit_type *, | 1411 jit_typeinfo::register_generic (const std::string&, jit_type *, |
1361 const std::vector<jit_type *>&) | 1412 const std::vector<jit_type *>&) |
1362 { | 1413 { |
1363 // FIXME: Implement | 1414 // FIXME: Implement |
1415 } | |
1416 | |
1417 llvm::Function * | |
1418 jit_typeinfo::wrap_complex (llvm::Function *wrap) | |
1419 { | |
1420 llvm::SmallVector<llvm::Type *, 5> new_args; | |
1421 new_args.reserve (wrap->arg_size ()); | |
1422 llvm::Type *complex_t = complex->to_llvm (); | |
1423 for (llvm::Function::arg_iterator iter = wrap->arg_begin (); | |
1424 iter != wrap->arg_end (); ++iter) | |
1425 { | |
1426 llvm::Value *value = iter; | |
1427 llvm::Type *type = value->getType (); | |
1428 new_args.push_back (type == complex_ret ? complex_t : type); | |
1429 } | |
1430 | |
1431 llvm::FunctionType *wrap_type = wrap->getFunctionType (); | |
1432 bool convert_ret = wrap_type->getReturnType () == complex_ret; | |
1433 llvm::Type *rtype = convert_ret ? complex_t : wrap->getReturnType (); | |
1434 llvm::FunctionType *ft = llvm::FunctionType::get (rtype, new_args, false); | |
1435 llvm::Function *fn = llvm::Function::Create (ft, | |
1436 llvm::Function::ExternalLinkage, | |
1437 wrap->getName () + "_wrap", | |
1438 module); | |
1439 llvm::BasicBlock *body = llvm::BasicBlock::Create (context, "body", fn); | |
1440 builder.SetInsertPoint (body); | |
1441 | |
1442 llvm::SmallVector<llvm::Value *, 5> converted (new_args.size ()); | |
1443 llvm::Function::arg_iterator witer = wrap->arg_begin (); | |
1444 llvm::Function::arg_iterator fiter = fn->arg_begin (); | |
1445 for (size_t i = 0; i < new_args.size (); ++i, ++witer, ++fiter) | |
1446 { | |
1447 llvm::Value *warg = witer; | |
1448 llvm::Value *arg = fiter; | |
1449 converted[i] = warg->getType () == arg->getType () ? arg | |
1450 : pack_complex (arg); | |
1451 } | |
1452 | |
1453 llvm::Value *ret = builder.CreateCall (wrap, converted); | |
1454 if (wrap_type->getReturnType () != builder.getVoidTy ()) | |
1455 { | |
1456 if (convert_ret) | |
1457 ret = unpack_complex (ret); | |
1458 builder.CreateRet (ret); | |
1459 } | |
1460 else | |
1461 builder.CreateRetVoid (); | |
1462 | |
1463 llvm::verifyFunction (*fn); | |
1464 return fn; | |
1465 } | |
1466 | |
1467 llvm::Value * | |
1468 jit_typeinfo::pack_complex (llvm::Value *cplx) | |
1469 { | |
1470 llvm::Value *real = builder.CreateExtractElement (cplx, builder.getInt32 (0)); | |
1471 llvm::Value *imag = builder.CreateExtractElement (cplx, builder.getInt32 (1)); | |
1472 llvm::Value *ret = llvm::UndefValue::get (complex_ret); | |
1473 ret = builder.CreateInsertValue (ret, real, 0); | |
1474 return builder.CreateInsertValue (ret, imag, 1); | |
1475 } | |
1476 | |
1477 llvm::Value * | |
1478 jit_typeinfo::unpack_complex (llvm::Value *result) | |
1479 { | |
1480 llvm::Type *complex_t = complex->to_llvm (); | |
1481 llvm::Value *real = builder.CreateExtractValue (result, 0); | |
1482 llvm::Value *imag = builder.CreateExtractValue (result, 1); | |
1483 llvm::Value *ret = llvm::UndefValue::get (complex_t); | |
1484 ret = builder.CreateInsertElement (ret, real, builder.getInt32 (0)); | |
1485 return builder.CreateInsertElement (ret, imag, builder.getInt32 (1)); | |
1364 } | 1486 } |
1365 | 1487 |
1366 jit_type * | 1488 jit_type * |
1367 jit_typeinfo::do_type_of (const octave_value &ov) const | 1489 jit_typeinfo::do_type_of (const octave_value &ov) const |
1368 { | 1490 { |
3444 void | 3566 void |
3445 tree_jit::optimize (llvm::Function *fn) | 3567 tree_jit::optimize (llvm::Function *fn) |
3446 { | 3568 { |
3447 module_pass_manager->run (*module); | 3569 module_pass_manager->run (*module); |
3448 pass_manager->run (*fn); | 3570 pass_manager->run (*fn); |
3571 | |
3572 #ifdef OCTAVE_JIT_DEBUG | |
3573 std::string error; | |
3574 llvm::raw_fd_ostream fout ("test.bc", error, | |
3575 llvm::raw_fd_ostream::F_Binary); | |
3576 llvm::WriteBitcodeToFile (module, fout); | |
3577 #endif | |
3449 } | 3578 } |
3450 | 3579 |
3451 // -------------------- jit_info -------------------- | 3580 // -------------------- jit_info -------------------- |
3452 jit_info::jit_info (tree_jit& tjit, tree& tee) | 3581 jit_info::jit_info (tree_jit& tjit, tree& tee) |
3453 : engine (tjit.get_engine ()), llvm_function (0) | 3582 : engine (tjit.get_engine ()), llvm_function (0) |