From 059457a61a7ff2b6a185f9cee10da9f7e57f4a2a Mon Sep 17 00:00:00 2001 From: Dragon Fire Date: Wed, 8 May 2024 18:14:08 -0400 Subject: [PATCH] Fix --- structures/Tensorflow.js | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/structures/Tensorflow.js b/structures/Tensorflow.js index 50130c5f..40b70a02 100644 --- a/structures/Tensorflow.js +++ b/structures/Tensorflow.js @@ -73,15 +73,13 @@ module.exports = class Tensorflow { async stylizeImage(image, styleImg) { const imageTensor = await tf.node.decodeImage(image, 3); - const loadedImage = imageTensor.div(tf.scalar(255)).expandDims(); + const loadedImage = imageTensor.toFloat().div(tf.scalar(255)).expandDims(); imageTensor.dispose(); const styleTensor = tf.node.decodeImage(styleImg, 3); - const loadedStyle = styleTensor.div(tf.scalar(255)).expandDims(); styleTensor.dispose(); - const reshapedStyle = loadedStyle.reshape([1, 1, 1, 100]); + const loadedStyle = styleTensor.toFloat().div(tf.scalar(255)).expandDims(); + const stylePrediction = await this.styleModel.predict(loadedStyle); loadedStyle.dispose(); - const stylePrediction = await this.styleModel.predict(reshapedStyle); - reshapedStyle.dispose(); const stylizedImage = await this.transformerModel.predict([loadedImage, stylePrediction.squeeze()]); loadedImage.dispose(); stylePrediction.dispose();