From 7378cce08e899cab4ef7e0a8947927a6b6582c16 Mon Sep 17 00:00:00 2001 From: pkuliyi2015 Date: Sun, 21 May 2023 15:18:56 +0000 Subject: [PATCH] add wavelet colorfix --- README.md | 5 +++ README_CN.md | 7 ++++- scripts/stablesr.py | 50 +++++++++++++++++++++++++----- srmodule/colorfix.py | 72 ++++++++++++++++++++++++++++++++++++++++++-- 4 files changed, 124 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index 47825c6..afe6e0e 100644 --- a/README.md +++ b/README.md @@ -76,6 +76,11 @@ Relevant Links - When enabling it, the script ignores your denoising strength and gives you much more detailed images, but also changes the color & sharpness significantly - When disabling it, the script starts by adding some noise to your image. The result will be not fully detailed, even if you set denoising strength = 1 (but maybe aesthetically good). See [Comparison](https://imgsli.com/MTgwMTMx). - If you disable Pure Noise, we recommend denoising strength=1 +- What is "Color Fix"? + - This is to mitigate the color shift problem from StableSR and the tiling process. + - AdaIN simply adjusts the color statistics between the original and the outcome images. This is the official algorithm but ineffective in many cases. + - Wavelet decomposes the original and the outcome images into low and high frequency, and then replace the outcome image's low-frequency part (colors) with the original image's. This is very powerful for uneven color shifting. The algorithm is from GIMP and Krita, which will take several seconds for each image. + - When enabling color fix, the original image will also show up in your preview window, but will NOT be saved automatically. ### 6. Important Notice diff --git a/README_CN.md b/README_CN.md index ac5d8c5..fad6d60 100644 --- a/README_CN.md +++ b/README_CN.md @@ -76,6 +76,11 @@ Licensed under S-Lab License 1.0 - 启用这个选项时,脚本会忽略你的重绘幅度设置。产出将会是更详细的图像,但也会显著改变颜色和锐度。 - 禁用这个选项时,脚本会开始添加一些噪声到你的图像。即使你将去噪强度设为1,结果也不会那么的细节(但可能更和谐好看)。参见 [对比图](https://imgsli.com/MTgwMTMx)。 - 如果禁用Pure Noise,推荐重绘幅度设置为1 +- 什么是"颜色修正"? + - 这是为了缓解来自StableSR和Tile处理过程中的颜色偏移问题。 + - AdaIN简单地匹配原图和结果图的颜色统计信息。这是StableSR官方算法,但常常效果不佳。 + - Wavelet将原图和结果图分解为低频和高频,然后用原图的低频信息(颜色)替换掉结果图的低频信息。该算法对于不均匀的颜色偏移非常强力。算法来自GIMP和Krita,对每张图像需要几秒钟的时间。 + - 启用颜色修正时,原图也会出现在您的预览窗口中,但不会被自动保存。 ### 6. 重要问题 @@ -86,7 +91,7 @@ Licensed under S-Lab License 1.0 - 如果你安装了可选的 VQVAE,整个模型权重将与融合权重为 0 的官方模型相同。 - 但是,你的结果将**不如**官方结果,因为: - 采样器差异: - -官方仓库进行 100 或 200 步的 legacy DDPM 采样,并使用自定义的时间步调度器,采样时不使用负提示。 + - 官方仓库进行 100 或 200 步的 legacy DDPM 采样,并使用自定义的时间步调度器,采样时不使用负提示。 - 然而,WebUI 不提供这样的采样器,必须带有负提示进行采样。**这是主要的差异。** - VQVAE 解码器差异: - 官方 VQVAE 解码器将一些编码器特征作为输入。 diff --git a/scripts/stablesr.py b/scripts/stablesr.py index f98eb81..91a0920 100644 --- a/scripts/stablesr.py +++ b/scripts/stablesr.py @@ -46,13 +46,14 @@ from torch import Tensor from tqdm import tqdm -from modules import scripts, processing, sd_samplers, devices +from modules import scripts, processing, sd_samplers, devices, images from modules.processing import StableDiffusionProcessingImg2Img, Processed +from modules.shared import opts from ldm.modules.diffusionmodules.openaimodel import UNetModel from srmodule.spade import SPADELayers from srmodule.struct_cond import EncoderUNetModelWT, build_unetwt -from srmodule.colorfix import fix_color +from srmodule.colorfix import adain_color_fix, wavelet_color_fix SD_WEBUI_PATH = Path.cwd() ME_PATH = SD_WEBUI_PATH / 'extensions' / 'sd-webui-stablesr' @@ -150,12 +151,14 @@ def refresh_fn(selected): with gr.Row(): scale_factor = gr.Slider(minimum=1, maximum=16, step=0.1, value=2, label='Scale Factor', elem_id=f'StableSR-scale') with gr.Row(): + color_fix = gr.Dropdown(['None', 'Wavelet', 'AdaIN'], label="Color Fix", value='Wavelet', elem_id=f'StableSR-color-fix') + save_original = gr.Checkbox(label='Save Original', value=False, elem_id=f'StableSR-save-original', visible=color_fix.value != 'None') + color_fix.change(fn=lambda selected: gr.Checkbox.update(visible=selected != 'None')) pure_noise = gr.Checkbox(label='Pure Noise', value=True, elem_id=f'StableSR-pure-noise') - color_fix = gr.Checkbox(label='Color Fix', value=True, elem_id=f'StableSR-color-fix') - return [model, scale_factor, pure_noise, color_fix] + return [model, scale_factor, pure_noise, color_fix, save_original] - def run(self, p: StableDiffusionProcessingImg2Img, model: str, scale_factor:float, pure_noise: bool, color_fix:bool): + def run(self, p: StableDiffusionProcessingImg2Img, model: str, scale_factor:float, pure_noise: bool, color_fix:str, save_original:bool) -> Processed: if model == 'None': # do clean up @@ -169,6 +172,10 @@ def run(self, p: StableDiffusionProcessingImg2Img, model: str, scale_factor:floa if not os.path.exists(self.model_list[model]): raise gr.Error(f"Model {model} is not on your disk! Please refresh the model list!") + if color_fix not in ['None', 'Wavelet', 'AdaIN']: + print(f'[StableSR] Invalid color fix method: {color_fix}') + color_fix = 'None' + # upscale the image, set the ouput size init_img: Image = p.init_images[0] target_width = int(init_img.width * scale_factor) @@ -222,11 +229,40 @@ def sample_custom(conditioning, unconditional_conditioning, seeds, subseeds, sub # Hook the unet, and unhook after processing. try: self.stablesr_model.hook(unet) + + if color_fix != 'None': + p.do_not_save_samples = True + result: Processed = processing.process_images(p) - if color_fix: + + if color_fix != 'None': + + fixed_images = [] + # fix the color + color_fix_func = wavelet_color_fix if color_fix == 'Wavelet' else adain_color_fix for i in range(len(result.images)): - result.images[i] = fix_color(result.images[i], init_img) + try: + fixed_images.append(color_fix_func(result.images[i], init_img)) + except Exception as e: + print(f'[StableSR] Error fixing color with default method: {e}') + + # save the fixed color images + for i in range(len(fixed_images)): + try: + images.save_image(fixed_images[i], p.outpath_samples, "", result.seed, result.prompt, opts.samples_format, info=result.infotexts, p=p) + except Exception as e: + print(f'[StableSR] Error saving color fixed image: {e}') + + if save_original: + for i in range(len(result.images)): + try: + images.save_image(result.images[i], p.outpath_samples, "", result.seed, result.prompt, opts.samples_format, info=result.infotexts, p=p, suffix="-before-color-fix") + except Exception as e: + print(f'[StableSR] Error saving original image: {e}') + result.images = result.images + fixed_images + return result + finally: self.stablesr_model.unhook(unet) diff --git a/srmodule/colorfix.py b/srmodule/colorfix.py index a284cca..55bb066 100644 --- a/srmodule/colorfix.py +++ b/srmodule/colorfix.py @@ -1,9 +1,11 @@ +import torch from PIL import Image from torch import Tensor +from torch.nn import functional as F from torchvision.transforms import ToTensor, ToPILImage -def fix_color(target: Image, source: Image): +def adain_color_fix(target: Image, source: Image): # Convert images to tensors to_tensor = ToTensor() target_tensor = to_tensor(target).unsqueeze(0) @@ -18,6 +20,21 @@ def fix_color(target: Image, source: Image): return result_image +def wavelet_color_fix(target: Image, source: Image): + # Convert images to tensors + to_tensor = ToTensor() + target_tensor = to_tensor(target).unsqueeze(0) + source_tensor = to_tensor(source).unsqueeze(0) + + # Apply wavelet reconstruction + result_tensor = wavelet_reconstruction(target_tensor, source_tensor) + + # Convert tensor back to image + to_image = ToPILImage() + result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0)) + + return result_image + def calc_mean_std(feat: Tensor, eps=1e-5): """Calculate mean and std for adaptive_instance_normalization. Args: @@ -45,4 +62,55 @@ def adaptive_instance_normalization(content_feat:Tensor, style_feat:Tensor): style_mean, style_std = calc_mean_std(style_feat) content_mean, content_std = calc_mean_std(content_feat) normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size) - return normalized_feat * style_std.expand(size) + style_mean.expand(size) \ No newline at end of file + return normalized_feat * style_std.expand(size) + style_mean.expand(size) + +def wavelet_blur(image: Tensor, radius: int): + """ + Apply wavelet blur to the input tensor. + """ + # input shape: (1, 3, H, W) + # convolution kernel + # input shape: (1, 3, H, W) + # convolution kernel + kernel_vals = [ + [0.0625, 0.125, 0.0625], + [0.125, 0.25, 0.125], + [0.0625, 0.125, 0.0625], + ] + kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device) + # add channel dimensions to the kernel to make it a 4D tensor + kernel = kernel[None, None] + # repeat the kernel across all input channels + kernel = kernel.repeat(3, 1, 1, 1) + image = F.pad(image, (radius, radius, radius, radius), mode='replicate') + # apply convolution + output = F.conv2d(image, kernel, groups=3, dilation=radius) + return output + +def wavelet_decomposition(image: Tensor, levels=5): + """ + Apply wavelet decomposition to the input tensor. + This function only returns the low frequency & the high frequency. + """ + high_freq = torch.zeros_like(image) + for i in range(levels): + radius = 2 ** i + low_freq = wavelet_blur(image, radius) + high_freq += (image - low_freq) + image = low_freq + + return high_freq, low_freq + +def wavelet_reconstruction(content_feat:Tensor, style_feat:Tensor): + """ + Apply wavelet decomposition, so that the content will have the same color as the style. + """ + # calculate the wavelet decomposition of the content feature + content_high_freq, content_low_freq = wavelet_decomposition(content_feat) + del content_low_freq + # calculate the wavelet decomposition of the style feature + style_high_freq, style_low_freq = wavelet_decomposition(style_feat) + del style_high_freq + # reconstruct the content feature with the style's high frequency + return content_high_freq + style_low_freq +