Skip to content

Commit

Permalink
Merge pull request #408 from Kosinkadink/new-hook-update
Browse files Browse the repository at this point in the history
Made lora hook patches work for newest ComfyUI
  • Loading branch information
Kosinkadink authored Jun 20, 2024
2 parents c98e8e6 + 517d62b commit 72482e7
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 6 deletions.
23 changes: 18 additions & 5 deletions animatediff/model_injection.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,16 +146,19 @@ def add_hooked_patches(self, lora_hook: LoraHook, patches, strength_patch=1.0, s
model_sd = self.model.state_dict()
for k in patches:
offset = None
function = None
if isinstance(k, str):
key = k
else:
offset = k[1]
key = k[0]
if len(k) > 2:
function = k[2]

if key in model_sd:
p.add(k)
current_patches: list[tuple] = current_hooked_patches.get(key, [])
current_patches.append((strength_patch, patches[k], strength_model, offset))
current_patches.append((strength_patch, patches[k], strength_model, offset, function))
current_hooked_patches[key] = current_patches
self.hooked_patches[lora_hook.hook_ref] = current_hooked_patches
# since should care about these patches too to determine if same model, reroll patches_uuid
Expand All @@ -171,18 +174,21 @@ def add_hooked_patches_as_diffs(self, lora_hook: LoraHook, patches: dict, streng
model_sd = self.model.state_dict()
for k in patches:
offset = None
function = None
if isinstance(k, str):
key = k
else:
offset = k[1]
key = k[0]
if len(k) > 2:
function = k[2]

if key in model_sd:
p.add(k)
current_patches: list[tuple] = current_hooked_patches.get(key, [])
# take difference between desired weight and existing weight to get diff
# TODO: create fix for fp8
current_patches.append((strength_patch, (patches[k]-comfy.utils.get_attr(self.model, key),), strength_model, offset))
current_patches.append((strength_patch, (patches[k]-comfy.utils.get_attr(self.model, key),), strength_model, offset, function))
current_hooked_patches[key] = current_patches
self.hooked_patches[lora_hook.hook_ref] = current_hooked_patches
# since should care about these patches too to determine if same model, reroll patches_uuid
Expand Down Expand Up @@ -506,23 +512,26 @@ def add_hooked_patches(self, lora_hook: LoraHook, patches, strength_patch=1.0, s
model_sd = self.model.state_dict()
for k in patches:
offset = None
function = None
if isinstance(k, str):
key = k
else:
offset = k[1]
key = k[0]
if len(k) > 2:
function = k[2]

if key in model_sd:
p.add(k)
current_patches: list[tuple] = current_hooked_patches.get(key, [])
current_patches.append((strength_patch, patches[k], strength_model, offset))
current_patches.append((strength_patch, patches[k], strength_model, offset, function))
current_hooked_patches[key] = current_patches
self.hooked_patches[lora_hook.hook_ref] = current_hooked_patches
# since should care about these patches too to determine if same model, reroll patches_uuid
self.patches_uuid = uuid.uuid4()
return list(p)

def add_hooked_patches_as_diffs(self, lora_hook: LoraHook, patches, strength_patch=1.0, strength_model=1.0):
def add_hooked_patches_as_diffs(self, lora_hook: LoraHook, patches: dict, strength_patch=1.0, strength_model=1.0):
'''
Based on add_hooked_patches, but intended for using a model's weights as lora hook.
'''
Expand All @@ -531,17 +540,21 @@ def add_hooked_patches_as_diffs(self, lora_hook: LoraHook, patches, strength_pat
model_sd = self.model.state_dict()
for k in patches:
offset = None
function = None
if isinstance(k, str):
key = k
else:
offset = k[1]
key = k[0]
if len(k) > 2:
function = k[2]

if key in model_sd:
p.add(k)
current_patches: list[tuple] = current_hooked_patches.get(key, [])
# take difference between desired weight and existing weight to get diff
current_patches.append((strength_patch, (patches[k]-comfy.utils.get_attr(self.model, key),), strength_model, offset))
# TODO: create fix for fp8
current_patches.append((strength_patch, (patches[k]-comfy.utils.get_attr(self.model, key),), strength_model, offset, function))
current_hooked_patches[key] = current_patches
self.hooked_patches[lora_hook.hook_ref] = current_hooked_patches
# since should care about these patches too to determine if same model, reroll patches_uuid
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[project]
name = "comfyui-animatediff-evolved"
description = "Improved AnimateDiff integration for ComfyUI."
version = "1.0.5"
version = "1.0.6"
license = "LICENSE"
dependencies = []

Expand Down

0 comments on commit 72482e7

Please sign in to comment.