
Iris Classification Neural Network

Published: 2023-09-06 01:55, editor: 傅花花

Neural Network

formula derivation

\[\begin{aligned}
a &= x \cdot w_1 \\
y &= a \cdot w_2 \\
  &= x \cdot w_1 \cdot w_2 \\
y &= \operatorname{softmax}(y)
\end{aligned}\]
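The step y = x · w1 · w2 holds because there is no activation between the two matrix products, so the two layers collapse into one linear map. A minimal pure-Python check of that step, assuming made-up integer weights (the values of w1 and w2 and the helper names are hypothetical; only the shapes 4×5 and 5×3 come from the code):

```python
import math

def matmul(A, B):
    """Multiply matrix A (n x k) by matrix B (k x m); both are lists of rows."""
    return [[sum(A[i][t] * B[t][j] for t in range(len(B))) for j in range(len(B[0]))]
            for i in range(len(A))]

def softmax(row):
    """Softmax of one row; the max is subtracted first for numerical stability."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    s = sum(exps)
    return [e / s for e in exps]

# Hypothetical integer weights with the article's shapes: w1 is 4x5, w2 is 5x3.
x = [[1, 2, 3, 4]]                                   # one sample with 4 features (1x4)
w1 = [[(i + j) % 3 - 1 for j in range(5)] for i in range(4)]
w2 = [[(i * j) % 4 - 2 for j in range(3)] for i in range(5)]

a = matmul(x, w1)                     # hidden layer, 1x5
y = matmul(a, w2)                     # output logits, 1x3
y_direct = matmul(x, matmul(w1, w2))  # x . (w1 . w2): identical, since no activation sits in between
probs = softmax(y[0])                 # class probabilities, summing to 1

print(y, y_direct, probs)
```

Integer weights keep the comparison exact, so `y` and `y_direct` are equal element by element.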

code (training only)

\[\begin{aligned}
a &= x \cdot w_1 \\
y &= a \cdot w_2
\end{aligned}\]

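import tensorflow as tf  # TensorFlow 1.x API

w1 = tf.Variable(tf.random_normal([4, 5], stddev=1, seed=1))  # input (4 features) -> hidden layer (5 units)
w2 = tf.Variable(tf.random_normal([5, 3], stddev=1, seed=1))  # hidden layer (5 units) -> output (3 classes)
x = tf.placeholder(tf.float32, shape=(None, 4), name='x-input')
a = tf.matmul(x, w1)
y = tf.matmul(a, w2)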

Since this is supervised learning, labels must be supplied during training so that the cross entropy can be computed.

# Placeholder for the data labels (one-hot vectors, 3 classes)
y_ = tf.placeholder(tf.float32, shape=(None, 3), name='y-input')
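Because `y_` has shape (None, 3), each integer class id (0, 1 or 2 for iris) has to be fed as a one-hot row. The full code does this with `label_convert`; a compact equivalent, where the name `one_hot` is hypothetical, could look like this:

```python
def one_hot(labels, num_classes=3):
    """Turn integer class ids into one-hot rows of length num_classes."""
    return [[1 if j == label else 0 for j in range(num_classes)] for label in labels]

Y = one_hot([0, 1, 2, 1])
print(Y)  # [[1, 0, 0], [0, 1, 0], [0, 0, 1], [0, 1, 0]]
```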

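A sigmoid activation is then applied. Note that, as written, it acts on the output y (the logits that are later fed to softmax), not on the hidden layer a.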

y = tf.sigmoid(y)
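Since this sigmoid squashes every logit into (0, 1) before softmax, it also limits how confident the softmax output can become. A small pure-Python illustration for three classes (the sample logits 20 and -20 are arbitrary, and the helpers are written here for the example, not taken from TensorFlow):

```python
import math

def sigmoid(z):
    """Logistic function 1 / (1 + e^-z); maps any real number into (0, 1)."""
    return 1.0 / (1.0 + math.exp(-z))

def softmax(row):
    """Softmax of one row, with the max subtracted for numerical stability."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    s = sum(exps)
    return [e / s for e in exps]

# Very confident raw logits are squashed into (0, 1) by the sigmoid...
squashed = [sigmoid(z) for z in [20.0, -20.0, -20.0]]
# ...so the top softmax probability approaches, but stays below, e / (e + 2).
p_max = softmax(squashed)[0]
bound = math.e / (math.e + 2)
print(round(p_max, 4), round(bound, 4))
```

With this setup no predicted class probability can exceed about 0.576, so the per-sample cross entropy cannot fall below -ln(e / (e + 2)), roughly 0.551, which is worth keeping in mind when reading the printed training loss.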

TensorFlow provides a single function that combines softmax with cross entropy; the loss is the mean of the per-sample cross entropies.

# softmax & cross entropy (one value per sample)
cross_entropy = tf.nn.softmax_cross_entropy_with_logits_v2(labels=y_, logits=y)
# mean over the batch
cross_entropy_mean = tf.reduce_mean(cross_entropy)
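To see what the combined op computes, here is a pure-Python sketch of the same quantity, -sum(label * log(softmax(logits))) per row, followed by the mean that `tf.reduce_mean` takes. The function name mirrors the TensorFlow op for readability, but this is a hypothetical re-implementation, not TensorFlow's code; the labels and logits are made-up values.

```python
import math

def softmax_cross_entropy_with_logits(labels, logits):
    """Per-row cross entropy between one-hot labels and softmax(logits).

    log(softmax) is computed via log-sum-exp so large logits do not overflow.
    """
    losses = []
    for lab, row in zip(labels, logits):
        m = max(row)
        log_sum_exp = m + math.log(sum(math.exp(v - m) for v in row))
        log_probs = [v - log_sum_exp for v in row]
        losses.append(-sum(l * lp for l, lp in zip(lab, log_probs)))
    return losses

labels = [[1, 0, 0], [0, 0, 1]]
logits = [[2.0, 1.0, 0.1], [0.5, 0.5, 0.5]]
ce = softmax_cross_entropy_with_logits(labels, logits)
ce_mean = sum(ce) / len(ce)  # the role tf.reduce_mean plays on the per-sample losses
print(ce, ce_mean)
```

For the second row all logits are equal, so each class gets probability 1/3 and the loss is ln 3.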

To keep the network from overfitting, a regularization term is added to the loss; the usual choice is L2 regularization.

loss = cross_entropy_mean + \
       tf.contrib.layers.l2_regularizer(regulation_lamda)(w1) + \
       tf.contrib.layers.l2_regularizer(regulation_lamda)(w2)
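In TensorFlow 1.x, `tf.contrib.layers.l2_regularizer(scale)(w)` returns `scale * tf.nn.l2_loss(w)`, and `tf.nn.l2_loss(t)` is `sum(t ** 2) / 2`. A pure-Python sketch of the penalty added above, with hypothetical helper names and small made-up weight matrices:

```python
def l2_loss(weights):
    """Half the sum of squared entries, the quantity tf.nn.l2_loss computes."""
    return sum(w * w for row in weights for w in row) / 2

def l2_penalty(scale, weights):
    """scale * l2_loss(weights): the regularization term for one weight matrix."""
    return scale * l2_loss(weights)

regulation_lamda = 0.001
w1 = [[1.0, 2.0], [3.0, 4.0]]  # hypothetical weights: squares sum to 30
w2 = [[0.5], [-0.5]]           # hypothetical weights: squares sum to 0.5
penalty = l2_penalty(regulation_lamda, w1) + l2_penalty(regulation_lamda, w2)
print(penalty)
```

Larger weights are penalized quadratically, which pushes the optimizer toward smaller weights and a smoother decision function.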

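To speed up training, an exponentially decaying learning rate ("exponential decay") should be added as well. The full code below still uses a fixed learning_rate of 0.001.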
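Exponential decay starts with a relatively large learning rate and shrinks it as training proceeds. The sketch below follows the formula documented for `tf.train.exponential_decay` in TensorFlow 1.x, written in pure Python; the initial rate 0.1, decay rate 0.96 and 100 decay steps are hypothetical values, not settings from the article.

```python
def exponential_decay(learning_rate, global_step, decay_steps, decay_rate, staircase=False):
    """decayed = learning_rate * decay_rate ** (global_step / decay_steps).

    With staircase=True the exponent is floored, so the rate drops in discrete steps.
    """
    exponent = global_step / decay_steps
    if staircase:
        exponent = global_step // decay_steps
    return learning_rate * decay_rate ** exponent

# Hypothetical schedule: start at 0.1 and multiply by 0.96 every 100 steps.
for step in (0, 50, 100, 250):
    print(step, exponential_decay(0.1, step, 100, 0.96, staircase=True))
```

In a TensorFlow 1.x graph, the decayed rate would be passed to the optimizer together with a global step variable that `minimize` increments.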

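The line below defines the training step in the computation graph. The optimizer is Adam; like plain gradient descent (GradientDescentOptimizer), which could be used instead, it updates the weights with gradients obtained by backpropagation.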

train_step = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss)
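Gradient descent and Adam consume the same backpropagated gradients and differ only in how a gradient becomes a weight update. A toy pure-Python sketch on the one-parameter loss f(w) = (w - 3)^2, assuming Adam's usual defaults (beta1=0.9, beta2=0.999, epsilon=1e-8) but a hypothetical learning rate of 0.01 and 2000 steps:

```python
import math

def grad(w):
    """Gradient of the toy loss f(w) = (w - 3)^2."""
    return 2.0 * (w - 3.0)

def gradient_descent(lr=0.01, steps=2000):
    w = 0.0
    for _ in range(steps):
        w -= lr * grad(w)                       # step straight along the negative gradient
    return w

def adam(lr=0.01, beta1=0.9, beta2=0.999, eps=1e-8, steps=2000):
    w, m, v = 0.0, 0.0, 0.0
    for t in range(1, steps + 1):
        g = grad(w)
        m = beta1 * m + (1 - beta1) * g         # running mean of gradients (momentum)
        v = beta2 * v + (1 - beta2) * g * g     # running mean of squared gradients
        m_hat = m / (1 - beta1 ** t)            # bias correction for the zero initialization
        v_hat = v / (1 - beta2 ** t)
        w -= lr * m_hat / (math.sqrt(v_hat) + eps)  # per-parameter scaled step
    return w

print(gradient_descent(), adam())
```

Both end up near the minimum at w = 3. Adam's division by the root of the squared-gradient average makes its step size roughly independent of the gradient's scale, which is one reason it is a common default choice.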

Training phase

with tf.Session() as sess:  # open the Session in a "with" block so its resources are released
    init_op = tf.global_variables_initializer()  # initializer for the network's variables, mainly w1 and w2
    sess.run(init_op)
    steps = 10000
    for i in range(steps):
        beg = (i * batch_size) % dataset_size      # start index of the current batch
        end = min(beg + batch_size, dataset_size)  # end index of the current batch
        sess.run(train_step, feed_dict={x: X[beg:end], y_: Y[beg:end]})  # backpropagation: one training step
        if i % 1000 == 0:
            # cross entropy on the whole dataset
            total_cross_entropy = sess.run(
                cross_entropy_mean,
                feed_dict={x: X, y_: Y}
            )
            print("After %d training steps, cross entropy on all data is %g" % (i, total_cross_entropy))
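The batch arithmetic in the loop cycles through the dataset in order. A pure-Python sketch of the same two lines, using batch_size = 10 and dataset_size = 150 from the full code (the function name `batch_bounds` is hypothetical):

```python
def batch_bounds(step, batch_size=10, dataset_size=150):
    """Start/end indices of the mini-batch used at a given training step."""
    beg = (step * batch_size) % dataset_size
    end = min(beg + batch_size, dataset_size)
    return beg, end

bounds = [batch_bounds(i) for i in range(16)]
print(bounds[0], bounds[14], bounds[15])  # (0, 10) (140, 150) (0, 10)
```

Every 15 steps the loop covers all 150 samples once, and because 150 is a multiple of 10 the `min` never truncates a batch. Note that `load_iris` returns the samples sorted by class (50 per class), so with this arithmetic each mini-batch of 10 contains a single class.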

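During training, a "moving average model" should also be introduced to make the model more robust on test data (that is how the book puts it; I would call it better generalization).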
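A moving average model keeps a "shadow" copy of each weight that is updated as shadow = decay * shadow + (1 - decay) * variable, and the shadow values are what get used at test time. The sketch below follows the update rule documented for `tf.train.ExponentialMovingAverage` in TensorFlow 1.x, including the optional num_updates rule decay = min(decay, (1 + num_updates) / (10 + num_updates)). The class name is hypothetical, and starting every shadow at 0.0 is an assumption of this sketch (TensorFlow initializes a shadow from its variable's initial value).

```python
class MovingAverageSketch:
    """Pure-Python shadow-variable averaging (hypothetical, not TensorFlow's class)."""

    def __init__(self, decay):
        self.decay = decay
        self.shadow = {}

    def apply(self, name, value, num_updates=None):
        decay = self.decay
        if num_updates is not None:
            # Early in training the effective decay is small, so the shadow moves quickly.
            decay = min(decay, (1 + num_updates) / (10 + num_updates))
        prev = self.shadow.get(name, 0.0)  # assumption: shadows start at 0.0
        self.shadow[name] = decay * prev + (1 - decay) * value
        return self.shadow[name]

ema = MovingAverageSketch(0.99)
first = ema.apply("v1", 5.0, num_updates=0)       # decay = min(0.99, 0.1) = 0.1, shadow = 4.5
second = ema.apply("v1", 10.0, num_updates=10000)  # decay = 0.99, shadow = 0.99 * 4.5 + 0.01 * 10, about 4.555
print(first, second)
```

Because the shadow changes slowly, it smooths out the step-to-step noise of the trained weights.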

Full code

# -*- encoding=utf8 -*-
# Written against the TensorFlow 1.x API (tf.placeholder, tf.Session, tf.contrib)
from sklearn.datasets import load_iris
import tensorflow as tf


def label_convert(Y):
    # Convert integer class ids 0/1/2 into one-hot rows
    l = list()
    for y in Y:
        if y == 0:
            l.append([1, 0, 0])
        elif y == 1:
            l.append([0, 1, 0])
        elif y == 2:
            l.append([0, 0, 1])
    return l


def load_data():
    iris = load_iris()
    X = iris.data
    Y = label_convert(iris.target)
    return (X, Y)


if __name__ == '__main__':
    X, Y = load_data()
    learning_rate = 0.001
    batch_size = 10
    dataset_size = 150
    regulation_lamda = 0.001
    w1 = tf.Variable(tf.random_normal([4, 5], stddev=1, seed=1))
    w2 = tf.Variable(tf.random_normal([5, 3], stddev=1, seed=1))
    x = tf.placeholder(tf.float32, shape=(None, 4), name='x-input')
    y_ = tf.placeholder(tf.float32, shape=(None, 3), name='y-input')
    a = tf.matmul(x, w1)
    y = tf.matmul(a, w2)
    y = tf.sigmoid(y)
    cross_entropy = tf.nn.softmax_cross_entropy_with_logits_v2(labels=y_, logits=y)
    cross_entropy_mean = tf.reduce_mean(cross_entropy)
    loss = cross_entropy_mean + \
        tf.contrib.layers.l2_regularizer(regulation_lamda)(w1) + \
        tf.contrib.layers.l2_regularizer(regulation_lamda)(w2)
    train_step = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss)
    with tf.Session() as sess:
        init_op = tf.global_variables_initializer()
        sess.run(init_op)
        steps = 10000
        for i in range(steps):
            beg = (i * batch_size) % dataset_size
            end = min(beg + batch_size, dataset_size)
            sess.run(train_step, feed_dict={x: X[beg:end], y_: Y[beg:end]})
            if i % 1000 == 0:
                total_cross_entropy = sess.run(
                    cross_entropy_mean,
                    feed_dict={x: X, y_: Y}
                )
                print("After %d training steps, cross entropy on all data is %g" % (i, total_cross_entropy))
        print(sess.run(w1))
        print(sess.run(w2))

Experiment Result

random split cross validation

[Figure: Iris Classification Neural Network experiment result]
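The article does not show its evaluation code. As a hedged illustration of what a random split cross validation involves, here is a pure-Python sketch; the function name `random_split`, the 20% test ratio and the seeds are hypothetical choices, not the author's setup.

```python
import random

def random_split(n_samples, test_ratio=0.2, seed=0):
    """Shuffle the indices 0..n_samples-1 and cut them into disjoint train/test lists."""
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)
    n_test = int(round(n_samples * test_ratio))
    return indices[n_test:], indices[:n_test]

# Repeating the split with different seeds: train on one part, evaluate on the
# held-out part, then average the test scores over the repeats.
splits = [random_split(150, seed=s) for s in range(5)]
train_idx, test_idx = splits[0]
print(len(train_idx), len(test_idx))  # 120 30
```

Shuffling before splitting also matters here because `load_iris` returns the samples sorted by class, so an unshuffled cut would leave whole classes out of the training set.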

Original post: https://www.cnblogs.com/fengyubo/p/9060249.html
