-
Notifications
You must be signed in to change notification settings - Fork 14
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
First commit to add weighted mean square #140
base: main
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Mind also adding tests for the value and gradient?
src/loss.jl
Outdated
end | ||
(::WeightedSquaredLoss)(y, w) = WeightedSquaredLoss(y, w) | ||
WeightedSquaredLoss() = WeightedSquaredLoss(nothing) | ||
target(wsl::WeightedSquaredLoss) = getfield(wsl, :y)#maybe need to return both :y and :weights? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, the target should be sliceable and the loss should be callable on target's result to create a new one.
It's used for slicing/iterating over batches.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So, the point is that WeightedSquaredLoss(target(wsl))
should be able to run, did I get it right?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes.
$ rg 'target\('
docs/src/examples/custom_loss_layer.md
49:SimpleChains.target(loss::BinaryLogitCrossEntropyLoss) = loss.targets
src/optimize.jl
93: tgt = view_slice_last(target(loss), f:l)
125: tgt = target(loss)
177: tgt = target(loss)
488: t = target(_chn)
679: tgt = target(chn)
src/loss.jl
25:target(_) = nothing
26:target(sc::SimpleChain) = target(last(sc.layers))
27:preserve_buffer(l::AbstractLoss) = target(l)
28:StrideArraysCore.object_and_preserve(l::AbstractLoss) = l, target(l)
31:iterate_over_losses(sc) = _iterate_over_losses(target(sc))
40: align(length(first(target(sl))) * static_sizeof(T)), static_sizeof(T)
42:function _layer_output_size_needs_temp_of_equal_len_as_target(
47: align(length(target(sl)) * static_sizeof(T)), static_sizeof(T)
66:target(sl::SquaredLoss) = getfield(sl, :y)
69:Base.getindex(sl::SquaredLoss, r) = SquaredLoss(view_slice_last(target(sl), r))
120:target(sl::AbsoluteLoss) = getfield(sl, :y)
127: AbsoluteLoss(view_slice_last(target(sl), r))
197:target(sl::LogitCrossEntropyLoss) = getfield(sl, :y)
205: _layer_output_size_needs_temp_of_equal_len_as_target(Val{T}(), sl, s)
212: _layer_output_size_needs_temp_of_equal_len_as_target(Val{T}(), sl, s)
254: LogitCrossEntropyLoss(view(target(sl), r))
273: correct_count(Y, target(loss))
283: ec = correct_count(Y, target(loss))
src/penalty.jl
68:target(c::AbstractPenalty) = target(getchain(c))
Note that we also need things like view_slice_last(target(loss), f:l)
to work.
So view_slice_last
should be implemented.
Some form of PtrArray(tgt)
should also work, but you could define a different function to use there that calls PtrArray
by default, as overloading constructors to return something else is generally frowned upon.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A few things @chriselrod .
So, as I thought target(wsl)
needs to give back all the field of the struct. This is needed because, as you pointed out, WeightedSquaredLoss(target(wsl))
need to be working.
So, I have updated the target method
target(wsl::WeightedSquaredLoss) = getfield(wsl, :y), getfield(wsl, :w)
Since this is giving back a tuple, I have added a constructor, using the splat operator
WeightedSquaredLoss(x::Tuple) = WeightedSquaredLoss(x...)
Do you have any consideration on that? In the meantime, I'll focus on view_slice_last
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If I correctly understand, view_slice_last
is used to slice the fields of the loss. If so, this could possibly working.
function view_slice_last(target(wsl::WeightedSquaredLoss), r)
return Tuple(view_slice_last(f, r) for f in target(wsl))
end
I am returning a Tuple
assuming that this can work with my constructor I just created.
Co-authored-by: Chris Elrod <[email protected]>
Codecov ReportPatch coverage has no change and project coverage change:
Additional details and impacted files@@ Coverage Diff @@
## main #140 +/- ##
==========================================
- Coverage 73.82% 72.97% -0.85%
==========================================
Files 15 15
Lines 2617 2646 +29
==========================================
- Hits 1932 1931 -1
- Misses 685 715 +30
☔ View full report in Codecov by Sentry. |
Added the Weighted Mean Square Error Loss