From 0d8e8fd4d9b3aa30516efd4abaa14a8df08e5696 Mon Sep 17 00:00:00 2001 From: Dragon Fire Date: Mon, 22 Mar 2021 21:43:11 -0400 Subject: [PATCH] More accurate check --- commands/analyze/nsfw.js | 10 +++++++--- util/Util.js | 12 ++++++++++-- 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/commands/analyze/nsfw.js b/commands/analyze/nsfw.js index 3133f47f..d1305323 100644 --- a/commands/analyze/nsfw.js +++ b/commands/analyze/nsfw.js @@ -1,5 +1,6 @@ const Command = require('../../structures/Command'); const request = require('node-superfetch'); +const { stripIndents } = require('common-tags'); const { isImageNSFW } = require('../../util/Util'); module.exports = class NsfwCommand extends Command { @@ -27,9 +28,12 @@ module.exports = class NsfwCommand extends Command { async run(msg, { image }) { try { const { body } = await request.get(image); - const prediction = await isImageNSFW(this.client.nsfwModel, body, false); - const prob = Math.round(prediction.probability * 100); - return msg.reply(`I'm **${prob}%** sure this image is: **${prediction.className}**.`); + const predictions = await isImageNSFW(this.client.nsfwModel, body, false); + const formatted = predictions.map(result => `${Math.round(result.probability * 100)}% ${result.className}`); + return msg.reply(stripIndents` + **This image gives the following results:** + ${formatted.join('\n')} + `); } catch (err) { return msg.reply(`Oh no, an error occurred: \`${err.message}\`. Try again later!`); } diff --git a/util/Util.js b/util/Util.js index d5d14e46..7a7b1b86 100644 --- a/util/Util.js +++ b/util/Util.js @@ -219,9 +219,17 @@ module.exports = class Util { static async isImageNSFW(model, image, bool = true) { const img = await tf.node.decodeImage(image, 3); - const predictions = await model.classify(img, 1); + const predictions = await model.classify(img, 2); img.dispose(); - return bool ? predictions[0].className !== 'Neutral' && predictions[0].className !== 'Drawing' : predictions[0]; + const results = []; + results.push(predictions[0]); + for (const result of predictions) { + if (result.className === predictions[0].className) continue; + if (result.probability >= predictions[0].probability - 5) results.push(result); + } + return bool + ? results.some(result => result.className !== 'Drawing' && result.className !== 'Neutral') + : results; } static async reactIfAble(msg, user, emoji, fallbackEmoji) {