#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""Load saved LeNet parameters and visualise Fashion-MNIST predictions.

The script builds a LeNet-style Gluon network, restores its weights from
``net.params``, classifies the first six validation images of Fashion-MNIST
and plots each image titled with its true label and its predicted label.

Created on Fri Aug 10 16:13:29 2018

@author: myhaspl
"""
import mxnet as mx
from mxnet import nd
from mxnet import gluon
from mxnet.gluon import nn
from mxnet.gluon.data.vision import datasets, transforms
import matplotlib.pyplot as plt


def build_lenet(net):
    """Add the LeNet-style layer stack to ``net`` and return it.

    Layers are added inside ``net.name_scope()`` in this order: two
    convolution + max-pooling stages, a flatten layer, two ReLU dense layers
    (120 and 84 units) and a final 10-unit dense output layer.

    Args:
        net: A Gluon container providing ``name_scope()`` and ``add()``
            (the script passes ``gluon.nn.Sequential()``).

    Returns:
        The same ``net`` object, with the layers appended.
    """
    with net.name_scope():
        net.add(
            gluon.nn.Conv2D(channels=6, kernel_size=5, activation="relu"),
            gluon.nn.MaxPool2D(pool_size=2, strides=2),
            gluon.nn.Conv2D(channels=16, kernel_size=3, activation="relu"),
            gluon.nn.MaxPool2D(pool_size=2, strides=2),
            gluon.nn.Flatten(),
            gluon.nn.Dense(120, activation="relu"),
            gluon.nn.Dense(84, activation="relu"),
            gluon.nn.Dense(10),
        )
    return net


# Human-readable class names, indexed by the integer class label.
text_labels = [
    't-shirt', 'trouser', 'pullover', 'dress', 'coat',
    'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot',
]

# Define the network.
net = build_lenet(gluon.nn.Sequential())
net.initialize(init=mx.init.Xavier())
# The call form prints the same text under Python 2 and Python 3.
print(net)

# Load the model parameters.
file_name = "net.params"
# ``load_params`` is deprecated in newer MXNet 1.x releases in favour of
# ``load_parameters``; prefer the newer method when the installed version
# provides it and fall back to the old one otherwise.
_load_parameters = getattr(net, "load_parameters", None) or net.load_params
_load_parameters(file_name)

# Convert each image to (channel, height, width) layout with a floating-point
# dtype via transforms.ToTensor, then normalise all pixel values with
# transforms.Normalize using mean 0.13 and standard deviation 0.31.
transformer = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(0.13, 0.31),
])

mnist_valid = gluon.data.vision.FashionMNIST(train=False)
X, y = mnist_valid[:6]

preds = []
for x in X:
    # The network expects a batch, so add a leading batch axis of size 1.
    x = transformer(x).expand_dims(axis=0)
    pred = net(x).argmax(axis=1)
    preds.append(pred.astype('int32').asscalar())

_, figs = plt.subplots(1, 6, figsize=(15, 15))
for ax, x, yi, pyi in zip(figs, X, y, preds):
    ax.imshow(x.reshape((28, 28)).asnumpy())
    # Title: true label on the first line, predicted label on the second.
    ax.set_title(text_labels[yi] + '\n' + text_labels[pyi])
    ax.title.set_fontsize(20)
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
plt.show()
# mxnet - reading model parameters (mxnet-读取模型参数)
# Original article (原文地址): http://blog.51cto.com/13959448/2317237