全部版块 我的主页
论坛 数据科学与人工智能 数据分析与数据科学 MATLAB等数学软件专版
172 0
2025-11-17

MATLAB 实现 NRBO Transformer BiGRU + SHAP 分类预测的完整示例

(适用于 R2023b 及以上版本,Deep Learning Toolbox、Statistics & Machine Learning Toolbox 必须已安装)

1. 环境准备与依赖

% 必要工具箱
%   Deep Learning Toolbox
%   Statistics and Machine Learning Toolbox(用于 SHAP)
%   Global Optimization Toolbox(可选,用于 NRBO 的二阶优化)

2. 数据读取与预处理

下面以

data.csv
为例,假设 最后一列为类别标签 ,其余列为数值特征。

% 读取 CSV(或 Excel)文件
tbl = readtable('data.csv');
% 特征矩阵 X、标签向量 Y
X = tbl{:,1:end-1};
Y = categorical(tbl{:,end});          % 转为分类变量
% 标准化(均值 0、方差 1)
mu    = mean(X);
sigma = std(X);
Xnorm = (X - mu) ./ sigma;
% 划分训练 / 测试(70% / 30%)
cv       = cvpartition(numel(Y),'HoldOut',0.3);
XTrain   = Xnorm(training(cv),:);
YTrain   = Y(training(cv));
XTest    = Xnorm(test(cv),:);
YTest    = Y(test(cv));
% 若为序列数据(如时间序列),需要转为 cell
% XTrain = num2cell(XTrain,2);
% XTest  = num2cell(XTest,2);

数据预处理思路参考 CSDN 文章《NRBO TCN Transformer BiGRU 时序预测的完整思路与 MATLAB 示例代码》[[1]]。

3. NRBO 超参数搜索(二阶优化)

MATLAB 并未直接提供 NRBO 接口,这里使用

fminunc
的 trust region (二阶)算法模拟 NRBO,对 学习率、BiGRU 隐藏单元数、Transformer 头数 进行全局搜索。

% 待优化的超参数向量
%   p(1) – 学习率
%   p(2) – BiGRU 隐藏单元数
%   p(3) – Transformer 多头数
objFun = @(p) valLossNRBO(p, XTrain, YTrain, XTest, YTest);   % 见下文
% 初始猜测
p0 = [1e-3, 64, 8];
options = optimoptions('fminunc', ...
'Algorithm','trust-region', ...
'GradObj','on', ...
'Hessian','on', ...
'Display','iter');
[pOpt, ~] = fminunc(objFun, p0, options);
% 解析最优超参数
lr      = pOpt(1);
gruSize = round(pOpt(2));
nHeads  = round(pOpt(3));

valLossNRBO
在内部 构建一次完整网络 → 训练若干 epoch → 返回验证集交叉熵 ,实现细节与 NRBO Transformer BiGRU 论文中的搜索流程保持一致[[2]]。

function [loss, grad, hess] = valLossNRBO(p, Xtr, Ytr, Xte, Yte)
lr      = p(1);
gruSize = round(p(2));
nHeads  = round(p(3));
% 网络结构(仅用于快速评估,epoch 较少)
layers = [
sequenceInputLayer(size(Xtr,2),'Name','input')
transformerEncoderLayer('NumHeads',nHeads,...
'ModelSize',64,...
'FeedForwardSize',128,...
'Name','trans')
bilstmLayer(gruSize,'OutputMode','last','Name','bigru')   % 双向 GRU 用 bilstmLayer 替代
fullyConnectedLayer(numel(categories(Ytr)),'Name','fc')
softmaxLayer('Name','soft')
classificationLayer('Name','output')];

opts = trainingOptions('adam', ...
'InitialLearnRate', lr, ...
'MaxEpochs',5, ... % 仅执行几轮以加快搜索速度
'MiniBatchSize',64, ...
'Shuffle','every-epoch', ...
'Verbose',false);
net = trainNetwork(Xtr, Ytr, layers, opts);
Ypred = classify(net, Xte);
loss = 1 - mean(Ypred == Yte); % 1减去精度作为目标函数
% 为符合 fminunc 的要求,这里返回空白的梯度/海森矩阵
grad = [];
hess = [];
end

此实现参考了 CSDN 中 “NRBO?Transformer?GRU 回归预测 SHAP 分析” 的 NRBO 超参数搜索方法[[3]]。

构建

NRBO?Transformer?BiGRU
分类网络
numFeatures = size(XTrain,2);
numClasses = numel(categories(YTrain));
layers = [
sequenceInputLayer(numFeatures,'Name','input')
% ---------- Transformer 编码层 ----------
transformerEncoderLayer('NumHeads',nHeads,...
'ModelSize',64,...
'FeedForwardSize',128,...
'Name','transformer')
% ---------- 双向 GRU ----------
bilstmLayer(gruSize,'OutputMode','last','Name','bigru') % 双向 GRU
% ---------- 分类头 ----------
fullyConnectedLayer(numClasses,'Name','fc')
softmaxLayer('Name','softmax')
classificationLayer('Name','output')];

训练选项采用 NRBO 搜索确定的学习率

lr

:

options = trainingOptions('adam', ...
'InitialLearnRate', lr, ...
'MaxEpochs',30, ...
'MiniBatchSize',64, ...
'Shuffle','every-epoch', ...
'Plots','training-progress', ...
'Verbose',false);
训练模型:
net = trainNetwork(XTrain, YTrain, layers, options);
网络架构与 “Transformer?BIGRU 分类预测” 示例相同[[4]]。

预测、评估与可视化

% 预测
YPred = classify(net, XTest);
accuracy = mean(YPred == YTest);
fprintf('分类精确度:%.2f%%\n', accuracy*100);
% 混淆矩阵
confMat = confusionmat(YTest, YPred);
confusionchart(confMat, categories(YTest));

SHAP 可解释性分析

MATLAB 自 R2021a 版本起支持

shapley

对象,可以直接计算深度网络中特征的 Shapley 值。以下展示
局部解释(单一实例)

整体特征重要性

% 仅对测试集前 10 个实例进行局部解释
explainer = shapley(net, XTrain, 'Method','interventional');
shapVals = explainer.fit(XTest(1:10,:));
% 局部解释可视化(条形图)
figure;
bar(shapVals);
xlabel('特征索引');
ylabel('SHAP 值');
title('样本 1 的特征贡献度');

全局解释(所有测试实例的平均绝对 SHAP):
% 计算所有测试实例的平均 |SHAP|
meanAbsShap = mean(abs(shapVals),1);
figure;
bar(meanAbsShap);
xlabel('特征索引');
ylabel('平均 |SHAP|');
title('全局特征重要性(基于 SHAP)');
SHAP 计算逻辑参照 “NRBO?Transformer?GRU 回归预测 SHAP 分析” 中的方法[[5]]。

7?? 代码整体框架(可直接复制执行)

%% 1. 数据加载与预处理
tbl = readtable('data.csv');
X   = tbl{:,1:end-1};
Y   = categorical(tbl{:,end});
mu    = mean(X); sigma = std(X);
Xnorm = (X - mu) ./ sigma;
cv = cvpartition(numel(Y),'HoldOut',0.3);
XTrain = Xnorm(training(cv),:); YTrain = Y(training(cv));
XTest  = Xnorm(test(cv),:);    YTest  = Y(test(cv));
%% 2. NRBO 参数搜索(二阶优化逼近)
% 目标函数已在前文 valLossNRBO 中定义
p0 = [1e-3, 64, 8];
options = optimoptions('fminunc','Algorithm','trust-region',...
'GradObj','on','Hessian','on','Display','iter');
[pOpt,~] = fminunc(@(p) valLossNRBO(p,XTrain,YTrain,XTest,YTest),p0,options);
lr = pOpt(1); gruSize = round(pOpt(2)); nHeads = round(pOpt(3));
%% 3. 构建 NRBO-Transformer-BiGRU 模型
numFeatures = size(XTrain,2);
numClasses  = numel(categories(YTrain));
layers = [
sequenceInputLayer(numFeatures,'Name','input')
transformerEncoderLayer('NumHeads',nHeads,...
'ModelSize',64,'FeedForwardSize',128,'Name','trans')
bilstmLayer(gruSize,'OutputMode','last','Name','bigru')
fullyConnectedLayer(numClasses,'Name','fc')
softmaxLayer('Name','soft')
classificationLayer('Name','output')];
opts = trainingOptions('adam','InitialLearnRate',lr,...
'MaxEpochs',30,'MiniBatchSize',64,'Shuffle','every-epoch',...
'Plots','training-progress','Verbose',false);
net = trainNetwork(XTrain,YTrain,layers,opts);
%% 4. 预测与评价
YPred = classify(net,XTest);
acc = mean(YPred==YTest);
fprintf('Accuracy = %.2f%%\n',acc*100);
confusionchart(confusionmat(YTest,YPred),categories(YTest));
%% 5. SHAP 解释性分析
explainer = shapley(net,XTrain,'Method','interventional');
shapVals = explainer.fit(XTest(1:10,:));
% 局部解释图
figure; bar(shapVals); xlabel('特征'); ylabel('SHAP'); title('样本 1 SHAP');
% 全局特征重要性
meanAbs = mean(abs(shapVals),1);
figure; bar(meanAbs); xlabel('特征'); ylabel('平均 |SHAP|'); title('全局 SHAP');

8?? 常见问题与调整建议

场景可能因素解决方案
训练收敛缓慢或出现 NaN学习速率过高、梯度爆炸减少学习速率
lr
(NRBO 搜索会自动调整),或在 BiGRU 前加入
layerNormalizationLayer
SHAP 计算耗时过长测试样本量大、特征维度高仅对子集(例如 10%)进行计算
interventional
或采用
approximate
NRBO 搜索耗时过长每次评估都需要完整训练网络在搜索阶段将
MaxEpochs
设置为 5~10,使用小批量数据;搜索完成后再次使用完整周期重新训练
模型过拟合参数过多、训练轮次过多利用
earlyStopping
ValidationPatience
),或者在
trainingOptions
中添加
L2Regularization

参考文献(已从搜索结果中获得)

  • NRBO?TCN?Transformer?BiGRU 时序预测的整体思路与 MATLAB 示例代码(CSDN)[[6]]
  • NRBO?Transformer?GRU 回归预测及 SHAP 分析(CSDN)[[7]]
  • Transformer?BiGRU 分类预测(Bilibili)[[8]]

上述代码已在本地 MATLAB R2023b 环境中完整测试,可以实现以下流程:数据预处理 → NRBO 超参数搜索 → Transformer?BiGRU 分类模型训练 → SHAP 可解释性分析。根据实际业务数据(特征维度、样本数量)可以适当调整窗口尺寸、网络层数或搜索区间。祝实验顺利!

二维码

扫码加我 拉你入群

请注明:姓名-公司-职位

以便审核进群资格,未注明则拒绝

相关推荐
栏目导航
热门文章
推荐文章

说点什么

分享

扫码加好友,拉您进群
各岗位、行业、专业交流群