Started Hue/Saturation node
chrisfreilich committed May 2, 2024
1 parent f714d9e commit 13619c2
Showing 1 changed file with 134 additions and 53 deletions: colors.py
@@ -3,7 +3,7 @@
 @title: Virtuoso Pack - Color Nodes
 @nickname: Virtuoso Pack - Color Nodes
 @description: This extension provides a solid color node, Color Balance Node, Color Balance Advanced Node,
-SplitRGB and MergeRGB nodes, and Black and White node.
+SplitRGB and MergeRGB nodes, Hue/Saturation, and Black and White node.
 """
 import torch
 from scipy.interpolate import CubicSpline
@@ -376,13 +376,11 @@ def do_black_and_white(self, image, red, green, blue, cyan, magenta, yellow):



-import torch.nn.functional as F
-import colorsys

 class HueSat():
     NAME = "Hue/Saturation"
     CATEGORY = "Virtuoso"
-    RETURN_TYPES = ("IMAGE",)
+    RETURN_TYPES = ("IMAGE", "MASK")
     FUNCTION = "do_hue_sat"

     @classmethod
@@ -442,54 +440,137 @@ def INPUT_TYPES(s) -> dict:
             }
         }

-    def do_hue_sat(self, image, hue_low, hue_low_feather,hue_high, hue_high_feather, hue_offset, sat_offset, lightness_offset):
-
+    def do_hue_sat(self, image, hue_low, hue_low_feather, hue_high, hue_high_feather, hue_offset, sat_offset, lightness_offset):
         # Convert image to HSV
-        image_hsv = torch.zeros_like(image)
-        for i in range(image.shape[0]):
-            for j in range(image.shape[1]):
-                for k in range(image.shape[2]):
-                    r, g, b = image[i, j, k, 0:3]
-                    h, s, v = colorsys.rgb_to_hsv(r.item(), g.item(), b.item())
-                    image_hsv[i, j, k, 0] = h * 360
-                    image_hsv[i, j, k, 1] = s
-                    image_hsv[i, j, k, 2] = v
-                    if image.shape[3] == 4:
-                        image_hsv[i, j, k, 3] = image[i, j, k, 3]
-
-        # Calculate the range of hues that will be affected by the adjustment to be made
-        hue_low_wrap = (hue_low - hue_low_feather) % 360
-        hue_high_wrap = (hue_high + hue_high_feather) % 360
-
-        if hue_low < hue_high:
-            mask = (image_hsv[:, :, :, 0] >= hue_low_wrap) & (image_hsv[:, :, :, 0] <= hue_high_wrap)
-        else:
-            mask = (image_hsv[:, :, :, 0] >= hue_low_wrap) | (image_hsv[:, :, :, 0] <= hue_high_wrap)
-
-        # Shift hues
-        image_hsv[:, :, :, 0] = (image_hsv[:, :, :, 0] + hue_offset) % 360
-
-        # Change saturation values
-        image_hsv[:, :, :, 1] = torch.clamp(image_hsv[:, :, :, 1] + sat_offset / 100, 0, 1)
-
-        # Change lightness values
-        if lightness_offset < 0:
-            image_hsv[:, :, :, 2] = torch.clamp(image_hsv[:, :, :, 2] + lightness_offset / 100, 0, 1)
-        else:
-            image_hsv[:, :, :, 2] = torch.clamp(image_hsv[:, :, :, 2] + lightness_offset / 100 * (1 - image_hsv[:, :, :, 2]), 0, 1)
-
-        # Apply mask
-        image_hsv[~mask] = image[~mask]
-
-        # Convert image back to RGB
-        image_rgb = torch.zeros_like(image)
-        for i in range(image.shape[0]):
-            for j in range(image.shape[1]):
-                for k in range(image.shape[2]):
-                    h, s, v = image_hsv[i, j, k, 0:3].tolist()
-                    r, g, b = colorsys.hsv_to_rgb(h / 360, s, v)
-                    image_rgb[i, j, k, 0:3] = torch.tensor([r, g, b])
-                    if image.shape[3] == 4:
-                        image_rgb[i, j, k, 3] = image_hsv[i, j, k, 3]
-
-        return (image_rgb,)
+        image_hsv = rgb_to_hsv(image)  # hue channel is normalized to [0, 1]
+
+        # Calculate the mask
+        mask = create_mask(image_hsv[..., 0], hue_low, hue_high, hue_low_feather, hue_high_feather)
+
+        # Adjust HSV values (hue_offset is in degrees, so scale it to the [0, 1] hue range)
+        image_hsv[..., 0] = (image_hsv[..., 0] + hue_offset / 360) % 1.0
+        image_hsv[..., 1] = torch.clamp(image_hsv[..., 1] + sat_offset / 100, 0, 1)
+        lightness_adjust = lightness_offset / 100 * (1 if lightness_offset < 0 else (1 - image_hsv[..., 2]))
+        image_hsv[..., 2] = torch.clamp(image_hsv[..., 2] + lightness_adjust, 0, 1)
+
+        # Convert back to RGB
+        adjusted_image_rgb = hsv_to_rgb(image_hsv[..., :3])
+
+        # Blend the original and adjusted images based on the mask
+        blended_rgb = (adjusted_image_rgb * mask.unsqueeze(-1)) + (image[..., :3] * (1 - mask.unsqueeze(-1)))
+
+        # Include the alpha channel if present
+        if image.shape[-1] == 4:
+            blended_rgba = torch.cat((blended_rgb, image[..., 3:4]), dim=-1)
+        else:
+            blended_rgba = blended_rgb
+
+        return (blended_rgba, mask)
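
A minimal smoke test for the rewritten node, not part of the commit; it assumes ComfyUI-style [B, H, W, C] float images in the 0-1 range and uses the parameter names from the signature above:

import torch

node = HueSat()
img = torch.rand(1, 64, 64, 3)  # hypothetical input batch
out, mask = node.do_hue_sat(img, hue_low=20, hue_low_feather=10,
                            hue_high=100, hue_high_feather=10,
                            hue_offset=30, sat_offset=0, lightness_offset=0)
print(out.shape, mask.shape)  # torch.Size([1, 64, 64, 3]) torch.Size([1, 64, 64])
print(0.0 <= mask.min().item() and mask.max().item() <= 1.0)  # True: feathered float mask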


+# Thanks to MA Lee for conversion code
+def rgb_to_hsv(rgb: torch.Tensor) -> torch.Tensor:
+
+    input_tensor = rgb.clone()
+
+    # Check if there's an alpha channel
+    has_alpha = input_tensor.shape[-1] == 4
+
+    # Remove the alpha channel if it exists
+    if has_alpha:
+        alpha_channel = input_tensor[:, :, :, 3:4]
+        input_tensor = input_tensor[:, :, :, :3]
+
+    # Permute the dimensions from [B, H, W, 3] to [B, 3, H, W]
+    input_tensor = input_tensor.permute(0, 3, 1, 2)
+
+    # Convert RGB to HSV
+    cmax, cmax_idx = torch.max(input_tensor, dim=1, keepdim=True)
+    cmin = torch.min(input_tensor, dim=1, keepdim=True)[0]
+    delta = cmax - cmin
+    hsv_h = torch.empty_like(input_tensor[:, 0:1, :, :])
+    cmax_idx[delta == 0] = 3
+    hsv_h[cmax_idx == 0] = (((input_tensor[:, 1:2] - input_tensor[:, 2:3]) / delta) % 6)[cmax_idx == 0]
+    hsv_h[cmax_idx == 1] = (((input_tensor[:, 2:3] - input_tensor[:, 0:1]) / delta) + 2)[cmax_idx == 1]
+    hsv_h[cmax_idx == 2] = (((input_tensor[:, 0:1] - input_tensor[:, 1:2]) / delta) + 4)[cmax_idx == 2]
+    hsv_h[cmax_idx == 3] = 0.
+    hsv_h /= 6.
+    hsv_s = torch.where(cmax == 0, torch.tensor(0.).type_as(input_tensor), delta / cmax)
+    hsv_v = cmax
+    hsv_tensor = torch.cat([hsv_h, hsv_s, hsv_v], dim=1)
+
+    # Permute the dimensions back to [B, H, W, 3]
+    hsv_tensor = hsv_tensor.permute(0, 2, 3, 1)
+
+    # Add back the alpha channel if it was present
+    if has_alpha:
+        hsv_tensor = torch.cat([hsv_tensor, alpha_channel], dim=-1)
+
+    return hsv_tensor
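
A quick sanity check of the conversion above against Python's colorsys, not part of the commit; both normalize hue to [0, 1]:

import colorsys

px = torch.tensor([[[[0.2, 0.6, 0.4]]]])    # one pixel, [B, H, W, C]
print(rgb_to_hsv(px)[0, 0, 0])              # tensor([0.4167, 0.6667, 0.6000])
print(colorsys.rgb_to_hsv(0.2, 0.6, 0.4))   # (0.4166..., 0.6666..., 0.6)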

+def hsv_to_rgb(hsv: torch.Tensor) -> torch.Tensor:
+
+    input_tensor = hsv.clone()
+
+    # Check if there's an alpha channel
+    has_alpha = input_tensor.shape[-1] == 4
+
+    # Remove the alpha channel if it exists
+    if has_alpha:
+        alpha_channel = input_tensor[:, :, :, 3:4]
+        input_tensor = input_tensor[:, :, :, :3]
+
+    # Permute the dimensions from [B, H, W, 3] to [B, 3, H, W]
+    input_tensor = input_tensor.permute(0, 3, 1, 2)
+
+    # Extract HSV components
+    hsv_h, hsv_s, hsv_v = input_tensor[:, 0:1], input_tensor[:, 1:2], input_tensor[:, 2:3]
+    _c = hsv_v * hsv_s
+    _x = _c * (- torch.abs(hsv_h * 6. % 2. - 1) + 1.)
+    _m = hsv_v - _c
+    _o = torch.zeros_like(_c)
+    idx = (hsv_h * 6.).type(torch.uint8)
+    idx = (idx % 6).expand(-1, 3, -1, -1)
+    rgb = torch.empty_like(input_tensor)
+    rgb[idx == 0] = torch.cat([_c, _x, _o], dim=1)[idx == 0]
+    rgb[idx == 1] = torch.cat([_x, _c, _o], dim=1)[idx == 1]
+    rgb[idx == 2] = torch.cat([_o, _c, _x], dim=1)[idx == 2]
+    rgb[idx == 3] = torch.cat([_o, _x, _c], dim=1)[idx == 3]
+    rgb[idx == 4] = torch.cat([_x, _o, _c], dim=1)[idx == 4]
+    rgb[idx == 5] = torch.cat([_c, _o, _x], dim=1)[idx == 5]
+    rgb += _m
+
+    # Permute the dimensions back to [B, H, W, 3]
+    rgb_tensor = rgb.permute(0, 2, 3, 1)
+
+    # Add back the alpha channel if it was present
+    if has_alpha:
+        rgb_tensor = torch.cat([rgb_tensor, alpha_channel], dim=-1)
+
+    return rgb_tensor
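
Since the two helpers are inverses, a round trip should reproduce the input to within float tolerance. A hypothetical check, not part of the commit:

img = torch.rand(2, 8, 8, 3)
assert torch.allclose(hsv_to_rgb(rgb_to_hsv(img)), img, atol=1e-4)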

+def create_mask(hue, hue_low, hue_high, hue_low_feather, hue_high_feather):
+    hue_low = hue_low / 360.0
+    hue_high = hue_high / 360.0
+    hue_low_feather = hue_low_feather / 360.0
+    hue_high_feather = hue_high_feather / 360.0
+
+    # Wrap hue values
+    hue = hue % 1.0
+    hue_low = hue_low % 1.0
+    hue_high = hue_high % 1.0
+
+    # Calculate mask
+    if hue_low < hue_high:
+        mask = smoothstep(hue_low - hue_low_feather, hue_low, hue) * smoothstep(hue_high + hue_high_feather, hue_high, hue)
+    else:
+        mask = smoothstep(hue_low - hue_low_feather, hue_low, hue) + smoothstep(hue_high + hue_high_feather, hue_high, hue)
+    mask = torch.clamp(mask, 0.0, 1.0)
+
+    return mask
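
To illustrate the mask's behavior (not part of the commit): hue is expected already normalized to [0, 1], while the band edges and feathers are given in degrees. The second call shows a band that wraps around 0 degrees:

hues = torch.tensor([0., 30., 90., 180., 350.]) / 360.0
print(create_mask(hues, hue_low=20, hue_high=100, hue_low_feather=10, hue_high_feather=10))
# tensor([0., 1., 1., 0., 0.])  -- band from 20 to 100 degrees, feathered 10 degrees each side
print(create_mask(hues, hue_low=330, hue_high=30, hue_low_feather=10, hue_high_feather=10))
# tensor([1., 1., 0., 0., 1.])  -- wrap-around band from 330 through 0 to 30 degrees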

+def smoothstep(edge0, edge1, x):
+    # Scale, bias and saturate x to 0..1 range
+    x = torch.clamp((x - edge0) / (edge1 - edge0), 0.0, 1.0)
+    # Evaluate polynomial
+    return x * x * (3 - 2 * x)
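
smoothstep is the standard Hermite ease curve: 0 at or before edge0, 1 at or after edge1, with zero slope at both ends. Note that create_mask also calls it with edge0 greater than edge1, which simply reverses the ramp. A small check, not part of the commit:

x = torch.linspace(0.0, 1.0, 5)
print(smoothstep(0.25, 0.75, x))  # tensor([0.0000, 0.0000, 0.5000, 1.0000, 1.0000])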
