# HG changeset patch # User Jordi GutiĆ©rrez Hermoso # Date 1342815215 14400 # Node ID 0ef151f9fdc90afafc9904a09da1e8674cb0e6a7 # Parent f8bb15f6a19b8c8ea87c4fe638eaf46bf793c27b# Parent bbc825cb2ea0ce6c1ad2383c12111e5ec19bc1a6 Merge in JIT branch \o/ diff --git a/build-aux/common.mk b/build-aux/common.mk --- a/build-aux/common.mk +++ b/build-aux/common.mk @@ -181,6 +181,10 @@ Z_LDFLAGS = @Z_LDFLAGS@ Z_LIBS = @Z_LIBS@ +LLVM_CPPFLAGS = @LLVM_CPPFLAGS@ +LLVM_LDFLAGS = @LLVM_LDFLAGS@ +LLVM_LIBS = @LLVM_LIBS@ + GRAPHICS_LIBS = @GRAPHICS_LIBS@ QHULL_CPPFLAGS = @QHULL_CPPFLAGS@ @@ -252,7 +256,7 @@ DL_LIBS = @DL_LIBS@ LIBS = @LIBS@ -ALL_CPPFLAGS = $(CPPFLAGS) $(HDF5_CPPFLAGS) $(Z_CPPFLAGS) +ALL_CPPFLAGS = $(CPPFLAGS) $(HDF5_CPPFLAGS) $(Z_CPPFLAGS) $(LLVM_CPPFLAGS) SPARSE_XCPPFLAGS = \ $(CHOLMOD_CPPFLAGS) $(UMFPACK_CPPFLAGS) \ @@ -544,6 +548,9 @@ -e "s|%OCTAVE_CONF_MAGICK_CPPFLAGS%|\"${MAGICK_CPPFLAGS}\"|" \ -e "s|%OCTAVE_CONF_MAGICK_LDFLAGS%|\"${MAGICK_LDFLAGS}\"|" \ -e "s|%OCTAVE_CONF_MAGICK_LIBS%|\"${MAGICK_LIBS}\"|" \ + -e "s|%OCTAVE_CONF_LLVM_CPPFLAGS%|\"${LLVM_CPPFLAGS}\"|" \ + -e "s|%OCTAVE_CONF_LLVM_LDFLAGS%|\"${LLVM_LDFLAGS}\"|" \ + -e "s|%OCTAVE_CONF_LLVM_LIBS%|\"${LLVM_LIBS}\"|" \ -e 's|%OCTAVE_CONF_MKOCTFILE_DL_LDFLAGS%|\"@MKOCTFILE_DL_LDFLAGS@\"|' \ -e "s|%OCTAVE_CONF_OCTAVE_LINK_DEPS%|\"${OCTAVE_LINK_DEPS}\"|" \ -e "s|%OCTAVE_CONF_OCTAVE_LINK_OPTS%|\"${OCTAVE_LINK_OPTS}\"|" \ diff --git a/build-aux/mkinstalldirs b/build-aux/mkinstalldirs diff --git a/configure.ac b/configure.ac --- a/configure.ac +++ b/configure.ac @@ -715,6 +715,95 @@ [ZLIB library not found. Octave will not be able to save or load compressed data files or HDF5 files.], [zlib.h], [gzclearerr]) +### Check for the llvm library +dnl +dnl +dnl llvm is odd and has its own pkg-config like script. We should probably check +dnl for existance and +dnl +warn_llvm="LLVM library fails tests. JIT compilation will be disabled." 
+ +AC_ARG_VAR(LLVM_CONFIG, [path to llvm-config utility]) + +AC_ARG_ENABLE([jit-debug], + AS_HELP_STRING([--enable-jit-debug], [Enable debug printing of jit IRs])) + +AS_IF([test "x$enable_jit_debug" = "xyes"], [ + AC_DEFINE(OCTAVE_JIT_DEBUG, 1, [Define for jit debug printing]) +]) + +LLVM_CXXFLAGS= +LLVM_CPPFLAGS= +LLVM_LDFLAGS= +LLVM_LIBS= + +if test "x$ac_cv_env_LLVM_CONFIG_set" = "xset"; then + # We use -isystem if avaiable because we do not want to see warnings in llvm + LLVM_INCLUDE_FLAG=-I + OCTAVE_CC_FLAG(-isystem ., [ + LLVM_INCLUDE_FLAG=-isystem + AC_MSG_NOTICE([using -isystem for llvm headers])]) + + LLVM_LDFLAGS="-L`$LLVM_CONFIG --libdir`" + LLVM_LIBS=`$LLVM_CONFIG --libs` + dnl Use -isystem so we don't get warnings from llvm headers + LLVM_CPPFLAGS="$LLVM_INCLUDE_FLAG `$LLVM_CONFIG --includedir`" + LLVM_CXXFLAGS= + + save_CPPFLAGS="$CPPFLAGS" + save_CXXFLAGS="$CXXFLAGS" + save_LIBS="$LIBS" + save_LDFLAGS="$LDFLAGS" + + dnl + dnl We define some extra flags that llvm requires in order to include headers. + dnl Idealy we should get these from llvm-config, but llvm-config isn't very + dnl helpful. + dnl + CPPFLAGS="-D__STDC_CONSTANT_MACROS -D__STDC_LIMIT_MACROS $LLVM_CPPFLAGS $CPPFLAGS" + CXXFLAGS="$LLVM_CXXFLAGS $CXXFLAGS" + LIBS="$LLVM_LIBS $LIBS" + LDFLAGS="$LLVM_LDFLAGS $LDFLAGS" + + AC_LANG_PUSH(C++) + AC_CHECK_HEADER([llvm/LLVMContext.h], [ + AC_MSG_CHECKING([for llvm::getGlobalContext in llvm/LLVMContext.h]) + AC_COMPILE_IFELSE( + [AC_LANG_PROGRAM([[#include ]], + [[llvm::LLVMContext& ctx = llvm::getGlobalContext ();]])], + [ + AC_MSG_RESULT([yes]) + warn_llvm= + XTRA_CXXFLAGS="$XTRA_CXXFLAGS $LLVM_CXXFLAGS $LLVM_CPPFLAGS" + ], + [AC_MSG_RESULT([no]) + ]) + ]) + AC_LANG_POP(C++) +else + warn_llvm="LLVM_CONFIG not set. JIT compilation will be disabled." 
+fi + +CPPFLAGS="$save_CPPFLAGS" +CXXFLAGS="$save_CXXFLAGS" +LIBS="$save_LIBS" +LDFLAGS="$save_LDFLAGS" + +if test -z "$warn_llvm"; then + AC_DEFINE(HAVE_LLVM, 1, [Define if LLVM is available]) +else + LLVM_CXXFLAGS= + LLVM_CPPFLAGS= + LLVM_LDFLAGS= + LLVM_LIBS= + OCTAVE_CONFIGURE_WARNING([warn_llvm]) +fi + +AC_SUBST(LLVM_CXXFLAGS) +AC_SUBST(LLVM_CPPFLAGS) +AC_SUBST(LLVM_LDFLAGS) +AC_SUBST(LLVM_LIBS) + ### Check for HDF5 library. save_CPPFLAGS="$CPPFLAGS" @@ -2245,6 +2334,9 @@ Magick++ CPPFLAGS: $MAGICK_CPPFLAGS Magick++ LDFLAGS: $MAGICK_LDFLAGS Magick++ libraries: $MAGICK_LIBS + LLVM CPPFLAGS: $LLVM_CPPFLAGS + LLVM LDFLAGS: $LLVM_LDFLAGS + LLVM Libraries: $LLVM_LIBS HDF5 CPPFLAGS: $HDF5_CPPFLAGS HDF5 LDFLAGS: $HDF5_LDFLAGS HDF5 libraries: $HDF5_LIBS diff --git a/liboctave/Array.h b/liboctave/Array.h --- a/liboctave/Array.h +++ b/liboctave/Array.h @@ -164,6 +164,14 @@ return &nr; } +protected: + + // For jit support + Array (T *sdata, octave_idx_type slen, octave_idx_type *adims, void *arep) + : dimensions (adims), + rep (reinterpret_cast::ArrayRep *> (arep)), + slice_data (sdata), slice_len (slen) {} + public: // Empty ctor (0x0). @@ -693,6 +701,16 @@ // supposedly equal dimensions (e.g. structs in the interpreter). 
bool optimize_dimensions (const dim_vector& dv); + // WARNING: Only call these functions from jit + + int *jit_ref_count (void) { return rep->count.get (); } + + T *jit_slice_data (void) const { return slice_data; } + + octave_idx_type *jit_dimensions (void) const { return dimensions.to_jit (); } + + void *jit_array_rep (void) const { return rep; } + private: void resize2 (octave_idx_type nr, octave_idx_type nc, const T& rfv); diff --git a/liboctave/MArray.h b/liboctave/MArray.h --- a/liboctave/MArray.h +++ b/liboctave/MArray.h @@ -39,6 +39,12 @@ class MArray : public Array { +protected: + + // For jit support + MArray (T *sdata, octave_idx_type slen, octave_idx_type *adims, void *arep) + : Array (sdata, slen, adims, arep) { } + public: MArray (void) : Array () {} diff --git a/liboctave/dNDArray.h b/liboctave/dNDArray.h --- a/liboctave/dNDArray.h +++ b/liboctave/dNDArray.h @@ -64,6 +64,10 @@ NDArray (const charNDArray&); + // For jit support only + NDArray (double *sdata, octave_idx_type slen, octave_idx_type *adims, void *arep) + : MArray (sdata, slen, adims, arep) { } + NDArray& operator = (const NDArray& a) { MArray::operator = (a); diff --git a/liboctave/dim-vector.h b/liboctave/dim-vector.h --- a/liboctave/dim-vector.h +++ b/liboctave/dim-vector.h @@ -212,6 +212,12 @@ void chop_all_singletons (void); + // WARNING: Only call by jit + octave_idx_type *to_jit (void) const + { + return rep; + } + private: static octave_idx_type *nil_rep (void) @@ -220,9 +226,6 @@ return zv.rep; } - explicit dim_vector (octave_idx_type *r) - : rep (r) { } - public: static octave_idx_type dim_max (void); @@ -233,6 +236,10 @@ dim_vector (const dim_vector& dv) : rep (dv.rep) { OCTREFCOUNT_ATOMIC_INCREMENT (&(count())); } + // FIXME: Should be private, but required by array constructor for jit + explicit dim_vector (octave_idx_type *r) + : rep (r) { } + static dim_vector alloc (int n) { return dim_vector (newrep (n < 2 ? 
2 : n)); diff --git a/liboctave/lo-macros.h b/liboctave/lo-macros.h --- a/liboctave/lo-macros.h +++ b/liboctave/lo-macros.h @@ -92,4 +92,8 @@ #define OCT_MAKE_DECL_LIST(TYPE, PREFIX, NUM) \ OCT_ITERATE_PARAM_MACRO(OCT_MAKE_DECL_LIST_HELPER, TYPE PREFIX, NUM) +// expands to PREFIX0, PREFIX1, ..., PREFIX ## (NUM-1) +#define OCT_MAKE_ARG_LIST(PREFIX, NUM) \ + OCT_ITERATE_PARAM_MACRO(OCT_MAKE_DECL_LIST_HELPER, PREFIX, NUM) + #endif diff --git a/liboctave/oct-refcount.h b/liboctave/oct-refcount.h --- a/liboctave/oct-refcount.h +++ b/liboctave/oct-refcount.h @@ -82,6 +82,11 @@ return static_cast (count); } + count_type *get (void) + { + return &count; + } + private: count_type count; }; diff --git a/m4/acinclude.m4 b/m4/acinclude.m4 --- a/m4/acinclude.m4 +++ b/m4/acinclude.m4 @@ -286,7 +286,7 @@ dnl dnl OCTAVE_CC_FLAG AC_DEFUN([OCTAVE_CC_FLAG], [ - ac_safe=`echo "$1" | sed 'y%./+-:=%__p___%'` + ac_safe=`echo "$1" | sed 'y% ./+-:=%___p___%'` AC_MSG_CHECKING([whether ${CC-cc} accepts $1]) AC_CACHE_VAL(octave_cv_cc_flag_$ac_safe, [ AC_LANG_PUSH(C) diff --git a/src/Makefile.am b/src/Makefile.am --- a/src/Makefile.am +++ b/src/Makefile.am @@ -220,6 +220,7 @@ pt-fcn-handle.h \ pt-id.h \ pt-idx.h \ + pt-jit.h \ pt-jump.h \ pt-loop.h \ pt-mat.h \ @@ -392,6 +393,7 @@ pt-fcn-handle.cc \ pt-id.cc \ pt-idx.cc \ + pt-jit.cc \ pt-jump.cc \ pt-loop.cc \ pt-mat.cc \ diff --git a/src/TEMPLATE-INST/Array-jit.cc b/src/TEMPLATE-INST/Array-jit.cc new file mode 100644 --- /dev/null +++ b/src/TEMPLATE-INST/Array-jit.cc @@ -0,0 +1,40 @@ +/* + +Copyright (C) 2012 Max Brister + +This file is part of Octave. + +Octave is free software; you can redistribute it and/or modify it +under the terms of the GNU General Public License as published by the +Free Software Foundation; either version 3 of the License, or (at your +option) any later version. 
+ +Octave is distributed in the hope that it will be useful, but WITHOUT +ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License +for more details. + +You should have received a copy of the GNU General Public License +along with Octave; see the file COPYING. If not, see +. + +*/ + +#ifdef HAVE_CONFIG_H +#include +#endif + +#ifdef HAVE_LLVM + +#include "Array.h" +#include "Array.cc" + +extern template class OCTAVE_API Array; + +#include "pt-jit.h" + +NO_INSTANTIATE_ARRAY_SORT (jit_operation::overload); + +INSTANTIATE_ARRAY (jit_operation::overload, OCTINTERP_API); + +#endif diff --git a/src/TEMPLATE-INST/module.mk b/src/TEMPLATE-INST/module.mk --- a/src/TEMPLATE-INST/module.mk +++ b/src/TEMPLATE-INST/module.mk @@ -2,4 +2,5 @@ TEMPLATE_INST_SRC = \ TEMPLATE-INST/Array-os.cc \ - TEMPLATE-INST/Array-tc.cc + TEMPLATE-INST/Array-tc.cc \ + TEMPLATE-INST/Array-jit.cc diff --git a/src/link-deps.mk b/src/link-deps.mk --- a/src/link-deps.mk +++ b/src/link-deps.mk @@ -13,14 +13,16 @@ $(Z_LIBS) \ $(OPENGL_LIBS) \ $(X11_LIBS) \ - $(CARBON_LIBS) + $(CARBON_LIBS) \ + $(LLVM_LIBS) LIBOCTINTERP_LINK_OPTS = \ $(GRAPHICS_LDFLAGS) \ $(FT2_LDFLAGS) \ $(HDF5_LDFLAGS) \ $(Z_LDFLAGS) \ - $(REGEX_LDFLAGS) + $(REGEX_LDFLAGS) \ + $(LLVM_LDFLAGS) OCT_LINK_DEPS = diff --git a/src/oct-conf.in.h b/src/oct-conf.in.h --- a/src/oct-conf.in.h +++ b/src/oct-conf.in.h @@ -384,6 +384,18 @@ #define OCTAVE_CONF_MAGICK_LIBS %OCTAVE_CONF_MAGICK_LIBS% #endif +#ifndef OCTAVE_CONF_LLVM_CPPFLAGS +#define OCTAVE_CONF_LLVM_CPPFLAGS %OCTAVE_CONF_LLVM_CPPFLAGS% +#endif + +#ifndef OCTAVE_CONF_LLVM_LDFLAGS +#define OCTAVE_CONF_LLVM_LDFLAGS %OCTAVE_CONF_LLVM_LDFLAGS% +#endif + +#ifndef OCTAVE_CONF_LLVM_LIBS +#define OCTAVE_CONF_LLVM_LIBS %OCTAVE_CONF_LLVM_LIBS% +#endif + #ifndef OCTAVE_CONF_MKOCTFILE_DL_LDFLAGS #define OCTAVE_CONF_MKOCTFILE_DL_LDFLAGS %OCTAVE_CONF_MKOCTFILE_DL_LDFLAGS% #endif diff --git a/src/ov-base.h b/src/ov-base.h 
--- a/src/ov-base.h +++ b/src/ov-base.h @@ -756,6 +756,21 @@ virtual bool fast_elem_insert_self (void *where, builtin_type_t btyp) const; + // Grab the reference count. For use by jit. + void + grab (void) + { + ++count; + } + + // Release the reference count. For use by jit. + void + release (void) + { + if (--count == 0) + delete this; + } + protected: // This should only be called for derived types. diff --git a/src/ov-builtin.cc b/src/ov-builtin.cc --- a/src/ov-builtin.cc +++ b/src/ov-builtin.cc @@ -152,5 +152,22 @@ return retval; } +jit_type * +octave_builtin::to_jit (void) const +{ + return jtype; +} + +void +octave_builtin::stash_jit (jit_type &type) +{ + jtype = &type; +} + +octave_builtin::fcn +octave_builtin::function (void) const +{ + return f; +} const std::list *octave_builtin::curr_lvalue_list = 0; diff --git a/src/ov-builtin.h b/src/ov-builtin.h --- a/src/ov-builtin.h +++ b/src/ov-builtin.h @@ -30,6 +30,7 @@ class octave_value; class octave_value_list; +class jit_type; // Builtin functions. @@ -39,13 +40,13 @@ { public: - octave_builtin (void) : octave_function (), f (0) { } + octave_builtin (void) : octave_function (), f (0), jtype (0) { } typedef octave_value_list (*fcn) (const octave_value_list&, int); octave_builtin (fcn ff, const std::string& nm = std::string (), const std::string& ds = std::string ()) - : octave_function (nm, ds), f (ff) { } + : octave_function (nm, ds), f (ff), jtype (0) { } ~octave_builtin (void) { } @@ -75,6 +76,12 @@ do_multi_index_op (int nargout, const octave_value_list& args, const std::list* lvalue_list); + jit_type *to_jit (void) const; + + void stash_jit (jit_type& type); + + fcn function (void) const; + static const std::list *curr_lvalue_list; protected: @@ -82,6 +89,9 @@ // A pointer to the actual function. fcn f; + // A pointer to the jit type that represents the function. + jit_type *jtype; + private: // No copying! 
diff --git a/src/pt-eval.cc b/src/pt-eval.cc --- a/src/pt-eval.cc +++ b/src/pt-eval.cc @@ -44,6 +44,12 @@ #include "symtab.h" #include "unwind-prot.h" +#if HAVE_LLVM +//FIXME: This should be part of tree_evaluator +#include "pt-jit.h" +static tree_jit jiter; +#endif + static tree_evaluator std_evaluator; tree_evaluator *current_evaluator = &std_evaluator; @@ -290,6 +296,11 @@ if (debug_mode) do_breakpoint (cmd.is_breakpoint ()); +#if HAVE_LLVM + if (jiter.execute (cmd)) + return; +#endif + // FIXME -- need to handle PARFOR loops here using cmd.in_parallel () // and cmd.maxproc_expr (); diff --git a/src/pt-id.cc b/src/pt-id.cc --- a/src/pt-id.cc +++ b/src/pt-id.cc @@ -65,7 +65,7 @@ if (error_state) return retval; - octave_value val = xsym ().find (); + octave_value val = sym->find (); if (val.is_defined ()) { @@ -116,7 +116,7 @@ octave_lvalue tree_identifier::lvalue (void) { - return octave_lvalue (&(xsym ().varref ())); + return octave_lvalue (&(sym->varref ())); } tree_identifier * diff --git a/src/pt-id.h b/src/pt-id.h --- a/src/pt-id.h +++ b/src/pt-id.h @@ -46,12 +46,12 @@ public: tree_identifier (int l = -1, int c = -1) - : tree_expression (l, c), sym (), scope (-1) { } + : tree_expression (l, c) { } tree_identifier (const symbol_table::symbol_record& s, int l = -1, int c = -1, symbol_table::scope_id sc = symbol_table::current_scope ()) - : tree_expression (l, c), sym (s), scope (sc) { } + : tree_expression (l, c), sym (s, sc) { } ~tree_identifier (void) { } @@ -63,9 +63,9 @@ // accessing it through sym so that this function may remain const. 
std::string name (void) const { return sym.name (); } - bool is_defined (void) { return xsym ().is_defined (); } + bool is_defined (void) { return sym->is_defined (); } - virtual bool is_variable (void) { return xsym ().is_variable (); } + virtual bool is_variable (void) { return sym->is_variable (); } virtual bool is_black_hole (void) { return false; } @@ -87,14 +87,14 @@ octave_value do_lookup (const octave_value_list& args = octave_value_list ()) { - return xsym ().find (args); + return sym->find (args); } - void mark_global (void) { xsym ().mark_global (); } + void mark_global (void) { sym->mark_global (); } - void mark_as_static (void) { xsym ().init_persistent (); } + void mark_as_static (void) { sym->init_persistent (); } - void mark_as_formal_parameter (void) { xsym ().mark_formal (); } + void mark_as_formal_parameter (void) { sym->mark_formal (); } // We really need to know whether this symbol referst to a variable // or a function, but we may not know that yet. @@ -114,28 +114,14 @@ void accept (tree_walker& tw); + symbol_table::symbol_reference symbol (void) const + { + return sym; + } private: // The symbol record that this identifier references. - symbol_table::symbol_record sym; - - symbol_table::scope_id scope; - - // A script may be executed in multiple scopes. If the last one was - // different from the one we are in now, update sym to be from the - // new scope. - symbol_table::symbol_record& xsym (void) - { - symbol_table::scope_id curr_scope = symbol_table::current_scope (); - - if (scope != curr_scope || ! sym.is_valid ()) - { - scope = curr_scope; - sym = symbol_table::insert (sym.name ()); - } - - return sym; - } + symbol_table::symbol_reference sym; // No copying! diff --git a/src/pt-jit.cc b/src/pt-jit.cc new file mode 100644 --- /dev/null +++ b/src/pt-jit.cc @@ -0,0 +1,3888 @@ +/* + +Copyright (C) 2012 Max Brister + +This file is part of Octave. 
+ +Octave is free software; you can redistribute it and/or modify it +under the terms of the GNU General Public License as published by the +Free Software Foundation; either version 3 of the License, or (at your +option) any later version. + +Octave is distributed in the hope that it will be useful, but WITHOUT +ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License +for more details. + +You should have received a copy of the GNU General Public License +along with Octave; see the file COPYING. If not, see +. + +*/ + +#define __STDC_LIMIT_MACROS +#define __STDC_CONSTANT_MACROS + +#ifdef HAVE_CONFIG_H +#include +#endif + +#ifdef HAVE_LLVM + +#include "pt-jit.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "octave.h" +#include "ov-fcn-handle.h" +#include "ov-usr-fcn.h" +#include "ov-builtin.h" +#include "ov-scalar.h" +#include "ov-complex.h" +#include "pt-all.h" +#include "symtab.h" + +static llvm::IRBuilder<> builder (llvm::getGlobalContext ()); + +static llvm::LLVMContext& context = llvm::getGlobalContext (); + +jit_typeinfo *jit_typeinfo::instance; + +// thrown when we should give up on JIT and interpret +class jit_fail_exception : public std::runtime_error +{ +public: + jit_fail_exception (void) : std::runtime_error ("unknown"), mknown (false) {} + jit_fail_exception (const std::string& reason) : std::runtime_error (reason), + mknown (true) + {} + + bool known (void) const { return mknown; } +private: + bool mknown; +}; + +static void fail (void) GCC_ATTR_NORETURN; +static void fail (const std::string&) GCC_ATTR_NORETURN; + +static void +fail (void) +{ + throw jit_fail_exception (); +} + +#ifdef OCTAVE_JIT_DEBUG +static void +fail (const std::string& reason) +{ + throw jit_fail_exception (reason); +} +#else 
+static void +fail (const std::string&) +{ + throw jit_fail_exception (); +} +#endif // OCTAVE_JIT_DEBUG + +std::ostream& jit_print (std::ostream& os, jit_type *atype) +{ + if (! atype) + return os << "null"; + return os << atype->name (); +} + +// function that jit code calls +extern "C" void +octave_jit_print_any (const char *name, octave_base_value *obv) +{ + obv->print_with_name (octave_stdout, name, true); +} + +extern "C" void +octave_jit_print_double (const char *name, double value) +{ + // FIXME: We should avoid allocating a new octave_scalar each time + octave_value ov (value); + ov.print_with_name (octave_stdout, name); +} + +extern "C" octave_base_value* +octave_jit_binary_any_any (octave_value::binary_op op, octave_base_value *lhs, + octave_base_value *rhs) +{ + octave_value olhs (lhs, true); + octave_value orhs (rhs, true); + octave_value result = do_binary_op (op, olhs, orhs); + octave_base_value *rep = result.internal_rep (); + rep->grab (); + return rep; +} + +extern "C" octave_idx_type +octave_jit_compute_nelem (double base, double limit, double inc) +{ + Range rng = Range (base, limit, inc); + return rng.nelem (); +} + +extern "C" void +octave_jit_release_any (octave_base_value *obv) +{ + obv->release (); +} + +extern "C" void +octave_jit_release_matrix (jit_matrix *m) +{ + delete m->array; +} + +extern "C" octave_base_value * +octave_jit_grab_any (octave_base_value *obv) +{ + obv->grab (); + return obv; +} + +extern "C" void +octave_jit_grab_matrix (jit_matrix *result, jit_matrix *m) +{ + *result = *m->array; +} + +extern "C" octave_base_value * +octave_jit_cast_any_matrix (jit_matrix *m) +{ + octave_value ret (*m->array); + octave_base_value *rep = ret.internal_rep (); + rep->grab (); + delete m->array; + + return rep; +} + +extern "C" void +octave_jit_cast_matrix_any (jit_matrix *ret, octave_base_value *obv) +{ + NDArray m = obv->array_value (); + *ret = m; + obv->release (); +} + +extern "C" double +octave_jit_cast_scalar_any 
(octave_base_value *obv) +{ + double ret = obv->double_value (); + obv->release (); + return ret; +} + +extern "C" octave_base_value * +octave_jit_cast_any_scalar (double value) +{ + return new octave_scalar (value); +} + +extern "C" Complex +octave_jit_cast_complex_any (octave_base_value *obv) +{ + Complex ret = obv->complex_value (); + obv->release (); + return ret; +} + +extern "C" octave_base_value * +octave_jit_cast_any_complex (Complex c) +{ + if (c.imag () == 0) + return new octave_scalar (c.real ()); + else + return new octave_complex (c); +} + +extern "C" void +octave_jit_gripe_nan_to_logical_conversion (void) +{ + try + { + gripe_nan_to_logical_conversion (); + } + catch (const octave_execution_exception&) + { + gripe_library_execution_error (); + } +} + +extern "C" void +octave_jit_ginvalid_index (void) +{ + try + { + gripe_invalid_index (); + } + catch (const octave_execution_exception&) + { + gripe_library_execution_error (); + } +} + +extern "C" void +octave_jit_gindex_range (int nd, int dim, octave_idx_type iext, + octave_idx_type ext) +{ + try + { + gripe_index_out_of_range (nd, dim, iext, ext); + } + catch (const octave_execution_exception&) + { + gripe_library_execution_error (); + } +} + +extern "C" void +octave_jit_paren_subsasgn_impl (jit_matrix *mat, octave_idx_type index, + double value) +{ + NDArray *array = mat->array; + if (array->nelem () < index) + array->resize1 (index); + + double *data = array->fortran_vec (); + data[index - 1] = value; + + mat->update (); +} + +extern "C" void +octave_jit_paren_subsasgn_matrix_range (jit_matrix *result, jit_matrix *mat, + jit_range *index, double value) +{ + NDArray *array = mat->array; + bool done = false; + + // optimize for the simple case (no resizing and no errors) + if (*array->jit_ref_count () == 1 + && index->all_elements_are_ints ()) + { + // this code is similar to idx_vector::fill, but we avoid allocating an + // idx_vector and its associated rep + octave_idx_type start = static_cast 
(index->base) - 1; + octave_idx_type step = static_cast (index->inc); + octave_idx_type nelem = index->nelem; + octave_idx_type final = start + nelem * step; + if (step < 0) + { + step = -step; + std::swap (final, start); + } + + if (start >= 0 && final < mat->slice_len) + { + done = true; + + double *data = array->jit_slice_data (); + if (step == 1) + std::fill (data + start, data + start + nelem, value); + else + { + for (octave_idx_type i = start; i < final; i += step) + data[i] = value; + } + } + } + + if (! done) + { + idx_vector idx (*index); + NDArray avalue (dim_vector (1, 1)); + avalue.xelem (0) = value; + array->assign (idx, avalue); + } + + result->update (array); +} + +extern "C" Complex +octave_jit_complex_div (Complex lhs, Complex rhs) +{ + // see src/OPERATORS/op-cs-cs.cc + if (rhs == 0.0) + gripe_divide_by_zero (); + + return lhs / rhs; +} + +// FIXME: CP form src/xpow.cc +static inline int +xisint (double x) +{ + return (D_NINT (x) == x + && ((x >= 0 && x < INT_MAX) + || (x <= 0 && x > INT_MIN))); +} + +extern "C" Complex +octave_jit_pow_scalar_scalar (double lhs, double rhs) +{ + // FIXME: almost CP from src/xpow.cc + if (lhs < 0.0 && ! 
xisint (rhs)) + return std::pow (Complex (lhs), rhs); + return std::pow (lhs, rhs); +} + +extern "C" Complex +octave_jit_pow_complex_complex (Complex lhs, Complex rhs) +{ + if (lhs.imag () == 0 && rhs.imag () == 0) + return octave_jit_pow_scalar_scalar (lhs.real (), rhs.real ()); + return std::pow (lhs, rhs); +} + +extern "C" Complex +octave_jit_pow_complex_scalar (Complex lhs, double rhs) +{ + if (lhs.imag () == 0) + return octave_jit_pow_scalar_scalar (lhs.real (), rhs); + return std::pow (lhs, rhs); +} + +extern "C" Complex +octave_jit_pow_scalar_complex (double lhs, Complex rhs) +{ + if (rhs.imag () == 0) + return octave_jit_pow_scalar_scalar (lhs, rhs.real ()); + return std::pow (lhs, rhs); +} + +extern "C" void +octave_jit_print_matrix (jit_matrix *m) +{ + std::cout << *m << std::endl; +} + +static void +gripe_bad_result (void) +{ + error ("incorrect type information given to the JIT compiler"); +} + +// FIXME: Add support for multiple outputs +extern "C" octave_base_value * +octave_jit_call (octave_builtin::fcn fn, size_t nargin, + octave_base_value **argin, jit_type *result_type) +{ + octave_value_list ovl (nargin); + for (size_t i = 0; i < nargin; ++i) + ovl.xelem (i) = octave_value (argin[i]); + + ovl = fn (ovl, 1); + + // These type checks are not strictly required, but I'm guessing that + // incorrect types will be entered on occasion. This will be very difficult to + // debug unless we do the sanity check here. + if (result_type) + { + if (ovl.length () != 1) + { + gripe_bad_result (); + return 0; + } + + octave_value& result = ovl.xelem (0); + jit_type *jtype = jit_typeinfo::join (jit_typeinfo::type_of (result), + result_type); + if (jtype != result_type) + { + gripe_bad_result (); + return 0; + } + + octave_base_value *ret = result.internal_rep (); + ret->grab (); + return ret; + } + + if (! 
(ovl.length () == 0 + || (ovl.length () == 1 && ovl.xelem (0).is_undefined ()))) + gripe_bad_result (); + + return 0; +} + +// -------------------- jit_range -------------------- +bool +jit_range::all_elements_are_ints () const +{ + Range r (*this); + return r.all_elements_are_ints (); +} + +std::ostream& +operator<< (std::ostream& os, const jit_range& rng) +{ + return os << "Range[" << rng.base << ", " << rng.limit << ", " << rng.inc + << ", " << rng.nelem << "]"; +} + +// -------------------- jit_matrix -------------------- + +std::ostream& +operator<< (std::ostream& os, const jit_matrix& mat) +{ + return os << "Matrix[" << mat.ref_count << ", " << mat.slice_data << ", " + << mat.slice_len << ", " << mat.dimensions << ", " + << mat.array << "]"; +} + +// -------------------- jit_type -------------------- +llvm::Type * +jit_type::to_llvm_arg (void) const +{ + return llvm_type ? llvm_type->getPointerTo () : 0; +} + +// -------------------- jit_operation -------------------- +void +jit_operation::add_overload (const overload& func, + const std::vector& args) +{ + if (args.size () >= overloads.size ()) + overloads.resize (args.size () + 1); + + Array& over = overloads[args.size ()]; + dim_vector dv (over.dims ()); + Array idx = to_idx (args); + bool must_resize = false; + + if (dv.length () != idx.numel ()) + { + dv.resize (idx.numel ()); + must_resize = true; + } + + for (octave_idx_type i = 0; i < dv.length (); ++i) + if (dv(i) <= idx(i)) + { + must_resize = true; + dv(i) = idx(i) + 1; + } + + if (must_resize) + over.resize (dv); + + over(idx) = func; +} + +const jit_operation::overload& +jit_operation::get_overload (const std::vector& types) const +{ + // FIXME: We should search for the next best overload on failure + static overload null_overload; + if (types.size () >= overloads.size ()) + return null_overload; + + for (size_t i =0; i < types.size (); ++i) + if (! 
types[i]) + return null_overload; + + const Array& over = overloads[types.size ()]; + dim_vector dv (over.dims ()); + Array idx = to_idx (types); + for (octave_idx_type i = 0; i < dv.length (); ++i) + if (idx(i) >= dv(i)) + return null_overload; + + return over(idx); +} + +Array +jit_operation::to_idx (const std::vector& types) const +{ + octave_idx_type numel = types.size (); + if (numel == 1) + numel = 2; + + Array idx (dim_vector (1, numel)); + for (octave_idx_type i = 0; i < static_cast (types.size ()); + ++i) + idx(i) = types[i]->type_id (); + + if (types.size () == 1) + { + idx(1) = idx(0); + idx(0) = 0; + } + + return idx; +} + +// -------------------- jit_typeinfo -------------------- +void +jit_typeinfo::initialize (llvm::Module *m, llvm::ExecutionEngine *e) +{ + instance = new jit_typeinfo (m, e); +} + +jit_typeinfo::jit_typeinfo (llvm::Module *m, llvm::ExecutionEngine *e) + : module (m), engine (e), next_id (0) +{ + // FIXME: We should be registering types like in octave_value_typeinfo + llvm::Type *any_t = llvm::StructType::create (context, "octave_base_value"); + any_t = any_t->getPointerTo (); + + llvm::Type *scalar_t = llvm::Type::getDoubleTy (context); + llvm::Type *bool_t = llvm::Type::getInt1Ty (context); + llvm::Type *string_t = llvm::Type::getInt8Ty (context); + string_t = string_t->getPointerTo (); + llvm::Type *index_t = llvm::Type::getIntNTy (context, + sizeof(octave_idx_type) * 8); + + llvm::StructType *range_t = llvm::StructType::create (context, "range"); + std::vector range_contents (4, scalar_t); + range_contents[3] = index_t; + range_t->setBody (range_contents); + + llvm::Type *refcount_t = llvm::Type::getIntNTy (context, sizeof(int) * 8); + llvm::Type *int_t = refcount_t; + + llvm::StructType *matrix_t = llvm::StructType::create (context, "matrix"); + llvm::Type *matrix_contents[5]; + matrix_contents[0] = refcount_t->getPointerTo (); + matrix_contents[1] = scalar_t->getPointerTo (); + matrix_contents[2] = index_t; + matrix_contents[3] 
= index_t->getPointerTo (); + matrix_contents[4] = string_t; + matrix_t->setBody (llvm::makeArrayRef (matrix_contents, 5)); + + llvm::Type *complex_t = llvm::VectorType::get (scalar_t, 2); + + // this is the structure that C functions return. Use this in order to get calling + // conventions right. + complex_ret = llvm::StructType::create (context, "complex_ret"); + llvm::Type *complex_ret_contents[] = {scalar_t, scalar_t}; + complex_ret->setBody (complex_ret_contents); + + // create types + any = new_type ("any", 0, any_t); + matrix = new_type ("matrix", any, matrix_t); + complex = new_type ("complex", any, complex_t); + scalar = new_type ("scalar", complex, scalar_t); + range = new_type ("range", any, range_t); + string = new_type ("string", any, string_t); + boolean = new_type ("bool", any, bool_t); + index = new_type ("index", any, index_t); + + casts.resize (next_id + 1); + identities.resize (next_id + 1, 0); + + // bind global variables + lerror_state = new llvm::GlobalVariable (*module, bool_t, false, + llvm::GlobalValue::ExternalLinkage, + 0, "error_state"); + engine->addGlobalMapping (lerror_state, + reinterpret_cast (&error_state)); + + // any with anything is an any op + llvm::Function *fn; + llvm::Type *binary_op_type + = llvm::Type::getIntNTy (context, sizeof (octave_value::binary_op)); + llvm::Function *any_binary = create_function ("octave_jit_binary_any_any", + any_t, binary_op_type, + any_t, any_t); + engine->addGlobalMapping (any_binary, + reinterpret_cast(&octave_jit_binary_any_any)); + + binary_ops.resize (octave_value::num_binary_ops); + for (size_t i = 0; i < octave_value::num_binary_ops; ++i) + { + octave_value::binary_op op = static_cast (i); + std::string op_name = octave_value::binary_op_as_string (op); + binary_ops[i].stash_name ("binary" + op_name); + } + + for (int op = 0; op < octave_value::num_binary_ops; ++op) + { + llvm::Twine fn_name ("octave_jit_binary_any_any_"); + fn_name = fn_name + llvm::Twine (op); + fn = create_function 
(fn_name, any, any, any); + llvm::BasicBlock *block = llvm::BasicBlock::Create (context, "body", fn); + builder.SetInsertPoint (block); + llvm::APInt op_int(sizeof (octave_value::binary_op), op, + std::numeric_limits::is_signed); + llvm::Value *op_as_llvm = llvm::ConstantInt::get (binary_op_type, op_int); + llvm::Value *ret = builder.CreateCall3 (any_binary, + op_as_llvm, + fn->arg_begin (), + ++fn->arg_begin ()); + builder.CreateRet (ret); + binary_ops[op].add_overload (fn, true, any, any, any); + } + + llvm::Type *void_t = llvm::Type::getVoidTy (context); + + // grab any + fn = create_function ("octave_jit_grab_any", any, any); + engine->addGlobalMapping (fn, reinterpret_cast(&octave_jit_grab_any)); + grab_fn.add_overload (fn, false, any, any); + grab_fn.stash_name ("grab"); + + // grab matrix + llvm::Function *print_matrix = create_function ("octave_jit_print_matrix", + void_t, + matrix_t->getPointerTo ()); + engine->addGlobalMapping (print_matrix, + reinterpret_cast(&octave_jit_print_matrix)); + fn = create_function ("octave_jit_grab_matrix", void_t, + matrix_t->getPointerTo (), matrix_t->getPointerTo ()); + engine->addGlobalMapping (fn, + reinterpret_cast (&octave_jit_grab_matrix)); + grab_fn.add_overload (fn, false, matrix, matrix); + + // release any + fn = create_function ("octave_jit_release_any", void_t, any_t); + llvm::Function *release_any = fn; + engine->addGlobalMapping (fn, + reinterpret_cast(&octave_jit_release_any)); + release_fn.add_overload (fn, false, 0, any); + release_fn.stash_name ("release"); + + // release matrix + fn = create_function ("octave_jit_release_matrix", void_t, + matrix_t->getPointerTo ()); + engine->addGlobalMapping (fn, + reinterpret_cast (&octave_jit_release_matrix)); + release_fn.add_overload (fn, false, 0, matrix); + + // release scalar + fn = create_identity (scalar); + release_fn.add_overload (fn, false, 0, scalar); + + // release complex + fn = create_identity (complex); + release_fn.add_overload (fn, false, 0, complex); 
+ + // release index + fn = create_identity (index); + release_fn.add_overload (fn, false, 0, index); + + // now for binary scalar operations + // FIXME: Finish all operations + add_binary_op (scalar, octave_value::op_add, llvm::Instruction::FAdd); + add_binary_op (scalar, octave_value::op_sub, llvm::Instruction::FSub); + add_binary_op (scalar, octave_value::op_mul, llvm::Instruction::FMul); + add_binary_op (scalar, octave_value::op_el_mul, llvm::Instruction::FMul); + + add_binary_fcmp (scalar, octave_value::op_lt, llvm::CmpInst::FCMP_ULT); + add_binary_fcmp (scalar, octave_value::op_le, llvm::CmpInst::FCMP_ULE); + add_binary_fcmp (scalar, octave_value::op_eq, llvm::CmpInst::FCMP_UEQ); + add_binary_fcmp (scalar, octave_value::op_ge, llvm::CmpInst::FCMP_UGE); + add_binary_fcmp (scalar, octave_value::op_gt, llvm::CmpInst::FCMP_UGT); + add_binary_fcmp (scalar, octave_value::op_ne, llvm::CmpInst::FCMP_UNE); + + llvm::Function *gripe_div0 = create_function ("gripe_divide_by_zero", void_t); + engine->addGlobalMapping (gripe_div0, + reinterpret_cast (&gripe_divide_by_zero)); + + // divide is annoying because it might error + fn = create_function ("octave_jit_div_scalar_scalar", scalar, scalar, scalar); + llvm::BasicBlock *body = llvm::BasicBlock::Create (context, "body", fn); + builder.SetInsertPoint (body); + { + llvm::BasicBlock *warn_block = llvm::BasicBlock::Create (context, "warn", + fn); + llvm::BasicBlock *normal_block = llvm::BasicBlock::Create (context, + "normal", fn); + + llvm::Value *zero = llvm::ConstantFP::get (scalar_t, 0); + llvm::Value *check = builder.CreateFCmpUEQ (zero, ++fn->arg_begin ()); + builder.CreateCondBr (check, warn_block, normal_block); + + builder.SetInsertPoint (warn_block); + builder.CreateCall (gripe_div0); + builder.CreateBr (normal_block); + + builder.SetInsertPoint (normal_block); + llvm::Value *ret = builder.CreateFDiv (fn->arg_begin (), + ++fn->arg_begin ()); + builder.CreateRet (ret); + + jit_operation::overload ol (fn, true, 
scalar, scalar, scalar); + binary_ops[octave_value::op_div].add_overload (ol); + binary_ops[octave_value::op_el_div].add_overload (ol); + } + llvm::verifyFunction (*fn); + + // ldiv is the same as div with the operators reversed + fn = mirror_binary (fn); + { + jit_operation::overload ol (fn, true, scalar, scalar, scalar); + binary_ops[octave_value::op_ldiv].add_overload (ol); + binary_ops[octave_value::op_el_ldiv].add_overload (ol); + } + + // In general, the result of scalar ^ scalar is a complex number. We might be + // able to improve on this if we keep track of the range of values varaibles + // can take on. + fn = create_function ("octave_jit_pow_scalar_scalar", complex_ret, scalar_t, + scalar_t); + engine->addGlobalMapping (fn, reinterpret_cast (octave_jit_pow_scalar_scalar)); + { + jit_operation::overload ol (wrap_complex (fn), false, complex, scalar, + scalar); + binary_ops[octave_value::op_pow].add_overload (ol); + binary_ops[octave_value::op_el_pow].add_overload (ol); + } + + // now for binary complex operations + add_binary_op (complex, octave_value::op_add, llvm::Instruction::FAdd); + add_binary_op (complex, octave_value::op_sub, llvm::Instruction::FSub); + + fn = create_function ("octave_jit_*_complex_complex", complex, complex, + complex); + body = llvm::BasicBlock::Create (context, "body", fn); + builder.SetInsertPoint (body); + { + // (x0*x1 - y0*y1, x0*y1 + y0*x1) = (x0,y0) * (x1,y1) + // We compute this in one vectorized multiplication, a subtraction, and an + // addition. + llvm::Value *lhs = fn->arg_begin (); + llvm::Value *rhs = ++fn->arg_begin (); + + // FIXME: We need a better way of doing this, working with llvm's IR + // directly is sort of a pain. 
+ llvm::Value *zero = builder.getInt32 (0); + llvm::Value *one = builder.getInt32 (1); + llvm::Value *two = builder.getInt32 (2); + llvm::Value *three = builder.getInt32 (3); + + llvm::Type *vec4 = llvm::VectorType::get (scalar_t, 4); + llvm::Value *mlhs = llvm::UndefValue::get (vec4); + llvm::Value *mrhs = mlhs; + + llvm::Value *temp = complex_real (lhs); + mlhs = builder.CreateInsertElement (mlhs, temp, zero); + mlhs = builder.CreateInsertElement (mlhs, temp, two); + temp = complex_imag (lhs); + mlhs = builder.CreateInsertElement (mlhs, temp, one); + mlhs = builder.CreateInsertElement (mlhs, temp, three); + + temp = complex_real (rhs); + mrhs = builder.CreateInsertElement (mrhs, temp, zero); + mrhs = builder.CreateInsertElement (mrhs, temp, three); + temp = complex_imag (rhs); + mrhs = builder.CreateInsertElement (mrhs, temp, one); + mrhs = builder.CreateInsertElement (mrhs, temp, two); + + llvm::Value *mres = builder.CreateFMul (mlhs, mrhs); + llvm::Value *tlhs = builder.CreateExtractElement (mres, zero); + llvm::Value *trhs = builder.CreateExtractElement (mres, one); + llvm::Value *ret_real = builder.CreateFSub (tlhs, trhs); + + tlhs = builder.CreateExtractElement (mres, two); + trhs = builder.CreateExtractElement (mres, three); + llvm::Value *ret_imag = builder.CreateFAdd (tlhs, trhs); + builder.CreateRet (complex_new (ret_real, ret_imag)); + + jit_operation::overload ol (fn, false, complex, complex, complex); + binary_ops[octave_value::op_mul].add_overload (ol); + binary_ops[octave_value::op_el_mul].add_overload (ol); + } + llvm::verifyFunction (*fn); + + llvm::Function *complex_div = create_function ("octave_jit_complex_div", + complex_ret, complex_ret, + complex_ret); + engine->addGlobalMapping (complex_div, + reinterpret_cast (&octave_jit_complex_div)); + complex_div = wrap_complex (complex_div); + { + jit_operation::overload ol (complex_div, true, complex, complex, complex); + binary_ops[octave_value::op_div].add_overload (ol); + 
binary_ops[octave_value::op_ldiv].add_overload (ol); + } + + fn = mirror_binary (complex_div); + { + jit_operation::overload ol (fn, true, complex, complex, complex); + binary_ops[octave_value::op_ldiv].add_overload (ol); + binary_ops[octave_value::op_el_ldiv].add_overload (ol); + } + + fn = create_function ("octave_jit_pow_complex_complex", complex_ret, + complex_ret, complex_ret); + engine->addGlobalMapping (fn, reinterpret_cast (octave_jit_pow_complex_complex)); + { + jit_operation::overload ol (wrap_complex (fn), false, complex, complex, + complex); + binary_ops[octave_value::op_pow].add_overload (ol); + binary_ops[octave_value::op_el_pow].add_overload (ol); + } + + fn = create_function ("octave_jit_*_scalar_complex", complex, scalar, + complex); + llvm::Function *mul_scalar_complex = fn; + body = llvm::BasicBlock::Create (context, "body", fn); + builder.SetInsertPoint (body); + { + llvm::Value *lhs = fn->arg_begin (); + llvm::Value *tlhs = complex_new (lhs, lhs); + llvm::Value *rhs = ++fn->arg_begin (); + builder.CreateRet (builder.CreateFMul (tlhs, rhs)); + + jit_operation::overload ol (fn, false, complex, scalar, complex); + binary_ops[octave_value::op_mul].add_overload (ol); + binary_ops[octave_value::op_el_mul].add_overload (ol); + } + llvm::verifyFunction (*fn); + + fn = mirror_binary (mul_scalar_complex); + { + jit_operation::overload ol (fn, false, complex, complex, scalar); + binary_ops[octave_value::op_mul].add_overload (ol); + binary_ops[octave_value::op_el_mul].add_overload (ol); + } + + fn = create_function ("octave_jit_+_scalar_complex", complex, scalar, + complex); + body = llvm::BasicBlock::Create (context, "body", fn); + builder.SetInsertPoint (body); + { + llvm::Value *lhs = fn->arg_begin (); + llvm::Value *rhs = ++fn->arg_begin (); + llvm::Value *real = builder.CreateFAdd (lhs, complex_real (rhs)); + builder.CreateRet (complex_real (rhs, real)); + llvm::verifyFunction (*fn); + + binary_ops[octave_value::op_add].add_overload (fn, false, 
complex, scalar, + complex); + fn = mirror_binary (fn); + binary_ops[octave_value::op_add].add_overload (fn, false, complex, complex, + scalar); + } + + fn = create_function ("octave_jit_-_complex_scalar", complex, complex, + scalar); + body = llvm::BasicBlock::Create (context, "body", fn); + builder.SetInsertPoint (body); + { + llvm::Value *lhs = fn->arg_begin (); + llvm::Value *rhs = ++fn->arg_begin (); + llvm::Value *real = builder.CreateFSub (complex_real (lhs), rhs); + builder.CreateRet (complex_real (lhs, real)); + llvm::verifyFunction (*fn); + + binary_ops[octave_value::op_sub].add_overload (fn, false, complex, complex, + scalar); + } + + fn = create_function ("octave_jit_-_scalar_complex", complex, scalar, + complex); + body = llvm::BasicBlock::Create (context, "body", fn); + builder.SetInsertPoint (body); + { + llvm::Value *lhs = fn->arg_begin (); + llvm::Value *rhs = ++fn->arg_begin (); + llvm::Value *real = builder.CreateFSub (lhs, complex_real (rhs)); + builder.CreateRet (complex_real (rhs, real)); + llvm::verifyFunction (*fn); + + binary_ops[octave_value::op_sub].add_overload (fn, false, complex, scalar, + complex); + } + + fn = create_function ("octave_jit_pow_scalar_complex", complex_ret, + scalar_t, complex_ret); + engine->addGlobalMapping (fn, reinterpret_cast (octave_jit_pow_scalar_complex)); + { + jit_operation::overload ol (wrap_complex (fn), false, complex, scalar, + complex); + binary_ops[octave_value::op_pow].add_overload (ol); + binary_ops[octave_value::op_el_pow].add_overload (ol); + } + + fn = create_function ("octave_jit_pow_complex_scalar", complex_ret, + complex_ret, scalar_t); + engine->addGlobalMapping (fn, reinterpret_cast (octave_jit_pow_complex_complex)); + { + jit_operation::overload ol (wrap_complex (fn), false, complex, complex, + scalar); + binary_ops[octave_value::op_pow].add_overload (ol); + binary_ops[octave_value::op_el_pow].add_overload (ol); + } + + // now for binary index operators + add_binary_op (index, 
octave_value::op_add, llvm::Instruction::Add); + + // and binary bool operators + add_binary_op (boolean, octave_value::op_el_or, llvm::Instruction::Or); + add_binary_op (boolean, octave_value::op_el_and, llvm::Instruction::And); + + // now for printing functions + print_fn.stash_name ("print"); + add_print (any, reinterpret_cast (&octave_jit_print_any)); + add_print (scalar, reinterpret_cast (&octave_jit_print_double)); + + // initialize for loop + for_init_fn.stash_name ("for_init"); + + fn = create_function ("octave_jit_for_range_init", index, range); + body = llvm::BasicBlock::Create (context, "body", fn); + builder.SetInsertPoint (body); + { + llvm::Value *zero = llvm::ConstantInt::get (index_t, 0); + builder.CreateRet (zero); + } + llvm::verifyFunction (*fn); + for_init_fn.add_overload (fn, false, index, range); + + // bounds check for for loop + for_check_fn.stash_name ("for_check"); + + fn = create_function ("octave_jit_for_range_check", boolean, range, index); + body = llvm::BasicBlock::Create (context, "body", fn); + builder.SetInsertPoint (body); + { + llvm::Value *nelem + = builder.CreateExtractValue (fn->arg_begin (), 3); + llvm::Value *idx = ++fn->arg_begin (); + llvm::Value *ret = builder.CreateICmpULT (idx, nelem); + builder.CreateRet (ret); + } + llvm::verifyFunction (*fn); + for_check_fn.add_overload (fn, false, boolean, range, index); + + // index variabe for for loop + for_index_fn.stash_name ("for_index"); + + fn = create_function ("octave_jit_for_range_idx", scalar, range, index); + body = llvm::BasicBlock::Create (context, "body", fn); + builder.SetInsertPoint (body); + { + llvm::Value *idx = ++fn->arg_begin (); + llvm::Value *didx = builder.CreateSIToFP (idx, scalar_t); + llvm::Value *rng = fn->arg_begin (); + llvm::Value *base = builder.CreateExtractValue (rng, 0); + llvm::Value *inc = builder.CreateExtractValue (rng, 2); + + llvm::Value *ret = builder.CreateFMul (didx, inc); + ret = builder.CreateFAdd (base, ret); + builder.CreateRet 
(ret); + } + llvm::verifyFunction (*fn); + for_index_fn.add_overload (fn, false, scalar, range, index); + + // logically true + logically_true_fn.stash_name ("logically_true"); + + llvm::Function *gripe_nantl + = create_function ("octave_jit_gripe_nan_to_logical_conversion", void_t); + engine->addGlobalMapping (gripe_nantl, reinterpret_cast (&octave_jit_gripe_nan_to_logical_conversion)); + + fn = create_function ("octave_jit_logically_true_scalar", boolean, scalar); + body = llvm::BasicBlock::Create (context, "body", fn); + builder.SetInsertPoint (body); + { + llvm::BasicBlock *error_block = llvm::BasicBlock::Create (context, "error", + fn); + llvm::BasicBlock *normal_block = llvm::BasicBlock::Create (context, + "normal", fn); + + llvm::Value *check = builder.CreateFCmpUNE (fn->arg_begin (), + fn->arg_begin ()); + builder.CreateCondBr (check, error_block, normal_block); + + builder.SetInsertPoint (error_block); + builder.CreateCall (gripe_nantl); + builder.CreateBr (normal_block); + builder.SetInsertPoint (normal_block); + + llvm::Value *zero = llvm::ConstantFP::get (scalar_t, 0); + llvm::Value *ret = builder.CreateFCmpONE (fn->arg_begin (), zero); + builder.CreateRet (ret); + } + llvm::verifyFunction (*fn); + logically_true_fn.add_overload (fn, true, boolean, scalar); + + fn = create_function ("octave_logically_true_bool", boolean, boolean); + body = llvm::BasicBlock::Create (context, "body", fn); + builder.SetInsertPoint (body); + builder.CreateRet (fn->arg_begin ()); + llvm::verifyFunction (*fn); + logically_true_fn.add_overload (fn, false, boolean, boolean); + + // make_range + // FIXME: May be benificial to implement all in LLVM + make_range_fn.stash_name ("make_range"); + llvm::Function *compute_nelem + = create_function ("octave_jit_compute_nelem", index, scalar, scalar, + scalar); + engine->addGlobalMapping (compute_nelem, + reinterpret_cast (&octave_jit_compute_nelem)); + + fn = create_function ("octave_jit_make_range", range, scalar, scalar, scalar); + 
body = llvm::BasicBlock::Create (context, "body", fn); + builder.SetInsertPoint (body); + { + llvm::Function::arg_iterator args = fn->arg_begin (); + llvm::Value *base = args; + llvm::Value *limit = ++args; + llvm::Value *inc = ++args; + llvm::Value *nelem = builder.CreateCall3 (compute_nelem, base, limit, inc); + + llvm::Value *dzero = llvm::ConstantFP::get (scalar_t, 0); + llvm::Value *izero = llvm::ConstantInt::get (index_t, 0); + llvm::Value *rng = llvm::ConstantStruct::get (range_t, dzero, dzero, dzero, + izero, NULL); + rng = builder.CreateInsertValue (rng, base, 0); + rng = builder.CreateInsertValue (rng, limit, 1); + rng = builder.CreateInsertValue (rng, inc, 2); + rng = builder.CreateInsertValue (rng, nelem, 3); + builder.CreateRet (rng); + } + llvm::verifyFunction (*fn); + make_range_fn.add_overload (fn, false, range, scalar, scalar, scalar); + + // paren_subsref + llvm::Function *ginvalid_index = create_function ("gipe_invalid_index", + void_t); + engine->addGlobalMapping (ginvalid_index, + reinterpret_cast (&octave_jit_ginvalid_index)); + + llvm::Function *gindex_range = create_function ("gripe_index_out_of_range", + void_t, int_t, int_t, index_t, + index_t); + engine->addGlobalMapping (gindex_range, + reinterpret_cast (&octave_jit_gindex_range)); + + fn = create_function ("()subsref", scalar, matrix, scalar); + body = llvm::BasicBlock::Create (context, "body", fn); + builder.SetInsertPoint (body); + { + llvm::Value *one = llvm::ConstantInt::get (index_t, 1); + llvm::Value *ione; + if (index_t == int_t) + ione = one; + else + ione = llvm::ConstantInt::get (int_t, 1); + + llvm::Value *undef = llvm::UndefValue::get (scalar_t); + + llvm::Function::arg_iterator args = fn->arg_begin (); + llvm::Value *mat = args++; + llvm::Value *idx = args; + + // convert index to scalar to integer, and check index >= 1 + llvm::Value *int_idx = builder.CreateFPToSI (idx, index_t); + llvm::Value *check_idx = builder.CreateSIToFP (int_idx, scalar_t); + llvm::Value *cond0 = 
builder.CreateFCmpUNE (idx, check_idx); + llvm::Value *cond1 = builder.CreateICmpSLT (int_idx, one); + llvm::Value *cond = builder.CreateOr (cond0, cond1); + + llvm::BasicBlock *done = llvm::BasicBlock::Create (context, "done", fn); + + llvm::BasicBlock *conv_error = llvm::BasicBlock::Create (context, + "conv_error", fn, + done); + llvm::BasicBlock *normal = llvm::BasicBlock::Create (context, "normal", fn, + done); + builder.CreateCondBr (cond, conv_error, normal); + + builder.SetInsertPoint (conv_error); + builder.CreateCall (ginvalid_index); + builder.CreateBr (done); + + builder.SetInsertPoint (normal); + llvm::Value *len = builder.CreateExtractValue (mat, + llvm::ArrayRef (2)); + cond = builder.CreateICmpSGT (int_idx, len); + + + llvm::BasicBlock *bounds_error = llvm::BasicBlock::Create (context, + "bounds_error", + fn, done); + + llvm::BasicBlock *success = llvm::BasicBlock::Create (context, "success", + fn, done); + builder.CreateCondBr (cond, bounds_error, success); + + builder.SetInsertPoint (bounds_error); + builder.CreateCall4 (gindex_range, ione, ione, int_idx, len); + builder.CreateBr (done); + + builder.SetInsertPoint (success); + llvm::Value *data = builder.CreateExtractValue (mat, + llvm::ArrayRef (1)); + llvm::Value *gep = builder.CreateInBoundsGEP (data, int_idx); + llvm::Value *ret = builder.CreateLoad (gep); + builder.CreateBr (done); + + builder.SetInsertPoint (done); + + llvm::PHINode *merge = llvm::PHINode::Create (scalar_t, 3); + builder.Insert (merge); + merge->addIncoming (undef, conv_error); + merge->addIncoming (undef, bounds_error); + merge->addIncoming (ret, success); + builder.CreateRet (merge); + } + llvm::verifyFunction (*fn); + paren_subsref_fn.add_overload (fn, true, scalar, matrix, scalar); + + // paren subsasgn + paren_subsasgn_fn.stash_name ("()subsasgn"); + + llvm::Function *resize_paren_subsasgn + = create_function ("octave_jit_paren_subsasgn_impl", void_t, + matrix_t->getPointerTo (), index_t, scalar_t); + 
engine->addGlobalMapping (resize_paren_subsasgn, + reinterpret_cast (&octave_jit_paren_subsasgn_impl)); + + fn = create_function ("octave_jit_paren_subsasgn", matrix, matrix, scalar, + scalar); + body = llvm::BasicBlock::Create (context, "body", fn); + builder.SetInsertPoint (body); + { + llvm::Value *one = llvm::ConstantInt::get (index_t, 1); + + llvm::Function::arg_iterator args = fn->arg_begin (); + llvm::Value *mat = args++; + llvm::Value *idx = args++; + llvm::Value *value = args; + + llvm::Value *int_idx = builder.CreateFPToSI (idx, index_t); + llvm::Value *check_idx = builder.CreateSIToFP (int_idx, scalar_t); + llvm::Value *cond0 = builder.CreateFCmpUNE (idx, check_idx); + llvm::Value *cond1 = builder.CreateICmpSLT (int_idx, one); + llvm::Value *cond = builder.CreateOr (cond0, cond1); + + llvm::BasicBlock *done = llvm::BasicBlock::Create (context, "done", fn); + + llvm::BasicBlock *conv_error = llvm::BasicBlock::Create (context, + "conv_error", fn, + done); + llvm::BasicBlock *normal = llvm::BasicBlock::Create (context, "normal", fn, + done); + builder.CreateCondBr (cond, conv_error, normal); + builder.SetInsertPoint (conv_error); + builder.CreateCall (ginvalid_index); + builder.CreateBr (done); + + builder.SetInsertPoint (normal); + llvm::Value *len = builder.CreateExtractValue (mat, + llvm::ArrayRef (2)); + cond0 = builder.CreateICmpSGT (int_idx, len); + + llvm::Value *rcount = builder.CreateExtractValue (mat, 0); + rcount = builder.CreateLoad (rcount); + cond1 = builder.CreateICmpSGT (rcount, one); + cond = builder.CreateOr (cond0, cond1); + + llvm::BasicBlock *bounds_error = llvm::BasicBlock::Create (context, + "bounds_error", + fn, done); + + llvm::BasicBlock *success = llvm::BasicBlock::Create (context, "success", + fn, done); + builder.CreateCondBr (cond, bounds_error, success); + + // resize on out of bounds access + builder.SetInsertPoint (bounds_error); + llvm::Value *resize_result = builder.CreateAlloca (matrix_t); + builder.CreateStore (mat, 
resize_result); + builder.CreateCall3 (resize_paren_subsasgn, resize_result, int_idx, value); + resize_result = builder.CreateLoad (resize_result); + builder.CreateBr (done); + + builder.SetInsertPoint (success); + llvm::Value *data = builder.CreateExtractValue (mat, + llvm::ArrayRef (1)); + llvm::Value *gep = builder.CreateInBoundsGEP (data, int_idx); + builder.CreateStore (value, gep); + builder.CreateBr (done); + + builder.SetInsertPoint (done); + + llvm::PHINode *merge = llvm::PHINode::Create (matrix_t, 3); + builder.Insert (merge); + merge->addIncoming (mat, conv_error); + merge->addIncoming (resize_result, bounds_error); + merge->addIncoming (mat, success); + builder.CreateRet (merge); + } + llvm::verifyFunction (*fn); + paren_subsasgn_fn.add_overload (fn, true, matrix, matrix, scalar, scalar); + + fn = create_function ("octave_jit_paren_subsasgn_matrix_range", void_t, + matrix_t->getPointerTo (), matrix_t->getPointerTo (), + range_t->getPointerTo (), scalar_t); + engine->addGlobalMapping (fn, + reinterpret_cast (&octave_jit_paren_subsasgn_matrix_range)); + paren_subsasgn_fn.add_overload (fn, true, matrix, matrix, range, scalar); + + casts[any->type_id ()].stash_name ("(any)"); + casts[scalar->type_id ()].stash_name ("(scalar)"); + casts[complex->type_id ()].stash_name ("(complex)"); + casts[matrix->type_id ()].stash_name ("(matrix)"); + + // cast any <- matrix + fn = create_function ("octave_jit_cast_any_matrix", any_t, + matrix_t->getPointerTo ()); + engine->addGlobalMapping (fn, + reinterpret_cast (&octave_jit_cast_any_matrix)); + casts[any->type_id ()].add_overload (fn, false, any, matrix); + + // cast matrix <- any + fn = create_function ("octave_jit_cast_matrix_any", void_t, + matrix_t->getPointerTo (), any_t); + engine->addGlobalMapping (fn, + reinterpret_cast (&octave_jit_cast_matrix_any)); + casts[matrix->type_id ()].add_overload (fn, false, matrix, any); + + // cast any <- scalar + fn = create_function ("octave_jit_cast_any_scalar", any, scalar); + 
engine->addGlobalMapping (fn, reinterpret_cast (&octave_jit_cast_any_scalar)); + casts[any->type_id ()].add_overload (fn, false, any, scalar); + + // cast scalar <- any + fn = create_function ("octave_jit_cast_scalar_any", scalar, any); + engine->addGlobalMapping (fn, reinterpret_cast (&octave_jit_cast_scalar_any)); + casts[scalar->type_id ()].add_overload (fn, false, scalar, any); + + // cast any <- complex + fn = create_function ("octave_jit_cast_any_complex", any_t, complex_ret); + engine->addGlobalMapping (fn, reinterpret_cast (&octave_jit_cast_any_complex)); + casts[any->type_id ()].add_overload (wrap_complex (fn), false, any, complex); + + // cast complex <- any + fn = create_function ("octave_jit_cast_complex_any", complex_ret, any_t); + engine->addGlobalMapping (fn, reinterpret_cast (&octave_jit_cast_complex_any)); + casts[complex->type_id ()].add_overload (wrap_complex (fn), false, complex, + any); + + // cast complex <- scalar + fn = create_function ("octave_jit_cast_complex_scalar", complex, scalar); + body = llvm::BasicBlock::Create (context, "body", fn); + builder.SetInsertPoint (body); + { + llvm::Value *zero = llvm::ConstantFP::get (scalar_t, 0); + builder.CreateRet (complex_new (fn->arg_begin (), zero)); + llvm::verifyFunction (*fn); + } + casts[complex->type_id ()].add_overload (fn, false, complex, scalar); + + // cast scalar <- complex + fn = create_function ("octave_jit_cast_scalar_complex", scalar, complex); + body = llvm::BasicBlock::Create (context, "body", fn); + builder.SetInsertPoint (body); + { + builder.CreateRet (complex_real (fn->arg_begin ())); + llvm::verifyFunction (*fn); + } + casts[scalar->type_id ()].add_overload (fn, false, scalar, complex); + + // cast any <- any + fn = create_identity (any); + casts[any->type_id ()].add_overload (fn, false, any, any); + + // cast scalar <- scalar + fn = create_identity (scalar); + casts[scalar->type_id ()].add_overload (fn, false, scalar, scalar); + + // cast complex <- complex + fn = 
create_identity (complex); + casts[complex->type_id ()].add_overload (fn, false, complex, complex); + + // -------------------- builtin functions -------------------- + add_builtin ("#unknown_function"); + unknown_function = builtins["#unknown_function"]; + + add_builtin ("sin"); + register_intrinsic ("sin", llvm::Intrinsic::sin, scalar, scalar); + register_generic ("sin", matrix, matrix); + + add_builtin ("cos"); + register_intrinsic ("cos", llvm::Intrinsic::cos, scalar, scalar); + register_generic ("cos", matrix, matrix); + + add_builtin ("exp"); + register_intrinsic ("exp", llvm::Intrinsic::cos, scalar, scalar); + register_generic ("exp", matrix, matrix); + + casts.resize (next_id + 1); + fn = create_identity (any); + for (std::map::iterator iter = builtins.begin (); + iter != builtins.end (); ++iter) + { + jit_type *btype = iter->second; + release_fn.add_overload (release_any, false, 0, btype); + casts[any->type_id ()].add_overload (fn, false, any, btype); + casts[btype->type_id ()].add_overload (fn, false, btype, any); + } +} + +void +jit_typeinfo::add_print (jit_type *ty, void *call) +{ + std::stringstream name; + name << "octave_jit_print_" << ty->name (); + + llvm::Type *void_t = llvm::Type::getVoidTy (context); + llvm::Function *fn = create_function (name.str (), void_t, + llvm::Type::getInt8PtrTy (context), + ty->to_llvm ()); + engine->addGlobalMapping (fn, call); + + jit_operation::overload ol (fn, false, 0, string, ty); + print_fn.add_overload (ol); +} + +// FIXME: cp between add_binary_op, add_binary_icmp, and add_binary_fcmp +void +jit_typeinfo::add_binary_op (jit_type *ty, int op, int llvm_op) +{ + std::stringstream fname; + octave_value::binary_op ov_op = static_cast(op); + fname << "octave_jit_" << octave_value::binary_op_as_string (ov_op) + << "_" << ty->name (); + + llvm::Function *fn = create_function (fname.str (), ty, ty, ty); + llvm::BasicBlock *block = llvm::BasicBlock::Create (context, "body", fn); + builder.SetInsertPoint (block); + 
llvm::Instruction::BinaryOps temp + = static_cast(llvm_op); + llvm::Value *ret = builder.CreateBinOp (temp, fn->arg_begin (), + ++fn->arg_begin ()); + builder.CreateRet (ret); + llvm::verifyFunction (*fn); + + jit_operation::overload ol(fn, false, ty, ty, ty); + binary_ops[op].add_overload (ol); +} + +void +jit_typeinfo::add_binary_icmp (jit_type *ty, int op, int llvm_op) +{ + std::stringstream fname; + octave_value::binary_op ov_op = static_cast(op); + fname << "octave_jit" << octave_value::binary_op_as_string (ov_op) + << "_" << ty->name (); + + llvm::Function *fn = create_function (fname.str (), boolean, ty, ty); + llvm::BasicBlock *block = llvm::BasicBlock::Create (context, "body", fn); + builder.SetInsertPoint (block); + llvm::CmpInst::Predicate temp + = static_cast(llvm_op); + llvm::Value *ret = builder.CreateICmp (temp, fn->arg_begin (), + ++fn->arg_begin ()); + builder.CreateRet (ret); + llvm::verifyFunction (*fn); + + jit_operation::overload ol (fn, false, boolean, ty, ty); + binary_ops[op].add_overload (ol); +} + +void +jit_typeinfo::add_binary_fcmp (jit_type *ty, int op, int llvm_op) +{ + std::stringstream fname; + octave_value::binary_op ov_op = static_cast(op); + fname << "octave_jit" << octave_value::binary_op_as_string (ov_op) + << "_" << ty->name (); + + llvm::Function *fn = create_function (fname.str (), boolean, ty, ty); + llvm::BasicBlock *block = llvm::BasicBlock::Create (context, "body", fn); + builder.SetInsertPoint (block); + llvm::CmpInst::Predicate temp + = static_cast(llvm_op); + llvm::Value *ret = builder.CreateFCmp (temp, fn->arg_begin (), + ++fn->arg_begin ()); + builder.CreateRet (ret); + llvm::verifyFunction (*fn); + + jit_operation::overload ol (fn, false, boolean, ty, ty); + binary_ops[op].add_overload (ol); +} + +llvm::Function * +jit_typeinfo::create_function (const llvm::Twine& name, jit_type *ret, + const std::vector& args) +{ + llvm::Type *void_t = llvm::Type::getVoidTy (context); + std::vector llvm_args (args.size (), void_t); 
+ for (size_t i = 0; i < args.size (); ++i) + if (args[i]) + llvm_args[i] = args[i]->to_llvm (); + + return create_function (name, ret ? ret->to_llvm () : void_t, llvm_args); +} + +llvm::Function * +jit_typeinfo::create_function (const llvm::Twine& name, llvm::Type *ret, + const std::vector& args) +{ + llvm::FunctionType *ft = llvm::FunctionType::get (ret, args, false); + llvm::Function *fn = llvm::Function::Create (ft, + llvm::Function::ExternalLinkage, + name, module); + fn->addFnAttr (llvm::Attribute::AlwaysInline); + return fn; +} + +llvm::Function * +jit_typeinfo::create_identity (jit_type *type) +{ + size_t id = type->type_id (); + if (id >= identities.size ()) + identities.resize (id + 1, 0); + + if (! identities[id]) + { + llvm::Function *fn = create_function ("id", type, type); + llvm::BasicBlock *body = llvm::BasicBlock::Create (context, "body", fn); + builder.SetInsertPoint (body); + builder.CreateRet (fn->arg_begin ()); + llvm::verifyFunction (*fn); + identities[id] = fn; + } + + return identities[id]; +} + +llvm::Value * +jit_typeinfo::do_insert_error_check (void) +{ + return builder.CreateLoad (lerror_state); +} + +void +jit_typeinfo::add_builtin (const std::string& name) +{ + jit_type *btype = new_type (name, any, any->to_llvm ()); + builtins[name] = btype; + + octave_builtin *ov_builtin = find_builtin (name); + if (ov_builtin) + ov_builtin->stash_jit (*btype); +} + +void +jit_typeinfo::register_intrinsic (const std::string& name, size_t iid, + jit_type *result, + const std::vector& args) +{ + jit_type *builtin_type = builtins[name]; + size_t nargs = args.size (); + llvm::SmallVector llvm_args (nargs); + for (size_t i = 0; i < nargs; ++i) + llvm_args[i] = args[i]->to_llvm (); + + llvm::Intrinsic::ID id = static_cast (iid); + llvm::Function *ifun = llvm::Intrinsic::getDeclaration (module, id, + llvm_args); + std::stringstream fn_name; + fn_name << "octave_jit_" << name; + + std::vector args1 (nargs + 1); + args1[0] = builtin_type; + std::copy 
(args.begin (), args.end (), args1.begin () + 1); + + // The first argument will be the Octave function, but we already know that + // the function call is the equivalent of the intrinsic, so we ignore it and + // call the intrinsic with the remaining arguments. + llvm::Function *fn = create_function (fn_name.str (), result, args1); + llvm::BasicBlock *body = llvm::BasicBlock::Create (context, "body", fn); + builder.SetInsertPoint (body); + + llvm::SmallVector fargs (nargs); + llvm::Function::arg_iterator iter = fn->arg_begin (); + ++iter; + for (size_t i = 0; i < nargs; ++i, ++iter) + fargs[i] = iter; + + llvm::Value *ret = builder.CreateCall (ifun, fargs); + builder.CreateRet (ret); + llvm::verifyFunction (*fn); + + paren_subsref_fn.add_overload (fn, false, result, args1); +} + +octave_builtin * +jit_typeinfo::find_builtin (const std::string& name) +{ + // FIXME: Finalize what we want to store in octave_builtin, then add functions + // to access these values in octave_value + octave_value ov_builtin = symbol_table::find (name); + return dynamic_cast (ov_builtin.internal_rep ()); +} + +void +jit_typeinfo::register_generic (const std::string&, jit_type *, + const std::vector&) +{ + // FIXME: Implement +} + +llvm::Function * +jit_typeinfo::mirror_binary (llvm::Function *fn) +{ + llvm::FunctionType *fn_type = fn->getFunctionType (); + llvm::Function *ret = create_function (fn->getName () + "_reverse", + fn_type->getReturnType (), + fn_type->getParamType (1), + fn_type->getParamType (0)); + llvm::BasicBlock *body = llvm::BasicBlock::Create (context, "body", ret); + builder.SetInsertPoint (body); + llvm::Value *result = builder.CreateCall2 (fn, ++ret->arg_begin (), + ret->arg_begin ()); + if (ret->getReturnType () == builder.getVoidTy ()) + builder.CreateRetVoid (); + else + builder.CreateRet (result); + + llvm::verifyFunction (*ret); + return ret; +} + +llvm::Function * +jit_typeinfo::wrap_complex (llvm::Function *wrap) +{ + llvm::SmallVector new_args; + 
new_args.reserve (wrap->arg_size ()); + llvm::Type *complex_t = complex->to_llvm (); + for (llvm::Function::arg_iterator iter = wrap->arg_begin (); + iter != wrap->arg_end (); ++iter) + { + llvm::Value *value = iter; + llvm::Type *type = value->getType (); + new_args.push_back (type == complex_ret ? complex_t : type); + } + + llvm::FunctionType *wrap_type = wrap->getFunctionType (); + bool convert_ret = wrap_type->getReturnType () == complex_ret; + llvm::Type *rtype = convert_ret ? complex_t : wrap->getReturnType (); + llvm::FunctionType *ft = llvm::FunctionType::get (rtype, new_args, false); + llvm::Function *fn = llvm::Function::Create (ft, + llvm::Function::ExternalLinkage, + wrap->getName () + "_wrap", + module); + llvm::BasicBlock *body = llvm::BasicBlock::Create (context, "body", fn); + builder.SetInsertPoint (body); + + llvm::SmallVector converted (new_args.size ()); + llvm::Function::arg_iterator witer = wrap->arg_begin (); + llvm::Function::arg_iterator fiter = fn->arg_begin (); + for (size_t i = 0; i < new_args.size (); ++i, ++witer, ++fiter) + { + llvm::Value *warg = witer; + llvm::Value *arg = fiter; + converted[i] = warg->getType () == arg->getType () ? 
arg + : pack_complex (arg); + } + + llvm::Value *ret = builder.CreateCall (wrap, converted); + if (wrap_type->getReturnType () != builder.getVoidTy ()) + { + if (convert_ret) + ret = unpack_complex (ret); + builder.CreateRet (ret); + } + else + builder.CreateRetVoid (); + + llvm::verifyFunction (*fn); + return fn; +} + +llvm::Value * +jit_typeinfo::pack_complex (llvm::Value *cplx) +{ + llvm::Value *real = builder.CreateExtractElement (cplx, builder.getInt32 (0)); + llvm::Value *imag = builder.CreateExtractElement (cplx, builder.getInt32 (1)); + llvm::Value *ret = llvm::UndefValue::get (complex_ret); + ret = builder.CreateInsertValue (ret, real, 0); + return builder.CreateInsertValue (ret, imag, 1); +} + +llvm::Value * +jit_typeinfo::unpack_complex (llvm::Value *result) +{ + llvm::Type *complex_t = complex->to_llvm (); + llvm::Value *real = builder.CreateExtractValue (result, 0); + llvm::Value *imag = builder.CreateExtractValue (result, 1); + llvm::Value *ret = llvm::UndefValue::get (complex_t); + ret = builder.CreateInsertElement (ret, real, builder.getInt32 (0)); + return builder.CreateInsertElement (ret, imag, builder.getInt32 (1)); +} + +llvm::Value * +jit_typeinfo::complex_real (llvm::Value *cx) +{ + return builder.CreateExtractElement (cx, builder.getInt32 (0)); +} + +llvm::Value * +jit_typeinfo::complex_real (llvm::Value *cx, llvm::Value *real) +{ + return builder.CreateInsertElement (cx, real, builder.getInt32 (0)); +} + +llvm::Value * +jit_typeinfo::complex_imag (llvm::Value *cx) +{ + return builder.CreateExtractElement (cx, builder.getInt32 (1)); +} + +llvm::Value * +jit_typeinfo::complex_imag (llvm::Value *cx, llvm::Value *imag) +{ + return builder.CreateInsertElement (cx, imag, builder.getInt32 (1)); +} + +llvm::Value * +jit_typeinfo::complex_new (llvm::Value *real, llvm::Value *imag) +{ + llvm::Value *ret = llvm::UndefValue::get (complex->to_llvm ()); + ret = complex_real (ret, real); + return complex_imag (ret, imag); +} + +jit_type * 
+jit_typeinfo::do_type_of (const octave_value &ov) const +{ + if (ov.is_function ()) + { + // FIXME: This is ugly, we need to finalize how we want to do this, then + // have octave_value fully support the needed functionality + octave_builtin *builtin + = dynamic_cast (ov.internal_rep ()); + return builtin && builtin->to_jit () ? builtin->to_jit () + : unknown_function; + } + + if (ov.is_range ()) + return get_range (); + + if (ov.is_double_type ()) + { + if (ov.is_real_scalar ()) + return get_scalar (); + + if (ov.is_matrix_type ()) + return get_matrix (); + } + + if (ov.is_complex_scalar ()) + return get_complex (); + + return get_any (); +} + +jit_type* +jit_typeinfo::new_type (const std::string& name, jit_type *parent, + llvm::Type *llvm_type) +{ + jit_type *ret = new jit_type (name, parent, llvm_type, next_id++); + id_to_type.push_back (ret); + return ret; +} + +// -------------------- jit_use -------------------- +jit_block * +jit_use::user_parent (void) const +{ + return muser->parent (); +} + +// -------------------- jit_value -------------------- +jit_value::~jit_value (void) +{} + +jit_block * +jit_value::first_use_block (void) +{ + jit_use *use = first_use (); + while (use) + { + if (! 
isa (use->user ())) + return use->user_parent (); + + use = use->next (); + } + + return 0; +} + +void +jit_value::replace_with (jit_value *value) +{ + while (first_use ()) + { + jit_instruction *user = first_use ()->user (); + size_t idx = first_use ()->index (); + user->stash_argument (idx, value); + } +} + +#define JIT_METH(clname) \ + void \ + jit_ ## clname::accept (jit_ir_walker& walker) \ + { \ + walker.visit (*this); \ + } + +JIT_VISIT_IR_NOTEMPLATE +#undef JIT_METH + +std::ostream& +operator<< (std::ostream& os, const jit_value& value) +{ + return value.short_print (os); +} + +std::ostream& +jit_print (std::ostream& os, jit_value *avalue) +{ + if (avalue) + return avalue->print (os); + return os << "NULL"; +} + +// -------------------- jit_instruction -------------------- +void +jit_instruction::remove (void) +{ + if (mparent) + mparent->remove (mlocation); + resize_arguments (0); +} + +llvm::BasicBlock * +jit_instruction::parent_llvm (void) const +{ + return mparent->to_llvm (); +} + +std::ostream& +jit_instruction::short_print (std::ostream& os) const +{ + if (type ()) + jit_print (os, type ()) << ": "; + return os << "#" << mid; +} + +void +jit_instruction::do_construct_ssa (size_t start, size_t end) +{ + for (size_t i = start; i < end; ++i) + { + jit_value *arg = argument (i); + jit_variable *var = dynamic_cast (arg); + if (var && var->has_top ()) + stash_argument (i, var->top ()); + } +} + +// -------------------- jit_block -------------------- +void +jit_block::replace_with (jit_value *value) +{ + assert (isa (value)); + jit_block *block = static_cast (value); + + jit_value::replace_with (block); + + while (ILIST_T::first_use ()) + { + jit_phi_incomming *incomming = ILIST_T::first_use (); + incomming->stash_value (block); + } +} + +void +jit_block::replace_in_phi (jit_block *ablock, jit_block *with) +{ + jit_phi_incomming *node = ILIST_T::first_use (); + while (node) + { + jit_phi_incomming *prev = node; + node = node->next (); + + if 
(prev->user_parent () == ablock) + prev->stash_value (with); + } +} + +jit_block * +jit_block::maybe_merge () +{ + if (successor_count () == 1 && successor (0) != this + && (successor (0)->use_count () == 1 || instructions.size () == 1)) + { + jit_block *to_merge = successor (0); + merge (*to_merge); + return to_merge; + } + + return 0; +} + +void +jit_block::merge (jit_block& block) +{ + // the merge block will contain a new terminator + jit_terminator *old_term = terminator (); + if (old_term) + old_term->remove (); + + bool was_empty = end () == begin (); + iterator merge_begin = end (); + if (! was_empty) + --merge_begin; + + instructions.splice (end (), block.instructions); + if (was_empty) + merge_begin = begin (); + else + ++merge_begin; + + // now merge_begin points to the start of the new instructions, we must + // update their parent information + for (iterator iter = merge_begin; iter != end (); ++iter) + { + jit_instruction *instr = *iter; + instr->stash_parent (this, iter); + } + + block.replace_with (this); +} + +jit_instruction * +jit_block::prepend (jit_instruction *instr) +{ + instructions.push_front (instr); + instr->stash_parent (this, instructions.begin ()); + return instr; +} + +jit_instruction * +jit_block::prepend_after_phi (jit_instruction *instr) +{ + // FIXME: Make this O(1) + for (iterator iter = begin (); iter != end (); ++iter) + { + jit_instruction *temp = *iter; + if (! 
isa (temp)) + { + insert_before (iter, instr); + return instr; + } + } + + return append (instr); +} + +void +jit_block::internal_append (jit_instruction *instr) +{ + instructions.push_back (instr); + instr->stash_parent (this, --instructions.end ()); +} + +jit_instruction * +jit_block::insert_before (iterator loc, jit_instruction *instr) +{ + iterator iloc = instructions.insert (loc, instr); + instr->stash_parent (this, iloc); + return instr; +} + +jit_instruction * +jit_block::insert_after (iterator loc, jit_instruction *instr) +{ + ++loc; + iterator iloc = instructions.insert (loc, instr); + instr->stash_parent (this, iloc); + return instr; +} + +jit_terminator * +jit_block::terminator (void) const +{ + assert (this); + if (instructions.empty ()) + return 0; + + jit_instruction *last = instructions.back (); + return dynamic_cast (last); +} + +bool +jit_block::branch_alive (jit_block *asucc) const +{ + return terminator ()->alive (asucc); +} + +jit_block * +jit_block::successor (size_t i) const +{ + jit_terminator *term = terminator (); + return term->successor (i); +} + +size_t +jit_block::successor_count (void) const +{ + jit_terminator *term = terminator (); + return term ? 
term->successor_count () : 0; +} + +llvm::BasicBlock * +jit_block::to_llvm (void) const +{ + return llvm::cast (llvm_value); +} + +std::ostream& +jit_block::print_dom (std::ostream& os) const +{ + short_print (os); + os << ":\n"; + os << " mid: " << mid << std::endl; + os << " predecessors: "; + for (jit_use *use = first_use (); use; use = use->next ()) + os << *use->user_parent () << " "; + os << std::endl; + + os << " successors: "; + for (size_t i = 0; i < successor_count (); ++i) + os << *successor (i) << " "; + os << std::endl; + + os << " idom: "; + if (idom) + os << *idom; + else + os << "NULL"; + os << std::endl; + os << " df: "; + for (df_iterator iter = df_begin (); iter != df_end (); ++iter) + os << **iter << " "; + os << std::endl; + + os << " dom_succ: "; + for (size_t i = 0; i < dom_succ.size (); ++i) + os << *dom_succ[i] << " "; + + return os << std::endl; +} + +void +jit_block::compute_df (size_t avisit_count) +{ + if (visited (avisit_count)) + return; + + if (use_count () >= 2) + { + for (jit_use *use = first_use (); use; use = use->next ()) + { + jit_block *runner = use->user_parent (); + while (runner != idom) + { + runner->mdf.insert (this); + runner = runner->idom; + } + } + } + + for (size_t i = 0; i < successor_count (); ++i) + successor (i)->compute_df (avisit_count); +} + +bool +jit_block::update_idom (size_t avisit_count) +{ + if (visited (avisit_count) || ! 
use_count ()) + return false; + + bool changed = false; + for (jit_use *use = first_use (); use; use = use->next ()) + { + jit_block *pred = use->user_parent (); + changed = pred->update_idom (avisit_count) || changed; + } + + jit_use *use = first_use (); + jit_block *new_idom = use->user_parent (); + use = use->next (); + + for (; use; use = use->next ()) + { + jit_block *pred = use->user_parent (); + jit_block *pidom = pred->idom; + if (pidom) + new_idom = idom_intersect (pidom, new_idom); + } + + if (idom != new_idom) + { + idom = new_idom; + return true; + } + + return changed; +} + +void +jit_block::pop_all (void) +{ + for (iterator iter = begin (); iter != end (); ++iter) + { + jit_instruction *instr = *iter; + instr->pop_variable (); + } +} + +jit_block * +jit_block::maybe_split (jit_convert& convert, jit_block *asuccessor) +{ + if (successor_count () > 1) + { + jit_terminator *term = terminator (); + size_t idx = term->successor_index (asuccessor); + jit_block *split = convert.create ("phi_split", mvisit_count); + + // try to place splits where they make sense + if (id () < asuccessor->id ()) + convert.insert_before (asuccessor, split); + else + convert.insert_after (this, split); + + term->stash_argument (idx, split); + jit_branch *br = split->append (convert.create (asuccessor)); + replace_in_phi (asuccessor, split); + + if (alive ()) + { + split->mark_alive (); + br->infer (); + } + + return split; + } + + return this; +} + +void +jit_block::create_dom_tree (size_t avisit_count) +{ + if (visited (avisit_count)) + return; + + if (idom != this) + idom->dom_succ.push_back (this); + + for (size_t i = 0; i < successor_count (); ++i) + successor (i)->create_dom_tree (avisit_count); +} + +jit_block * +jit_block::idom_intersect (jit_block *i, jit_block *j) +{ + while (i && j && i != j) + { + while (i && i->id () > j->id ()) + i = i->idom; + + while (i && j && j->id () > i->id ()) + j = j->idom; + } + + return i ? 
i : j; +} + +// -------------------- jit_phi_incomming -------------------- + +jit_block * +jit_phi_incomming::user_parent (void) const +{ return muser->parent (); } + +// -------------------- jit_phi -------------------- +bool +jit_phi::prune (void) +{ + jit_block *p = parent (); + size_t new_idx = 0; + jit_value *unique = argument (1); + + for (size_t i = 0; i < argument_count (); ++i) + { + jit_block *inc = incomming (i); + if (inc->branch_alive (p)) + { + if (unique != argument (i)) + unique = 0; + + if (new_idx != i) + { + stash_argument (new_idx, argument (i)); + mincomming[new_idx].stash_value (inc); + } + + ++new_idx; + } + } + + if (new_idx != argument_count ()) + { + resize_arguments (new_idx); + mincomming.resize (new_idx); + } + + assert (argument_count () > 0); + if (unique) + { + replace_with (unique); + return true; + } + + return false; +} + +bool +jit_phi::infer (void) +{ + jit_block *p = parent (); + if (! p->alive ()) + return false; + + jit_type *infered = 0; + for (size_t i = 0; i < argument_count (); ++i) + { + jit_block *inc = incomming (i); + if (inc->branch_alive (p)) + infered = jit_typeinfo::join (infered, argument_type (i)); + } + + if (infered != type ()) + { + stash_type (infered); + return true; + } + + return false; +} + +llvm::PHINode * +jit_phi::to_llvm (void) const +{ + return llvm::cast (jit_value::to_llvm ()); +} + +// -------------------- jit_terminator -------------------- +size_t +jit_terminator::successor_index (const jit_block *asuccessor) const +{ + size_t scount = successor_count (); + for (size_t i = 0; i < scount; ++i) + if (successor (i) == asuccessor) + return i; + + panic_impossible (); +} + +bool +jit_terminator::infer (void) +{ + if (! parent ()->alive ()) + return false; + + bool changed = false; + for (size_t i = 0; i < malive.size (); ++i) + if (! 
malive[i] && check_alive (i)) + { + changed = true; + malive[i] = true; + successor (i)->mark_alive (); + } + + return changed; +} + +llvm::TerminatorInst * +jit_terminator::to_llvm (void) const +{ + return llvm::cast (jit_value::to_llvm ()); +} + +// -------------------- jit_call -------------------- +bool +jit_call::infer (void) +{ + // FIXME: explain algorithm + for (size_t i = 0; i < argument_count (); ++i) + { + already_infered[i] = argument_type (i); + if (! already_infered[i]) + return false; + } + + jit_type *infered = mfunction.get_result (already_infered); + if (! infered && use_count ()) + { + std::stringstream ss; + ss << "Missing overload in type inference for "; + print (ss, 0); + fail (ss.str ()); + } + + if (infered != type ()) + { + stash_type (infered); + return true; + } + + return false; +} + +// -------------------- jit_convert -------------------- +jit_convert::jit_convert (llvm::Module *module, tree &tee) + : iterator_count (0), short_count (0), breaking (false) +{ + jit_instruction::reset_ids (); + + entry_block = create ("body"); + final_block = create ("final"); + append (entry_block); + entry_block->mark_alive (); + block = entry_block; + visit (tee); + + // FIXME: Remove if we no longer only compile loops + assert (! 
breaking); + assert (breaks.empty ()); + assert (continues.empty ()); + + block->append (create (final_block)); + append (final_block); + + for (vmap_t::iterator iter = vmap.begin (); iter != vmap.end (); ++iter) + { + jit_variable *var = iter->second; + const std::string& name = var->name (); + if (name.size () && name[0] != '#') + final_block->append (create (var)); + } + + construct_ssa (); + + // initialize the worklist to instructions derived from constants + for (std::list::iterator iter = constants.begin (); + iter != constants.end (); ++iter) + append_users (*iter); + + // FIXME: Describe algorithm here + while (worklist.size ()) + { + jit_instruction *next = worklist.front (); + worklist.pop_front (); + next->stash_in_worklist (false); + + if (next->infer ()) + { + // terminators need to be handled specially + if (jit_terminator *term = dynamic_cast (next)) + append_users_term (term); + else + append_users (next); + } + } + + remove_dead (); + merge_blocks (); + final_block->label (); + place_releases (); + simplify_phi (); + +#ifdef OCTAVE_JIT_DEBUG + final_block->label (); + std::cout << "-------------------- Compiling tree --------------------\n"; + std::cout << tee.str_print_code () << std::endl; + print_blocks ("octave jit ir"); +#endif + + // for now just init arguments from entry, later we will have to do something + // more interesting + for (jit_block::iterator iter = entry_block->begin (); + iter != entry_block->end (); ++iter) + if (jit_extract_argument *extract + = dynamic_cast (*iter)) + arguments.push_back (std::make_pair (extract->name (), true)); + + convert_llvm to_llvm (*this); + function = to_llvm.convert (module, arguments, blocks, constants); + +#ifdef OCTAVE_JIT_DEBUG + std::cout << "-------------------- llvm ir --------------------"; + llvm::raw_os_ostream llvm_cout (std::cout); + function->print (llvm_cout); + std::cout << std::endl; + llvm::verifyFunction (*function); +#endif +} + +jit_convert::~jit_convert (void) +{ + for 
(std::list::iterator iter = all_values.begin (); + iter != all_values.end (); ++iter) + delete *iter; +} + +void +jit_convert::visit_anon_fcn_handle (tree_anon_fcn_handle&) +{ + fail (); +} + +void +jit_convert::visit_argument_list (tree_argument_list&) +{ + fail (); +} + +void +jit_convert::visit_binary_expression (tree_binary_expression& be) +{ + if (be.op_type () >= octave_value::num_binary_ops) + { + tree_boolean_expression *boole; + boole = dynamic_cast (&be); + assert (boole); + bool is_and = boole->op_type () == tree_boolean_expression::bool_and; + + std::stringstream ss; + ss << "#short_result" << short_count++; + + std::string short_name = ss.str (); + jit_variable *short_result = create (short_name); + vmap[short_name] = short_result; + + jit_block *done = create (block->name ()); + tree_expression *lhs = be.lhs (); + jit_value *lhsv = visit (lhs); + lhsv = create_checked (&jit_typeinfo::logically_true, lhsv); + + jit_block *short_early = create ("short_early"); + append (short_early); + + jit_block *short_cont = create ("short_cont"); + + if (is_and) + block->append (create (lhsv, short_cont, short_early)); + else + block->append (create (lhsv, short_early, short_cont)); + + block = short_early; + + jit_value *early_result = create (! 
is_and); + block->append (create (short_result, early_result)); + block->append (create (done)); + + append (short_cont); + block = short_cont; + + tree_expression *rhs = be.rhs (); + jit_value *rhsv = visit (rhs); + rhsv = create_checked (&jit_typeinfo::logically_true, rhsv); + block->append (create (short_result, rhsv)); + block->append (create (done)); + + append (done); + block = done; + result = short_result; + } + else + { + tree_expression *lhs = be.lhs (); + jit_value *lhsv = visit (lhs); + + tree_expression *rhs = be.rhs (); + jit_value *rhsv = visit (rhs); + + const jit_operation& fn = jit_typeinfo::binary_op (be.op_type ()); + result = create_checked (fn, lhsv, rhsv); + } +} + +void +jit_convert::visit_break_command (tree_break_command&) +{ + breaks.push_back (block); + breaking = true; +} + +void +jit_convert::visit_colon_expression (tree_colon_expression& expr) +{ + // in the future we need to add support for classes and deal with rvalues + jit_value *base = visit (expr.base ()); + jit_value *limit = visit (expr.limit ()); + jit_value *increment; + tree_expression *tinc = expr.increment (); + + if (tinc) + increment = visit (tinc); + else + increment = create (1); + + result = block->append (create (jit_typeinfo::make_range, base, + limit, increment)); +} + +void +jit_convert::visit_continue_command (tree_continue_command&) +{ + continues.push_back (block); + breaking = true; +} + +void +jit_convert::visit_global_command (tree_global_command&) +{ + fail (); +} + +void +jit_convert::visit_persistent_command (tree_persistent_command&) +{ + fail (); +} + +void +jit_convert::visit_decl_elt (tree_decl_elt&) +{ + fail (); +} + +void +jit_convert::visit_decl_init_list (tree_decl_init_list&) +{ + fail (); +} + +void +jit_convert::visit_simple_for_command (tree_simple_for_command& cmd) +{ + // Note we do an initial check to see if the loop will run at least once. 
+ // This allows us to get better type inference bounds on variables defined + // and used only inside the for loop (e.g. the index variable) + + // If we are a nested for loop we need to store the previous breaks + assert (! breaking); + unwind_protect prot; + prot.protect_var (breaks); + prot.protect_var (continues); + prot.protect_var (breaking); + breaks.clear (); + continues.clear (); + + // we need a variable for our iterator, because it is used in multiple blocks + std::stringstream ss; + ss << "#iter" << iterator_count++; + std::string iter_name = ss.str (); + jit_variable *iterator = create (iter_name); + vmap[iter_name] = iterator; + + jit_block *body = create ("for_body"); + append (body); + + jit_block *tail = create ("for_tail"); + + // do control expression, iter init, and condition check in prev_block (block) + jit_value *control = visit (cmd.control_expr ()); + jit_call *init_iter = create (jit_typeinfo::for_init, control); + block->append (init_iter); + block->append (create (iterator, init_iter)); + + jit_value *check = block->append (create (jit_typeinfo::for_check, + control, iterator)); + block->append (create (check, body, tail)); + block = body; + + // compute the syntactical iterator + jit_call *idx_rhs = create (jit_typeinfo::for_index, control, + iterator); + block->append (idx_rhs); + do_assign (cmd.left_hand_side (), idx_rhs); + + // do loop + tree_statement_list *pt_body = cmd.body (); + pt_body->accept (*this); + + if (breaking && continues.empty ()) + { + // WTF are you doing user? Every branch was a break, why did you have + // a loop??? Users are silly people... + finish_breaks (tail, breaks); + append (tail); + block = tail; + return; + } + + // check our condition, continues jump to this block + jit_block *check_block = create ("for_check"); + append (check_block); + + if (! 
breaking) + block->append (create (check_block)); + finish_breaks (check_block, continues); + + block = check_block; + const jit_operation& add_fn = jit_typeinfo::binary_op (octave_value::op_add); + jit_value *one = create (1); + jit_call *iter_inc = create (add_fn, iterator, one); + block->append (iter_inc); + block->append (create (iterator, iter_inc)); + check = block->append (create (jit_typeinfo::for_check, control, + iterator)); + block->append (create (check, body, tail)); + + // breaks will go to our tail + append (tail); + finish_breaks (tail, breaks); + block = tail; +} + +void +jit_convert::visit_complex_for_command (tree_complex_for_command&) +{ + fail (); +} + +void +jit_convert::visit_octave_user_script (octave_user_script&) +{ + fail (); +} + +void +jit_convert::visit_octave_user_function (octave_user_function&) +{ + fail (); +} + +void +jit_convert::visit_octave_user_function_header (octave_user_function&) +{ + fail (); +} + +void +jit_convert::visit_octave_user_function_trailer (octave_user_function&) +{ + fail (); +} + +void +jit_convert::visit_function_def (tree_function_def&) +{ + fail (); +} + +void +jit_convert::visit_identifier (tree_identifier& ti) +{ + result = get_variable (ti.name ()); +} + +void +jit_convert::visit_if_clause (tree_if_clause&) +{ + fail (); +} + +void +jit_convert::visit_if_command (tree_if_command& cmd) +{ + tree_if_command_list *lst = cmd.cmd_list (); + assert (lst); // jwe: Can this be null? + lst->accept (*this); +} + +void +jit_convert::visit_if_command_list (tree_if_command_list& lst) +{ + tree_if_clause *last = lst.back (); + size_t last_else = static_cast (last->is_else_clause ()); + + // entry_blocks represents the block you need to enter in order to execute + // the condition check for the ith clause. For the else, it is simply the + // else body. 
If there is no else body, then it is padded with the tail + std::vector entry_blocks (lst.size () + 1 - last_else); + std::vector branch_blocks (lst.size (), 0); // final blocks + entry_blocks[0] = block; + + // we need to construct blocks first, because they have jumps to eachother + tree_if_command_list::iterator iter = lst.begin (); + ++iter; + for (size_t i = 1; iter != lst.end (); ++iter, ++i) + { + tree_if_clause *tic = *iter; + if (tic->is_else_clause ()) + entry_blocks[i] = create ("else"); + else + entry_blocks[i] = create ("ifelse_cond"); + } + + jit_block *tail = create ("if_tail"); + if (! last_else) + entry_blocks[entry_blocks.size () - 1] = tail; + + size_t num_incomming = 0; // number of incomming blocks to our tail + iter = lst.begin (); + for (size_t i = 0; iter != lst.end (); ++iter, ++i) + { + tree_if_clause *tic = *iter; + block = entry_blocks[i]; + assert (block); + + if (i) // the first block is prev_block, so it has already been added + append (entry_blocks[i]); + + if (! tic->is_else_clause ()) + { + tree_expression *expr = tic->condition (); + jit_value *cond = visit (expr); + jit_call *check = create_checked (&jit_typeinfo::logically_true, + cond); + jit_block *body = create (i == 0 ? "if_body" + : "ifelse_body"); + append (body); + + jit_instruction *br = create (check, body, + entry_blocks[i + 1]); + block->append (br); + block = body; + } + + tree_statement_list *stmt_lst = tic->commands (); + assert (stmt_lst); // jwe: Can this be null? + stmt_lst->accept (*this); + + if (breaking) + breaking = false; + else + { + ++num_incomming; + block->append (create (tail)); + } + } + + if (num_incomming || ! 
last_else) + { + append (tail); + block = tail; + } + else + // every branch broke, so we don't have a tail + breaking = true; +} + +void +jit_convert::visit_index_expression (tree_index_expression& exp) +{ + std::pair res = resolve (exp); + jit_value *object = res.first; + jit_value *index = res.second; + + result = create_checked (jit_typeinfo::paren_subsref, object, index); +} + +void +jit_convert::visit_matrix (tree_matrix&) +{ + fail (); +} + +void +jit_convert::visit_cell (tree_cell&) +{ + fail (); +} + +void +jit_convert::visit_multi_assignment (tree_multi_assignment&) +{ + fail (); +} + +void +jit_convert::visit_no_op_command (tree_no_op_command&) +{ + fail (); +} + +void +jit_convert::visit_constant (tree_constant& tc) +{ + octave_value v = tc.rvalue1 (); + if (v.is_real_scalar () && v.is_double_type ()) + { + double dv = v.double_value (); + result = create (dv); + } + else if (v.is_range ()) + { + Range rv = v.range_value (); + result = create (rv); + } + else if (v.is_complex_scalar ()) + { + Complex cv = v.complex_value (); + result = create (cv); + } + else + fail ("Unknown constant"); +} + +void +jit_convert::visit_fcn_handle (tree_fcn_handle&) +{ + fail (); +} + +void +jit_convert::visit_parameter_list (tree_parameter_list&) +{ + fail (); +} + +void +jit_convert::visit_postfix_expression (tree_postfix_expression&) +{ + fail (); +} + +void +jit_convert::visit_prefix_expression (tree_prefix_expression&) +{ + fail (); +} + +void +jit_convert::visit_return_command (tree_return_command&) +{ + fail (); +} + +void +jit_convert::visit_return_list (tree_return_list&) +{ + fail (); +} + +void +jit_convert::visit_simple_assignment (tree_simple_assignment& tsa) +{ + if (tsa.op_type () != octave_value::op_asn_eq) + fail ("Unsupported assign"); + + // resolve rhs + tree_expression *rhs = tsa.right_hand_side (); + jit_value *rhsv = visit (rhs); + + result = do_assign (tsa.left_hand_side (), rhsv); +} + +void +jit_convert::visit_statement (tree_statement& stmt) +{ 
+ tree_command *cmd = stmt.command (); + tree_expression *expr = stmt.expression (); + + if (cmd) + visit (cmd); + else + { + // stolen from tree_evaluator::visit_statement + bool do_bind_ans = false; + + if (expr->is_identifier ()) + { + tree_identifier *id = dynamic_cast (expr); + + do_bind_ans = (! id->is_variable ()); + } + else + do_bind_ans = (! expr->is_assignment_expression ()); + + jit_value *expr_result = visit (expr); + + if (do_bind_ans) + do_assign ("ans", expr_result, expr->print_result ()); + else if (expr->is_identifier () && expr->print_result ()) + { + // FIXME: ugly hack, we need to come up with a way to pass + // nargout to visit_identifier + const jit_operation& fn = jit_typeinfo::print_value (); + jit_const_string *name = create (expr->name ()); + block->append (create (fn, name, expr_result)); + } + } +} + +void +jit_convert::visit_statement_list (tree_statement_list& lst) +{ + for (tree_statement_list::iterator iter = lst.begin (); iter != lst.end(); + ++iter) + { + tree_statement *elt = *iter; + // jwe: Can this ever be null? + assert (elt); + elt->accept (*this); + + if (breaking) + break; + } +} + +void +jit_convert::visit_switch_case (tree_switch_case&) +{ + fail (); +} + +void +jit_convert::visit_switch_case_list (tree_switch_case_list&) +{ + fail (); +} + +void +jit_convert::visit_switch_command (tree_switch_command&) +{ + fail (); +} + +void +jit_convert::visit_try_catch_command (tree_try_catch_command&) +{ + fail (); +} + +void +jit_convert::visit_unwind_protect_command (tree_unwind_protect_command&) +{ + fail (); +} + +void +jit_convert::visit_while_command (tree_while_command& wc) +{ + assert (! 
breaking); + unwind_protect prot; + prot.protect_var (breaks); + prot.protect_var (continues); + prot.protect_var (breaking); + breaks.clear (); + continues.clear (); + + jit_block *cond_check = create ("while_cond_check"); + block->append (create (cond_check)); + append (cond_check); + block = cond_check; + + tree_expression *expr = wc.condition (); + assert (expr && "While expression can not be null"); + jit_value *check = visit (expr); + check = create_checked (&jit_typeinfo::logically_true, check); + + jit_block *body = create ("while_body"); + append (body); + + jit_block *tail = create ("while_tail"); + block->append (create (check, body, tail)); + block = body; + + tree_statement_list *loop_body = wc.body (); + if (loop_body) + loop_body->accept (*this); + + finish_breaks (tail, breaks); + finish_breaks (cond_check, continues); + + if (! breaking) + block->append (create (cond_check)); + + append (tail); + block = tail; +} + +void +jit_convert::visit_do_until_command (tree_do_until_command&) +{ + fail (); +} + +void +jit_convert::append (jit_block *ablock) +{ + blocks.push_back (ablock); + ablock->stash_location (--blocks.end ()); +} + +void +jit_convert::insert_before (block_iterator iter, jit_block *ablock) +{ + iter = blocks.insert (iter, ablock); + ablock->stash_location (iter); +} + +void +jit_convert::insert_after (block_iterator iter, jit_block *ablock) +{ + ++iter; + insert_before (iter, ablock); +} + +jit_variable * +jit_convert::get_variable (const std::string& vname) +{ + vmap_t::iterator iter; + iter = vmap.find (vname); + if (iter != vmap.end ()) + return iter->second; + + jit_variable *var = create (vname); + octave_value val = symbol_table::find (vname); + jit_type *type = jit_typeinfo::type_of (val); + jit_extract_argument *extract; + extract = create (type, var); + entry_block->prepend (extract); + + return vmap[vname] = var; +} + +std::pair +jit_convert::resolve (tree_index_expression& exp) +{ + std::string type = exp.type_tags (); + if (! 
(type.size () == 1 && type[0] == '(')) + fail ("Unsupported index operation"); + + std::list args = exp.arg_lists (); + if (args.size () != 1) + fail ("Bad number of arguments in tree_index_expression"); + + tree_argument_list *arg_list = args.front (); + if (! arg_list) + fail ("null argument list"); + + if (arg_list->size () != 1) + fail ("Bad number of arguments in arg_list"); + + tree_expression *tree_object = exp.expression (); + jit_value *object = visit (tree_object); + tree_expression *arg0 = arg_list->front (); + jit_value *index = visit (arg0); + + return std::make_pair (object, index); +} + +jit_value * +jit_convert::do_assign (tree_expression *exp, jit_value *rhs, bool artificial) +{ + if (! exp) + fail ("NULL lhs in assign"); + + if (isa (exp)) + return do_assign (exp->name (), rhs, exp->print_result (), artificial); + else if (tree_index_expression *idx + = dynamic_cast (exp)) + { + std::pair res = resolve (*idx); + jit_value *object = res.first; + jit_value *index = res.second; + jit_call *new_object = create (&jit_typeinfo::paren_subsasgn, + object, index, rhs); + block->append (new_object); + do_assign (idx->expression (), new_object, true); + create_check (new_object); + + // FIXME: Will not work for values that must be release/grabed + return rhs; + } + else + fail ("Unsupported assignment"); +} + +jit_value * +jit_convert::do_assign (const std::string& lhs, jit_value *rhs, + bool print, bool artificial) +{ + jit_variable *var = get_variable (lhs); + jit_assign *assign = block->append (create (var, rhs)); + + if (artificial) + assign->mark_artificial (); + + if (print) + { + const jit_operation& print_fn = jit_typeinfo::print_value (); + jit_const_string *name = create (lhs); + block->append (create (print_fn, name, var)); + } + + return var; +} + +jit_value * +jit_convert::visit (tree& tee) +{ + result = 0; + tee.accept (*this); + + jit_value *ret = result; + result = 0; + return ret; +} + +void +jit_convert::append_users_term (jit_terminator 
*term) +{ + for (size_t i = 0; i < term->successor_count (); ++i) + { + if (term->alive (i)) + { + jit_block *succ = term->successor (i); + for (jit_block::iterator iter = succ->begin (); iter != succ->end () + && isa (*iter); ++iter) + push_worklist (*iter); + + jit_terminator *sterm = succ->terminator (); + if (sterm) + push_worklist (sterm); + } + } +} + +void +jit_convert::merge_blocks (void) +{ + std::vector dead; + for (block_list::iterator iter = blocks.begin (); iter != blocks.end (); + ++iter) + { + jit_block *b = *iter; + jit_block *merged = b->maybe_merge (); + + if (merged) + { + if (merged == final_block) + final_block = b; + + if (merged == entry_block) + entry_block = b; + + dead.push_back (merged); + } + } + + for (size_t i = 0; i < dead.size (); ++i) + blocks.erase (dead[i]->location ()); +} + +void +jit_convert::construct_ssa (void) +{ + merge_blocks (); + final_block->label (); + final_block->compute_idom (entry_block); + entry_block->compute_df (); + entry_block->create_dom_tree (); + + // insert phi nodes where needed, this is done on a per variable basis + for (vmap_t::iterator iter = vmap.begin (); iter != vmap.end (); ++iter) + { + jit_block::df_set visited, added_phi; + std::list ssa_worklist; + iter->second->use_blocks (visited); + ssa_worklist.insert (ssa_worklist.begin (), visited.begin (), + visited.end ()); + + while (ssa_worklist.size ()) + { + jit_block *b = ssa_worklist.front (); + ssa_worklist.pop_front (); + + for (jit_block::df_iterator diter = b->df_begin (); + diter != b->df_end (); ++diter) + { + jit_block *dblock = *diter; + if (! added_phi.count (dblock)) + { + jit_phi *phi = create (iter->second, + dblock->use_count ()); + dblock->prepend (phi); + added_phi.insert (dblock); + } + + if (! 
visited.count (dblock)) + { + ssa_worklist.push_back (dblock); + visited.insert (dblock); + } + } + } + } + + do_construct_ssa (*entry_block, entry_block->visit_count ()); +} + +void +jit_convert::do_construct_ssa (jit_block& ablock, size_t avisit_count) +{ + if (ablock.visited (avisit_count)) + return; + + // replace variables with their current SSA value + for (jit_block::iterator iter = ablock.begin (); iter != ablock.end (); ++iter) + { + jit_instruction *instr = *iter; + instr->construct_ssa (); + instr->push_variable (); + } + + // finish phi nodes of successors + for (size_t i = 0; i < ablock.successor_count (); ++i) + { + jit_block *finish = ablock.successor (i); + + for (jit_block::iterator iter = finish->begin (); iter != finish->end () + && isa (*iter);) + { + jit_phi *phi = static_cast (*iter); + jit_variable *var = phi->dest (); + if (var->has_top ()) + { + phi->add_incomming (&ablock, var->top ()); + ++iter; + } + else + { + // temporaries may have extranious phi nodes which can be removed + assert (! phi->use_count ()); + assert (var->name ().size () && var->name ()[0] == '#'); + iter = finish->remove (iter); + } + } + } + + for (size_t i = 0; i < ablock.dom_successor_count (); ++i) + do_construct_ssa (*ablock.dom_successor (i), avisit_count); + + ablock.pop_all (); +} + +void +jit_convert::remove_dead () +{ + block_list::iterator biter; + for (biter = blocks.begin (); biter != blocks.end (); ++biter) + { + jit_block *b = *biter; + if (b->alive ()) + { + for (jit_block::iterator iter = b->begin (); iter != b->end () + && isa (*iter);) + { + jit_phi *phi = static_cast (*iter); + if (phi->prune ()) + iter = b->remove (iter); + else + ++iter; + } + } + } + + for (biter = blocks.begin (); biter != blocks.end ();) + { + jit_block *b = *biter; + if (b->alive ()) + { + // FIXME: A special case for jit_error_check, if we generalize to + // we will need to change! + jit_terminator *term = b->terminator (); + if (term && term->successor_count () == 2 && ! 
term->alive (0)) + { + jit_block *succ = term->successor (1); + term->remove (); + jit_branch *abreak = b->append (create (succ)); + abreak->infer (); + } + + ++biter; + } + else + { + jit_terminator *term = b->terminator (); + if (term) + term->remove (); + biter = blocks.erase (biter); + } + } +} + +void +jit_convert::place_releases (void) +{ + std::set temporaries; + for (block_list::iterator iter = blocks.begin (); iter != blocks.end (); + ++iter) + { + jit_block& ablock = **iter; + if (ablock.id () != jit_block::NO_ID) + { + release_temp (ablock, temporaries); + release_dead_phi (ablock); + } + } +} + +void +jit_convert::release_temp (jit_block& ablock, std::set& temp) +{ + for (jit_block::iterator iter = ablock.begin (); iter != ablock.end (); + ++iter) + { + jit_instruction *instr = *iter; + + // check for temporaries that require release and live across + // multiple blocks + if (instr->needs_release ()) + { + jit_block *fu_block = instr->first_use_block (); + if (fu_block && fu_block != &ablock) + temp.insert (instr); + } + + if (isa (instr)) + { + // place releases for temporary arguments + for (size_t i = 0; i < instr->argument_count (); ++i) + { + jit_value *arg = instr->argument (i); + if (arg->needs_release ()) + { + jit_call *release = create (&jit_typeinfo::release, + arg); + release->infer (); + ablock.insert_after (iter, release); + ++iter; + temp.erase (arg); + } + } + } + } + + if (! temp.size () || ! 
isa (ablock.terminator ())) + return; + + // FIXME: If we support try/catch or unwind_protect final_block may not be the + // destination + jit_block *split = ablock.maybe_split (*this, final_block); + jit_terminator *term = split->terminator (); + for (std::set::const_iterator iter = temp.begin (); + iter != temp.end (); ++iter) + { + jit_value *value = *iter; + jit_call *release = create (&jit_typeinfo::release, value); + split->insert_before (term, release); + release->infer (); + } +} + +void +jit_convert::release_dead_phi (jit_block& ablock) +{ + jit_block::iterator iter = ablock.begin (); + while (iter != ablock.end () && isa (*iter)) + { + jit_phi *phi = static_cast (*iter); + ++iter; + + jit_use *use = phi->first_use (); + if (phi->use_count () == 1 && isa (use->user ())) + { + // instead of releasing on assign, release on all incomming branches, + // this can get rid of casts inside loops + for (size_t i = 0; i < phi->argument_count (); ++i) + { + jit_value *arg = phi->argument (i); + jit_block *inc = phi->incomming (i); + jit_block *split = inc->maybe_split (*this, ablock); + jit_terminator *term = split->terminator (); + jit_call *release = create (jit_typeinfo::release, arg); + release->infer (); + split->insert_before (term, release); + } + + phi->replace_with (0); + phi->remove (); + } + } +} + +void +jit_convert::simplify_phi (void) +{ + for (block_list::iterator biter = blocks.begin (); biter != blocks.end (); + ++biter) + { + jit_block &ablock = **biter; + for (jit_block::iterator iter = ablock.begin (); iter != ablock.end () + && isa (*iter); ++iter) + simplify_phi (*static_cast (*iter)); + } +} + +void +jit_convert::simplify_phi (jit_phi& phi) +{ + jit_block& pblock = *phi.parent (); + const jit_operation& cast_fn = jit_typeinfo::cast (phi.type ()); + jit_variable *dest = phi.dest (); + for (size_t i = 0; i < phi.argument_count (); ++i) + { + jit_value *arg = phi.argument (i); + if (arg->type () != phi.type ()) + { + jit_block *pred = 
phi.incomming (i); + jit_block *split = pred->maybe_split (*this, pblock); + jit_terminator *term = split->terminator (); + jit_instruction *cast = create (cast_fn, arg); + jit_assign *assign = create (dest, cast); + + split->insert_before (term, cast); + split->insert_before (term, assign); + cast->infer (); + assign->infer (); + phi.stash_argument (i, assign); + } + } +} + +void +jit_convert::finish_breaks (jit_block *dest, const block_list& lst) +{ + for (block_list::const_iterator iter = lst.begin (); iter != lst.end (); + ++iter) + { + jit_block *b = *iter; + b->append (create (dest)); + } +} + +// -------------------- jit_convert::convert_llvm -------------------- +llvm::Function * +jit_convert::convert_llvm::convert (llvm::Module *module, + const std::vector >& args, + const std::list& blocks, + const std::list& constants) +{ + jit_type *any = jit_typeinfo::get_any (); + + // argument is an array of octave_base_value*, or octave_base_value** + llvm::Type *arg_type = any->to_llvm (); // this is octave_base_value* + arg_type = arg_type->getPointerTo (); + llvm::FunctionType *ft = llvm::FunctionType::get (llvm::Type::getVoidTy (context), + arg_type, false); + function = llvm::Function::Create (ft, llvm::Function::ExternalLinkage, + "foobar", module); + + try + { + prelude = llvm::BasicBlock::Create (context, "prelude", function); + builder.SetInsertPoint (prelude); + + llvm::Value *arg = function->arg_begin (); + for (size_t i = 0; i < args.size (); ++i) + { + llvm::Value *loaded_arg = builder.CreateConstInBoundsGEP1_32 (arg, i); + arguments[args[i].first] = loaded_arg; + } + + std::list::const_iterator biter; + for (biter = blocks.begin (); biter != blocks.end (); ++biter) + { + jit_block *jblock = *biter; + llvm::BasicBlock *block = llvm::BasicBlock::Create (context, + jblock->name (), + function); + jblock->stash_llvm (block); + } + + jit_block *first = *blocks.begin (); + builder.CreateBr (first->to_llvm ()); + + // constants aren't in the IR, we visit 
those first + for (std::list::const_iterator iter = constants.begin (); + iter != constants.end (); ++iter) + if (! isa (*iter)) + visit (*iter); + + // convert all instructions + for (biter = blocks.begin (); biter != blocks.end (); ++biter) + visit (*biter); + + // now finish phi nodes + for (biter = blocks.begin (); biter != blocks.end (); ++biter) + { + jit_block& block = **biter; + for (jit_block::iterator piter = block.begin (); + piter != block.end () && isa (*piter); ++piter) + { + jit_instruction *phi = *piter; + finish_phi (static_cast (phi)); + } + } + + jit_block *last = blocks.back (); + builder.SetInsertPoint (last->to_llvm ()); + builder.CreateRetVoid (); + } catch (const jit_fail_exception& e) + { + function->eraseFromParent (); + throw; + } + + return function; +} + +void +jit_convert::convert_llvm::finish_phi (jit_phi *phi) +{ + llvm::PHINode *llvm_phi = phi->to_llvm (); + for (size_t i = 0; i < phi->argument_count (); ++i) + { + llvm::BasicBlock *pred = phi->incomming_llvm (i); + llvm_phi->addIncoming (phi->argument_llvm (i), pred); + } +} + +void +jit_convert::convert_llvm::visit (jit_const_string& cs) +{ + cs.stash_llvm (builder.CreateGlobalStringPtr (cs.value ())); +} + +void +jit_convert::convert_llvm::visit (jit_const_bool& cb) +{ + cb.stash_llvm (llvm::ConstantInt::get (cb.type_llvm (), cb.value ())); +} + +void +jit_convert::convert_llvm::visit (jit_const_scalar& cs) +{ + cs.stash_llvm (llvm::ConstantFP::get (cs.type_llvm (), cs.value ())); +} + +void +jit_convert::convert_llvm::visit (jit_const_complex& cc) +{ + llvm::Type *scalar_t = jit_typeinfo::get_scalar_llvm (); + llvm::Constant *values[2]; + Complex value = cc.value (); + values[0] = llvm::ConstantFP::get (scalar_t, value.real ()); + values[1] = llvm::ConstantFP::get (scalar_t, value.imag ()); + cc.stash_llvm (llvm::ConstantVector::get (values)); +} + +void jit_convert::convert_llvm::visit (jit_const_index& ci) +{ + ci.stash_llvm (llvm::ConstantInt::get (ci.type_llvm (), ci.value 
())); +} + +void +jit_convert::convert_llvm::visit (jit_const_range& cr) +{ + llvm::StructType *stype = llvm::cast(cr.type_llvm ()); + llvm::Type *scalar_t = jit_typeinfo::get_scalar_llvm (); + llvm::Type *idx = jit_typeinfo::get_index_llvm (); + const jit_range& rng = cr.value (); + + llvm::Constant *constants[4]; + constants[0] = llvm::ConstantFP::get (scalar_t, rng.base); + constants[1] = llvm::ConstantFP::get (scalar_t, rng.limit); + constants[2] = llvm::ConstantFP::get (scalar_t, rng.inc); + constants[3] = llvm::ConstantInt::get (idx, rng.nelem); + + llvm::Value *as_llvm; + as_llvm = llvm::ConstantStruct::get (stype, + llvm::makeArrayRef (constants, 4)); + cr.stash_llvm (as_llvm); +} + +void +jit_convert::convert_llvm::visit (jit_block& b) +{ + llvm::BasicBlock *block = b.to_llvm (); + builder.SetInsertPoint (block); + for (jit_block::iterator iter = b.begin (); iter != b.end (); ++iter) + visit (*iter); +} + +void +jit_convert::convert_llvm::visit (jit_branch& b) +{ + b.stash_llvm (builder.CreateBr (b.successor_llvm ())); +} + +void +jit_convert::convert_llvm::visit (jit_cond_branch& cb) +{ + llvm::Value *cond = cb.cond_llvm (); + llvm::Value *br; + br = builder.CreateCondBr (cond, cb.successor_llvm (0), + cb.successor_llvm (1)); + cb.stash_llvm (br); +} + +void +jit_convert::convert_llvm::visit (jit_call& call) +{ + llvm::Value *ret = create_call (call.overload (), call.arguments ()); + call.stash_llvm (ret); +} + +void +jit_convert::convert_llvm::visit (jit_extract_argument& extract) +{ + llvm::Value *arg = arguments[extract.name ()]; + assert (arg); + arg = builder.CreateLoad (arg); + + jit_value *jarg = jthis.create (jit_typeinfo::get_any (), arg); + extract.stash_llvm (create_call (extract.overload (), jarg)); +} + +void +jit_convert::convert_llvm::visit (jit_store_argument& store) +{ + llvm::Value *arg_value = create_call (store.overload (), store.result ()); + + llvm::Value *arg = arguments[store.name ()]; + store.stash_llvm (builder.CreateStore 
(arg_value, arg)); +} + +void +jit_convert::convert_llvm::visit (jit_phi& phi) +{ + // we might not have converted all incoming branches, so we don't + // set incomming branches now + llvm::PHINode *node = llvm::PHINode::Create (phi.type_llvm (), + phi.argument_count ()); + builder.Insert (node); + phi.stash_llvm (node); +} + +void +jit_convert::convert_llvm::visit (jit_variable&) +{ + fail ("ERROR: SSA construction should remove all variables"); +} + +void +jit_convert::convert_llvm::visit (jit_error_check& check) +{ + llvm::Value *cond = jit_typeinfo::insert_error_check (); + llvm::Value *br = builder.CreateCondBr (cond, check.successor_llvm (0), + check.successor_llvm (1)); + check.stash_llvm (br); +} + +void +jit_convert::convert_llvm::visit (jit_assign& assign) +{ + assign.stash_llvm (assign.src ()->to_llvm ()); + + if (assign.artificial ()) + return; + + jit_value *new_value = assign.src (); + if (isa (new_value)) + { + const jit_operation::overload& ol + = jit_typeinfo::get_grab (new_value->type ()); + if (ol.function) + assign.stash_llvm (create_call (ol, new_value)); + } + + jit_value *overwrite = assign.overwrite (); + if (isa (overwrite)) + { + const jit_operation::overload& ol + = jit_typeinfo::get_release (overwrite->type ()); + if (ol.function) + create_call (ol, overwrite); + } +} + +void +jit_convert::convert_llvm::visit (jit_argument&) +{} + +llvm::Value * +jit_convert::convert_llvm::create_call (const jit_operation::overload& ol, + const std::vector& jargs) +{ + llvm::IRBuilder<> alloca_inserter (prelude, prelude->begin ()); + + llvm::Function *fun = ol.function; + if (! 
fun) + fail ("Missing overload"); + + const llvm::Function::ArgumentListType& alist = fun->getArgumentList (); + size_t nargs = alist.size (); + bool sret = false; + if (nargs != jargs.size ()) + { + // first argument is the structure return value + assert (nargs == jargs.size () + 1); + sret = true; + } + + std::vector args (nargs); + llvm::Function::arg_iterator llvm_arg = fun->arg_begin (); + if (sret) + { + args[0] = alloca_inserter.CreateAlloca (ol.result->to_llvm ()); + ++llvm_arg; + } + + for (size_t i = 0; i < jargs.size (); ++i, ++llvm_arg) + { + llvm::Value *arg = jargs[i]->to_llvm (); + llvm::Type *arg_type = arg->getType (); + llvm::Type *llvm_arg_type = llvm_arg->getType (); + + if (arg_type == llvm_arg_type) + args[i + sret] = arg; + else + { + // pass structure by pointer + assert (arg_type->getPointerTo () == llvm_arg_type); + llvm::Value *new_arg = alloca_inserter.CreateAlloca (arg_type); + builder.CreateStore (arg, new_arg); + args[i + sret] = new_arg; + } + } + + llvm::Value *llvm_call = builder.CreateCall (fun, args); + return sret ? builder.CreateLoad (args[0]) : llvm_call; +} + +llvm::Value * +jit_convert::convert_llvm::create_call (const jit_operation::overload& ol, + const std::vector& uses) +{ + std::vector values (uses.size ()); + for (size_t i = 0; i < uses.size (); ++i) + values[i] = uses[i].value (); + + return create_call (ol, values); +} + +// -------------------- tree_jit -------------------- + +tree_jit::tree_jit (void) : module (0), engine (0) +{ +} + +tree_jit::~tree_jit (void) +{} + +bool +tree_jit::execute (tree_simple_for_command& cmd) +{ + if (! initialize ()) + return false; + + jit_info *info = cmd.get_info (); + if (! info || ! info->match ()) + { + delete info; + info = new jit_info (*this, cmd); + cmd.stash_info (info); + } + + return info->execute (); +} + +bool +tree_jit::initialize (void) +{ + if (engine) + return true; + + if (! 
module) + { + llvm::InitializeNativeTarget (); + module = new llvm::Module ("octave", context); + } + + // sometimes this fails pre main + engine = llvm::ExecutionEngine::createJIT (module); + + if (! engine) + return false; + + module_pass_manager = new llvm::PassManager (); + module_pass_manager->add (llvm::createAlwaysInlinerPass ()); + + pass_manager = new llvm::FunctionPassManager (module); + pass_manager->add (new llvm::TargetData(*engine->getTargetData ())); + pass_manager->add (llvm::createBasicAliasAnalysisPass ()); + pass_manager->add (llvm::createPromoteMemoryToRegisterPass ()); + pass_manager->add (llvm::createInstructionCombiningPass ()); + pass_manager->add (llvm::createReassociatePass ()); + pass_manager->add (llvm::createGVNPass ()); + pass_manager->add (llvm::createCFGSimplificationPass ()); + pass_manager->doInitialization (); + + jit_typeinfo::initialize (module, engine); + + return true; +} + + +void +tree_jit::optimize (llvm::Function *fn) +{ + module_pass_manager->run (*module); + pass_manager->run (*fn); + +#ifdef OCTAVE_JIT_DEBUG + std::string error; + llvm::raw_fd_ostream fout ("test.bc", error, + llvm::raw_fd_ostream::F_Binary); + llvm::WriteBitcodeToFile (module, fout); +#endif +} + +// -------------------- jit_info -------------------- +jit_info::jit_info (tree_jit& tjit, tree& tee) + : engine (tjit.get_engine ()), llvm_function (0) +{ + try + { + jit_convert conv (tjit.get_module (), tee); + llvm_function = conv.get_function (); + arguments = conv.get_arguments (); + bounds = conv.get_bounds (); + } + catch (const jit_fail_exception& e) + { +#ifdef OCTAVE_JIT_DEBUG + if (e.known ()) + std::cout << "jit fail: " << e.what () << std::endl; +#endif + } + + if (! 
llvm_function) + { + function = 0; + return; + } + + tjit.optimize (llvm_function); + +#ifdef OCTAVE_JIT_DEBUG + std::cout << "-------------------- optimized llvm ir --------------------\n"; + llvm::raw_os_ostream llvm_cout (std::cout); + llvm_function->print (llvm_cout); + std::cout << std::endl; +#endif + + void *void_fn = engine->getPointerToFunction (llvm_function); + function = reinterpret_cast (void_fn); +} + +jit_info::~jit_info (void) +{ + if (llvm_function) + llvm_function->eraseFromParent (); +} + +bool +jit_info::execute (void) const +{ + if (! function) + return false; + + std::vector real_arguments (arguments.size ()); + for (size_t i = 0; i < arguments.size (); ++i) + { + if (arguments[i].second) + { + octave_value ¤t = symbol_table::varref (arguments[i].first); + octave_base_value *obv = current.internal_rep (); + obv->grab (); + real_arguments[i] = obv; + current = octave_value (); + } + } + + function (&real_arguments[0]); + + for (size_t i = 0; i < arguments.size (); ++i) + symbol_table::varref (arguments[i].first) = real_arguments[i]; + + return true; +} + +bool +jit_info::match (void) const +{ + if (! function) + return true; + + for (size_t i = 0; i < bounds.size (); ++i) + { + const std::string& arg_name = bounds[i].second; + octave_value value = symbol_table::find (arg_name); + jit_type *type = jit_typeinfo::type_of (value); + + // FIXME: Check for a parent relationship + if (type != bounds[i].first) + return false; + } + + return true; +} +#endif diff --git a/src/pt-jit.h b/src/pt-jit.h new file mode 100644 --- /dev/null +++ b/src/pt-jit.h @@ -0,0 +1,2358 @@ +/* + +Copyright (C) 2012 Max Brister + +This file is part of Octave. + +Octave is free software; you can redistribute it and/or modify it +under the terms of the GNU General Public License as published by the +Free Software Foundation; either version 3 of the License, or (at your +option) any later version. 
+ +Octave is distributed in the hope that it will be useful, but WITHOUT +ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License +for more details. + +You should have received a copy of the GNU General Public License +along with Octave; see the file COPYING. If not, see +. + +*/ + +#if !defined (octave_tree_jit_h) +#define octave_tree_jit_h 1 + +#ifdef HAVE_LLVM + +#include +#include +#include +#include +#include +#include + +#include "Array.h" +#include "Range.h" +#include "pt-walk.h" + +// -------------------- Current status -------------------- +// Simple binary operations (+-*/) on octave_scalar's (doubles) are optimized. +// a = 5; +// b = a * 5 + a; +// +// Indexing matrices with scalars works. +// +// if, elseif, else, break, continue, and for compile. Compilation is triggered +// at the start of a simple for loop. +// +// The octave low level IR is a linear IR, it works by converting everything to +// calls to jit_operations. This turns expressions like c = a + b into +// c = call binary+ (a, b) +// The jit_operations contain information about overloads for different types. +// For, example, if we know a and b are scalars, then c must also be a scalar. +// +// Support for function calls is in progress. Currently, calls to sin with a +// scalar argument will compile. +// +// TODO: +// 1. Function calls (In progress) +// 2. Cleanup/documentation +// 3. ... 
+// --------------------------------------------------------- + + +// we don't want to include llvm headers here, as they require +// __STDC_LIMIT_MACROS and __STDC_CONSTANT_MACROS be defined in the entire +// compilation unit +namespace llvm +{ + class Value; + class Module; + class FunctionPassManager; + class PassManager; + class ExecutionEngine; + class Function; + class BasicBlock; + class LLVMContext; + class Type; + class StructType; + class Twine; + class GlobalVariable; + class TerminatorInst; + class PHINode; +} + +class octave_base_value; +class octave_builtin; +class octave_value; +class tree; +class tree_expression; + +template +class jit_internal_node; + +// jit_internal_list and jit_internal_node implement generic embedded doubly +// linked lists. List items extend from jit_internal_list, and can be placed +// in nodes of type jit_internal_node. We use CRTP twice. +template +class +jit_internal_list +{ + friend class jit_internal_node; +public: + jit_internal_list (void) : use_head (0), use_tail (0), muse_count (0) {} + + virtual ~jit_internal_list (void) + { + while (use_head) + use_head->stash_value (0); + } + + NODE_T *first_use (void) const { return use_head; } + + size_t use_count (void) const { return muse_count; } +private: + NODE_T *use_head; + NODE_T *use_tail; + size_t muse_count; +}; + +// a node for internal linked lists +template +class +jit_internal_node +{ +public: + typedef jit_internal_list jit_ilist; + + jit_internal_node (void) : mvalue (0), mnext (0), mprev (0) {} + + ~jit_internal_node (void) { remove (); } + + LIST_T *value (void) const { return mvalue; } + + void stash_value (LIST_T *avalue) + { + remove (); + + mvalue = avalue; + + if (mvalue) + { + jit_ilist *ilist = mvalue; + NODE_T *sthis = static_cast (this); + if (ilist->use_head) + { + ilist->use_tail->mnext = sthis; + mprev = ilist->use_tail; + } + else + ilist->use_head = sthis; + + ilist->use_tail = sthis; + ++ilist->muse_count; + } + } + + NODE_T *next (void) const { 
return mnext; } + + NODE_T *prev (void) const { return mprev; } +private: + void remove () + { + if (mvalue) + { + jit_ilist *ilist = mvalue; + if (mprev) + mprev->mnext = mnext; + else + // we are the use_head + ilist->use_head = mnext; + + if (mnext) + mnext->mprev = mprev; + else + // we are the use tail + ilist->use_tail = mprev; + + mnext = mprev = 0; + --ilist->muse_count; + mvalue = 0; + } + } + + LIST_T *mvalue; + NODE_T *mnext; + NODE_T *mprev; +}; + +// Use like: isa (value) +// basically just a short cut type typing dyanmic_cast. +template +bool isa (U *value) +{ + return dynamic_cast (value); +} + +// jit_range is compatable with the llvm range structure +struct +jit_range +{ + jit_range (const Range& from) : base (from.base ()), limit (from.limit ()), + inc (from.inc ()), nelem (from.nelem ()) + {} + + operator Range () const + { + return Range (base, limit, inc); + } + + bool all_elements_are_ints () const; + + double base; + double limit; + double inc; + octave_idx_type nelem; +}; + +std::ostream& operator<< (std::ostream& os, const jit_range& rng); + +// jit_array is compatable with the llvm array/matrix structures +template +struct +jit_array +{ + jit_array (T& from) : array (new T (from)) + { + update (); + } + + void update (void) + { + ref_count = array->jit_ref_count (); + slice_data = array->jit_slice_data () - 1; + slice_len = array->capacity (); + dimensions = array->jit_dimensions (); + } + + void update (T *aarray) + { + array = aarray; + update (); + } + + operator T () const + { + return *array; + } + + int *ref_count; + + U *slice_data; + octave_idx_type slice_len; + octave_idx_type *dimensions; + + T *array; +}; + +typedef jit_array jit_matrix; + +std::ostream& operator<< (std::ostream& os, const jit_matrix& mat); + +// Used to keep track of estimated (infered) types during JIT. This is a +// hierarchical type system which includes both concrete and abstract types. +// +// Current, we only support any and scalar types. 
If we can't figure out what +// type a variable is, we assign it the any type. This allows us to generate +// code even for the case of poor type inference. +class +jit_type +{ +public: + jit_type (const std::string& aname, jit_type *aparent, llvm::Type *allvm_type, + int aid) : + mname (aname), mparent (aparent), llvm_type (allvm_type), mid (aid), + mdepth (aparent ? aparent->mdepth + 1 : 0) + {} + + // a user readable type name + const std::string& name (void) const { return mname; } + + // a unique id for the type + int type_id (void) const { return mid; } + + // An abstract base type, may be null + jit_type *parent (void) const { return mparent; } + + // convert to an llvm type + llvm::Type *to_llvm (void) const { return llvm_type; } + + // how this type gets passed as a function argument + llvm::Type *to_llvm_arg (void) const; + + size_t depth (void) const { return mdepth; } +private: + std::string mname; + jit_type *mparent; + llvm::Type *llvm_type; + int mid; + size_t mdepth; +}; + +// seperate print function to allow easy printing if type is null +std::ostream& jit_print (std::ostream& os, jit_type *atype); + +// Keeps track of overloads for a builtin function. Used for both type inference +// and code generation. 
+class +jit_operation +{ +public: + struct + overload + { + overload (void) : function (0), can_error (false), result (0) {} + +#define ASSIGN_ARG(i) arguments[i] = arg ## i; +#define OVERLOAD_CTOR(N) \ + overload (llvm::Function *f, bool e, jit_type *ret, \ + OCT_MAKE_DECL_LIST (jit_type *, arg, N)) \ + : function (f), can_error (e), result (ret), arguments (N) \ + { \ + OCT_ITERATE_MACRO (ASSIGN_ARG, N); \ + } + + OVERLOAD_CTOR (1) + OVERLOAD_CTOR (2) + OVERLOAD_CTOR (3) + +#undef ASSIGN_ARG +#undef OVERLOAD_CTOR + + overload (llvm::Function *f, bool e, jit_type *r, + const std::vector& aarguments) + : function (f), can_error (e), result (r), arguments (aarguments) + {} + + llvm::Function *function; + bool can_error; + jit_type *result; + std::vector arguments; + }; + + void add_overload (const overload& func) + { + add_overload (func, func.arguments); + } + +#define ADD_OVERLOAD(N) \ + void add_overload (llvm::Function *f, bool e, jit_type *ret, \ + OCT_MAKE_DECL_LIST (jit_type *, arg, N)) \ + { \ + overload ol (f, e, ret, OCT_MAKE_ARG_LIST (arg, N)); \ + add_overload (ol); \ + } + + ADD_OVERLOAD (1); + ADD_OVERLOAD (2); + ADD_OVERLOAD (3); + +#undef ADD_OVERLOAD + + void add_overload (llvm::Function *f, bool e, jit_type *r, + const std::vector& args) + { + overload ol (f, e, r, args); + add_overload (ol); + } + + void add_overload (const overload& func, + const std::vector& args); + + const overload& get_overload (const std::vector& types) const; + + const overload& get_overload (jit_type *arg0) const + { + std::vector types (1); + types[0] = arg0; + return get_overload (types); + } + + const overload& get_overload (jit_type *arg0, jit_type *arg1) const + { + std::vector types (2); + types[0] = arg0; + types[1] = arg1; + return get_overload (types); + } + + jit_type *get_result (const std::vector& types) const + { + const overload& temp = get_overload (types); + return temp.result; + } + + jit_type *get_result (jit_type *arg0, jit_type *arg1) const + { + const 
overload& temp = get_overload (arg0, arg1); + return temp.result; + } + + const std::string& name (void) const { return mname; } + + void stash_name (const std::string& aname) { mname = aname; } +private: + Array to_idx (const std::vector& types) const; + + std::vector > overloads; + + std::string mname; +}; + +// Get information and manipulate jit types. +class +jit_typeinfo +{ +public: + static void initialize (llvm::Module *m, llvm::ExecutionEngine *e); + + static jit_type *join (jit_type *lhs, jit_type *rhs) + { + return instance->do_join (lhs, rhs); + } + + static jit_type *get_any (void) { return instance->any; } + + static jit_type *get_matrix (void) { return instance->matrix; } + + static jit_type *get_scalar (void) { return instance->scalar; } + + static llvm::Type *get_scalar_llvm (void) + { return instance->scalar->to_llvm (); } + + static jit_type *get_range (void) { return instance->range; } + + static jit_type *get_string (void) { return instance->string; } + + static jit_type *get_bool (void) { return instance->boolean; } + + static jit_type *get_index (void) { return instance->index; } + + static llvm::Type *get_index_llvm (void) + { return instance->index->to_llvm (); } + + static jit_type *get_complex (void) { return instance->complex; } + + static jit_type *type_of (const octave_value& ov) + { + return instance->do_type_of (ov); + } + + static const jit_operation& binary_op (int op) + { + return instance->do_binary_op (op); + } + + static const jit_operation& grab (void) { return instance->grab_fn; } + + static const jit_operation::overload& get_grab (jit_type *type) + { + return instance->grab_fn.get_overload (type); + } + + static const jit_operation& release (void) + { + return instance->release_fn; + } + + static const jit_operation::overload& get_release (jit_type *type) + { + return instance->release_fn.get_overload (type); + } + + static const jit_operation& print_value (void) + { + return instance->print_fn; + } + + static const 
jit_operation& for_init (void) + { + return instance->for_init_fn; + } + + static const jit_operation& for_check (void) + { + return instance->for_check_fn; + } + + static const jit_operation& for_index (void) + { + return instance->for_index_fn; + } + + static const jit_operation& make_range (void) + { + return instance->make_range_fn; + } + + static const jit_operation& paren_subsref (void) + { + return instance->paren_subsref_fn; + } + + static const jit_operation& paren_subsasgn (void) + { + return instance->paren_subsasgn_fn; + } + + static const jit_operation& logically_true (void) + { + return instance->logically_true_fn; + } + + static const jit_operation& cast (jit_type *result) + { + return instance->do_cast (result); + } + + static const jit_operation::overload& cast (jit_type *to, jit_type *from) + { + return instance->do_cast (to, from); + } + + static llvm::Value *insert_error_check (void) + { + return instance->do_insert_error_check (); + } +private: + jit_typeinfo (llvm::Module *m, llvm::ExecutionEngine *e); + + // FIXME: Do these methods really need to be in jit_typeinfo? + jit_type *do_join (jit_type *lhs, jit_type *rhs) + { + // empty case + if (! lhs) + return rhs; + + if (! rhs) + return lhs; + + // check for a shared parent + while (lhs != rhs) + { + if (lhs->depth () > rhs->depth ()) + lhs = lhs->parent (); + else if (lhs->depth () < rhs->depth ()) + rhs = rhs->parent (); + else + { + // we MUST have depth > 0 as any is the base type of everything + do + { + lhs = lhs->parent (); + rhs = rhs->parent (); + } + while (lhs != rhs); + } + } + + return lhs; + } + + jit_type *do_difference (jit_type *lhs, jit_type *) + { + // FIXME: Maybe we can do something smarter? 
+ return lhs; + } + + jit_type *do_type_of (const octave_value &ov) const; + + const jit_operation& do_binary_op (int op) const + { + assert (static_cast(op) < binary_ops.size ()); + return binary_ops[op]; + } + + const jit_operation& do_cast (jit_type *to) + { + static jit_operation null_function; + if (! to) + return null_function; + + size_t id = to->type_id (); + if (id >= casts.size ()) + return null_function; + return casts[id]; + } + + const jit_operation::overload& do_cast (jit_type *to, jit_type *from) + { + return do_cast (to).get_overload (from); + } + + jit_type *new_type (const std::string& name, jit_type *parent, + llvm::Type *llvm_type); + + + void add_print (jit_type *ty, void *call); + + void add_binary_op (jit_type *ty, int op, int llvm_op); + + void add_binary_icmp (jit_type *ty, int op, int llvm_op); + + void add_binary_fcmp (jit_type *ty, int op, int llvm_op); + + + llvm::Function *create_function (const llvm::Twine& name, llvm::Type *ret) + { + std::vector args; + return create_function (name, ret, args); + } + +#define ASSIGN_ARG(i) args[i] = arg ## i; +#define CREATE_FUNCTIONT(TYPE, N) \ + llvm::Function *create_function (const llvm::Twine& name, TYPE *ret, \ + OCT_MAKE_DECL_LIST (TYPE *, arg, N)) \ + { \ + std::vector args (N); \ + OCT_ITERATE_MACRO (ASSIGN_ARG, N); \ + return create_function (name, ret, args); \ + } + +#define CREATE_FUNCTION(N) \ + CREATE_FUNCTIONT(llvm::Type, N) \ + CREATE_FUNCTIONT(jit_type, N) + + CREATE_FUNCTION(1) + CREATE_FUNCTION(2) + CREATE_FUNCTION(3) + CREATE_FUNCTION(4) + +#undef ASSIGN_ARG +#undef CREATE_FUNCTIONT +#undef CREATE_FUNCTION + + llvm::Function *create_function (const llvm::Twine& name, jit_type *ret, + const std::vector& args); + + llvm::Function *create_function (const llvm::Twine& name, llvm::Type *ret, + const std::vector& args); + + llvm::Function *create_identity (jit_type *type); + + llvm::Value *do_insert_error_check (void); + + void add_builtin (const std::string& name); + + void 
register_intrinsic (const std::string& name, size_t id, + jit_type *result, jit_type *arg0) + { + std::vector args (1, arg0); + register_intrinsic (name, id, result, args); + } + + void register_intrinsic (const std::string& name, size_t id, jit_type *result, + const std::vector& args); + + void register_generic (const std::string& name, jit_type *result, + jit_type *arg0) + { + std::vector args (1, arg0); + register_generic (name, result, args); + } + + void register_generic (const std::string& name, jit_type *result, + const std::vector& args); + + octave_builtin *find_builtin (const std::string& name); + + llvm::Function *mirror_binary (llvm::Function *fn); + + llvm::Function *wrap_complex (llvm::Function *wrap); + + llvm::Value *pack_complex (llvm::Value *cplx); + + llvm::Value *unpack_complex (llvm::Value *result); + + llvm::Value *complex_real (llvm::Value *cx); + + llvm::Value *complex_real (llvm::Value *cx, llvm::Value *real); + + llvm::Value *complex_imag (llvm::Value *cx); + + llvm::Value *complex_imag (llvm::Value *cx, llvm::Value *imag); + + llvm::Value *complex_new (llvm::Value *real, llvm::Value *imag); + + static jit_typeinfo *instance; + + llvm::Module *module; + llvm::ExecutionEngine *engine; + int next_id; + + llvm::GlobalVariable *lerror_state; + + std::vector id_to_type; + jit_type *any; + jit_type *matrix; + jit_type *scalar; + jit_type *range; + jit_type *string; + jit_type *boolean; + jit_type *index; + jit_type *complex; + jit_type *unknown_function; + std::map builtins; + + llvm::StructType *complex_ret; + + std::vector binary_ops; + jit_operation grab_fn; + jit_operation release_fn; + jit_operation print_fn; + jit_operation for_init_fn; + jit_operation for_check_fn; + jit_operation for_index_fn; + jit_operation logically_true_fn; + jit_operation make_range_fn; + jit_operation paren_subsref_fn; + jit_operation paren_subsasgn_fn; + + // type id -> cast function TO that type + std::vector casts; + + // type id -> identity function + 
std::vector identities; +}; + +// The low level octave jit ir +// this ir is close to llvm, but contains information for doing type inference. +// We convert the octave parse tree to this IR directly. + +#define JIT_VISIT_IR_NOTEMPLATE \ + JIT_METH(block); \ + JIT_METH(branch); \ + JIT_METH(cond_branch); \ + JIT_METH(call); \ + JIT_METH(extract_argument); \ + JIT_METH(store_argument); \ + JIT_METH(phi); \ + JIT_METH(variable); \ + JIT_METH(error_check); \ + JIT_METH(assign) \ + JIT_METH(argument) + +#define JIT_VISIT_IR_CONST \ + JIT_METH(const_bool); \ + JIT_METH(const_scalar); \ + JIT_METH(const_complex); \ + JIT_METH(const_index); \ + JIT_METH(const_string); \ + JIT_METH(const_range) + +#define JIT_VISIT_IR_CLASSES \ + JIT_VISIT_IR_NOTEMPLATE \ + JIT_VISIT_IR_CONST + +// forward declare all ir classes +#define JIT_METH(cname) \ + class jit_ ## cname; + +JIT_VISIT_IR_NOTEMPLATE + +#undef JIT_METH + +class jit_convert; + +// ABCs which aren't included in JIT_VISIT_IR_ALL +class jit_instruction; +class jit_terminator; + +template +class jit_const; + +typedef jit_const jit_const_bool; +typedef jit_const jit_const_scalar; +typedef jit_const jit_const_complex; +typedef jit_const jit_const_index; + +typedef jit_const jit_const_string; +typedef jit_const +jit_const_range; + +class jit_ir_walker; +class jit_use; + +class +jit_value : public jit_internal_list +{ +public: + jit_value (void) : llvm_value (0), ty (0), mlast_use (0), + min_worklist (false) {} + + virtual ~jit_value (void); + + bool in_worklist (void) const + { + return min_worklist; + } + + void stash_in_worklist (bool ain_worklist) + { + min_worklist = ain_worklist; + } + + // The block of the first use which is not a jit_error_check + // So this is not necessarily first_use ()->parent (). + jit_block *first_use_block (void); + + // replace all uses with + virtual void replace_with (jit_value *value); + + jit_type *type (void) const { return ty; } + + llvm::Type *type_llvm (void) const + { + return ty ? 
ty->to_llvm () : 0; + } + + const std::string& type_name (void) const + { + return ty->name (); + } + + void stash_type (jit_type *new_ty) { ty = new_ty; } + + std::string print_string (void) + { + std::stringstream ss; + print (ss); + return ss.str (); + } + + jit_instruction *last_use (void) const { return mlast_use; } + + void stash_last_use (jit_instruction *alast_use) + { + mlast_use = alast_use; + } + + virtual bool needs_release (void) const { return false; } + + virtual std::ostream& print (std::ostream& os, size_t indent = 0) const = 0; + + virtual std::ostream& short_print (std::ostream& os) const + { return print (os); } + + virtual void accept (jit_ir_walker& walker) = 0; + + bool has_llvm (void) const + { + return llvm_value; + } + + llvm::Value *to_llvm (void) const + { + assert (llvm_value); + return llvm_value; + } + + void stash_llvm (llvm::Value *compiled) + { + llvm_value = compiled; + } + +protected: + std::ostream& print_indent (std::ostream& os, size_t indent = 0) const + { + for (size_t i = 0; i < indent * 8; ++i) + os << " "; + return os; + } + + llvm::Value *llvm_value; +private: + jit_type *ty; + jit_instruction *mlast_use; + bool min_worklist; +}; + +std::ostream& operator<< (std::ostream& os, const jit_value& value); +std::ostream& jit_print (std::ostream& os, jit_value *avalue); + +class +jit_use : public jit_internal_node +{ +public: + jit_use (void) : muser (0), mindex (0) {} + + // we should really have a move operator, but not until c++11 :( + jit_use (const jit_use& use) : muser (0), mindex (0) + { + *this = use; + } + + jit_use& operator= (const jit_use& use) + { + stash_value (use.value (), use.user (), use.index ()); + return *this; + } + + size_t index (void) const { return mindex; } + + jit_instruction *user (void) const { return muser; } + + jit_block *user_parent (void) const; + + std::list user_parent_location (void) const; + + void stash_value (jit_value *avalue, jit_instruction *auser = 0, + size_t aindex = -1) + { + 
jit_internal_node::stash_value (avalue); + mindex = aindex; + muser = auser; + } +private: + jit_instruction *muser; + size_t mindex; +}; + +class +jit_instruction : public jit_value +{ +public: + // FIXME: this code could be so much pretier with varadic templates... + jit_instruction (void) : mid (next_id ()), mparent (0) + {} + + jit_instruction (size_t nargs) : mid (next_id ()), mparent (0) + { + already_infered.reserve (nargs); + marguments.reserve (nargs); + } + +#define STASH_ARG(i) stash_argument (i, arg ## i); +#define JIT_INSTRUCTION_CTOR(N) \ + jit_instruction (OCT_MAKE_DECL_LIST (jit_value *, arg, N)) \ + : already_infered (N), marguments (N), mid (next_id ()), mparent (0) \ + { \ + OCT_ITERATE_MACRO (STASH_ARG, N); \ + } + + JIT_INSTRUCTION_CTOR(1) + JIT_INSTRUCTION_CTOR(2) + JIT_INSTRUCTION_CTOR(3) + JIT_INSTRUCTION_CTOR(4) + +#undef STASH_ARG +#undef JIT_INSTRUCTION_CTOR + + static void reset_ids (void) + { + next_id (true); + } + + jit_value *argument (size_t i) const + { + return marguments[i].value (); + } + + llvm::Value *argument_llvm (size_t i) const + { + assert (argument (i)); + return argument (i)->to_llvm (); + } + + jit_type *argument_type (size_t i) const + { + return argument (i)->type (); + } + + llvm::Type *argument_type_llvm (size_t i) const + { + assert (argument (i)); + return argument_type (i)->to_llvm (); + } + + std::ostream& print_argument (std::ostream& os, size_t i) const + { + if (argument (i)) + return argument (i)->short_print (os); + else + return os << "NULL"; + } + + void stash_argument (size_t i, jit_value *arg) + { + marguments[i].stash_value (arg, this, i); + } + + void push_argument (jit_value *arg) + { + marguments.push_back (jit_use ()); + stash_argument (marguments.size () - 1, arg); + already_infered.push_back (0); + } + + size_t argument_count (void) const + { + return marguments.size (); + } + + void resize_arguments (size_t acount, jit_value *adefault = 0) + { + size_t old = marguments.size (); + 
marguments.resize (acount); + already_infered.resize (acount); + + if (adefault) + for (size_t i = old; i < acount; ++i) + stash_argument (i, adefault); + } + + const std::vector& arguments (void) const { return marguments; } + + // argument types which have been infered already + const std::vector& argument_types (void) const + { return already_infered; } + + virtual void push_variable (void) {} + + virtual void pop_variable (void) {} + + virtual void construct_ssa (void) + { + do_construct_ssa (0, argument_count ()); + } + + virtual bool infer (void) { return false; } + + void remove (void); + + virtual std::ostream& short_print (std::ostream& os) const; + + jit_block *parent (void) const { return mparent; } + + std::list::iterator location (void) const + { + return mlocation; + } + + llvm::BasicBlock *parent_llvm (void) const; + + void stash_parent (jit_block *aparent, + std::list::iterator alocation) + { + mparent = aparent; + mlocation = alocation; + } + + size_t id (void) const { return mid; } +protected: + + // Do SSA replacement on arguments in [start, end) + void do_construct_ssa (size_t start, size_t end); + + std::vector already_infered; +private: + static size_t next_id (bool reset = false) + { + static size_t ret = 0; + if (reset) + return ret = 0; + + return ret++; + } + + std::vector marguments; + + size_t mid; + jit_block *mparent; + std::list::iterator mlocation; +}; + +// defnie accept methods for subclasses +#define JIT_VALUE_ACCEPT \ + virtual void accept (jit_ir_walker& walker); + +// for use as a dummy argument during conversion to LLVM +class +jit_argument : public jit_value +{ +public: + jit_argument (jit_type *atype, llvm::Value *avalue) + { + stash_type (atype); + stash_llvm (avalue); + } + + virtual std::ostream& print (std::ostream& os, size_t indent = 0) const + { + print_indent (os, indent); + return jit_print (os, type ()) << ": DUMMY"; + } + + JIT_VALUE_ACCEPT; +}; + +template +class +jit_const : public jit_value +{ +public: + 
typedef PASS_T pass_t; + + jit_const (PASS_T avalue) : mvalue (avalue) + { + stash_type (EXTRACT_T ()); + } + + PASS_T value (void) const { return mvalue; } + + virtual std::ostream& print (std::ostream& os, size_t indent = 0) const + { + print_indent (os, indent); + jit_print (os, type ()) << ": "; + if (QUOTE) + os << "\""; + os << mvalue; + if (QUOTE) + os << "\""; + return os; + } + + JIT_VALUE_ACCEPT; +private: + T mvalue; +}; + +class jit_phi_incomming; + +class +jit_block : public jit_value, public jit_internal_list +{ + typedef jit_internal_list ILIST_T; +public: + typedef std::list instruction_list; + typedef instruction_list::iterator iterator; + typedef instruction_list::const_iterator const_iterator; + + typedef std::set df_set; + typedef df_set::const_iterator df_iterator; + + static const size_t NO_ID = static_cast (-1); + + jit_block (const std::string& aname, size_t avisit_count = 0) + : mvisit_count (avisit_count), mid (NO_ID), idom (0), mname (aname), + malive (false) + {} + + virtual void replace_with (jit_value *value); + + void replace_in_phi (jit_block *ablock, jit_block *with); + + // we have a new internal list, but we want to stay compatable with jit_value + jit_use *first_use (void) const { return jit_value::first_use (); } + + size_t use_count (void) const { return jit_value::use_count (); } + + // if a block is alive, then it might be visited during execution + bool alive (void) const { return malive; } + + void mark_alive (void) { malive = true; } + + // If we can merge with a successor, do so and return the now empty block + jit_block *maybe_merge (); + + // merge another block into this block, leaving the merge block empty + void merge (jit_block& merge); + + const std::string& name (void) const { return mname; } + + jit_instruction *prepend (jit_instruction *instr); + + jit_instruction *prepend_after_phi (jit_instruction *instr); + + template + T *append (T *instr) + { + internal_append (instr); + return instr; + } + + 
jit_instruction *insert_before (iterator loc, jit_instruction *instr); + + jit_instruction *insert_before (jit_instruction *loc, jit_instruction *instr) + { + return insert_before (loc->location (), instr); + } + + jit_instruction *insert_after (iterator loc, jit_instruction *instr); + + jit_instruction *insert_after (jit_instruction *loc, jit_instruction *instr) + { + return insert_after (loc->location (), instr); + } + + iterator remove (iterator iter) + { + jit_instruction *instr = *iter; + iter = instructions.erase (iter); + instr->stash_parent (0, instructions.end ()); + return iter; + } + + jit_terminator *terminator (void) const; + + // is the jump from pred alive? + bool branch_alive (jit_block *asucc) const; + + jit_block *successor (size_t i) const; + + size_t successor_count (void) const; + + iterator begin (void) { return instructions.begin (); } + + const_iterator begin (void) const { return instructions.begin (); } + + iterator end (void) { return instructions.end (); } + + const_iterator end (void) const { return instructions.end (); } + + iterator phi_begin (void); + + iterator phi_end (void); + + iterator nonphi_begin (void); + + // must label before id is valid + size_t id (void) const { return mid; } + + // dominance frontier + const df_set& df (void) const { return mdf; } + + df_iterator df_begin (void) const { return mdf.begin (); } + + df_iterator df_end (void) const { return mdf.end (); } + + // label with a RPO walk + void label (void) + { + size_t number = 0; + label (mvisit_count, number); + } + + void label (size_t avisit_count, size_t& number) + { + if (visited (avisit_count)) + return; + + for (jit_use *use = first_use (); use; use = use->next ()) + { + jit_block *pred = use->user_parent (); + pred->label (avisit_count, number); + } + + mid = number++; + } + + // See for idom computation algorithm + // Cooper, Keith D.; Harvey, Timothy J; and Kennedy, Ken (2001). 
+ // "A Simple, Fast Dominance Algorithm" + void compute_idom (jit_block *entry_block) + { + bool changed; + entry_block->idom = entry_block; + do + changed = update_idom (mvisit_count); + while (changed); + } + + // compute dominance frontier + void compute_df (void) + { + compute_df (mvisit_count); + } + + void create_dom_tree (void) + { + create_dom_tree (mvisit_count); + } + + jit_block *dom_successor (size_t idx) const + { + return dom_succ[idx]; + } + + size_t dom_successor_count (void) const + { + return dom_succ.size (); + } + + // call pop_varaible on all instructions + void pop_all (void); + + virtual std::ostream& print (std::ostream& os, size_t indent = 0) const + { + print_indent (os, indent); + short_print (os) << ": %pred = "; + for (jit_use *use = first_use (); use; use = use->next ()) + { + jit_block *pred = use->user_parent (); + os << *pred; + if (use->next ()) + os << ", "; + } + os << std::endl; + + for (const_iterator iter = begin (); iter != end (); ++iter) + { + jit_instruction *instr = *iter; + instr->print (os, indent + 1) << std::endl; + } + return os; + } + + // ... + jit_block *maybe_split (jit_convert& convert, jit_block *asuccessor); + + jit_block *maybe_split (jit_convert& convert, jit_block& asuccessor) + { + return maybe_split (convert, &asuccessor); + } + + // print dominator infomration + std::ostream& print_dom (std::ostream& os) const; + + virtual std::ostream& short_print (std::ostream& os) const + { + os << mname; + if (mid != NO_ID) + os << mid; + return os; + } + + llvm::BasicBlock *to_llvm (void) const; + + std::list::iterator location (void) const + { return mlocation; } + + void stash_location (std::list::iterator alocation) + { mlocation = alocation; } + + // used to prevent visiting the same node twice in the graph + size_t visit_count (void) const { return mvisit_count; } + + // check if this node has been visited yet at the given visit count. If we + // have not been visited yet, mark us as visited. 
+ bool visited (size_t avisit_count) + { + if (mvisit_count <= avisit_count) + { + mvisit_count = avisit_count + 1; + return false; + } + + return true; + } + + JIT_VALUE_ACCEPT; +private: + void internal_append (jit_instruction *instr); + + void compute_df (size_t avisit_count); + + bool update_idom (size_t avisit_count); + + void create_dom_tree (size_t avisit_count); + + static jit_block *idom_intersect (jit_block *i, jit_block *j); + + size_t mvisit_count; + size_t mid; + jit_block *idom; + df_set mdf; + std::vector dom_succ; + std::string mname; + instruction_list instructions; + bool malive; + std::list::iterator mlocation; +}; + +// keeps track of phi functions that use a block on incomming edges +class +jit_phi_incomming : public jit_internal_node +{ +public: + jit_phi_incomming (void) : muser (0) {} + + jit_phi_incomming (jit_phi *auser) : muser (auser) {} + + jit_phi_incomming (const jit_phi_incomming& use) : jit_internal_node () + { + *this = use; + } + + jit_phi_incomming& operator= (const jit_phi_incomming& use) + { + stash_value (use.value ()); + muser = use.muser; + return *this; + } + + jit_phi *user (void) const { return muser; } + + jit_block *user_parent (void) const; +private: + jit_phi *muser; +}; + +// A non-ssa variable +class +jit_variable : public jit_value +{ +public: + jit_variable (const std::string& aname) : mname (aname), mlast_use (0) {} + + const std::string &name (void) const { return mname; } + + // manipulate the value_stack, for use during SSA construction. The top of the + // value stack represents the current value for this variable + bool has_top (void) const + { + return ! 
value_stack.empty (); + } + + jit_value *top (void) const + { + return value_stack.top (); + } + + void push (jit_instruction *v) + { + value_stack.push (v); + mlast_use = v; + } + + void pop (void) + { + value_stack.pop (); + } + + jit_instruction *last_use (void) const + { + return mlast_use; + } + + void stash_last_use (jit_instruction *instr) + { + mlast_use = instr; + } + + // blocks in which we are used + void use_blocks (jit_block::df_set& result) + { + jit_use *use = first_use (); + while (use) + { + result.insert (use->user_parent ()); + use = use->next (); + } + } + + virtual std::ostream& print (std::ostream& os, size_t indent = 0) const + { + return print_indent (os, indent) << mname; + } + + JIT_VALUE_ACCEPT; +private: + std::string mname; + std::stack value_stack; + jit_instruction *mlast_use; +}; + +class +jit_assign_base : public jit_instruction +{ +public: + jit_assign_base (jit_variable *adest) : jit_instruction (), mdest (adest) {} + + jit_assign_base (jit_variable *adest, size_t npred) : jit_instruction (npred), + mdest (adest) {} + + jit_assign_base (jit_variable *adest, jit_value *arg0, jit_value *arg1) + : jit_instruction (arg0, arg1), mdest (adest) {} + + jit_variable *dest (void) const { return mdest; } + + virtual void push_variable (void) + { + mdest->push (this); + } + + virtual void pop_variable (void) + { + mdest->pop (); + } + + virtual std::ostream& short_print (std::ostream& os) const + { + if (type ()) + jit_print (os, type ()) << ": "; + + dest ()->short_print (os); + return os << "#" << id (); + } +private: + jit_variable *mdest; +}; + +class +jit_assign : public jit_assign_base +{ +public: + jit_assign (jit_variable *adest, jit_value *asrc) + : jit_assign_base (adest, adest, asrc), martificial (false) {} + + jit_value *overwrite (void) const + { + return argument (0); + } + + jit_value *src (void) const + { + return argument (1); + } + + // variables don't get modified in an SSA, but COW requires we modify + // variables. 
An artificial assign is for when a variable gets modified. We + // need an assign in the SSA, but the reference counts shouldn't be updated. + bool artificial (void) const { return martificial; } + + void mark_artificial (void) { martificial = true; } + + virtual bool infer (void) + { + jit_type *stype = src ()->type (); + if (stype != type()) + { + stash_type (stype); + return true; + } + + return false; + } + + virtual std::ostream& print (std::ostream& os, size_t indent = 0) const + { + print_indent (os, indent) << *this << " = " << *src (); + + if (artificial ()) + os << " [artificial]"; + + return os; + } + + JIT_VALUE_ACCEPT; +private: + bool martificial; +}; + +class +jit_phi : public jit_assign_base +{ +public: + jit_phi (jit_variable *adest, size_t npred) + : jit_assign_base (adest, npred) + { + mincomming.reserve (npred); + } + + // removes arguments form dead incomming jumps + bool prune (void); + + void add_incomming (jit_block *from, jit_value *value) + { + push_argument (value); + mincomming.push_back (jit_phi_incomming (this)); + mincomming[mincomming.size () - 1].stash_value (from); + } + + jit_block *incomming (size_t i) const + { + return mincomming[i].value (); + } + + llvm::BasicBlock *incomming_llvm (size_t i) const + { + return incomming (i)->to_llvm (); + } + + virtual void construct_ssa (void) {} + + virtual bool infer (void); + + virtual std::ostream& print (std::ostream& os, size_t indent = 0) const + { + std::stringstream ss; + print_indent (ss, indent); + short_print (ss) << " phi "; + std::string ss_str = ss.str (); + std::string indent_str (ss_str.size (), ' '); + os << ss_str; + + for (size_t i = 0; i < argument_count (); ++i) + { + if (i > 0) + os << indent_str; + os << "| "; + + os << *incomming (i) << " -> "; + os << *argument (i); + + if (i + 1 < argument_count ()) + os << std::endl; + } + + return os; + } + + llvm::PHINode *to_llvm (void) const; + + JIT_VALUE_ACCEPT; +private: + std::vector mincomming; +}; + +class 
+jit_terminator : public jit_instruction +{ +public: +#define JIT_TERMINATOR_CONST(N) \ + jit_terminator (size_t asuccessor_count, \ + OCT_MAKE_DECL_LIST (jit_value *, arg, N)) \ + : jit_instruction (OCT_MAKE_ARG_LIST (arg, N)), \ + malive (asuccessor_count, false) {} + + JIT_TERMINATOR_CONST (1) + JIT_TERMINATOR_CONST (2) + JIT_TERMINATOR_CONST (3) + +#undef JIT_TERMINATOR_CONST + + jit_block *successor (size_t idx = 0) const + { + return static_cast (argument (idx)); + } + + llvm::BasicBlock *successor_llvm (size_t idx = 0) const + { + return successor (idx)->to_llvm (); + } + + size_t successor_index (const jit_block *asuccessor) const; + + std::ostream& print_successor (std::ostream& os, size_t idx = 0) const + { + if (alive (idx)) + os << "[live] "; + else + os << "[dead] "; + + return successor (idx)->short_print (os); + } + + // Check if the jump to successor is live + bool alive (const jit_block *asuccessor) const + { + return alive (successor_index (asuccessor)); + } + + bool alive (size_t idx) const { return malive[idx]; } + + bool alive (int idx) const { return malive[idx]; } + + size_t successor_count (void) const { return malive.size (); } + + virtual bool infer (void); + + llvm::TerminatorInst *to_llvm (void) const; +protected: + virtual bool check_alive (size_t) const { return true; } +private: + std::vector malive; +}; + +class +jit_branch : public jit_terminator +{ +public: + jit_branch (jit_block *succ) : jit_terminator (1, succ) {} + + virtual size_t successor_count (void) const { return 1; } + + virtual std::ostream& print (std::ostream& os, size_t indent = 0) const + { + print_indent (os, indent) << "branch: "; + return print_successor (os); + } + + JIT_VALUE_ACCEPT; +}; + +class +jit_cond_branch : public jit_terminator +{ +public: + jit_cond_branch (jit_value *c, jit_block *ctrue, jit_block *cfalse) + : jit_terminator (2, ctrue, cfalse, c) {} + + jit_value *cond (void) const { return argument (2); } + + std::ostream& print_cond (std::ostream& 
os) const + { + return cond ()->short_print (os); + } + + llvm::Value *cond_llvm (void) const + { + return cond ()->to_llvm (); + } + + virtual size_t successor_count (void) const { return 2; } + + virtual std::ostream& print (std::ostream& os, size_t indent = 0) const + { + print_indent (os, indent) << "cond_branch: "; + print_cond (os) << ", "; + print_successor (os, 0) << ", "; + return print_successor (os, 1); + } + + JIT_VALUE_ACCEPT; +}; + +class +jit_call : public jit_instruction +{ +public: +#define JIT_CALL_CONST(N) \ + jit_call (const jit_operation& afunction, \ + OCT_MAKE_DECL_LIST (jit_value *, arg, N)) \ + : jit_instruction (OCT_MAKE_ARG_LIST (arg, N)), mfunction (afunction) {} \ + \ + jit_call (const jit_operation& (*afunction) (void), \ + OCT_MAKE_DECL_LIST (jit_value *, arg, N)) \ + : jit_instruction (OCT_MAKE_ARG_LIST (arg, N)), mfunction (afunction ()) {} + + JIT_CALL_CONST (1) + JIT_CALL_CONST (2) + JIT_CALL_CONST (3) + JIT_CALL_CONST (4) + +#undef JIT_CALL_CONST + + + const jit_operation& function (void) const { return mfunction; } + + bool can_error (void) const + { + return overload ().can_error; + } + + const jit_operation::overload& overload (void) const + { + return mfunction.get_overload (argument_types ()); + } + + virtual bool needs_release (void) const + { + return type () && jit_typeinfo::get_release (type ()).function; + } + + virtual std::ostream& print (std::ostream& os, size_t indent = 0) const + { + print_indent (os, indent); + + if (use_count ()) + short_print (os) << " = "; + os << "call " << mfunction.name () << " ("; + + for (size_t i = 0; i < argument_count (); ++i) + { + print_argument (os, i); + if (i + 1 < argument_count ()) + os << ", "; + } + return os << ")"; + } + + virtual bool infer (void); + + JIT_VALUE_ACCEPT; +private: + const jit_operation& mfunction; +}; + +// FIXME: This is just ugly... 
+// checks error_state, if error_state is false then goto the normal branche, +// otherwise goto the error branch +class +jit_error_check : public jit_terminator +{ +public: + jit_error_check (jit_call *acheck_for, jit_block *normal, jit_block *error) + : jit_terminator (2, error, normal, acheck_for) {} + + jit_call *check_for (void) const + { + return static_cast (argument (2)); + } + + virtual std::ostream& print (std::ostream& os, size_t indent = 0) const + { + print_indent (os, indent) << "error_check " << *check_for () << ", "; + print_successor (os, 1) << ", "; + return print_successor (os, 0); + } + + JIT_VALUE_ACCEPT; +protected: + virtual bool check_alive (size_t idx) const + { + return idx == 1 ? true : check_for ()->can_error (); + } +}; + +class +jit_extract_argument : public jit_assign_base +{ +public: + jit_extract_argument (jit_type *atype, jit_variable *adest) + : jit_assign_base (adest) + { + stash_type (atype); + } + + const std::string& name (void) const + { + return dest ()->name (); + } + + const jit_operation::overload& overload (void) const + { + return jit_typeinfo::cast (type (), jit_typeinfo::get_any ()); + } + + virtual std::ostream& print (std::ostream& os, size_t indent = 0) const + { + print_indent (os, indent); + + return short_print (os) << " = extract " << name (); + } + + JIT_VALUE_ACCEPT; +}; + +class +jit_store_argument : public jit_instruction +{ +public: + jit_store_argument (jit_variable *var) + : jit_instruction (var), dest (var) + {} + + const std::string& name (void) const + { + return dest->name (); + } + + const jit_operation::overload& overload (void) const + { + return jit_typeinfo::cast (jit_typeinfo::get_any (), result_type ()); + } + + jit_value *result (void) const + { + return argument (0); + } + + jit_type *result_type (void) const + { + return result ()->type (); + } + + llvm::Value *result_llvm (void) const + { + return result ()->to_llvm (); + } + + virtual std::ostream& print (std::ostream& os, size_t indent = 
0) const + { + jit_value *res = result (); + print_indent (os, indent) << "store "; + dest->short_print (os); + + if (! isa (res)) + { + os << " = "; + res->short_print (os); + } + + return os; + } + + JIT_VALUE_ACCEPT; +private: + jit_variable *dest; +}; + +class +jit_ir_walker +{ +public: + virtual ~jit_ir_walker () {} + +#define JIT_METH(clname) \ + virtual void visit (jit_ ## clname&) = 0; + + JIT_VISIT_IR_CLASSES; + +#undef JIT_METH +}; + +template +void +jit_const::accept (jit_ir_walker& walker) +{ + walker.visit (*this); +} + +// convert between IRs +// FIXME: Class relationships are messy from here on down. They need to be +// cleaned up. +class +jit_convert : public tree_walker +{ +public: + typedef std::pair type_bound; + typedef std::vector type_bound_vector; + + jit_convert (llvm::Module *module, tree &tee); + + ~jit_convert (void); + + llvm::Function *get_function (void) const { return function; } + + const std::vector >& get_arguments(void) const + { return arguments; } + + const type_bound_vector& get_bounds (void) const { return bounds; } + + void visit_anon_fcn_handle (tree_anon_fcn_handle&); + + void visit_argument_list (tree_argument_list&); + + void visit_binary_expression (tree_binary_expression&); + + void visit_break_command (tree_break_command&); + + void visit_colon_expression (tree_colon_expression&); + + void visit_continue_command (tree_continue_command&); + + void visit_global_command (tree_global_command&); + + void visit_persistent_command (tree_persistent_command&); + + void visit_decl_elt (tree_decl_elt&); + + void visit_decl_init_list (tree_decl_init_list&); + + void visit_simple_for_command (tree_simple_for_command&); + + void visit_complex_for_command (tree_complex_for_command&); + + void visit_octave_user_script (octave_user_script&); + + void visit_octave_user_function (octave_user_function&); + + void visit_octave_user_function_header (octave_user_function&); + + void visit_octave_user_function_trailer (octave_user_function&); 
+ + void visit_function_def (tree_function_def&); + + void visit_identifier (tree_identifier&); + + void visit_if_clause (tree_if_clause&); + + void visit_if_command (tree_if_command&); + + void visit_if_command_list (tree_if_command_list&); + + void visit_index_expression (tree_index_expression&); + + void visit_matrix (tree_matrix&); + + void visit_cell (tree_cell&); + + void visit_multi_assignment (tree_multi_assignment&); + + void visit_no_op_command (tree_no_op_command&); + + void visit_constant (tree_constant&); + + void visit_fcn_handle (tree_fcn_handle&); + + void visit_parameter_list (tree_parameter_list&); + + void visit_postfix_expression (tree_postfix_expression&); + + void visit_prefix_expression (tree_prefix_expression&); + + void visit_return_command (tree_return_command&); + + void visit_return_list (tree_return_list&); + + void visit_simple_assignment (tree_simple_assignment&); + + void visit_statement (tree_statement&); + + void visit_statement_list (tree_statement_list&); + + void visit_switch_case (tree_switch_case&); + + void visit_switch_case_list (tree_switch_case_list&); + + void visit_switch_command (tree_switch_command&); + + void visit_try_catch_command (tree_try_catch_command&); + + void visit_unwind_protect_command (tree_unwind_protect_command&); + + void visit_while_command (tree_while_command&); + + void visit_do_until_command (tree_do_until_command&); + + // this would be easier with variadic templates + template + T *create (void) + { + T *ret = new T(); + track_value (ret); + return ret; + } + +#define DECL_ARG(n) const ARG ## n& arg ## n +#define JIT_CREATE(N) \ + template \ + T *create (OCT_MAKE_LIST (DECL_ARG, N)) \ + { \ + T *ret = new T (OCT_MAKE_ARG_LIST (arg, N)); \ + track_value (ret); \ + return ret; \ + } + + JIT_CREATE (1) + JIT_CREATE (2) + JIT_CREATE (3) + JIT_CREATE (4) + +#undef JIT_CREATE + +#define JIT_CREATE_CHECKED(N) \ + template \ + jit_call *create_checked (OCT_MAKE_LIST (DECL_ARG, N)) \ + { \ + jit_call *ret 
= create (OCT_MAKE_ARG_LIST (arg, N)); \ + return create_checked_impl (ret); \ + } + + JIT_CREATE_CHECKED (1) + JIT_CREATE_CHECKED (2) + JIT_CREATE_CHECKED (3) + JIT_CREATE_CHECKED (4) + +#undef JIT_CREATE_CHECKED +#undef DECL_ARG + + typedef std::list block_list; + typedef block_list::iterator block_iterator; + + void append (jit_block *ablock); + + void insert_before (block_iterator iter, jit_block *ablock); + + void insert_before (jit_block *loc, jit_block *ablock) + { + insert_before (loc->location (), ablock); + } + + void insert_after (block_iterator iter, jit_block *ablock); + + void insert_after (jit_block *loc, jit_block *ablock) + { + insert_after (loc->location (), ablock); + } +private: + std::vector > arguments; + type_bound_vector bounds; + + // used instead of return values from visit_* functions + jit_value *result; + + jit_block *entry_block; + + jit_block *final_block; + + jit_block *block; + + llvm::Function *function; + + std::list blocks; + + std::list worklist; + + std::list constants; + + std::list all_values; + + size_t iterator_count; + size_t short_count; + + typedef std::map vmap_t; + vmap_t vmap; + + jit_call *create_checked_impl (jit_call *ret) + { + block->append (ret); + create_check (ret); + return ret; + } + + jit_error_check *create_check (jit_call *call) + { + jit_block *normal = create (block->name ()); + jit_error_check *ret + = block->append (create (call, normal, final_block)); + append (normal); + block = normal; + + return ret; + } + + jit_variable *get_variable (const std::string& vname); + + std::pair resolve (tree_index_expression& exp); + + jit_value *do_assign (tree_expression *exp, jit_value *rhs, + bool artificial = false); + + jit_value *do_assign (const std::string& lhs, jit_value *rhs, bool print, + bool artificial = false); + + jit_value *visit (tree *tee) { return visit (*tee); } + + jit_value *visit (tree& tee); + + void push_worklist (jit_instruction *instr) + { + if (! 
instr->in_worklist ()) + { + instr->stash_in_worklist (true); + worklist.push_back (instr); + } + } + + void append_users (jit_value *v) + { + for (jit_use *use = v->first_use (); use; use = use->next ()) + push_worklist (use->user ()); + } + + void append_users_term (jit_terminator *term); + + void track_value (jit_value *value) + { + if (value->type ()) + constants.push_back (value); + all_values.push_back (value); + } + + void merge_blocks (void); + + void construct_ssa (void); + + void do_construct_ssa (jit_block& block, size_t avisit_count); + + void remove_dead (); + + void place_releases (void); + + void release_temp (jit_block& ablock, std::set& temp); + + void release_dead_phi (jit_block& ablock); + + void simplify_phi (void); + + void simplify_phi (jit_phi& phi); + + void print_blocks (const std::string& header) + { + std::cout << "-------------------- " << header << " --------------------\n"; + for (std::list::iterator iter = blocks.begin (); + iter != blocks.end (); ++iter) + { + assert (*iter); + (*iter)->print (std::cout, 0); + } + std::cout << std::endl; + } + + void print_dom (void) + { + std::cout << "-------------------- dom info --------------------\n"; + for (std::list::iterator iter = blocks.begin (); + iter != blocks.end (); ++iter) + { + assert (*iter); + (*iter)->print_dom (std::cout); + } + std::cout << std::endl; + } + + bool breaking; // true if we are breaking OR continuing + block_list breaks; + block_list continues; + + void finish_breaks (jit_block *dest, const block_list& lst); + + // this case is much simpler, just convert from the jit ir to llvm + class + convert_llvm : public jit_ir_walker + { + public: + convert_llvm (jit_convert& jc) : jthis (jc) {} + + llvm::Function *convert (llvm::Module *module, + const std::vector >& args, + const std::list& blocks, + const std::list& constants); + +#define JIT_METH(clname) \ + virtual void visit (jit_ ## clname&); + + JIT_VISIT_IR_CLASSES; + +#undef JIT_METH + private: + // name -> llvm 
argument + std::map arguments; + + void finish_phi (jit_phi *phi); + + void visit (jit_value *jvalue) + { + return visit (*jvalue); + } + + void visit (jit_value &jvalue) + { + jvalue.accept (*this); + } + + llvm::Value *create_call (const jit_operation::overload& ol, jit_value *arg0) + { + std::vector args (1, arg0); + return create_call (ol, args); + } + + llvm::Value *create_call (const jit_operation::overload& ol, jit_value *arg0, + jit_value *arg1) + { + std::vector args (2); + args[0] = arg0; + args[1] = arg1; + + return create_call (ol, args); + } + + llvm::Value *create_call (const jit_operation::overload& ol, + const std::vector& jargs); + + llvm::Value *create_call (const jit_operation::overload& ol, + const std::vector& uses); + private: + jit_convert &jthis; + llvm::Function *function; + llvm::BasicBlock *prelude; + }; +}; + +class jit_info; + +class +tree_jit +{ +public: + tree_jit (void); + + ~tree_jit (void); + + bool execute (tree_simple_for_command& cmd); + + llvm::ExecutionEngine *get_engine (void) const { return engine; } + + llvm::Module *get_module (void) const { return module; } + + void optimize (llvm::Function *fn); + private: + bool initialize (void); + + // FIXME: Temporary hack to test + typedef std::map compiled_map; + llvm::Module *module; + llvm::PassManager *module_pass_manager; + llvm::FunctionPassManager *pass_manager; + llvm::ExecutionEngine *engine; +}; + +class +jit_info +{ +public: + jit_info (tree_jit& tjit, tree& tee); + + ~jit_info (void); + + bool execute (void) const; + + bool match (void) const; +private: + typedef jit_convert::type_bound type_bound; + typedef jit_convert::type_bound_vector type_bound_vector; + typedef void (*jited_function)(octave_base_value**); + + llvm::ExecutionEngine *engine; + jited_function function; + llvm::Function *llvm_function; + + std::vector > arguments; + type_bound_vector bounds; +}; + +// some #defines we use in the header, but not the cc file +#undef JIT_VISIT_IR_CLASSES +#undef 
JIT_VISIT_IR_CONST +#undef JIT_VALUE_ACCEPT + +#endif +#endif diff --git a/src/pt-loop.cc b/src/pt-loop.cc --- a/src/pt-loop.cc +++ b/src/pt-loop.cc @@ -35,6 +35,7 @@ #include "pt-bp.h" #include "pt-cmd.h" #include "pt-exp.h" +#include "pt-jit.h" #include "pt-jump.h" #include "pt-loop.h" #include "pt-stmt.h" @@ -97,6 +98,7 @@ delete list; delete lead_comm; delete trail_comm; + delete compiled; } tree_command * diff --git a/src/pt-loop.h b/src/pt-loop.h --- a/src/pt-loop.h +++ b/src/pt-loop.h @@ -36,6 +36,8 @@ #include "pt-cmd.h" #include "symtab.h" +class jit_info; + // While. class @@ -146,7 +148,7 @@ tree_simple_for_command (int l = -1, int c = -1) : tree_command (l, c), parallel (false), lhs (0), expr (0), - maxproc (0), list (0), lead_comm (0), trail_comm (0) { } + maxproc (0), list (0), lead_comm (0), trail_comm (0), compiled (0) { } tree_simple_for_command (bool parallel_arg, tree_expression *le, tree_expression *re, @@ -157,7 +159,7 @@ int l = -1, int c = -1) : tree_command (l, c), parallel (parallel_arg), lhs (le), expr (re), maxproc (maxproc_arg), list (lst), - lead_comm (lc), trail_comm (tc) { } + lead_comm (lc), trail_comm (tc), compiled (0) { } ~tree_simple_for_command (void); @@ -180,8 +182,18 @@ void accept (tree_walker& tw); + // some functions used by tree_jit + jit_info *get_info (void) const + { + return compiled; + } + + void stash_info (jit_info *jinfo) + { + compiled = jinfo; + } + private: - // TRUE means operate in parallel (subject to the value of the // maxproc expression). bool parallel; @@ -205,6 +217,9 @@ // Comment preceding ENDFOR token. octave_comment_list *trail_comm; + // compiled version of the loop + jit_info *compiled; + // No copying! 
tree_simple_for_command (const tree_simple_for_command&); diff --git a/src/pt-stmt.h b/src/pt-stmt.h --- a/src/pt-stmt.h +++ b/src/pt-stmt.h @@ -35,12 +35,13 @@ #include "base-list.h" #include "comment-list.h" #include "symtab.h" +#include "pt.h" // A statement is either a command to execute or an expression to // evaluate. class -tree_statement +tree_statement : public tree { public: diff --git a/src/symtab.h b/src/symtab.h --- a/src/symtab.h +++ b/src/symtab.h @@ -484,7 +484,7 @@ return symbol_record (rep->dup (new_scope)); } - std::string name (void) const { return rep->name; } + const std::string& name (void) const { return rep->name; } octave_value find (const octave_value_list& args = octave_value_list ()) const; @@ -581,6 +581,66 @@ symbol_record (symbol_record_rep *new_rep) : rep (new_rep) { } }; + // Always access a symbol from the current scope. + // Useful for scripts, as they may be executed in more than one scope. + class + symbol_reference + { + public: + symbol_reference (void) : scope (-1) {} + + symbol_reference (symbol_record record, + scope_id curr_scope = symbol_table::current_scope ()) + : scope (curr_scope), sym (record) + {} + + symbol_reference& operator = (const symbol_reference& ref) + { + scope = ref.scope; + sym = ref.sym; + return *this; + } + + // The name is the same regardless of scope. + const std::string& name (void) const { return sym.name (); } + + symbol_record *operator-> (void) + { + update (); + return &sym; + } + + symbol_record *operator-> (void) const + { + update (); + return &sym; + } + + // can be used to place symbol_reference in maps, we don't overload < as + // it doesn't make any sense for symbol_reference + struct comparator + { + bool operator ()(const symbol_reference& lhs, + const symbol_reference& rhs) const + { + return lhs.name () < rhs.name (); + } + }; + private: + void update (void) const + { + scope_id curr_scope = symbol_table::current_scope (); + if (scope != curr_scope || ! 
sym.is_valid ()) + { + scope = curr_scope; + sym = symbol_table::insert (sym.name ()); + } + } + + mutable scope_id scope; + mutable symbol_record sym; + }; + class fcn_info { diff --git a/src/toplev.cc b/src/toplev.cc --- a/src/toplev.cc +++ b/src/toplev.cc @@ -1325,6 +1325,9 @@ { false, "MAGICK_CPPFLAGS", OCTAVE_CONF_MAGICK_CPPFLAGS }, { false, "MAGICK_LDFLAGS", OCTAVE_CONF_MAGICK_LDFLAGS }, { false, "MAGICK_LIBS", OCTAVE_CONF_MAGICK_LIBS }, + { false, "LLVM_CPPFLAGS", OCTAVE_CONF_LLVM_CPPFLAGS }, + { false, "LLVM_LDFLAGS", OCTAVE_CONF_LLVM_LDFLAGS }, + { false, "LLVM_LIBS", OCTAVE_CONF_LLVM_LIBS }, { false, "MKOCTFILE_DL_LDFLAGS", OCTAVE_CONF_MKOCTFILE_DL_LDFLAGS }, { false, "OCTAVE_LINK_DEPS", OCTAVE_CONF_OCTAVE_LINK_DEPS }, { false, "OCTAVE_LINK_OPTS", OCTAVE_CONF_OCTAVE_LINK_OPTS },