将 PyTorch transform 后的 tensor 还原为 PIL.Image 图片

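注意：以下代码将单张图片的 tensor 转换为 PIL.Image 格式。函数内部用到了 torch 以及 PIL 中的 Image（需要 `from PIL import Image`）；拿到返回的 img 后，调用 `img.save('xx.png')` 即可保存。另外，代码是按 tensor 的最小值/最大值把像素线性拉伸到 [0, 255]，并不是 `transforms.Normalize` 的精确逆变换，因此得到图片的亮度/对比度不一定与原图完全一致。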

    def transform_invert(self, img, show=False):
        """Convert a single image tensor into a ``PIL.Image.Image``.

        The tensor is min-max rescaled to [0, 1] using its own minimum and
        maximum. This is a contrast stretch, not an exact inverse of
        ``transforms.Normalize``. The result is then scaled to [0, 255],
        rounded, and converted to uint8.

        Args:
            img: Image tensor laid out as (C, H, W) (e.g. a CIFAR-10 sample
                after its transform), or (1, C, H, W) with a size-1 batch
                dimension. The caller's tensor is not modified.
            show: If True, also display the image via ``Image.show()``.

        Returns:
            The resulting ``PIL.Image.Image``; call ``.save(path)`` on it to
            write it to disk.
        """
        # Work on a detached copy. The in-place ops below would otherwise
        # overwrite the caller's tensor (unsqueeze returns a view), and
        # in-place ops are rejected on tensors that require grad.
        img = img.detach().clone()
        if img.dim() == 3:  # single image (C, H, W) -> (1, C, H, W)
            img = img.unsqueeze(0)
        low = float(img.min())
        high = float(img.max())
        # (img - low) / (high - low); the epsilon guards against a constant image.
        img.sub_(low).div_(max(high - low, 1e-5))
        grid = img.squeeze(0)  # drop the size-1 batch dimension -> (C, H, W)
        # Add 0.5 after scaling to [0, 255] so the uint8 cast rounds to nearest.
        ndarr = (
            grid.mul(255).add_(0.5).clamp_(0, 255)
            .permute(1, 2, 0).to("cpu", torch.uint8).numpy()
        )
        if ndarr.shape[2] == 1:
            # PIL cannot build an image from an (H, W, 1) array; use (H, W).
            ndarr = ndarr[:, :, 0]
        img = Image.fromarray(ndarr)
        if show:
            img.show()
        return img

以下代码将一个 batch 的图片 tensor 拼接成一张网格图片，并保存到 path。注意它与上面的函数同名（都叫 transform_invert），如果放在同一个类中，后定义的会覆盖先定义的，需要给其中一个改名。

    def transform_invert(self, imgs, path, *, nrow=4, padding=2, pad_value=8, format=None):
        """Tile a batch of image tensors into one grid image and save it.

        The whole batch is min-max rescaled to [0, 1] using the batch-wide
        minimum and maximum. Images are then laid out row by row with at most
        ``nrow`` per row, converted to uint8 and written with
        ``PIL.Image.Image.save``.

        Args:
            imgs: Batch tensor laid out as (B, C, H, W); a single (C, H, W)
                image is also accepted. The caller's tensor is not modified.
            path: Destination passed to ``Image.save`` (file name or file object).
            nrow: Maximum number of images per grid row.
            padding: Gap in pixels between images and around the border.
            pad_value: Fill value for the gaps, in the normalized [0, 1]
                space. Values above 1 saturate at 255, so the default 8
                yields white gaps.
            format: Optional format name forwarded to ``Image.save``
                (e.g. "PNG"). When None, Pillow infers it from the file
                extension of ``path``.

        Raises:
            ValueError: If ``imgs`` contains no images.
        """
        import math

        # Work on a detached copy. The in-place normalization below would
        # otherwise overwrite the caller's tensor, and in-place ops are
        # rejected on tensors that require grad.
        tensor = imgs.detach().clone()
        if tensor.dim() == 3:  # single (C, H, W) image -> batch of one
            tensor = tensor.unsqueeze(0)
        nmaps = tensor.size(0)
        if nmaps == 0:
            raise ValueError("imgs must contain at least one image")

        # Min-max rescale the whole batch to [0, 1]; the epsilon avoids
        # dividing by zero when every pixel has the same value.
        low = float(tensor.min())
        high = float(tensor.max())
        tensor.sub_(low).div_(max(high - low, 1e-5))

        # Grid geometry: each cell is one image plus one padding strip; the
        # extra `padding` in the canvas size closes the right/bottom border.
        xmaps = min(nrow, nmaps)
        ymaps = int(math.ceil(float(nmaps) / xmaps))
        num_channels, img_h, img_w = tensor.size(1), tensor.size(2), tensor.size(3)
        height, width = img_h + padding, img_w + padding
        grid = tensor.new_full(
            (num_channels, height * ymaps + padding, width * xmaps + padding),
            pad_value,
        )
        for k in range(nmaps):
            y, x = divmod(k, xmaps)  # row-major placement
            grid.narrow(1, y * height + padding, img_h).narrow(
                2, x * width + padding, img_w
            ).copy_(tensor[k])

        # Add 0.5 after scaling to [0, 255] so the uint8 cast rounds to nearest.
        ndarr = (
            grid.mul(255).add_(0.5).clamp_(0, 255)
            .permute(1, 2, 0).to("cpu", torch.uint8).numpy()
        )
        if ndarr.shape[2] == 1:
            # PIL cannot build an image from an (H, W, 1) array; use (H, W).
            ndarr = ndarr[:, :, 0]
        # Forward the caller's `format` (the original passed the builtin
        # `format` function here, which made Image.save fail).
        Image.fromarray(ndarr).save(path, format=format)

猜你喜欢

转载自blog.csdn.net/cj151525/article/details/128263525