Mercurial > hg > octave-nkf
comparison liboctave/numeric/floatCHOL.cc @ 20672:5ce959c55cc0
Propagate 'lower' in chol(a, 'lower') to underlying library function.
* chol.cc (chol): Send 'L' parameter correctly when chol is called with 'lower'.
* floatCHOL.cc (init): Propagate 'lower' to underlying library function.
* floatCHOL.h: Modify the method prototypes.
* fMatrix.cc (inverse): Invoke chol with the additional parameter.
* dbleCHOL.cc (init): Propagate 'lower' to underlying library function.
* dbleCHOL.h: Modify the method prototypes.
* dMatrix.cc (inverse): Invoke chol with the additional parameter.
* CmplxCHOL.cc (init): Propagate 'lower' to underlying library function.
* CmplxCHOL.h: Modify the method prototypes.
* CMatrix.cc (inverse): Invoke chol with the additional parameter.
author | PrasannaKumar Muralidharan <prasannatsmkumar@gmail.com> |
---|---|
date | Sun, 24 Aug 2014 19:35:06 +0530 |
parents | a9574e3c6e9e |
children | dcfbf4c1c3c8 |
comparison
equal
deleted
inserted
replaced
20671:0fe7133da8ce | 20672:5ce959c55cc0 |
---|---|
85 const octave_idx_type&, float*); | 85 const octave_idx_type&, float*); |
86 #endif | 86 #endif |
87 } | 87 } |
88 | 88 |
89 octave_idx_type | 89 octave_idx_type |
90 FloatCHOL::init (const FloatMatrix& a, bool calc_cond) | 90 FloatCHOL::init (const FloatMatrix& a, bool upper, bool calc_cond) |
91 { | 91 { |
92 octave_idx_type a_nr = a.rows (); | 92 octave_idx_type a_nr = a.rows (); |
93 octave_idx_type a_nc = a.cols (); | 93 octave_idx_type a_nc = a.cols (); |
94 | 94 |
95 if (a_nr != a_nc) | 95 if (a_nr != a_nc) |
99 } | 99 } |
100 | 100 |
101 octave_idx_type n = a_nc; | 101 octave_idx_type n = a_nc; |
102 octave_idx_type info; | 102 octave_idx_type info; |
103 | 103 |
104 is_upper = upper; | |
105 | |
104 chol_mat.clear (n, n); | 106 chol_mat.clear (n, n); |
105 for (octave_idx_type j = 0; j < n; j++) | 107 if (is_upper) |
106 { | 108 { |
107 for (octave_idx_type i = 0; i <= j; i++) | 109 for (octave_idx_type j = 0; j < n; j++) |
108 chol_mat.xelem (i, j) = a(i, j); | 110 { |
109 for (octave_idx_type i = j+1; i < n; i++) | 111 for (octave_idx_type i = 0; i <= j; i++) |
110 chol_mat.xelem (i, j) = 0.0f; | 112 chol_mat.xelem (i, j) = a(i, j); |
111 } | 113 for (octave_idx_type i = j+1; i < n; i++) |
114 chol_mat.xelem (i, j) = 0.0f; | |
115 } | |
116 } | |
117 else | |
118 { | |
119 for (octave_idx_type j = 0; j < n; j++) | |
120 { | |
121 for (octave_idx_type i = 0; i <= j; i++) | |
122 chol_mat.xelem (i, j) = 0.0f; | |
123 for (octave_idx_type i = j+1; i < n; i++) | |
124 chol_mat.xelem (i, j) = a(i, j); | |
125 } | |
126 } | |
127 | |
112 float *h = chol_mat.fortran_vec (); | 128 float *h = chol_mat.fortran_vec (); |
113 | 129 |
114 // Calculate the norm of the matrix, for later use. | 130 // Calculate the norm of the matrix, for later use. |
115 float anorm = 0; | 131 float anorm = 0; |
116 if (calc_cond) | 132 if (calc_cond) |
117 anorm = xnorm (a, 1); | 133 anorm = xnorm (a, 1); |
118 | 134 |
119 F77_XFCN (spotrf, SPOTRF, (F77_CONST_CHAR_ARG2 ("U", 1), | 135 if (is_upper) |
120 n, h, n, info | 136 { |
121 F77_CHAR_ARG_LEN (1))); | 137 F77_XFCN (spotrf, SPOTRF, (F77_CONST_CHAR_ARG2 ("U", 1), |
138 n, h, n, info | |
139 F77_CHAR_ARG_LEN (1))); | |
140 } | |
141 else | |
142 { | |
143 F77_XFCN (spotrf, SPOTRF, (F77_CONST_CHAR_ARG2 ("L", 1), | |
144 n, h, n, info | |
145 F77_CHAR_ARG_LEN (1))); | |
146 } | |
122 | 147 |
123 xrcond = 0.0; | 148 xrcond = 0.0; |
124 if (info > 0) | 149 if (info > 0) |
125 chol_mat.resize (info - 1, info - 1); | 150 chol_mat.resize (info - 1, info - 1); |
126 else if (calc_cond) | 151 else if (calc_cond) |
130 // Now calculate the condition number for non-singular matrix. | 155 // Now calculate the condition number for non-singular matrix. |
131 Array<float> z (dim_vector (3*n, 1)); | 156 Array<float> z (dim_vector (3*n, 1)); |
132 float *pz = z.fortran_vec (); | 157 float *pz = z.fortran_vec (); |
133 Array<octave_idx_type> iz (dim_vector (n, 1)); | 158 Array<octave_idx_type> iz (dim_vector (n, 1)); |
134 octave_idx_type *piz = iz.fortran_vec (); | 159 octave_idx_type *piz = iz.fortran_vec (); |
135 F77_XFCN (spocon, SPOCON, (F77_CONST_CHAR_ARG2 ("U", 1), n, h, | 160 if (is_upper) |
136 n, anorm, xrcond, pz, piz, spocon_info | 161 { |
137 F77_CHAR_ARG_LEN (1))); | 162 F77_XFCN (spocon, SPOCON, (F77_CONST_CHAR_ARG2 ("U", 1), n, h, |
163 n, anorm, xrcond, pz, piz, spocon_info | |
164 F77_CHAR_ARG_LEN (1))); | |
165 } | |
166 else | |
167 { | |
168 F77_XFCN (spocon, SPOCON, (F77_CONST_CHAR_ARG2 ("L", 1), n, h, | |
169 n, anorm, xrcond, pz, piz, spocon_info | |
170 F77_CHAR_ARG_LEN (1))); | |
171 } | |
172 | |
138 | 173 |
139 if (spocon_info != 0) | 174 if (spocon_info != 0) |
140 info = -1; | 175 info = -1; |
141 } | 176 } |
142 | 177 |
143 return info; | 178 return info; |
144 } | 179 } |
145 | 180 |
146 static FloatMatrix | 181 static FloatMatrix |
147 chol2inv_internal (const FloatMatrix& r) | 182 chol2inv_internal (const FloatMatrix& r, bool is_upper = true) |
148 { | 183 { |
149 FloatMatrix retval; | 184 FloatMatrix retval; |
150 | 185 |
151 octave_idx_type r_nr = r.rows (); | 186 octave_idx_type r_nr = r.rows (); |
152 octave_idx_type r_nc = r.cols (); | 187 octave_idx_type r_nc = r.cols (); |
159 FloatMatrix tmp = r; | 194 FloatMatrix tmp = r; |
160 float *v = tmp.fortran_vec (); | 195 float *v = tmp.fortran_vec (); |
161 | 196 |
162 if (info == 0) | 197 if (info == 0) |
163 { | 198 { |
164 F77_XFCN (spotri, SPOTRI, (F77_CONST_CHAR_ARG2 ("U", 1), n, | 199 if (is_upper) |
165 v, n, info | 200 { |
166 F77_CHAR_ARG_LEN (1))); | 201 F77_XFCN (spotri, SPOTRI, (F77_CONST_CHAR_ARG2 ("U", 1), n, |
202 v, n, info | |
203 F77_CHAR_ARG_LEN (1))); | |
204 } | |
205 else | |
206 { | |
207 F77_XFCN (spotri, SPOTRI, (F77_CONST_CHAR_ARG2 ("L", 1), n, | |
208 v, n, info | |
209 F77_CHAR_ARG_LEN (1))); | |
210 } | |
167 | 211 |
168 // If someone thinks of a more graceful way of doing this (or | 212 // If someone thinks of a more graceful way of doing this (or |
169 // faster for that matter :-)), please let me know! | 213 // faster for that matter :-)), please let me know! |
170 | 214 |
171 if (n > 1) | 215 if (n > 1) |
172 for (octave_idx_type j = 0; j < r_nc; j++) | 216 { |
173 for (octave_idx_type i = j+1; i < r_nr; i++) | 217 if (is_upper) |
174 tmp.xelem (i, j) = tmp.xelem (j, i); | 218 { |
219 for (octave_idx_type j = 0; j < r_nc; j++) | |
220 for (octave_idx_type i = j+1; i < r_nr; i++) | |
221 tmp.xelem (i, j) = tmp.xelem (j, i); | |
222 } | |
223 else | |
224 { | |
225 for (octave_idx_type j = 0; j < r_nc; j++) | |
226 for (octave_idx_type i = j+1; i < r_nr; i++) | |
227 tmp.xelem (j, i) = tmp.xelem (i, j); | |
228 } | |
229 } | |
175 | 230 |
176 retval = tmp; | 231 retval = tmp; |
177 } | 232 } |
178 } | 233 } |
179 else | 234 else |
184 | 239 |
185 // Compute the inverse of a matrix using the Cholesky factorization. | 240 // Compute the inverse of a matrix using the Cholesky factorization. |
186 FloatMatrix | 241 FloatMatrix |
187 FloatCHOL::inverse (void) const | 242 FloatCHOL::inverse (void) const |
188 { | 243 { |
189 return chol2inv_internal (chol_mat); | 244 return chol2inv_internal (chol_mat, is_upper); |
190 } | 245 } |
191 | 246 |
192 void | 247 void |
193 FloatCHOL::set (const FloatMatrix& R) | 248 FloatCHOL::set (const FloatMatrix& R) |
194 { | 249 { |