pytorch 自定义高斯核进行卷积操作

1.介绍

    高斯滤波的用处很多,也有很多现成的包可以被调用,比如opencv里面的cv2.GaussianBlur,一般情况,我们是没必要去造轮子,除非遇到特殊情况,比如我们在使用pytorch的过程中,需要自定义高斯核进行卷积操作,假设,我们要用的高斯核的参数是以下数目:

0.00655965 0.01330373 0.00655965 0.00078633 0.00002292
0.00655965 0.05472157 0.11098164 0.05472157 0.00655965
0.01330373 0.11098164 0.22508352 0.11098164 0.01330373
0.00655965 0.05472157 0.11098164 0.05472157 0.00655965
0.00078633 0.00655965 0.01330373 0.00655965 0.00078633

    在使用pytorch过程中,常用的卷积函数是:

class torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True)

    感觉是无法自定义卷积权重,那么我们就此放弃吗?肯定不是,当你再仔细看看pytorch的说明书之后,会发现一个好东西:

torch.nn.functional.conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1)

    里面的weight参数刚好可以用高斯核参数来填充。

2.代码

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
import cv2


class GaussianBlurConv(nn.Module):
    def __init__(self, channels=3):
        super(GaussianBlurConv, self).__init__()
        self.channels = channels
        kernel = [[0.00078633, 0.00655965, 0.01330373, 0.00655965, 0.00078633],
                  [0.00655965, 0.05472157, 0.11098164, 0.05472157, 0.00655965],
                  [0.01330373, 0.11098164, 0.22508352, 0.11098164, 0.01330373],
                  [0.00655965, 0.05472157, 0.11098164, 0.05472157, 0.00655965],
                  [0.00078633, 0.00655965, 0.01330373, 0.00655965, 0.00078633]]
        kernel = torch.FloatTensor(kernel).unsqueeze(0).unsqueeze(0)
        kernel = np.repeat(kernel, self.channels, axis=0)
        self.weight = nn.Parameter(data=kernel, requires_grad=False)

    def __call__(self, x):
        x = F.conv2d(x.unsqueeze(0), self.weight, padding=2, groups=self.channels)
        return x

input_x = cv2.imread("kodim04.png")
cv2.imshow("input_x", input_x)
input_x = Variable(torch.from_numpy(input_x.astype(np.float32))).permute(2, 0, 1)
gaussian_conv = GaussianBlurConv()
out_x = gaussian_conv(input_x)
out_x = out_x.squeeze(0).permute(1, 2, 0).data.numpy().astype(np.uint8)
cv2.imshow("out_x", out_x)
cv2.waitKey(0)

    原图:

    输出图:

3.扩展应用

    我们知道了怎么自定义高斯核,其它的核都可以照搬,这里就不一一讲述了。

发布了138 篇原创文章 · 获赞 141 · 访问量 11万+

猜你喜欢

转载自blog.csdn.net/u013289254/article/details/103896635
今日推荐