diff --git a/src/main.js b/src/main.js index 17d011f..0f449f5 100644 --- a/src/main.js +++ b/src/main.js @@ -17,7 +17,7 @@ const backendloaded = (async () => { try { // dead code elimination should occur here // eslint-disable-next-line camelcase - if (execution_mode === 'userscript') { + if (execution_mode === 'userscript' || execution_mode === 'test') { weightsData = (await import('./model.weights.bin')).default const tfwasmthreadedsimd = (await import('./tfjs-backend-wasm-threaded-simd.wasm')).default const tfwasmsimd = (await import('./tfjs-backend-wasm-simd.wasm')).default @@ -148,7 +148,7 @@ function imageFromCanvas (img, bg, off) { canvas.height = w * scale + pw * 2 canvas.width = th - const ctx = canvas.getContext('2d') + const ctx = canvas.getContext('2d', { willReadFrequently: true }) ctx.fillStyle = 'rgb(238,238,238)' ctx.fillRect(0, 0, canvas.width, canvas.height) @@ -161,6 +161,7 @@ function imageFromCanvas (img, bg, off) { const draw = function (off, adj) { if (adj) { + // stretching might cause interpolation that throws off the model, might need to clean up if (bg) { const border = 4 ctx.drawImage( @@ -254,6 +255,14 @@ async function predict (img, bg, off) { } const image = imageFromCanvas(img, bg, off) + for (let i = 0; i < image.data.length; i += 4) { + if (image.data[i + 0] || + image.data[i + 1] || + image.data[i + 2]) { + image.data[i + 0] = image.data[i + 1] = image.data[i + 2] = 238 + } + } + const tensor = tf.browser .fromPixels(image, 1) .mul(-1 / 238) @@ -379,7 +388,8 @@ async function imageFromUri (uri) { if (uri.startsWith('url("')) { uri = uri.substr(5, uri.length - 7) } - if (!uri.startsWith('data:')) { + // eslint-disable-next-line camelcase + if (execution_mode !== 'test' && !uri.startsWith('data:')) { return null }