diff --git a/extensions/sd_EasyPhoto/scripts/easyphoto_down.py b/extensions/sd_EasyPhoto/scripts/easyphoto_down.py index f3878ac3..e919b3a0 100644 --- a/extensions/sd_EasyPhoto/scripts/easyphoto_down.py +++ b/extensions/sd_EasyPhoto/scripts/easyphoto_down.py @@ -487,7 +487,7 @@ def check_files_exists_and_download(check_hash, download_mode="base"): if exist_flag: continue - ep_logger.info(f"Start Downloading: {url}") + # ep_logger.info(f"Start Downloading: {url}") os.makedirs(os.path.dirname(filename[0]), exist_ok=True) urldownload_progressbar(url, filename[0]) diff --git a/extensions/sd_EasyPhoto/scripts/easyphoto_infer.py b/extensions/sd_EasyPhoto/scripts/easyphoto_infer.py index 581f23d5..e7e35167 100644 --- a/extensions/sd_EasyPhoto/scripts/easyphoto_infer.py +++ b/extensions/sd_EasyPhoto/scripts/easyphoto_infer.py @@ -1,5 +1,6 @@ import copy import glob +import base64 import math import os import traceback @@ -416,6 +417,7 @@ def inpaint( @switch_sd_model_vae() @cleanup_decorator() def easyphoto_infer_forward( + user_id, sd_model_checkpoint, selected_template_images, init_image, @@ -485,7 +487,7 @@ def easyphoto_infer_forward( valid_user_id_num += 1 if len(user_ids) == last_user_id_none_num: - ep_logger.error("Please choose a user id.") + # ep_logger.error("Please choose a user id.") return "Please choose a user id.", [], [] # check & download weights of basemodel/controlnet+annotator/VAE/face_skin/buffalo/validation_template @@ -501,7 +503,7 @@ def easyphoto_infer_forward( # check the checkpoint_type of sd_model_checkpoint checkpoint_type = get_checkpoint_type(sd_model_checkpoint) if checkpoint_type == 2: - ep_logger.error("EasyPhoto does not support the SD2 checkpoint.") + # ep_logger.error("EasyPhoto does not support the SD2 checkpoint.") return "EasyPhoto does not support the SD2 checkpoint.", [], [] sdxl_pipeline_flag = True if checkpoint_type == 3 else False @@ -546,7 +548,7 @@ def easyphoto_infer_forward( error_info = "The type of the stable diffusion model {} ({}) and the user id {} ({}) does not match.".format( sd_model_checkpoint, checkpoint_type_name, user_id, lora_type_name ) - ep_logger.error(error_info) + # ep_logger.error(error_info) return error_info, [], [] loractl_flag = False @@ -555,7 +557,7 @@ def easyphoto_infer_forward( error_info = "The type of the stable diffusion model {} and attribute edit sliders ({}) does not match.".format( sd_model_checkpoint, additional_prompt ) - ep_logger.error(error_info) + # ep_logger.error(error_info) return error_info, [], [] # download all sliders here. check_files_exists_and_download(check_hash.get("sliders", True), download_mode="sliders") @@ -566,23 +568,23 @@ def easyphoto_infer_forward( controlnet_version = get_controlnet_version() major, minor, patch = map(int, controlnet_version.split(".")) if major == 0 and minor == 0 and patch == 0: - ep_logger.error("Please install sd-webui-controlnet from https://github.com/Mikubill/sd-webui-controlnet.") + # ep_logger.error("Please install sd-webui-controlnet from https://github.com/Mikubill/sd-webui-controlnet.") return "Please install sd-webui-controlnet from https://github.com/Mikubill/sd-webui-controlnet.", [], [] if ipa_control: if major < 1 or minor < 1 or patch < 417: - ep_logger.error("To use IP-Adapter Control, please upgrade sd-webui-controlnet to the latest version.") + # ep_logger.error("To use IP-Adapter Control, please upgrade sd-webui-controlnet to the latest version.") return "To use IP-Adapter Control, please upgrade sd-webui-controlnet to the latest version.", [], [] # check the number of controlnets max_control_net_unit_count = 3 if not ipa_control else 4 control_net_unit_count = shared.opts.data.get("control_net_unit_count", 3) - ep_logger.info("ControlNet unit number: {}".format(control_net_unit_count)) + # ep_logger.info("ControlNet unit number: {}".format(control_net_unit_count)) if control_net_unit_count < max_control_net_unit_count: error_info = ( "Please go to Settings/ControlNet and at least set {} for " "Multi-ControlNet: ControlNet unit number (requires restart).".format(max_control_net_unit_count) ) - ep_logger.error(error_info) + # ep_logger.error(error_info) return error_info, [], [] if ipa_control: @@ -598,7 +600,7 @@ def easyphoto_infer_forward( valid_user_id_num += 1 if valid_user_id_num > 1: - ep_logger.error("EasyPhoto does not support IP-Adapter Control with multiple user ids currently.") + # ep_logger.error("EasyPhoto does not support IP-Adapter Control with multiple user ids currently.") return "EasyPhoto does not support IP-Adapter Control with multiple user ids currently.", [], [] if ipa_control and valid_user_id_num != valid_ipa_image_path_num: ep_logger.warning( @@ -607,7 +609,7 @@ def easyphoto_infer_forward( ) if not display_score: display_score = True - ep_logger.warning("Display score is forced to be true when IP-Adapter Control is enabled.") + # ep_logger.warning("Display score is forced to be true when IP-Adapter Control is enabled.") if lcm_accelerate: lcm_lora_name_and_weight = "lcm_lora_sdxl:0.40" if sdxl_pipeline_flag else "lcm_lora_sd15:0.80" @@ -662,7 +664,7 @@ def easyphoto_infer_forward( except Exception: torch.cuda.empty_cache() traceback.print_exc() - ep_logger.error("Please choose or upload a template.") + # ep_logger.error("Please choose or upload a template.") return "Please choose or upload a template.", [], [] # create modelscope model @@ -678,7 +680,7 @@ def easyphoto_infer_forward( except Exception as e: torch.cuda.empty_cache() traceback.print_exc() - ep_logger.error(f"Skin Retouching model load error. Error Info: {e}") + # ep_logger.error(f"Skin Retouching model load error. Error Info: {e}") if portrait_enhancement is None or old_super_resolution_method != super_resolution_method: try: if super_resolution_method == "gpen": @@ -693,7 +695,7 @@ def easyphoto_infer_forward( except Exception as e: torch.cuda.empty_cache() traceback.print_exc() - ep_logger.error(f"Portrait Enhancement model load error. Error Info: {e}") + # ep_logger.error(f"Portrait Enhancement model load error. Error Info: {e}") # To save the GPU memory, create the face recognition model for computing FaceID if the user intend to show it. if display_score and face_recognition is None: @@ -710,7 +712,7 @@ def easyphoto_infer_forward( except Exception as e: torch.cuda.empty_cache() traceback.print_exc() - ep_logger.error(f"MakeUp Transfer model load error. Error Info: {e}") + # ep_logger.error(f"MakeUp Transfer model load error. Error Info: {e}") # This is to increase the fault tolerance of the code. # If the code exits abnormally, it may cause the model to not function properly on the CPU @@ -747,7 +749,7 @@ def easyphoto_infer_forward( if lcm_accelerate: input_prompt_without_lora += f", " - ep_logger.info("Start templates and user_ids preprocess.") + # ep_logger.info("Start templates and user_ids preprocess.") if tabs == 3: reload_sd_model_vae(prompt_generate_sd_model_checkpoint, prompt_generate_vae) @@ -772,10 +774,10 @@ def easyphoto_infer_forward( # scene lora path scene_lora_model_path = os.path.join(models_path, "Lora", f"{scene_id}.safetensors") if not os.path.exists(scene_lora_model_path): - ep_logger.error("Please check scene lora is exist or not.") + # ep_logger.error("Please check scene lora is exist or not.") return "Please check scene lora is exist or not.", [], [] if not check_scene_valid(f"{scene_id}.safetensors", models_path): - ep_logger.error("Please use the lora trained by ep.") + # ep_logger.error("Please use the lora trained by ep.") return "Please use the lora trained by ep.", [], [] # get lora scene prompt @@ -800,7 +802,7 @@ def easyphoto_infer_forward( last_scene_lora_prompt_low_weight += f", " # text to image with scene lora - ep_logger.info(f"Text to Image with prompt: {last_scene_lora_prompt_high_weight} and lora: {scene_lora_model_path}") + # ep_logger.info(f"Text to Image with prompt: {last_scene_lora_prompt_high_weight} and lora: {scene_lora_model_path}") template_images = txt2img( controlnet_pairs, @@ -814,7 +816,7 @@ def easyphoto_infer_forward( seed=seed, sampler="Euler a", ) - ep_logger.info(f"Hire Fix with prompt: {last_scene_lora_prompt_low_weight} and lora: {scene_lora_model_path}") + # ep_logger.info(f"Hire Fix with prompt: {last_scene_lora_prompt_low_weight} and lora: {scene_lora_model_path}") template_images = inpaint( template_images[0], None, @@ -839,7 +841,7 @@ def easyphoto_infer_forward( # text to image for template if lcm_accelerate: text_to_image_input_prompt += f", " - ep_logger.info(f"Text to Image with prompt: {text_to_image_input_prompt}") + # ep_logger.info(f"Text to Image with prompt: {text_to_image_input_prompt}") template_images = txt2img( controlnet_pairs, @@ -895,7 +897,7 @@ def easyphoto_infer_forward( retinaface_detection, ipa_image, 1.05, "crop" ) if len(_ipa_retinaface_boxes) == 0: - ep_logger.error("No face is detected in the uploaded image prompt.") + # ep_logger.error("No face is detected in the uploaded image prompt.") return "Please upload a image prompt with face.", [], [] if len(_ipa_retinaface_boxes) > 1: ep_logger.warning( @@ -951,7 +953,7 @@ def easyphoto_infer_forward( retinaface_detection, ipa_image, 1, "crop" ) if len(_ipa_retinaface_boxes) == 0: - ep_logger.error("No face is detected in the uploaded image prompt.") + # ep_logger.error("No face is detected in the uploaded image prompt.") return "Please upload a image prompt with face.", [], [] if len(_ipa_retinaface_boxes) > 1: ep_logger.warning( @@ -1015,7 +1017,7 @@ def easyphoto_infer_forward( ipa_only_weight : {str(ipa_only_weight)} ipa_only_image_path : {str(ipa_only_image_path)} """ - ep_logger.info(template_idx_info) + # ep_logger.info(template_idx_info) try: # open the template image if tabs == 0 or tabs == 2: @@ -1025,7 +1027,7 @@ def easyphoto_infer_forward( template_face_safe_boxes, _, _ = call_face_crop(retinaface_detection, template_image, multi_user_safecrop_ratio, "crop") if len(template_face_safe_boxes) == 0: - ep_logger.error("Please upload a template with face.") + # ep_logger.error("Please upload a template with face.") return "Please upload a template with face.", [], [] template_detected_facenum = len(template_face_safe_boxes) @@ -1101,10 +1103,10 @@ def easyphoto_infer_forward( target_area = 1024 * 1024 ratio = math.sqrt(target_area / (input_image.width * input_image.height)) new_size = (int(input_image.width * ratio), int(input_image.height * ratio)) - ep_logger.info("Start resize image from {} to {}.".format(input_image.size, new_size)) + # ep_logger.info("Start resize image from {} to {}.".format(input_image.size, new_size)) else: input_short_size = 512.0 - ep_logger.info("Start Image resize to {}.".format(input_short_size)) + # ep_logger.info("Start Image resize to {}.".format(input_short_size)) short_side = min(input_image.width, input_image.height) resize = float(short_side / input_short_size) new_size = (int(input_image.width // resize), int(input_image.height // resize)) @@ -1118,7 +1120,7 @@ def easyphoto_infer_forward( input_image = input_image.resize([new_width, new_height], Image.Resampling.LANCZOS) # Detect the box where the face of the template image is located and obtain its corresponding small mask - ep_logger.info("Start face detect.") + # ep_logger.info("Start face detect.") input_image_retinaface_boxes, input_image_retinaface_keypoints, input_masks = call_face_crop( retinaface_detection, input_image, 1.1, "template" ) @@ -1275,7 +1277,7 @@ def easyphoto_infer_forward( replaced_input_image = input_image # First diffusion, facial reconstruction - ep_logger.info("Start First diffusion.") + # ep_logger.info("Start First diffusion.") if not face_shape_match: if not sdxl_pipeline_flag: controlnet_pairs = [ @@ -1355,7 +1357,7 @@ def easyphoto_infer_forward( if color_shift_middle: # apply color shift - ep_logger.info("Start color shift middle.") + # ep_logger.info("Start color shift middle.") first_diffusion_output_image_uint8 = np.uint8(np.array(first_diffusion_output_image)) # crop image first first_diffusion_output_image_crop = Image.fromarray( @@ -1391,7 +1393,7 @@ def easyphoto_infer_forward( # Second diffusion if roop_images[index] is not None and apply_face_fusion_after: # Fusion of facial photos with user photos - ep_logger.info("Start second face fusion.") + # ep_logger.info("Start second face fusion.") fusion_image = image_face_fusion(dict(template=first_diffusion_output_image, user=roop_images[index]))[ OutputKeys.OUTPUT_IMG ] # swap_face(target_img=output_image, source_img=roop_image, model="inswapper_128.onnx", upscale_options=UpscaleOptions()) @@ -1436,7 +1438,7 @@ def easyphoto_infer_forward( if enable_second_diffusion: # Add mouth_mask to avoid some fault lips, close if you dont need if need_mouth_fix: - ep_logger.info("Start mouth detect.") + # ep_logger.info("Start mouth detect.") mouth_mask, face_mask = face_skin( input_image, retinaface_detection, [[4, 5, 12, 13], [1, 2, 3, 4, 5, 10, 11, 12, 13]] ) @@ -1454,7 +1456,7 @@ def easyphoto_infer_forward( face_mask = face_mask.resize([m_w, m_h]) input_mask = Image.fromarray(np.uint8(np.clip(np.float32(face_mask) + np.float32(mouth_mask), 0, 255))) - ep_logger.info("Start Second diffusion.") + # ep_logger.info("Start Second diffusion.") if not sdxl_pipeline_flag: controlnet_pairs = [["canny", fusion_image, 1.00], ["tile", fusion_image, 1.00]] if ipa_control: @@ -1482,7 +1484,7 @@ def easyphoto_infer_forward( # use original template face area to shift generated face color at last if color_shift_last: - ep_logger.info("Start color shift last.") + # ep_logger.info("Start color shift last.") # scale box rescale_retinaface_box = [int(i * default_hr_scale) for i in input_image_retinaface_box] second_diffusion_output_image_uint8 = np.uint8(np.array(second_diffusion_output_image)) @@ -1553,7 +1555,7 @@ def easyphoto_infer_forward( # If it is a large template for cutting, paste the reconstructed image back if crop_face_preprocess: - ep_logger.info("Start paste crop image to origin template.") + # ep_logger.info("Start paste crop image to origin template.") origin_loop_template_image = np.array(copy.deepcopy(loop_template_image)) x1, y1, x2, y2 = loop_template_crop_safe_box @@ -1593,7 +1595,7 @@ def easyphoto_infer_forward( loop_output_image = Image.fromarray(loop_output_image) if min(len(template_face_safe_boxes), len(user_ids) - last_user_id_none_num) > 1: - ep_logger.info("Start paste crop image to origin template in multi people.") + # ep_logger.info("Start paste crop image to origin template in multi people.") template_face_safe_box = template_face_safe_boxes[index] output_image_mask = np.zeros_like(np.array(output_image)) output_image_mask[ @@ -1610,7 +1612,7 @@ def easyphoto_infer_forward( try: if min(len(template_face_safe_boxes), len(user_ids) - last_user_id_none_num) > 1 or background_restore: - ep_logger.info("Start Third diffusion for background.") + # ep_logger.info("Start Third diffusion for background.") output_image = Image.fromarray(np.uint8(output_image)) # When reconstructing the entire background, use smaller denoise values with larger diffusion_steps to prevent discordant scenes and image collapse. denoising_strength = background_restore_denoising_strength if background_restore else 0.3 @@ -1716,13 +1718,13 @@ def easyphoto_infer_forward( except Exception as e: torch.cuda.empty_cache() traceback.print_exc() - ep_logger.error(f"Background Restore Failed, Please check the ratio of height and width in template. Error Info: {e}") + # ep_logger.error(f"Background Restore Failed, Please check the ratio of height and width in template. Error Info: {e}") return f"Background Restore Failed, Please check the ratio of height and width in template. Error Info: {e}", outputs, [] if total_processed_person != 0: if skin_retouching_bool: try: - ep_logger.info("Start Skin Retouching.") + # ep_logger.info("Start Skin Retouching.") # Skin Retouching is performed here. output_image = Image.fromarray( cv2.cvtColor(skin_retouching(output_image)[OutputKeys.OUTPUT_IMG], cv2.COLOR_BGR2RGB) @@ -1730,12 +1732,12 @@ def easyphoto_infer_forward( except Exception as e: torch.cuda.empty_cache() traceback.print_exc() - ep_logger.error(f"Skin Retouching error: {e}") + # ep_logger.error(f"Skin Retouching error: {e}") # TODO: 上采样可能存在问题 if super_resolution: try: - ep_logger.info("Start Portrait enhancement.") + # ep_logger.info("Start Portrait enhancement.") h, w, c = np.shape(np.array(output_image)) # Super-resolution is performed here. output_image_sr = Image.fromarray( @@ -1758,13 +1760,32 @@ def easyphoto_infer_forward( except Exception as e: torch.cuda.empty_cache() traceback.print_exc() - ep_logger.error(f"Portrait enhancement error: {e}") + print(f"Portrait enhancement error: {e}") + # ep_logger.error(f"Portrait enhancement error: {e}") else: output_image = template_image + + # Save the output image + output_path = f'./outputs_easyphoto/{user_ids[0]}/' + if output_path is not None: + # 如果文件夹不存在就创建它 + if not os.path.exists(output_path): + os.makedirs(output_path) + print(f"Save template {str(template_idx + 1)} to S3.") + image = decode_image_from_base64jpeg(output_image) + output_img_path = os.path.join(os.path.join(output_path), + f"{user_ids[0]}_" + + str(template_idx + 1)) + cv2.imwrite(output_img_path, image) + if shared.upload_image(output_img_path, user_id, user_ids[0]): + print(f"Template {str(template_idx + 1)} Success.") + else: + print(f"Template {str(template_idx + 1)} Failed.") outputs.append(output_image) - if loractl_flag: - outputs.append(lora_weight_image) + # if loractl_flag: + # outputs.append(lora_weight_image) + save_image( output_image, easyphoto_outpath_samples, @@ -1784,7 +1805,7 @@ def easyphoto_infer_forward( except Exception as e: torch.cuda.empty_cache() traceback.print_exc() - ep_logger.error(f"Template {str(template_idx + 1)} error: Error info is {e}, skip it.") + # ep_logger.error(f"Template {str(template_idx + 1)} error: Error info is {e}, skip it.") if loop_message != "": loop_message += "\n" @@ -2880,3 +2901,22 @@ def easyphoto_video_infer_forward( loop_message += f"Template {str(template_idx + 1)} error: Error info is {e}." return loop_message, output_video, output_gif, outputs + + +def post_single_image(image_path, user_id, unique_id): + bucket, key = shared.get_bucket_and_key(shared.generated_lora_s3uri) + if key.endswith('/'): + key = key[:-1] + key += "/" + user_id + shared.s3_client.put_object( + Body=open(image_path, 'rb'), + Bucket=bucket, + Key=f'{key}/{unique_id}.safetensors' + ) + + +def decode_image_from_base64jpeg(base64_image): + image_bytes = base64.b64decode(base64_image) + np_arr = np.frombuffer(image_bytes, np.uint8) + image = cv2.imdecode(np_arr, cv2.IMREAD_COLOR) + return image \ No newline at end of file diff --git a/extensions/sd_EasyPhoto/scripts/easyphoto_train.py b/extensions/sd_EasyPhoto/scripts/easyphoto_train.py index 828c44fe..fcd94019 100644 --- a/extensions/sd_EasyPhoto/scripts/easyphoto_train.py +++ b/extensions/sd_EasyPhoto/scripts/easyphoto_train.py @@ -62,10 +62,10 @@ def easyphoto_train_forward( print(f'lora training v2: {unique_id} start...') if unique_id == "" or unique_id is None: - ep_logger.error("User id cannot be set to empty.") + # ep_logger.error("User id cannot be set to empty.") return "User id cannot be set to empty." if unique_id == "none": - ep_logger.error("User id cannot be set to none.") + # ep_logger.error("User id cannot be set to none.") return "User id cannot be set to none." ids = [] @@ -76,15 +76,15 @@ def easyphoto_train_forward( ids.append(_id) ids = sorted(ids) if unique_id in ids: - ep_logger.error("User id non-repeatability.") + # ep_logger.error("User id non-repeatability.") return "User id non-repeatability." if len(instance_images) == 0: - ep_logger.error("Please upload training photos.") + # ep_logger.error("Please upload training photos.") return "Please upload training photos." if int(rank) < int(network_alpha): - ep_logger.error("The network alpha {} must not exceed rank {}. " "It will result in an unintended LoRA.".format(network_alpha, rank)) + # ep_logger.error("The network alpha {} must not exceed rank {}. " "It will result in an unintended LoRA.".format(network_alpha, rank)) return "The network alpha {} must not exceed rank {}. " "It will result in an unintended LoRA.".format(network_alpha, rank) # TODO: 将训练所需资源从 S3 下载到本地(当前方式为从网络下载) @@ -95,7 +95,7 @@ def easyphoto_train_forward( checkpoint_type = get_checkpoint_type(sd_model_checkpoint) if checkpoint_type == 2: - ep_logger.error("EasyPhoto does not support the SD2 checkpoint: {}.".format(sd_model_checkpoint)) + # ep_logger.error("EasyPhoto does not support the SD2 checkpoint: {}.".format(sd_model_checkpoint)) return "EasyPhoto does not support the SD2 checkpoint: {}.".format(sd_model_checkpoint) sdxl_pipeline_flag = True if checkpoint_type == 3 else False @@ -105,21 +105,21 @@ def easyphoto_train_forward( # check if user want to train Scene Lora train_scene_lora_bool = True if train_mode_choose == "Train Scene Lora" else False - if train_scene_lora_bool and float(crop_ratio) < 1: - ep_logger.warning("The crop ratio {} is smaller than 1. Use original photos to train the scene LoRA.".format(crop_ratio)) + # if train_scene_lora_bool and float(crop_ratio) < 1: + # ep_logger.warning("The crop ratio {} is smaller than 1. Use original photos to train the scene LoRA.".format(crop_ratio)) cache_outpath_samples = scene_id_outpath_samples if train_scene_lora_bool else user_id_outpath_samples # Check conflicted arguments in SDXL training. if sdxl_pipeline_flag: if enable_rl: - ep_logger.error("EasyPhoto does not support RL with the SDXL checkpoint: {}.".format(sd_model_checkpoint)) + # ep_logger.error("EasyPhoto does not support RL with the SDXL checkpoint: {}.".format(sd_model_checkpoint)) return "EasyPhoto does not support RL with the SDXL checkpoint: {}.".format(sd_model_checkpoint) if int(resolution) < 1024: - ep_logger.error("The resolution for SDXL Training needs to be 1024.") + # ep_logger.error("The resolution for SDXL Training needs to be 1024.") return "The resolution for SDXL Training needs to be 1024." if validation: # We do not ensemble models by validation in SDXL training. - ep_logger.error("To save training time and VRAM, please turn off validation in SDXL training.") + # ep_logger.error("To save training time and VRAM, please turn off validation in SDXL training.") return "To save training time and VRAM, please turn off validation in SDXL training." # Template address @@ -194,23 +194,24 @@ def easyphoto_train_forward( # check preprocess results train_images = glob(os.path.join(images_save_path, "*.jpg")) if len(train_images) == 0: - ep_logger.error("Failed to obtain preprocessed images, please check the preprocessing process.") + # ep_logger.error("Failed to obtain preprocessed images, please check the preprocessing process.") return "Failed to obtain preprocessed images, please check the preprocessing process." if not os.path.exists(json_save_path): - ep_logger.error("Failed to obtain preprocessed metadata.jsonl, please check the preprocessing process.") + # ep_logger.error("Failed to obtain preprocessed metadata.jsonl, please check the preprocessing process.") return "Failed to obtain preprocessed metadata.jsonl, please check the preprocessing process." if not sdxl_pipeline_flag: train_kohya_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "train_kohya/train_lora.py") else: train_kohya_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "train_kohya/train_lora_sd_XL.py") - ep_logger.info("train_file_path : {}".format(train_kohya_path)) + # ep_logger.info("train_file_path : {}".format(train_kohya_path)) + print(f'train_file_path : {train_kohya_path}') if enable_rl and not train_scene_lora_bool: train_ddpo_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "train_kohya/train_ddpo.py") - ep_logger.info("train_ddpo_path : {}".format(train_kohya_path)) + # ep_logger.info("train_ddpo_path : {}".format(train_kohya_path)) # outputs/easyphoto-tmp/train_kohya_log.txt, use to cache log and flush to UI - ep_logger.info("cache_log_file_path: {}".format(cache_log_file_path)) + # ep_logger.info("cache_log_file_path: {}".format(cache_log_file_path)) if not os.path.exists(os.path.dirname(cache_log_file_path)): os.makedirs(os.path.dirname(cache_log_file_path), exist_ok=True) @@ -426,7 +427,7 @@ def easyphoto_train_forward( max_rl_time = int(float(max_rl_time) * 60 * 60) env["MAX_RL_TIME"] = str(max_rl_time) try: - ep_logger.info("Start RL (reinforcement learning). The max time of RL is {}.".format(max_rl_time)) + # ep_logger.info("Start RL (reinforcement learning). The max time of RL is {}.".format(max_rl_time)) # Since `accelerate` spawns a new process, set `timeout` in `subprocess.run` does not take effects. subprocess.run(command, env=env, check=True) except subprocess.CalledProcessError as e: diff --git a/modules/api/api.py b/modules/api/api.py index d588ad64..53b78877 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -922,7 +922,7 @@ def invocations(self, req: models.InvocationsRequest): with self.invocations_lock: print('-------invocation------') print(req) - print("working..........") + # print("working..........") # check memory and collect garbage self.check_memory_and_collect_garbage() @@ -942,7 +942,7 @@ def invocations(self, req: models.InvocationsRequest): # 请求任务模式,临时下载数据集 user_path = f'./datasets/{req.userId}/{req.unique_id}' if req.s3Url !='': - print(f'user_path: {user_path}, download dataset from s3: {req.s3Url}.') + # print(f'user_path: {user_path}, download dataset from s3: {req.s3Url}.') shared.download_dataset_from_s3(req.s3Url, user_path) print(f'download dataset from s3: {req.s3Url} success.') @@ -965,7 +965,7 @@ def invocations(self, req: models.InvocationsRequest): "network_alpha": req.network_alpha } outputs = self.easyphoto_train(payload_train) - print(outputs["message"]) + print("Train outputs: ", outputs["message"]) time.sleep(10) @@ -985,12 +985,13 @@ def invocations(self, req: models.InvocationsRequest): # 判断是不是图片格式文件 if image_format not in ['*.jpg', '*.jpeg', '*.png', '*.webp']: continue - print(glob(os.path.join(template_dir, image_format))) + # print(glob(os.path.join(template_dir, image_format))) img_list.extend(glob(os.path.join(template_dir, image_format))) if len(img_list) == 0: print(f"Input template dir {template_dir} contains no images") else: print(f"Total {len(img_list)} templates to process for {req.unique_id} ID") + print(img_list) output_path = f'./outputs_easyphoto/{req.unique_id}/' if output_path is not None: # 如果文件夹不存在就创建它 @@ -1004,11 +1005,13 @@ def invocations(self, req: models.InvocationsRequest): selected_template_images.append(encoded_image) payload_infer = { - "user_ids": [req.unique_id], + "user_id": req.userId, + "user_ids": req.unique_id, "sd_model_checkpoint": req.model, "selected_template_images": selected_template_images, } outputs = self.easyphoto_infer(payload_infer) + ############################################################## print("Infer results: ", outputs["message"]) print("Infer results numbers: ", len(outputs["outputs"])) print("Mode images numbers: ", len(img_list)) @@ -1016,7 +1019,8 @@ def invocations(self, req: models.InvocationsRequest): for idx, img_path_output in enumerate(img_list): image = decode_image_from_base64jpeg(outputs["outputs"][idx]) output_img_path = os.path.join(os.path.join(output_path), - f"{req.unique_id}_" + os.path.basename(img_path_output)) + f"{req.unique_id}_" + + os.path.basename(img_path_output)) print(output_img_path) cv2.imwrite(output_img_path, image) # 最终将用户数据集上传到 S3 ( user_id / uuid 路径下) @@ -1136,6 +1140,7 @@ def easyphoto_train(self, datas: dict): return {"message": message} def easyphoto_infer(self, datas: dict): + user_id = datas.get("user_id", "test") user_ids = datas.get("user_ids", []) sd_model_checkpoint = datas.get("sd_model_checkpoint", "Chilloutmix-Ni-pruned-fp16-fix.safetensors") selected_template_images = datas.get("selected_template_images", []) @@ -1249,6 +1254,7 @@ def easyphoto_infer(self, datas: dict): tabs = int(tabs) try: comment, outputs, face_id_outputs = easyphoto_infer_forward( + user_id, sd_model_checkpoint, selected_template_images, init_image, @@ -1306,6 +1312,7 @@ def easyphoto_infer(self, datas: dict): torch.cuda.empty_cache() comment = f"Infer error, error info:{str(e)}" outputs = [] + print("Infer error, error info:", str(e)) face_id_outputs_base64 = [] traceback.print_exc() diff --git a/modules/shared.py b/modules/shared.py index cd297725..4b4668ec 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -159,6 +159,22 @@ def get_bucket_and_key(s3uri): return bucket, key +def upload_image(output_file_path, user_id, unique_id): + bucket, key = get_bucket_and_key(generated_images_s3uri) + try: + file_name = os.path.basename(output_file_path) + if key.endswith('/'): + key = key[:-1] + key += "/" + user_id + __s3file = f'{key}/{unique_id}_{file_name}' + print(output_file_path, __s3file) + s3_client.upload_file(output_file_path, bucket, __s3file) + except ClientError as e: + print(e) + return False + return True + + def upload_image_to_s3(output_file_path, user_id, unique_id): bucket, key = get_bucket_and_key(generated_images_s3uri) if key.endswith('/'): @@ -255,6 +271,7 @@ def upload_s3files(s3uri, file_path_with_pattern): return False return True + def upload_s3folder(s3uri, file_path): pos = s3uri.find('/', 5) bucket = s3uri[5 : pos]