Mercurial > hg > octave-lyh
changeset 17524:534247e14b03
Merge the official development
author | LYH <lyh.kernel@gmail.com> |
---|---|
date | Fri, 27 Sep 2013 03:01:11 +0800 |
parents | 417fae0562da (diff) 55680de6a897 (current diff) |
children | 1d0fa3c34ad7 |
files | |
diffstat | 5 files changed, 906 insertions(+), 31 deletions(-) [+] |
line wrap: on
line diff
--- a/libinterp/corefcn/jit-typeinfo.cc +++ b/libinterp/corefcn/jit-typeinfo.cc @@ -1,3 +1,5 @@ +#pragma GCC diagnostic push +#pragma GCC diagnostic error "-Werror" /* Copyright (C) 2012 Max Brister @@ -66,8 +68,15 @@ #include "ov-builtin.h" #include "ov-complex.h" #include "ov-scalar.h" +#include "ov-float.h" +#include "ov-int8.h" +#include "ov-int64.h" +#include "ov-uint16.h" #include "pager.h" +typedef __int128_t int128_t; +typedef __uint128_t uint128_t; + static llvm::LLVMContext& context = llvm::getGlobalContext (); jit_typeinfo *jit_typeinfo::instance = 0; @@ -190,6 +199,62 @@ return new octave_scalar (value); } +extern "C" float +octave_jit_cast_single_any (octave_base_value *obv) +{ + float ret = obv->float_value (); + obv->release (); + return ret; +} + +extern "C" octave_base_value * +octave_jit_cast_any_single (float value) +{ + return new octave_float_scalar (value); +} + +extern "C" int8_t +octave_jit_cast_int8_any (octave_base_value *obv) +{ + int8_t ret = obv->int8_scalar_value (); + obv->release (); + return ret; +} + +extern "C" octave_base_value * +octave_jit_cast_any_int8 (int8_t value) +{ + return new octave_int8_scalar (value); +} + +extern "C" int64_t +octave_jit_cast_int64_any (octave_base_value *obv) +{ + int64_t ret = obv->int64_scalar_value (); + obv->release (); + return ret; +} + +extern "C" octave_base_value * +octave_jit_cast_any_int64 (int64_t value) +{ + return new octave_int64_scalar (value); +} + +extern "C" uint16_t +octave_jit_cast_uint16_any (octave_base_value *obv) +{ + uint16_t ret = obv->uint16_scalar_value (); + obv->release (); + return ret; +} + +extern "C" octave_base_value * +octave_jit_cast_any_uint16 (uint16_t value) +{ + return new octave_uint16_scalar (value); +} + extern "C" Complex octave_jit_cast_complex_any (octave_base_value *obv) { @@ -444,6 +509,25 @@ return std::pow (lhs, rhs); } +// FIXME: should handle FloatComplex +extern "C" Complex +octave_jit_pow_single_single (float lhs, float rhs) +{ + // FIXME: almost CP from libinterp/corefcn/xpow.cc + float retval; + + if (lhs < 0.0 && ! xisint (rhs)) + { + FloatComplex lhstmp (lhs); + + return std::pow (lhstmp, rhs); + } + else + retval = std::pow (lhs, rhs); + + return retval; +} + extern "C" Complex octave_jit_pow_complex_scalar (Complex lhs, double rhs) { @@ -460,6 +544,191 @@ return std::pow (lhs, rhs); } +/************************************************************ + * + * int8 related external helper function + * + ************************************************************/ + +extern "C" int8_t +octave_jit_add_int8_int8 (int8_t lhs, int8_t rhs) +{ + uint8_t ulhs = lhs; + uint8_t urhs = rhs; + uint8_t res = ulhs + urhs; + + /* Calculate overflowed result. (Don't change the sign bit of ux) */ + ulhs = (ulhs >> 7) + SCHAR_MAX; + + /* Force compiler to use cmovns instruction */ + if ((int8_t) ((ulhs ^ urhs) | ~(urhs ^ res)) >= 0) + { + res = ulhs; + } + + return res; +} + +extern "C" int8_t +octave_jit_sub_int8_int8 (int8_t lhs, int8_t rhs) +{ + uint8_t ulhs = lhs; + uint8_t urhs = rhs; + uint8_t res = ulhs - urhs; + + ulhs = (ulhs >> 7) + SCHAR_MAX; + + /* Force compiler to use cmovns instruction */ + if ((int8_t)((ulhs ^ urhs) & (ulhs ^ res)) < 0) + { + res = ulhs; + } + + return res; +} + +extern "C" int8_t +octave_jit_mul_int8_int8 (int8_t lhs, int8_t rhs) +{ + int16_t res = (int16_t) lhs * (int16_t) rhs; + uint8_t res2 = ((uint8_t) (lhs ^ rhs) >> 7) + SCHAR_MAX; + + int8_t hi = (res >> 8); + int8_t lo = res; + + if (hi != (lo >> 7)) res = res2; + + return res; +} + +extern "C" int8_t +octave_jit_incr_int8 (int8_t val) +{ + return octave_jit_add_int8_int8 (val, 1); +} + +extern "C" int8_t +octave_jit_decr_int8 (int8_t val) +{ + return octave_jit_sub_int8_int8 (val, 1); +} + +/************************************************************ + * + * int64 related external helper function + * + ************************************************************/ + +extern "C" int64_t +octave_jit_add_int64_int64 (int64_t lhs, int64_t rhs) +{ + uint64_t ulhs = lhs; + uint64_t urhs = rhs; + uint64_t res = ulhs + urhs; + + /* Calculate overflowed result. (Don't change the sign bit of ux) */ + ulhs = (ulhs >> 63) + LONG_MAX; + + /* Force compiler to use cmovns instruction */ + if ((int64_t) ((ulhs ^ urhs) | ~(urhs ^ res)) >= 0) + { + res = ulhs; + } + + return res; +} + +extern "C" int64_t +octave_jit_sub_int64_int64 (int64_t lhs, int64_t rhs) +{ + uint64_t ulhs = lhs; + uint64_t urhs = rhs; + uint64_t res = ulhs - urhs; + + ulhs = (ulhs >> 63) + LONG_MAX; + + /* Force compiler to use cmovns instruction */ + if ((int64_t)((ulhs ^ urhs) & (ulhs ^ res)) < 0) + { + res = ulhs; + } + + return res; +} + +extern "C" int64_t +octave_jit_mul_int64_int64 (int64_t lhs, int64_t rhs) +{ + int128_t res = (int128_t) lhs * (int128_t) rhs; + uint64_t res2 = ((uint64_t) (lhs ^ rhs) >> 63) + LONG_MAX; + + int64_t hi = (res >> 64); + int64_t lo = res; + + if (hi != (lo >> 63)) res = res2; + + return res; +} + +extern "C" int64_t +octave_jit_incr_int64 (int64_t val) +{ + return octave_jit_add_int64_int64 (val, 1); +} + +extern "C" int64_t +octave_jit_decr_int64 (int64_t val) +{ + return octave_jit_sub_int64_int64 (val, 1); +} + +/************************************************************ + * + * uint16 related external helper function + * + ************************************************************/ + +extern "C" uint16_t +octave_jit_add_uint16_uint16 (uint16_t lhs, uint16_t rhs) +{ + uint16_t res = lhs + rhs; + res |= -(res < lhs); + + return res; +} + +extern "C" uint16_t +octave_jit_sub_uint16_uint16 (uint16_t lhs, uint16_t rhs) +{ + uint16_t res = lhs - rhs; + res &= -(res <= lhs); + + return res; +} + +extern "C" uint16_t +octave_jit_mul_uint16_uint16 (uint16_t lhs, uint16_t rhs) +{ + uint32_t res = (uint32_t) lhs * (uint32_t) rhs; + + uint16_t hi = res >> 16; + uint16_t lo = res; + + return lo | -!!hi; +} + +extern "C" uint16_t +octave_jit_incr_uint16 (uint16_t val) +{ + return octave_jit_add_uint16_uint16 (val, 1); +} + +extern "C" uint16_t +octave_jit_decr_uint16 (uint16_t val) +{ + return octave_jit_sub_uint16_uint16 (val, 1); +} + extern "C" void octave_jit_print_matrix (jit_matrix *m) { @@ -655,6 +924,10 @@ jit_function::call (llvm::IRBuilderD& builder, const std::vector<jit_value *>& in_args) const { + // FIXME: Unhandled case: + // function ret = lt(x, y) + // ret = x < y; + // endfunction if (! valid ()) throw jit_fail_exception ("Call not implemented"); @@ -1080,6 +1353,10 @@ any_t = any_t->getPointerTo (); llvm::Type *scalar_t = llvm::Type::getDoubleTy (context); + llvm::Type *single_t = llvm::Type::getFloatTy (context); + llvm::Type *int8__t = llvm::Type::getIntNTy (context, 8); + llvm::Type *int64__t = llvm::Type::getIntNTy (context, 64); + llvm::Type *uint16__t = llvm::Type::getIntNTy (context, 16); llvm::Type *bool_t = llvm::Type::getInt1Ty (context); llvm::Type *string_t = llvm::Type::getInt8Ty (context); string_t = string_t->getPointerTo (); @@ -1120,6 +1397,7 @@ matrix = new_type ("matrix", any, matrix_t); complex = new_type ("complex", any, complex_t); scalar = new_type ("scalar", complex, scalar_t); + single = new_type ("single", any, single_t); scalar_ptr = new_type ("scalar_ptr", 0, scalar_t->getPointerTo ()); any_ptr = new_type ("any_ptr", 0, any_t->getPointerTo ()); range = new_type ("range", any, range_t); @@ -1132,6 +1410,8 @@ create_int (32); create_int (64); + create_uint (16); + casts.resize (next_id + 1); identities.resize (next_id + 1); @@ -1224,6 +1504,10 @@ grab_fn.add_overload (fn); grab_fn.add_overload (create_identity (scalar)); + grab_fn.add_overload (create_identity (single)); + grab_fn.add_overload (create_identity (intN (8))); + grab_fn.add_overload (create_identity (intN (64))); + grab_fn.add_overload (create_identity (uintN (16))); grab_fn.add_overload (create_identity (scalar_ptr)); grab_fn.add_overload (create_identity (any_ptr)); grab_fn.add_overload (create_identity (boolean)); @@ -1243,10 +1527,20 @@ destroy_fn = release_fn; destroy_fn.stash_name ("destroy"); destroy_fn.add_overload (create_identity(scalar)); + destroy_fn.add_overload (create_identity(single)); + destroy_fn.add_overload (create_identity(intN (8))); + destroy_fn.add_overload (create_identity(intN (64))); + destroy_fn.add_overload (create_identity(uintN (16))); destroy_fn.add_overload (create_identity(boolean)); destroy_fn.add_overload (create_identity(index)); destroy_fn.add_overload (create_identity(complex)); + /************************************************************ + * + * scalar related operations + * + ************************************************************/ + // now for binary scalar operations add_binary_op (scalar, octave_value::op_add, llvm::Instruction::FAdd); add_binary_op (scalar, octave_value::op_sub, llvm::Instruction::FSub); @@ -1335,12 +1629,19 @@ val = builder.CreateFMul (val, mone); fn.do_return (builder, val); } + unary_ops[octave_value::op_uminus].add_overload (fn); fn = create_identity (scalar); unary_ops[octave_value::op_uplus].add_overload (fn); unary_ops[octave_value::op_transpose].add_overload (fn); unary_ops[octave_value::op_hermitian].add_overload (fn); + /************************************************************ + * + * complex related operations + * + ************************************************************/ + // now for binary complex operations fn = create_internal ("octave_jit_+_complex_complex", complex, complex, complex); @@ -1388,6 +1689,299 @@ binary_ops[octave_value::op_pow].add_overload (fn); binary_ops[octave_value::op_el_pow].add_overload (fn); + /************************************************************ + * + * single related operations + * + ************************************************************/ + + // now for binary single operations + add_binary_op (single, octave_value::op_add, llvm::Instruction::FAdd); + add_binary_op (single, octave_value::op_sub, llvm::Instruction::FSub); + add_binary_op (single, octave_value::op_mul, llvm::Instruction::FMul); + add_binary_op (single, octave_value::op_el_mul, llvm::Instruction::FMul); + + add_binary_fcmp (single, octave_value::op_lt, llvm::CmpInst::FCMP_ULT); + add_binary_fcmp (single, octave_value::op_le, llvm::CmpInst::FCMP_ULE); + add_binary_fcmp (single, octave_value::op_eq, llvm::CmpInst::FCMP_UEQ); + add_binary_fcmp (single, octave_value::op_ge, llvm::CmpInst::FCMP_UGE); + add_binary_fcmp (single, octave_value::op_gt, llvm::CmpInst::FCMP_UGT); + add_binary_fcmp (single, octave_value::op_ne, llvm::CmpInst::FCMP_UNE); + + // divide is annoying because it might error + fn = create_internal ("octave_jit_div_single_single", single, single, single); + fn.mark_can_error (); + + body = fn.new_block (); + builder.SetInsertPoint (body); + { + llvm::BasicBlock *warn_block = fn.new_block ("warn"); + llvm::BasicBlock *normal_block = fn.new_block ("normal"); + + llvm::Value *zero = llvm::ConstantFP::get (single_t, 0); + llvm::Value *check = builder.CreateFCmpUEQ (zero, fn.argument (builder, 1)); + builder.CreateCondBr (check, warn_block, normal_block); + + builder.SetInsertPoint (warn_block); + gripe_div0.call (builder); + builder.CreateBr (normal_block); + + builder.SetInsertPoint (normal_block); + llvm::Value *ret = builder.CreateFDiv (fn.argument (builder, 0), + fn.argument (builder, 1)); + fn.do_return (builder, ret); + } + binary_ops[octave_value::op_div].add_overload (fn); + binary_ops[octave_value::op_el_div].add_overload (fn); + + // ldiv is the same as div with the operators reversed + fn = mirror_binary (fn); + binary_ops[octave_value::op_ldiv].add_overload (fn); + binary_ops[octave_value::op_el_ldiv].add_overload (fn); + + // In general, the result of scalar ^ scalar is a complex number. We might be + // able to improve on this if we keep track of the range of values varaibles + // can take on. + fn = create_external (JIT_FN (octave_jit_pow_single_single), complex, single, + single); + binary_ops[octave_value::op_pow].add_overload (fn); + binary_ops[octave_value::op_el_pow].add_overload (fn); + + // now for unary single operations + // FIXME: Impelment not + fn = create_internal ("octave_jit_++", single, single); + body = fn.new_block (); + builder.SetInsertPoint (body); + { + llvm::Value *one = llvm::ConstantFP::get (single_t, 1); + llvm::Value *val = fn.argument (builder, 0); + val = builder.CreateFAdd (val, one); + fn.do_return (builder, val); + } + unary_ops[octave_value::op_incr].add_overload (fn); + + fn = create_internal ("octave_jit_--", single, single); + body = fn.new_block (); + builder.SetInsertPoint (body); + { + llvm::Value *one = llvm::ConstantFP::get (single_t, 1); + llvm::Value *val = fn.argument (builder, 0); + val = builder.CreateFSub (val, one); + fn.do_return (builder, val); + } + unary_ops[octave_value::op_decr].add_overload (fn); + + fn = create_internal ("octave_jit_uminus", single, single); + body = fn.new_block (); + builder.SetInsertPoint (body); + { + llvm::Value *mone = llvm::ConstantFP::get (single_t, -1); + llvm::Value *val = fn.argument (builder, 0); + val = builder.CreateFMul (val, mone); + fn.do_return (builder, val); + } + unary_ops[octave_value::op_uminus].add_overload (fn); + + fn = create_identity (single); + unary_ops[octave_value::op_uplus].add_overload (fn); + unary_ops[octave_value::op_transpose].add_overload (fn); + unary_ops[octave_value::op_hermitian].add_overload (fn); + + /************************************************************ + * + * int8 related operations + * + ************************************************************/ + + // now for binary int8 operations + fn = create_external (JIT_FN (octave_jit_add_int8_int8), intN (8), intN (8), + intN (8)); + binary_ops[octave_value::op_add].add_overload (fn); + fn = create_external (JIT_FN (octave_jit_sub_int8_int8), intN (8), intN (8), + intN (8)); + binary_ops[octave_value::op_sub].add_overload (fn); + fn = create_external (JIT_FN (octave_jit_mul_int8_int8), intN (8), intN (8), + intN (8)); + binary_ops[octave_value::op_mul].add_overload (fn); + binary_ops[octave_value::op_el_mul].add_overload (fn); + + add_binary_icmp (intN (8), octave_value::op_lt, llvm::CmpInst::ICMP_SLT); + add_binary_icmp (intN (8), octave_value::op_le, llvm::CmpInst::ICMP_SLE); + add_binary_icmp (intN (8), octave_value::op_eq, llvm::CmpInst::ICMP_EQ); + add_binary_icmp (intN (8), octave_value::op_ge, llvm::CmpInst::ICMP_SGE); + add_binary_icmp (intN (8), octave_value::op_gt, llvm::CmpInst::ICMP_SGT); + add_binary_icmp (intN (8), octave_value::op_ne, llvm::CmpInst::ICMP_NE); + + // FIXME: saturation divide definition? interpreter convert int to double, calculate and round. + // divide is annoying because it might error + // FIXME: Implement div + + // FIXME: Implement pow + + // now for unary int8 operations + // FIXME: Impelment not + fn = create_external (JIT_FN (octave_jit_incr_int8), intN (8), intN (8)); + unary_ops[octave_value::op_incr].add_overload (fn); + + fn = create_external (JIT_FN (octave_jit_decr_int8), intN (8), intN (8)); + unary_ops[octave_value::op_decr].add_overload (fn); + + fn = create_internal ("octave_jit_uminus", intN (8), intN (8)); + body = fn.new_block (); + builder.SetInsertPoint (body); + { + llvm::Value *mone = llvm::ConstantInt::get (int8__t, -1); + llvm::Value *val = fn.argument (builder, 0); + val = builder.CreateMul (val, mone); + fn.do_return (builder, val); + } + unary_ops[octave_value::op_uminus].add_overload (fn); + + fn = create_identity (intN (8)); + unary_ops[octave_value::op_uplus].add_overload (fn); + unary_ops[octave_value::op_transpose].add_overload (fn); + unary_ops[octave_value::op_hermitian].add_overload (fn); + + /************************************************************ + * + * int64 related operations + * + ************************************************************/ + + // FIXME: overflow occurs at minus + minus, minus - plus + // now for binary int64 operations + fn = create_external (JIT_FN (octave_jit_add_int64_int64), intN (64), intN (64), + intN (64)); + binary_ops[octave_value::op_add].add_overload (fn); + fn = create_external (JIT_FN (octave_jit_sub_int64_int64), intN (64), intN (64), + intN (64)); + binary_ops[octave_value::op_sub].add_overload (fn); + fn = create_external (JIT_FN (octave_jit_mul_int64_int64), intN (64), intN (64), + intN (64)); + binary_ops[octave_value::op_mul].add_overload (fn); + binary_ops[octave_value::op_el_mul].add_overload (fn); + + add_binary_icmp (intN (64), octave_value::op_lt, llvm::CmpInst::ICMP_SLT); + add_binary_icmp (intN (64), octave_value::op_le, llvm::CmpInst::ICMP_SLE); + add_binary_icmp (intN (64), octave_value::op_eq, llvm::CmpInst::ICMP_EQ); + add_binary_icmp (intN (64), octave_value::op_ge, llvm::CmpInst::ICMP_SGE); + add_binary_icmp (intN (64), octave_value::op_gt, llvm::CmpInst::ICMP_SGT); + add_binary_icmp (intN (64), octave_value::op_ne, llvm::CmpInst::ICMP_NE); + + // FIXME: saturation divide definition? interpreter convert int to double, calculate and round. + // divide is annoying because it might error + // FIXME: Implement div + + // FIXME: Implement pow + + // now for unary int8 operations + // FIXME: Impelment not + fn = create_external (JIT_FN (octave_jit_incr_int64), intN (64), intN (64)); + unary_ops[octave_value::op_incr].add_overload (fn); + + fn = create_external (JIT_FN (octave_jit_decr_int64), intN (64), intN (64)); + unary_ops[octave_value::op_decr].add_overload (fn); + + fn = create_internal ("octave_jit_uminus", intN (64), intN (64)); + body = fn.new_block (); + builder.SetInsertPoint (body); + { + llvm::Value *mone = llvm::ConstantInt::get (int64__t, -1); + llvm::Value *val = fn.argument (builder, 0); + val = builder.CreateMul (val, mone); + fn.do_return (builder, val); + } + unary_ops[octave_value::op_uminus].add_overload (fn); + + fn = create_identity (intN (64)); + unary_ops[octave_value::op_uplus].add_overload (fn); + unary_ops[octave_value::op_transpose].add_overload (fn); + unary_ops[octave_value::op_hermitian].add_overload (fn); + + /************************************************************ + * + * uint16 related operations + * + ************************************************************/ + + // now for binary uint16 operations + fn = create_external (JIT_FN (octave_jit_add_uint16_uint16), uintN (16), uintN (16), + uintN (16)); + binary_ops[octave_value::op_add].add_overload (fn); + fn = create_external (JIT_FN (octave_jit_sub_uint16_uint16), uintN (16), uintN (16), + uintN (16)); + binary_ops[octave_value::op_sub].add_overload (fn); + fn = create_external (JIT_FN (octave_jit_mul_uint16_uint16), uintN (16), uintN (16), + uintN (16)); + binary_ops[octave_value::op_mul].add_overload (fn); + binary_ops[octave_value::op_el_mul].add_overload (fn); + + add_binary_icmp (uintN (16), octave_value::op_lt, llvm::CmpInst::ICMP_ULT); + add_binary_icmp (uintN (16), octave_value::op_le, llvm::CmpInst::ICMP_ULE); + add_binary_icmp (uintN (16), octave_value::op_eq, llvm::CmpInst::ICMP_EQ); + add_binary_icmp (uintN (16), octave_value::op_ge, llvm::CmpInst::ICMP_UGE); + add_binary_icmp (uintN (16), octave_value::op_gt, llvm::CmpInst::ICMP_UGT); + add_binary_icmp (uintN (16), octave_value::op_ne, llvm::CmpInst::ICMP_NE); + + // FIXME: saturation divide definition? interpreter convert uint to double, calculate and round. + // divide is annoying because it might error +#if 0 + fn = create_internal ("octave_jit_div_uint16_uint16", uintN (16), uintN (16), uintN (16)); + fn.mark_can_error (); + + body = fn.new_block (); + builder.SetInsertPoint (body); + { + llvm::BasicBlock *warn_block = fn.new_block ("warn"); + llvm::BasicBlock *normal_block = fn.new_block ("normal"); + + llvm::Value *zero = llvm::ConstantInt::get (uint16__t, 0); + llvm::Value *check = builder.CreateICmpEQ (zero, fn.argument (builder, 1)); + builder.CreateCondBr (check, warn_block, normal_block); + + builder.SetInsertPoint (warn_block); + gripe_div0.call (builder); + builder.CreateBr (normal_block); + + builder.SetInsertPoint (normal_block); + llvm::Value *ret = builder.CreateUDiv (fn.argument (builder, 0), + fn.argument (builder, 1)); + fn.do_return (builder, ret); + } + binary_ops[octave_value::op_div].add_overload (fn); + binary_ops[octave_value::op_el_div].add_overload (fn); + + // ldiv is the same as div with the operators reversed + fn = mirror_binary (fn); + binary_ops[octave_value::op_ldiv].add_overload (fn); + binary_ops[octave_value::op_el_ldiv].add_overload (fn); +#endif + + // FIXME: Implement pow + + // now for unary uint16 operations + // FIXME: Impelment not + fn = create_external (JIT_FN (octave_jit_incr_uint16), uintN (16), uintN (16)); + unary_ops[octave_value::op_incr].add_overload (fn); + + fn = create_external (JIT_FN (octave_jit_decr_uint16), uintN (16), uintN (16)); + unary_ops[octave_value::op_decr].add_overload (fn); + + fn = create_internal ("octave_jit_uminus", uintN (16), uintN (16)); + body = fn.new_block (); + builder.SetInsertPoint (body); + { + llvm::Value *zero = llvm::ConstantInt::get (uint16__t, 0); + + fn.do_return (builder, zero); + } + unary_ops[octave_value::op_uminus].add_overload (fn); + + fn = create_identity (uintN (16)); + unary_ops[octave_value::op_uplus].add_overload (fn); + unary_ops[octave_value::op_transpose].add_overload (fn); + unary_ops[octave_value::op_hermitian].add_overload (fn); + fn = create_internal ("octave_jit_*_scalar_complex", complex, scalar, complex); jit_function mul_scalar_complex = fn; @@ -1767,6 +2361,10 @@ casts[any->type_id ()].stash_name ("(any)"); casts[scalar->type_id ()].stash_name ("(scalar)"); + casts[single->type_id ()].stash_name ("(single)"); + casts[intN (8)->type_id ()].stash_name ("(int8)"); + casts[intN (64)->type_id ()].stash_name ("(int64)"); + casts[uintN (16)->type_id ()].stash_name ("(uint16)"); casts[complex->type_id ()].stash_name ("(complex)"); casts[matrix->type_id ()].stash_name ("(matrix)"); casts[range->type_id ()].stash_name ("(range)"); @@ -1795,6 +2393,38 @@ fn = create_external (JIT_FN (octave_jit_cast_scalar_any), scalar, any); casts[scalar->type_id ()].add_overload (fn); + // cast any <- single + fn = create_external (JIT_FN (octave_jit_cast_any_single), any, single); + casts[any->type_id ()].add_overload (fn); + + // cast single <- any + fn = create_external (JIT_FN (octave_jit_cast_single_any), single, any); + casts[single->type_id ()].add_overload (fn); + + // cast any <- int8 + fn = create_external (JIT_FN (octave_jit_cast_any_int8), any, intN (8)); + casts[any->type_id ()].add_overload (fn); + + // cast int8 <- any + fn = create_external (JIT_FN (octave_jit_cast_int8_any), intN (8), any); + casts[intN (8)->type_id ()].add_overload (fn); + + // cast any <- int64 + fn = create_external (JIT_FN (octave_jit_cast_any_int64), any, intN (64)); + casts[any->type_id ()].add_overload (fn); + + // cast int64 <- any + fn = create_external (JIT_FN (octave_jit_cast_int64_any), intN (64), any); + casts[intN (64)->type_id ()].add_overload (fn); + + // cast any <- uint16 + fn = create_external (JIT_FN (octave_jit_cast_any_uint16), any, uintN (16)); + casts[any->type_id ()].add_overload (fn); + + // cast uint16 <- any + fn = create_external (JIT_FN (octave_jit_cast_uint16_any), uintN (16), any); + casts[uintN (16)->type_id ()].add_overload (fn); + // cast any <- complex fn = create_external (JIT_FN (octave_jit_cast_any_complex), any, complex); casts[any->type_id ()].add_overload (fn); @@ -1828,6 +2458,22 @@ fn = create_identity (scalar); casts[scalar->type_id ()].add_overload (fn); + // cast single <- single + fn = create_identity (single); + casts[single->type_id ()].add_overload (fn); + + // cast int8 <- int8 + fn = create_identity (intN (8)); + casts[intN (8)->type_id ()].add_overload (fn); + + // cast int64 <- int64 + fn = create_identity (intN (64)); + casts[intN (64)->type_id ()].add_overload (fn); + + // cast uint16 <- uint16 + fn = create_identity (uintN (16)); + casts[uintN (16)->type_id ()].add_overload (fn); + // cast complex <- complex fn = create_identity (complex); casts[complex->type_id ()].add_overload (fn); @@ -2225,6 +2871,15 @@ nbits)); } +void +jit_typeinfo::create_uint (size_t nbits) +{ + std::stringstream tname; + tname << "uint" << nbits; + uints[nbits] = new_type (tname.str (), any, llvm::Type::getIntNTy (context, + nbits)); +} + jit_type * jit_typeinfo::intN (size_t nbits) const { @@ -2236,6 +2891,16 @@ } jit_type * +jit_typeinfo::uintN (size_t nbits) const +{ + std::map<size_t, jit_type *>::const_iterator iter = uints.find (nbits); + if (iter != uints.end ()) + return iter->second; + + throw jit_fail_exception ("No such unsigned integer type"); +} + +jit_type * jit_typeinfo::do_type_of (const octave_value &ov) const { if (ov.is_function ()) @@ -2260,6 +2925,27 @@ return get_matrix (); } + if (ov.is_single_type () && ! ov.is_complex_type ()) + { + if (ov.is_real_scalar ()) + return get_single (); + } + + if (ov.is_int8_type()) + { + return intN (8); + } + + if (ov.is_int64_type()) + { + return intN (64); + } + + if (ov.is_uint16_type()) + { + return uintN (16); + } + if (ov.is_complex_scalar ()) { Complex cv = ov.complex_value (); @@ -2274,3 +2960,5 @@ } #endif + +#pragma GCC diagnostic pop
--- a/libinterp/corefcn/jit-typeinfo.h +++ b/libinterp/corefcn/jit-typeinfo.h @@ -458,6 +458,8 @@ static llvm::Type *get_scalar_llvm (void) { return instance->scalar->to_llvm (); } + static jit_type *get_single (void) { return instance->single; } + static jit_type *get_scalar_ptr (void) { return instance->scalar_ptr; } static jit_type *get_any_ptr (void) { return instance->any_ptr; } @@ -790,8 +792,12 @@ void create_int (size_t nbits); + void create_uint (size_t nbits); + jit_type *intN (size_t nbits) const; + jit_type *uintN (size_t nbits) const; + static jit_typeinfo *instance; llvm::Module *module; @@ -807,6 +813,7 @@ jit_type *any; jit_type *matrix; jit_type *scalar; + jit_type *single; jit_type *scalar_ptr; // a fake type for interfacing with C++ jit_type *any_ptr; // a fake type for interfacing with C++ jit_type *range; @@ -816,6 +823,7 @@ jit_type *complex; jit_type *unknown_function; std::map<size_t, jit_type *> ints; + std::map<size_t, jit_type *> uints; std::map<std::string, jit_type *> builtins; llvm::StructType *complex_ret;
--- a/libinterp/corefcn/pt-jit.cc +++ b/libinterp/corefcn/pt-jit.cc @@ -230,13 +230,13 @@ void jit_convert::visit_anon_fcn_handle (tree_anon_fcn_handle&) { - throw jit_fail_exception (); + throw jit_fail_exception ("No visit_anon_fcn_handle implementation"); } void jit_convert::visit_argument_list (tree_argument_list&) { - throw jit_fail_exception (); + throw jit_fail_exception ("No visit_argument_list implementation"); } void @@ -335,25 +335,25 @@ void jit_convert::visit_global_command (tree_global_command&) { - throw jit_fail_exception (); + throw jit_fail_exception ("No visit_global_command implemenation"); } void jit_convert::visit_persistent_command (tree_persistent_command&) { - throw jit_fail_exception (); + throw jit_fail_exception ("No visit_persistent_command implementation"); } void jit_convert::visit_decl_elt (tree_decl_elt&) { - throw jit_fail_exception (); + throw jit_fail_exception ("No visit_decl_elt implementation"); } void jit_convert::visit_decl_init_list (tree_decl_init_list&) { - throw jit_fail_exception (); + throw jit_fail_exception ("No visit_decl_init_list implementation"); } void @@ -462,37 +462,37 @@ void jit_convert::visit_complex_for_command (tree_complex_for_command&) { - throw jit_fail_exception (); + throw jit_fail_exception ("No visit_complex_for_command implementation"); } void jit_convert::visit_octave_user_script (octave_user_script&) { - throw jit_fail_exception (); + throw jit_fail_exception ("No visit_octave_user_script implementation"); } void jit_convert::visit_octave_user_function (octave_user_function&) { - throw jit_fail_exception (); + throw jit_fail_exception ("No visit_octave_user_function implementation"); } void jit_convert::visit_octave_user_function_header (octave_user_function&) { - throw jit_fail_exception (); + throw jit_fail_exception ("No visit_octave_user_function_header implementation"); } void jit_convert::visit_octave_user_function_trailer (octave_user_function&) { - throw jit_fail_exception (); + throw jit_fail_exception ("No visit_octave_user_function_trailer implementation"); } void jit_convert::visit_function_def (tree_function_def&) { - throw jit_fail_exception (); + throw jit_fail_exception ("No visit_function_def implementation"); } void @@ -516,7 +516,7 @@ void jit_convert::visit_if_clause (tree_if_clause&) { - throw jit_fail_exception (); + throw jit_fail_exception ("No visit_if_clause implementation"); } void @@ -628,25 +628,25 @@ void jit_convert::visit_matrix (tree_matrix&) { - throw jit_fail_exception (); + throw jit_fail_exception ("No visit_matrix implementation"); } void jit_convert::visit_cell (tree_cell&) { - throw jit_fail_exception (); + throw jit_fail_exception ("No visit_cell implementation"); } void jit_convert::visit_multi_assignment (tree_multi_assignment&) { - throw jit_fail_exception (); + throw jit_fail_exception ("No visit_multi_assignment implementation"); } void jit_convert::visit_no_op_command (tree_no_op_command&) { - throw jit_fail_exception (); + return; } void @@ -677,13 +677,13 @@ void jit_convert::visit_fcn_handle (tree_fcn_handle&) { - throw jit_fail_exception (); + throw jit_fail_exception ("No visit_fcn_handle implementation"); } void jit_convert::visit_parameter_list (tree_parameter_list&) { - throw jit_fail_exception (); + throw jit_fail_exception ("No visit_parameter_list implementation"); } void @@ -719,13 +719,13 @@ void jit_convert::visit_return_command (tree_return_command&) { - throw jit_fail_exception (); + throw jit_fail_exception ("No visit_return_command implementation"); } void jit_convert::visit_return_list (tree_return_list&) { - throw jit_fail_exception (); + throw jit_fail_exception ("No visit_return_list implementation"); } void @@ -802,31 +802,147 @@ void jit_convert::visit_switch_case (tree_switch_case&) { - throw jit_fail_exception (); + throw jit_fail_exception ("No visit_switch_case implementation"); } void jit_convert::visit_switch_case_list (tree_switch_case_list&) { - throw jit_fail_exception (); + throw jit_fail_exception ("No visit_switch_case_list implementation"); } void -jit_convert::visit_switch_command (tree_switch_command&) +jit_convert::visit_switch_command (tree_switch_command& cmd) { - throw jit_fail_exception (); + // create blocks + jit_block *head = factory.create<jit_block> ("switch_head"); + + tree_switch_case_list *lst = cmd.case_list (); + tree_switch_case *last = lst->back (); + + size_t cond_blocks_num; + if (lst->size() && last->is_default_case ()) + cond_blocks_num = lst->size () - 1; + else + cond_blocks_num = lst->size (); + + std::vector<jit_block *> if_blocks (cond_blocks_num); + std::vector<jit_block *> body_blocks (cond_blocks_num); + std::vector<jit_block *> else_blocks (cond_blocks_num); + + //tree_switch_case_list::iterator iter = lst->begin (); + for (size_t i = 0; i < cond_blocks_num; ++i) + { + if_blocks[i] = factory.create<jit_block> ("if_cond"); + body_blocks[i] = factory.create<jit_block> ("if_body"); + else_blocks[i] = factory.create<jit_block> ("else"); + } + + jit_block *tail = factory.create<jit_block> ("switch_tail"); + + + // link & fullfil these blocks + block->append (factory.create<jit_branch> (head)); + + blocks.push_back (head); + block = head; // switch_head + tree_expression *expr = cmd.switch_value (); + assert (expr && "Switch value can not be null"); + jit_value *value = visit (expr); + assert (value); + + // each branch in the if statement will have different breaks/continues + block_list current_breaks = breaks; + block_list current_continues = continues; + breaks.clear (); + continues.clear (); + +#if 0 + size_t num_incomming = 0; // number of incomming blocks to our tail +#endif + tree_switch_case_list::iterator iter = lst->begin (); + for (size_t i = 0; i < cond_blocks_num; ++iter, ++i) + { + block->append (factory.create<jit_branch> (if_blocks[i])); + + blocks.push_back (if_blocks[i]); + block = if_blocks[i]; // if_cond + tree_switch_case *twc = *iter; + tree_expression *expr = twc->case_label (); + jit_value *label = visit (expr); + assert(label); + + const jit_operation& fn = jit_typeinfo::binary_op (octave_value::op_eq); + jit_value *cond = create_checked (fn, value, label); + assert(cond); + jit_call *check = create_checked (&jit_typeinfo::logically_true, + cond); + block->append (factory.create<jit_cond_branch> (check, body_blocks[i], + else_blocks[i])); + + blocks.push_back (body_blocks[i]); + block = body_blocks[i]; // if_body + tree_statement_list *stmt_lst = twc->commands (); + assert(stmt_lst); +#if 0 + try + { +#endif + stmt_lst->accept (*this); +#if 0 + num_incomming++; +#endif + block->append (factory.create<jit_branch> (tail)); +#if 0 + } + catch(const jit_break_exception&) + {} +#endif + +#if 0 + // each branch in the if statement will have different breaks/continues + current_breaks.splice (current_breaks.end (), breaks); + current_continues.splice (current_continues.end (), continues); +#endif + + blocks.push_back (else_blocks[i]); + block = else_blocks[i]; // else + } + + if (lst->size() && last->is_default_case ()) + { + tree_statement_list *stmt_lst = last->commands (); + assert(stmt_lst); + stmt_lst->accept (*this); + +#if 0 + // each branch in the if statement will have different breaks/continues + current_breaks.splice (current_breaks.end (), breaks); + current_continues.splice (current_continues.end (), continues); +#endif + } + +#if 0 + // each branch in the if statement will have different breaks/continues + breaks.splice (breaks.end (), current_breaks); + continues.splice (continues.end (), current_continues); +#endif + + block->append (factory.create<jit_branch> (tail)); + blocks.push_back (tail); + block = tail; // switch_tail } void jit_convert::visit_try_catch_command (tree_try_catch_command&) { - throw jit_fail_exception (); + throw jit_fail_exception ("No visit_try_catch_command implementation"); } void jit_convert::visit_unwind_protect_command (tree_unwind_protect_command&) { - throw jit_fail_exception (); + throw jit_fail_exception ("No visit_unwind_protect_command implementation"); } void @@ -892,9 +1008,68 @@ } void -jit_convert::visit_do_until_command (tree_do_until_command&) +jit_convert::visit_do_until_command (tree_do_until_command& duc) { - throw jit_fail_exception (); + unwind_protect prot; + prot.protect_var (breaks); + prot.protect_var (continues); + breaks.clear (); + continues.clear (); + + jit_block *body = factory.create<jit_block> ("do_until_body"); + block->append (factory.create<jit_branch> (body)); + blocks.push_back (body); + block = body; + + tree_statement_list *loop_body = duc.body (); + bool all_breaking = false; + if (loop_body) + { + try + { + loop_body->accept (*this); + } + catch (const jit_break_exception&) + { + all_breaking = true; + } + } + + jit_block *cond_check = factory.create<jit_block> ("do_until_cond_check"); + block->append (factory.create<jit_branch> (cond_check)); + blocks.push_back (cond_check); + block = cond_check; + + tree_expression *expr = duc.condition (); + assert (expr && "Do-Until expression can not be null"); + jit_value *check = visit (expr); + check = create_checked (&jit_typeinfo::logically_true, check); + + jit_block *tail = factory.create<jit_block> ("do_until_tail"); + block->append (factory.create<jit_cond_branch> (check, tail, body)); + + finish_breaks (tail, breaks); + +#if 0 + if (! all_breaking || continues.size ()) + { + jit_block *interrupt_check + = factory.create<jit_block> ("interrupt_check"); + blocks.push_back (interrupt_check); + finish_breaks (interrupt_check, continues); + if (! all_breaking) + block->append (factory.create<jit_branch> (interrupt_check)); + + block = interrupt_check; + jit_error_check *ec + = factory.create<jit_error_check> (jit_error_check::var_interrupt, + cond_check, final_block); + block->append (ec); + } +#endif + + blocks.push_back (tail); + block = tail; } void