diff --git a/structures/Tensorflow.js b/structures/Tensorflow.js index 8006a0fa..7234737e 100644 --- a/structures/Tensorflow.js +++ b/structures/Tensorflow.js @@ -73,8 +73,14 @@ module.exports = class Tensorflow { async stylizeImage(image, styleImg) { const imageTensor = await tf.node.decodeImage(image, 3); - const loadedImage = imageTensor.toFloat().div(tf.scalar(255)).expandDims(); + const [originalHeight, originalWidth] = imageTensor.shape.slice(0, 2); + const desiredWidth = 400; + const aspectRatio = originalWidth / originalHeight; + const newHeight = Math.round(desiredWidth / aspectRatio); + const resizedImage = tf.image.resizeBilinear(imageTensor, [newHeight, desiredWidth]); imageTensor.dispose(); + const loadedImage = resizedImage.toFloat().div(tf.scalar(255)).expandDims(); + resizedImage.dispose(); const styleTensor = tf.node.decodeImage(styleImg, 3); const loadedStyle = styleTensor.toFloat().div(tf.scalar(255)).expandDims(); styleTensor.dispose();