-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathgrid_sampler_python.cpp
More file actions
245 lines (214 loc) · 7.76 KB
/
grid_sampler_python.cpp
File metadata and controls
245 lines (214 loc) · 7.76 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
#include <pybind11/stl.h>
#include <cuda_runtime.h>
#include <cstddef>
#include <iostream>
#include <limits>
#include <stdexcept>
#include <string>
#include <vector>
// Host-side entry point for the grid sampler CUDA implementation. Only the
// declaration is visible in this file; the definition is provided elsewhere.
//
// How grid_sampler() in this file calls it:
//   input, grid, output - buffers obtained from cudaMalloc; input and grid are
//                         filled from the host arrays before the call and
//                         output is copied back to the host afterwards.
//                         Presumably laid out as [N, C, H, W],
//                         [N, H_out, W_out, 2] and [N, C, H_out, W_out]
//                         respectively - TODO confirm against the kernel.
//   input_batch .. input_width  - the four extents of the input array.
//   output_height, output_width - extents 1 and 2 of the grid array.
//   mode          - integer value of the Mode enum.
//   padding_mode  - integer value of the PaddingMode enum.
//   align_corners - the Python-level bool converted to int.
//   stream        - defaults to 0; grid_sampler() does not pass it.
extern "C" {
void grid_sampler_cuda(
const float* input,
const float* grid,
float* output,
int input_batch, int input_channels, int input_height, int input_width,
int output_height, int output_width,
int mode, int padding_mode, int align_corners,
cudaStream_t stream = 0
);
}
namespace py = pybind11;
// Interpolation mode. The numeric values are forwarded unchanged (as int) to
// grid_sampler_cuda, so they must not be renumbered.
enum class Mode : int {
    BILINEAR = 0,  // selected by the string "bilinear"
    NEAREST = 1    // selected by the string "nearest"
};
// Padding mode. The numeric values are forwarded unchanged (as int) to
// grid_sampler_cuda, so they must not be renumbered.
enum class PaddingMode : int {
    ZEROS = 0,      // selected by the string "zeros"
    BORDER = 1,     // selected by the string "border"
    REFLECTION = 2  // selected by the string "reflection"
};
// String-to-enum conversion for the interpolation mode.
//
// Returns Mode::BILINEAR for "bilinear" and Mode::NEAREST for "nearest".
// The comparison is exact (case-sensitive). Any other value throws
// std::invalid_argument with the message "Invalid mode: <value>".
Mode parse_mode(const std::string& mode_str) {
    struct NamedMode {
        const char* name;
        Mode value;
    };
    static const NamedMode kKnownModes[] = {
        {"bilinear", Mode::BILINEAR},
        {"nearest", Mode::NEAREST},
    };

    for (const NamedMode& candidate : kKnownModes) {
        if (mode_str == candidate.name) {
            return candidate.value;
        }
    }
    throw std::invalid_argument("Invalid mode: " + mode_str);
}
// String-to-enum conversion for the padding mode.
//
// Returns PaddingMode::ZEROS for "zeros", PaddingMode::BORDER for "border" and
// PaddingMode::REFLECTION for "reflection". The comparison is exact
// (case-sensitive). Any other value throws std::invalid_argument with the
// message "Invalid padding mode: <value>".
PaddingMode parse_padding_mode(const std::string& padding_str) {
    struct NamedPadding {
        const char* name;
        PaddingMode value;
    };
    static const NamedPadding kKnownPaddings[] = {
        {"zeros", PaddingMode::ZEROS},
        {"border", PaddingMode::BORDER},
        {"reflection", PaddingMode::REFLECTION},
    };

    for (const NamedPadding& candidate : kKnownPaddings) {
        if (padding_str == candidate.name) {
            return candidate.value;
        }
    }
    throw std::invalid_argument("Invalid padding mode: " + padding_str);
}
// Main grid_sampler entry point exposed to Python.
//
// Copies `input` and `grid` to the GPU, runs grid_sampler_cuda on them and
// returns the result as a new host array.
//
// Parameters:
//   input         - 4-D float array, [N, C, H, W].
//   grid          - 4-D float array, [N, H_out, W_out, 2]; N must equal input's N.
//   mode          - "bilinear" or "nearest" (see parse_mode).
//   padding_mode  - "zeros", "border" or "reflection" (see parse_padding_mode).
//   align_corners - forwarded to the kernel as 0/1.
//
// Returns:
//   A new float array of shape [N, C, H_out, W_out].
//
// Throws:
//   std::invalid_argument - wrong rank/shape, mismatched batch sizes, unknown
//                           mode strings, an extent that does not fit in int,
//                           or an empty input plane with a non-empty output.
//   std::runtime_error    - any CUDA allocation, copy or launch failure.
py::array_t<float> grid_sampler(
    py::array_t<float> input,  // [N, C, H, W]
    py::array_t<float> grid,   // [N, H_out, W_out, 2]
    const std::string& mode = "bilinear",
    const std::string& padding_mode = "zeros",
    bool align_corners = false
) {
    // Validate the arguments.
    if (input.ndim() != 4) {
        throw std::invalid_argument("Input must be 4D tensor [N, C, H, W]");
    }
    if (grid.ndim() != 4 || grid.shape(3) != 2) {
        throw std::invalid_argument("Grid must be 4D tensor [N, H_out, W_out, 2]");
    }
    // Compare at full width, before any narrowing to int.
    if (input.shape(0) != grid.shape(0)) {
        throw std::invalid_argument("Batch dimensions must match between input and grid");
    }

    // Parse the mode strings.
    const Mode mode_enum = parse_mode(mode);
    const PaddingMode padding_enum = parse_padding_mode(padding_mode);

    // The kernel interface takes int extents; reject anything that would be
    // silently truncated by the narrowing conversion.
    const auto checked_dim = [](py::ssize_t extent, const char* what) {
        if (extent > static_cast<py::ssize_t>(std::numeric_limits<int>::max())) {
            throw std::invalid_argument(
                std::string(what) + " is too large for the CUDA kernel interface");
        }
        return static_cast<int>(extent);
    };
    const int input_batch = checked_dim(input.shape(0), "Input batch size");
    const int input_channels = checked_dim(input.shape(1), "Input channel count");
    const int input_height = checked_dim(input.shape(2), "Input height");
    const int input_width = checked_dim(input.shape(3), "Input width");
    const int output_height = checked_dim(grid.shape(1), "Grid height");
    const int output_width = checked_dim(grid.shape(2), "Grid width");

    // The host buffers are copied to the device as flat blocks, so they must be
    // C-contiguous. array_t<float> alone does not guarantee that (a transposed
    // or sliced NumPy view passes through unchanged); ensure() makes a dense
    // copy only when one is actually needed.
    using DenseArray = py::array_t<float, py::array::c_style | py::array::forcecast>;
    const DenseArray input_c = DenseArray::ensure(input);
    const DenseArray grid_c = DenseArray::ensure(grid);
    if (!input_c || !grid_c) {
        throw std::invalid_argument(
            "Input and grid must be convertible to contiguous float arrays");
    }

    // Create the output array.
    const std::vector<std::size_t> output_shape = {
        static_cast<std::size_t>(input_batch),
        static_cast<std::size_t>(input_channels),
        static_cast<std::size_t>(output_height),
        static_cast<std::size_t>(output_width)
    };
    py::array_t<float> output(output_shape);

    // Nothing to compute: skip the GPU entirely rather than launching a kernel
    // over zero elements.
    if (output.size() == 0) {
        return output;
    }
    // Output is non-empty, so N and C are positive; an empty input therefore
    // means H or W is zero and there is nothing to sample from.
    if (input_c.size() == 0) {
        throw std::invalid_argument("Input must have non-empty spatial dimensions");
    }

    // Byte counts come from the arrays' own element counts, so the arithmetic
    // is never done in (overflow-prone) int.
    const std::size_t input_bytes = static_cast<std::size_t>(input_c.size()) * sizeof(float);
    const std::size_t grid_bytes = static_cast<std::size_t>(grid_c.size()) * sizeof(float);
    const std::size_t output_bytes = static_cast<std::size_t>(output.size()) * sizeof(float);

    // Owns one device allocation and frees it on every exit path, including
    // exceptions.
    struct DeviceBuffer {
        float* ptr = nullptr;
        DeviceBuffer() = default;
        DeviceBuffer(const DeviceBuffer&) = delete;
        DeviceBuffer& operator=(const DeviceBuffer&) = delete;
        ~DeviceBuffer() {
            if (ptr != nullptr) {
                cudaFree(ptr);
            }
        }
    };

    // Turns a failed CUDA status into std::runtime_error. The runtime keeps a
    // failed call recorded as the thread's "last error" until it is read, so it
    // is cleared here; otherwise the next, successful call would report it.
    const auto check = [](cudaError_t status, const char* what) {
        if (status != cudaSuccess) {
            const std::string message = std::string(what) + ": " + cudaGetErrorString(status);
            (void)cudaGetLastError();
            throw std::runtime_error(message);
        }
    };

    // Allocate GPU memory.
    DeviceBuffer d_input;
    DeviceBuffer d_grid;
    DeviceBuffer d_output;
    check(cudaMalloc(&d_input.ptr, input_bytes), "Failed to allocate GPU memory for input");
    check(cudaMalloc(&d_grid.ptr, grid_bytes), "Failed to allocate GPU memory for grid");
    check(cudaMalloc(&d_output.ptr, output_bytes), "Failed to allocate GPU memory for output");

    // Copy the inputs to the GPU.
    check(cudaMemcpy(d_input.ptr, input_c.data(), input_bytes, cudaMemcpyHostToDevice),
          "Failed to copy input to GPU");
    check(cudaMemcpy(d_grid.ptr, grid_c.data(), grid_bytes, cudaMemcpyHostToDevice),
          "Failed to copy grid to GPU");

    // Run the CUDA implementation (default stream).
    grid_sampler_cuda(
        d_input.ptr, d_grid.ptr, d_output.ptr,
        input_batch, input_channels, input_height, input_width,
        output_height, output_width,
        static_cast<int>(mode_enum),
        static_cast<int>(padding_enum),
        static_cast<int>(align_corners)
    );
    // Check for launch errors immediately, before the result is read back.
    check(cudaGetLastError(), "CUDA error");

    // Copy the result back to the host.
    check(cudaMemcpy(output.mutable_data(), d_output.ptr, output_bytes, cudaMemcpyDeviceToHost),
          "Failed to copy output from GPU");

    return output;
}
// Create test data: a [batch, channels, height, width] float array whose
// element at (n, c, h, w) holds n * 1000 + c * 100 + h * 10 + w, so every value
// encodes its own position.
//
// Throws std::invalid_argument if any dimension is negative (a negative int
// would otherwise be converted to an enormous unsigned extent).
py::array_t<float> create_test_input(int batch, int channels, int height, int width) {
    if (batch < 0 || channels < 0 || height < 0 || width < 0) {
        throw std::invalid_argument("Test input dimensions must be non-negative");
    }

    std::vector<size_t> shape = {
        static_cast<size_t>(batch),
        static_cast<size_t>(channels),
        static_cast<size_t>(height),
        static_cast<size_t>(width)
    };
    py::array_t<float> input(shape);

    // Fill the test data. The loops visit elements in the same row-major order
    // the freshly created array stores them in, so a running pointer replaces
    // the flat index (which was computed in int and could overflow).
    float* out = input.mutable_data();
    for (int n = 0; n < batch; ++n) {
        for (int c = 0; c < channels; ++c) {
            for (int h = 0; h < height; ++h) {
                for (int w = 0; w < width; ++w) {
                    *out++ = static_cast<float>(n * 1000 + c * 100 + h * 10 + w);
                }
            }
        }
    }
    return input;
}
// Create a simple sampling grid of shape [batch, height, width, 2] whose
// coordinates are evenly spaced over [-1, 1]. The last axis stores
// (x, y): x varies with the column index, y with the row index.
//
// An axis with a single sample gets coordinate 0 (the middle of the range);
// the even-spacing formula would otherwise divide by zero and produce NaN.
//
// Throws std::invalid_argument if any dimension is negative (a negative int
// would otherwise be converted to an enormous unsigned extent).
py::array_t<float> create_test_grid(int batch, int height, int width) {
    if (batch < 0 || height < 0 || width < 0) {
        throw std::invalid_argument("Test grid dimensions must be non-negative");
    }

    std::vector<size_t> shape = {
        static_cast<size_t>(batch),
        static_cast<size_t>(height),
        static_cast<size_t>(width),
        2
    };
    py::array_t<float> grid(shape);

    // Maps index 0..extent-1 onto [-1, 1]; a one-sample axis maps to 0.
    const auto normalized = [](int index, int extent) -> float {
        if (extent <= 1) {
            return 0.0f;
        }
        return (index * 2.0f / (extent - 1)) - 1.0f;
    };

    // The loops visit elements in the same row-major order the freshly created
    // array stores them in, so the values are written sequentially.
    float* out = grid.mutable_data();
    for (int n = 0; n < batch; ++n) {
        for (int h = 0; h < height; ++h) {
            for (int w = 0; w < width; ++w) {
                *out++ = normalized(w, width);   // x coordinate
                *out++ = normalized(h, height);  // y coordinate
            }
        }
    }
    return grid;
}
// Python module definition. Registers grid_sampler plus the two test-data
// helpers under the module name grid_sampler_cuda, with keyword arguments and
// defaults matching the C++ signatures.
PYBIND11_MODULE(grid_sampler_cuda, m) {
    // Enables the "name"_a shorthand, which is equivalent to py::arg("name").
    using namespace pybind11::literals;

    m.doc() = "CUDA Grid Sampler implementation";

    m.def("grid_sampler", &grid_sampler, "Grid sampler function",
          "input"_a,
          "grid"_a,
          "mode"_a = "bilinear",
          "padding_mode"_a = "zeros",
          "align_corners"_a = false);

    m.def("create_test_input", &create_test_input, "Create test input tensor",
          "batch"_a = 1,
          "channels"_a = 3,
          "height"_a = 4,
          "width"_a = 4);

    m.def("create_test_grid", &create_test_grid, "Create test grid tensor",
          "batch"_a = 1,
          "height"_a = 2,
          "width"_a = 2);
}