mirror of
https://github.com/arthur-pbty/xiao.git
synced 2026-06-03 23:36:43 +02:00
Rewrite how tensorflow works
This commit is contained in:
+3
-40
@@ -3,20 +3,16 @@ const request = require('node-superfetch');
|
||||
const { Collection } = require('@discordjs/collection');
|
||||
const winston = require('winston');
|
||||
const fontFinder = require('font-finder');
|
||||
const nsfw = require('nsfwjs');
|
||||
const tfnode = require('@tensorflow/tfjs-node');
|
||||
const faceDetection = require('@tensorflow-models/face-detection');
|
||||
const model = faceDetection.SupportedModels.MediaPipeFaceDetector;
|
||||
const moment = require('moment-timezone');
|
||||
const fs = require('fs');
|
||||
const url = require('url');
|
||||
const path = require('path');
|
||||
const Redis = require('./Redis');
|
||||
const Tensorflow = require('./Tensorflow');
|
||||
const Font = require('./Font');
|
||||
const PhoneManager = require('./phone/PhoneManager');
|
||||
const TimerManager = require('./remind/TimerManager');
|
||||
const PokemonStore = require('./pokemon/PokemonStore');
|
||||
const activities = require('./activity');
|
||||
const activities = require('./Activity');
|
||||
const { REPORT_CHANNEL_ID, JOIN_LEAVE_CHANNEL_ID } = process.env;
|
||||
|
||||
module.exports = class XiaoClient extends CommandClient {
|
||||
@@ -37,17 +33,9 @@ module.exports = class XiaoClient extends CommandClient {
|
||||
this.dispatchers = new Map();
|
||||
this.cleverbots = new Map();
|
||||
this.phone = new PhoneManager(this);
|
||||
this.tensorflow = new Tensorflow(this);
|
||||
this.activities = activities;
|
||||
this.adultSiteList = null;
|
||||
this.nsfwModel = null;
|
||||
this.faceDetector = null;
|
||||
}
|
||||
|
||||
async loadParseDomain() {
|
||||
const parseDomainModule = await import('parse-domain');
|
||||
this.parseDomain = parseDomainModule.parseDomain;
|
||||
this.ParseResultType = parseDomainModule.ParseResultType;
|
||||
return parseDomainModule;
|
||||
}
|
||||
|
||||
async registerFontsIn(filepath) {
|
||||
@@ -76,31 +64,6 @@ module.exports = class XiaoClient extends CommandClient {
|
||||
return this.adultSiteList;
|
||||
}
|
||||
|
||||
async loadNSFWModel() {
|
||||
const nsfwModel = await nsfw.load(
|
||||
`${url.pathToFileURL(path.join(__dirname, '..', 'tf_models', 'nsfw', 'web_model')).href}/`,
|
||||
{ type: 'graph' }
|
||||
);
|
||||
this.nsfwModel = nsfwModel;
|
||||
return this.nsfwModel;
|
||||
}
|
||||
|
||||
async loadFaceDetector() {
|
||||
this.faceDetector = await faceDetection.createDetector(model, { runtime: 'tfjs', maxFaces: 10 });
|
||||
return this.faceDetector;
|
||||
}
|
||||
|
||||
async detectFaces(imgData) {
|
||||
if (Buffer.byteLength(imgData) >= 4e+6) return 'size';
|
||||
tfnode.setBackend('tensorflow');
|
||||
const image = tfnode.node.decodeImage(imgData, 3);
|
||||
tfnode.setBackend('cpu');
|
||||
const faces = await this.faceDetector.estimateFaces(image);
|
||||
tfnode.setBackend('tensorflow');
|
||||
if (!faces || !faces.length) return null;
|
||||
return faces;
|
||||
}
|
||||
|
||||
fetchReportChannel() {
|
||||
if (!REPORT_CHANNEL_ID) return null;
|
||||
return this.channels.fetch(REPORT_CHANNEL_ID);
|
||||
|
||||
@@ -0,0 +1,58 @@
|
||||
const tfnode = require('@tensorflow/tfjs-node');
|
||||
const nsfw = require('nsfwjs');
|
||||
const faceDetection = require('@tensorflow-models/face-detection');
|
||||
const url = require('url');
|
||||
const path = require('path');
|
||||
|
||||
module.exports = class Tensorflow {
|
||||
constructor(client) {
|
||||
Object.defineProperty(this, 'client', { value: client });
|
||||
|
||||
this.nsfwModel = null;
|
||||
this.faceModel = faceDetection.SupportedModels.MediaPipeFaceDetector;
|
||||
this.faceDetector = null;
|
||||
}
|
||||
|
||||
async loadNSFWModel() {
|
||||
const nsfwModel = await nsfw.load(
|
||||
`${url.pathToFileURL(path.join(__dirname, '..', 'tf_models', 'nsfw', 'web_model')).href}/`,
|
||||
{ type: 'graph' }
|
||||
);
|
||||
this.nsfwModel = nsfwModel;
|
||||
return this.nsfwModel;
|
||||
}
|
||||
|
||||
async loadFaceDetector() {
|
||||
const faceDetector = await faceDetection.createDetector(this.faceModel, { runtime: 'tfjs', maxFaces: 10 });
|
||||
this.faceDetector = faceDetector;
|
||||
return this.faceDetector;
|
||||
}
|
||||
|
||||
async detectFaces(imgData) {
|
||||
if (Buffer.byteLength(imgData) >= 4e+6) return 'size';
|
||||
tfnode.setBackend('tensorflow');
|
||||
const image = tfnode.node.decodeImage(imgData, 3);
|
||||
tfnode.setBackend('cpu');
|
||||
const faces = await this.faceDetector.estimateFaces(image);
|
||||
tfnode.setBackend('tensorflow');
|
||||
image.dispose();
|
||||
if (!faces || !faces.length) return null;
|
||||
return faces;
|
||||
}
|
||||
|
||||
async isImageNSFW(image, bool = true) {
|
||||
const img = await tfnode.node.decodeImage(image, 3);
|
||||
const predictions = await this.nsfwModel.classify(img);
|
||||
img.dispose();
|
||||
if (bool) {
|
||||
const results = [];
|
||||
results.push(predictions[0]);
|
||||
for (const result of predictions) {
|
||||
if (result.className === predictions[0].className) continue;
|
||||
if (result.probability >= predictions[0].probability - 0.1) results.push(result);
|
||||
}
|
||||
return results.some(result => result.className !== 'Drawing' && result.className !== 'Neutral');
|
||||
}
|
||||
return predictions;
|
||||
}
|
||||
};
|
||||
Reference in New Issue
Block a user