Skip to content

Commit

Permalink
Intrinsic pack replaced by pointers in get_params and `get_gradie…
Browse files Browse the repository at this point in the history
…nts` (#183)

* Replace intrinsic pack by pointers

* Dense layer: remove an avoidable reshape

* conv2d: avoid intrinsics pack and reshape

* replace a reshape by a pointer

* clean conv2d_layer_submodule

---------

Co-authored-by: Vandenplas, Jeremie <[email protected]>
Co-authored-by: milancurcic <[email protected]>
  • Loading branch information
3 people authored Jun 14, 2024
1 parent a843c83 commit 118f795
Show file tree
Hide file tree
Showing 8 changed files with 60 additions and 46 deletions.
8 changes: 4 additions & 4 deletions src/nf/nf_conv2d_layer.f90
Original file line number Diff line number Diff line change
Expand Up @@ -89,19 +89,19 @@ pure module function get_num_params(self) result(num_params)
!! Number of parameters
end function get_num_params

pure module function get_params(self) result(params)
module function get_params(self) result(params)
!! Return the parameters (weights and biases) of this layer.
!! The parameters are ordered as weights first, biases second.
class(conv2d_layer), intent(in) :: self
class(conv2d_layer), intent(in), target :: self
!! A `conv2d_layer` instance
real, allocatable :: params(:)
!! Parameters to get
end function get_params

pure module function get_gradients(self) result(gradients)
module function get_gradients(self) result(gradients)
!! Return the gradients of this layer.
!! The gradients are ordered as weights first, biases second.
class(conv2d_layer), intent(in) :: self
class(conv2d_layer), intent(in), target :: self
!! A `conv2d_layer` instance
real, allocatable :: gradients(:)
!! Gradients to get
Expand Down
29 changes: 18 additions & 11 deletions src/nf/nf_conv2d_layer_submodule.f90
Original file line number Diff line number Diff line change
Expand Up @@ -189,24 +189,32 @@ pure module function get_num_params(self) result(num_params)
end function get_num_params


pure module function get_params(self) result(params)
class(conv2d_layer), intent(in) :: self
module function get_params(self) result(params)
class(conv2d_layer), intent(in), target :: self
real, allocatable :: params(:)

real, pointer :: w_(:) => null()

w_(1:size(self % kernel)) => self % kernel

params = [ &
pack(self % kernel, .true.), &
w_, &
self % biases &
]

end function get_params


pure module function get_gradients(self) result(gradients)
class(conv2d_layer), intent(in) :: self
module function get_gradients(self) result(gradients)
class(conv2d_layer), intent(in), target :: self
real, allocatable :: gradients(:)

real, pointer :: dw_(:) => null()

dw_(1:size(self % dw)) => self % dw

gradients = [ &
pack(self % dw, .true.), &
dw_, &
self % db &
]

Expand All @@ -219,7 +227,7 @@ module subroutine set_params(self, params)

! Check that the number of parameters is correct.
if (size(params) /= self % get_num_params()) then
error stop 'conv2d % set_params: Number of parameters does not match'
error stop 'conv2d % set_params: Number of parameters does not match'
end if

! Reshape the kernel.
Expand All @@ -229,10 +237,9 @@ module subroutine set_params(self, params)
)

! Reshape the biases.
self % biases = reshape( &
params(product(shape(self % kernel)) + 1:), &
[self % filters] &
)
associate(n => product(shape(self % kernel)))
self % biases = params(n + 1 : n + self % filters)
end associate

end subroutine set_params

Expand Down
10 changes: 5 additions & 5 deletions src/nf/nf_dense_layer.f90
Original file line number Diff line number Diff line change
Expand Up @@ -87,19 +87,19 @@ pure module function get_num_params(self) result(num_params)
!! Number of parameters in this layer
end function get_num_params

pure module function get_params(self) result(params)
module function get_params(self) result(params)
!! Return the parameters (weights and biases) of this layer.
!! The parameters are ordered as weights first, biases second.
class(dense_layer), intent(in) :: self
class(dense_layer), intent(in), target :: self
!! Dense layer instance
real, allocatable :: params(:)
!! Parameters of this layer
end function get_params

pure module function get_gradients(self) result(gradients)
module function get_gradients(self) result(gradients)
!! Return the gradients of this layer.
!! The gradients are ordered as weights first, biases second.
class(dense_layer), intent(in) :: self
class(dense_layer), intent(in), target :: self
!! Dense layer instance
real, allocatable :: gradients(:)
!! Gradients of this layer
Expand All @@ -110,7 +110,7 @@ module subroutine set_params(self, params)
!! The parameters are ordered as weights first, biases second.
class(dense_layer), intent(in out) :: self
!! Dense layer instance
real, intent(in) :: params(:)
real, intent(in), target :: params(:)
!! Parameters of this layer
end subroutine set_params

Expand Down
43 changes: 25 additions & 18 deletions src/nf/nf_dense_layer_submodule.f90
Original file line number Diff line number Diff line change
Expand Up @@ -61,24 +61,32 @@ pure module function get_num_params(self) result(num_params)
end function get_num_params


pure module function get_params(self) result(params)
class(dense_layer), intent(in) :: self
module function get_params(self) result(params)
class(dense_layer), intent(in), target :: self
real, allocatable :: params(:)

real, pointer :: w_(:) => null()

w_(1:size(self % weights)) => self % weights

params = [ &
pack(self % weights, .true.), &
w_, &
self % biases &
]

end function get_params


pure module function get_gradients(self) result(gradients)
class(dense_layer), intent(in) :: self
module function get_gradients(self) result(gradients)
class(dense_layer), intent(in), target :: self
real, allocatable :: gradients(:)

real, pointer :: dw_(:) => null()

dw_(1:size(self % dw)) => self % dw

gradients = [ &
pack(self % dw, .true.), &
dw_, &
self % db &
]

Expand All @@ -87,24 +95,23 @@ end function get_gradients

module subroutine set_params(self, params)
class(dense_layer), intent(in out) :: self
real, intent(in) :: params(:)
real, intent(in), target :: params(:)

real, pointer :: p_(:,:) => null()

! check if the number of parameters is correct
if (size(params) /= self % get_num_params()) then
error stop 'Error: number of parameters does not match'
end if

! reshape the weights
self % weights = reshape( &
params(:self % input_size * self % output_size), &
[self % input_size, self % output_size] &
)

! reshape the biases
self % biases = reshape( &
params(self % input_size * self % output_size + 1:), &
[self % output_size] &
)
associate(n => self % input_size * self % output_size)
! reshape the weights
p_(1:self % input_size, 1:self % output_size) => params(1 : n)
self % weights = p_

! reshape the biases
self % biases = params(n + 1 : n + self % output_size)
end associate

end subroutine set_params

Expand Down
4 changes: 2 additions & 2 deletions src/nf/nf_layer.f90
Original file line number Diff line number Diff line change
Expand Up @@ -129,15 +129,15 @@ elemental module function get_num_params(self) result(num_params)
!! Number of parameters in this layer
end function get_num_params

pure module function get_params(self) result(params)
module function get_params(self) result(params)
!! Returns the parameters of this layer.
class(layer), intent(in) :: self
!! Layer instance
real, allocatable :: params(:)
!! Parameters of this layer
end function get_params

pure module function get_gradients(self) result(gradients)
module function get_gradients(self) result(gradients)
!! Returns the gradients of this layer.
class(layer), intent(in) :: self
!! Layer instance
Expand Down
4 changes: 2 additions & 2 deletions src/nf/nf_layer_submodule.f90
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ elemental module function get_num_params(self) result(num_params)

end function get_num_params

pure module function get_params(self) result(params)
module function get_params(self) result(params)
class(layer), intent(in) :: self
real, allocatable :: params(:)

Expand All @@ -323,7 +323,7 @@ pure module function get_params(self) result(params)

end function get_params

pure module function get_gradients(self) result(gradients)
module function get_gradients(self) result(gradients)
class(layer), intent(in) :: self
real, allocatable :: gradients(:)

Expand Down
4 changes: 2 additions & 2 deletions src/nf/nf_network.f90
Original file line number Diff line number Diff line change
Expand Up @@ -172,15 +172,15 @@ pure module integer function get_num_params(self)
!! Network instance
end function get_num_params

pure module function get_params(self) result(params)
module function get_params(self) result(params)
!! Get the network parameters (weights and biases).
class(network), intent(in) :: self
!! Network instance
real, allocatable :: params(:)
!! Network parameters to get
end function get_params

pure module function get_gradients(self) result(gradients)
module function get_gradients(self) result(gradients)
class(network), intent(in) :: self
!! Network instance
real, allocatable :: gradients(:)
Expand Down
4 changes: 2 additions & 2 deletions src/nf/nf_network_submodule.f90
Original file line number Diff line number Diff line change
Expand Up @@ -526,7 +526,7 @@ pure module function get_num_params(self)
end function get_num_params


pure module function get_params(self) result(params)
module function get_params(self) result(params)
class(network), intent(in) :: self
real, allocatable :: params(:)
integer :: n, nstart, nend
Expand All @@ -546,7 +546,7 @@ pure module function get_params(self) result(params)
end function get_params


pure module function get_gradients(self) result(gradients)
module function get_gradients(self) result(gradients)
class(network), intent(in) :: self
real, allocatable :: gradients(:)
integer :: n, nstart, nend
Expand Down

0 comments on commit 118f795

Please sign in to comment.