NSFW Command with Tensorflow

This commit is contained in:
Dragon Fire
2021-03-22 17:23:12 -04:00
parent 9f3cd4d2db
commit d837eec501
6 changed files with 66 additions and 1 deletions
+3
View File
@@ -14,6 +14,9 @@ command-leaderboard.json
command-last-run.json
blacklist.json
# Tensorflow Models
tf_models/
# Tesseract Trained Data
*.traineddata
+7
View File
@@ -148,6 +148,13 @@ client.on('ready', async () => {
client.logger.error(`[ADULT SITES] Failed to fetch list\n${err.stack}`);
}
// Fetch NSFW model
try {
await client.loadNSFWModel();
} catch (err) {
client.logger.error(`[NSFW MODEL] Failed to load NSFW model\n${err.stack}`);
}
// Post bot list stats
await client.postBotsGGStats();
setInterval(() => client.postBotsGGStats(), 1.8e+6);
+36
View File
@@ -0,0 +1,36 @@
const Command = require('../../structures/Command');
const request = require('node-superfetch');
const { isImageNSFW } = require('../../util/Util');
module.exports = class NsfwCommand extends Command {
constructor(client) {
super(client, {
name: 'nsfw',
aliases: ['nsfw-image', 'nsfw-img', 'img-nsfw', 'image-nsfw'],
group: 'analyze',
memberName: 'nsfw',
description: 'Determines if an image is NSFW.',
throttling: {
usages: 1,
duration: 30
},
args: [
{
key: 'image',
prompt: 'What image would you like to test?',
type: 'image-or-avatar'
}
]
});
}
async run(msg, { image }) {
try {
const { body } = await request.get(image);
const prediction = await isImageNSFW(this.client.nsfwModel, body, false);
return msg.reply(`I'm **${prediction.probability}%** sure this image is: **${prediction.className}**.`);
} catch (err) {
return msg.reply(`Oh no, an error occurred: \`${err.message}\`. Try again later!`);
}
}
};
+3 -1
View File
@@ -1,6 +1,6 @@
{
"name": "xiao",
"version": "133.0.0",
"version": "133.1.0",
"description": "Your personal server companion.",
"main": "Xiao.js",
"scripts": {
@@ -33,6 +33,7 @@
"dependencies": {
"@discordjs/collection": "^0.1.6",
"@discordjs/opus": "^0.5.0",
"@tensorflow/tfjs-node": "^3.3.0",
"@vitalets/google-translate-api": "^5.1.0",
"aki-api": "^5.2.1",
"bombsweeper.js": "^1.0.1",
@@ -66,6 +67,7 @@
"moment-duration-format": "^2.3.2",
"moment-timezone": "^0.5.33",
"node-superfetch": "^0.1.11",
"nsfwjs": "^2.3.0",
"pokersolver": "^2.1.4",
"random-js": "^2.1.0",
"rss-parser": "^3.12.0",
+9
View File
@@ -4,7 +4,9 @@ const request = require('node-superfetch');
const Collection = require('@discordjs/collection');
const winston = require('winston');
const fontFinder = require('font-finder');
const nsfw = require('nsfwjs');
const fs = require('fs');
const url = require('url');
const path = require('path');
const Redis = require('./Redis');
const Font = require('./Font');
@@ -45,6 +47,7 @@ module.exports = class XiaoClient extends CommandoClient {
this.activities = activities;
this.leaveMessages = leaveMsgs;
this.adultSiteList = null;
this.nsfwModel = null;
}
async registerFontsIn(filepath) {
@@ -183,6 +186,12 @@ module.exports = class XiaoClient extends CommandoClient {
return this.adultSiteList;
}
async loadNSFWModel() {
const model = await nsfw.load(url.pathToFileURL(path.join(__dirname, '..', 'tf_models', 'nsfw')));
this.nsfwModel = model;
return this.nsfwModel;
}
fetchReportChannel() {
if (!REPORT_CHANNEL_ID) return null;
return this.channels.fetch(REPORT_CHANNEL_ID);
+8
View File
@@ -1,4 +1,5 @@
const crypto = require('crypto');
const tf = require('@tensorflow/tfjs-node');
const { decode: decodeHTML } = require('html-entities');
const { stripIndents } = require('common-tags');
const { URL } = require('url');
@@ -216,6 +217,13 @@ module.exports = class Util {
return str;
}
static isImageNSFW(model, image, bool = true) {
const img = await tf.node.decodeImage(new Uint8Array(image), 3);
const predictions = model.classify(image, 1);
img.dispose();
return bool ? predictions[0] !== 'Neutral' && predictions[0] !== 'Drawing' : predictions[0];
}
static async reactIfAble(msg, user, emoji, fallbackEmoji) {
const dm = !msg.guild;
if (fallbackEmoji && (!dm && !msg.channel.permissionsFor(user).has('USE_EXTERNAL_EMOJIS'))) {