提取BP神经网络模型数学表达式

Get Mathmatical Function of a BP Neural Network in Matlab

Posted by Chao Zhang on October 26, 2022

1. 问题描述

之前介绍了如何采用Matlab神经网络工具拟合同心管机器人(CTR)正运动学模型,并得到了模型的Matlab function,供程序直接调用。这些工作极大提升了正运动模型的计算速度,并保证了较高的模型拟合精度。

采用神经网络拟合所得到的显示CTR正运动学方程,为下一步CTR逆运动学的数值求解提供了方便。参考Girerd 2020年T-RO论文“Design and Control of a Hand-Held Concentric Tube Robot for Minimally Invasive Surgery”,其通过傅里叶级数拟合方法得到CTR显式正运动学方程,之后采用Newton-Raphson寻根法,求解逆运动学问题。

$q_{k+1} = q_k + \gamma J^{\dagger} F(q_k)$

其中 $F = p(q_k)-p_{des}$,$J = \partial{F}/\partial{q}$ 为雅克比矩阵,$J^\dagger$为$J$的伪逆,$\gamma$ 为迭代步长。

于是,需进一步对所得到的神经网络拟合模型的Matlab function $p=f(q)$ 求偏导,以获得雅克比矩阵$J$。 所以,需要更进一步对神经网络拟合模型的Matlab function做精简,并求解雅克比矩阵。

2. 提取神经网络模型数学表达式

Matlab神经网络曲线拟合工具采用的基于BP(Back Propagation)神经网络的结构,BP神经网络是一种按照误差逆向传播算法训练的多层前馈神经网络,是应用最广泛的神经网络模型之一。BP神经网络具有任意复杂的模式分类能力和优良的多维函数映射能力,解决了简单感知器不能解决的异或(Exclusive OR,XOR)和一些其他问题。从结构上讲,BP网络具有输入层、隐藏层和输出层;从本质上讲,BP算法就是以网络误差平方为目标函数、采用梯度下降法来计算目标函数的最小值。

老饼讲解-BP神经网络对BP神经网络进行了较多介绍,其中提取模型的数学表达式的内容给了本文较多启发。实际上,我们之前所得到的模型Matlab function已经将提取模型数学表达式的工作完成了,只是代码尚可进一步精简,以满足进一步求偏导等需求。

所采用的网络模型如图

图1 Network Architecture
首先给出所得到模型的Matlab function代码:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
function [y1] = ctr_nn_compliant_100e3(x1)
%MYNEURALNETWORKFUNCTION neural network simulation function.
%
% Auto-generated by MATLAB, 18-Oct-2022 10:28:45.
%
% [y1] = myNeuralNetworkFunction(x1) takes these arguments:
%   x = 6xQ matrix, input #1
% and returns:
%   y = 3xQ matrix, output #1
% where Q is the number of samples.

%#ok<*RPMT0>

% ===== NEURAL NETWORK CONSTANTS =====

% Input 1
x1_step1.xoffset = [-3.14151301294718;-3.14158930186924;-3.14155700428902;1.39278504079909e-07;3.796394083046e-05;0.00030838924090045];
x1_step1.gain = [0.318319717213362;0.31831692263459;0.318312047960775;40.0009294773944;26.7292411660151;20.2091493687639];
x1_step1.ymin = -1;

% Layer 1
b1 = [-1.7388939194493422402;-2.4749417490986771462;-1.1249356202752482936;-2.6714123085951806402;-1.858068131102155629;1.6478849199826257621;2.0846431364830761979;-0.89758955793323513817;0.40029527742044296312;-0.74791567806529213325;-0.60109872491771010328;-0.22093805634656574211;0.1065022160035486859;0.95913365748786505449;0.14339184536803128989;-0.54386742892165718111;-0.029544042983252836315;-0.63379761097495923572;-0.47405479609874645597;0.040394931367360040197;-0.33023457936925720846;-1.3853308931221361977;0.33426439440168753681;0.89849937204249341605;0.63809375203337226345;-1.8653663278578695994;0.83059864496975277515;1.2747355402384539236;1.3790147727652526388;1.0089799371839851538];
IW1_1 = [0.70481821454761317547 0.94066804495340350378 0.19250806896302796734 -0.175961397806350478 -0.26282501476827724307 0.49390354611646203331;1.7410593586642995056 0.090500092368769199935 -0.099058213574340178709 0.26895023627673586564 -0.11956263934775875069 0.53771048340498228679;0.19185886834324769601 0.15729276321291146834 -0.022744322563751247918 -0.21072866906180459656 0.22215943412019445447 -0.91056392604408276092;0.0021397716370694344797 0.1469248172827831711 -2.4427413900570775773 -0.12830278955081728109 -0.31646773667236560712 0.70714665856302749702;-1.7493888027266726315 1.9199067651441621418 -1.6882125009460646758 0.0027791275946151025415 -0.42869533399801107443 0.42692001936564372144;-0.40592776178898870576 -1.0540391176762362591 0.22607798737491094365 0.19895099638437335754 0.31090257842229579444 -0.65142163947122866396;0.035477532444007316148 -0.077402207909181480749 1.760706547608921424 0.11849754222414986271 0.41258618623999304686 -0.84595145483721523316;1.3849718476235566289 -0.039323501261560298015 -0.0047239758189265212626 0.29161423190608293243 0.078763916218152624227 0.17822837124910223539;-1.6000816141983904561 0.0042133417763625105254 0.030256997109806012258 0.17074487183614442398 0.0061362076773207267083 0.22987050130053141017;0.019073588217767584174 -0.0018744968898865030458 1.2276300703908842316 0.035991546321953303555 0.3630223495478334117 -0.5016916197639439412;1.0977343392837806135 -0.02538236614927592838 -0.029557893580960371838 0.27634998831452611823 0.014353760944995814369 0.44068236611367705979;0.99038352169396393077 0.84755627882871376055 -0.9825709030042119041 -0.054065007250253468984 -0.24300870861524129229 0.16509199144961023142;-0.02034320820540015673 -0.029122973820098509851 0.00098084129364271743778 -0.0029049751507347894124 -0.0068183736211576253042 -0.26613290444897252929;-1.5535630529917612286 1.4993155622530653925 -1.5847915517453563528 0.023862338643182847558 0.64741397241384046612 -0.40086303321128718968;-0.013051330236215213931 -1.3886059061710134266 -0.038679842880104439717 0.17749540683558542864 -0.009735829994397633344 -0.34965948264055973782;-0.72852353827228610861 0.28836294609575863213 0.053464308287469695935 0.15050259567743545785 0.0086806933462079769598 0.36008161845038977322;-1.3959000911314303206 0.0072718912499317306386 -0.0069091806452972158448 0.03252452948596788046 0.048183043775132917974 -0.090004533970955175048;0.029316958483144493763 0.0068122394289410262425 1.4660066146536769871 -0.029409721168997138979 0.090524309159346352804 -0.009961803889056555783;0.054078802264379929876 -0.023713993948357554864 1.5943560908796581277 -0.074222662838174510713 -0.22441033592658374274 0.50551890703026935991;0.0035134494768403463742 1.3502391337626034051 0.014838537631567675495 -0.40800052557255178654 0.079036890918649058135 0.73408084964250219784;0.010603067605131143852 1.5460676333179090847 0.018498026940454076472 0.050939514578350088281 -0.10984694487970295129 0.023854516361341549652;-2.0644152722145840961 0.062256199736235064146 0.010099582419517756648 0.24696139272926059816 -0.056722527263166865052 0.57354572960747041144;1.2736655989673915101 -0.0016608421242834366712 -0.013142338622925350128 0.060806058588442978008 -0.10913724161146608205 0.40712951938690905562;1.3563642192912257478 -0.066059170963851179814 -0.027998934225487100841 -0.19717015644104468608 0.036359149085922679023 -0.55027871774311698516;0.61511372183686718884 0.60574488897333444992 -0.60505416978781767501 -0.032938541182708856725 -0.1245610378710442917 0.15330960964077225417;-1.4321718532029801274 -1.4152075334331100098 1.4774194199729890897 -0.0095519428685474899038 -0.36551570628646334882 0.13947160658707380421;0.047088026964543248931 1.9632754031751309398 -0.12247807809440017235 1.3990678899276431935 -0.29786224136050915012 -2.5649841273971314415;0.92607958452549044015 0.45990980224124256015 -0.56780226574830439112 -0.13230264017675080757 -0.43881086844851979745 0.21874355880125948004;1.1499191988764048133 1.1191640252952768098 -1.1492517230657046667 -0.02728930514325038012 0.02499126912989173277 0.077244432577952451546;1.5027301590613100224 -1.2784194144168476992 -0.15317464112647982688 -0.20384741059534966512 0.016788712254186353295 -0.26029784868664052055];

% Layer 2
b2 = [-1.4968960437553147091;0.09266324647582233065;0.11831367794408265137];
LW2_1 = [-0.026798455178927101106 0.31062139465496668711 -0.59673020937732734748 1.2437816161673396209 0.033946857064605616416 -0.15816782374712254344 1.9309726286046808852 -1.5792383117644599455 1.4466802729318273268 1.879789463274934791 1.859415926791560425 0.36196210544740647919 5.4483676301544337051 0.11240296940425087424 -1.5723155458306832255 -1.2114611146042448819 -1.8165522759967325239 -2.7875006334475633807 1.2922336729261238997 0.12711006658453255502 -1.5650768409728668118 -1.524514707392883528 -0.076706793013409643733 -3.4713022533032731332 -0.029093688476668117665 -0.51356173072635524779 -0.10562740255306507509 0.59690561548468912267 -0.57025917207124099395 -0.058007918757893302264;0.81297092409812687919 -0.91248591959138591889 0.079370382333509251205 -0.27160918404612349741 0.1481323661261624014 1.4257601305469893571 -0.8515699508022619435 0.80726992549971554602 1.1644521901331774671 -1.2052811356574733015 -0.63527644230756519228 0.51995527412494946251 -1.1666353198985592865 0.1244720869770774424 -4.6783206261946350679 0.51643843629648267246 -3.9587505939816471923 0.17640980609422252101 0.43968562378159337101 -2.3360377338742424058 -1.9427863431711474007 -0.42722113336911637926 -2.4271409373472181414 -0.032435321472576453261 -3.477590289431355508 0.88427927382098014952 0.054563073395808479871 0.0072991319649624948271 2.4143813995740988076 0.031333085866140059084;0.022222259146938684593 -0.065141654320545430701 -0.18054582524586476611 -0.055962144208057292005 0.010487793548110921696 0.090401090720626658959 -0.078805389377559487141 0.062369380783094882648 0.11460067679474597591 -0.055270759819245576794 -0.12688528026812198601 -0.022813800595243044489 -3.7955463643099465187 0.0010926894532936018596 -0.083999885210372257638 -0.52670806804045999794 -0.187845048296486683 0.027151200772054986249 -0.0068673655053727179071 -0.052709825244097292807 -0.053924593052744898558 -0.10041287171172326442 -0.081675153465197547487 -0.33356646642284026871 0.018718526376028117586 -0.0076316471942440705645 0.00097107486741824220779 -0.055743316065909098711 -0.040974891096299356774 -0.10486871589016878714];

% Output 1
y1_step1.ymin = -1;
y1_step1.gain = [32.5493839278837;31.3341058590817;20.7971266101852];
y1_step1.xoffset = [-0.031417559279225;-0.0319097844020925;0.00030438901981315];

% ===== SIMULATION ========

% Dimensions
Q = size(x1,2); % samples

% Input 1
xp1 = mapminmax_apply(x1,x1_step1);

% Layer 1
a1 = tansig_apply(repmat(b1,1,Q) + IW1_1*xp1);

% Layer 2
a2 = repmat(b2,1,Q) + LW2_1*a1;

% Output 1
y1 = mapminmax_reverse(a2,y1_step1);
end

% ===== MODULE FUNCTIONS ========

% Map Minimum and Maximum Input Processing Function
function y = mapminmax_apply(x,settings)
y = bsxfun(@minus,x,settings.xoffset);
y = bsxfun(@times,y,settings.gain);
y = bsxfun(@plus,y,settings.ymin);
end

% Sigmoid Symmetric Transfer Function
function a = tansig_apply(n,~)
a = 2 ./ (1 + exp(-2*n)) - 1;
end

% Map Minimum and Maximum Output Reverse-Processing Function
function x = mapminmax_reverse(y,settings)
x = bsxfun(@minus,y,settings.ymin);
x = bsxfun(@rdivide,x,settings.gain);
x = bsxfun(@plus,x,settings.xoffset);
end

可以看到,Matlab function已经将模型参数给出,然后将模型参数传递关系表示出来,主要包括:

1.输入$x$归一化 $x_{norm} = 2\dfrac{x-x_{min}}{x_{max}-x_{min}}-1$

2.隐含层表示 $a_1 = tansig(w_{1} x_{norm}+b_1)$,采用$tansig(x) = \dfrac{2}{1+\exp^{-2x}}-1$ 做激活函数

3.输出层表示 $y_{norm} = w_{2}a_1+b_2$

4.输出$y_{norm}$反归一化 $y=\dfrac{(y_{norm}+1)(y_{max}-y_{min})}{2}+y_{min}$

根据本例情况,我们将上述Matlab function精简改写为了如下形式:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
function [y1] = ctr_nn_mathmatic(x1)
% ===== NEURAL NETWORK CONSTANTS =====
% Input 1
x1_step1.xoffset = [-3.14151301294718;-3.14158930186924;-3.14155700428902;1.39278504079909e-07;3.796394083046e-05;0.00030838924090045];
x1_step1.gain = [0.318319717213362;0.31831692263459;0.318312047960775;40.0009294773944;26.7292411660151;20.2091493687639];
x1_step1.ymin = -1;

% Layer 1
b1 = [-1.7388939194493422402;-2.4749417490986771462;-1.1249356202752482936;-2.6714123085951806402;-1.858068131102155629;1.6478849199826257621;2.0846431364830761979;-0.89758955793323513817;0.40029527742044296312;-0.74791567806529213325;-0.60109872491771010328;-0.22093805634656574211;0.1065022160035486859;0.95913365748786505449;0.14339184536803128989;-0.54386742892165718111;-0.029544042983252836315;-0.63379761097495923572;-0.47405479609874645597;0.040394931367360040197;-0.33023457936925720846;-1.3853308931221361977;0.33426439440168753681;0.89849937204249341605;0.63809375203337226345;-1.8653663278578695994;0.83059864496975277515;1.2747355402384539236;1.3790147727652526388;1.0089799371839851538];
IW1_1 = [0.70481821454761317547 0.94066804495340350378 0.19250806896302796734 -0.175961397806350478 -0.26282501476827724307 0.49390354611646203331;1.7410593586642995056 0.090500092368769199935 -0.099058213574340178709 0.26895023627673586564 -0.11956263934775875069 0.53771048340498228679;0.19185886834324769601 0.15729276321291146834 -0.022744322563751247918 -0.21072866906180459656 0.22215943412019445447 -0.91056392604408276092;0.0021397716370694344797 0.1469248172827831711 -2.4427413900570775773 -0.12830278955081728109 -0.31646773667236560712 0.70714665856302749702;-1.7493888027266726315 1.9199067651441621418 -1.6882125009460646758 0.0027791275946151025415 -0.42869533399801107443 0.42692001936564372144;-0.40592776178898870576 -1.0540391176762362591 0.22607798737491094365 0.19895099638437335754 0.31090257842229579444 -0.65142163947122866396;0.035477532444007316148 -0.077402207909181480749 1.760706547608921424 0.11849754222414986271 0.41258618623999304686 -0.84595145483721523316;1.3849718476235566289 -0.039323501261560298015 -0.0047239758189265212626 0.29161423190608293243 0.078763916218152624227 0.17822837124910223539;-1.6000816141983904561 0.0042133417763625105254 0.030256997109806012258 0.17074487183614442398 0.0061362076773207267083 0.22987050130053141017;0.019073588217767584174 -0.0018744968898865030458 1.2276300703908842316 0.035991546321953303555 0.3630223495478334117 -0.5016916197639439412;1.0977343392837806135 -0.02538236614927592838 -0.029557893580960371838 0.27634998831452611823 0.014353760944995814369 0.44068236611367705979;0.99038352169396393077 0.84755627882871376055 -0.9825709030042119041 -0.054065007250253468984 -0.24300870861524129229 0.16509199144961023142;-0.02034320820540015673 -0.029122973820098509851 0.00098084129364271743778 -0.0029049751507347894124 -0.0068183736211576253042 -0.26613290444897252929;-1.5535630529917612286 1.4993155622530653925 -1.5847915517453563528 0.023862338643182847558 0.64741397241384046612 -0.40086303321128718968;-0.013051330236215213931 -1.3886059061710134266 -0.038679842880104439717 0.17749540683558542864 -0.009735829994397633344 -0.34965948264055973782;-0.72852353827228610861 0.28836294609575863213 0.053464308287469695935 0.15050259567743545785 0.0086806933462079769598 0.36008161845038977322;-1.3959000911314303206 0.0072718912499317306386 -0.0069091806452972158448 0.03252452948596788046 0.048183043775132917974 -0.090004533970955175048;0.029316958483144493763 0.0068122394289410262425 1.4660066146536769871 -0.029409721168997138979 0.090524309159346352804 -0.009961803889056555783;0.054078802264379929876 -0.023713993948357554864 1.5943560908796581277 -0.074222662838174510713 -0.22441033592658374274 0.50551890703026935991;0.0035134494768403463742 1.3502391337626034051 0.014838537631567675495 -0.40800052557255178654 0.079036890918649058135 0.73408084964250219784;0.010603067605131143852 1.5460676333179090847 0.018498026940454076472 0.050939514578350088281 -0.10984694487970295129 0.023854516361341549652;-2.0644152722145840961 0.062256199736235064146 0.010099582419517756648 0.24696139272926059816 -0.056722527263166865052 0.57354572960747041144;1.2736655989673915101 -0.0016608421242834366712 -0.013142338622925350128 0.060806058588442978008 -0.10913724161146608205 0.40712951938690905562;1.3563642192912257478 -0.066059170963851179814 -0.027998934225487100841 -0.19717015644104468608 0.036359149085922679023 -0.55027871774311698516;0.61511372183686718884 0.60574488897333444992 -0.60505416978781767501 -0.032938541182708856725 -0.1245610378710442917 0.15330960964077225417;-1.4321718532029801274 -1.4152075334331100098 1.4774194199729890897 -0.0095519428685474899038 -0.36551570628646334882 0.13947160658707380421;0.047088026964543248931 1.9632754031751309398 -0.12247807809440017235 1.3990678899276431935 -0.29786224136050915012 -2.5649841273971314415;0.92607958452549044015 0.45990980224124256015 -0.56780226574830439112 -0.13230264017675080757 -0.43881086844851979745 0.21874355880125948004;1.1499191988764048133 1.1191640252952768098 -1.1492517230657046667 -0.02728930514325038012 0.02499126912989173277 0.077244432577952451546;1.5027301590613100224 -1.2784194144168476992 -0.15317464112647982688 -0.20384741059534966512 0.016788712254186353295 -0.26029784868664052055];

% Layer 2
b2 = [-1.4968960437553147091;0.09266324647582233065;0.11831367794408265137];
LW2_1 = [-0.026798455178927101106 0.31062139465496668711 -0.59673020937732734748 1.2437816161673396209 0.033946857064605616416 -0.15816782374712254344 1.9309726286046808852 -1.5792383117644599455 1.4466802729318273268 1.879789463274934791 1.859415926791560425 0.36196210544740647919 5.4483676301544337051 0.11240296940425087424 -1.5723155458306832255 -1.2114611146042448819 -1.8165522759967325239 -2.7875006334475633807 1.2922336729261238997 0.12711006658453255502 -1.5650768409728668118 -1.524514707392883528 -0.076706793013409643733 -3.4713022533032731332 -0.029093688476668117665 -0.51356173072635524779 -0.10562740255306507509 0.59690561548468912267 -0.57025917207124099395 -0.058007918757893302264;0.81297092409812687919 -0.91248591959138591889 0.079370382333509251205 -0.27160918404612349741 0.1481323661261624014 1.4257601305469893571 -0.8515699508022619435 0.80726992549971554602 1.1644521901331774671 -1.2052811356574733015 -0.63527644230756519228 0.51995527412494946251 -1.1666353198985592865 0.1244720869770774424 -4.6783206261946350679 0.51643843629648267246 -3.9587505939816471923 0.17640980609422252101 0.43968562378159337101 -2.3360377338742424058 -1.9427863431711474007 -0.42722113336911637926 -2.4271409373472181414 -0.032435321472576453261 -3.477590289431355508 0.88427927382098014952 0.054563073395808479871 0.0072991319649624948271 2.4143813995740988076 0.031333085866140059084;0.022222259146938684593 -0.065141654320545430701 -0.18054582524586476611 -0.055962144208057292005 0.010487793548110921696 0.090401090720626658959 -0.078805389377559487141 0.062369380783094882648 0.11460067679474597591 -0.055270759819245576794 -0.12688528026812198601 -0.022813800595243044489 -3.7955463643099465187 0.0010926894532936018596 -0.083999885210372257638 -0.52670806804045999794 -0.187845048296486683 0.027151200772054986249 -0.0068673655053727179071 -0.052709825244097292807 -0.053924593052744898558 -0.10041287171172326442 -0.081675153465197547487 -0.33356646642284026871 0.018718526376028117586 -0.0076316471942440705645 0.00097107486741824220779 -0.055743316065909098711 -0.040974891096299356774 -0.10486871589016878714];

% Output 1
y1_step1.ymin = -1;
y1_step1.gain = [32.5493839278837;31.3341058590817;20.7971266101852];
y1_step1.xoffset = [-0.031417559279225;-0.0319097844020925;0.00030438901981315];


% ===== GET EXPRESSION =====
% Input 1
xp1 = (x1 - x1_step1.xoffset) .* x1_step1.gain -1;      % 输入归一化

% Layer 1
a1 = 2 ./ (1 + exp(-2*(b1 + IW1_1*xp1))) - 1;           % 隐含层

% Layer 2
a2 = b2 + LW2_1*a1;                                     % 输出层

% Output 1
y1 = (a2+1) ./ y1_step1.gain + y1_step1.xoffset;        % 输出反归一化
end