分享web开发知识

注册/登录|最近发布|今日推荐

主页 IT知识网页技术软件开发前端开发代码编程运营维护技术分享教程案例
当前位置:首页 > 技术分享

[MXNet逐梦之旅]练习一·使用MXNet拟合直线手动实现

发布时间:2023-09-06 02:33责任编辑:沈小雨关键词:暂无标签

[MXNet逐梦之旅]练习一·使用MXNet拟合直线手动实现

  • code
#%%from matplotlib import pyplot as pltfrom mxnet import autograd, ndimport random#%%num_inputs = 1num_examples = 100true_w = 1.56true_b = 1.24features = nd.arange(0,10,0.1).reshape((-1, 1))labels = true_w * features + true_blabels += nd.random.normal(scale=0.2, shape=labels.shape)features[0], labels[0]#%%# 本函数已保存在d2lzh包中方便以后使用def data_iter(batch_size, features, labels): ???num_examples = len(features) ???indices = list(range(num_examples)) ???random.shuffle(indices) ?# 样本的读取顺序是随机的 ???for i in range(0, num_examples, batch_size): ???????j = nd.array(indices[i: min(i + batch_size, num_examples)]) ???????yield features.take(j), labels.take(j) ?# take函数根据索引返回对应元素#%%batch_size = 10for X, y in data_iter(batch_size, features, labels): ???print(X, y) ???break#%%w = nd.random.normal(scale=0.01, shape=(num_inputs, 1))b = nd.zeros(shape=(1,))#%%w.attach_grad()b.attach_grad()#%%def linreg(X, w, b): ?# 本函数已保存在d2lzh包中方便以后使用 ???return nd.dot(X, w) + b#%%def squared_loss(y_hat, y): ?# 本函数已保存在d2lzh包中方便以后使用 ???return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2#%%def sgd(params, lr, batch_size): ?# 本函数已保存在d2lzh包中方便以后使用 ???for param in params: ???????param[:] = param - lr * param.grad / batch_size#%%lr = 0.05num_epochs = 20net = linregloss = squared_lossfor epoch in range(num_epochs): ?# 训练模型一共需要num_epochs个迭代周期 ???# 在每一个迭代周期中,会使用训练数据集中所有样本一次(假设样本数能够被批量大小整除)。X ???# 和y分别是小批量样本的特征和标签 ???for X, y in data_iter(batch_size, features, labels): ???????with autograd.record(): ???????????l = loss(net(X, w, b), y) ?# l是有关小批量X和y的损失 ???????l.backward() ?# 小批量的损失对模型参数求梯度 ???????sgd([w, b], lr, batch_size) ?# 使用小批量随机梯度下降迭代模型参数 ???train_l = loss(net(features, w, b), labels) ???print(‘epoch %d, loss %f‘ % (epoch + 1, train_l.mean().asnumpy()))#%%true_w, w#%%true_b, b#%%plt.scatter(features.asnumpy(), labels.asnumpy(), 1)labels1 = linreg(features,w,b)plt.scatter(features.asnumpy(), labels1.asnumpy(), 1)plt.show()
  • out

黄色是原始数据

绿色为拟合数据

?

[MXNet逐梦之旅]练习一·使用MXNet拟合直线手动实现

原文地址:https://www.cnblogs.com/xiaosongshine/p/10421687.html

知识推荐

我的编程学习网——分享web前端后端开发技术知识。 垃圾信息处理邮箱 tousu563@163.com 网站地图
icp备案号 闽ICP备2023006418号-8 不良信息举报平台 互联网安全管理备案 Copyright 2023 www.wodecom.cn All Rights Reserved