Mercurial > hg > octave-lyh
comparison liboctave/Sparse.cc @ 10512:aac9f4265048
rewrite sparse indexed assignment
author | Jaroslav Hajek <highegg@gmail.com> |
---|---|
date | Tue, 13 Apr 2010 12:36:21 +0200 |
parents | ddbd812d09aa |
children | f0266ee4aabe |
comparison
equal
deleted
inserted
replaced
10511:153e6226a669 | 10512:aac9f4265048 |
---|---|
171 } | 171 } |
172 | 172 |
173 template <class T> | 173 template <class T> |
174 template <class U> | 174 template <class U> |
175 Sparse<T>::Sparse (const Sparse<U>& a) | 175 Sparse<T>::Sparse (const Sparse<U>& a) |
176 : dimensions (a.dimensions), idx (0), idx_count (0) | 176 : dimensions (a.dimensions) |
177 { | 177 { |
178 if (a.nnz () == 0) | 178 if (a.nnz () == 0) |
179 rep = new typename Sparse<T>::SparseRep (rows (), cols()); | 179 rep = new typename Sparse<T>::SparseRep (rows (), cols()); |
180 else | 180 else |
181 { | 181 { |
193 } | 193 } |
194 } | 194 } |
195 | 195 |
196 template <class T> | 196 template <class T> |
197 Sparse<T>::Sparse (octave_idx_type nr, octave_idx_type nc, T val) | 197 Sparse<T>::Sparse (octave_idx_type nr, octave_idx_type nc, T val) |
198 : dimensions (dim_vector (nr, nc)), idx (0), idx_count (0) | 198 : dimensions (dim_vector (nr, nc)) |
199 { | 199 { |
200 if (val != T ()) | 200 if (val != T ()) |
201 { | 201 { |
202 rep = new typename Sparse<T>::SparseRep (nr, nc, nr*nc); | 202 rep = new typename Sparse<T>::SparseRep (nr, nc, nr*nc); |
203 | 203 |
221 } | 221 } |
222 } | 222 } |
223 | 223 |
224 template <class T> | 224 template <class T> |
225 Sparse<T>::Sparse (const dim_vector& dv) | 225 Sparse<T>::Sparse (const dim_vector& dv) |
226 : dimensions (dv), idx (0), idx_count (0) | 226 : dimensions (dv) |
227 { | 227 { |
228 if (dv.length() != 2) | 228 if (dv.length() != 2) |
229 (*current_liboctave_error_handler) | 229 (*current_liboctave_error_handler) |
230 ("Sparse::Sparse (const dim_vector&): dimension mismatch"); | 230 ("Sparse::Sparse (const dim_vector&): dimension mismatch"); |
231 else | 231 else |
232 rep = new typename Sparse<T>::SparseRep (dv(0), dv(1)); | 232 rep = new typename Sparse<T>::SparseRep (dv(0), dv(1)); |
233 } | 233 } |
234 | 234 |
235 template <class T> | 235 template <class T> |
236 Sparse<T>::Sparse (const Sparse<T>& a, const dim_vector& dv) | 236 Sparse<T>::Sparse (const Sparse<T>& a, const dim_vector& dv) |
237 : dimensions (dv), idx (0), idx_count (0) | 237 : dimensions (dv) |
238 { | 238 { |
239 | 239 |
240 // Work in unsigned long long to avoid overflow issues with numel | 240 // Work in unsigned long long to avoid overflow issues with numel |
241 unsigned long long a_nel = static_cast<unsigned long long>(a.rows ()) * | 241 unsigned long long a_nel = static_cast<unsigned long long>(a.rows ()) * |
242 static_cast<unsigned long long>(a.cols ()); | 242 static_cast<unsigned long long>(a.cols ()); |
278 | 278 |
279 template <class T> | 279 template <class T> |
280 Sparse<T>::Sparse (const Array<T>& a, const idx_vector& r, | 280 Sparse<T>::Sparse (const Array<T>& a, const idx_vector& r, |
281 const idx_vector& c, octave_idx_type nr, | 281 const idx_vector& c, octave_idx_type nr, |
282 octave_idx_type nc, bool sum_terms) | 282 octave_idx_type nc, bool sum_terms) |
283 : rep (nil_rep ()), dimensions (), idx (0), idx_count (0) | 283 : rep (nil_rep ()), dimensions () |
284 { | 284 { |
285 if (nr < 0) | 285 if (nr < 0) |
286 nr = r.extent (0); | 286 nr = r.extent (0); |
287 else if (r.extent (nr) > nr) | 287 else if (r.extent (nr) > nr) |
288 (*current_liboctave_error_handler) ("sparse: row index %d out of bound %d", | 288 (*current_liboctave_error_handler) ("sparse: row index %d out of bound %d", |
617 } | 617 } |
618 } | 618 } |
619 | 619 |
620 template <class T> | 620 template <class T> |
621 Sparse<T>::Sparse (const Array<T>& a) | 621 Sparse<T>::Sparse (const Array<T>& a) |
622 : dimensions (a.dims ()), idx (0), idx_count (0) | 622 : dimensions (a.dims ()) |
623 { | 623 { |
624 if (dimensions.length () > 2) | 624 if (dimensions.length () > 2) |
625 (*current_liboctave_error_handler) | 625 (*current_liboctave_error_handler) |
626 ("Sparse::Sparse (const Array<T>&): dimension mismatch"); | 626 ("Sparse::Sparse (const Array<T>&): dimension mismatch"); |
627 else | 627 else |
656 template <class T> | 656 template <class T> |
657 Sparse<T>::~Sparse (void) | 657 Sparse<T>::~Sparse (void) |
658 { | 658 { |
659 if (--rep->count <= 0) | 659 if (--rep->count <= 0) |
660 delete rep; | 660 delete rep; |
661 | |
662 delete [] idx; | |
663 } | 661 } |
664 | 662 |
665 template <class T> | 663 template <class T> |
666 Sparse<T>& | 664 Sparse<T>& |
667 Sparse<T>::operator = (const Sparse<T>& a) | 665 Sparse<T>::operator = (const Sparse<T>& a) |
673 | 671 |
674 rep = a.rep; | 672 rep = a.rep; |
675 rep->count++; | 673 rep->count++; |
676 | 674 |
677 dimensions = a.dimensions; | 675 dimensions = a.dimensions; |
678 | |
679 delete [] idx; | |
680 idx_count = 0; | |
681 idx = 0; | |
682 } | 676 } |
683 | 677 |
684 return *this; | 678 return *this; |
685 } | 679 } |
686 | 680 |
890 void | 884 void |
891 Sparse<T>::resize1 (octave_idx_type n) | 885 Sparse<T>::resize1 (octave_idx_type n) |
892 { | 886 { |
893 octave_idx_type nr = rows (), nc = cols (); | 887 octave_idx_type nr = rows (), nc = cols (); |
894 | 888 |
895 if (nr == 1 || nr == 0) | 889 if (nr == 0) |
890 resize (1, std::max (nc, n)); | |
891 else if (nc == 0) | |
892 // FIXME: Due to Matlab 2007a, but some existing tests fail on this. | |
893 resize (nr, (n + nr - 1) / nr); | |
894 else if (nr == 1) | |
896 resize (1, n); | 895 resize (1, n); |
897 else if (nc == 1) | 896 else if (nc == 1) |
898 resize (n, 1); | 897 resize (n, 1); |
899 else | 898 else |
900 gripe_invalid_resize (); | 899 gripe_invalid_resize (); |
1098 assert (nnz () == retval.xcidx (nr)); | 1097 assert (nnz () == retval.xcidx (nr)); |
1099 // retval.xcidx[1:nr] holds row entry *end* offsets for rows 0:(nr-1) | 1098 // retval.xcidx[1:nr] holds row entry *end* offsets for rows 0:(nr-1) |
1100 // and retval.xcidx[0:(nr-1)] holds their row entry *start* offsets | 1099 // and retval.xcidx[0:(nr-1)] holds their row entry *start* offsets |
1101 | 1100 |
1102 return retval; | 1101 return retval; |
1103 } | |
1104 | |
1105 template <class T> | |
1106 void | |
1107 Sparse<T>::clear_index (void) | |
1108 { | |
1109 delete [] idx; | |
1110 idx = 0; | |
1111 idx_count = 0; | |
1112 } | |
1113 | |
1114 template <class T> | |
1115 void | |
1116 Sparse<T>::set_index (const idx_vector& idx_arg) | |
1117 { | |
1118 octave_idx_type nd = ndims (); | |
1119 | |
1120 if (! idx && nd > 0) | |
1121 idx = new idx_vector [nd]; | |
1122 | |
1123 if (idx_count < nd) | |
1124 { | |
1125 idx[idx_count++] = idx_arg; | |
1126 } | |
1127 else | |
1128 { | |
1129 idx_vector *new_idx = new idx_vector [idx_count+1]; | |
1130 | |
1131 for (octave_idx_type i = 0; i < idx_count; i++) | |
1132 new_idx[i] = idx[i]; | |
1133 | |
1134 new_idx[idx_count++] = idx_arg; | |
1135 | |
1136 delete [] idx; | |
1137 | |
1138 idx = new_idx; | |
1139 } | |
1140 } | 1102 } |
1141 | 1103 |
1142 // Lower bound lookup. Could also use octave_sort, but that has upper bound | 1104 // Lower bound lookup. Could also use octave_sort, but that has upper bound |
1143 // semantics, so requires some manipulation to set right. Uses a plain loop for | 1105 // semantics, so requires some manipulation to set right. Uses a plain loop for |
1144 // small columns. | 1106 // small columns. |
1336 { | 1298 { |
1337 (*current_liboctave_error_handler) | 1299 (*current_liboctave_error_handler) |
1338 ("invalid dimension in delete_elements"); | 1300 ("invalid dimension in delete_elements"); |
1339 return; | 1301 return; |
1340 } | 1302 } |
1341 } | |
1342 | |
1343 template <class T> | |
1344 Sparse<T> | |
1345 Sparse<T>::value (void) | |
1346 { | |
1347 Sparse<T> retval; | |
1348 | |
1349 int n_idx = index_count (); | |
1350 | |
1351 if (n_idx == 2) | |
1352 { | |
1353 idx_vector *tmp = get_idx (); | |
1354 | |
1355 idx_vector idx_i = tmp[0]; | |
1356 idx_vector idx_j = tmp[1]; | |
1357 | |
1358 retval = index (idx_i, idx_j); | |
1359 } | |
1360 else if (n_idx == 1) | |
1361 { | |
1362 retval = index (idx[0]); | |
1363 } | |
1364 else | |
1365 (*current_liboctave_error_handler) | |
1366 ("Sparse<T>::value: invalid number of indices specified"); | |
1367 | |
1368 clear_index (); | |
1369 | |
1370 return retval; | |
1371 } | 1303 } |
1372 | 1304 |
1373 template <class T> | 1305 template <class T> |
1374 Sparse<T> | 1306 Sparse<T> |
1375 Sparse<T>::index (const idx_vector& idx, bool resize_ok) const | 1307 Sparse<T>::index (const idx_vector& idx, bool resize_ok) const |
1761 } | 1693 } |
1762 | 1694 |
1763 return retval; | 1695 return retval; |
1764 } | 1696 } |
1765 | 1697 |
1698 template <class T> | |
1699 void | |
1700 Sparse<T>::assign (const idx_vector& idx, const Sparse<T>& rhs) | |
1701 { | |
1702 Sparse<T> retval; | |
1703 | |
1704 assert (ndims () == 2); | |
1705 | |
1706 // FIXME: please don't fix the shadowed member warning yet because | |
1707 // Sparse<T>::idx will eventually go away. | |
1708 | |
1709 octave_idx_type nr = dim1 (); | |
1710 octave_idx_type nc = dim2 (); | |
1711 octave_idx_type nz = nnz (); | |
1712 | |
1713 octave_idx_type n = numel (); // Can throw. | |
1714 | |
1715 octave_idx_type rhl = rhs.numel (); | |
1716 | |
1717 if (idx.length (n) == rhl) | |
1718 { | |
1719 if (rhl == 0) | |
1720 return; | |
1721 | |
1722 octave_idx_type nx = idx.extent (n); | |
1723 // Try to resize first if necessary. | |
1724 if (nx != n) | |
1725 { | |
1726 resize1 (nx); | |
1727 n = numel (); | |
1728 nr = rows (); | |
1729 nc = cols (); | |
1730 // nz is preserved. | |
1731 } | |
1732 | |
1733 if (idx.is_colon ()) | |
1734 { | |
1735 *this = rhs.reshape (dimensions); | |
1736 } | |
1737 else if (nc == 1 && rhs.cols () == 1) | |
1738 { | |
1739 // Sparse column vector to sparse column vector assignment. | |
1740 | |
1741 octave_idx_type lb, ub; | |
1742 if (idx.is_cont_range (nr, lb, ub)) | |
1743 { | |
1744 // Special-case a contiguous range. | |
1745 // Look-up indices first. | |
1746 octave_idx_type li = lblookup (ridx (), nz, lb); | |
1747 octave_idx_type ui = lblookup (ridx (), nz, ub); | |
1748 octave_idx_type rnz = rhs.nnz (), new_nz = nz - (ui - li) + rnz; | |
1749 | |
1750 if (new_nz >= nz && new_nz <= capacity ()) | |
1751 { | |
1752 // Adding/overwriting elements, enough capacity allocated. | |
1753 | |
1754 if (new_nz > nz) | |
1755 { | |
1756 // Make room first. | |
1757 std::copy_backward (data () + ui, data () + nz, data () + li + rnz); | |
1758 std::copy_backward (ridx () + ui, ridx () + nz, ridx () + li + rnz); | |
1759 } | |
1760 | |
1761 // Copy data and adjust indices from rhs. | |
1762 copy_or_memcpy (rnz, rhs.data (), data () + li); | |
1763 mx_inline_add (rnz, ridx () + li, rhs.ridx (), lb); | |
1764 } | |
1765 else | |
1766 { | |
1767 // Clearing elements or exceeding capacity, allocate afresh | |
1768 // and paste pieces. | |
1769 const Sparse<T> tmp = *this; | |
1770 *this = Sparse<T> (nr, 1, new_nz); | |
1771 | |
1772 // Head ... | |
1773 copy_or_memcpy (li, tmp.data (), data ()); | |
1774 copy_or_memcpy (li, tmp.ridx (), ridx ()); | |
1775 | |
1776 // new stuff ... | |
1777 copy_or_memcpy (rnz, rhs.data (), data () + li); | |
1778 mx_inline_add (rnz, ridx () + li, rhs.ridx (), lb); | |
1779 | |
1780 // ...tail | |
1781 copy_or_memcpy (nz - ui, data () + ui, data () + li + rnz); | |
1782 copy_or_memcpy (nz - ui, ridx () + ui, ridx () + li + rnz); | |
1783 } | |
1784 | |
1785 cidx(1) = new_nz; | |
1786 } | |
1787 else if (idx.is_range () && idx.increment () == -1) | |
1788 { | |
1789 // It's s(u:-1:l) = r. Reverse the assignment. | |
1790 assign (idx.sorted (), rhs.index (idx_vector (rhl - 1, 0, -1))); | |
1791 } | |
1792 else if (idx.is_permutation (n)) | |
1793 { | |
1794 *this = rhs.index (idx.inverse_permutation (n)); | |
1795 } | |
1796 else if (rhs.nnz () == 0) | |
1797 { | |
1798 // Elements are being zeroed. | |
1799 octave_idx_type *ri = ridx (); | |
1800 for (octave_idx_type i = 0; i < rhl; i++) | |
1801 { | |
1802 octave_idx_type iidx = idx(i); | |
1803 octave_idx_type li = lblookup (ri, nz, iidx); | |
1804 if (li != nz && ri[li] == iidx) | |
1805 xdata(li) = T(); | |
1806 } | |
1807 | |
1808 maybe_compress (true); | |
1809 } | |
1810 else | |
1811 { | |
1812 const Sparse<T> tmp = *this; | |
1813 octave_idx_type new_nz = nz + rhl; | |
1814 // Disassembly our matrix... | |
1815 Array<octave_idx_type> new_ri (new_nz, 1); | |
1816 Array<T> new_data (new_nz, 1); | |
1817 copy_or_memcpy (nz, tmp.ridx (), new_ri.fortran_vec ()); | |
1818 copy_or_memcpy (nz, tmp.data (), new_data.fortran_vec ()); | |
1819 // ... insert new data (densified) ... | |
1820 idx.copy_data (new_ri.fortran_vec () + nz); | |
1821 new_data.assign (idx_vector (nz, new_nz), rhs.array_value ()); | |
1822 // ... reassembly. | |
1823 *this = Sparse<T> (new_data, new_ri, 0, nr, nc, false); | |
1824 } | |
1825 } | |
1826 else | |
1827 { | |
1828 dim_vector save_dims = dimensions; | |
1829 *this = index (idx_vector::colon); | |
1830 assign (idx, rhs.index (idx_vector::colon)); | |
1831 *this = reshape (save_dims); | |
1832 } | |
1833 } | |
1834 else if (rhl == 1) | |
1835 { | |
1836 rhl = idx.length (n); | |
1837 if (rhs.nnz () != 0) | |
1838 assign (idx, Sparse<T> (rhl, 1, rhs.data (0))); | |
1839 else | |
1840 assign (idx, Sparse<T> (rhl, 1)); | |
1841 } | |
1842 else | |
1843 gripe_invalid_assignment_size (); | |
1844 } | |
1845 | |
1846 template <class T> | |
1847 void | |
1848 Sparse<T>::assign (const idx_vector& idx_i, | |
1849 const idx_vector& idx_j, const Sparse<T>& rhs) | |
1850 { | |
1851 Sparse<T> retval; | |
1852 | |
1853 assert (ndims () == 2); | |
1854 | |
1855 // FIXME: please don't fix the shadowed member warning yet because | |
1856 // Sparse<T>::idx will eventually go away. | |
1857 | |
1858 octave_idx_type nr = dim1 (); | |
1859 octave_idx_type nc = dim2 (); | |
1860 octave_idx_type nz = nnz (); | |
1861 | |
1862 octave_idx_type n = rhs.rows (); | |
1863 octave_idx_type m = rhs.columns (); | |
1864 | |
1865 if (idx_i.length (nr) == n && idx_j.length (nc) == m) | |
1866 { | |
1867 if (n == 0 || m == 0) | |
1868 return; | |
1869 | |
1870 octave_idx_type nrx = idx_i.extent (nr), ncx = idx_j.extent (nc); | |
1871 // Try to resize first if necessary. | |
1872 if (nrx != nr || ncx != nc) | |
1873 { | |
1874 resize (nrx, ncx); | |
1875 nr = rows (); | |
1876 nc = cols (); | |
1877 // nz is preserved. | |
1878 } | |
1879 | |
1880 if (idx_i.is_colon ()) | |
1881 { | |
1882 octave_idx_type lb, ub; | |
1883 // Great, we're just manipulating columns. This is going to be quite | |
1884 // efficient, because the columns can stay compressed as they are. | |
1885 if (idx_j.is_colon ()) | |
1886 *this = rhs; // Shallow copy. | |
1887 else if (idx_j.is_cont_range (nc, lb, ub)) | |
1888 { | |
1889 // Special-case a contiguous range. | |
1890 octave_idx_type li = cidx(lb), ui = cidx(ub); | |
1891 octave_idx_type rnz = rhs.nnz (), new_nz = nz - (ui - li) + rnz; | |
1892 | |
1893 if (new_nz >= nz && new_nz <= capacity ()) | |
1894 { | |
1895 // Adding/overwriting elements, enough capacity allocated. | |
1896 | |
1897 if (new_nz > nz) | |
1898 { | |
1899 // Make room first. | |
1900 std::copy_backward (data () + ui, data () + nz, data () + li + rnz); | |
1901 std::copy_backward (ridx () + ui, ridx () + nz, ridx () + li + rnz); | |
1902 mx_inline_add2 (nc - ub, cidx () + ub + 1, new_nz - nz); | |
1903 } | |
1904 | |
1905 // Copy data and indices from rhs. | |
1906 copy_or_memcpy (rnz, rhs.data (), data () + li); | |
1907 copy_or_memcpy (rnz, rhs.ridx (), ridx () + li); | |
1908 mx_inline_add (ub - lb, cidx () + lb + 1, rhs.cidx () + 1, li); | |
1909 | |
1910 assert (nnz () == new_nz); | |
1911 } | |
1912 else | |
1913 { | |
1914 // Clearing elements or exceeding capacity, allocate afresh | |
1915 // and paste pieces. | |
1916 const Sparse<T> tmp = *this; | |
1917 *this = Sparse<T> (nr, nc, new_nz); | |
1918 | |
1919 // Head... | |
1920 copy_or_memcpy (li, tmp.data (), data ()); | |
1921 copy_or_memcpy (li, tmp.ridx (), ridx ()); | |
1922 copy_or_memcpy (lb, tmp.cidx () + 1, cidx () + 1); | |
1923 | |
1924 // new stuff... | |
1925 copy_or_memcpy (rnz, rhs.data (), data () + li); | |
1926 copy_or_memcpy (rnz, rhs.ridx (), ridx () + li); | |
1927 mx_inline_add (ub - lb, cidx () + lb + 1, rhs.cidx () + 1, li); | |
1928 | |
1929 // ...tail. | |
1930 copy_or_memcpy (nz - ui, tmp.data () + ui, data () + li + rnz); | |
1931 copy_or_memcpy (nz - ui, tmp.ridx () + ui, ridx () + li + rnz); | |
1932 mx_inline_add (nc - ub, cidx () + ub + 1, tmp.cidx () + ub + 1, new_nz - nz); | |
1933 | |
1934 assert (nnz () == new_nz); | |
1935 } | |
1936 } | |
1937 else if (idx_j.is_range () && idx_j.increment () == -1) | |
1938 { | |
1939 // It's s(:,u:-1:l) = r. Reverse the assignment. | |
1940 assign (idx_i, idx_j.sorted (), rhs.index (idx_i, idx_vector (m - 1, 0, -1))); | |
1941 } | |
1942 else if (idx_j.is_permutation (nc)) | |
1943 { | |
1944 *this = rhs.index (idx_i, idx_j.inverse_permutation (nc)); | |
1945 } | |
1946 else | |
1947 { | |
1948 const Sparse<T> tmp = *this; | |
1949 *this = Sparse<T> (nr, nc); | |
1950 OCTAVE_LOCAL_BUFFER_INIT (octave_idx_type, jsav, nc, -1); | |
1951 | |
1952 // Assemble column lengths. | |
1953 for (octave_idx_type i = 0; i < nc; i++) | |
1954 xcidx(i+1) = tmp.cidx(i+1) - tmp.cidx(i); | |
1955 | |
1956 for (octave_idx_type i = 0; i < m; i++) | |
1957 { | |
1958 octave_idx_type j =idx_j(i); | |
1959 jsav[j] = i; | |
1960 xcidx(j+1) = rhs.cidx(i+1) - rhs.cidx(i); | |
1961 } | |
1962 | |
1963 // Make cumulative. | |
1964 for (octave_idx_type i = 0; i < nc; i++) | |
1965 xcidx(i+1) += xcidx(i); | |
1966 | |
1967 change_capacity (nnz ()); | |
1968 | |
1969 // Merge columns. | |
1970 for (octave_idx_type i = 0; i < nc; i++) | |
1971 { | |
1972 octave_idx_type l = xcidx(i), u = xcidx(i+1), j = jsav[i]; | |
1973 if (j >= 0) | |
1974 { | |
1975 // from rhs | |
1976 octave_idx_type k = rhs.cidx(j); | |
1977 copy_or_memcpy (u - l, rhs.data () + k, xdata () + l); | |
1978 copy_or_memcpy (u - l, rhs.ridx () + k, xridx () + l); | |
1979 } | |
1980 else | |
1981 { | |
1982 // original | |
1983 octave_idx_type k = tmp.cidx(i); | |
1984 copy_or_memcpy (u - l, tmp.data () + k, xdata () + l); | |
1985 copy_or_memcpy (u - l, tmp.ridx () + k, xridx () + l); | |
1986 } | |
1987 } | |
1988 | |
1989 } | |
1990 } | |
1991 else if (idx_j.is_colon ()) | |
1992 { | |
1993 if (idx_i.is_permutation (nr)) | |
1994 { | |
1995 *this = rhs.index (idx_i.inverse_permutation (nr), idx_j); | |
1996 } | |
1997 else | |
1998 { | |
1999 // FIXME: optimize more special cases? | |
2000 // In general this requires unpacking the columns, which is slow, | |
2001 // especially for many small columns. OTOH, transpose is an | |
2002 // efficient O(nr+nc+nnz) operation. | |
2003 *this = transpose (); | |
2004 assign (idx_vector::colon, idx_i, rhs.transpose ()); | |
2005 *this = transpose (); | |
2006 } | |
2007 } | |
2008 else | |
2009 { | |
2010 // Split it into 2 assignments and one indexing. | |
2011 Sparse<T> tmp = index (idx_vector::colon, idx_j); | |
2012 tmp.assign (idx_i, idx_vector::colon, rhs); | |
2013 assign (idx_vector::colon, idx_j, tmp); | |
2014 } | |
2015 } | |
2016 else if (m == 1 && n == 1) | |
2017 { | |
2018 n = idx_i.length (nr); | |
2019 m = idx_j.length (nc); | |
2020 if (rhs.nnz () != 0) | |
2021 assign (idx_i, idx_j, Sparse<T> (n, m, rhs.data (0))); | |
2022 else | |
2023 assign (idx_i, idx_j, Sparse<T> (n, m)); | |
2024 } | |
2025 else | |
2026 gripe_assignment_dimension_mismatch (); | |
2027 } | |
2028 | |
1766 // Can't use versions of these in Array.cc due to duplication of the | 2029 // Can't use versions of these in Array.cc due to duplication of the |
1767 // instantiations for Array<double and Sparse<double>, etc | 2030 // instantiations for Array<double and Sparse<double>, etc |
1768 template <class T> | 2031 template <class T> |
1769 bool | 2032 bool |
1770 sparse_ascending_compare (typename ref_param<T>::type a, | 2033 sparse_ascending_compare (typename ref_param<T>::type a, |
2136 { | 2399 { |
2137 for (octave_idx_type j = 0, nc = cols (); j < nc; j++) | 2400 for (octave_idx_type j = 0, nc = cols (); j < nc; j++) |
2138 for (octave_idx_type i = cidx(j), iu = cidx(j+1); i < iu; i++) | 2401 for (octave_idx_type i = cidx(j), iu = cidx(j+1); i < iu; i++) |
2139 retval (ridx(i), j) = data (i); | 2402 retval (ridx(i), j) = data (i); |
2140 } | 2403 } |
2141 | |
2142 return retval; | |
2143 } | |
2144 | |
2145 // FIXME | |
2146 // Unfortunately numel can overflow for very large but very sparse matrices. | |
2147 // For now just flag an error when this happens. | |
2148 template <class LT, class RT> | |
2149 int | |
2150 assign1 (Sparse<LT>& lhs, const Sparse<RT>& rhs) | |
2151 { | |
2152 int retval = 1; | |
2153 | |
2154 idx_vector *idx_tmp = lhs.get_idx (); | |
2155 | |
2156 idx_vector lhs_idx = idx_tmp[0]; | |
2157 | |
2158 octave_idx_type lhs_len = lhs.numel (); | |
2159 octave_idx_type rhs_len = rhs.numel (); | |
2160 | |
2161 uint64_t long_lhs_len = | |
2162 static_cast<uint64_t> (lhs.rows ()) * | |
2163 static_cast<uint64_t> (lhs.cols ()); | |
2164 | |
2165 uint64_t long_rhs_len = | |
2166 static_cast<uint64_t> (rhs.rows ()) * | |
2167 static_cast<uint64_t> (rhs.cols ()); | |
2168 | |
2169 if (long_rhs_len != static_cast<uint64_t>(rhs_len) || | |
2170 long_lhs_len != static_cast<uint64_t>(lhs_len)) | |
2171 { | |
2172 (*current_liboctave_error_handler) | |
2173 ("A(I) = X: Matrix dimensions too large to ensure correct\n", | |
2174 "operation. This is an limitation that should be removed\n", | |
2175 "in the future."); | |
2176 | |
2177 lhs.clear_index (); | |
2178 return 0; | |
2179 } | |
2180 | |
2181 octave_idx_type nr = lhs.rows (); | |
2182 octave_idx_type nc = lhs.cols (); | |
2183 octave_idx_type nz = lhs.nnz (); | |
2184 | |
2185 octave_idx_type n = lhs_idx.freeze (lhs_len, "vector", true); | |
2186 | |
2187 if (n != 0) | |
2188 { | |
2189 octave_idx_type max_idx = lhs_idx.max () + 1; | |
2190 max_idx = max_idx < lhs_len ? lhs_len : max_idx; | |
2191 | |
2192 // Take a constant copy of lhs. This means that elem won't | |
2193 // create missing elements. | |
2194 const Sparse<LT> c_lhs (lhs); | |
2195 | |
2196 if (rhs_len == n) | |
2197 { | |
2198 octave_idx_type new_nzmx = lhs.nnz (); | |
2199 | |
2200 OCTAVE_LOCAL_BUFFER (octave_idx_type, rhs_idx, n); | |
2201 if (! lhs_idx.is_colon ()) | |
2202 { | |
2203 // Ok here we have to be careful with the indexing, | |
2204 // to treat cases like "a([3,2,1]) = b", and still | |
2205 // handle the need for strict sorting of the sparse | |
2206 // elements. | |
2207 OCTAVE_LOCAL_BUFFER (octave_idx_vector_sort *, sidx, n); | |
2208 OCTAVE_LOCAL_BUFFER (octave_idx_vector_sort, sidxX, n); | |
2209 | |
2210 for (octave_idx_type i = 0; i < n; i++) | |
2211 { | |
2212 sidx[i] = &sidxX[i]; | |
2213 sidx[i]->i = lhs_idx.elem(i); | |
2214 sidx[i]->idx = i; | |
2215 } | |
2216 | |
2217 octave_quit (); | |
2218 octave_sort<octave_idx_vector_sort *> | |
2219 sort (octave_idx_vector_comp); | |
2220 | |
2221 sort.sort (sidx, n); | |
2222 | |
2223 intNDArray<octave_idx_type> new_idx (dim_vector (n,1)); | |
2224 | |
2225 for (octave_idx_type i = 0; i < n; i++) | |
2226 { | |
2227 new_idx.xelem(i) = sidx[i]->i; | |
2228 rhs_idx[i] = sidx[i]->idx; | |
2229 } | |
2230 | |
2231 lhs_idx = idx_vector (new_idx); | |
2232 } | |
2233 else | |
2234 for (octave_idx_type i = 0; i < n; i++) | |
2235 rhs_idx[i] = i; | |
2236 | |
2237 // First count the number of non-zero elements | |
2238 for (octave_idx_type i = 0; i < n; i++) | |
2239 { | |
2240 octave_quit (); | |
2241 | |
2242 octave_idx_type ii = lhs_idx.elem (i); | |
2243 if (i < n - 1 && lhs_idx.elem (i + 1) == ii) | |
2244 continue; | |
2245 if (ii < lhs_len && c_lhs.elem(ii) != LT ()) | |
2246 new_nzmx--; | |
2247 if (rhs.elem(rhs_idx[i]) != RT ()) | |
2248 new_nzmx++; | |
2249 } | |
2250 | |
2251 if (nr > 1) | |
2252 { | |
2253 Sparse<LT> tmp ((max_idx > nr ? max_idx : nr), 1, new_nzmx); | |
2254 tmp.cidx(0) = 0; | |
2255 tmp.cidx(1) = new_nzmx; | |
2256 | |
2257 octave_idx_type i = 0; | |
2258 octave_idx_type ii = 0; | |
2259 if (i < nz) | |
2260 ii = c_lhs.ridx(i); | |
2261 | |
2262 octave_idx_type j = 0; | |
2263 octave_idx_type jj = lhs_idx.elem(j); | |
2264 | |
2265 octave_idx_type kk = 0; | |
2266 | |
2267 while (j < n || i < nz) | |
2268 { | |
2269 if (j < n - 1 && lhs_idx.elem (j + 1) == jj) | |
2270 { | |
2271 j++; | |
2272 jj = lhs_idx.elem (j); | |
2273 continue; | |
2274 } | |
2275 if (j == n || (i < nz && ii < jj)) | |
2276 { | |
2277 tmp.xdata (kk) = c_lhs.data (i); | |
2278 tmp.xridx (kk++) = ii; | |
2279 if (++i < nz) | |
2280 ii = c_lhs.ridx(i); | |
2281 } | |
2282 else | |
2283 { | |
2284 RT rtmp = rhs.elem (rhs_idx[j]); | |
2285 if (rtmp != RT ()) | |
2286 { | |
2287 tmp.xdata (kk) = rtmp; | |
2288 tmp.xridx (kk++) = jj; | |
2289 } | |
2290 | |
2291 if (ii == jj && i < nz) | |
2292 if (++i < nz) | |
2293 ii = c_lhs.ridx(i); | |
2294 if (++j < n) | |
2295 jj = lhs_idx.elem(j); | |
2296 } | |
2297 } | |
2298 | |
2299 lhs = tmp; | |
2300 } | |
2301 else | |
2302 { | |
2303 Sparse<LT> tmp (1, (max_idx > nc ? max_idx : nc), new_nzmx); | |
2304 | |
2305 octave_idx_type i = 0; | |
2306 octave_idx_type ii = 0; | |
2307 while (ii < nc && c_lhs.cidx(ii+1) <= i) | |
2308 ii++; | |
2309 | |
2310 octave_idx_type j = 0; | |
2311 octave_idx_type jj = lhs_idx.elem(j); | |
2312 | |
2313 octave_idx_type kk = 0; | |
2314 octave_idx_type ic = 0; | |
2315 | |
2316 while (j < n || i < nz) | |
2317 { | |
2318 if (j < n - 1 && lhs_idx.elem (j + 1) == jj) | |
2319 { | |
2320 j++; | |
2321 jj = lhs_idx.elem (j); | |
2322 continue; | |
2323 } | |
2324 if (j == n || (i < nz && ii < jj)) | |
2325 { | |
2326 while (ic <= ii) | |
2327 tmp.xcidx (ic++) = kk; | |
2328 tmp.xdata (kk) = c_lhs.data (i); | |
2329 tmp.xridx (kk++) = 0; | |
2330 i++; | |
2331 while (ii < nc && c_lhs.cidx(ii+1) <= i) | |
2332 ii++; | |
2333 } | |
2334 else | |
2335 { | |
2336 while (ic <= jj) | |
2337 tmp.xcidx (ic++) = kk; | |
2338 | |
2339 RT rtmp = rhs.elem (rhs_idx[j]); | |
2340 if (rtmp != RT ()) | |
2341 { | |
2342 tmp.xdata (kk) = rtmp; | |
2343 tmp.xridx (kk++) = 0; | |
2344 } | |
2345 if (ii == jj) | |
2346 { | |
2347 i++; | |
2348 while (ii < nc && c_lhs.cidx(ii+1) <= i) | |
2349 ii++; | |
2350 } | |
2351 j++; | |
2352 if (j < n) | |
2353 jj = lhs_idx.elem(j); | |
2354 } | |
2355 } | |
2356 | |
2357 for (octave_idx_type iidx = ic; iidx < max_idx+1; iidx++) | |
2358 tmp.xcidx(iidx) = kk; | |
2359 | |
2360 lhs = tmp; | |
2361 } | |
2362 } | |
2363 else if (rhs_len == 1) | |
2364 { | |
2365 octave_idx_type new_nzmx = lhs.nnz (); | |
2366 RT scalar = rhs.elem (0); | |
2367 bool scalar_non_zero = (scalar != RT ()); | |
2368 lhs_idx.sort (true); | |
2369 n = lhs_idx.length (n); | |
2370 | |
2371 // First count the number of non-zero elements | |
2372 if (scalar != RT ()) | |
2373 new_nzmx += n; | |
2374 for (octave_idx_type i = 0; i < n; i++) | |
2375 { | |
2376 octave_quit (); | |
2377 | |
2378 octave_idx_type ii = lhs_idx.elem (i); | |
2379 if (ii < lhs_len && c_lhs.elem(ii) != LT ()) | |
2380 new_nzmx--; | |
2381 } | |
2382 | |
2383 if (nr > 1) | |
2384 { | |
2385 Sparse<LT> tmp ((max_idx > nr ? max_idx : nr), 1, new_nzmx); | |
2386 tmp.cidx(0) = 0; | |
2387 tmp.cidx(1) = new_nzmx; | |
2388 | |
2389 octave_idx_type i = 0; | |
2390 octave_idx_type ii = 0; | |
2391 if (i < nz) | |
2392 ii = c_lhs.ridx(i); | |
2393 | |
2394 octave_idx_type j = 0; | |
2395 octave_idx_type jj = lhs_idx.elem(j); | |
2396 | |
2397 octave_idx_type kk = 0; | |
2398 | |
2399 while (j < n || i < nz) | |
2400 { | |
2401 if (j == n || (i < nz && ii < jj)) | |
2402 { | |
2403 tmp.xdata (kk) = c_lhs.data (i); | |
2404 tmp.xridx (kk++) = ii; | |
2405 if (++i < nz) | |
2406 ii = c_lhs.ridx(i); | |
2407 } | |
2408 else | |
2409 { | |
2410 if (scalar_non_zero) | |
2411 { | |
2412 tmp.xdata (kk) = scalar; | |
2413 tmp.xridx (kk++) = jj; | |
2414 } | |
2415 | |
2416 if (ii == jj && i < nz) | |
2417 if (++i < nz) | |
2418 ii = c_lhs.ridx(i); | |
2419 if (++j < n) | |
2420 jj = lhs_idx.elem(j); | |
2421 } | |
2422 } | |
2423 | |
2424 lhs = tmp; | |
2425 } | |
2426 else | |
2427 { | |
2428 Sparse<LT> tmp (1, (max_idx > nc ? max_idx : nc), new_nzmx); | |
2429 | |
2430 octave_idx_type i = 0; | |
2431 octave_idx_type ii = 0; | |
2432 while (ii < nc && c_lhs.cidx(ii+1) <= i) | |
2433 ii++; | |
2434 | |
2435 octave_idx_type j = 0; | |
2436 octave_idx_type jj = lhs_idx.elem(j); | |
2437 | |
2438 octave_idx_type kk = 0; | |
2439 octave_idx_type ic = 0; | |
2440 | |
2441 while (j < n || i < nz) | |
2442 { | |
2443 if (j == n || (i < nz && ii < jj)) | |
2444 { | |
2445 while (ic <= ii) | |
2446 tmp.xcidx (ic++) = kk; | |
2447 tmp.xdata (kk) = c_lhs.data (i); | |
2448 i++; | |
2449 while (ii < nc && c_lhs.cidx(ii+1) <= i) | |
2450 ii++; | |
2451 tmp.xridx (kk++) = 0; | |
2452 } | |
2453 else | |
2454 { | |
2455 while (ic <= jj) | |
2456 tmp.xcidx (ic++) = kk; | |
2457 if (scalar_non_zero) | |
2458 { | |
2459 tmp.xdata (kk) = scalar; | |
2460 tmp.xridx (kk++) = 0; | |
2461 } | |
2462 if (ii == jj) | |
2463 { | |
2464 i++; | |
2465 while (ii < nc && c_lhs.cidx(ii+1) <= i) | |
2466 ii++; | |
2467 } | |
2468 j++; | |
2469 if (j < n) | |
2470 jj = lhs_idx.elem(j); | |
2471 } | |
2472 } | |
2473 | |
2474 for (octave_idx_type iidx = ic; iidx < max_idx+1; iidx++) | |
2475 tmp.xcidx(iidx) = kk; | |
2476 | |
2477 lhs = tmp; | |
2478 } | |
2479 } | |
2480 else | |
2481 { | |
2482 (*current_liboctave_error_handler) | |
2483 ("A(I) = X: X must be a scalar or a vector with same length as I"); | |
2484 | |
2485 retval = 0; | |
2486 } | |
2487 } | |
2488 else if (lhs_idx.is_colon ()) | |
2489 { | |
2490 if (lhs_len == 0) | |
2491 { | |
2492 | |
2493 octave_idx_type new_nzmx = rhs.nnz (); | |
2494 Sparse<LT> tmp (1, rhs_len, new_nzmx); | |
2495 | |
2496 octave_idx_type ii = 0; | |
2497 octave_idx_type jj = 0; | |
2498 for (octave_idx_type i = 0; i < rhs.cols(); i++) | |
2499 for (octave_idx_type j = rhs.cidx(i); j < rhs.cidx(i+1); j++) | |
2500 { | |
2501 octave_quit (); | |
2502 for (octave_idx_type k = jj; k <= i * rhs.rows() + rhs.ridx(j); k++) | |
2503 tmp.cidx(jj++) = ii; | |
2504 | |
2505 tmp.data(ii) = rhs.data(j); | |
2506 tmp.ridx(ii++) = 0; | |
2507 } | |
2508 | |
2509 for (octave_idx_type i = jj; i < rhs_len + 1; i++) | |
2510 tmp.cidx(i) = ii; | |
2511 | |
2512 lhs = tmp; | |
2513 } | |
2514 else | |
2515 (*current_liboctave_error_handler) | |
2516 ("A(:) = X: A must be the same size as X"); | |
2517 } | |
2518 else if (! (rhs_len == 1 || rhs_len == 0)) | |
2519 { | |
2520 (*current_liboctave_error_handler) | |
2521 ("A([]) = X: X must also be an empty matrix or a scalar"); | |
2522 | |
2523 retval = 0; | |
2524 } | |
2525 | |
2526 lhs.clear_index (); | |
2527 | |
2528 return retval; | |
2529 } | |
2530 | |
2531 template <class LT, class RT> | |
2532 int | |
2533 assign (Sparse<LT>& lhs, const Sparse<RT>& rhs) | |
2534 { | |
2535 int retval = 1; | |
2536 | |
2537 int n_idx = lhs.index_count (); | |
2538 | |
2539 octave_idx_type lhs_nr = lhs.rows (); | |
2540 octave_idx_type lhs_nc = lhs.cols (); | |
2541 octave_idx_type lhs_nz = lhs.nnz (); | |
2542 | |
2543 octave_idx_type rhs_nr = rhs.rows (); | |
2544 octave_idx_type rhs_nc = rhs.cols (); | |
2545 | |
2546 idx_vector *tmp = lhs.get_idx (); | |
2547 | |
2548 idx_vector idx_i; | |
2549 idx_vector idx_j; | |
2550 | |
2551 if (n_idx > 2) | |
2552 { | |
2553 (*current_liboctave_error_handler) | |
2554 ("A(I, J) = X: can only have 1 or 2 indexes for sparse matrices"); | |
2555 | |
2556 lhs.clear_index (); | |
2557 return 0; | |
2558 } | |
2559 | |
2560 if (n_idx > 1) | |
2561 idx_j = tmp[1]; | |
2562 | |
2563 if (n_idx > 0) | |
2564 idx_i = tmp[0]; | |
2565 | |
2566 // Take a constant copy of lhs. This means that ridx and family won't | |
2567 // call make_unique. | |
2568 const Sparse<LT> c_lhs (lhs); | |
2569 | |
2570 if (n_idx == 2) | |
2571 { | |
2572 octave_idx_type n = idx_i.freeze (lhs_nr, "row", true); | |
2573 octave_idx_type m = idx_j.freeze (lhs_nc, "column", true); | |
2574 | |
2575 int idx_i_is_colon = idx_i.is_colon (); | |
2576 int idx_j_is_colon = idx_j.is_colon (); | |
2577 | |
2578 if (lhs_nr == 0 && lhs_nc == 0) | |
2579 { | |
2580 if (idx_i_is_colon) | |
2581 n = rhs_nr; | |
2582 | |
2583 if (idx_j_is_colon) | |
2584 m = rhs_nc; | |
2585 } | |
2586 | |
2587 if (idx_i && idx_j) | |
2588 { | |
2589 if (rhs_nr == 1 && rhs_nc == 1 && n >= 0 && m >= 0) | |
2590 { | |
2591 if (n > 0 && m > 0) | |
2592 { | |
2593 idx_i.sort (true); | |
2594 n = idx_i.length (n); | |
2595 idx_j.sort (true); | |
2596 m = idx_j.length (m); | |
2597 | |
2598 octave_idx_type max_row_idx = idx_i_is_colon ? rhs_nr : | |
2599 idx_i.max () + 1; | |
2600 octave_idx_type max_col_idx = idx_j_is_colon ? rhs_nc : | |
2601 idx_j.max () + 1; | |
2602 octave_idx_type new_nr = max_row_idx > lhs_nr ? | |
2603 max_row_idx : lhs_nr; | |
2604 octave_idx_type new_nc = max_col_idx > lhs_nc ? | |
2605 max_col_idx : lhs_nc; | |
2606 RT scalar = rhs.elem (0, 0); | |
2607 | |
2608 // Count the number of non-zero terms | |
2609 octave_idx_type new_nzmx = lhs.nnz (); | |
2610 for (octave_idx_type j = 0; j < m; j++) | |
2611 { | |
2612 octave_idx_type jj = idx_j.elem (j); | |
2613 if (jj < lhs_nc) | |
2614 { | |
2615 for (octave_idx_type i = 0; i < n; i++) | |
2616 { | |
2617 octave_quit (); | |
2618 | |
2619 octave_idx_type ii = idx_i.elem (i); | |
2620 | |
2621 if (ii < lhs_nr) | |
2622 { | |
2623 for (octave_idx_type k = c_lhs.cidx(jj); | |
2624 k < c_lhs.cidx(jj+1); k++) | |
2625 { | |
2626 if (c_lhs.ridx(k) == ii) | |
2627 new_nzmx--; | |
2628 if (c_lhs.ridx(k) >= ii) | |
2629 break; | |
2630 } | |
2631 } | |
2632 } | |
2633 } | |
2634 } | |
2635 | |
2636 if (scalar != RT()) | |
2637 new_nzmx += m * n; | |
2638 | |
2639 Sparse<LT> stmp (new_nr, new_nc, new_nzmx); | |
2640 | |
2641 octave_idx_type jji = 0; | |
2642 octave_idx_type jj = idx_j.elem (jji); | |
2643 octave_idx_type kk = 0; | |
2644 stmp.cidx(0) = 0; | |
2645 for (octave_idx_type j = 0; j < new_nc; j++) | |
2646 { | |
2647 if (jji < m && jj == j) | |
2648 { | |
2649 octave_idx_type iii = 0; | |
2650 octave_idx_type ii = idx_i.elem (iii); | |
2651 octave_idx_type ppp = 0; | |
2652 octave_idx_type ppi = (j >= lhs_nc ? 0 : | |
2653 c_lhs.cidx(j+1) - | |
2654 c_lhs.cidx(j)); | |
2655 octave_idx_type pp = (ppp < ppi ? | |
2656 c_lhs.ridx(c_lhs.cidx(j)+ppp) : | |
2657 new_nr); | |
2658 while (ppp < ppi || iii < n) | |
2659 { | |
2660 if (iii < n && ii <= pp) | |
2661 { | |
2662 if (scalar != RT ()) | |
2663 { | |
2664 stmp.data(kk) = scalar; | |
2665 stmp.ridx(kk++) = ii; | |
2666 } | |
2667 if (ii == pp) | |
2668 pp = (++ppp < ppi ? c_lhs.ridx(c_lhs.cidx(j)+ppp) : new_nr); | |
2669 if (++iii < n) | |
2670 ii = idx_i.elem(iii); | |
2671 } | |
2672 else | |
2673 { | |
2674 stmp.data(kk) = | |
2675 c_lhs.data(c_lhs.cidx(j)+ppp); | |
2676 stmp.ridx(kk++) = pp; | |
2677 pp = (++ppp < ppi ? c_lhs.ridx(c_lhs.cidx(j)+ppp) : new_nr); | |
2678 } | |
2679 } | |
2680 if (++jji < m) | |
2681 jj = idx_j.elem(jji); | |
2682 } | |
2683 else if (j < lhs_nc) | |
2684 { | |
2685 for (octave_idx_type i = c_lhs.cidx(j); | |
2686 i < c_lhs.cidx(j+1); i++) | |
2687 { | |
2688 stmp.data(kk) = c_lhs.data(i); | |
2689 stmp.ridx(kk++) = c_lhs.ridx(i); | |
2690 } | |
2691 } | |
2692 stmp.cidx(j+1) = kk; | |
2693 } | |
2694 | |
2695 lhs = stmp; | |
2696 } | |
2697 else | |
2698 { | |
2699 #if 0 | |
2700 // FIXME -- the following code will make this | |
2701 // function behave the same as the full matrix | |
2702 // case for things like | |
2703 // | |
2704 // x = sparse (ones (2)); | |
2705 // x([],3) = 2; | |
2706 // | |
2707 // x = | |
2708 // | |
2709 // Compressed Column Sparse (rows = 2, cols = 3, nnz = 4) | |
2710 // | |
2711 // (1, 1) -> 1 | |
2712 // (2, 1) -> 1 | |
2713 // (1, 2) -> 1 | |
2714 // (2, 2) -> 1 | |
2715 // | |
2716 // However, Matlab doesn't resize in this case | |
2717 // even though it does in the full matrix case. | |
2718 | |
2719 if (n > 0) | |
2720 { | |
2721 octave_idx_type max_row_idx = idx_i_is_colon ? | |
2722 rhs_nr : idx_i.max () + 1; | |
2723 octave_idx_type new_nr = max_row_idx > lhs_nr ? | |
2724 max_row_idx : lhs_nr; | |
2725 octave_idx_type new_nc = lhs_nc; | |
2726 | |
2727 lhs.resize (new_nr, new_nc); | |
2728 } | |
2729 else if (m > 0) | |
2730 { | |
2731 octave_idx_type max_col_idx = idx_j_is_colon ? | |
2732 rhs_nc : idx_j.max () + 1; | |
2733 octave_idx_type new_nr = lhs_nr; | |
2734 octave_idx_type new_nc = max_col_idx > lhs_nc ? | |
2735 max_col_idx : lhs_nc; | |
2736 | |
2737 lhs.resize (new_nr, new_nc); | |
2738 } | |
2739 #endif | |
2740 } | |
2741 } | |
2742 else if (n == rhs_nr && m == rhs_nc) | |
2743 { | |
2744 if (n > 0 && m > 0) | |
2745 { | |
2746 octave_idx_type max_row_idx = idx_i_is_colon ? rhs_nr : | |
2747 idx_i.max () + 1; | |
2748 octave_idx_type max_col_idx = idx_j_is_colon ? rhs_nc : | |
2749 idx_j.max () + 1; | |
2750 octave_idx_type new_nr = max_row_idx > lhs_nr ? | |
2751 max_row_idx : lhs_nr; | |
2752 octave_idx_type new_nc = max_col_idx > lhs_nc ? | |
2753 max_col_idx : lhs_nc; | |
2754 | |
2755 OCTAVE_LOCAL_BUFFER (octave_idx_type, rhs_idx_i, n); | |
2756 if (! idx_i.is_colon ()) | |
2757 { | |
2758 // Ok here we have to be careful with the indexing, | |
2759 // to treat cases like "a([3,2,1],:) = b", and still | |
2760 // handle the need for strict sorting of the sparse | |
2761 // elements. | |
2762 OCTAVE_LOCAL_BUFFER (octave_idx_vector_sort *, | |
2763 sidx, n); | |
2764 OCTAVE_LOCAL_BUFFER (octave_idx_vector_sort, | |
2765 sidxX, n); | |
2766 | |
2767 for (octave_idx_type i = 0; i < n; i++) | |
2768 { | |
2769 sidx[i] = &sidxX[i]; | |
2770 sidx[i]->i = idx_i.elem(i); | |
2771 sidx[i]->idx = i; | |
2772 } | |
2773 | |
2774 octave_quit (); | |
2775 octave_sort<octave_idx_vector_sort *> | |
2776 sort (octave_idx_vector_comp); | |
2777 | |
2778 sort.sort (sidx, n); | |
2779 | |
2780 intNDArray<octave_idx_type> new_idx (dim_vector (n,1)); | |
2781 | |
2782 for (octave_idx_type i = 0; i < n; i++) | |
2783 { | |
2784 new_idx.xelem(i) = sidx[i]->i; | |
2785 rhs_idx_i[i] = sidx[i]->idx; | |
2786 } | |
2787 | |
2788 idx_i = idx_vector (new_idx); | |
2789 } | |
2790 else | |
2791 for (octave_idx_type i = 0; i < n; i++) | |
2792 rhs_idx_i[i] = i; | |
2793 | |
2794 OCTAVE_LOCAL_BUFFER (octave_idx_type, rhs_idx_j, m); | |
2795 if (! idx_j.is_colon ()) | |
2796 { | |
2797 // Ok here we have to be careful with the indexing, | |
2798 // to treat cases like "a([3,2,1],:) = b", and still | |
2799 // handle the need for strict sorting of the sparse | |
2800 // elements. | |
2801 OCTAVE_LOCAL_BUFFER (octave_idx_vector_sort *, | |
2802 sidx, m); | |
2803 OCTAVE_LOCAL_BUFFER (octave_idx_vector_sort, | |
2804 sidxX, m); | |
2805 | |
2806 for (octave_idx_type i = 0; i < m; i++) | |
2807 { | |
2808 sidx[i] = &sidxX[i]; | |
2809 sidx[i]->i = idx_j.elem(i); | |
2810 sidx[i]->idx = i; | |
2811 } | |
2812 | |
2813 octave_quit (); | |
2814 octave_sort<octave_idx_vector_sort *> | |
2815 sort (octave_idx_vector_comp); | |
2816 | |
2817 sort.sort (sidx, m); | |
2818 | |
2819 intNDArray<octave_idx_type> new_idx (dim_vector (m,1)); | |
2820 | |
2821 for (octave_idx_type i = 0; i < m; i++) | |
2822 { | |
2823 new_idx.xelem(i) = sidx[i]->i; | |
2824 rhs_idx_j[i] = sidx[i]->idx; | |
2825 } | |
2826 | |
2827 idx_j = idx_vector (new_idx); | |
2828 } | |
2829 else | |
2830 for (octave_idx_type i = 0; i < m; i++) | |
2831 rhs_idx_j[i] = i; | |
2832 | |
2833 // Maximum number of non-zero elements | |
2834 octave_idx_type new_nzmx = lhs.nnz() + rhs.nnz(); | |
2835 | |
2836 Sparse<LT> stmp (new_nr, new_nc, new_nzmx); | |
2837 | |
2838 octave_idx_type jji = 0; | |
2839 octave_idx_type jj = idx_j.elem (jji); | |
2840 octave_idx_type kk = 0; | |
2841 stmp.cidx(0) = 0; | |
2842 for (octave_idx_type j = 0; j < new_nc; j++) | |
2843 { | |
2844 if (jji < m && jj == j) | |
2845 { | |
2846 octave_idx_type iii = 0; | |
2847 octave_idx_type ii = idx_i.elem (iii); | |
2848 octave_idx_type ppp = 0; | |
2849 octave_idx_type ppi = (j >= lhs_nc ? 0 : | |
2850 c_lhs.cidx(j+1) - | |
2851 c_lhs.cidx(j)); | |
2852 octave_idx_type pp = (ppp < ppi ? | |
2853 c_lhs.ridx(c_lhs.cidx(j)+ppp) : | |
2854 new_nr); | |
2855 while (ppp < ppi || iii < n) | |
2856 { | |
2857 if (iii < n && ii <= pp) | |
2858 { | |
2859 if (iii < n - 1 && | |
2860 idx_i.elem (iii + 1) == ii) | |
2861 { | |
2862 iii++; | |
2863 ii = idx_i.elem(iii); | |
2864 continue; | |
2865 } | |
2866 | |
2867 RT rtmp = rhs.elem (rhs_idx_i[iii], | |
2868 rhs_idx_j[jji]); | |
2869 if (rtmp != RT ()) | |
2870 { | |
2871 stmp.data(kk) = rtmp; | |
2872 stmp.ridx(kk++) = ii; | |
2873 } | |
2874 if (ii == pp) | |
2875 pp = (++ppp < ppi ? c_lhs.ridx(c_lhs.cidx(j)+ppp) : new_nr); | |
2876 if (++iii < n) | |
2877 ii = idx_i.elem(iii); | |
2878 } | |
2879 else | |
2880 { | |
2881 stmp.data(kk) = | |
2882 c_lhs.data(c_lhs.cidx(j)+ppp); | |
2883 stmp.ridx(kk++) = pp; | |
2884 pp = (++ppp < ppi ? c_lhs.ridx(c_lhs.cidx(j)+ppp) : new_nr); | |
2885 } | |
2886 } | |
2887 if (++jji < m) | |
2888 jj = idx_j.elem(jji); | |
2889 } | |
2890 else if (j < lhs_nc) | |
2891 { | |
2892 for (octave_idx_type i = c_lhs.cidx(j); | |
2893 i < c_lhs.cidx(j+1); i++) | |
2894 { | |
2895 stmp.data(kk) = c_lhs.data(i); | |
2896 stmp.ridx(kk++) = c_lhs.ridx(i); | |
2897 } | |
2898 } | |
2899 stmp.cidx(j+1) = kk; | |
2900 } | |
2901 | |
2902 stmp.maybe_compress(); | |
2903 lhs = stmp; | |
2904 } | |
2905 } | |
2906 else if (n == 0 && m == 0) | |
2907 { | |
2908 if (! ((rhs_nr == 1 && rhs_nc == 1) | |
2909 || (rhs_nr == 0 || rhs_nc == 0))) | |
2910 { | |
2911 (*current_liboctave_error_handler) | |
2912 ("A([], []) = X: X must be an empty matrix or a scalar"); | |
2913 | |
2914 retval = 0; | |
2915 } | |
2916 } | |
2917 else | |
2918 { | |
2919 (*current_liboctave_error_handler) | |
2920 ("A(I, J) = X: X must be a scalar or the number of elements in I must"); | |
2921 (*current_liboctave_error_handler) | |
2922 ("match the number of rows in X and the number of elements in J must"); | |
2923 (*current_liboctave_error_handler) | |
2924 ("match the number of columns in X"); | |
2925 | |
2926 retval = 0; | |
2927 } | |
2928 } | |
2929 // idx_vector::freeze() printed an error message for us. | |
2930 } | |
2931 else if (n_idx == 1) | |
2932 { | |
2933 int lhs_is_empty = lhs_nr == 0 || lhs_nc == 0; | |
2934 | |
2935 if (lhs_is_empty || (lhs_nr == 1 && lhs_nc == 1)) | |
2936 { | |
2937 octave_idx_type lhs_len = lhs.length (); | |
2938 | |
2939 // Called for side-effects on idx_i. | |
2940 idx_i.freeze (lhs_len, 0, true); | |
2941 | |
2942 if (idx_i) | |
2943 { | |
2944 if (lhs_is_empty | |
2945 && idx_i.is_colon () | |
2946 && ! (rhs_nr == 1 || rhs_nc == 1)) | |
2947 { | |
2948 (*current_liboctave_warning_with_id_handler) | |
2949 ("Octave:fortran-indexing", | |
2950 "A(:) = X: X is not a vector or scalar"); | |
2951 } | |
2952 else | |
2953 { | |
2954 octave_idx_type idx_nr = idx_i.orig_rows (); | |
2955 octave_idx_type idx_nc = idx_i.orig_columns (); | |
2956 | |
2957 if (! (rhs_nr == idx_nr && rhs_nc == idx_nc)) | |
2958 (*current_liboctave_warning_with_id_handler) | |
2959 ("Octave:fortran-indexing", | |
2960 "A(I) = X: X does not have same shape as I"); | |
2961 } | |
2962 | |
2963 if (! assign1 (lhs, rhs)) | |
2964 retval = 0; | |
2965 } | |
2966 // idx_vector::freeze() printed an error message for us. | |
2967 } | |
2968 else if (lhs_nr == 1) | |
2969 { | |
2970 idx_i.freeze (lhs_nc, "vector", true); | |
2971 | |
2972 if (idx_i) | |
2973 { | |
2974 if (! assign1 (lhs, rhs)) | |
2975 retval = 0; | |
2976 } | |
2977 // idx_vector::freeze() printed an error message for us. | |
2978 } | |
2979 else if (lhs_nc == 1) | |
2980 { | |
2981 idx_i.freeze (lhs_nr, "vector", true); | |
2982 | |
2983 if (idx_i) | |
2984 { | |
2985 if (! assign1 (lhs, rhs)) | |
2986 retval = 0; | |
2987 } | |
2988 // idx_vector::freeze() printed an error message for us. | |
2989 } | |
2990 else | |
2991 { | |
2992 if (! idx_i.is_colon ()) | |
2993 (*current_liboctave_warning_with_id_handler) | |
2994 ("Octave:fortran-indexing", "single index used for matrix"); | |
2995 | |
2996 octave_idx_type lhs_len = lhs.length (); | |
2997 | |
2998 octave_idx_type len = idx_i.freeze (lhs_nr * lhs_nc, "matrix"); | |
2999 | |
3000 if (idx_i) | |
3001 { | |
3002 if (len == 0) | |
3003 { | |
3004 if (! ((rhs_nr == 1 && rhs_nc == 1) | |
3005 || (rhs_nr == 0 || rhs_nc == 0))) | |
3006 (*current_liboctave_error_handler) | |
3007 ("A([]) = X: X must be an empty matrix or scalar"); | |
3008 } | |
3009 else if (len == rhs_nr * rhs_nc) | |
3010 { | |
3011 octave_idx_type new_nzmx = lhs_nz; | |
3012 OCTAVE_LOCAL_BUFFER (octave_idx_type, rhs_idx, len); | |
3013 | |
3014 if (! idx_i.is_colon ()) | |
3015 { | |
3016 // Ok here we have to be careful with the indexing, to | |
3017 // treat cases like "a([3,2,1]) = b", and still handle | |
3018 // the need for strict sorting of the sparse elements. | |
3019 | |
3020 OCTAVE_LOCAL_BUFFER (octave_idx_vector_sort *, sidx, | |
3021 len); | |
3022 OCTAVE_LOCAL_BUFFER (octave_idx_vector_sort, sidxX, | |
3023 len); | |
3024 | |
3025 for (octave_idx_type i = 0; i < len; i++) | |
3026 { | |
3027 sidx[i] = &sidxX[i]; | |
3028 sidx[i]->i = idx_i.elem(i); | |
3029 sidx[i]->idx = i; | |
3030 } | |
3031 | |
3032 octave_quit (); | |
3033 octave_sort<octave_idx_vector_sort *> | |
3034 sort (octave_idx_vector_comp); | |
3035 | |
3036 sort.sort (sidx, len); | |
3037 | |
3038 intNDArray<octave_idx_type> new_idx (dim_vector (len,1)); | |
3039 | |
3040 for (octave_idx_type i = 0; i < len; i++) | |
3041 { | |
3042 new_idx.xelem(i) = sidx[i]->i; | |
3043 rhs_idx[i] = sidx[i]->idx; | |
3044 } | |
3045 | |
3046 idx_i = idx_vector (new_idx); | |
3047 } | |
3048 else | |
3049 for (octave_idx_type i = 0; i < len; i++) | |
3050 rhs_idx[i] = i; | |
3051 | |
3052 // First count the number of non-zero elements | |
3053 for (octave_idx_type i = 0; i < len; i++) | |
3054 { | |
3055 octave_quit (); | |
3056 | |
3057 octave_idx_type ii = idx_i.elem (i); | |
3058 if (i < len - 1 && idx_i.elem (i + 1) == ii) | |
3059 continue; | |
3060 if (ii < lhs_len && c_lhs.elem(ii) != LT ()) | |
3061 new_nzmx--; | |
3062 if (rhs.elem(rhs_idx[i]) != RT ()) | |
3063 new_nzmx++; | |
3064 } | |
3065 | |
3066 Sparse<LT> stmp (lhs_nr, lhs_nc, new_nzmx); | |
3067 | |
3068 octave_idx_type i = 0; | |
3069 octave_idx_type ii = 0; | |
3070 octave_idx_type ic = 0; | |
3071 if (i < lhs_nz) | |
3072 { | |
3073 while (ic < lhs_nc && i >= c_lhs.cidx(ic+1)) | |
3074 ic++; | |
3075 ii = ic * lhs_nr + c_lhs.ridx(i); | |
3076 } | |
3077 | |
3078 octave_idx_type j = 0; | |
3079 octave_idx_type jj = idx_i.elem (j); | |
3080 octave_idx_type jr = jj % lhs_nr; | |
3081 octave_idx_type jc = (jj - jr) / lhs_nr; | |
3082 | |
3083 octave_idx_type kk = 0; | |
3084 octave_idx_type kc = 0; | |
3085 | |
3086 while (j < len || i < lhs_nz) | |
3087 { | |
3088 if (j < len - 1 && idx_i.elem (j + 1) == jj) | |
3089 { | |
3090 j++; | |
3091 jj = idx_i.elem (j); | |
3092 jr = jj % lhs_nr; | |
3093 jc = (jj - jr) / lhs_nr; | |
3094 continue; | |
3095 } | |
3096 | |
3097 if (j == len || (i < lhs_nz && ii < jj)) | |
3098 { | |
3099 while (kc <= ic) | |
3100 stmp.xcidx (kc++) = kk; | |
3101 stmp.xdata (kk) = c_lhs.data (i); | |
3102 stmp.xridx (kk++) = c_lhs.ridx (i); | |
3103 i++; | |
3104 while (ic < lhs_nc && i >= c_lhs.cidx(ic+1)) | |
3105 ic++; | |
3106 if (i < lhs_nz) | |
3107 ii = ic * lhs_nr + c_lhs.ridx(i); | |
3108 } | |
3109 else | |
3110 { | |
3111 while (kc <= jc) | |
3112 stmp.xcidx (kc++) = kk; | |
3113 RT rtmp = rhs.elem (rhs_idx[j]); | |
3114 if (rtmp != RT ()) | |
3115 { | |
3116 stmp.xdata (kk) = rtmp; | |
3117 stmp.xridx (kk++) = jr; | |
3118 } | |
3119 if (ii == jj) | |
3120 { | |
3121 i++; | |
3122 while (ic < lhs_nc && i >= c_lhs.cidx(ic+1)) | |
3123 ic++; | |
3124 if (i < lhs_nz) | |
3125 ii = ic * lhs_nr + c_lhs.ridx(i); | |
3126 } | |
3127 j++; | |
3128 if (j < len) | |
3129 { | |
3130 jj = idx_i.elem (j); | |
3131 jr = jj % lhs_nr; | |
3132 jc = (jj - jr) / lhs_nr; | |
3133 } | |
3134 } | |
3135 } | |
3136 | |
3137 for (octave_idx_type iidx = kc; iidx < lhs_nc+1; iidx++) | |
3138 stmp.xcidx(iidx) = kk; | |
3139 | |
3140 lhs = stmp; | |
3141 } | |
3142 else if (rhs_nr == 1 && rhs_nc == 1) | |
3143 { | |
3144 RT scalar = rhs.elem (0, 0); | |
3145 octave_idx_type new_nzmx = lhs_nz; | |
3146 idx_i.sort (true); | |
3147 len = idx_i.length (len); | |
3148 | |
3149 // First count the number of non-zero elements | |
3150 if (scalar != RT ()) | |
3151 new_nzmx += len; | |
3152 for (octave_idx_type i = 0; i < len; i++) | |
3153 { | |
3154 octave_quit (); | |
3155 octave_idx_type ii = idx_i.elem (i); | |
3156 if (ii < lhs_len && c_lhs.elem(ii) != LT ()) | |
3157 new_nzmx--; | |
3158 } | |
3159 | |
3160 Sparse<LT> stmp (lhs_nr, lhs_nc, new_nzmx); | |
3161 | |
3162 octave_idx_type i = 0; | |
3163 octave_idx_type ii = 0; | |
3164 octave_idx_type ic = 0; | |
3165 if (i < lhs_nz) | |
3166 { | |
3167 while (ic < lhs_nc && i >= c_lhs.cidx(ic+1)) | |
3168 ic++; | |
3169 ii = ic * lhs_nr + c_lhs.ridx(i); | |
3170 } | |
3171 | |
3172 octave_idx_type j = 0; | |
3173 octave_idx_type jj = idx_i.elem (j); | |
3174 octave_idx_type jr = jj % lhs_nr; | |
3175 octave_idx_type jc = (jj - jr) / lhs_nr; | |
3176 | |
3177 octave_idx_type kk = 0; | |
3178 octave_idx_type kc = 0; | |
3179 | |
3180 while (j < len || i < lhs_nz) | |
3181 { | |
3182 if (j == len || (i < lhs_nz && ii < jj)) | |
3183 { | |
3184 while (kc <= ic) | |
3185 stmp.xcidx (kc++) = kk; | |
3186 stmp.xdata (kk) = c_lhs.data (i); | |
3187 stmp.xridx (kk++) = c_lhs.ridx (i); | |
3188 i++; | |
3189 while (ic < lhs_nc && i >= c_lhs.cidx(ic+1)) | |
3190 ic++; | |
3191 if (i < lhs_nz) | |
3192 ii = ic * lhs_nr + c_lhs.ridx(i); | |
3193 } | |
3194 else | |
3195 { | |
3196 while (kc <= jc) | |
3197 stmp.xcidx (kc++) = kk; | |
3198 if (scalar != RT ()) | |
3199 { | |
3200 stmp.xdata (kk) = scalar; | |
3201 stmp.xridx (kk++) = jr; | |
3202 } | |
3203 if (ii == jj) | |
3204 { | |
3205 i++; | |
3206 while (ic < lhs_nc && i >= c_lhs.cidx(ic+1)) | |
3207 ic++; | |
3208 if (i < lhs_nz) | |
3209 ii = ic * lhs_nr + c_lhs.ridx(i); | |
3210 } | |
3211 j++; | |
3212 if (j < len) | |
3213 { | |
3214 jj = idx_i.elem (j); | |
3215 jr = jj % lhs_nr; | |
3216 jc = (jj - jr) / lhs_nr; | |
3217 } | |
3218 } | |
3219 } | |
3220 | |
3221 for (octave_idx_type iidx = kc; iidx < lhs_nc+1; iidx++) | |
3222 stmp.xcidx(iidx) = kk; | |
3223 | |
3224 lhs = stmp; | |
3225 } | |
3226 else | |
3227 { | |
3228 (*current_liboctave_error_handler) | |
3229 ("A(I) = X: X must be a scalar or a matrix with the same size as I"); | |
3230 | |
3231 retval = 0; | |
3232 } | |
3233 } | |
3234 // idx_vector::freeze() printed an error message for us. | |
3235 } | |
3236 } | |
3237 else | |
3238 { | |
3239 (*current_liboctave_error_handler) | |
3240 ("invalid number of indices for matrix expression"); | |
3241 | |
3242 retval = 0; | |
3243 } | |
3244 | |
3245 lhs.clear_index (); | |
3246 | 2404 |
3247 return retval; | 2405 return retval; |
3248 } | 2406 } |
3249 | 2407 |
3250 /* | 2408 /* |
3328 %!test test_sparse_slice([2 2], 11, 2); | 2486 %!test test_sparse_slice([2 2], 11, 2); |
3329 %!test test_sparse_slice([2 2], 11, 3); | 2487 %!test test_sparse_slice([2 2], 11, 3); |
3330 %!test test_sparse_slice([2 2], 11, 4); | 2488 %!test test_sparse_slice([2 2], 11, 4); |
3331 %!test test_sparse_slice([2 2], 11, [4, 4]); | 2489 %!test test_sparse_slice([2 2], 11, [4, 4]); |
3332 # These 2 errors are the same as in the full case | 2490 # These 2 errors are the same as in the full case |
3333 %!error <invalid matrix index = 5> set_slice(sparse(ones([2 2])), 11, 5); | 2491 %!error id=Octave:invalid-resize set_slice(sparse(ones([2 2])), 11, 5); |
3334 %!error <invalid matrix index = 6> set_slice(sparse(ones([2 2])), 11, 6); | 2492 %!error id=Octave:invalid-resize set_slice(sparse(ones([2 2])), 11, 6); |
3335 | 2493 |
3336 | 2494 |
3337 #### 2d indexing | 2495 #### 2d indexing |
3338 | 2496 |
3339 ## size = [2 0] | 2497 ## size = [2 0] |
3419 << prefix << "rep->data: " << static_cast<void *> (rep->d) << "\n" | 2577 << prefix << "rep->data: " << static_cast<void *> (rep->d) << "\n" |
3420 << prefix << "rep->ridx: " << static_cast<void *> (rep->r) << "\n" | 2578 << prefix << "rep->ridx: " << static_cast<void *> (rep->r) << "\n" |
3421 << prefix << "rep->cidx: " << static_cast<void *> (rep->c) << "\n" | 2579 << prefix << "rep->cidx: " << static_cast<void *> (rep->c) << "\n" |
3422 << prefix << "rep->count: " << rep->count << "\n"; | 2580 << prefix << "rep->count: " << rep->count << "\n"; |
3423 } | 2581 } |
2582 | |
// Explicitly instantiate the Sparse<T> class template for element type T,
// decorated with the given export/import API macro.
#define INSTANTIATE_SPARSE(T, API) \
  template class API Sparse<T>;
2585 |