转自 zhihu

1. 背景

去年我理解了torch.gather()用法,今年看到又给忘了,索性把自己的理解梳理出来,方便今后遗忘后快速上手。

  • 官方文档:

TORCH.GATHER

官方文档对torch.gather()的定义非常简洁

定义:从原tensor中获取指定dim和指定index的数据

看到这个核心定义,我们很容易想到gather()的基本想法其实就类似从完整数据中按索引取值般简单,比如下面从列表中按索引取值

lst = [1, 2, 3, 4, 5]
value = lst[2]  # value = 3
value = lst[2:4]  # value = [3, 4]

上面的取值例子是取单个值或具有逻辑顺序序列的例子,而对于深度学习常用的批量tensor数据来说,我们的需求可能是选取其中多个且乱序的值,此时gather()就是一个很好的tool,它可以帮助我们从批量tensor中取出指定乱序索引下的数据,因此其用途如下

用途:方便从批量tensor中获取指定索引下的数据,该索引是高度自定义化的,可乱序的

2. 实战

我们找个3×3的二维矩阵做个实验

import torch

tensor_0 = torch.arange(3, 12).view(3, 3)
print(tensor_0)

输出结果

tensor([[ 3,  4,  5],
        [ 6,  7,  8],
        [ 9, 10, 11]])

2.1 输入行向量index,并替换行索引(dim=0)

index = torch.tensor([[2, 1, 0]])
tensor_1 = tensor_0.gather(0, index)
print(tensor_1)

输出结果

tensor([[9, 7, 5]])

过程如图所示

2.2 输入行向量index,并替换列索引(dim=1)

index = torch.tensor([[2, 1, 0]])
tensor_1 = tensor_0.gather(1, index)
print(tensor_1)

输出结果

tensor([[5, 4, 3]])

过程如图所示

2.3 输入列向量index,并替换列索引(dim=1)

index = torch.tensor([[2, 1, 0]]).t()
tensor_1 = tensor_0.gather(1, index)
print(tensor_1)

输出结果

tensor([[5],
        [7],
        [9]])

过程如图所示

2.4 输入二维矩阵index,并替换列索引(dim=1)

index = torch.tensor([[0, 2], [1, 2]])
tensor_1 = tensor_0.gather(1, index)
print(tensor_1)

输出结果

tensor([[3, 5],
        [7, 8]])

过程同上

3. 在强化学习DQN中的使用

在PyTorch官网DQN页面的代码中,$Q(S_t, a)$ 是这样获取的

# Compute Q(s_t, a) - the model computes Q(s_t), then we select the
# columns of actions taken. These are the actions which would've been taken
# for each batch state according to policy_net
state_action_values = policy_net(state_batch).gather(1, action_batch)

其中 $Q(S_t)$,即policy_net(state_batch)为shape=(128, 2)的二维表,动作数为2

而我们通过神经网络输出的对应批量动作为

此时,使用gather()函数即可轻松获取批量状态对应批量动作的$Q(S_t, a)$

打赏作者

发表回复

您的电子邮箱地址不会被公开。 必填项已用 * 标注

CAPTCHA