Text Encoder cache (WIP)
kohya-ss committed Nov 27, 2024 · 1 parent bdac55e · commit 3677094
Showing 15 changed files with 637 additions and 480 deletions.
20 changes: 15 additions & 5 deletions flux_train.py
@@ -151,15 +151,20 @@ def train(args):
 
     _, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path)
     if args.debug_dataset:
+        t5xxl_max_token_length = (
+            args.t5xxl_max_token_length if args.t5xxl_max_token_length is not None else (256 if is_schnell else 512)
+        )
         if args.cache_text_encoder_outputs:
             strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(
                 strategy_flux.FluxTextEncoderOutputsCachingStrategy(
-                    args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, args.skip_cache_check, False
+                    args.cache_text_encoder_outputs_to_disk,
+                    args.text_encoder_batch_size,
+                    args.skip_cache_check,
+                    t5xxl_max_token_length,
+                    args.apply_t5_attn_mask,
+                    False,
                 )
             )
-        t5xxl_max_token_length = (
-            args.t5xxl_max_token_length if args.t5xxl_max_token_length is not None else (256 if is_schnell else 512)
-        )
         strategy_base.TokenizeStrategy.set_strategy(strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length))
 
         train_dataset_group.set_current_strategies()
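
The token-length computation moves above the caching-strategy setup because the updated constructor now consumes it as its fourth argument. The default resolution itself is unchanged: 256 tokens for FLUX.1 schnell, 512 otherwise. As a standalone sketch of that logic (the helper name is hypothetical, not part of the commit):

    def resolve_t5xxl_max_token_length(cli_value, is_schnell):
        # An explicit --t5xxl_max_token_length always wins; otherwise fall back
        # to the model default (256 for FLUX.1 schnell, 512 for FLUX.1 dev).
        return cli_value if cli_value is not None else (256 if is_schnell else 512)

    assert resolve_t5xxl_max_token_length(None, is_schnell=True) == 256
    assert resolve_t5xxl_max_token_length(1024, is_schnell=False) == 1024
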
@@ -236,7 +241,12 @@ def train(args):
         t5xxl.to(accelerator.device)
 
         text_encoder_caching_strategy = strategy_flux.FluxTextEncoderOutputsCachingStrategy(
-            args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, False, False, args.apply_t5_attn_mask
+            args.cache_text_encoder_outputs_to_disk,
+            args.text_encoder_batch_size,
+            args.skip_cache_check,
+            t5xxl_max_token_length,
+            args.apply_t5_attn_mask,
+            False,
         )
         strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_caching_strategy)
 
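Both call sites above now pass six arguments where the old code passed four or five, and the second call site now honors --skip_cache_check instead of hard-coding False. The FluxTextEncoderOutputsCachingStrategy constructor evidently gained a t5xxl_max_token_length parameter and takes apply_t5_attn_mask ahead of is_partial. The class itself is defined in library/strategy_flux.py, which this excerpt does not show; a minimal sketch of the signature as inferred from the call sites (attribute names are assumptions):

    class FluxTextEncoderOutputsCachingStrategy:
        # Inferred sketch only; the real class lives in library/strategy_flux.py
        # and presumably extends strategy_base.TextEncoderOutputsCachingStrategy.
        def __init__(
            self,
            cache_to_disk: bool,
            batch_size: int,
            skip_cache_check: bool,
            t5xxl_max_token_length: int,
            apply_t5_attn_mask: bool,
            is_partial: bool = False,
        ):
            self.cache_to_disk = cache_to_disk
            self.batch_size = batch_size
            self.skip_cache_check = skip_cache_check
            self.t5xxl_max_token_length = t5xxl_max_token_length
            self.apply_t5_attn_mask = apply_t5_attn_mask
            self.is_partial = is_partial

Carrying the token length inside the strategy presumably lets cached text-encoder outputs be validated against the length they were produced with, in line with the WIP caching work this commit belongs to.
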
11 changes: 8 additions & 3 deletions flux_train_network.py
@@ -10,15 +10,16 @@
 
 init_ipex()
 
-from library import flux_models, flux_train_utils, flux_utils, sd3_train_utils, strategy_base, strategy_flux, train_util
-import train_network
 from library.utils import setup_logging
 
 setup_logging()
 import logging
 
 logger = logging.getLogger(__name__)
 
+from library import flux_models, flux_train_utils, flux_utils, sd3_train_utils, strategy_base, strategy_flux, train_util
+import train_network
+
 
 class FluxNetworkTrainer(train_network.NetworkTrainer):
     def __init__(self):
@@ -174,13 +175,17 @@ def get_text_encoders_train_flags(self, args, text_encoders):
 
     def get_text_encoder_outputs_caching_strategy(self, args):
         if args.cache_text_encoder_outputs:
+            fluxTokenizeStrategy: strategy_flux.FluxTokenizeStrategy = strategy_base.TokenizeStrategy.get_strategy()
+            t5xxl_max_token_length = fluxTokenizeStrategy.t5xxl_max_length
+
             # if the text encoders are trained, we need tokenization, so is_partial is True
             return strategy_flux.FluxTextEncoderOutputsCachingStrategy(
                 args.cache_text_encoder_outputs_to_disk,
                 args.text_encoder_batch_size,
                 args.skip_cache_check,
+                t5xxl_max_token_length,
+                args.apply_t5_attn_mask,
                 is_partial=self.train_clip_l or self.train_t5xxl,
-                apply_t5_attn_mask=args.apply_t5_attn_mask,
             )
         else:
             return None
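
The get_strategy() call works because FluxTokenizeStrategy was registered earlier via TokenizeStrategy.set_strategy(...) (see the flux_train.py hunk above), so the caching strategy reuses the same t5xxl_max_length the tokenizer was configured with rather than re-deriving it from args. The registry appears to be a simple class-level singleton; a minimal sketch of that pattern under the names seen in this diff (the real implementation is in library/strategy_base.py, not shown here):

    class TokenizeStrategy:
        _strategy = None  # class-level singleton slot

        @classmethod
        def set_strategy(cls, strategy):
            cls._strategy = strategy

        @classmethod
        def get_strategy(cls):
            return cls._strategy

    class FluxTokenizeStrategy(TokenizeStrategy):
        def __init__(self, t5xxl_max_token_length):
            # Stored under the attribute name the caching code reads back above.
            self.t5xxl_max_length = t5xxl_max_token_length
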
(Diffs for the remaining 13 changed files are not rendered in this excerpt.)
