Skip to content

Commit

Permalink
add wavelet colorfix
Browse files Browse the repository at this point in the history
  • Loading branch information
pkuliyi2015 committed May 21, 2023
1 parent d2dc384 commit 7378cce
Show file tree
Hide file tree
Showing 4 changed files with 124 additions and 10 deletions.
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,11 @@ Relevant Links
- When enabling it, the script ignores your denoising strength and gives you much more detailed images, but also changes the color & sharpness significantly
- When disabling it, the script starts by adding some noise to your image. The result will be not fully detailed, even if you set denoising strength = 1 (but maybe aesthetically good). See [Comparison](https://imgsli.com/MTgwMTMx).
- If you disable Pure Noise, we recommend denoising strength=1
- What is "Color Fix"?
- This is to mitigate the color shift problem from StableSR and the tiling process.
- AdaIN simply adjusts the color statistics between the original and the outcome images. This is the official algorithm but ineffective in many cases.
- Wavelet decomposes the original and the outcome images into low and high frequency, and then replace the outcome image's low-frequency part (colors) with the original image's. This is very powerful for uneven color shifting. The algorithm is from GIMP and Krita, which will take several seconds for each image.
- When enabling color fix, the original image will also show up in your preview window, but will NOT be saved automatically.

### 6. Important Notice

Expand Down
7 changes: 6 additions & 1 deletion README_CN.md
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,11 @@ Licensed under S-Lab License 1.0
- 启用这个选项时,脚本会忽略你的重绘幅度设置。产出将会是更详细的图像,但也会显著改变颜色和锐度。
- 禁用这个选项时,脚本会开始添加一些噪声到你的图像。即使你将去噪强度设为1,结果也不会那么的细节(但可能更和谐好看)。参见 [对比图](https://imgsli.com/MTgwMTMx)
- 如果禁用Pure Noise,推荐重绘幅度设置为1
- 什么是"颜色修正"?
- 这是为了缓解来自StableSR和Tile处理过程中的颜色偏移问题。
- AdaIN简单地匹配原图和结果图的颜色统计信息。这是StableSR官方算法,但常常效果不佳。
- Wavelet将原图和结果图分解为低频和高频,然后用原图的低频信息(颜色)替换掉结果图的低频信息。该算法对于不均匀的颜色偏移非常强力。算法来自GIMP和Krita,对每张图像需要几秒钟的时间。
- 启用颜色修正时,原图也会出现在您的预览窗口中,但不会被自动保存。

### 6. 重要问题

Expand All @@ -86,7 +91,7 @@ Licensed under S-Lab License 1.0
- 如果你安装了可选的 VQVAE,整个模型权重将与融合权重为 0 的官方模型相同。
- 但是,你的结果将**不如**官方结果,因为:
- 采样器差异:
-官方仓库进行 100 或 200 步的 legacy DDPM 采样,并使用自定义的时间步调度器,采样时不使用负提示。
- 官方仓库进行 100 或 200 步的 legacy DDPM 采样,并使用自定义的时间步调度器,采样时不使用负提示。
- 然而,WebUI 不提供这样的采样器,必须带有负提示进行采样。**这是主要的差异。**
- VQVAE 解码器差异:
- 官方 VQVAE 解码器将一些编码器特征作为输入。
Expand Down
50 changes: 43 additions & 7 deletions scripts/stablesr.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,14 @@
from torch import Tensor
from tqdm import tqdm

from modules import scripts, processing, sd_samplers, devices
from modules import scripts, processing, sd_samplers, devices, images
from modules.processing import StableDiffusionProcessingImg2Img, Processed
from modules.shared import opts
from ldm.modules.diffusionmodules.openaimodel import UNetModel

from srmodule.spade import SPADELayers
from srmodule.struct_cond import EncoderUNetModelWT, build_unetwt
from srmodule.colorfix import fix_color
from srmodule.colorfix import adain_color_fix, wavelet_color_fix

SD_WEBUI_PATH = Path.cwd()
ME_PATH = SD_WEBUI_PATH / 'extensions' / 'sd-webui-stablesr'
Expand Down Expand Up @@ -150,12 +151,14 @@ def refresh_fn(selected):
with gr.Row():
scale_factor = gr.Slider(minimum=1, maximum=16, step=0.1, value=2, label='Scale Factor', elem_id=f'StableSR-scale')
with gr.Row():
color_fix = gr.Dropdown(['None', 'Wavelet', 'AdaIN'], label="Color Fix", value='Wavelet', elem_id=f'StableSR-color-fix')
save_original = gr.Checkbox(label='Save Original', value=False, elem_id=f'StableSR-save-original', visible=color_fix.value != 'None')
color_fix.change(fn=lambda selected: gr.Checkbox.update(visible=selected != 'None'))
pure_noise = gr.Checkbox(label='Pure Noise', value=True, elem_id=f'StableSR-pure-noise')
color_fix = gr.Checkbox(label='Color Fix', value=True, elem_id=f'StableSR-color-fix')

return [model, scale_factor, pure_noise, color_fix]
return [model, scale_factor, pure_noise, color_fix, save_original]

def run(self, p: StableDiffusionProcessingImg2Img, model: str, scale_factor:float, pure_noise: bool, color_fix:bool):
def run(self, p: StableDiffusionProcessingImg2Img, model: str, scale_factor:float, pure_noise: bool, color_fix:str, save_original:bool) -> Processed:

if model == 'None':
# do clean up
Expand All @@ -169,6 +172,10 @@ def run(self, p: StableDiffusionProcessingImg2Img, model: str, scale_factor:floa
if not os.path.exists(self.model_list[model]):
raise gr.Error(f"Model {model} is not on your disk! Please refresh the model list!")

if color_fix not in ['None', 'Wavelet', 'AdaIN']:
print(f'[StableSR] Invalid color fix method: {color_fix}')
color_fix = 'None'

# upscale the image, set the ouput size
init_img: Image = p.init_images[0]
target_width = int(init_img.width * scale_factor)
Expand Down Expand Up @@ -222,11 +229,40 @@ def sample_custom(conditioning, unconditional_conditioning, seeds, subseeds, sub
# Hook the unet, and unhook after processing.
try:
self.stablesr_model.hook(unet)

if color_fix != 'None':
p.do_not_save_samples = True

result: Processed = processing.process_images(p)
if color_fix:

if color_fix != 'None':

fixed_images = []
# fix the color
color_fix_func = wavelet_color_fix if color_fix == 'Wavelet' else adain_color_fix
for i in range(len(result.images)):
result.images[i] = fix_color(result.images[i], init_img)
try:
fixed_images.append(color_fix_func(result.images[i], init_img))
except Exception as e:
print(f'[StableSR] Error fixing color with default method: {e}')

# save the fixed color images
for i in range(len(fixed_images)):
try:
images.save_image(fixed_images[i], p.outpath_samples, "", result.seed, result.prompt, opts.samples_format, info=result.infotexts, p=p)
except Exception as e:
print(f'[StableSR] Error saving color fixed image: {e}')

if save_original:
for i in range(len(result.images)):
try:
images.save_image(result.images[i], p.outpath_samples, "", result.seed, result.prompt, opts.samples_format, info=result.infotexts, p=p, suffix="-before-color-fix")
except Exception as e:
print(f'[StableSR] Error saving original image: {e}')
result.images = result.images + fixed_images

return result

finally:
self.stablesr_model.unhook(unet)

Expand Down
72 changes: 70 additions & 2 deletions srmodule/colorfix.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import torch
from PIL import Image
from torch import Tensor
from torch.nn import functional as F

from torchvision.transforms import ToTensor, ToPILImage

def fix_color(target: Image, source: Image):
def adain_color_fix(target: Image, source: Image):
# Convert images to tensors
to_tensor = ToTensor()
target_tensor = to_tensor(target).unsqueeze(0)
Expand All @@ -18,6 +20,21 @@ def fix_color(target: Image, source: Image):

return result_image

def wavelet_color_fix(target: Image, source: Image):
# Convert images to tensors
to_tensor = ToTensor()
target_tensor = to_tensor(target).unsqueeze(0)
source_tensor = to_tensor(source).unsqueeze(0)

# Apply wavelet reconstruction
result_tensor = wavelet_reconstruction(target_tensor, source_tensor)

# Convert tensor back to image
to_image = ToPILImage()
result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0))

return result_image

def calc_mean_std(feat: Tensor, eps=1e-5):
"""Calculate mean and std for adaptive_instance_normalization.
Args:
Expand Down Expand Up @@ -45,4 +62,55 @@ def adaptive_instance_normalization(content_feat:Tensor, style_feat:Tensor):
style_mean, style_std = calc_mean_std(style_feat)
content_mean, content_std = calc_mean_std(content_feat)
normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
return normalized_feat * style_std.expand(size) + style_mean.expand(size)
return normalized_feat * style_std.expand(size) + style_mean.expand(size)

def wavelet_blur(image: Tensor, radius: int):
"""
Apply wavelet blur to the input tensor.
"""
# input shape: (1, 3, H, W)
# convolution kernel
# input shape: (1, 3, H, W)
# convolution kernel
kernel_vals = [
[0.0625, 0.125, 0.0625],
[0.125, 0.25, 0.125],
[0.0625, 0.125, 0.0625],
]
kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device)
# add channel dimensions to the kernel to make it a 4D tensor
kernel = kernel[None, None]
# repeat the kernel across all input channels
kernel = kernel.repeat(3, 1, 1, 1)
image = F.pad(image, (radius, radius, radius, radius), mode='replicate')
# apply convolution
output = F.conv2d(image, kernel, groups=3, dilation=radius)
return output

def wavelet_decomposition(image: Tensor, levels=5):
"""
Apply wavelet decomposition to the input tensor.
This function only returns the low frequency & the high frequency.
"""
high_freq = torch.zeros_like(image)
for i in range(levels):
radius = 2 ** i
low_freq = wavelet_blur(image, radius)
high_freq += (image - low_freq)
image = low_freq

return high_freq, low_freq

def wavelet_reconstruction(content_feat:Tensor, style_feat:Tensor):
"""
Apply wavelet decomposition, so that the content will have the same color as the style.
"""
# calculate the wavelet decomposition of the content feature
content_high_freq, content_low_freq = wavelet_decomposition(content_feat)
del content_low_freq
# calculate the wavelet decomposition of the style feature
style_high_freq, style_low_freq = wavelet_decomposition(style_feat)
del style_high_freq
# reconstruct the content feature with the style's high frequency
return content_high_freq + style_low_freq

0 comments on commit 7378cce

Please sign in to comment.