
Training with custom dataset. #14

Open
SamiurRahman1 opened this issue Dec 30, 2020 · 20 comments

Comments

@SamiurRahman1

Hello, I am new to meta-learning. For a project, I need to train a model on custom image data. Is it possible to use your repo to train a model on my own data?

Thank you

@siavash-khodadadeh
Owner

siavash-khodadadeh commented Dec 30, 2020

Hello,

Yes, that would be possible. First, you need to subclass the Database class here.
Then you need to implement two abstract methods:
The first is this method, which returns three dictionaries: train_classes, val_classes and test_classes. Each dictionary maps a class name (a string) to a list containing the full paths of all instances in that class.
The second is this other method, which describes how each instance of a class should be parsed. For example, here we show how to do that for JPG images. If your dataset also consists of JPG images, you can use our mixin; otherwise you can write your own parse_function to make sure the data shape and everything else are fed to the network as you wish.
As an example, here is how we define the Omniglot dataset.
More examples can be found in the Python files here.
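As a minimal sketch of the first method, assuming your images are stored one folder per class (`<root>/<class_name>/<image files>`), the three dictionaries could be built as below. The function `split_class_folders` and its parameters are hypothetical helpers, not part of the repo; inside your subclass you might call something like it from `get_train_val_test_folders` and return its result.

```python
import os
import random
from typing import Dict, List, Tuple

ClassDict = Dict[str, List[str]]


def split_class_folders(
    root: str, num_train_classes: int, num_val_classes: int, random_seed: int = 0
) -> Tuple[ClassDict, ClassDict, ClassDict]:
    """Split class folders under `root` into train/val/test class dictionaries.

    Each dictionary maps a class name (the folder name) to the full paths of
    all instances (files) in that class. Leftover classes go to the test set.
    """
    class_names = sorted(
        name for name in os.listdir(root) if os.path.isdir(os.path.join(root, name))
    )
    # Shuffle deterministically so the split is reproducible for a given seed.
    random.Random(random_seed).shuffle(class_names)

    def instances(class_name: str) -> List[str]:
        class_dir = os.path.join(root, class_name)
        return sorted(
            os.path.join(class_dir, file_name)
            for file_name in os.listdir(class_dir)
            if os.path.isfile(os.path.join(class_dir, file_name))
        )

    val_end = num_train_classes + num_val_classes
    train_classes = {name: instances(name) for name in class_names[:num_train_classes]}
    val_classes = {name: instances(name) for name in class_names[num_train_classes:val_end]}
    test_classes = {name: instances(name) for name in class_names[val_end:]}
    return train_classes, val_classes, test_classes
```

Sorting before shuffling keeps the split independent of the order in which the file system lists folders.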

After this, you just pass an instance of your database class to any of the meta-learning algorithms implemented in this repo, and it will be applied to your dataset. You can view the TensorBoard logs and saved models as well. If you want, you can also make a run file for it and contribute it to the repo. The run file looks something like this.
Let me know if you have any questions.

@SamiurRahman1
Author

SamiurRahman1 commented Jan 2, 2021

Hi. Thank you so much for your time and the nice explanation. I believe I was able to create the scripts I needed by following your instructions, but I am getting an error when I run my script. Here is the error:

File "myTestRun.py", line 1, in <module>
    from .models.maml.maml import ModelAgnosticMetaLearningModel
ModuleNotFoundError: No module named '__main__.models'; '__main__' is not a package

Any idea why I am getting this?
Thanks
Edit: I placed my "myTestRun.py" (based on "maml_omniglot.py") in the "/models/maml/" folder and "myDataset.py" (subclassed from data_bases.py) in the "/databases/" folder.

@SamiurRahman1
Author

If I write

from models.maml.maml import ModelAgnosticMetaLearningModel
from networks.maml_umtra_networks import SimpleModel

I get ModuleNotFoundError: No module named 'models'

@siavash-khodadadeh
Copy link
Owner

siavash-khodadadeh commented Jan 2, 2021

I see. I believe this is because the Python interpreter cannot find the models module. To resolve it, add the path of the project root to your PYTHONPATH environment variable. It would be something like this:

export PYTHONPATH=$PYTHONPATH:<absolute-path-to-project-root>
python models/maml/myTestRun.py

There is also a Pythonic way to do this: in your myTestRun.py, add

import sys


project_root_address = '<absolute-path-to-project-root-folder>'
sys.path.insert(0, project_root_address)

...
# The rest of your code here
...
...
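Instead of hard-coding the absolute path, a small sketch can derive the project root from the script's own location. This assumes, as in this thread, that the script lives two folders below the root (e.g. `models/maml/myTestRun.py`); the helper name `add_project_root_to_path` is hypothetical.

```python
import sys
from pathlib import Path


def add_project_root_to_path(script_path: str, levels_up: int = 2) -> Path:
    """Insert the folder `levels_up` above the script's folder at the front of sys.path.

    For models/maml/myTestRun.py, levels_up=2 resolves to the project root.
    """
    script_dir = Path(script_path).resolve().parent
    # parents[0] is one level up, parents[1] two levels up, and so on.
    project_root = script_dir if levels_up == 0 else script_dir.parents[levels_up - 1]
    root_str = str(project_root)
    if root_str not in sys.path:
        sys.path.insert(0, root_str)
    return project_root
```

You would call `add_project_root_to_path(__file__)` at the top of `myTestRun.py`, before the `from models...` imports. Another standard option is to run the script as a module from the project root, `python -m models.maml.myTestRun`, which puts the current directory on `sys.path` and makes the absolute `from models.maml.maml import ...` form work.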

@SamiurRahman1
Author

Hi, thanks for your quick reply. I have managed to solve the problem. Now I have another one, which I have faced before and still haven't been able to fix. It has something to do with the shape of the images. The error I am getting is:

File "<string>", line 3, in raise_from
tensorflow.python.framework.errors_impl.InvalidArgumentError: PartialTensorShape: Incompatible shapes during merge: [28,28,3] vs. [28,28,4]
	 [[{{node map_1/TensorArrayV2Stack/TensorListStack}}]] [Op:IteratorGetNext]

Since I adapted your code written for Omniglot, I am assuming the shape of my images doesn't match the Omniglot dataset.
I would be really grateful if you could point me to the script where I should reshape (or properly shape) my images.
P.S. I am a noob in TensorFlow and image processing.

@SamiurRahman1
Author

I think I did not understand how to use the parse_function. I have not used it anywhere. I need to implement that function, but I am not sure where exactly.

@siavash-khodadadeh
Copy link
Owner

siavash-khodadadeh commented Jan 2, 2021

So you need to write a function similar to this in your database class. In particular,
image = tf.image.decode_jpeg(tf.io.read_file(example_address))

this line loads your image assuming it is a JPEG image. Note that if your image is not JPEG, you have to decode it based on its format, or save it as JPEG before using this.

The resizing happens in this line, which uses self.get_input_shape(). You can ignore that for now and just pass fixed numbers based on your network architecture:
image = tf.image.resize(image, self.get_input_shape()[:2])

For example, you can do something like this:
image = tf.image.resize(image, (84, 84))
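Note that `tf.image.resize` changes only height and width; it never touches the channel dimension. A minimal illustration with a hypothetical all-zero RGBA tensor shows why resizing alone cannot remove a 3-vs-4-channel mismatch:

```python
import tensorflow as tf

# A hypothetical 4-channel (RGBA) image of arbitrary size.
rgba = tf.zeros([120, 160, 4], dtype=tf.uint8)

# The default (bilinear) method returns float32 values.
resized = tf.image.resize(rgba, (84, 84))
print(resized.shape)  # (84, 84, 4): height and width change, channels do not
```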

@SamiurRahman1
Author

SamiurRahman1 commented Jan 3, 2021

So let's say this is my subclass:

class MySubClass(JPGParseMixin,Database):
    def __init__(
            self,
            random_seed,
            num_train_classes,
            num_val_classes,
    ):
        self.num_train_classes = num_train_classes
        self.num_val_classes = num_val_classes
        super(MySubClass, self).__init__(
            settings.OMNIGLOT_RAW_DATA_ADDRESS,
            os.path.join(settings.PROJECT_ROOT_ADDRESS, 'data/'),
            random_seed=random_seed,
            input_shape=(28, 28, 1)
        )
def get_train_val_test_folders(self) -> Tuple:
... ... ... ...
#rest of my code...

Do I add the def _get_parse_function(self) -> Callable: method before my get_train_val_test_folders function?
And then how is my _get_parse_function used? That is, at which point do I call this function to reshape my images, or is it called automatically?
Sorry if this seems like a stupid question.

@siavash-khodadadeh
Owner

Yes, it is done in the algorithm part. As an example, you can look at this. Just make an instance of your dataset class and pass it to the MAML algorithm, or to any other algorithm such as UMTRA, Protonets, etc. Also, in your example above, make sure you indent the functions so that they are members of your class.
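To see why the indentation matters, here is a self-contained sketch using a stand-in abstract base class. `HypotheticalDatabase` is not the repo's `Database`; it assumes the abstract methods are declared with `abc.abstractmethod`. A function written at module level does not count as an implementation, so the subclass cannot be instantiated.

```python
from abc import ABC, abstractmethod
from typing import Tuple


class HypotheticalDatabase(ABC):
    """Stand-in for a base class that declares an abstract method."""

    @abstractmethod
    def get_train_val_test_folders(self) -> Tuple:
        ...


class MisIndented(HypotheticalDatabase):
    pass


# Not indented: this is a plain module-level function, not a method of MisIndented.
def get_train_val_test_folders(self) -> Tuple:
    return {}, {}, {}


class Indented(HypotheticalDatabase):
    # Indented inside the class body, so it implements the abstract method.
    def get_train_val_test_folders(self) -> Tuple:
        return {}, {}, {}


try:
    MisIndented()
except TypeError as error:
    print("MisIndented:", error)

print("Indented:", Indented().get_train_val_test_folders())
```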

@SamiurRahman1
Author

Hi, I have done that, but I still get this error:

pywrap_tfe.TFE_ExecutorWaitForAllPendingNodes(self._handle)
tensorflow.python.framework.errors_impl.InvalidArgumentError: PartialTensorShape: Incompatible shapes during merge: [84,84,3] vs. [84,84,4]
	 [[{{node map_1/TensorArrayV2Stack/TensorListStack}}]]

Any idea how to fix it?

@lboloni
Copy link
Collaborator

lboloni commented Jan 3, 2021 via email

@SamiurRahman1
Author

Thank you for your suggestion. I ran the following bit of code:

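from PIL import Image
import cv2

myImg = '<path-to-one-of-my-images>'

# Pillow reports the color mode (e.g. RGB, RGBA, L) and (width, height)
img_pil = Image.open(myImg)
print('Pillow: ', img_pil.mode, img_pil.size)

# IMREAD_UNCHANGED keeps an alpha channel if the file has one
img = cv2.imread(myImg, cv2.IMREAD_UNCHANGED)
print('OpenCV: ', img.shape)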

and i got this output:

Pillow:  RGB (1280, 720)
OpenCV:  (720, 1280, 3)

This suggests that my images are already RGB with 3 channels, so I can't figure out where exactly the error is coming from.
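Checking a single image is not conclusive: the merge error means at least one file in a batch decodes to 4 channels. Also, `tf.io.decode_jpeg` does not look at the file extension, and its documentation states that it also decodes PNGs, so a PNG with an alpha channel saved under a `.jpg` name is one plausible source. The stdlib-only sketch below (function names are hypothetical) reads each file's leading bytes and reports its real format plus the channel count implied by the header: the IHDR color type for PNGs, the frame-header component count for JPEGs. Palette PNGs are reported with `None` because their decoded channel count depends on transparency data.

```python
import os
from typing import Dict, Optional, Tuple

PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"
# PNG IHDR color type -> channels after decoding (type 3, palette, is left out).
PNG_CHANNELS = {0: 1, 2: 3, 4: 2, 6: 4}
# JPEG start-of-frame markers SOF0-SOF15, excluding DHT (C4), JPG (C8) and DAC (CC).
JPEG_SOF_MARKERS = set(range(0xC0, 0xD0)) - {0xC4, 0xC8, 0xCC}


def _png_channels(data: bytes) -> Optional[int]:
    # IHDR is always the first chunk; its color type byte sits at offset 25.
    if len(data) < 26:
        return None
    return PNG_CHANNELS.get(data[25])


def _jpeg_channels(data: bytes) -> Optional[int]:
    i = 2  # skip the SOI marker (FF D8)
    while i + 4 <= len(data):
        if data[i] != 0xFF:
            return None  # not positioned on a marker: malformed or unsupported file
        marker = data[i + 1]
        if marker == 0xFF:  # fill byte before a marker
            i += 1
            continue
        if marker == 0x01 or 0xD0 <= marker <= 0xD9:  # markers without a length field
            i += 2
            continue
        if marker == 0xDA:  # start of scan reached without a frame header
            return None
        if marker in JPEG_SOF_MARKERS:
            # Frame header: length(2) precision(1) height(2) width(2) components(1)
            return data[i + 9] if i + 9 < len(data) else None
        segment_length = int.from_bytes(data[i + 2:i + 4], "big")
        i += 2 + segment_length
    return None


def sniff_image(path: str) -> Tuple[str, Optional[int]]:
    """Return (real format, channel count or None) based on the file's bytes."""
    with open(path, "rb") as f:
        data = f.read()
    if data.startswith(PNG_SIGNATURE):
        return "png", _png_channels(data)
    if data.startswith(b"\xff\xd8"):
        return "jpeg", _jpeg_channels(data)
    return "other", None


def scan_images(root: str) -> Dict[str, Tuple[str, Optional[int]]]:
    """Sniff every file below `root`, keyed by full path."""
    report = {}
    for dirpath, _, filenames in os.walk(root):
        for name in sorted(filenames):
            path = os.path.join(dirpath, name)
            report[path] = sniff_image(path)
    return report
```

Printing the entries of `scan_images('<path-to-your-dataset>')` whose channel count is not 3 would list the candidate files behind a `[84,84,3] vs. [84,84,4]` error.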

@siavash-khodadadeh
Owner

Can you please check your image shapes when you load them with TensorFlow? To do that, you can add the line
tf.print(image.shape)
before returning the image in this function.
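Keep in mind that `image.shape` is the static shape TensorFlow infers when it traces the parse function (inside `Dataset.map` this happens once, before any file is read), so dimensions it cannot infer appear as `None`. To see each file's actual size, print `tf.shape(image)`. A minimal sketch; the name `parse_and_report` is hypothetical:

```python
import tensorflow as tf


def parse_and_report(example_address):
    image = tf.io.decode_jpeg(tf.io.read_file(example_address))
    # Static shape, fixed at trace time: without a `channels` argument,
    # decode_jpeg cannot know height, width or channel count in advance.
    print("static shape:", image.shape)
    # Dynamic shape, evaluated for every file that flows through the pipeline.
    tf.print("runtime shape:", tf.shape(image))
    return image
```

It would be used as `tf.data.Dataset.from_tensor_slices(paths).map(parse_and_report)`, where `paths` is your list of image files.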

@SamiurRahman1
Author

Hmm, interesting. This is the output I get when I run
tf.print(image.shape):

TensorShape([84, 84, None])

@SamiurRahman1
Author

Is it possible to get this output if the function doesn't receive the input images? Maybe I made a mistake while providing the input images. Could that cause this problem?

@SamiurRahman1
Author

Another thing: in _get_parse_function, if I comment out the image = tf.image.resize(image, (84, 84)) line, I should technically get the actual size of the images when I run tf.print(image.shape). But what I get is: TensorShape([None, None, None])

@siavash-khodadadeh
Owner

siavash-khodadadeh commented Jan 4, 2021

I would definitely try something like this to see if it solves the problem:
image = image[:, :, :3]
These dimensions are None because of the way decode_jpeg handles the channels attribute. See here. As that page shows, newer versions of TF document this function under tf.io instead of tf.image. Make sure to take that into account if you are going to use the latest version of TF.
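Building on that, here is a minimal sketch of a parse function that asks the decoder for a fixed number of channels, so every image ends up with the same shape whether the file is RGB, RGBA or grayscale. It uses `tf.io.decode_image` with `channels=3` (an alpha channel is dropped and grayscale is expanded to RGB) and `expand_animations=False` (the result is always 3-D); `set_shape` then records the known channel count in the static shape. `make_parse_function` and its `image_size` and `grayscale` parameters are illustrative assumptions; pick values that match your network's input shape.

```python
import tensorflow as tf


def make_parse_function(image_size=(84, 84), grayscale=False):
    """Return a parse function mapping a file path to a float image in [0, 1]."""

    def parse_function(example_address):
        contents = tf.io.read_file(example_address)
        # Force 3 channels: RGBA loses its alpha channel, grayscale is expanded.
        # expand_animations=False guarantees a [height, width, channels] tensor.
        image = tf.io.decode_image(contents, channels=3, expand_animations=False)
        image.set_shape([None, None, 3])
        image = tf.image.resize(image, image_size)  # returns float32
        if grayscale:
            image = tf.image.rgb_to_grayscale(image)  # requires exactly 3 channels
        return image / 255.0

    return parse_function
```

In a database subclass, `_get_parse_function` could then return `make_parse_function(grayscale=True)` for a single-channel network.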

@SamiurRahman1
Author

So, since my code isn't throwing any errors and it seems to be training, I hope it is working and the problem has been solved. I actually applied both of your suggestions, so I am not sure which one did the trick. This is what my _get_parse_function looks like now:

def _get_parse_function(self) -> Callable:
    def parse_function(example_address):
        image = tf.io.read_file(example_address)
        image = tf.io.decode_jpeg(image)
        image = tf.image.resize(image, (84, 84))
        # Keep only the first three channels (drops an alpha channel if present)
        image = image[:, :, :3]
        # Convert to a single grayscale channel; this requires exactly 3 channels
        image = tf.image.rgb_to_grayscale(image)
        image = tf.cast(image, tf.float32)
        return image / 255.
    return parse_function

I will give an update once my code finishes running.
Thanks a lot for your help!

@SamiurRahman1
Author

SamiurRahman1 commented Jan 5, 2021

I have one more question: how do I read the log files and plot/view my training performance?
Also, right after it started, it was stuck here:

No previous checkpoint found!
0it [00:00, ?it/s]

for a few hours. Is that normal?

@siavash-khodadadeh
Owner

siavash-khodadadeh commented Jan 5, 2021

Hello Samiur,

Thank you for using this repository for your project. Glad that you were able to run the training. To make it easy for people to search the repo's issues in the future, would you mind opening another issue for your last question, since it is a different topic?

Also, I will keep this issue open until we add a description and details of how to use custom datasets to the project's README.

Thanks again for your great questions!
