5683
|
1 /* |
|
2 |
|
3 Copyright (C) 2006 David Bateman |
|
4 |
|
5 Octave is free software; you can redistribute it and/or modify it |
|
6 under the terms of the GNU General Public License as published by the |
|
7 Free Software Foundation; either version 2, or (at your option) any |
|
8 later version. |
|
9 |
|
10 Octave is distributed in the hope that it will be useful, but WITHOUT |
|
11 ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or |
|
12 FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License |
|
13 for more details. |
|
14 |
|
15 You should have received a copy of the GNU General Public License |
|
16 along with this program; see the file COPYING. If not, write to the |
|
17 Free Software Foundation, Inc., 51 Franklin Street, Fifth Floor, |
|
18 Boston, MA 02110-1301, USA. |
|
19 |
|
20 */ |
|
21 |
|
22 #ifdef HAVE_CONFIG_H |
|
23 #include <config.h> |
|
24 #endif |
|
25 |
|
26 |
|
27 |
|
28 #include "ov-re-sparse.h" |
|
29 #include "ov-cx-sparse.h" |
|
30 #include "MArray2.h" |
|
31 #include "MSparse.h" |
|
32 #include "SparseQR.h" |
|
33 #include "SparseCmplxQR.h" |
5785
|
34 #include "MatrixType.h" |
5683
|
35 #include "oct-sort.h" |
|
36 |
|
37 template <class T> |
|
38 static MSparse<T> |
|
39 dmsolve_extract (const MSparse<T> &A, const octave_idx_type *Pinv, |
|
40 const octave_idx_type *Q, octave_idx_type rst, |
|
41 octave_idx_type rend, octave_idx_type cst, |
|
42 octave_idx_type cend, octave_idx_type maxnz = -1, |
|
43 bool lazy = false) |
|
44 { |
|
45 octave_idx_type nz = (rend - rst) * (cend - cst); |
|
46 maxnz = (maxnz < 0 ? A.nnz () : maxnz); |
|
47 MSparse<T> B (rend - rst, cend - cst, (nz < maxnz ? nz : maxnz)); |
|
48 // Some sparse functions can support lazy indexing (where elements |
|
49 // in the row are in no particular order), even though octave in |
|
50 // general can't. For those functions that can using it is a big |
|
51 // win here in terms of speed. |
|
52 if (lazy) |
|
53 { |
|
54 nz = 0; |
|
55 for (octave_idx_type j = cst ; j < cend ; j++) |
|
56 { |
|
57 octave_idx_type qq = (Q ? Q [j] : j); |
|
58 B.xcidx (j - cst) = nz; |
|
59 for (octave_idx_type p = A.cidx(qq) ; p < A.cidx (qq+1) ; p++) |
|
60 { |
|
61 OCTAVE_QUIT; |
|
62 octave_idx_type r = (Pinv ? Pinv [A.ridx (p)] : A.ridx (p)); |
|
63 if (r >= rst && r < rend) |
|
64 { |
|
65 B.xdata (nz) = A.data (p); |
|
66 B.xridx (nz++) = r - rst ; |
|
67 } |
|
68 } |
|
69 } |
|
70 B.xcidx (cend - cst) = nz ; |
|
71 } |
|
72 else |
|
73 { |
|
74 OCTAVE_LOCAL_BUFFER (T, X, rend - rst); |
|
75 octave_sort<octave_idx_type> sort; |
|
76 octave_idx_type *ri = B.xridx(); |
|
77 nz = 0; |
|
78 for (octave_idx_type j = cst ; j < cend ; j++) |
|
79 { |
|
80 octave_idx_type qq = (Q ? Q [j] : j); |
|
81 B.xcidx (j - cst) = nz; |
|
82 for (octave_idx_type p = A.cidx(qq) ; p < A.cidx (qq+1) ; p++) |
|
83 { |
|
84 OCTAVE_QUIT; |
|
85 octave_idx_type r = (Pinv ? Pinv [A.ridx (p)] : A.ridx (p)); |
|
86 if (r >= rst && r < rend) |
|
87 { |
|
88 X [r-rst] = A.data (p); |
|
89 B.xridx (nz++) = r - rst ; |
|
90 } |
|
91 } |
|
92 sort.sort (ri + B.xcidx (j - cst), nz - B.xcidx (j - cst)); |
|
93 for (octave_idx_type p = B.cidx (j - cst); p < nz; p++) |
|
94 B.xdata (p) = X [B.xridx (p)]; |
|
95 } |
|
96 B.xcidx (cend - cst) = nz ; |
|
97 } |
|
98 |
|
99 return B; |
|
100 } |
|
101 |
|
102 #if !defined (CXX_NEW_FRIEND_TEMPLATE_DECL) |
|
103 static MSparse<double> |
|
104 dmsolve_extract (const MSparse<double> &A, const octave_idx_type *Pinv, |
|
105 const octave_idx_type *Q, octave_idx_type rst, |
|
106 octave_idx_type rend, octave_idx_type cst, |
|
107 octave_idx_type cend, octave_idx_type maxnz, |
|
108 bool lazy); |
|
109 |
|
110 static MSparse<Complex> |
|
111 dmsolve_extract (const MSparse<Complex> &A, const octave_idx_type *Pinv, |
|
112 const octave_idx_type *Q, octave_idx_type rst, |
|
113 octave_idx_type rend, octave_idx_type cst, |
|
114 octave_idx_type cend, octave_idx_type maxnz, |
|
115 bool lazy); |
|
116 #endif |
|
117 |
|
118 template <class T> |
|
119 static MArray2<T> |
|
120 dmsolve_extract (const MArray2<T> &m, const octave_idx_type *, |
|
121 const octave_idx_type *, octave_idx_type r1, |
|
122 octave_idx_type r2, octave_idx_type c1, |
|
123 octave_idx_type c2) |
|
124 { |
|
125 r2 -= 1; |
|
126 c2 -= 1; |
|
127 if (r1 > r2) { octave_idx_type tmp = r1; r1 = r2; r2 = tmp; } |
|
128 if (c1 > c2) { octave_idx_type tmp = c1; c1 = c2; c2 = tmp; } |
|
129 |
|
130 octave_idx_type new_r = r2 - r1 + 1; |
|
131 octave_idx_type new_c = c2 - c1 + 1; |
|
132 |
|
133 MArray2<T> result (new_r, new_c); |
|
134 |
|
135 for (octave_idx_type j = 0; j < new_c; j++) |
|
136 for (octave_idx_type i = 0; i < new_r; i++) |
|
137 result.xelem (i, j) = m.elem (r1+i, c1+j); |
|
138 |
|
139 return result; |
|
140 } |
|
141 |
|
142 #if !defined (CXX_NEW_FRIEND_TEMPLATE_DECL) |
|
143 static MArray2<double> |
|
144 dmsolve_extract (const MArray2<double> &m, const octave_idx_type *, |
|
145 const octave_idx_type *, octave_idx_type r1, |
|
146 octave_idx_type r2, octave_idx_type c1, |
|
147 octave_idx_type c2) |
|
148 |
|
149 static MArray2<Complex> |
|
150 dmsolve_extract (const MArray2<Complex> &m, const octave_idx_type *, |
|
151 const octave_idx_type *, octave_idx_type r1, |
|
152 octave_idx_type r2, octave_idx_type c1, |
|
153 octave_idx_type c2) |
|
154 #endif |
|
155 |
|
156 template <class T> |
|
157 static void |
|
158 dmsolve_insert (MArray2<T> &a, const MArray2<T> &b, const octave_idx_type *Q, |
|
159 octave_idx_type r, octave_idx_type c) |
|
160 { |
|
161 T *ax = a.fortran_vec(); |
|
162 const T *bx = b.fortran_vec(); |
|
163 octave_idx_type anr = a.rows(); |
|
164 octave_idx_type nr = b.rows(); |
|
165 octave_idx_type nc = b.cols(); |
|
166 for (octave_idx_type j = 0; j < nc; j++) |
|
167 { |
|
168 octave_idx_type aoff = (c + j) * anr; |
|
169 octave_idx_type boff = j * nr; |
|
170 for (octave_idx_type i = 0; i < nr; i++) |
|
171 { |
|
172 OCTAVE_QUIT; |
|
173 ax [Q [r + i] + aoff] = bx [i + boff]; |
|
174 } |
|
175 } |
|
176 } |
|
177 |
|
178 #if !defined (CXX_NEW_FRIEND_TEMPLATE_DECL) |
|
179 static void |
|
180 dmsolve_insert (MArray2<double> &a, const MArray2<double> &b, |
|
181 const octave_idx_type *Q, octave_idx_type r, octave_idx_type c); |
|
182 |
|
183 static void |
|
184 dmsolve_insert (MArray2<Complex> &a, const MArray2<Complex> &b, |
|
185 const octave_idx_type *Q, octave_idx_type r, octave_idx_type c); |
|
186 #endif |
|
187 |
|
188 template <class T> |
|
189 static void |
|
190 dmsolve_insert (MSparse<T> &a, const MSparse<T> &b, const octave_idx_type *Q, |
|
191 octave_idx_type r, octave_idx_type c) |
|
192 { |
|
193 octave_idx_type b_rows = b.rows (); |
|
194 octave_idx_type b_cols = b.cols (); |
|
195 octave_idx_type nr = a.rows (); |
|
196 octave_idx_type nc = a.cols (); |
|
197 |
|
198 OCTAVE_LOCAL_BUFFER (octave_idx_type, Qinv, nr); |
|
199 for (octave_idx_type i = 0; i < nr; i++) |
|
200 Qinv [Q [i]] = i; |
|
201 |
|
202 // First count the number of elements in the final array |
|
203 octave_idx_type nel = a.xcidx(c) + b.nnz (); |
|
204 |
|
205 if (c + b_cols < nc) |
|
206 nel += a.xcidx(nc) - a.xcidx(c + b_cols); |
|
207 |
|
208 for (octave_idx_type i = c; i < c + b_cols; i++) |
|
209 for (octave_idx_type j = a.xcidx(i); j < a.xcidx(i+1); j++) |
|
210 if (Qinv [a.xridx(j)] < r || Qinv [a.xridx(j)] >= r + b_rows) |
|
211 nel++; |
|
212 |
|
213 OCTAVE_LOCAL_BUFFER (T, X, nr); |
|
214 octave_sort<octave_idx_type> sort; |
|
215 MSparse<T> tmp (a); |
|
216 a = MSparse<T> (nr, nc, nel); |
|
217 octave_idx_type *ri = a.xridx(); |
|
218 |
|
219 for (octave_idx_type i = 0; i < tmp.cidx(c); i++) |
|
220 { |
|
221 a.xdata(i) = tmp.xdata(i); |
|
222 a.xridx(i) = tmp.xridx(i); |
|
223 } |
|
224 for (octave_idx_type i = 0; i < c + 1; i++) |
|
225 a.xcidx(i) = tmp.xcidx(i); |
|
226 |
|
227 octave_idx_type ii = a.xcidx(c); |
|
228 |
|
229 for (octave_idx_type i = c; i < c + b_cols; i++) |
|
230 { |
|
231 OCTAVE_QUIT; |
|
232 |
|
233 for (octave_idx_type j = tmp.xcidx(i); j < tmp.xcidx(i+1); j++) |
|
234 if (Qinv [tmp.xridx(j)] < r || Qinv [tmp.xridx(j)] >= r + b_rows) |
|
235 { |
|
236 X [tmp.xridx(j)] = tmp.xdata(j); |
|
237 a.xridx(ii++) = tmp.xridx(j); |
|
238 } |
|
239 |
|
240 OCTAVE_QUIT; |
|
241 |
|
242 for (octave_idx_type j = b.cidx(i-c); j < b.cidx(i-c+1); j++) |
|
243 { |
|
244 X [Q [r + b.ridx(j)]] = b.data(j); |
|
245 a.xridx(ii++) = Q [r + b.ridx(j)]; |
|
246 } |
|
247 |
|
248 sort.sort (ri + a.xcidx (i), ii - a.xcidx (i)); |
|
249 for (octave_idx_type p = a.xcidx (i); p < ii; p++) |
|
250 a.xdata (p) = X [a.xridx (p)]; |
|
251 a.xcidx(i+1) = ii; |
|
252 } |
|
253 |
|
254 for (octave_idx_type i = c + b_cols; i < nc; i++) |
|
255 { |
|
256 for (octave_idx_type j = tmp.xcidx(i); j < tmp.cidx(i+1); j++) |
|
257 { |
|
258 a.xdata(ii) = tmp.xdata(j); |
|
259 a.xridx(ii++) = tmp.xridx(j); |
|
260 } |
|
261 a.xcidx(i+1) = ii; |
|
262 } |
|
263 } |
|
264 |
|
265 #if !defined (CXX_NEW_FRIEND_TEMPLATE_DECL) |
|
266 static void |
|
267 dmsolve_insert (MSparse<double> &a, const SparseMatrix &b, |
|
268 const octave_idx_type *Q, octave_idx_type r, octave_idx_type c); |
|
269 |
|
270 static void |
|
271 dmsolve_insert (MSparse<Complex> &a, const MSparse<Complex> &b, |
|
272 const octave_idx_type *Q, octave_idx_type r, octave_idx_type c); |
|
273 #endif |
|
274 |
|
275 template <class T, class RT> |
|
276 static void |
|
277 dmsolve_permute (MArray2<RT> &a, const MArray2<T>& b, const octave_idx_type *p) |
|
278 { |
|
279 octave_idx_type b_nr = b.rows (); |
|
280 octave_idx_type b_nc = b.cols (); |
|
281 const T *Bx = b.fortran_vec(); |
|
282 a.resize(b_nr, b_nc); |
|
283 RT *Btx = a.fortran_vec(); |
|
284 for (octave_idx_type j = 0; j < b_nc; j++) |
|
285 { |
|
286 octave_idx_type off = j * b_nr; |
|
287 for (octave_idx_type i = 0; i < b_nr; i++) |
|
288 { |
|
289 OCTAVE_QUIT; |
|
290 Btx [p [i] + off] = Bx [ i + off]; |
|
291 } |
|
292 } |
|
293 } |
|
294 |
|
295 #if !defined (CXX_NEW_FRIEND_TEMPLATE_DECL) |
|
296 static void |
|
297 dmsolve_permute (MArray2<double> &a, const MArray2<double>& b, |
|
298 const octave_idx_type *p); |
|
299 |
|
300 static void |
|
301 dmsolve_permute (MArray2<Complex> &a, const MArray2<double>& b, |
|
302 const octave_idx_type *p); |
|
303 |
|
304 static void |
|
305 dmsolve_permute (MArray2<Complex> &a, const MArray2<Complex>& b, |
|
306 const octave_idx_type *p); |
|
307 #endif |
|
308 |
|
309 template <class T, class RT> |
|
310 static void |
|
311 dmsolve_permute (MSparse<RT> &a, const MSparse<T>& b, const octave_idx_type *p) |
|
312 { |
|
313 octave_idx_type b_nr = b.rows (); |
|
314 octave_idx_type b_nc = b.cols (); |
|
315 octave_idx_type b_nz = b.nnz (); |
|
316 octave_idx_type nz = 0; |
|
317 a = MSparse<RT> (b_nr, b_nc, b_nz); |
|
318 octave_sort<octave_idx_type> sort; |
|
319 octave_idx_type *ri = a.xridx(); |
|
320 OCTAVE_LOCAL_BUFFER (RT, X, b_nr); |
|
321 a.xcidx(0) = 0; |
|
322 for (octave_idx_type j = 0; j < b_nc; j++) |
|
323 { |
|
324 for (octave_idx_type i = b.cidx(j); i < b.cidx(j+1); i++) |
|
325 { |
|
326 OCTAVE_QUIT; |
|
327 octave_idx_type r = p [b.ridx (i)]; |
|
328 X [r] = b.data (i); |
|
329 a.xridx(nz++) = p [b.ridx (i)]; |
|
330 } |
|
331 sort.sort (ri + a.xcidx (j), nz - a.xcidx (j)); |
|
332 for (octave_idx_type i = a.cidx (j); i < nz; i++) |
|
333 { |
|
334 OCTAVE_QUIT; |
|
335 a.xdata (i) = X [a.xridx (i)]; |
|
336 } |
|
337 a.xcidx(j+1) = nz; |
|
338 } |
|
339 } |
|
340 |
|
341 #if !defined (CXX_NEW_FRIEND_TEMPLATE_DECL) |
|
342 static void |
|
343 dmsolve_permute (MSparse<double> &a, const MSparse<double>& b, |
|
344 const octave_idx_type *p); |
|
345 |
|
346 static void |
|
347 dmsolve_permute (MSparse<Complex> &a, const MSparse<double>& b, |
|
348 const octave_idx_type *p); |
|
349 |
|
350 static void |
|
351 dmsolve_permute (MSparse<Complex> &a, const MSparse<Complex>& b, |
|
352 const octave_idx_type *p); |
|
353 #endif |
|
354 |
|
// Singularity callback that deliberately does nothing.  It is handed
// to the LU solver so that a numerically rank deficient matrix is not
// reported as an error.  The argument is ignored.
static void
solve_singularity_warning (double)
{
}
|
361 |
|
362 template <class RT, class ST, class T> |
|
363 RT |
|
364 dmsolve (const ST &a, const T &b, octave_idx_type &info) |
|
365 { |
5684
|
366 #ifdef HAVE_CXSPARSE |
5683
|
367 octave_idx_type nr = a.rows (); |
|
368 octave_idx_type nc = a.cols (); |
|
369 octave_idx_type b_nr = b.rows (); |
|
370 octave_idx_type b_nc = b.cols (); |
|
371 RT retval; |
|
372 |
|
373 if (nr < 1 || nc < 1 || nr != b_nr) |
|
374 (*current_liboctave_error_handler) |
|
375 ("matrix dimension mismatch in solution of minimum norm problem"); |
|
376 else |
|
377 { |
|
378 octave_idx_type nnz_remaining = a.nnz (); |
|
379 CXSPARSE_DNAME () csm; |
|
380 csm.m = nr; |
|
381 csm.n = nc; |
|
382 csm.x = NULL; |
|
383 csm.nz = -1; |
|
384 csm.nzmax = a.nnz (); |
|
385 // Cast away const on A, with full knowledge that CSparse won't touch it. |
|
386 // Prevents the methods below making a copy of the data. |
|
387 csm.p = const_cast<octave_idx_type *>(a.cidx ()); |
|
388 csm.i = const_cast<octave_idx_type *>(a.ridx ()); |
|
389 |
|
390 CXSPARSE_DNAME (d) *dm = CXSPARSE_DNAME(_dmperm) (&csm); |
|
391 octave_idx_type *p = dm->P; |
|
392 octave_idx_type *q = dm->Q; |
|
393 OCTAVE_LOCAL_BUFFER (octave_idx_type, pinv, nr); |
|
394 for (octave_idx_type i = 0; i < nr; i++) |
|
395 pinv [p [i]] = i; |
|
396 RT btmp; |
|
397 dmsolve_permute (btmp, b, pinv); |
|
398 info = 0; |
|
399 retval.resize (nc, b_nc); |
|
400 |
|
401 // Leading over-determined block |
|
402 if (dm->rr [2] < nr && dm->cc [3] < nc) |
|
403 { |
|
404 ST m = dmsolve_extract (a, pinv, q, dm->rr [2], nr, dm->cc [3], nc, |
|
405 nnz_remaining, true); |
|
406 nnz_remaining -= m.nnz(); |
|
407 RT mtmp = |
|
408 qrsolve (m, dmsolve_extract (btmp, NULL, NULL, dm->rr[2], b_nr, 0, |
|
409 b_nc), info); |
|
410 dmsolve_insert (retval, mtmp, q, dm->cc [3], 0); |
|
411 if (dm->rr [2] > 0 && !info && !error_state) |
|
412 { |
|
413 m = dmsolve_extract (a, pinv, q, 0, dm->rr [2], |
|
414 dm->cc [3], nc, nnz_remaining, true); |
|
415 nnz_remaining -= m.nnz(); |
|
416 RT ctmp = dmsolve_extract (btmp, NULL, NULL, 0, |
|
417 dm->rr[2], 0, b_nc); |
|
418 btmp.insert (ctmp - m * mtmp, 0, 0); |
|
419 } |
|
420 } |
|
421 |
|
422 // Structurally non-singular blocks |
5775
|
423 // FIXME Should use fine Dulmange-Mendelsohn decomposition here. |
5683
|
424 if (dm->rr [1] < dm->rr [2] && dm->cc [2] < dm->cc [3] && |
|
425 !info && !error_state) |
|
426 { |
|
427 ST m = dmsolve_extract (a, pinv, q, dm->rr [1], dm->rr [2], |
|
428 dm->cc [2], dm->cc [3], nnz_remaining, false); |
|
429 nnz_remaining -= m.nnz(); |
|
430 RT btmp2 = dmsolve_extract (btmp, NULL, NULL, dm->rr [1], dm->rr [2], |
|
431 0, b_nc); |
|
432 double rcond = 0.0; |
5785
|
433 MatrixType mtyp (MatrixType::Full); |
5683
|
434 RT mtmp = m.solve (mtyp, btmp2, info, rcond, |
5697
|
435 solve_singularity_warning, false); |
5683
|
436 if (info != 0) |
|
437 { |
|
438 info = 0; |
|
439 mtmp = qrsolve (m, btmp2, info); |
|
440 } |
|
441 |
|
442 dmsolve_insert (retval, mtmp, q, dm->cc [2], 0); |
|
443 if (dm->rr [1] > 0 && !info && !error_state) |
|
444 { |
|
445 m = dmsolve_extract (a, pinv, q, 0, dm->rr [1], dm->cc [2], |
|
446 dm->cc [3], nnz_remaining, true); |
|
447 nnz_remaining -= m.nnz(); |
|
448 RT ctmp = dmsolve_extract (btmp, NULL, NULL, 0, |
|
449 dm->rr[1], 0, b_nc); |
|
450 btmp.insert (ctmp - m * mtmp, 0, 0); |
|
451 } |
|
452 } |
|
453 |
|
454 // Trailing under-determined block |
|
455 if (dm->rr [1] > 0 && dm->cc [2] > 0 && !info && !error_state) |
|
456 { |
|
457 ST m = dmsolve_extract (a, pinv, q, 0, dm->rr [1], 0, |
|
458 dm->cc [2], nnz_remaining, true); |
|
459 RT mtmp = |
|
460 qrsolve (m, dmsolve_extract(btmp, NULL, NULL, 0, dm->rr [1] , 0, |
|
461 b_nc), info); |
|
462 dmsolve_insert (retval, mtmp, q, 0, 0); |
|
463 } |
|
464 |
|
465 CXSPARSE_DNAME (_dfree) (dm); |
|
466 } |
|
467 return retval; |
5684
|
468 #else |
|
469 return RT (); |
|
470 #endif |
5683
|
471 } |
|
472 |
|
473 #if !defined (CXX_NEW_FRIEND_TEMPLATE_DECL) |
|
474 extern Matrix |
|
475 dmsolve (const SparseMatrix &a, const Matrix &b, |
|
476 octave_idx_type &info); |
|
477 |
|
478 extern ComplexMatrix |
|
479 dmsolve (const SparseMatrix &a, const ComplexMatrix &b, |
|
480 octave_idx_type &info); |
|
481 |
|
482 extern ComplexMatrix |
|
483 dmsolve (const SparseComplexMatrix &a, const Matrix &b, |
|
484 octave_idx_type &info); |
|
485 |
|
486 extern ComplexMatrix |
|
487 dmsolve (const SparseComplexMatrix &a, const ComplexMatrix &b, |
|
488 octave_idx_type &info); |
|
489 |
|
490 extern SparseMatrix |
|
491 dmsolve (const SparseMatrix &a, const SparseMatrix &b, |
|
492 octave_idx_type &info); |
|
493 |
|
494 extern SparseComplexMatrix |
|
495 dmsolve (const SparseMatrix &a, const SparseComplexMatrix &b, |
|
496 octave_idx_type &info); |
|
497 |
|
498 extern SparseComplexMatrix |
|
499 dmsolve (const SparseComplexMatrix &a, const SparseMatrix &b, |
|
500 octave_idx_type &info); |
|
501 |
|
502 extern SparseComplexMatrix |
|
503 dmsolve (const SparseComplexMatrix &a, const SparseComplexMatrix &b, |
|
504 octave_idx_type &info); |
|
505 #endif |
|
506 |
|
507 /* |
|
508 ;;; Local Variables: *** |
|
509 ;;; mode: C++ *** |
|
510 ;;; End: *** |
|
511 */ |