1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79
| import * as tf from '@tensorflow/tfjs'; import * as tfvis from '@tensorflow/tfjs-vis'; import { getInputs } from "./data"; import { img2x, file2img } from './utils';
/** URL of the MobileNet layers-model manifest; points at a server on the loopback address. */
const MOBILENET_MODEL_PATH = 'http://127.0.0.1:8080/mobilenet/web_model/model.json';

/**
 * Display name for each class index. The predicted arg-max index is used to
 * look up an entry here.
 * NOTE(review): the order presumably has to match the label encoding returned
 * by `getInputs()` — confirm against ./data.
 */
const BRAND_CLASSES = ['WeChat', 'QQ'];

/**
 * Number of units in the softmax output layer. Derived from the label list so
 * that adding or removing a brand cannot leave the two out of sync.
 */
const NUM_CLASSES = BRAND_CLASSES.length;
/**
 * Page entry point.
 *
 * 1. Loads the training images and labels and shows the images in the
 *    tfjs-vis visor.
 * 2. Loads MobileNet and cuts it off at the `conv_pw_13_relu` layer so it can
 *    be used as a fixed feature extractor.
 * 3. Trains a small dense classifier head on the extracted features.
 * 4. Installs `window.predict(file)`, which classifies an image file and
 *    alerts the predicted brand name.
 *
 * Tensors that are only needed temporarily (training data, prediction
 * outputs) are released explicitly, because tensors that escape `tf.tidy`
 * are not freed automatically.
 */
window.onload = async () => {
  const { inputs, labels } = await getInputs();

  // Render every training image into a visor surface for visual inspection.
  const surface = tfvis.visor().surface({ name: '输入示例', styles: { height: 250 } });
  inputs.forEach((imgEl) => {
    surface.drawArea.appendChild(imgEl);
  });

  // Feature extractor: MobileNet from its input up to `conv_pw_13_relu`.
  const mobilenet = await tf.loadLayersModel(MOBILENET_MODEL_PATH);
  mobilenet.summary();
  const layer = mobilenet.getLayer('conv_pw_13_relu');
  const truncatedMobilenet = tf.model({
    inputs: mobilenet.inputs,
    outputs: layer.output,
  });

  // Classifier head: flatten the feature map, one hidden layer, softmax output.
  const model = tf.sequential();
  model.add(tf.layers.flatten({
    // Drop the leading (batch) entry of the truncated layer's output shape.
    inputShape: layer.outputShape.slice(1),
  }));
  model.add(tf.layers.dense({ units: 15, activation: 'relu' }));
  model.add(tf.layers.dense({ units: NUM_CLASSES, activation: 'softmax' }));
  model.compile({ loss: 'categoricalCrossentropy', optimizer: tf.train.adam() });

  // Build the training set inside tidy so the per-image intermediates are
  // freed; only the returned `xs` and `ys` survive the scope.
  const { xs, ys } = tf.tidy(() => {
    // NOTE(review): assumes img2x returns a batched tensor for one image — confirm in ./utils.
    const features = inputs.map((imgEl) => truncatedMobilenet.predict(img2x(imgEl)));
    return { xs: tf.concat(features), ys: tf.tensor(labels) };
  });

  try {
    await model.fit(xs, ys, {
      epochs: 200,
      callbacks: tfvis.show.fitCallbacks(
        { name: '训练效果' },
        ['loss'],
        { callbacks: ['onEpochEnd'] },
      ),
    });
  } finally {
    // The training tensors are not used again after fitting; release them
    // whether or not training succeeded.
    xs.dispose();
    ys.dispose();
  }

  /**
   * Classifies an image file and alerts the predicted brand name.
   *
   * @param {File} file - Image file to classify; the call is a no-op when it is missing.
   * @returns {Promise<void>}
   */
  window.predict = async (file) => {
    if (!file) {
      // Nothing was supplied (e.g. a cancelled file dialog); there is nothing to classify.
      return;
    }

    const img = await file2img(file);
    document.body.appendChild(img);

    // Everything, including the prediction and arg-max tensors, is created
    // inside tidy and only the plain number leaves it, so nothing leaks per call.
    const index = tf.tidy(() => {
      const features = truncatedMobilenet.predict(img2x(img));
      return model.predict(features).argMax(1).dataSync()[0];
    });

    // Defer the blocking alert to a later task (presumably so the appended
    // image can be displayed first).
    setTimeout(() => {
      alert(`预测结果:${BRAND_CLASSES[index]}`);
    }, 0);
  };
};
|