
Addition of the Loss derived type and of the MSE loss function #175

Merged: 8 commits into modern-fortran:main from loss_dt on Apr 19, 2024

Conversation

@jvdp1 (Collaborator) commented on Apr 16, 2024

As discussed with @milancurcic in #173:

  • addition of a derived type for loss functions (a sketch follows after the TODO list below)
  • addition of the Mean Square Error loss function

TODO:

  • Addition of docs
  • Addition of tests
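
A hedged sketch of the general shape such a derived type could take (all names here are illustrative assumptions, not necessarily the API merged in this PR):

module loss_sketch_m
  implicit none

  ! Interfaces that any loss function and its derivative must satisfy.
  abstract interface
    pure function loss_eval(true, predicted) result(res)
      real, intent(in) :: true(:)
      real, intent(in) :: predicted(:)
      real :: res
    end function loss_eval
    pure function loss_eval_derivative(true, predicted) result(res)
      real, intent(in) :: true(:)
      real, intent(in) :: predicted(:)
      real :: res(size(true))
    end function loss_eval_derivative
  end interface

  ! Abstract base type; concrete losses such as MSE extend it and
  ! provide the deferred bindings.
  type, abstract :: loss_type
  contains
    procedure(loss_eval), deferred, nopass :: eval
    procedure(loss_eval_derivative), deferred, nopass :: derivative
  end type loss_type

  type, extends(loss_type) :: mse
  contains
    procedure, nopass :: eval => mse_eval
    procedure, nopass :: derivative => mse_derivative
  end type mse

contains

  pure function mse_eval(true, predicted) result(res)
    real, intent(in) :: true(:)
    real, intent(in) :: predicted(:)
    real :: res
    res = sum((predicted - true)**2) / size(true)
  end function mse_eval

  pure function mse_derivative(true, predicted) result(res)
    real, intent(in) :: true(:)
    real, intent(in) :: predicted(:)
    real :: res(size(true))
    res = 2 * (predicted - true) / size(true)
  end function mse_derivative

end module loss_sketch_m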

@jvdp1 marked this pull request as ready for review on April 16, 2024 19:17
@jvdp1 (Collaborator, Author) commented on Apr 16, 2024

@milancurcic what is the strategy for the tests?

Comment on lines +28 to +33
pure module function mse_derivative(true, predicted) result(res)
real, intent(in) :: true(:)
real, intent(in) :: predicted(:)
real :: res(size(true))
res = 2 * (predicted - true) / size(true)
end function mse_derivative
@jvdp1 (Collaborator, Author):

This function should be checked to confirm it is valid.
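
For reference, a quick check, assuming MSE is defined as the mean of squared errors over a batch of size n:

\mathrm{MSE}(t, p) = \frac{1}{n} \sum_{i=1}^{n} (p_i - t_i)^2,
\qquad
\frac{\partial\, \mathrm{MSE}}{\partial p_i} = \frac{2\,(p_i - t_i)}{n}

which matches res = 2 * (predicted - true) / size(true) elementwise.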

@milancurcic self-requested a review on April 18, 2024 13:52
@milancurcic (Member) commented:
Thanks, @jvdp1, I'll start a test program.

@milancurcic (Member) commented:
@jvdp1 I put in a few very minimal tests that check the expected values given simple inputs. Feel free to add if you can think of better tests. I've also been thinking about how we can test the integration of these loss functions with the network; perhaps also using simple inputs and known outputs, but passing them through the network type.
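
A minimal sketch of such a value check (the inputs and expected outputs below are illustrative, not taken from the PR; the two functions are inlined so the sketch is self-contained):

program test_mse_sketch
  implicit none
  real :: t(2), p(2)
  real, parameter :: tol = 1e-6

  t = [0., 1.]
  p = [1., 1.]

  ! Expected by hand: MSE = ((1-0)**2 + (1-1)**2) / 2 = 0.5,
  ! and the derivative is 2*([1,1] - [0,1])/2 = [1., 0.].
  if (abs(mse(t, p) - 0.5) > tol) error stop 'mse value check failed'
  if (any(abs(mse_derivative(t, p) - [1., 0.]) > tol)) &
    error stop 'mse_derivative value check failed'
  print *, 'MSE checks passed.'

contains

  pure function mse(true, predicted) result(res)
    real, intent(in) :: true(:), predicted(:)
    real :: res
    res = sum((predicted - true)**2) / size(true)
  end function mse

  ! Same expression as in the diff above.
  pure function mse_derivative(true, predicted) result(res)
    real, intent(in) :: true(:), predicted(:)
    real :: res(size(true))
    res = 2 * (predicted - true) / size(true)
  end function mse_derivative

end program test_mse_sketch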

@jvdp1 (Collaborator, Author) commented on Apr 19, 2024

Feel free to add if you can think of better tests.

Thank you. These tests LGTM.

perhaps also using simple inputs and known outputs, but pass them through the network type.

It could be a possibility, but I guess that would test their integration in the implementation more than the functions themselves. If so, would such tests be more appropriate in, e.g., test_dense_network.f90?

@milancurcic (Member) commented:
On second thought, let's wait on testing the integration with the network (regardless of where those tests would be defined). As we implemented general mechanisms to specify and use losses and optimizers, it has become apparent to me that it's important to separate model creation (i.e., via the network_from_layers constructor) from the "compilation", as is done in more mature Python frameworks. In Keras, for example, you first create the model by specifying the architecture, and then in a separate step "compile" it by passing the loss function, the optimizer, and the evaluation metrics to use; this allows, for example, reusing the same network instance with different optimizers or losses.
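
Purely for illustration, a toy sketch of that two-step flow; neither network_type nor compile here is claimed to be the actual neural-fortran API:

module compile_sketch_m
  implicit none

  ! A network that is constructed first and "compiled" (given a loss)
  ! in a separate step, so the same instance can be reused with
  ! different losses or optimizers.
  type :: network_type
    character(:), allocatable :: loss_name
  contains
    procedure :: compile
  end type network_type

contains

  subroutine compile(self, loss_name)
    class(network_type), intent(inout) :: self
    character(*), intent(in) :: loss_name
    self % loss_name = loss_name
  end subroutine compile

end module compile_sketch_m

program compile_demo
  use compile_sketch_m
  implicit none
  type(network_type) :: net
  net = network_type()       ! step 1: create the model (architecture only)
  call net % compile('mse')  ! step 2: attach the loss in a separate step
  print *, 'compiled with loss: ', net % loss_name
end program compile_demo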

I'll merge this and open a separate issue. Thank you for the PR!

@milancurcic merged commit f7b6006 into modern-fortran:main on Apr 19, 2024 (2 checks passed)
@jvdp1 deleted the loss_dt branch on April 19, 2024 16:34