Mercurial > hg > octave-nkf
diff src/data.cc @ 10758:f3892d8eea9f
optimize horzcat/vertcat for scalars, cells and structs
author | Jaroslav Hajek <highegg@gmail.com> |
---|---|
date | Mon, 28 Jun 2010 12:06:48 +0200 |
parents | f7f26094021b |
children | b397b8edd8c5 |
line wrap: on
line diff
--- a/src/data.cc +++ b/src/data.cc @@ -1369,6 +1369,17 @@ */ +static bool +all_scalar_1x1 (const octave_value_list& args) +{ + int n_args = args.length (); + for (int i = 0; i < n_args; i++) + if (args(i).numel () != 1) + return false; + + return true; +} + template <class TYPE, class T> static void single_type_concat (Array<T>& result, @@ -1376,17 +1387,41 @@ int dim) { int n_args = args.length (); - OCTAVE_LOCAL_BUFFER (Array<T>, array_list, n_args); - - for (int j = 0; j < n_args && ! error_state; j++) + if (! (equal_types<T, char>::value + || equal_types<T, octave_value>::value) + && all_scalar_1x1 (args)) { - octave_quit (); - - array_list[j] = octave_value_extract<TYPE> (args(j)); + // Optimize all scalars case. + dim_vector dv (1, 1); + if (dim == -1 || dim == -2) + dim = -dim - 1; + else if (dim >= 2) + dv.resize (dim+1, 1); + dv(dim) = n_args; + + result.clear (dv); + + for (int j = 0; j < n_args && ! error_state; j++) + { + octave_quit (); + + result(j) = octave_value_extract<T> (args(j)); + } } - - if (! error_state) - result = Array<T>::cat (dim, n_args, array_list); + else + { + OCTAVE_LOCAL_BUFFER (Array<T>, array_list, n_args); + + for (int j = 0; j < n_args && ! error_state; j++) + { + octave_quit (); + + array_list[j] = octave_value_extract<TYPE> (args(j)); + } + + if (! error_state) + result = Array<T>::cat (dim, n_args, array_list); + } } template <class TYPE, class T> @@ -1421,6 +1456,44 @@ return result; } +template<class MAP> +static void +single_type_concat_map (octave_map& result, + const octave_value_list& args, + int dim) +{ + int n_args = args.length (); + OCTAVE_LOCAL_BUFFER (MAP, map_list, n_args); + + for (int j = 0; j < n_args && ! error_state; j++) + { + octave_quit (); + + map_list[j] = octave_value_extract<MAP> (args(j)); + } + + if (! error_state) + result = octave_map::cat (dim, n_args, map_list); +} + +static octave_map +do_single_type_concat_map (const octave_value_list& args, + int dim) +{ + octave_map result; + if (all_scalar_1x1 (args)) // optimize all scalars case. + { + if (dim < 0) + dim = -dim; + + single_type_concat_map<octave_scalar_map> (result, args, dim); + } + else + single_type_concat_map<octave_map> (result, args, dim); + + return result; +} + static octave_value do_cat (const octave_value_list& args, int dim, std::string fname) { @@ -1514,6 +1587,10 @@ retval = do_single_type_concat<uint32NDArray> (args, dim); else if (result_type == "uint64") retval = do_single_type_concat<uint64NDArray> (args, dim); + else if (result_type == "cell") + retval = do_single_type_concat<Cell> (args, dim); + else if (result_type == "struct") + retval = do_single_type_concat_map (args, dim); else { dim_vector dv = args(0).dims ();