mirror of
https://github.com/arthur-pbty/xiao.git
synced 2026-06-06 06:10:49 +02:00
NSFW Command with Tensorflow
This commit is contained in:
@@ -14,6 +14,9 @@ command-leaderboard.json
|
||||
command-last-run.json
|
||||
blacklist.json
|
||||
|
||||
# Tensorflow Models
|
||||
tf_models/
|
||||
|
||||
# Tesseract Trained Data
|
||||
*.traineddata
|
||||
|
||||
|
||||
@@ -148,6 +148,13 @@ client.on('ready', async () => {
|
||||
client.logger.error(`[ADULT SITES] Failed to fetch list\n${err.stack}`);
|
||||
}
|
||||
|
||||
// Fetch NSFW model
|
||||
try {
|
||||
await client.loadNSFWModel();
|
||||
} catch (err) {
|
||||
client.logger.error(`[NSFW MODEL] Failed to load NSFW model\n${err.stack}`);
|
||||
}
|
||||
|
||||
// Post bot list stats
|
||||
await client.postBotsGGStats();
|
||||
setInterval(() => client.postBotsGGStats(), 1.8e+6);
|
||||
|
||||
@@ -0,0 +1,36 @@
|
||||
const Command = require('../../structures/Command');
|
||||
const request = require('node-superfetch');
|
||||
const { isImageNSFW } = require('../../util/Util');
|
||||
|
||||
module.exports = class NsfwCommand extends Command {
|
||||
constructor(client) {
|
||||
super(client, {
|
||||
name: 'nsfw',
|
||||
aliases: ['nsfw-image', 'nsfw-img', 'img-nsfw', 'image-nsfw'],
|
||||
group: 'analyze',
|
||||
memberName: 'nsfw',
|
||||
description: 'Determines if an image is NSFW.',
|
||||
throttling: {
|
||||
usages: 1,
|
||||
duration: 30
|
||||
},
|
||||
args: [
|
||||
{
|
||||
key: 'image',
|
||||
prompt: 'What image would you like to test?',
|
||||
type: 'image-or-avatar'
|
||||
}
|
||||
]
|
||||
});
|
||||
}
|
||||
|
||||
async run(msg, { image }) {
|
||||
try {
|
||||
const { body } = await request.get(image);
|
||||
const prediction = await isImageNSFW(this.client.nsfwModel, body, false);
|
||||
return msg.reply(`I'm **${prediction.probability}%** sure this image is: **${prediction.className}**.`);
|
||||
} catch (err) {
|
||||
return msg.reply(`Oh no, an error occurred: \`${err.message}\`. Try again later!`);
|
||||
}
|
||||
}
|
||||
};
|
||||
+3
-1
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "xiao",
|
||||
"version": "133.0.0",
|
||||
"version": "133.1.0",
|
||||
"description": "Your personal server companion.",
|
||||
"main": "Xiao.js",
|
||||
"scripts": {
|
||||
@@ -33,6 +33,7 @@
|
||||
"dependencies": {
|
||||
"@discordjs/collection": "^0.1.6",
|
||||
"@discordjs/opus": "^0.5.0",
|
||||
"@tensorflow/tfjs-node": "^3.3.0",
|
||||
"@vitalets/google-translate-api": "^5.1.0",
|
||||
"aki-api": "^5.2.1",
|
||||
"bombsweeper.js": "^1.0.1",
|
||||
@@ -66,6 +67,7 @@
|
||||
"moment-duration-format": "^2.3.2",
|
||||
"moment-timezone": "^0.5.33",
|
||||
"node-superfetch": "^0.1.11",
|
||||
"nsfwjs": "^2.3.0",
|
||||
"pokersolver": "^2.1.4",
|
||||
"random-js": "^2.1.0",
|
||||
"rss-parser": "^3.12.0",
|
||||
|
||||
@@ -4,7 +4,9 @@ const request = require('node-superfetch');
|
||||
const Collection = require('@discordjs/collection');
|
||||
const winston = require('winston');
|
||||
const fontFinder = require('font-finder');
|
||||
const nsfw = require('nsfwjs');
|
||||
const fs = require('fs');
|
||||
const url = require('url');
|
||||
const path = require('path');
|
||||
const Redis = require('./Redis');
|
||||
const Font = require('./Font');
|
||||
@@ -45,6 +47,7 @@ module.exports = class XiaoClient extends CommandoClient {
|
||||
this.activities = activities;
|
||||
this.leaveMessages = leaveMsgs;
|
||||
this.adultSiteList = null;
|
||||
this.nsfwModel = null;
|
||||
}
|
||||
|
||||
async registerFontsIn(filepath) {
|
||||
@@ -183,6 +186,12 @@ module.exports = class XiaoClient extends CommandoClient {
|
||||
return this.adultSiteList;
|
||||
}
|
||||
|
||||
async loadNSFWModel() {
|
||||
const model = await nsfw.load(url.pathToFileURL(path.join(__dirname, '..', 'tf_models', 'nsfw')));
|
||||
this.nsfwModel = model;
|
||||
return this.nsfwModel;
|
||||
}
|
||||
|
||||
fetchReportChannel() {
|
||||
if (!REPORT_CHANNEL_ID) return null;
|
||||
return this.channels.fetch(REPORT_CHANNEL_ID);
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
const crypto = require('crypto');
|
||||
const tf = require('@tensorflow/tfjs-node');
|
||||
const { decode: decodeHTML } = require('html-entities');
|
||||
const { stripIndents } = require('common-tags');
|
||||
const { URL } = require('url');
|
||||
@@ -216,6 +217,13 @@ module.exports = class Util {
|
||||
return str;
|
||||
}
|
||||
|
||||
static isImageNSFW(model, image, bool = true) {
|
||||
const img = await tf.node.decodeImage(new Uint8Array(image), 3);
|
||||
const predictions = model.classify(image, 1);
|
||||
img.dispose();
|
||||
return bool ? predictions[0] !== 'Neutral' && predictions[0] !== 'Drawing' : predictions[0];
|
||||
}
|
||||
|
||||
static async reactIfAble(msg, user, emoji, fallbackEmoji) {
|
||||
const dm = !msg.guild;
|
||||
if (fallbackEmoji && (!dm && !msg.channel.permissionsFor(user).has('USE_EXTERNAL_EMOJIS'))) {
|
||||
|
||||
Reference in New Issue
Block a user