PyTorch FractionalMaxPool3d解析

函数入口

主要负责核心处理逻辑前的数据检查以及运行时类型。运行时类型靠AT_DISPATCH_FLOATING_TYPES_AND2宏实现,可展开查看。lambda表达式负责实际的逻辑执行。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
TORCH_IMPL_FUNC(fractional_max_pool3d_out_cpu)(
// 核心输入张量。这是需要进行分数最大池化操作的特征图。
// 它的形状通常是 4D [C, T, H, W](通道, 时间, 高度, 宽度)或 5D [N, C, T, H, W](批次, 通道, 时间, 高度, 宽度)。
// 函数内部会首先将其转换为内存连续的 (contiguous) 版本以提高计算效率。 对应算子参数input.
const at::Tensor& input_,
// 池化窗口在**时间(T)**维度上的大小,即 kernel_size[0]。对应算子参数kernel_size.
int64_t poolSizeT,
// 池化窗口在**高度(H)**维度上的大小,即 kernel_size[1]。对应算子参数kernel_size.
int64_t poolSizeH,
// 池化窗口在**宽度(W)**维度上的大小,即 kernel_size[2]。对应算子参数kernel_size.
int64_t poolSizeW,
// 期望输出张量在**时间(T)**维度上的尺寸。可能由output_size直接给出,也可能由output_ratio_t计算。
int64_t outputT,
// 期望输出张量在**高度(H)**维度上的尺寸。可能由output_size直接给出,也可能由output_ratio_h计算。
int64_t outputH,
// 期望输出张量在**宽度(W)**维度上的尺寸。可能由output_size直接给出,也可能由output_ratio_w计算。
int64_t outputW,
// 分数池化的关键输入。这个张量包含了用于生成池化窗口起始位置的随机样本值。正是这个输入使得池化区域变得不规则和“分数化”。
// 它的形状应为 [N, C, 3] (5D输入) 或 [1, C, 3] (4D输入),最后的维度3分别对应T, H, W三个方向的随机数。 对应算子参数random_sample
const at::Tensor& randomSamples_,
// 输入张量的批次大小(N)。这个值是在调用此函数之前的 meta 函数中从 input_ 的形状计算得出的,直接传入可以避免重复计算。对于4D输入,该值为1。
int64_t numBatch,
// 输入张量的通道数(C),也称为 "planes"。同样是为了效率而预先计算并传入的。
int64_t numPlanes,
// 输入张量的**时间维度(T)**的大小。
int64_t inputT,
// 输入张量的**高度维度(H)**的大小。
int64_t inputH,
// 输入张量的**宽度维度(W)**的大小。
int64_t inputW,
// 预先分配好空间的输出张量。这个函数的所有计算结果(池化后的值)都将写入这个张量。它的形状为 [N, C, outputT, outputH, outputW]。对应算子参数的ouput
const at::Tensor& output,
// 预先分配好空间的索引张量。用于存储每个池化窗口中找到的最大值在原始 input_ 张量中的位置索引。
// 这些索引对于后续的反向传播(求导)至关重要。它的形状与 output 张量完全相同。对应算子参数的indices。
const at::Tensor& indices) {

// 检查参数形状
fractional_max_pool_check_shape</*ndim*/ 3>(input_, randomSamples_);

if (output.numel() == 0) {
return;
}

/* get contiguous input and samples */
// 准备数据,确保数据连续。
auto input = input_.contiguous();
auto randomSamples = randomSamples_.contiguous();

// 自动根据运行时输入张量 input 的实际数据类型,来实例化并执行对应类型的代码。
AT_DISPATCH_FLOATING_TYPES_AND2(
kBFloat16,
kHalf,
input.scalar_type(),
"fractional_max_pool3d_out_frame",
[&] {
fractional_max_pool3d_out_frame<scalar_t>(
input.const_data_ptr<scalar_t>(),
output.data_ptr<scalar_t>(),
indices.data_ptr<int64_t>(),
randomSamples.const_data_ptr<scalar_t>(),
numBatch, numPlanes,
inputT, inputH, inputW,
outputT, outputH, outputW,
poolSizeT, poolSizeH, poolSizeW
);
}
);
}

并行处理函数

由于不同的通道互不干扰,可以完全并行,由此函数实现并行处理。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
template<typename scalar_t>
static void fractional_max_pool3d_out_frame(
const scalar_t* input,
scalar_t* output,
int64_t* indices,
const scalar_t* randomSamples,
int64_t numBatch, int64_t numPlanes,
int64_t inputT, int64_t inputH, int64_t inputW,
int64_t outputT, int64_t outputH, int64_t outputW,
int64_t poolSizeT, int64_t poolSizeH, int64_t poolSizeW) {
// 根据情况分批次并行处理
if(numBatch == 1) {
fractional_max_pool3d_out_single_batch_frame<scalar_t>(
input, output, indices, randomSamples,
numPlanes,
inputT, inputH, inputW,
outputT, outputH, outputW,
poolSizeT, poolSizeH, poolSizeW
);
return;
}

at::parallel_for(0, numBatch, 0, [&](int64_t start, int64_t end) {
for (const auto batch : c10::irange(start, end)) {
fractional_max_pool3d_out_single_batch_frame<scalar_t>(
input + batch * numPlanes * inputW * inputH * inputT,
output + batch * numPlanes * outputW * outputH * outputT,
indices + batch * numPlanes * outputW * outputH * outputT,
randomSamples + batch * numPlanes * 3,
numPlanes,
inputT, inputH, inputW,
outputT, outputH, outputW,
poolSizeT, poolSizeH, poolSizeW
);
}
});
}

}

核心逻辑函数

确定池化窗口并获取最大值。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
template<typename scalar_t>
static void fractional_max_pool3d_out_single_batch_frame(
const scalar_t* input,
scalar_t* output,
int64_t* indices,
const scalar_t* randomSamples,
int64_t numPlanes,
int64_t inputT, int64_t inputH, int64_t inputW,
int64_t outputT, int64_t outputH, int64_t outputW,
int64_t poolSizeT, int64_t poolSizeH, int64_t poolSizeW) {

// 并行处理多通道数据
at::parallel_for(0, numPlanes, 0, [&](int64_t start, int64_t end) {
// 单通道处理
// parallel_for自动把0 - numPlanes分到多个线程,每个线程又负责处理多个plane
for (const auto plane : c10::irange(start, end)) {
/* each plane contains 3 random samples,
one for T, one for W, and one for H */
const scalar_t* randomSamplesForPlane = randomSamples + plane * 3;

// 核心逻辑 对于单个维度(比如时间 T,或高度 H,或宽度 W),根据一个输入的随机样本 sample,生成一个伪随机的、单调递增的池化窗口起始点序列。
/* Generate interval sequence */
auto sequenceT = generate_intervals<scalar_t>(
randomSamplesForPlane[0], inputT, outputT, poolSizeT);
auto sequenceH = generate_intervals<scalar_t>(
randomSamplesForPlane[1], inputH, outputH, poolSizeH);
auto sequenceW = generate_intervals<scalar_t>(
randomSamplesForPlane[2], inputW, outputW, poolSizeW);

/* loop over output */
// 压缩的一维数据
const scalar_t* inputForPlane = input + plane * inputT * inputH * inputW;
scalar_t* outputForPlane = output + plane * outputT * outputH * outputW;
int64_t* indicesForPlane = indices + plane * outputT * outputH * outputW;

for (int64_t t = 0; t < outputT; ++t) {
int64_t inputTStart = sequenceT[t];

for (int64_t h = 0; h < outputH; ++h) {
int64_t inputHStart = sequenceH[h];

for (int64_t w = 0; w < outputW; ++w) {
int64_t inputWStart = sequenceW[w];

int64_t t2 = inputTStart, h2 = inputHStart, w2 = inputWStart;
scalar_t maxVal = -std::numeric_limits<scalar_t>::infinity();
int64_t maxIndex = t2 * inputH * inputW + h2 * inputW + w2;

// 寻找最大值
for (t2 = inputTStart; t2 < inputTStart + poolSizeT; ++t2) {
for (h2 = inputHStart; h2 < inputHStart + poolSizeH; ++h2) {
for (w2 = inputWStart; w2 < inputWStart + poolSizeW; ++w2) {
AT_ASSERT(t2 >= 0 && t2 < inputT);
AT_ASSERT(h2 >= 0 && h2 < inputH);
AT_ASSERT(w2 >= 0 && w2 < inputW);

int64_t planeIndex = t2 * inputH * inputW + h2 * inputW + w2;
scalar_t val = inputForPlane[planeIndex];
if (val > maxVal || std::isnan(val)) {
maxVal = val;
maxIndex = planeIndex;
}
}
}
}
// 保存结果
outputForPlane[t * outputH * outputW + h * outputW + w] = maxVal;
indicesForPlane[t * outputH * outputW + h * outputW + w] = maxIndex;
}
}
}
}
});
}

精髓所在

整个分数池化算法的绝对核心——generate_intervals 函数。生成一个伪随机的、单调递增的池化窗口起始点序列

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
// 核心逻辑(精髓)
template<typename scalar_t>
inline std::vector<int64_t> generate_intervals(
scalar_t sample, // 一个浮点数,通常在 [0, 1) 区间内。这是随机性的来源,对应 randomSamples 张量中的一个值。
int64_t inputSize, // 输入在这个维度上的尺寸(例如 inputT)。
int64_t outputSize, // 期望输出在这个维度上的尺寸(例如 outputT)。
int64_t poolSize // 池化核在这个维度上的尺寸(例如 poolSizeT)。
) {
std::vector<int64_t> sequence(outputSize);
if (outputSize > 1) {
scalar_t alpha = static_cast<scalar_t>(inputSize - poolSize /* 池化窗口起始点可以移动的有效范围 */) /
static_cast<scalar_t>(outputSize - 1/* 间隔数 */);

// 生成中间的起点 (带随机抖动)
for (const auto i : c10::irange(outputSize - 1)) {
sequence[i] =
static_cast<int>((i + sample) * alpha) - static_cast<int>(sample * alpha);
}
}
// 强制最后一个池化窗口的起始点位于 inputSize - poolSize
if (outputSize > 0) {
sequence[outputSize - 1] = inputSize - poolSize;
}
return sequence;
}

该封面图片由Da DongPixabay上发布