A First Lesson in AI: Handwritten Digit Recognition with a Visual Drawing Board (mnist)

大憨熊, submitted on 2020-02-07 04:20:08

This project recognizes handwritten digits drawn on a web drawing board, using a model trained on mnist. Tech stack: Django + js + tensorflow + html + bootstrap + the inspinia framework.

Let's start with what the finished project looks like:
(image: project demo)

1. Training the model

Project structure
(image: project structure)
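The model is built with tensorflow.
About 70 lines of code are enough for training. First, download the mnist dataset.
It is easy to find with a web search. The downloaded files look like this:
(image: mnist dataset files)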
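
Since the screenshot of the dataset files is not reproduced here, the sketch below shows one way to sanity-check a download. It assumes the standard MNIST distribution of four gzip-compressed IDX files placed under `data/mnist_data/` (the directory the training script reads from). The helper names `parse_idx_header` and `read_idx_header` are made up for this illustration and are not part of the project.

```python
import gzip
import struct

# Standard file names of the MNIST distribution (an assumption about what was downloaded).
MNIST_FILES = [
    "train-images-idx3-ubyte.gz",
    "train-labels-idx1-ubyte.gz",
    "t10k-images-idx3-ubyte.gz",
    "t10k-labels-idx1-ubyte.gz",
]


def parse_idx_header(raw):
    """Return (dtype_code, shape) parsed from the first bytes of an IDX file.

    An IDX header is: two zero bytes, one data-type byte (0x08 = unsigned byte),
    one byte holding the number of dimensions, then one big-endian uint32 per dimension.
    """
    zero, dtype_code, ndim = struct.unpack(">HBB", raw[:4])
    if zero != 0:
        raise ValueError("not an IDX file: the first two bytes must be zero")
    shape = struct.unpack(">" + "I" * ndim, raw[4:4 + 4 * ndim])
    return dtype_code, shape


def read_idx_header(path):
    """Open a gzip-compressed IDX file and parse its header."""
    with gzip.open(path, "rb") as f:
        return parse_idx_header(f.read(16))  # 16 bytes cover up to 3 dimensions
```

For example, `read_idx_header("data/mnist_data/train-images-idx3-ubyte.gz")` should report a three-dimensional array of 28x28 images if the file is intact.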
Create a file named mnist_train.py
and add the following code:

# -*-coding:utf-8 -*-
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

learning_rate = 0.001
TRAINING_STEPS = 100000
BATCH_SIZE = 32

def conv_layer(input, in_channel, out_channel):
    # Convolutional layer: 5x5 kernels, stride 1, SAME padding, followed by ReLU
    w = tf.Variable(tf.truncated_normal([5, 5, in_channel, out_channel], stddev=0.1))  # 5x5 kernels mapping in_channel to out_channel
    b = tf.Variable(tf.constant(0.1, shape=[out_channel]))
    conv = tf.nn.conv2d(input, w, strides=[1, 1, 1, 1], padding="SAME")
    act = tf.nn.relu(conv + b)
    return act

def fc_layer(input, size_in, size_out):
    # Fully connected layer (no activation; the caller applies one if needed)
    w = tf.Variable(tf.truncated_normal([size_in, size_out], stddev=0.1))
    b = tf.Variable(tf.constant(0.1, shape=[size_out]))
    fc = tf.matmul(input, w) + b
    return fc

# Convolutional network: two conv + pool stages, then two fully connected layers
def inference(image, keep_prob):
    # conv1
    conv1 = conv_layer(image,1, 32)
    conv1_pool = tf.nn.max_pool(conv1, ksize=[1,2,2,1], strides=[1,2,2,1], padding="SAME")
    # conv2
    conv2= conv_layer(conv1_pool, 32, 64)
    conv2_pool = tf.nn.max_pool(conv2, ksize=[1,2,2,1], strides=[1,2,2,1], padding="SAME")
    conv_flaten = tf.reshape(conv2_pool, shape=[-1, 7*7*64])
    # fc1
    fc1 = tf.nn.relu(fc_layer(conv_flaten, 7*7*64, 1024))
    fc1 = tf.nn.dropout(fc1, keep_prob=keep_prob)
    model_output=tf.add(fc_layer(fc1, 1024, 10),0,name='model_output')
    return model_output

def inputs():
    # Define the input placeholders
    keep_prob = tf.placeholder(tf.float32, name="keep_prob")
    x = tf.placeholder(tf.float32, shape=[None, 784], name="x")
    y = tf.placeholder(tf.float32, shape=[None, 10], name="labels")
    return x, y, keep_prob


def loss(y_pred, y_real):
    # Cross-entropy loss, averaged over the batch
    xent = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
            logits=y_pred, labels=y_real))
    return xent

def train_optimizer(loss, global_step=None):
    # Adam optimizer that minimizes the loss
    train_step = tf.train.AdamOptimizer(learning_rate).minimize(loss, global_step=global_step)
    return train_step

def accuracy(y_pred, y_real):
    # Prediction accuracy
    correct_prediction = tf.equal(tf.argmax(y_pred, 1), tf.argmax(y_real, 1))
    acc = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    return acc

def train(mnist):
    x, y, keep_prob = inputs()
    x_image = tf.reshape(x, [-1, 28, 28, 1])
    global_step = tf.Variable(0, trainable=False)
    logits = inference(x_image, keep_prob=keep_prob)   # build the model
    losses = loss(y_pred=logits, y_real=y)
    train_step = train_optimizer(losses, global_step=global_step)
    acc = accuracy(y_pred=logits, y_real=y)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for i in range(TRAINING_STEPS):
            xs, ys = mnist.train.next_batch(BATCH_SIZE)
            _, loss_value, step = sess.run([train_step, losses, global_step],
                                           feed_dict={x: xs, y: ys, keep_prob:0.5})
            if i % 200 ==0:
                print("After %d training step(s), loss on training batch is %g."
                      % (step, loss_value))
            if i % 1000 == 0:
                valid_acc = sess.run(acc, feed_dict={x: mnist.validation.images, y: mnist.validation.labels, keep_prob:1.})
                print("After %d training step(s), accuracy on validation is %g." % (step, valid_acc))
        test_acc = sess.run(acc, feed_dict={x:mnist.test.images, y:mnist.test.labels, keep_prob:1.})
        print("After %d training step(s), accuracy on test is %g." % (TRAINING_STEPS, test_acc))
        saver=tf.train.Saver()
        saver.save(sess,'my_model/mnist_model.ckpt')
if __name__ == "__main__":
    mnist = input_data.read_data_sets("data/mnist_data/", one_hot=True)
    train(mnist)
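
The `7*7*64` used when flattening in `inference` follows from the layer settings: with `padding="SAME"`, a convolution or pooling op produces an output of size ceil(input / stride). The short sketch below traces the shapes through the network. The function names are made up for this illustration and are not part of mnist_train.py.

```python
import math


def same_padding_out(size, stride):
    """Output length of a conv/pool op with padding="SAME": ceil(size / stride)."""
    return math.ceil(size / stride)


def trace_shapes(height=28, width=28):
    """Follow one image through conv1 -> pool1 -> conv2 -> pool2 as defined in inference()."""
    layers = [
        ("conv1", 1, 32),  # 5x5 conv, stride 1, 32 output channels
        ("pool1", 2, 32),  # 2x2 max pool, stride 2
        ("conv2", 1, 64),  # 5x5 conv, stride 1, 64 output channels
        ("pool2", 2, 64),  # 2x2 max pool, stride 2
    ]
    shapes = [("input", (height, width, 1))]
    h, w = height, width
    for name, stride, channels in layers:
        h, w = same_padding_out(h, stride), same_padding_out(w, stride)
        shapes.append((name, (h, w, channels)))
    return shapes


if __name__ == "__main__":
    for name, shape in trace_shapes():
        print(name, shape)
    h, w, c = trace_shapes()[-1][1]
    print("flattened size:", h * w * c)  # 7 * 7 * 64 = 3136
```

The stride-1 convolutions keep the 28x28 size, and each stride-2 pool halves it (28 to 14 to 7), which is why the first fully connected layer takes 3136 inputs.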

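Run this file to start training.
(image: training log output)
Training for about 50,000 steps should be enough.
For convenience I only trained for about 10,000 steps here, which already gives an accuracy above 90%.

When training finishes, four files are produced. Together these four files are the saved model, so be sure to keep them. The next step is to put the model to use.
(image: saved model files)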
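
The post does not show how the saved files are loaded again, so here is a minimal sketch of one way to do it with the TensorFlow 1.x API used above. It relies on the names given in mnist_train.py (`x`, `keep_prob`, `model_output`) and on the checkpoint prefix passed to `saver.save`. The functions `to_model_input` and `predict_digit` are hypothetical helpers, not the code from the author's repository.

```python
CKPT_PREFIX = "my_model/mnist_model.ckpt"  # same prefix that mnist_train.py passes to saver.save


def to_model_input(pixels):
    """Check that pixels is a flat sequence of 784 values and wrap it as a batch of one."""
    pixels = [float(p) for p in pixels]
    if len(pixels) != 784:
        raise ValueError("expected 784 pixel values, got %d" % len(pixels))
    return [pixels]


def predict_digit(pixels, ckpt_prefix=CKPT_PREFIX):
    """Restore the checkpoint and return the predicted digit (0-9) for one image."""
    import tensorflow as tf  # TensorFlow 1.x assumed; imported here so the helper above works without it

    graph = tf.Graph()
    with graph.as_default():
        # The .meta file stores the graph; restore() loads the trained weights.
        saver = tf.train.import_meta_graph(ckpt_prefix + ".meta")
        with tf.Session(graph=graph) as sess:
            saver.restore(sess, ckpt_prefix)
            x = graph.get_tensor_by_name("x:0")
            keep_prob = graph.get_tensor_by_name("keep_prob:0")
            logits = graph.get_tensor_by_name("model_output:0")
            # keep_prob = 1.0 disables dropout at prediction time.
            scores = sess.run(logits, feed_dict={x: to_model_input(pixels), keep_prob: 1.0})[0]
    return max(range(10), key=lambda digit: scores[digit])
```

Restoring the checkpoint on every call keeps the sketch short. A web application would more likely restore once at startup and reuse the session.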

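2. Using the model in a visual application

Since this is a handwriting model, it needs a drawing board to try it out visually. Here I use Django to build a web application. Creating a Django project was covered in an earlier post and is not repeated here. If you are unfamiliar with it, see https://blog.csdn.net/qq_40947673/article/details/104106573
The project layout is as follows:
(image: Django project layout)
The most important part is the front-end interaction.
Here is the front-end template: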

{% extends "index.html" %}
{% block details %}

    <div class="row  border-bottom white-bg dashboard-header">

        <div class="col-sm-3">
            <h2>Handwritten digit model test</h2>
            <small>mnist model test.</small>
            <ul class="list-group clear-list m-t">
                <li class="list-group-item fist-item">
                    <span class="label label-success">1</span> Draw a digit on the board to the right
                </li>
                <li class="list-group-item">
                    <span class="label label-info">2</span> Click Recognize
                </li>
                <li class="list-group-item">
                    <span class="label label-info">3</span> View the result
                </li>
                <li class="list-group-item">
                    <span class="label label-primary">4</span> Click Clear
                </li>
            </ul>
        </div>

        <div class="col-sm-6">
            <canvas id="canvas" width="500" height="500">
            </canvas>
            <div class="row text-left">
                <div class="col-xs-4">
                    <button onclick="shibie()" class="btn btn-primary  m-t">Recognize</button>
                    <button onclick="qingkong()" class="btn btn-primary  m-t" >Clear</button>
                </div>

            </div>
        </div>

        <div class="col-sm-3">
            <div class="statistic-box">
                <div class="row text-center">

                    <div class="col-lg-6">
                        <h2>Result</h2>

                        <table id="tbl" border="1">
                            <tbody id="body"></tbody>
                        </table>
                    </div>
                </div>
            </div>
        </div>

        <style>
            #canvas {
                background: #fff;
                cursor: crosshair;
                margin-left: 10px;
                margin-top: 10px;
                -webkit-box-shadow: 4px 4px 8px rgba(0, 0, 0, 0.5);
                -moz-box-shadow: 4px 4px 8px rgba(0, 0, 0, 0.5);
                box-shadow: 4px 4px 8px rgba(0, 0, 0, 0.5);
            }


        </style>

        <script>
            var canvas = document.querySelector("canvas");
            var cobj = canvas.getContext("2d");

            var data = [];
            var s = "pen";
            var c = "#000";
            var w = "8";

            function drawGrid(stepX, stepY, color, lineWidth) {
                cobj.beginPath();
                // Build the path for the vertical grid lines
                for (var i = 0.5 + stepX; i < canvas.width; i += stepX) {
                    cobj.moveTo(i, 0);
                    cobj.lineTo(i, canvas.height);
                }
                // Build the path for the horizontal grid lines
                for (var j = 0.5 + stepY; j < canvas.height; j += stepY) {
                    cobj.moveTo(0, j);
                    cobj.lineTo(canvas.width, j);
                }
                // Set the stroke color
                cobj.strokeStyle = color;
                // Set the line width
                cobj.lineWidth = lineWidth;
                // Draw the grid
                cobj.stroke();
                // Reset the path
                cobj.beginPath();
            }

            drawGrid(10, 10, 'lightgray', 0.5);

            canvas.onmousedown = function (e) {
                var ox = e.offsetX;
                var oy = e.offsetY;
                var draw = new Draw(cobj, {
                    color: c,
                    width: w
                });

                cobj.beginPath();
                cobj.moveTo(ox, oy);

                canvas.onmousemove = function (e) {
                    var mx = e.offsetX;
                    var my = e.offsetY;
                    if (s != "eraser") {
                        if (data.length != 0) {

                            cobj.putImageData(data[data.length - 1], 0, 0, 0, 0, 500, 500); // put the last saved snapshot back onto the canvas; the last four arguments are optional

                        }
                    }
                    //            cobj.strokeRect(ox,oy,mx-ox,my-oy);
                    // cobj.beginPath()

                    draw[s](ox, oy, mx, my);
                };
                document.onmouseup = function () {
                    data.push(cobj.getImageData(0, 0, 500, 500)); // save a snapshot of all pixel data in the given canvas region
                    canvas.onmousemove = null;
                    document.onmouseup = null;
                }
            };

            function qingkong() {
                cobj.clearRect(0, 0, 500, 500);
                data = [];
                drawGrid(10, 10, 'lightgray', 0.5);

            }

            class Draw {
                constructor(cobj, option) {
                    this.cobj = cobj;
                    this.color = option.color;
                    this.width = option.width;
                    this.style = option.style;
                }

                init() { // apply the stroke settings
                    this.cobj.strokeStyle = this.color;
                    this.cobj.fillStyle = this.color;
                    this.cobj.lineWidth = this.width;
                }

                pen(ox, oy, mx, my) {
                    this.init();
                    this.cobj.lineTo(mx, my);
                    this.cobj.stroke();
                }
            }
        </script>

        <script>
            function shibie() {
                var img = document.getElementById("canvas").toDataURL("image/png");
                img = img.replace(/^data:image\/(png|jpg);base64,/, "");
                var sendData = {
                    "img": img,
                };
                $.ajax({
                    traditional: true,
                    url: "/mnist/",
                    type: 'get',
                    data: sendData,
                    dataType: "json",
                    success: function (data) {
                        {#alert(data["res"])#}
                        var html = data["res"];
                        var html3 = document.createElement("h1");
                        html3.innerHTML = html;
                        $("#body").append(html3);
                    }
                })
            }
        </script>
    </div>

{% endblock %}
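
The template sends the canvas as a base64 PNG string in the `img` parameter, while the model expects 784 values in [0, 1] with a bright digit on a dark background, as in mnist. The sketch below shows one possible preprocessing step on the server side. It is not the author's implementation: the function names are made up, and it assumes the PNG has already been decoded (for example with an imaging library) into a square grid of grayscale values with transparent pixels composited onto white, since the canvas pixels are transparent wherever nothing was drawn.

```python
import base64


def decode_canvas_payload(img_b64):
    """Decode the base64 string from the "img" parameter into raw PNG bytes."""
    return base64.b64decode(img_b64)


def canvas_to_mnist(gray, out_size=28):
    """Shrink a square grayscale grid (0 = black, 255 = white) to out_size x out_size.

    Each output pixel is the mean of one block of input pixels. The result is
    inverted and scaled to [0, 1] (dark strokes become values near 1) and
    returned as a flat, row-major list, matching the shape of placeholder x.
    """
    n = len(gray)
    if n < out_size or any(len(row) != n for row in gray):
        raise ValueError("expected a square grid of at least %d x %d" % (out_size, out_size))
    result = []
    for r in range(out_size):
        r0, r1 = r * n // out_size, (r + 1) * n // out_size
        for c in range(out_size):
            c0, c1 = c * n // out_size, (c + 1) * n // out_size
            block = [gray[i][j] for i in range(r0, r1) for j in range(c0, c1)]
            mean = sum(block) / len(block)
            result.append(1.0 - mean / 255.0)
    return result
```

The flat list returned by `canvas_to_mnist` can be fed to the model, and the view behind `/mnist/` would then return JSON with the predicted digit under the `res` key, which is what the `success` callback reads.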

Final result
(image: final result)

There are too many details to cover one by one here. The full code is on git: https://github.com/kulinbin/mnist_web
