Skip to content

Commit

Permalink
[cm] Initial RONN code
Browse files Browse the repository at this point in the history
  • Loading branch information
christhetree committed Dec 27, 2023
1 parent e7c3a3a commit f59585b
Showing 1 changed file with 49 additions and 31 deletions.
80 changes: 49 additions & 31 deletions examples/example_ronn.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,66 +20,75 @@
log.setLevel(level=os.environ.get("LOGLEVEL", "INFO"))


class OverdriveModel(nn.Module):
class RONNModel(nn.Module):
def __init__(self,
activation: str = "ReLU",
init: str = "normal",
act_name: str = "relu",
init_name: str = "normal",
in_ch: int = 1,
n_blocks: int = 1,
channel_width: int = 1,
kernel_size: int = 3,
dilation_growth: int = 2,
n_params: int = 2,
n_cond_params: int = 2,
cond_dim: int = 128) -> None:
super().__init__()

# MLP layers for conditioning
self.n_controls = n_params
self.control_to_cond_network = nn.Sequential(
nn.Linear(n_params, cond_dim // 2),
# nn.ReLU(),
nn.Linear(cond_dim // 2, cond_dim),
# nn.ReLU(),
nn.Linear(cond_dim, cond_dim),
# nn.ReLU(),
# MLP layers for conditioning vector generation
self.n_cond_params = n_cond_params
self.cond_generator = nn.Sequential(
nn.Linear(n_cond_params, n_cond_params ** 2),
nn.ReLU(),
nn.Linear(n_cond_params ** 2, n_cond_params ** 4),
nn.ReLU(),
nn.Linear(n_cond_params ** 4, cond_dim),
nn.ReLU(),
)

# TCN model
out_channels = [channel_width] * n_blocks
dilations = [dilation_growth ** n for n in range(n_blocks)]
self.tcn = TCN(out_channels,
dilations,
in_ch,
self.tcn = TCN(in_ch,
out_channels,
kernel_size,
use_act=False,
dilations=dilations,
use_act=True,
act_name=act_name,
use_res=False,
cond_dim=cond_dim,
use_film_bn=False,
is_cached=True)
bias=True,
batch_size=2,
causal=True,
cached=True)

# Weight initialization
self.init_weights(init)
self.init_weights(init_name)

def forward(self, x: Tensor, params: Tensor) -> Tensor:
print(f"in x {x.min()}")
print(f"in x {x.max()}")
cond = self.control_to_cond_network(params) # Map params to conditioning vector
assert x.ndim == 3
assert params.ndim == 2
# print(f"in x {x.min()}")
# print(f"in x {x.max()}")
cond = self.cond_generator(params) # Map params to conditioning vector
x = self.tcn(x, cond) # Process the dry audio
# x = self.tcn(x) # Process the dry audio
# x = self.output(x) # Convert to 1 channel
# x = tr.tanh(x) # Ensure the wet audio is between -1 and 1
print(x.min())
print(x.mean())
print(x.max())
# print(x.min())
# print(x.mean())
# print(x.max())
return x

def init_weights(self, init: str) -> None:
def init_weights(self, init_name: str) -> None:
for k, param in dict(self.named_parameters()).items():
if "weight" in k:
self.init_param_weight(param, init)
self.init_param_weight(param, init_name)

@staticmethod
def init_param_weight(param: Tensor, init: str) -> None:
"""
Most of the code and experimental results in this method are from
https://github.com/csteinmetz1/ronn
"""
if init == "normal":
nn.init.normal_(param, std=1) # smooth
elif init == "uniform":
Expand All @@ -100,10 +109,10 @@ def init_param_weight(param: Tensor, init: str) -> None:

class OverdriveModelWrapper(WaveformToWaveformBase):
def get_model_name(self) -> str:
return "conv1d-overdrive.random"
return "tcn.ronn"

def get_model_authors(self) -> List[str]:
return ["Nao Tokui"]
return ["Christopher Mitcheltree"]

def get_model_short_description(self) -> str:
return "Neural distortion/overdrive effect"
Expand Down Expand Up @@ -153,6 +162,10 @@ def get_native_sample_rates(self) -> List[int]:
def get_native_buffer_sizes(self) -> List[int]:
return [] # Supports all buffer sizes

@tr.jit.export
def calc_model_delay_samples(self) -> int:
return self.model.tcn.get_delay_samples()

def do_forward_pass(self, x: Tensor, params: Dict[str, Tensor]) -> Tensor:
# conditioning for FiLM layer
p1 = params["P1"]
Expand All @@ -161,8 +174,13 @@ def do_forward_pass(self, x: Tensor, params: Dict[str, Tensor]) -> Tensor:
cond = tr.stack([p1, p2], dim=1) * depth
cond = cond.expand(2, cond.size(1))
x = x.unsqueeze(1)
# prev_x = x
x = self.model(x, cond)
x = x.squeeze(1)
max_val = x.abs().max() + 1e-8
x /= max_val
dc_offset = x.mean(dim=-1, keepdim=True)
x -= dc_offset
return x


Expand All @@ -172,7 +190,7 @@ def do_forward_pass(self, x: Tensor, params: Dict[str, Tensor]) -> Tensor:
args = parser.parse_args()
root_dir = Path(args.output)

model = OverdriveModel()
model = RONNModel()
wrapper = OverdriveModelWrapper(model)
metadata = wrapper.to_metadata()
save_neutone_model(
Expand Down

0 comments on commit f59585b

Please sign in to comment.