# HG changeset patch # User Max Brister # Date 1343858412 18000 # Node ID fe4752f772e29f95737a83a362d73a7efb9e2d3b # Parent f0b04a20d7cfa90b20cd616c979f9c0587c7739a Generate ND indexing functions on demand in JIT. * src/jit-typeinfo.cc (jit_operation::~jit_operation, jit_operation::do_generate, jit_operation::generate, jit_operation::signature_cmp::operator()): New function. (jit_operation::overload): Call do_generate when lookup fails. (jit_index_operation, jit_paren_subsref, jit_paren_subsasgn): New class. (jit_typeinfo::jit_typeinfo): Update to use jit_paren_subsref and jit_paren_subsasgn. (jit_typeinfo::gen_subsref, jit_typeinfo::gen_subsasgn): Removed functions. * src/jit-typeinfo.h (jit_operation::~jit_operation, jit_operation::generate, jit_operation::do_generate): New declaration. (jit_operation::add_overload, jit_operation::overload, jit_operation::result, jit_operation::to_idx): Use signature_vec typedef. (jit_operation::singature_cmp): New class. (jit_index_operation, jit_paren_subsref, jit_paren_subsasgn): New class. (jit_typeinfo::get_scalar_ptr): Nwe function. (jit_typeinfo::gen_subsref, jit_typeinfo::gen_subsasgn): Removed declaration. * src/pt-jit.cc: New test. diff --git a/src/jit-typeinfo.cc b/src/jit-typeinfo.cc --- a/src/jit-typeinfo.cc +++ b/src/jit-typeinfo.cc @@ -708,6 +708,16 @@ } // -------------------- jit_operation -------------------- +jit_operation::~jit_operation (void) +{ + for (generated_map::iterator iter = generated.begin (); + iter != generated.end (); ++iter) + { + delete iter->first; + delete iter->second; + } +} + void jit_operation::add_overload (const jit_function& func, const std::vector& args) @@ -742,23 +752,26 @@ const jit_function& jit_operation::overload (const std::vector& types) const { - // FIXME: We should search for the next best overload on failure static jit_function null_overload; - if (types.size () >= overloads.size ()) - return null_overload; - for (size_t i =0; i < types.size (); ++i) if (! types[i]) return null_overload; + if (types.size () >= overloads.size ()) + return do_generate (types); + const Array& over = overloads[types.size ()]; dim_vector dv (over.dims ()); Array idx = to_idx (types); for (octave_idx_type i = 0; i < dv.length (); ++i) if (idx(i) >= dv(i)) - return null_overload; + return do_generate (types); - return over(idx); + const jit_function& ret = over(idx); + if (! ret.valid ()) + return do_generate (types); + + return ret; } Array @@ -782,6 +795,175 @@ return idx; } +const jit_function& +jit_operation::do_generate (const signature_vec& types) const +{ + static jit_function null_overload; + generated_map::const_iterator find = generated.find (&types); + if (find != generated.end ()) + { + if (find->second) + return *find->second; + else + return null_overload; + } + + jit_function *ret = generate (types); + generated[new signature_vec (types)] = ret; + return ret ? *ret : null_overload; +} + +jit_function * +jit_operation::generate (const signature_vec& types) const +{ + return 0; +} + +bool +jit_operation::signature_cmp +::operator() (const signature_vec *lhs, const signature_vec *rhs) +{ + const signature_vec& l = *lhs; + const signature_vec& r = *rhs; + + if (l.size () < r.size ()) + return true; + else if (l.size () > r.size ()) + return false; + + for (size_t i = 0; i < l.size (); ++i) + { + if (l[i]->type_id () < r[i]->type_id ()) + return true; + else if (l[i]->type_id () > r[i]->type_id ()) + return false; + } + + return false; +} + +// -------------------- jit_index_operation -------------------- +jit_function * +jit_index_operation::generate (const signature_vec& types) const +{ + if (types.size () > 2 && types[0] == jit_typeinfo::get_matrix ()) + { + // indexing a matrix with scalars + jit_type *scalar = jit_typeinfo::get_scalar (); + for (size_t i = 1; i < types.size (); ++i) + if (types[i] != scalar) + return 0; + + return generate_matrix (types); + } + + return 0; +} + +llvm::Value * +jit_index_operation::create_arg_array (llvm::IRBuilderD& builder, + const jit_function &fn, size_t start_idx, + size_t end_idx) const +{ + size_t n = end_idx - start_idx; + llvm::Type *scalar_t = jit_typeinfo::get_scalar_llvm (); + llvm::ArrayType *array_t = llvm::ArrayType::get (scalar_t, n); + llvm::Value *array = llvm::UndefValue::get (array_t); + for (size_t i = start_idx; i < end_idx; ++i) + { + llvm::Value *idx = fn.argument (builder, i); + array = builder.CreateInsertValue (array, idx, i - start_idx); + } + + llvm::Value *array_mem = builder.CreateAlloca (array_t); + builder.CreateStore (array, array_mem); + return builder.CreateBitCast (array_mem, scalar_t->getPointerTo ()); +} + +// -------------------- jit_paren_subsref -------------------- +jit_function * +jit_paren_subsref::generate_matrix (const signature_vec& types) const +{ + std::stringstream ss; + ss << "jit_paren_subsref_matrix_scalar" << (types.size () - 1); + + jit_type *scalar = jit_typeinfo::get_scalar (); + jit_function *fn = new jit_function (module, jit_convention::internal, + ss.str (), scalar, types); + fn->mark_can_error (); + llvm::BasicBlock *body = fn->new_block (); + llvm::IRBuilder<> builder (body); + + llvm::Value *array = create_arg_array (builder, *fn, 1, types.size ()); + jit_type *index = jit_typeinfo::get_index (); + llvm::Value *nelem = llvm::ConstantInt::get (index->to_llvm (), + types.size () - 1); + llvm::Value *mat = fn->argument (builder, 0); + llvm::Value *ret = paren_scalar.call (builder, mat, array, nelem); + fn->do_return (builder, ret); + return fn; +} + +void +jit_paren_subsref::do_initialize (void) +{ + std::vector types (3); + types[0] = jit_typeinfo::get_matrix (); + types[1] = jit_typeinfo::get_scalar_ptr (); + types[2] = jit_typeinfo::get_index (); + + jit_type *scalar = jit_typeinfo::get_scalar (); + paren_scalar = jit_function (module, jit_convention::external, + "octave_jit_paren_scalar", scalar, types); + paren_scalar.add_mapping (engine, &octave_jit_paren_scalar); + paren_scalar.mark_can_error (); +} + +// -------------------- jit_paren_subsasgn -------------------- +jit_function * +jit_paren_subsasgn::generate_matrix (const signature_vec& types) const +{ + std::stringstream ss; + ss << "jit_paren_subsasgn_matrix_scalar" << (types.size () - 2); + + jit_type *matrix = jit_typeinfo::get_matrix (); + jit_function *fn = new jit_function (module, jit_convention::internal, + ss.str (), matrix, types); + fn->mark_can_error (); + llvm::BasicBlock *body = fn->new_block (); + llvm::IRBuilder<> builder (body); + + llvm::Value *array = create_arg_array (builder, *fn, 1, types.size () - 1); + jit_type *index = jit_typeinfo::get_index (); + llvm::Value *nelem = llvm::ConstantInt::get (index->to_llvm (), + types.size () - 2); + + llvm::Value *mat = fn->argument (builder, 0); + llvm::Value *value = fn->argument (builder, types.size () - 1); + llvm::Value *ret = paren_scalar.call (builder, mat, array, nelem, value); + fn->do_return (builder, ret); + return fn; +} + +void +jit_paren_subsasgn::do_initialize (void) +{ + if (paren_scalar.valid ()) + return; + + jit_type *matrix = jit_typeinfo::get_matrix (); + std::vector types (4); + types[0] = matrix; + types[1] = jit_typeinfo::get_scalar_ptr (); + types[2] = jit_typeinfo::get_index (); + types[3] = jit_typeinfo::get_scalar (); + + paren_scalar = jit_function (module, jit_convention::external, + "octave_jit_paren_scalar", matrix, types); + paren_scalar.add_mapping (engine, &octave_jit_paren_scalar_subsasgn); + paren_scalar.mark_can_error (); +} + // -------------------- jit_typeinfo -------------------- void jit_typeinfo::initialize (llvm::Module *m, llvm::ExecutionEngine *e) @@ -835,14 +1017,12 @@ matrix = new_type ("matrix", any, matrix_t); complex = new_type ("complex", any, complex_t); scalar = new_type ("scalar", complex, scalar_t); + scalar_ptr = new_type ("scalar_ptr", 0, scalar_t->getPointerTo ()); range = new_type ("range", any, range_t); string = new_type ("string", any, string_t); boolean = new_type ("bool", any, bool_t); index = new_type ("index", any, index_t); - // a fake type for interfacing with C++ - jit_type *scalar_ptr = new_type ("scalar_ptr", 0, scalar_t->getPointerTo ()); - create_int (8); create_int (16); create_int (32); @@ -867,6 +1047,9 @@ if (sizeof (void *) == 4) complex->mark_sret (); + paren_subsref_fn.initialize (module, engine); + paren_subsasgn_fn.initialize (module, engine); + // bind global variables lerror_state = new llvm::GlobalVariable (*module, bool_t, false, llvm::GlobalValue::ExternalLinkage, @@ -1364,28 +1547,6 @@ } paren_subsref_fn.add_overload (fn); - // generate () subsref for ND indexing of matricies with scalars - jit_function paren_scalar = create_function (jit_convention::external, - "octave_jit_paren_scalar", - scalar, matrix, scalar_ptr, - index); - paren_scalar.add_mapping (engine, &octave_jit_paren_scalar); - paren_scalar.mark_can_error (); - - jit_function paren_scalar_subsasgn - = create_function (jit_convention::external, - "octave_jit_paren_scalar_subsasgn", matrix, matrix, - scalar_ptr, index, scalar); - paren_scalar_subsasgn.add_mapping (engine, &octave_jit_paren_scalar_subsasgn); - paren_scalar_subsasgn.mark_can_error (); - - // FIXME: Generate this on the fly - for (size_t i = 2; i < 10; ++i) - { - gen_subsref (paren_scalar, i); - gen_subsasgn (paren_scalar_subsasgn, i); - } - // paren subsasgn paren_subsasgn_fn.stash_name ("()subsasgn"); @@ -1907,71 +2068,4 @@ return ret; } -void -jit_typeinfo::gen_subsref (const jit_function& paren_scalar, size_t n) -{ - std::stringstream name; - name << "jit_paren_subsref_matrix_scalar" << n; - std::vector args (n + 1, scalar); - args[0] = matrix; - jit_function fn = create_function (jit_convention::internal, name.str (), - scalar, args); - fn.mark_can_error (); - llvm::BasicBlock *body = fn.new_block (); - builder.SetInsertPoint (body); - - llvm::Type *scalar_t = scalar->to_llvm (); - llvm::ArrayType *array_t = llvm::ArrayType::get (scalar_t, n); - llvm::Value *array = llvm::UndefValue::get (array_t); - for (size_t i = 0; i < n; ++i) - { - llvm::Value *idx = fn.argument (builder, i + 1); - array = builder.CreateInsertValue (array, idx, i); - } - - llvm::Value *array_mem = builder.CreateAlloca (array_t); - builder.CreateStore (array, array_mem); - array = builder.CreateBitCast (array_mem, scalar_t->getPointerTo ()); - - llvm::Value *nelem = llvm::ConstantInt::get (index->to_llvm (), n); - llvm::Value *mat = fn.argument (builder, 0); - llvm::Value *ret = paren_scalar.call (builder, mat, array, nelem); - fn.do_return (builder, ret); - paren_subsref_fn.add_overload (fn); -} - -void -jit_typeinfo::gen_subsasgn (const jit_function& paren_scalar, size_t n) -{ - std::stringstream name; - name << "jit_paren_subsasgn_matrix_scalar" << n; - std::vector args (n + 2, scalar); - args[0] = matrix; - jit_function fn = create_function (jit_convention::internal, name.str (), - matrix, args); - fn.mark_can_error (); - llvm::BasicBlock *body = fn.new_block (); - builder.SetInsertPoint (body); - - llvm::Type *scalar_t = scalar->to_llvm (); - llvm::ArrayType *array_t = llvm::ArrayType::get (scalar_t, n); - llvm::Value *array = llvm::UndefValue::get (array_t); - for (size_t i = 0; i < n; ++i) - { - llvm::Value *idx = fn.argument (builder, i + 1); - array = builder.CreateInsertValue (array, idx, i); - } - - llvm::Value *array_mem = builder.CreateAlloca (array_t); - builder.CreateStore (array, array_mem); - array = builder.CreateBitCast (array_mem, scalar_t->getPointerTo ()); - - llvm::Value *nelem = llvm::ConstantInt::get (index->to_llvm (), n); - llvm::Value *mat = fn.argument (builder, 0); - llvm::Value *value = fn.argument (builder, n + 1); - llvm::Value *ret = paren_scalar.call (builder, mat, array, nelem, value); - fn.do_return (builder, ret); - paren_subsasgn_fn.add_overload (fn); -} - #endif diff --git a/src/jit-typeinfo.h b/src/jit-typeinfo.h --- a/src/jit-typeinfo.h +++ b/src/jit-typeinfo.h @@ -314,17 +314,22 @@ jit_operation { public: + // type signature vector + typedef std::vector signature_vec; + + virtual ~jit_operation (void); + void add_overload (const jit_function& func) { add_overload (func, func.arguments ()); } void add_overload (const jit_function& func, - const std::vector& args); + const signature_vec& args); - const jit_function& overload (const std::vector& types) const; + const jit_function& overload (const signature_vec& types) const; - jit_type *result (const std::vector& types) const + jit_type *result (const signature_vec& types) const { const jit_function& temp = overload (types); return temp.result (); @@ -346,14 +351,79 @@ const std::string& name (void) const { return mname; } void stash_name (const std::string& aname) { mname = aname; } +protected: + virtual jit_function *generate (const signature_vec& types) const; private: - Array to_idx (const std::vector& types) const; + Array to_idx (const signature_vec& types) const; + + const jit_function& do_generate (const signature_vec& types) const; + + struct signature_cmp + { + bool operator() (const signature_vec *lhs, const signature_vec *rhs); + }; + + typedef std::map + generated_map; + + mutable generated_map generated; std::vector > overloads; std::string mname; }; +class +jit_index_operation : public jit_operation +{ +public: + jit_index_operation (void) : module (0), engine (0) {} + + void initialize (llvm::Module *amodule, llvm::ExecutionEngine *aengine) + { + module = amodule; + engine = aengine; + do_initialize (); + } +protected: + virtual jit_function *generate (const signature_vec& types) const; + + virtual jit_function *generate_matrix (const signature_vec& types) const = 0; + + virtual void do_initialize (void) = 0; + + // helper functions + // [start_idx, end_idx). + llvm::Value *create_arg_array (llvm::IRBuilderD& builder, + const jit_function &fn, size_t start_idx, + size_t end_idx) const; + + llvm::Module *module; + llvm::ExecutionEngine *engine; +}; + +class +jit_paren_subsref : public jit_index_operation +{ +protected: + virtual jit_function *generate_matrix (const signature_vec& types) const; + + virtual void do_initialize (void); +private: + jit_function paren_scalar; +}; + +class +jit_paren_subsasgn : public jit_index_operation +{ +protected: + jit_function *generate_matrix (const signature_vec& types) const; + + virtual void do_initialize (void); +private: + jit_function paren_scalar; +}; + // A singleton class which handles the construction of jit_types and // jit_operations. class @@ -376,6 +446,8 @@ static llvm::Type *get_scalar_llvm (void) { return instance->scalar->to_llvm (); } + static jit_type *get_scalar_ptr (void) { return instance->scalar_ptr; } + static jit_type *get_range (void) { return instance->range; } static jit_type *get_string (void) { return instance->string; } @@ -631,10 +703,6 @@ jit_type *intN (size_t nbits) const; - void gen_subsref (const jit_function& paren_scalar, size_t n); - - void gen_subsasgn (const jit_function& paren_scalar, size_t n); - static jit_typeinfo *instance; llvm::Module *module; @@ -647,6 +715,7 @@ jit_type *any; jit_type *matrix; jit_type *scalar; + jit_type *scalar_ptr; // a fake type for interfacing with C++ jit_type *range; jit_type *string; jit_type *boolean; @@ -667,8 +736,8 @@ jit_operation for_index_fn; jit_operation logically_true_fn; jit_operation make_range_fn; - jit_operation paren_subsref_fn; - jit_operation paren_subsasgn_fn; + jit_paren_subsref paren_subsref_fn; + jit_paren_subsasgn paren_subsasgn_fn; jit_operation end_fn; // type id -> cast function TO that type diff --git a/src/pt-jit.cc b/src/pt-jit.cc --- a/src/pt-jit.cc +++ b/src/pt-jit.cc @@ -1883,4 +1883,25 @@ %! m2(:) = 1:(ndim^2); %! assert (all (m == m2)); +%!test +%! ndim = 2; +%! m = zeros (ndim, ndim, ndim, ndim); +%! result = 0; +%! i0 = 1; +%! while (i0 <= ndim) +%! for i1 = 1:ndim +%! for i2 = 1:ndim +%! for i3 = 1:ndim +%! m(i0, i1, i2, i3) = 1; +%! m(i0, i1, i2, i3, 1, 1, 1, 1, 1, 1) = 1; +%! result = result + m(i0, i1, i2, i3); +%! endfor +%! endfor +%! endfor +%! i0 = i0 + 1; +%! endwhile +%! expected = ones (ndim, ndim, ndim, ndim); +%! assert (all (m == expected)); +%! assert (result == sum (expected (:))); + */