@@ -39,32 +39,38 @@ class FoundationPose : public Base6DofDetectionModel {
3939 const cv::Mat &depth,
4040 const cv::Mat &mask,
4141 const std::string &target_name,
42- Eigen::Matrix4f &out_pose_in_mesh) override ;
42+ Eigen::Matrix4f &out_pose_in_mesh,
43+ size_t refine_itr = 1 ) override ;
4344
4445 bool Track (const cv::Mat &rgb,
4546 const cv::Mat &depth,
4647 const Eigen::Matrix4f &hyp_pose_in_mesh,
4748 const std::string &target_name,
48- Eigen::Matrix4f &out_pose_in_mesh) override ;
49+ Eigen::Matrix4f &out_pose_in_mesh,
50+ size_t refine_itr = 1 ) override ;
4951
5052private:
5153 bool CheckInputArguments (const cv::Mat &rgb,
5254 const cv::Mat &depth,
5355 const cv::Mat &mask,
5456 const std::string &target_name);
5557
56- bool UploadDataToDevice (const cv::Mat &rgb,
57- const cv::Mat &depth,
58- const cv::Mat &mask,
59- const std::shared_ptr<FoundationPosePipelinePackage> &package);
58+ using ParsingType = std::unique_ptr<FoundationPosePipelinePackage>;
6059
61- bool RefinePreProcess (std::shared_ptr<async_pipeline::IPipelinePackage> package);
60+ bool UploadDataToDevice (const cv::Mat &rgb,
61+ const cv::Mat &depth,
62+ const cv::Mat &mask,
63+ const ParsingType &package);
6264
63- bool ScorePreprocess (std::shared_ptr<async_pipeline::IPipelinePackage> package);
65+ bool RefinePreProcess ( const ParsingType & package);
6466
65- bool ScorePostProcess (std::shared_ptr<async_pipeline::IPipelinePackage> package);
67+ bool RefinePostProcess ( const ParsingType & package);
6668
67- bool TrackPostProcess (std::shared_ptr<async_pipeline::IPipelinePackage> package);
69+ bool ScorePreprocess (const ParsingType &package);
70+
71+ bool ScorePostProcess (const ParsingType &package);
72+
73+ bool TrackPostProcess (const ParsingType &package);
6874
6975private:
7076 // 以下参数不对外开放
@@ -76,8 +82,8 @@ class FoundationPose : public Base6DofDetectionModel {
7682 const float REFINE_ROT_NORMALIZER = 0.349065850398865 ;
7783 const std::string SCORE_OUTPUT_BLOB_NAME = " scores" ;
7884 // render参数
79- const int score_mode_poses_num_ = 252 ;
80- const int refine_mode_poses_num_ = 1 ;
85+ const int score_mode_poses_num_ = 252 ;
86+ const int refine_mode_poses_num_ = 1 ;
8187 const float refine_mode_crop_ratio_ = 1.2 ;
8288 const float score_mode_crop_ratio_ = 1.1 ;
8389
@@ -150,9 +156,9 @@ FoundationPose::FoundationPose(std::shared_ptr<inference_core::BaseInferCore>
150156 {
151157 const std::string &target_name = mesh_loader->GetName ();
152158 LOG (INFO) << " [FoundationPose] Got target_name : " << target_name;
153- map_name2loaders_[target_name] = mesh_loader;
154- map_name2renderer_[target_name] = std::make_shared<FoundationPoseRenderer>(
155- mesh_loader, intrinsic_, score_mode_poses_num_);
159+ map_name2loaders_[target_name] = mesh_loader;
160+ map_name2renderer_[target_name] =
161+ std::make_shared<FoundationPoseRenderer>( mesh_loader, intrinsic_, score_mode_poses_num_);
156162 }
157163
158164 hyp_poses_sampler_ = std::make_shared<FoundationPoseSampler>(
@@ -189,12 +195,13 @@ bool FoundationPose::Register(const cv::Mat &rgb,
189195 const cv::Mat &depth,
190196 const cv::Mat &mask,
191197 const std::string &target_name,
192- Eigen::Matrix4f &out_pose_in_mesh)
198+ Eigen::Matrix4f &out_pose_in_mesh,
199+ size_t refine_itr)
193200{
194201 CHECK_STATE (CheckInputArguments (rgb, depth, mask, target_name),
195202 " [FoundationPose] `Register` Got invalid arguments!!!" );
196203
197- auto package = std::make_shared <FoundationPosePipelinePackage>();
204+ auto package = std::make_unique <FoundationPosePipelinePackage>();
198205 package->rgb_on_host = rgb;
199206 package->depth_on_host = depth;
200207 package->mask_on_host = mask;
@@ -203,19 +210,23 @@ bool FoundationPose::Register(const cv::Mat &rgb,
203210 MESSURE_DURATION_AND_CHECK_STATE (UploadDataToDevice (rgb, depth, mask, package),
204211 " [FoundationPose] SyncDetect Failed to upload data!!!" );
205212
206- MESSURE_DURATION_AND_CHECK_STATE (
207- RefinePreProcess (package),
208- " [FoundationPose] SyncDetect Failed to execute RefinePreProcess!!!" );
213+ for (size_t i = 0 ; i < refine_itr ; ++ i) {
214+ MESSURE_DURATION_AND_CHECK_STATE (
215+ RefinePreProcess (package),
216+ " [FoundationPose] SyncDetect Failed to execute RefinePreProcess!!!" );
209217
210- // package->infer_buffer = package->refiner_blobs_buffer;
211- MESSURE_DURATION_AND_CHECK_STATE (
212- refiner_core_->SyncInfer (package->GetInferBuffer ()),
213- " [FoundationPose] SyncDetect Failed to execute refiner_core_->SyncInfer!!!" );
218+ MESSURE_DURATION_AND_CHECK_STATE (
219+ refiner_core_->SyncInfer (package->GetInferBuffer ()),
220+ " [FoundationPose] SyncDetect Failed to execute refiner_core_->SyncInfer!!!" );
221+
222+ MESSURE_DURATION_AND_CHECK_STATE (
223+ RefinePostProcess (package),
224+ " [FoundationPose] SyncDetect Failed to execute RefinePostProcess!!!" );
225+ }
214226
215227 MESSURE_DURATION_AND_CHECK_STATE (
216228 ScorePreprocess (package), " [FoundationPose] SyncDetect Failed to execute ScorePreprocess!!!" );
217229
218- // unit_buffer->p_blob_buffers = package->scorer_blobs_buffer;
219230 MESSURE_DURATION_AND_CHECK_STATE (
220231 scorer_core_->SyncInfer (package->GetInferBuffer ()),
221232 " [FoundationPose] SyncDetect Failed to execute scorer_core_->SyncInfer!!!" );
@@ -232,12 +243,13 @@ bool FoundationPose::Track(const cv::Mat &rgb,
232243 const cv::Mat &depth,
233244 const Eigen::Matrix4f &hyp_pose_in_mesh,
234245 const std::string &target_name,
235- Eigen::Matrix4f &out_pose_in_mesh)
246+ Eigen::Matrix4f &out_pose_in_mesh,
247+ size_t refine_itr)
236248{
237249 CHECK_STATE (CheckInputArguments (rgb, depth, cv::Mat (), target_name),
238250 " [FoundationPose] `Track` Got invalid arguments!!!" );
239251
240- auto package = std::make_shared <FoundationPosePipelinePackage>();
252+ auto package = std::make_unique <FoundationPosePipelinePackage>();
241253 package->rgb_on_host = rgb;
242254 package->depth_on_host = depth;
243255 package->target_name = target_name;
@@ -246,26 +258,27 @@ bool FoundationPose::Track(const cv::Mat &rgb,
246258 MESSURE_DURATION_AND_CHECK_STATE (UploadDataToDevice (rgb, depth, cv::Mat (), package),
247259 " [FoundationPose] Track Failed to upload data!!!" );
248260
249- MESSURE_DURATION_AND_CHECK_STATE (RefinePreProcess (package),
250- " [FoundationPose] Track Failed to execute RefinePreProcess!!!" );
261+ for (size_t i = 0 ; i < refine_itr ; ++ i) {
262+ MESSURE_DURATION_AND_CHECK_STATE (RefinePreProcess (package),
263+ " [FoundationPose] Track Failed to execute RefinePreProcess!!!" );
251264
252- MESSURE_DURATION_AND_CHECK_STATE (
253- refiner_core_->SyncInfer (package->GetInferBuffer ()),
254- " [FoundationPose] Track Failed to execute refiner_core_->SyncInfer!!!" );
265+ MESSURE_DURATION_AND_CHECK_STATE (
266+ refiner_core_->SyncInfer (package->GetInferBuffer ()),
267+ " [FoundationPose] Track Failed to execute refiner_core_->SyncInfer!!!" );
255268
256- MESSURE_DURATION_AND_CHECK_STATE (TrackPostProcess (package),
257- " [Foundation] Track Failed to execute `TrackPostProcess`!!!" );
269+ MESSURE_DURATION_AND_CHECK_STATE (RefinePostProcess (package),
270+ " [Foundation] Track Failed to execute `RefinePostProcess`!!!" );
271+ }
258272
259- out_pose_in_mesh = std::move (package->actual_pose );
273+ out_pose_in_mesh = std::move (package->hyp_poses [ 0 ] );
260274
261275 return true ;
262276}
263277
264- bool FoundationPose::UploadDataToDevice (
265- const cv::Mat &rgb,
266- const cv::Mat &depth,
267- const cv::Mat &mask,
268- const std::shared_ptr<FoundationPosePipelinePackage> &package)
278+ bool FoundationPose::UploadDataToDevice (const cv::Mat &rgb,
279+ const cv::Mat &depth,
280+ const cv::Mat &mask,
281+ const ParsingType &package)
269282{
270283 const int input_image_height = rgb.rows , input_image_width = rgb.cols ;
271284 package->input_image_height = input_image_height;
@@ -312,11 +325,8 @@ bool FoundationPose::UploadDataToDevice(
312325 return true ;
313326}
314327
315- bool FoundationPose::RefinePreProcess (std::shared_ptr<async_pipeline::IPipelinePackage> _package )
328+ bool FoundationPose::RefinePreProcess (const ParsingType &package )
316329{
317- auto package = std::dynamic_pointer_cast<FoundationPosePipelinePackage>(_package);
318- CHECK_STATE (package != nullptr , " [FoundationPose] RefinePreProcess Got INVALID package ptr!!!" );
319-
320330 // 1. sample
321331 if (package->hyp_poses .empty ())
322332 {
@@ -327,11 +337,15 @@ bool FoundationPose::RefinePreProcess(std::shared_ptr<async_pipeline::IPipelineP
327337 }
328338
329339 // 2. render
330- auto &refine_renderer = map_name2renderer_[package->target_name ];
331- auto refiner_blob_buffer = refiner_core_->GetBuffer (false );
340+ if (package->refiner_blobs_buffer == nullptr ) {
341+ package->refiner_blobs_buffer = refiner_core_->GetBuffer (true );
342+ }
343+ const auto & refiner_blob_buffer = package->refiner_blobs_buffer ;
332344 // 设置推理前blob的输入位置为device,输出的blob位置为host端
333345 refiner_blob_buffer->SetBlobBuffer (RENDER_INPUT_BLOB_NAME, DataLocation::DEVICE);
334346 refiner_blob_buffer->SetBlobBuffer (TRANSF_INPUT_BLOB_NAME, DataLocation::DEVICE);
347+
348+ auto &refine_renderer = map_name2renderer_[package->target_name ];
335349 CHECK_STATE (
336350 refine_renderer->RenderAndTransform (
337351 package->hyp_poses , package->rgb_on_device .get (), package->depth_on_device .get (),
@@ -346,22 +360,21 @@ bool FoundationPose::RefinePreProcess(std::shared_ptr<async_pipeline::IPipelineP
346360 {input_poses_num, crop_window_H_, crop_window_W_, 6 });
347361 refiner_blob_buffer->SetBlobShape (TRANSF_INPUT_BLOB_NAME,
348362 {input_poses_num, crop_window_H_, crop_window_W_, 6 });
349- package->refiner_blobs_buffer = refiner_blob_buffer;
350363 package->infer_buffer = refiner_blob_buffer;
351364
352365 return true ;
353366}
354367
355- bool FoundationPose::ScorePreprocess (std::shared_ptr<async_pipeline::IPipelinePackage> _package )
368+ bool FoundationPose::RefinePostProcess ( const ParsingType &package )
356369{
357- auto package = std::dynamic_pointer_cast<FoundationPosePipelinePackage>(_package);
358- CHECK_STATE (package != nullptr , " [FoundationPose] ScorePreprocess Got INVALID package ptr!!!" );
359370 // 获取refiner模型的缓存指针
360371 const auto &refiner_blob_buffer = package->refiner_blobs_buffer ;
361372 const auto _trans_ptr = refiner_blob_buffer->GetOuterBlobBuffer (REFINE_TRANS_OUT_BLOB_NAME).first ;
362373 const auto _rot_ptr = refiner_blob_buffer->GetOuterBlobBuffer (REFINE_ROT_OUT_BLOB_NAME).first ;
363374 const float *trans_ptr = static_cast <float *>(_trans_ptr);
364375 const float *rot_ptr = static_cast <float *>(_rot_ptr);
376+ CHECK_STATE (trans_ptr != nullptr , " [FoundationPose] RefinePostProcess got invalid trans_ptr !" );
377+ CHECK_STATE (rot_ptr != nullptr , " [FoundationPose] RefinePostProcess got invalid rot_ptr !" );
365378
366379 // 获取生成的假设位姿
367380 const auto &hyp_poses = package->hyp_poses ;
@@ -371,10 +384,12 @@ bool FoundationPose::ScorePreprocess(std::shared_ptr<async_pipeline::IPipelinePa
371384 const auto &mesh_loader = map_name2loaders_[package->target_name ];
372385
373386 // transformation 将模型输出的相对位姿转换为绝对位姿
374- const float mesh_diameter = mesh_loader->GetMeshDiameter ();
387+ const float mesh_diameter = mesh_loader->GetMeshDiameter ();
388+
375389 std::vector<Eigen::Vector3f> trans_delta (poses_num);
376390 std::vector<Eigen::Vector3f> rot_delta (poses_num);
377391 std::vector<Eigen::Matrix3f> rot_mat_delta (poses_num);
392+
378393 for (int i = 0 ; i < poses_num; ++i)
379394 {
380395 const size_t offset = i * 3 ;
@@ -398,6 +413,12 @@ bool FoundationPose::ScorePreprocess(std::shared_ptr<async_pipeline::IPipelinePa
398413 refine_poses[i].block <3 , 3 >(0 , 0 ) = result_3x3;
399414 }
400415
416+ package->hyp_poses = std::move (refine_poses);
417+ return true ;
418+ }
419+
420+ bool FoundationPose::ScorePreprocess (const ParsingType &package)
421+ {
401422 auto scorer_blob_buffer = scorer_core_->GetBuffer (false );
402423 // 获取对应的score_renderer
403424 // 设置推理前后blob输出的位置,这里输入输出都在device端
@@ -407,29 +428,26 @@ bool FoundationPose::ScorePreprocess(std::shared_ptr<async_pipeline::IPipelinePa
407428 auto &score_renderer = map_name2renderer_[package->target_name ];
408429 CHECK_STATE (
409430 score_renderer->RenderAndTransform (
410- refine_poses , package->rgb_on_device .get (), package->depth_on_device .get (),
431+ package-> hyp_poses , package->rgb_on_device .get (), package->depth_on_device .get (),
411432 package->xyz_map_on_device .get (), package->input_image_height , package->input_image_width ,
412433 scorer_blob_buffer->GetOuterBlobBuffer (RENDER_INPUT_BLOB_NAME).first ,
413434 scorer_blob_buffer->GetOuterBlobBuffer (TRANSF_INPUT_BLOB_NAME).first ,
414435 score_mode_crop_ratio_),
415436 " [FoundationPose] score_renderer RenderAndTransform Failed!!!" );
416437
417- package->refine_poses = std::move (refine_poses);
418438 package->scorer_blobs_buffer = scorer_blob_buffer;
419439 package->infer_buffer = scorer_blob_buffer;
420440
421441 return true ;
422442}
423443
424- bool FoundationPose::ScorePostProcess (std::shared_ptr<async_pipeline::IPipelinePackage> _package )
444+ bool FoundationPose::ScorePostProcess (const ParsingType &package )
425445{
426- auto package = std::dynamic_pointer_cast<FoundationPosePipelinePackage>(_package);
427- CHECK_STATE (package != nullptr , " [FoundationPose] ScorePostProcess Got INVALID package ptr!!!" );
428446 const auto &scorer_blob_buffer = package->scorer_blobs_buffer ;
429447 // 获取scorer模型的输出缓存指针
430448 void *score_ptr = scorer_blob_buffer->GetOuterBlobBuffer (SCORE_OUTPUT_BLOB_NAME).first ;
431449
432- const auto &refine_poses = package->refine_poses ;
450+ const auto &refine_poses = package->hyp_poses ;
433451 const int poses_num = refine_poses.size ();
434452
435453 // 获取置信度最大的refined_pose
@@ -439,58 +457,6 @@ bool FoundationPose::ScorePostProcess(std::shared_ptr<async_pipeline::IPipelineP
439457 return true ;
440458}
441459
// TrackPostProcess: converts the refiner network's raw translation/rotation outputs into
// refined poses and stores the first refined pose in package->actual_pose.
//
// Every line of this definition carries the diff's `-` marker, i.e. this commit removes it;
// the Track call site in this same diff calls RefinePostProcess in its place.
//
// _package: generic pipeline package; it is downcast to FoundationPosePipelinePackage and the
//           result is checked with CHECK_STATE (macro defined elsewhere -- presumably it
//           returns false on failure; confirm against its definition).
// Returns true after actual_pose has been written.
//
// NOTE(review): refine_poses[0] is read without checking poses_num > 0, and trans_ptr/rot_ptr
// are dereferenced without a null check in this version.
442- bool FoundationPose::TrackPostProcess (std::shared_ptr<async_pipeline::IPipelinePackage> _package)
443- {
444- auto package = std::dynamic_pointer_cast<FoundationPosePipelinePackage>(_package);
445- CHECK_STATE (package != nullptr , " [FoundationPose] TrackPostProcess Got INVALID package ptr!!!" );
446-
447- // Fetch the refiner model's output buffer pointers (translation and rotation blobs), read as float arrays
448- const auto &refiner_blob_buffer = package->refiner_blobs_buffer ;
449- const auto _trans_ptr = refiner_blob_buffer->GetOuterBlobBuffer (REFINE_TRANS_OUT_BLOB_NAME).first ;
450- const auto _rot_ptr = refiner_blob_buffer->GetOuterBlobBuffer (REFINE_ROT_OUT_BLOB_NAME).first ;
451- const float *trans_ptr = static_cast <float *>(_trans_ptr);
452- const float *rot_ptr = static_cast <float *>(_rot_ptr);
453-
454- // Fetch the generated hypothesis poses; their count sets how many outputs are consumed
455- const auto &hyp_poses = package->hyp_poses ;
456- const int poses_num = hyp_poses.size ();
457-
458- // Look up the mesh_loader registered under this package's target name
459- const auto &mesh_loader = map_name2loaders_[package->target_name ];
460-
461- // Transformation: convert the model's relative pose outputs into absolute poses. Each pose consumes 3 floats from each output; translation is scaled by mesh_diameter / 2, rotation goes through tanh * REFINE_ROT_NORMALIZER, is read as an axis-angle vector (angle = norm, axis = unit direction), and the transposed rotation matrix is kept
462- const float mesh_diameter = mesh_loader->GetMeshDiameter ();
463- std::vector<Eigen::Vector3f> trans_delta (poses_num);
464- std::vector<Eigen::Vector3f> rot_delta (poses_num);
465- std::vector<Eigen::Matrix3f> rot_mat_delta (poses_num);
466- for (int i = 0 ; i < poses_num; ++i)
467- {
468- const size_t offset = i * 3 ;
469- trans_delta[i] << trans_ptr[offset], trans_ptr[offset + 1 ], trans_ptr[offset + 2 ];
470- trans_delta[i] *= mesh_diameter / 2 ;
471-
472- rot_delta[i] << rot_ptr[offset], rot_ptr[offset + 1 ], rot_ptr[offset + 2 ];
473- auto normalized_vect = (rot_delta[i].array ().tanh () * REFINE_ROT_NORMALIZER).matrix ();
474- Eigen::AngleAxis rot_delta_angle_axis (normalized_vect.norm (), normalized_vect.normalized ());
475- rot_mat_delta[i] = rot_delta_angle_axis.toRotationMatrix ().transpose ();
476- }
477-
478- std::vector<Eigen::Matrix4f> refine_poses (poses_num);
479- for (int i = 0 ; i < poses_num; ++i)
480- {
481- refine_poses[i] = hyp_poses[i];
482- refine_poses[i].col (3 ).head (3 ) += trans_delta[i];
483-
484- Eigen::Matrix3f top_left_3x3 = refine_poses[i].block <3 , 3 >(0 , 0 );
485- Eigen::Matrix3f result_3x3 = rot_mat_delta[i] * top_left_3x3;
486- refine_poses[i].block <3 , 3 >(0 , 0 ) = result_3x3;
487- }
488-
489- package->actual_pose = refine_poses[0 ];
490-
491- return true ;
492- }
493-
494460std::shared_ptr<Base6DofDetectionModel> CreateFoundationPoseModel (
495461 std::shared_ptr<inference_core::BaseInferCore> refiner_core,
496462 std::shared_ptr<inference_core::BaseInferCore> scorer_core,
0 commit comments