A Causality-Aware Perspective on Domain Generalization via Domain Intervention
Key Words:REPRESENTATIONS
Abstract:Domain Generalization (DG) focuses on Out-Of-Distribution (OOD) generalization: learning a robust model that transfers the knowledge acquired from the source domains to an unseen target domain. However, due to domain shift, domain-invariant representation learning is challenging. Guided by fine-grained knowledge, we propose a novel paradigm, Mask-Shift-Inference (MSI), for DG based on the architecture of Convolutional Neural Networks (CNN). Rather than relying on a series of constraints and assumptions for model optimization, this paradigm shifts the focus to feature channels in the latent space for domain-invariant representation learning. We put forward a two-branch working mode consisting of a main module and multiple domain-specific sub-modules. Each sub-module achieves good prediction performance only in its own domain and predicts poorly in the other source domains, which provides the main module with fine-grained knowledge guidance and improves the cognitive ability of MSI. First, during the forward propagation of the main module, MSI accurately discards unstable channels, identified by spurious classifications that vary across domains; these channels have domain-specific prediction limitations and are not conducive to generalization. In this process, a progressive scheme adaptively increases the masking ratio according to the training progress to further reduce the risk of overfitting. Subsequently, the paradigm enters the compatible shifting stage before the formal prediction. While maximizing semantic retention, we implement domain style matching and shifting through a simple transformation based on the Fourier transform, which explicitly and safely shifts the target domain back to the source domain whose style is closest to it, requiring no additional model updates and reducing the domain gap. Finally, MSI enters the formal inference stage. The updated target domain is predicted by the main module trained in the previous stage, with the benefit of familiar knowledge from the nearest source domain and the masking scheme. The paradigm is logically progressive: it intuitively excludes the confounding influence of domain-specific spurious information, mitigates domain shift, and implicitly performs semantically invariant representation learning, achieving robust OOD generalization. Extensive experimental results on the PACS, VLCS, Office-Home and DomainNet datasets verify the superiority and effectiveness of the proposed method.
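The style matching and shifting step described in the abstract can be illustrated with a small NumPy sketch. The abstract says only that the shift uses a simple Fourier-based transformation, so everything below is an assumption rather than the paper's implementation: style is taken to be the FFT amplitude spectrum, the closest source domain is chosen by Euclidean distance between amplitudes, and only a low-frequency window (controlled by the hypothetical parameter `beta`) is swapped while the target's phase is kept to retain semantics. The function names are illustrative.

```python
import numpy as np


def amplitude_and_phase(img):
    """Split an image (H, W) or (H, W, C) into FFT amplitude and phase
    computed over the two spatial axes."""
    spectrum = np.fft.fft2(img, axes=(0, 1))
    return np.abs(spectrum), np.angle(spectrum)


def nearest_source_style(target_img, source_amplitudes):
    """Return the index of the source amplitude closest to the target's.

    Assumption: style distance is the Euclidean norm between amplitude
    spectra; the paper may use a different measure.
    """
    amp, _ = amplitude_and_phase(target_img)
    distances = [np.linalg.norm(amp - src) for src in source_amplitudes]
    return int(np.argmin(distances))


def shift_to_source(target_img, source_amplitude, beta=0.1):
    """Swap the low-frequency amplitude of the target for the source's,
    keeping the target phase. No model parameters are involved.

    `beta` (hypothetical) sets the half-width of the swapped window as a
    fraction of the image size.
    """
    amp, phase = amplitude_and_phase(target_img)

    # Move the zero-frequency component to the centre of the spectrum.
    amp_c = np.fft.fftshift(amp, axes=(0, 1))
    src_c = np.fft.fftshift(source_amplitude, axes=(0, 1))

    h, w = target_img.shape[:2]
    bh, bw = max(1, int(h * beta)), max(1, int(w * beta))
    ch, cw = h // 2, w // 2

    # Replace a window that is symmetric around the zero frequency.
    amp_c[ch - bh:ch + bh + 1, cw - bw:cw + bw + 1] = \
        src_c[ch - bh:ch + bh + 1, cw - bw:cw + bw + 1]

    mixed = np.fft.ifftshift(amp_c, axes=(0, 1))
    shifted = np.fft.ifft2(mixed * np.exp(1j * phase), axes=(0, 1))
    return np.real(shifted)
```

In use, one amplitude spectrum per source domain (for example a per-domain mean) would be precomputed, `nearest_source_style` would pick the closest one for each target image, and `shift_to_source` would produce the image that is then passed to the trained main module.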
Volume:179
Issue:-
Translation or Not:no