Mercurial > hg > octave-lyh
diff src/pt-jit.cc @ 14969:bbeef7b8ea2e
Add support for matrix indexed assignment to JIT
* src/pt-jit.cc (octave_jit_subsasgn_impl, jit_convert::resolve): New function.
(jit_typeinfo::jit_typeinfo): Add subsasgn implementation in llvm.
(jit_convert::visit_simple_for_command): Use new do_assign overload.
(jit_convert::visit_index_expression): Use new do_assign overload and resolve.
(jit_convert::visit_simple_assignment): Use new do_assign overload.
(jit_convert::do_assign): New overload.
(jit_convert::convert_llvm::visit): Check if assignment is artificial.
* src/pt-jit.h (jit_typeinfo::paren_subsasgn, jit_convert::create_check):
New function.
(jit_assign::jit_assign): Initialize martificial.
(jit_assign::artificial, jit_assign::mark_artificial): New function.
(jit_assign::print): Print the artificial flag.
(jit_convert::create_checked_impl): Call create_check.
(jit_convert::resolve): New declaration.
(jit_convert::do_assign): New overload declaration.
author | Max Brister <max@2bass.com> |
---|---|
date | Mon, 25 Jun 2012 14:21:45 -0500 |
parents | 7f60cdfcc0e5 |
children | b23a98ca0e43 |
line wrap: on
line diff
--- a/src/pt-jit.cc +++ b/src/pt-jit.cc @@ -237,6 +237,24 @@ } extern "C" void +octave_jit_paren_subsasgn_impl (jit_matrix *mat, octave_idx_type index, + double value) +{ + std::cout << "impl\n"; + NDArray *array = mat->array; + if (array->nelem () < index) + array->resize1 (index); + + double *data = array->fortran_vec (); + data[index - 1] = value; + + mat->ref_count = array->jit_ref_count (); + mat->slice_data = array->jit_slice_data () - 1; + mat->dimensions = array->jit_dimensions (); + mat->slice_len = array->nelem (); +} + +extern "C" void octave_jit_print_matrix (jit_matrix *m) { std::cout << *m << std::endl; @@ -755,6 +773,92 @@ llvm::verifyFunction (*fn); paren_subsref_fn.add_overload (fn, true, scalar, matrix, scalar); + // paren subsasgn + paren_subsasgn_fn.stash_name ("()subsasgn"); + + llvm::Function *resize_paren_subsasgn + = create_function ("octave_jit_paren_subsasgn_impl", void_t, + matrix_t->getPointerTo (), index_t, scalar_t); + engine->addGlobalMapping (resize_paren_subsasgn, + reinterpret_cast<void *> (&octave_jit_paren_subsasgn_impl)); + + fn = create_function ("octave_jit_paren_subsasgn", matrix, matrix, scalar, + scalar); + body = llvm::BasicBlock::Create (context, "body", fn); + builder.SetInsertPoint (body); + { + llvm::Value *one = llvm::ConstantInt::get (index_t, 1); + + llvm::Function::arg_iterator args = fn->arg_begin (); + llvm::Value *mat = args++; + llvm::Value *idx = args++; + llvm::Value *value = args; + + llvm::Value *int_idx = builder.CreateFPToSI (idx, index_t); + llvm::Value *check_idx = builder.CreateSIToFP (int_idx, scalar_t); + llvm::Value *cond0 = builder.CreateFCmpUNE (idx, check_idx); + llvm::Value *cond1 = builder.CreateICmpSLT (int_idx, one); + llvm::Value *cond = builder.CreateOr (cond0, cond1); + + llvm::BasicBlock *done = llvm::BasicBlock::Create (context, "done", fn); + + llvm::BasicBlock *conv_error = llvm::BasicBlock::Create (context, + "conv_error", fn, + done); + llvm::BasicBlock *normal = llvm::BasicBlock::Create (context, "normal", fn, + done); + builder.CreateCondBr (cond, conv_error, normal); + builder.SetInsertPoint (conv_error); + builder.CreateCall (ginvalid_index); + builder.CreateBr (done); + + builder.SetInsertPoint (normal); + llvm::Value *len = builder.CreateExtractValue (mat, + llvm::ArrayRef<unsigned> (2)); + cond0 = builder.CreateICmpSGT (int_idx, len); + + llvm::Value *rcount = builder.CreateExtractValue (mat, 0); + rcount = builder.CreateLoad (rcount); + cond1 = builder.CreateICmpSGT (rcount, one); + cond = builder.CreateOr (cond0, cond1); + + llvm::BasicBlock *bounds_error = llvm::BasicBlock::Create (context, + "bounds_error", + fn, done); + + llvm::BasicBlock *success = llvm::BasicBlock::Create (context, "success", + fn, done); + builder.CreateCondBr (cond, bounds_error, success); + + // resize on out of bounds access + builder.SetInsertPoint (bounds_error); + llvm::Value *resize_result = builder.CreateAlloca (matrix_t); + builder.CreateStore (mat, resize_result); + builder.CreateCall3 (resize_paren_subsasgn, resize_result, int_idx, value); + resize_result = builder.CreateLoad (resize_result); + builder.CreateBr (done); + + builder.SetInsertPoint (success); + llvm::Value *data = builder.CreateExtractValue (mat, + llvm::ArrayRef<unsigned> (1)); + llvm::Value *gep = builder.CreateInBoundsGEP (data, int_idx); + builder.CreateStore (value, gep); + builder.CreateBr (done); + + builder.SetInsertPoint (done); + + llvm::PHINode *merge = llvm::PHINode::Create (matrix_t, 3); + builder.Insert (merge); + merge->addIncoming (mat, conv_error); + merge->addIncoming (resize_result, bounds_error); + merge->addIncoming (mat, success); + builder.CreateRet (merge); + } + llvm::verifyFunction (*fn); + paren_subsasgn_fn.add_overload (fn, true, matrix, matrix, scalar, scalar); + + // paren_subsasgn + casts[any->type_id ()].stash_name ("(any)"); casts[scalar->type_id ()].stash_name ("(scalar)"); @@ -1689,12 +1793,6 @@ prot.protect_var (breaking); breaks.clear (); - // FIXME: one of these days we will introduce proper lvalues... - tree_identifier *lhs = dynamic_cast<tree_identifier *>(cmd.left_hand_side ()); - if (! lhs) - fail (); - std::string lhs_name = lhs->name (); - // we need a variable for our iterator, because it is used in multiple blocks std::stringstream ss; ss << "#iter" << iterator_count++; @@ -1719,9 +1817,10 @@ block = body; // compute the syntactical iterator - jit_call *idx_rhs = create<jit_call> (jit_typeinfo::for_index, control, iterator); + jit_call *idx_rhs = create<jit_call> (jit_typeinfo::for_index, control, + iterator); block->append (idx_rhs); - do_assign (lhs_name, idx_rhs, false); + do_assign (cmd.left_hand_side (), idx_rhs); // do loop tree_statement_list *pt_body = cmd.body (); @@ -1901,26 +2000,9 @@ void jit_convert::visit_index_expression (tree_index_expression& exp) { - std::string type = exp.type_tags (); - if (! (type.size () == 1 && type[0] == '(')) - fail ("Unsupported index operation"); - - std::list<tree_argument_list *> args = exp.arg_lists (); - if (args.size () != 1) - fail ("Bad number of arguments in tree_index_expression"); - - tree_argument_list *arg_list = args.front (); - if (! arg_list) - fail ("null argument list"); - - if (arg_list->size () != 1) - fail ("Bad number of arguments in arg_list"); - - tree_expression *tree_object = exp.expression (); - jit_value *object = visit (tree_object); - - tree_expression *arg0 = arg_list->front (); - jit_value *index = visit (arg0); + std::pair<jit_value *, jit_value *> res = resolve (exp); + jit_value *object = res.first; + jit_value *index = res.second; result = create_checked (jit_typeinfo::paren_subsref, object, index); } @@ -2013,13 +2095,7 @@ tree_expression *rhs = tsa.right_hand_side (); jit_value *rhsv = visit (rhs); - // resolve lhs - tree_expression *lhs = tsa.left_hand_side (); - if (! lhs->is_identifier ()) - fail (); - - std::string lhs_name = lhs->name (); - result = do_assign (lhs_name, rhsv, tsa.print_result ()); + do_assign (tsa.left_hand_side (), rhsv); } void @@ -2156,12 +2232,68 @@ return vmap[vname] = var; } +std::pair<jit_value *, jit_value *> +jit_convert::resolve (tree_index_expression& exp) +{ + std::string type = exp.type_tags (); + if (! (type.size () == 1 && type[0] == '(')) + fail ("Unsupported index operation"); + + std::list<tree_argument_list *> args = exp.arg_lists (); + if (args.size () != 1) + fail ("Bad number of arguments in tree_index_expression"); + + tree_argument_list *arg_list = args.front (); + if (! arg_list) + fail ("null argument list"); + + if (arg_list->size () != 1) + fail ("Bad number of arguments in arg_list"); + + tree_expression *tree_object = exp.expression (); + jit_value *object = visit (tree_object); + tree_expression *arg0 = arg_list->front (); + jit_value *index = visit (arg0); + + return std::make_pair (object, index); +} + +jit_value * +jit_convert::do_assign (tree_expression *exp, jit_value *rhs, bool artificial) +{ + if (! exp) + fail ("NULL lhs in assign"); + + if (isa<tree_identifier> (exp)) + return do_assign (exp->name (), rhs, exp->print_result (), artificial); + else if (tree_index_expression *idx + = dynamic_cast<tree_index_expression *> (exp)) + { + std::pair<jit_value *, jit_value *> res = resolve (*idx); + jit_value *object = res.first; + jit_value *index = res.second; + jit_call *new_object = create<jit_call> (&jit_typeinfo::paren_subsasgn, + object, index, rhs); + block->append (new_object); + do_assign (idx->expression (), new_object, true); + create_check (new_object); + + // FIXME: Will not work for values that must be release/grabed + return rhs; + } + else + fail ("Unsupported assignment"); +} + jit_value * jit_convert::do_assign (const std::string& lhs, jit_value *rhs, - bool print) + bool print, bool artificial) { jit_variable *var = get_variable (lhs); - block->append (create<jit_assign> (var, rhs)); + jit_assign *assign = block->append (create<jit_assign> (var, rhs)); + + if (artificial) + assign->mark_artificial (); if (print) { @@ -2776,6 +2908,9 @@ { assign.stash_llvm (assign.src ()->to_llvm ()); + if (assign.artificial ()) + return; + jit_value *new_value = assign.src (); if (isa<jit_assign_base> (new_value)) {