From a18543a5e5856ce35697ecd4ed8df385bf3d9e1e Mon Sep 17 00:00:00 2001 From: hawor Date: Fri, 23 Feb 2024 14:37:06 +0800 Subject: [PATCH] =?UTF-8?q?infer=20bug=20=E4=BF=AE=E5=A4=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- extensions/sd_EasyPhoto/scripts/easyphoto_infer.py | 1 + modules/api/api.py | 9 ++++++++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/extensions/sd_EasyPhoto/scripts/easyphoto_infer.py b/extensions/sd_EasyPhoto/scripts/easyphoto_infer.py index 49023378..581f23d5 100644 --- a/extensions/sd_EasyPhoto/scripts/easyphoto_infer.py +++ b/extensions/sd_EasyPhoto/scripts/easyphoto_infer.py @@ -1732,6 +1732,7 @@ def easyphoto_infer_forward( traceback.print_exc() ep_logger.error(f"Skin Retouching error: {e}") + # TODO: 上采样可能存在问题 if super_resolution: try: ep_logger.info("Start Portrait enhancement.") diff --git a/modules/api/api.py b/modules/api/api.py index 8e6a38a0..a674f04f 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -982,6 +982,10 @@ def invocations(self, req: models.InvocationsRequest): print(f'download template from s3: {req.s3ModeUrl} success.') for image_format in image_formats: + # 判断是不是图片格式文件 + if image_format not in ['*.jpg', '*.jpeg', '*.png', '*.webp']: + continue + print(glob(os.path.join(template_dir, image_format))) img_list.extend(glob(os.path.join(template_dir, image_format))) if len(img_list) == 0: print(f"Input template dir {template_dir} contains no images") @@ -1005,6 +1009,9 @@ def invocations(self, req: models.InvocationsRequest): "selected_template_images": selected_template_images, } outputs = self.easyphoto_infer(payload_infer) + print("Infer results: ", outputs["message"]) + print("Infer results numbers: ", len(outputs["outputs"])) + print("Mode images numbers: ", len(img_list)) if len(outputs["outputs"]) == len(img_list): for idx, img_path_output in enumerate(img_list): image = decode_image_from_base64jpeg(outputs["outputs"][idx]) @@ -1301,4 +1308,4 @@ def easyphoto_infer(self, datas: dict): face_id_outputs_base64 = [] traceback.print_exc() - return {"message": comment, "outputs": outputs, "face_id_outputs": face_id_outputs_base64} \ No newline at end of file + return {"message": comment, "outputs": outputs, "face_id_outputs": face_id_outputs_base64}