changeset 14762:8efcaf5aa233

Prevent crash when using scalars as conditionals
author Max Brister <max@2bass.com>
date Fri, 08 Jun 2012 13:35:40 -0500
parents e8487d98561c
children c0a5ab3b9278
files src/pt-jit.cc src/pt-jit.h
diffstat 2 files changed, 49 insertions(+), 9 deletions(-) [+]
line wrap: on
line diff
--- a/src/pt-jit.cc
+++ b/src/pt-jit.cc
@@ -167,6 +167,19 @@
   return new octave_scalar (value);
 }
 
+extern "C" void
+octave_jit_gripe_nan_to_logical_conversion (void)
+{
+  try
+    {
+      gripe_nan_to_logical_conversion ();
+    }
+  catch (const octave_execution_exception&)
+    {
+      gripe_library_execution_error ();
+    }
+}
+
 // -------------------- jit_range --------------------
 std::ostream&
 operator<< (std::ostream& os, const jit_range& rng)
@@ -495,25 +508,40 @@
   for_index_fn.add_overload (fn, false, false, scalar, range, index);
 
   // logically true
-  // FIXME: Check for NaN
-  fn = create_function ("octave_logically_true_scalar", boolean, scalar);
+  logically_true_fn.stash_name ("logically_true");
+
+  llvm::Function *gripe_nantl = create_function ("octave_jit_gripe_nan_to_logical_conversion", void_t);
+  engine->addGlobalMapping (gripe_nantl, reinterpret_cast<void *> (&octave_jit_gripe_nan_to_logical_conversion));
+                            
+
+  fn = create_function ("octave_jit_logically_true_scalar", boolean, scalar);
   body = llvm::BasicBlock::Create (context, "body", fn);
   builder.SetInsertPoint (body);
   {
-    llvm::Value *zero = llvm::ConstantFP::get (scalar->to_llvm (), 0);
-    llvm::Value *ret = builder.CreateFCmpUNE (fn->arg_begin (), zero);
+    llvm::BasicBlock *error_block = llvm::BasicBlock::Create (context, "error", fn);
+    llvm::BasicBlock *normal_block = llvm::BasicBlock::Create (context, "normal", fn);
+
+    llvm::Value *check = builder.CreateFCmpUNE (fn->arg_begin (), fn->arg_begin ());
+    builder.CreateCondBr (check, error_block, normal_block);
+
+    builder.SetInsertPoint (error_block);
+    builder.CreateCall (gripe_nantl);
+    builder.CreateBr (normal_block);
+    builder.SetInsertPoint (normal_block);
+
+    llvm::Value *zero = llvm::ConstantFP::get (dbl, 0);
+    llvm::Value *ret = builder.CreateFCmpONE (fn->arg_begin (), zero);
     builder.CreateRet (ret);
   }
   llvm::verifyFunction (*fn);
-  logically_true.add_overload (fn, true, false, boolean, scalar);
+  logically_true_fn.add_overload (fn, true, false, boolean, scalar);
 
   fn = create_function ("octave_logically_true_bool", boolean, boolean);
   body = llvm::BasicBlock::Create (context, "body", fn);
   builder.SetInsertPoint (body);
   builder.CreateRet (fn->arg_begin ());
   llvm::verifyFunction (*fn);
-  logically_true.add_overload (fn, false, false, boolean, boolean);
-  logically_true.stash_name ("logically_true");
+  logically_true_fn.add_overload (fn, false, false, boolean, boolean);
 
   // make_range
   // FIXME: May be benificial to implement all in LLVM
@@ -1500,7 +1528,14 @@
       if (! tic->is_else_clause ())
         {
           tree_expression *expr = tic->condition ();
-          jit_value *cond = visit (expr);
+          jit_instruction *cond = visit (expr);
+          cond = create<jit_call> (&jit_typeinfo::logically_true, cond);
+          block->append (cond);
+
+          jit_block *next = create<jit_block> (block->name () + "a");
+          blocks.push_back (next);
+          block->append (create<jit_check_error> (next, final_block));
+          block = next;
 
           jit_block *body = create<jit_block> (i == 0 ? "if_body" : "ifelse_body");
           blocks.push_back (body);
--- a/src/pt-jit.h
+++ b/src/pt-jit.h
@@ -348,6 +348,11 @@
     return instance->make_range_fn;
   }
 
+  static const jit_function& logically_true (void)
+  {
+    return instance->logically_true_fn;
+  }
+
   static const jit_function& cast (jit_type *result)
   {
     return instance->do_cast (result);
@@ -525,7 +530,7 @@
   jit_function for_init_fn;
   jit_function for_check_fn;
   jit_function for_index_fn;
-  jit_function logically_true;
+  jit_function logically_true_fn;
   jit_function make_range_fn;
 
   // type id -> cast function TO that type