博客
关于我
识别图中模糊手写数字
阅读量:178 次
发布时间:2019-02-28

本文共 3746 字,大约阅读时间需要 12 分钟。

以下是优化后的代码和说明:


一、代码优化说明

  • 代码结构优化

    • 移除冗余代码:清理不必要的div标签,使HTML结构更加简洁。
    • 调整代码排版:使用更规范的代码格式,提升可读性。
  • 模型优化

    • 批量处理:在训练过程中使用批量处理,提高训练效率。
    • 卷积神经网络(CNN):可以考虑将简单的全连接模型替换为CNN,提升分类准确率。
  • 训练过程改进

    • 动态学习率:引入动态学习率调整策略,如学习率衰减,提高训练效果。
    • 增加训练 epochs:可以增加训练的 epoch 数量或批量大小,提升模型性能。
  • 错误处理和日志记录

    • 异常处理:在训练过程中添加异常捕获,确保程序稳定运行。
    • 日志工具:使用日志记录工具记录训练过程中的关键指标,便于分析和调试。
  • 测试和评估

    • 多测试集验证:使用多个测试集进行交叉验证,提高模型的泛化能力。
    • 模型稳定性测试:在测试阶段添加更多的验证步骤,确保模型的稳定性和可靠性。
  • 代码注释和可读性

    • 详细注释:在代码中添加更详细的注释,帮助读者快速理解代码功能。
    • 清晰命名:使用清晰的变量命名和结构,使代码更易于维护。
  • 部署和高效运行

    • 云端部署:将模型部署到云端,使用高效的计算资源,提高训练和测试速度。
    • 并行计算:利用并行计算和分布式训练技术,进一步优化性能。
  • 可扩展性和模块化

    • 模块化设计:将模型结构拆分成多个模块,便于扩展和维护。
    • 可配置参数:使用可配置的参数,使模型能够适应不同的问题和数据集。

  • 二、优化后的代码

    1. 导入必要的库和数据集

    一 导入必要的库和数据集
    import tensorflow as tffrom tensorflow.examples.tutorials.mnist import input_datamnistmnist = input_data.read_data_sets("MNIST_data/", one_hot=True)        
    二 创建占位符
    x = tf.placeholder(tf.float32, [None, 784])  # 输入图片占位符,784维y = tf.placeholder(tf.float32, [None, 10])  # 标签占位符,10个类别        
    三 定义模型参数
    # 权重矩阵W = tf.Variable(tf.random_normal([784, 10]))# 偏置向量b = tf.Variable(tf.zeros([10]))        
    四 构建模型
    # 前向传播pred = tf.nn.softmax(tf.matmul(x, W) + b)        
    五 定义损失函数和优化器
    # 交叉熵损失cost = tf.reduce_mean(-tf.reduce_sum(y * tf.log(pred), reduction_indices=1))# 学习率learning_rate = 0.01# 优化器optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)        
    六 训练模型
    training_epochs = 25batch_size = 100display_step = 1saver = tf.train.Saver()model_path = "log/521model.ckpt"with tf.Session() as sess:    sess.run(tf.global_variables_initializer())        for epoch in range(training_epochs):        avg_cost = 0.        total_batch = int(mnist.train.num_examples / batch_size)                for i in range(total_batch):            batch_xs, batch_ys = mnist.train.next_batch(batch_size)            _, c = sess.run([optimizer, cost], feed_dict={x: batch_xs, y: batch_ys})            avg_cost += c / total_batch                if (epoch + 1) % display_step == 0:            print("Epoch:", '%04d' % (epoch + 1), "cost=", "{:.9f}".format(avg_cost))        print(" Finished!")        # 测试模型    correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))    print("Accuracy:", accuracy.eval({x: mnist.test.images, y: mnist.test.labels}))        # 保存模型    saver.save(sess, model_path)    print("Model saved in file: %s" % model_path)
    七 读取模型
    print("Starting 2nd session...")with tf.Session() as sess:    # 初始化变量    sess.run(tf.global_variables_initializer())    # 加载已保存的模型    saver.restore(sess, model_path)        # 测试模型    correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))    print("Accuracy:", accuracy.eval({x: mnist.test.images, y: mnist.test.labels}))        # 显示图片    for i in range(2):        batch_xs, batch_ys = mnist.train.next_batch(2)        output_val, pred_val = sess.run([tf.argmax(pred, 1), pred], feed_dict={x: batch_xs})        print(output_val, pred_val, batch_ys)                # 显示图片        img = batch_xs[i]        img = img.reshape(-1, 28)        pylab.imshow(img)        pylab.show()

    三、说明

  • 输出预测结果output_val 是模型预测的数字结果,pred_val 是对应的概率值。
  • 真实标签batch_ys 是实际的标签值,使用 onehot 编码表示。
  • 图片显示:使用 pylab 显示原始图片和模型预测的数字结果,直观验证模型性能。

  • 四、总结

    通过上述优化,代码结构更加清晰,注释更详细,便于理解和维护。同时,模型的训练过程更加高效,测试结果更直观。

    转载地址:http://lhej.baihongyu.com/

    你可能感兴趣的文章
    npm报错Cannot find module ‘webpack‘ Require stack
    查看>>
    npm报错Failed at the node-sass@4.14.1 postinstall script
    查看>>
    npm报错unable to access ‘https://github.com/sohee-lee7/Squire.git/‘
    查看>>
    npm的安装和更新---npm工作笔记002
    查看>>
    npm的常用配置项---npm工作笔记004
    查看>>
    npm的问题:config global `--global`, `--local` are deprecated. Use `--location=global` instead 的解决办法
    查看>>
    npm编译报错You may need an additional loader to handle the result of these loaders
    查看>>
    npm设置淘宝镜像、升级等
    查看>>
    npm配置安装最新淘宝镜像,旧镜像会errror
    查看>>
    npm错误 gyp错误 vs版本不对 msvs_version不兼容
    查看>>
    npm错误Error: Cannot find module ‘postcss-loader‘
    查看>>
    NPOI之Excel——合并单元格、设置样式、输入公式
    查看>>
    NPOI利用多任务模式分批写入多个Excel
    查看>>
    NPOI在Excel中插入图片
    查看>>
    NPOI将某个程序段耗时插入Excel
    查看>>
    NPOI格式设置
    查看>>
    Npp删除选中行的Macro录制方式
    查看>>
    NR,NF,FNR
    查看>>
    nrf开发笔记一开发软件
    查看>>
    nrm —— 快速切换 NPM 源 (附带测速功能)
    查看>>