Skip to content

Commit

Permalink
Added Black and White node
Browse files Browse the repository at this point in the history
  • Loading branch information
chrisfreilich committed May 1, 2024
1 parent b4c39ca commit e1782cf
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 4 deletions.
7 changes: 5 additions & 2 deletions __init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from .colors import MergeRGB
from .colors import ColorBalance
from .colors import ColorBalanceAdvanced
from .colors import BlackAndWhite
from .contrast import Levels

NODE_CLASS_MAPPINGS = {
Expand All @@ -15,7 +16,8 @@
"ColorBalance": ColorBalance,
"ColorBalanceAdvanced": ColorBalanceAdvanced,
"SplitRGB": SplitRGB,
"MergeRGB": MergeRGB
"MergeRGB": MergeRGB,
"BlackAndWhite": BlackAndWhite
}

NODE_DISPLAY_NAME_MAPPINGS = {
Expand All @@ -26,7 +28,8 @@
"ColorBalance": "Color Balance",
"ColorBalanceAdvanced": "Color Balance Advanced",
"SplitRGB": "Split RGB",
"MergeRGB": "Merge RGB"
"MergeRGB": "Merge RGB",
"BlackAndWhite": "Black and White"
}

__all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS']
101 changes: 99 additions & 2 deletions colors.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
@author: Chris Freilich
@title: Virtuoso Pack - Color Nodes
@nickname: Virtuoso Pack -Color Nodes
@description: This extension provides a solid color node, SplitRGB and MergeRGB nodes.
@description: This extension provides a solid color node, Color Balance Node, Color Balance Advanced Node,
SplitRGB and MergeRGB nodes, and Black and White node.
"""
import torch
from scipy.interpolate import CubicSpline
Expand Down Expand Up @@ -275,4 +276,100 @@ def adjust(x, center, value, max_adjustment):
current_luminance = 0.2126 * img_copy[..., 0] + 0.7152 * img_copy[..., 1] + 0.0722 * img_copy[..., 2]
img_copy *= (original_luminance / current_luminance).unsqueeze(-1)

return img_copy
return img_copy

class BlackAndWhite():
NAME = "Black and White"
CATEGORY = "Virtuoso"
RETURN_TYPES = ("IMAGE",)
FUNCTION = "do_black_and_white"

@classmethod
def INPUT_TYPES(s) -> dict:
return {
"required": {
"image": ("IMAGE",),
"red": ("FLOAT", {
"default": 0,
"min": -1.0,
"max": 1.0,
"step": 0.01,
"round": 0.001,
"display": "number"}),
"green": ("FLOAT", {
"default": 0,
"min": -1.0,
"max": 1.0,
"step": 0.01,
"round": 0.001,
"display": "number"}),
"blue": ("FLOAT", {
"default": 0,
"min": -1.0,
"max": 1.0,
"step": 0.01,
"round": 0.001,
"display": "number"}),
"cyan": ("FLOAT", {
"default": 0,
"min": -1.0,
"max": 1.0,
"step": 0.01,
"round": 0.001,
"display": "number"}),
"magenta": ("FLOAT", {
"default": 0,
"min": -1.0,
"max": 1.0,
"step": 0.01,
"round": 0.001,
"display": "number"}),
"yellow": ("FLOAT", {
"default": 0,
"min": -1.0,
"max": 1.0,
"step": 0.01,
"round": 0.001,
"display": "number"}),
}
}

def do_black_and_white(self, image, red, green, blue, cyan, magenta, yellow):
"""
Convert a color image to black and white with adjustable color weights.
Parameters:
img (torch.Tensor): Input image tensor with shape [batch size, height, width, number of channels]
red (float): Weight for red, range -1.0 to 1.0
green (float): Weight for green, range -1.0 to 1.0
blue (float): Weight for blue, range -1.0 to 1.0
cyan (float): Weight for cyan, range -1.0 to 1.0
magenta (float): Weight for magenta, range -1.0 to 1.0
yellow (float): Weight for yellow, range -1.0 to 1.0
Returns:
torch.Tensor: Black and white image tensor with values in range 0-1
"""
# Calculate minimum color value across all color channels for each pixel
min_c, _ = image.min(dim=-1)

# Calculate differences between color channels and minimum color value
diff = image - min_c.unsqueeze(-1)

# Create masks for red, green, and blue pixels
red_mask = (diff[:, :, :, 0] == 0)
green_mask = torch.logical_and((diff[:, :, :, 1] == 0), ~red_mask)
blue_mask = ~torch.logical_or(red_mask, green_mask)

# Calculate c, m, and yel values
c, _ = diff[:, :, :, 1:].min(dim=-1)
m, _ = diff[:, :, :, [0, 2]].min(dim=-1)
yel, _ = diff[:, :, :, :2].min(dim=-1)

# Calculate luminance using vectorized operations
luminance = min_c + red_mask * (c * cyan + (diff[:, :, :, 1] - c) * green + (diff[:, :, :, 2] - c) * blue)
luminance += green_mask * (m * magenta + (diff[:, :, :, 0] - m) * red + (diff[:, :, :, 2] - m) * blue)
luminance += blue_mask * (yel * yellow + (diff[:, :, :, 0] - yel) * red + (diff[:, :, :, 1] - yel) * green)

# Clip luminance values to be between 0 and 1
return (luminance.clamp(0, 1),)

0 comments on commit e1782cf

Please sign in to comment.