How to use this together with wandb? Plus a question about class imbalance #489
-
Hi Kevin, thanks for creating this wonderful repo! I was wondering what the best practice is for using this together with wandb, which I typically use for hyperparameter tuning. Basically, I'll need to log each epoch's loss, metrics, etc. to wandb. I guess I'll probably either not use the built-in trainers, or find some way to hook into them?

Also, as someone new to metric learning, I'm wondering how well the standard metric learning losses, miners, etc. handle imbalanced training data. For example, I have a few classes with tens of samples, while most classes have only one or two samples. This is like a contrastive learning setting where I'll have only one positive sample (or none) and a lot of negatives. Could you recommend some papers that compare the losses, samplers, or miners in those settings? Thanks a lot!
-
Yes, you could pass in an `end_of_iteration_hook` like this:

```python
def hook(trainer):
    for k, v in trainer.losses.items():
        log(k, v)

trainer = MetricLossOnly(..., end_of_iteration_hook=hook)
```

But instead of this, I would write my own training code, or use a framework like PyTorch Lightning, PyTorch Ignite, Catalyst, etc. Their trainer classes offer more features and are better maintained, because the entire focus of those libraries is to make training easier.
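That said, if you do stick with the hook approach and want the values in wandb specifically, here is a minimal, untested sketch (it assumes you've already called `wandb.init` somewhere in your script; `wandb_hook` is just an illustrative name):

```python
import wandb

def wandb_hook(trainer):
    # trainer.losses maps loss names to their current values,
    # so the whole dict can be forwarded to wandb each iteration
    wandb.log({k: v for k, v in trainer.losses.items()})

# then pass it in: MetricLossOnly(..., end_of_iteration_hook=wandb_hook)
```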
A common technique in metric learning papers is to make each batch balanced. For example, if the batch size is 64, you could include 16 classes with 4 samples each. In this library it's called `MPerClassSampler`:

```python
from pytorch_metric_learning.samplers import MPerClassSampler

# pass this into your dataloader
sampler = MPerClassSampler(labels, m=4, batch_size=64)
```

At the moment I can't think of any papers on this specific subject.
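To actually use the sampler, you'd pass it to your `DataLoader` in place of `shuffle=True`. A rough sketch, where `train_dataset` is a placeholder for your own dataset:

```python
from torch.utils.data import DataLoader

# MPerClassSampler expects batch_size to be a multiple of m;
# drop_last=True keeps every batch the same size
train_loader = DataLoader(
    train_dataset,   # placeholder: your own dataset
    batch_size=64,
    sampler=sampler,
    drop_last=True,
)
```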