Skip to content

Commit

Permalink
add documentation
Browse files Browse the repository at this point in the history
  • Loading branch information
Vandenplas, Jeremie committed Apr 16, 2024
1 parent 572c331 commit 646a564
Showing 1 changed file with 9 additions and 4 deletions.
13 changes: 9 additions & 4 deletions src/nf/nf_loss.f90
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
module nf_loss

!! This module will eventually provide a collection of loss functions and
!! their derivatives. For the time being it provides only the quadratic
!! function.
!! This module provides a collection of loss functions and their derivatives.
!! The implementation is based on an abstract loss derived type
!! which has the required eval and derivative methods.
!! An implementation of a new loss type thus requires writing a concrete
!! loss type that extends the abstract loss derived type, and that
!! implements concrete eval and derivative methods that accept vectors.

implicit none

Expand Down Expand Up @@ -31,12 +34,14 @@ end function loss_derivative_interface
end interface

type, extends(loss_type) :: mse
!! Mean Square Error loss function
contains
procedure, nopass :: eval => mse_eval
procedure, nopass :: derivative => mse_derivative
end type mse

type, extends(loss_type) :: quadratic
!! Quadratic loss function
contains
procedure, nopass :: eval => quadratic_eval
procedure, nopass :: derivative => quadratic_derivative
Expand Down Expand Up @@ -73,7 +78,7 @@ end function mse_derivative
pure module function quadratic_eval(true, predicted) result(res)
!! Quadratic loss function:
!!
!! L = (predicted - true)**2 / 2
!! L = sum((predicted - true)**2) / 2
!!
real, intent(in) :: true(:)
!! True values, i.e. labels from training datasets
Expand Down

0 comments on commit 646a564

Please sign in to comment.