diff --git a/library/ipex/__init__.py b/library/ipex/__init__.py index 19ec8eea1..43accd9f3 100644 --- a/library/ipex/__init__.py +++ b/library/ipex/__init__.py @@ -144,7 +144,7 @@ def ipex_init(): # pylint: disable=too-many-statements ipex._C._DeviceProperties.minor = 2 #Fix functions with ipex: - torch.cuda.mem_get_info = lambda device=None: [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_allocated(device)), torch.xpu.get_device_properties(device).total_memory] + torch.cuda.mem_get_info = lambda device=None: [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_reserved(device)), torch.xpu.get_device_properties(device).total_memory] torch._utils._get_available_device_type = lambda: "xpu" torch.has_cuda = True torch.cuda.has_half = True @@ -156,7 +156,6 @@ def ipex_init(): # pylint: disable=too-many-statements torch.cuda.get_device_properties.minor = 7 torch.cuda.ipc_collect = lambda *args, **kwargs: None torch.cuda.utilization = lambda *args, **kwargs: 0 - # getDeviceIdListForCard is renamed since https://github.com/intel/intel-extension-for-pytorch/commit/835b41fd5c8b6facf9efee8312f20699850ee592 if hasattr(torch.xpu, 'getDeviceIdListForCard'): torch.cuda.getDeviceIdListForCard = torch.xpu.getDeviceIdListForCard torch.cuda.get_device_id_list_per_card = torch.xpu.getDeviceIdListForCard diff --git a/library/ipex/attention.py b/library/ipex/attention.py index e38689f21..84848b6a6 100644 --- a/library/ipex/attention.py +++ b/library/ipex/attention.py @@ -10,13 +10,15 @@ def torch_bmm(input, mat2, *, out=None): #ARC GPUs can't allocate more than 4GB to a single block, Slice it: batch_size_attention, input_tokens, mat2_shape = input.shape[0], input.shape[1], mat2.shape[2] - block_multiply = 2.4 if input.dtype == torch.float32 else 1.2 - block_size = (batch_size_attention * input_tokens * mat2_shape) / 1024 * block_multiply #MB + block_multiply = input.element_size() + slice_block_size = input_tokens * mat2_shape / 1024 / 1024 * block_multiply + block_size = batch_size_attention * slice_block_size + split_slice_size = batch_size_attention - if block_size >= 4000: + if block_size > 4: do_split = True #Find something divisible with the input_tokens - while ((split_slice_size * input_tokens * mat2_shape) / 1024 * block_multiply) > 4000: + while (split_slice_size * slice_block_size) > 4: split_slice_size = split_slice_size // 2 if split_slice_size <= 1: split_slice_size = 1 @@ -24,12 +26,12 @@ def torch_bmm(input, mat2, *, out=None): else: do_split = False - split_block_size = (split_slice_size * input_tokens * mat2_shape) / 1024 * block_multiply #MB split_2_slice_size = input_tokens - if split_block_size >= 4000: + if split_slice_size * slice_block_size > 4: + slice_block_size2 = split_slice_size * mat2_shape / 1024 / 1024 * block_multiply do_split_2 = True #Find something divisible with the input_tokens - while ((split_slice_size * split_2_slice_size * mat2_shape) / 1024 * block_multiply) > 4000: + while (split_2_slice_size * slice_block_size2) > 4: split_2_slice_size = split_2_slice_size // 2 if split_2_slice_size <= 1: split_2_slice_size = 1 @@ -71,13 +73,16 @@ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0. else: shape_one, batch_size_attention, query_tokens, shape_four = query.shape no_shape_one = False - block_multiply = 3.6 if query.dtype == torch.float32 else 1.8 - block_size = (shape_one * batch_size_attention * query_tokens * shape_four) / 1024 * block_multiply #MB + + block_multiply = query.element_size() + slice_block_size = shape_one * query_tokens * shape_four / 1024 / 1024 * block_multiply + block_size = batch_size_attention * slice_block_size + split_slice_size = batch_size_attention - if block_size >= 4000: + if block_size > 4: do_split = True #Find something divisible with the shape_one - while ((shape_one * split_slice_size * query_tokens * shape_four) / 1024 * block_multiply) > 4000: + while (split_slice_size * slice_block_size) > 4: split_slice_size = split_slice_size // 2 if split_slice_size <= 1: split_slice_size = 1 @@ -85,12 +90,12 @@ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0. else: do_split = False - split_block_size = (shape_one * split_slice_size * query_tokens * shape_four) / 1024 * block_multiply #MB split_2_slice_size = query_tokens - if split_block_size >= 4000: + if split_slice_size * slice_block_size > 4: + slice_block_size2 = shape_one * split_slice_size * shape_four / 1024 / 1024 * block_multiply do_split_2 = True #Find something divisible with the batch_size_attention - while ((shape_one * split_slice_size * split_2_slice_size * shape_four) / 1024 * block_multiply) > 4000: + while (split_2_slice_size * slice_block_size2) > 4: split_2_slice_size = split_2_slice_size // 2 if split_2_slice_size <= 1: split_2_slice_size = 1 diff --git a/library/ipex/diffusers.py b/library/ipex/diffusers.py index 4c39896ed..005ee49f0 100644 --- a/library/ipex/diffusers.py +++ b/library/ipex/diffusers.py @@ -55,13 +55,14 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a ) #ARC GPUs can't allocate more than 4GB to a single block, Slice it: - block_multiply = 2.4 if query.dtype == torch.float32 else 1.2 - block_size = (batch_size_attention * query_tokens * shape_three) / 1024 * block_multiply #MB + block_multiply = query.element_size() + slice_block_size = self.slice_size * shape_three / 1024 / 1024 * block_multiply + block_size = query_tokens * slice_block_size split_2_slice_size = query_tokens - if block_size >= 4000: + if block_size > 4: do_split_2 = True #Find something divisible with the query_tokens - while ((self.slice_size * split_2_slice_size * shape_three) / 1024 * block_multiply) > 4000: + while (split_2_slice_size * slice_block_size) > 4: split_2_slice_size = split_2_slice_size // 2 if split_2_slice_size <= 1: split_2_slice_size = 1 diff --git a/networks/merge_lora.py b/networks/merge_lora.py index c8d743f56..71492621e 100644 --- a/networks/merge_lora.py +++ b/networks/merge_lora.py @@ -110,7 +110,7 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype): module.weight = torch.nn.Parameter(weight) -def merge_lora_models(models, ratios, merge_dtype): +def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): base_alphas = {} # alpha for merged model base_dims = {} @@ -158,6 +158,12 @@ def merge_lora_models(models, ratios, merge_dtype): for key in lora_sd.keys(): if "alpha" in key: continue + if "lora_up" in key and concat: + concat_dim = 1 + elif "lora_down" in key and concat: + concat_dim = 0 + else: + concat_dim = None lora_module_name = key[: key.rfind(".lora_")] @@ -165,12 +171,16 @@ def merge_lora_models(models, ratios, merge_dtype): alpha = alphas[lora_module_name] scale = math.sqrt(alpha / base_alpha) * ratio + scale = abs(scale) if "lora_up" in key else scale # マイナスの重みに対応する。 if key in merged_sd: assert ( - merged_sd[key].size() == lora_sd[key].size() + merged_sd[key].size() == lora_sd[key].size() or concat_dim is not None ), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません" - merged_sd[key] = merged_sd[key] + lora_sd[key] * scale + if concat_dim is not None: + merged_sd[key] = torch.cat([merged_sd[key], lora_sd[key] * scale], dim=concat_dim) + else: + merged_sd[key] = merged_sd[key] + lora_sd[key] * scale else: merged_sd[key] = lora_sd[key] * scale @@ -178,6 +188,13 @@ def merge_lora_models(models, ratios, merge_dtype): for lora_module_name, alpha in base_alphas.items(): key = lora_module_name + ".alpha" merged_sd[key] = torch.tensor(alpha) + if shuffle: + key_down = lora_module_name + ".lora_down.weight" + key_up = lora_module_name + ".lora_up.weight" + dim = merged_sd[key_down].shape[0] + perm = torch.randperm(dim) + merged_sd[key_down] = merged_sd[key_down][perm] + merged_sd[key_up] = merged_sd[key_up][:,perm] print("merged model") print(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}") @@ -256,7 +273,7 @@ def str_to_dtype(p): args.v2, args.save_to, text_encoder, unet, args.sd_model, 0, 0, sai_metadata, save_dtype, vae ) else: - state_dict, metadata, v2 = merge_lora_models(args.models, args.ratios, merge_dtype) + state_dict, metadata, v2 = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle) print(f"calculating hashes and creating metadata...") @@ -317,7 +334,19 @@ def setup_parser() -> argparse.ArgumentParser: help="do not save sai modelspec metadata (minimum ss_metadata for LoRA is saved) / " + "sai modelspecのメタデータを保存しない(LoRAの最低限のss_metadataは保存される)", ) - + parser.add_argument( + "--concat", + action="store_true", + help="concat lora instead of merge (The dim(rank) of the output LoRA is the sum of the input dims) / " + + "マージの代わりに結合する(LoRAのdim(rank)は入力dimの合計になる)", + ) + parser.add_argument( + "--shuffle", + action="store_true", + help="shuffle lora weight./ " + + "LoRAの重みをシャッフルする", + ) + return parser diff --git a/networks/sdxl_merge_lora.py b/networks/sdxl_merge_lora.py index 0608c01f9..c513eb59f 100644 --- a/networks/sdxl_merge_lora.py +++ b/networks/sdxl_merge_lora.py @@ -113,7 +113,7 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_ module.weight = torch.nn.Parameter(weight) -def merge_lora_models(models, ratios, merge_dtype): +def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False): base_alphas = {} # alpha for merged model base_dims = {} @@ -161,6 +161,13 @@ def merge_lora_models(models, ratios, merge_dtype): for key in tqdm(lora_sd.keys()): if "alpha" in key: continue + + if "lora_up" in key and concat: + concat_dim = 1 + elif "lora_down" in key and concat: + concat_dim = 0 + else: + concat_dim = None lora_module_name = key[: key.rfind(".lora_")] @@ -168,12 +175,16 @@ def merge_lora_models(models, ratios, merge_dtype): alpha = alphas[lora_module_name] scale = math.sqrt(alpha / base_alpha) * ratio - + scale = abs(scale) if "lora_up" in key else scale # マイナスの重みに対応する。 + if key in merged_sd: assert ( - merged_sd[key].size() == lora_sd[key].size() + merged_sd[key].size() == lora_sd[key].size() or concat_dim is not None ), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません" - merged_sd[key] = merged_sd[key] + lora_sd[key] * scale + if concat_dim is not None: + merged_sd[key] = torch.cat([merged_sd[key], lora_sd[key] * scale], dim=concat_dim) + else: + merged_sd[key] = merged_sd[key] + lora_sd[key] * scale else: merged_sd[key] = lora_sd[key] * scale @@ -181,6 +192,13 @@ def merge_lora_models(models, ratios, merge_dtype): for lora_module_name, alpha in base_alphas.items(): key = lora_module_name + ".alpha" merged_sd[key] = torch.tensor(alpha) + if shuffle: + key_down = lora_module_name + ".lora_down.weight" + key_up = lora_module_name + ".lora_up.weight" + dim = merged_sd[key_down].shape[0] + perm = torch.randperm(dim) + merged_sd[key_down] = merged_sd[key_down][perm] + merged_sd[key_up] = merged_sd[key_up][:,perm] print("merged model") print(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}") @@ -252,7 +270,7 @@ def str_to_dtype(p): args.save_to, text_model1, text_model2, unet, 0, 0, ckpt_info, vae, logit_scale, sai_metadata, save_dtype ) else: - state_dict, metadata = merge_lora_models(args.models, args.ratios, merge_dtype) + state_dict, metadata = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle) print(f"calculating hashes and creating metadata...") @@ -307,6 +325,18 @@ def setup_parser() -> argparse.ArgumentParser: help="do not save sai modelspec metadata (minimum ss_metadata for LoRA is saved) / " + "sai modelspecのメタデータを保存しない(LoRAの最低限のss_metadataは保存される)", ) + parser.add_argument( + "--concat", + action="store_true", + help="concat lora instead of merge (The dim(rank) of the output LoRA is the sum of the input dims) / " + + "マージの代わりに結合する(LoRAのdim(rank)は入力dimの合計になる)", + ) + parser.add_argument( + "--shuffle", + action="store_true", + help="shuffle lora weight./ " + + "LoRAの重みをシャッフルする", + ) return parser