From d837eec501dc55c6a2863f54d5e5256059ac6e73 Mon Sep 17 00:00:00 2001 From: Dragon Fire Date: Mon, 22 Mar 2021 17:23:12 -0400 Subject: [PATCH] NSFW Command with Tensorflow --- .gitignore | 3 +++ Xiao.js | 7 +++++++ commands/analyze/nsfw.js | 36 ++++++++++++++++++++++++++++++++++++ package.json | 4 +++- structures/Client.js | 9 +++++++++ util/Util.js | 8 ++++++++ 6 files changed, 66 insertions(+), 1 deletion(-) create mode 100644 commands/analyze/nsfw.js diff --git a/.gitignore b/.gitignore index 5c23c14a..79b52548 100644 --- a/.gitignore +++ b/.gitignore @@ -14,6 +14,9 @@ command-leaderboard.json command-last-run.json blacklist.json +# Tensorflow Models +tf_models/ + # Tesseract Trained Data *.traineddata diff --git a/Xiao.js b/Xiao.js index 00c2f211..c3058a46 100644 --- a/Xiao.js +++ b/Xiao.js @@ -148,6 +148,13 @@ client.on('ready', async () => { client.logger.error(`[ADULT SITES] Failed to fetch list\n${err.stack}`); } + // Fetch NSFW model + try { + await client.loadNSFWModel(); + } catch (err) { + client.logger.error(`[NSFW MODEL] Failed to load NSFW model\n${err.stack}`); + } + // Post bot list stats await client.postBotsGGStats(); setInterval(() => client.postBotsGGStats(), 1.8e+6); diff --git a/commands/analyze/nsfw.js b/commands/analyze/nsfw.js new file mode 100644 index 00000000..3da02fa5 --- /dev/null +++ b/commands/analyze/nsfw.js @@ -0,0 +1,36 @@ +const Command = require('../../structures/Command'); +const request = require('node-superfetch'); +const { isImageNSFW } = require('../../util/Util'); + +module.exports = class NsfwCommand extends Command { + constructor(client) { + super(client, { + name: 'nsfw', + aliases: ['nsfw-image', 'nsfw-img', 'img-nsfw', 'image-nsfw'], + group: 'analyze', + memberName: 'nsfw', + description: 'Determines if an image is NSFW.', + throttling: { + usages: 1, + duration: 30 + }, + args: [ + { + key: 'image', + prompt: 'What image would you like to test?', + type: 'image-or-avatar' + } + ] + }); + } + + async run(msg, { image }) { + try { + const { body } = await request.get(image); + const prediction = await isImageNSFW(this.client.nsfwModel, body, false); + return msg.reply(`I'm **${prediction.probability}%** sure this image is: **${prediction.className}**.`); + } catch (err) { + return msg.reply(`Oh no, an error occurred: \`${err.message}\`. Try again later!`); + } + } +}; diff --git a/package.json b/package.json index 99ff4b4b..a7444ecc 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "xiao", - "version": "133.0.0", + "version": "133.1.0", "description": "Your personal server companion.", "main": "Xiao.js", "scripts": { @@ -33,6 +33,7 @@ "dependencies": { "@discordjs/collection": "^0.1.6", "@discordjs/opus": "^0.5.0", + "@tensorflow/tfjs-node": "^3.3.0", "@vitalets/google-translate-api": "^5.1.0", "aki-api": "^5.2.1", "bombsweeper.js": "^1.0.1", @@ -66,6 +67,7 @@ "moment-duration-format": "^2.3.2", "moment-timezone": "^0.5.33", "node-superfetch": "^0.1.11", + "nsfwjs": "^2.3.0", "pokersolver": "^2.1.4", "random-js": "^2.1.0", "rss-parser": "^3.12.0", diff --git a/structures/Client.js b/structures/Client.js index 13bb9aee..47a3f388 100644 --- a/structures/Client.js +++ b/structures/Client.js @@ -4,7 +4,9 @@ const request = require('node-superfetch'); const Collection = require('@discordjs/collection'); const winston = require('winston'); const fontFinder = require('font-finder'); +const nsfw = require('nsfwjs'); const fs = require('fs'); +const url = require('url'); const path = require('path'); const Redis = require('./Redis'); const Font = require('./Font'); @@ -45,6 +47,7 @@ module.exports = class XiaoClient extends CommandoClient { this.activities = activities; this.leaveMessages = leaveMsgs; this.adultSiteList = null; + this.nsfwModel = null; } async registerFontsIn(filepath) { @@ -183,6 +186,12 @@ module.exports = class XiaoClient extends CommandoClient { return this.adultSiteList; } + async loadNSFWModel() { + const model = await nsfw.load(url.pathToFileURL(path.join(__dirname, '..', 'tf_models', 'nsfw'))); + this.nsfwModel = model; + return this.nsfwModel; + } + fetchReportChannel() { if (!REPORT_CHANNEL_ID) return null; return this.channels.fetch(REPORT_CHANNEL_ID); diff --git a/util/Util.js b/util/Util.js index 3cb876ea..fbc22fef 100644 --- a/util/Util.js +++ b/util/Util.js @@ -1,4 +1,5 @@ const crypto = require('crypto'); +const tf = require('@tensorflow/tfjs-node'); const { decode: decodeHTML } = require('html-entities'); const { stripIndents } = require('common-tags'); const { URL } = require('url'); @@ -216,6 +217,13 @@ module.exports = class Util { return str; } + static isImageNSFW(model, image, bool = true) { + const img = await tf.node.decodeImage(new Uint8Array(image), 3); + const predictions = model.classify(image, 1); + img.dispose(); + return bool ? predictions[0] !== 'Neutral' && predictions[0] !== 'Drawing' : predictions[0]; + } + static async reactIfAble(msg, user, emoji, fallbackEmoji) { const dm = !msg.guild; if (fallbackEmoji && (!dm && !msg.channel.permissionsFor(user).has('USE_EXTERNAL_EMOJIS'))) {