-
Notifications
You must be signed in to change notification settings - Fork 17
feat(scalarization)!: Add UW #721
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
17 commits
Select commit
Hold shift + click to select a range
b9f3ca3
add UW
ppraneth 988882b
add UW
ppraneth cd5ec22
Merge branch 'main' into scalarization-4
ppraneth 2d581a0
fix docs
ppraneth 0b454f1
fix
ppraneth e05f76a
Merge branch 'main' into scalarization-4
ppraneth e91e2f3
fix modo import
ppraneth 375e627
Merge branch 'main' into scalarization-4
ppraneth 81afe8b
fix pre-commit error CI
ppraneth afbdb6e
Update src/torchjd/scalarization/_uw.py
ppraneth a0cd5bc
Update src/torchjd/scalarization/_uw.py
ppraneth 8514922
Update src/torchjd/scalarization/_uw.py
ppraneth 8989b11
Update docs/source/docs/scalarization/uw.rst
ppraneth 5c78a4d
Update src/torchjd/scalarization/_uw.py
ppraneth adaf4ce
Merge branch 'main' into scalarization-4
ppraneth e19dd80
fix
ppraneth 525789c
Revert re-exporting Stateful in scalarization and aggregation, add an…
ValerianRey File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Some comments aren't visible on the classic Files Changed page.
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,7 @@ | ||
| :hide-toc: | ||
|
|
||
| UW | ||
| == | ||
|
|
||
| .. autoclass:: torchjd.scalarization.UW | ||
| :members: __call__, reset |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,4 @@ | ||
| :orphan: | ||
|
|
||
| .. autoclass:: torchjd.Stateful | ||
| :members: reset |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,81 @@ | ||
| from collections.abc import Sequence | ||
|
|
||
| import torch | ||
| from torch import Tensor, nn | ||
|
|
||
| from torchjd._mixins import Stateful | ||
|
|
||
| from ._scalarizer_base import Scalarizer | ||
|
|
||
|
|
||
| class UW(Scalarizer, Stateful): | ||
| r""" | ||
| :class:`~torchjd.Stateful` | ||
| :class:`~torchjd.scalarization.Scalarizer` that combines the input tensor of values using | ||
| learned per-task uncertainties. ``UW`` is short for Uncertainty Weighting, the method proposed | ||
| in `Multi-Task Learning Using Uncertainty to Weigh Losses for Scene Geometry and Semantics | ||
| <https://openaccess.thecvf.com/content_cvpr_2018/papers/Kendall_Multi-Task_Learning_Using_CVPR_2018_paper.pdf>`_. | ||
|
|
||
| Each value :math:`L_i` is assigned a learnable log-variance :math:`s_i`, and the values are | ||
| combined as | ||
|
|
||
| .. math:: | ||
| \sum_i \left( \frac{1}{2} e^{-s_i} L_i + \frac{1}{2} s_i \right) | ||
|
|
||
| where: | ||
|
|
||
| - :math:`L_i` is the :math:`i`-th value (typically the loss of task :math:`i`); | ||
| - :math:`s_i = \log \sigma_i^2` is the learnable log-variance of task :math:`i`. | ||
|
|
||
| Following the paper, the log-variance :math:`s_i` is learned rather than the variance | ||
| :math:`\sigma_i^2` directly: this is numerically more stable (the combination never divides by | ||
| zero) and keeps :math:`s_i` unconstrained, since :math:`e^{-s_i}` is always positive. The | ||
| :math:`s_i` are stored as an ``nn.Parameter``, so the parameters of this scalarizer must be | ||
| passed to the optimizer to be learned jointly with the model. | ||
|
|
||
| :param shape: The shape of the values to scalarize, used to create one log-variance per value. | ||
| An ``int`` ``n`` is interpreted as the shape ``(n,)``. | ||
|
|
||
| The following example shows train a model with Uncertainty Weighting, as described in the paper. | ||
|
|
||
| >>> import torch | ||
| >>> from torch.nn import Linear | ||
| >>> | ||
| >>> from torchjd.scalarization import UW | ||
| >>> | ||
| >>> model = Linear(3, 2) | ||
| >>> scalarizer = UW(2) # Move to the right device with e.g. UW(2).to(device="cuda") | ||
| >>> optimizer = torch.optim.SGD([*model.parameters(), *scalarizer.parameters()], lr=0.1) | ||
| >>> | ||
| >>> features = torch.randn(8, 3) | ||
| >>> # Compute some dummy losses just for the sake of the example | ||
| >>> losses = model(features).pow(2).mean(dim=0) # One loss per output dimension. | ||
|
ppraneth marked this conversation as resolved.
|
||
| >>> loss = scalarizer(losses) | ||
| >>> loss.backward() | ||
| >>> optimizer.step() | ||
|
|
||
| .. note:: | ||
| The log-variances are initialized to ``0`` (i.e. :math:`\sigma_i^2 = 1`), which gives | ||
| uniform weights at the start of training. The paper reports that the result is robust to | ||
| this initialization. (`LibMTL <https://github.com/median-research-group/LibMTL>`_ | ||
| initializes them to ``-0.5`` instead.) | ||
| """ | ||
|
|
||
| def __init__(self, shape: int | Sequence[int]) -> None: | ||
| super().__init__() | ||
| self.log_var = nn.Parameter(torch.zeros(shape)) | ||
|
|
||
| def forward(self, values: Tensor, /) -> Tensor: | ||
| if values.shape != self.log_var.shape: | ||
| raise ValueError( | ||
| f"Parameter `values` should have shape {tuple(self.log_var.shape)} (matching the " | ||
| f"shape of the log-variances). Found `values.shape = {tuple(values.shape)}`.", | ||
| ) | ||
| return (0.5 * torch.exp(-self.log_var) * values + 0.5 * self.log_var).sum() | ||
|
|
||
| def reset(self) -> None: | ||
| with torch.no_grad(): | ||
| self.log_var.zero_() | ||
|
|
||
| def __repr__(self) -> str: | ||
| return f"{self.__class__.__name__}(shape={tuple(self.log_var.shape)})" | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.