Skip to content

Commit

Permalink
Addition of the MSE loss function
Browse files Browse the repository at this point in the history
  • Loading branch information
Vandenplas, Jeremie committed Apr 16, 2024
1 parent 6efeea0 commit 572c331
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 1 deletion.
1 change: 1 addition & 0 deletions src/nf.f90
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ module nf
use nf_layer, only: layer
use nf_layer_constructors, only: &
conv2d, dense, flatten, input, maxpool2d, reshape
use nf_loss, only: mse, quadratic
use nf_network, only: network
use nf_optimizers, only: sgd, rmsprop, adam, adagrad
use nf_activation, only: activation_function, elu, exponential, &
Expand Down
35 changes: 34 additions & 1 deletion src/nf/nf_loss.f90
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ module nf_loss

private
public :: loss_type
public :: mse
public :: quadratic

type, abstract :: loss_type
Expand All @@ -29,6 +30,12 @@ pure function loss_derivative_interface(true, predicted) result(res)
end function loss_derivative_interface
end interface

type, extends(loss_type) :: mse
contains
procedure, nopass :: eval => mse_eval
procedure, nopass :: derivative => mse_derivative
end type mse

type, extends(loss_type) :: quadratic
contains
procedure, nopass :: eval => quadratic_eval
Expand All @@ -37,6 +44,32 @@ end function loss_derivative_interface

interface

pure module function mse_eval(true, predicted) result(res)
!! Mean Square Error loss function:
!!
!! L = sum((predicted - true)**2) / size(true)
!!
real, intent(in) :: true(:)
!! True values, i.e. labels from training datasets
real, intent(in) :: predicted(:)
!! Values predicted by the network
real :: res
!! Resulting loss value
end function mse_eval

pure module function mse_derivative(true, predicted) result(res)
!! First derivative of the Mean Square Error loss function:
!!
!! L = 2 * (predicted - true) / size(true)
!!
real, intent(in) :: true(:)
!! True values, i.e. labels from training datasets
real, intent(in) :: predicted(:)
!! Values predicted by the network
real :: res(size(true))
!! Resulting loss values
end function mse_derivative

pure module function quadratic_eval(true, predicted) result(res)
!! Quadratic loss function:
!!
Expand All @@ -47,7 +80,7 @@ pure module function quadratic_eval(true, predicted) result(res)
real, intent(in) :: predicted(:)
!! Values predicted by the network
real :: res
!! Resulting loss values
!! Resulting loss value
end function quadratic_eval

pure module function quadratic_derivative(true, predicted) result(res)
Expand Down
14 changes: 14 additions & 0 deletions src/nf/nf_loss_submodule.f90
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,18 @@ pure module function quadratic_derivative(true, predicted) result(res)
res = predicted - true
end function quadratic_derivative

pure module function mse_eval(true, predicted) result(res)
real, intent(in) :: true(:)
real, intent(in) :: predicted(:)
real :: res
res = sum((predicted - true)**2) / size(true)
end function mse_eval

pure module function mse_derivative(true, predicted) result(res)
real, intent(in) :: true(:)
real, intent(in) :: predicted(:)
real :: res(size(true))
res = 2 * (predicted - true) / size(true)
end function mse_derivative

end submodule nf_loss_submodule

0 comments on commit 572c331

Please sign in to comment.