Merge branch 'pr/1165'
gesen2egee committed Mar 16, 2024
2 parents f1a482f + b5e8045 commit 582830b
Showing 2 changed files with 16 additions and 5 deletions.
library/config_util.py (6 changes: 4 additions & 2 deletions)
@@ -491,8 +491,10 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
         else:
             subset_klass = FineTuningSubset
             dataset_klass = FineTuningDataset

-        subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets if subset_blueprint.params.is_reg is False]
+        if subset_klass == DreamBoothSubset:
+            subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets if subset_blueprint.params.is_reg is False]
+        else:
+            subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets]
         dataset = dataset_klass(subsets=subsets, is_train=False, **asdict(dataset_blueprint.params))
         val_datasets.append(dataset)
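
The new guard matters because, in the config blueprints, only the DreamBooth subset parameters define an is_reg flag; filtering fine-tuning subsets on params.is_reg would fail. A minimal, hypothetical sketch of that distinction (simplified dataclasses, not the real definitions in config_util.py):

from dataclasses import dataclass

# Hypothetical, simplified stand-ins for the blueprint parameter dataclasses.
@dataclass
class DreamBoothSubsetParams:
    image_dir: str = ""
    is_reg: bool = False          # regularization images exist only for DreamBooth subsets

@dataclass
class FineTuningSubsetParams:
    image_dir: str = ""
    metadata_file: str = ""       # no is_reg field here

db = DreamBoothSubsetParams(image_dir="reg", is_reg=True)
ft = FineTuningSubsetParams(image_dir="train", metadata_file="meta.json")
print(hasattr(db, "is_reg"))  # True  -> the is_reg filter applies
print(hasattr(ft, "is_reg"))  # False -> accessing ft.is_reg would raise AttributeError
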

library/train_util.py (15 changes: 12 additions & 3 deletions)
@@ -1819,6 +1819,7 @@ class ControlNetDataset(BaseDataset):
     def __init__(
         self,
         subsets: Sequence[ControlNetSubset],
+        is_train: bool,
         batch_size: int,
         tokenizer,
         max_token_length,
@@ -1829,6 +1830,8 @@ def __init__(
         max_bucket_reso: int,
         bucket_reso_steps: int,
         bucket_no_upscale: bool,
+        validation_split: float,
+        validation_seed: Optional[int],
         debug_dataset: float,
     ) -> None:
         super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset)
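
The two hunks above only extend the constructor signature; the hunks that follow pass the new arguments straight through to an internal DreamBoothDataset and mirror them as attributes. A self-contained toy sketch of that delegation pattern (the Tiny* classes are hypothetical stand-ins, not repository classes):

from typing import Optional

# Hypothetical miniature of the delegation used by ControlNetDataset.
class TinyDelegate:
    def __init__(self, is_train: bool, validation_split: float, validation_seed: Optional[int]):
        self.is_train = is_train
        self.validation_split = validation_split
        self.validation_seed = validation_seed

class TinyControlNetDataset:
    def __init__(self, is_train: bool, validation_split: float, validation_seed: Optional[int]):
        # forward the new arguments to the delegate that does the real work
        self.delegate = TinyDelegate(is_train, validation_split, validation_seed)
        # and keep copies on the wrapper for code that reads them directly
        self.is_train = is_train
        self.validation_split = validation_split
        self.validation_seed = validation_seed

ds = TinyControlNetDataset(is_train=False, validation_split=0.1, validation_seed=42)
print(ds.delegate.validation_split)  # 0.1
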
@@ -1863,6 +1866,7 @@ def __init__(

         self.dreambooth_dataset_delegate = DreamBoothDataset(
             db_subsets,
+            is_train,
             batch_size,
             tokenizer,
             max_token_length,
@@ -1874,14 +1878,19 @@ def __init__(
             bucket_reso_steps,
             bucket_no_upscale,
             1.0,
+            validation_split,
+            validation_seed,
             debug_dataset,
         )

         # keep values that are referenced from config_util etc. (a bit awkward, should be improved)
         self.image_data = self.dreambooth_dataset_delegate.image_data
         self.batch_size = batch_size
         self.num_train_images = self.dreambooth_dataset_delegate.num_train_images
-        self.num_reg_images = self.dreambooth_dataset_delegate.num_reg_images
+        self.num_reg_images = self.dreambooth_dataset_delegate.num_reg_images
+        self.is_train = is_train
+        self.validation_split = validation_split
+        self.validation_seed = validation_seed

         # assert all conditioning data exists
         missing_imgs = []
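
ControlNetDataset itself only forwards validation_split and validation_seed, so the actual train/validation partition presumably happens inside DreamBoothDataset. A minimal, hypothetical sketch of such a seeded split (illustrative only, not the repository's implementation):

import random
from typing import List, Optional, Tuple

def split_train_val(paths: List[str], validation_split: float, validation_seed: Optional[int]) -> Tuple[List[str], List[str]]:
    # Hypothetical helper: seeded shuffle makes the split reproducible across runs.
    paths = sorted(paths)
    random.Random(validation_seed).shuffle(paths)
    n_val = int(len(paths) * validation_split)
    if n_val == 0:
        return paths, []
    return paths[:-n_val], paths[-n_val:]

train, val = split_train_val([f"img_{i}.png" for i in range(10)], validation_split=0.2, validation_seed=42)
print(len(train), len(val))  # 8 2
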
@@ -1914,8 +1923,8 @@ def __init__(
                    [cond_img_path for cond_img_path in conditioning_img_paths if cond_img_path not in cond_imgs_with_img]
                )

-        assert len(missing_imgs) == 0, f"missing conditioning data for {len(missing_imgs)} images: {missing_imgs}"
-        assert len(extra_imgs) == 0, f"extra conditioning data for {len(extra_imgs)} images: {extra_imgs}"
+        #assert len(missing_imgs) == 0, f"missing conditioning data for {len(missing_imgs)} images: {missing_imgs}"
+        #assert len(extra_imgs) == 0, f"extra conditioning data for {len(extra_imgs)} images: {extra_imgs}"

         self.conditioning_image_transforms = IMAGE_TRANSFORMS
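
With the asserts commented out, images that lack conditioning data (or conditioning images without a matching training image) no longer abort dataset construction; they are silently tolerated. A hypothetical softer middle ground, reusing the missing_imgs and extra_imgs lists built above, would be to warn instead of failing:

import logging

logger = logging.getLogger(__name__)

# Hypothetical alternative to removing the checks: report mismatches without aborting.
if len(missing_imgs) > 0:
    logger.warning(f"missing conditioning data for {len(missing_imgs)} images: {missing_imgs}")
if len(extra_imgs) > 0:
    logger.warning(f"extra conditioning data for {len(extra_imgs)} images: {extra_imgs}")
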

