Skip to content

Commit

Permalink
More fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
huchenlei committed May 2, 2024
1 parent 1359e10 commit 2b08015
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
global_state,
external_code,
)
from lib_controlnet.external_code import ControlNetUnit
from lib_controlnet.logging import logger
from lib_controlnet.controlnet_ui.openpose_editor import OpenposeEditor
from lib_controlnet.controlnet_ui.preset import ControlNetPresetUI
Expand All @@ -21,7 +22,6 @@
from modules import shared, script_callbacks
from modules.ui_components import FormRow
from modules_forge.forge_util import HWC3
from lib_controlnet.external_code import UiControlNetUnit


@dataclass
Expand Down Expand Up @@ -172,10 +172,10 @@ def __init__(
self.webcam_mirrored = False

# Note: All gradio elements declared in `render` will be defined as member variable.
# Update counter to trigger a force update of UiControlNetUnit.
# Update counter to trigger a force update of ControlNetUnit.
# dummy_gradio_update_trigger is useful when a field with no event subscriber available changes.
# e.g. gr.Gallery, gr.State, etc. After an update to gr.State / gr.Gallery, please increment
# this counter to trigger a sync update of UiControlNetUnit.
# this counter to trigger a sync update of ControlNetUnit.
self.dummy_gradio_update_trigger = None
self.enabled = None
self.upload_tab = None
Expand Down Expand Up @@ -610,6 +610,12 @@ def render(self, tabname: str, elem_id_tabname: str) -> None:
)

unit = gr.State(self.default_unit)
def create_unit(*args):
return ControlNetUnit.from_dict({
k: v
for k, v in zip(vars(ControlNetUnit()).keys(), args)
})

for comp in unit_args + (self.dummy_gradio_update_trigger,):
event_subscribers = []
if hasattr(comp, "edit"):
Expand All @@ -626,15 +632,15 @@ def render(self, tabname: str, elem_id_tabname: str) -> None:

for event_subscriber in event_subscribers:
event_subscriber(
fn=UiControlNetUnit, inputs=list(unit_args), outputs=unit
fn=create_unit, inputs=list(unit_args), outputs=unit
)

(
ControlNetUiGroup.a1111_context.img2img_submit_button
if self.is_img2img
else ControlNetUiGroup.a1111_context.txt2img_submit_button
).click(
fn=UiControlNetUnit,
fn=create_unit,
inputs=list(unit_args),
outputs=unit,
queue=False,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def try_crop_image_with_a1111_mask(
def get_input_data(self, p, unit, preprocessor, h, w):
logger.info(f'ControlNet Input Mode: {unit.input_mode}')
image_list = []
resize_mode = external_code.resize_mode_from_value(unit.resize_mode)
resize_mode = unit.resize_mode

if unit.input_mode == InputMode.MERGE:
for idx, item in enumerate(unit.batch_input_gallery):
Expand Down

0 comments on commit 2b08015

Please sign in to comment.