Progress on moving away from 8-bit blend modes.
chrisfreilich committed May 5, 2024
1 parent d900e8d commit cdbc9ca
Showing 1 changed file with 36 additions and 56 deletions.
blendmodes.py
@@ -14,10 +14,10 @@
     grain_extract, grain_merge, divide, overlay
 from .resize import match_sizes
 
-modes_8bit = ["difference", "exclusion", "normal", "screen", "soft light", "lighten",
-              "lighter color", "dodge", "color dodge", "linear burn","linear dodge (add)",
-              "linear light", "vivid light"," pin light", "hard mix", "darken", "darker color",
-              "multiply", "color burn", "hard light", "subtract", "grain extract", "grain merge",
+modes_8bit = ["difference", "normal", "screen", "soft light", "lighten",
+              "lighter color", "dodge", "linear dodge (add)",
+              "darken", "darker color",
+              "multiply", "hard light", "grain extract", "grain merge",
               "divide", "overlay", "hue", "saturation", "color", "luminosity"]
 
 class BlendModes:
@@ -67,42 +67,35 @@ def do_blend(self, backdrop, source, blend_mode, opacity, source_adjust, invert_
             final_tensor = (torch.from_numpy(blended_np / 255)).unsqueeze(0)
             return (final_tensor,)
         else:
+            backdrop_prepped = handle_alpha(backdrop, False)
             source_prepped = handle_alpha(source, invert_mask, mask)
-            source_prepped = match_sizes(source_prepped, backdrop)
-            return (backdrop, modes[blend_mode](backdrop, source_prepped, opacity))
+            source_prepped, _ = match_sizes(source_adjust, source_prepped, backdrop_prepped)
+            final_tensor = modes[blend_mode](backdrop_prepped, source_prepped, opacity)
+            return (final_tensor,)
 
 def handle_alpha(img, invert_mask="true", mask=None):
 
     alpha_img = img.clone()
 
     # Create or process mask
     if mask is None:
-        if alpha_img.shape[2] == 4: # If img already has an alpha channel, use it
+        if alpha_img.shape[3] == 4: # If img already has an alpha channel, use it
             alpha_channel = alpha_img[:, :, 3:4]
         else:
-            alpha_channel = torch.full((alpha_img.shape[0], alpha_img.shape[1], 1), fill_value=1, dtype=img.dtype, device=img.device)
+            alpha_channel = torch.full((1, alpha_img.shape[1], alpha_img.shape[2], 1), fill_value=1, dtype=img.dtype, device=img.device)
     else:
+        alpha_channel = mask.clone()
+        # Add channel dimension if it doesn't exist
+        # if len(alpha_channel.shape) == 3:
+        #     alpha_channel = alpha_channel.unsqueeze(-1)
         if invert_mask == "yes":
            alpha_channel = 1 - alpha_channel
-
-        # Resize mask to match image dimensions
-        h, w = alpha_img.shape[:2]
-        alpha_channel = F.interpolate(mask[None, ...], size=(h, w), mode='bilinear', align_corners=False)[0]
+        _, h, w, _ = alpha_img.shape
+        alpha_channel = F.interpolate(alpha_channel[None, ...], size=(h, w), mode='bilinear', align_corners=False)[0]
 
     # Ensure alpha_channel has the same number of dimensions as img
     if len(alpha_channel.shape) < len(alpha_img.shape):
         alpha_channel = alpha_channel.unsqueeze(-1)
 
-    # If img already has an alpha channel, replace it
-    if alpha_img.shape[2] == 4:
+    if alpha_img.shape[3] == 4: # If img already has an alpha channel, replace it
         alpha_img[:, :, 3:4] = alpha_channel
     else:
         # Concatenate the input image with the alpha channel along the channel dimension
-        alpha_img = torch.cat((alpha_img, alpha_channel), dim=2)
+        alpha_img = torch.cat((alpha_img, alpha_channel), dim=3)
 
     return alpha_img
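The rewritten handle_alpha assumes batched NHWC tensors of shape (batch, height, width, channels) rather than single HWC arrays, which is why the alpha test moves from shape[2] to shape[3] and the concatenation from dim=2 to dim=3. A minimal sketch of that layout under the same assumption; the img and mask values here are illustrative, not from the repo:

import torch
import torch.nn.functional as F

img = torch.rand(1, 64, 64, 3)   # batch of one RGB image, values in [0, 1]
mask = torch.rand(1, 32, 32)     # hypothetical mask at a different resolution

# Resize the mask to the image's spatial size, mirroring the new code path.
_, h, w, _ = img.shape
alpha = F.interpolate(mask[None, ...], size=(h, w), mode='bilinear',
                      align_corners=False)[0]
alpha = alpha.unsqueeze(-1)      # -> (1, 64, 64, 1), matching NHWC

# The channel axis is dim=3 for NHWC tensors, hence torch.cat(..., dim=3).
rgba = torch.cat((img, alpha), dim=3)
print(rgba.shape)                # torch.Size([1, 64, 64, 4])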

@@ -308,51 +301,38 @@ def lighter_color(backdrop, source, opacity):
     return darker_lighter_color(backdrop, source, opacity, "light")
 
 def simple_mode(backdrop, source, opacity, mode):
-    # Normalize the RGB and alpha values to 0-1
-    backdrop_norm = backdrop[:, :, :3] / 255
-    source_norm = source[:, :, :3] / 255
-    source_alpha_norm = source[:, :, 3:4] / 255
 
     # Calculate the blend without any transparency considerations
 
     if mode == "linear_burn":
-        blend = backdrop_norm + source_norm - 1
+        blend = backdrop + source - 1
     elif mode == "linear_light":
-        blend = backdrop_norm + (2 * source_norm) - 1
+        blend = backdrop + (2 * source) - 1
     elif mode == "color_dodge":
-        blend = backdrop_norm / (1 - source_norm)
-        blend = np.clip(blend, 0, 1)
+        blend = backdrop / (1 - source)
     elif mode == "color_burn":
-        blend = 1 - ((1 - backdrop_norm) / source_norm)
-        blend = np.clip(blend, 0, 1)
+        blend = 1 - ((1 - backdrop) / source)
     elif mode == "exclusion":
-        blend = backdrop_norm + source_norm - (2 * backdrop_norm * source_norm)
+        blend = backdrop + source - (2 * backdrop * source)
     elif mode == "subtract":
-        blend = backdrop_norm - source_norm
+        blend = backdrop - source
     elif mode == "vivid_light":
-        blend = np.where(source_norm <= 0.5, backdrop_norm / (1 - 2 * source_norm), 1 - (1 -backdrop_norm) / (2 * source_norm - 0.5) )
-        blend = np.clip(blend, 0, 1)
+        blend = torch.where(source <= 0.5, backdrop / (1 - 2 * source), 1 - (1 -backdrop) / (2 * source - 0.5) )
     elif mode == "pin_light":
-        blend = np.where(source_norm <= 0.5, np.minimum(backdrop_norm, 2 * source_norm), np.maximum(backdrop_norm, 2 * (source_norm - 0.5)))
+        blend = torch.where(source <= 0.5, torch.minimum(backdrop, 2 * source), torch.maximum(backdrop, 2 * (source - 0.5)))
     elif mode == "hard_mix":
         blend = simple_mode(backdrop, source, opacity, "linear_light")
-        blend = np.round(blend[:, :, :3] / 255)
-
-    # Apply the blended layer back onto the backdrop layer while utilizing the alpha channel and opacity information
-    new_rgb = (1 - source_alpha_norm * opacity) * backdrop_norm + source_alpha_norm * opacity * blend
-
-    # Ensure the RGB values are within the valid range
-    new_rgb = np.clip(new_rgb, 0, 1)
-
-    # Convert the RGB values back to 0-255
-    new_rgb = new_rgb * 255
-
-    # Calculate the new alpha value by taking the maximum of the backdrop and source alpha channels
-    new_alpha = np.maximum(backdrop[:, :, 3], source[:, :, 3])
-
-    # Create a new RGBA image with the calculated RGB and alpha values
-    result = np.dstack((new_rgb, new_alpha))
-
-    return result
+        rgb_channels = torch.round(blend[:, :, :, :3])
+        alpha_channel = blend[:, :, :, 3:4]
+        blend = torch.cat((rgb_channels, alpha_channel), dim=-1)
+
+    blend = torch.clamp(blend, 0, 1)
+
+    source_alpha = source[:, :, :, 3:4]
+    new_rgb = (1 - source_alpha * opacity) * backdrop + source_alpha * opacity * blend
+    #new_rgb = torch.clamp(new_rgb, 0, 1)
+    rgb_channels = new_rgb[:, :, :, :3]
+    backdrop_alpha = backdrop[:, :, :, 3:4]
+    new_img = torch.cat((rgb_channels, backdrop_alpha), dim=-1)
+    return new_img
 
 def linear_light(backdrop, source, opacity):
     return simple_mode(backdrop, source, opacity, "linear_light")
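In the new simple_mode, both layers are assumed to arrive as float RGBA tensors already in [0, 1], so the blend formulas apply directly with no /255 or *255 round-trips, a single torch.clamp bounds the result, and the backdrop's alpha is carried into the output. A quick hand-check of the "linear_light" branch under that assumption, with values chosen for illustration:

import torch

backdrop = torch.tensor([[[[0.5, 0.5, 0.5, 1.0]]]])   # (1, 1, 1, 4), mid-gray, opaque
source   = torch.tensor([[[[0.75, 0.25, 0.5, 1.0]]]])
opacity  = 1.0

# Float-domain linear light: backdrop + 2*source - 1, then one clamp.
blend = backdrop + (2 * source) - 1
blend = torch.clamp(blend, 0, 1)

# Composite with source alpha and opacity, keeping the backdrop's alpha.
source_alpha = source[..., 3:4]
new_rgb = (1 - source_alpha * opacity) * backdrop + source_alpha * opacity * blend
result = torch.cat((new_rgb[..., :3], backdrop[..., 3:4]), dim=-1)
print(result)  # tensor([[[[1.0000, 0.0000, 0.5000, 1.0000]]]])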
