1. 为什么需要函数近似?
在前面的章节中,我们都假设使用表格 (Table) 来存储 Value Function:
- $v(s)$ 是一个长度为 $|∘|$ 的向量。
- $q(s, a)$ 是一个大小为 $|∘| \times |∘|$ 的矩阵。
致命问题:
- 状态空间爆炸:围棋的状态数是 $10^{170}$,任何计算机都存不下这个表。
- 泛化能力差:如果是连续状态(比如机器人的位置坐标),你没见过的坐标就没有 Value,表格法无法处理“相似的状态”。
解决方案:用一个函数 $f(s, w)$ 来近似 $v(s)$。
$$ \hat{v}(s, w) \approx v_\pi(s) $$
其中 $w$ 是参数(比如神经网络的权重)。这样我们只需要存储 $w$,通常 $w$ 的维度远小于状态数。
2. 目标函数与优化
我们希望 $\hat{v}(s, w)$ 越准越好。定义目标函数(损失函数):
$$ J(w) = \mathbb{E}{\pi} \left[ (v\pi(S) - \hat{v}(S, w))^2 \right] $$
使用 SGD 进行更新:
$$ w \leftarrow w + \alpha [v_\pi(s) - \hat{v}(s, w)] \nabla_w \hat{v}(s, w) $$
但问题是:我们不知道真实的 $v_\pi(s)$。
办法:用 TD Target 代替 $v_\pi(s)$。
$$ w \leftarrow w + \alpha [ \underbrace{(R + \gamma \hat{v}(s’, w))}_{\text{TD Target}} - \hat{v}(s, w) ] \nabla_w \hat{v}(s, w) $$
这叫 Semi-gradient,因为我们在求导时,忽略了 TD Target 对 $w$ 的依赖(把 Target 当常数看)。
3. Deep Q-Network (DQN)
当 $\hat{v}$ 是一个深度神经网络时,这就是 Deep RL。
DQN 是将 Q-learning 与神经网络结合的里程碑算法(DeepMind, Nature 2015)。
3.1 核心创新
为了解决 Deep RL 训练不稳定的问题,DQN 引入了两大通过:
- Experience Replay (经验回放):
- 把 $(s, a, r, s’)$ 存进一个 Buffer。
- 训练时随机采样 Batch。这打破了数据的时间相关性(Correlation),让数据分布更像 i.i.d.,稳定神经网络训练。
- Target Network (目标网络):
- 在计算 TD Target 时,使用一个参数固定的旧网络 $Q(s’, a; w^-)$。
- $y = r + \gamma \max_{a’} Q(s’, a’; w^-)$。
- 这避免了“自己追自己”(Chasing a moving target)的震荡问题。
3.2 Python 代码实战:DQN 核心逻辑
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62
| import torch import torch.nn as nn import torch.optim as optim import random from collections import deque
class QNetwork(nn.Module): def __init__(self, state_dim, action_dim): super(QNetwork, self).__init__() self.fc = nn.Sequential( nn.Linear(state_dim, 64), nn.ReLU(), nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, action_dim) ) def forward(self, x): return self.fc(x)
class DQNAgent: def __init__(self, state_dim, action_dim): self.q_net = QNetwork(state_dim, action_dim) self.target_net = QNetwork(state_dim, action_dim) self.target_net.load_state_dict(self.q_net.state_dict()) self.optimizer = optim.Adam(self.q_net.parameters(), lr=0.001) self.memory = deque(maxlen=10000) self.batch_size = 32 self.gamma = 0.99 def update(self): if len(self.memory) < self.batch_size: return batch = random.sample(self.memory, self.batch_size) states, actions, rewards, next_states, dones = zip(*batch) states = torch.tensor(states, dtype=torch.float32) actions = torch.tensor(actions, dtype=torch.long) rewards = torch.tensor(rewards, dtype=torch.float32) next_states = torch.tensor(next_states, dtype=torch.float32) dones = torch.tensor(dones, dtype=torch.float32) q_values = self.q_net(states) q_current = q_values.gather(1, actions.unsqueeze(1)).squeeze(1) with torch.no_grad(): next_q_values = self.target_net(next_states) max_next_q = next_q_values.max(1)[0] q_target = rewards + (1 - dones) * self.gamma * max_next_q loss = nn.MSELoss()(q_current, q_target) self.optimizer.zero_grad() loss.backward() self.optimizer.step() def sync_target(self): self.target_net.load_state_dict(self.q_net.state_dict())
|
4. 总结
本章我们迈出了从“玩具问题”到“实际应用”的关键一步。
- Function Approximation 让 RL 能处理无限状态。
- DQN 通过 Replay Buffer 和 Target Network 解决了非线性近似的不稳定性。
至此,Value-based 方法(学习 $Q$ 值)已经讲完了。但还有一类问题很难处理:连续动作空间(比如机器人关节角度)。这就需要 Policy-based 方法,直接学习策略函数 $\pi(a|s)$。
下一章:策略梯度方法 (Policy Gradient Methods)。
上一章:第7章 - 时序差分方法 | 下一章:第9章 - 策略梯度方法 >>