Skip to content

Commit

Permalink
[update] update loader with latents space
Browse files Browse the repository at this point in the history
  • Loading branch information
Jourdelune committed Jul 1, 2024
1 parent 2e8117b commit 1ce6c4a
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 5 deletions.
19 changes: 15 additions & 4 deletions audioenhancer/dataset/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,12 +110,23 @@ def __getitem__(self, index: int) -> tuple:
compressed_waveform = compressed_waveform[:, : 2**self._pad_length_input]
base_waveform = base_waveform[:, : 2**self._pad_length_output]

encoded_compressed_waveform = self.autoencoder.compress(
compressed_waveform = self.autoencoder.preprocess(
compressed_waveform.audio_data, compressed_waveform.sample_rate
)

base_waveform = self.autoencoder.preprocess(
base_waveform.audio_data, base_waveform.sample_rate
)

compressed_waveform = compressed_waveform.transpose(0, 1).cuda()
base_waveform = base_waveform.transpose(0, 1).cuda()

encoded_compressed_waveform, _, _, _, _ = self.autoencoder.encode(
compressed_waveform
).codes
encoded_base_waveform = self.autoencoder.compress(base_waveform).codes
)

encoded_base_waveform, _, _, _, _ = self.autoencoder.encode(base_waveform)

# convert to mono or stereo
if self._mono:
encoded_compressed_waveform = encoded_compressed_waveform.mean(dim=1)
encoded_base_waveform = encoded_base_waveform.mean(dim=1)
Expand Down
3 changes: 2 additions & 1 deletion scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,10 +155,11 @@

# rearrange x and y
x = rearrange(x, "b c d t -> b t (c d)")
y = rearrange(y, "b c d t -> b t (c d)")

y_hat = model(x, mask=None)

y_hat = rearrange(y_hat, "b t (c d) -> b c d t", c=2, d=9)

loss = sum([loss_fn[i](y_hat, y) for i in range(len(loss_fn))])
loss.backward()

Expand Down

0 comments on commit 1ce6c4a

Please sign in to comment.