comparison src/corefcn/kron.cc @ 15039:e753177cde93

maint: Move non-dynamically linked functions from DLD-FUNCTIONS/ to corefcn/ directory * __contourc__.cc, __dispatch__.cc, __lin_interpn__.cc, __pchip_deriv__.cc, __qp__.cc, balance.cc, besselj.cc, betainc.cc, bsxfun.cc, cellfun.cc, colloc.cc, conv2.cc, daspk.cc, dasrt.cc, dassl.cc, det.cc, dlmread.cc, dot.cc, eig.cc, fft.cc, fft2.cc, fftn.cc, filter.cc, find.cc, gammainc.cc, gcd.cc, getgrent.cc, getpwent.cc, getrusage.cc, givens.cc, hess.cc, hex2num.cc, inv.cc, kron.cc, lookup.cc, lsode.cc, lu.cc, luinc.cc, matrix_type.cc, max.cc, md5sum.cc, mgorth.cc, nproc.cc, pinv.cc, quad.cc, quadcc.cc, qz.cc, rand.cc, rcond.cc, regexp.cc, schur.cc, spparms.cc, sqrtm.cc, str2double.cc, strfind.cc, sub2ind.cc, svd.cc, syl.cc, time.cc, tril.cc, typecast.cc: Move functions from DLD-FUNCTIONS/ to corefcn/ directory. Include "defun.h", not "defun-dld.h". Change docstring to refer to these as "Built-in Functions". * build-aux/mk-opts.pl: Generate options code with '#include "defun.h"'. Change option docstrings to refer to these as "Built-in Functions". * corefcn/module.mk: List of functions to build in corefcn/ dir. * DLD-FUNCTIONS/config-module.awk: Update to new build system. * DLD-FUNCTIONS/module-files: Remove functions which are now in corefcn/ directory. * src/Makefile.am: Update to build "convenience library" in corefcn/. Octave program now links against all other libraries + corefcn library. * src/find-defun-files.sh: Strip $srcdir from filename. * src/link-deps.mk: Add REGEX and FFTW link dependencies for liboctinterp. * type.m, which.m: Change failing tests to use 'amd', still a dynamic function, rather than 'dot', which isn't.
author Rik <rik@octave.org>
date Fri, 27 Jul 2012 15:35:00 -0700
parents src/DLD-FUNCTIONS/kron.cc@5ae9f0f77635
children
comparison
equal deleted inserted replaced
15038:ab18578c2ade 15039:e753177cde93
1 /*
2
3 Copyright (C) 2002-2012 John W. Eaton
4
5 This file is part of Octave.
6
7 Octave is free software; you can redistribute it and/or modify it
8 under the terms of the GNU General Public License as published by the
9 Free Software Foundation; either version 3 of the License, or (at your
10 option) any later version.
11
12 Octave is distributed in the hope that it will be useful, but WITHOUT
13 ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
14 FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
15 for more details.
16
17 You should have received a copy of the GNU General Public License
18 along with Octave; see the file COPYING. If not, see
19 <http://www.gnu.org/licenses/>.
20
21 */
22
23 // Author: Paul Kienzle <pkienzle@users.sf.net>
24
25 #ifdef HAVE_CONFIG_H
26 #include <config.h>
27 #endif
28
29 #include "dMatrix.h"
30 #include "fMatrix.h"
31 #include "CMatrix.h"
32 #include "fCMatrix.h"
33
34 #include "dSparse.h"
35 #include "CSparse.h"
36
37 #include "dDiagMatrix.h"
38 #include "fDiagMatrix.h"
39 #include "CDiagMatrix.h"
40 #include "fCDiagMatrix.h"
41
42 #include "PermMatrix.h"
43
44 #include "mx-inlines.cc"
45 #include "quit.h"
46
47 #include "defun.h"
48 #include "error.h"
49 #include "oct-obj.h"
50
51 template <class R, class T>
52 static MArray<T>
53 kron (const MArray<R>& a, const MArray<T>& b)
54 {
55 assert (a.ndims () == 2);
56 assert (b.ndims () == 2);
57
58 octave_idx_type nra = a.rows (), nrb = b.rows ();
59 octave_idx_type nca = a.cols (), ncb = b.cols ();
60
61 MArray<T> c (dim_vector (nra*nrb, nca*ncb));
62 T *cv = c.fortran_vec ();
63
64 for (octave_idx_type ja = 0; ja < nca; ja++)
65 for (octave_idx_type jb = 0; jb < ncb; jb++)
66 for (octave_idx_type ia = 0; ia < nra; ia++)
67 {
68 octave_quit ();
69 mx_inline_mul (nrb, cv, a(ia, ja), b.data () + nrb*jb);
70 cv += nrb;
71 }
72
73 return c;
74 }
75
76 template <class R, class T>
77 static MArray<T>
78 kron (const MDiagArray2<R>& a, const MArray<T>& b)
79 {
80 assert (b.ndims () == 2);
81
82 octave_idx_type nra = a.rows (), nrb = b.rows (), dla = a.diag_length ();
83 octave_idx_type nca = a.cols (), ncb = b.cols ();
84
85 MArray<T> c (dim_vector (nra*nrb, nca*ncb), T ());
86
87 for (octave_idx_type ja = 0; ja < dla; ja++)
88 for (octave_idx_type jb = 0; jb < ncb; jb++)
89 {
90 octave_quit ();
91 mx_inline_mul (nrb, &c.xelem (ja*nrb, ja*ncb + jb), a.dgelem (ja), b.data () + nrb*jb);
92 }
93
94 return c;
95 }
96
97 template <class T>
98 static MSparse<T>
99 kron (const MSparse<T>& A, const MSparse<T>& B)
100 {
101 octave_idx_type idx = 0;
102 MSparse<T> C (A.rows () * B.rows (), A.columns () * B.columns (),
103 A.nnz () * B.nnz ());
104
105 C.cidx (0) = 0;
106
107 for (octave_idx_type Aj = 0; Aj < A.columns (); Aj++)
108 for (octave_idx_type Bj = 0; Bj < B.columns (); Bj++)
109 {
110 octave_quit ();
111 for (octave_idx_type Ai = A.cidx (Aj); Ai < A.cidx (Aj+1); Ai++)
112 {
113 octave_idx_type Ci = A.ridx (Ai) * B.rows ();
114 const T v = A.data (Ai);
115
116 for (octave_idx_type Bi = B.cidx (Bj); Bi < B.cidx (Bj+1); Bi++)
117 {
118 C.data (idx) = v * B.data (Bi);
119 C.ridx (idx++) = Ci + B.ridx (Bi);
120 }
121 }
122 C.cidx (Aj * B.columns () + Bj + 1) = idx;
123 }
124
125 return C;
126 }
127
128 static PermMatrix
129 kron (const PermMatrix& a, const PermMatrix& b)
130 {
131 octave_idx_type na = a.rows (), nb = b.rows ();
132 const octave_idx_type *pa = a.data (), *pb = b.data ();
133 PermMatrix c(na*nb); // Row permutation.
134 octave_idx_type *pc = c.fortran_vec ();
135
136 bool cola = a.is_col_perm (), colb = b.is_col_perm ();
137 if (cola && colb)
138 {
139 for (octave_idx_type i = 0; i < na; i++)
140 for (octave_idx_type j = 0; j < nb; j++)
141 pc[pa[i]*nb+pb[j]] = i*nb+j;
142 }
143 else if (cola)
144 {
145 for (octave_idx_type i = 0; i < na; i++)
146 for (octave_idx_type j = 0; j < nb; j++)
147 pc[pa[i]*nb+j] = i*nb+pb[j];
148 }
149 else if (colb)
150 {
151 for (octave_idx_type i = 0; i < na; i++)
152 for (octave_idx_type j = 0; j < nb; j++)
153 pc[i*nb+pb[j]] = pa[i]*nb+j;
154 }
155 else
156 {
157 for (octave_idx_type i = 0; i < na; i++)
158 for (octave_idx_type j = 0; j < nb; j++)
159 pc[i*nb+j] = pa[i]*nb+pb[j];
160 }
161
162 return c;
163 }
164
165 template <class MTA, class MTB>
166 octave_value
167 do_kron (const octave_value& a, const octave_value& b)
168 {
169 MTA am = octave_value_extract<MTA> (a);
170 MTB bm = octave_value_extract<MTB> (b);
171 return octave_value (kron (am, bm));
172 }
173
174 octave_value
175 dispatch_kron (const octave_value& a, const octave_value& b)
176 {
177 octave_value retval;
178 if (a.is_perm_matrix () && b.is_perm_matrix ())
179 retval = do_kron<PermMatrix, PermMatrix> (a, b);
180 else if (a.is_diag_matrix ())
181 {
182 if (b.is_diag_matrix () && a.rows () == a.columns ()
183 && b.rows () == b.columns ())
184 {
185 // We have two diagonal matrices, the product of those will be
186 // another diagonal matrix. To do that efficiently, extract
187 // the diagonals as vectors and compute the product. That
188 // will be another vector, which we then use to construct a
189 // diagonal matrix object. Note that this will fail if our
190 // digaonal matrix object is modified to allow the non-zero
191 // values to be stored off of the principal diagonal (i.e., if
192 // diag ([1,2], 3) is modified to return a diagonal matrix
193 // object instead of a full matrix object).
194
195 octave_value tmp = dispatch_kron (a.diag (), b.diag ());
196 retval = tmp.diag ();
197 }
198 else if (a.is_single_type () || b.is_single_type ())
199 {
200 if (a.is_complex_type ())
201 retval = do_kron<FloatComplexDiagMatrix, FloatComplexMatrix> (a, b);
202 else if (b.is_complex_type ())
203 retval = do_kron<FloatDiagMatrix, FloatComplexMatrix> (a, b);
204 else
205 retval = do_kron<FloatDiagMatrix, FloatMatrix> (a, b);
206 }
207 else
208 {
209 if (a.is_complex_type ())
210 retval = do_kron<ComplexDiagMatrix, ComplexMatrix> (a, b);
211 else if (b.is_complex_type ())
212 retval = do_kron<DiagMatrix, ComplexMatrix> (a, b);
213 else
214 retval = do_kron<DiagMatrix, Matrix> (a, b);
215 }
216 }
217 else if (a.is_sparse_type () || b.is_sparse_type ())
218 {
219 if (a.is_complex_type () || b.is_complex_type ())
220 retval = do_kron<SparseComplexMatrix, SparseComplexMatrix> (a, b);
221 else
222 retval = do_kron<SparseMatrix, SparseMatrix> (a, b);
223 }
224 else if (a.is_single_type () || b.is_single_type ())
225 {
226 if (a.is_complex_type ())
227 retval = do_kron<FloatComplexMatrix, FloatComplexMatrix> (a, b);
228 else if (b.is_complex_type ())
229 retval = do_kron<FloatMatrix, FloatComplexMatrix> (a, b);
230 else
231 retval = do_kron<FloatMatrix, FloatMatrix> (a, b);
232 }
233 else
234 {
235 if (a.is_complex_type ())
236 retval = do_kron<ComplexMatrix, ComplexMatrix> (a, b);
237 else if (b.is_complex_type ())
238 retval = do_kron<Matrix, ComplexMatrix> (a, b);
239 else
240 retval = do_kron<Matrix, Matrix> (a, b);
241 }
242 return retval;
243 }
244
245
// Interpreter entry point for kron (A, B, ...).
//
// Requires at least two arguments; with fewer, print_usage () is called
// and the default-constructed return value is handed back.  With two or
// more, the product of the first two arguments is computed by
// dispatch_kron and every further argument is then multiplied in from
// the left, i.e. kron (kron (A1, A2), A3) and so on.
246 DEFUN (kron, args, , "-*- texinfo -*-\n\
247 @deftypefn {Built-in Function} {} kron (@var{A}, @var{B})\n\
248 @deftypefnx {Built-in Function} {} kron (@var{A1}, @var{A2}, @dots{})\n\
249 Form the Kronecker product of two or more matrices, defined block by \n\
250 block as\n\
251 \n\
252 @example\n\
253 x = [ a(i,j)*b ]\n\
254 @end example\n\
255 \n\
256 For example:\n\
257 \n\
258 @example\n\
259 @group\n\
260 kron (1:4, ones (3, 1))\n\
261 @result{} 1 2 3 4\n\
262 1 2 3 4\n\
263 1 2 3 4\n\
264 @end group\n\
265 @end example\n\
266 \n\
267 If there are more than two input arguments @var{A1}, @var{A2}, @dots{}, \n\
268 @var{An} the Kronecker product is computed as\n\
269 \n\
270 @example\n\
271 kron (kron (@var{A1}, @var{A2}), @dots{}, @var{An})\n\
272 @end example\n\
273 \n\
274 @noindent\n\
275 Since the Kronecker product is associative, this is well-defined.\n\
276 @end deftypefn")
277 {
278 octave_value retval;
279
280 int nargin = args.length ();
281
// A Kronecker product needs at least two operands.
282 if (nargin >= 2)
283 {
284 octave_value a = args(0), b = args(1);
// Product of the first two operands.
285 retval = dispatch_kron (a, b);
// Fold each remaining operand into the running product, left to right.
286 for (octave_idx_type i = 2; i < nargin; i++)
287 retval = dispatch_kron (retval, args(i));
288 }
289 else
// Too few arguments: report correct usage; retval is left unassigned.
290 print_usage ();
291
292 return retval;
293 }
294
295
296 /*
297 %!test
298 %! x = ones (2);
299 %! assert (kron (x, x), ones (4));
300
301 %!shared x, y, z
302 %! x = [1, 2];
303 %! y = [-1, -2];
304 %! z = [1, 2, 3, 4; 1, 2, 3, 4; 1, 2, 3, 4];
305 %!assert (kron (1:4, ones (3, 1)), z)
306 %!assert (kron (x, y, z), kron (kron (x, y), z))
307 %!assert (kron (x, y, z), kron (x, kron (y, z)))
308
309 %!assert (kron (diag ([1, 2]), diag ([3, 4])), diag ([3, 4, 6, 8]))
310
311 %% Test for two diag matrices. See the comments above in
312 %% dispatch_kron for this case.
313 %%
314 %!test
315 %! expected = zeros (16, 16);
316 %! expected (1, 11) = 3;
317 %! expected (2, 12) = 4;
318 %! expected (5, 15) = 6;
319 %! expected (6, 16) = 8;
320 %! assert (kron (diag ([1, 2], 2), diag ([3, 4], 2)), expected)
321 */