Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

First commit to add weighted mean square #140

Open
wants to merge 6 commits into
base: main
Choose a base branch
from

Conversation

marcobonici
Copy link

Added the Weighted Mean Square Error Loss

Copy link
Contributor

@chriselrod chriselrod left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mind also adding tests for the value and gradient?

src/loss.jl Outdated Show resolved Hide resolved
src/loss.jl Outdated
end
(::WeightedSquaredLoss)(y, w) = WeightedSquaredLoss(y, w)
WeightedSquaredLoss() = WeightedSquaredLoss(nothing)
target(wsl::WeightedSquaredLoss) = getfield(wsl, :y)#maybe need to return both :y and :weights?
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, the target should be sliceable and the loss should be callable on target's result to create a new one.
It's used for slicing/iterating over batches.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So, the point is that WeightedSquaredLoss(target(wsl)) should be able to run, did I get it right?

Copy link
Contributor

@chriselrod chriselrod May 10, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes.

$ rg 'target\('
docs/src/examples/custom_loss_layer.md
49:SimpleChains.target(loss::BinaryLogitCrossEntropyLoss) = loss.targets

src/optimize.jl
93:  tgt = view_slice_last(target(loss), f:l)
125:  tgt = target(loss)
177:  tgt = target(loss)
488:  t = target(_chn)
679:  tgt = target(chn)

src/loss.jl
25:target(_) = nothing
26:target(sc::SimpleChain) = target(last(sc.layers))
27:preserve_buffer(l::AbstractLoss) = target(l)
28:StrideArraysCore.object_and_preserve(l::AbstractLoss) = l, target(l)
31:iterate_over_losses(sc) = _iterate_over_losses(target(sc))
40:  align(length(first(target(sl))) * static_sizeof(T)), static_sizeof(T)
42:function _layer_output_size_needs_temp_of_equal_len_as_target(
47:  align(length(target(sl)) * static_sizeof(T)), static_sizeof(T)
66:target(sl::SquaredLoss) = getfield(sl, :y)
69:Base.getindex(sl::SquaredLoss, r) = SquaredLoss(view_slice_last(target(sl), r))
120:target(sl::AbsoluteLoss) = getfield(sl, :y)
127:  AbsoluteLoss(view_slice_last(target(sl), r))
197:target(sl::LogitCrossEntropyLoss) = getfield(sl, :y)
205:  _layer_output_size_needs_temp_of_equal_len_as_target(Val{T}(), sl, s)
212:  _layer_output_size_needs_temp_of_equal_len_as_target(Val{T}(), sl, s)
254:  LogitCrossEntropyLoss(view(target(sl), r))
273:  correct_count(Y, target(loss))
283:  ec = correct_count(Y, target(loss))

src/penalty.jl
68:target(c::AbstractPenalty) = target(getchain(c))

Note that we also need things like view_slice_last(target(loss), f:l) to work.

So view_slice_last should be implemented.
Some form of PtrArray(tgt) should also work, but you could define a different function to use there that calls PtrArray by default, as overloading constructors to return something else is generally frowned upon.

Copy link
Author

@marcobonici marcobonici May 11, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A few things @chriselrod .

So, as I thought target(wsl) needs to give back all the field of the struct. This is needed because, as you pointed out, WeightedSquaredLoss(target(wsl)) need to be working.

So, I have updated the target method

target(wsl::WeightedSquaredLoss) = getfield(wsl, :y), getfield(wsl, :w)

Since this is giving back a tuple, I have added a constructor, using the splat operator
WeightedSquaredLoss(x::Tuple) = WeightedSquaredLoss(x...)

Do you have any consideration on that? In the meantime, I'll focus on view_slice_last.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I correctly understand, view_slice_last is used to slice the fields of the loss. If so, this could possibly working.

function view_slice_last(target(wsl::WeightedSquaredLoss), r)
    return Tuple(view_slice_last(f, r) for f in target(wsl))
end

I am returning a Tuple assuming that this can work with my constructor I just created.

@codecov
Copy link

codecov bot commented Jul 3, 2023

Codecov Report

Patch coverage has no change and project coverage change: -0.85 ⚠️

Comparison is base (1df08c6) 73.82% compared to head (0ba4462) 72.97%.

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #140      +/-   ##
==========================================
- Coverage   73.82%   72.97%   -0.85%     
==========================================
  Files          15       15              
  Lines        2617     2646      +29     
==========================================
- Hits         1932     1931       -1     
- Misses        685      715      +30     
Impacted Files Coverage Δ
src/SimpleChains.jl 100.00% <ø> (ø)
src/loss.jl 56.59% <0.00%> (-10.73%) ⬇️

... and 1 file with indirect coverage changes

☔ View full report in Codecov by Sentry.
📢 Do you have feedback about the report comment? Let us know in this issue.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants