MXNet：读取模型参数


#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
Created on Fri Aug 10 16:13:29 2018
@author: myhaspl
"""
import mxnet as mx
from mxnet import nd 
from mxnet import gluon
from mxnet.gluon import nn
from mxnet.gluon.data.vision import datasets, transforms 
import matplotlib.pyplot as plt

def build_lenet(net):
    """Append LeNet-style layers to ``net`` and return it.

    Layers added, in order:
      * Conv2D(6 channels, 5x5 kernel, ReLU) -> MaxPool2D(2, stride 2)
      * Conv2D(16 channels, 3x3 kernel, ReLU) -> MaxPool2D(2, stride 2)
      * Flatten -> Dense(120, ReLU) -> Dense(84, ReLU) -> Dense(10)

    All layers are constructed inside ``net.name_scope()`` and in the same
    order as before, so generated layer names are unchanged. A parameter
    file saved from this architecture presumably relies on those names.

    :param net: a Gluon container exposing ``name_scope()`` and ``add()``;
        the caller passes ``gluon.nn.Sequential()``. It is mutated in place.
    :return: the same ``net`` object, now holding the layers above.
    """
    with net.name_scope():
        # Convolutional feature extractor: two conv + pooling stages.
        feature_extractor = [
            gluon.nn.Conv2D(channels=6, kernel_size=5, activation="relu"),
            gluon.nn.MaxPool2D(pool_size=2, strides=2),
            gluon.nn.Conv2D(channels=16, kernel_size=3, activation="relu"),
            gluon.nn.MaxPool2D(pool_size=2, strides=2),
        ]
        # Fully connected head; the final layer has 10 outputs, one per class.
        classifier = [
            gluon.nn.Flatten(),
            gluon.nn.Dense(120, activation="relu"),
            gluon.nn.Dense(84, activation="relu"),
            gluon.nn.Dense(10),
        ]
        net.add(*(feature_extractor + classifier))
    return net

# Human-readable Fashion-MNIST class names.
# The list position equals the integer class label used for lookups.
text_labels = [
    't-shirt',     # 0
    'trouser',     # 1
    'pullover',    # 2
    'dress',       # 3
    'coat',        # 4
    'sandal',      # 5
    'shirt',       # 6
    'sneaker',     # 7
    'bag',         # 8
    'ankle boot',  # 9
]
# Define the network and give it an initial set of parameters.
net = build_lenet(gluon.nn.Sequential())
net.initialize(init=mx.init.Xavier())
# print(...) with one argument prints the same on Python 2 and 3;
# a bare `print net` statement is a SyntaxError on Python 3.
print(net)

# Load the trained parameters, replacing the Xavier-initialized values
# set above. The file must come from the same architecture.
file_name = "net.params"
# `load_params` is deprecated since MXNet 1.3 in favour of
# `load_parameters`; use the newer name when this MXNet version has it.
if hasattr(net, "load_parameters"):
    net.load_parameters(file_name)
else:
    net.load_params(file_name)

# transforms.ToTensor converts each image to (channel, height, width)
# floating-point data. transforms.Normalize then normalizes pixel values
# with mean 0.13 and standard deviation 0.31.
transformer = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(0.13, 0.31)])

# Number of validation images to classify and display (one subplot each).
num_samples = 6

mnist_valid = gluon.data.vision.FashionMNIST(train=False)
# Slicing the dataset yields the batch of images and their labels.
X, y = mnist_valid[:num_samples]

preds = []
for x in X:
    # Add a leading batch axis so the network receives a batch of one image.
    x = transformer(x).expand_dims(axis=0)
    pred = net(x).argmax(axis=1)
    preds.append(pred.astype('int32').asscalar())

# Each subplot title shows the true label on the first line and the
# predicted label on the second.
_, figs = plt.subplots(1, num_samples, figsize=(15, 15))
for ax, x, yi, pyi in zip(figs, X, y, preds):
    ax.imshow(x.reshape((28, 28)).asnumpy())
    ax.set_title(text_labels[yi] + '\n' + text_labels[pyi])
    ax.title.set_fontsize(20)
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
plt.show()

猜你喜欢

转载自：https://blog.51cto.com/13959448/2317237