detach()

2020-11-09 23:48:34 The war of rebellion

   While learning to write GAN code in PyTorch, I noticed that implementations differ slightly in the details of the training step. Some use the detach() function to cut off the gradient flow. Others don't use detach() at all and instead call backward(retain_graph=True) on a loss during backpropagation. Using two GAN implementations as examples, this article explains what each approach does, analyzes them, and looks at how the different update strategies affect the program's efficiency.

   These two GAN implementations use two different training strategies:

  • First train the discriminator, then the generator. This is the algorithm in the original paper, Generative Adversarial Networks.
  • First train the generator, then the discriminator.

   There is already plenty of material on GANs online, so to avoid adding to the noise I won't repeat the theory here. Readers who want more background on how GANs work can refer to my dedicated article: Neural network structure: Generative adversarial network (GAN).

Background you need:

  detach(): cuts off the backward flow of gradients at a node. It returns a tensor (a Variable, in old PyTorch terminology) that shares the node's data but does not require gradients, so when backpropagation reaches this node, the gradient does not flow through it to the nodes in front of it.
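   A minimal sketch of this behavior on scalar tensors (the values are arbitrary, chosen only for illustration): the detached copy has the same value as y but requires_grad=False, so only the non-detached branch sends gradient back to x.

```python
import torch

x = torch.tensor(1.0, requires_grad=True)
y = x * 2                  # y is part of the graph: dy/dx = 2
y_detached = y.detach()    # same value as y, but cut off from the graph

# Only the y * 3 term can send gradient back to x;
# the y_detached * 5 term is treated as a constant.
z = y * 3 + y_detached * 5
z.backward()

print(y_detached.requires_grad)  # False
print(x.grad)                    # tensor(6.): 3 * dy/dx, the detached branch adds nothing
```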

Update strategy

   Let's get straight to the subject of this article: in PyTorch, what are detach and retain_graph for? This article uses three pieces of GAN implementation code to show how they work.
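   retain_graph is easiest to see on its own first. The sketch below uses an arbitrary nn.Linear layer followed by a sigmoid (toy sizes, chosen only for illustration): calling backward() twice on the same graph fails, because the first call frees the buffers the graph saved, unless the first call passes retain_graph=True. When both calls succeed, their gradients add up in .grad, which is why training code zeroes gradients before each update.

```python
import torch

torch.manual_seed(0)
layer = torch.nn.Linear(3, 1)
x = torch.randn(4, 3)

# Without retain_graph: the first backward frees the saved buffers,
# so a second backward through the same graph raises a RuntimeError.
loss = torch.sigmoid(layer(x)).sum()
loss.backward()
try:
    loss.backward()
    second_call_failed = False
except RuntimeError:
    second_call_failed = True
print(second_call_failed)  # True

# With retain_graph=True on the first call, the buffers are kept and the
# second backward works; the two gradients accumulate in .grad.
layer.zero_grad()
loss = torch.sigmoid(layer(x)).sum()
loss.backward(retain_graph=True)
grad_once = layer.weight.grad.clone()
loss.backward()
print(torch.allclose(layer.weight.grad, 2 * grad_once))  # True
```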

First train the discriminator, then the generator

Strategy one

Let's analyze the code for one step of the training loop:

valid = torch.ones(imgs.size(0), 1, device=device)   # labels for real samples, all 1
fake = torch.zeros(imgs.size(0), 1, device=device)   # labels for fake samples, all 0

# ########################
#  Train the discriminator  #
# ########################
real_imgs = imgs.to(device)                          # real images
z = torch.randn(imgs.size(0), 100, device=device)    # noise

gen_imgs = generator(z)                # generate fake data from the noise
pred_gen = discriminator(gen_imgs)     # discriminator output on the fake data (not detached)
pred_real = discriminator(real_imgs)   # discriminator output on the real data

optimizer_D.zero_grad()                          # zero the gradients of all discriminator parameters
real_loss = adversarial_loss(pred_real, valid)   # discriminator loss on real samples
fake_loss = adversarial_loss(pred_gen, fake)     # discriminator loss on fake samples
d_loss = (real_loss + fake_loss) / 2             # average of the two losses

# The next line is the focus of this article
d_loss.backward(retain_graph=True)  # retain_graph=True is essential: otherwise the graph's saved buffers are freed and g_loss.backward() below fails
optimizer_D.step()                  # update the discriminator parameters

# ########################
#    Train the generator    #
# ########################
g_loss = adversarial_loss(pred_gen, valid)  # generator loss, reusing pred_gen from the single forward pass
optimizer_G.zero_grad()                     # zero the generator gradients (including those left by d_loss.backward())
g_loss.backward()                           # backpropagate through the discriminator into the generator
                                            # (may fail in PyTorch >= 1.5: optimizer_D.step() modified D weights saved in this graph)
optimizer_G.step()                          # update the generator parameters

Code explanation

   The discriminator loss d_loss is made up of real_loss and fake_loss, and fake_loss is computed from noise that has passed through the generator. So when d_loss is backpropagated, autograd computes not only the gradients of the discriminator parameters but also those of the generator parameters (even though optimizer_D.step() in this step updates only the discriminator's parameters). That is why, before updating the generator parameters, we first zero the generator's gradients, so that the gradients flowing back from the discriminator loss do not affect the update.
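   The extra generator gradients can be checked directly. In this sketch, generator and discriminator are hypothetical two-layer toy models (not the article's networks), and only the fake-sample part of d_loss is used: without detach the generator's weights receive a .grad, with detach they do not, while the discriminator gets its gradients either way.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical toy models standing in for the article's generator/discriminator.
generator = nn.Sequential(nn.Linear(2, 4), nn.Tanh())
discriminator = nn.Sequential(nn.Linear(4, 1), nn.Sigmoid())
adversarial_loss = nn.BCELoss()
z = torch.randn(8, 2)
fake = torch.zeros(8, 1)

# Fake-sample part of d_loss WITHOUT detach (as in strategy one).
fake_loss = adversarial_loss(discriminator(generator(z)), fake)
fake_loss.backward()
g_grad_without_detach = generator[0].weight.grad is not None

# The same loss WITH detach: backward stops at the detached tensor.
generator.zero_grad(set_to_none=True)
discriminator.zero_grad(set_to_none=True)
fake_loss = adversarial_loss(discriminator(generator(z).detach()), fake)
fake_loss.backward()
g_grad_with_detach = generator[0].weight.grad is not None
d_grad_with_detach = discriminator[0].weight.grad is not None

print(g_grad_without_detach, g_grad_with_detach, d_grad_with_detach)  # True False True
```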

  The generator loss likewise has to pass back through the discriminator network before it reaches the generator. From the input noise to the discriminator's output there is only one forward pass but two backward passes, so the first backward pass (for the discriminator) must use backward(retain_graph=True) to keep the graph from being freed. By default, PyTorch allows only one backward pass through a computation graph: once backpropagation finishes, the intermediate buffers saved in the graph are released, and this parameter tells it to keep them. As a result, when the generator loss is backpropagated, the gradients of the discriminator parameters are computed as well; this time the discriminator parameters are not updated, only the generator's, via optimizer_G.step(). We also see that the next step starts by resetting the discriminator's gradients to 0, precisely so that the gradients computed as a side effect of the generator loss's backward pass (and those accumulated by the previous step's discriminator loss) do not affect the new update.

   To sum up, one step of parameter updates takes two backward passes. The first, meant to update the discriminator, redundantly computes the generator's gradients. The second, meant to update the generator, also computes the discriminator's gradients, so the discriminator's gradients must be cleared at the start of the next step.

   If this still isn't clear, you can simply follow the template above.
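   One caveat before copying this template: as far as I know, since PyTorch 1.5 optimizer.step() updates parameters in place in a way that autograd's version counter tracks, and pred_gen's graph saved the discriminator weights used in the forward pass. Reusing pred_gen after optimizer_D.step(), as strategy one does, then fails. The sketch below reproduces the ordering with hypothetical toy models and plain SGD (a discriminator built from other Linear or Conv layers should behave the same way).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical toy models and data, only to reproduce strategy one's ordering.
generator = nn.Sequential(nn.Linear(2, 4), nn.Tanh())
discriminator = nn.Sequential(nn.Linear(4, 1), nn.Sigmoid())
adversarial_loss = nn.BCELoss()
optimizer_G = torch.optim.SGD(generator.parameters(), lr=0.1)
optimizer_D = torch.optim.SGD(discriminator.parameters(), lr=0.1)

real_imgs = torch.randn(8, 4)
valid, fake = torch.ones(8, 1), torch.zeros(8, 1)
z = torch.randn(8, 2)

# A single forward pass shared by both losses.
gen_imgs = generator(z)
pred_gen = discriminator(gen_imgs)
pred_real = discriminator(real_imgs)

optimizer_D.zero_grad()
d_loss = (adversarial_loss(pred_real, valid) + adversarial_loss(pred_gen, fake)) / 2
d_loss.backward(retain_graph=True)
optimizer_D.step()  # in-place update of weights that pred_gen's graph saved

g_loss = adversarial_loss(pred_gen, valid)
optimizer_G.zero_grad()
try:
    g_loss.backward()
    error_message = None
except RuntimeError as err:
    error_message = str(err)
print(error_message)  # mentions "modified by an inplace operation"
```

   Strategies two and three below avoid this, because the generator loss never reuses a discriminator output that was computed before optimizer_D.step().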

Strategy two

   I've come across this strategy a lot. It also trains the discriminator first and then the generator.

   In the discriminator training phase, noise is fed into the generator, which outputs fake data. This output is detached and fed into the discriminator together with the real data, the discriminator loss is computed, and the discriminator's parameters are updated. In the generator training phase, the fake data that has not been detached is fed into the discriminator again, the generator loss is computed, and the gradient is backpropagated to update the generator's parameters. This strategy computes the discriminator gradients twice and the generator gradients once. I feel it matches the habit of updating the discriminator first better. The drawback is that the graph built by the generator's forward pass must be kept until the discriminator update is finished, and only then released.

valid = torch.ones(imgs.size(0), 1, device=device)   # labels for real samples, all 1
fake = torch.zeros(imgs.size(0), 1, device=device)   # labels for fake samples, all 0

# ########################
#  Train the discriminator  #
# ########################
real_imgs = imgs.to(device)                          # real images
z = torch.randn(imgs.size(0), 100, device=device)    # noise

gen_imgs = generator(z)                       # generate fake data from the noise
pred_gen = discriminator(gen_imgs.detach())   # detach the fake data, then get the discriminator output on it
pred_real = discriminator(real_imgs)          # discriminator output on the real data

optimizer_D.zero_grad()                          # zero the gradients of all discriminator parameters
real_loss = adversarial_loss(pred_real, valid)   # discriminator loss on real samples
fake_loss = adversarial_loss(pred_gen, fake)     # discriminator loss on fake samples
d_loss = (real_loss + fake_loss) / 2             # average of the two losses

# No retain_graph needed: gen_imgs was detached, so this backward pass stops
# at the discriminator input and never touches the generator's graph
d_loss.backward()
optimizer_D.step()                               # update the discriminator parameters

# ########################
#    Train the generator    #
# ########################
g_loss = adversarial_loss(discriminator(gen_imgs), valid)  # fresh forward pass on the NON-detached fake data
optimizer_G.zero_grad()                                    # zero the generator gradients
g_loss.backward()                                          # backpropagate through the discriminator into the generator
optimizer_G.step()                                         # update the generator parameters
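   A quick check of strategy two's ordering on hypothetical toy models (random data and plain SGD, chosen only for illustration): the discriminator phase leaves the generator without gradients, the generator phase runs a fresh discriminator forward pass on the non-detached batch, and both networks end up updated without any retain_graph.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical toy models and data, only to exercise strategy two's update order.
generator = nn.Sequential(nn.Linear(2, 4), nn.Tanh())
discriminator = nn.Sequential(nn.Linear(4, 1), nn.Sigmoid())
adversarial_loss = nn.BCELoss()
optimizer_G = torch.optim.SGD(generator.parameters(), lr=0.1)
optimizer_D = torch.optim.SGD(discriminator.parameters(), lr=0.1)

real_imgs = torch.randn(8, 4)
valid, fake = torch.ones(8, 1), torch.zeros(8, 1)
z = torch.randn(8, 2)
g_before = generator[0].weight.detach().clone()
d_before = discriminator[0].weight.detach().clone()

# Discriminator phase: the fake batch is detached, so no generator gradients.
gen_imgs = generator(z)
optimizer_D.zero_grad()
d_loss = (adversarial_loss(discriminator(real_imgs), valid)
          + adversarial_loss(discriminator(gen_imgs.detach()), fake)) / 2
d_loss.backward()
optimizer_D.step()
g_grad_after_d_phase = generator[0].weight.grad

# Generator phase: fresh discriminator forward pass on the non-detached batch.
optimizer_G.zero_grad()
g_loss = adversarial_loss(discriminator(gen_imgs), valid)
g_loss.backward()
optimizer_G.step()

print(g_grad_after_d_phase)                           # None
print(torch.equal(generator[0].weight, g_before))      # False: generator updated
print(torch.equal(discriminator[0].weight, d_before))  # False: discriminator updated
```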

Train the generator first, then the discriminator

Strategy three

   Let's analyze the code for one step of the training loop:

valid = torch.ones(imgs.size(0), 1, device=device)             # labels for real samples, all 1
fake = torch.zeros(imgs.size(0), 1, device=device)             # labels for generated samples, all 0
z = torch.randn(imgs.size(0), opt.latent_dim, device=device)   # noise
real_imgs = imgs.to(device)                                    # real images

# ########################
#    Train the generator    #
# ########################
optimizer_G.zero_grad()                                    # zero the generator gradients
gen_imgs = generator(z)                                    # generate fake samples from the noise
g_loss = adversarial_loss(discriminator(gen_imgs), valid)  # fake samples + real labels: generator loss
g_loss.backward()   # backpropagation passes through the discriminator, so the discriminator parameters get gradients too
optimizer_G.step()  # update the generator; the discriminator has gradients but is not updated here

# ########################
#  Train the discriminator  #
# ########################
optimizer_D.zero_grad()  # clear the discriminator gradients left over from g_loss.backward()
real_loss = adversarial_loss(discriminator(real_imgs), valid)         # real samples + real labels
fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)  # fake samples + fake labels
d_loss = (real_loss + fake_loss) / 2  # total discriminator loss
d_loss.backward()                     # backpropagate the discriminator loss
optimizer_D.step()                    # update the discriminator parameters

   To update the generator parameters, we compute gradients from the generator's loss and backpropagate them. The discriminator lies on the propagation path, so by the chain rule the gradients of the discriminator parameters are computed as well, even though the discriminator parameters are not updated in this step. After backpropagation, the forward graph from the noise through the fake images to the discriminator's output is released and no longer exists.
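   This can be checked on hypothetical toy models, assuming optimizer_G was built from generator.parameters() only (the article does not show the optimizer construction): after g_loss.backward(), the discriminator's weights have a .grad, yet optimizer_G.step() leaves them untouched.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical toy models standing in for the article's generator/discriminator.
generator = nn.Sequential(nn.Linear(2, 4), nn.Tanh())
discriminator = nn.Sequential(nn.Linear(4, 1), nn.Sigmoid())
adversarial_loss = nn.BCELoss()
# Assumption: optimizer_G holds the generator's parameters only.
optimizer_G = torch.optim.SGD(generator.parameters(), lr=0.1)

z = torch.randn(8, 2)
valid = torch.ones(8, 1)
d_weight_before = discriminator[0].weight.detach().clone()

# Generator phase of strategy three.
optimizer_G.zero_grad()
gen_imgs = generator(z)
g_loss = adversarial_loss(discriminator(gen_imgs), valid)
g_loss.backward()
optimizer_G.step()

d_has_grad = discriminator[0].weight.grad is not None                # the chain rule went through D
d_unchanged = torch.equal(discriminator[0].weight, d_weight_before)  # but D was not stepped
print(d_has_grad, d_unchanged)  # True True
```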

   Next we update the discriminator parameters. Note that the discriminator's input now comes in two parts: the real data, and the generator's output, i.e. the fake data. Pay attention to one detail: in the discriminator's forward pass, the fake input is detached. detach means this data is "decoupled" from the graph that produced it, so the gradient stops when it reaches this tensor and does not propagate any further. (In fact it could not go further anyway: the generator's graph was already released by the first backward pass, so without detach this backward pass would raise an error.) Therefore, the discriminator's backward pass only reaches the discriminator's own parameters.
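   The role of detach here can be checked on hypothetical toy models: after the generator phase's backward pass has released the generator's graph, a discriminator loss built from the non-detached gen_imgs raises a RuntimeError when backpropagated, whereas the detached version backpropagates fine and leaves the generator with no gradient.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical toy models standing in for the article's generator/discriminator.
generator = nn.Sequential(nn.Linear(2, 4), nn.Tanh())
discriminator = nn.Sequential(nn.Linear(4, 1), nn.Sigmoid())
adversarial_loss = nn.BCELoss()

z = torch.randn(8, 2)
valid, fake = torch.ones(8, 1), torch.zeros(8, 1)

# Generator phase of strategy three: this backward releases the generator's graph.
gen_imgs = generator(z)
adversarial_loss(discriminator(gen_imgs), valid).backward()

# Discriminator phase WITHOUT detach: backward tries to re-enter the released graph.
try:
    adversarial_loss(discriminator(gen_imgs), fake).backward()
    without_detach_ok = True
except RuntimeError:
    without_detach_ok = False

# Discriminator phase WITH detach: the gradient stops at the detached tensor.
generator.zero_grad(set_to_none=True)
adversarial_loss(discriminator(gen_imgs.detach()), fake).backward()

print(without_detach_ok)         # False
print(generator[0].weight.grad)  # None: the discriminator loss never reached the generator
```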

   So compared with the first strategy, this one saves one computation of the gradients of all generator parameters, and it does not have to keep the graph around, occupying unnecessary memory.

   Note, however, that in the first strategy there is only one forward pass from the noise entering the generator to the discriminator's output, and the discriminator's output is used twice: once to compute the discriminator loss and once to compute the generator loss.

   In this strategy, the noise goes from the generator's input to the discriminator's output, the generator loss is computed and backpropagated, the generator's parameters are updated, and the graph is released. Then, when the discriminator parameters are updated, the generator's output is detached and passed through the discriminator again. In other words, the generator's output passes through the discriminator twice and produces the same output both times, which is clearly redundant as well.

Summary

In summary, the first and third pieces of code each have their own advantages and disadvantages:

   The first code's advantage is that the noise goes through only one forward pass. Its drawbacks are that updating the discriminator also computes the generator's gradients unnecessarily, and that the discriminator update must retain the computation graph so that the graph needed for the generator loss is not destroyed.

   The third code's advantage is that, by updating the generator first, the forward graph can be destroyed right after use, so no graph has to be kept around taking up memory. Also, unlike the first code, updating the discriminator does not redundantly compute the generator's gradients. Its drawback is that the discriminator runs a forward pass on the generator's output twice, and the second pass builds a new computation graph (though a smaller one than the first).

So one approach computes the generator gradients one extra time, and the other runs one extra discriminator forward pass, and the difference between them is small. If the discriminator is more complex than the generator, the first strategy should be used; if the discriminator is simpler than the generator, the third strategy should be used. Usually the discriminator is simpler than the generator, so if the results are about the same, prefer the third strategy.

   Still, updating the generator first and then the discriminator always feels a bit odd, because the generator needs an up-to-date discriminator to provide an accurate loss and gradient; otherwise, isn't it updating blindly?

   Strategy three's strength, though, is that its graph is released as soon as it has been used. All things considered, strategy three is the best, strategy two is second, and strategy one is the worst (the difference lies in the extra computation of the generator's gradients, which usually costs more than one extra discriminator forward pass). So detach is necessary.


Copyright notice
This article was written by [The war of rebellion]. Please include a link to the original when reposting. Thank you.