Mercurial > hg > octave-nkf
diff src/DLD-FUNCTIONS/cellfun.cc @ 10670:654fbde5dceb
make cellfun's fast scalar collection mechanism public
author | Jaroslav Hajek <highegg@gmail.com> |
---|---|
date | Fri, 28 May 2010 12:28:06 +0200 |
parents | 4d1fc073fbb7 |
children | a8ce6bdecce5 |
line wrap: on
line diff
--- a/src/DLD-FUNCTIONS/cellfun.cc +++ b/src/DLD-FUNCTIONS/cellfun.cc @@ -58,172 +58,6 @@ #include "ov-uint32.h" #include "ov-uint64.h" -// Rationale: -// The octave_base_value::subsasgn method carries too much overhead for -// per-element assignment strategy. -// This class will optimize the most optimistic and most likely case -// when the output really is scalar by defining a hierarchy of virtual -// collectors specialized for some scalar types. - -class scalar_col_helper -{ -public: - virtual bool collect (octave_idx_type i, const octave_value& val) = 0; - virtual octave_value result (void) = 0; - virtual ~scalar_col_helper (void) { } -}; - -// The default collector represents what was previously done in the main loop. -// This reuses the existing assignment machinery via octave_value::subsasgn, -// which can perform all sorts of conversions, but is relatively slow. - -class scalar_col_helper_def : public scalar_col_helper -{ - std::list<octave_value_list> idx_list; - octave_value resval; -public: - scalar_col_helper_def (const octave_value& val, const dim_vector& dims) - : idx_list (1), resval (val) - { - idx_list.front ().resize (1); - if (resval.dims () != dims) - resval.resize (dims); - } - ~scalar_col_helper_def (void) { } - - bool collect (octave_idx_type i, const octave_value& val) - { - if (val.numel () == 1) - { - idx_list.front ()(0) = static_cast<double> (i + 1); - resval = resval.subsasgn ("(", idx_list, val); - } - else - error ("cellfun: expecting all values to be scalars for UniformOutput = true"); - - return true; - } - octave_value result (void) - { - return resval; - } -}; - -template <class T> -static bool can_extract (const octave_value& val) -{ return false; } - -#define DEF_CAN_EXTRACT(T, CLASS) \ -template <> \ -bool can_extract<T> (const octave_value& val) \ -{ return val.type_id () == octave_ ## CLASS::static_type_id (); } - -DEF_CAN_EXTRACT (double, scalar); -DEF_CAN_EXTRACT (float, float_scalar); -DEF_CAN_EXTRACT (bool, bool); -DEF_CAN_EXTRACT (octave_int8, int8_scalar); -DEF_CAN_EXTRACT (octave_int16, int16_scalar); -DEF_CAN_EXTRACT (octave_int32, int32_scalar); -DEF_CAN_EXTRACT (octave_int64, int64_scalar); -DEF_CAN_EXTRACT (octave_uint8, uint8_scalar); -DEF_CAN_EXTRACT (octave_uint16, uint16_scalar); -DEF_CAN_EXTRACT (octave_uint32, uint32_scalar); -DEF_CAN_EXTRACT (octave_uint64, uint64_scalar); - -template <> -bool can_extract<Complex> (const octave_value& val) -{ - int t = val.type_id (); - return (t == octave_complex::static_type_id () - || t == octave_scalar::static_type_id ()); -} - -template <> -bool can_extract<FloatComplex> (const octave_value& val) -{ - int t = val.type_id (); - return (t == octave_float_complex::static_type_id () - || t == octave_float_scalar::static_type_id ()); -} - -// This specializes for collecting elements of a single type, by accessing -// an array directly. If the scalar is not valid, it returns false. - -template <class NDA> -class scalar_col_helper_nda : public scalar_col_helper -{ - NDA arrayval; - typedef typename NDA::element_type T; -public: - scalar_col_helper_nda (const octave_value& val, const dim_vector& dims) - : arrayval (dims) - { - arrayval(0) = octave_value_extract<T> (val); - } - ~scalar_col_helper_nda (void) { } - - bool collect (octave_idx_type i, const octave_value& val) - { - bool retval = can_extract<T> (val); - if (retval) - arrayval(i) = octave_value_extract<T> (val); - return retval; - } - octave_value result (void) - { - return arrayval; - } -}; - -template class scalar_col_helper_nda<NDArray>; -template class scalar_col_helper_nda<FloatNDArray>; -template class scalar_col_helper_nda<ComplexNDArray>; -template class scalar_col_helper_nda<FloatComplexNDArray>; -template class scalar_col_helper_nda<boolNDArray>; -template class scalar_col_helper_nda<int8NDArray>; -template class scalar_col_helper_nda<int16NDArray>; -template class scalar_col_helper_nda<int32NDArray>; -template class scalar_col_helper_nda<int64NDArray>; -template class scalar_col_helper_nda<uint8NDArray>; -template class scalar_col_helper_nda<uint16NDArray>; -template class scalar_col_helper_nda<uint32NDArray>; -template class scalar_col_helper_nda<uint64NDArray>; - -// the virtual constructor. -scalar_col_helper * -make_col_helper (const octave_value& val, const dim_vector& dims) -{ - scalar_col_helper *retval; - - // No need to check numel() here. - switch (val.builtin_type ()) - { -#define ARRAYCASE(BTYP, ARRAY) \ - case BTYP: \ - retval = new scalar_col_helper_nda<ARRAY> (val, dims); \ - break - - ARRAYCASE (btyp_double, NDArray); - ARRAYCASE (btyp_float, FloatNDArray); - ARRAYCASE (btyp_complex, ComplexNDArray); - ARRAYCASE (btyp_float_complex, FloatComplexNDArray); - ARRAYCASE (btyp_bool, boolNDArray); - ARRAYCASE (btyp_int8, int8NDArray); - ARRAYCASE (btyp_int16, int16NDArray); - ARRAYCASE (btyp_int32, int32NDArray); - ARRAYCASE (btyp_int64, int64NDArray); - ARRAYCASE (btyp_uint8, uint8NDArray); - ARRAYCASE (btyp_uint16, uint16NDArray); - ARRAYCASE (btyp_uint32, uint32NDArray); - ARRAYCASE (btyp_uint64, uint64NDArray); - default: - retval = new scalar_col_helper_def (val, dims); - break; - } - - return retval; -} - static octave_value_list get_output_list (octave_idx_type count, octave_idx_type nargout, const octave_value_list& inputlist, @@ -636,7 +470,11 @@ if (uniform_output) { - OCTAVE_LOCAL_BUFFER (std::auto_ptr<scalar_col_helper>, retptr, nargout1); + std::list<octave_value_list> idx_list (1); + idx_list.front ().resize (1); + std::string idx_type = "("; + + OCTAVE_LOCAL_BUFFER (octave_value, retv, nargout1); for (octave_idx_type count = 0; count < k ; count++) { @@ -670,7 +508,7 @@ octave_value val = tmp(j); if (val.numel () == 1) - retptr[j].reset (make_col_helper (val, fdims)); + retv[j] = val.resize (fdims); else { error ("cellfun: expecting all values to be scalars for UniformOutput = true"); @@ -684,13 +522,22 @@ { octave_value val = tmp(j); - if (! retptr[j]->collect (count, val)) + if (! retv[j].fast_elem_insert (count, val)) { - // FIXME: A more elaborate structure would allow again a virtual - // constructor here. - retptr[j].reset (new scalar_col_helper_def (retptr[j]->result (), - fdims)); - retptr[j]->collect (count, val); + if (val.numel () == 1) + { + idx_list.front ()(0) = count + 1.0; + retv[j].assign (octave_value::op_asn_eq, + idx_type, idx_list, val); + + if (error_state) + break; + } + else + { + error ("cellfun: expecting all values to be scalars for UniformOutput = true"); + break; + } } } } @@ -701,12 +548,7 @@ retval.resize (nargout1); for (int j = 0; j < nargout1; j++) - { - if (retptr[j].get ()) - retval(j) = retptr[j]->result (); - else - retval(j) = Matrix (); - } + retval(j) = retv[j]; } else {