# HG changeset patch # User Max Brister # Date 1339709886 18000 # Node ID 4c9fd3e314366db7838e8524dea6506bac28b565 # Parent 7ab3ac5c676c094e00c8b649dfef962b65d30829 Start of jit support for double matricies diff --git a/liboctave/Array.h b/liboctave/Array.h --- a/liboctave/Array.h +++ b/liboctave/Array.h @@ -164,6 +164,14 @@ return &nr; } +protected: + + // For jit support + Array (T *sdata, octave_idx_type slen, octave_idx_type *adims, void *arep) + : dimensions (adims), + rep (reinterpret_cast::ArrayRep *> (arep)), + slice_data (sdata), slice_len (slen) {} + public: // Empty ctor (0x0). @@ -693,6 +701,16 @@ // supposedly equal dimensions (e.g. structs in the interpreter). bool optimize_dimensions (const dim_vector& dv); + // WARNING: Only call these functions from jit + + int *jit_ref_count (void) { return rep->count.get (); } + + T *jit_slice_data (void) const { return slice_data; } + + octave_idx_type *jit_dimensions (void) const { return dimensions.to_jit (); } + + void *jit_array_rep (void) const { return rep; } + private: void resize2 (octave_idx_type nr, octave_idx_type nc, const T& rfv); diff --git a/liboctave/MArray.h b/liboctave/MArray.h --- a/liboctave/MArray.h +++ b/liboctave/MArray.h @@ -39,6 +39,12 @@ class MArray : public Array { +protected: + + // For jit support + MArray (T *sdata, octave_idx_type slen, octave_idx_type *adims, void *arep) + : Array (sdata, slen, adims, arep) { } + public: MArray (void) : Array () {} diff --git a/liboctave/dNDArray.h b/liboctave/dNDArray.h --- a/liboctave/dNDArray.h +++ b/liboctave/dNDArray.h @@ -64,6 +64,10 @@ NDArray (const charNDArray&); + // For jit support only + NDArray (double *sdata, octave_idx_type slen, octave_idx_type *adims, void *arep) + : MArray (sdata, slen, adims, arep) { } + NDArray& operator = (const NDArray& a) { MArray::operator = (a); diff --git a/liboctave/dim-vector.h b/liboctave/dim-vector.h --- a/liboctave/dim-vector.h +++ b/liboctave/dim-vector.h @@ -212,6 +212,12 @@ void chop_all_singletons (void); + // WARNING: Only call by jit + octave_idx_type *to_jit (void) const + { + return rep; + } + private: static octave_idx_type *nil_rep (void) @@ -220,9 +226,6 @@ return zv.rep; } - explicit dim_vector (octave_idx_type *r) - : rep (r) { } - public: static octave_idx_type dim_max (void); @@ -233,6 +236,10 @@ dim_vector (const dim_vector& dv) : rep (dv.rep) { OCTREFCOUNT_ATOMIC_INCREMENT (&(count())); } + // FIXME: Should be private, but required by array constructor for jit + explicit dim_vector (octave_idx_type *r) + : rep (r) { } + static dim_vector alloc (int n) { return dim_vector (newrep (n < 2 ? 2 : n)); diff --git a/liboctave/oct-refcount.h b/liboctave/oct-refcount.h --- a/liboctave/oct-refcount.h +++ b/liboctave/oct-refcount.h @@ -82,6 +82,11 @@ return static_cast (count); } + count_type *get (void) + { + return &count; + } + private: count_type count; }; diff --git a/src/pt-jit.cc b/src/pt-jit.cc --- a/src/pt-jit.cc +++ b/src/pt-jit.cc @@ -23,6 +23,8 @@ #define __STDC_LIMIT_MACROS #define __STDC_CONSTANT_MACROS +#define OCTAVE_JIT_DEBUG + #ifdef HAVE_CONFIG_H #include #endif @@ -147,6 +149,12 @@ obv->release (); } +extern "C" void +octave_jit_delete_matrix (jit_matrix *m) +{ + NDArray array (*m); +} + extern "C" octave_base_value * octave_jit_grab_any (octave_base_value *obv) { @@ -154,6 +162,25 @@ return obv; } +extern "C" octave_base_value * +octave_jit_cast_any_matrix (jit_matrix *jmatrix) +{ + ++(*jmatrix->ref_count); + NDArray matrix = *jmatrix; + octave_value ret (matrix); + + octave_base_value *rep = ret.internal_rep (); + rep->grab (); + return rep; +} + +extern "C" void +octave_jit_cast_matrix_any (jit_matrix *ret, octave_base_value *obv) +{ + NDArray m = obv->array_value (); + *ret = m; +} + extern "C" double octave_jit_cast_scalar_any (octave_base_value *obv) { @@ -190,6 +217,40 @@ return obv; } +extern "C" void +octave_jit_ginvalid_index (void) +{ + try + { + gripe_invalid_index (); + } + catch (const octave_execution_exception&) + { + gripe_library_execution_error (); + } +} + +extern "C" void +octave_jit_gindex_range (int nd, int dim, octave_idx_type iext, + octave_idx_type ext) +{ + std::cout << "gindex_range\n"; + try + { + gripe_index_out_of_range (nd, dim, iext, ext); + } + catch (const octave_execution_exception&) + { + gripe_library_execution_error (); + } +} + +extern "C" void +octave_jit_print_matrix (jit_matrix *m) +{ + std::cout << *m << std::endl; +} + // -------------------- jit_range -------------------- std::ostream& operator<< (std::ostream& os, const jit_range& rng) @@ -198,6 +259,16 @@ << ", " << rng.nelem << "]"; } +// -------------------- jit_matrix -------------------- + +std::ostream& +operator<< (std::ostream& os, const jit_matrix& mat) +{ + return os << "Matrix[" << mat.ref_count << ", " << mat.slice_data << ", " + << mat.slice_len << ", " << mat.dimensions << ", " + << mat.array_rep << "]"; +} + // -------------------- jit_type -------------------- llvm::Type * jit_type::to_llvm_arg (void) const @@ -291,34 +362,36 @@ : module (m), engine (e), next_id (0) { // FIXME: We should be registering types like in octave_value_typeinfo - ov_t = llvm::StructType::create (context, "octave_base_value"); - ov_t = ov_t->getPointerTo (); - - llvm::Type *dbl = llvm::Type::getDoubleTy (context); + llvm::Type *any_t = llvm::StructType::create (context, "octave_base_value"); + any_t = any_t->getPointerTo (); + + llvm::Type *scalar_t = llvm::Type::getDoubleTy (context); llvm::Type *bool_t = llvm::Type::getInt1Ty (context); llvm::Type *string_t = llvm::Type::getInt8Ty (context); string_t = string_t->getPointerTo (); - llvm::Type *index_t = 0; - switch (sizeof(octave_idx_type)) - { - case 4: - index_t = llvm::Type::getInt32Ty (context); - break; - case 8: - index_t = llvm::Type::getInt64Ty (context); - break; - default: - assert (false && "Unrecognized index type size"); - } + llvm::Type *index_t = llvm::Type::getIntNTy (context, sizeof(octave_idx_type) * 8); llvm::StructType *range_t = llvm::StructType::create (context, "range"); - std::vector range_contents (4, dbl); + std::vector range_contents (4, scalar_t); range_contents[3] = index_t; range_t->setBody (range_contents); + llvm::Type *refcount_t = llvm::Type::getIntNTy (context, sizeof(int) * 8); + llvm::Type *int_t = refcount_t; + + llvm::StructType *matrix_t = llvm::StructType::create (context, "matrix"); + llvm::Type *matrix_contents[5]; + matrix_contents[0] = refcount_t->getPointerTo (); + matrix_contents[1] = scalar_t->getPointerTo (); + matrix_contents[2] = index_t; + matrix_contents[3] = index_t->getPointerTo (); + matrix_contents[4] = string_t; + matrix_t->setBody (llvm::makeArrayRef (matrix_contents, 5)); + // create types - any = new_type ("any", 0, ov_t); - scalar = new_type ("scalar", any, dbl); + any = new_type ("any", 0, any_t); + matrix = new_type ("matrix", any, matrix_t); + scalar = new_type ("scalar", any, scalar_t); range = new_type ("range", any, range_t); string = new_type ("string", any, string_t); boolean = new_type ("bool", any, bool_t); @@ -378,6 +451,27 @@ grab_fn.add_overload (fn, false, any, any); grab_fn.stash_name ("grab"); + // grab matrix + llvm::Function *print_matrix = create_function ("octave_jit_print_matrix", + void_t, + matrix_t->getPointerTo ()); + engine->addGlobalMapping (print_matrix, reinterpret_cast(&octave_jit_print_matrix)); + + fn = create_function ("octave_jit_grab_matrix", matrix, matrix); + llvm::BasicBlock *body = llvm::BasicBlock::Create (context, "body", fn); + builder.SetInsertPoint (body); + { + llvm::Value *one = llvm::ConstantInt::get (refcount_t, 1); + + llvm::Value *mat = fn->arg_begin (); + llvm::Value *rcount= builder.CreateExtractValue (mat, 0); + llvm::Value *count = builder.CreateLoad (rcount); + count = builder.CreateAdd (count, one); + builder.CreateStore (count, rcount); + builder.CreateRet (mat); + } + grab_fn.add_overload (fn, false, matrix, matrix); + // grab scalar fn = create_identity (scalar); grab_fn.add_overload (fn, false, scalar, scalar); @@ -387,11 +481,45 @@ grab_fn.add_overload (fn, false, index, index); // release any - fn = create_function ("octave_jit_release_any", void_t, any->to_llvm ()); + fn = create_function ("octave_jit_release_any", void_t, any_t); engine->addGlobalMapping (fn, reinterpret_cast(&octave_jit_release_any)); release_fn.add_overload (fn, false, 0, any); release_fn.stash_name ("release"); + // release matrix + llvm::Function *delete_mat = create_function ("octave_jit_delete_matrix", void_t, + matrix_t); + engine->addGlobalMapping (delete_mat, + reinterpret_cast (&octave_jit_delete_matrix)); + + fn = create_function ("octave_jit_release_matrix", void_t, matrix_t); + llvm::Function *release_mat = fn; + body = llvm::BasicBlock::Create (context, "body", fn); + builder.SetInsertPoint (body); + { + llvm::Value *one = llvm::ConstantInt::get (refcount_t, 1); + llvm::Value *zero = llvm::ConstantInt::get (refcount_t, 0); + + llvm::Value *mat = fn->arg_begin (); + llvm::Value *rcount= builder.CreateExtractValue (mat, 0); + llvm::Value *count = builder.CreateLoad (rcount); + count = builder.CreateSub (count, one); + + llvm::BasicBlock *dead = llvm::BasicBlock::Create (context, "dead", fn); + llvm::BasicBlock *live = llvm::BasicBlock::Create (context, "live", fn); + llvm::Value *isdead = builder.CreateICmpEQ (count, zero); + builder.CreateCondBr (isdead, dead, live); + + builder.SetInsertPoint (dead); + builder.CreateCall (delete_mat, mat); + builder.CreateRetVoid (); + + builder.SetInsertPoint (live); + builder.CreateStore (count, rcount); + builder.CreateRetVoid (); + } + release_fn.add_overload (fn, false, 0, matrix); + // release scalar fn = create_identity (scalar); release_fn.add_overload (fn, false, 0, scalar); @@ -429,13 +557,13 @@ // divide is annoying because it might error fn = create_function ("octave_jit_div_scalar_scalar", scalar, scalar, scalar); - llvm::BasicBlock *body = llvm::BasicBlock::Create (context, "body", fn); + body = llvm::BasicBlock::Create (context, "body", fn); builder.SetInsertPoint (body); { llvm::BasicBlock *warn_block = llvm::BasicBlock::Create (context, "warn", fn); llvm::BasicBlock *normal_block = llvm::BasicBlock::Create (context, "normal", fn); - llvm::Value *zero = llvm::ConstantFP::get (dbl, 0); + llvm::Value *zero = llvm::ConstantFP::get (scalar_t, 0); llvm::Value *check = builder.CreateFCmpUEQ (zero, ++fn->arg_begin ()); builder.CreateCondBr (check, warn_block, normal_block); @@ -514,7 +642,7 @@ builder.SetInsertPoint (body); { llvm::Value *idx = ++fn->arg_begin (); - llvm::Value *didx = builder.CreateUIToFP (idx, dbl); + llvm::Value *didx = builder.CreateSIToFP (idx, scalar_t); llvm::Value *rng = fn->arg_begin (); llvm::Value *base = builder.CreateExtractValue (rng, 0); llvm::Value *inc = builder.CreateExtractValue (rng, 2); @@ -548,7 +676,7 @@ builder.CreateBr (normal_block); builder.SetInsertPoint (normal_block); - llvm::Value *zero = llvm::ConstantFP::get (dbl, 0); + llvm::Value *zero = llvm::ConstantFP::get (scalar_t, 0); llvm::Value *ret = builder.CreateFCmpONE (fn->arg_begin (), zero); builder.CreateRet (ret); } @@ -580,7 +708,7 @@ llvm::Value *inc = ++args; llvm::Value *nelem = builder.CreateCall3 (compute_nelem, base, limit, inc); - llvm::Value *dzero = llvm::ConstantFP::get (dbl, 0); + llvm::Value *dzero = llvm::ConstantFP::get (scalar_t, 0); llvm::Value *izero = llvm::ConstantInt::get (index_t, 0); llvm::Value *rng = llvm::ConstantStruct::get (range_t, dzero, dzero, dzero, izero, NULL); @@ -593,9 +721,110 @@ llvm::verifyFunction (*fn); make_range_fn.add_overload (fn, false, range, scalar, scalar, scalar); + // paren_subsref + llvm::Function *ginvalid_index = create_function ("gipe_invalid_index", void_t); + engine->addGlobalMapping (ginvalid_index, + reinterpret_cast (&octave_jit_ginvalid_index)); + + llvm::Function *gindex_range = create_function ("gripe_index_out_of_range", + void_t, int_t, int_t, index_t, + index_t); + engine->addGlobalMapping (gindex_range, + reinterpret_cast (&octave_jit_gindex_range)); + + fn = create_function ("()subsref", scalar, matrix, scalar); + body = llvm::BasicBlock::Create (context, "body", fn); + builder.SetInsertPoint (body); + { + llvm::Value *one = llvm::ConstantInt::get (index_t, 1); + llvm::Value *ione; + if (index_t == int_t) + ione = one; + else + ione = llvm::ConstantInt::get (int_t, 1); + + + llvm::Value *szero = llvm::ConstantFP::get (scalar_t, 0); + + llvm::Function::arg_iterator args = fn->arg_begin (); + llvm::Value *mat = args++; + llvm::Value *idx = args; + + // convert index to scalar to integer, and check index >= 1 + llvm::Value *int_idx = builder.CreateFPToSI (idx, index_t); + llvm::Value *check_idx = builder.CreateSIToFP (int_idx, scalar_t); + llvm::Value *cond0 = builder.CreateFCmpUNE (idx, check_idx); + llvm::Value *cond1 = builder.CreateICmpSLT (int_idx, one); + llvm::Value *cond = builder.CreateOr (cond0, cond1); + + llvm::BasicBlock *done = llvm::BasicBlock::Create (context, "done", fn); + + llvm::BasicBlock *conv_error = llvm::BasicBlock::Create (context, + "conv_error", fn, + done); + llvm::BasicBlock *normal = llvm::BasicBlock::Create (context, "normal", fn, + done); + builder.CreateCondBr (cond, conv_error, normal); + + builder.SetInsertPoint (conv_error); + builder.CreateCall (ginvalid_index); + builder.CreateBr (done); + + builder.SetInsertPoint (normal); + llvm::Value *len = builder.CreateExtractValue (mat, + llvm::ArrayRef (2)); + cond = builder.CreateICmpSGT (int_idx, len); + + + llvm::BasicBlock *bounds_error = llvm::BasicBlock::Create (context, + "bounds_error", + fn, done); + + llvm::BasicBlock *success = llvm::BasicBlock::Create (context, "success", + fn, done); + builder.CreateCondBr (cond, bounds_error, success); + + builder.SetInsertPoint (bounds_error); + builder.CreateCall4 (gindex_range, ione, ione, int_idx, len); + builder.CreateBr (done); + + builder.SetInsertPoint (success); + llvm::Value *data = builder.CreateExtractValue (mat, + llvm::ArrayRef (1)); + llvm::Value *gep = builder.CreateInBoundsGEP (data, int_idx); + llvm::Value *ret = builder.CreateLoad (gep); + builder.CreateBr (done); + + builder.SetInsertPoint (done); + + llvm::PHINode *merge = llvm::PHINode::Create (scalar_t, 3); + builder.Insert (merge); + merge->addIncoming (szero, conv_error); + merge->addIncoming (szero, bounds_error); + merge->addIncoming (ret, success); + builder.CreateCall (release_mat, mat); + builder.CreateRet (merge); + } + llvm::verifyFunction (*fn); + paren_subsref_fn.add_overload (fn, true, scalar, matrix, scalar); + casts[any->type_id ()].stash_name ("(any)"); casts[scalar->type_id ()].stash_name ("(scalar)"); + // cast any <- matrix + fn = create_function ("octave_jit_cast_any_matrix", any_t, + matrix_t->getPointerTo ()); + engine->addGlobalMapping (fn, + reinterpret_cast (&octave_jit_cast_any_matrix)); + casts[any->type_id ()].add_overload (fn, false, any, matrix); + + // cast matrix <- any + fn = create_function ("octave_jit_cast_matrix_any", void_t, + matrix_t->getPointerTo (), any_t); + engine->addGlobalMapping (fn, + reinterpret_cast (&octave_jit_cast_matrix_any)); + casts[matrix->type_id ()].add_overload (fn, false, matrix, any); + // cast any <- scalar fn = create_function ("octave_jit_cast_any_scalar", any, scalar); engine->addGlobalMapping (fn, reinterpret_cast (&octave_jit_cast_any_scalar)); @@ -740,14 +969,20 @@ jit_typeinfo::do_type_of (const octave_value &ov) const { if (ov.is_function ()) - return 0; - - if (ov.is_double_type () && ov.is_real_scalar ()) - return get_scalar (); + return 0; // functions are not supported if (ov.is_range ()) return get_range (); + if (ov.is_double_type ()) + { + if (ov.is_real_scalar ()) + return get_scalar (); + + if (ov.is_matrix_type ()) + return get_matrix (); + } + return get_any (); } @@ -1345,7 +1580,7 @@ if (jit_extract_argument *extract = dynamic_cast (*iter)) arguments.push_back (std::make_pair (extract->name (), true)); - convert_llvm to_llvm; + convert_llvm to_llvm (*this); function = to_llvm.convert (module, arguments, blocks, constants); #ifdef OCTAVE_JIT_DEBUG @@ -1686,9 +1921,34 @@ } void -jit_convert::visit_index_expression (tree_index_expression&) +jit_convert::visit_index_expression (tree_index_expression& exp) { - fail (); + std::string type = exp.type_tags (); + if (! (type.size () == 1 && type[0] == '(')) + fail ("Unsupported index operation"); + + std::list args = exp.arg_lists (); + if (args.size () != 1) + fail ("Bad number of arguments in tree_index_expression"); + + tree_argument_list *arg_list = args.front (); + if (arg_list->size () != 1) + fail ("Bad number of arguments in arg_list"); + + tree_expression *tree_object = exp.expression (); + jit_value *object = visit (tree_object); + + tree_expression *arg0 = arg_list->front (); + jit_value *index = visit (arg0); + + jit_call *call = create (jit_typeinfo::paren_subsref, object, index); + block->append (call); + + jit_block *normal = create (block->name ()); + block->append (create (call, normal, final_block)); + add_block (normal); + block = normal; + result = call; } void @@ -2286,7 +2546,7 @@ fail (ss.str ()); } - builder.CreateCall (ol.function, phi->argument_llvm (i)); + create_call (ol, phi->argument (i)); } } } @@ -2305,17 +2565,8 @@ const jit_function::overload& ol = jit_typeinfo::cast (phi->type (), phi->argument_type (i)); - if (! ol.function) - { - std::stringstream ss; - ss << "No cast for phi(" << i << "): "; - phi->print (ss); - fail (ss.str ()); - } - - llvm::Value *casted; - casted = builder.CreateCall (ol.function, - phi->argument_llvm (i)); + + llvm::Value *casted = create_call (ol, phi->argument (i)); llvm_phi->addIncoming (casted, pred); } } @@ -2343,14 +2594,14 @@ jit_convert::convert_llvm::visit (jit_const_range& cr) { llvm::StructType *stype = llvm::cast(cr.type_llvm ()); - llvm::Type *dbl = jit_typeinfo::get_scalar_llvm (); + llvm::Type *scalar_t = jit_typeinfo::get_scalar_llvm (); llvm::Type *idx = jit_typeinfo::get_index_llvm (); const jit_range& rng = cr.value (); llvm::Constant *constants[4]; - constants[0] = llvm::ConstantFP::get (dbl, rng.base); - constants[1] = llvm::ConstantFP::get (dbl, rng.limit); - constants[2] = llvm::ConstantFP::get (dbl, rng.inc); + constants[0] = llvm::ConstantFP::get (scalar_t, rng.base); + constants[1] = llvm::ConstantFP::get (scalar_t, rng.limit); + constants[2] = llvm::ConstantFP::get (scalar_t, rng.inc); constants[3] = llvm::ConstantInt::get (idx, rng.nelem); llvm::Value *as_llvm; @@ -2386,39 +2637,25 @@ void jit_convert::convert_llvm::visit (jit_call& call) { - const jit_function::overload& ol = call.overload (); - if (! ol.function) - fail ("No overload for: " + call.print_string ()); - - std::vector args (call.argument_count ()); - for (size_t i = 0; i < call.argument_count (); ++i) - args[i] = call.argument_llvm (i); - - call.stash_llvm (builder.CreateCall (ol.function, args)); + llvm::Value *ret = create_call (call.overload (), call.arguments ()); + call.stash_llvm (ret); } void jit_convert::convert_llvm::visit (jit_extract_argument& extract) { - const jit_function::overload& ol = extract.overload (); - if (! ol.function) - fail (); - llvm::Value *arg = arguments[extract.name ()]; assert (arg); arg = builder.CreateLoad (arg); - extract.stash_llvm (builder.CreateCall (ol.function, arg, extract.name ())); + + jit_value *jarg = jthis.create (jit_typeinfo::get_any (), arg); + extract.stash_llvm (create_call (extract.overload (), jarg)); } void jit_convert::convert_llvm::visit (jit_store_argument& store) { - llvm::Value *arg_value = store.result_llvm (); - const jit_function::overload& ol = store.overload (); - if (! ol.function) - fail (); - - arg_value = builder.CreateCall (ol.function, arg_value); + llvm::Value *arg_value = create_call (store.overload (), store.result ()); llvm::Value *arg = arguments[store.name ()]; store.stash_llvm (builder.CreateStore (arg_value, arg)); @@ -2463,6 +2700,69 @@ jit_convert::convert_llvm::visit (jit_assign&) {} +void +jit_convert::convert_llvm::visit (jit_argument&) +{} + +llvm::Value * +jit_convert::convert_llvm::create_call (const jit_function::overload& ol, + const std::vector& jargs) +{ + llvm::Function *fun = ol.function; + if (! fun) + fail ("Missing overload"); + + const llvm::Function::ArgumentListType& alist = fun->getArgumentList (); + size_t nargs = alist.size (); + bool sret = false; + if (nargs != jargs.size ()) + { + // first argument is the structure return value + assert (nargs == jargs.size () + 1); + sret = true; + } + + std::vector args (nargs); + llvm::Function::arg_iterator llvm_arg = fun->arg_begin (); + if (sret) + { + args[0] = builder.CreateAlloca (ol.result->to_llvm ()); + ++llvm_arg; + } + + for (size_t i = 0; i < jargs.size (); ++i, ++llvm_arg) + { + llvm::Value *arg = jargs[i]->to_llvm (); + llvm::Type *arg_type = arg->getType (); + llvm::Type *llvm_arg_type = llvm_arg->getType (); + + if (arg_type == llvm_arg_type) + args[i + sret] = arg; + else + { + // pass structure by pointer + assert (arg_type->getPointerTo () == llvm_arg_type); + llvm::Value *new_arg = builder.CreateAlloca (arg_type); + builder.CreateStore (arg, new_arg); + args[i + sret] = new_arg; + } + } + + llvm::Value *llvm_call = builder.CreateCall (fun, args); + return sret ? builder.CreateLoad (args[0]) : llvm_call; +} + +llvm::Value * +jit_convert::convert_llvm::create_call (const jit_function::overload& ol, + const std::vector& uses) +{ + std::vector values (uses.size ()); + for (size_t i = 0; i < uses.size (); ++i) + values[i] = uses[i].value (); + + return create_call (ol, values); +} + // -------------------- tree_jit -------------------- tree_jit::tree_jit (void) : module (0), engine (0) diff --git a/src/pt-jit.h b/src/pt-jit.h --- a/src/pt-jit.h +++ b/src/pt-jit.h @@ -219,6 +219,43 @@ std::ostream& operator<< (std::ostream& os, const jit_range& rng); +// jit_array is compatable with the llvm array/matrix structures +template +struct +jit_array +{ + jit_array (T& from) : ref_count (from.jit_ref_count ()), + slice_data (from.jit_slice_data () - 1), + slice_len (from.capacity ()), + dimensions (from.jit_dimensions ()), + array_rep (from.jit_array_rep ()) + { + grab_dimensions (); + } + + void grab_dimensions (void) + { + ++(dimensions[-2]); + } + + operator T () const + { + return T (slice_data + 1, slice_len, dimensions, array_rep); + } + + int *ref_count; + + U *slice_data; + octave_idx_type slice_len; + octave_idx_type *dimensions; + + void *array_rep; +}; + +typedef jit_array jit_matrix; + +std::ostream& operator<< (std::ostream& os, const jit_matrix& mat); + // Used to keep track of estimated (infered) types during JIT. This is a // hierarchical type system which includes both concrete and abstract types. // @@ -384,6 +421,8 @@ static jit_type *get_any (void) { return instance->any; } + static jit_type *get_matrix (void) { return instance->matrix; } + static jit_type *get_scalar (void) { return instance->scalar; } static llvm::Type *get_scalar_llvm (void) { return instance->scalar->to_llvm (); } @@ -445,6 +484,11 @@ return instance->make_range_fn; } + static const jit_function& paren_subsref (void) + { + return instance->paren_subsref_fn; + } + static const jit_function& logically_true (void) { return instance->logically_true_fn; @@ -597,6 +641,18 @@ } llvm::Function *create_function (const llvm::Twine& name, llvm::Type *ret, + llvm::Type *arg0, llvm::Type *arg1, + llvm::Type *arg2, llvm::Type *arg3) + { + std::vector args (4); + args[0] = arg0; + args[1] = arg1; + args[2] = arg2; + args[3] = arg3; + return create_function (name, ret, args); + } + + llvm::Function *create_function (const llvm::Twine& name, llvm::Type *ret, const std::vector& args); llvm::Function *create_identity (jit_type *type); @@ -609,11 +665,11 @@ llvm::ExecutionEngine *engine; int next_id; - llvm::Type *ov_t; llvm::GlobalVariable *lerror_state; std::vector id_to_type; jit_type *any; + jit_type *matrix; jit_type *scalar; jit_type *range; jit_type *string; @@ -629,6 +685,7 @@ jit_function for_index_fn; jit_function logically_true_fn; jit_function make_range_fn; + jit_function paren_subsref_fn; // type id -> cast function TO that type std::vector casts; @@ -651,7 +708,8 @@ JIT_METH(phi); \ JIT_METH(variable); \ JIT_METH(check_error); \ - JIT_METH(assign) + JIT_METH(assign) \ + JIT_METH(argument) #define JIT_VISIT_IR_CONST \ JIT_METH(const_scalar); \ @@ -830,18 +888,18 @@ : already_infered (nargs, reinterpret_cast(0)), mid (next_id ()), mparent (0) { - arguments.reserve (nargs); + marguments.reserve (nargs); } jit_instruction (jit_value *arg0) - : already_infered (1, reinterpret_cast(0)), arguments (1), + : already_infered (1, reinterpret_cast(0)), marguments (1), mid (next_id ()), mparent (0) { stash_argument (0, arg0); } jit_instruction (jit_value *arg0, jit_value *arg1) - : already_infered (2, reinterpret_cast(0)), arguments (2), + : already_infered (2, reinterpret_cast(0)), marguments (2), mid (next_id ()), mparent (0) { stash_argument (0, arg0); @@ -849,7 +907,7 @@ } jit_instruction (jit_value *arg0, jit_value *arg1, jit_value *arg2) - : already_infered (3, reinterpret_cast(0)), arguments (3), + : already_infered (3, reinterpret_cast(0)), marguments (3), mid (next_id ()), mparent (0) { stash_argument (0, arg0); @@ -859,7 +917,7 @@ jit_instruction (jit_value *arg0, jit_value *arg1, jit_value *arg2, jit_value *arg3) - : already_infered (3, reinterpret_cast(0)), arguments (4), + : already_infered (3, reinterpret_cast(0)), marguments (4), mid (next_id ()), mparent (0) { stash_argument (0, arg0); @@ -875,7 +933,7 @@ jit_value *argument (size_t i) const { - return arguments[i].value (); + return marguments[i].value (); } llvm::Value *argument_llvm (size_t i) const @@ -905,25 +963,25 @@ void stash_argument (size_t i, jit_value *arg) { - arguments[i].stash_value (arg, this, i); + marguments[i].stash_value (arg, this, i); } void push_argument (jit_value *arg) { - arguments.push_back (jit_use ()); - stash_argument (arguments.size () - 1, arg); + marguments.push_back (jit_use ()); + stash_argument (marguments.size () - 1, arg); already_infered.push_back (0); } size_t argument_count (void) const { - return arguments.size (); + return marguments.size (); } void resize_arguments (size_t acount, jit_value *adefault = 0) { - size_t old = arguments.size (); - arguments.resize (acount); + size_t old = marguments.size (); + marguments.resize (acount); already_infered.resize (acount); if (adefault) @@ -931,6 +989,8 @@ stash_argument (i, adefault); } + const std::vector& arguments (void) const { return marguments; } + // argument types which have been infered already const std::vector& argument_types (void) const { return already_infered; } @@ -974,7 +1034,7 @@ return ret++; } - std::vector arguments; + std::vector marguments; size_t mid; jit_block *mparent; @@ -982,9 +1042,29 @@ }; // defnie accept methods for subclasses -#define JIT_VALUE_ACCEPT(clname) \ +#define JIT_VALUE_ACCEPT \ virtual void accept (jit_ir_walker& walker); +// for use as a dummy argument during conversion to LLVM +class +jit_argument : public jit_value +{ +public: + jit_argument (jit_type *atype, llvm::Value *avalue) + { + stash_type (atype); + stash_llvm (avalue); + } + + virtual std::ostream& print (std::ostream& os, size_t indent = 0) const + { + print_indent (os, indent); + return jit_print (os, type ()) << ": DUMMY"; + } + + JIT_VALUE_ACCEPT; +}; + template class @@ -1012,7 +1092,7 @@ return os; } - JIT_VALUE_ACCEPT (jit_const); + JIT_VALUE_ACCEPT; private: T mvalue; }; @@ -1212,7 +1292,7 @@ void stash_location (std::list::iterator alocation) { mlocation = alocation; } - JIT_VALUE_ACCEPT (block); + JIT_VALUE_ACCEPT; private: void internal_append (jit_instruction *instr); @@ -1370,7 +1450,7 @@ return print_indent (os, indent) << mname; } - JIT_VALUE_ACCEPT (variable) + JIT_VALUE_ACCEPT; private: std::string mname; std::stack value_stack; @@ -1426,7 +1506,7 @@ return print_indent (os, indent) << *dest () << " = " << *src (); } - JIT_VALUE_ACCEPT (assign); + JIT_VALUE_ACCEPT; private: jit_variable *mdest; }; @@ -1498,7 +1578,7 @@ return os << "#" << id (); } - JIT_VALUE_ACCEPT (phi); + JIT_VALUE_ACCEPT; private: std::vector mincomming; }; @@ -1597,7 +1677,7 @@ return print_successor (os); } - JIT_VALUE_ACCEPT (break) + JIT_VALUE_ACCEPT; }; class @@ -1629,7 +1709,7 @@ return print_successor (os, 1); } - JIT_VALUE_ACCEPT (cond_break) + JIT_VALUE_ACCEPT; }; class @@ -1691,7 +1771,7 @@ virtual bool infer (void); - JIT_VALUE_ACCEPT (call) + JIT_VALUE_ACCEPT; private: const jit_function& mfunction; }; @@ -1718,7 +1798,7 @@ return print_successor (os, 0); } - JIT_VALUE_ACCEPT (jit_check_error) + JIT_VALUE_ACCEPT; protected: virtual bool check_alive (size_t idx) const { @@ -1753,7 +1833,7 @@ return short_print (os) << " = extract " << name (); } - JIT_VALUE_ACCEPT (extract_argument) + JIT_VALUE_ACCEPT; }; class @@ -1804,7 +1884,7 @@ return os; } - JIT_VALUE_ACCEPT (store_argument) + JIT_VALUE_ACCEPT; private: jit_variable *dest; }; @@ -2103,6 +2183,8 @@ convert_llvm : public jit_ir_walker { public: + convert_llvm (jit_convert& jc) : jthis (jc) {} + llvm::Function *convert (llvm::Module *module, const std::vector >& args, const std::list& blocks, @@ -2129,7 +2211,30 @@ { jvalue.accept (*this); } + + llvm::Value *create_call (const jit_function::overload& ol, jit_value *arg0) + { + std::vector args (1, arg0); + return create_call (ol, args); + } + + llvm::Value *create_call (const jit_function::overload& ol, jit_value *arg0, + jit_value *arg1) + { + std::vector args (2); + args[0] = arg0; + args[1] = arg1; + + return create_call (ol, args); + } + + llvm::Value *create_call (const jit_function::overload& ol, + const std::vector& jargs); + + llvm::Value *create_call (const jit_function::overload& ol, + const std::vector& uses); private: + jit_convert &jthis; llvm::Function *function; }; };