# HG changeset patch # User Max Brister # Date 1339180540 18000 # Node ID 8efcaf5aa233d7adf574fec15520226e4c0248ea # Parent e8487d98561c1062af61fce040de7a73039e4bbc Prevent crash when using scalars as conditionals diff --git a/src/pt-jit.cc b/src/pt-jit.cc --- a/src/pt-jit.cc +++ b/src/pt-jit.cc @@ -167,6 +167,19 @@ return new octave_scalar (value); } +extern "C" void +octave_jit_gripe_nan_to_logical_conversion (void) +{ + try + { + gripe_nan_to_logical_conversion (); + } + catch (const octave_execution_exception&) + { + gripe_library_execution_error (); + } +} + // -------------------- jit_range -------------------- std::ostream& operator<< (std::ostream& os, const jit_range& rng) @@ -495,25 +508,40 @@ for_index_fn.add_overload (fn, false, false, scalar, range, index); // logically true - // FIXME: Check for NaN - fn = create_function ("octave_logically_true_scalar", boolean, scalar); + logically_true_fn.stash_name ("logically_true"); + + llvm::Function *gripe_nantl = create_function ("octave_jit_gripe_nan_to_logical_conversion", void_t); + engine->addGlobalMapping (gripe_nantl, reinterpret_cast (&octave_jit_gripe_nan_to_logical_conversion)); + + + fn = create_function ("octave_jit_logically_true_scalar", boolean, scalar); body = llvm::BasicBlock::Create (context, "body", fn); builder.SetInsertPoint (body); { - llvm::Value *zero = llvm::ConstantFP::get (scalar->to_llvm (), 0); - llvm::Value *ret = builder.CreateFCmpUNE (fn->arg_begin (), zero); + llvm::BasicBlock *error_block = llvm::BasicBlock::Create (context, "error", fn); + llvm::BasicBlock *normal_block = llvm::BasicBlock::Create (context, "normal", fn); + + llvm::Value *check = builder.CreateFCmpUNE (fn->arg_begin (), fn->arg_begin ()); + builder.CreateCondBr (check, error_block, normal_block); + + builder.SetInsertPoint (error_block); + builder.CreateCall (gripe_nantl); + builder.CreateBr (normal_block); + builder.SetInsertPoint (normal_block); + + llvm::Value *zero = llvm::ConstantFP::get (dbl, 0); + llvm::Value *ret = builder.CreateFCmpONE (fn->arg_begin (), zero); builder.CreateRet (ret); } llvm::verifyFunction (*fn); - logically_true.add_overload (fn, true, false, boolean, scalar); + logically_true_fn.add_overload (fn, true, false, boolean, scalar); fn = create_function ("octave_logically_true_bool", boolean, boolean); body = llvm::BasicBlock::Create (context, "body", fn); builder.SetInsertPoint (body); builder.CreateRet (fn->arg_begin ()); llvm::verifyFunction (*fn); - logically_true.add_overload (fn, false, false, boolean, boolean); - logically_true.stash_name ("logically_true"); + logically_true_fn.add_overload (fn, false, false, boolean, boolean); // make_range // FIXME: May be benificial to implement all in LLVM @@ -1500,7 +1528,14 @@ if (! tic->is_else_clause ()) { tree_expression *expr = tic->condition (); - jit_value *cond = visit (expr); + jit_instruction *cond = visit (expr); + cond = create (&jit_typeinfo::logically_true, cond); + block->append (cond); + + jit_block *next = create (block->name () + "a"); + blocks.push_back (next); + block->append (create (next, final_block)); + block = next; jit_block *body = create (i == 0 ? "if_body" : "ifelse_body"); blocks.push_back (body); diff --git a/src/pt-jit.h b/src/pt-jit.h --- a/src/pt-jit.h +++ b/src/pt-jit.h @@ -348,6 +348,11 @@ return instance->make_range_fn; } + static const jit_function& logically_true (void) + { + return instance->logically_true_fn; + } + static const jit_function& cast (jit_type *result) { return instance->do_cast (result); @@ -525,7 +530,7 @@ jit_function for_init_fn; jit_function for_check_fn; jit_function for_index_fn; - jit_function logically_true; + jit_function logically_true_fn; jit_function make_range_fn; // type id -> cast function TO that type