1. 为什么需要函数近似?

在前面的章节中,我们都假设使用表格 (Table) 来存储 Value Function:

  • $v(s)$ 是一个长度为 $|∘|$ 的向量。
  • $q(s, a)$ 是一个大小为 $|∘| \times |∘|$ 的矩阵。

致命问题

  1. 状态空间爆炸:围棋的状态数是 $10^{170}$,任何计算机都存不下这个表。
  2. 泛化能力差:如果是连续状态(比如机器人的位置坐标),你没见过的坐标就没有 Value,表格法无法处理“相似的状态”。

解决方案:用一个函数 $f(s, w)$ 来近似 $v(s)$。
$$ \hat{v}(s, w) \approx v_\pi(s) $$
其中 $w$ 是参数(比如神经网络的权重)。这样我们只需要存储 $w$,通常 $w$ 的维度远小于状态数。


2. 目标函数与优化

我们希望 $\hat{v}(s, w)$ 越准越好。定义目标函数(损失函数):
$$ J(w) = \mathbb{E}{\pi} \left[ (v\pi(S) - \hat{v}(S, w))^2 \right] $$

使用 SGD 进行更新:
$$ w \leftarrow w + \alpha [v_\pi(s) - \hat{v}(s, w)] \nabla_w \hat{v}(s, w) $$

但问题是:我们不知道真实的 $v_\pi(s)$。
办法:用 TD Target 代替 $v_\pi(s)$。
$$ w \leftarrow w + \alpha [ \underbrace{(R + \gamma \hat{v}(s’, w))}_{\text{TD Target}} - \hat{v}(s, w) ] \nabla_w \hat{v}(s, w) $$

这叫 Semi-gradient,因为我们在求导时,忽略了 TD Target 对 $w$ 的依赖(把 Target 当常数看)。


3. Deep Q-Network (DQN)

当 $\hat{v}$ 是一个深度神经网络时,这就是 Deep RL。
DQN 是将 Q-learning 与神经网络结合的里程碑算法(DeepMind, Nature 2015)。

3.1 核心创新

为了解决 Deep RL 训练不稳定的问题,DQN 引入了两大通过:

  1. Experience Replay (经验回放)
    • 把 $(s, a, r, s’)$ 存进一个 Buffer。
    • 训练时随机采样 Batch。这打破了数据的时间相关性(Correlation),让数据分布更像 i.i.d.,稳定神经网络训练。
  2. Target Network (目标网络)
    • 在计算 TD Target 时,使用一个参数固定的旧网络 $Q(s’, a; w^-)$。
    • $y = r + \gamma \max_{a’} Q(s’, a’; w^-)$。
    • 这避免了“自己追自己”(Chasing a moving target)的震荡问题。

3.2 Python 代码实战:DQN 核心逻辑

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import torch
import torch.nn as nn
import torch.optim as optim
import random
from collections import deque

class QNetwork(nn.Module):
def __init__(self, state_dim, action_dim):
super(QNetwork, self).__init__()
self.fc = nn.Sequential(
nn.Linear(state_dim, 64),
nn.ReLU(),
nn.Linear(64, 64),
nn.ReLU(),
nn.Linear(64, action_dim)
)

def forward(self, x):
return self.fc(x)

class DQNAgent:
def __init__(self, state_dim, action_dim):
self.q_net = QNetwork(state_dim, action_dim)
self.target_net = QNetwork(state_dim, action_dim)
self.target_net.load_state_dict(self.q_net.state_dict()) # 初始化相同

self.optimizer = optim.Adam(self.q_net.parameters(), lr=0.001)
self.memory = deque(maxlen=10000) # Replay Buffer
self.batch_size = 32
self.gamma = 0.99

def update(self):
if len(self.memory) < self.batch_size: return

# 1. Experience Replay
batch = random.sample(self.memory, self.batch_size)
states, actions, rewards, next_states, dones = zip(*batch)

states = torch.tensor(states, dtype=torch.float32)
actions = torch.tensor(actions, dtype=torch.long)
rewards = torch.tensor(rewards, dtype=torch.float32)
next_states = torch.tensor(next_states, dtype=torch.float32)
dones = torch.tensor(dones, dtype=torch.float32)

# 2. 计算 Q_current
q_values = self.q_net(states)
q_current = q_values.gather(1, actions.unsqueeze(1)).squeeze(1)

# 3. 计算 Q_target (使用 Target Network)
with torch.no_grad():
next_q_values = self.target_net(next_states)
max_next_q = next_q_values.max(1)[0]
q_target = rewards + (1 - dones) * self.gamma * max_next_q

# 4. 梯度下降
loss = nn.MSELoss()(q_current, q_target)
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()

def sync_target(self):
self.target_net.load_state_dict(self.q_net.state_dict())

4. 总结

本章我们迈出了从“玩具问题”到“实际应用”的关键一步。

  • Function Approximation 让 RL 能处理无限状态。
  • DQN 通过 Replay Buffer 和 Target Network 解决了非线性近似的不稳定性。

至此,Value-based 方法(学习 $Q$ 值)已经讲完了。但还有一类问题很难处理:连续动作空间(比如机器人关节角度)。这就需要 Policy-based 方法,直接学习策略函数 $\pi(a|s)$。

下一章:策略梯度方法 (Policy Gradient Methods)


上一章:第7章 - 时序差分方法 | 下一章:第9章 - 策略梯度方法 >>