Mercurial > hg > octave-nkf
comparison src/DLD-FUNCTIONS/tril.cc @ 9756:b134960cea23
implement built-in tril/triu
author | Jaroslav Hajek <highegg@gmail.com> |
---|---|
date | Fri, 23 Oct 2009 10:10:37 +0200 |
parents | |
children | 40dfc0c99116 |
comparison
equal
deleted
inserted
replaced
9755:4f4873f6f873 | 9756:b134960cea23 |
---|---|
1 /* | |
2 | |
3 Copyright (C) 2004, 2007 David Bateman | |
4 Copyright (C) 2009 VZLU Prague | |
5 | |
6 This program is free software; you can redistribute it and/or modify it | |
7 under the terms of the GNU General Public License as published by | |
8 the Free Software Foundation; either version 3, or (at your option) | |
9 any later version. | |
10 | |
11 This program is distributed in the hope that it will be useful, but | |
12 WITHOUT ANY WARRANTY; without even the implied warranty of | |
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU | |
14 General Public License for more details. | |
15 | |
16 You should have received a copy of the GNU General Public License | |
17 along with Octave; see the file COPYING. If not, see | |
18 <http://www.gnu.org/licenses/>. | |
19 | |
20 */ | |
21 | |
22 #ifdef HAVE_CONFIG_H | |
23 #include <config.h> | |
24 #endif | |
25 | |
26 #include <algorithm> | |
27 #include "Array.h" | |
28 #include "Sparse.h" | |
29 #include "mx-base.h" | |
30 | |
31 #include "ov.h" | |
32 #include "Cell.h" | |
33 | |
34 #include "defun-dld.h" | |
35 #include "error.h" | |
36 #include "oct-obj.h" | |
37 | |
// Dense kernels: these do the bulk of the work, walking the column-major
// data of the input one column at a time.
39 template <class T> | |
40 static Array<T> | |
41 do_tril (const Array<T>& a, octave_idx_type k, bool pack) | |
42 { | |
43 octave_idx_type nr = a.rows (), nc = a.columns (); | |
44 const T *avec = a.fortran_vec (); | |
45 | |
46 if (pack) | |
47 { | |
48 octave_idx_type j1 = std::min (std::max (0, k), nc); | |
49 octave_idx_type j2 = std::min (std::max (0, nr + k), nc); | |
50 octave_idx_type n = j1 * nr + ((j2 - j1) * (nr-(j1-k) + nr-(j2-1-k))) / 2; | |
51 Array<T> r (n); | |
52 T *rvec = r.fortran_vec (); | |
53 for (octave_idx_type j = 0; j < nc; j++) | |
54 { | |
55 octave_idx_type ii = std::min (std::max (0, j - k), nr); | |
56 rvec = std::copy (avec + ii, avec + nr, rvec); | |
57 avec += nr; | |
58 } | |
59 | |
60 return r; | |
61 } | |
62 else | |
63 { | |
64 Array<T> r (a.dims ()); | |
65 T *rvec = r.fortran_vec (); | |
66 for (octave_idx_type j = 0; j < nc; j++) | |
67 { | |
68 octave_idx_type ii = std::min (std::max (0, j - k), nr); | |
69 std::fill (rvec, rvec + ii, T()); | |
70 std::copy (avec + ii, avec + nr, rvec + ii); | |
71 avec += nr; | |
72 rvec += nr; | |
73 } | |
74 | |
75 return r; | |
76 } | |
77 } | |
78 | |
79 template <class T> | |
80 static Array<T> | |
81 do_triu (const Array<T>& a, octave_idx_type k, bool pack) | |
82 { | |
83 octave_idx_type nr = a.rows (), nc = a.columns (); | |
84 const T *avec = a.fortran_vec (); | |
85 | |
86 if (pack) | |
87 { | |
88 octave_idx_type j1 = std::min (std::max (0, k), nc); | |
89 octave_idx_type j2 = std::min (std::max (0, nr + k), nc); | |
90 octave_idx_type n = ((j2 - j1) * ((j1+1-k) + (j2-k))) / 2 + (nc - j2) * nr; | |
91 Array<T> r (n); | |
92 T *rvec = r.fortran_vec (); | |
93 for (octave_idx_type j = 0; j < nc; j++) | |
94 { | |
95 octave_idx_type ii = std::min (std::max (0, j + 1 - k), nr); | |
96 rvec = std::copy (avec, avec + ii, rvec); | |
97 avec += nr; | |
98 } | |
99 | |
100 return r; | |
101 } | |
102 else | |
103 { | |
104 NoAlias<Array<T> > r (a.dims ()); | |
105 T *rvec = r.fortran_vec (); | |
106 for (octave_idx_type j = 0; j < nc; j++) | |
107 { | |
108 octave_idx_type ii = std::min (std::max (0, j + 1 - k), nr); | |
109 std::copy (avec, avec + ii, rvec); | |
110 std::fill (rvec + ii, rvec + nr, T()); | |
111 avec += nr; | |
112 rvec += nr; | |
113 } | |
114 | |
115 return r; | |
116 } | |
117 } | |
118 | |
119 // These two are by David Bateman. | |
120 // FIXME: optimizations possible. "pack" support missing. | |
121 | |
122 template <class T> | |
123 static Sparse<T> | |
124 do_tril (const Sparse<T>& a, octave_idx_type k, bool pack) | |
125 { | |
126 if (pack) // FIXME | |
127 { | |
128 error ("tril: \"pack\" not implemented for sparse matrices"); | |
129 return Sparse<T> (); | |
130 } | |
131 | |
132 Sparse<T> m = a; | |
133 octave_idx_type nc = m.cols(); | |
134 | |
135 for (octave_idx_type j = 0; j < nc; j++) | |
136 for (octave_idx_type i = m.cidx(j); i < m.cidx(j+1); i++) | |
137 if (m.ridx(i) < j-k) | |
138 m.data(i) = 0.; | |
139 | |
140 m.maybe_compress (true); | |
141 return m; | |
142 } | |
143 | |
144 template <class T> | |
145 static Sparse<T> | |
146 do_triu (const Sparse<T>& a, octave_idx_type k, bool pack) | |
147 { | |
148 if (pack) // FIXME | |
149 { | |
150 error ("triu: \"pack\" not implemented for sparse matrices"); | |
151 return Sparse<T> (); | |
152 } | |
153 | |
154 Sparse<T> m = a; | |
155 octave_idx_type nc = m.cols(); | |
156 | |
157 for (octave_idx_type j = 0; j < nc; j++) | |
158 for (octave_idx_type i = m.cidx(j); i < m.cidx(j+1); i++) | |
159 if (m.ridx(i) > j-k) | |
160 m.data(i) = 0.; | |
161 | |
162 m.maybe_compress (true); | |
163 return m; | |
164 } | |
165 | |
166 // Convenience dispatchers. | |
167 template <class T> | |
168 static Array<T> | |
169 do_trilu (const Array<T>& a, octave_idx_type k, bool lower, bool pack) | |
170 { | |
171 return lower ? do_tril (a, k, pack) : do_triu (a, k, pack); | |
172 } | |
173 | |
174 template <class T> | |
175 static Sparse<T> | |
176 do_trilu (const Sparse<T>& a, octave_idx_type k, bool lower, bool pack) | |
177 { | |
178 return lower ? do_tril (a, k, pack) : do_triu (a, k, pack); | |
179 } | |
180 | |
181 static octave_value | |
182 do_trilu (const std::string& name, | |
183 const octave_value_list& args) | |
184 { | |
185 bool lower = name == "tril"; | |
186 | |
187 octave_value retval; | |
188 int nargin = args.length (); | |
189 octave_idx_type k = 0; | |
190 bool pack = false; | |
191 if (nargin >= 2 && args(nargin-1).is_string ()) | |
192 { | |
193 pack = args(nargin-1).string_value () == "pack"; | |
194 nargin--; | |
195 } | |
196 | |
197 if (nargin == 2) | |
198 { | |
199 k = args(1).int_value (true); | |
200 | |
201 if (error_state) | |
202 return retval; | |
203 } | |
204 | |
205 if (nargin < 1 || nargin > 2) | |
206 print_usage (); | |
207 else | |
208 { | |
209 octave_value arg = args (0); | |
210 | |
211 dim_vector dims = arg.dims (); | |
212 if (dims.length () != 2) | |
213 error ("%s: needs a 2D matrix", name.c_str ()); | |
214 else if (k < -dims (0) || k > dims(1)) | |
215 error ("%s: requested diagonal out of range", name.c_str ()); | |
216 else | |
217 { | |
218 switch (arg.builtin_type ()) | |
219 { | |
220 case btyp_double: | |
221 if (arg.is_sparse_type ()) | |
222 retval = do_trilu (arg.sparse_matrix_value (), k, lower, pack); | |
223 else | |
224 retval = do_trilu (arg.array_value (), k, lower, pack); | |
225 break; | |
226 case btyp_complex: | |
227 if (arg.is_sparse_type ()) | |
228 retval = do_trilu (arg.sparse_complex_matrix_value (), k, lower, pack); | |
229 else | |
230 retval = do_trilu (arg.complex_array_value (), k, lower, pack); | |
231 break; | |
232 case btyp_bool: | |
233 if (arg.is_sparse_type ()) | |
234 retval = do_trilu (arg.sparse_bool_matrix_value (), k, lower, pack); | |
235 else | |
236 retval = do_trilu (arg.bool_array_value (), k, lower, pack); | |
237 break; | |
238 #define ARRAYCASE(TYP) \ | |
239 case btyp_ ## TYP: \ | |
240 retval = do_trilu (arg.TYP ## _array_value (), k, lower, pack); \ | |
241 break | |
242 ARRAYCASE (float); | |
243 ARRAYCASE (float_complex); | |
244 ARRAYCASE (int8); | |
245 ARRAYCASE (int16); | |
246 ARRAYCASE (int32); | |
247 ARRAYCASE (int64); | |
248 ARRAYCASE (uint8); | |
249 ARRAYCASE (uint16); | |
250 ARRAYCASE (uint32); | |
251 ARRAYCASE (uint64); | |
252 ARRAYCASE (char); | |
253 #undef ARRAYCASE | |
254 default: | |
255 { | |
256 // Generic code that works on octave-values, that is slow | |
257 // but will also work on arbitrary user types | |
258 | |
259 if (pack) // FIXME | |
260 { | |
261 error ("%s: \"pack\" not implemented for class %s", | |
262 name.c_str (), arg.class_name ().c_str ()); | |
263 return octave_value (); | |
264 } | |
265 | |
266 octave_value tmp = arg; | |
267 if (arg.numel () == 0) | |
268 return arg; | |
269 | |
270 octave_idx_type nr = dims(0), nc = dims (1); | |
271 | |
272 // The sole purpose of the below is to force the correct | |
273 // matrix size. This would not be necessary if the | |
274 // octave_value resize function allowed a fill_value. | |
275 // It also allows odd attributes in some user types | |
276 // to be handled. With a fill_value ot should be replaced | |
277 // with | |
278 // | |
279 // octave_value_list ov_idx; | |
280 // tmp = tmp.resize(dim_vector (0,0)).resize (dims, fill_value); | |
281 | |
282 octave_value_list ov_idx; | |
283 std::list<octave_value_list> idx_tmp; | |
284 ov_idx(1) = static_cast<double> (nc+1); | |
285 ov_idx(0) = Range (1, nr); | |
286 idx_tmp.push_back (ov_idx); | |
287 ov_idx(1) = static_cast<double> (nc); | |
288 tmp = tmp.resize (dim_vector (0,0)); | |
289 tmp = tmp.subsasgn("(",idx_tmp, arg.do_index_op (ov_idx)); | |
290 tmp = tmp.resize(dims); | |
291 | |
292 if (lower) | |
293 { | |
294 octave_idx_type st = nc < nr + k ? nc : nr + k; | |
295 | |
296 for (octave_idx_type j = 1; j <= st; j++) | |
297 { | |
298 octave_idx_type nr_limit = 1 > j - k ? 1 : j - k; | |
299 ov_idx(1) = static_cast<double> (j); | |
300 ov_idx(0) = Range (nr_limit, nr); | |
301 std::list<octave_value_list> idx; | |
302 idx.push_back (ov_idx); | |
303 | |
304 tmp = tmp.subsasgn ("(", idx, arg.do_index_op(ov_idx)); | |
305 | |
306 if (error_state) | |
307 return retval; | |
308 } | |
309 } | |
310 else | |
311 { | |
312 octave_idx_type st = k + 1 > 1 ? k + 1 : 1; | |
313 | |
314 for (octave_idx_type j = st; j <= nc; j++) | |
315 { | |
316 octave_idx_type nr_limit = nr < j - k ? nr : j - k; | |
317 ov_idx(1) = static_cast<double> (j); | |
318 ov_idx(0) = Range (1, nr_limit); | |
319 std::list<octave_value_list> idx; | |
320 idx.push_back (ov_idx); | |
321 | |
322 tmp = tmp.subsasgn ("(", idx, arg.do_index_op(ov_idx)); | |
323 | |
324 if (error_state) | |
325 return retval; | |
326 } | |
327 } | |
328 | |
329 retval = tmp; | |
330 } | |
331 } | |
332 } | |
333 } | |
334 | |
335 return retval; | |
336 } | |
337 | |
338 DEFUN_DLD (tril, args, , | |
339 "-*- texinfo -*-\n\ | |
340 @deftypefn {Function File} {} tril (@var{a}, @var{k}, @var{pack})\n\ | |
341 @deftypefnx {Function File} {} triu (@var{a}, @var{k}, @var{pack})\n\ | |
342 Return a new matrix formed by extracting extract the lower (@code{tril})\n\ | |
343 or upper (@code{triu}) triangular part of the matrix @var{a}, and\n\ | |
344 setting all other elements to zero. The second argument is optional,\n\ | |
345 and specifies how many diagonals above or below the main diagonal should\n\ | |
346 also be set to zero.\n\ | |
347 \n\ | |
348 The default value of @var{k} is zero, so that @code{triu} and\n\ | |
349 @code{tril} normally include the main diagonal as part of the result\n\ | |
350 matrix.\n\ | |
351 \n\ | |
352 If the value of @var{k} is negative, additional elements above (for\n\ | |
353 @code{tril}) or below (for @code{triu}) the main diagonal are also\n\ | |
354 selected.\n\ | |
355 \n\ | |
356 The absolute value of @var{k} must not be greater than the number of\n\ | |
357 sub- or super-diagonals.\n\ | |
358 \n\ | |
359 For example,\n\ | |
360 \n\ | |
361 @example\n\ | |
362 @group\n\ | |
363 tril (ones (3), -1)\n\ | |
364 @result{} 0 0 0\n\ | |
365 1 0 0\n\ | |
366 1 1 0\n\ | |
367 @end group\n\ | |
368 @end example\n\ | |
369 \n\ | |
370 @noindent\n\ | |
371 and\n\ | |
372 \n\ | |
373 @example\n\ | |
374 @group\n\ | |
375 tril (ones (3), 1)\n\ | |
376 @result{} 1 1 0\n\ | |
377 1 1 1\n\ | |
378 1 1 1\n\ | |
379 @end group\n\ | |
380 @end example\n\ | |
381 \n\ | |
382 If the option \"pack\" is given as third argument, the extracted elements\n\ | |
383 are not inserted into a matrix, but rather stacked column-wise one above other.\n\ | |
384 @seealso{triu, diag}\n\ | |
385 @end deftypefn") | |
386 { | |
387 return do_trilu ("tril", args); | |
388 } | |
389 | |
DEFUN_DLD (triu, args, ,
  "-*- texinfo -*-\n\
@deftypefn  {Function File} {} triu (@var{a})\n\
@deftypefnx {Function File} {} triu (@var{a}, @var{k})\n\
@deftypefnx {Function File} {} triu (@dots{}, \"pack\")\n\
Return the upper triangular part of @var{a}.  See @code{tril}.\n\
@end deftypefn")
{
  return do_trilu ("triu", args);
}
398 | |
399 /* | |
400 | |
401 %!test | |
402 %! a = [1, 2, 3; 4, 5, 6; 7, 8, 9; 10, 11, 12]; | |
403 %! | |
404 %! l0 = [1, 0, 0; 4, 5, 0; 7, 8, 9; 10, 11, 12]; | |
405 %! l1 = [1, 2, 0; 4, 5, 6; 7, 8, 9; 10, 11, 12]; | |
406 %! l2 = [1, 2, 3; 4, 5, 6; 7, 8, 9; 10, 11, 12]; | |
407 %! lm1 = [0, 0, 0; 4, 0, 0; 7, 8, 0; 10, 11, 12]; | |
408 %! lm2 = [0, 0, 0; 0, 0, 0; 7, 0, 0; 10, 11, 0]; | |
409 %! lm3 = [0, 0, 0; 0, 0, 0; 0, 0, 0; 10, 0, 0]; | |
410 %! lm4 = [0, 0, 0; 0, 0, 0; 0, 0, 0; 0, 0, 0]; | |
411 %! | |
412 %! assert((tril (a, -4) == lm4 && tril (a, -3) == lm3 | |
413 %! && tril (a, -2) == lm2 && tril (a, -1) == lm1 | |
414 %! && tril (a) == l0 && tril (a, 1) == l1 && tril (a, 2) == l2)); | |
415 | |
416 %!error tril (); | |
417 | |
418 */ | |
419 | |
420 /* | |
421 ;;; Local Variables: *** | |
422 ;;; mode: C++ *** | |
423 ;;; End: *** | |
424 */ |