This is a PyTorch implementation of the paper *Optimization as a Model for Few-Shot Learning* (Ravi & Larochelle, 2017).
- Gradient-based optimization algorithms are not designed to work with only a limited number of updates, especially when the objective function is non-convex (so only a local optimum can be reached).
- Because the parameters are randomly initialized, a limited number of updates will not lead to a good solution. (Relying on pre-trained network parameters also greatly reduces accuracy as the pre-training task diverges from the target task.)

Meta-learning therefore works on two levels. The first is the quick acquisition of knowledge within each separate task presented. This process is guided by the second, which involves a slower extraction of the information learned across all the tasks.
- We leverage the observation that the gradient-descent update resembles the cell-state update of an LSTM (see the correspondence sketched after this list).
- Avoiding batch normalization during meta-testing
- The preprocessing method of Andrychowicz et al. (2016) works well when applied to both the gradients and the losses at each time step (sketched below).
- We set the cell state of the LSTM to be the parameters of the learner, i.e. c_t = θ_t.
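Concretely, the gradient-descent step and the LSTM cell-state update line up as follows (notation as in the paper):

```latex
% Gradient-descent update of the learner's parameters
\theta_t = \theta_{t-1} - \alpha_t \nabla_{\theta_{t-1}} \mathcal{L}_t

% LSTM cell-state update
c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c}_t

% The two coincide when
f_t = 1, \qquad c_{t-1} = \theta_{t-1}, \qquad i_t = \alpha_t, \qquad \tilde{c}_t = -\nabla_{\theta_{t-1}} \mathcal{L}_t
```

Instead of fixing i_t and f_t, the meta-learner LSTM learns them as functions of the current gradient, loss, learner parameters, and previous gate values, so it effectively learns an adaptive learning rate (i_t) and a weight-decay-like forgetting term (f_t).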
💡 The learner and the meta-learner are two different things: the learner is a neural-network classifier, and the meta-learner is an LSTM-based optimizer trained to optimize the learner through an update analogous to the LSTM cell-state update.
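For the preprocessing mentioned in the list above, here is a minimal sketch of the log/sign transform from Andrychowicz et al. (2016); the function name, the constant p = 10, and the tensor layout are illustrative assumptions rather than this repository's exact code:

```python
import torch

def preprocess(x, p=10.0):
    """Log/sign preprocessing of Andrychowicz et al. (2016), applied element-wise
    to the gradients and losses before they are fed to the meta-learner LSTM.
    Each scalar v is mapped to a 2-D vector:
        (log|v| / p, sign(v))     if |v| >= exp(-p)
        (-1,         exp(p) * v)  otherwise
    """
    p = torch.tensor(p, dtype=x.dtype)
    large = x.abs() >= torch.exp(-p)
    first = torch.where(large, x.abs().clamp_min(1e-12).log() / p, torch.full_like(x, -1.0))
    second = torch.where(large, torch.sign(x), torch.exp(p) * x)
    return torch.stack([first, second], dim=-1)

# A tiny gradient falls into the "small" branch, larger values into the log/sign branch.
g = torch.tensor([1e-8, 0.5, -2.0])
print(preprocess(g))   # shape (3, 2)
```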
- Here the output produced by the meta-learner is used by both the learner and the meta-learner at the next iteration; hence the two arrows leaving the meta-learner.
- The dashed line indicates that the gradient of the learner's loss with respect to its parameters is used in the meta-learner's output equation.
- Randomly initialize the meta-learner parameters (the meta-learner's cell state is initialized with the parameters of the learner network)
- Iterate over the datasets in the meta-train set (a minimal sketch of this training loop appears below):
- Randomly sample D_train and D_test from the meta-train set
- Initialize the learner parameters
- Train on D_train: randomly pick a batch from D_train
- Compute the loss using the current learner parameters
- Update the learner parameters via the meta-learner's LSTM cell update
- Compute the loss on D_test using the updated learner parameters
- Update the meta-learner parameters using gradient descent
- Parameter sharing (the meta-learner's weights are shared across all coordinates of the learner's parameters)
- Avoiding batch normalization during meta-testing
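Putting the steps above together, here is a minimal, self-contained sketch of the meta-training loop. It illustrates the structure only and is not this repository's code: the learner is shrunk to a linear classifier on synthetic data, the meta-learner LSTM is reduced to a single per-coordinate gating layer, and all names (MetaLearner, learner_loss) and sizes are made up for the example.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MetaLearner(nn.Module):
    """Toy stand-in for the meta-learner. The paper's meta-learner is an LSTM whose
    weights are shared across all coordinates of the learner's parameters; here a
    single linear layer maps (gradient, parameter) -> (input gate i_t, forget gate f_t)
    per coordinate, which is enough to show the structure of the training loop."""
    def __init__(self, num_params):
        super().__init__()
        # c_0: the learner's initial parameters are part of the meta-learner's state.
        self.theta0 = nn.Parameter(torch.zeros(num_params))
        self.gate = nn.Linear(2, 2)

    def step(self, theta, grad):
        x = torch.stack([grad, theta], dim=-1)            # (num_params, 2)
        i_t, f_t = torch.sigmoid(self.gate(x)).unbind(-1)
        return f_t * theta - i_t * grad                   # c_t = f_t*c_{t-1} + i_t*(-grad)

def learner_loss(theta, x, y, in_dim=8, n_classes=5):
    """Functional forward pass of a linear classifier whose weights live in `theta`,
    so gradients of the meta-loss can flow back through the parameter updates."""
    W = theta[: in_dim * n_classes].view(n_classes, in_dim)
    b = theta[in_dim * n_classes:]
    return F.cross_entropy(x @ W.t() + b, y)

in_dim, n_classes, inner_steps = 8, 5, 5
meta = MetaLearner(num_params=in_dim * n_classes + n_classes)
meta_opt = torch.optim.Adam(meta.parameters(), lr=1e-3)

for episode in range(100):                                # iterate over meta-train episodes
    # Randomly sampled D_train / D_test for this episode (synthetic data as placeholder).
    x_tr, y_tr = torch.randn(25, in_dim), torch.randint(0, n_classes, (25,))
    x_te, y_te = torch.randn(15, in_dim), torch.randint(0, n_classes, (15,))

    theta = meta.theta0                                   # initialize learner parameters from c_0
    for _ in range(inner_steps):                          # train the learner on D_train
        loss = learner_loss(theta, x_tr, y_tr)            # loss with current learner parameters
        grad, = torch.autograd.grad(loss, theta, create_graph=True)
        theta = meta.step(theta, grad)                    # LSTM-style cell update of the parameters

    meta_loss = learner_loss(theta, x_te, y_te)           # evaluate updated learner on D_test
    meta_opt.zero_grad()
    meta_loss.backward()                                  # gradient descent on the meta-learner
    meta_opt.step()
```

Note that this sketch simply backpropagates through everything; the paper additionally makes the simplifying assumption that the learner's gradient does not depend on the meta-learner parameters.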
[Reading the images from each folder separately using the ImageDataset class] In each episode, n_shot + n_eval = batch_size images are read from each folder.
ImageDataset is instantiated while creating an episode and is responsible for reading the images in each folder.
The episode sampler is used as a batch sampler while loading the data: BatchSampler takes indices from your Sampler() instance (in this case 3 of them) and returns them as a list, so they can be used in your MyDataset `__getitem__` method (see the sketch below).
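Below is a minimal, self-contained sketch of this dataset/sampler pattern. The class names (ClassFolderDataset, EpisodeSampler) and all sizes are illustrative stand-ins for the repo's ImageDataset and episode sampler, and random tensors replace the actual image reading:

```python
import torch
from torch.utils.data import Dataset, Sampler, BatchSampler, DataLoader

class ClassFolderDataset(Dataset):
    """Hypothetical stand-in for ImageDataset: index i corresponds to one class
    folder, and __getitem__ returns n_shot + n_eval images sampled from it."""
    def __init__(self, n_classes=20, n_shot=5, n_eval=15, image_shape=(3, 84, 84)):
        self.n_classes = n_classes
        self.batch_size = n_shot + n_eval            # images read per class per episode
        self.image_shape = image_shape

    def __len__(self):
        return self.n_classes

    def __getitem__(self, class_idx):
        # The real loader reads image files from the class folder;
        # random tensors stand in for decoded images here.
        images = torch.randn(self.batch_size, *self.image_shape)
        labels = torch.full((self.batch_size,), class_idx, dtype=torch.long)
        return images, labels

class EpisodeSampler(Sampler):
    """Yields the class indices that make up each episode (e.g. 5-way -> 5 indices)."""
    def __init__(self, n_classes, n_way, n_episodes):
        self.n_classes, self.n_way, self.n_episodes = n_classes, n_way, n_episodes

    def __iter__(self):
        for _ in range(self.n_episodes):
            yield from torch.randperm(self.n_classes)[: self.n_way].tolist()

    def __len__(self):
        return self.n_episodes * self.n_way

dataset = ClassFolderDataset()
sampler = BatchSampler(EpisodeSampler(len(dataset), n_way=5, n_episodes=10),
                       batch_size=5, drop_last=False)    # one batch = one episode's classes
loader = DataLoader(dataset, batch_sampler=sampler)

for images, labels in loader:
    # images: (n_way, n_shot + n_eval, C, H, W); split into D_train / D_test per class
    print(images.shape, labels.shape)
    break
```

Used this way, each batch handed out by the DataLoader is exactly the set of classes for one episode: the first n_shot images of every class form D_train and the remaining n_eval form D_test.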
If you have any doubts, contact me: [email protected]