comparison src/pt-jit.cc @ 14920:51d4b1018efb

For loops compile with new IR * src/pt-eval.cc (tree_evaluator::visit_simple_for_command): Compile loops. (tree_evaluator::visit_statement): No longer compile individual statements. * src/pt-loop.h (tree_simple_for_command::get_info): Remove type map. (tree_simple_for_command::stash_info): Remove type map. * src/pt-loop.cc (tree_simple_for_command::~tree_simple_for_command): Delete compiled code instead of map.
author Max Brister <max@2bass.com>
date Sat, 26 May 2012 20:30:28 -0500
parents 13465aab507f
children 2e6f83b2f2b9
comparison
equal deleted inserted replaced
14919:f0499b0af646 14920:51d4b1018efb
63 static llvm::LLVMContext& context = llvm::getGlobalContext (); 63 static llvm::LLVMContext& context = llvm::getGlobalContext ();
64 64
65 jit_typeinfo *jit_typeinfo::instance; 65 jit_typeinfo *jit_typeinfo::instance;
66 66
67 // thrown when we should give up on JIT and interpret 67 // thrown when we should give up on JIT and interpret
68 class jit_fail_exception : public std::exception {}; 68 class jit_fail_exception : public std::runtime_error
69 {
70 public:
71 jit_fail_exception (void) : std::runtime_error ("unknown"), mknown (false) {}
72 jit_fail_exception (const std::string& reason) : std::runtime_error (reason),
73 mknown (true)
74 {}
75
76 bool known (void) const { return mknown; }
77 private:
78 bool mknown;
79 };
69 80
70 static void 81 static void
71 fail (void) 82 fail (void)
72 { 83 {
73 throw jit_fail_exception (); 84 throw jit_fail_exception ();
85 }
86
87 static void
88 fail (const std::string& reason)
89 {
90 throw jit_fail_exception (reason);
74 } 91 }
75 92
76 // function that jit code calls 93 // function that jit code calls
77 extern "C" void 94 extern "C" void
78 octave_jit_print_any (const char *name, octave_base_value *obv) 95 octave_jit_print_any (const char *name, octave_base_value *obv)
123 140
124 extern "C" octave_base_value * 141 extern "C" octave_base_value *
125 octave_jit_cast_any_scalar (double value) 142 octave_jit_cast_any_scalar (double value)
126 { 143 {
127 return new octave_scalar (value); 144 return new octave_scalar (value);
145 }
146
147 // -------------------- jit_range --------------------
148 std::ostream&
149 operator<< (std::ostream& os, const jit_range& rng)
150 {
151 return os << "Range[" << rng.base << ", " << rng.limit << ", " << rng.inc
152 << ", " << rng.nelem << "]";
128 } 153 }
129 154
130 // -------------------- jit_type -------------------- 155 // -------------------- jit_type --------------------
131 llvm::Type * 156 llvm::Type *
132 jit_type::to_llvm_arg (void) const 157 jit_type::to_llvm_arg (void) const
306 331
307 // grab scalar 332 // grab scalar
308 fn = create_identity (scalar); 333 fn = create_identity (scalar);
309 grab_fn.add_overload (fn, false, scalar, scalar); 334 grab_fn.add_overload (fn, false, scalar, scalar);
310 335
336 // grab index
337 fn = create_identity (index);
338 grab_fn.add_overload (fn, false, index, index);
339
311 // release any 340 // release any
312 fn = create_function ("octave_jit_release_any", void_t, any->to_llvm ()); 341 fn = create_function ("octave_jit_release_any", void_t, any->to_llvm ());
313 engine->addGlobalMapping (fn, reinterpret_cast<void*>(&octave_jit_release_any)); 342 engine->addGlobalMapping (fn, reinterpret_cast<void*>(&octave_jit_release_any));
314 release_fn.add_overload (fn, false, 0, any); 343 release_fn.add_overload (fn, false, 0, any);
315 release_fn.stash_name ("release"); 344 release_fn.stash_name ("release");
316 345
317 // release scalar 346 // release scalar
318 fn = create_identity (scalar); 347 fn = create_identity (scalar);
319 release_fn.add_overload (fn, false, 0, scalar); 348 release_fn.add_overload (fn, false, 0, scalar);
349
350 // release index
351 fn = create_identity (index);
352 release_fn.add_overload (fn, false, 0, index);
320 353
321 // now for binary scalar operations 354 // now for binary scalar operations
322 // FIXME: Finish all operations 355 // FIXME: Finish all operations
323 add_binary_op (scalar, octave_value::op_add, llvm::Instruction::FAdd); 356 add_binary_op (scalar, octave_value::op_add, llvm::Instruction::FAdd);
324 add_binary_op (scalar, octave_value::op_sub, llvm::Instruction::FSub); 357 add_binary_op (scalar, octave_value::op_sub, llvm::Instruction::FSub);
334 add_binary_fcmp (scalar, octave_value::op_eq, llvm::CmpInst::FCMP_UEQ); 367 add_binary_fcmp (scalar, octave_value::op_eq, llvm::CmpInst::FCMP_UEQ);
335 add_binary_fcmp (scalar, octave_value::op_ge, llvm::CmpInst::FCMP_UGE); 368 add_binary_fcmp (scalar, octave_value::op_ge, llvm::CmpInst::FCMP_UGE);
336 add_binary_fcmp (scalar, octave_value::op_gt, llvm::CmpInst::FCMP_UGT); 369 add_binary_fcmp (scalar, octave_value::op_gt, llvm::CmpInst::FCMP_UGT);
337 add_binary_fcmp (scalar, octave_value::op_ne, llvm::CmpInst::FCMP_UNE); 370 add_binary_fcmp (scalar, octave_value::op_ne, llvm::CmpInst::FCMP_UNE);
338 371
372 // now for binary index operators
373 add_binary_op (index, octave_value::op_add, llvm::Instruction::Add);
374
339 // now for printing functions 375 // now for printing functions
340 print_fn.stash_name ("print"); 376 print_fn.stash_name ("print");
341 add_print (any, reinterpret_cast<void*> (&octave_jit_print_any)); 377 add_print (any, reinterpret_cast<void*> (&octave_jit_print_any));
342 add_print (scalar, reinterpret_cast<void*> (&octave_jit_print_double)); 378 add_print (scalar, reinterpret_cast<void*> (&octave_jit_print_double));
343 379
380 // initialize for loop
381 for_init_fn.stash_name ("for_init");
382
383 fn = create_function ("octave_jit_for_range_init", index, range);
384 llvm::BasicBlock *body = llvm::BasicBlock::Create (context, "body", fn);
385 builder.SetInsertPoint (body);
386 {
387 llvm::Value *zero = llvm::ConstantInt::get (index_t, 0);
388 builder.CreateRet (zero);
389 }
390 llvm::verifyFunction (*fn);
391 for_init_fn.add_overload (fn, false, index, range);
392
344 // bounds check for for loop 393 // bounds check for for loop
345 fn = create_function ("octave_jit_simple_for_range", boolean, range, index); 394 for_check_fn.stash_name ("for_check");
346 llvm::BasicBlock *body = llvm::BasicBlock::Create (context, "body", fn); 395
396 fn = create_function ("octave_jit_for_range_check", boolean, range, index);
397 body = llvm::BasicBlock::Create (context, "body", fn);
347 builder.SetInsertPoint (body); 398 builder.SetInsertPoint (body);
348 { 399 {
349 llvm::Value *nelem 400 llvm::Value *nelem
350 = builder.CreateExtractValue (fn->arg_begin (), 3); 401 = builder.CreateExtractValue (fn->arg_begin (), 3);
351 // llvm::Value *idx = builder.CreateLoad (++fn->arg_begin ());
352 llvm::Value *idx = ++fn->arg_begin (); 402 llvm::Value *idx = ++fn->arg_begin ();
353 llvm::Value *ret = builder.CreateICmpULT (idx, nelem); 403 llvm::Value *ret = builder.CreateICmpULT (idx, nelem);
354 builder.CreateRet (ret); 404 builder.CreateRet (ret);
355 } 405 }
356 llvm::verifyFunction (*fn); 406 llvm::verifyFunction (*fn);
357 simple_for_check.add_overload (fn, false, boolean, range, index); 407 for_check_fn.add_overload (fn, false, boolean, range, index);
358
359 // increment for for loop
360 fn = create_function ("octave_jit_imple_for_range_incr", index, index);
361 body = llvm::BasicBlock::Create (context, "body", fn);
362 builder.SetInsertPoint (body);
363 {
364 llvm::Value *one = llvm::ConstantInt::get (index_t, 1);
365 llvm::Value *idx = fn->arg_begin ();
366 llvm::Value *ret = builder.CreateAdd (idx, one);
367 builder.CreateRet (ret);
368 }
369 llvm::verifyFunction (*fn);
370 simple_for_incr.add_overload (fn, false, index, index);
371 408
372 // index variabe for for loop 409 // index variabe for for loop
373 fn = create_function ("octave_jit_simple_for_idx", scalar, range, index); 410 for_index_fn.stash_name ("for_index");
411
412 fn = create_function ("octave_jit_for_range_idx", scalar, range, index);
374 body = llvm::BasicBlock::Create (context, "body", fn); 413 body = llvm::BasicBlock::Create (context, "body", fn);
375 builder.SetInsertPoint (body); 414 builder.SetInsertPoint (body);
376 { 415 {
377 llvm::Value *idx = ++fn->arg_begin (); 416 llvm::Value *idx = ++fn->arg_begin ();
378 llvm::Value *didx = builder.CreateUIToFP (idx, dbl); 417 llvm::Value *didx = builder.CreateUIToFP (idx, dbl);
383 llvm::Value *ret = builder.CreateFMul (didx, inc); 422 llvm::Value *ret = builder.CreateFMul (didx, inc);
384 ret = builder.CreateFAdd (base, ret); 423 ret = builder.CreateFAdd (base, ret);
385 builder.CreateRet (ret); 424 builder.CreateRet (ret);
386 } 425 }
387 llvm::verifyFunction (*fn); 426 llvm::verifyFunction (*fn);
388 simple_for_index.add_overload (fn, false, scalar, range, index); 427 for_index_fn.add_overload (fn, false, scalar, range, index);
389 428
390 // logically true 429 // logically true
391 // FIXME: Check for NaN 430 // FIXME: Check for NaN
392 fn = create_function ("octave_logically_true_scalar", boolean, scalar); 431 fn = create_function ("octave_logically_true_scalar", boolean, scalar);
393 body = llvm::BasicBlock::Create (context, "body", fn); 432 body = llvm::BasicBlock::Create (context, "body", fn);
567 jit_type *ret = new jit_type (name, parent, llvm_type, next_id++); 606 jit_type *ret = new jit_type (name, parent, llvm_type, next_id++);
568 id_to_type.push_back (ret); 607 id_to_type.push_back (ret);
569 return ret; 608 return ret;
570 } 609 }
571 610
611 // -------------------- jit_use --------------------
612 jit_block *
613 jit_use::user_parent (void) const
614 {
615 return usr->parent ();
616 }
617
618 // -------------------- jit_value --------------------
619 #define JIT_METH(clname) \
620 void \
621 jit_ ## clname::accept (jit_ir_walker& walker) \
622 { \
623 walker.visit (*this); \
624 }
625
626 JIT_VISIT_IR_NOTEMPLATE
627 #undef JIT_METH
628
629 // -------------------- jit_instruction --------------------
630 llvm::BasicBlock *
631 jit_instruction::parent_llvm (void) const
632 {
633 return mparent->to_llvm ();
634 }
635
572 // -------------------- jit_block -------------------- 636 // -------------------- jit_block --------------------
637 jit_instruction *
638 jit_block::prepend (jit_instruction *instr)
639 {
640 instructions.push_front (instr);
641 instr->stash_parent (this);
642 return instr;
643 }
644
645 jit_instruction *
646 jit_block::append (jit_instruction *instr)
647 {
648 instructions.push_back (instr);
649 instr->stash_parent (this);
650 return instr;
651 }
652
653 jit_terminator *
654 jit_block::terminator (void) const
655 {
656 if (instructions.empty ())
657 return 0;
658
659 jit_instruction *last = instructions.back ();
660 return dynamic_cast<jit_terminator *> (last);
661 }
662
663 llvm::Value *
664 jit_block::pred_terminator_llvm (size_t idx) const
665 {
666 jit_terminator *term = pred_terminator (idx);
667 return term ? term->to_llvm () : 0;
668 }
669
670 void
671 jit_block::create_merge (llvm::Function *inside, size_t pred_idx)
672 {
673 mpred_llvm.resize (pred_count ());
674
675 jit_block *ipred = pred (pred_idx);
676 if (! mpred_llvm[pred_idx] && ipred->pred_count () > 1)
677 {
678 llvm::BasicBlock *merge;
679 merge = llvm::BasicBlock::Create (context, "phi_merge", inside,
680 to_llvm ());
681
682 // fix the predecessor jump if it has been created
683 llvm::Value *term = pred_terminator_llvm (pred_idx);
684 if (term)
685 {
686 llvm::TerminatorInst *branch = llvm::cast<llvm::TerminatorInst> (term);
687 for (size_t i = 0; i < branch->getNumSuccessors (); ++i)
688 {
689 if (branch->getSuccessor (i) == to_llvm ())
690 branch->setSuccessor (i, merge);
691 }
692 }
693
694 llvm::IRBuilder<> temp (merge);
695 temp.CreateBr (to_llvm ());
696 mpred_llvm[pred_idx] = merge;
697 }
698 }
699
700 size_t
701 jit_block::succ_count (void) const
702 {
703 jit_terminator *term = terminator ();
704 return term ? term->sucessor_count () : 0;
705 }
706
573 llvm::BasicBlock * 707 llvm::BasicBlock *
574 jit_block::to_llvm (void) const 708 jit_block::to_llvm (void) const
575 { 709 {
576 return llvm::cast<llvm::BasicBlock> (llvm_value); 710 return llvm::cast<llvm::BasicBlock> (llvm_value);
577 } 711 }
607 // -------------------- jit_convert -------------------- 741 // -------------------- jit_convert --------------------
608 jit_convert::jit_convert (llvm::Module *module, tree &tee) 742 jit_convert::jit_convert (llvm::Module *module, tree &tee)
609 { 743 {
610 jit_instruction::reset_ids (); 744 jit_instruction::reset_ids ();
611 745
612 entry_block = new jit_block ("entry"); 746 jit_block *entry_block = new jit_block ("body");
613 blocks.push_back (entry_block); 747 block = entry_block;
614 block = new jit_block ("body");
615 blocks.push_back (block); 748 blocks.push_back (block);
616 749
750 toplevel_map tlevel (block);
751 variables = &tlevel;
617 final_block = new jit_block ("final"); 752 final_block = new jit_block ("final");
618 visit (tee); 753 visit (tee);
754
619 blocks.push_back (final_block); 755 blocks.push_back (final_block);
620
621 entry_block->append (new jit_break (block));
622 block->append (new jit_break (final_block)); 756 block->append (new jit_break (final_block));
623 757
624 for (variable_map::iterator iter = variables.begin (); 758 for (variable_map::iterator iter = variables->begin ();
625 iter != variables.end (); ++iter) 759 iter != variables->end (); ++iter)
626 final_block->append (new jit_store_argument (iter->first, iter->second)); 760 final_block->append (new jit_store_argument (iter->first, iter->second));
627 761
628 // FIXME: Maybe we should remove dead code here? 762 // FIXME: Maybe we should remove dead code here?
629 763
630 // initialize the worklist to instructions derived from constants 764 // initialize the worklist to instructions derived from constants
631 for (std::list<jit_value *>::iterator iter = constants.begin (); 765 for (std::list<jit_value *>::iterator iter = constants.begin ();
632 iter != constants.end (); ++iter) 766 iter != constants.end (); ++iter)
633 append_users (*iter); 767 append_users (*iter);
768
769 // also get anything from jit_extract_argument, as these have constant types
770 for (jit_block::iterator iter = entry_block->begin ();
771 iter != entry_block->end (); ++iter)
772 {
773 jit_instruction *instr = *iter;
774 if (jit_extract_argument *extract = dynamic_cast<jit_extract_argument *>(instr))
775 {
776 if (! extract->type ())
777 fail (); // we depend on an unknown type
778 append_users (extract);
779 }
780 }
634 781
635 // FIXME: Describe algorithm here 782 // FIXME: Describe algorithm here
636 while (worklist.size ()) 783 while (worklist.size ())
637 { 784 {
638 jit_instruction *next = worklist.front (); 785 jit_instruction *next = worklist.front ();
651 iter != blocks.end (); ++iter) 798 iter != blocks.end (); ++iter)
652 (*iter)->print (std::cout, 0); 799 (*iter)->print (std::cout, 0);
653 std::cout << std::endl; 800 std::cout << std::endl;
654 } 801 }
655 802
803 // for now just init arguments from entry, later we will have to do something
804 // more interesting
805 for (jit_block::iterator iter = entry_block->begin ();
806 iter != entry_block->end (); ++iter)
807 {
808 if (jit_extract_argument *extract = dynamic_cast<jit_extract_argument *> (*iter))
809 arguments.push_back (std::make_pair (extract->tag (), true));
810 }
811
656 convert_llvm to_llvm; 812 convert_llvm to_llvm;
657 function = to_llvm.convert (module, arguments, blocks, constants); 813 function = to_llvm.convert (module, arguments, blocks, constants);
658 814
659 if (debug_print) 815 if (debug_print)
660 { 816 {
661 std::cout << "-------------------- llvm ir --------------------"; 817 std::cout << "-------------------- llvm ir --------------------";
662 llvm::raw_os_ostream llvm_cout (std::cout); 818 llvm::raw_os_ostream llvm_cout (std::cout);
663 function->print (llvm_cout); 819 function->print (llvm_cout);
664 std::cout << std::endl; 820 std::cout << std::endl;
821 llvm::verifyFunction (*function);
665 } 822 }
666 } 823 }
667 824
668 void 825 void
669 jit_convert::visit_anon_fcn_handle (tree_anon_fcn_handle&) 826 jit_convert::visit_anon_fcn_handle (tree_anon_fcn_handle&)
735 { 892 {
736 fail (); 893 fail ();
737 } 894 }
738 895
739 void 896 void
740 jit_convert::visit_simple_for_command (tree_simple_for_command&) 897 jit_convert::visit_simple_for_command (tree_simple_for_command& cmd)
741 { 898 {
742 fail (); 899 // how a for statement is compiled. Note we do an initial check
900 // to see if the loop will run atleast once. This allows us to get
901 // better type inference bounds on variables defined and used only
902 // inside the for loop (e.g. the index variable)
903
904 // prev_block: % pred = ?
905 // #control.0 = % compute_control (note this will just be a temp)
906 // #iter.0 = call for_init (#control.0) % Let type of control decide iter
907 // % initial value and type
908 // #temp.0 = call for_check (control.0, #iter.0)
909 // cond_break #temp.0, for_body, for_tail
910 // for_body: % pred = for_init, for_cond
911 // idxvar.2 = phi | for_init -> idxvar.1
912 // | for_body -> idxvar.3
913 // #iter.1 = phi | for_init -> #iter.0
914 // | for_body -> #iter.2
915 // idxvar.3 = call for_index (#control.0, #iter.1)
916 // % do loop body
917 // #iter.2 = #iter.1 + 1 % release is implicit in iter reuse
918 // #check = call for_check (#control.0, iter.2)
919 // cond_break #check for_body, for_tail
920 // for_tail: % pred = prev_block, for_body
921 // #iter.3 = phi | prev_block -> #iter.0
922 // | for_body -> #iter.2
923 // idxvar.4 = phi | prev_block -> idxvar.0
924 // | for_body -> idxvar.3
925 // call release (#iter.3)
926 // % rest of code
927
928 // FIXME: one of these days we will introduce proper lvalues...
929 tree_identifier *lhs = dynamic_cast<tree_identifier *>(cmd.left_hand_side ());
930 if (! lhs)
931 fail ();
932 std::string lhs_name = lhs->name ();
933
934 jit_block *body = new jit_block ("for_body");
935 blocks.push_back (body);
936
937 jit_block *tail = new jit_block ("for_tail");
938 unwind_protect prot_tail;
939 prot_tail.add_delete (tail); // incase we fail before adding tail to blocks
940
941 // do control expression, iter init, and condition check in prev_block (block)
942 jit_value *control = visit (cmd.control_expr ());
943 jit_call *init_iter = new jit_call (jit_typeinfo::for_init, control);
944 init_iter->stash_tag ("#iter");
945 block->append (init_iter);
946 jit_value *check = block->append (new jit_call (jit_typeinfo::for_check,
947 control, init_iter));
948 block->append (new jit_cond_break (check, body, tail));
949
950 // we need to do iter phi manually, for_map handles the rest
951 jit_phi *iter_phi = new jit_phi (2);
952 iter_phi->stash_tag ("#iter");
953 iter_phi->stash_argument (1, init_iter);
954 body->append (iter_phi);
955
956 variable_map *merge_vars = variables;
957 for_map body_vars (variables, body);
958 variables = &body_vars;
959 block = body;
960
961 // first thing we do in the for loop is bind our index from our itertor
962 jit_call *idx_rhs = new jit_call (jit_typeinfo::for_index, control, iter_phi);
963 block->append (idx_rhs);
964 idx_rhs->stash_tag (lhs_name);
965 do_assign (lhs_name, idx_rhs, false);
966
967 tree_statement_list *pt_body = cmd.body ();
968 pt_body->accept (*this);
969
970 // increment iterator, check conditional, and repeat
971 const jit_function& add_fn = jit_typeinfo::binary_op (octave_value::op_add);
972 jit_call *iter_inc = new jit_call (add_fn, iter_phi,
973 get_const<jit_const_index> (1));
974 iter_inc->stash_tag ("#iter");
975 block->append (iter_inc);
976 check = block->append (new jit_call (jit_typeinfo::for_check, control,
977 iter_inc));
978 block->append (new jit_cond_break (check, body, tail));
979 iter_phi->stash_argument (0, iter_inc);
980 body_vars.finish_phi (*variables);
981
982 blocks.push_back (tail);
983 prot_tail.discard ();
984 block = tail;
985
986 variables = merge_vars;
987 merge (body_vars);
988 iter_phi = new jit_phi (2);
989 iter_phi->stash_tag ("#iter");
990 iter_phi->stash_argument (0, iter_inc);
991 iter_phi->stash_argument (1, init_iter);
992 block->append (iter_phi);
993 block->append (new jit_call (jit_typeinfo::release, iter_phi));
743 } 994 }
744 995
745 void 996 void
746 jit_convert::visit_complex_for_command (tree_complex_for_command&) 997 jit_convert::visit_complex_for_command (tree_complex_for_command&)
747 { 998 {
779 } 1030 }
780 1031
781 void 1032 void
782 jit_convert::visit_identifier (tree_identifier& ti) 1033 jit_convert::visit_identifier (tree_identifier& ti)
783 { 1034 {
784 std::string name = ti.name ();
785 variable_map::iterator iter = variables.find (name);
786 jit_value *var;
787 if (iter == variables.end ())
788 {
789 octave_value var_value = ti.do_lookup ();
790 jit_type *var_type = jit_typeinfo::type_of (var_value);
791 var = entry_block->append (new jit_extract_argument (var_type, name));
792 constants.push_back (var);
793 bounds.push_back (std::make_pair (var_type, name));
794 variables[name] = var;
795 arguments.push_back (std::make_pair (name, true));
796 }
797 else
798 var = iter->second;
799
800 const jit_function& fn = jit_typeinfo::grab (); 1035 const jit_function& fn = jit_typeinfo::grab ();
1036 jit_value *var = variables->get (ti.name ());
801 result = block->append (new jit_call (fn, var)); 1037 result = block->append (new jit_call (fn, var));
802 } 1038 }
803 1039
804 void 1040 void
805 jit_convert::visit_if_clause (tree_if_clause&) 1041 jit_convert::visit_if_clause (tree_if_clause&)
854 { 1090 {
855 octave_value v = tc.rvalue1 (); 1091 octave_value v = tc.rvalue1 ();
856 if (v.is_real_scalar () && v.is_double_type ()) 1092 if (v.is_real_scalar () && v.is_double_type ())
857 { 1093 {
858 double dv = v.double_value (); 1094 double dv = v.double_value ();
859 result = get_scalar (dv); 1095 result = get_const<jit_const_scalar> (dv);
860 } 1096 }
861 else if (v.is_range ()) 1097 else if (v.is_range ())
862 fail (); 1098 {
1099 Range rv = v.range_value ();
1100 result = get_const<jit_const_range> (rv);
1101 }
863 else 1102 else
864 fail (); 1103 fail ();
865 } 1104 }
866 1105
867 void 1106 void
949 else if (expr->is_identifier () && expr->print_result ()) 1188 else if (expr->is_identifier () && expr->print_result ())
950 { 1189 {
951 // FIXME: ugly hack, we need to come up with a way to pass 1190 // FIXME: ugly hack, we need to come up with a way to pass
952 // nargout to visit_identifier 1191 // nargout to visit_identifier
953 const jit_function& fn = jit_typeinfo::print_value (); 1192 const jit_function& fn = jit_typeinfo::print_value ();
954 jit_const_string *name = get_string (expr->name ()); 1193 jit_const_string *name = get_const<jit_const_string> (expr->name ());
955 block->append (new jit_call (fn, name, expr_result)); 1194 block->append (new jit_call (fn, name, expr_result));
956 } 1195 }
957 } 1196 }
958 } 1197 }
959 1198
960 void 1199 void
961 jit_convert::visit_statement_list (tree_statement_list&) 1200 jit_convert::visit_statement_list (tree_statement_list& lst)
962 { 1201 {
963 fail (); 1202 for (tree_statement_list::iterator iter = lst.begin (); iter != lst.end();
1203 ++iter)
1204 {
1205 tree_statement *elt = *iter;
1206 // jwe: Can this ever be null?
1207 assert (elt);
1208 elt->accept (*this);
1209 }
964 } 1210 }
965 1211
966 void 1212 void
967 jit_convert::visit_switch_case (tree_switch_case&) 1213 jit_convert::visit_switch_case (tree_switch_case&)
968 { 1214 {
1006 } 1252 }
1007 1253
1008 void 1254 void
1009 jit_convert::do_assign (const std::string& lhs, jit_value *rhs, bool print) 1255 jit_convert::do_assign (const std::string& lhs, jit_value *rhs, bool print)
1010 { 1256 {
1011 variable_map::iterator iter = variables.find (lhs); 1257 const jit_function& release = jit_typeinfo::release ();
1012 if (iter == variables.end ()) 1258 jit_value *current = variables->get (lhs);
1013 arguments.push_back (std::make_pair (lhs, false)); 1259 block->append (new jit_call (release, current));
1014 else 1260 variables->set (lhs, rhs);
1015 {
1016 const jit_function& fn = jit_typeinfo::release ();
1017 block->append (new jit_call (fn, iter->second));
1018 }
1019
1020 variables[lhs] = rhs;
1021 1261
1022 if (print) 1262 if (print)
1023 { 1263 {
1024 const jit_function& fn = jit_typeinfo::print_value (); 1264 const jit_function& print_fn = jit_typeinfo::print_value ();
1025 jit_const_string *name = get_string (lhs); 1265 jit_const_string *name = get_const<jit_const_string> (lhs);
1026 block->append (new jit_call (fn, name, rhs)); 1266 block->append (new jit_call (print_fn, name, rhs));
1027 } 1267 }
1028 } 1268 }
1029 1269
1030 jit_value * 1270 jit_value *
1031 jit_convert::visit (tree& tee) 1271 jit_convert::visit (tree& tee)
1034 tee.accept (*this); 1274 tee.accept (*this);
1035 1275
1036 jit_value *ret = result; 1276 jit_value *ret = result;
1037 result = 0; 1277 result = 0;
1038 return ret; 1278 return ret;
1279 }
1280
1281 void
1282 jit_convert::merge (const variable_map& ref)
1283 {
1284 assert (variables->size () == ref.size ());
1285 variable_map::iterator viter = variables->begin ();
1286 variable_map::const_iterator riter = ref.begin ();
1287 for (; viter != variables->end (); ++viter, ++riter)
1288 {
1289 assert (viter->first == riter->first);
1290 if (viter->second != riter->second)
1291 {
1292 jit_phi *phi = new jit_phi (2);
1293 phi->stash_tag (viter->first);
1294 block->prepend (phi);
1295 phi->stash_argument (0, riter->second);
1296 phi->stash_argument (1, viter->second);
1297 viter->second = phi;
1298 }
1299 }
1300 }
1301
1302 // -------------------- jit_convert::toplevel_map --------------------
1303 jit_value *
1304 jit_convert::toplevel_map::insert (const std::string& name, jit_value *pval)
1305 {
1306 assert (pval == 0); // we have no parent
1307
1308 jit_block *entry = block ();
1309 octave_value val = symbol_table::find (name);
1310 jit_type *type = jit_typeinfo::type_of (val);
1311 jit_instruction *ret = new jit_extract_argument (type, name);
1312 return vars[name] = entry->prepend (ret);
1039 } 1313 }
1040 1314
1041 // -------------------- jit_convert::convert_llvm -------------------- 1315 // -------------------- jit_convert::convert_llvm --------------------
1042 llvm::Function * 1316 llvm::Function *
1043 jit_convert::convert_llvm::convert (llvm::Module *module, 1317 jit_convert::convert_llvm::convert (llvm::Module *module,
1050 // argument is an array of octave_base_value*, or octave_base_value** 1324 // argument is an array of octave_base_value*, or octave_base_value**
1051 llvm::Type *arg_type = any->to_llvm (); // this is octave_base_value* 1325 llvm::Type *arg_type = any->to_llvm (); // this is octave_base_value*
1052 arg_type = arg_type->getPointerTo (); 1326 arg_type = arg_type->getPointerTo ();
1053 llvm::FunctionType *ft = llvm::FunctionType::get (llvm::Type::getVoidTy (context), 1327 llvm::FunctionType *ft = llvm::FunctionType::get (llvm::Type::getVoidTy (context),
1054 arg_type, false); 1328 arg_type, false);
1055 llvm::Function *function = llvm::Function::Create (ft, 1329 function = llvm::Function::Create (ft, llvm::Function::ExternalLinkage,
1056 llvm::Function::ExternalLinkage, 1330 "foobar", module);
1057 "foobar", module);
1058 1331
1059 try 1332 try
1060 { 1333 {
1061 llvm::BasicBlock *prelude = llvm::BasicBlock::Create (context, "prelude", 1334 llvm::BasicBlock *prelude = llvm::BasicBlock::Create (context, "prelude",
1062 function); 1335 function);
1071 1344
1072 // we need to generate llvm values for constants, as these don't appear in 1345 // we need to generate llvm values for constants, as these don't appear in
1073 // a block 1346 // a block
1074 for (std::list<jit_value *>::const_iterator iter = constants.begin (); 1347 for (std::list<jit_value *>::const_iterator iter = constants.begin ();
1075 iter != constants.end (); ++iter) 1348 iter != constants.end (); ++iter)
1076 { 1349 visit (*iter);
1077 jit_value *constant = *iter;
1078 if (! dynamic_cast<jit_instruction *> (constant))
1079 visit (constant);
1080 }
1081 1350
1082 std::list<jit_block *>::const_iterator biter; 1351 std::list<jit_block *>::const_iterator biter;
1083 for (biter = blocks.begin (); biter != blocks.end (); ++biter) 1352 for (biter = blocks.begin (); biter != blocks.end (); ++biter)
1084 { 1353 {
1085 jit_block *jblock = *biter; 1354 jit_block *jblock = *biter;
1089 } 1358 }
1090 1359
1091 jit_block *first = *blocks.begin (); 1360 jit_block *first = *blocks.begin ();
1092 builder.CreateBr (first->to_llvm ()); 1361 builder.CreateBr (first->to_llvm ());
1093 1362
1363 // convert all instructions
1094 for (biter = blocks.begin (); biter != blocks.end (); ++biter) 1364 for (biter = blocks.begin (); biter != blocks.end (); ++biter)
1095 visit (*biter); 1365 visit (*biter);
1096 1366
1367 // now finish phi nodes
1368 for (biter = blocks.begin (); biter != blocks.end (); ++biter)
1369 {
1370 jit_block& block = **biter;
1371 for (jit_block::iterator piter = block.begin ();
1372 piter != block.end () && dynamic_cast<jit_phi *> (*piter); ++piter)
1373 {
1374 // our phi nodes don't have to have the same incomming type,
1375 // so we do casts here
1376 jit_instruction *phi = *piter;
1377 jit_block *pblock = phi->parent ();
1378 llvm::PHINode *llvm_phi = llvm::cast<llvm::PHINode> (phi->to_llvm ());
1379 for (size_t i = 0; i < phi->argument_count (); ++i)
1380 {
1381 llvm::BasicBlock *pred = pblock->pred_llvm (i);
1382 if (phi->argument_type_llvm (i) == phi->type_llvm ())
1383 {
1384 llvm_phi->addIncoming (phi->argument_llvm (i), pred);
1385 }
1386 else
1387 {
1388 // add cast right before pred terminator
1389 builder.SetInsertPoint (--pred->end ());
1390
1391 const jit_function::overload& ol
1392 = jit_typeinfo::cast (phi->type (),
1393 phi->argument_type (i));
1394 if (! ol.function)
1395 {
1396 std::stringstream ss;
1397 ss << "No cast for phi(" << i << "): ";
1398 phi->print (ss);
1399 fail (ss.str ());
1400 }
1401
1402 llvm::Value *casted;
1403 casted = builder.CreateCall (ol.function,
1404 phi->argument_llvm (i));
1405 llvm_phi->addIncoming (casted, pred);
1406 }
1407 }
1408 }
1409 }
1410
1411 jit_block *last = blocks.back ();
1412 builder.SetInsertPoint (last->to_llvm ());
1097 builder.CreateRetVoid (); 1413 builder.CreateRetVoid ();
1098 } catch (const jit_fail_exception&) 1414 } catch (const jit_fail_exception& e)
1099 { 1415 {
1100 function->eraseFromParent (); 1416 function->eraseFromParent ();
1101 throw; 1417 throw;
1102 } 1418 }
1103 1419
1104 llvm::verifyFunction (*function);
1105
1106 return function; 1420 return function;
1107 } 1421 }
1108 1422
1109 void 1423 void
1110 jit_convert::convert_llvm::visit_const_string (jit_const_string& cs) 1424 jit_convert::convert_llvm::visit (jit_const_string& cs)
1111 { 1425 {
1112 cs.stash_llvm (builder.CreateGlobalStringPtr (cs.value ())); 1426 cs.stash_llvm (builder.CreateGlobalStringPtr (cs.value ()));
1113 } 1427 }
1114 1428
1115 void 1429 void
1116 jit_convert::convert_llvm::visit_const_scalar (jit_const_scalar& cs) 1430 jit_convert::convert_llvm::visit (jit_const_scalar& cs)
1117 { 1431 {
1118 llvm::Type *dbl = llvm::Type::getDoubleTy (context); 1432 cs.stash_llvm (llvm::ConstantFP::get (cs.type_llvm (), cs.value ()));
1119 cs.stash_llvm (llvm::ConstantFP::get (dbl, cs.value ())); 1433 }
1120 } 1434
1121 1435 void jit_convert::convert_llvm::visit (jit_const_index& ci)
1122 void 1436 {
1123 jit_convert::convert_llvm::visit_block (jit_block& b) 1437 ci.stash_llvm (llvm::ConstantInt::get (ci.type_llvm (), ci.value ()));
1438 }
1439
1440 void
1441 jit_convert::convert_llvm::visit (jit_const_range& cr)
1442 {
1443 llvm::StructType *stype = llvm::cast<llvm::StructType>(cr.type_llvm ());
1444 llvm::Type *dbl = jit_typeinfo::get_scalar_llvm ();
1445 llvm::Type *idx = jit_typeinfo::get_index_llvm ();
1446 const jit_range& rng = cr.value ();
1447
1448 llvm::Constant *constants[4];
1449 constants[0] = llvm::ConstantFP::get (dbl, rng.base);
1450 constants[1] = llvm::ConstantFP::get (dbl, rng.limit);
1451 constants[2] = llvm::ConstantFP::get (dbl, rng.inc);
1452 constants[3] = llvm::ConstantInt::get (idx, rng.nelem);
1453
1454 llvm::Value *as_llvm;
1455 as_llvm = llvm::ConstantStruct::get (stype,
1456 llvm::makeArrayRef (constants, 4));
1457 cr.stash_llvm (as_llvm);
1458 }
1459
1460 void
1461 jit_convert::convert_llvm::visit (jit_block& b)
1124 { 1462 {
1125 llvm::BasicBlock *block = b.to_llvm (); 1463 llvm::BasicBlock *block = b.to_llvm ();
1126 builder.SetInsertPoint (block); 1464 builder.SetInsertPoint (block);
1127 for (jit_block::iterator iter = b.begin (); iter != b.end (); ++iter) 1465 for (jit_block::iterator iter = b.begin (); iter != b.end (); ++iter)
1128 visit (*iter); 1466 visit (*iter);
1129 } 1467 }
1130 1468
1131 void 1469 void
1132 jit_convert::convert_llvm::visit_break (jit_break& b) 1470 jit_convert::convert_llvm::visit (jit_break& b)
1133 { 1471 {
1134 builder.CreateBr (b.sucessor_llvm ()); 1472 b.stash_llvm (builder.CreateBr (b.sucessor_llvm ()));
1135 } 1473 }
1136 1474
1137 void 1475 void
1138 jit_convert::convert_llvm::visit_cond_break (jit_cond_break& cb) 1476 jit_convert::convert_llvm::visit (jit_cond_break& cb)
1139 { 1477 {
1140 llvm::Value *cond = cb.cond_llvm (); 1478 llvm::Value *cond = cb.cond_llvm ();
1141 builder.CreateCondBr (cond, cb.sucessor_llvm (0), cb.sucessor_llvm (1)); 1479 llvm::Value *br;
1142 } 1480 br = builder.CreateCondBr (cond, cb.sucessor_llvm (0), cb.sucessor_llvm (1));
1143 1481 cb.stash_llvm (br);
1144 void 1482 }
1145 jit_convert::convert_llvm::visit_call (jit_call& call) 1483
1484 void
1485 jit_convert::convert_llvm::visit (jit_call& call)
1146 { 1486 {
1147 const jit_function::overload& ol = call.overload (); 1487 const jit_function::overload& ol = call.overload ();
1148 if (! ol.function) 1488 if (! ol.function)
1149 fail (); 1489 fail ("No overload for: " + call.print_string ());
1150 1490
1151 std::vector<llvm::Value *> args (call.argument_count ()); 1491 std::vector<llvm::Value *> args (call.argument_count ());
1152 for (size_t i = 0; i < call.argument_count (); ++i) 1492 for (size_t i = 0; i < call.argument_count (); ++i)
1153 args[i] = call.argument_llvm (i); 1493 args[i] = call.argument_llvm (i);
1154 1494
1155 call.stash_llvm (builder.CreateCall (ol.function, args)); 1495 call.stash_llvm (builder.CreateCall (ol.function, args, call.tag ()));
1156 } 1496 }
1157 1497
1158 void 1498 void
1159 jit_convert::convert_llvm::visit_extract_argument (jit_extract_argument& extract) 1499 jit_convert::convert_llvm::visit (jit_extract_argument& extract)
1160 { 1500 {
1161 const jit_function::overload& ol = extract.overload (); 1501 const jit_function::overload& ol = extract.overload ();
1162 if (! ol.function) 1502 if (! ol.function)
1163 fail (); 1503 fail ();
1164 1504
1165 llvm::Value *arg = arguments[extract.tag ()]; 1505 llvm::Value *arg = arguments[extract.tag ()];
1506 assert (arg);
1166 arg = builder.CreateLoad (arg); 1507 arg = builder.CreateLoad (arg);
1167 extract.stash_llvm (builder.CreateCall (ol.function, arg)); 1508 extract.stash_llvm (builder.CreateCall (ol.function, arg, extract.tag ()));
1168 } 1509 }
1169 1510
1170 void 1511 void
1171 jit_convert::convert_llvm::visit_store_argument (jit_store_argument& store) 1512 jit_convert::convert_llvm::visit (jit_store_argument& store)
1172 { 1513 {
1173 llvm::Value *arg_value = store.result_llvm (); 1514 llvm::Value *arg_value = store.result_llvm ();
1174 const jit_function::overload& ol = store.overload (); 1515 const jit_function::overload& ol = store.overload ();
1175 if (! ol.function) 1516 if (! ol.function)
1176 fail (); 1517 fail ();
1179 1520
1180 llvm::Value *arg = arguments[store.tag ()]; 1521 llvm::Value *arg = arguments[store.tag ()];
1181 store.stash_llvm (builder.CreateStore (arg_value, arg)); 1522 store.stash_llvm (builder.CreateStore (arg_value, arg));
1182 } 1523 }
1183 1524
1525 void
1526 jit_convert::convert_llvm::visit (jit_phi& phi)
1527 {
1528 // we might not have converted all incoming branches, so we don't
1529 // set incomming branches now
1530 llvm::PHINode *node = llvm::PHINode::Create (phi.type_llvm (),
1531 phi.argument_count (),
1532 phi.tag ());
1533 builder.Insert (node);
1534 phi.stash_llvm (node);
1535
1536 jit_block *parent = phi.parent ();
1537 for (size_t i = 0; i < phi.argument_count (); ++i)
1538 if (phi.argument_type (i) != phi.type ())
1539 parent->create_merge (function, i);
1540 }
1541
1184 // -------------------- tree_jit -------------------- 1542 // -------------------- tree_jit --------------------
1185 1543
1186 tree_jit::tree_jit (void) : context (llvm::getGlobalContext ()), engine (0) 1544 tree_jit::tree_jit (void) : module (0), engine (0)
1187 { 1545 {
1188 llvm::InitializeNativeTarget ();
1189 module = new llvm::Module ("octave", context);
1190 } 1546 }
1191 1547
1192 tree_jit::~tree_jit (void) 1548 tree_jit::~tree_jit (void)
1193 {} 1549 {}
1194 1550
1195 bool 1551 bool
1196 tree_jit::execute (tree& cmd) 1552 tree_jit::execute (tree_simple_for_command& cmd)
1197 { 1553 {
1198 if (! initialize ()) 1554 if (! initialize ())
1199 return false; 1555 return false;
1200 1556
1201 compiled_map::iterator iter = compiled.find (&cmd); 1557 jit_info *info = cmd.get_info ();
1202 jit_info *jinfo = 0; 1558 if (! info || ! info->match ())
1203 if (iter != compiled.end ()) 1559 {
1204 { 1560 delete info;
1205 jinfo = iter->second; 1561 info = new jit_info (*this, cmd);
1206 if (! jinfo->match ()) 1562 cmd.stash_info (info);
1207 { 1563 }
1208 delete jinfo; 1564
1209 jinfo = 0; 1565 return info->execute ();
1210 }
1211 }
1212
1213 if (! jinfo)
1214 {
1215 jinfo = new jit_info (*this, cmd);
1216 compiled[&cmd] = jinfo;
1217 }
1218
1219 return jinfo->execute ();
1220 } 1566 }
1221 1567
1222 bool 1568 bool
1223 tree_jit::initialize (void) 1569 tree_jit::initialize (void)
1224 { 1570 {
1225 if (engine) 1571 if (engine)
1226 return true; 1572 return true;
1573
1574 if (! module)
1575 {
1576 llvm::InitializeNativeTarget ();
1577 module = new llvm::Module ("octave", context);
1578 }
1227 1579
1228 // sometimes this fails pre main 1580 // sometimes this fails pre main
1229 engine = llvm::ExecutionEngine::createJIT (module); 1581 engine = llvm::ExecutionEngine::createJIT (module);
1230 1582
1231 if (! engine) 1583 if (! engine)
1267 jit_convert conv (tjit.get_module (), tee); 1619 jit_convert conv (tjit.get_module (), tee);
1268 fun = conv.get_function (); 1620 fun = conv.get_function ();
1269 arguments = conv.get_arguments (); 1621 arguments = conv.get_arguments ();
1270 bounds = conv.get_bounds (); 1622 bounds = conv.get_bounds ();
1271 } 1623 }
1272 catch (const jit_fail_exception&) 1624 catch (const jit_fail_exception& e)
1273 {} 1625 {
1626 if (debug_print && e.known ())
1627 std::cout << "jit fail: " << e.what () << std::endl;
1628 }
1274 1629
1275 if (! fun) 1630 if (! fun)
1276 { 1631 {
1277 function = 0; 1632 function = 0;
1278 return; 1633 return;
1324 return true; 1679 return true;
1325 1680
1326 for (size_t i = 0; i < bounds.size (); ++i) 1681 for (size_t i = 0; i < bounds.size (); ++i)
1327 { 1682 {
1328 const std::string& arg_name = bounds[i].second; 1683 const std::string& arg_name = bounds[i].second;
1329 octave_value value = symbol_table::varval (arg_name); 1684 octave_value value = symbol_table::find (arg_name);
1330 jit_type *type = jit_typeinfo::type_of (value); 1685 jit_type *type = jit_typeinfo::type_of (value);
1331 1686
1332 // FIXME: Check for a parent relationship 1687 // FIXME: Check for a parent relationship
1333 if (type != bounds[i].first) 1688 if (type != bounds[i].first)
1334 return false; 1689 return false;