Add validation loss #914

Open: rockerBOO wants to merge 8 commits into base: dev

rockerBOO
Contributor

@rockerBOO rockerBOO commented Oct 30, 2023

  • Add a --validation_ratio parameter that splits the training dataset into training and validation sets by the given ratio. Defaults to 0.0 (no split).
  • Add a loss/val entry to the logs.

[Screenshot 2023-10-30 at 11-31-41: Weights & Biases]

[Screenshot 2023-10-30 at 11-35-44: Weights & Biases]

@kohya-ss I want to propose these changes. Before I spend too much time on them, I wanted to get your feedback. I copied the training step code into the validation step, which duplicates it. If you accept this change, how would you like that duplication handled? I am willing to make the changes. Thank you!

See #193 for more discussion

@kohya-ss
Owner

Thank you so much for your PR to our long pending project! Your implementation is simple and well done.

But unfortunately this may not work well. The reason is in the implementation of the dataset: the dataset does its own shuffling, so even if we specify the same index, it will not return the same data.

Aspect Ratio Bucketing is the reason for this implementation.
First, because of ARB, it is not possible to assemble a batch with a DataLoader (because each image size is different). Therefore, the batch is assembled in a dataset, and the DataLoader's batch_size is set to 1.
Also, each bucket is shuffled at each epoch. Therefore, even if the same index is specified in the dataset, different data will be returned.

This means that a validation dataset cannot simply be created as a split of the original dataset.
I think the split has to happen inside the dataset's data loading process, at the point where the image data is read.
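To make the index problem concrete, here is a toy sketch of a dataset that assembles its own batches per bucket and reshuffles every epoch. ToyBucketDataset and shuffle_buckets are hypothetical names for illustration only, not the actual classes or methods in train_util.py.

```python
import random


class ToyBucketDataset:
    """Toy stand-in for a dataset that builds its own batches per resolution bucket."""

    def __init__(self, images_by_bucket, batch_size, seed=0):
        self.images_by_bucket = images_by_bucket  # {(width, height): [image names]}
        self.batch_size = batch_size
        self.rng = random.Random(seed)
        self.batches = []
        self.shuffle_buckets()

    def shuffle_buckets(self):
        # Called once per epoch: shuffle inside each bucket, then rebuild the batch list.
        self.batches = []
        for reso, names in self.images_by_bucket.items():
            names = list(names)
            self.rng.shuffle(names)
            for start in range(0, len(names), self.batch_size):
                self.batches.append((reso, names[start:start + self.batch_size]))
        self.rng.shuffle(self.batches)

    def __len__(self):
        return len(self.batches)

    def __getitem__(self, index):
        # One index is one pre-assembled batch, so the DataLoader batch_size stays 1.
        return self.batches[index]


dataset = ToyBucketDataset(
    {(512, 768): [f"tall_{i}" for i in range(8)], (768, 512): [f"wide_{i}" for i in range(8)]},
    batch_size=2,
)
epoch_1 = [dataset[i] for i in range(len(dataset))]
dataset.shuffle_buckets()  # next epoch
epoch_2 = [dataset[i] for i in range(len(dataset))]
print(epoch_1[0], epoch_2[0])  # the same index usually holds different images after the reshuffle
```

Because an index only identifies a slot in the current epoch's batch list, fixing a set of indices does not fix a set of images.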

I also think that regularization images need to be considered. Regularization images may need to be removed from the validation dataset.

However, I think it is possible to add the validation loss if we modify the dataset code...

That is my understanding, please let me know what you think.

@rockerBOO
Contributor Author

rockerBOO commented Oct 31, 2023

Thank you Kohya :) I thought that it might be problematic to split the dataset but the code ran so I got a little excited.

Originally I had the following code that allowed me to pass in a validation dataset.

  • args.validation_dataset: dataset path
  • args.validation_metadata_file: metadata.json, in the form { "filename.jpg": { "caption": "the image caption" }, ... }

class ValidationDataset(train_util.FineTuningDataset):
    def __init__(self, image_dir, metadata_file, tokenizer, debug_dataset=False):
        # A single subset with no repeats, augmentation, or caption dropout, so the validation data stays fixed.
        validation_subset = train_util.FineTuningSubset(image_dir=image_dir, metadata_file=metadata_file, num_repeats=1, shuffle_caption=False, keep_tokens=0, color_aug=None, flip_aug=None, face_crop_aug_range=None, random_crop=None, caption_dropout_rate=0.0, caption_dropout_every_n_epochs=0.0, caption_tag_dropout_rate=0.0, caption_prefix="", caption_suffix="", token_warmup_min=0, token_warmup_step=0)
        # Fixed resolution with bucketing disabled.
        super().__init__(subsets=[validation_subset], batch_size=1, tokenizer=tokenizer, max_token_length=75, resolution=[768,448], enable_bucket=False, min_bucket_reso=256, max_bucket_reso=1024, bucket_reso_steps=64, bucket_no_upscale=False, debug_dataset=debug_dataset)

# tokenizer is the one already loaded for training.
val_dataset = ValidationDataset(args.validation_dataset, args.validation_metadata_file, tokenizer, debug_dataset=args.debug_dataset)
val_dataset.make_buckets()
val_dataset.set_seed(args.seed)

It's really rough, but is this a direction you think we should go in? Mostly, we could start with a separate validation dataset. I am using the FineTuningDataset option (which requires a metadata file), but the DreamBoothDataset option seemed easier since we could just pass an image directory. I'm only reading the surface of these dataset classes at the moment, but I wanted to get your feedback first.

@kohya-ss
Owner

kohya-ss commented Nov 1, 2023

Thank you for your comment!

Your original code appears to be a solid approach, and I believe it will work well. However, preparing new metadata (or an image directory), even if it's just a copy of the one for training, might be a bit cumbersome for users.

Therefore, your new approach to splitting the dataset might be more desirable.

I took a quick look at the code, and I think we can split the image files into training and validation sets in the following section for the DreamBooth method:

img_paths = glob_images(subset.image_dir, "*")

Similarly, it seems possible to split the metadata in this section for Finetuning:
https://github.com/kohya-ss/sd-scripts/blob/2a23713f71628b2d1b88a51035b3e4ee2b5dbe46/library/train_util.py#L1515C1-L1522C1

In the Dataset constructor's arguments, I believe we can pass whether it's a training set or a validation set, as well as the proportion of the validation set.

However, I can't be sure if it will work correctly unless I read it a little more closely.

Unfortunately, I don't have time today, but I'll take a closer look tomorrow.

@rockerBOO
Contributor Author

rockerBOO commented Nov 2, 2023

I was considering your proposal to do the split inside the dataset when we initialize it. Maybe we'd want to do the split before we __init__ the class? The problem I see is that once we are inside the dataset class, we can't properly create 2 different datasets. We could set the dataset type from the outside, but we'd still need to split the data up beforehand.

Maybe we could use PyTorch's Subset, which holds a list of indices into the parent dataset. We could then iterate through the subsets.

The only problem with that approach would be the reg images, at least with the random_split function.

But we could set up our own Subset with our own indices and work around the reg images in the "validation set". Then we can keep everything else the same, if I am thinking about that correctly.

Will wait for your considerations. Thank you :)

@rockerBOO
Contributor Author

rockerBOO commented Nov 2, 2023

Additionally this may allow us to have the "validation" dataset be a [[dataset.subset]] where we don't split it but the user can pass it (maybe something like is_validation = true). Then we can create the 2 torch.utils.data.Subset (train, validation) from that.

@rockerBOO
Contributor Author

For example, I have the following code. It is a version of PyTorch's random_split, modified to draw only from the viable (non-regularization) indices.

import math
import random
import warnings
from itertools import accumulate  # public equivalent of the private torch._utils._accumulate
from typing import List, Sequence

import torch
from torch.utils.data import Subset


def get_indices_without_reg(dataset: torch.utils.data.Dataset) -> List[int]:
    # Keep only the positions of non-regularization images.
    return [i for i, item in enumerate(dataset.image_data.values()) if not item.is_reg]


# Adapted from https://pytorch.org/docs/stable/_modules/torch/utils/data/dataset.html#random_split
def random_split(dataset: torch.utils.data.Dataset, lengths: Sequence[float]) -> List[Subset]:
    indices = get_indices_without_reg(dataset)
    random.shuffle(indices)  # uses the global seed set via random.seed

    # Convert each fraction into an item count, rounding down.
    subset_lengths = []
    for i, frac in enumerate(lengths):
        if frac < 0 or frac > 1:
            raise ValueError(f"Fraction at index {i} is not between 0 and 1")
        subset_lengths.append(math.floor(len(indices) * frac))

    # Hand out the items lost to rounding, one per split in round-robin order.
    remainder = len(indices) - sum(subset_lengths)
    for i in range(remainder):
        subset_lengths[i % len(subset_lengths)] += 1

    for i, length in enumerate(subset_lengths):
        if length == 0:
            warnings.warn(f"Length of split at index {i} is 0. This might result in an empty dataset.")

    if sum(subset_lengths) != len(indices):
        raise ValueError("Sum of input lengths does not equal the length of the input dataset!")

    return [Subset(dataset, indices[offset - length : offset]) for offset, length in zip(accumulate(subset_lengths), subset_lengths)]

It would then be called the same way to split the dataset (ignoring the reg images):

train_ratio = 1 - args.validation_ratio
validation_ratio = args.validation_ratio
train, val = random_split(
    train_dataset_group,
    [train_ratio, validation_ratio]
)

Note that the random source differs: this version uses Python's random.shuffle, whereas PyTorch's random_split uses randperm. The split is still reproducible with a seed, since random.seed is being set.

Example indices:

print(train.indices)
print(val.indices)

[45, 8, 35, 36, 51, 34, 16, 24, 6, 14, 48, 9, 30, 27, 19, 0, 2, 32, 18, 23, 29, 4, 3, 33, 55, 10, 46, 25, 11, 38, 56, 40, 49, 1, 15, 12, 21, 50, 47, 7, 39, 20, 5, 22, 44, 41] 
[42, 31, 28, 17, 53, 26, 52, 37, 43, 13, 54]
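The example above contains 57 indices in total, 46 for training and 11 for validation. Assuming the run used a validation ratio of 0.2 (the value is not stated here), the length arithmetic of the split can be reproduced on its own. split_lengths is a hypothetical helper that isolates just that step:

```python
import math


def split_lengths(n_items, fractions):
    # Round each share down, then give the leftover items to the splits in order.
    lengths = [math.floor(n_items * frac) for frac in fractions]
    remainder = n_items - sum(lengths)
    for i in range(remainder):
        lengths[i % len(lengths)] += 1
    return lengths


# floor(57 * 0.8) = 45 and floor(57 * 0.2) = 11, leaving 1 item that goes to the first split.
print(split_lengths(57, [1 - 0.2, 0.2]))  # [46, 11]
```

The round-robin remainder step is why the training split ends up with 46 items rather than 45.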

@kohya-ss
Owner

kohya-ss commented Nov 3, 2023

Thank you for your suggestion! And sorry for the delay.

It is generally a good idea to split the data set using indexes. However, with the current dataset implementation, it does not seem possible to split the dataset using indexes, since the same index returns different values, as mentioned above.

If the dataset is to have the same index return the same content, this in turn creates the problem that Aspect Ratio Bucketing does not work well.

The idea I have that would probably work is for config_util.py to split the dataset into train and validation.
Below is a pseudo implementation.

First, add the parameter validation_split (or ratio etc.) to the dataset configuration for config_util.py.

@dataclass
class BaseDatasetParams:
  tokenizer: Union[CLIPTokenizer, List[CLIPTokenizer]] = None
  max_token_length: int = None
  resolution: Optional[Tuple[int, int]] = None
  debug_dataset: bool = False
  validation_split: float = 0.0 # add this

Then, in config_util.py, make generate_dataset_group_by_blueprint create two dataset groups, train and validation. The dataset type (train or validation) and the validation split ratio are passed to each dataset.

  # firstly, make train datasets
  datasets: List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = []
  for dataset_blueprint in dataset_group_blueprint.datasets:
    if dataset_blueprint.is_controlnet:
      subset_klass = ControlNetSubset
      dataset_klass = ControlNetDataset
    elif dataset_blueprint.is_dreambooth:
      subset_klass = DreamBoothSubset
      dataset_klass = DreamBoothDataset
    else:
      subset_klass = FineTuningSubset
      dataset_klass = FineTuningDataset

    subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets]
    dataset = dataset_klass(subsets=subsets, is_train=True, **asdict(dataset_blueprint.params))
    datasets.append(dataset)
  
  # secondly, make validation datasets
  val_datasets:List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = []
  for dataset_blueprint in dataset_group_blueprint.datasets:
    if dataset_blueprint.params.validation_split <= 0.0:
      continue
    if dataset_blueprint.is_controlnet:
      subset_klass = ControlNetSubset
      dataset_klass = ControlNetDataset
    elif dataset_blueprint.is_dreambooth:
      subset_klass = DreamBoothSubset
      dataset_klass = DreamBoothDataset
    else:
      subset_klass = FineTuningSubset
      dataset_klass = FineTuningDataset

    subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets]
    dataset = dataset_klass(subsets=subsets, is_train=False, **asdict(dataset_blueprint.params))
    val_datasets.append(dataset)

  ...

  return DatasetGroup(datasets), DatasetGroup(val_datasets) if val_datasets else None

The dataset then splits the image files in two using the ratio and a seed value. The seed may need to be added as a parameter. Here is a DreamBoothDataset example.

class DreamBoothDataset(BaseDataset):
    def __init__(
        self,
        subsets: Sequence[DreamBoothSubset],
        is_train: bool, 
        batch_size: int,
        tokenizer,
        max_token_length,
        resolution,
        enable_bucket: bool,
        min_bucket_reso: int,
        max_bucket_reso: int,
        bucket_reso_steps: int,
        bucket_no_upscale: bool,
        prior_loss_weight: float,
        validation_split: float,
        debug_dataset,
    ) -> None:
        super().__init__(tokenizer, max_token_length, resolution, debug_dataset)

        assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です"

        self.is_train = is_train
        self.validation_split = validation_split
        
        ...

        def load_dreambooth_dir(subset: DreamBoothSubset):
            if not os.path.isdir(subset.image_dir):
                print(f"not directory: {subset.image_dir}")
                return [], []

            img_paths = glob_images(subset.image_dir, "*")
            if self.validation_split > 0.0:
                img_paths = split_train_val(img_paths, self.is_train, self.validation_split, some_seed) # seed might be also needed in dataset config

            print(f"found directory {subset.image_dir} contains {len(img_paths)} image files")

I don't think this works as is (will be some errors), but I think you get the basic idea.
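The pseudocode calls a split_train_val helper that is not defined anywhere in this thread. A minimal sketch of what it could look like follows, assuming the signature implied by the call above. The sorting, the rounding rule, and the use of a private random.Random instance are assumptions, not the final implementation.

```python
import random
from typing import List


def split_train_val(paths: List[str], is_train: bool, validation_split: float, validation_seed: int) -> List[str]:
    """Return the training or validation share of paths for a given ratio and seed."""
    if not 0.0 <= validation_split <= 1.0:
        raise ValueError("validation_split must be between 0 and 1")

    # Sort first so the result does not depend on directory listing order.
    shuffled = sorted(paths)
    # A private Random instance leaves the global training RNG untouched.
    random.Random(validation_seed).shuffle(shuffled)

    # Assumption: round to the nearest whole number of validation images.
    n_val = round(len(shuffled) * validation_split)
    val_paths = shuffled[:n_val]
    train_paths = shuffled[n_val:]
    return train_paths if is_train else val_paths


img_paths = [f"img_{i:02d}.png" for i in range(10)]
train_paths = split_train_val(img_paths, True, 0.2, 6789)
val_paths = split_train_val(img_paths, False, 0.2, 6789)
print(len(train_paths), len(val_paths))  # 8 2
```

Because both dataset instances call the helper with the same ratio and seed, the two results are disjoint and together cover every file, without the datasets needing to share any state.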

This is just what I have come up with so far, and I suspect it will not be easy to implement given how complex the dataset code is.

I would be happy to hear your thoughts :)

@rockerBOO
Contributor Author

rockerBOO commented Nov 5, 2023

[Screenshot 2023-11-05 at 01-46-48: Weights & Biases]

sample output log from running a validation split

In my dataset_config.toml

[[datasets]]
validation_seed=6789
validation_split=0.2
...

Ok, that seems to have worked. The only open point was the seed. Using a separate seed for validation makes it a little tricky. I made it work both with a validation_seed and without one (in which case it could use the training seed instead), but let me know what you think of that.
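The fallback described above could be sketched roughly like this. resolve_validation_seed and its default of 0 are hypothetical and only illustrate the order of precedence:

```python
from typing import Optional


def resolve_validation_seed(validation_seed: Optional[int], training_seed: Optional[int], default: int = 0) -> int:
    # Prefer an explicit validation_seed, then the training seed, then a fixed default (assumption).
    if validation_seed is not None:
        return validation_seed
    if training_seed is not None:
        return training_seed
    return default


print(resolve_validation_seed(6789, 42))  # 6789
print(resolve_validation_seed(None, 42))  # 42
```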

I also added latent caching for the validation dataset and looked at other options that may need to be updated for both datasets.

If this looks good so far, then we can look into the duplicate code for the training block (the one I copied and pasted into the validation section). We should probably unify those 2 so they both run the same code. If you have any thoughts on that part, I'm happy to implement it.

Thank you! :)

@kohya-ss
Owner

kohya-ss commented Nov 5, 2023

Thank you for updating!

Using a training seed is a very good idea. I think it is fine to use that when omitting the seed for validation set.

Also, thank you for considering the latent caching.

As I commented above, I think the block that gets the batch, prepares it, calls the U-Net, and calculates the loss could be extracted into a new function shared by the training and validation code. However, there may be a problem I am missing.
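As a rough sketch of that idea, a shared function can take a flag that controls whether gradients are tracked, so the same code path serves both loops. The tiny linear model and MSE loss here are toy stand-ins for the U-Net and the real loss; only torch.set_grad_enabled is the point of the example.

```python
import torch
import torch.nn.functional as F


def process_batch(model, batch, is_train: bool):
    """Compute the loss for one batch; gradients are tracked only when training."""
    inputs, targets = batch
    with torch.set_grad_enabled(is_train):
        predictions = model(inputs)
        loss = F.mse_loss(predictions, targets)
    return loss


model = torch.nn.Linear(4, 4)  # toy stand-in for the network being trained
batch = (torch.randn(2, 4), torch.randn(2, 4))

train_loss = process_batch(model, batch, is_train=True)
train_loss.backward()  # gradients flow during training

model.eval()
val_loss = process_batch(model, batch, is_train=False)
print(val_loss.requires_grad)  # False
```

Keeping the flag as an explicit argument avoids a second copy of the loss computation wrapped in torch.no_grad().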

@rockerBOO
Contributor Author

Ok I moved them over to process_batch and tried not to refactor anything so the function signature is a little long.

I ran into a bug (I think it's a bug, at least) when producing the charts with wandb. They changed how strictly step metrics are handled: you now have to define a step metric for each kind of step, otherwise the values are dropped from the report. It was very frustrating to figure out, but hopefully we'll get somewhere with that. I have additional code for defining these metrics. See https://docs.wandb.ai/guides/track/log/customize-logging-axes and wandb/wandb#6554 for more information.

The following code would also need to be added to define these metrics. This is probably also related to #792.

    accelerator.init_trackers("kohya-train", config=args)

    if args.log_with == "wandb":
        wandb_tracker = accelerator.get_tracker("wandb").tracker
        wandb_tracker.define_metric("loss_validation_current_step")
        wandb_tracker.define_metric("loss/validation_current", step_metric="loss_validation_current_step")

        wandb_tracker.define_metric("loss_validation_average_step")
        wandb_tracker.define_metric("loss/validation_average", step_metric="loss_validation_average_step")

        wandb_tracker.define_metric("loss_epoch_average_step")
        wandb_tracker.define_metric("loss/epoch_average", step_metric="loss_epoch_average_step")

Note that for the following charts I had to remove the step=x argument from accelerator.log, so they are a little disjointed, but they show that the values are reported. I left the code in the form that should work once the wandb bug is fixed. Please try it with tensorboard or other trackers to confirm.

[Screenshot 2023-11-05 at 16-35-33: Weights & Biases]

@kohya-ss
Owner

kohya-ss commented Nov 7, 2023

Thanks again for the update! The code looks very good.

The issue regarding wandb logging is annoying. Hopefully that bug will be fixed soon.

Sorry for the delay in reviewing this as I haven't had much time to do so. I will check as soon as I can find the time.

We only want to be enabling grad if we are training.
@rockerBOO
Contributor Author

No worries about reviewing this. Would just want to not deal with too many conflicts so we are good :). I am continuing to use this branch and testing it. Will try to fix little bugs as I find them.

  • If a subset directory ends up with no training images, the log reports it as empty. This happens when all of the images in that directory are moved over to the validation dataset.

found directory /path/to/dir contains 0 image files
ignore subset with image_dir='/path/to/dir ': no images found / 画像が見つからないためサブセットを無視します

@rockerBOO rockerBOO changed the base branch from main to dev November 20, 2023 02:03
@rockerBOO
Contributor Author

A question about validating the DreamBooth method with repeats, with respect to regularization images. We currently split the reg images between the 2 datasets (training and validation), which means the validation dataset will have some reg images to process. Is the idea to have reg images in the validation dataset?

Secondly, we probably want validation to trigger after a certain number of steps, either in addition to or instead of at each epoch. With the DreamBooth method, where the training dataset has a lot of repeats, we probably want to be able to run validation partway through an epoch. Maybe we use the number of repeats as an indicator of when to run validation, or just a plain step count.
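One way to express the step-based and epoch-based triggers together is a small predicate checked after every optimizer step. This is only a sketch: should_validate, validate_every_n_steps, and validate_every_epoch are hypothetical names, not existing options.

```python
from typing import Optional


def should_validate(
    global_step: int,
    steps_per_epoch: int,
    validate_every_n_steps: Optional[int] = None,
    validate_every_epoch: bool = True,
) -> bool:
    """Decide whether to run validation after the given (1-based) global step."""
    if global_step <= 0:
        return False
    # Step-based trigger, useful when many repeats make a single epoch very long.
    if validate_every_n_steps and global_step % validate_every_n_steps == 0:
        return True
    # Epoch-based trigger: the last step of each epoch.
    if validate_every_epoch and global_step % steps_per_epoch == 0:
        return True
    return False


steps = [s for s in range(1, 21) if should_validate(s, steps_per_epoch=10, validate_every_n_steps=4)]
print(steps)  # [4, 8, 10, 12, 16, 20]
```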

@zhchaoxing

Hi @rockerBOO , thank you so much for making the validation loss for DB! I need to run finetuning with validation loss, could you please support finetuning as well? Thanks in advance.

@rockerBOO
Contributor Author

> Hi @rockerBOO , thank you so much for making the validation loss for DB! I need to run finetuning with validation loss, could you please support finetuning as well? Thanks in advance.

I would love to add it to all the options. I don't have a good way of testing the finetuning or DreamBooth options, so I would need help making sure they work properly if I made the changes. I think the plan is to get this merged first, then replicate it for TI, and then for the other options.

@markojak

Will this be merged at some point?

@markojak

@rockerBOO if I can help you test, let me know. I can run any of the finetuning options locally or in the cloud.

@kohya-ss
Owner

Sorry it is taking so long. I will make time to deal with this soon.

@markojak

Any updates?

@mostlyhuman

Was this ever implemented?

@rockerBOO
Contributor Author

This was taken a little further in #1165. I haven't tried it, though.

I wanted to wait to merge until we got a dev/sd3/flux release out so it could get the proper attention.
