用mmdetection训练目标检测模型时,出现了bbox_loss和cls_loss为nan的问题,记录排查原因的过程,以及最终的解决。现象是,loss 在正常降低的过程中,突然跳变nan,整体震荡下降。
mmdetection中出现 loss 在正常降低的过程中,突然跳变nan的可能原因如下:
1 、mmdet中的 core/ evalution/ classnames.py 中的类别没修改
2 、mmdet 中的 datasets/ voc.py 中的类别没修改
3 、config/ base/models 中基础模型的 num_classes 没有修改正确
4 、config/ cascade_rcnn/cascade_rcnn_r50_fpn_1x_coco.py 中 datasets的路径 没有修改
5 、图片的标签 (xml) 中含有其他类别,比如训练类别只有 a ,但是标签里面有 a, b, c。
6、学习率过高
总结:
1)如果nan的情况是间断性出现的,比如前面几个 batch 的loss正常下降,突然有几个batch的loss变成nan ,然后loss又正常了,就是前面5中情况中的一种,大概率是情况5。
2)如果nan的情况不是间断性出现的,比如前面几个batch 的loss正常下降,突然持续变成nan,不恢复正常,则有可能是情况6,在 schedule1x.py 中把learning_rate 调至它的0.1倍或者更小。
数据清洗
mmdetection v2版本,在gt的box和image的重叠区域为0时,会出现loss nan的情况,如下:
于是,我检查了一遍数据。检查数据check_data.py代码如下:
# -*- coding:utf-8 -*-
# A demo for checking whether the annos has out-of-image boxs.
import xml.etree.ElementTree as ET
import os
import cv2
import mmcv
from PIL import Image
import numpy as np
xml_root = "Annotations" # 原始的所有xml文件路径
new_xml_root = "new_xml" # check并纠正后的所有xml文件路径
image_root = "JPEGImages" # 所有图片的路径
xml_name_list = sorted(os.listdir(xml_root))
def print_all_classes():
all_name_list = []
for xml_name in xml_name_list:
print(f"{
xml_name}")
xml_path = os.path.join(xml_root, xml_name)
tree = ET.parse(xml_path)
root = tree.getroot()
for obj in root.findall("object"):
name = obj.find("name").text
all_name_list.append(name)
print(all_name_list)
def check_hw():
tranposed_name_lists = []
for xml_name in xml_name_list:
xml_path = os.path.join(xml_root, xml_name)
tree = ET.parse(xml_path)
root = tree.getroot()
size = root.find("size")
width = int(size.find("width").text)
height = int(size.find("height").text)
image_path = os.path.join(image_root, xml_name[:-4] + ".jpg")
img = cv2.imread(image_path, flags=cv2.IMREAD_COLOR)
h, w, _ = img.shape
if height != h or width != w:
print(width, w, height, h)
print(f"{
xml_name}'s h, w is tranposed.")
tranposed_name_lists.append(xml_name)
print(tranposed_name_lists)
def check_bbox():
if not os.path.exists(new_xml_root):
os.makedirs(new_xml_root)
for xml_name in xml_name_list:
xml_path = os.path.join(xml_root, xml_name)
tree = ET.parse(xml_path)
root = tree.getroot()
for obj in root.findall("object"):
bnd_box = obj.find("bndbox")
bbox = [
int(float(bnd_box.find("xmin").text)),
int(float(bnd_box.find("ymin").text)),
int(float(bnd_box.find("xmax").text)),
int(float(bnd_box.find("ymax").text)),
]
image_path = os.path.join(image_root, xml_name[:-4] + ".jpg")
img = cv2.imread(image_path, flags=cv2.IMREAD_COLOR)
h, w, _ = img.shape
if bbox[0] >= bbox[2] or bbox[1] >= bbox[3]:
print("bbox[0] >= bbox[2] or bbox[1] >= bbox[3]", bbox, xml_name)
# bboxes = np.array([bbox])
# mmcv.imshow_det_bboxes(img, bboxes, labels=np.array(["h"]))
# bbox_min_ge_max_name_lists.append(xml_name)
root.remove(obj) # 删除掉有问题的box
elif bbox[3] > h or bbox[2] > w:
bnd_box.find("xmax").text = str(min(w, bbox[2]))
bnd_box.find("ymax").text = str(min(h, bbox[3]))
print("bbox[3] > h or bbox[2] > w", bbox, h, w, xml_name)
# bboxes = np.array([bbox])
# mmcv.imshow_det_bboxes(img, bboxes, labels=np.array(["h"]))
# bbox_max_border_name_lists.append(xml_name)
tree.write(os.path.join(new_xml_root, xml_name))
check_bbox()
运行结果如下,会把有问题的xml文件名及其中有问题的box坐标输出来。
清洗数据后再次运行训练,不会再出现loss为nan的情况。