不确定 SDE-Net

paper code video
https://arxiv.org/abs/2008.10546 https://github.com/Lingkai-Kong/SDE-Net https://www.youtube.com/watch?v=RylZA4Ioc3M

离散化:
$\Large x_{t+1} = x_t + f(x_t,\, t)$
连续化:
$\Large dx_t = f(x_t,\, t)\, dt$

$\Large dx_t = f(x_t,\, t)\, dt + g(x_t,\, t)\, dw_t$

$\Large x_{k+1} = x_k + f(x_k,\, t)\, \delta t + g(x_0)\, \sqrt{\delta t}\, Z_k$ (欧拉-丸山离散化;与代码一致,扩散项 $g$ 只在初始状态 $x_0$ 处计算一次)

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述在这里插入图片描述

  • 简单图示过程:

在这里插入图片描述

python sdenet_mnist.py 

Evaluation:

python test_detection.py --pre_trained_net save_sdenet_mnist/final_model --network sdenet --dataset mnist --out_dataset svhn

代码

  • 好消息是环境异常简单,直接打开工程,然后添加了一个venv的环境
    在这里插入图片描述
  • 然后安装个torch就能运行了
pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu
SDENet_mnist(
  (downsampling_layers): Sequential(
    (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1))
    (1): GroupNorm(32, 64, eps=1e-05, affine=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(64, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (4): GroupNorm(32, 64, eps=1e-05, affine=True)
    (5): ReLU(inplace=True)
    (6): Conv2d(64, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  )
  (drift): Drift(
    (norm1): GroupNorm(32, 64, eps=1e-05, affine=True)
    (relu): ReLU(inplace=True)
    (conv1): ConcatConv2d(
      (_layer): Conv2d(65, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (norm2): GroupNorm(32, 64, eps=1e-05, affine=True)
    (conv2): ConcatConv2d(
      (_layer): Conv2d(65, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (norm3): GroupNorm(32, 64, eps=1e-05, affine=True)
  )
  (diffusion): Diffusion(
    (norm1): GroupNorm(32, 64, eps=1e-05, affine=True)
    (relu): ReLU(inplace=True)
    (conv1): ConcatConv2d(
      (_layer): Conv2d(65, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (norm2): GroupNorm(32, 64, eps=1e-05, affine=True)
    (conv2): ConcatConv2d(
      (_layer): Conv2d(65, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (fc): Sequential( # Diffusion 和 Drift 相比就多了一个这个
      (0): GroupNorm(32, 64, eps=1e-05, affine=True)
      (1): ReLU(inplace=True)
      (2): AdaptiveAvgPool2d(output_size=(1, 1))
      (3): Flatten()
      (4): Linear(in_features=64, out_features=1, bias=True) # 这个和下边的就只有输出维度不同
      (5): Sigmoid()
    )
  )
  (fc_layers): Sequential(
    (0): GroupNorm(32, 64, eps=1e-05, affine=True)
    (1): ReLU(inplace=True)
    (2): AdaptiveAvgPool2d(output_size=(1, 1))
    (3): Flatten()
    (4): Linear(in_features=64, out_features=10, bias=True)
  )
)
torch.Size([36, 1, 28, 28]) =>(out = self.downsampling_layers(x)) torch.Size([36, 64, 6, 6])
diffusion_term = self.sigma*self.diffusion(t, out)即为 20 * torch.Size([36, 1])
diffusion_term = torch.unsqueeze(diffusion_term, 2)=> torch.Size([36, 1, 1])
diffusion_term = torch.unsqueeze(diffusion_term, 3)=> torch.Size([36, 1, 1, 1])
t为 0.0 输出的大小 torch.Size([36, 64, 6, 6])
t为 1.0 输出的大小 torch.Size([36, 64, 6, 6])
t为 2.0 输出的大小 torch.Size([36, 64, 6, 6])
t为 3.0 输出的大小 torch.Size([36, 64, 6, 6])
t为 4.0 输出的大小 torch.Size([36, 64, 6, 6])
t为 5.0 输出的大小 torch.Size([36, 64, 6, 6])
final_out torch.Size([36, 10])
最关键的一句out = out + self.drift(t, out)*self.deltat + diffusion_term*math.sqrt(self.deltat)*torch.randn_like(out).to(x)
其中 self.deltat = 1.0
  • 注意代码中使用的GroupNorm和ConcatConv2d,其中ConcatConv2d为:
  • 代码中的t是一种权重
class ConcatConv2d(nn.Module):
    """Convolution that receives the scalar time t as an extra input channel.

    The time value is broadcast into a constant feature map, concatenated in
    front of the input channels, and fed through a single conv layer, so the
    layer sees `dim_in + 1` input channels.
    """

    def __init__(self, dim_in, dim_out, ksize=3, stride=1, padding=0, dilation=1, groups=1, bias=True, transpose=False):
        super(ConcatConv2d, self).__init__()
        conv_cls = nn.ConvTranspose2d if transpose else nn.Conv2d
        self._layer = conv_cls(
            dim_in + 1,
            dim_out,
            kernel_size=ksize,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )

    def forward(self, t, x):
        # One constant map with value t, shaped like a single channel of x.
        t_channel = torch.ones_like(x[:, :1, :, :]) * t
        return self._layer(torch.cat([t_channel, x], dim=1))

快速运行

无须下载数据的sdenet_mnist.py版本

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Mar 11 16:34:10 2019

@author: lingkaikong
"""

from __future__ import print_function

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

import torchvision
import random
import os
import argparse
import sed as models
#import data_loader

# Command-line configuration for SDE-Net MNIST training.
parser = argparse.ArgumentParser(description='PyTorch SDE-Net Training')
parser.add_argument('--lr', default=0.1, type=float, help='learning rate of drift net')
parser.add_argument('--lr2', default=0.01, type=float, help='learning rate of diffusion net')
# NOTE(review): store_false with default=True means passing the flag turns it off.
parser.add_argument('--training_out', action='store_false', default=True, help='training_with_out')
parser.add_argument('--epochs', type=int, default=40, help='number of epochs to train')
parser.add_argument('--eva_iter', default=5, type=int, help='number of passes when evaluation')
parser.add_argument('--dataset_inDomain', default='mnist', help='training dataset')
parser.add_argument('--batch_size', type=int, default=36, help='input batch size for training')
parser.add_argument('--imageSize', type=int, default=28, help='the height / width of the input image to network')
parser.add_argument('--test_batch_size', type=int, default=1000)
parser.add_argument('--gpu', type=int, default=0)
parser.add_argument('--seed', type=float, default=0)
parser.add_argument('--droprate', type=float, default=0.1, help='learning rate decay')
parser.add_argument('--decreasing_lr', default=[10, 20,30], nargs='+', help='decreasing strategy')
parser.add_argument('--decreasing_lr2', default=[15, 30], nargs='+', help='decreasing strategy')
args = parser.parse_args()

device = torch.device('cuda:' + str(args.gpu) if torch.cuda.is_available() else 'cpu')

torch.manual_seed(args.seed)
random.seed(args.seed)

# BUG FIX: `device` is a torch.device (e.g. cuda:0), so comparing it to the
# string 'cuda' was always False and this branch never executed; compare the
# device *type* instead so cudnn autotuning and the CUDA seed actually apply.
if device.type == 'cuda':
    cudnn.benchmark = True
    torch.cuda.manual_seed(args.seed)



# print('load in-domain data: ',args.dataset_inDomain)
# train_loader_inDomain, test_loader_inDomain = data_loader.getDataSet(args.dataset_inDomain, args.batch_size, args.test_batch_size, args.imageSize)

# Model
print('==> Building model..')
net = models.SDENet_mnist(layer_depth=6, num_classes=10, dim=64)
net = net.to(device)


# Binary labels for the diffusion net's task: 0 = in-domain, 1 = out-of-domain.
real_label = 0
fake_label = 1

criterion = nn.CrossEntropyLoss()  # classification loss for the drift path
criterion2 = nn.BCELoss()          # in/out-of-domain loss for the diffusion net

# Two optimizers: F trains the deterministic path (downsampling, drift,
# classifier head); G trains only the diffusion net, at a smaller lr.
optimizer_F = optim.SGD([ {'params': net.downsampling_layers.parameters()}, {'params': net.drift.parameters()},
{'params': net.fc_layers.parameters()}], lr=args.lr, momentum=0.9, weight_decay=5e-4)

optimizer_G = optim.SGD([ {'params': net.diffusion.parameters()}], lr=args.lr2, momentum=0.9, weight_decay=5e-4)

#use a smaller sigma during training for training stability
net.sigma = 20

# Training
def train(epoch):
    """Run one training epoch.

    Alternates two updates per batch: (1) fit drift net + classifier on
    in-domain data with cross-entropy (optimizer_F); (2) fit the diffusion
    net to output 0 for in-domain and 1 for noise-perturbed out-of-domain
    inputs with BCE (optimizer_G). Relies on the module-level `net`,
    `criterion`, `criterion2`, `optimizer_F`, `optimizer_G` and `args`.
    """
    print('\nEpoch: %d' % epoch)
    net.train()

    train_loss = 0
    correct = 0
    total = 0
    train_loss_out = 0
    train_loss_in = 0

    ##training with in-domain data
    for batch_idx in range(800):#, (inputs, targets) in enumerate(train_loader_inDomain):
        #inputs, targets = inputs.to(device), targets.to(device)
        # Synthetic stand-in batch so the script runs without downloading MNIST.
        inputs = torch.randn([36, 1, 28, 28]).to(device)
        targets = torch.tensor([8, 1, 2, 7, 1, 2, 3, 0, 1, 2, 4, 5, 9, 6, 3, 9, 0, 3, 5, 7, 6, 9, 8, 1,
        2, 5, 0, 2, 6, 9, 7, 3, 3, 4, 0, 8]).to(device)
        optimizer_F.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer_F.step()
        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
    #training with out-of-domain data
        # In-domain samples should map to real_label (0).
        label = torch.full((args.batch_size,1), real_label, device=device).float()# .float() added so BCELoss receives a float target
        optimizer_G.zero_grad()
        predict_in = net(inputs, training_diffusion=True)
        loss_in = criterion2(predict_in, label)
        loss_in.backward()
        label.fill_(fake_label)
        # Out-of-domain surrogate: same batch perturbed by strong Gaussian noise.
        inputs_out = 2*torch.randn(args.batch_size,1, args.imageSize, args.imageSize, device = device)+inputs
        predict_out = net(inputs_out, training_diffusion=True)
        loss_out = criterion2(predict_out, label)
        loss_out.backward()
        train_loss_out += loss_out.item()
        train_loss_in += loss_in.item()
        optimizer_G.step()

    # print('Train epoch:{} \tLoss: {:.6f} | Loss_in: {:.6f}, Loss_out: {:.6f} | Acc: {:.6f} ({}/{})'
    #     .format(epoch, train_loss/(len(train_loader_inDomain)), train_loss_in/len(train_loader_inDomain), train_loss_out/len(train_loader_inDomain), 100.*correct/total, correct, total))


# def test(epoch):
#     net.eval()
#     correct = 0
#     total = 0
#     with torch.no_grad():
#         for batch_idx, (inputs, targets) in enumerate(test_loader_inDomain):
#             inputs, targets = inputs.to(device), targets.to(device)
#             outputs = 0
#             for j in range(args.eva_iter):
#                 current_batch = net(inputs)
#                 outputs = outputs + F.softmax(current_batch, dim = 1)
#
#             outputs = outputs/args.eva_iter
#             _, predicted = outputs.max(1)
#             total += targets.size(0)
#             correct += predicted.eq(targets).sum().item()
#
#         print('Test epoch: {} | Acc: {:.6f} ({}/{})'
#         .format(epoch, 100.*correct/total, correct, total))


# Main loop: train for args.epochs epochs (evaluation and lr decay are
# disabled in this standalone, download-free version of the script).
for epoch in range(0, args.epochs):
    train(epoch)
    # test(epoch)
    # if epoch in args.decreasing_lr:
    #     for param_group in optimizer_F.param_groups:
    #         param_group['lr'] *= args.droprate
    # if epoch in args.decreasing_lr2:
    #     for param_group in optimizer_G.param_groups:
    #         param_group['lr'] *= args.droprate

# if not os.path.isdir('./save_sdenet_mnist'):
#     os.makedirs('./save_sdenet_mnist')
# torch.save(net.state_dict(),'./save_sdenet_mnist/final_model')


最主要的部分(独立运行这个也行)

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Mar 11 16:42:11 2019

@author: lingkaikong
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import torch.nn.init as init
import math

__all__ = ['SDENet_mnist']


def init_params(net):
    """Reset the weights of every Conv2d, BatchNorm2d and Linear submodule.

    Conv weights get Kaiming-normal (fan_out) init, batch norm starts as the
    identity transform, linear weights get small Gaussian noise; all biases
    are zeroed.
    """
    for layer in net.modules():
        if isinstance(layer, nn.Conv2d):
            # He initialization suits the ReLU activations used throughout.
            init.kaiming_normal_(layer.weight, mode='fan_out')
            if layer.bias is not None:
                init.constant_(layer.bias, 0)
        elif isinstance(layer, nn.BatchNorm2d):
            init.constant_(layer.weight, 1)
            init.constant_(layer.bias, 0)
        elif isinstance(layer, nn.Linear):
            init.normal_(layer.weight, std=1e-3)
            if layer.bias is not None:
                init.constant_(layer.bias, 0)


# torch.manual_seed(0)
def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 convolution with 1-pixel padding."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )


def conv1x1(in_planes, out_planes, stride=1):
    """Return a bias-free 1x1 (pointwise) convolution."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )


def norm(dim):
    """GroupNorm over `dim` channels, using at most 32 groups."""
    num_groups = min(32, dim)
    return nn.GroupNorm(num_groups, dim)


class ConcatConv2d(nn.Module):
    """Convolution that receives the scalar time t as an extra input channel.

    t is broadcast into a constant feature map, concatenated ahead of the
    input channels, and passed through one conv layer (hence `dim_in + 1`
    input channels).
    """

    def __init__(self, dim_in, dim_out, ksize=3, stride=1, padding=0, dilation=1, groups=1, bias=True, transpose=False):
        super(ConcatConv2d, self).__init__()
        conv_cls = nn.ConvTranspose2d if transpose else nn.Conv2d
        self._layer = conv_cls(
            dim_in + 1,
            dim_out,
            kernel_size=ksize,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )

    def forward(self, t, x):
        # Constant map filled with t, shaped like one channel of x.
        t_channel = torch.ones_like(x[:, :1, :, :]) * t
        return self._layer(torch.cat([t_channel, x], dim=1))


class Drift(nn.Module):
    """Drift network f(x, t): two time-conditioned convolutions that preserve
    the channel count and spatial size of the input feature map."""

    def __init__(self, dim):
        super(Drift, self).__init__()
        self.norm1 = norm(dim)
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = ConcatConv2d(dim, dim, 3, 1, 1)
        self.norm2 = norm(dim)
        self.conv2 = ConcatConv2d(dim, dim, 3, 1, 1)
        self.norm3 = norm(dim)

    def forward(self, t, x):
        # Pattern: (norm -> relu -> conv) twice, then a trailing norm.
        h = self.relu(self.norm1(x))
        h = self.conv1(t, h)
        h = self.relu(self.norm2(h))
        h = self.conv2(t, h)
        return self.norm3(h)


class Diffusion(nn.Module):
    """Diffusion network g(x, t): the same conv trunk as Drift, followed by a
    head that pools the feature map and emits one sigmoid scalar per sample
    (the per-sample diffusion magnitude / out-of-domain score)."""

    def __init__(self, dim_in, dim_out):
        super(Diffusion, self).__init__()
        self.norm1 = norm(dim_in)
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = ConcatConv2d(dim_in, dim_out, 3, 1, 1)
        self.norm2 = norm(dim_in)
        self.conv2 = ConcatConv2d(dim_in, dim_out, 3, 1, 1)
        # Head: normalize, pool to 1x1, flatten, then one sigmoid unit.
        self.fc = nn.Sequential(
            norm(dim_out),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d((1, 1)),
            Flatten(),
            nn.Linear(dim_out, 1),
            nn.Sigmoid(),
        )

    def forward(self, t, x):
        h = self.relu(self.norm1(x))
        h = self.conv1(t, h)
        h = self.relu(self.norm2(h))
        h = self.conv2(t, h)
        return self.fc(h)


class Flatten(nn.Module):
    """Collapse all non-batch dimensions into a single feature dimension."""

    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        # Number of features per sample = product of the trailing dims.
        per_sample = math.prod(x.shape[1:])
        return x.view(-1, per_sample)


class SDENet_mnist(nn.Module):
    """SDE-Net for MNIST.

    Downsamples the image to a feature map, then integrates the SDE
        dx_t = f(x_t, t) dt + g(x_0) dW_t
    with Euler-Maruyama steps, where `drift` is f and `diffusion` is g.
    g is evaluated once at t=0 on the initial feature map and reused as a
    constant per-sample noise scale. Debug prints are kept intentionally
    (this is an annotated walkthrough version).

    Args:
        layer_depth: number of Euler steps; t sweeps [0, 6) in `layer_depth`
            increments with step size deltat = 6 / layer_depth.
        num_classes: output size of the classification head.
        dim: channel width of the feature maps.
    """

    def __init__(self, layer_depth, num_classes=10, dim=64):
        super(SDENet_mnist, self).__init__()
        self.layer_depth = layer_depth
        self.downsampling_layers = nn.Sequential(
            nn.Conv2d(1, dim, 3, 1),
            norm(dim),
            nn.ReLU(inplace=True),
            nn.Conv2d(dim, dim, 4, 2, 1),
            norm(dim),
            nn.ReLU(inplace=True),
            nn.Conv2d(dim, dim, 4, 2, 1),
        )
        self.drift = Drift(dim)
        self.diffusion = Diffusion(dim, dim)
        # BUG FIX: the head hard-coded 10 outputs, silently ignoring the
        # num_classes argument; honor it (default 10 keeps old behavior).
        self.fc_layers = nn.Sequential(norm(dim), nn.ReLU(inplace=True), nn.AdaptiveAvgPool2d((1, 1)), Flatten(),
                                       nn.Linear(dim, num_classes))
        self.deltat = 6. / self.layer_depth  # Euler step size over t in [0, 6)
        self.apply(init_params)
        self.sigma = 500  # diffusion scale; the training script overrides this to 20

    def forward(self, x, training_diffusion=False):
        """Return class logits (N, num_classes), or the diffusion net's
        (N, 1) sigmoid output when training_diffusion=True."""
        out = self.downsampling_layers(x)
        print(x.shape, "=>(out = self.downsampling_layers(x))", out.shape)
        if not training_diffusion:
            t = 0
            # Evaluate g once at the initial state. (The original called
            # self.diffusion a second time just for the logging line below.)
            diffusion_out = self.diffusion(t, out)
            diffusion_term = self.sigma * diffusion_out
            print("diffusion_term = self.sigma*self.diffusion(t, out)即为", self.sigma, "*", diffusion_out.shape)
            # Expand (N, 1) -> (N, 1, 1, 1) so it broadcasts over NCHW.
            diffusion_term = torch.unsqueeze(diffusion_term, 2)
            print("diffusion_term = torch.unsqueeze(diffusion_term, 2)=>", diffusion_term.shape)
            diffusion_term = torch.unsqueeze(diffusion_term, 3)
            print("diffusion_term = torch.unsqueeze(diffusion_term, 3)=>", diffusion_term.shape)
            for i in range(self.layer_depth):
                t = 6 * (float(i)) / self.layer_depth
                print(
                    "最关键的一句out = out + self.drift(t, out)*self.deltat + diffusion_term*math.sqrt(self.deltat)*torch.randn_like(out).to(x)")
                print("self.deltat", self.deltat)
                # Euler-Maruyama step: drift * dt + g(x_0) * sqrt(dt) * N(0, 1).
                out = out + self.drift(t, out) * self.deltat + diffusion_term * math.sqrt(
                    self.deltat) * torch.randn_like(out).to(x)  # .to(x): match x's dtype/device
                print("t为", t, "输出的大小", out.shape)
            final_out = self.fc_layers(out)
            print("final_out", final_out.shape)
        else:
            # Diffusion-only training: detach so gradients do not flow back
            # into the downsampling trunk.
            t = 0
            final_out = self.diffusion(t, out.detach())
        return final_out


# def test():
#     model = SDENet_mnist(layer_depth=10, num_classes=10, dim=64)
#     return model


def count_parameters(model):
    """Count the trainable (requires_grad) scalar parameters of `model`."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total


if __name__ == '__main__':
    # Smoke test: build the model and print its trainable-parameter count.
    model = SDENet_mnist(layer_depth=10, num_classes=10, dim=64)#test() # some setups require pytest to run test(), so it is left commented out
    num_params = count_parameters(model)
    print(num_params)

CG

可以估计不确定性的神经网络——SDE-Net模型浅析
概率自回归预测——DeepAR模型浅析

https://github.com/Junghwan-brian/SDE-Net/blob/master/model/SDENet.py

https://github.com/Lingkai-Kong/SDE-Net/blob/master/MNIST/resnet_mnist.py

猜你喜欢

转载自blog.csdn.net/ResumeProject/article/details/127894176