comparison liboctave/Array-util.cc @ 9479:d9716e3ee0dd

supply optimized compiled sub2ind & ind2sub
author Jaroslav Hajek <highegg@gmail.com>
date Mon, 03 Aug 2009 15:52:40 +0200
parents 864805896876
children b096d11237be
comparison
equal deleted inserted replaced
9478:7e1e90837fef 9479:d9716e3ee0dd
25 #endif 25 #endif
26 26
27 #include "Array-util.h" 27 #include "Array-util.h"
28 #include "dim-vector.h" 28 #include "dim-vector.h"
29 #include "lo-error.h" 29 #include "lo-error.h"
30 #include "oct-locbuf.h"
30 31
31 bool 32 bool
32 index_in_bounds (const Array<octave_idx_type>& ra_idx, 33 index_in_bounds (const Array<octave_idx_type>& ra_idx,
33 const dim_vector& dimensions) 34 const dim_vector& dimensions)
34 { 35 {
473 } 474 }
474 475
475 return rdv; 476 return rdv;
476 } 477 }
477 478
479 // A helper class.
480 struct sub2ind_helper
481 {
482 octave_idx_type *ind, n;
483 sub2ind_helper (octave_idx_type *_ind, octave_idx_type _n)
484 : ind(_ind), n(_n) { }
485 void operator ()(octave_idx_type k) { (*ind++ *= n) += k; }
486 };
487
488 idx_vector sub2ind (const dim_vector& dv, const Array<idx_vector>& idxa)
489 {
490 idx_vector retval;
491 octave_idx_type len = idxa.length ();
492
493 if (len >= 2)
494 {
495 const dim_vector dvx = dv.redim (len);
496 bool all_ranges = true;
497 octave_idx_type clen = -1;
498
499 for (octave_idx_type i = 0; i < len; i++)
500 {
501 idx_vector idx = idxa(i);
502 octave_idx_type n = dvx(i);
503
504 all_ranges = all_ranges && idx.is_range ();
505 if (clen < 0)
506 clen = idx.length (n);
507 else if (clen != idx.length (n))
508 current_liboctave_error_handler ("sub2ind: lengths of indices must match");
509
510 if (idx.extent (n) > n)
511 current_liboctave_error_handler ("sub2ind: index out of range");
512 }
513
514 if (clen == 1)
515 {
516 // All scalars case - the result is a scalar.
517 octave_idx_type idx = idxa(len-1)(0);
518 for (octave_idx_type i = len - 2; i >= 0; i--)
519 idx = idx * dvx(i) + idxa(i)(0);
520 retval = idx_vector (idx);
521 }
522 else if (all_ranges && clen != 0)
523 {
524 // All ranges case - the result is a range.
525 octave_idx_type start = 0, step = 0;
526 for (octave_idx_type i = len - 1; i >= 0; i--)
527 {
528 octave_idx_type xstart = idxa(i)(0), xstep = idxa(i)(1) - xstart;
529 start = start * dvx(i) + xstart;
530 step = step * dvx(i) + xstep;
531 }
532 retval = idx_vector::make_range (start, step, clen);
533 }
534 else
535 {
536 Array<octave_idx_type> idx (idxa(0).orig_dimensions ());
537 octave_idx_type *idx_vec = idx.fortran_vec ();
538
539 for (octave_idx_type i = len - 1; i >= 0; i--)
540 {
541 if (i < len - 1)
542 idxa(i).loop (clen, sub2ind_helper (idx_vec, dvx(i)));
543 else
544 idxa(i).copy_data (idx_vec);
545 }
546
547 retval = idx_vector (idx);
548 }
549 }
550 else
551 current_liboctave_error_handler ("sub2ind: needs at least 2 indices");
552
553 return retval;
554 }
555
556 Array<idx_vector> ind2sub (const dim_vector& dv, const idx_vector& idx)
557 {
558 octave_idx_type len = idx.length (0), n = dv.length ();
559 Array<idx_vector> retval(n);
560 octave_idx_type numel = dv.numel ();
561
562 if (idx.extent (numel) > numel)
563 current_liboctave_error_handler ("ind2sub: index out of range");
564 else
565 {
566 if (idx.is_scalar ())
567 {
568 octave_idx_type k = idx(0);
569 for (octave_idx_type j = 0; j < n; j++)
570 {
571 retval(j) = k % dv(j);
572 k /= dv(j);
573 }
574 }
575 else
576 {
577 OCTAVE_LOCAL_BUFFER (Array<octave_idx_type>, rdata, n);
578
579 dim_vector odv = idx.orig_dimensions ();
580 for (octave_idx_type j = 0; j < n; j++)
581 rdata[j] = Array<octave_idx_type> (odv);
582
583 for (octave_idx_type i = 0; i < len; i++)
584 {
585 octave_idx_type k = idx(i);
586 for (octave_idx_type j = 0; j < n; j++)
587 {
588 rdata[j](i) = k % dv(j);
589 k /= dv(j);
590 }
591 }
592
593 for (octave_idx_type j = 0; j < n; j++)
594 retval(j) = rdata[j];
595 }
596
597
598 }
599
600 return retval;
601 }
602
478 int 603 int
479 permute_vector_compare (const void *a, const void *b) 604 permute_vector_compare (const void *a, const void *b)
480 { 605 {
481 const permute_vector *pva = static_cast<const permute_vector *> (a); 606 const permute_vector *pva = static_cast<const permute_vector *> (a);
482 const permute_vector *pvb = static_cast<const permute_vector *> (b); 607 const permute_vector *pvb = static_cast<const permute_vector *> (b);