分享web开发知识

注册/登录|最近发布|今日推荐

主页 IT知识网页技术软件开发前端开发代码编程运营维护技术分享教程案例
当前位置:首页 > 软件开发

MXNET:权重衰减-gluon实现

发布时间:2023-09-06 02:11责任编辑:彭小芳关键词:暂无标签

构建数据集

# -*- coding: utf-8 -*-from mxnet import initfrom mxnet import ndarray as ndfrom mxnet.gluon import loss as glossimport gbn_train = 20n_test = 100num_inputs = 200true_w = nd.ones((num_inputs, 1)) * 0.01true_b = 0.05features = nd.random.normal(shape=(n_train+n_test, num_inputs))labels = nd.dot(features, true_w) + true_blabels += nd.random.normal(scale=0.01, shape=labels.shape)train_features, test_features = features[:n_train, :], features[n_train:, :]train_labels, test_labels = labels[:n_train], labels[n_train:]

数据迭代器

from mxnet import autogradfrom mxnet.gluon import data as gdatabatch_size = 1num_epochs = 10learning_rate = 0.003train_iter = gdata.DataLoader(gdata.ArrayDataset( ???train_features, train_labels), batch_size, shuffle=True)loss = gloss.L2Loss()

训练并展示结果

gb.semilogy函数:绘制训练和测试数据的loss

from mxnet import gluonfrom mxnet.gluon import nndef fit_and_plot(weight_decay): ???net = nn.Sequential() ???net.add(nn.Dense(1)) ???net.initialize(init.Normal(sigma=1)) ???# 对权重参数做 L2 范数正则化,即权重衰减。 ???trainer_w = gluon.Trainer(net.collect_params('.*weight'), 'sgd', { ???????'learning_rate': learning_rate, 'wd': weight_decay}) ???# 不对偏差参数做 L2 范数正则化。 ???trainer_b = gluon.Trainer(net.collect_params('.*bias'), 'sgd', { ???????'learning_rate': learning_rate}) ???train_ls = [] ???test_ls = [] ???for _ in range(num_epochs): ???????for X, y in train_iter: ???????????with autograd.record(): ???????????????l = loss(net(X), y) ???????????l.backward() ???????????# 对两个 Trainer 实例分别调用 step 函数。 ???????????trainer_w.step(batch_size) ???????????trainer_b.step(batch_size) ???????train_ls.append(loss(net(train_features), ????????????????????????????train_labels).mean().asscalar()) ???????test_ls.append(loss(net(test_features), ???????????????????????????test_labels).mean().asscalar()) ???gb.semilogy(range(1, num_epochs + 1), train_ls, 'epochs', 'loss', ???????????????range(1, num_epochs + 1), test_ls, ['train', 'test']) ???return 'w[:10]:', net[0].weight.data()[:, :10], 'b:', net[0].bias.data()print fit_and_plot(5)
  • 使用 Gluon 的 wd 超参数可以使用权重衰减来应对过拟合问题。
  • 我们可以定义多个 Trainer 实例对不同的模型参数使用不同的迭代方法。

MXNET:权重衰减-gluon实现

原文地址:https://www.cnblogs.com/houkai/p/9521015.html

知识推荐

我的编程学习网——分享web前端后端开发技术知识。 垃圾信息处理邮箱 tousu563@163.com 网站地图
icp备案号 闽ICP备2023006418号-8 不良信息举报平台 互联网安全管理备案 Copyright 2023 www.wodecom.cn All Rights Reserved