Mercurial > hg > octave-lyh
comparison src/data.cc @ 7620:36594d5bbe13
Move diag function into the octave_value class
author | David Bateman <dbateman@free.fr> |
---|---|
date | Fri, 21 Mar 2008 00:08:24 +0100 |
parents | 3209a584e1ac |
children | 431f3788f5c4 |
comparison (legend: equal, deleted, inserted, replaced)
7619:56012914972a | 7620:36594d5bbe13 |
---|---|
881 @end deftypefn") | 881 @end deftypefn") |
882 { | 882 { |
883 DATA_REDUCTION (cumsum); | 883 DATA_REDUCTION (cumsum); |
884 } | 884 } |
885 | 885 |
886 template <class T> | |
887 static octave_value | |
888 make_diag (const T& v, octave_idx_type k) | |
889 { | |
890 octave_value retval; | |
891 dim_vector dv = v.dims (); | |
892 octave_idx_type nd = dv.length (); | |
893 | |
894 if (nd > 2) | |
895 error ("diag: expecting 2-dimensional matrix"); | |
896 else | |
897 { | |
898 octave_idx_type nr = dv (0); | |
899 octave_idx_type nc = dv (1); | |
900 | |
901 if (nr == 0 || nc == 0) | |
902 retval = T (); | |
903 else if (nr != 1 && nc != 1) | |
904 retval = v.diag (k); | |
905 else | |
906 { | |
907 octave_idx_type roff = 0; | |
908 octave_idx_type coff = 0; | |
909 if (k > 0) | |
910 { | |
911 roff = 0; | |
912 coff = k; | |
913 } | |
914 else if (k < 0) | |
915 { | |
916 roff = -k; | |
917 coff = 0; | |
918 } | |
919 | |
920 if (nr == 1) | |
921 { | |
922 octave_idx_type n = nc + std::abs (k); | |
923 T m (dim_vector (n, n), T::resize_fill_value ()); | |
924 | |
925 for (octave_idx_type i = 0; i < nc; i++) | |
926 m (i+roff, i+coff) = v (0, i); | |
927 retval = m; | |
928 } | |
929 else | |
930 { | |
931 octave_idx_type n = nr + std::abs (k); | |
932 T m (dim_vector (n, n), T::resize_fill_value ()); | |
933 for (octave_idx_type i = 0; i < nr; i++) | |
934 m (i+roff, i+coff) = v (i, 0); | |
935 retval = m; | |
936 } | |
937 } | |
938 } | |
939 | |
940 return retval; | |
941 } | |
942 | |
943 #if !defined (CXX_NEW_FRIEND_TEMPLATE_DECL) | |
944 static octave_value | |
945 make_diag (const Matrix& v, octave_idx_type k); | |
946 | |
947 static octave_value | |
948 make_diag (const ComplexMatrix& v, octave_idx_type k); | |
949 | |
950 static octave_value | |
951 make_diag (const charMatrix& v, octave_idx_type k); | |
952 | |
953 static octave_value | |
954 make_diag (const boolMatrix& v, octave_idx_type k); | |
955 | |
956 static octave_value | |
957 make_diag (const int8NDArray& v, octave_idx_type k); | |
958 | |
959 static octave_value | |
960 make_diag (const int16NDArray& v, octave_idx_type k); | |
961 | |
962 static octave_value | |
963 make_diag (const int32NDArray& v, octave_idx_type k); | |
964 | |
965 static octave_value | |
966 make_diag (const int64NDArray& v, octave_idx_type k); | |
967 | |
968 static octave_value | |
969 make_diag (const uint8NDArray& v, octave_idx_type k); | |
970 | |
971 static octave_value | |
972 make_diag (const uint16NDArray& v, octave_idx_type k); | |
973 | |
974 static octave_value | |
975 make_diag (const uint32NDArray& v, octave_idx_type k); | |
976 | |
977 static octave_value | |
978 make_diag (const uint64NDArray& v, octave_idx_type k); | |
979 | |
980 static octave_value | |
981 make_diag (const Cell& v, octave_idx_type k); | |
982 #endif | |
983 | |
984 template <class T> | |
985 static octave_value | |
986 make_spdiag (const T& v, octave_idx_type k) | |
987 { | |
988 octave_value retval; | |
989 dim_vector dv = v.dims (); | |
990 octave_idx_type nr = dv (0); | |
991 octave_idx_type nc = dv (1); | |
992 | |
993 if (nr == 0 || nc == 0) | |
994 retval = T (); | |
995 else if (nr != 1 && nc != 1) | |
996 retval = v.diag (k); | |
997 else | |
998 { | |
999 octave_idx_type roff = 0; | |
1000 octave_idx_type coff = 0; | |
1001 if (k > 0) | |
1002 { | |
1003 roff = 0; | |
1004 coff = k; | |
1005 } | |
1006 else if (k < 0) | |
1007 { | |
1008 roff = -k; | |
1009 coff = 0; | |
1010 } | |
1011 | |
1012 if (nr == 1) | |
1013 { | |
1014 octave_idx_type n = nc + std::abs (k); | |
1015 octave_idx_type nz = v.nzmax (); | |
1016 T r (n, n, nz); | |
1017 for (octave_idx_type i = 0; i < coff+1; i++) | |
1018 r.xcidx (i) = 0; | |
1019 for (octave_idx_type j = 0; j < nc; j++) | |
1020 { | |
1021 for (octave_idx_type i = v.cidx(j); i < v.cidx(j+1); i++) | |
1022 { | |
1023 r.xdata (i) = v.data (i); | |
1024 r.xridx (i) = j + roff; | |
1025 } | |
1026 r.xcidx (j+coff+1) = v.cidx(j+1); | |
1027 } | |
1028 for (octave_idx_type i = nc+coff+1; i < n+1; i++) | |
1029 r.xcidx (i) = nz; | |
1030 retval = r; | |
1031 } | |
1032 else | |
1033 { | |
1034 octave_idx_type n = nr + std::abs (k); | |
1035 octave_idx_type nz = v.nzmax (); | |
1036 octave_idx_type ii = 0; | |
1037 octave_idx_type ir = v.ridx(0); | |
1038 T r (n, n, nz); | |
1039 for (octave_idx_type i = 0; i < coff+1; i++) | |
1040 r.xcidx (i) = 0; | |
1041 for (octave_idx_type i = 0; i < nr; i++) | |
1042 { | |
1043 if (ir == i) | |
1044 { | |
1045 r.xdata (ii) = v.data (ii); | |
1046 r.xridx (ii++) = ir + roff; | |
1047 if (ii != nz) | |
1048 ir = v.ridx (ii); | |
1049 } | |
1050 r.xcidx (i+coff+1) = ii; | |
1051 } | |
1052 for (octave_idx_type i = nr+coff+1; i < n+1; i++) | |
1053 r.xcidx (i) = nz; | |
1054 retval = r; | |
1055 } | |
1056 } | |
1057 | |
1058 return retval; | |
1059 } | |
1060 | |
1061 #if !defined (CXX_NEW_FRIEND_TEMPLATE_DECL) | |
1062 static octave_value | |
1063 make_spdiag (const SparseMatrix& v, octave_idx_type k); | |
1064 | |
1065 static octave_value | |
1066 make_spdiag (const SparseComplexMatrix& v, octave_idx_type k); | |
1067 | |
1068 static octave_value | |
1069 make_spdiag (const SparseBoolMatrix& v, octave_idx_type k); | |
1070 #endif | |
1071 | |
1072 static octave_value | |
1073 make_diag (const octave_value& a, octave_idx_type k) | |
1074 { | |
1075 octave_value retval; | |
1076 std::string result_type = a.class_name (); | |
1077 | |
1078 if (result_type == "double") | |
1079 { | |
1080 if (a.is_sparse_type ()) | |
1081 { | |
1082 if (a.is_real_type ()) | |
1083 { | |
1084 SparseMatrix m = a.sparse_matrix_value (); | |
1085 if (!error_state) | |
1086 retval = make_spdiag (m, k); | |
1087 } | |
1088 else | |
1089 { | |
1090 SparseComplexMatrix m = a.sparse_complex_matrix_value (); | |
1091 if (!error_state) | |
1092 retval = make_spdiag (m, k); | |
1093 } | |
1094 } | |
1095 else | |
1096 { | |
1097 if (a.is_real_type ()) | |
1098 { | |
1099 Matrix m = a.matrix_value (); | |
1100 if (!error_state) | |
1101 retval = make_diag (m, k); | |
1102 } | |
1103 else | |
1104 { | |
1105 ComplexMatrix m = a.complex_matrix_value (); | |
1106 if (!error_state) | |
1107 retval = make_diag (m, k); | |
1108 } | |
1109 } | |
1110 } | |
1111 #if 0 | |
1112 else if (result_type == "single") | |
1113 retval = make_diag (a.single_array_value (), k); | |
1114 #endif | |
1115 else if (result_type == "char") | |
1116 { | |
1117 charMatrix m = a.char_matrix_value (); | |
1118 if (!error_state) | |
1119 { | |
1120 retval = make_diag (m, k); | |
1121 if (a.is_sq_string ()) | |
1122 retval = octave_value (retval.char_array_value (), true, '\''); | |
1123 } | |
1124 } | |
1125 else if (result_type == "logical") | |
1126 { | |
1127 if (a.is_sparse_type ()) | |
1128 { | |
1129 SparseBoolMatrix m = a.sparse_bool_matrix_value (); | |
1130 if (!error_state) | |
1131 retval = make_spdiag (m, k); | |
1132 } | |
1133 else | |
1134 { | |
1135 boolMatrix m = a.bool_matrix_value (); | |
1136 if (!error_state) | |
1137 retval = make_diag (m, k); | |
1138 } | |
1139 } | |
1140 else if (result_type == "int8") | |
1141 retval = make_diag (a.int8_array_value (), k); | |
1142 else if (result_type == "int16") | |
1143 retval = make_diag (a.int16_array_value (), k); | |
1144 else if (result_type == "int32") | |
1145 retval = make_diag (a.int32_array_value (), k); | |
1146 else if (result_type == "int64") | |
1147 retval = make_diag (a.int64_array_value (), k); | |
1148 else if (result_type == "uint8") | |
1149 retval = make_diag (a.uint8_array_value (), k); | |
1150 else if (result_type == "uint16") | |
1151 retval = make_diag (a.uint16_array_value (), k); | |
1152 else if (result_type == "uint32") | |
1153 retval = make_diag (a.uint32_array_value (), k); | |
1154 else if (result_type == "uint64") | |
1155 retval = make_diag (a.uint64_array_value (), k); | |
1156 else if (result_type == "cell") | |
1157 retval = make_diag (a.cell_value (), k); | |
1158 else | |
1159 gripe_wrong_type_arg ("diag", a); | |
1160 | |
1161 return retval; | |
1162 } | |
1163 | |
1164 static octave_value | |
1165 make_diag (const octave_value& arg) | |
1166 { | |
1167 return make_diag (arg, 0); | |
1168 } | |
1169 | |
1170 static octave_value | |
1171 make_diag (const octave_value& a, const octave_value& b) | |
1172 { | |
1173 octave_value retval; | |
1174 | |
1175 octave_idx_type k = b.int_value (); | |
1176 | |
1177 if (error_state) | |
1178 error ("diag: invalid second argument"); | |
1179 else | |
1180 retval = make_diag (a, k); | |
1181 | |
1182 return retval; | |
1183 } | |
1184 | |
1185 DEFUN (diag, args, , | 886 DEFUN (diag, args, , |
1186 "-*- texinfo -*-\n\ | 887 "-*- texinfo -*-\n\ |
1187 @deftypefn {Built-in Function} {} diag (@var{v}, @var{k})\n\ | 888 @deftypefn {Built-in Function} {} diag (@var{v}, @var{k})\n\ |
1188 Return a diagonal matrix with vector @var{v} on diagonal @var{k}. The\n\ | 889 Return a diagonal matrix with vector @var{v} on diagonal @var{k}. The\n\ |
1189 second argument is optional. If it is positive, the vector is placed on\n\ | 890 second argument is optional. If it is positive, the vector is placed on\n\ |
1209 octave_value retval; | 910 octave_value retval; |
1210 | 911 |
1211 int nargin = args.length (); | 912 int nargin = args.length (); |
1212 | 913 |
1213 if (nargin == 1 && args(0).is_defined ()) | 914 if (nargin == 1 && args(0).is_defined ()) |
1214 retval = make_diag (args(0)); | 915 retval = args(0).diag(); |
1215 else if (nargin == 2 && args(0).is_defined () && args(1).is_defined ()) | 916 else if (nargin == 2 && args(0).is_defined () && args(1).is_defined ()) |
1216 retval = make_diag (args(0), args(1)); | 917 { |
918 octave_idx_type k = args(1).int_value (); | |
919 | |
920 if (error_state) | |
921 error ("diag: invalid second argument"); | |
922 else | |
923 retval = args(0).diag(k); | |
924 } | |
1217 else | 925 else |
1218 print_usage (); | 926 print_usage (); |
1219 | 927 |
1220 return retval; | 928 return retval; |
1221 } | 929 } |