【Baseline】CCF Remote Sensing Image Land Parcel Segmentation (Unofficial)
Reproduction of this article is prohibited.
This baseline uses DeepLabV3+.
!pip install paddlex -i https://mirror.baidu.com/pypi/simple
!pip install imgaug -i https://mirror.baidu.com/pypi/simple
# Use GPU card 0 (if no GPU is available, the model will still train on the CPU after this is set)
import matplotlib
import os
import paddlex as pdx
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
from paddlex.seg import transforms
import imgaug.augmenters as iaa
train_transforms = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.Resize(target_size=300),
    transforms.RandomPaddingCrop(crop_size=256),
    transforms.RandomBlur(prob=0.1),
    transforms.RandomRotate(rotate_range=15),
    # transforms.RandomDistort(brightness_range=0.5),
    transforms.Normalize()
])
eval_transforms = transforms.Compose([
    transforms.Resize(256),
    transforms.Normalize()
])
/opt/conda/envs/python35-paddle120-env/lib/python3.7/importlib/_bootstrap.py:219: RuntimeWarning: numpy.ufunc size changed, may indicate binary incompatibility. Expected 192 from C header, got 216 from PyObject
return f(*args, **kwds)
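Note that imgaug is installed and imported above but not actually wired into this baseline. As a side note, the sketch below shows one way an imgaug pipeline could jointly augment an image and its mask; the sample paths, augmenter choices, and variable names are illustrative assumptions, not part of the baseline.
# Illustrative only (not used by the baseline): jointly augment an image and its
# mask with imgaug; the sample paths come from the dataset listing printed below.
import cv2
import imgaug.augmenters as iaa
from imgaug.augmentables.segmaps import SegmentationMapsOnImage

img = cv2.imread('img_train/T100206.jpg')                       # HxWx3, uint8
lab = cv2.imread('lab_train/T100206.png', cv2.IMREAD_GRAYSCALE)

seq = iaa.Sequential([
    iaa.Fliplr(0.5),                  # horizontal flip, like RandomHorizontalFlip above
    iaa.Affine(rotate=(-15, 15)),     # random rotation, like RandomRotate above
    iaa.GaussianBlur(sigma=(0, 1.0))  # mild blur, like RandomBlur above
])

segmap = SegmentationMapsOnImage(lab, shape=img.shape)
img_aug, segmap_aug = seq(image=img, segmentation_maps=segmap)
lab_aug = segmap_aug.get_arr()        # augmented mask back as a numpy array
print(img_aug.shape, lab_aug.shape)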
# !unzip data/data55723/img_testA.zip
# !unzip data/data55723/train_data.zip
# !unzip train_data/lab_train.zip
# !unzip train_data/img_train.zip
import numpy as np
datas = []
image_base = 'img_train'
annos_base = 'lab_train'
ids_ = [v.split('.')[0] for v in os.listdir(image_base)]
for id_ in ids_:
    img_pt0 = os.path.join(image_base, '{}.jpg'.format(id_))
    img_pt1 = os.path.join(annos_base, '{}.png'.format(id_))
    if not (os.path.exists(img_pt0) and os.path.exists(img_pt1)):
        # raising a plain string is invalid in Python 3; raise a real exception instead
        raise FileNotFoundError('path invalid: {} / {}'.format(img_pt0, img_pt1))
    datas.append((img_pt0.replace('/home/aistudio/work/', ''), img_pt1.replace('/home/aistudio/work/', '')))
print('total:', len(datas))
print(datas[0][0])
print(datas[0][1])
data_dir = '/home/aistudio/work/'
total: 145981
img_train/T100206.jpg
lab_train/T100206.png
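Before building the file lists, a quick sanity check on one image/mask pair can catch mismatched sizes or unexpected label values. The sketch below is optional; the assumption that mask values lie in 0..6 (one per class) should be verified against the dataset description.
# Optional sanity check: the image and its mask should share height/width, and the
# mask values are assumed to be 0..6 (one per class) - verify against the dataset docs.
import cv2
import numpy as np

img = cv2.imread('img_train/T100206.jpg')
lab = cv2.imread('lab_train/T100206.png', cv2.IMREAD_GRAYSCALE)
print(img.shape, lab.shape)   # expect matching H and W
print(np.unique(lab))         # expect a subset of {0, ..., 6}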
import numpy as np
# Dataset class names (building, cropland, forest, water, road, grassland, other)
labels = [
    '建筑', '耕地', '林地',
    '水体', '道路', '草地',
    '其他'
]
with open('labels.txt', 'w') as f:
    for v in labels:
        f.write(v + '\n')
np.random.seed(5)
np.random.shuffle(datas)
split_num = int(0.02*len(datas))
train_data = datas[:-split_num]
valid_data = datas[-split_num:]
with open('train_list.txt', 'w') as f:
    for img, lbl in train_data:
        f.write(img + ' ' + lbl + '\n')
with open('valid_list.txt', 'w') as f:
    for img, lbl in valid_data:
        f.write(img + ' ' + lbl + '\n')
print('train:', len(train_data))
print('valid:', len(valid_data))
train: 143062
valid: 2919
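The per-class IoU reported later is quite uneven, so it can be worth estimating class pixel frequencies up front. The sketch below samples a small, arbitrary number of training masks (200) and assumes mask values 0..6 follow the order of `labels` above.
# Rough class-balance estimate over a small random subset of training masks.
# Assumes mask values 0..6 correspond to the order of `labels`; subset size is arbitrary.
import cv2
import numpy as np

counts = np.zeros(len(labels), dtype=np.int64)
for _, lbl_path in train_data[:200]:
    m = cv2.imread(lbl_path, cv2.IMREAD_GRAYSCALE)
    counts += np.bincount(m.ravel(), minlength=len(labels))[:len(labels)]

for name, c in zip(labels, counts):
    print(name, round(c / counts.sum(), 4))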
data_dir = './'
train_dataset = pdx.datasets.SegDataset(
    data_dir=data_dir,
    file_list='train_list.txt',
    label_list='labels.txt',
    transforms=train_transforms,
    shuffle=True)
eval_dataset = pdx.datasets.SegDataset(
    data_dir=data_dir,
    file_list='valid_list.txt',
    label_list='labels.txt',
    transforms=eval_transforms)
2020-10-14 15:55:45 [INFO] 143062 samples in file train_list.txt
2020-10-14 15:55:45 [INFO] 2919 samples in file valid_list.txt
num_classes = len(train_dataset.labels)
model = pdx.seg.DeepLabv3p(
    num_classes=num_classes, backbone='Xception65', use_bce_loss=False
)
model.train(
    num_epochs=4,
    train_dataset=train_dataset,
    train_batch_size=4,
    eval_dataset=eval_dataset,
    learning_rate=0.0002,
    save_interval_epochs=1,
    save_dir='output/deeplab',
    log_interval_steps=200,
    # warm-start from the best checkpoint of a previous run of this notebook;
    # on a first run there is no such checkpoint yet, so this must point at valid weights
    pretrain_weights='output/deeplab/best_model')
model.evaluate(eval_dataset, batch_size=1, epoch_id=None, return_details=False)
2020-10-14 18:00:57 [INFO] Start to evaluating(total_samples=2919, total_steps=2919)...
100%|██████████| 2919/2919 [01:13<00:00, 39.50it/s]
OrderedDict([('miou', 0.5736259173652354),
('category_iou',
array([0.63470599, 0.85614582, 0.80515013, 0.8607127 , 0.19762517,
0.13351321, 0.5275284 ])),
('macc', 0.8838217731898637),
('category_acc',
array([0.78868336, 0.90928854, 0.88048334, 0.91777826, 0.32795152,
0.32774787, 0.75617823])),
('kappa', 0.8223011553032171)])
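The category_iou array should follow the order of labels.txt, so pairing it with the class names makes it easier to see which classes (road and grassland here) drag the mean IoU down. A small sketch:
# Pair each per-class IoU with its class name (order assumed to follow labels.txt).
metrics = model.evaluate(eval_dataset, batch_size=1)
for name, iou in zip(labels, metrics['category_iou']):
    print('{}: {:.4f}'.format(name, iou))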
# model = pdx.load_model('./output/deeplab/best_model')
from tqdm import tqdm
import cv2
test_base = 'img_testA/'
out_base = 'ccf_baidu_remote_sense/results/'
if not os.path.exists(out_base):
    os.makedirs(out_base)
# predict every test image and save its label map as a PNG with the same name
for im in tqdm(os.listdir(test_base)):
    if not im.endswith('.jpg'):
        continue
    pt = test_base + im
    result = model.predict(pt)
    cv2.imwrite(out_base + im.replace('.jpg', '.png'), result['label_map'])
100%|██████████| 10001/10001 [03:18<00:00, 50.30it/s]
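Finally, the predicted PNGs usually need to be packaged for submission. The archive name and layout below are assumptions; check the competition's submission rules.
# Package the predictions; archive name/layout are assumptions - check the rules.
import shutil
shutil.make_archive('results', 'zip', root_dir='ccf_baidu_remote_sense', base_dir='results')
print('wrote results.zip')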
If you have project requests, you are welcome to join my commission group; ask the group owner in the group for details.
Interested readers can follow my WeChat official account: 可达鸭的深度学习教程.