comparison src/DLD-FUNCTIONS/tril.cc @ 9756:b134960cea23

implement built-in tril/triu
author Jaroslav Hajek <highegg@gmail.com>
date Fri, 23 Oct 2009 10:10:37 +0200
parents
children 40dfc0c99116
comparison
equal deleted inserted replaced
9755:4f4873f6f873 9756:b134960cea23
1 /*
2
3 Copyright (C) 2004, 2007 David Bateman
4 Copyright (C) 2009 VZLU Prague
5
6 This program is free software; you can redistribute it and/or modify it
7 under the terms of the GNU General Public License as published by
8 the Free Software Foundation; either version 3, or (at your option)
9 any later version.
10
11 This program is distributed in the hope that it will be useful, but
12 WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
14 General Public License for more details.
15
16 You should have received a copy of the GNU General Public License
17 along with Octave; see the file COPYING. If not, see
18 <http://www.gnu.org/licenses/>.
19
20 */
21
22 #ifdef HAVE_CONFIG_H
23 #include <config.h>
24 #endif
25
26 #include <algorithm>
27 #include "Array.h"
28 #include "Sparse.h"
29 #include "mx-base.h"
30
31 #include "ov.h"
32 #include "Cell.h"
33
34 #include "defun-dld.h"
35 #include "error.h"
36 #include "oct-obj.h"
37
38 // The bulk of the work.
39 template <class T>
40 static Array<T>
41 do_tril (const Array<T>& a, octave_idx_type k, bool pack)
42 {
43 octave_idx_type nr = a.rows (), nc = a.columns ();
44 const T *avec = a.fortran_vec ();
45
46 if (pack)
47 {
48 octave_idx_type j1 = std::min (std::max (0, k), nc);
49 octave_idx_type j2 = std::min (std::max (0, nr + k), nc);
50 octave_idx_type n = j1 * nr + ((j2 - j1) * (nr-(j1-k) + nr-(j2-1-k))) / 2;
51 Array<T> r (n);
52 T *rvec = r.fortran_vec ();
53 for (octave_idx_type j = 0; j < nc; j++)
54 {
55 octave_idx_type ii = std::min (std::max (0, j - k), nr);
56 rvec = std::copy (avec + ii, avec + nr, rvec);
57 avec += nr;
58 }
59
60 return r;
61 }
62 else
63 {
64 Array<T> r (a.dims ());
65 T *rvec = r.fortran_vec ();
66 for (octave_idx_type j = 0; j < nc; j++)
67 {
68 octave_idx_type ii = std::min (std::max (0, j - k), nr);
69 std::fill (rvec, rvec + ii, T());
70 std::copy (avec + ii, avec + nr, rvec + ii);
71 avec += nr;
72 rvec += nr;
73 }
74
75 return r;
76 }
77 }
78
79 template <class T>
80 static Array<T>
81 do_triu (const Array<T>& a, octave_idx_type k, bool pack)
82 {
83 octave_idx_type nr = a.rows (), nc = a.columns ();
84 const T *avec = a.fortran_vec ();
85
86 if (pack)
87 {
88 octave_idx_type j1 = std::min (std::max (0, k), nc);
89 octave_idx_type j2 = std::min (std::max (0, nr + k), nc);
90 octave_idx_type n = ((j2 - j1) * ((j1+1-k) + (j2-k))) / 2 + (nc - j2) * nr;
91 Array<T> r (n);
92 T *rvec = r.fortran_vec ();
93 for (octave_idx_type j = 0; j < nc; j++)
94 {
95 octave_idx_type ii = std::min (std::max (0, j + 1 - k), nr);
96 rvec = std::copy (avec, avec + ii, rvec);
97 avec += nr;
98 }
99
100 return r;
101 }
102 else
103 {
104 NoAlias<Array<T> > r (a.dims ());
105 T *rvec = r.fortran_vec ();
106 for (octave_idx_type j = 0; j < nc; j++)
107 {
108 octave_idx_type ii = std::min (std::max (0, j + 1 - k), nr);
109 std::copy (avec, avec + ii, rvec);
110 std::fill (rvec + ii, rvec + nr, T());
111 avec += nr;
112 rvec += nr;
113 }
114
115 return r;
116 }
117 }
118
119 // These two are by David Bateman.
120 // FIXME: optimizations possible. "pack" support missing.
121
122 template <class T>
123 static Sparse<T>
124 do_tril (const Sparse<T>& a, octave_idx_type k, bool pack)
125 {
126 if (pack) // FIXME
127 {
128 error ("tril: \"pack\" not implemented for sparse matrices");
129 return Sparse<T> ();
130 }
131
132 Sparse<T> m = a;
133 octave_idx_type nc = m.cols();
134
135 for (octave_idx_type j = 0; j < nc; j++)
136 for (octave_idx_type i = m.cidx(j); i < m.cidx(j+1); i++)
137 if (m.ridx(i) < j-k)
138 m.data(i) = 0.;
139
140 m.maybe_compress (true);
141 return m;
142 }
143
144 template <class T>
145 static Sparse<T>
146 do_triu (const Sparse<T>& a, octave_idx_type k, bool pack)
147 {
148 if (pack) // FIXME
149 {
150 error ("triu: \"pack\" not implemented for sparse matrices");
151 return Sparse<T> ();
152 }
153
154 Sparse<T> m = a;
155 octave_idx_type nc = m.cols();
156
157 for (octave_idx_type j = 0; j < nc; j++)
158 for (octave_idx_type i = m.cidx(j); i < m.cidx(j+1); i++)
159 if (m.ridx(i) > j-k)
160 m.data(i) = 0.;
161
162 m.maybe_compress (true);
163 return m;
164 }
165
166 // Convenience dispatchers.
167 template <class T>
168 static Array<T>
169 do_trilu (const Array<T>& a, octave_idx_type k, bool lower, bool pack)
170 {
171 return lower ? do_tril (a, k, pack) : do_triu (a, k, pack);
172 }
173
174 template <class T>
175 static Sparse<T>
176 do_trilu (const Sparse<T>& a, octave_idx_type k, bool lower, bool pack)
177 {
178 return lower ? do_tril (a, k, pack) : do_triu (a, k, pack);
179 }
180
181 static octave_value
182 do_trilu (const std::string& name,
183 const octave_value_list& args)
184 {
185 bool lower = name == "tril";
186
187 octave_value retval;
188 int nargin = args.length ();
189 octave_idx_type k = 0;
190 bool pack = false;
191 if (nargin >= 2 && args(nargin-1).is_string ())
192 {
193 pack = args(nargin-1).string_value () == "pack";
194 nargin--;
195 }
196
197 if (nargin == 2)
198 {
199 k = args(1).int_value (true);
200
201 if (error_state)
202 return retval;
203 }
204
205 if (nargin < 1 || nargin > 2)
206 print_usage ();
207 else
208 {
209 octave_value arg = args (0);
210
211 dim_vector dims = arg.dims ();
212 if (dims.length () != 2)
213 error ("%s: needs a 2D matrix", name.c_str ());
214 else if (k < -dims (0) || k > dims(1))
215 error ("%s: requested diagonal out of range", name.c_str ());
216 else
217 {
218 switch (arg.builtin_type ())
219 {
220 case btyp_double:
221 if (arg.is_sparse_type ())
222 retval = do_trilu (arg.sparse_matrix_value (), k, lower, pack);
223 else
224 retval = do_trilu (arg.array_value (), k, lower, pack);
225 break;
226 case btyp_complex:
227 if (arg.is_sparse_type ())
228 retval = do_trilu (arg.sparse_complex_matrix_value (), k, lower, pack);
229 else
230 retval = do_trilu (arg.complex_array_value (), k, lower, pack);
231 break;
232 case btyp_bool:
233 if (arg.is_sparse_type ())
234 retval = do_trilu (arg.sparse_bool_matrix_value (), k, lower, pack);
235 else
236 retval = do_trilu (arg.bool_array_value (), k, lower, pack);
237 break;
238 #define ARRAYCASE(TYP) \
239 case btyp_ ## TYP: \
240 retval = do_trilu (arg.TYP ## _array_value (), k, lower, pack); \
241 break
242 ARRAYCASE (float);
243 ARRAYCASE (float_complex);
244 ARRAYCASE (int8);
245 ARRAYCASE (int16);
246 ARRAYCASE (int32);
247 ARRAYCASE (int64);
248 ARRAYCASE (uint8);
249 ARRAYCASE (uint16);
250 ARRAYCASE (uint32);
251 ARRAYCASE (uint64);
252 ARRAYCASE (char);
253 #undef ARRAYCASE
254 default:
255 {
256 // Generic code that works on octave-values, that is slow
257 // but will also work on arbitrary user types
258
259 if (pack) // FIXME
260 {
261 error ("%s: \"pack\" not implemented for class %s",
262 name.c_str (), arg.class_name ().c_str ());
263 return octave_value ();
264 }
265
266 octave_value tmp = arg;
267 if (arg.numel () == 0)
268 return arg;
269
270 octave_idx_type nr = dims(0), nc = dims (1);
271
272 // The sole purpose of the below is to force the correct
273 // matrix size. This would not be necessary if the
274 // octave_value resize function allowed a fill_value.
275 // It also allows odd attributes in some user types
276 // to be handled. With a fill_value ot should be replaced
277 // with
278 //
279 // octave_value_list ov_idx;
280 // tmp = tmp.resize(dim_vector (0,0)).resize (dims, fill_value);
281
282 octave_value_list ov_idx;
283 std::list<octave_value_list> idx_tmp;
284 ov_idx(1) = static_cast<double> (nc+1);
285 ov_idx(0) = Range (1, nr);
286 idx_tmp.push_back (ov_idx);
287 ov_idx(1) = static_cast<double> (nc);
288 tmp = tmp.resize (dim_vector (0,0));
289 tmp = tmp.subsasgn("(",idx_tmp, arg.do_index_op (ov_idx));
290 tmp = tmp.resize(dims);
291
292 if (lower)
293 {
294 octave_idx_type st = nc < nr + k ? nc : nr + k;
295
296 for (octave_idx_type j = 1; j <= st; j++)
297 {
298 octave_idx_type nr_limit = 1 > j - k ? 1 : j - k;
299 ov_idx(1) = static_cast<double> (j);
300 ov_idx(0) = Range (nr_limit, nr);
301 std::list<octave_value_list> idx;
302 idx.push_back (ov_idx);
303
304 tmp = tmp.subsasgn ("(", idx, arg.do_index_op(ov_idx));
305
306 if (error_state)
307 return retval;
308 }
309 }
310 else
311 {
312 octave_idx_type st = k + 1 > 1 ? k + 1 : 1;
313
314 for (octave_idx_type j = st; j <= nc; j++)
315 {
316 octave_idx_type nr_limit = nr < j - k ? nr : j - k;
317 ov_idx(1) = static_cast<double> (j);
318 ov_idx(0) = Range (1, nr_limit);
319 std::list<octave_value_list> idx;
320 idx.push_back (ov_idx);
321
322 tmp = tmp.subsasgn ("(", idx, arg.do_index_op(ov_idx));
323
324 if (error_state)
325 return retval;
326 }
327 }
328
329 retval = tmp;
330 }
331 }
332 }
333 }
334
335 return retval;
336 }
337
338 DEFUN_DLD (tril, args, ,
339 "-*- texinfo -*-\n\
340 @deftypefn {Function File} {} tril (@var{a}, @var{k}, @var{pack})\n\
341 @deftypefnx {Function File} {} triu (@var{a}, @var{k}, @var{pack})\n\
342 Return a new matrix formed by extracting extract the lower (@code{tril})\n\
343 or upper (@code{triu}) triangular part of the matrix @var{a}, and\n\
344 setting all other elements to zero. The second argument is optional,\n\
345 and specifies how many diagonals above or below the main diagonal should\n\
346 also be set to zero.\n\
347 \n\
348 The default value of @var{k} is zero, so that @code{triu} and\n\
349 @code{tril} normally include the main diagonal as part of the result\n\
350 matrix.\n\
351 \n\
352 If the value of @var{k} is negative, additional elements above (for\n\
353 @code{tril}) or below (for @code{triu}) the main diagonal are also\n\
354 selected.\n\
355 \n\
356 The absolute value of @var{k} must not be greater than the number of\n\
357 sub- or super-diagonals.\n\
358 \n\
359 For example,\n\
360 \n\
361 @example\n\
362 @group\n\
363 tril (ones (3), -1)\n\
364 @result{} 0 0 0\n\
365 1 0 0\n\
366 1 1 0\n\
367 @end group\n\
368 @end example\n\
369 \n\
370 @noindent\n\
371 and\n\
372 \n\
373 @example\n\
374 @group\n\
375 tril (ones (3), 1)\n\
376 @result{} 1 1 0\n\
377 1 1 1\n\
378 1 1 1\n\
379 @end group\n\
380 @end example\n\
381 \n\
382 If the option \"pack\" is given as third argument, the extracted elements\n\
383 are not inserted into a matrix, but rather stacked column-wise one above other.\n\
384 @seealso{triu, diag}\n\
385 @end deftypefn")
386 {
387 return do_trilu ("tril", args);
388 }
389
390 DEFUN_DLD (triu, args, ,
391 "-*- texinfo -*-\n\
392 @deftypefn {Function File} {} triu (@var{a}, @var{k})\n\
393 See tril.\n\
394 @end deftypefn")
395 {
396 return do_trilu ("triu", args);
397 }
398
399 /*
400
401 %!test
402 %! a = [1, 2, 3; 4, 5, 6; 7, 8, 9; 10, 11, 12];
403 %!
404 %! l0 = [1, 0, 0; 4, 5, 0; 7, 8, 9; 10, 11, 12];
405 %! l1 = [1, 2, 0; 4, 5, 6; 7, 8, 9; 10, 11, 12];
406 %! l2 = [1, 2, 3; 4, 5, 6; 7, 8, 9; 10, 11, 12];
407 %! lm1 = [0, 0, 0; 4, 0, 0; 7, 8, 0; 10, 11, 12];
408 %! lm2 = [0, 0, 0; 0, 0, 0; 7, 0, 0; 10, 11, 0];
409 %! lm3 = [0, 0, 0; 0, 0, 0; 0, 0, 0; 10, 0, 0];
410 %! lm4 = [0, 0, 0; 0, 0, 0; 0, 0, 0; 0, 0, 0];
411 %!
412 %! assert((tril (a, -4) == lm4 && tril (a, -3) == lm3
413 %! && tril (a, -2) == lm2 && tril (a, -1) == lm1
414 %! && tril (a) == l0 && tril (a, 1) == l1 && tril (a, 2) == l2));
415
416 %!error tril ();
417
418 */
419
420 /*
421 ;;; Local Variables: ***
422 ;;; mode: C++ ***
423 ;;; End: ***
424 */