使用Unity Sentis 和MoveNet实现轻量级人体关键点检测

前言

随着深度学习技术的发展,人体姿态估计已成为计算机视觉领域的热门话题之一。它在游戏开发、增强现实(AR)、虚拟现实(VR)等领域展现出巨大的应用潜力。Unity作为一款广泛使用的游戏和实时3D应用开发引擎,结合轻量级人体关键点检测模型MoveNet,可以实现高效准确的人体姿态估计。本文将简要介绍如何在Unity中使用Sentis和MoveNet实现轻量级人体关键点检测。

  • MoveNet模型:MoveNet是一种用于人体姿态估计的轻量级模型,能够提供实时、高效和准确的关键点检测能力。

模型解析:

因为没有找到合适的 MoveNet 的 ONNX 模型,我用 PyTorch 自己训练了一个并导出使用:

输入值:一张 192×192 的三通道(RGB)图像,张量布局为 NCHW,即形状为 1×3×192×192;

输出值: output1: 当前特征图中所有人的关键点的热力图;

output2:预测人体的几何中心点;

output3:预测特征图中每个人的 k 个关键点相对于其人体中心点的坐标偏移值,每个关键点分为 x、y 两个通道(下面的代码中 k = 17,第 i 个关键点的 x、y 偏移分别位于通道 2i 和 2i+1)。

output4:每个关键点的偏移场,同样按 x、y 两个通道存放;代码中把它加到关键点所在的网格坐标上,再除以 48 得到归一化坐标。

代码示例:

根据这四个输出就可以计算出各个关键点的位置:

 // Load the serialized network and create the GPU backend (used later to pre-scale the input).
        model = ModelLoader.Load(modelAsset);
        gpu = new GPUComputeBackend();

        // Wrap the raw network in a functional graph that decodes its four outputs
        // (heatmaps, centers, regressions, offsets) into one (x, y) pair per keypoint.
        // Each pair is divided by the grid size; keypoints whose score is below the
        // threshold are reported as (-1, -1).
        var model2 = Functional.Compile(input =>
        {
            const int gridSize = 48;             // width/height of the output feature maps
            const int keypointCount = 17;        // number of keypoints decoded per person
            const float scoreThreshold = 0.1f;   // heatmap values below this are ignored
            const float distanceBias = 1.8f;     // added to the distance before dividing the heatmap by it

            var outputs = model.Forward(input);
            var heatmaps = outputs[0];
            var centers = outputs[1];
            var regs = outputs[2];
            var offsets = outputs[3];

            // Scalar constants are created once and shared by every Where node instead of
            // allocating a new TensorFloat per use.
            // NOTE(review): these two TensorFloat instances are still never disposed; confirm
            // whether FunctionalTensor.FromTensor copies the data so they can be disposed here.
            var zero = FunctionalTensor.FromTensor(new TensorFloat(0));
            var invalid = FunctionalTensor.FromTensor(new TensorFloat(-1));

            // Suppress weak heatmap responses.
            heatmaps = Functional.Where(heatmaps < scoreThreshold, zero, heatmaps);

            // MaxPoint is defined elsewhere; the % and / below treat its result as a flat
            // row-major index into the gridSize x gridSize map -- TODO confirm.
            var centerIndex = MaxPoint(centers, true);
            var cx = centerIndex % gridSize;
            var cy = centerIndex / gridSize;

            // Index tensors used to read a single value at (cy, cx): gather the row first
            // (axis 0, index broadcast across the columns), then the column (axis 1).
            var center_y = cy.Unsqueeze(-1).BroadcastTo(new[] { gridSize });
            var centerColumn = cx.Reshape(new[] { 1, 1 });

            // Coordinate grids (reshaped to 1 x gridSize x gridSize) used to measure the
            // distance of every cell to a target point; _range_weight_x is the transpose
            // of _range_weight_y.
            var _range_weight = Functional.ARange(0, gridSize, 1);
            var _range_weight_y = _range_weight.Unsqueeze(-1).BroadcastTo(new int[] { gridSize });
            var _range_weight_x = _range_weight_y.Transpose(0, 1);
            var range_weight_x = _range_weight_x.Reshape(new[] { 1, gridSize, gridSize });
            var range_weight_y = _range_weight_y.Reshape(new[] { 1, gridSize, gridSize });

            for (int i = 0; i < keypointCount; i++)
            {
                // NCHW layout: channel 2*i holds the x regression and 2*i+1 the y regression.
                // Read the regression at the person center to get the coarse keypoint position.
                // Adding 0.5 before Int() is a round-to-nearest idiom; how Int() treats
                // negative values should be confirmed.
                FunctionalTensor reg_x_origin = Functional.Gather(regs[0, i * 2], 0, center_y);
                reg_x_origin = (Functional.Gather(reg_x_origin, 1, centerColumn) + 0.5f).Int().Reshape(new[] { 1 });

                FunctionalTensor reg_y_origin = Functional.Gather(regs[0, i * 2 + 1], 0, center_y);
                reg_y_origin = (Functional.Gather(reg_y_origin, 1, centerColumn) + 0.5f).Int().Reshape(new[] { 1 });

                var reg_x = reg_x_origin + cx;
                var reg_y = reg_y_origin + cy;

                // Expand the coarse position so it can be subtracted from the coordinate grids.
                reg_x = reg_x.Unsqueeze(-1).BroadcastTo(new[] { gridSize }).Unsqueeze(-1).BroadcastTo(new[] { gridSize });
                reg_y = reg_y.Unsqueeze(-1).BroadcastTo(new[] { gridSize }).Unsqueeze(-1).BroadcastTo(new[] { gridSize });

                var tmp_reg_x = (range_weight_x - reg_x) * (range_weight_x - reg_x);
                var tmp_reg_y = (range_weight_y - reg_y) * (range_weight_y - reg_y);

                // Euclidean distance of every cell to the coarse position, plus a bias so the
                // division below never divides by zero.
                var tmp_reg = Functional.Sqrt(tmp_reg_x + tmp_reg_y) + distanceBias;

                // Weight this keypoint's heatmap by inverse distance so the peak nearest the
                // coarse position wins.
                tmp_reg = heatmaps[.., i, .., ..] / tmp_reg;
                tmp_reg = tmp_reg.Unsqueeze(1);

                var regCenter = MaxPoint(tmp_reg, false);
                var regc_x = regCenter % gridSize;
                var regc_y = regCenter / gridSize;

                var regCenter_y = regc_y.Unsqueeze(-1).BroadcastTo(new[] { gridSize });
                var peakColumn = regc_x.Reshape(new[] { 1, 1 });

                // Heatmap score at the selected peak.
                FunctionalTensor score = Functional.Gather(heatmaps[0, i], 0, regCenter_y);
                score = Functional.Gather(score, 1, peakColumn).Reshape(new[] { 1 });

                // Offsets read at the selected peak (x in channel 2*i, y in channel 2*i+1).
                FunctionalTensor offset_x = Functional.Gather(offsets[0, i * 2], 0, regCenter_y);
                offset_x = Functional.Gather(offset_x, 1, peakColumn).Reshape(new[] { 1 });

                FunctionalTensor offset_y = Functional.Gather(offsets[0, i * 2 + 1], 0, regCenter_y);
                offset_y = Functional.Gather(offset_y, 1, peakColumn).Reshape(new[] { 1 });

                // Peak cell plus offset, divided by the grid size.
                var res_x = (regc_x + offset_x) / (float)gridSize;
                var res_y = (regc_y + offset_y) / (float)gridSize;

                // Low-confidence keypoints are reported as -1.
                res_x = Functional.Where(score < scoreThreshold, invalid, res_x);
                res_y = Functional.Where(score < scoreThreshold, invalid, res_y);

                var res_xy = Functional.Concat(new[] { res_x, res_y }, -1);

                // NOTE(review): `res` is declared outside this snippet; if it is a long-lived
                // collection it must be empty before this lambda runs, otherwise stale entries
                // end up in the stacked output.
                res.Add(res_xy);
            }

            // One row per keypoint, each row holding (x, y).
            var res_out = Functional.Stack(res.ToArray(), 0);

            return res_out;
        },
        InputDef.FromModel(model)[0]
        );
        
        // Create a GPU worker for the compiled decode graph.
        // NOTE(review): `worker` and `gpu` are disposable and are not disposed in this snippet;
        // make sure the owning object disposes them.
        worker = WorkerFactory.CreateWorker(BackendType.GPUCompute, model2);

        // Convert the source texture to a 1x3x192x192 tensor and multiply it by 255 before
        // inference (presumably ToTensor yields normalized values while the model expects
        // 0-255 -- confirm). All temporaries, including the scalar multiplier, are disposed
        // when the block exits.
        using (var input = TextureConverter.ToTensor(source, 192, 192, 3))
        using (var scale = new TensorFloat(255))
        using (var input1 = TensorFloat.AllocNoData(new TensorShape(1, 3, 192, 192)))
        {
            gpu.Mul(input, scale, input1);
            worker.Execute(input1);
        }

        // PeekOutput returns a tensor that is still owned by the worker, so it must not be
        // disposed here; it is released when the worker is disposed.
        var keypoint = worker.PeekOutput("output_0") as TensorFloat;

猜你喜欢

转载自blog.csdn.net/m0_55632444/article/details/139771518