diff --git a/magvit2_pytorch/magvit2_pytorch.py b/magvit2_pytorch/magvit2_pytorch.py index 33e74b7..5df0938 100644 --- a/magvit2_pytorch/magvit2_pytorch.py +++ b/magvit2_pytorch/magvit2_pytorch.py @@ -910,6 +910,8 @@ def __init__( self.conv_out = CausalConv3d(init_dim, channels, output_conv_kernel_size, pad_mode = pad_mode) dim = init_dim + dim_out = dim + time_downsample_factor = 1 has_cond = False @@ -919,7 +921,11 @@ def __init__( if layer_type == 'residual': encoder_layer = ResidualUnit(dim, residual_conv_kernel_size) decoder_layer = ResidualUnit(dim, residual_conv_kernel_size) - dim_out = dim + + elif layer_type == 'consecutive_residual': + num_consecutive, = layer_params + encoder_layer = Sequential(*[ResidualUnit(dim, residual_conv_kernel_size) for _ in range(num_consecutive)]) + decoder_layer = Sequential(*[ResidualUnit(dim, residual_conv_kernel_size) for _ in range(num_consecutive)]) elif layer_type == 'cond_residual': assert exists(dim_cond), 'dim_cond must be passed into VideoTokenizer, if tokenizer is to be conditioned' @@ -928,6 +934,7 @@ def __init__( encoder_layer = ResidualUnitMod(dim, residual_conv_kernel_size, dim_cond = int(dim_cond * dim_cond_expansion_factor)) decoder_layer = ResidualUnitMod(dim, residual_conv_kernel_size, dim_cond = int(dim_cond * dim_cond_expansion_factor)) + dim_out = dim elif layer_type == 'compress_space': dim_out, = layer_params diff --git a/magvit2_pytorch/version.py b/magvit2_pytorch/version.py index abebd78..369521b 100644 --- a/magvit2_pytorch/version.py +++ b/magvit2_pytorch/version.py @@ -1 +1 @@ -__version__ = '0.0.44' +__version__ = '0.0.45'