Mercurial > hg > octave-lyh
changeset 9678:c929f09457b7
rewrite num2cell for speed-up + a few associated fixes
author | Jaroslav Hajek <highegg@gmail.com> |
---|---|
date | Thu, 01 Oct 2009 14:07:06 +0200 |
parents | 8cf522ce9c4d |
children | 0896714301e4 |
files | liboctave/Array.cc liboctave/ChangeLog src/ChangeLog src/DLD-FUNCTIONS/cellfun.cc |
diffstat | 4 files changed, 201 insertions(+), 64 deletions(-) [+] |
line wrap: on
line diff
--- a/liboctave/Array.cc +++ b/liboctave/Array.cc @@ -592,6 +592,8 @@ // Need this array to check for identical elements in permutation array. OCTAVE_LOCAL_BUFFER_INIT (bool, checked, perm_vec_len, false); + bool identity = true; + // Find dimension vector of permuted array. for (int i = 0; i < perm_vec_len; i++) { @@ -614,11 +616,17 @@ return retval; } else - checked[perm_elt] = true; + { + checked[perm_elt] = true; + identity = identity && perm_elt == i; + } dv_new(i) = dv(perm_elt); } + if (identity) + return *this; + if (inv) { for (int i = 0; i < perm_vec_len; i++)
--- a/liboctave/ChangeLog +++ b/liboctave/ChangeLog @@ -1,3 +1,7 @@ +2009-10-01 Jaroslav Hajek <highegg@gmail.com> + + * Array.cc (Array<T>::permute): Fast case identity permutation. + 2009-09-27 Jaroslav Hajek <highegg@gmail.com> * oct-cmplx.h: Fix complex-real orderings.
--- a/src/ChangeLog +++ b/src/ChangeLog @@ -1,3 +1,10 @@ +2009-10-01 Jaroslav Hajek <highegg@gmail.com> + + * DLD-FUNCTIONS/cellfun.cc + (do_num2cell_helper, do_num2cell): New funcs. + (Fnum2cell): Rewrite. + (do_cellslices_nda): Do not leave trailing dims. + 2009-09-30 John W. Eaton <jwe@octave.org> * error.cc (error_1, pr_where_2, handle_message):
--- a/src/DLD-FUNCTIONS/cellfun.cc +++ b/src/DLD-FUNCTIONS/cellfun.cc @@ -40,6 +40,7 @@ #include "variables.h" #include "ov-colon.h" #include "unwind-prot.h" +#include "gripes.h" // Rationale: // The octave_base_value::subsasgn method carries too much overhead for @@ -905,6 +906,91 @@ */ +static void +do_num2cell_helper (const dim_vector& dv, + const Array<int>& dimv, + dim_vector& celldv, dim_vector& arraydv, + Array<int>& perm) +{ + int dvl = dimv.length (); + int maxd = dv.length (); + celldv = dv; + for (int i = 0; i < dvl; i++) + maxd = std::max (maxd, dimv(i)); + if (maxd > dv.length ()) + celldv.resize (maxd, 1); + arraydv = celldv; + + OCTAVE_LOCAL_BUFFER_INIT (bool, sing, maxd, false); + + perm.clear (maxd); + for (int i = 0; i < dvl; i++) + { + int k = dimv(i) - 1; + if (k < 0) + { + error ("num2cell: dimension indices must be positive"); + return; + } + else if (i > 0 && k < dimv(i-1) - 1) + { + error ("num2cell: dimension indices must be strictly increasing"); + return; + } + + sing[k] = true; + perm(i) = k; + } + + for (int k = 0, i = dvl; k < maxd; k++) + if (! sing[k]) + perm(i++) = k; + + for (int i = 0; i < maxd; i++) + if (sing[i]) + celldv(i) = 1; + else + arraydv(i) = 1; +} + +template<class NDA> +static Cell +do_num2cell (const NDA& array, const Array<int>& dimv) +{ + if (dimv.is_empty ()) + { + Cell retval (array.dims ()); + octave_idx_type nel = array.numel (); + for (octave_idx_type i = 0; i < nel; i++) + retval.xelem (i) = array(i); + + return retval; + } + else + { + dim_vector celldv, arraydv; + Array<int> perm; + do_num2cell_helper (array.dims (), dimv, celldv, arraydv, perm); + if (error_state) + return Cell (); + + NDA parray = array.permute (perm); + + octave_idx_type nela = arraydv.numel (), nelc = celldv.numel (); + parray = parray.reshape (dim_vector (nela, nelc)); + + Cell retval (celldv); + for (octave_idx_type i = 0; i < nelc; i++) + { + NDA tmp (parray.index (idx_vector::colon, idx_vector (i))); + retval.xelem (i) = tmp.reshape (arraydv); + } + + return retval; + } +} + + DEFUN_DLD (num2cell, args, , "-*- texinfo -*-\n\ @deftypefn {Loadable Function} {@var{c} =} num2cell (@var{m})\n\ @@ -922,72 +1008,83 @@ print_usage (); else { - dim_vector dv = args(0).dims (); - Array<int> sings; - - if (nargin == 2) - { - ColumnVector dsings = ColumnVector (args(1).vector_value - (false, true)); - sings.resize (dsings.length()); - - if (!error_state) - for (octave_idx_type i = 0; i < dsings.length(); i++) - if (dsings(i) > dv.length() || dsings(i) < 1 || - D_NINT(dsings(i)) != dsings(i)) - { - error ("invalid dimension specified"); - break; - } - else - sings(i) = NINT(dsings(i)) - 1; - } - - if (! error_state) - { - Array<bool> idx_colon (dv.length()); - dim_vector new_dv (dv); - octave_value_list lst (new_dv.length(), octave_value()); + octave_value array = args(0); + Array<int> dimv; + if (nargin > 1) + dimv = args (1).int_vector_value (true); - for (int i = 0; i < dv.length(); i++) - { - idx_colon(i) = false; - for (int j = 0; j < sings.length(); j++) - { - if (sings(j) == i) - { - new_dv(i) = 1; - idx_colon(i) = true; - lst(i) = octave_value (octave_value::magic_colon_t); - break; - } - } - } - - Cell ret (new_dv); - octave_idx_type nel = new_dv.numel(); - octave_idx_type ntot = 1; + if (error_state) + ; + else if (array.is_bool_type ()) + retval = do_num2cell (array.bool_array_value (), dimv); + else if (array.is_char_matrix ()) + retval = do_num2cell (array.char_array_value (), dimv); + else if (array.is_numeric_type ()) + { + if (array.is_integer_type ()) + { + if (array.is_int8_type ()) + retval = do_num2cell (array.int8_array_value (), dimv); + else if (array.is_int16_type ()) + retval = do_num2cell (array.int16_array_value (), dimv); + else if (array.is_int32_type ()) + retval = do_num2cell (array.int32_array_value (), dimv); + else if (array.is_int64_type ()) + retval = do_num2cell (array.int64_array_value (), dimv); + else if (array.is_uint8_type ()) + retval = do_num2cell (array.uint8_array_value (), dimv); + else if (array.is_uint16_type ()) + retval = do_num2cell (array.uint16_array_value (), dimv); + else if (array.is_uint32_type ()) + retval = do_num2cell (array.uint32_array_value (), dimv); + else if (array.is_uint64_type ()) + retval = do_num2cell (array.uint64_array_value (), dimv); + } + else if (array.is_complex_type ()) + { + if (array.is_single_type ()) + retval = do_num2cell (array.float_complex_array_value (), dimv); + else + retval = do_num2cell (array.complex_array_value (), dimv); + } + else + { + if (array.is_single_type ()) + retval = do_num2cell (array.float_array_value (), dimv); + else + retval = do_num2cell (array.array_value (), dimv); + } + } + else if (array.is_cell () || array.is_map ()) + { + dim_vector celldv, arraydv; + Array<int> perm; + do_num2cell_helper (array.dims (), dimv, celldv, arraydv, perm); - for (int j = 0; j < new_dv.length()-1; j++) - ntot *= new_dv(j); + if (! error_state) + { + // FIXME: this operation may be rather inefficient. + octave_value parray = array.permute (perm); + + octave_idx_type nela = arraydv.numel (), nelc = celldv.numel (); + parray = parray.reshape (dim_vector (nela, nelc)); + + Cell retcell (celldv); + octave_value_list idx (2); + idx(0) = octave_value::magic_colon_t; - for (octave_idx_type i = 0; i < nel; i++) - { - octave_idx_type n = ntot; - octave_idx_type ii = i; - for (int j = new_dv.length() - 1; j >= 0 ; j--) - { - if (! idx_colon(j)) - lst (j) = ii/n + 1; - ii = ii % n; - if (j != 0) - n /= new_dv(j-1); - } - ret(i) = args(0).do_index_op(lst, 0); - } + for (octave_idx_type i = 0; i < nelc; i++) + { + idx(1) = i + 1; + octave_value tmp = parray.do_index_op (idx); + retcell(i) = tmp.reshape (arraydv); + } - retval = ret; - } + retval = retcell; + } + } + else + gripe_wrong_type_arg ("num2cell", array); } return retval; @@ -1224,7 +1321,7 @@ */ template <class NDA> -Cell +static Cell do_cellslices_nda (const NDA& array, const idx_vector& lb, const idx_vector& ub) { octave_idx_type n = lb.length (0); @@ -1242,7 +1339,9 @@ for (octave_idx_type i = 0; i < n && ! error_state; i++) { // Do it with a single index to speed things up. + dv = array.dims (); dv(dv.length () - 1) = ub(i) + 1 - lb(i); + dv.chop_trailing_singletons (); retval(i) = array.index (idx_vector (nl*lb(i), nl*(ub(i) + 1))).reshape (dv); } } @@ -1293,6 +1392,25 @@ retcell = do_cellslices_nda (x.bool_array_value (), lb, ub); else if (x.is_char_matrix ()) retcell = do_cellslices_nda (x.char_array_value (), lb, ub); + else if (x.is_integer_type ()) + { + if (x.is_int8_type ()) + retcell = do_cellslices_nda (x.int8_array_value (), lb, ub); + else if (x.is_int16_type ()) + retcell = do_cellslices_nda (x.int16_array_value (), lb, ub); + else if (x.is_int32_type ()) + retcell = do_cellslices_nda (x.int32_array_value (), lb, ub); + else if (x.is_int64_type ()) + retcell = do_cellslices_nda (x.int64_array_value (), lb, ub); + else if (x.is_uint8_type ()) + retcell = do_cellslices_nda (x.uint8_array_value (), lb, ub); + else if (x.is_uint16_type ()) + retcell = do_cellslices_nda (x.uint16_array_value (), lb, ub); + else if (x.is_uint32_type ()) + retcell = do_cellslices_nda (x.uint32_array_value (), lb, ub); + else if (x.is_uint64_type ()) + retcell = do_cellslices_nda (x.uint64_array_value (), lb, ub); + } else if (x.is_complex_type ()) { if (x.is_single_type ())