comparison src/data.cc @ 7620:36594d5bbe13

Move diag function into the octave_value class
author David Bateman <dbateman@free.fr>
date Fri, 21 Mar 2008 00:08:24 +0100
parents 3209a584e1ac
children 431f3788f5c4
comparison
equal deleted inserted replaced
7619:56012914972a 7620:36594d5bbe13
881 @end deftypefn") 881 @end deftypefn")
882 { 882 {
883 DATA_REDUCTION (cumsum); 883 DATA_REDUCTION (cumsum);
884 } 884 }
885 885
886 template <class T>
887 static octave_value
888 make_diag (const T& v, octave_idx_type k)
889 {
890 octave_value retval;
891 dim_vector dv = v.dims ();
892 octave_idx_type nd = dv.length ();
893
894 if (nd > 2)
895 error ("diag: expecting 2-dimensional matrix");
896 else
897 {
898 octave_idx_type nr = dv (0);
899 octave_idx_type nc = dv (1);
900
901 if (nr == 0 || nc == 0)
902 retval = T ();
903 else if (nr != 1 && nc != 1)
904 retval = v.diag (k);
905 else
906 {
907 octave_idx_type roff = 0;
908 octave_idx_type coff = 0;
909 if (k > 0)
910 {
911 roff = 0;
912 coff = k;
913 }
914 else if (k < 0)
915 {
916 roff = -k;
917 coff = 0;
918 }
919
920 if (nr == 1)
921 {
922 octave_idx_type n = nc + std::abs (k);
923 T m (dim_vector (n, n), T::resize_fill_value ());
924
925 for (octave_idx_type i = 0; i < nc; i++)
926 m (i+roff, i+coff) = v (0, i);
927 retval = m;
928 }
929 else
930 {
931 octave_idx_type n = nr + std::abs (k);
932 T m (dim_vector (n, n), T::resize_fill_value ());
933 for (octave_idx_type i = 0; i < nr; i++)
934 m (i+roff, i+coff) = v (i, 0);
935 retval = m;
936 }
937 }
938 }
939
940 return retval;
941 }
942
943 #if !defined (CXX_NEW_FRIEND_TEMPLATE_DECL)
944 static octave_value
945 make_diag (const Matrix& v, octave_idx_type k);
946
947 static octave_value
948 make_diag (const ComplexMatrix& v, octave_idx_type k);
949
950 static octave_value
951 make_diag (const charMatrix& v, octave_idx_type k);
952
953 static octave_value
954 make_diag (const boolMatrix& v, octave_idx_type k);
955
956 static octave_value
957 make_diag (const int8NDArray& v, octave_idx_type k);
958
959 static octave_value
960 make_diag (const int16NDArray& v, octave_idx_type k);
961
962 static octave_value
963 make_diag (const int32NDArray& v, octave_idx_type k);
964
965 static octave_value
966 make_diag (const int64NDArray& v, octave_idx_type k);
967
968 static octave_value
969 make_diag (const uint8NDArray& v, octave_idx_type k);
970
971 static octave_value
972 make_diag (const uint16NDArray& v, octave_idx_type k);
973
974 static octave_value
975 make_diag (const uint32NDArray& v, octave_idx_type k);
976
977 static octave_value
978 make_diag (const uint64NDArray& v, octave_idx_type k);
979
980 static octave_value
981 make_diag (const Cell& v, octave_idx_type k);
982 #endif
983
984 template <class T>
985 static octave_value
986 make_spdiag (const T& v, octave_idx_type k)
987 {
988 octave_value retval;
989 dim_vector dv = v.dims ();
990 octave_idx_type nr = dv (0);
991 octave_idx_type nc = dv (1);
992
993 if (nr == 0 || nc == 0)
994 retval = T ();
995 else if (nr != 1 && nc != 1)
996 retval = v.diag (k);
997 else
998 {
999 octave_idx_type roff = 0;
1000 octave_idx_type coff = 0;
1001 if (k > 0)
1002 {
1003 roff = 0;
1004 coff = k;
1005 }
1006 else if (k < 0)
1007 {
1008 roff = -k;
1009 coff = 0;
1010 }
1011
1012 if (nr == 1)
1013 {
1014 octave_idx_type n = nc + std::abs (k);
1015 octave_idx_type nz = v.nzmax ();
1016 T r (n, n, nz);
1017 for (octave_idx_type i = 0; i < coff+1; i++)
1018 r.xcidx (i) = 0;
1019 for (octave_idx_type j = 0; j < nc; j++)
1020 {
1021 for (octave_idx_type i = v.cidx(j); i < v.cidx(j+1); i++)
1022 {
1023 r.xdata (i) = v.data (i);
1024 r.xridx (i) = j + roff;
1025 }
1026 r.xcidx (j+coff+1) = v.cidx(j+1);
1027 }
1028 for (octave_idx_type i = nc+coff+1; i < n+1; i++)
1029 r.xcidx (i) = nz;
1030 retval = r;
1031 }
1032 else
1033 {
1034 octave_idx_type n = nr + std::abs (k);
1035 octave_idx_type nz = v.nzmax ();
1036 octave_idx_type ii = 0;
1037 octave_idx_type ir = v.ridx(0);
1038 T r (n, n, nz);
1039 for (octave_idx_type i = 0; i < coff+1; i++)
1040 r.xcidx (i) = 0;
1041 for (octave_idx_type i = 0; i < nr; i++)
1042 {
1043 if (ir == i)
1044 {
1045 r.xdata (ii) = v.data (ii);
1046 r.xridx (ii++) = ir + roff;
1047 if (ii != nz)
1048 ir = v.ridx (ii);
1049 }
1050 r.xcidx (i+coff+1) = ii;
1051 }
1052 for (octave_idx_type i = nr+coff+1; i < n+1; i++)
1053 r.xcidx (i) = nz;
1054 retval = r;
1055 }
1056 }
1057
1058 return retval;
1059 }
1060
1061 #if !defined (CXX_NEW_FRIEND_TEMPLATE_DECL)
1062 static octave_value
1063 make_spdiag (const SparseMatrix& v, octave_idx_type k);
1064
1065 static octave_value
1066 make_spdiag (const SparseComplexMatrix& v, octave_idx_type k);
1067
1068 static octave_value
1069 make_spdiag (const SparseBoolMatrix& v, octave_idx_type k);
1070 #endif
1071
1072 static octave_value
1073 make_diag (const octave_value& a, octave_idx_type k)
1074 {
1075 octave_value retval;
1076 std::string result_type = a.class_name ();
1077
1078 if (result_type == "double")
1079 {
1080 if (a.is_sparse_type ())
1081 {
1082 if (a.is_real_type ())
1083 {
1084 SparseMatrix m = a.sparse_matrix_value ();
1085 if (!error_state)
1086 retval = make_spdiag (m, k);
1087 }
1088 else
1089 {
1090 SparseComplexMatrix m = a.sparse_complex_matrix_value ();
1091 if (!error_state)
1092 retval = make_spdiag (m, k);
1093 }
1094 }
1095 else
1096 {
1097 if (a.is_real_type ())
1098 {
1099 Matrix m = a.matrix_value ();
1100 if (!error_state)
1101 retval = make_diag (m, k);
1102 }
1103 else
1104 {
1105 ComplexMatrix m = a.complex_matrix_value ();
1106 if (!error_state)
1107 retval = make_diag (m, k);
1108 }
1109 }
1110 }
1111 #if 0
1112 else if (result_type == "single")
1113 retval = make_diag (a.single_array_value (), k);
1114 #endif
1115 else if (result_type == "char")
1116 {
1117 charMatrix m = a.char_matrix_value ();
1118 if (!error_state)
1119 {
1120 retval = make_diag (m, k);
1121 if (a.is_sq_string ())
1122 retval = octave_value (retval.char_array_value (), true, '\'');
1123 }
1124 }
1125 else if (result_type == "logical")
1126 {
1127 if (a.is_sparse_type ())
1128 {
1129 SparseBoolMatrix m = a.sparse_bool_matrix_value ();
1130 if (!error_state)
1131 retval = make_spdiag (m, k);
1132 }
1133 else
1134 {
1135 boolMatrix m = a.bool_matrix_value ();
1136 if (!error_state)
1137 retval = make_diag (m, k);
1138 }
1139 }
1140 else if (result_type == "int8")
1141 retval = make_diag (a.int8_array_value (), k);
1142 else if (result_type == "int16")
1143 retval = make_diag (a.int16_array_value (), k);
1144 else if (result_type == "int32")
1145 retval = make_diag (a.int32_array_value (), k);
1146 else if (result_type == "int64")
1147 retval = make_diag (a.int64_array_value (), k);
1148 else if (result_type == "uint8")
1149 retval = make_diag (a.uint8_array_value (), k);
1150 else if (result_type == "uint16")
1151 retval = make_diag (a.uint16_array_value (), k);
1152 else if (result_type == "uint32")
1153 retval = make_diag (a.uint32_array_value (), k);
1154 else if (result_type == "uint64")
1155 retval = make_diag (a.uint64_array_value (), k);
1156 else if (result_type == "cell")
1157 retval = make_diag (a.cell_value (), k);
1158 else
1159 gripe_wrong_type_arg ("diag", a);
1160
1161 return retval;
1162 }
1163
1164 static octave_value
1165 make_diag (const octave_value& arg)
1166 {
1167 return make_diag (arg, 0);
1168 }
1169
1170 static octave_value
1171 make_diag (const octave_value& a, const octave_value& b)
1172 {
1173 octave_value retval;
1174
1175 octave_idx_type k = b.int_value ();
1176
1177 if (error_state)
1178 error ("diag: invalid second argument");
1179 else
1180 retval = make_diag (a, k);
1181
1182 return retval;
1183 }
1184
1185 DEFUN (diag, args, , 886 DEFUN (diag, args, ,
1186 "-*- texinfo -*-\n\ 887 "-*- texinfo -*-\n\
1187 @deftypefn {Built-in Function} {} diag (@var{v}, @var{k})\n\ 888 @deftypefn {Built-in Function} {} diag (@var{v}, @var{k})\n\
1188 Return a diagonal matrix with vector @var{v} on diagonal @var{k}. The\n\ 889 Return a diagonal matrix with vector @var{v} on diagonal @var{k}. The\n\
1189 second argument is optional. If it is positive, the vector is placed on\n\ 890 second argument is optional. If it is positive, the vector is placed on\n\
1209 octave_value retval; 910 octave_value retval;
1210 911
1211 int nargin = args.length (); 912 int nargin = args.length ();
1212 913
1213 if (nargin == 1 && args(0).is_defined ()) 914 if (nargin == 1 && args(0).is_defined ())
1214 retval = make_diag (args(0)); 915 retval = args(0).diag();
1215 else if (nargin == 2 && args(0).is_defined () && args(1).is_defined ()) 916 else if (nargin == 2 && args(0).is_defined () && args(1).is_defined ())
1216 retval = make_diag (args(0), args(1)); 917 {
918 octave_idx_type k = args(1).int_value ();
919
920 if (error_state)
921 error ("diag: invalid second argument");
922 else
923 retval = args(0).diag(k);
924 }
1217 else 925 else
1218 print_usage (); 926 print_usage ();
1219 927
1220 return retval; 928 return retval;
1221 } 929 }