Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CIFAR-10 evaluation #12

Open
wants to merge 144 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
144 commits
Select commit Hold shift + click to select a range
4424bda
Initial implementation of CIFAR evaluation, currently runs but haven'…
rohinmshah Jul 22, 2020
53d2f2e
Initial implementation of CIFAR evaluation, currently runs but haven'…
rohinmshah Jul 22, 2020
1fc4d9a
Add support for GPU training, miscellaneous improvements
rohinmshah Aug 5, 2020
5f24dd9
Pull out model training code into its own function
rohinmshah Aug 5, 2020
96dc990
Compatibility with new learn interface
rohinmshah Aug 5, 2020
513b530
Merge master
rohinmshah Aug 5, 2020
7ff2804
Merge branch 'master' into cifar_eval
rohinmshah Aug 6, 2020
89a30dc
Implement the correct augmentations for SimCLR on CIFAR-10
rohinmshah Aug 6, 2020
db49e92
Changes to optimizers and learning rates to be more in line with SimCLR
rohinmshah Aug 9, 2020
c29cf52
Add momentum to optimizer
rohinmshah Aug 10, 2020
66ae3d5
Fix indentation bug
rohinmshah Aug 10, 2020
8303ca5
Address comments on PR, except for LinearWarmupCosine documentation, …
rohinmshah Aug 17, 2020
0d1622f
Merge branch 'master' into cifar_eval
rohinmshah Aug 17, 2020
474b6f9
Rewrote LinearWarmupCosine to be more understandable
rohinmshah Aug 17, 2020
375e76f
Merge
rohinmshah Aug 25, 2020
4950c8e
Merge
rohinmshah Aug 25, 2020
294d436
Miscellaneous small fixes
rohinmshah Aug 25, 2020
9cae4a7
Make things more parameterizable
rohinmshah Aug 26, 2020
89d5864
Update .gitignore
RPC2 Apr 13, 2021
51dad87
Merge branch 'master' into cifar_eval
RPC2 Apr 13, 2021
3fef45e
update model setting
RPC2 Apr 19, 2021
a93d017
Make CIFAR runnable for RepL!
RPC2 Apr 20, 2021
f2fc56b
classification + cleanup
RPC2 Apr 20, 2021
988ec4a
some cleanup
RPC2 Apr 20, 2021
fa647ad
Hardcode warmup_epochs to 2
decodyng Apr 21, 2021
8fe17b5
Import time
decodyng Apr 21, 2021
95a01d0
Make testloader exist
decodyng Apr 21, 2021
6869714
Is RepL training?
decodyng Apr 21, 2021
d977571
Hardcode dataset length
decodyng Apr 21, 2021
574d8fc
Remove excess logging
decodyng Apr 21, 2021
45b285a
Remove Cosine Annealing to be consistent with repo
decodyng Apr 21, 2021
f0b8f71
Put their loss in for ours
decodyng Apr 21, 2021
fc10cd4
Comment otu their losss which is nan for some reason
decodyng Apr 21, 2021
97bf11a
Add breakpoint
decodyng Apr 21, 2021
c22329f
Fix config name
decodyng Apr 21, 2021
55963e0
Switch to running their loss
decodyng Apr 21, 2021
b51f49c
Add another breakpoint
decodyng Apr 21, 2021
2c3b69a
Fix numpy call
decodyng Apr 21, 2021
9e6c792
What if you used their loss but normalized first to maybe avoid infin…
decodyng Apr 21, 2021
09254b1
Add ability to do K means evaluation
decodyng Apr 21, 2021
d464fca
Accidentally called encoder.encoder
decodyng Apr 21, 2021
44e6213
double-import tqdm
decodyng Apr 21, 2021
a35ba0a
Maybe avoid needing traj_info
decodyng Apr 21, 2021
8033682
Remove unused feature, out
decodyng Apr 21, 2021
720c15d
Allow passing in a pretrained model
decodyng Apr 21, 2021
382e9fd
Make it easier to switch between our loss and repo loss
decodyng Apr 21, 2021
2f221c0
Normalize our features before using them in KNN
decodyng Apr 21, 2021
81932ee
Unbreak torch.nn.functional import
decodyng Apr 21, 2021
5724cef
Examine image scale before augmentations
decodyng Apr 22, 2021
8592c66
Explicitly use their model class
decodyng Apr 22, 2021
6bda628
Use SimCLR model for encoder at least
decodyng Apr 22, 2021
d1463f5
No longer expect a tuple in KNN code
decodyng Apr 22, 2021
767511d
Modify decoder kwargs to be closer to SimCLR
decodyng Apr 22, 2021
6e55630
Add comma back in
decodyng Apr 22, 2021
0fc1975
Add code to save images out
decodyng Apr 22, 2021
592a278
Remove image prepreprocessing to avoid double-normalizing
decodyng Apr 22, 2021
fc7388d
Add more image saving and warnings
decodyng Apr 22, 2021
56ebd6f
Add more image saving and warnings
decodyng Apr 22, 2021
618f4ea
Save out more images
decodyng Apr 22, 2021
8b6f75a
Try to get augmentations to match SimCLR
decodyng Apr 22, 2021
01b6a6a
Still convert to PILImage
decodyng Apr 22, 2021
c5aee28
Add back numpy conversion without x255
decodyng Apr 22, 2021
94f0b9a
Add 255x back in
decodyng Apr 22, 2021
be29ef9
Normalize in the same way as SimCLR
decodyng Apr 22, 2021
9270bbc
For some reason getting a dimension error
decodyng Apr 22, 2021
7039fbf
Go back to other normalization
decodyng Apr 22, 2021
cc24d47
Cleanup and final push for the evening
decodyng Apr 23, 2021
d201c9a
Switch from bilinear to bicubic interpolation
decodyng Apr 23, 2021
fda28dd
No longer convert to numpy array before PIL image
decodyng Apr 23, 2021
292df58
Transpose numpy array so PILImage has the right shape:
decodyng Apr 23, 2021
5bdada6
Breakpoint before augmentation
Apr 23, 2021
8b7ac9e
Merge branch 'cifar_eval' of github.com:HumanCompatibleAI/il-represen…
Apr 23, 2021
40d8e7a
Uniform_ contexts and target
decodyng Apr 23, 2021
28cdd83
Merge branch 'cifar_eval' of github.com:HumanCompatibleAI/il-represen…
Apr 23, 2021
43c2f57
log every interval
decodyng Apr 23, 2021
0049873
Merge branch 'cifar_eval' of github.com:HumanCompatibleAI/il-represen…
Apr 23, 2021
d5bdbb0
Make zs uniform
decodyng Apr 23, 2021
90fb7c0
Merge branch 'cifar_eval' of github.com:HumanCompatibleAI/il-represen…
Apr 23, 2021
4446621
Set seed to 10
decodyng Apr 23, 2021
1b16a55
Merge branch 'cifar_eval' of github.com:HumanCompatibleAI/il-represen…
Apr 23, 2021
9884d1b
Add seed back in
Apr 23, 2021
432c7b0
No longer have random zs
Apr 23, 2021
1224059
Remove breakpoint
Apr 23, 2021
e6bff32
Remove random seed
decodyng Apr 23, 2021
877ecc5
Examine distribution after encoder
decodyng Apr 23, 2021
fc9dad2
No longer randomize images
decodyng Apr 23, 2021
7bbb153
Remove extraneous breakpoint
decodyng Apr 23, 2021
cb06dbc
Add parameter check to repl
decodyng Apr 24, 2021
71e129f
Swap our data loader for theirs
decodyng Apr 26, 2021
5553d2a
Add dataloader back in and add breakpoint
decodyng Apr 26, 2021
294bab4
Try to get .next() to work
decodyng Apr 26, 2021
2884db6
Swap in new contexts/targets temporarily
decodyng Apr 26, 2021
9eb90b9
If we double augment that should break things... right?
decodyng Apr 26, 2021
31864d6
Switch back to using our data loader
decodyng Apr 26, 2021
502c223
Skip the decoding step entirely
decodyng Apr 26, 2021
c4096a1
Don't calculate norm on decoder while we're testing out not using it
decodyng Apr 26, 2021
3dd1e80
Remove decoder from _calculate_norms
decodyng Apr 26, 2021
f2645d5
try to use direct network output instead of a distribution
Apr 27, 2021
29a9083
return to using multivariate normal and adjust loss and temperature
Apr 27, 2021
31d04e5
test linear head
Apr 27, 2021
77afc23
Try to fully use SimCLR repo's linear evaluation code
RPC2 Apr 27, 2021
130c127
select test method
RPC2 Apr 27, 2021
869f3bd
Add comment
decodyng Apr 27, 2021
feffcf0
Merge branch 'cifar_eval' of github.com:HumanCompatibleAI/il-represen…
decodyng Apr 27, 2021
1a537bf
Specifically ablate change to decoder
decodyng Apr 27, 2021
f769190
Switch ReLu back to be after BatchNorm
decodyng Apr 27, 2021
5d2ff3f
Remove breakpoint on Github
decodyng Apr 27, 2021
ea8deae
config for running few trajs
RPC2 Apr 30, 2021
7319c1e
Merge branch 'master' into gcp-cyn
RPC2 Apr 30, 2021
6a2842e
Update chain_configs.py
RPC2 Apr 30, 2021
93783a1
Finding a good gpu number balance
RPC2 Apr 30, 2021
4c1c2b8
Add SimCLR model to default SimCLR settings
RPC2 May 6, 2021
01473d8
Try to use 3e-4 lr for SimCLR repl
RPC2 May 6, 2021
05805a0
Merge branch 'gcp-cyn' into cifar_eval
RPC2 May 6, 2021
701c691
update config
RPC2 May 6, 2021
17af15f
comment out context saving code
RPC2 May 6, 2021
766b160
Try augmenting with SimCLR default
RPC2 May 6, 2021
a398ed8
adjust augmenter
May 6, 2021
ae814c3
Try to use multiple GPUs
RPC2 May 6, 2021
2491d70
Add a script for running simclr
May 6, 2021
d992c52
adjust decoder input dim
RPC2 May 6, 2021
111a97d
Merge branch 'cifar_eval' of https://github.com/HumanCompatibleAI/il-…
RPC2 May 6, 2021
8d4ebd5
Adjust decoder shape and normalization
May 6, 2021
872101b
Update run_il.sh for long dmc runs with few trajs
May 6, 2021
56c724e
Merge branch 'gcp-cyn' of ssh://github.com/HumanCompatibleAI/il-repre…
May 6, 2021
3834b88
Setting up loading procgen dataset
RPC2 May 7, 2021
d57fcbf
Merge branch 'procgen' of github.com:HumanCompatibleAI/il-representat…
May 7, 2021
1f2a6ad
Merge branch 'gcp-cyn' into procgen
RPC2 May 7, 2021
c679534
Merge branch 'procgen' of github.com:HumanCompatibleAI/il-representat…
May 7, 2021
2c70645
Adding support for procgen (loading env)
May 7, 2021
2f0659a
Set Procgen env names
May 7, 2021
fe9558e
Update loading procgen envs
May 11, 2021
dcd44ca
Maybe we don't need next_obs?
May 11, 2021
e236df2
Env wrapper is already handled by Procgen
May 11, 2021
9faea2d
More clean up
May 11, 2021
19d1dc3
Add framestack
May 11, 2021
7e67b54
Adjust encoder network channel
RPC2 May 12, 2021
0234acf
Merge branch 'procgen' into cifar_eval
RPC2 May 12, 2021
f2c9be2
Update simclr running script
May 12, 2021
ad9a9d1
Try a smaller network
RPC2 May 12, 2021
4365e02
See if it can run end to end
RPC2 May 12, 2021
5eb9283
Update encoder kwargs
RPC2 May 12, 2021
5fdf57a
Current script to train simclr as repl
May 12, 2021
9983123
Use default augmentation
May 12, 2021
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ torchsummary~=1.5.1
#webdataset introduces breaking changes in 0.1.49, so setting this to an exact equality
webdataset==0.1.40
tqdm~=4.48.0
procgen==0.10.4

# Jupyter Lab is used for our experiment analysis notebook
jupyterlab~=2.2.6
Expand Down
2 changes: 1 addition & 1 deletion src/il_representations/algos/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from il_representations.algos.representation_learner import RepresentationLearner, DEFAULT_HARDCODED_PARAMS
from il_representations.algos.encoders import MomentumEncoder, InverseDynamicsEncoder, TargetStoringActionEncoder, \
RecurrentEncoder, BaseEncoder, VAEEncoder, ActionEncodingEncoder, ActionEncodingInverseDynamicsEncoder, \
infer_action_shape_info
infer_action_shape_info, SimCLRModel
from il_representations.algos.decoders import NoOp, MomentumProjectionHead, \
BYOLProjectionHead, ActionConditionedVectorDecoder, ContrastiveInverseDynamicsConcatenationHead, \
ActionPredictionHead, PixelDecoder, SymmetricProjectionHead, AsymmetricProjectionHead
Expand Down
29 changes: 25 additions & 4 deletions src/il_representations/algos/augmenters.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,18 @@
either augment just the context, or both the context and the target, depending on the algorithm.
"""

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class Augmenter(ABC):
def __init__(self, augmenter_spec, color_space):
augment_op = StandardAugmentations.from_string_spec(
augmenter_spec, color_space)
self.augment_op = augment_op
def __init__(self, augmenter_spec, color_space, augment_func=None):
self.augment_func = augment_func
if augment_func:
self.augment_op = augment_func
else:
augment_op = StandardAugmentations.from_string_spec(
augmenter_spec, color_space)
self.augment_op = augment_op

@abstractmethod
def __call__(self, contexts, targets):
Expand All @@ -33,6 +39,21 @@ def __call__(self, contexts, targets):

class AugmentContextAndTarget(Augmenter):
def __call__(self, contexts, targets):
pil_process_func = transforms.Compose([
transforms.ToPILImage()
])
if self.augment_func:
context_ret, target_ret = [], []
for context, target in zip(contexts, targets):
if isinstance(context, torch.Tensor) and \
isinstance(self.augment_op.transforms[0],
transforms.RandomResizedCrop):
context, target = pil_process_func(context.cpu()), \
pil_process_func(target.cpu())
context_ret.append(self.augment_op(context))
target_ret.append(self.augment_op(target))
return torch.stack(context_ret, dim=0).to(device), \
torch.stack(target_ret, dim=0).to(device)
return self.augment_op(contexts), self.augment_op(targets)


Expand Down
4 changes: 2 additions & 2 deletions src/il_representations/algos/decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ def get_sequential_from_architecture(architecture, representation_dim, projectio
input_dim = representation_dim
for layer_def in architecture:
layers.append(nn.Linear(input_dim, layer_def['output_dim']))
layers.append(nn.ReLU())
layers.append(nn.BatchNorm1d(num_features=layer_def['output_dim']))
layers.append(nn.ReLU(inplace=True))
input_dim = layer_def['output_dim']
layers.append(nn.Linear(input_dim, projection_dim))
return nn.Sequential(*layers)
Expand Down Expand Up @@ -131,7 +131,7 @@ def _apply_projection_layer(self, z_dist, mean_layer, stdev_layer):
# We better not have had a learned standard deviation in
# the encoder, since there's no clear way on how to pass
# it forward
assert np.all((z_dist.stddev == 1).numpy())
assert np.all((z_dist.stddev == 1).cpu().numpy())
stddev = self.ones_like_projection_dim(mean)
else:
stddev = stdev_layer(z_vector)
Expand Down
67 changes: 51 additions & 16 deletions src/il_representations/algos/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from torchvision.models.resnet import BasicBlock as BasicResidualBlock
import torch
from torch import nn
from torchvision.models.resnet import resnet50, resnet34
import torch.nn.functional as F
from pyro.distributions import Delta

from gym import spaces
Expand Down Expand Up @@ -197,8 +199,10 @@ def __init__(self,
use_sn=False,
arch_str='MAGICALCNN-resnet-128',
ActivationCls=torch.nn.ReLU):

super().__init__()


# If block_type == resnet, use ResNet's basic block.
# If block_type == magical, use MAGICAL block from its paper.
assert arch_str in NETWORK_ARCHITECTURE_DEFINITIONS.keys()
Expand Down Expand Up @@ -265,11 +269,35 @@ def forward(self, x):
warn_on_non_image_tensor(x)
return self.shared_network(x)


class SimCLRModel(nn.Module):
def __init__(self, observation_space, representation_dim=128):
super(SimCLRModel, self).__init__()

self.f = []
in_channel = observation_space.shape[0]
for name, module in resnet34().named_children():
if name == 'conv1':
module = nn.Conv2d(in_channel, 64, kernel_size=3, stride=1, padding=1, bias=False)
if not isinstance(module, nn.Linear) and not isinstance(module, nn.MaxPool2d):
self.f.append(module)
# encoder
# Temporarily add an extra layer to be closer to our model implementation
self.f = nn.Sequential(*self.f)


def forward(self, x):
x = self.f(x)
feature = torch.flatten(x, start_dim=1)
return F.normalize(feature, dim=-1)


# string names for convolutional networks; this makes it easier to choose
# between them from the command line
NETWORK_SHORT_NAMES = {
'BasicCNN': BasicCNN,
'MAGICALCNN': MAGICALCNN,
'SimCLRModel': SimCLRModel
}


Expand Down Expand Up @@ -348,22 +376,22 @@ class BaseEncoder(Encoder):
def __init__(self, obs_space, representation_dim, obs_encoder_cls=None,
learn_scale=False, latent_dim=None, scale_constant=1, obs_encoder_cls_kwargs=None):
"""
:param obs_space: The observation space that this Encoder will be used on
:param representation_dim: The number of dimensions of the representation
that will be learned
:param obs_encoder_cls: An internal architecture implementing `forward`
to return a single vector representing the mean representation z
of a fixed-variance representation distribution (in the deterministic
case), or a latent dimension, in the stochastic case. This is
expected NOT to end in a ReLU (i.e. final layer should be linear).
:param learn_scale: A flag for whether we want to learn a parametrized
standard deviation. If this is set to False, a constant value of
<scale_constant> will be returned as the standard deviation
:param latent_dim: Dimension of the latents that feed into mean and std networks
If not set, this defaults to representation_dim * 2.
:param scale_constant: The constant value that will be returned if learn_scale is
set to False.
:param obs_encoder_cls_kwargs: kwargs the encoder class will take.
:param obs_space: The observation space that this Encoder will be used on
:param representation_dim: The number of dimensions of the representation
that will be learned
:param obs_encoder_cls: An internal architecture implementing `forward`
to return a single vector representing the mean representation z
of a fixed-variance representation distribution (in the deterministic
case), or a latent dimension, in the stochastic case. This is
expected NOT to end in a ReLU (i.e. final layer should be linear).
:param learn_scale: A flag for whether we want to learn a parametrized
standard deviation. If this is set to False, a constant value of
<scale_constant> will be returned as the standard deviation
:param latent_dim: Dimension of the latents that feed into mean and std networks
If not set, this defaults to representation_dim * 2.
:param scale_constant: The constant value that will be returned if learn_scale is
set to False.
:param obs_encoder_cls_kwargs: kwargs the encoder class will take.
"""
super().__init__()
if obs_encoder_cls_kwargs is None:
Expand All @@ -380,6 +408,13 @@ def __init__(self, obs_space, representation_dim, obs_encoder_cls=None,
self.network = obs_encoder_cls(obs_space, representation_dim, **obs_encoder_cls_kwargs)
self.scale_constant = scale_constant

if torch.cuda.device_count() > 1:
print("Using", torch.cuda.device_count(), "GPUs!")
self.network = nn.DataParallel(self.network)

self.network.to(self.device)


def forward(self, x, traj_info):
if self.learn_scale:
return self.forward_with_stddev(x, traj_info)
Expand Down
90 changes: 58 additions & 32 deletions src/il_representations/algos/losses.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from abc import ABC, abstractmethod
import torch
import numpy as np
import torch.nn.functional as F
import stable_baselines3.common.logger as sb_logger
from pyro.distributions import Delta
Expand Down Expand Up @@ -161,11 +162,12 @@ class SymmetricContrastiveLoss(RepresentationLoss):
all similarities with J, and also all similarities with I, and calculates cross-entropy on both
"""

def __init__(self, device, sample=False, temp=0.1, normalize=True):
def __init__(self, device, sample=False, temp=0.1, normalize=True, use_repo_loss=False):
super(SymmetricContrastiveLoss, self).__init__(device, sample)

self.criterion = torch.nn.CrossEntropyLoss()
self.temp = temp
self.use_repo_loss = use_repo_loss

# Most methods use either cosine similarity or matrix multiplication similarity. Since cosine similarity equals
# taking MatMul on normalized vectors, setting normalize=True is equivalent to using torch.CosineSimilarity().
Expand All @@ -180,50 +182,74 @@ def __call__(self, decoded_context_dist, target_dist, encoded_context_dist=None)
# decoded_context -> representation of context + optional projection head
# target -> representation of target + optional projection head
# encoded_context -> not used by this loss

decoded_contexts, targets = self.get_vector_forms(decoded_context_dist, target_dist)
z_i = decoded_contexts
z_j = targets
batch_size = z_i.shape[0]

if self.normalize: # Use cosine similarity

if self.use_repo_loss:
# Normalize to avoid infinities
z_i = F.normalize(z_i, dim=1)
z_j = F.normalize(z_j, dim=1)
out = torch.cat([z_i, z_j], dim=0)
# [2*B, 2*B]
sim_matrix = torch.exp(torch.mm(out, out.t().contiguous()) / self.temp)
mask = (torch.ones_like(sim_matrix) - torch.eye(2 * batch_size, device=sim_matrix.device)).bool()
# [2*B, 2*B-1]
sim_matrix = sim_matrix.masked_select(mask).view(2 * batch_size, -1)

# compute loss
pos_sim = torch.exp(torch.sum(z_i * z_j, dim=-1) / self.temp)
# [2*B]
pos_sim = torch.cat([pos_sim, pos_sim], dim=0)
loss = (- torch.log(pos_sim / sim_matrix.sum(dim=-1))).mean()
if torch.isnan(loss):
breakpoint()
return loss
else:
if not self.normalize:
breakpoint()
if self.normalize: # Use cosine similarity
z_i = F.normalize(z_i, dim=1)
z_j = F.normalize(z_j, dim=1)

mask = (torch.eye(batch_size) * self.large_num).to(self.device)

# Similarity of the original images with all other original images in current batch. Return a matrix of NxN.
logits_aa = torch.matmul(z_i, z_i.T) # NxN

# Values on the diagonal line are each image's similarity with itself
logits_aa = logits_aa - mask
# Similarity of the augmented images with all other augmented images.
logits_bb = torch.matmul(z_j, z_j.T) # NxN
logits_bb = logits_bb - mask
# Similarity of original images and augmented images
logits_ab = torch.matmul(z_i, z_j.T) # NxN
logits_ba = torch.matmul(z_j, z_i.T) # NxN

avg_self_similarity = logits_ab.diag().mean().item()
logits_other_sim_mask = ~torch.eye(batch_size, dtype=bool, device=logits_ab.device)
avg_other_similarity = logits_ab.masked_select(logits_other_sim_mask).mean().item()
mask = (torch.eye(batch_size) * self.large_num).to(self.device)

sb_logger.record('avg_self_similarity', avg_self_similarity)
sb_logger.record('avg_other_similarity', avg_other_similarity)
sb_logger.record('self_other_sim_delta', avg_self_similarity - avg_other_similarity)
# Similarity of the original images with all other original images in current batch. Return a matrix of NxN.
logits_aa = torch.matmul(z_i, z_i.T) # NxN

# Each row now contains an image's similarity with the batch's augmented images & original images. This applies
# to both original and augmented images (hence "symmetric").
logits_i = torch.cat((logits_ab, logits_aa), 1) # Nx2N
logits_j = torch.cat((logits_ba, logits_bb), 1) # Nx2N
logits = torch.cat((logits_i, logits_j), axis=0) # 2Nx2N
logits /= self.temp
# Values on the diagonal line are each image's similarity with itself
logits_aa = logits_aa - mask
# Similarity of the augmented images with all other augmented images.
logits_bb = torch.matmul(z_j, z_j.T) # NxN
logits_bb = logits_bb - mask
# Similarity of original images and augmented images
logits_ab = torch.matmul(z_i, z_j.T) # NxN
logits_ba = torch.matmul(z_j, z_i.T) # NxN

avg_self_similarity = logits_ab.diag().mean().item()
logits_other_sim_mask = ~torch.eye(batch_size, dtype=bool, device=logits_ab.device)
avg_other_similarity = logits_ab.masked_select(logits_other_sim_mask).mean().item()
sb_logger.record('avg_self_similarity', avg_self_similarity)
sb_logger.record('avg_other_similarity', avg_other_similarity)
sb_logger.record('self_other_sim_delta', avg_self_similarity - avg_other_similarity)

# Each row now contains an image's similarity with the batch's augmented images & original images. This applies
# to both original and augmented images (hence "symmetric").
logits_i = torch.cat((logits_ab, logits_aa), 1) # Nx2N
logits_j = torch.cat((logits_ba, logits_bb), 1) # Nx2N
logits = torch.cat((logits_i, logits_j), axis=0) # 2Nx2N
logits /= self.temp

# The values we want to maximize lie on the i-th index of each row i. i.e. the dot product of
# represent(image_i) and represent(augmented_image_i).
label = torch.arange(batch_size, dtype=torch.long).to(self.device)
labels = torch.cat((label, label), axis=0)
# The values we want to maximize lie on the i-th index of each row i. i.e. the dot product of
# represent(image_i) and represent(augmented_image_i).
label = torch.arange(batch_size, dtype=torch.long).to(self.device)
labels = torch.cat((label, label), axis=0)

return self.criterion(logits, labels)
return self.criterion(logits, labels)


class NegativeLogLikelihood(RepresentationLoss):
Expand Down
Loading