-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdrawn_digit_input.js
117 lines (107 loc) · 3.26 KB
/
drawn_digit_input.js
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
/**
 * Build a small "draw a digit" widget: a square grid painted on a canvas,
 * a "Predict" button, a "Clear" button and a paragraph showing the result.
 *
 * The four elements are appended to `root` in this order: predict button,
 * clear button, output paragraph, canvas.
 *
 * @param {object} options
 * @param {Node} options.root - Container the widget elements are appended to.
 * @param {number} options.width - Canvas bitmap width in pixels.
 * @param {number} options.height - Canvas bitmap height in pixels.
 * @param {number} options.size - Number of cells per grid side (size x size).
 * @param {string} options.cellColor - Fill style used for painted cells.
 * @param {function(number[][]): *} options.predict - Called with the live
 *   grid (rows of 0/1 values) when "Predict" is clicked; its return value is
 *   converted to a string and shown as plain text (never parsed as HTML).
 * @returns {{clear: function(): void, destroy: function(): void}} Handle that
 *   lets the caller reset the grid or tear the widget down (removes the
 *   document-level listeners, cancels a pending frame and detaches the
 *   created elements).
 */
export function createDrawnDigitInput({
  root,
  width,
  height,
  size,
  cellColor,
  predict,
}) {
  // Plus-shaped stamp painted around the cell under the pointer.
  const PLUS_BRUSH = [
    [0, 1, 0],
    [1, 1, 1],
    [0, 1, 0],
  ];

  // Grid state: data[row][col] is 1 for a painted cell, 0 otherwise.
  const data = Array.from({ length: size }, () =>
    Array.from({ length: size }, () => 0)
  );

  // Create the canvas element.
  const canvas = document.createElement("canvas");
  const ctx = canvas.getContext("2d");
  canvas.width = width;
  canvas.height = height;
  canvas.style.border = "1px solid black";

  // Paint the whole grid from `data`.
  function render() {
    ctx.clearRect(0, 0, canvas.width, canvas.height);
    const cellHeight = canvas.height / size;
    const cellWidth = canvas.width / size;
    // The outline style is the same for every cell, so set it once.
    ctx.strokeStyle = "black";
    ctx.lineWidth = 0.25;
    for (let i = 0; i < size; i++) {
      const y = i * cellHeight;
      for (let j = 0; j < size; j++) {
        const x = j * cellWidth;
        ctx.fillStyle = data[i][j] === 1 ? cellColor : "#fff";
        ctx.fillRect(x, y, cellWidth, cellHeight);
        ctx.strokeRect(x, y, cellWidth, cellHeight);
      }
    }
  }

  // Repaint on the next animation frame, at most once per frame. The grid is
  // only redrawn after it changed instead of on every frame forever.
  let frameId = null;
  function scheduleRender() {
    if (frameId !== null) {
      return;
    }
    frameId = requestAnimationFrame(() => {
      frameId = null;
      render();
    });
  }

  // Stamp `brush` onto the grid, centred on the cell under the pointer.
  function handleDraw(event, brush = [[1]]) {
    const cellHeight = canvas.height / size;
    const cellWidth = canvas.width / size;
    // offsetX/offsetY are measured in CSS pixels; rescale them when CSS has
    // resized the canvas so that they map onto bitmap coordinates.
    const scaleX =
      canvas.clientWidth > 0 ? canvas.width / canvas.clientWidth : 1;
    const scaleY =
      canvas.clientHeight > 0 ? canvas.height / canvas.clientHeight : 1;
    const row = Math.floor((event.offsetY * scaleY) / cellHeight);
    const col = Math.floor((event.offsetX * scaleX) / cellWidth);
    for (let bi = 0; bi < brush.length; bi++) {
      for (let bj = 0; bj < brush[bi].length; bj++) {
        if (brush[bi][bj] !== 1) {
          continue;
        }
        const r = row + bi - Math.floor(brush.length / 2);
        const c = col + bj - Math.floor(brush[bi].length / 2);
        // Ignore the parts of the brush that fall outside the grid.
        if (r >= 0 && r < size && c >= 0 && c < size) {
          data[r][c] = 1;
        }
      }
    }
    scheduleRender();
  }

  // Reset every cell to 0.
  function clear() {
    data.forEach((row) => row.fill(0));
    scheduleRender();
  }

  const outputContainer = document.createElement("p");
  outputContainer.textContent =
    '💡 Hint: Draw a digit on the canvas below and click the "Predict" button. See what happens!';

  // Run the predictor and show its result with a timestamp.
  function updateOutput() {
    const prediction = predict(data);
    // Build the nodes explicitly so the prediction is inserted as text and
    // can never be interpreted as markup.
    const strong = document.createElement("b");
    strong.textContent = String(prediction);
    outputContainer.textContent = "";
    outputContainer.append(
      "Prediction: ",
      strong,
      ` - ${new Date().toLocaleTimeString()}`
    );
  }

  // Track the pressed state at document level so that a drag which starts
  // outside the canvas still draws once the pointer enters it.
  let isDrawing = false;
  const startDrawing = () => {
    isDrawing = true;
  };
  const stopDrawing = () => {
    isDrawing = false;
  };
  document.body.addEventListener("mousedown", startDrawing);
  document.body.addEventListener("mouseup", stopDrawing);

  // A press on the canvas paints immediately, without requiring a move.
  canvas.addEventListener("mousedown", (event) => {
    handleDraw(event, PLUS_BRUSH);
  });
  canvas.addEventListener("mousemove", (event) => {
    if (!isDrawing) {
      return;
    }
    // The button may have been released where no mouseup reached the body;
    // `buttons === 0` means nothing is held, so leave drawing mode.
    if (event.buttons === 0) {
      isDrawing = false;
      return;
    }
    handleDraw(event, PLUS_BRUSH);
  });

  const predictButton = document.createElement("button");
  predictButton.classList.add("button");
  predictButton.textContent = "Predict";
  predictButton.addEventListener("click", () => {
    updateOutput();
  });

  const clearButton = document.createElement("button");
  clearButton.classList.add("button");
  clearButton.textContent = "Clear";
  clearButton.addEventListener("click", clear);

  // Append the elements to the root container.
  root.appendChild(predictButton);
  root.appendChild(clearButton);
  root.appendChild(outputContainer);
  root.appendChild(canvas);

  // Initial paint of the empty grid.
  render();

  // Undo everything this function registered or attached.
  function destroy() {
    document.body.removeEventListener("mousedown", startDrawing);
    document.body.removeEventListener("mouseup", stopDrawing);
    if (frameId !== null) {
      cancelAnimationFrame(frameId);
      frameId = null;
    }
    for (const element of [
      predictButton,
      clearButton,
      outputContainer,
      canvas,
    ]) {
      element.remove();
    }
  }

  return { clear, destroy };
}