Skip to content

Commit

Permalink
Only download CLIP on rank 0 when doing eval (#135)
Browse files Browse the repository at this point in the history
  • Loading branch information
coryMosaicML authored Apr 15, 2024
1 parent eef6a01 commit 41b13bc
Showing 1 changed file with 16 additions and 5 deletions.
21 changes: 16 additions & 5 deletions diffusion/evaluation/clean_fid_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import json
import os
from typing import List, Optional
from typing import Dict, List, Optional

import clip
import torch
Expand Down Expand Up @@ -50,6 +50,7 @@ class CleanFIDEvaluator:
precision (str): The precision to use for evaluation. Default: ``'amp_fp16'``.
prompts (List[str], optional): The prompts to use for image visualization.
            Default: ``["A shiba inu wearing a blue sweater"]``.
additional_generate_kwargs (Dict, optional): Additional keyword arguments to pass to the model.generate method.
"""

Expand All @@ -69,7 +70,8 @@ def __init__(self,
output_dir: str = '/tmp/',
num_samples: Optional[int] = None,
precision: str = 'amp_fp16',
prompts: Optional[List[str]] = None):
prompts: Optional[List[str]] = None,
additional_generate_kwargs: Optional[Dict] = None):
self.model = model
self.tokenizer: PreTrainedTokenizerBase = model.tokenizer
self.eval_dataloader = eval_dataloader
Expand All @@ -86,6 +88,7 @@ def __init__(self,
self.num_samples = num_samples if num_samples is not None else float('inf')
self.precision = precision
self.prompts = prompts if prompts is not None else ['A shiba inu wearing a blue sweater']
self.additional_generate_kwargs = additional_generate_kwargs if additional_generate_kwargs is not None else {}
self.sdxl = model.sdxl

# Init loggers
Expand All @@ -107,7 +110,13 @@ def __init__(self,
self.clip_metric = self.clip_metric.to(self.device)

# Predownload the CLIP model for computing clip-fid
_, _ = clip.load('ViT-B/32', device=self.device)
clip_url = 'https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt'
clip_name = os.path.basename(clip_url)
clip_path = os.path.expanduser('~/.cache/clip')
if dist.get_local_rank() == 0:
clip.clip._download(clip_url, clip_path)
with dist.local_rank_zero_download_and_wait(os.path.join(clip_path, clip_name)):
clip.load('ViT-B/32', device=self.device)

def _generate_images(self, guidance_scale: float):
"""Core image generation function. Generates images at a given guidance scale.
Expand Down Expand Up @@ -156,7 +165,8 @@ def _generate_images(self, guidance_scale: float):
seed=seed,
crop_params=crop_params,
input_size_params=input_size_params,
progress_bar=False) # type: ignore
progress_bar=False,
**self.additional_generate_kwargs) # type: ignore
# Get the prompts from the tokens
text_captions = self.tokenizer.batch_decode(captions, skip_special_tokens=True)
self.clip_metric.update((generated_images * 255).to(torch.uint8), text_captions)
Expand Down Expand Up @@ -233,7 +243,8 @@ def _generate_images_from_prompts(self, guidance_scale: float):
height=self.size,
width=self.size,
guidance_scale=guidance_scale,
seed=self.seed) # type: ignore
seed=self.seed,
**self.additional_generate_kwargs) # type: ignore
else:
generated_images = []
return generated_images
Expand Down

0 comments on commit 41b13bc

Please sign in to comment.