Mask R-CNN(十):代码理解inspect_weights.ipynb

版权声明:本文为博主原创文章,未经作者允许请勿转载。 https://blog.csdn.net/heiheiya https://blog.csdn.net/heiheiya/article/details/82114010

一、导包

import os
import sys
import numpy as np
import tensorflow as tf
import matplotlib
import matplotlib.pyplot as plt
import keras

#设置根目录
ROOT_DIR = os.path.abspath("../../")

#导入Mask RCNN
sys.path.append(ROOT_DIR)  # To find local version of the library
from mrcnn import utils
import mrcnn.model as modellib
from mrcnn import visualize
from mrcnn.model import log

%matplotlib inline 

#保存log和model的目录
MODEL_DIR = os.path.join(ROOT_DIR, "logs")

#预训练权重文件的路径
COCO_MODEL_PATH = os.path.join(ROOT_DIR, "mask_rcnn_coco.h5")
#下载COCO预训练权重文件
if not os.path.exists(COCO_MODEL_PATH):
    utils.download_trained_weights(COCO_MODEL_PATH)

#Shapes数据集预训练权重文件的路径
SHAPES_MODEL_PATH = os.path.join(ROOT_DIR, "mask_rcnn_shapes.h5")

二、配置

#以下代码块二选其一即可

# Shapes toy数据集
# import shapes
# config = shapes.ShapesConfig()

# MS COCO数据集
import coco
config = coco.CocoConfig()

三、Notebook Preferences

#选择加载神经网络的设备.
#当你同时在该设备上训练模型的时候这个参数就比较有用了 
#你可以使用CPU,将GPU留作训练用
DEVICE = "/cpu:0"  # /cpu:0 or /gpu:0

def get_ax(rows=1, cols=1, size=16):
    """返回一个在该notebook中用于所有可视化的Matplotlib Axes array。
    提供一个中央点坐标来控制graph的尺寸。
    
    调整attribute的尺寸来控制渲染多大的图像
    """
    _, ax = plt.subplots(rows, cols, figsize=(size*cols, size*rows))
    return ax

三、加载Model

#创建一个用于预测的model
with tf.device(DEVICE):
    model = modellib.MaskRCNN(mode="inference", model_dir=MODEL_DIR,
                              config=config)

#设置weights文件路径
if config.NAME == "shapes":
    weights_path = SHAPES_MODEL_PATH
elif config.NAME == "coco":
    weights_path = COCO_MODEL_PATH
#或者取消下面的注释行,加载最近训练的模型
# weights_path = model.find_last()

#加载weights
print("Loading weights ", weights_path)
model.load_weights(weights_path, by_name=True)

#显示所有训练的weights    
visualize.display_weight_stats(model)

下面是部分权重。

四、Weights的直方图

#选择一些layers显示
LAYER_TYPES = ['Conv2D', 'Dense', 'Conv2DTranspose']
#获取layers
layers = model.get_trainable_layers()
layers = list(filter(lambda l: l.__class__.__name__ in LAYER_TYPES, 
                layers))
#显示直方图
fig, ax = plt.subplots(len(layers), 2, figsize=(10, 3*len(layers)),
                       gridspec_kw={"hspace":1})
for l, layer in enumerate(layers):
    weights = layer.get_weights()
    for w, weight in enumerate(weights):
        tensor = layer.weights[w]
        ax[l, w].set_title(tensor.name)
        _ = ax[l, w].hist(weight[w].flatten(), 50)

猜你喜欢

转载自blog.csdn.net/heiheiya/article/details/82114010