Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/train-latent-space' into train-latent-space
Browse files Browse the repository at this point in the history

# Conflicts:
#	audioenhancer/dataset/loader.py
#	audioenhancer/model/audio_ae/model.py
  • Loading branch information
ostix360 committed Jul 4, 2024
2 parents 6f9ed4b + 5446afc commit 8aa1f26
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 34 deletions.
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -165,4 +165,6 @@ test.py
data/output.*
*.pth
runs/
data/input.mp3
data/input.mp3

flagged
2 changes: 1 addition & 1 deletion audioenhancer/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

SAMPLING_RATE = 44100
MAX_AUDIO_LENGTH = 10
BATCH_SIZE = 2
BATCH_SIZE = 4
EPOCH = 1
LOGGING_STEPS = 10
GRADIENT_ACCUMULATION_STEPS = 2
Expand Down
42 changes: 19 additions & 23 deletions audioenhancer/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@

import os

import dac
import torch
import torchaudio
from einops import rearrange

from audioenhancer.model.audio_ae.model import model_xtransformer_small as model
from audiocraft.models import MultiBandDiffusion


class Inference:
Expand All @@ -22,8 +22,7 @@ def __init__(self, model_path: str, sampling_rate: int):

self._sampling_rate = sampling_rate

autoencoder_path = dac.utils.download(model_type="44khz")
self._autoencoder = dac.DAC.load(autoencoder_path).to(self.device)
self.autoencoder = MultiBandDiffusion.get_mbd_musicgen()

def load(self, waveform_path):
"""
Expand All @@ -45,6 +44,7 @@ def load(self, waveform_path):

return waveform.to(self.device)

@torch.no_grad()
def inference(self, audio_path: str, chunk_duration: int = 10):
"""Run inference on the given audio file.
Expand All @@ -68,35 +68,31 @@ def inference(self, audio_path: str, chunk_duration: int = 10):
0,
)

with torch.no_grad():
encoded, encoded_q, _, _, _, _ = self._autoencoder.encode(
chunk.transpose(0, 1)
)

# create input for the model
decoded = self._autoencoder.decode(encoded_q)
encoded = self.autoencoder.get_condition(
chunk.transpose(0, 1), sample_rate=self._sampling_rate
)

decoded = decoded.transpose(0, 1)
# create input for the model
decoded = self.autoencoder.generate(encoded)

ae_input = torch.cat([ae_input, decoded], dim=2)
decoded = decoded.transpose(0, 1)

encoded = encoded.unsqueeze(0)
c, d = encoded.shape[1], encoded.shape[2]
encoded = rearrange(encoded, "b c d t -> b (t c) d")
ae_input = torch.cat([ae_input, decoded], dim=2)

pred = self.model(encoded)
encoded = encoded.unsqueeze(0)
c, d = encoded.shape[1], encoded.shape[2]
encoded = rearrange(encoded, "b c d t -> b (t c) d")

pred = rearrange(pred, "b (t c) d -> b c d t", c=c, d=d)
pred = pred.squeeze(0)
pred = self.model(encoded)

# quantize
z_q, _, _, _, _ = self._autoencoder.quantizer(pred, None)
pred = rearrange(pred, "b (t c) d -> b c d t", c=c, d=d)
pred = pred.squeeze(0)

decoded = self._autoencoder.decode(z_q)
decoded = self.autoencoder.generate(pred)

decoded = decoded.transpose(0, 1)
decoded = decoded.transpose(0, 1)

output = torch.cat([output, decoded], dim=2)
output = torch.cat([output, decoded], dim=2)

# fix runtime error: numpy
output = output.squeeze(0).detach().cpu()
Expand Down
10 changes: 1 addition & 9 deletions scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@

args = parser.parse_args()

dtype = torch.float32
dtype = torch.bfloat16

# Load the dataset
dataset = SynthDataset(
Expand Down Expand Up @@ -132,7 +132,6 @@
],
lr=1e-4,
betas=(0.95, 0.999),
eps=1e-6,
weight_decay=1e-3,
)

Expand Down Expand Up @@ -161,13 +160,6 @@
# print number of parameters
print(f"Number of parameters: {sum(p.numel() for p in model.parameters()) / 1e6}M")

import dac

autoencoder_path = dac.utils.download(model_type="44khz")
autoencoder = dac.DAC.load(autoencoder_path).to("cuda")
autoencoder.eval()
autoencoder.requires_grad_(False)


def eval_model(model, test_loader):
model.eval()
Expand Down

0 comments on commit 8aa1f26

Please sign in to comment.