1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39
| template<typename scalar_t> static void fractional_max_pool3d_out_frame( const scalar_t* input, scalar_t* output, int64_t* indices, const scalar_t* randomSamples, int64_t numBatch, int64_t numPlanes, int64_t inputT, int64_t inputH, int64_t inputW, int64_t outputT, int64_t outputH, int64_t outputW, int64_t poolSizeT, int64_t poolSizeH, int64_t poolSizeW) { if(numBatch == 1) { fractional_max_pool3d_out_single_batch_frame<scalar_t>( input, output, indices, randomSamples, numPlanes, inputT, inputH, inputW, outputT, outputH, outputW, poolSizeT, poolSizeH, poolSizeW ); return; }
at::parallel_for(0, numBatch, 0, [&](int64_t start, int64_t end) { for (const auto batch : c10::irange(start, end)) { fractional_max_pool3d_out_single_batch_frame<scalar_t>( input + batch * numPlanes * inputW * inputH * inputT, output + batch * numPlanes * outputW * outputH * outputT, indices + batch * numPlanes * outputW * outputH * outputT, randomSamples + batch * numPlanes * 3, numPlanes, inputT, inputH, inputW, outputT, outputH, outputW, poolSizeT, poolSizeH, poolSizeW ); } }); }
}
|