[pose] deep-high-resolution-net.pytorch
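
A single-image pose-estimation demo built on the deep-high-resolution-net.pytorch (HRNet) repo: it loads a trained model from an experiment config, treats the whole image as one detection box, warps the crop to the model's input size, runs inference, and draws the predicted COCO-style skeleton.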

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import os
import pprint
import time
import torch
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
# import _init_paths  # uncomment when running from the repo's tools/ directory so lib/ is on sys.path
from config import cfg
from config import update_config
from core.loss import JointsMSELoss
from core.function import validate, get_final_preds
from utils.utils import create_logger
from utils.transforms import *
import cv2
import dataset
import models
import numpy as np
def parse_args():
    parser = argparse.ArgumentParser(description='Test a keypoints network on a single image')
    # general
    parser.add_argument('--cfg',
                        help='experiment configure file name',
                        default='/home/cody/PycharmProjects/deep-high-resolution-net.pytorch/experiments/coco/hrnet/w48_256x192_adam_lr1e-3.yaml',
                        type=str)

    parser.add_argument('opts',
                        help="Modify config options using the command-line",
                        default=None,
                        nargs=argparse.REMAINDER)
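    # trailing KEY VALUE pairs captured here (e.g. TEST.MODEL_FILE
    # /path/to/weights.pth) are merged into cfg by update_config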

    parser.add_argument('--img-file',
                        help='path to the test image',
                        type=str,
                        default='/home/cody/PycharmProjects/deep-high-resolution-net.pytorch/videos/004.png')
    # philly
    parser.add_argument('--modelDir',
                        help='model directory',
                        type=str,
                        default='')
    parser.add_argument('--logDir',
                        help='log directory',
                        type=str,
                        default='')
    parser.add_argument('--dataDir',
                        help='data directory',
                        type=str,
                        default='')
    parser.add_argument('--prevModelDir',
                        help='prev Model directory',
                        type=str,
                        default='')
    args = parser.parse_args()
    return args

def _box2cs(box, image_width, image_height):
    x, y, w, h = box[:4]
    return _xywh2cs(x, y, w, h, image_width, image_height)


def _xywh2cs(x, y, w, h, image_width, image_height):
    center = np.zeros((2), dtype=np.float32)
    center[0] = x + w * 0.5
    center[1] = y + h * 0.5

    aspect_ratio = image_width * 1.0 / image_height
    pixel_std = 200  # the repo's fixed reference: scale is the box size divided by 200

    if w > aspect_ratio * h:
        h = w * 1.0 / aspect_ratio
    elif w < aspect_ratio * h:
        w = h * aspect_ratio
    scale = np.array(
        [w * 1.0 / pixel_std, h * 1.0 / pixel_std],
        dtype=np.float32)
    # if center[0] != -1:
    #     scale = scale * 1.25

    return center, scale
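# Worked example (hypothetical numbers): for a 640x480 full-image box and a
# 192x256 model input, aspect_ratio = 192/256 = 0.75; since 640 > 0.75 * 480,
# h stretches to 640 / 0.75 ≈ 853.3, giving center = (320, 240) and
# scale = (640/200, 853.3/200) ≈ (3.2, 4.27).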

def main():
    args = parse_args()
    update_config(cfg, args)

    logger, final_output_dir, tb_log_dir = create_logger(
        cfg, args.cfg, 'valid')

    logger.info(pprint.pformat(args))
    logger.info(cfg)

    # cudnn related setting
    cudnn.benchmark = cfg.CUDNN.BENCHMARK
    torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED

    # look up the constructor named in the config, e.g. models.pose_hrnet.get_pose_net
    model = eval('models.' + cfg.MODEL.NAME + '.get_pose_net')(
        cfg, is_train=False
    )

    if cfg.TEST.MODEL_FILE:
        logger.info('=> loading model from {}'.format(cfg.TEST.MODEL_FILE))
        model.load_state_dict(torch.load(cfg.TEST.MODEL_FILE), strict=False)
    else:
        model_state_file = os.path.join(
            final_output_dir, 'final_state.pth'
        )
        logger.info('=> loading model from {}'.format(model_state_file))
        model.load_state_dict(torch.load(model_state_file))

    model = torch.nn.DataParallel(model).cuda()
    # model = torch.nn.DataParallel(model, device_ids=cfg.GPUS).cuda()
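    # nn.DataParallel moves inputs onto the GPUs inside forward(), so the CPU
    # tensor built below can be fed to the model directly.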

    # define the loss function (kept from the validation script; unused in this demo)
    criterion = JointsMSELoss(
        use_target_weight=cfg.LOSS.USE_TARGET_WEIGHT
    ).cuda()

    # Loading an image
    image_file = args.img_file
    data_numpy = cv2.imread(image_file, cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION)
    if data_numpy is None:
        logger.error('=> failed to read {}'.format(image_file))
        raise ValueError('=> failed to read {}'.format(image_file))

    # object detection box
    box = [0, 0, data_numpy.shape[1], data_numpy.shape[0]]

    # the aspect ratio passed to _box2cs must match the model input
    # (cfg.MODEL.IMAGE_SIZE), not the source image, so the crop is padded
    # the same way as during training
    c, s = _box2cs(box, cfg.MODEL.IMAGE_SIZE[0], cfg.MODEL.IMAGE_SIZE[1])
    r = 0

    trans = get_affine_transform(c, s, r, cfg.MODEL.IMAGE_SIZE)
    input = cv2.warpAffine(
        data_numpy,
        trans,
        (int(cfg.MODEL.IMAGE_SIZE[0]), int(cfg.MODEL.IMAGE_SIZE[1])),
        flags=cv2.INTER_LINEAR)
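    # ImageNet mean/std; the pretrained HRNet models expect this normalization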
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    input = transform(input).unsqueeze(0)
    # switch to evaluate mode
    model.eval()
    with torch.no_grad():
        # compute the output heatmaps, then decode them back to original-image
        # coordinates via the center/scale pair
        start = time.time()
        output = model(input)
        preds, maxvals = get_final_preds(
            cfg, output.clone().cpu().numpy(), np.asarray([c]), np.asarray([s]))
        # note: the first forward pass includes CUDA/cudnn warm-up, so this
        # single-run timing is pessimistic
        print("time cost", time.time() - start)

        image = data_numpy.copy()
        '''
        # per-keypoint visualization, kept for reference
        for index, mat in enumerate(preds[0]):
            x, y = int(mat[0]), int(mat[1])
            cv2.circle(image, (x, y), 2, (255, 0, 0), 2)
            cv2.putText(image, str(index), (x, y), cv2.FONT_HERSHEY_COMPLEX, 1.2, (255, 0, 0), 1)
            cv2.imshow('res', image)
            cv2.waitKey(0)
            cv2.destroyAllWindows()
        cv2.rectangle(image, (box[0], box[1]), (box[0] + box[2], box[1] + box[3]), (0, 0, 255), 3)
        # vis result
        cv2.imwrite("test_h36m.jpg", image)
        cv2.imshow('res', image)
        cv2.waitKey(0)
        cv2.destroyAllWindows()
        '''

        # HRNet / COCO keypoint order:
        #   0 - nose, 1 - left eye, 2 - right eye, 3 - left ear, 4 - right ear,
        #   5 - left shoulder, 6 - right shoulder, 7 - left elbow, 8 - right elbow,
        #   9 - left wrist, 10 - right wrist, 11 - left hip, 12 - right hip,
        #   13 - left knee, 14 - right knee, 15 - left ankle, 16 - right ankle

        # consecutive index pairs in new_order define the skeleton segments,
        # e.g. (0, 1) nose to left eye, (10, 8) right wrist to right elbow
        new_order = [0, 1, 1, 3, 0, 2, 2, 4,
                     10, 8, 8, 6, 6, 5, 5, 7, 7, 9,
                     16, 14, 14, 12, 12, 11, 11, 13, 13, 15]
        label = preds[0]
        line_key = label[new_order, :]  # (len(new_order), 2) array of x, y coordinates

        color = [(0, 0, 255), (255, 0, 0), (0, 255, 0), (0, 255, 255), (255, 255, 0), (255, 255, 255),
                 (255, 0, 0), (255, 0, 255), (0, 255, 0), (0, 0, 255), (255, 0, 0), (0, 255, 0), (0, 255, 255),
                 (255, 255, 0), (255, 255, 255), (255, 0, 0), (255, 0, 255)]
        for i in range(0, len(new_order), 2):
            # skip segments whose endpoints were not detected (decoded to (0, 0))
            # if (line_key[i] == 0).all() or (line_key[i + 1] == 0).all():
            #     continue
            pt1 = (int(line_key[i][0]), int(line_key[i][1]))  # cv2.line needs int coords
            pt2 = (int(line_key[i + 1][0]), int(line_key[i + 1][1]))
            cv2.line(image, pt1, pt2, color[i // 2], 4)
        cv2.imshow("line", image)
        cv2.waitKey(0)
        cv2.destroyAllWindows()


if __name__ == '__main__':
    main()
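
To run it (a sketch; the script name demo_image.py and the weight path are assumptions, substitute your own):

python demo_image.py \
    --cfg experiments/coco/hrnet/w48_256x192_adam_lr1e-3.yaml \
    --img-file videos/004.png \
    TEST.MODEL_FILE models/pytorch/pose_coco/pose_hrnet_w48_256x192.pth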

Reposted from blog.csdn.net/weixin_41449637/article/details/109575872