DQN(Deep Q-Network)可以用来实现对倒立摆(CartPole)对象的控制。
DQN的原理就是建立一个神经网络来替代Q-Learning算法中Q-Table,根据对象的状态和采用的动作输出对应的Q值,Q值越高表示动作能得到的奖励越高。在DQN用于强化学习时,采取历史回放和Fixed Target策略,即系统状态和动作被记录的历史数据中,并被在学习过程中被回放进行学习,以模拟人的学习原理。另外,采用两个网络,一个网络(Target Net)相对稳定,用来评估目标的Q值,另一个网络不断学习迭代,经过一定次数迭代后再替换原来的Target Net。关于Q-Learning和DQN等详细介绍,可以参考文末参考链接。
实现示例如下
1 载入模块
载入需要的模块
# import required modules
import gym
import random
import numpy as np
import math
import torch
import torch.nn as nn
2 定义网络
class Net(nn.Module):
def __init__(self, n_states, n_actions):
super().__init__()
self.fc1 = nn.Linear(n_states, 10)
self.fc2 = nn.Linear(10, n_actions)
self.fc1.weight.data.normal_(0,0.1)
self.fc2.weight.data.normal_(0,0.1)
def forward(self, inputs):
x = self.fc1(inputs)
x = nn.functional.relu(x)
outputs = self.fc2(x)
return outputs
网络由一个前向的全连接层,一个relu层,和一个全连接输出层组成。输入为倒立摆对象的状态(位置、速度、角度、角速度),输出为对应两个动作(0和1)的Q值,Q值高的动作被选择作为采取的动作。
3 定义DQN学习策略
# define DQN
class DQN:
def __init__(self, n_states, n_actions):
# two nets
self.eval_net = Net(n_states, n_actions)
self.target_net = Net(n_states, n_actions)
self.loss
文章评论