0%

手写数字识别

利用卷积神经网络+canvas实现了手写数字识别

JavaScript 代码

import * as tf from '@tensorflow/tfjs';
import * as tfvis from '@tensorflow/tfjs-vis';
import { MnistData } from './data';

/** Number of sample digits previewed in the visor before training starts. */
const EXAMPLE_COUNT = 20;
/** Side length, in pixels, of the square single-channel images fed to the network. */
const IMAGE_SIZE = 28;
/** Batch sizes used for fitting and for validation. */
const TRAIN_BATCH_SIZE = 5000;
const VALIDATION_BATCH_SIZE = 200;

/**
 * Draws the first EXAMPLE_COUNT images of a batch into a tfjs-vis surface.
 *
 * @param {{xs: tf.Tensor}} examples - Batch whose `xs` rows are flattened
 *   IMAGE_SIZE * IMAGE_SIZE images.
 * @returns {Promise<void>} Resolves once every preview canvas has been drawn.
 */
async function showExamples(examples) {
  const surface = tfvis.visor().surface({ name: '输入示例:' });

  for (let i = 0; i < EXAMPLE_COUNT; i += 1) {
    const imageTensor = tf.tidy(() =>
      examples.xs
        .slice([i, 0], [1, IMAGE_SIZE * IMAGE_SIZE])
        .reshape([IMAGE_SIZE, IMAGE_SIZE, 1]));

    const preview = document.createElement('canvas');
    preview.width = IMAGE_SIZE;
    preview.height = IMAGE_SIZE;
    preview.style = 'margin: 4px';

    try {
      await tf.browser.toPixels(imageTensor, preview);
    } finally {
      // The tensor escaped tf.tidy as its return value, so it must be released by hand.
      imageTensor.dispose();
    }
    surface.drawArea.appendChild(preview);
  }
}

/**
 * Builds and compiles the convolutional network:
 * conv(5x5, 8) -> maxPool -> conv(5x5, 16) -> maxPool -> flatten -> dense(10, softmax).
 *
 * @returns {tf.Sequential} The compiled, untrained model.
 */
function buildModel() {
  const model = tf.sequential();

  model.add(tf.layers.conv2d({
    inputShape: [IMAGE_SIZE, IMAGE_SIZE, 1],
    kernelSize: 5,
    filters: 8,
    strides: 1,
    activation: 'relu',
    kernelInitializer: 'varianceScaling',
  }));
  model.add(tf.layers.maxPool2d({ poolSize: [2, 2], strides: [2, 2] }));
  model.add(tf.layers.conv2d({
    kernelSize: 5,
    filters: 16,
    strides: 1,
    activation: 'relu',
    kernelInitializer: 'varianceScaling',
  }));
  model.add(tf.layers.maxPool2d({ poolSize: [2, 2], strides: [2, 2] }));
  model.add(tf.layers.flatten());
  model.add(tf.layers.dense({
    units: 10,
    activation: 'softmax',
    kernelInitializer: 'varianceScaling',
  }));

  // Loss function and optimizer.
  model.compile({
    loss: 'categoricalCrossentropy',
    optimizer: tf.train.adam(),
    metrics: ['accuracy'],
  });

  return model;
}

/**
 * Pulls one batch from the data source and reshapes the images for conv2d input.
 *
 * NOTE(review): the training batch is also drawn with `nextTestBatch`, exactly
 * as in the original code. `./data` is not visible from here; if it offers a
 * dedicated training-batch method, the training call site should probably use
 * it — confirm against the data module.
 *
 * @param {MnistData} data - Loaded data source.
 * @param {number} size - Number of samples to fetch.
 * @returns {[tf.Tensor, tf.Tensor]} `[images, labels]`; the caller owns both
 *   tensors and must dispose them.
 */
function loadBatch(data, size) {
  return tf.tidy(() => {
    const batch = data.nextTestBatch(size);
    return [
      batch.xs.reshape([size, IMAGE_SIZE, IMAGE_SIZE, 1]),
      batch.labels,
    ];
  });
}

/**
 * Fits the model while plotting loss and accuracy curves in the visor.
 *
 * @param {tf.Sequential} model - Compiled model to train in place.
 * @param {MnistData} data - Loaded data source.
 * @returns {Promise<void>} Resolves when all epochs have finished.
 */
async function trainModel(model, data) {
  const [trainXs, trainYs] = loadBatch(data, TRAIN_BATCH_SIZE);
  const [testXs, testYs] = loadBatch(data, VALIDATION_BATCH_SIZE);

  try {
    await model.fit(trainXs, trainYs, {
      validationData: [testXs, testYs],
      epochs: 20,
      callbacks: tfvis.show.fitCallbacks(
        { name: '训练效果' },
        ['loss', 'val_loss', 'acc', 'val_acc'],
        { callbacks: ['onEpochEnd'] },
      ),
    });
  } finally {
    // Nothing reads these tensors after fitting, so free their memory.
    tf.dispose([trainXs, trainYs, testXs, testYs]);
  }
}

/**
 * Wires up the drawing board and exposes `window.clear` / `window.predict`,
 * which the page's buttons call.
 *
 * @param {tf.Sequential} model - Trained model used for prediction.
 * @throws {Error} If the page contains no <canvas> element.
 */
function setupDrawingBoard(model) {
  // querySelector returns the first canvas in document order. By this point the
  // example previews have also been appended as canvases.
  // NOTE(review): confirm the drawing canvas precedes the visor in the DOM.
  const board = document.querySelector('canvas');
  if (!board) {
    throw new Error('Drawing canvas not found: the page must contain a <canvas> element.');
  }
  const ctx = board.getContext('2d');

  // Paint a white square under the cursor while the primary button is held.
  board.addEventListener('mousemove', (e) => {
    if (e.buttons === 1) {
      ctx.fillStyle = 'rgb(255, 255, 255)';
      ctx.fillRect(e.offsetX, e.offsetY, 20, 20);
    }
  });

  // Fill the whole board with black, whatever size the canvas has.
  window.clear = () => {
    ctx.fillStyle = 'rgb(0, 0, 0)';
    ctx.fillRect(0, 0, board.width, board.height);
  };

  window.clear();

  // Downscale the drawing to the network's input size, keep one colour channel,
  // scale to [0, 1] and report the most probable class.
  window.predict = () => {
    // tf.tidy releases every intermediate tensor; only the plain number escapes.
    const digit = tf.tidy(() => {
      const input = tf.image
        .resizeBilinear(tf.browser.fromPixels(board), [IMAGE_SIZE, IMAGE_SIZE], true)
        .slice([0, 0, 0], [IMAGE_SIZE, IMAGE_SIZE, 1])
        .toFloat()
        .div(255)
        .reshape([1, IMAGE_SIZE, IMAGE_SIZE, 1]);
      return model.predict(input).argMax(1).dataSync()[0];
    });
    alert(`预测结果为 ${digit}`);
  };
}

/**
 * Page entry point: loads the data, previews some samples, trains the network,
 * then enables the drawing board for interactive prediction.
 */
window.onload = async () => {
  const data = new MnistData();
  await data.load();

  const examples = data.nextTestBatch(EXAMPLE_COUNT);
  try {
    await showExamples(examples);
  } finally {
    // The preview batch is not used again.
    tf.dispose(examples);
  }

  const model = buildModel();
  await trainModel(model, data);
  setupDrawingBoard(model);
};

HTML 代码

<!-- Drawing board: hold the left mouse button and move to draw a digit. -->
<canvas width="300" height="300" style="border: 2px solid #666;"></canvas>
<br>
<!-- window.clear and window.predict are registered by the page script after training completes. -->
<button onclick="window.clear();" style="margin: 4px;">清除</button>
<button onclick="window.predict();" style="margin: 4px;">预测</button>