diff --git a/package.json b/package.json index a7444ecc..8f2297b0 100644 --- a/package.json +++ b/package.json @@ -37,6 +37,7 @@ "@vitalets/google-translate-api": "^5.1.0", "aki-api": "^5.2.1", "bombsweeper.js": "^1.0.1", + "buffer-image-size": "^0.6.4", "canvas": "^2.7.0", "cheerio": "^1.0.0-rc.5", "cloc": "^2.7.0", diff --git a/util/Util.js b/util/Util.js index 44e5f49f..fa0d253b 100644 --- a/util/Util.js +++ b/util/Util.js @@ -1,5 +1,6 @@ const crypto = require('crypto'); const tf = require('@tensorflow/tfjs-node'); +const sizeOf = require('buffer-image-size'); const { decode: decodeHTML } = require('html-entities'); const { stripIndents } = require('common-tags'); const { URL } = require('url'); @@ -218,8 +219,14 @@ module.exports = class Util { } static async isImageNSFW(model, image, bool = true) { - const img = await tf.node.decodeImage(new Uint8Array(image), 3); - const predictions = model.classify(image, 1); + const dimensions = sizeOf(image); + const data = { + data: new Uint8Array(image), + width: dimensions.width, + height: dimensions.height + }; + const img = await tf.node.decodeImage(data, 3); + const predictions = await model.classify(image, 1); img.dispose(); return bool ? predictions[0] !== 'Neutral' && predictions[0] !== 'Drawing' : predictions[0]; }