Skip to content

Commit

Permalink
Fix a bug where cached outputs affected IS_CHANGED (comfyanonymous#4535)
Browse files Browse the repository at this point in the history
This change fixes a bug where non-constant values could be passed to the
IS_CHANGED function. This would result in workflows taking an extra
execution before they acted as if they were cached.

The actual change is like 4 characters -- the rest is adding unit tests.
  • Loading branch information
guill authored Aug 22, 2024
1 parent 5f84ea6 commit dafbe32
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 1 deletion.
3 changes: 2 additions & 1 deletion execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ def get(self, node_id):
self.is_changed[node_id] = node["is_changed"]
return self.is_changed[node_id]

input_data_all, _ = get_input_data(node["inputs"], class_def, node_id, self.outputs_cache)
# Intentionally do not use cached outputs here. We only want constants in IS_CHANGED
input_data_all, _ = get_input_data(node["inputs"], class_def, node_id, None)
try:
is_changed = _map_node_over_list(class_def, input_data_all, "IS_CHANGED")
node["is_changed"] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed]
Expand Down
19 changes: 19 additions & 0 deletions tests/inference/test_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,3 +459,22 @@ def test_output_reuse(self, client: ComfyClient, builder: GraphBuilder):
assert len(images1) == 1, "Should have 1 image"
assert len(images2) == 1, "Should have 1 image"


# This tests that only constant outputs are used in the call to `IS_CHANGED`
def test_is_changed_with_outputs(self, client: ComfyClient, builder: GraphBuilder):
g = builder
input1 = g.node("StubConstantImage", value=0.5, height=512, width=512, batch_size=1)
test_node = g.node("TestIsChangedWithConstants", image=input1.out(0), value=0.5)

output = g.node("PreviewImage", images=test_node.out(0))

result = client.run(g)
images = result.get_images(output)
assert len(images) == 1, "Should have 1 image"
assert numpy.array(images[0]).min() == 63 and numpy.array(images[0]).max() == 63, "Image should have value 0.25"

result = client.run(g)
images = result.get_images(output)
assert len(images) == 1, "Should have 1 image"
assert numpy.array(images[0]).min() == 63 and numpy.array(images[0]).max() == 63, "Image should have value 0.25"
assert not result.did_run(test_node), "The execution should have been cached"
27 changes: 27 additions & 0 deletions tests/inference/testing_nodes/testing-pack/specific_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,31 @@ def IS_CHANGED(cls, should_change=False, *args, **kwargs):
else:
return False

class TestIsChangedWithConstants:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"image": ("IMAGE",),
"value": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0}),
},
}

RETURN_TYPES = ("IMAGE",)
FUNCTION = "custom_is_changed"

CATEGORY = "Testing/Nodes"

def custom_is_changed(self, image, value):
return (image * value,)

@classmethod
def IS_CHANGED(cls, image, value):
if image is None:
return value
else:
return image.mean().item() * value

class TestCustomValidation1:
@classmethod
def INPUT_TYPES(cls):
Expand Down Expand Up @@ -312,6 +337,7 @@ def mixed_expansion_returns(self, input1):
"TestLazyMixImages": TestLazyMixImages,
"TestVariadicAverage": TestVariadicAverage,
"TestCustomIsChanged": TestCustomIsChanged,
"TestIsChangedWithConstants": TestIsChangedWithConstants,
"TestCustomValidation1": TestCustomValidation1,
"TestCustomValidation2": TestCustomValidation2,
"TestCustomValidation3": TestCustomValidation3,
Expand All @@ -325,6 +351,7 @@ def mixed_expansion_returns(self, input1):
"TestLazyMixImages": "Lazy Mix Images",
"TestVariadicAverage": "Variadic Average",
"TestCustomIsChanged": "Custom IsChanged",
"TestIsChangedWithConstants": "IsChanged With Constants",
"TestCustomValidation1": "Custom Validation 1",
"TestCustomValidation2": "Custom Validation 2",
"TestCustomValidation3": "Custom Validation 3",
Expand Down
24 changes: 24 additions & 0 deletions tests/inference/testing_nodes/testing-pack/stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,28 @@ def stub_image(self, content, height, width, batch_size):
elif content == "NOISE":
return (torch.rand(batch_size, height, width, 3),)

class StubConstantImage:
def __init__(self):
pass
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"value": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
"height": ("INT", {"default": 512, "min": 1, "max": 1024 ** 3, "step": 1}),
"width": ("INT", {"default": 512, "min": 1, "max": 4096 ** 3, "step": 1}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 1024 ** 3, "step": 1}),
},
}

RETURN_TYPES = ("IMAGE",)
FUNCTION = "stub_constant_image"

CATEGORY = "Testing/Stub Nodes"

def stub_constant_image(self, value, height, width, batch_size):
return (torch.ones(batch_size, height, width, 3) * value,)

class StubMask:
def __init__(self):
pass
Expand Down Expand Up @@ -93,12 +115,14 @@ def stub_float(self, value):

TEST_STUB_NODE_CLASS_MAPPINGS = {
"StubImage": StubImage,
"StubConstantImage": StubConstantImage,
"StubMask": StubMask,
"StubInt": StubInt,
"StubFloat": StubFloat,
}
TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS = {
"StubImage": "Stub Image",
"StubConstantImage": "Stub Constant Image",
"StubMask": "Stub Mask",
"StubInt": "Stub Int",
"StubFloat": "Stub Float",
Expand Down

0 comments on commit dafbe32

Please sign in to comment.