Mercurial > hg > octave-lyh
changeset 10670:654fbde5dceb
make cellfun's fast scalar collection mechanism public
author | Jaroslav Hajek <highegg@gmail.com> |
---|---|
date | Fri, 28 May 2010 12:28:06 +0200 |
parents | cab3b148d4e4 |
children | f5f9bc8e83fc |
files | src/ChangeLog src/DLD-FUNCTIONS/cellfun.cc src/ov-base-mat.cc src/ov-base-mat.h src/ov-base-scalar.cc src/ov-base-scalar.h src/ov-base.cc src/ov-base.h src/ov-cell.cc src/ov-float.cc src/ov-float.h src/ov-scalar.cc src/ov-scalar.h src/ov.h |
diffstat | 14 files changed, 232 insertions(+), 180 deletions(-) [+] |
line wrap: on
line diff
--- a/src/ChangeLog +++ b/src/ChangeLog @@ -1,3 +1,28 @@ +2010-05-28 Jaroslav Hajek <highegg@gmail.com> + + * ov.h (octave_value::fast_elem_extract, + octave_value::fast_elem_insert): New methods. + * ov-base.cc (octave_base_value::fast_elem_extract, + octave_base_value::fast_elem_insert, + octave_base_value::fast_elem_insert_self): New methods. + * ov-base.h: Declare them. + * ov-base-mat.cc (octave_base_matrix::fast_elem_extract, + octave_base_matrix::fast_elem_insert): New overrides. + * ov-base-mat.h: Declare them. + * ov-base-scalar.cc (octave_base_scalar::fast_elem_extract, + octave_base_scalar::fast_elem_insert_self): New overrides. + * ov-base-scalar.h: Declare them. + (octave_base_scalar::scalar_ref): New method. + * ov-scalar.cc (octave_scalar::fast_elem_insert_self): New override. + * ov-scalar.h: Declare it. + * ov-float.cc (octave_float_scalar::fast_elem_insert_self): New override. + * ov-float.h: Declare it. + * ov-cell.cc (octave_base_matrix<Cell>::fast_elem_extract, + octave_base_matrix<Cell>::fast_elem_insert): New specializations. + * DLD-FUNCTIONS/cellfun.cc (scalar_col_helper, scalar_col_helper_def, + scalar_col_helper_nda, make_col_helper, can_extract): Remove. + (Fcellfun): Use the new fast_elem_insert method. + 2010-05-10 Rik <octave@nomad.inbox5.com> * DLD-FUNCTIONS/eigs.cc: Improve documentation string.
--- a/src/DLD-FUNCTIONS/cellfun.cc +++ b/src/DLD-FUNCTIONS/cellfun.cc @@ -58,172 +58,6 @@ #include "ov-uint32.h" #include "ov-uint64.h" -// Rationale: -// The octave_base_value::subsasgn method carries too much overhead for -// per-element assignment strategy. -// This class will optimize the most optimistic and most likely case -// when the output really is scalar by defining a hierarchy of virtual -// collectors specialized for some scalar types. - -class scalar_col_helper -{ -public: - virtual bool collect (octave_idx_type i, const octave_value& val) = 0; - virtual octave_value result (void) = 0; - virtual ~scalar_col_helper (void) { } -}; - -// The default collector represents what was previously done in the main loop. -// This reuses the existing assignment machinery via octave_value::subsasgn, -// which can perform all sorts of conversions, but is relatively slow. - -class scalar_col_helper_def : public scalar_col_helper -{ - std::list<octave_value_list> idx_list; - octave_value resval; -public: - scalar_col_helper_def (const octave_value& val, const dim_vector& dims) - : idx_list (1), resval (val) - { - idx_list.front ().resize (1); - if (resval.dims () != dims) - resval.resize (dims); - } - ~scalar_col_helper_def (void) { } - - bool collect (octave_idx_type i, const octave_value& val) - { - if (val.numel () == 1) - { - idx_list.front ()(0) = static_cast<double> (i + 1); - resval = resval.subsasgn ("(", idx_list, val); - } - else - error ("cellfun: expecting all values to be scalars for UniformOutput = true"); - - return true; - } - octave_value result (void) - { - return resval; - } -}; - -template <class T> -static bool can_extract (const octave_value& val) -{ return false; } - -#define DEF_CAN_EXTRACT(T, CLASS) \ -template <> \ -bool can_extract<T> (const octave_value& val) \ -{ return val.type_id () == octave_ ## CLASS::static_type_id (); } - -DEF_CAN_EXTRACT (double, scalar); -DEF_CAN_EXTRACT (float, float_scalar); -DEF_CAN_EXTRACT (bool, bool); 
-DEF_CAN_EXTRACT (octave_int8, int8_scalar); -DEF_CAN_EXTRACT (octave_int16, int16_scalar); -DEF_CAN_EXTRACT (octave_int32, int32_scalar); -DEF_CAN_EXTRACT (octave_int64, int64_scalar); -DEF_CAN_EXTRACT (octave_uint8, uint8_scalar); -DEF_CAN_EXTRACT (octave_uint16, uint16_scalar); -DEF_CAN_EXTRACT (octave_uint32, uint32_scalar); -DEF_CAN_EXTRACT (octave_uint64, uint64_scalar); - -template <> -bool can_extract<Complex> (const octave_value& val) -{ - int t = val.type_id (); - return (t == octave_complex::static_type_id () - || t == octave_scalar::static_type_id ()); -} - -template <> -bool can_extract<FloatComplex> (const octave_value& val) -{ - int t = val.type_id (); - return (t == octave_float_complex::static_type_id () - || t == octave_float_scalar::static_type_id ()); -} - -// This specializes for collecting elements of a single type, by accessing -// an array directly. If the scalar is not valid, it returns false. - -template <class NDA> -class scalar_col_helper_nda : public scalar_col_helper -{ - NDA arrayval; - typedef typename NDA::element_type T; -public: - scalar_col_helper_nda (const octave_value& val, const dim_vector& dims) - : arrayval (dims) - { - arrayval(0) = octave_value_extract<T> (val); - } - ~scalar_col_helper_nda (void) { } - - bool collect (octave_idx_type i, const octave_value& val) - { - bool retval = can_extract<T> (val); - if (retval) - arrayval(i) = octave_value_extract<T> (val); - return retval; - } - octave_value result (void) - { - return arrayval; - } -}; - -template class scalar_col_helper_nda<NDArray>; -template class scalar_col_helper_nda<FloatNDArray>; -template class scalar_col_helper_nda<ComplexNDArray>; -template class scalar_col_helper_nda<FloatComplexNDArray>; -template class scalar_col_helper_nda<boolNDArray>; -template class scalar_col_helper_nda<int8NDArray>; -template class scalar_col_helper_nda<int16NDArray>; -template class scalar_col_helper_nda<int32NDArray>; -template class scalar_col_helper_nda<int64NDArray>; 
-template class scalar_col_helper_nda<uint8NDArray>; -template class scalar_col_helper_nda<uint16NDArray>; -template class scalar_col_helper_nda<uint32NDArray>; -template class scalar_col_helper_nda<uint64NDArray>; - -// the virtual constructor. -scalar_col_helper * -make_col_helper (const octave_value& val, const dim_vector& dims) -{ - scalar_col_helper *retval; - - // No need to check numel() here. - switch (val.builtin_type ()) - { -#define ARRAYCASE(BTYP, ARRAY) \ - case BTYP: \ - retval = new scalar_col_helper_nda<ARRAY> (val, dims); \ - break - - ARRAYCASE (btyp_double, NDArray); - ARRAYCASE (btyp_float, FloatNDArray); - ARRAYCASE (btyp_complex, ComplexNDArray); - ARRAYCASE (btyp_float_complex, FloatComplexNDArray); - ARRAYCASE (btyp_bool, boolNDArray); - ARRAYCASE (btyp_int8, int8NDArray); - ARRAYCASE (btyp_int16, int16NDArray); - ARRAYCASE (btyp_int32, int32NDArray); - ARRAYCASE (btyp_int64, int64NDArray); - ARRAYCASE (btyp_uint8, uint8NDArray); - ARRAYCASE (btyp_uint16, uint16NDArray); - ARRAYCASE (btyp_uint32, uint32NDArray); - ARRAYCASE (btyp_uint64, uint64NDArray); - default: - retval = new scalar_col_helper_def (val, dims); - break; - } - - return retval; -} - static octave_value_list get_output_list (octave_idx_type count, octave_idx_type nargout, const octave_value_list& inputlist, @@ -636,7 +470,11 @@ if (uniform_output) { - OCTAVE_LOCAL_BUFFER (std::auto_ptr<scalar_col_helper>, retptr, nargout1); + std::list<octave_value_list> idx_list (1); + idx_list.front ().resize (1); + std::string idx_type = "("; + + OCTAVE_LOCAL_BUFFER (octave_value, retv, nargout1); for (octave_idx_type count = 0; count < k ; count++) { @@ -670,7 +508,7 @@ octave_value val = tmp(j); if (val.numel () == 1) - retptr[j].reset (make_col_helper (val, fdims)); + retv[j] = val.resize (fdims); else { error ("cellfun: expecting all values to be scalars for UniformOutput = true"); @@ -684,13 +522,22 @@ { octave_value val = tmp(j); - if (! retptr[j]->collect (count, val)) + if (! 
retv[j].fast_elem_insert (count, val)) { - // FIXME: A more elaborate structure would allow again a virtual - // constructor here. - retptr[j].reset (new scalar_col_helper_def (retptr[j]->result (), - fdims)); - retptr[j]->collect (count, val); + if (val.numel () == 1) + { + idx_list.front ()(0) = count + 1.0; + retv[j].assign (octave_value::op_asn_eq, + idx_type, idx_list, val); + + if (error_state) + break; + } + else + { + error ("cellfun: expecting all values to be scalars for UniformOutput = true"); + break; + } } } } @@ -701,12 +548,7 @@ retval.resize (nargout1); for (int j = 0; j < nargout1; j++) - { - if (retptr[j].get ()) - retval(j) = retptr[j]->result (); - else - retval(j) = Matrix (); - } + retval(j) = retv[j]; } else {
--- a/src/ov-base-mat.cc +++ b/src/ov-base-mat.cc @@ -33,6 +33,7 @@ #include "oct-map.h" #include "ov-base.h" #include "ov-base-mat.h" +#include "ov-base-scalar.h" #include "pr-output.h" template <class MT> @@ -448,3 +449,35 @@ { matrix.print_info (os, prefix); } + +template <class MT> +octave_value +octave_base_matrix<MT>::fast_elem_extract (octave_idx_type n) const +{ + if (n < matrix.numel ()) + return matrix(n); + else + return octave_value (); +} + +template <class MT> +bool +octave_base_matrix<MT>::fast_elem_insert (octave_idx_type n, + const octave_value& x) +{ + if (n < matrix.numel ()) + { + // Don't use builtin_type () here to avoid an extra VM call. + typedef typename MT::element_type ET; + const builtin_type_t btyp = class_to_btyp<ET>::btyp; + if (btyp == btyp_unknown) // Dead branch? + return false; + + // Set up the pointer to the proper place. + void *here = reinterpret_cast<void *> (&matrix(n)); + // Ask x to store there if it can. + return x.get_rep().fast_elem_insert_self (here, btyp); + } + else + return false; +}
--- a/src/ov-base-mat.h +++ b/src/ov-base-mat.h @@ -165,6 +165,12 @@ return matrix; } + octave_value + fast_elem_extract (octave_idx_type n) const; + + bool + fast_elem_insert (octave_idx_type n, const octave_value& x); + protected: MT matrix;
--- a/src/ov-base-scalar.cc +++ b/src/ov-base-scalar.cc @@ -154,3 +154,18 @@ os << name << " = "; return false; } + +template <class ST> +bool +octave_base_scalar<ST>::fast_elem_insert_self (void *where, builtin_type_t btyp) const +{ + + // Don't use builtin_type () here to avoid an extra VM call. + if (btyp == class_to_btyp<ST>::btyp) + { + *(reinterpret_cast<ST *>(where)) = scalar; + return true; + } + else + return false; +}
--- a/src/ov-base-scalar.h +++ b/src/ov-base-scalar.h @@ -136,6 +136,12 @@ // You should not use it anywhere else. void *mex_get_data (void) const { return const_cast<ST *> (&scalar); } + const ST& scalar_ref (void) const { return scalar; } + + ST& scalar_ref (void) { return scalar; } + + bool fast_elem_insert_self (void *where, builtin_type_t btyp) const; + protected: // The value of this scalar.
--- a/src/ov-base.cc +++ b/src/ov-base.cc @@ -1425,6 +1425,25 @@ curr_print_indent_level = 0; } + +octave_value +octave_base_value::fast_elem_extract (octave_idx_type n) const +{ + return octave_value (); +} + +bool +octave_base_value::fast_elem_insert (octave_idx_type n, const octave_value& x) +{ + return false; +} + +bool +octave_base_value::fast_elem_insert_self (void *where, builtin_type_t btyp) const +{ + return false; +} + CONVDECLX (matrix_conv) { return new octave_matrix ();
--- a/src/ov-base.h +++ b/src/ov-base.h @@ -714,6 +714,26 @@ virtual octave_value map (unary_mapper_t) const; + // These are fast indexing & assignment shortcuts for extracting + // or inserting a single scalar from/to an array. + + // Extract the n-th element, aka val(n). Result is undefined if val is not an + // array type or n is out of range. Never error. + virtual octave_value + fast_elem_extract (octave_idx_type n) const; + + // Assign the n-th element, aka val(n) = x. Returns false if val is not an + // array type, x is not a matching scalar type, or n is out of range. + // Never error. + virtual bool + fast_elem_insert (octave_idx_type n, const octave_value& x); + + // This is a helper for the above, to be overridden in scalar types. The + // whole point is to handle the insertion efficiently with just *two* VM + // calls, which is basically the theoretical minimum. + virtual bool + fast_elem_insert_self (void *where, builtin_type_t btyp) const; + protected: // This should only be called for derived types.
--- a/src/ov-cell.cc +++ b/src/ov-cell.cc @@ -93,6 +93,34 @@ matrix.delete_elements (idx); } +// FIXME: this list of specializations is becoming so long that we should really ask +// whether octave_cell should inherit from octave_base_matrix at all. + +template <> +octave_value +octave_base_matrix<Cell>::fast_elem_extract (octave_idx_type n) const +{ + if (n < matrix.numel ()) + return Cell (matrix(n)); + else + return octave_value (); +} + +template <> +bool +octave_base_matrix<Cell>::fast_elem_insert (octave_idx_type n, + const octave_value& x) +{ + const octave_cell *xrep = + dynamic_cast<const octave_cell *> (&x.get_rep ()); + + bool retval = xrep && xrep->matrix.numel () == 1 && n < matrix.numel (); + if (retval) + matrix(n) = xrep->matrix(0); + + return retval; +} + template class octave_base_matrix<Cell>; DEFINE_OCTAVE_ALLOCATOR (octave_cell);
--- a/src/ov-float.cc +++ b/src/ov-float.cc @@ -319,3 +319,22 @@ return octave_base_value::map (umap); } } + +bool +octave_float_scalar::fast_elem_insert_self (void *where, builtin_type_t btyp) const +{ + + // Support inline real->complex conversion. + if (btyp == btyp_float) + { + *(reinterpret_cast<float *>(where)) = scalar; + return true; + } + else if (btyp == btyp_float_complex) + { + *(reinterpret_cast<FloatComplex *>(where)) = scalar; + return true; + } + else + return false; +}
--- a/src/ov-float.h +++ b/src/ov-float.h @@ -246,6 +246,8 @@ octave_value map (unary_mapper_t umap) const; + bool fast_elem_insert_self (void *where, builtin_type_t btyp) const; + private: DECLARE_OCTAVE_ALLOCATOR
--- a/src/ov-scalar.cc +++ b/src/ov-scalar.cc @@ -341,3 +341,22 @@ return octave_base_value::map (umap); } } + +bool +octave_scalar::fast_elem_insert_self (void *where, builtin_type_t btyp) const +{ + + // Support inline real->complex conversion. + if (btyp == btyp_double) + { + *(reinterpret_cast<double *>(where)) = scalar; + return true; + } + else if (btyp == btyp_complex) + { + *(reinterpret_cast<Complex *>(where)) = scalar; + return true; + } + else + return false; +}
--- a/src/ov-scalar.h +++ b/src/ov-scalar.h @@ -247,6 +247,8 @@ octave_value map (unary_mapper_t umap) const; + bool fast_elem_insert_self (void *where, builtin_type_t btyp) const; + private: DECLARE_OCTAVE_ALLOCATOR
--- a/src/ov.h +++ b/src/ov.h @@ -1139,6 +1139,22 @@ octave_value map (octave_base_value::unary_mapper_t umap) const { return rep->map (umap); } + // Extract the n-th element, aka val(n). Result is undefined if val is not an + // array type or n is out of range. Never error. + octave_value + fast_elem_extract (octave_idx_type n) const + { return rep->fast_elem_extract (n); } + + // Assign the n-th element, aka val(n) = x. Returns false if val is not an + // array type, x is not a matching scalar type, or n is out of range. + // Never error. + virtual bool + fast_elem_insert (octave_idx_type n, const octave_value& x) + { + make_unique (); + return rep->fast_elem_insert (n, x); + } + protected: // The real representation.