Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enhance Multiclass documentation #686

Merged
merged 1 commit into from
Jun 13, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions docs/multi_species.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,29 @@
DeepForest allows training on multiple species annotations.
When creating a deepforest model object, pass the designed number of classes and a label dictionary that maps each numeric class to a character label. The number of classes can be either be specified in the config, or using config_args during creation.

```
``` python
m = main.deepforest(config_args={"num_classes":2},label_dict={"Alive":0,"Dead":1})
```

It is often, but not always, useful to start with a prebuilt model when trying to identify multiple species. This helps the model focus on learning the multiple classes and not waste data and time re-learning bounding boxes.

To load the backboard and box prediction portions of the release model, but create a classification model for more than one species.
To load the backbone and box prediction portions of the release model, but create a classification model for more than one species.
Here is an example using the alive/dead tree data stored in the package, but the same logic applies to the bird detector.

```


``` python
# Initialize new Deepforest model ( the model that you will train ) with your classes
m = main.deepforest(config_args={"num_classes":2}, label_dict={"Alive":0,"Dead":1})

# Inatialize Deepforest model ( the model that you will modify its regression head )
deepforest_release_model = main.deepforest()
deepforest_release_model.use_release()
deepforest_release_model.use_release() # or use_bird_release

# Extract single class backbone that will have useful features for multi-class classification
m.model.backbone.load_state_dict(deepforest_release_model.model.backbone.state_dict())

# load regression head in the new model
m.model.head.regression_head.load_state_dict(deepforest_release_model.model.head.regression_head.state_dict())

m.config["train"]["csv_file"] = get_data("testfile_multi.csv")
Expand All @@ -33,3 +40,5 @@ m.config["validation"]["val_accuracy_interval"] = 1
m.create_trainer()
m.trainer.fit(m)
```

* For more on loading with state_dict: [Pytorch Docs](https://pytorch.org/tutorials/beginner/saving_loading_models.html#save-load-state-dict-recommended)
Loading