comparison liboctave/numeric/floatCHOL.cc @ 20672:5ce959c55cc0

Propagate 'lower' in chol(a, 'lower') to the underlying library function. * chol.cc (chol): Pass the 'L' parameter correctly when chol is called with 'lower'. * floatCHOL.cc (init): Propagate 'lower' to the underlying library function. * floatCHOL.h: Update method prototypes. * fMatrix.cc (inverse): Invoke chol with the additional parameter. * dbleCHOL.cc (init): Propagate 'lower' to the underlying library function. * dbleCHOL.h: Update method prototypes. * dMatrix.cc (inverse): Invoke chol with the additional parameter. * CmplxCHOL.cc (init): Propagate 'lower' to the underlying library function. * CmplxCHOL.h: Update method prototypes. * CMatrix.cc (inverse): Invoke chol with the additional parameter.
author PrasannaKumar Muralidharan <prasannatsmkumar@gmail.com>
date Sun, 24 Aug 2014 19:35:06 +0530
parents a9574e3c6e9e
children dcfbf4c1c3c8
comparison
equal deleted inserted replaced
20671:0fe7133da8ce 20672:5ce959c55cc0
85 const octave_idx_type&, float*); 85 const octave_idx_type&, float*);
86 #endif 86 #endif
87 } 87 }
88 88
89 octave_idx_type 89 octave_idx_type
90 FloatCHOL::init (const FloatMatrix& a, bool calc_cond) 90 FloatCHOL::init (const FloatMatrix& a, bool upper, bool calc_cond)
91 { 91 {
92 octave_idx_type a_nr = a.rows (); 92 octave_idx_type a_nr = a.rows ();
93 octave_idx_type a_nc = a.cols (); 93 octave_idx_type a_nc = a.cols ();
94 94
95 if (a_nr != a_nc) 95 if (a_nr != a_nc)
99 } 99 }
100 100
101 octave_idx_type n = a_nc; 101 octave_idx_type n = a_nc;
102 octave_idx_type info; 102 octave_idx_type info;
103 103
104 is_upper = upper;
105
104 chol_mat.clear (n, n); 106 chol_mat.clear (n, n);
105 for (octave_idx_type j = 0; j < n; j++) 107 if (is_upper)
106 { 108 {
107 for (octave_idx_type i = 0; i <= j; i++) 109 for (octave_idx_type j = 0; j < n; j++)
108 chol_mat.xelem (i, j) = a(i, j); 110 {
109 for (octave_idx_type i = j+1; i < n; i++) 111 for (octave_idx_type i = 0; i <= j; i++)
110 chol_mat.xelem (i, j) = 0.0f; 112 chol_mat.xelem (i, j) = a(i, j);
111 } 113 for (octave_idx_type i = j+1; i < n; i++)
114 chol_mat.xelem (i, j) = 0.0f;
115 }
116 }
117 else
118 { // Lower factor requested: keep a's lower triangle INCLUDING the diagonal (mirrors the upper branch, which keeps i <= j); spotrf ("L") reads exactly this triangle, so a zeroed diagonal would make it fail.
119 for (octave_idx_type j = 0; j < n; j++)
120 {
121 for (octave_idx_type i = 0; i < j; i++) // strictly above the diagonal: zero
122 chol_mat.xelem (i, j) = 0.0f;
123 for (octave_idx_type i = j; i < n; i++) // diagonal and below: copy from a
124 chol_mat.xelem (i, j) = a(i, j);
125 }
126 }
127
112 float *h = chol_mat.fortran_vec (); 128 float *h = chol_mat.fortran_vec ();
113 129
114 // Calculate the norm of the matrix, for later use. 130 // Calculate the norm of the matrix, for later use.
115 float anorm = 0; 131 float anorm = 0;
116 if (calc_cond) 132 if (calc_cond)
117 anorm = xnorm (a, 1); 133 anorm = xnorm (a, 1);
118 134
119 F77_XFCN (spotrf, SPOTRF, (F77_CONST_CHAR_ARG2 ("U", 1), 135 if (is_upper)
120 n, h, n, info 136 {
121 F77_CHAR_ARG_LEN (1))); 137 F77_XFCN (spotrf, SPOTRF, (F77_CONST_CHAR_ARG2 ("U", 1),
138 n, h, n, info
139 F77_CHAR_ARG_LEN (1)));
140 }
141 else
142 {
143 F77_XFCN (spotrf, SPOTRF, (F77_CONST_CHAR_ARG2 ("L", 1),
144 n, h, n, info
145 F77_CHAR_ARG_LEN (1)));
146 }
122 147
123 xrcond = 0.0; 148 xrcond = 0.0;
124 if (info > 0) 149 if (info > 0)
125 chol_mat.resize (info - 1, info - 1); 150 chol_mat.resize (info - 1, info - 1);
126 else if (calc_cond) 151 else if (calc_cond)
130 // Now calculate the condition number for non-singular matrix. 155 // Now calculate the condition number for non-singular matrix.
131 Array<float> z (dim_vector (3*n, 1)); 156 Array<float> z (dim_vector (3*n, 1));
132 float *pz = z.fortran_vec (); 157 float *pz = z.fortran_vec ();
133 Array<octave_idx_type> iz (dim_vector (n, 1)); 158 Array<octave_idx_type> iz (dim_vector (n, 1));
134 octave_idx_type *piz = iz.fortran_vec (); 159 octave_idx_type *piz = iz.fortran_vec ();
135 F77_XFCN (spocon, SPOCON, (F77_CONST_CHAR_ARG2 ("U", 1), n, h, 160 if (is_upper)
136 n, anorm, xrcond, pz, piz, spocon_info 161 {
137 F77_CHAR_ARG_LEN (1))); 162 F77_XFCN (spocon, SPOCON, (F77_CONST_CHAR_ARG2 ("U", 1), n, h,
163 n, anorm, xrcond, pz, piz, spocon_info
164 F77_CHAR_ARG_LEN (1)));
165 }
166 else
167 {
168 F77_XFCN (spocon, SPOCON, (F77_CONST_CHAR_ARG2 ("L", 1), n, h,
169 n, anorm, xrcond, pz, piz, spocon_info
170 F77_CHAR_ARG_LEN (1)));
171 }
172
138 173
139 if (spocon_info != 0) 174 if (spocon_info != 0)
140 info = -1; 175 info = -1;
141 } 176 }
142 177
143 return info; 178 return info;
144 } 179 }
145 180
146 static FloatMatrix 181 static FloatMatrix
147 chol2inv_internal (const FloatMatrix& r) 182 chol2inv_internal (const FloatMatrix& r, bool is_upper = true)
148 { 183 {
149 FloatMatrix retval; 184 FloatMatrix retval;
150 185
151 octave_idx_type r_nr = r.rows (); 186 octave_idx_type r_nr = r.rows ();
152 octave_idx_type r_nc = r.cols (); 187 octave_idx_type r_nc = r.cols ();
159 FloatMatrix tmp = r; 194 FloatMatrix tmp = r;
160 float *v = tmp.fortran_vec (); 195 float *v = tmp.fortran_vec ();
161 196
162 if (info == 0) 197 if (info == 0)
163 { 198 {
164 F77_XFCN (spotri, SPOTRI, (F77_CONST_CHAR_ARG2 ("U", 1), n, 199 if (is_upper)
165 v, n, info 200 {
166 F77_CHAR_ARG_LEN (1))); 201 F77_XFCN (spotri, SPOTRI, (F77_CONST_CHAR_ARG2 ("U", 1), n,
202 v, n, info
203 F77_CHAR_ARG_LEN (1)));
204 }
205 else
206 {
207 F77_XFCN (spotri, SPOTRI, (F77_CONST_CHAR_ARG2 ("L", 1), n,
208 v, n, info
209 F77_CHAR_ARG_LEN (1)));
210 }
167 211
168 // If someone thinks of a more graceful way of doing this (or 212 // If someone thinks of a more graceful way of doing this (or
169 // faster for that matter :-)), please let me know! 213 // faster for that matter :-)), please let me know!
170 214
171 if (n > 1) 215 if (n > 1)
172 for (octave_idx_type j = 0; j < r_nc; j++) 216 {
173 for (octave_idx_type i = j+1; i < r_nr; i++) 217 if (is_upper)
174 tmp.xelem (i, j) = tmp.xelem (j, i); 218 {
219 for (octave_idx_type j = 0; j < r_nc; j++)
220 for (octave_idx_type i = j+1; i < r_nr; i++)
221 tmp.xelem (i, j) = tmp.xelem (j, i);
222 }
223 else
224 {
225 for (octave_idx_type j = 0; j < r_nc; j++)
226 for (octave_idx_type i = j+1; i < r_nr; i++)
227 tmp.xelem (j, i) = tmp.xelem (i, j);
228 }
229 }
175 230
176 retval = tmp; 231 retval = tmp;
177 } 232 }
178 } 233 }
179 else 234 else
184 239
185 // Compute the inverse of a matrix using the Cholesky factorization. 240 // Compute the inverse of the factored matrix from its stored Cholesky factor; is_upper tells chol2inv_internal which triangle of chol_mat holds the factor ("U" vs "L" for spotri).
186 FloatMatrix 241 FloatMatrix
187 FloatCHOL::inverse (void) const 242 FloatCHOL::inverse (void) const
188 { 243 {
189 return chol2inv_internal (chol_mat); 244 return chol2inv_internal (chol_mat, is_upper);
190 } 245 }
191 246
192 void 247 void
193 FloatCHOL::set (const FloatMatrix& R) 248 FloatCHOL::set (const FloatMatrix& R)
194 { 249 {