diff --git a/Xiao.js b/Xiao.js index 1afdd9dd..2f0124e7 100644 --- a/Xiao.js +++ b/Xiao.js @@ -235,14 +235,6 @@ client.on('ready', async () => { client.logger.error(`[TIMEZONES] Failed to set timezones\n${err.stack}`); } - // Set up parse-domain - try { - await client.loadParseDomain(); - client.logger.info('[PARSE DOMAIN] parse-domain loaded.'); - } catch (err) { - client.logger.error(`[PARSE DOMAIN] Failed to load parse-domain\n${err.stack}`); - } - // Fetch adult site list try { await client.fetchAdultSiteList(); @@ -253,7 +245,7 @@ client.on('ready', async () => { // Fetch NSFW model try { - await client.loadNSFWModel(); + await client.tensorflow.loadNSFWModel(); client.logger.info('[NSFW MODEL] Loaded NSFW model.'); } catch (err) { client.logger.error(`[NSFW MODEL] Failed to load NSFW model\n${err.stack}`); @@ -261,7 +253,7 @@ client.on('ready', async () => { // Set up face detection try { - await client.loadFaceDetector(); + await client.tensorflow.loadFaceDetector(); client.logger.info('[FACE DETECTOR] Loaded face detector.'); } catch (err) { client.logger.error(`[FACE DETECTOR] Failed to load face detector\n${err.stack}`); diff --git a/commands/analyze/faces.js b/commands/analyze/faces.js index 91801cbc..763cec54 100644 --- a/commands/analyze/faces.js +++ b/commands/analyze/faces.js @@ -25,7 +25,7 @@ module.exports = class FacesCommand extends Command { async run(msg, { image }) { const imgData = await request.get(image); - const faces = await this.client.detectFaces(imgData.body); + const faces = await this.client.tensorflow.detectFaces(imgData.body); if (!faces) return msg.reply('There are no faces in this image.'); if (faces === 'size') return msg.reply('This image is too large.'); const base = await loadImage(imgData.body); diff --git a/commands/analyze/is-it-down.js b/commands/analyze/is-it-down.js index 16124127..b1a7a7ad 100644 --- a/commands/analyze/is-it-down.js +++ b/commands/analyze/is-it-down.js @@ -1,5 +1,11 @@ const Command = require('../../framework/Command'); const request = require('node-superfetch'); +let parseDomain; +let ParseResultType; +import('parse-domain').then(loadedModule => { + parseDomain = loadedModule.parseDomain; + ParseResultType = loadedModule.ParseResultType; +}); module.exports = class IsItDownCommand extends Command { constructor(client) { @@ -26,8 +32,8 @@ module.exports = class IsItDownCommand extends Command { } async run(msg, { url }) { - const { type, domain, topLevelDomains } = this.client.parseDomain(url.hostname); - if (type !== this.client.ParseResultType.Listed) return msg.reply('This domain is not supported.'); + const { type, domain, topLevelDomains } = parseDomain(url.hostname); + if (type !== ParseResultType.Listed) return msg.reply('This domain is not supported.'); const { text } = await request .post('https://www.isitdownrightnow.com/check.php') .query({ domain: `${domain}.${topLevelDomains.join('.')}` }); diff --git a/commands/analyze/nsfw-image.js b/commands/analyze/nsfw-image.js index 628c674d..b6f5a2b8 100644 --- a/commands/analyze/nsfw-image.js +++ b/commands/analyze/nsfw-image.js @@ -1,7 +1,6 @@ const Command = require('../../framework/Command'); const request = require('node-superfetch'); const { stripIndents } = require('common-tags'); -const { isImageNSFW } = require('../../util/Util'); const displayNames = { Drawing: 'SFW (Drawing)', Neutral: 'SFW', @@ -33,7 +32,7 @@ module.exports = class NsfwImageCommand extends Command { async run(msg, { image }) { const { body } = await request.get(image); - const predictions = await isImageNSFW(this.client.nsfwModel, body, false); + const predictions = await this.client.tensorflow.isImageNSFW(body, false); const formatted = predictions.map(result => { const percentage = Math.round(result.probability * 100); return `${percentage}% ${displayNames[result.className]}`; diff --git a/commands/analyze/screenshot.js b/commands/analyze/screenshot.js index 54bea40e..5c802ee3 100644 --- a/commands/analyze/screenshot.js +++ b/commands/analyze/screenshot.js @@ -1,7 +1,7 @@ const Command = require('../../framework/Command'); const { PermissionFlagsBits } = require('discord.js'); const request = require('node-superfetch'); -const { isImageNSFW, isUrlNSFW } = require('../../util/Util'); +const { isUrlNSFW } = require('../../util/Util'); module.exports = class ScreenshotCommand extends Command { constructor(client) { @@ -41,7 +41,7 @@ module.exports = class ScreenshotCommand extends Command { } const { body } = await request.get(`https://image.thum.io/get/width/1920/crop/675/noanimate/${url.href}`); if (!msg.channel.nsfw) { - const aiDetect = await isImageNSFW(this.client.nsfwModel, body); + const aiDetect = await this.client.tensorflow.isImageNSFW(body); if (aiDetect) return msg.reply('This site isn\'t NSFW, but the resulting image was.'); } return msg.say({ files: [{ attachment: body, name: 'screenshot.png' }] }); diff --git a/commands/edit-face/anime-eyes.js b/commands/edit-face/anime-eyes.js index dd8c4cd0..6b77240b 100644 --- a/commands/edit-face/anime-eyes.js +++ b/commands/edit-face/anime-eyes.js @@ -28,7 +28,7 @@ module.exports = class AnimeEyesCommand extends Command { const leftEye = await loadImage(path.join(__dirname, '..', '..', 'assets', 'images', 'anime-eyes', 'left.png')); const rightEye = await loadImage(path.join(__dirname, '..', '..', 'assets', 'images', 'anime-eyes', 'right.png')); const imgData = await request.get(image); - const faces = await this.client.detectFaces(imgData.body); + const faces = await this.client.tensorflow.detectFaces(imgData.body); if (!faces) return msg.reply('There are no faces in this image.'); if (faces === 'size') return msg.reply('This image is too large.'); const base = await loadImage(imgData.body); diff --git a/commands/edit-face/danny-devito.js b/commands/edit-face/danny-devito.js index 8063ffdc..9cb79b6e 100644 --- a/commands/edit-face/danny-devito.js +++ b/commands/edit-face/danny-devito.js @@ -34,7 +34,7 @@ module.exports = class DannyDevitoCommand extends Command { async run(msg, { image }) { const danny = await loadImage(path.join(__dirname, '..', '..', 'assets', 'images', 'danny-devito.png')); const imgData = await request.get(image); - const faces = await this.client.detectFaces(imgData.body); + const faces = await this.client.tensorflow.detectFaces(imgData.body); if (!faces) return msg.reply('There are no faces in this image.'); if (faces === 'size') return msg.reply('This image is too large.'); const base = await loadImage(imgData.body); diff --git a/commands/edit-face/emoji-face.js b/commands/edit-face/emoji-face.js index 906476ae..0a05f8bc 100644 --- a/commands/edit-face/emoji-face.js +++ b/commands/edit-face/emoji-face.js @@ -41,7 +41,7 @@ module.exports = class EmojiFaceCommand extends Command { const emojiData = await request.get(emojiURL); const emojiImg = await loadImage(emojiData.body); const imgData = await request.get(image); - const faces = await this.client.detectFaces(imgData.body); + const faces = await this.client.tensorflow.detectFaces(imgData.body); if (!faces) return msg.reply('There are no faces in this image.'); if (faces === 'size') return msg.reply('This image is too large.'); const base = await loadImage(imgData.body); diff --git a/commands/edit-face/eyes.js b/commands/edit-face/eyes.js index 662b009f..639ac0a2 100644 --- a/commands/edit-face/eyes.js +++ b/commands/edit-face/eyes.js @@ -26,7 +26,7 @@ module.exports = class EyesCommand extends Command { async run(msg, { image }) { const eyes = await loadImage(path.join(__dirname, '..', '..', 'assets', 'images', 'eyes.png')); const imgData = await request.get(image); - const faces = await this.client.detectFaces(imgData.body); + const faces = await this.client.tensorflow.detectFaces(imgData.body); if (!faces) return msg.reply('There are no faces in this image.'); if (faces === 'size') return msg.reply('This image is too large.'); const base = await loadImage(imgData.body); diff --git a/commands/edit-face/shrek.js b/commands/edit-face/shrek.js index e88686b3..0c567b5d 100644 --- a/commands/edit-face/shrek.js +++ b/commands/edit-face/shrek.js @@ -34,7 +34,7 @@ module.exports = class ShrekCommand extends Command { async run(msg, { image }) { const shrek = await loadImage(path.join(__dirname, '..', '..', 'assets', 'images', 'shrek.png')); const imgData = await request.get(image); - const faces = await this.client.detectFaces(imgData.body); + const faces = await this.client.tensorflow.detectFaces(imgData.body); if (!faces) return msg.reply('There are no faces in this image.'); if (faces === 'size') return msg.reply('This image is too large.'); const base = await loadImage(imgData.body); diff --git a/package.json b/package.json index 33c1bdf5..e87d9456 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "xiao", - "version": "147.0.1", + "version": "148.0.0", "description": "Your personal server companion.", "main": "Xiao.js", "private": true, diff --git a/structures/Client.js b/structures/Client.js index a1cd3a9c..f31d5840 100644 --- a/structures/Client.js +++ b/structures/Client.js @@ -3,20 +3,16 @@ const request = require('node-superfetch'); const { Collection } = require('@discordjs/collection'); const winston = require('winston'); const fontFinder = require('font-finder'); -const nsfw = require('nsfwjs'); -const tfnode = require('@tensorflow/tfjs-node'); -const faceDetection = require('@tensorflow-models/face-detection'); -const model = faceDetection.SupportedModels.MediaPipeFaceDetector; const moment = require('moment-timezone'); const fs = require('fs'); -const url = require('url'); const path = require('path'); const Redis = require('./Redis'); +const Tensorflow = require('./Tensorflow'); const Font = require('./Font'); const PhoneManager = require('./phone/PhoneManager'); const TimerManager = require('./remind/TimerManager'); const PokemonStore = require('./pokemon/PokemonStore'); -const activities = require('./activity'); +const activities = require('./Activity'); const { REPORT_CHANNEL_ID, JOIN_LEAVE_CHANNEL_ID } = process.env; module.exports = class XiaoClient extends CommandClient { @@ -37,17 +33,9 @@ module.exports = class XiaoClient extends CommandClient { this.dispatchers = new Map(); this.cleverbots = new Map(); this.phone = new PhoneManager(this); + this.tensorflow = new Tensorflow(this); this.activities = activities; this.adultSiteList = null; - this.nsfwModel = null; - this.faceDetector = null; - } - - async loadParseDomain() { - const parseDomainModule = await import('parse-domain'); - this.parseDomain = parseDomainModule.parseDomain; - this.ParseResultType = parseDomainModule.ParseResultType; - return parseDomainModule; } async registerFontsIn(filepath) { @@ -76,31 +64,6 @@ module.exports = class XiaoClient extends CommandClient { return this.adultSiteList; } - async loadNSFWModel() { - const nsfwModel = await nsfw.load( - `${url.pathToFileURL(path.join(__dirname, '..', 'tf_models', 'nsfw', 'web_model')).href}/`, - { type: 'graph' } - ); - this.nsfwModel = nsfwModel; - return this.nsfwModel; - } - - async loadFaceDetector() { - this.faceDetector = await faceDetection.createDetector(model, { runtime: 'tfjs', maxFaces: 10 }); - return this.faceDetector; - } - - async detectFaces(imgData) { - if (Buffer.byteLength(imgData) >= 4e+6) return 'size'; - tfnode.setBackend('tensorflow'); - const image = tfnode.node.decodeImage(imgData, 3); - tfnode.setBackend('cpu'); - const faces = await this.faceDetector.estimateFaces(image); - tfnode.setBackend('tensorflow'); - if (!faces || !faces.length) return null; - return faces; - } - fetchReportChannel() { if (!REPORT_CHANNEL_ID) return null; return this.channels.fetch(REPORT_CHANNEL_ID); diff --git a/structures/Tensorflow.js b/structures/Tensorflow.js new file mode 100644 index 00000000..17712091 --- /dev/null +++ b/structures/Tensorflow.js @@ -0,0 +1,58 @@ +const tfnode = require('@tensorflow/tfjs-node'); +const nsfw = require('nsfwjs'); +const faceDetection = require('@tensorflow-models/face-detection'); +const url = require('url'); +const path = require('path'); + +module.exports = class Tensorflow { + constructor(client) { + Object.defineProperty(this, 'client', { value: client }); + + this.nsfwModel = null; + this.faceModel = faceDetection.SupportedModels.MediaPipeFaceDetector; + this.faceDetector = null; + } + + async loadNSFWModel() { + const nsfwModel = await nsfw.load( + `${url.pathToFileURL(path.join(__dirname, '..', 'tf_models', 'nsfw', 'web_model')).href}/`, + { type: 'graph' } + ); + this.nsfwModel = nsfwModel; + return this.nsfwModel; + } + + async loadFaceDetector() { + const faceDetector = await faceDetection.createDetector(this.faceModel, { runtime: 'tfjs', maxFaces: 10 }); + this.faceDetector = faceDetector; + return this.faceDetector; + } + + async detectFaces(imgData) { + if (Buffer.byteLength(imgData) >= 4e+6) return 'size'; + tfnode.setBackend('tensorflow'); + const image = tfnode.node.decodeImage(imgData, 3); + tfnode.setBackend('cpu'); + const faces = await this.faceDetector.estimateFaces(image); + tfnode.setBackend('tensorflow'); + image.dispose(); + if (!faces || !faces.length) return null; + return faces; + } + + async isImageNSFW(image, bool = true) { + const img = await tfnode.node.decodeImage(image, 3); + const predictions = await this.nsfwModel.classify(img); + img.dispose(); + if (bool) { + const results = []; + results.push(predictions[0]); + for (const result of predictions) { + if (result.className === predictions[0].className) continue; + if (result.probability >= predictions[0].probability - 0.1) results.push(result); + } + return results.some(result => result.className !== 'Drawing' && result.className !== 'Neutral'); + } + return predictions; + } +}; diff --git a/util/Util.js b/util/Util.js index 747b825c..02d6ea34 100644 --- a/util/Util.js +++ b/util/Util.js @@ -2,7 +2,6 @@ const { ActionRowBuilder, ButtonBuilder, PermissionFlagsBits, ButtonStyle, Compo const crypto = require('crypto'); const request = require('node-superfetch'); const fs = require('fs'); -const tf = require('@tensorflow/tfjs-node'); let parseDomain; let ParseResultType; import('parse-domain').then(loadedModule => { @@ -244,22 +243,6 @@ module.exports = class Util { return str; } - static async isImageNSFW(model, image, bool = true) { - const img = await tf.node.decodeImage(image, 3); - const predictions = await model.classify(img); - img.dispose(); - if (bool) { - const results = []; - results.push(predictions[0]); - for (const result of predictions) { - if (result.className === predictions[0].className) continue; - if (result.probability >= predictions[0].probability - 0.1) results.push(result); - } - return results.some(result => result.className !== 'Drawing' && result.className !== 'Neutral'); - } - return predictions; - } - static async reactIfAble(msg, user, emoji, fallbackEmoji) { const dm = !msg.guild; if (!emoji) emoji = fallbackEmoji;