A partial implementation of Continuous Diffusion for Categorical Data (CDCD) by DeepMind, in PyTorch.

elyxlz/cdcd-pytorch

cdcd-pytorch

Usage

import torch
from cdcd_pytorch import CDCD, DDPMScheduler

scheduler = DDPMScheduler(
    num_train_timesteps=1000,
    beta_start=0.0001,
    beta_end=0.01,
    beta_schedule="linear",
    clip_sample=False,
    prediction_type="epsilon",
)

model = CDCD(
    scheduler=scheduler,
    hidden_size=768,
    num_heads=8,
    depth=12,
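    score_hidden_size=256,  # dimensionality of the token embeddings being diffused
    embedding_max_length=1,  # length of the conditioning sequence
    embedding_features=512,  # feature size of the conditioning signal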
)

# token ids of shape (batch, seq_len) and a conditioning signal of shape
# (batch, embedding_max_length, embedding_features)
x = torch.randint(0, 50000, (8, 1000))
embedding = torch.randn(8, 1, 512)

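# illustrative optimizer (assumption: CDCD is a standard nn.Module; the lr is not from this repo)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# one training step; repeat this over many batches
loss = model(x, embedding=embedding)
optimizer.zero_grad()
loss.backward()
optimizer.step()
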
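# once training is done, sample from the model with classifier-free guidance
# noise has shape (batch, seq_len, score_hidden_size)
noise = torch.randn(1, 1024, 256)
token_pred = model.sample(
    noise,
    num_steps=50,
    embedding=embedding[:1],  # conditioning batch size must match the noise batch size (1)
    embedding_scale=2.5,  # classifier-free guidance scale
)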
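The sampling call above takes an `embedding_scale` argument for classifier-free guidance. The sketch below shows the standard guidance combination, `uncond + scale * (cond - uncond)`, applied to a generic denoiser. It is not taken from this repository: `cfg_combine`, `guided_prediction`, the toy denoiser, and the convention of an all-zero "null" embedding for the unconditional branch are all hypothetical, and `model.sample` may implement guidance differently (for example with a learned null embedding).

```python
import torch


def cfg_combine(pred_cond: torch.Tensor, pred_uncond: torch.Tensor, scale: float) -> torch.Tensor:
    """Classifier-free guidance: move the prediction away from the unconditional
    branch. scale=1.0 returns the conditional prediction unchanged."""
    return pred_uncond + scale * (pred_cond - pred_uncond)


def guided_prediction(denoiser, x_t: torch.Tensor, t: torch.Tensor,
                      embedding: torch.Tensor, scale: float) -> torch.Tensor:
    # Hypothetical convention: an all-zero embedding marks the unconditional branch.
    null = torch.zeros_like(embedding)
    # Evaluate both branches in a single batched forward pass.
    x_in = torch.cat([x_t, x_t], dim=0)
    t_in = torch.cat([t, t], dim=0)
    emb_in = torch.cat([embedding, null], dim=0)
    pred_cond, pred_uncond = denoiser(x_in, t_in, emb_in).chunk(2, dim=0)
    return cfg_combine(pred_cond, pred_uncond, scale)


if __name__ == "__main__":
    # Toy denoiser (hypothetical) whose output depends on the embedding,
    # so the conditional and unconditional branches differ.
    def toy_denoiser(x, t, emb):
        return 0.5 * x + emb.mean(dim=-1, keepdim=True)

    x_t = torch.randn(1, 1024, 256)  # (batch, seq_len, score_hidden_size)
    t = torch.tensor([999])
    emb = torch.randn(1, 1, 512)     # (batch, embedding_max_length, embedding_features)
    out = guided_prediction(toy_denoiser, x_t, t, emb, scale=2.5)
    print(out.shape)  # torch.Size([1, 1024, 256])
```

With `scale=1.0` the result is just the conditional prediction; larger values such as the `2.5` used above push samples harder toward the conditioning signal, typically at some cost in diversity.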

TODO

  • Add self-conditioning
  • Add input masking
  • Experiment with two-stage training: a first stage that trains the embeddings with a cross-entropy loss, and a second stage with frozen embeddings and a score-matching loss
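Self-conditioning, the first item in the TODO list, usually means feeding the model its own previous estimate of the clean data as an extra input. During training, with some probability (0.5 is assumed here) the model first runs without gradients to produce that estimate; otherwise it receives zeros in its place. A minimal sketch, assuming a hypothetical `ToyDenoiser` that concatenates the noisy input with the estimate and a hypothetical MSE loss on the clean embeddings; the loss actually used inside `CDCD` is not shown in this README and may differ.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyDenoiser(nn.Module):
    """Hypothetical denoiser: maps [noisy input, self-conditioning estimate] to a clean estimate."""

    def __init__(self, dim: int):
        super().__init__()
        self.proj = nn.Linear(2 * dim, dim)

    def forward(self, x_t: torch.Tensor, x0_prev: torch.Tensor) -> torch.Tensor:
        return self.proj(torch.cat([x_t, x0_prev], dim=-1))


def self_conditioned_loss(model: nn.Module, x_t: torch.Tensor, x0_target: torch.Tensor,
                          p_self_cond: float = 0.5) -> torch.Tensor:
    # Without self-conditioning, the model sees zeros instead of a previous estimate.
    x0_prev = torch.zeros_like(x_t)
    if torch.rand(()).item() < p_self_cond:
        # First pass without gradients produces the estimate to condition on.
        with torch.no_grad():
            x0_prev = model(x_t, x0_prev)
    # Only the second pass is trained (assumed MSE target on the clean embeddings).
    x0_pred = model(x_t, x0_prev)
    return F.mse_loss(x0_pred, x0_target)


if __name__ == "__main__":
    torch.manual_seed(0)
    model = ToyDenoiser(dim=256)
    x0 = torch.randn(8, 16, 256)     # clean embeddings (batch, seq_len, dim)
    x_t = x0 + torch.randn_like(x0)  # noisy version at a single illustrative noise level
    loss = self_conditioned_loss(model, x_t, x0)
    loss.backward()
    print(loss.item())
```

At sampling time the estimate produced at each denoising step would be passed into the next step in place of the zeros.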
