Skip to content

Commit 5b0904e

Browse files
committed
feat[FoundationPose]: Support multi-iteration refinement process
Signed-off-by: zz990099 <771647586@qq.com>
1 parent b049286 commit 5b0904e

4 files changed

Lines changed: 83 additions & 116 deletions

File tree

detection_6d_foundationpose/include/detection_6d_foundationpose/foundationpose.hpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,16 @@ class Base6DofDetectionModel {
2929
* @param mask Object mask (CV_8UC1 format, positive pixels > 0)
3030
* @param target_name Object category name (must match construction mapping)
3131
* @param out_pose_in_mesh Output pose in mesh coordinate frame
32+
* @param refine_itr Refinement process iteration num
3233
* @return true Registration successful
3334
* @return false Registration failed
3435
*/
3536
virtual bool Register(const cv::Mat &rgb,
3637
const cv::Mat &depth,
3738
const cv::Mat &mask,
3839
const std::string &target_name,
39-
Eigen::Matrix4f &out_pose_in_mesh) = 0;
40+
Eigen::Matrix4f &out_pose_in_mesh,
41+
size_t refine_itr = 1) = 0;
4042

4143
/**
4244
* @brief Track object pose from subsequent frames (lightweight version of Register)
@@ -50,14 +52,16 @@ class Base6DofDetectionModel {
5052
* @param hyp_pose_in_mesh Hypothesis pose in mesh frame (from Register or other sources)
5153
* @param target_name Object category name (must match construction mapping)
5254
* @param out_pose_in_mesh Output pose in mesh coordinate frame
55+
* @param refine_itr Refinement process iteration num
5356
* @return true Tracking successful
5457
* @return false Tracking failed
5558
*/
5659
virtual bool Track(const cv::Mat &rgb,
5760
const cv::Mat &depth,
5861
const Eigen::Matrix4f &hyp_pose_in_mesh,
5962
const std::string &target_name,
60-
Eigen::Matrix4f &out_pose_in_mesh) = 0;
63+
Eigen::Matrix4f &out_pose_in_mesh,
64+
size_t refine_itr = 1) = 0;
6165

6266
/**
6367
* @brief Virtual destructor for proper resource cleanup

detection_6d_foundationpose/src/foundationpose.cpp

Lines changed: 75 additions & 109 deletions
Original file line numberDiff line numberDiff line change
@@ -39,32 +39,38 @@ class FoundationPose : public Base6DofDetectionModel {
3939
const cv::Mat &depth,
4040
const cv::Mat &mask,
4141
const std::string &target_name,
42-
Eigen::Matrix4f &out_pose_in_mesh) override;
42+
Eigen::Matrix4f &out_pose_in_mesh,
43+
size_t refine_itr = 1) override;
4344

4445
bool Track(const cv::Mat &rgb,
4546
const cv::Mat &depth,
4647
const Eigen::Matrix4f &hyp_pose_in_mesh,
4748
const std::string &target_name,
48-
Eigen::Matrix4f &out_pose_in_mesh) override;
49+
Eigen::Matrix4f &out_pose_in_mesh,
50+
size_t refine_itr = 1) override;
4951

5052
private:
5153
bool CheckInputArguments(const cv::Mat &rgb,
5254
const cv::Mat &depth,
5355
const cv::Mat &mask,
5456
const std::string &target_name);
5557

56-
bool UploadDataToDevice(const cv::Mat &rgb,
57-
const cv::Mat &depth,
58-
const cv::Mat &mask,
59-
const std::shared_ptr<FoundationPosePipelinePackage> &package);
58+
using ParsingType = std::unique_ptr<FoundationPosePipelinePackage>;
6059

61-
bool RefinePreProcess(std::shared_ptr<async_pipeline::IPipelinePackage> package);
60+
bool UploadDataToDevice(const cv::Mat &rgb,
61+
const cv::Mat &depth,
62+
const cv::Mat &mask,
63+
const ParsingType &package);
6264

63-
bool ScorePreprocess(std::shared_ptr<async_pipeline::IPipelinePackage> package);
65+
bool RefinePreProcess(const ParsingType &package);
6466

65-
bool ScorePostProcess(std::shared_ptr<async_pipeline::IPipelinePackage> package);
67+
bool RefinePostProcess(const ParsingType &package);
6668

67-
bool TrackPostProcess(std::shared_ptr<async_pipeline::IPipelinePackage> package);
69+
bool ScorePreprocess(const ParsingType &package);
70+
71+
bool ScorePostProcess(const ParsingType &package);
72+
73+
bool TrackPostProcess(const ParsingType &package);
6874

6975
private:
7076
// 以下参数不对外开放
@@ -76,8 +82,8 @@ class FoundationPose : public Base6DofDetectionModel {
7682
const float REFINE_ROT_NORMALIZER = 0.349065850398865;
7783
const std::string SCORE_OUTPUT_BLOB_NAME = "scores";
7884
// render参数
79-
const int score_mode_poses_num_ = 252;
80-
const int refine_mode_poses_num_ = 1;
85+
const int score_mode_poses_num_ = 252;
86+
const int refine_mode_poses_num_ = 1;
8187
const float refine_mode_crop_ratio_ = 1.2;
8288
const float score_mode_crop_ratio_ = 1.1;
8389

@@ -150,9 +156,9 @@ FoundationPose::FoundationPose(std::shared_ptr<inference_core::BaseInferCore>
150156
{
151157
const std::string &target_name = mesh_loader->GetName();
152158
LOG(INFO) << "[FoundationPose] Got target_name : " << target_name;
153-
map_name2loaders_[target_name] = mesh_loader;
154-
map_name2renderer_[target_name] = std::make_shared<FoundationPoseRenderer>(
155-
mesh_loader, intrinsic_, score_mode_poses_num_);
159+
map_name2loaders_[target_name] = mesh_loader;
160+
map_name2renderer_[target_name] =
161+
std::make_shared<FoundationPoseRenderer>(mesh_loader, intrinsic_, score_mode_poses_num_);
156162
}
157163

158164
hyp_poses_sampler_ = std::make_shared<FoundationPoseSampler>(
@@ -189,12 +195,13 @@ bool FoundationPose::Register(const cv::Mat &rgb,
189195
const cv::Mat &depth,
190196
const cv::Mat &mask,
191197
const std::string &target_name,
192-
Eigen::Matrix4f &out_pose_in_mesh)
198+
Eigen::Matrix4f &out_pose_in_mesh,
199+
size_t refine_itr)
193200
{
194201
CHECK_STATE(CheckInputArguments(rgb, depth, mask, target_name),
195202
"[FoundationPose] `Register` Got invalid arguments!!!");
196203

197-
auto package = std::make_shared<FoundationPosePipelinePackage>();
204+
auto package = std::make_unique<FoundationPosePipelinePackage>();
198205
package->rgb_on_host = rgb;
199206
package->depth_on_host = depth;
200207
package->mask_on_host = mask;
@@ -203,19 +210,23 @@ bool FoundationPose::Register(const cv::Mat &rgb,
203210
MESSURE_DURATION_AND_CHECK_STATE(UploadDataToDevice(rgb, depth, mask, package),
204211
"[FoundationPose] SyncDetect Failed to upload data!!!");
205212

206-
MESSURE_DURATION_AND_CHECK_STATE(
207-
RefinePreProcess(package),
208-
"[FoundationPose] SyncDetect Failed to execute RefinePreProcess!!!");
213+
for (size_t i = 0 ; i < refine_itr ; ++ i) {
214+
MESSURE_DURATION_AND_CHECK_STATE(
215+
RefinePreProcess(package),
216+
"[FoundationPose] SyncDetect Failed to execute RefinePreProcess!!!");
209217

210-
// package->infer_buffer = package->refiner_blobs_buffer;
211-
MESSURE_DURATION_AND_CHECK_STATE(
212-
refiner_core_->SyncInfer(package->GetInferBuffer()),
213-
"[FoundationPose] SyncDetect Failed to execute refiner_core_->SyncInfer!!!");
218+
MESSURE_DURATION_AND_CHECK_STATE(
219+
refiner_core_->SyncInfer(package->GetInferBuffer()),
220+
"[FoundationPose] SyncDetect Failed to execute refiner_core_->SyncInfer!!!");
221+
222+
MESSURE_DURATION_AND_CHECK_STATE(
223+
RefinePostProcess(package),
224+
"[FoundationPose] SyncDetect Failed to execute RefinePostProcess!!!");
225+
}
214226

215227
MESSURE_DURATION_AND_CHECK_STATE(
216228
ScorePreprocess(package), "[FoundationPose] SyncDetect Failed to execute ScorePreprocess!!!");
217229

218-
// unit_buffer->p_blob_buffers = package->scorer_blobs_buffer;
219230
MESSURE_DURATION_AND_CHECK_STATE(
220231
scorer_core_->SyncInfer(package->GetInferBuffer()),
221232
"[FoundationPose] SyncDetect Failed to execute scorer_core_->SyncInfer!!!");
@@ -232,12 +243,13 @@ bool FoundationPose::Track(const cv::Mat &rgb,
232243
const cv::Mat &depth,
233244
const Eigen::Matrix4f &hyp_pose_in_mesh,
234245
const std::string &target_name,
235-
Eigen::Matrix4f &out_pose_in_mesh)
246+
Eigen::Matrix4f &out_pose_in_mesh,
247+
size_t refine_itr)
236248
{
237249
CHECK_STATE(CheckInputArguments(rgb, depth, cv::Mat(), target_name),
238250
"[FoundationPose] `Track` Got invalid arguments!!!");
239251

240-
auto package = std::make_shared<FoundationPosePipelinePackage>();
252+
auto package = std::make_unique<FoundationPosePipelinePackage>();
241253
package->rgb_on_host = rgb;
242254
package->depth_on_host = depth;
243255
package->target_name = target_name;
@@ -246,26 +258,27 @@ bool FoundationPose::Track(const cv::Mat &rgb,
246258
MESSURE_DURATION_AND_CHECK_STATE(UploadDataToDevice(rgb, depth, cv::Mat(), package),
247259
"[FoundationPose] Track Failed to upload data!!!");
248260

249-
MESSURE_DURATION_AND_CHECK_STATE(RefinePreProcess(package),
250-
"[FoundationPose] Track Failed to execute RefinePreProcess!!!");
261+
for (size_t i = 0 ; i < refine_itr ; ++ i) {
262+
MESSURE_DURATION_AND_CHECK_STATE(RefinePreProcess(package),
263+
"[FoundationPose] Track Failed to execute RefinePreProcess!!!");
251264

252-
MESSURE_DURATION_AND_CHECK_STATE(
253-
refiner_core_->SyncInfer(package->GetInferBuffer()),
254-
"[FoundationPose] Track Failed to execute refiner_core_->SyncInfer!!!");
265+
MESSURE_DURATION_AND_CHECK_STATE(
266+
refiner_core_->SyncInfer(package->GetInferBuffer()),
267+
"[FoundationPose] Track Failed to execute refiner_core_->SyncInfer!!!");
255268

256-
MESSURE_DURATION_AND_CHECK_STATE(TrackPostProcess(package),
257-
"[Foundation] Track Failed to execute `TrackPostProcess`!!!");
269+
MESSURE_DURATION_AND_CHECK_STATE(RefinePostProcess(package),
270+
"[Foundation] Track Failed to execute `RefinePostProcess`!!!");
271+
}
258272

259-
out_pose_in_mesh = std::move(package->actual_pose);
273+
out_pose_in_mesh = std::move(package->hyp_poses[0]);
260274

261275
return true;
262276
}
263277

264-
bool FoundationPose::UploadDataToDevice(
265-
const cv::Mat &rgb,
266-
const cv::Mat &depth,
267-
const cv::Mat &mask,
268-
const std::shared_ptr<FoundationPosePipelinePackage> &package)
278+
bool FoundationPose::UploadDataToDevice(const cv::Mat &rgb,
279+
const cv::Mat &depth,
280+
const cv::Mat &mask,
281+
const ParsingType &package)
269282
{
270283
const int input_image_height = rgb.rows, input_image_width = rgb.cols;
271284
package->input_image_height = input_image_height;
@@ -312,11 +325,8 @@ bool FoundationPose::UploadDataToDevice(
312325
return true;
313326
}
314327

315-
bool FoundationPose::RefinePreProcess(std::shared_ptr<async_pipeline::IPipelinePackage> _package)
328+
bool FoundationPose::RefinePreProcess(const ParsingType &package)
316329
{
317-
auto package = std::dynamic_pointer_cast<FoundationPosePipelinePackage>(_package);
318-
CHECK_STATE(package != nullptr, "[FoundationPose] RefinePreProcess Got INVALID package ptr!!!");
319-
320330
// 1. sample
321331
if (package->hyp_poses.empty())
322332
{
@@ -327,11 +337,15 @@ bool FoundationPose::RefinePreProcess(std::shared_ptr<async_pipeline::IPipelineP
327337
}
328338

329339
// 2. render
330-
auto &refine_renderer = map_name2renderer_[package->target_name];
331-
auto refiner_blob_buffer = refiner_core_->GetBuffer(false);
340+
if (package->refiner_blobs_buffer == nullptr) {
341+
package->refiner_blobs_buffer = refiner_core_->GetBuffer(true);
342+
}
343+
const auto& refiner_blob_buffer = package->refiner_blobs_buffer;
332344
// 设置推理前blob的输入位置为device,输出的blob位置为host端
333345
refiner_blob_buffer->SetBlobBuffer(RENDER_INPUT_BLOB_NAME, DataLocation::DEVICE);
334346
refiner_blob_buffer->SetBlobBuffer(TRANSF_INPUT_BLOB_NAME, DataLocation::DEVICE);
347+
348+
auto &refine_renderer = map_name2renderer_[package->target_name];
335349
CHECK_STATE(
336350
refine_renderer->RenderAndTransform(
337351
package->hyp_poses, package->rgb_on_device.get(), package->depth_on_device.get(),
@@ -346,22 +360,21 @@ bool FoundationPose::RefinePreProcess(std::shared_ptr<async_pipeline::IPipelineP
346360
{input_poses_num, crop_window_H_, crop_window_W_, 6});
347361
refiner_blob_buffer->SetBlobShape(TRANSF_INPUT_BLOB_NAME,
348362
{input_poses_num, crop_window_H_, crop_window_W_, 6});
349-
package->refiner_blobs_buffer = refiner_blob_buffer;
350363
package->infer_buffer = refiner_blob_buffer;
351364

352365
return true;
353366
}
354367

355-
bool FoundationPose::ScorePreprocess(std::shared_ptr<async_pipeline::IPipelinePackage> _package)
368+
bool FoundationPose::RefinePostProcess(const ParsingType &package)
356369
{
357-
auto package = std::dynamic_pointer_cast<FoundationPosePipelinePackage>(_package);
358-
CHECK_STATE(package != nullptr, "[FoundationPose] ScorePreprocess Got INVALID package ptr!!!");
359370
// 获取refiner模型的缓存指针
360371
const auto &refiner_blob_buffer = package->refiner_blobs_buffer;
361372
const auto _trans_ptr = refiner_blob_buffer->GetOuterBlobBuffer(REFINE_TRANS_OUT_BLOB_NAME).first;
362373
const auto _rot_ptr = refiner_blob_buffer->GetOuterBlobBuffer(REFINE_ROT_OUT_BLOB_NAME).first;
363374
const float *trans_ptr = static_cast<float *>(_trans_ptr);
364375
const float *rot_ptr = static_cast<float *>(_rot_ptr);
376+
CHECK_STATE(trans_ptr != nullptr, "[FoundationPose] RefinePostProcess got invalid trans_ptr !");
377+
CHECK_STATE(rot_ptr != nullptr, "[FoundationPose] RefinePostProcess got invalid rot_ptr !");
365378

366379
// 获取生成的假设位姿
367380
const auto &hyp_poses = package->hyp_poses;
@@ -371,10 +384,12 @@ bool FoundationPose::ScorePreprocess(std::shared_ptr<async_pipeline::IPipelinePa
371384
const auto &mesh_loader = map_name2loaders_[package->target_name];
372385

373386
// transformation 将模型输出的相对位姿转换为绝对位姿
374-
const float mesh_diameter = mesh_loader->GetMeshDiameter();
387+
const float mesh_diameter = mesh_loader->GetMeshDiameter();
388+
375389
std::vector<Eigen::Vector3f> trans_delta(poses_num);
376390
std::vector<Eigen::Vector3f> rot_delta(poses_num);
377391
std::vector<Eigen::Matrix3f> rot_mat_delta(poses_num);
392+
378393
for (int i = 0; i < poses_num; ++i)
379394
{
380395
const size_t offset = i * 3;
@@ -398,6 +413,12 @@ bool FoundationPose::ScorePreprocess(std::shared_ptr<async_pipeline::IPipelinePa
398413
refine_poses[i].block<3, 3>(0, 0) = result_3x3;
399414
}
400415

416+
package->hyp_poses = std::move(refine_poses);
417+
return true;
418+
}
419+
420+
bool FoundationPose::ScorePreprocess(const ParsingType &package)
421+
{
401422
auto scorer_blob_buffer = scorer_core_->GetBuffer(false);
402423
// 获取对应的score_renderer
403424
// 设置推理前后blob输出的位置,这里输入输出都在device端
@@ -407,29 +428,26 @@ bool FoundationPose::ScorePreprocess(std::shared_ptr<async_pipeline::IPipelinePa
407428
auto &score_renderer = map_name2renderer_[package->target_name];
408429
CHECK_STATE(
409430
score_renderer->RenderAndTransform(
410-
refine_poses, package->rgb_on_device.get(), package->depth_on_device.get(),
431+
package->hyp_poses, package->rgb_on_device.get(), package->depth_on_device.get(),
411432
package->xyz_map_on_device.get(), package->input_image_height, package->input_image_width,
412433
scorer_blob_buffer->GetOuterBlobBuffer(RENDER_INPUT_BLOB_NAME).first,
413434
scorer_blob_buffer->GetOuterBlobBuffer(TRANSF_INPUT_BLOB_NAME).first,
414435
score_mode_crop_ratio_),
415436
"[FoundationPose] score_renderer RenderAndTransform Failed!!!");
416437

417-
package->refine_poses = std::move(refine_poses);
418438
package->scorer_blobs_buffer = scorer_blob_buffer;
419439
package->infer_buffer = scorer_blob_buffer;
420440

421441
return true;
422442
}
423443

424-
bool FoundationPose::ScorePostProcess(std::shared_ptr<async_pipeline::IPipelinePackage> _package)
444+
bool FoundationPose::ScorePostProcess(const ParsingType &package)
425445
{
426-
auto package = std::dynamic_pointer_cast<FoundationPosePipelinePackage>(_package);
427-
CHECK_STATE(package != nullptr, "[FoundationPose] ScorePostProcess Got INVALID package ptr!!!");
428446
const auto &scorer_blob_buffer = package->scorer_blobs_buffer;
429447
// 获取scorer模型的输出缓存指针
430448
void *score_ptr = scorer_blob_buffer->GetOuterBlobBuffer(SCORE_OUTPUT_BLOB_NAME).first;
431449

432-
const auto &refine_poses = package->refine_poses;
450+
const auto &refine_poses = package->hyp_poses;
433451
const int poses_num = refine_poses.size();
434452

435453
// 获取置信度最大的refined_pose
@@ -439,58 +457,6 @@ bool FoundationPose::ScorePostProcess(std::shared_ptr<async_pipeline::IPipelineP
439457
return true;
440458
}
441459

442-
bool FoundationPose::TrackPostProcess(std::shared_ptr<async_pipeline::IPipelinePackage> _package)
443-
{
444-
auto package = std::dynamic_pointer_cast<FoundationPosePipelinePackage>(_package);
445-
CHECK_STATE(package != nullptr, "[FoundationPose] TrackPostProcess Got INVALID package ptr!!!");
446-
447-
// 获取refiner模型的缓存指针
448-
const auto &refiner_blob_buffer = package->refiner_blobs_buffer;
449-
const auto _trans_ptr = refiner_blob_buffer->GetOuterBlobBuffer(REFINE_TRANS_OUT_BLOB_NAME).first;
450-
const auto _rot_ptr = refiner_blob_buffer->GetOuterBlobBuffer(REFINE_ROT_OUT_BLOB_NAME).first;
451-
const float *trans_ptr = static_cast<float *>(_trans_ptr);
452-
const float *rot_ptr = static_cast<float *>(_rot_ptr);
453-
454-
// 获取生成的假设位姿
455-
const auto &hyp_poses = package->hyp_poses;
456-
const int poses_num = hyp_poses.size();
457-
458-
// 获取对应的mesh_loader
459-
const auto &mesh_loader = map_name2loaders_[package->target_name];
460-
461-
// transformation 将模型输出的相对位姿转换为绝对位姿
462-
const float mesh_diameter = mesh_loader->GetMeshDiameter();
463-
std::vector<Eigen::Vector3f> trans_delta(poses_num);
464-
std::vector<Eigen::Vector3f> rot_delta(poses_num);
465-
std::vector<Eigen::Matrix3f> rot_mat_delta(poses_num);
466-
for (int i = 0; i < poses_num; ++i)
467-
{
468-
const size_t offset = i * 3;
469-
trans_delta[i] << trans_ptr[offset], trans_ptr[offset + 1], trans_ptr[offset + 2];
470-
trans_delta[i] *= mesh_diameter / 2;
471-
472-
rot_delta[i] << rot_ptr[offset], rot_ptr[offset + 1], rot_ptr[offset + 2];
473-
auto normalized_vect = (rot_delta[i].array().tanh() * REFINE_ROT_NORMALIZER).matrix();
474-
Eigen::AngleAxis rot_delta_angle_axis(normalized_vect.norm(), normalized_vect.normalized());
475-
rot_mat_delta[i] = rot_delta_angle_axis.toRotationMatrix().transpose();
476-
}
477-
478-
std::vector<Eigen::Matrix4f> refine_poses(poses_num);
479-
for (int i = 0; i < poses_num; ++i)
480-
{
481-
refine_poses[i] = hyp_poses[i];
482-
refine_poses[i].col(3).head(3) += trans_delta[i];
483-
484-
Eigen::Matrix3f top_left_3x3 = refine_poses[i].block<3, 3>(0, 0);
485-
Eigen::Matrix3f result_3x3 = rot_mat_delta[i] * top_left_3x3;
486-
refine_poses[i].block<3, 3>(0, 0) = result_3x3;
487-
}
488-
489-
package->actual_pose = refine_poses[0];
490-
491-
return true;
492-
}
493-
494460
std::shared_ptr<Base6DofDetectionModel> CreateFoundationPoseModel(
495461
std::shared_ptr<inference_core::BaseInferCore> refiner_core,
496462
std::shared_ptr<inference_core::BaseInferCore> scorer_core,

0 commit comments

Comments
 (0)