分享web开发知识

注册/登录|最近发布|今日推荐

主页 IT知识网页技术软件开发前端开发代码编程运营维护技术分享教程案例
当前位置:首页 > 技术分享

mxnet-训练器与分批读取样本

发布时间:2023-09-06 02:22责任编辑:胡小海关键词:暂无标签
#!/usr/bin/env python2# -*- coding: utf-8 -*-"""Created on Fri Aug 10 16:13:29 2018@author: myhaspl"""from mxnet import nd, gluon, init, autogradfrom mxnet.gluon import nnfrom mxnet.gluon.data.vision import datasets,transforms import matplotlib.pyplot as pltfrom time import timemnist_train = datasets.FashionMNIST(train=True)X, y = mnist_train[0]print (‘X shape: ‘, X.shape, ‘X dtype‘, X.dtype, ‘y:‘, y,‘Y dtype‘, y.dtype)#x:(height, width, channel)#y:numpy.scalar,标签text_labels = [ ???????????‘t-shirt‘, ‘trouser‘, ‘pullover‘, ‘dress‘, ‘coat‘, ???????????‘sandal‘, ‘shirt‘, ‘sneaker‘, ‘bag‘, ‘ankle boot‘]X, y = mnist_train[0:6]#取6个样本_, figs = plt.subplots(1, X.shape[0], figsize=(15, 15))for f,x,yi in zip(figs, X,y): ???# 3D->2D by removing the last channel dim ???f.imshow(x.reshape((28,28)).asnumpy()) ???ax = f.axes ???ax.set_title(text_labels[int(yi)]) ???ax.title.set_fontsize(20) ???ax.get_xaxis().set_visible(False) ???ax.get_yaxis().set_visible(False)plt.show()#转换图像为(channel, height, weight)格式,并且为floating数据类型,通过transforms.ToTensor。#另外,normalize所有像素值 使用 transforms.Normalize平均值0.13和标准差0.31. transformer = transforms.Compose([ ???????????transforms.ToTensor(), ???????????transforms.Normalize(0.13, 0.31)])#只转换第一个元素,图像部分。第二个元素为标签。mnist_train = mnist_train.transform_first(transformer)#加载批次数据batch_size = 200train_data = gluon.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=4)#读取本批数据i=1for data, label in train_data: ???print i ???print data,label ???break#没有这一行,会以每批次200个数据来读取。mnist_valid = gluon.data.vision.FashionMNIST(train=False)valid_data = gluon.data.DataLoader(mnist_valid.transform_first(transformer),batch_size=batch_size, num_workers=4)#定义网络net = nn.Sequential()net.add(nn.Conv2D(channels=6,kernel_size=5,activation="relu"), ???????nn.MaxPool2D(pool_size=2, strides=2), ???????nn.Conv2D(channels=16, kernel_size=3, activation="relu"), ???????nn.MaxPool2D(pool_size=2, strides=2), ???????nn.Flatten(), ???????nn.Dense(120, activation="relu"), ???????nn.Dense(84, activation="relu"), ???????nn.Dense(10))net.initialize(init=init.Xavier())print net#输出softmax与误差softmax_cross_entropy = gluon.loss.SoftmaxCrossEntropyLoss()#定义训练器trainer = gluon.Trainer(net.collect_params(), ‘sgd‘, {‘learning_rate‘: 0.1})

-0.41935483
? ? -0.41935483]
? ?[-0.41935483 -0.41935483 -0.41935483 ... -0.41935483 -0.41935483
? ? -0.41935483]]]]
<NDArray 200x1x28x28 @cpu_shared(0)>?
[9 0 9 ... 3 8 5]
<NDArray 200 @cpu_shared(0)>
Sequential(
? (0): Conv2D(None -> 6, kernel_size=(5, 5), stride=(1, 1))
? (1): MaxPool2D(size=(2, 2), stride=(2, 2), padding=(0, 0), ceil_mode=False)
? (2): Conv2D(None -> 16, kernel_size=(3, 3), stride=(1, 1))
? (3): MaxPool2D(size=(2, 2), stride=(2, 2), padding=(0, 0), ceil_mode=False)
? (4): Flatten
? (5): Dense(None -> 120, Activation(relu))
? (6): Dense(None -> 84, Activation(relu))
? (7): Dense(None -> 10, linear)
)

mxnet-训练器与分批读取样本

原文地址:http://blog.51cto.com/13959448/2317239

知识推荐

我的编程学习网——分享web前端后端开发技术知识。 垃圾信息处理邮箱 tousu563@163.com 网站地图
icp备案号 闽ICP备2023006418号-8 不良信息举报平台 互联网安全管理备案 Copyright 2023 www.wodecom.cn All Rights Reserved