分享web开发知识

注册/登录|最近发布|今日推荐

主页 IT知识网页技术软件开发前端开发代码编程运营维护技术分享教程案例
当前位置:首页 > 教程案例

AlexNet 分类 FashionMNIST

发布时间:2023-09-06 02:24责任编辑:傅花花关键词:暂无标签
from mxnet import gluon,init,nd,autogradfrom mxnet.gluon import data as gdata,nnfrom mxnet.gluon import loss as glossimport mxnet as mximport timeimport osimport sys# 建立网络net = nn.Sequential()# 使用较大的 11 x 11 窗口来捕获物体。同时使用步幅 4 来较大减小输出高和宽。# 这里使用的输入通道数比 LeNet 中的也要大很多。net.add(nn.Conv2D(96, kernel_size=11, strides=4, activation=‘relu‘), ???????nn.MaxPool2D(pool_size=3, strides=2), ???????# 减小卷积窗口,使用填充为 2 来使得输入输出高宽一致,且增大输出通道数。 ???????nn.Conv2D(256, kernel_size=5, padding=2, activation=‘relu‘), ???????nn.MaxPool2D(pool_size=3, strides=2), ???????# 连续三个卷积层,且使用更小的卷积窗口。除了最后的卷积层外,进一步增大了输出通道数。 ???????# 前两个卷积层后不使用池化层来减小输入的高和宽。 ???????nn.Conv2D(384, kernel_size=3, padding=1, activation=‘relu‘), ???????nn.Conv2D(384, kernel_size=3, padding=1, activation=‘relu‘), ???????nn.Conv2D(256, kernel_size=3, padding=1, activation=‘relu‘), ???????nn.MaxPool2D(pool_size=3, strides=2), ???????# 这里全连接层的输出个数比 LeNet 中的大数倍。使用丢弃层来缓解过拟合。 ???????nn.Dense(4096, activation="relu"), nn.Dropout(0.5), ???????nn.Dense(4096, activation="relu"), nn.Dropout(0.5), ???????# 输出层。由于这里使用 Fashion-MNIST,所以用类别数为 10,而非论文中的 1000。 ???????nn.Dense(10))X = nd.random.uniform(shape=(1,1,224,224))net.initialize()for layer in net: ???X = layer(X) ???print(layer.name,‘output shape:\t‘,X.shape)# 读取数据# fashionMNIST 28*28 转为224*224def load_data_fashion_mnist(batch_size, resize=None, root=os.path.join( ???????‘~‘, ‘.mxnet‘, ‘datasets‘, ‘fashion-mnist‘)): ???root = os.path.expanduser(root) ?# 展开用户路径 ‘~‘。 ???transformer = [] ???if resize: ???????transformer += [gdata.vision.transforms.Resize(resize)] ???transformer += [gdata.vision.transforms.ToTensor()] ???transformer = gdata.vision.transforms.Compose(transformer) ???mnist_train = gdata.vision.FashionMNIST(root=root, train=True) ???mnist_test = gdata.vision.FashionMNIST(root=root, train=False) ???num_workers = 0 if sys.platform.startswith(‘win32‘) else 4 ???train_iter = gdata.DataLoader( ???????mnist_train.transform_first(transformer), batch_size, shuffle=True, ???????num_workers=num_workers) ???test_iter = gdata.DataLoader( ???????mnist_test.transform_first(transformer), batch_size, shuffle=False, ???????num_workers=num_workers) ???return train_iter, test_iterbatch_size = 128train_iter, test_iter = load_data_fashion_mnist(batch_size, resize=224)def accuracy(y_hat,y): ???return (y_hat.argmax(axis=1)==y.astype(‘float32‘)).mean().asscalar()def evaluate_accuracy(data_iter,net,ctx): ???acc = nd.array([0],ctx=ctx) ???for X,y in data_iter: ???????X = X.as_in_context(ctx) ???????y = y.as_in_context(ctx) ???????acc+=accuracy(net(X),y) ???return acc.asscalar() / len(data_iter)# 训练模型def train(net,train_iter,test_iter,batch_size,trainer,ctx,num_epochs): ???print(‘training on‘,ctx) ???loss = gloss.SoftmaxCrossEntropyLoss() ???for epoch in range(num_epochs): ???????train_l_sum = 0 ???????train_acc_sum = 0 ???????start = time.time() ???????for X,y in train_iter: ???????????X = X.as_in_context(ctx) ???????????y = y.as_in_context(ctx) ???????????with autograd.record(): ???????????????y_hat = net(X) ???????????????l = loss(y_hat,y) ???????????l.backward() ???????????trainer.step(batch_size) ???????????train_l_sum += l.mean().asscalar() ???????????train_acc_sum += evaluate_accuracy(test_iter,net,ctx) ???????test_acc = evaluate_accuracy(test_iter,net,ctx) ???????print(‘epoch %d, loss %.4f, train acc %.3f, test acc %.3f, ‘ ?????????????‘time %.1f sec‘ % (epoch+1,train_l_sum/len(train_iter),test_acc,time.time()-start))def try_gpu(): ???try: ???????ctx = mx.gpu() ???????_ = nd.zeros((1,),ctx=ctx) ???except mx.base.MXNetError: ???????ctx = mx.cpu() ???return ctxlr = 0.01num_epochs = 5ctx = try_gpu()net.initialize(force_reinit=True,ctx=ctx,init=init.Xavier())trainer = gluon.Trainer(net.collect_params(),‘sgd‘,{‘learning_rate‘:lr})train(net,train_iter,test_iter,batch_size,trainer,ctx,num_epochs)

AlexNet 分类 FashionMNIST

原文地址:https://www.cnblogs.com/TreeDream/p/10045670.html

知识推荐

我的编程学习网——分享web前端后端开发技术知识。 垃圾信息处理邮箱 tousu563@163.com 网站地图
icp备案号 闽ICP备2023006418号-8 不良信息举报平台 互联网安全管理备案 Copyright 2023 www.wodecom.cn All Rights Reserved