Merge pull request #14547 from AUTOMATIC1111/lyco-forward

Implement general forward method for all method in built-in lora ext
This commit is contained in:
AUTOMATIC1111
2024-01-06 10:50:06 +03:00
committed by GitHub
2 changed files with 38 additions and 7 deletions
+32 -1
View File
@@ -3,6 +3,9 @@ import os
from collections import namedtuple from collections import namedtuple
import enum import enum
import torch.nn as nn
import torch.nn.functional as F
from modules import sd_models, cache, errors, hashes, shared from modules import sd_models, cache, errors, hashes, shared
NetworkWeights = namedtuple('NetworkWeights', ['network_key', 'sd_key', 'w', 'sd_module']) NetworkWeights = namedtuple('NetworkWeights', ['network_key', 'sd_key', 'w', 'sd_module'])
@@ -115,6 +118,29 @@ class NetworkModule:
if hasattr(self.sd_module, 'weight'): if hasattr(self.sd_module, 'weight'):
self.shape = self.sd_module.weight.shape self.shape = self.sd_module.weight.shape
self.ops = None
self.extra_kwargs = {}
if isinstance(self.sd_module, nn.Conv2d):
self.ops = F.conv2d
self.extra_kwargs = {
'stride': self.sd_module.stride,
'padding': self.sd_module.padding
}
elif isinstance(self.sd_module, nn.Linear):
self.ops = F.linear
elif isinstance(self.sd_module, nn.LayerNorm):
self.ops = F.layer_norm
self.extra_kwargs = {
'normalized_shape': self.sd_module.normalized_shape,
'eps': self.sd_module.eps
}
elif isinstance(self.sd_module, nn.GroupNorm):
self.ops = F.group_norm
self.extra_kwargs = {
'num_groups': self.sd_module.num_groups,
'eps': self.sd_module.eps
}
self.dim = None self.dim = None
self.bias = weights.w.get("bias") self.bias = weights.w.get("bias")
self.alpha = weights.w["alpha"].item() if "alpha" in weights.w else None self.alpha = weights.w["alpha"].item() if "alpha" in weights.w else None
@@ -155,5 +181,10 @@ class NetworkModule:
raise NotImplementedError() raise NotImplementedError()
def forward(self, x, y): def forward(self, x, y):
raise NotImplementedError() """A general forward implementation for all modules"""
if self.ops is None:
raise NotImplementedError()
else:
updown, ex_bias = self.calc_updown(self.sd_module.weight)
return y + self.ops(x, weight=updown, bias=ex_bias, **self.extra_kwargs)
+6 -6
View File
@@ -458,23 +458,23 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
self.network_current_names = wanted_names self.network_current_names = wanted_names
def network_forward(module, input, original_forward): def network_forward(org_module, input, original_forward):
""" """
Old way of applying Lora by executing operations during layer's forward. Old way of applying Lora by executing operations during layer's forward.
Stacking many loras this way results in big performance degradation. Stacking many loras this way results in big performance degradation.
""" """
if len(loaded_networks) == 0: if len(loaded_networks) == 0:
return original_forward(module, input) return original_forward(org_module, input)
input = devices.cond_cast_unet(input) input = devices.cond_cast_unet(input)
network_restore_weights_from_backup(module) network_restore_weights_from_backup(org_module)
network_reset_cached_weight(module) network_reset_cached_weight(org_module)
y = original_forward(module, input) y = original_forward(org_module, input)
network_layer_name = getattr(module, 'network_layer_name', None) network_layer_name = getattr(org_module, 'network_layer_name', None)
for lora in loaded_networks: for lora in loaded_networks:
module = lora.modules.get(network_layer_name, None) module = lora.modules.get(network_layer_name, None)
if module is None: if module is None: