代码可系统性地完成混淆矩阵的生成、可视化和性能评估。实际应用中建议结合ROC曲线、PR曲线等补充评价指标。
一、基础混淆矩阵生成
% 示例数据(真实标签与预测标签)
actual = [1 1 0 1 0 0 1 0 0 1]; % 真实类别
predicted = [1 0 0 1 0 0 1 1 0 1]; % 预测类别
% 生成混淆矩阵(自动识别类别标签)
C = confusionmat(actual, predicted);
disp('基础混淆矩阵:');
disp(C);
输出示例:
基础混淆矩阵:
2 0 0
0 3 1
0 1 2
说明:矩阵行表示真实类别,列表示预测类别。对角线元素为正确分类数。
二、指定类别顺序的混淆矩阵
% 自定义类别顺序(如按['Negative','Positive']排序)
labels = {'Negative','Positive'};
C_ordered = confusionmat(actual, predicted, 'Order', labels);
% 显示带标签的矩阵
disp('自定义顺序混淆矩阵:');
disp(C_ordered);
关键参数:
'Order'
强制指定行/列的类别顺序
自动过滤缺失值(如NaN)
三、多分类问题处理(One-Hot编码)
% 假设原始标签为分类变量
actual_categorical = categorical(actual, [0 1], {'Class0','Class1'});
predicted_categorical = categorical(predicted, [0 1], {'Class0','Class1'});
% 转换为One-Hot编码
actual_onehot = onehotencode(actual_categorical, 2);
predicted_onehot = onehotencode(predicted_categorical, 2);
% 生成混淆矩阵
C_multi = confusionmat(actual_categorical, predicted_categorical);
disp('多分类混淆矩阵:');
disp(C_multi);
扩展功能:
支持任意类别数量
自动处理非数值型标签
四、混淆矩阵可视化
% 基础图形绘制
figure;
cm = confusionchart(actual, predicted);
title('混淆矩阵热力图');
xlabel('预测类别');
ylabel('真实类别');
% 高级定制化
figure;
cm = confusionchart(C_multi, {'Class0','Class1','Class2'});
cm.Title = '多分类混淆矩阵';
cm.XLabel = '预测类别';
cm.YLabel = '真实类别';
cm.ColorbarVisible = 'on';
特性:
自动计算并显示精确率、召回率等指标
支持颜色映射调整(
colormap
参数)
五、性能指标计算
% 提取混淆矩阵元素
TP = diag(C); % 真正例
FN = sum(C,2) - TP; % 假反例
FP = sum(C,1) - TP; % 假正例
TN = sum(C(:)) - (TP + FN + FP); % 真反例
% 计算评估指标
accuracy = sum(diag(C))/sum(C(:)); % 准确率
precision = TP ./ (TP + FP); % 精确率
recall = TP ./ (TP + FN); % 召回率
f1_score = 2*(precision.*recall)./(precision + recall); % F1分数
% 显示结果
metrics = struct(...
'Accuracy', accuracy,...
'Precision', precision,...
'Recall', recall,...
'F1_Score', f1_score...
);
disp('分类性能指标:');
disp(metrics);
公式说明:
准确率 = (TP+TN)/(TP+TN+FP+FN)
精确率 = TP/(TP+FP)
召回率 = TP/(TP+FN)
六、完整应用案例
% 加载示例数据集
load fisheriris
X = meas(:,3:4); % 使用花瓣长度和宽度
Y = species;
% 划分训练集/测试集
cv = cvpartition(Y,'HoldOut',0.3);
idxTrain = training(cv);
idxTest = test(cv);
% 训练分类器
mdl = fitcknn(X(idxTrain,:), Y(idxTrain));
% 预测与评估
Y_pred = predict(mdl, X(idxTest,:));
C = confusionmat(Y(idxTest), Y_pred);
% 可视化与指标
figure;
confusionchart(C, {'setosa', 'versicolor', 'virginica'});
title('Iris分类混淆矩阵');
% 输出指标
accuracy = sum(diag(C)) / sum(C(:));
disp(['分类准确率: ', num2str(accuracy * 100), '%']);
参考代码 混淆矩阵的matlab代码
www.youwenfan.com/contentcsl/78004.html
七、进阶技巧
处理不平衡数据
% 加权混淆矩阵(调整类别权重)
classWeights = [0.5 1 2]; % 少数类权重更高
C_weighted = confusionmat(Y(idxTest), Y_pred, classWeights);
批量处理多模型对比
models = {@fitctree, @fitcsvm, @fitcknn};
figure;
for i = 1 : numel(models)
mdl = models{i}(X(idxTrain,:), Y(idxTrain));
Y_pred = predict(mdl, X(idxTest,:));
subplot(1, numel(models), i);
confusionchart(Y(idxTest), Y_pred);
title(sprintf('%s分类结果', class2str(models{i})));
end
混淆矩阵统计分析
% 计算每个类别的错误率
errorRates = 1 - diag(C) / sum(C, 2);
[~, maxIdx] = max(errorRates);
fprintf('错误率最高的类别: %s (%.2f%%)\n', labels{maxIdx}, errorRates(maxIdx) * 100);
'Order'