Deep Learning Neural Network

Neural Network cơ bản (Phần 2)

Neural Network cơ bản (Phần 2)

Trong quá trình tìm hiểu về mạng NN, mình thấy khá là khó hiểu, đặc biệt với các bạn không mạnh về toán. Bài này, mình sẽ diễn giải cách thức làm việc của NN một cách trực quan, dễ hiểu cho các bạn thông qua một ví dụ cụ thể.

1. Nhắc lại lý thuyết

Giả sử ta có mạng NN như sau:

Quá trình training model bao gồm 2 phases:

1.1 Forward Path

Phase này tính toán (dự đoán) đầu ra o1,o2, tính loss.

Giả sử activation là hàm sigmoid:

Ta sẽ tính lần lượt các đại lượng trung gian:

  • inh1: input của h1
  • inh2: input của h2
  • outh1: output của h1
  • outh2: output của h2
  • ino1: input của o1
  • ino2: input của o2
  • outo1: output của o1
  • outo2:outputcao_2$$

Công thức tính của từng đại lượng như sau:

inh1=w1i1+w2i2+b11

inh2=w3i1+w4i2+b11

outh1=sigmoid(inh1)=11+einh1

outh2=sigmoid(inh2)=11+einh2

ino1=w5outh1+w6outh2+b21

ino2=w7outh1+w8outh2+b21

outo1=sigmoid(ino1)=11+eino1

outo2=sigmoid(ino2)=11+eino2

Tiếp theo là tính loss bằng cách so sánh đầu ra của mạng NN với các giá trị thực tế:

  • targeto1
  • targeto2:

Công thức tính loss như sau:

Etotal=i=12Eoi=i=1212(targetoioutoi)2=Eo1+Eo2

Eo1=12(targeto1outo1)2

Eo2=12(targeto2outo2)2

1.2 Backward Path

Mục đích của phase này là cập nhật trọng số w sao cho tối thiểu hóa loss.

Ta sẽ sử dụng thuật toán tối ưu Stochastic Gradient Descent (SGD) để cập nhật w.

Công thức cập nhật như sau:

θ=θηθf(θ)

với:

  • θf(θ) là đạo hàm của Loss Function tại θ (đạo hàm từng phần theo ).
  • η là một số > 0, gọi là learning rate.
  • θ là tập hợp các vector các tham số của model cần tối ưu. Trong trường hợp này là các trọng số w.

Đạo hàm từng phần của các w tại output layer được tính theo quy tắc chain rule như sau:

Etotalw5=Etotalouto1outo1ino1ino1w5

Etotalw6=Etotalouto1outo1ino1ino1w6

Etotalw7=Etotalouto2outo2ino2ino2w7</p>

Etotalw8=Etotalouto2outo2ino2ino2w8

Đạo hàm từng phần của các w tại hidden layer được tính như sau:

Etotalw1=Etotalouth1outh1inh1inh1w1

Etotalw2=Etotalouth1outh1inh1inh1w2

Etotalw3=Etotalouth2outh2inh2inh2w3

Etotalw4=Etotalouth2outh2inh2inh2w4

Sau khi tính được đạo hàm từng phần của mỗi w, ta áp dụng công thức phía trên để cập nhật w.

2. Ví dụ áp dụng

Vẫn với kiến trúc mạng như trên, ta sẽ gán các giá trị khởi tạo cho các tham số như hình bên dưới:

Ok, bây giờ ta sẽ bắt đầu đi tính toán.

2.1 Fordward Path

Input của h1:

inh1=w1i1+w2i2+b11

inh1=0.150.05+0.20.1+0.351

inh1=0.3775

Input của h2:

inh2=w3i1+w4i2+b11

inh2=0.250.05+0.30.1+0.351

inh2=0.3925

Ouput của h1:

outh1= 11+einh1

outh1=11+e0.3775

outh1=0.593269992

Output của h2:

outh2= 11+einh2

outh2=11+e0.3925

outh2=0.596884378

Input của o1:

ino1=w5outh1+w6outh2+b21

ino1=0.40.593269992+0.450.596884378+0.61

ino1=1.105905967

Input của o2:

ino2=w7outh1+w8outh2+b21

ino2=0.50.593269992+0.550.596884378+0.61

ino2=1.224921404

Output của o1:

outo1= 11+eino1

outo1=11+e1.105905967

outo1=0.75136507

Output cuat o2:

outo2= 11+eino2

outo2=11+e1.224921404

outo2=0.772928465

Tổng lỗi:

Eo1=12(targeto1outo1)2

Eo2=12(0.010.75136507)2

Eo1=0.274811083

Eo2=12(targeto1outo1)2

Eo2=12(0.010.772928465)2

Eo2=0.023560026

Etotal=i=12Eoi

Etotal=0.274811083+0.023560026

Etotal=0.298371109

2.2 Backward Path

Tính đạo hàm từng phần của Loss Function theo mỗi w.

Các w của output layer (w5,w6,w7,w8) có cách tính giống nhau:

  • w5:

Etotalw5=Etotalouto1outo1ino1ino1w5

Ta biết:

Etotal=i=1212(targetoioutoi)2

Etotal=12(targeto1outo1)2+12(targeto2outo2)2

Nên:

Etotalouto1 =212(targeto1outo1)21(1)+0

Etotalouto1 =(targeto1outo1)

Etotalouto1 =(0.010.75136507)=0.74136507

Tiếp theo, vì:

outo1=sigmoid(ino1)= 11+ein01

Nên:

outo1ino1 =outo1(1outo1)

outo1ino1 =0.75136507(10.75136507)

outo1ino1 =0.186815602

Và,

ino1=w5outh1+w6outh2+b21

Nên:

ino1w5 =outh1

ino1w5 =0.593269992

Tổng hợp lại ta được:

Etotalw5=Etotalouto1outo1ino1ino1w5$

Etotalw5 =0.741365070.1868156020.593269992

Etotalw5 =0.082167041

  • w6:

Etotalw6=Etotalouto1outo1ino1ino1w6

Etotalouto1=212(targeto1outo1)21(1)+0$

Etotalouto1 =(targeto1outo1)

Etotalouto1 =(0.010.75136507)=0.74136507

outo1ino1 =outo1(1outo1)

outo1ino1 =0.75136507(10.75136507)

outo1ino1 =0.186815602

ino1w6 =outh2

ino1w6 =0.596884378

Tổng hợp lại ta được:

Etotalw6=Etotalouto1outo1ino1ino1w6

Etotalw6 =0.741365070.1868156020.596884378

Etotalw6 =0.082667628

  • w7:

Etotalw7=Etotalouto2outo2ino2ino2w7

Etotalouto2=0+212(targeto2outo2)21(1)

Etotalouto2 =(targeto2outo2)

Etotalouto2 =(0.990.772928465)=0.217071535

outo2ino2 =outo2(1outo2)

outo2ino2 =0.772928465(10.772928465)

outo2ino2 =0.175510053

ino2w7 =outh1

ino2w7 =0.593269992

Tổng hợp lại ta được:

Etotalw6=Etotalouto1outo1ino1ino1w6

Etotalw6 =0.2170715350.1755100530.593269992

Etotalw6 =0.022602541

  • w8:

Etotalw8=Etotalouto2outo2ino2ino2w8

Etotalouto2=0+212(targeto2outo2)21(1)

Etotalouto2 =(targeto2outo2)

Etotalouto2 =(0.990.772928465)=0.217071535

outo2ino2 =outo2(1outo2)

outo2ino2 =0.772928465(10.772928465)

outo2ino2 =0.175510053

ino2w8 =outh2

ino2w8 =0.596884378

Tổng hợp lại ta được:

Etotalw6=Etotalouto1outo1ino1ino1w6

Etotalw6 =0.2170715350.1755100530.596884378

Etotalw6 =0.022740242

Các w của hidden layer (w1,w2,w3,w4) có cách tính giống nhau:

  • w1:

Etotalw1=Etotalouth1outh1inh1inh1w1

--------------------------------------------------------------

Etotalouth1=Eo1outh1+Eo2outh1

----------------------------------------------------

Eo1outh1=Eo1ino1ino1outh1

---------------------------------

Eo1ino1=Eo1outo1outo1ino1

Eo1ino1=(12(targeto1outo1)2)outo1(11+eino1)ino1

Eo1ino1=212(targeto1outo1)(1)outo1(1outo1)

Eo1ino1 =(0.010.75136507)(1)0.75136507(10.75136507)

Eo1ino1 =0.138498562

-------------------------

ino1outh1=(w5outh1+w6outh2+b21)outh1

ino1outh1 =w5

ino1outh1 =0.4

Gộp lại:

Eo1outh1=Eo1ino1ino1outh1

Eo1outh1 =0.1384985620.4

Eo1outh1 =0.055399425

----------------------------------------------------

Eo2outh1=Eo2ino2ino2outh1

---------------------------------

Eo2ino2=Eo2outo2outo2ino2

Eo2ino2=(12(targeto2outo2)2)outo2(11+eino2)ino2

Eo2ino1=212(targeto2outo2)(1)outo2(1outo2)

Eo2ino1 =(0.990.772928465)(1)0.772928465(10.772928465)

Eo2ino1 =0.038098237

-------------------------

ino2outh1=(w7outh1+w8outh2+b21)outh1

ino2outh1 =w7

ino2outh1 =0.5

Gộp lại:

Eo2outh1=Eo2ino1ino1outh1

Eo2outh1 =0.0380982370.5

Eo2outh1 =0.019049118

---------------------------------

Etotalouth1=Eo1outh1+Eo2outh1

Etotalouth1=0.055399425+(0.019049118)=0,036350307

----------------------

outh1inh1=(11+einh1)inh1

outh1inh1 =outh1(1outh1)

outh1inh1 =0.59326999(10.59326999)=0.241300709

----------------------

inh1w1=(w1i1+w2i2+b11)w1

inh1w1 =i1

inh1w1 =0.05

--------------------------------------------------------------

Etotalw1=Etotalouth1outh1inh1inh1w1

Etotalw1 =0.0363503060.2413007090.05

Etotalw1 =0.000438568

  • w2:

Etotalw2=Etotalouth1outh1inh1inh1w2

--------------------------------------------------------------

Etotalouth1=Eo1outh1+Eo2outh1

----------------------------------------------------

Eo1outh1=Eo1ino1ino1outh1

---------------------------------

Eo1ino1=Eo1outo1outo1ino1

Eo1ino1=(12(targeto1outo1)2)outo1(11+eino1)ino1

Eo1ino1=212(targeto1outo1)(1)outo1(1outo1)

Eo1ino1 =(0.010.75136507)(1)0.75136507(10.75136507)

Eo1ino1 =0.138498562

-------------------------

ino1outh1=(w5outh1+w6outh2+b21)outh1

ino1outh1 =w5

ino1outh1 =0.4

Gộp lại:

Eo1outh1=Eo1ino1ino1outh1

Eo1outh1 =0.1384985620.4

Eo1outh1 =0.055399425

----------------------------------------------------

Eo2outh1=Eo2ino2ino2outh1

---------------------------------

Eo2ino2=Eo2outo2outo2ino2

Eo2ino2=(12(targeto2outo2)2)outo2(11+eino2)ino2

Eo2ino1=212(targeto2outo2)(1)outo2(1outo2)

Eo2ino1 =(0.990.772928465)(1)0.772928465(10.772928465)

Eo2ino1 =0.038098237

-------------------------

ino2outh1=(w7outh1+w8outh2+b21)outh1

ino2outh1 =w7

ino2outh1 =0.5

Gộp lại:

Eo2outh1=Eo2ino1ino1outh1

Eo2outh1 =0.0380982370.5

Eo2outh1 =0.019049118

---------------------------------

Etotalouth1=Eo1outh1+Eo2outh1

Etotalouth1=0.055399425+(0.019049118)=0,036350307

----------------------

outh1inh1=(11+einh1)inh1

outh1inh1 =outh1(1outh1)

outh1inh1 =0.59326999(10.59326999)=0.241300709

----------------------

inh1w2=(w1i1+w2i2+b11)w2

inh1w2 =i2

inh1w2 =0.1

--------------------------------------------------------------

Etotalw2=Etotalouth1outh1inh1inh1w2

Etotalw2 =0.0363503060.2413007090.1

Etotalw2 =0.000877135

  • w3:

Etotalw3=Etotalouth2outh2inh2inh2w3

--------------------------------------------------------------

Etotalouth2=Eo1outh2+Eo2outh2

----------------------------------------------------

Eo1outh2=Eo1ino1ino1outh2

---------------------------------

Eo1ino1=Eo1outo1outo1ino1

Eo1ino1=(12(targeto1outo1)2)outo1(11+eino1)ino1

Eo1ino1=212(targeto1outo1)(1)outo1(1outo1)

Eo1ino1 =(0.010.75136507)(1)0.75136507(10.75136507)

Eo1ino1 =0.138498562

-------------------------

ino1outh2=(w5outh1+w6outh2+b21)outh2

ino1outh2 =w6

ino1outh2 =0.45

Gộp lại:

Eo1outh2=Eo1ino1ino1outh2

Eo1outh2 =0.1384985620.45

Eo1outh2 =0.062324353

----------------------------------------------------

Eo2outh2=Eo2ino2ino2outh2

---------------------------------

Eo2ino2=Eo2outo2outo2ino2

Eo2ino2=(12(targeto2outo2)2)outo2(11+eino2)ino2

Eo2ino2=212(targeto2outo2)(1)outo2(1outo2)

Eo2ino2 =(0.990.772928465)(1)0.772928465(10.772928465)

Eo2ino2 =0.038098237

-------------------------

ino2outh2=(w7outh1+w8outh2+b21)outh2

ino2outh2 =w8

ino2outh2 =0.55

Gộp lại:

Eo2outh2=Eo2ino1ino1outh2

Eo2outh2 =0.0380982370.55

Eo2outh2 =0.02095403

---------------------------------

Etotalouth2=Eo1outh2+Eo2outh2

Etotalouth2=0.062324353+(0.02095403)=0.041370323

----------------------

outh2inh2=(11+einh2)inh2

outh2inh2 =outh2(1outh2)

outh2inh2 =0.596884378(10.596884378)=0.240613417

----------------------

inh2w3=(w3i1+w4i2+b11)w3

inh2w3 =i1

inh2w3 =0.05

--------------------------------------------------------------

Etotalw3=Etotalouth2outh2inh2inh2w3

Etotalw3 =0.0413703230.2406134170.05

Etotalw3 =0.000497713

  • w4:

Etotalw4=Etotalouth2outh2inh2inh2w4

--------------------------------------------------------------

Etotalouth2=Eo1outh2+Eo2outh2

----------------------------------------------------

Eo1outh2=Eo1ino1ino1outh2

---------------------------------

Eo1ino1=Eo1outo1outo1ino1

Eo1ino1=(12(targeto1outo1)2)outo1(11+eino1)ino1

Eo1ino1=212(targeto1outo1)(1)outo1(1outo1)

Eo1ino1 =(0.010.75136507)(1)0.75136507(10.75136507)

Eo1ino1 =0.138498562

-------------------------

ino1outh2=(w5outh1+w6outh2+b21)outh2

ino1outh2 =w6

ino1outh2 =0.45

Gộp lại:

Eo1outh2=Eo1ino1ino1outh2

Eo1outh2 =0.1384985620.45

Eo1outh2 =0.062324353

----------------------------------------------------

Eo2outh2=Eo2ino2ino2outh2

---------------------------------

Eo2ino2=Eo2outo2outo2ino2

Eo2ino2=(12(targeto2outo2)2)outo2(11+eino2)ino2

Eo2ino2=212(targeto2outo2)(1)outo2(1outo2)

Eo2ino2 =(0.990.772928465)(1)0.772928465(10.772928465)

Eo2ino2 =0.038098237

-------------------------

ino2outh2=(w7outh1+w8outh2+b21)outh2

ino2outh2 =w8

ino2outh2 =0.55

Gộp lại:

Eo2outh2=Eo2ino1ino1outh2

Eo2outh2 =0.0380982370.55

Eo2outh2 =0.02095403

---------------------------------

Etotalouth2=Eo1outh2+Eo2outh2

Etotalouth2=0.062324353+(0.02095403)=0.041370323

----------------------

outh2inh2=(11+einh2)inh2

outh2inh2 =outh2(1outh2)

outh2inh2 =0.596884378(10.596884378)=0.240613417

----------------------

inh2w4=(w3i1+w4i2+b11)w3

inh2w4 =i2

inh2w4 =0.1

--------------------------------------------------------------

Etotalw4=Etotalouth2outh2inh2inh2w4

Etotalw4 =0.0413703230.2406134170.1

Etotalw4 =0.000995425

Đến đây ta đã tính xong các đạo hàm từng phần theo các w. Áp dụng SGD để cập nhật các w ta được (chọn η=0.9):

w5+=w5ηEtotalw5

w5+=0.40.90.082167041

w5+=0.326049663

-------------------------

w6+=w6ηEtotalw6

w6+=0.450.90.082667628

w6+=0.375599135

-------------------------

w7+=w7ηEtotalw7

w7+=0.50.9(0.022602541)

w7+=0.520342287

-------------------------

w8+=w8ηEtotalw8

w8+=0.550.9(0.022740242)

w8+=0.570466218

-------------------------

w1+=w1ηEtotalw1

w1+=0.150.90.000438568

w1+=0.149605289

-------------------------

w2+=w2ηEtotalw2

w2+=0.20.90.0080877135

w2+=0.192721058

-------------------------

w3+=w3ηEtotalw3

w3+=0.250.90.000497713

w3+=0.249552058

-------------------------

w4+=w4ηEtotalw4

w4+=0.30.90.000995425

w4+=0.299104118

-------------------------

Phù, như vậy là chúng ta đã cập nhật xong giá trị mới cho các trọng số w. Đây là những phép toán xảy ra trong mỗi lần cập nhật khi training model. Hi vọng, thông qua ví dụ trong bài này, các bạn đã có thể hiểu rõ hơn bản chất của mạng NN. Hẹn gặp lại các bạn trong các bài tiếp theo!

3. Tham khảo