Understanding Machine Learning: From Theory to Algorithms


Neural Networks


Next, we discuss how to calculate the partial derivatives with respect to the edges from $V_{t-1}$ to $V_t$, namely, with respect to the elements in $W_{t-1}$. Since we fix all other weights of the network, it follows that the outputs of all the neurons in $V_{t-1}$ are fixed numbers which do not depend on the weights in $W_{t-1}$. Denote the corresponding vector by $o_{t-1}$. In addition, let us denote by $\ell_t : \mathbb{R}^{k_t} \to \mathbb{R}$ the loss function of the subnetwork defined by layers $V_t, \ldots, V_T$ as a function of the outputs of the neurons in $V_t$. The input to the neurons of $V_t$ can be written as $a_t = W_{t-1} o_{t-1}$ and the output of the neurons of $V_t$ is $o_t = \sigma(a_t)$. That is, for every $j$ we have $o_{t,j} = \sigma(a_{t,j})$. We obtain that the loss, as a function of $W_{t-1}$, can be written as

$$g_t(W_{t-1}) = \ell_t(o_t) = \ell_t(\sigma(a_t)) = \ell_t(\sigma(W_{t-1} o_{t-1})).$$
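To make these quantities concrete, here is a minimal pure-Python sketch of one layer's forward computation and of $g_t$ viewed as a function of $W_{t-1}$ alone. It assumes a sigmoid activation (the text only needs $\sigma$ to be differentiable), stores $W_{t-1}$ as a list of $k_t$ rows of length $k_{t-1}$, and takes $\ell_t$ as a caller-supplied function; the names `layer_forward` and `g_t` are illustrative, not from the text.

```python
import math

def sigmoid(z):
    # Assumed activation; any differentiable sigma would do for the derivation.
    return 1.0 / (1.0 + math.exp(-z))

def mat_vec(W, x):
    # W is a list of k_t rows, each of length k_{t-1}; x has length k_{t-1}.
    return [sum(w_ji * x_i for w_ji, x_i in zip(row, x)) for row in W]

def layer_forward(W_prev, o_prev):
    """Return (a_t, o_t) with a_t = W_{t-1} o_{t-1} and o_t = sigma(a_t)."""
    a = mat_vec(W_prev, o_prev)
    o = [sigmoid(a_j) for a_j in a]  # o_{t,j} = sigma(a_{t,j})
    return a, o

def g_t(W_prev, o_prev, ell_t):
    """Loss as a function of W_{t-1}; o_{t-1} and all other weights are held fixed."""
    _, o = layer_forward(W_prev, o_prev)
    return ell_t(o)
```

For example, with all-zero weights every pre-activation is 0, so `layer_forward([[0.0, 0.0]], [1.0, 2.0])` returns `([0.0], [0.5])`.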

It will be convenient to rewrite this as follows. Let $w_{t-1} \in \mathbb{R}^{k_{t-1} k_t}$ be the column vector obtained by concatenating the rows of $W_{t-1}$ and then taking the transpose of the resulting long vector. Define $O_{t-1}$ to be the $k_t \times (k_{t-1} k_t)$ matrix

$$O_{t-1} = \begin{pmatrix} o_{t-1}^\top & 0 & \cdots & 0 \\ 0 & o_{t-1}^\top & \cdots & 0 \\ \vdots & \vdots & \ddots & \vdots \\ 0 & 0 & \cdots & o_{t-1}^\top \end{pmatrix}. \tag{20.2}$$

Then, $W_{t-1} o_{t-1} = O_{t-1} w_{t-1}$, so we can also write

$$g_t(w_{t-1}) = \ell_t(\sigma(O_{t-1} w_{t-1})).$$
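The identity $W_{t-1} o_{t-1} = O_{t-1} w_{t-1}$ is easy to check numerically. The sketch below builds $O_{t-1}$ as in Eq. (20.2) and flattens $W_{t-1}$ row by row into $w_{t-1}$; it uses plain lists of floats, and the helper names `flatten_rows` and `build_O` are hypothetical.

```python
def flatten_rows(W):
    # w_{t-1}: concatenate the rows of W_{t-1} into one long vector.
    return [w for row in W for w in row]

def build_O(o_prev, k_t):
    """Block-diagonal k_t x (k_{t-1} k_t) matrix of Eq. (20.2).

    Row j holds o_{t-1}^T in columns j*k_{t-1} .. (j+1)*k_{t-1} - 1 and zeros elsewhere.
    """
    k_prev = len(o_prev)
    O = [[0.0] * (k_prev * k_t) for _ in range(k_t)]
    for j in range(k_t):
        O[j][j * k_prev:(j + 1) * k_prev] = list(o_prev)
    return O

def mat_vec(M, x):
    # Matrix (list of rows) times vector.
    return [sum(m * xi for m, xi in zip(row, x)) for row in M]

W = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]  # k_t = 2, k_{t-1} = 3
o_prev = [0.5, -1.0, 2.0]
print(mat_vec(W, o_prev))                                  # W_{t-1} o_{t-1}
print(mat_vec(build_O(o_prev, len(W)), flatten_rows(W)))   # O_{t-1} w_{t-1}
```

Both products equal `[4.5, 9.0]` here. Because row $j$ of $O_{t-1}$ only "sees" the $j$th block of $w_{t-1}$, which is row $j$ of $W_{t-1}$, the identity holds for any $W_{t-1}$ and $o_{t-1}$.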

Therefore, applying the chain rule, we obtain that

$$J_{w_{t-1}}(g_t) = J_{\sigma(O_{t-1} w_{t-1})}(\ell_t)\, \mathrm{diag}(\sigma'(O_{t-1} w_{t-1}))\, O_{t-1}.$$

Using our notation we have $o_t = \sigma(O_{t-1} w_{t-1})$ and $a_t = O_{t-1} w_{t-1}$, which yields

$$J_{w_{t-1}}(g_t) = J_{o_t}(\ell_t)\, \mathrm{diag}(\sigma'(a_t))\, O_{t-1}.$$
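One way to sanity-check this chain-rule expression is to form the product $J_{o_t}(\ell_t)\,\mathrm{diag}(\sigma'(a_t))\,O_{t-1}$ explicitly and compare it with central finite differences of $g_t(w_{t-1})$. The sketch assumes a sigmoid activation, for which $\sigma'(z) = \sigma(z)(1 - \sigma(z))$, and a caller-supplied function `grad_ell` (hypothetical) that returns $J_{o_t}(\ell_t)$ as a list; all helper names are illustrative.

```python
import math

def sigmoid(z):
    # Assumed activation.
    return 1.0 / (1.0 + math.exp(-z))

def sigmoid_prime(z):
    s = sigmoid(z)
    return s * (1.0 - s)

def build_O(o_prev, k_t):
    # Block-diagonal matrix of Eq. (20.2): row j holds o_{t-1}^T in block j.
    k_prev = len(o_prev)
    O = [[0.0] * (k_prev * k_t) for _ in range(k_t)]
    for j in range(k_t):
        O[j][j * k_prev:(j + 1) * k_prev] = list(o_prev)
    return O

def mat_vec(M, x):
    return [sum(m * xi for m, xi in zip(row, x)) for row in M]

def jacobian_wrt_w(w, o_prev, k_t, grad_ell):
    """J_{w_{t-1}}(g_t) = J_{o_t}(ell_t) diag(sigma'(a_t)) O_{t-1}, a list of length k_{t-1} k_t."""
    O = build_O(o_prev, k_t)
    a = mat_vec(O, w)                  # a_t = O_{t-1} w_{t-1}
    o = [sigmoid(aj) for aj in a]      # o_t = sigma(a_t)
    delta = grad_ell(o)                # row vector J_{o_t}(ell_t)
    scaled = [d * sigmoid_prime(aj) for d, aj in zip(delta, a)]  # delta diag(sigma'(a_t))
    # Row vector times matrix: entry c is sum_j scaled[j] * O[j][c].
    return [sum(scaled[j] * O[j][c] for j in range(k_t)) for c in range(len(w))]
```

Forming $O_{t-1}$ explicitly stores $k_t \cdot k_{t-1} k_t$ numbers, almost all of them zero, so this version is useful only as a check; the block structure of $O_{t-1}$ is what allows the more compact form that follows.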

Let us also denote $\delta_t = J_{o_t}(\ell_t)$. Then, we can further rewrite $J_{w_{t-1}}(g_t)$ as

$$J_{w_{t-1}}(g_t) = \left( \delta_{t,1}\, \sigma'(a_{t,1})\, o_{t-1}^\top, \; \ldots, \; \delta_{t,k_t}\, \sigma'(a_{t,k_t})\, o_{t-1}^\top \right). \tag{20.3}$$
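Since $w_{t-1}$ concatenates the rows of $W_{t-1}$, block $j$ of Eq. (20.3), namely $\delta_{t,j}\,\sigma'(a_{t,j})\,o_{t-1}^\top$, is the gradient with respect to row $j$ of $W_{t-1}$. Reshaped into a $k_t \times k_{t-1}$ matrix, the gradient is therefore the outer product of the vector with entries $\delta_{t,j}\sigma'(a_{t,j})$ and $o_{t-1}$. A minimal sketch, again assuming a sigmoid activation by default and using the illustrative name `grad_W_prev`:

```python
import math

def sigmoid_prime(z):
    # Derivative of the (assumed) sigmoid activation: sigma(z) * (1 - sigma(z)).
    s = 1.0 / (1.0 + math.exp(-z))
    return s * (1.0 - s)

def grad_W_prev(delta, a, o_prev, sigma_prime=sigmoid_prime):
    """Gradient of g_t w.r.t. W_{t-1} following Eq. (20.3), as k_t rows of length k_{t-1}.

    Entry (j, i) is delta_{t,j} * sigma'(a_{t,j}) * o_{t-1,i}.
    """
    return [[d * sigma_prime(aj) * oi for oi in o_prev] for d, aj in zip(delta, a)]

def flatten_rows(M):
    # Concatenating the rows recovers J_{w_{t-1}}(g_t) in the ordering of w_{t-1}.
    return [m for row in M for m in row]
```

For example, `grad_W_prev([1.0], [0.0], [4.0, -2.0])` returns `[[1.0, -0.5]]`, since $\sigma'(0) = 1/4$ for the sigmoid. This costs $O(k_t k_{t-1})$ operations and never materializes $O_{t-1}$.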

It remains to calculate the vector $\delta_t = J_{o_t}(\ell_t)$ for every $t$; this is the gradient of $\ell_t$ at $o_t$. We calculate it recursively. First observe that for the last layer we have $\ell_T(u) = \Delta(u, y)$, where $\Delta$ is the loss function. Since we assume that $\Delta(u, y) = \frac{1}{2}\|u - y\|^2$, we obtain $J_u(\ell_T) = (u - y)$. In particular, $\delta_T = J_{o_T}(\ell_T) = (o_T - y)$. Next, note that

$$\ell_t(u) = \ell_{t+1}(\sigma(W_t u)).$$

Therefore, by the chain rule,

$$J_u(\ell_t) = J_{\sigma(W_t u)}(\ell_{t+1})\, \mathrm{diag}(\sigma'(W_t u))\, W_t.$$
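Substituting $u = o_t$, so that $W_t u = a_{t+1}$ and $\sigma(W_t u) = o_{t+1}$, this becomes $\delta_t = \delta_{t+1}\,\mathrm{diag}(\sigma'(a_{t+1}))\,W_t$. Together with $\delta_T = o_T - y$ and Eq. (20.3), that gives a complete backward pass. The sketch below is a minimal pure-Python illustration, not the book's pseudocode; it assumes sigmoid activations in every layer (including the output), the squared loss, no bias terms, and weights stored so that `Ws[t-1]` is $W_{t-1}$ (a list of $k_t$ rows of length $k_{t-1}$). All names are hypothetical.

```python
import math

def sigmoid(z):
    # Assumed activation, used in every layer including the output layer.
    return 1.0 / (1.0 + math.exp(-z))

def sigmoid_prime(z):
    s = sigmoid(z)
    return s * (1.0 - s)

def forward(Ws, x):
    """Return lists a, o with o[0] = x, a[t] = W_{t-1} o[t-1] and o[t] = sigma(a[t])."""
    a, o = [None], [list(x)]  # a[0] is unused: the input layer has no pre-activation
    for W in Ws:
        a_t = [sum(w * oi for w, oi in zip(row, o[-1])) for row in W]
        a.append(a_t)
        o.append([sigmoid(z) for z in a_t])
    return a, o

def loss(Ws, x, y):
    """Squared loss Delta(o_T, y) = 1/2 ||o_T - y||^2."""
    _, o = forward(Ws, x)
    return 0.5 * sum((oi - yi) ** 2 for oi, yi in zip(o[-1], y))

def backprop(Ws, x, y):
    """Return grads with grads[t-1][j][i] = d loss / d W_{t-1}[j][i] for t = 1..T."""
    a, o = forward(Ws, x)
    T = len(Ws)
    delta = [oi - yi for oi, yi in zip(o[T], y)]  # delta_T = o_T - y
    grads = [None] * T
    for t in range(T, 0, -1):
        # delta_t diag(sigma'(a_t)): the common factor in Eq. (20.3) and in the recursion.
        scaled = [d * sigmoid_prime(z) for d, z in zip(delta, a[t])]
        # Eq. (20.3): row j of the gradient w.r.t. W_{t-1} is scaled[j] * o_{t-1}^T.
        grads[t - 1] = [[s * oi for oi in o[t - 1]] for s in scaled]
        if t > 1:
            # Recursion with u = o_{t-1}: delta_{t-1} = delta_t diag(sigma'(a_t)) W_{t-1}.
            W = Ws[t - 1]
            delta = [sum(scaled[j] * W[j][i] for j in range(len(W)))
                     for i in range(len(o[t - 1]))]
    return grads
```

Comparing `backprop` against central finite differences of `loss` on a small two-layer network is a cheap way to catch index errors in the recursion.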