Mercurial > hg > octave-shane
changeset 9028:e67dc11ed6e8
use Array<T>::find in find
author | Jaroslav Hajek <highegg@gmail.com> |
---|---|
date | Thu, 26 Mar 2009 13:50:46 +0100 |
parents | 9a46ba093db4 |
children | 2df28ad88b0e |
files | src/ChangeLog src/DLD-FUNCTIONS/find.cc |
diffstat | 2 files changed, 106 insertions(+), 130 deletions(-) [+] |
line wrap: on
line diff
--- a/src/ChangeLog +++ b/src/ChangeLog @@ -1,3 +1,12 @@ +2009-03-26 Jaroslav Hajek <highegg@gmail.com> + + * DLD-FUNCTIONS/find.cc + (find_nonzero_elem_idx (const Array<T>&, ...)): Simplify. + Instantiate for bool and octave_int types. + (find_nonzero_elem_idx (const Sparse<T>&, ...)): + Instantiate for bool. + (Ffind): Handle bool and octave_int cases. + 2009-03-25 John W. Eaton <jwe@octave.org> * version.h (OCTAVE_VERSION): Now 3.1.55+.
--- a/src/DLD-FUNCTIONS/find.cc +++ b/src/DLD-FUNCTIONS/find.cc @@ -43,155 +43,75 @@ { octave_value_list retval ((nargout == 0 ? 1 : nargout), Matrix ()); - octave_idx_type count = 0; - - octave_idx_type nel = nda.nelem (); - - // Set the starting element to the correct value based on the - // direction to search. - octave_idx_type k = 0; - if (direction == -1) - k = nel - 1; - - // Search in the default range. - octave_idx_type start_el = -1; - octave_idx_type end_el = -1; - - // Search for the number of elements to return. - while (k < nel && k > -1 && n_to_find != count) - { - OCTAVE_QUIT; - - if (nda(k) != T ()) - { - end_el = k; - if (start_el == -1) - start_el = k; - count++; - } - k = k + direction; - } - - // Reverse the range if we're looking backward. - if (direction == -1) - { - octave_idx_type tmp_el = start_el; - start_el = end_el; - end_el = tmp_el; - } - // Fix an off by one error. - end_el++; - - // If the original argument was a row vector, force a row vector of - // the overall indices to be returned. But see below for scalar - // case... - - octave_idx_type result_nr = count; - octave_idx_type result_nc = 1; - - bool column_vector_arg = false; - bool scalar_arg = false; - - if (nda.ndims () == 2) - { - octave_idx_type nr = nda.rows (); - octave_idx_type nc = nda.columns (); + Array<octave_idx_type> idx; + if (n_to_find >= 0) + idx = nda.find (n_to_find, direction == -1); + else + idx = nda.find (); - if (nr == 1) - { - result_nr = 1; - result_nc = count; - - scalar_arg = (nc == 1); - } - else if (nc == 1) - column_vector_arg = true; - } - - Matrix idx (result_nr, result_nc); - - Matrix i_idx (result_nr, result_nc); - Matrix j_idx (result_nr, result_nc); - - ArrayN<T> val (dim_vector (result_nr, result_nc)); - - if (count > 0) - { - count = 0; - - octave_idx_type nr = nda.rows (); - - octave_idx_type i = 0; - - // Search for elements to return. Only search the region where - // there are elements to be found using the count that we want - // to find. + // Fixup idx dimensions, for Matlab compatibility. + // find(zeros(0,0)) -> zeros(0,0) + // find(zeros(1,0)) -> zeros(1,0) + // find(zeros(0,1)) -> zeros(0,1) + // find(zeros(0,X)) -> zeros(0,1) + // find(zeros(1,1)) -> zeros(0,0) !!!! WHY? + // find(zeros(0,1,0)) -> zeros(0,0) + // find(zeros(0,1,0,1)) -> zeros(0,0) etc + // FIXME: I don't believe this is right. Matlab seems to violate its own docs + // here, because a scalar *is* a row vector. - // For compatibility, all N-d arrays are handled as if they are - // 2-d, with the number of columns equal to "prod (dims (2:end))". - - for (k = start_el; k < end_el; k++) - { - OCTAVE_QUIT; - - if (nda(k) != T ()) - { - idx(count) = k + 1; - - octave_idx_type xr = k % nr; - i_idx(count) = xr + 1; - j_idx(count) = (k - xr) / nr + 1; - - val(count) = nda(k); - - count++; - } - - i++; - } - } - else if (scalar_arg || (nda.rows () == 0 && ! column_vector_arg)) - { - idx.resize (0, 0); - - i_idx.resize (0, 0); - j_idx.resize (0, 0); - - val.resize (dim_vector (0, 0)); - } + if ((nda.numel () == 1 && idx.is_empty ()) + || (nda.rows () == 0 && nda.dims ().numel (1) == 0)) + idx = idx.reshape (dim_vector (0, 0)); + else if (nda.rows () == 1 && nda.ndims () == 2) + idx = idx.reshape (dim_vector (1, idx.length ())); switch (nargout) { default: case 3: - retval(2) = val; + retval(2) = ArrayN<T> (nda.index (idx_vector (idx))); // Fall through! case 2: - retval(1) = j_idx; - retval(0) = i_idx; - break; + { + Array<octave_idx_type> jdx (idx.dims ()); + octave_idx_type n = idx.length (), nr = nda.rows (); + for (octave_idx_type i = 0; i < n; i++) + { + jdx.xelem (i) = idx.xelem (i) / nr; + idx.xelem (i) %= nr; + } + retval(1) = NDArray (jdx, true); + } + // Fall through! case 1: case 0: - retval(0) = idx; + retval(0) = NDArray (idx, true); break; } return retval; } -template octave_value_list find_nonzero_elem_idx (const Array<double>&, int, - octave_idx_type, int); - -template octave_value_list find_nonzero_elem_idx (const Array<Complex>&, int, - octave_idx_type, int); +#define INSTANTIATE_FIND_ARRAY(T) \ +template octave_value_list find_nonzero_elem_idx (const Array<T>&, int, \ + octave_idx_type, int) -template octave_value_list find_nonzero_elem_idx (const Array<float>&, int, - octave_idx_type, int); - -template octave_value_list find_nonzero_elem_idx (const Array<FloatComplex>&, - int, octave_idx_type, int); +INSTANTIATE_FIND_ARRAY(double); +INSTANTIATE_FIND_ARRAY(float); +INSTANTIATE_FIND_ARRAY(Complex); +INSTANTIATE_FIND_ARRAY(FloatComplex); +INSTANTIATE_FIND_ARRAY(bool); +INSTANTIATE_FIND_ARRAY(octave_int8); +INSTANTIATE_FIND_ARRAY(octave_int16); +INSTANTIATE_FIND_ARRAY(octave_int32); +INSTANTIATE_FIND_ARRAY(octave_int64); +INSTANTIATE_FIND_ARRAY(octave_uint8); +INSTANTIATE_FIND_ARRAY(octave_uint16); +INSTANTIATE_FIND_ARRAY(octave_uint32); +INSTANTIATE_FIND_ARRAY(octave_uint64); template <typename T> octave_value_list @@ -342,6 +262,9 @@ template octave_value_list find_nonzero_elem_idx (const Sparse<Complex>&, int, octave_idx_type, int); +template octave_value_list find_nonzero_elem_idx (const Sparse<bool>&, int, + octave_idx_type, int); + octave_value_list find_nonzero_elem_idx (const PermMatrix& v, int nargout, octave_idx_type n_to_find, int direction) @@ -561,7 +484,51 @@ octave_value arg = args(0); - if (arg.is_sparse_type ()) + if (arg.is_bool_type ()) + { + if (arg.is_sparse_type ()) + { + SparseBoolMatrix v = arg.sparse_bool_matrix_value (); + + if (! error_state) + retval = find_nonzero_elem_idx (v, nargout, + n_to_find, direction); + } + else + { + boolNDArray v = arg.bool_array_value (); + + if (! error_state) + retval = find_nonzero_elem_idx (v, nargout, + n_to_find, direction); + } + } + else if (arg.is_integer_type ()) + { +#define DO_INT_BRANCH(INTT) \ + else if (arg.is_ ## INTT ## _type ()) \ + { \ + INTT ## NDArray v = arg.INTT ## _array_value (); \ + \ + if (! error_state) \ + retval = find_nonzero_elem_idx (v, nargout, \ + n_to_find, direction);\ + } + + if (false) + ; + DO_INT_BRANCH (int8) + DO_INT_BRANCH (int16) + DO_INT_BRANCH (int32) + DO_INT_BRANCH (int64) + DO_INT_BRANCH (uint8) + DO_INT_BRANCH (uint16) + DO_INT_BRANCH (uint32) + DO_INT_BRANCH (uint64) + else + panic_impossible (); + } + else if (arg.is_sparse_type ()) { if (arg.is_real_type ()) {