diff --git a/.gitignore b/.gitignore index 36fd3ed7..1bfa58fb 100644 --- a/.gitignore +++ b/.gitignore @@ -26,4 +26,3 @@ tf_models/ # In-Development Commands commands/edit-face/danny-devito.js commands/edit-face/emoji-face.js -commands/edit-face/eyes.js diff --git a/Xiao.js b/Xiao.js index a420a670..4f962f55 100644 --- a/Xiao.js +++ b/Xiao.js @@ -210,6 +210,10 @@ client.on('ready', async () => { client.logger.error(`[NSFW MODEL] Failed to load NSFW model\n${err.stack}`); } + // Set up face detection + await client.loadFaceDetector(); + client.logger.info('[FACE DETECTOR] Loaded face detector.'); + // Fetch all members for (const [id, guild] of client.guilds.cache) { // eslint-disable-line no-unused-vars await guild.members.fetch(); diff --git a/commands/edit-face/anime-eyes.js b/commands/edit-face/anime-eyes.js index da38e6c6..5032eea3 100644 --- a/commands/edit-face/anime-eyes.js +++ b/commands/edit-face/anime-eyes.js @@ -1,9 +1,6 @@ const Command = require('../../framework/Command'); const request = require('node-superfetch'); const { createCanvas, loadImage } = require('canvas'); -const tfnode = require('@tensorflow/tfjs-node'); -const faceDetection = require('@tensorflow-models/face-detection'); -const model = faceDetection.SupportedModels.MediaPipeFaceDetector; const path = require('path'); module.exports = class AnimeEyesCommand extends Command { @@ -26,16 +23,13 @@ module.exports = class AnimeEyesCommand extends Command { } ] }); - - this.detector = null; } async run(msg, { image }) { - if (!this.detector) this.detector = await faceDetection.createDetector(model, { runtime: 'tfjs', maxFaces: 10 }); const leftEye = await loadImage(path.join(__dirname, '..', '..', 'assets', 'images', 'anime-eyes', 'left.png')); const rightEye = await loadImage(path.join(__dirname, '..', '..', 'assets', 'images', 'anime-eyes', 'right.png')); const imgData = await request.get(image); - const faces = await this.detect(imgData.body); + const faces = await this.client.detectFaces(imgData.body); if (!faces) return msg.reply('There are no faces in this image.'); if (faces === 'size') return msg.reply('This image is too large.'); const base = await loadImage(imgData.body); @@ -56,15 +50,4 @@ module.exports = class AnimeEyesCommand extends Command { } return msg.say({ files: [{ attachment: canvas.toBuffer(), name: 'anime-eyes.png' }] }); } - - async detect(imgData) { - if (Buffer.byteLength(imgData) >= 4e+6) return 'size'; - tfnode.setBackend('tensorflow'); - const image = tfnode.node.decodeImage(imgData); - tfnode.setBackend('cpu'); - const faces = await this.detector.estimateFaces(image); - tfnode.setBackend('tensorflow'); - if (!faces || !faces.length) return null; - return faces; - } }; diff --git a/commands/edit-face/eyes.js b/commands/edit-face/eyes.js new file mode 100644 index 00000000..5b9e98bf --- /dev/null +++ b/commands/edit-face/eyes.js @@ -0,0 +1,51 @@ +const Command = require('../../framework/Command'); +const request = require('node-superfetch'); +const { createCanvas, loadImage } = require('canvas'); +const path = require('path'); + +module.exports = class EyesCommand extends Command { + constructor(client) { + super(client, { + name: 'eyes', + group: 'edit-face', + memberName: 'eyes', + description: 'Draws emoji eyes onto the faces in an image.', + throttling: { + usages: 1, + duration: 60 + }, + args: [ + { + key: 'image', + prompt: 'What face would you like to scan?', + type: 'image-or-avatar' + } + ] + }); + } + + async run(msg, { image }) { + const eyes = await loadImage(path.join(__dirname, '..', '..', 'assets', 'images', 'eyes.png')); + const imgData = await request.get(image); + const faces = await this.client.detectFaces(imgData); + if (!faces) return msg.reply('There are no faces in this image.'); + if (faces === 'size') return msg.reply('This image is too large.'); + const base = await loadImage(imgData.body); + const canvas = createCanvas(base.width, base.height); + const ctx = canvas.getContext('2d'); + ctx.drawImage(base, 0, 0); + for (const face of faces) { + const eyeWidth = face.box.width / 4; + const eyeHeight = face.box.height / 4; + const leftEyeData = face.keypoints.find(landmark => landmark.name === 'leftEye'); + const rightEyeData = face.keypoints.find(landmark => landmark.name === 'rightEye'); + const leftEyeX = leftEyeData.x - (eyeWidth / 2); + const leftEyeY = leftEyeData.y - (eyeHeight / 2); + const rightEyeX = rightEyeData.x - (eyeWidth / 2); + const rightEyeY = rightEyeData.y - (eyeHeight / 2); + ctx.drawImage(eyes, leftEyeX, leftEyeY, eyeWidth, eyeHeight); + ctx.drawImage(eyes, rightEyeX, rightEyeY, eyeWidth, eyeHeight); + } + return msg.say({ files: [{ attachment: canvas.toBuffer(), name: 'eyes.png' }] }); + } +}; diff --git a/structures/Client.js b/structures/Client.js index a128eaf6..9c79ef20 100644 --- a/structures/Client.js +++ b/structures/Client.js @@ -4,6 +4,9 @@ const { Collection } = require('@discordjs/collection'); const winston = require('winston'); const fontFinder = require('font-finder'); const nsfw = require('nsfwjs'); +const tfnode = require('@tensorflow/tfjs-node'); +const faceDetection = require('@tensorflow-models/face-detection'); +const model = faceDetection.SupportedModels.MediaPipeFaceDetector; const moment = require('moment-timezone'); const fs = require('fs'); const url = require('url'); @@ -38,6 +41,7 @@ module.exports = class XiaoClient extends CommandClient { this.activities = activities; this.adultSiteList = null; this.nsfwModel = null; + this.faceDetector = null; } async loadParseDomain() { @@ -82,6 +86,22 @@ module.exports = class XiaoClient extends CommandClient { return this.nsfwModel; } + async loadFaceDetector() { + this.faceDetector = await faceDetection.createDetector(model, { runtime: 'tfjs', maxFaces: 10 }); + return this.faceDetector; + } + + async detectFaces(imgData) { + if (Buffer.byteLength(imgData) >= 4e+6) return 'size'; + tfnode.setBackend('tensorflow'); + const image = tfnode.node.decodeImage(imgData); + tfnode.setBackend('cpu'); + const faces = await this.faceDetector.estimateFaces(image); + tfnode.setBackend('tensorflow'); + if (!faces || !faces.length) return null; + return faces; + } + fetchReportChannel() { if (!REPORT_CHANNEL_ID) return null; return this.channels.fetch(REPORT_CHANNEL_ID);