feat(scalarization)!: Add UW #721
PierreQuinton
left a comment
I think I would prefer a usage example of how to co-train parameters of a model and the log_var parameters in the docstring.
Thanks, addressed all four:
Signed-off-by: ppraneth <pranethparuchuri@gmail.com>
Already looking very good. I'll read the UW paper and make a thorough review soon.
Signed-off-by: ppraneth <pranethparuchuri@gmail.com>
ValerianRey
left a comment
Very clean. I also like how the testing is simple and comprehensive. Thanks a lot for that.
I made a few minor change suggestions to slightly improve clarity. Also, I'd like to try to improve how the Stateful class is defined.
Currently (in this PR), Stateful is defined in torchjd._mixins, and is used in aggregation and scalarization. Its documentation entry is only in aggregation.
I think we could have Stateful still defined in torchjd._mixins, but exported as torchjd.aggregation.Stateful AND torchjd.scalarization.Stateful, and have one documentation entry for each. This way, we will avoid having to make changes to existing stateful aggregators in this PR, having to make a breaking change, and having the link to "Stateful" of UW go to the wrong package (i.e. go to the Aggregator page instead of the Scalarizer page).
Co-authored-by: Valérian Rey <31951177+ValerianRey@users.noreply.github.com>
@ValerianRey I have made the changes. Can you review it again? Also, one of the CI checks is failing, but I think it is not relevant to this PR.
… orphan doc entry for it
Actually, what I suggested (a common Stateful class, exported by both aggregation and scalarization) does not work because Sphinx doesn't allow that. I guess that's why you don't use autoclass in e19dd80. The fix is to have a single common Stateful class, exported only as torchjd.Stateful, whose documentation page is an orphan. I made the change in 525789c. We can merge now IMO.
Adds `UW`, the uncertainty weighting scalarizer from Multi-Task Learning Using Uncertainty to Weigh Losses for Scene Geometry and Semantics (Kendall et al., CVPR 2018). This is the first stateful, trainable scalarizer, so the PR also moves the `Stateful` mixin to a shared location that both the aggregation and scalarization packages can use.

UW

Each value $L_i$ (typically a per-task loss) is assigned a learnable log-variance $s_i = \log \sigma_i^2$, and the values are combined as:

$$\sum_i \left( \tfrac{1}{2} e^{-s_i} L_i + \tfrac{1}{2} s_i \right)$$
This is the regression objective (eq. 7 in the paper) after substituting $s_i = \log \sigma_i^2$, which matches the LibMTL implementation.
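As a check on the substitution, here is the per-value term worked out. This assumes eq. 7 contributes $\frac{1}{2\sigma_i^2} L_i + \log \sigma_i$ for each value (the paper writes it out for two tasks), so treat it as a sketch of the algebra rather than a quotation:

$$
\begin{aligned}
s_i = \log \sigma_i^2 \;&\Rightarrow\; \sigma_i^2 = e^{s_i}, \qquad \log \sigma_i = \tfrac{1}{2} \log \sigma_i^2 = \tfrac{1}{2} s_i \\
\frac{1}{2\sigma_i^2} L_i + \log \sigma_i \;&=\; \tfrac{1}{2} e^{-s_i} L_i + \tfrac{1}{2} s_i
\end{aligned}
$$

Summing the right-hand side over $i$ gives the combined value.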
Following the paper, the model learns the log-variance $s_i$ rather than the variance $\sigma_i^2$ directly. This is numerically more stable (the combination never divides by zero) and keeps $s_i$ unconstrained, since $e^{-s_i}$ is always positive. The $s_i$ are stored as an `nn.Parameter`, so the scalarizer's parameters must be passed to the optimizer to be learned jointly with the model.

Design notes:
- `shape` is given at construction (`UW(3)` or `UW((2, 3))`), since the parameter has to exist before the optimizer is built. The shape is validated against the input at call time, like `Constant`.
- The log-variances are initialized to `0` (so …) … `-0.5` instead.
- State is cleared with `.reset()` (from `Stateful`), which zeros the log-variances.

Stateful move

`Stateful` was defined in `aggregation/_mixins.py`, but it is now needed by scalarization too. It moves to the shared `torchjd/_mixins.py`. This is backward compatible:

- `torchjd.aggregation.Stateful` still works (re-exported).
- `torchjd.Stateful` is added as the new top-level path, since the mixin is now cross-cutting.

Internal aggregation modules (`_cr_mogm`, `_gradvac`, `_nash_mtl`) had their imports and docstring cross-references updated. No behavior change for existing aggregators.

Tests
`tests/unit/scalarization/test_uw.py` covers the value at init, int-vs-tuple shape equivalence, scalar output and gradient flow over all input shapes (0-dim, vector, matrix, higher-dim), gradient flow to `log_var`, shape validation, `reset()`, that negative inputs are allowed (unlike `GeometricMean`), trainability via an optimizer step, and the representations. The full `tests/unit` suite was run as a regression check since the `Stateful` move touches the aggregation package.

Question on the shape API
Unlike the stateless scalarizers, `UW` cannot be shape-agnostic: it holds one learnable log-variance per value, and that parameter has to be created at construction time so it can be handed to the optimizer before training starts.

I went with `shape: int | Sequence[int]`, so `UW(3)` builds a length-3 vector and `UW((2, 3))` builds a 2D grid. Reasons:

- `Constant` already establishes the "fix the shape at construction, validate at call time" pattern for shape-bound scalarizers, and `UW` follows it rather than being a 1D-only special case.
- `int` just collapses to `(n,)` internally, so `UW(3)` is exactly as ergonomic as a `num_tasks` argument would be, while higher-dim losses still work for free.
- `int | Sequence[int]` for a shape matches how many torch constructors behave.

The alternative is `num_tasks: int` only (1D vectors of m task losses, matching LibMTL exactly). It is slightly simpler conceptually, since "uncertainty per task" is naturally 1D, but it cannot scalarize higher-dim loss tensors and makes `UW` inconsistent with the other scalarizers. Happy to switch if you prefer that.
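To make the shape handling and the joint training concrete, here is a minimal plain-PyTorch sketch of the pattern described above. The class name `UncertaintyWeighting`, its zero initialization, the error message, and the toy model and data are illustrative assumptions, not the code in this PR; only standard `torch` calls are used.

```python
import torch
from torch import nn


class UncertaintyWeighting(nn.Module):
    """Hypothetical stand-in for UW (not the torchjd implementation)."""

    def __init__(self, shape):
        super().__init__()
        # An int collapses to the 1D shape (n,); a sequence is used as given.
        shape = (shape,) if isinstance(shape, int) else tuple(shape)
        # One learnable log-variance s_i per value (assumed to start at zero).
        self.log_var = nn.Parameter(torch.zeros(shape))

    def forward(self, values):
        # The shape is fixed at construction and validated at call time.
        if values.shape != self.log_var.shape:
            raise ValueError(
                f"expected shape {tuple(self.log_var.shape)}, got {tuple(values.shape)}"
            )
        # sum_i (0.5 * exp(-s_i) * L_i + 0.5 * s_i)
        return (0.5 * torch.exp(-self.log_var) * values + 0.5 * self.log_var).sum()

    def reset(self):
        # Zero the log-variances in place without recording a gradient.
        with torch.no_grad():
            self.log_var.zero_()


torch.manual_seed(0)
model = nn.Linear(4, 2)       # toy model with two regression outputs
uw = UncertaintyWeighting(2)  # one log-variance per task loss

# Both parameter sets go to the optimizer so they are learned jointly.
optimizer = torch.optim.SGD(
    list(model.parameters()) + list(uw.parameters()), lr=0.1
)

x = torch.randn(8, 4)
y = torch.randn(8, 2)
losses = ((model(x) - y) ** 2).mean(dim=0)  # per-task MSE, shape (2,)

optimizer.zero_grad()
uw(losses).backward()
optimizer.step()
```

In this sketch, `UncertaintyWeighting(3)` and `UncertaintyWeighting((3,))` build the same parameter, and `UncertaintyWeighting((2, 3))` accepts a 2D loss tensor, which is the behaviour a `num_tasks: int` argument would not cover.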