本文共 3746 字,大约阅读时间需要 12 分钟。
以下是优化后的代码和说明:
代码结构优化:
div标签,使HTML结构更加简洁。模型优化:
训练过程改进:
错误处理和日志记录:
测试和评估:
代码注释和可读性:
部署和高效运行:
可扩展性和模块化:
一 导入必要的库和数据集import tensorflow as tffrom tensorflow.examples.tutorials.mnist import input_datamnistmnist = input_data.read_data_sets("MNIST_data/", one_hot=True)二 创建占位符x = tf.placeholder(tf.float32, [None, 784]) # 输入图片占位符,784维y = tf.placeholder(tf.float32, [None, 10]) # 标签占位符,10个类别三 定义模型参数# 权重矩阵W = tf.Variable(tf.random_normal([784, 10]))# 偏置向量b = tf.Variable(tf.zeros([10]))四 构建模型# 前向传播pred = tf.nn.softmax(tf.matmul(x, W) + b)五 定义损失函数和优化器# 交叉熵损失cost = tf.reduce_mean(-tf.reduce_sum(y * tf.log(pred), reduction_indices=1))# 学习率learning_rate = 0.01# 优化器optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)六 训练模型training_epochs = 25batch_size = 100display_step = 1saver = tf.train.Saver()model_path = "log/521model.ckpt"with tf.Session() as sess: sess.run(tf.global_variables_initializer()) for epoch in range(training_epochs): avg_cost = 0. total_batch = int(mnist.train.num_examples / batch_size) for i in range(total_batch): batch_xs, batch_ys = mnist.train.next_batch(batch_size) _, c = sess.run([optimizer, cost], feed_dict={x: batch_xs, y: batch_ys}) avg_cost += c / total_batch if (epoch + 1) % display_step == 0: print("Epoch:", '%04d' % (epoch + 1), "cost=", "{:.9f}".format(avg_cost)) print(" Finished!") # 测试模型 correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1)) accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) print("Accuracy:", accuracy.eval({x: mnist.test.images, y: mnist.test.labels})) # 保存模型 saver.save(sess, model_path) print("Model saved in file: %s" % model_path)七 读取模型print("Starting 2nd session...")with tf.Session() as sess: # 初始化变量 sess.run(tf.global_variables_initializer()) # 加载已保存的模型 saver.restore(sess, model_path) # 测试模型 correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1)) accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) print("Accuracy:", accuracy.eval({x: mnist.test.images, y: mnist.test.labels})) # 显示图片 for i in range(2): batch_xs, batch_ys = mnist.train.next_batch(2) output_val, pred_val = sess.run([tf.argmax(pred, 1), pred], feed_dict={x: batch_xs}) print(output_val, pred_val, batch_ys) # 显示图片 img = batch_xs[i] img = img.reshape(-1, 28) pylab.imshow(img) pylab.show()
output_val 是模型预测的数字结果,pred_val 是对应的概率值。batch_ys 是实际的标签值,使用 onehot 编码表示。pylab 显示原始图片和模型预测的数字结果,直观验证模型性能。通过上述优化,代码结构更加清晰,注释更详细,便于理解和维护。同时,模型的训练过程更加高效,测试结果更直观。
转载地址:http://lhej.baihongyu.com/