forked from CodingTrain/Auto-Encoder-Demo
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathindex.js
More file actions
119 lines (106 loc) · 2.6 KB
/
index.js
File metadata and controls
119 lines (106 loc) · 2.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
// Startup banner. Note: ES module `import` declarations are hoisted, so the
// modules imported below (including @tensorflow/tfjs-node) are loaded before
// this line actually runs.
console.log("Hello Autoencoder 🚂");
import * as tf from "@tensorflow/tfjs-node";
// import canvas from "canvas";
// const { loadImage } = canvas;
import Jimp from "jimp";
import numeral from "numeral";
main().catch((err) => {
  // Without this handler any failure (e.g. a missing data file) would surface
  // only as an unhandled promise rejection; report it and exit non-zero.
  console.error(err);
  process.exitCode = 1;
});

/**
 * Script entry point: build the autoencoder, load the square images, train on
 * the first `trainCount` of them and write reconstructions of the rest.
 *
 * @returns {Promise<void>}
 */
async function main() {
  const totalImages = 550;
  const trainCount = 500;
  const epochs = 250;

  // Build the model
  const autoencoder = buildModel();
  // Load all image data
  const images = await loadImages(totalImages);

  // Train the model on the first `trainCount` images.
  const xTrain = tf.tensor2d(images.slice(0, trainCount));
  try {
    await trainModel(autoencoder, xTrain, epochs);
  } finally {
    // tf tensors hold native memory that is not garbage-collected.
    xTrain.dispose();
  }

  // Test the model on the remaining (held-out) images.
  const xTest = tf.tensor2d(images.slice(trainCount));
  try {
    await generateTests(autoencoder, xTest);
  } finally {
    xTest.dispose();
  }
}
/**
 * Run the autoencoder on `x_test` and write each reconstruction to
 * `output/squareNNN.png` (NNN = zero-padded row index).
 *
 * @param {tf.LayersModel} autoencoder - trained model whose `predict` output
 *   has one flattened image per row.
 * @param {tf.Tensor2D} x_test - input images, one flattened image per row.
 * @param {number} [width=28] - width of each written image in pixels.
 * @param {number} [height=28] - height of each written image in pixels.
 * @returns {Promise<void>} resolves only after every PNG has been written.
 * @throws {Error} if a prediction row does not contain width*height values,
 *   or if Jimp fails to build or write an image.
 */
async function generateTests(autoencoder, x_test, width = 28, height = 28) {
  const output = autoencoder.predict(x_test);
  let newImages;
  try {
    newImages = await output.array();
  } finally {
    // predict() allocates a fresh tensor; release it once copied into JS arrays.
    output.dispose();
  }

  for (const [i, pixels] of newImages.entries()) {
    if (pixels.length !== width * height) {
      throw new Error(
        `prediction ${i} has ${pixels.length} values, expected ${width * height}`
      );
    }
    const num = numeral(i).format("000");
    await writeRgbaImage(
      grayscaleToRgba(pixels),
      width,
      height,
      `output/square${num}.png`
    );
  }
}

/**
 * Convert intensities (nominally in [0, 1]) into an opaque RGBA byte buffer,
 * replicating each intensity into R, G and B with alpha fixed at 255.
 * Values are scaled by 255, floored and clamped to 0..255 so out-of-range
 * inputs saturate instead of wrapping around modulo 256.
 *
 * @param {number[]} pixels - one intensity per pixel.
 * @returns {Buffer} buffer of length pixels.length * 4.
 */
function grayscaleToRgba(pixels) {
  const rgba = Buffer.alloc(pixels.length * 4);
  pixels.forEach((intensity, n) => {
    const val = Math.min(255, Math.max(0, Math.floor(intensity * 255)));
    const base = n * 4;
    rgba[base] = val;
    rgba[base + 1] = val;
    rgba[base + 2] = val;
    rgba[base + 3] = 255;
  });
  return rgba;
}

/**
 * Build a Jimp image from raw RGBA bytes and write it to `path`, adapting
 * Jimp's callback API to a Promise so callers can await completion and errors.
 *
 * @param {Buffer} data - RGBA bytes, 4 per pixel.
 * @param {number} width
 * @param {number} height
 * @param {string} path - destination file; the extension selects the format.
 * @returns {Promise<void>}
 */
function writeRgbaImage(data, width, height, path) {
  return new Promise((resolve, reject) => {
    new Jimp({ data, width, height }, (err, image) => {
      if (err) {
        reject(err);
        return;
      }
      image.write(path, (writeErr) => {
        if (writeErr) {
          reject(writeErr);
        } else {
          resolve();
        }
      });
    });
  });
}
/**
 * Create and compile a fully-connected autoencoder with layer widths
 * 784 → 256 → 128 → 256 → 784. The output width matches the input width
 * so the model can be trained to reproduce its input.
 *
 * Compiled with the Adam optimizer, binary cross-entropy loss and an
 * accuracy metric.
 *
 * @returns {tf.Sequential} the compiled, untrained model.
 */
function buildModel() {
  // Dense layer configs, applied in order.
  const layerSpecs = [
    { units: 256, inputShape: [784], activation: "relu" },
    { units: 128, activation: "relu" }, // narrowest layer
    { units: 256, activation: "sigmoid" },
    { units: 784, activation: "sigmoid" }, // same width as the input
  ];

  const model = tf.sequential();
  for (const spec of layerSpecs) {
    model.add(tf.layers.dense(spec));
  }

  model.compile({
    optimizer: "adam",
    loss: "binaryCrossentropy",
    metrics: ["accuracy"],
  });
  return model;
}
/**
 * Train the autoencoder to reproduce its own input.
 *
 * @param {tf.LayersModel} autoencoder - compiled model exposing `fit`.
 * @param {tf.Tensor} x_train - training data, used as both input and target.
 * @param {number} epochs - number of passes over the training data.
 * @param {number} [batchSize=32] - samples per gradient update.
 * @returns {Promise<tf.History>} the training history reported by `fit`.
 */
async function trainModel(autoencoder, x_train, epochs, batchSize = 32) {
  // Input and target are the same tensor: the model learns to reconstruct it.
  return autoencoder.fit(x_train, x_train, {
    epochs,
    // tfjs expects camelCase `batchSize`; a `batch_size` key is silently ignored.
    batchSize,
    shuffle: true,
    // tfjs verbosity is numeric (0 = silent, 1 = per-epoch logging).
    verbose: 1,
  });
}
/**
 * Load `data/square000.png` … `data/squareNNN.png` (indices 0..total-1) and
 * flatten each into an array of red-channel intensities scaled to [0, 1].
 *
 * @param {number} total - number of images to load.
 * @param {number} [width=28] - required image width in pixels.
 * @param {number} [height=28] - required image height in pixels.
 * @returns {Promise<number[][]>} one flattened image (width*height values)
 *   per file, in index order.
 * @throws {Error} if a file cannot be read or has unexpected dimensions; the
 *   underlying error is attached as `cause`.
 */
async function loadImages(total, width = 28, height = 28) {
  const allImages = [];
  // Read sequentially so at most one file is open at a time.
  for (let i = 0; i < total; i++) {
    const num = numeral(i).format("000");
    const path = `data/square${num}.png`;
    try {
      const img = await Jimp.read(path);
      allImages.push(bitmapToGrayscale(img.bitmap, width, height));
    } catch (err) {
      throw new Error(`failed to load ${path}`, { cause: err });
    }
  }
  return allImages;
}

/**
 * Flatten an RGBA bitmap (4 bytes per pixel, row-major) into red-channel
 * intensities divided by 255.
 *
 * Only the red byte of each pixel is read. This presumes the source images
 * are grayscale (R = G = B); confirm against the dataset.
 *
 * @param {{data: ArrayLike<number>, width: number, height: number}} bitmap
 * @param {number} [width=28] - expected width.
 * @param {number} [height=28] - expected height.
 * @returns {number[]} width*height intensities in [0, 1].
 * @throws {Error} if the bitmap is not exactly width×height; otherwise
 *   missing pixels would become NaN and extra ones would be silently skipped.
 */
function bitmapToGrayscale(bitmap, width = 28, height = 28) {
  if (bitmap.width !== width || bitmap.height !== height) {
    throw new Error(
      `expected a ${width}x${height} image, got ${bitmap.width}x${bitmap.height}`
    );
  }
  return Array.from(
    { length: width * height },
    (_, n) => bitmap.data[n * 4] / 255.0
  );
}