0%

DQN倒立摆代码解析

对莫烦老师的DQN倒立摆代码做了逐行解析

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import gym

# Hyper Parameters 定义超参数
BATCH_SIZE = 32 # 样本数量
LR = 0.01 # learning rate 学习率
EPSILON = 0.9 # greedy policy
GAMMA = 0.9 # reward discount 折扣回报
TARGET_REPLACE_ITER = 100 # target update frequency 目标网络更新频率
MEMORY_CAPACITY = 2000 # 记忆库容量
env = gym.make('CartPole-v1') # 选择游戏 CartPole
env = env.unwrapped # 打开环境封装
N_ACTIONS = env.action_space.n # 杆子动作
N_STATES = env.observation_space.shape[0] # 杆子状态
ENV_A_SHAPE = 0 if isinstance(env.action_space.sample(), int) else env.action_space.sample().shape # to confirm the shape

"""
Action Space
| Num | Action |
|-----|------------------------|
| 0 | Push cart to the left |
| 1 | Push cart to the right |
"""

"""
Observation Space
| Num | Observation | Min | Max |
|-----|-----------------------|----------------------|--------------------|
| 0 | Cart Position | -4.8 | 4.8 |
| 1 | Cart Velocity | -Inf | Inf |
| 2 | Pole Angle | ~ -0.418 rad (-24°) | ~ 0.418 rad (24°) |
| 3 | Pole Angular Velocity | -Inf | Inf |
"""

# 定义Net类
class Net(nn.Module):
def __init__(self, ): # 定义构造函数 nn.Module的子类函数必须在构造函数中执行父类的构造函数
super(Net, self).__init__() # 等价于 nn.Module.__init__()
self.fc1 = nn.Linear(N_STATES, 50) # 设置第一个全连接层(输入层到隐藏层): 状态数个神经元到50个神经元
self.fc1.weight.data.normal_(0, 0.1) # initialization 权值初始化(均值为0,方差为0.1的正态分布) 使神经网络更加收敛
self.out = nn.Linear(50, N_ACTIONS) # 设置第二个全连接层(隐藏层到输出层): 50个神经元到动作数个神经元
self.out.weight.data.normal_(0, 0.1) # initialization 权值初始化(均值为0,方差为0.1的正态分布)

def forward(self, x): # 前向传播 x为状态
x = self.fc1(x) # 连接输入层到隐藏层
x = F.relu(x) # 使用激励函数ReLU来处理经过隐藏层后的值
actions_value = self.out(x) # 连接隐藏层到输出层,获得最终的输出值 (即动作值)
return actions_value # 返回动作值

# 定义DQN网络(两个网络)
class DQN(object):
def __init__(self):
self.eval_net, self.target_net = Net(), Net() # 利用Net类创建两个神经网络: 评估网络和目标网络

self.learn_step_counter = 0 # for target updating
self.memory_counter = 0 # for storing memory
self.memory = np.zeros((MEMORY_CAPACITY, N_STATES * 2 + 2)) # initialize memory 初始化记忆库,一行一个transition,为s(位置、速度、角度、角速度)、s'(位置、速度、角度、角速度)、动作、奖励
self.optimizer = torch.optim.Adam(self.eval_net.parameters(), lr=LR) # 使用Adam优化器 (输入为评估网络的参数和学习率)
self.loss_func = nn.MSELoss() # 使用均方损失函数 (loss(xi, yi)=(xi-yi)^2)

def choose_action(self, x): # 定义动作选择函数 x为状态
x = torch.unsqueeze(torch.FloatTensor(x), 0) # 将x转换成32-bit floating point形式,并在dim=0增加维数为1的维度
# input only one sample
if np.random.uniform() < EPSILON: # greedy 生成一个在[0, 1)内的随机数,如果小于EPSILON,选择最优动作
actions_value = self.eval_net.forward(x) # 通过对评估网络输入状态x,前向传播获得动作值
action = torch.max(actions_value, 1)[1].data.numpy() # 输出每一行最大值的索引,并转化为numpy ndarray形式
action = action[0] if ENV_A_SHAPE == 0 else action.reshape(ENV_A_SHAPE) # return the argmax index 输出action的第一个数
else: # random 随机选择动作
action = np.random.randint(0, N_ACTIONS) # 这里action随机等于0或1 (N_ACTIONS = 2)
action = action if ENV_A_SHAPE == 0 else action.reshape(ENV_A_SHAPE)
return action

def store_transition(self, s, a, r, s_): # 定义记忆存储函数 (这里输入为一个transition)
transition = np.hstack((s, [a, r], s_)) # 在水平方向上拼接数组
# replace the old memory with new memory
index = self.memory_counter % MEMORY_CAPACITY # 获取transition要置入的行数
self.memory[index, :] = transition # 置入transition
self.memory_counter += 1 # memory_counter自加1

def learn(self): # 定义学习函数(记忆库已满后便开始学习)
# target parameter update
if self.learn_step_counter % TARGET_REPLACE_ITER == 0: # 一开始触发,然后每100步触发
self.target_net.load_state_dict(self.eval_net.state_dict()) # 将评估网络的参数赋给目标网络
self.learn_step_counter += 1 # 学习步数自加1

# sample batch transitions 抽取记忆库中的数据
sample_index = np.random.choice(MEMORY_CAPACITY, BATCH_SIZE) # 在[0, 2000)内随机抽取32个数,可能会重复
b_memory = self.memory[sample_index, :] # 抽取32个索引对应的32个transition,存入b_memory
b_s = torch.FloatTensor(b_memory[:, :N_STATES]) # 将32个s抽出,转为32-bit floating point形式,并存储到b_s中,b_s为32行4列
b_a = torch.LongTensor(b_memory[:, N_STATES:N_STATES+1].astype(int)) # 将32个a抽出,转为64-bit integer (signed)形式,并存储到b_a中 (之所以为LongTensor类型,是为了方便后面torch.gather的使用),b_a为32行1列
b_r = torch.FloatTensor(b_memory[:, N_STATES+1:N_STATES+2]) # 将32个r抽出,转为32-bit floating point形式,并存储到b_s中,b_r为32行1列
b_s_ = torch.FloatTensor(b_memory[:, -N_STATES:]) # 将32个s_抽出,转为32-bit floating point形式,并存储到b_s中,b_s_为32行4列

# q_eval w.r.t the action in experience 获取32个transition的评估值和目标值,并利用损失函数和优化器进行评估网络参数更新
q_eval = self.eval_net(b_s).gather(1, b_a) # shape (batch, 1) eval_net(b_s)通过评估网络输出32行每个b_s对应的一系列动作值,然后.gather(1, b_a)代表对每行对应索引b_a的Q值提取进行聚合
q_next = self.target_net(b_s_).detach() # detach from graph, don't backpropagate q_next不进行反向传递误差,所以detach;q_next表示通过目标网络输出32行每个b_s_对应的一系列动作值
q_target = b_r + GAMMA * q_next.max(1)[0].view(BATCH_SIZE, 1) # shape (batch, 1) q_next.max(1)[0]表示只返回每一行的最大值,不返回索引(长度为32的一维张量);.view()表示把前面所得到的一维张量变成(BATCH_SIZE, 1)的形状;最终通过公式得到目标值
loss = self.loss_func(q_eval, q_target) # 输入32个评估值和32个目标值,使用均方损失函数

self.optimizer.zero_grad() # 清空上一步的残余更新参数值
loss.backward() # 误差反向传播, 计算参数更新值
self.optimizer.step() # 更新评估网络的所有参数

dqn = DQN() # 令dqn=DQN类

print('\nCollecting experience...')
for i_episode in range(400): # 400个episode循环
s = env.reset() # 重置环境
ep_r = 0 # 初始化该循环对应的episode的总奖励
while True: # 开始一个episode (每一个循环代表一步)
env.render() # 显示实验动画
a = dqn.choose_action(s) # 输入该步对应的状态s,选择动作

# take action
s_, r, done, info = env.step(a) # 执行动作,获得反馈

# modify the reward 修改奖励 (不修改也可以,修改奖励只是为了更快地得到训练好的摆杆)
x, x_dot, theta, theta_dot = s_
r1 = (env.x_threshold - abs(x)) / env.x_threshold - 0.8
r2 = (env.theta_threshold_radians - abs(theta)) / env.theta_threshold_radians - 0.5
r = r1 + r2

dqn.store_transition(s, a, r, s_) # 存储样本

ep_r += r # 逐步加上一个episode内每个step的reward
if dqn.memory_counter > MEMORY_CAPACITY:
dqn.learn()
if done:
print('Ep: ', i_episode,
'| Ep_r: ', round(ep_r, 2))

if done:
break
s = s_ # 更新状态