分享web开发知识

注册/登录|最近发布|今日推荐

主页 IT知识网页技术软件开发前端开发代码编程运营维护技术分享教程案例
当前位置:首页 > IT知识

encode与decode

发布时间:2023-09-06 02:27责任编辑:赖小花关键词:暂无标签
# Fully-connected autoencoder trained on MNIST.
# Figure 1: the first N_TEST_IMG training images (top row) and their
# reconstructions (bottom row), refreshed every 100 training steps.
# Figure 2: the first three latent coordinates of 200 training images,
# drawn as a 3-D scatter of their digit labels.
import numpy as np
import torch
import torch.utils.data as Data
from torch import nn

# Hyper Parameters
EPOCH = 10
BATCH_SIZE = 64
LR = 0.005  # learning rate
DOWNLOAD_MNIST = False  # passed as download= to torchvision's MNIST; data must already exist under ./mnist/ when False
N_TEST_IMG = 5


class AutoEncoder(nn.Module):
    """Fully-connected autoencoder for flattened 28x28 images.

    Encoder: 784 -> 128 -> 64 -> ``latent_dim`` with Tanh between layers and
    a linear (unbounded) code. Decoder mirrors it back to 784 and ends in a
    Sigmoid, so reconstructions lie in (0, 1) like pixels scaled by 1/255.

    Args:
        latent_dim: size of the code produced by the encoder (default 12).
            With ``latent_dim=3`` the 3-D latent plot shows the full code.
    """

    def __init__(self, latent_dim=12):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(28 * 28, 128),
            nn.Tanh(),
            nn.Linear(128, 64),
            nn.Tanh(),
            nn.Linear(64, latent_dim),
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 64),
            nn.Tanh(),
            nn.Linear(64, 128),
            nn.Tanh(),
            nn.Linear(128, 28 * 28),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return ``(encoded, decoded)`` for a batch ``x`` of shape (N, 784)."""
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return encoded, decoded


def _show_images(axes_row, images):
    """Draw each flattened 28x28 image in ``images`` on the matching axis."""
    for ax, img in zip(axes_row, images):
        ax.clear()
        ax.imshow(np.reshape(img, (28, 28)), cmap='gray')
        ax.set_xticks(())
        ax.set_yticks(())


def main():
    """Train the autoencoder on MNIST and plot reconstructions and latent codes."""
    # Plotting/dataset libraries are only needed when run as a script, so the
    # model above stays importable without them.
    import matplotlib.pyplot as plt
    import torchvision
    from matplotlib import cm
    from mpl_toolkits.mplot3d import Axes3D  # noqa: F401  registers the '3d' projection on older matplotlib

    train_data = torchvision.datasets.MNIST(
        root='./mnist/',
        train=True,
        transform=torchvision.transforms.ToTensor(),
        download=DOWNLOAD_MNIST,
    )
    train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)

    autoencoder = AutoEncoder()
    optimizer = torch.optim.Adam(autoencoder.parameters(), lr=LR)  # optimize all autoencoder parameters
    loss_func = nn.MSELoss()

    # 2 x N_TEST_IMG grid: originals on top, reconstructions below.
    _, axes = plt.subplots(2, N_TEST_IMG, figsize=(5, 2))
    plt.ion()  # continuously plot

    # `.data` holds the raw images; divide by 255 to match ToTensor's [0, 1] scale.
    view_data = train_data.data[:N_TEST_IMG].view(-1, 28 * 28).float() / 255
    _show_images(axes[0], view_data.numpy())

    for epoch in range(EPOCH):
        for step, (x, _) in enumerate(train_loader):
            b_x = x.view(-1, 28 * 28)
            b_y = x.view(-1, 28 * 28)  # reconstruction target is the input itself

            _, decoded = autoencoder(b_x)
            loss = loss_func(decoded, b_y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if step % 100 == 0:
                print('Epoch:|', epoch, 'train loss:%0.4f' % loss.item())
                # Reconstruct the fixed test images (not the current batch) so
                # the bottom row lines up with the originals above it.
                with torch.no_grad():
                    _, decoded_data = autoencoder(view_data)
                _show_images(axes[1], decoded_data.numpy())
                plt.draw()
                plt.pause(0.05)

    plt.ioff()
    plt.show()

    # Latent codes of 200 training images; only the first three coordinates are plotted.
    view_data = train_data.data[:200].view(-1, 28 * 28).float() / 255
    with torch.no_grad():
        encoded_data, _ = autoencoder(view_data)
    fig = plt.figure(2)
    # Axes3D(fig) no longer attaches itself to the figure in recent matplotlib.
    ax = fig.add_subplot(111, projection='3d')
    X, Y, Z = (encoded_data[:, k].numpy() for k in range(3))
    values = train_data.targets[:200].numpy()
    for px, py, pz, s in zip(X, Y, Z, values):
        c = cm.rainbow(int(255 * s / 9))  # colour by digit label 0-9
        ax.text(px, py, pz, str(s), backgroundcolor=c)
    ax.set_xlim(X.min(), X.max())
    ax.set_ylim(Y.min(), Y.max())
    ax.set_zlim(Z.min(), Z.max())
    plt.show()


if __name__ == '__main__':
    main()

从训练集中选出前五张图片用于可视化测试。

图像按 2 行 5 列显示：上面一行是原始图像，下面一行是经过编码、再解码（重建）后的图像。

encode与decode

原文地址:https://www.cnblogs.com/wmy-ncut/p/10190482.html

知识推荐

我的编程学习网——分享web前端后端开发技术知识。 垃圾信息处理邮箱 tousu563@163.com 网站地图
icp备案号 闽ICP备2023006418号-8 不良信息举报平台 互联网安全管理备案 Copyright 2023 www.wodecom.cn All Rights Reserved