train.py
import torch
import torch.nn as nn
import torch.utils.data as data
from torch.autograd import Variable as V
import cv2
import os
import numpy as np
from time import time
import datetime
from networks.unet import Unet
from networks.dunet import Dunet
from networks.dinknet import LinkNet34, DinkNet34, DinkNet50, DinkNet101, DinkNet34_less_pool
from framework import MyFrame
from loss import dice_bce_loss
from data import ImageFolder
import matplotlib.pyplot as plt
SHAPE = (512,512)
dsm_ROOT = 'dataset/dsm_train/'
rgb_ROOT = 'dataset/rgb_train/'
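# Keep only the '*_sat.tif' tiles and strip the 8-character '_sat.tif' suffix to get bare tile ids.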
dsm_imagelist = filter(lambda x: x.find('sat')!=-1, os.listdir(dsm_ROOT))
dsm_trainlist = list(map(lambda x: x[:-8], dsm_imagelist))
NAME = 'log01_Dink101'
BATCHSIZE_PER_CARD = 2
solver = MyFrame(DinkNet34, dice_bce_loss, 2e-4)
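# NOTE: NAME ('log01_Dink101') does not match the network actually trained here (DinkNet34).
# The effective batch size scales with the number of visible GPUs; DataParallel splits each batch across them.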
batchsize = torch.cuda.device_count() * BATCHSIZE_PER_CARD
dataset = ImageFolder(dsm_trainlist, dsm_ROOT,rgb_ROOT)
data_loader = torch.utils.data.DataLoader(
dataset,
batch_size=batchsize,
shuffle=True)
mylog = open('logs/'+NAME+'.log','w')
tic = time()
time1 = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
no_optim = 0
total_epoch = 200
train_epoch_best_loss = 150
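# no_optim counts consecutive epochs without a loss improvement; it drives the LR decay and early-stop logic below.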
Loss_list = []
print('start!!')
for epoch in range(1, total_epoch + 1):
data_loader_iter = iter(data_loader)
train_epoch_loss = 0
for dsm_img, rgb_img, dsm_mask in data_loader_iter:
solver.set_input(dsm_img, rgb_img, dsm_mask)
train_loss = solver.optimize()
train_epoch_loss += train_loss
train_epoch_loss /= len(data_loader_iter)
Loss_list.append(train_epoch_loss)
mylog.write( '********')
#mylog.write('epoch:' + str(epoch) + ' time:'+ str(int(time()-tic)))
mylog.write('epoch:' + str(epoch) + ' time:'+ datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
mylog.write(' train_loss:' + str(train_epoch_loss))
mylog.write(' SHAPE:' + SHAPE.__str__())
mylog.write('\n')
    print('********')
    #print('epoch:', epoch, ' time:', int(time()-tic))
    print('epoch:', epoch, ' time:', datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
    print('train_loss:', train_epoch_loss)
    print('SHAPE:', SHAPE)
if train_epoch_loss >= train_epoch_best_loss:
no_optim += 1
else:
no_optim = 0
train_epoch_best_loss = train_epoch_loss
solver.save('weights/'+NAME+'.th')
    if no_optim > 6:
        mylog.write('early stop at %d epoch\n' % epoch)
        print('early stop at %d epoch' % epoch)
        solver.save('weights/'+NAME+'_earlystop_%d.th' % epoch)
        #break  # uncomment to actually stop once patience runs out
if epoch%5==0:
solver.save('weights/'+NAME+'_five_%d.th'% epoch)
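    # after 3 consecutive epochs without improvement, reload the best checkpoint and divide the LR by 5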
    if no_optim > 3:
        if solver.old_lr < 5e-7:
            mylog.write('solver.old_lr < 5e-7 at %d epoch\n' % epoch)
            #break
        solver.load('weights/'+NAME+'.th')
        solver.update_lr(5.0, factor=True, mylog=mylog)
mylog.flush()
mylog.write('Finish!\n')
print('Finish!')
time2 = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
print("begin train!")
# print the formatted start timestamp
print(time1)
print("finish train!")
# print the formatted end timestamp
print(time2)
mylog.close()
x1 = range(0, len(Loss_list))
y1 = Loss_list
plt.plot(x1, y1, 'o-')
plt.title('Model loss vs. epochs')
plt.ylabel('Model loss')
plt.savefig("model_loss.jpg")
plt.show()
data.py
"""
Based on https://github.com/asanakoy/kaggle_carvana_segmentation
"""
import torch
import torch.utils.data as data
from torch.autograd import Variable as V
import cv2
import numpy as np
import os
import math
def randomHueSaturationValue(image, hue_shift_limit=(-180, 180),
sat_shift_limit=(-255, 255),
val_shift_limit=(-255, 255), u=0.5):
if np.random.random() < u:
image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
h, s, v = cv2.split(image)
hue_shift = np.random.randint(hue_shift_limit[0], hue_shift_limit[1]+1)
hue_shift = np.uint8(hue_shift)
h += hue_shift
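        # note: uint8 addition wraps modulo 256; with OpenCV's 0-179 hue range this is only an approximate cyclic hue shift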
sat_shift = np.random.uniform(sat_shift_limit[0], sat_shift_limit[1])
s = cv2.add(s, sat_shift)
val_shift = np.random.uniform(val_shift_limit[0], val_shift_limit[1])
v = cv2.add(v, val_shift)
image = cv2.merge((h, s, v))
image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR)
return image
def randomShiftScaleRotate(dsm_image, rgb_image, mask,
shift_limit=(-0.0, 0.0),
scale_limit=(-0.0, 0.0),
rotate_limit=(-0.0, 0.0),
aspect_limit=(-0.0, 0.0),
borderMode=cv2.BORDER_CONSTANT, u=0.5):
if np.random.random() < u:
height, width, channel = dsm_image.shape
angle = np.random.uniform(rotate_limit[0], rotate_limit[1])
scale = np.random.uniform(1 + scale_limit[0], 1 + scale_limit[1])
aspect = np.random.uniform(1 + aspect_limit[0], 1 + aspect_limit[1])
sx = scale * aspect / (aspect ** 0.5)
sy = scale / (aspect ** 0.5)
dx = round(np.random.uniform(shift_limit[0], shift_limit[1]) * width)
dy = round(np.random.uniform(shift_limit[0], shift_limit[1]) * height)
        cc = math.cos(angle / 180 * math.pi) * sx
        ss = math.sin(angle / 180 * math.pi) * sy
rotate_matrix = np.array([[cc, -ss], [ss, cc]])
box0 = np.array([[0, 0], [width, 0], [width, height], [0, height], ])
box1 = box0 - np.array([width / 2, height / 2])
box1 = np.dot(box1, rotate_matrix.T) + np.array([width / 2 + dx, height / 2 + dy])
box0 = box0.astype(np.float32)
box1 = box1.astype(np.float32)
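        # express the combined shift/scale/rotate as a perspective warp mapping the original corners (box0) to the transformed corners (box1)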
mat = cv2.getPerspectiveTransform(box0, box1)
        dsm_image = cv2.warpPerspective(dsm_image, mat, (width, height), flags=cv2.INTER_LINEAR,
                                        borderMode=borderMode, borderValue=(0, 0, 0))
        rgb_image = cv2.warpPerspective(rgb_image, mat, (width, height), flags=cv2.INTER_LINEAR,
                                        borderMode=borderMode, borderValue=(0, 0, 0))
        mask = cv2.warpPerspective(mask, mat, (width, height), flags=cv2.INTER_LINEAR,
                                   borderMode=borderMode, borderValue=(0, 0, 0))
return dsm_image, rgb_image, mask
def randomHorizontalFlip(dsm_image, rgb_image, mask, u=0.5):
if np.random.random() < u:
dsm_image = cv2.flip(dsm_image, 1)
rgb_image = cv2.flip(rgb_image, 1)
mask = cv2.flip(mask, 1)
return dsm_image, rgb_image, mask
def randomVerticleFlip(dsm_image, rgb_image, mask, u=0.5):
if np.random.random() < u:
dsm_image = cv2.flip(dsm_image, 0)
rgb_image = cv2.flip(rgb_image, 0)
mask = cv2.flip(mask, 0)
return dsm_image, rgb_image, mask
def randomRotate90(dsm_image, rgb_image, mask, u=0.5):
if np.random.random() < u:
dsm_image=np.rot90(dsm_image)
rgb_image=np.rot90(rgb_image)
mask=np.rot90(mask)
return dsm_image, rgb_image, mask
def default_loader(id, dsm_root, rgb_root):
    dsm_img = cv2.imread(os.path.join(dsm_root, '{}_sat.tif'.format(id)))
    rgb_img = cv2.imread(os.path.join(rgb_root, '{}_sat.tif'.format(id)))
    mask = cv2.imread(os.path.join(dsm_root, '{}_mask.png'.format(id)), cv2.IMREAD_GRAYSCALE)
    if mask is None:  # cv2.imread returns None for missing/unreadable files
        print(os.path.join(dsm_root, '{}_mask.png'.format(id)))
        print(mask)
dsm_img = randomHueSaturationValue(dsm_img,
hue_shift_limit=(-30, 30),
sat_shift_limit=(-5, 5),
val_shift_limit=(-15, 15))
rgb_img = randomHueSaturationValue(rgb_img,
hue_shift_limit=(-30, 30),
sat_shift_limit=(-5, 5),
val_shift_limit=(-15, 15))
dsm_img, rgb_img, mask = randomShiftScaleRotate(dsm_img, rgb_img, mask,
shift_limit=(-0.1, 0.1),
scale_limit=(-0.1, 0.1),
aspect_limit=(-0.1, 0.1),
rotate_limit=(-0, 0))
dsm_img, rgb_img, mask = randomHorizontalFlip(dsm_img, rgb_img, mask)
dsm_img, rgb_img, mask = randomVerticleFlip(dsm_img, rgb_img, mask)
dsm_img, rgb_img, mask = randomRotate90(dsm_img, rgb_img, mask)
mask = np.expand_dims(mask, axis=2)
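    # scale pixels from [0, 255] into roughly [-1.6, 1.6]; the mask stays in [0, 1]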
dsm_img = np.array(dsm_img, np.float32).transpose(2,0,1)/255.0 * 3.2 - 1.6
rgb_img = np.array(rgb_img, np.float32).transpose(2,0,1)/255.0 * 3.2 - 1.6
mask = np.array(mask, np.float32).transpose(2,0,1)/255.0
mask[mask>=0.5] = 1
mask[mask<=0.5] = 0
#mask = abs(mask-1)
return dsm_img, rgb_img, mask
class ImageFolder(data.Dataset):
def __init__(self, trainlist, dsm_root,rgb_root):
self.ids = trainlist
self.loader = default_loader
self.dsm_root = dsm_root
self.rgb_root = rgb_root
def __getitem__(self, index):
id = self.ids[index]
dsm_img, rgb_img, mask = self.loader(id, self.dsm_root,self.rgb_root)
dsm_img = torch.Tensor(dsm_img)
rgb_img = torch.Tensor(rgb_img)
mask = torch.Tensor(mask)
return dsm_img, rgb_img, mask
    def __len__(self):
        return len(self.ids)
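# Minimal usage sketch (hypothetical tile id; assumes the dataset/ layout from train.py):
#   dataset = ImageFolder(['tile_0001'], 'dataset/dsm_train/', 'dataset/rgb_train/')
#   dsm, rgb, mask = dataset[0]  # float32 tensors of shape (3,H,W), (3,H,W), (1,H,W)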
dinknet.py
"""
Codes of LinkNet based on https://github.com/snakers4/spacenet-three
"""
import torch
import torch.nn as nn
from torch.autograd import Variable
from torchvision import models
import torch.nn.functional as F
from functools import partial
nonlinearity = partial(F.relu,inplace=True)
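# nonlinearity is a reusable in-place ReLU: partial() pins inplace=True so it can be assigned like a layer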
class Dblock_more_dilate(nn.Module):
def __init__(self,channel):
super(Dblock_more_dilate, self).__init__()
self.dilate1 = nn.Conv2d(channel, channel, kernel_size=3, dilation=1, padding=1)
self.dilate2 = nn.Conv2d(channel, channel, kernel_size=3, dilation=2, padding=2)
self.dilate3 = nn.Conv2d(channel, channel, kernel_size=3, dilation=4, padding=4)
self.dilate4 = nn.Conv2d(channel, channel, kernel_size=3, dilation=8, padding=8)
self.dilate5 = nn.Conv2d(channel, channel, kernel_size=3, dilation=16, padding=16)
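        # stacked dilated 3x3 convs (rates 1, 2, 4, 8, 16) enlarge the receptive field without downsampling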
for m in self.modules():
if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
if m.bias is not None:
m.bias.data.zero_()
def forward(self, x):
dilate1_out = nonlinearity(self.dilate1(x))
dilate2_out = nonlinearity(self.dilate2(dilate1_out))
dilate3_out = nonlinearity(self.dilate3(dilate2_out))
dilate4_out = nonlinearity(self.dilate4(dilate3_out))
dilate5_out = nonlinearity(self.dilate5(dilate4_out))
out = x + dilate1_out + dilate2_out + dilate3_out + dilate4_out + dilate5_out
return out
class Dblock(nn.Module):
def __init__(self,channel):
super(Dblock, self).__init__()
self.dilate1 = nn.Conv2d(channel, channel, kernel_size=3, dilation=1, padding=1)
self.dilate2 = nn.Conv2d(channel, channel, kernel_size=3, dilation=2, padding=2)
self.dilate3 = nn.Conv2d(channel, channel, kernel_size=3, dilation=4, padding=4)
self.dilate4 = nn.Conv2d(channel, channel, kernel_size=3, dilation=8, padding=8)
#self.dilate5 = nn.Conv2d(channel, channel, kernel_size=3, dilation=16, padding=16)
for m in self.modules():
if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
if m.bias is not None:
m.bias.data.zero_()
def forward(self, x):
dilate1_out = nonlinearity(self.dilate1(x))
dilate2_out = nonlinearity(self.dilate2(dilate1_out))
dilate3_out = nonlinearity(self.dilate3(dilate2_out))
dilate4_out = nonlinearity(self.dilate4(dilate3_out))
#dilate5_out = nonlinearity(self.dilate5(dilate4_out))
out = x + dilate1_out + dilate2_out + dilate3_out + dilate4_out# + dilate5_out
return out
class DecoderBlock(nn.Module):
def __init__(self, in_channels, n_filters):
super(DecoderBlock,self).__init__()
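        # LinkNet-style decoder: 1x1 conv cuts channels to 1/4, a stride-2 transposed conv doubles the resolution, then a 1x1 conv maps to n_filters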
self.conv1 = nn.Conv2d(in_channels, in_channels // 4, 1)
self.norm1 = nn.BatchNorm2d(in_channels // 4)
self.relu1 = nonlinearity
self.deconv2 = nn.ConvTranspose2d(in_channels // 4, in_channels // 4, 3, stride=2, padding=1, output_padding=1)
self.norm2 = nn.BatchNorm2d(in_channels // 4)
self.relu2 = nonlinearity
self.conv3 = nn.Conv2d(in_channels // 4, n_filters, 1)
self.norm3 = nn.BatchNorm2d(n_filters)
self.relu3 = nonlinearity
def forward(self, x):
x = self.conv1(x)
x = self.norm1(x)
x = self.relu1(x)
x = self.deconv2(x)
x = self.norm2(x)
x = self.relu2(x)
x = self.conv3(x)
x = self.norm3(x)
x = self.relu3(x)
return x
class DinkNet34_less_pool(nn.Module):
def __init__(self, num_classes=1):
        super(DinkNet34_less_pool, self).__init__()
filters = [64, 128, 256, 512]
resnet = models.resnet34(pretrained=True)
self.firstconv = resnet.conv1
self.firstbn = resnet.bn1
self.firstrelu = resnet.relu
self.firstmaxpool = resnet.maxpool
self.encoder1 = resnet.layer1
self.encoder2 = resnet.layer2
self.encoder3 = resnet.layer3
self.dblock = Dblock_more_dilate(256)
self.decoder3 = DecoderBlock(filters[2], filters[1])
self.decoder2 = DecoderBlock(filters[1], filters[0])
self.decoder1 = DecoderBlock(filters[0], filters[0])
self.finaldeconv1 = nn.ConvTranspose2d(filters[0], 32, 4, 2, 1)
self.finalrelu1 = nonlinearity
self.finalconv2 = nn.Conv2d(32, 32, 3, padding=1)
self.finalrelu2 = nonlinearity
self.finalconv3 = nn.Conv2d(32, num_classes, 3, padding=1)
def forward(self, x):
# Encoder
x = self.firstconv(x)
x = self.firstbn(x)
x = self.firstrelu(x)
x = self.firstmaxpool(x)
e1 = self.encoder1(x)
e2 = self.encoder2(e1)
e3 = self.encoder3(e2)
#Center
e3 = self.dblock(e3)
# Decoder
d3 = self.decoder3(e3) + e2
d2 = self.decoder2(d3) + e1
d1 = self.decoder1(d2)
# Final Classification
out = self.finaldeconv1(d1)
out = self.finalrelu1(out)
out = self.finalconv2(out)
out = self.finalrelu2(out)
out = self.finalconv3(out)
return torch.sigmoid(out)
class DinkNet34(nn.Module):
def __init__(self, num_classes=1, num_channels=3):
super(DinkNet34, self).__init__()
filters = [64, 128, 256, 512]
resnet = models.resnet34(weights=None)
resnet.load_state_dict(torch.load('D:/Deepl/networks/resnet34-333f7ec4.pth'))
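        # NOTE: machine-specific path to the ImageNet ResNet-34 checkpoint; adjust for your environment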
self.firstconv = resnet.conv1
# self.firstconv = nn.Conv2d(4, 64, kernel_size=7, stride=2, padding=3, bias=False)
self.firstbn = resnet.bn1
self.firstrelu = resnet.relu
self.firstmaxpool = resnet.maxpool
self.encoder1 = resnet.layer1
self.encoder2 = resnet.layer2
self.encoder3 = resnet.layer3
self.encoder4 = resnet.layer4
self.dblock = Dblock(512)
self.decoder4 = DecoderBlock(filters[3], filters[2])
self.decoder3 = DecoderBlock(filters[2], filters[1])
self.decoder2 = DecoderBlock(filters[1], filters[0])
self.decoder1 = DecoderBlock(filters[0], filters[0])
self.finaldeconv1 = nn.ConvTranspose2d(filters[0], 32, 4, 2, 1)
self.finalrelu1 = nonlinearity
self.finalconv2 = nn.Conv2d(32, 32, 3, padding=1)
self.finalrelu2 = nonlinearity
self.finalconv3 = nn.Conv2d(32, num_classes, 3, padding=1)
def forward(self, x, y):
# Encoder
x = self.firstconv(x)
x = self.firstbn(x)
x = self.firstrelu(x)
x = self.firstmaxpool(x)
y = self.firstconv(y)
y = self.firstbn(y)
y = self.firstrelu(y)
y = self.firstmaxpool(y)
e1_x = self.encoder1(x)
e2_x = self.encoder2(e1_x)
e3_x = self.encoder3(e2_x)
e4_x = self.encoder4(e3_x)
e1_y = self.encoder1(y)
e2_y = self.encoder2(e1_y)
e3_y = self.encoder3(e2_y)
e4_y = self.encoder4(e3_y)
# Center
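        # fuse the DSM and RGB streams by element-wise summation at the bottleneck (and again at each decoder skip below)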
e4 = self.dblock(e4_x) + self.dblock(e4_y)
# Decoder
d4 = self.decoder4(e4) + e3_x + e3_y
d3 = self.decoder3(d4) + e2_x + e2_y
d2 = self.decoder2(d3) + e1_x + e1_y
d1 = self.decoder1(d2)
out = self.finaldeconv1(d1)
out = self.finalrelu1(out)
out = self.finalconv2(out)
out = self.finalrelu2(out)
out = self.finalconv3(out)
return torch.sigmoid(out)
class DinkNet50(nn.Module):
def __init__(self, num_classes=1):
super(DinkNet50, self).__init__()
filters = [256, 512, 1024, 2048]
resnet = models.resnet50(pretrained=True)
self.firstconv = resnet.conv1
self.firstbn = resnet.bn1
self.firstrelu = resnet.relu
self.firstmaxpool = resnet.maxpool
self.encoder1 = resnet.layer1
self.encoder2 = resnet.layer2
self.encoder3 = resnet.layer3
self.encoder4 = resnet.layer4
self.dblock = Dblock_more_dilate(2048)
self.decoder4 = DecoderBlock(filters[3], filters[2])
self.decoder3 = DecoderBlock(filters[2], filters[1])
self.decoder2 = DecoderBlock(filters[1], filters[0])
self.decoder1 = DecoderBlock(filters[0], filters[0])
self.finaldeconv1 = nn.ConvTranspose2d(filters[0], 32, 4, 2, 1)
self.finalrelu1 = nonlinearity
self.finalconv2 = nn.Conv2d(32, 32, 3, padding=1)
self.finalrelu2 = nonlinearity
self.finalconv3 = nn.Conv2d(32, num_classes, 3, padding=1)
def forward(self, x):
# Encoder
x = self.firstconv(x)
x = self.firstbn(x)
x = self.firstrelu(x)
x = self.firstmaxpool(x)
e1 = self.encoder1(x)
e2 = self.encoder2(e1)
e3 = self.encoder3(e2)
e4 = self.encoder4(e3)
# Center
e4 = self.dblock(e4)
# Decoder
d4 = self.decoder4(e4) + e3
d3 = self.decoder3(d4) + e2
d2 = self.decoder2(d3) + e1
d1 = self.decoder1(d2)
out = self.finaldeconv1(d1)
out = self.finalrelu1(out)
out = self.finalconv2(out)
out = self.finalrelu2(out)
out = self.finalconv3(out)
return torch.sigmoid(out)
class DinkNet101(nn.Module):
def __init__(self, num_classes=1):
super(DinkNet101, self).__init__()
filters = [256, 512, 1024, 2048]
resnet = models.resnet101(pretrained=True)
self.firstconv = resnet.conv1
self.firstbn = resnet.bn1
self.firstrelu = resnet.relu
self.firstmaxpool = resnet.maxpool
self.encoder1 = resnet.layer1
self.encoder2 = resnet.layer2
self.encoder3 = resnet.layer3
self.encoder4 = resnet.layer4
self.dblock = Dblock_more_dilate(2048)
self.decoder4 = DecoderBlock(filters[3], filters[2])
self.decoder3 = DecoderBlock(filters[2], filters[1])
self.decoder2 = DecoderBlock(filters[1], filters[0])
self.decoder1 = DecoderBlock(filters[0], filters[0])
self.finaldeconv1 = nn.ConvTranspose2d(filters[0], 32, 4, 2, 1)
self.finalrelu1 = nonlinearity
self.finalconv2 = nn.Conv2d(32, 32, 3, padding=1)
self.finalrelu2 = nonlinearity
self.finalconv3 = nn.Conv2d(32, num_classes, 3, padding=1)
def forward(self, x):
# Encoder
x = self.firstconv(x)
x = self.firstbn(x)
x = self.firstrelu(x)
x = self.firstmaxpool(x)
e1 = self.encoder1(x)
e2 = self.encoder2(e1)
e3 = self.encoder3(e2)
e4 = self.encoder4(e3)
# Center
e4 = self.dblock(e4)
# Decoder
d4 = self.decoder4(e4) + e3
d3 = self.decoder3(d4) + e2
d2 = self.decoder2(d3) + e1
d1 = self.decoder1(d2)
out = self.finaldeconv1(d1)
out = self.finalrelu1(out)
out = self.finalconv2(out)
out = self.finalrelu2(out)
out = self.finalconv3(out)
return torch.sigmoid(out)
class LinkNet34(nn.Module):
def __init__(self, num_classes=1):
super(LinkNet34, self).__init__()
filters = [64, 128, 256, 512]
resnet = models.resnet34(pretrained=True)
self.firstconv = resnet.conv1
self.firstbn = resnet.bn1
self.firstrelu = resnet.relu
self.firstmaxpool = resnet.maxpool
self.encoder1 = resnet.layer1
self.encoder2 = resnet.layer2
self.encoder3 = resnet.layer3
self.encoder4 = resnet.layer4
self.decoder4 = DecoderBlock(filters[3], filters[2])
self.decoder3 = DecoderBlock(filters[2], filters[1])
self.decoder2 = DecoderBlock(filters[1], filters[0])
self.decoder1 = DecoderBlock(filters[0], filters[0])
self.finaldeconv1 = nn.ConvTranspose2d(filters[0], 32, 3, stride=2)
self.finalrelu1 = nonlinearity
self.finalconv2 = nn.Conv2d(32, 32, 3)
self.finalrelu2 = nonlinearity
self.finalconv3 = nn.Conv2d(32, num_classes, 2, padding=1)
def forward(self, x):
# Encoder
x = self.firstconv(x)
x = self.firstbn(x)
x = self.firstrelu(x)
x = self.firstmaxpool(x)
e1 = self.encoder1(x)
e2 = self.encoder2(e1)
e3 = self.encoder3(e2)
e4 = self.encoder4(e3)
# Decoder
d4 = self.decoder4(e4) + e3
d3 = self.decoder3(d4) + e2
d2 = self.decoder2(d3) + e1
d1 = self.decoder1(d2)
out = self.finaldeconv1(d1)
out = self.finalrelu1(out)
out = self.finalconv2(out)
out = self.finalrelu2(out)
out = self.finalconv3(out)
return torch.sigmoid(out)
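# Quick shape check (sketch; assumes the local ResNet-34 checkpoint used by DinkNet34 exists):
#   net = DinkNet34()
#   out = net(torch.randn(1, 3, 512, 512), torch.randn(1, 3, 512, 512))
#   out.shape  # torch.Size([1, 1, 512, 512]), values in (0, 1) from the final sigmoid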
framework.py
import torch
import torch.nn as nn
from torch.autograd import Variable as V
import cv2
import numpy as np
class MyFrame():
def __init__(self, net, loss, lr=2e-4, evalmode = False):
self.net = net().cuda()
self.net = torch.nn.DataParallel(self.net, device_ids=range(torch.cuda.device_count()))
self.optimizer = torch.optim.Adam(params=self.net.parameters(), lr=lr)
#self.optimizer = torch.optim.RMSprop(params=self.net.parameters(), lr=lr)
self.loss = loss()
self.old_lr = lr
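        # optional eval mode: freeze BatchNorm running statistics (useful for inference/fine-tuning)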
if evalmode:
for i in self.net.modules():
if isinstance(i, nn.BatchNorm2d):
i.eval()
def set_input(self, dsm_img_batch, rgb_img_batch, mask_batch=None, img_id=None):
self.dsm_img = dsm_img_batch
self.rgb_img = rgb_img_batch
self.mask = mask_batch
self.img_id = img_id
    # NOTE: the network takes two inputs (DSM and RGB), so the test helpers accept both.
    def test_one_img(self, dsm_img, rgb_img):
        pred = self.net.forward(dsm_img, rgb_img)
        pred[pred > 0.5] = 1
        pred[pred <= 0.5] = 0
        mask = pred.squeeze().cpu().data.numpy()
        return mask
    def test_batch(self):
        self.forward(volatile=True)
        mask = self.net.forward(self.dsm_img, self.rgb_img).cpu().data.numpy().squeeze(1)
        mask[mask > 0.5] = 1
        mask[mask <= 0.5] = 0
        return mask, self.img_id
    def test_one_img_from_path(self, dsm_path, rgb_path):
        dsm_img = cv2.imread(dsm_path)
        rgb_img = cv2.imread(rgb_path)
        # same CHW layout and normalization as the training loader
        dsm_img = np.array(dsm_img, np.float32).transpose(2, 0, 1) / 255.0 * 3.2 - 1.6
        rgb_img = np.array(rgb_img, np.float32).transpose(2, 0, 1) / 255.0 * 3.2 - 1.6
        dsm_img = V(torch.Tensor(dsm_img).cuda()).unsqueeze(0)
        rgb_img = V(torch.Tensor(rgb_img).cuda()).unsqueeze(0)
        mask = self.net.forward(dsm_img, rgb_img).squeeze().cpu().data.numpy()
        mask[mask > 0.5] = 1
        mask[mask <= 0.5] = 0
        return mask
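    # note: Variable/volatile are no-ops on PyTorch >= 0.4; torch.no_grad() is the modern
    # idiom for inference, kept as-is here to match the original code's style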
def forward(self, volatile=False):
self.dsm_img = V(self.dsm_img.cuda(), volatile=volatile)
self.rgb_img = V(self.rgb_img.cuda(), volatile=volatile)
if self.mask is not None:
self.mask = V(self.mask.cuda(), volatile=volatile)
def optimize(self):
self.forward()
self.optimizer.zero_grad()
pred = self.net.forward(self.dsm_img,self.rgb_img)
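        # dice_bce_loss is called as loss(target, prediction)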
loss = self.loss(self.mask, pred)
loss.backward()
self.optimizer.step()
return loss.item()
def save(self, path):
torch.save(self.net.state_dict(), path)
def load(self, path):
self.net.load_state_dict(torch.load(path))
def update_lr(self, new_lr, mylog, factor=False):
if factor:
new_lr = self.old_lr / new_lr
for param_group in self.optimizer.param_groups:
param_group['lr'] = new_lr
        mylog.write('update learning rate: %f -> %f\n' % (self.old_lr, new_lr))
        print('update learning rate: %f -> %f' % (self.old_lr, new_lr))
self.old_lr = new_lr
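# Example (sketch): solver.update_lr(5.0, factor=True, mylog=mylog) divides the current
# learning rate by 5, which is how train.py decays the LR after stagnant epochs.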