diff --git a/tests/sd-webui/test_api.py b/tests/sd-webui/test_api.py index f4cee6601..846d3b591 100644 --- a/tests/sd-webui/test_api.py +++ b/tests/sd-webui/test_api.py @@ -26,7 +26,7 @@ @pytest.fixture(scope="session", autouse=True) def change_model(): option_payload = { - "sd_model_checkpoint": "checkpoints/AWPainting_v1.2.safetensors", + "sd_model_checkpoint": "AWPainting_v1.2.safetensors", } post_request_and_check(f"{WEBUI_SERVER_URL}/{OPTIONS_API_ENDPOINT}", option_payload) diff --git a/tests/sd-webui/utils.py b/tests/sd-webui/utils.py index 3a1bbaedd..63d164c85 100644 --- a/tests/sd-webui/utils.py +++ b/tests/sd-webui/utils.py @@ -22,9 +22,32 @@ TXT2IMG_TARGET_FOLDER = "/share_nfs/onediff_ci/sd-webui/images/txt2img" SAVED_GRAPH_NAME = "saved_graph" +control_modules = ["segmentation", "canny", "depth", "openpose"] + +sd15_mapping = { + "segmentation": "control_sd15_seg", + "canny": "control_v11p_sd15_canny", + "depth": "control_v11f1p_sd15_depth", + "openpose": "control_sd15_openpose", +} + +control_mapping_imgs = { + "segmentation": "an-source.jpg", + "canny": "sk-b-src.png", + "depth": "sk-b-dep.png", + "openpose": "an-pose.png", +} + +sdxl_mapping = { + "canny": "controlnet-canny-sdxl", +} + os.makedirs(IMG2IMG_TARGET_FOLDER, exist_ok=True) os.makedirs(TXT2IMG_TARGET_FOLDER, exist_ok=True) +def get_model_img(model_name : str ) -> str : + img_path = str(Path(__file__).parent / "images" /"txt2img"/control_mapping_imgs.get(model_name)) + return encode_file_to_base64(img_path) def get_base_args() -> Dict[str, Any]: return { @@ -42,6 +65,10 @@ def get_base_args() -> Dict[str, Any]: } +def get_model(module : str, mapping :Dict[str, str]) -> str: + return mapping.get(module, "Unknown Module") + + def get_extra_args() -> List[Dict[str, Any]]: quant_args = [ { @@ -57,11 +84,42 @@ def get_extra_args() -> List[Dict[str, Any]]: txt2img_args = [ {}, {"init_images": [get_init_image()]}, + # {"init_images": ["images"]}, + ] + + + controlnet_args = [ + { + "alwayson_scripts": { + "controlnet": { + "args": [ + { + "enabled": (not txt2img_args[0]) and x, + "module": module, + "model": get_model(module, sd15_mapping), + "weight": 1.0, + "image": get_model_img(module), + "resize_mode": "Crop and Resize", + "low_vram": False, + "processor_res": 64, + "threshold_a": 64, + "threshold_b": 64, + "guidance_start": 0.0, + "guidance_end": 1.0, + "control_mode": "Balanced", + "pixel_perfect": False + } + ] + } + } + } if x else {} + for x in [True, False] + for module in control_modules ] - return [ quant_args, txt2img_args, + controlnet_args ] @@ -73,6 +131,13 @@ def get_all_args() -> Iterable[Dict[str, Any]]: yield args +def get_controlnet_model(data: Dict[str, Any]) -> bool: + try: + return data["alwayson_scripts"]["controlnet"]["args"][0]["model"] + except (KeyError, IndexError): + return False + + def is_txt2img(data: Dict[str, Any]) -> bool: return "init_images" not in data @@ -80,6 +145,12 @@ def is_txt2img(data: Dict[str, Any]) -> bool: def is_quant(data: Dict[str, Any]) -> bool: return data["script_args"][0] +def is_controlnet(data: Dict[str, Any]) -> bool: + try: + return data["alwayson_scripts"]["controlnet"]["args"][0]["enabled"] + except (KeyError, IndexError): + return False + def encode_file_to_base64(path: str) -> str: with open(path, "rb") as file: @@ -127,7 +198,9 @@ def get_target_image_filename(data: Dict[str, Any]) -> str: txt2img_str = "txt2img" if is_txt2img(data) else "img2img" quant_str = "-quant" if is_quant(data) else "" - return f"{parent_path}/onediff{quant_str}-{txt2img_str}-w{WIDTH}-h{HEIGHT}-seed-{SEED}-numstep-{NUM_STEPS}.png" + controlnet_str = "-controlnet" if is_controlnet(data) else "" + controlnet_model_str = get_controlnet_model(data) if is_controlnet(data) else "" + return f"{parent_path}/onediff{quant_str}-{txt2img_str}{controlnet_str}{controlnet_model_str}-w{WIDTH}-h{HEIGHT}-seed-{SEED}-numstep-{NUM_STEPS}.png" def check_and_generate_images(): @@ -144,6 +217,8 @@ def get_data_summary(data: Dict[str, Any]) -> Dict[str, bool]: return { "is_txt2img": is_txt2img(data), "is_quant": is_quant(data), + "is_controlnet":is_controlnet(data), + "controlnet_model":get_controlnet_model(data) if is_controlnet(data) else "", }