0%

利用迁移学习进行图标识别

迁移学习是指将一个预训练好的模型经过截断和重构后,应用到其他任务中。

JavaScript代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import * as tf from '@tensorflow/tfjs';
import * as tfvis from '@tensorflow/tfjs-vis';
import { getInputs } from "./data";
import { img2x, file2img } from './utils';

// Model-loading configuration.
// URL of the model.json that is handed to tf.loadLayersModel.
const MOBILENET_MODEL_PATH = 'http://127.0.0.1:8080/mobilenet/web_model/model.json';
// Human-readable class names; the predicted class index is looked up in this list.
const BRAND_CLASSES = ['WeChat', 'QQ'];
// Size of the classifier's output layer, derived from the label list so the
// two can never drift apart.
const NUM_CLASSES = BRAND_CLASSES.length;

/**
 * Page entry point: trains a small classifier on top of a truncated,
 * pre-trained model and then exposes `window.predict(file)` for inference.
 *
 * Steps:
 *   1. Load the training images/labels via getInputs() and show the images.
 *   2. Load the pre-trained model and cut it off at layer 'conv_pw_13_relu'.
 *   3. Build a flatten -> dense(relu) -> dense(softmax) head.
 *   4. Run every training image through the truncated model and fit the head
 *      on the resulting features.
 *   5. Install `window.predict`, which is only available once training is done.
 */
window.onload = async () => {
  const { inputs, labels } = await getInputs();

  // Display the training images in a tfjs-vis surface.
  const surface = tfvis.visor().surface({ name: '输入示例', styles: { height: 250 } });
  inputs.forEach((imgEl) => {
    surface.drawArea.appendChild(imgEl);
  });

  // Truncate the model: reuse its inputs, but stop at an intermediate layer
  // so that layer's activations serve as the feature vector.
  const mobilenet = await tf.loadLayersModel(MOBILENET_MODEL_PATH);
  mobilenet.summary();
  const layer = mobilenet.getLayer('conv_pw_13_relu');
  const truncatedMobilenet = tf.model({
    inputs: mobilenet.inputs,
    outputs: layer.output,
  });

  // Two-layer dense head stacked on the truncated model's output.
  const model = tf.sequential();
  model.add(tf.layers.flatten({
    // Drop the leading (batch) dimension of the layer's output shape.
    inputShape: layer.outputShape.slice(1),
  }));
  model.add(tf.layers.dense({
    units: 15,
    activation: 'relu',
  }));
  model.add(tf.layers.dense({
    units: NUM_CLASSES,
    activation: 'softmax',
  }));

  // Loss function and optimizer.
  model.compile({ loss: 'categoricalCrossentropy', optimizer: tf.train.adam() });

  // Pre-compute features for the whole training set with the truncated model.
  // tf.tidy frees the per-image intermediates; only xs and ys survive.
  // NOTE(review): assumes img2x returns a batched tensor per image and that
  // labels are one-hot rows matching NUM_CLASSES — confirm in ./utils and ./data.
  const { xs, ys } = tf.tidy(() => {
    const features = tf.concat(
      inputs.map((imgEl) => truncatedMobilenet.predict(img2x(imgEl))),
    );
    const targets = tf.tensor(labels);
    return { xs: features, ys: targets };
  });

  // Train only the dense head; plot the loss after every epoch.
  try {
    await model.fit(xs, ys, {
      epochs: 200,
      callbacks: tfvis.show.fitCallbacks(
        { name: '训练效果' },
        ['loss'],
        { callbacks: ['onEpochEnd'] },
      ),
    });
  } finally {
    // The training tensors are not needed after fitting; release them
    // explicitly because they were returned out of tf.tidy.
    tf.dispose([xs, ys]);
  }

  /**
   * Classifies an image file and alerts the predicted brand name.
   *
   * @param {*} file - Value accepted by file2img (presumably a File/Blob from
   *   a file input — confirm against ./utils).
   * @returns {Promise<void>}
   */
  window.predict = async (file) => {
    const img = await file2img(file);
    document.body.appendChild(img);

    // Do the whole inference inside tf.tidy and return a plain number, so the
    // prediction and argMax tensors are released on every call.
    const index = tf.tidy(() => {
      const x = img2x(img);
      const features = truncatedMobilenet.predict(x);
      return model.predict(features).argMax(1).dataSync()[0];
    });

    // Defer the blocking alert, presumably so the appended image can render first.
    setTimeout(() => {
      alert(`预测结果:${BRAND_CLASSES[index]}`);
    }, 0);
  };
};

效果图: