Representation Learning via Invariant Causal Mechanisms
Self-supervised learning has emerged as a strategy to reduce the reliance on costly supervised signals by pretraining representations using only unlabeled data. These methods combine heuristic proxy classification tasks with data augmentations and have achieved significant success, but our theoretical understanding of this success remains limited. In this paper we analyze self-supervised representation learning using a causal framework. We show how data augmentations can be more effectively utilized through explicit invariance constraints on the proxy classifiers employed during pretraining. Based on this, we propose a novel self-supervised objective, Representation Learning via Invariant Causal Mechanisms (ReLIC), that enforces invariant prediction of proxy targets across augmentations through an invariance regularizer which yields improved generalization guarantees. Further, using causality we generalize contrastive learning, a particular kind of self-supervised method, and provide an alternative theoretical explanation for the success of these methods. Empirically, ReLIC significantly outperforms competing methods in terms of robustness and out-of-distribution generalization on ImageNet, and also significantly outperforms these methods on Atari, achieving above human-level performance on 51 out of 57 games.
Training deep networks often relies heavily on large amounts of useful supervisory signal, such as labels for supervised learning or rewards for reinforcement learning. These training signals can be costly or otherwise impractical to acquire. On the other hand, unlabeled data is often abundantly available. Therefore, pretraining representations for unknown downstream tasks without the need for labels or extrinsic reward holds great promise for reducing the cost of applying machine learning models. To pretrain representations, self-supervised learning makes use of proxy tasks defined on unlabeled data. Recently, self-supervised methods using contrastive objectives have emerged as one of the most successful strategies for unsupervised representation learning (Oord et al., 2018; Hjelm et al., 2018; Chen et al., 2020a). These methods learn a representation by classifying every datapoint against all other datapoints (negative examples). Under assumptions on how the negative examples are sampled, minimizing the resulting contrastive loss has been justified as maximizing a lower bound on the mutual information (MI) between representations (Poole et al., 2019). However, (Tschannen et al., 2019) has shown that performance on downstream tasks may be more tightly correlated with the choice of encoder architecture than with the achieved MI bound, highlighting issues with the MI theory of contrastive learning. Further, contrastive approaches compare different views of the data (usually under different data augmentations) to calculate similarity scores. This approach to computing scores has been empirically observed as a key success factor of contrastive methods, but has yet to be theoretically justified. This lack of a solid theoretical explanation for the effectiveness of contrastive methods hinders their further development.
To remedy these theoretical shortcomings, we analyze the problem of self-supervised representation learning through a causal lens. We formalize intuitions about the data generating process using a causal graph and leverage causal tools to derive properties of the optimal representation. We show that a representation should be an invariant predictor of proxy targets under interventions on features that are only correlated with, but not causally related to, the downstream targets of interest. Since neither the causally nor the purely correlationally related features are observed, performing actual interventions on them is not feasible; to learn representations with this property, we instead use data augmentations to simulate a subset of possible interventions. Based on our causal interpretation, we propose a regularizer which enforces that the prediction of the proxy targets is invariant across data augmentations. This yields a novel objective for self-supervised representation learning called Representation Learning via Invariant Causal Mechanisms (ReLIC). We show how this explicit invariance regularization leverages augmentations more effectively than previous self-supervised methods and that representations learned using ReLIC are guaranteed to generalize well to downstream tasks under weaker assumptions than those required by previous work (Saunshi et al., 2019).
Next, we generalize contrastive learning and provide an alternative theoretical explanation to MI for the success of these methods. We generalize the proxy task of instance discrimination commonly used in contrastive learning using the causal concept of refinements (Chalupka et al., 2014). Intuitively, a refinement of a task can be understood as a more fine-grained variant of the original problem. For example, a refinement for classifying cats against dogs would be the task of classifying individual cat and dog breeds. The instance discrimination task results from the most fine-grained refinement, e.g. discriminating individual cats and dogs from one another. We show that using refinements as proxy tasks enables us to learn useful representations for downstream tasks. Specifically, using causal tools, we show that learning a representation on refinements such that it is an invariant predictor of proxy targets across augmentations is a sufficient condition for these representations to generalize to downstream tasks (cf. Theorem 1). In summary, we provide theoretical support both for the general form of the contrastive objective and for the use of data augmentations. Thus, we provide an alternative explanation to mutual information for the success of recent contrastive approaches, namely causal refinements of downstream tasks.
We test ReLIC on a variety of prediction and reinforcement learning problems. First, we evaluate the quality of representations pretrained on ImageNet with a special focus on robustness and out-of-distribution generalization. ReLIC performs competitively with current state-of-the-art methods on ImageNet, while significantly outperforming competing methods on robustness and out-of-distribution generalization of the learned representations when tested on corrupted ImageNet (ImageNet-C (Hendrycks and Dietterich, 2019)) and a version of ImageNet that consists of different renditions of the same classes (ImageNet-R (Hendrycks et al., 2020)). In terms of robustness, ReLIC also significantly outperforms the supervised baseline, with an absolute reduction in error. Unlike much prior work that specifically focuses on computer vision tasks, we also test ReLIC for representation learning in the context of reinforcement learning on the Atari suite (Bellemare et al., 2013). There we find that ReLIC significantly outperforms competing methods and achieves above human-level performance on 51 out of 57 games.
We formalize the problem of self-supervised representation learning using causality and propose to more effectively leverage data augmentations through invariant prediction.
We propose a new self-supervised objective, Representation Learning via Invariant Causal Mechanisms (ReLIC), that enforces invariant prediction through an explicit regularizer, and show improved generalization guarantees.
We generalize contrastive learning using refinements and show that learning on refinements is a sufficient condition for learning useful representations; this provides an alternative explanation to MI for the success of contrastive methods.
2 Representation Learning via Invariant Causal Mechanisms
Problem setting. Let $X$ denote the unlabelled observed data and $\mathcal{Y} = \{Y_t\}_{t=1}^{T}$ be a set of unknown tasks, with $Y_t$ denoting the targets for task $t$. The tasks can represent both a multi-environment and a multi-task setup. Our goal is to pretrain, with unsupervised data, a representation $f(X)$ that will be useful for solving the downstream tasks $\mathcal{Y}$.
Causal interpretation. To let the learning algorithm exploit common assumptions and intuitions about how the data of the unknown downstream tasks is generated, we propose to formalize them using a causal graph.
We start from the following assumptions: a) the data is generated from content and style variables, with b) only content (and not style) being relevant for the unknown downstream tasks and c) content and style are independent, i.e. style changes are content-preserving.
For example, when classifying dogs against giraffes from images, different parts of the animals constitute content, while style could be, for example, background, lighting conditions and camera lens characteristics.
By assumption, content is a good representation of the data for downstream tasks and we therefore cast the goal of representation learning as estimating content.
In the following, we compactly formalize these assumptions with a causal graph.
Let $C$ and $S$ be the latent variables describing content and style. In Figure 1(a), the directed arrows from $C$ and $S$ to the observed data $X$ (e.g. images) indicate that $X$ is generated based on content and style. The directed arrow from $C$ to the target $Y_t$ (e.g. class labels) encodes the assumption that content directly influences the target tasks, while the absence of any directed arrow from $S$ to $Y_t$ indicates that style does not. Thus, content $C$ has all the necessary information to predict $Y_t$. The absence of any directed path between $C$ and $S$ in Figure 1(a) encodes the intuition that these variables are independent, i.e. $C \perp\!\!\!\perp S$.
Using the independence of mechanisms (Peters et al., 2017), we can conclude that under this causal model performing interventions on $S$ does not change the conditional distribution $P(Y_t \mid C)$, i.e. manipulating the value of $S$ does not influence this conditional distribution. Thus, $P(Y_t \mid C)$ is invariant under changes in style $S$. We call $C$ an invariant representation for $Y_t$ under $S$, i.e.
$$p^{do(S = s_i)}(Y_t \mid C) = p^{do(S = s_j)}(Y_t \mid C) \quad \forall\, s_i, s_j \in \mathcal{S}, \tag{1}$$
where $p^{do(S = s)}$ denotes the distribution arising from assigning the value $s$ to $S$, with $\mathcal{S}$ the domain of $S$ (Pearl, 2009). Specifically, using $C$ as a representation allows us to predict targets stably across perturbations, i.e. content is both a useful and a robust representation for the tasks $\mathcal{Y}$.
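To make Equation 1 concrete, the following toy example (hypothetical, not taken from the paper) enumerates a small discrete model in which content $C$ and style $S$ are independent bits, the observation is simply $X = (C, S)$, and the target depends only on $C$. It computes the conditional of the target given a representation exactly under the interventions $do(S=0)$ and $do(S=1)$: the representation $C$ itself is invariant, while a representation that mixes in style is not. The probabilities and function names are illustrative assumptions.

```python
from fractions import Fraction
from itertools import product

# Hypothetical toy model: content C and style S are independent fair bits,
# and the target Y copies C with probability 3/4 (S never affects Y).
P_C = {0: Fraction(1, 2), 1: Fraction(1, 2)}
P_Y_GIVEN_C = {(y, c): Fraction(3, 4) if y == c else Fraction(1, 4)
               for y in (0, 1) for c in (0, 1)}


def conditional_under_do(rep, s):
    """Exact p^{do(S=s)}(Y = y | rep = r), returned as a dict keyed by (y, r)."""
    joint = {}
    for c, y in product((0, 1), repeat=2):
        r = rep(c, s)  # the intervention fixes style to s
        joint[(y, r)] = joint.get((y, r), 0) + P_C[c] * P_Y_GIVEN_C[(y, c)]
    marginal = {}
    for (y, r), p in joint.items():
        marginal[r] = marginal.get(r, 0) + p
    return {(y, r): p / marginal[r] for (y, r), p in joint.items()}


def is_invariant(rep):
    """Equation 1 style check: identical conditionals under do(S=0) and do(S=1)."""
    return conditional_under_do(rep, 0) == conditional_under_do(rep, 1)


def content_only(c, s):
    return c  # representation f(X) = C


def style_mixed(c, s):
    return c ^ s  # representation that also depends on style


print(is_invariant(content_only))  # True
print(is_invariant(style_mixed))   # False
```

Exact fractions are used so that invariance can be tested with plain equality rather than a numerical tolerance.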
Since the targets $\mathcal{Y}$ are unknown, we construct a proxy task in order to learn representations from unlabeled data only.
To learn useful representations for $\mathcal{Y}$, we construct proxy tasks $Y^R$ that represent more fine-grained problems than $\mathcal{Y}$; for a more formal treatment of proxy tasks, please refer to Section 3.
Further, to learn invariant representations such as $C$, we enforce Equation 1, which requires us to observe data under different style interventions, i.e. we need data that describes the same content under varying style.
Since we do not have access to $S$, we simulate style variability using content-preserving data augmentations (e.g. rotation, grayscaling, translation and cropping for images).
Specifically, we treat data augmentations as interventions on the style variable $S$, i.e. applying data augmentation $a_i$ corresponds to intervening on $S$ and setting it to $s_{a_i}$.
ReLIC objective. Equation 1 provides a general scheme to estimate content (cf. Figure 1(a)). We operationalize this by proposing to learn representations such that prediction of proxy targets from the representation is invariant under data augmentations. The representation $f(X)$ must fulfill the following invariant prediction criterion:
$$p^{do(a_i)}(Y^R \mid f(X)) = p^{do(a_j)}(Y^R \mid f(X)) \quad \forall\, a_i, a_j \in \mathcal{A}, \tag{2}$$
where $\mathcal{A} = \{a_1, \dots, a_m\}$ is the set of data augmentations which simulate interventions on the style variables and $p^{do(a)}$ denotes $p^{do(S = s_a)}$.
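To achieve invariant prediction, we propose to explicitly enforce invariance under augmentations through a regularizer. This gives rise to an objective for self-supervised learning we call Representation Learning via Invariant Causal Mechanisms (ReLIC). We write this objective as
$$\mathbb{E}_{X \sim p(X)}\, \mathbb{E}_{(a_{lk}, a_{qt}) \sim \mathcal{A} \times \mathcal{A}} \sum_{b \in \{a_{lk}, a_{qt}\}} \mathcal{L}_b\big(Y^R, f(X)\big) \quad \text{s.t.} \quad KL\big(p^{do(a_{lk})}(Y^R \mid f(X)),\, p^{do(a_{qt})}(Y^R \mid f(X))\big) \le \rho$$
where $\mathcal{L}_b$ is the proxy task loss under augmentation $b$ and $KL$ is the Kullback-Leibler (KL) divergence. Note that any distance measure on distributions can be used in place of the KL divergence. We explain the remaining terms in detail below.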
Concretely, as proxy task we associate to every datapoint $x_i$ the label $y_i^R = i$. This corresponds to the instance discrimination task, commonly used in contrastive learning (Hadsell et al., 2006). We take pairs of points $(x_i, x_j)$ to compute similarity scores and use pairs of augmentations $a_{lk} = (a_l, a_k) \in \mathcal{A} \times \mathcal{A}$ to perform a style intervention. Given a batch of samples $\{x_i\}_{i=1}^{N} \sim \mathcal{D}$, we use
$$p^{do(a_{lk})}(Y^R = j \mid f(x_i)) \propto \exp\big(\phi(f(x_i^{a_l}), h(x_j^{a_k}))/\tau\big)$$
with $x^{a}$ data augmented with $a$ and $\tau$ a softmax temperature parameter. We encode $f$ using a neural network and choose $h$ to be related to $f$, e.g. $h = f$ or as a network with an exponential moving average of the weights of $f$ (e.g. target networks similar to (Grill et al., 2020)). To compare representations we use the function $\phi(f(x_i), h(x_j)) = \langle g(f(x_i)), g(h(x_j)) \rangle$, where $g$ is a fully-connected neural network often called the critic.
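A minimal sketch of the proxy-task distribution $p^{do(a_{lk})}(Y^R = j \mid f(x_i))$ in plain Python: row $i$ is a temperature-scaled softmax, over the contrast set, of critic similarities between the anchor $f(x_i^{a_l})$ and each target $h(x_j^{a_k})$. As a simplifying assumption, the fully-connected critic $g$ is replaced by plain L2 normalization; the function and argument names are hypothetical.

```python
import math


def l2_normalize(v):
    """Stand-in for the critic g (assumption): scale a vector to unit length."""
    norm = math.sqrt(sum(x * x for x in v)) or 1.0
    return [x / norm for x in v]


def instance_probs(anchors, targets, tau=0.1, critic=l2_normalize):
    """Row i is softmax_j(<g(anchors[i]), g(targets[j])> / tau).

    anchors[i] plays the role of f(x_i^{a_l}); targets[j] that of h(x_j^{a_k}).
    """
    za = [critic(v) for v in anchors]
    zt = [critic(v) for v in targets]
    probs = []
    for u in za:
        logits = [sum(a * b for a, b in zip(u, w)) / tau for w in zt]
        m = max(logits)  # subtract the max for numerical stability
        exps = [math.exp(l - m) for l in logits]
        total = sum(exps)
        probs.append([e / total for e in exps])
    return probs


probs = instance_probs([[1.0, 0.0], [0.0, 1.0]], [[0.9, 0.1], [0.2, 0.8]], tau=0.1)
print([round(sum(row), 6) for row in probs])  # [1.0, 1.0]
```

Each row is a proper distribution over the contrast set, and the mass concentrates on the matching index $j = i$ when the two views of the same datapoint are similar.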
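Combining these pieces, we learn representations by minimizing the following objective over the full set of data $x_i \in \mathcal{D}$ and augmentations $a_{lk} \in \mathcal{A} \times \mathcal{A}$:
$$-\sum_{i=1}^{N}\sum_{a_{lk}, a_{qt}} \log \frac{\exp\big(\phi(f(x_i^{a_l}), h(x_i^{a_k}))/\tau\big)}{\sum_{m=1}^{M}\exp\big(\phi(f(x_i^{a_l}), h(x_m^{a_k}))/\tau\big)} + \alpha \sum_{a_{lk}, a_{qt}} KL\big(p^{do(a_{lk})}, p^{do(a_{qt})}\big) \tag{3}$$
with $M$ the number of points we use to construct the contrast set and $\alpha$ the weighting of the invariance penalty. We used the shorthand $p^{do(a_{lk})}$ for $p^{do(a_{lk})}(Y^R = j \mid f(x_i))$. With appropriate choices for $\phi$, $h$, and $\alpha$ above, Equation 3 recovers many recent state-of-the-art methods (cf. Table 5 in Appendix A). Figure 1(b) presents a schematic of the ReLIC objective.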
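The sketch below assembles the two terms of Equation 3 for a single pair of augmentation pairs $(a_{lk}, a_{qt})$, assuming the temperature-scaled scores $\phi(\cdot,\cdot)/\tau$ have already been computed for every anchor $i$ and contrast point $m$. Following the constrained form of the objective, both augmentation pairs contribute an instance-discrimination term, and the KL penalty compares the rows $p^{do(a_{lk})}(\cdot \mid f(x_i))$ and $p^{do(a_{qt})}(\cdot \mid f(x_i))$. It deliberately omits batching, gradients and target-network details and is not the authors' implementation; `relic_loss` and its arguments are hypothetical names.

```python
import math


def softmax(row):
    m = max(row)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def kl(p, q):
    """KL(p || q) for discrete distributions; q must be positive wherever p is."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def relic_loss(scores_lk, scores_qt, alpha=1.0):
    """Contrastive + invariance loss for one pair of augmentation pairs.

    scores_lk[i][m] = phi(f(x_i^{a_l}), h(x_m^{a_k})) / tau, and likewise
    scores_qt for (a_q, a_t). The positive for anchor i is contrast point i.
    """
    loss = 0.0
    for i, (row_lk, row_qt) in enumerate(zip(scores_lk, scores_qt)):
        p_lk, p_qt = softmax(row_lk), softmax(row_qt)
        loss -= math.log(p_lk[i]) + math.log(p_qt[i])  # instance discrimination
        loss += alpha * kl(p_lk, p_qt)                   # invariance penalty
    return loss


scores_lk = [[2.0, 0.0], [0.0, 2.0]]
scores_qt = [[1.0, 0.5], [0.3, 1.5]]
print(relic_loss(scores_lk, scores_qt, alpha=0.0) < relic_loss(scores_lk, scores_qt, alpha=1.0))  # True
```

Setting `alpha=0.0` leaves only the contrastive part; the penalty vanishes exactly when the two augmentation pairs induce identical predictive distributions.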
The explicit invariance penalty encourages the within-class distances (for a downstream task of interest) of the representations learned by ReLIC to be tightly concentrated. We show this empirically in Figure 2 and theoretically in Appendix B. In the following section we provide a theoretical justification, from a causal perspective, for using an instance discrimination-based contrastive loss. We also show (cf. Theorem 1 below) that minimizing the contrastive loss alone (i.e. $\alpha = 0$) does not guarantee generalization; instead, invariance across augmentations must be explicitly enforced.
3 Generalizing Contrastive Learning
Learning with refinements. In contrastive learning, the task of instance discrimination, i.e. classifying the dataset $\{x_i\}_{i=1}^{N}$, is used as the proxy task. To better understand contrastive learning and motivate this proxy task, we generalize instance discrimination using the causal concept of refinements (Chalupka et al., 2014). Intuitively, a refinement of one problem is another, more fine-grained problem. If task $Y_t$ is to classify cats against dogs, then a refinement of $Y_t$ is the task of classifying cats and dogs into their individual breeds. See Figure 4 for a further visual example. For any set of tasks, there exist many different refinements. However, the most fine-grained refinement corresponds exactly to classifying the dataset $\{x_i\}_{i=1}^{N}$. Thus, the instance discrimination task used in contrastive learning is a specific type of refinement. For a definition and formal treatment of refinements, please refer to Appendix D.
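If we read a refinement purely as a partition of the data (an intuition only; the causal definition is in Appendix D), checking whether one labelling refines another amounts to checking that each fine-grained class falls inside a single coarse class. The sketch below, with hypothetical labels and a hypothetical helper `is_refinement`, confirms that breeds refine species and that instance labels $y_i = i$ refine any labelling.

```python
def is_refinement(fine_labels, coarse_labels):
    """True if each fine-grained class lies entirely inside one coarse class.

    Partition view only (a simplification); the causal definition is in Appendix D.
    """
    parent = {}
    for fine, coarse in zip(fine_labels, coarse_labels):
        if parent.setdefault(fine, coarse) != coarse:
            return False  # one fine class would straddle two coarse classes
    return True


# Hypothetical labels for five images.
species = ["cat", "cat", "dog", "dog", "cat"]
breeds = ["siamese", "persian", "beagle", "beagle", "siamese"]
instances = list(range(len(species)))  # instance discrimination: y_i = i

print(is_refinement(breeds, species))     # True
print(is_refinement(instances, species))  # True
print(is_refinement(species, breeds))     # False
```

Because every instance label is unique, the instance labelling passes this check against any coarse labelling, which matches the statement that instance discrimination is the most fine-grained refinement.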
Let $Y^R$ be the targets of a proxy task that is a refinement for all tasks in $\mathcal{Y}$.
Leveraging causal tools, we connect learning on refinements to learning on downstream tasks.
Specifically, we provide a theoretical justification for exchanging unknown downstream tasks with these specially constructed proxy tasks.
We show that if $f(X)$ is an invariant representation for $Y^R$ under changes in style $S$, then $f(X)$ is also an invariant representation for tasks in $\mathcal{Y}$ under changes in style $S$.
Thus by enforcing invariance under style interventions on a refinement, we learn representations that generalize to downstream tasks.
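Theorem 1. Let $\mathcal{Y} = \{Y_t\}_{t=1}^{T}$ be a family of downstream tasks. Let $Y^R$ be a refinement for all tasks in $\mathcal{Y}$. If $f(X)$ is an invariant representation for $Y^R$ under style interventions $S$, then $f(X)$ is an invariant representation for all tasks in $\mathcal{Y}$ under style interventions $S$, i.e.
$$p^{do(s_i)}(Y^R \mid f(X)) = p^{do(s_j)}(Y^R \mid f(X)) \;\Longrightarrow\; p^{do(s_i)}(Y_t \mid f(X)) = p^{do(s_j)}(Y_t \mid f(X)) \tag{4}$$
for all $Y_t \in \mathcal{Y}$ with $s_i, s_j \in \mathcal{S}$. Thus, $f(X)$ is a representation that generalizes to $\mathcal{Y}$.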
Theorem 1 states that if $Y^R$ is a refinement of $\mathcal{Y}$, then learning an invariant representation on $Y^R$ is a sufficient condition for this representation to be useful on $\mathcal{Y}$. For a formal exposition of these points and accompanying proofs, please refer to Appendix D. Recall that the instance discrimination proxy task is the most fine-grained refinement, and so the left-hand side of Equation 4 is satisfied for any downstream task satisfying the stated assumptions of the theorem.
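One way to see why the direction of Theorem 1 is plausible, assuming each refined class maps to exactly one downstream class through a fixed map $\pi$: then $p(Y_t = y \mid f(X)) = \sum_{r : \pi(r) = y} p(Y^R = r \mid f(X))$, so equal refined distributions under two style interventions give equal downstream distributions. The sketch checks this with exact fractions and also shows that the converse fails. It illustrates the intuition only, not the proof in Appendix D; the map `parent` and all probabilities are hypothetical.

```python
from fractions import Fraction as F


def coarsen(p_fine, parent):
    """Marginalize p(Y^R | f(X)) to p(Y_t | f(X)) via a fine-to-coarse label map."""
    p_coarse = {}
    for r, p in p_fine.items():
        p_coarse[parent[r]] = p_coarse.get(parent[r], 0) + p
    return p_coarse


parent = {"siamese": "cat", "persian": "cat", "beagle": "dog"}

# Hypothetical predictions for one value of f(X) under two style interventions.
fine_s1 = {"siamese": F(1, 2), "persian": F(1, 5), "beagle": F(3, 10)}
fine_s2 = dict(fine_s1)  # invariant on the refinement Y^R ...
assert coarsen(fine_s1, parent) == coarsen(fine_s2, parent)  # ... hence on Y_t

# The converse fails: invariance on Y_t does not imply invariance on Y^R.
fine_s3 = {"siamese": F(3, 5), "persian": F(1, 10), "beagle": F(3, 10)}
fine_s4 = {"siamese": F(1, 10), "persian": F(3, 5), "beagle": F(3, 10)}
print(coarsen(fine_s3, parent) == coarsen(fine_s4, parent), fine_s3 == fine_s4)  # True False
```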
We generalize contrastive learning through refinements and connect representations learned on refinements and downstream tasks in Theorem 1. Thus, using causality, we provide an alternative explanation to mutual information for the success of contrastive learning. Note that our methodology of refinements is not limited to instance discrimination tasks and is thus more general than currently used contrastive losses. Real-world data often include rich metadata, which can guide the construction of refinements by grouping the data according to the available metadata. Note that the coarser the refinement we can construct, the more data-efficient we can expect representation learning for downstream tasks to be. Further, we can also expect to require less supervised data to finetune the representation.
4 Related Work
Contrastive objectives and mutual information maximization. Many recent approaches to self-supervised learning are rooted in the well-established idea of maximizing mutual information (MI), e.g. Contrastive Predictive Coding (CPC) (Oord et al., 2018; Hénaff et al., 2019), Deep InfoMax (DIM) (Hjelm et al., 2018) and Augmented Multiscale DIM (AMDIM) (Bachman et al., 2019). These methods are based on noise contrastive estimation (NCE) (Gutmann and Hyvärinen, 2010) which, under specific conditions, can be viewed as a lower bound on MI (Poole et al., 2019). The resulting objective functions are commonly referred to as InfoNCE.
The precise role played by mutual information maximization in self-supervised learning is subject to some debate. (Tschannen et al., 2019) argue that the performance on downstream tasks is not correlated with the achieved bound on MI, but may be more tightly correlated with encoder architecture and capacity.
Importantly, InfoNCE objectives require custom architectures to ensure the network does not converge to non-informative solutions, thus precluding the use of standard architectures.
Recently, several works (He et al., 2019; Chen et al., 2020a) successfully combined contrastive estimation with a standard ResNet-50 architecture.
In particular, SimCLR (Chen et al., 2020a) relies on a set of strong augmentations.
Recently, (Saunshi et al., 2019) proposed a learning theoretic framework to analyze the performance of contrastive objectives. However, they note that without strong assumptions on intra-class concentration, contrastive objectives are fundamentally limited in the representations they are able to learn. ReLIC explicitly enforces intra-class concentration via the invariance regularizer, ensuring that it generalizes under weaker assumptions. Unlike (Saunshi et al., 2019), who do not discuss augmentations, we incorporate augmentations into our theoretical explanation of contrastive methods.
The reasons for the improvement in performance from AMDIM through to SimCLR and BYOL are not easily explained by either the MI maximization or the learning theoretic viewpoint. Further, it is not clear why relatively minor architectural differences between the methods result in significant differences in performance nor is it obvious how current state-of-the-art can be improved. In contrast to prior art, the performance of ReLIC is explained by connections to causal theory. As such it gives a clear path for improving results by devising problem appropriate refinements, interventions and invariance penalties. Furthermore, the use of invariance penalties in ReLIC as dictated by causal theory yields significantly more robust representations that generalize better than those learned with SimCLR or BYOL.
Causality and invariance. Recently, the notion of invariant prediction has emerged as an important operational concept in causal inference (Peters et al., 2016). This idea has been used to learn classifiers which are robust against domain shifts (Gong et al., 2016). Notably, (Heinze-Deml and Meinshausen, 2017) propose to use group structure to delineate between different environments where the aim is to minimize the classification loss while also ensuring that the conditional variance of the prediction function within each group remains small. Unlike (Heinze-Deml and Meinshausen, 2017) who use supervised data and rely on having a grouping in the training data, our approach does not rely on ground-truth targets and can flexibly create groupings of the training data if none are present. Further, we enforce invariant prediction within the group by constraining the distance between distributions resulting from contrasting data across groups.
We first visualize the influence of the explicit invariance constraint in ReLIC on the linear separability of the learned representations. We then evaluate ReLIC on a number of prediction and reinforcement learning tasks for usefulness and robustness. For the prediction tasks, we test ReLIC after pretraining the representation in a self-supervised way on the training set of the ImageNet ILSVRC-2012 dataset (Russakovsky et al., 2015). We evaluate ReLIC in the linear evaluation setup on ImageNet and test its robustness and out-of-distribution generalization on datasets related to ImageNet. Unlike much prior work in contrastive learning, which focuses specifically on computer vision tasks, we also test ReLIC in the context of learning representations for reinforcement learning. Specifically, we test ReLIC on the Atari suite (Bellemare et al., 2013), which consists of diverse games of varying difficulty.
Linear evaluation. In order to understand how representations learned by ReLIC differ from those of other methods, we compare them against those learned by AMDIM and SimCLR in terms of Fisher’s linear discriminant ratio (Friedman et al., 2009): $J(f) = \frac{\|\mu_i - \mu_j\|_2^2}{\sum_{c \in \{i, j\}} \frac{1}{|C_c|}\sum_{k \in C_c} \|f(x_k) - \mu_c\|_2^2}$, where $\mu_c = \frac{1}{|C_c|}\sum_{k \in C_c} f(x_k)$ is the mean of the representations of class $c$ and $C_c$ is the index set of that class. A larger $J(f)$ implies that classes are more easily separated with a linear classifier. This can be achieved by either increasing distances between classes (numerator) or shrinking within-class variance (denominator).
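The ratio $J(f)$ can be computed directly from the embeddings of two classes. In this sketch the within-class term for each class is its mean squared distance to the class mean (an assumption about normalization; the paper's exact normalization may differ), and the toy embeddings are illustrative. Shrinking within-class spread while keeping the class means fixed raises $J$, which is the effect the invariance penalty is meant to have.

```python
def fisher_ratio(class_i, class_j):
    """Squared distance between class means over summed within-class spreads."""
    def mean(vs):
        return [sum(col) / len(vs) for col in zip(*vs)]

    def sq_dist(u, v):
        return sum((a - b) ** 2 for a, b in zip(u, v))

    mu_i, mu_j = mean(class_i), mean(class_j)
    within = (sum(sq_dist(v, mu_i) for v in class_i) / len(class_i)
              + sum(sq_dist(v, mu_j) for v in class_j) / len(class_j))
    return sq_dist(mu_i, mu_j) / within


# Same class means (0, 1) and (4, 1); the second pair has tighter classes.
loose = fisher_ratio([[0.0, 0.0], [0.0, 2.0]], [[4.0, 0.0], [4.0, 2.0]])
tight = fisher_ratio([[0.0, 0.5], [0.0, 1.5]], [[4.0, 0.5], [4.0, 1.5]])
print(loose, tight)  # 8.0 32.0
```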
Figure 2 shows the distribution of $J(f)$ for ReLIC, SimCLR and AMDIM after training, as measured on the (downsampled) ImageNet validation set. The distance between the medians of ReLIC and SimCLR is 162, while AMDIM is tightly concentrated close to 20. The invariance penalty ensures that, even though labels are unknown a priori, the within-class variability of $f(X)$ is concentrated for ReLIC, leading to better linear separability between classes in the downstream task of interest. This is reflected in the rightward shift of the distribution of $J(f)$ in Figure 2 for ReLIC compared with SimCLR and AMDIM, which do not impose such a constraint.
Table 1: Linear evaluation on ImageNet, top-1 and top-5 accuracy (%).
|Method||Top-1 (%)||Top-5 (%)|
|ResNet-50:|
|PIRL (Misra and Maaten, 2020)||63.6||-|
|CPC v2 (Hénaff et al., 2019)||63.8||85.3|
|CMC (Tian et al., 2019)||66.2||87.0|
|SimCLR (Chen et al., 2020a) *||69.3||89.0|
|SwAV (Caron et al., 2020) *||70.1||-|
|InfoMin Aug. (Tian et al., 2020) †||73.0||91.1|
|SwAV (Caron et al., 2020) †||75.3||-|
|ResNet-50 with target network:|
|MoCo v2 (Chen et al., 2020b)||71.1||-|
|BYOL (Grill et al., 2020) *||74.3||91.6|
Next, we evaluate ReLIC’s representation by training a linear classifier on top of the fixed encoder, following the procedure in (Kolesnikov et al., 2019; Chen et al., 2020a) and Appendix E.4. In Table 1, we report top-1 and top-5 accuracy on the ImageNet test set. Methods denoted with * use SimCLR augmentations (Chen et al., 2020a), while methods denoted with † use custom, stronger augmentations. Comparing methods which use SimCLR augmentations, ReLIC outperforms competing approaches on both ResNet-50 and ResNet-50 with target network. For completeness, we report results for SwAV (Caron et al., 2020) and InfoMin (Tian et al., 2020), but note that these methods use stronger augmentations, which on their own have been shown to boost performance considerably. A fair comparison between different objectives can only be achieved under the same architecture and the same set of augmentations.
Robustness and generalization. We evaluate the robustness and out-of-distribution generalization of ReLIC’s representation on ImageNet-C (Hendrycks and Dietterich, 2019) and ImageNet-R (Hendrycks et al., 2020), respectively. To evaluate ReLIC’s representation, we train a linear classifier on top of the frozen representation following the procedure described in (Chen et al., 2020a) and Appendix E.5.2. For ImageNet-C, we report the mean Corruption Error (mCE) and the Corruption Errors for Noise corruptions in Table 3. ReLIC has significantly lower mCE than both the supervised ResNet-50 baseline and the unsupervised methods SimCLR and BYOL. Also, it has the lowest Corruption Error on 14 out of 15 corruptions when compared to SimCLR and BYOL. Thus, ReLIC learns the most robust representation among the compared methods. ReLIC also outperforms SimCLR and BYOL on ImageNet-R, showing its superior out-of-distribution generalization ability; see Table 2. For further details and results, please consult Appendix E.5.
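For reference, the mCE metric of (Hendrycks and Dietterich, 2019) sums a model's top-1 error on each corruption type over the five severity levels, divides by AlexNet's summed error on the same corruption to obtain the Corruption Error (CE), and averages the CEs over corruption types; results are usually reported in percent. A minimal sketch with made-up error rates for two corruption types (the helper names are hypothetical):

```python
def corruption_error(model_errors, alexnet_errors):
    """CE for one corruption type: summed top-1 error over severities, normalized by AlexNet's."""
    return sum(model_errors) / sum(alexnet_errors)


def mean_corruption_error(model, alexnet):
    """mCE: average CE over corruption types (dicts map corruption -> per-severity errors)."""
    ces = [corruption_error(model[c], alexnet[c]) for c in alexnet]
    return sum(ces) / len(ces)


# Made-up top-1 error rates in percent, five severities per corruption type.
alexnet = {"gaussian_noise": [60, 70, 80, 90, 100], "shot_noise": [60, 70, 80, 90, 100]}
model = {"gaussian_noise": [30, 35, 40, 45, 50], "shot_noise": [45, 50, 60, 70, 75]}
print(100 * mean_corruption_error(model, alexnet))  # 62.5
```

Normalizing by a fixed reference model keeps corruptions that are intrinsically hard from dominating the average.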
Table 2: Top-1 error (%) on ImageNet-R (lower is better).
|Method||Top-1 Error (%)|
|ResNet-50:|
|Supervised||63.9|
|SimCLR||81.7|
|ReLIC (ours)||77.4|
|ResNet-50 with target network:|
|BYOL||77.0|
|ReLIC (ours)||76.2|
Reinforcement learning. Much prior work in contrastive learning has focused specifically on computer vision tasks. In order to compare these approaches in a different domain, we investigate representation learning in the context of reinforcement learning. We compare ReLIC as an auxiliary loss against other state-of-the-art self-supervised losses on an agent trained on 57 Atari games. We use the original architecture and hyperparameters of the R2D2 agent (Kapturowski et al., 2019), supplement it with a second encoder trained with a given representation learning loss, and measure performance with human-normalized scores. When auxiliary losses are present, the Q-Network takes the output of the second encoder as an input. The Q-Network and the encoder are trained with separate optimizers. For the augmentation baseline, the Q-Network takes two identical encoders trained end-to-end. Table 4 shows a comparison between ReLIC, SimCLR, BYOL, CURL (Srinivas et al., 2020), and feeding augmented observations directly to the agent (Kostrikov et al., 2020). We find that ReLIC has a significant advantage over competing self-supervised methods, performing best in 25 out of 57 games. The next best method, CURL, performs best in 11 games. Full details are presented in Appendix E.6.
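Human-normalized scores rescale raw game scores so that a random policy scores 0 and the human reference scores 1; counting superhuman games then amounts to counting games whose normalized score exceeds 1. The sketch below uses made-up raw scores, and the helper names are hypothetical; the actual random and human reference values come from the standard Atari evaluation protocol.

```python
def human_normalized_score(agent, random_score, human):
    """HNS = (agent - random) / (human - random): 0 for a random policy, 1 at the human reference."""
    return (agent - random_score) / (human - random_score)


def count_superhuman(scores):
    """Count games whose HNS is strictly above 1; `scores` maps game -> (agent, random, human)."""
    return sum(1 for agent, rnd, human in scores.values()
               if human_normalized_score(agent, rnd, human) > 1.0)


# Made-up raw scores, for illustration only.
scores = {
    "game_a": (150.0, 50.0, 100.0),  # HNS 2.0
    "game_b": (75.0, 50.0, 100.0),   # HNS 0.5
    "game_c": (100.0, 50.0, 100.0),  # HNS 1.0, exactly human-level
}
print(count_superhuman(scores))  # 1
```

Whether a score exactly at the human reference counts as "above human-level" is a convention; this sketch uses a strict inequality.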
|Metric||ReLIC (ours)||SimCLR||BYOL||CURL||Augmentation|
|Number of superhuman games||51||49||49||49||34|
In this work we have analyzed self-supervised learning using a causal framework. Using a causal graph, we have formalized the problem of self-supervised representation learning and derived properties of the optimal representation. We have shown that representations need to be invariant predictors of proxy targets under interventions on features that are only correlated with, but not causally related to, the downstream tasks. We have leveraged data augmentations to simulate these interventions and have proposed to explicitly enforce this invariance constraint. Based on this, we have proposed a new self-supervised objective, Representation Learning via Invariant Causal Mechanisms (ReLIC), that enforces invariant prediction of proxy targets across augmentations using an invariance regularizer. Further, we have generalized contrastive methods using the concept of refinements and have shown that learning a representation on refinements using the principle of invariant prediction is a sufficient condition for these representations to generalize to downstream tasks. With this, we have provided an alternative explanation to mutual information for the success of contrastive methods. Empirically, we have compared ReLIC against recent self-supervised methods on a variety of prediction and reinforcement learning tasks. Specifically, ReLIC significantly outperforms competing methods in terms of robustness and out-of-distribution generalization of the representations it learns on ImageNet. ReLIC also significantly outperforms related self-supervised methods on the Atari suite, achieving superhuman performance on 51 out of 57 games. In future work, we aim to investigate the construction of more coarse-grained refinements and the empirical evaluation of different kinds of refinements.
Acknowledgements. We thank David Balduzzi, Melanie Rey, Christina Heinze-Deml, Ilja Kuzborskij, Ali Eslami, Jeffrey de Fauw and Josip Djolonga for invaluable discussions.
Appendix A Relationship between ReLIC and other methods
|CPC (Hénaff et al., 2019)||PixelCNN||-|
|AMDIM (Bachman et al., 2019)||-||-|
|SimCLR (Chen et al., 2020a)||MLP, norml.||-|
|BYOL (Grill et al., 2020)||-||, 1 layer MLP, norml.|
|ReLIC (ours)||MLP, norml.||Eq. (3)|
Appendix B Distance concentration and generalization
Quantifying the generalization performance of representations learned on unlabelled data is difficult without imposing assumptions on the underlying structure of the data and the downstream tasks of interest. The results in (Saunshi et al., 2019) assume a latent class structure underlying the data. The similarity of images under each (potentially overlapping) latent class $c$ is measured by a probability distribution $\mathcal{D}_c$. In the contrastive setting, a positive pair of points $(x, x^+)$ is said to be sampled from a distribution $\mathcal{D}_{sim}$ and a negative example $x^-$ is sampled from the marginal distribution $\mathcal{D}_{neg}$. The task of interest is multi-class classification using the learned representation. In our setting, the augmented data points $(x_i^{a_l}, x_i^{a_k})$ and $(x_i^{a_l}, x_j^{a_k})$ take the roles of the pairs of positive and negative points, respectively.
In this section, under the same structural assumptions on the data as (Saunshi et al., 2019), we show that a similar result holds, but under weaker assumptions on the representation function $f$.
To build intuition for the following results, we can view our explicit invariance constraint through the lens of distance concentration; its effect is illustrated in Figure 3. The shaded region represents the set of augmentations $\mathcal{A}(x)$ around an image $x$. Depicted are two images $x_i$ and $x_j$ from the ImageNet class Stingray. The points $x_i^{a}$ and $x_j^{a'}$ are augmentations which correspond to a region of overlap between the augmentation sets of $x_i$ and $x_j$. If the augmentations $x_i^{a}$ and $x_j^{a'}$ are similar enough, encouraging $f(x_i)$ to be close to $f(x_i^{a})$, and similarly $f(x_j)$ to be close to $f(x_j^{a'})$, indirectly encourages $f(x_i)$ to be close to $f(x_j)$. This has the effect of concentrating distances between similar images. We make this intuition more formal in the following discussion.
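Consider a modified, Euclidean-distance-regularized version of our objective
$$\min_f \; \mathbb{E}_{x, x^+, x^-}\Big[\ell\big(f(x)^\top (f(x^+) - f(x^-))\big)\Big] \quad \text{s.t.} \quad \|f(x) - f(x^a)\|_2 \le \rho \tag{5}$$
where $x^a = a(x)$ with $a \in \mathcal{A}$. Here $\ell(v) = \log(1 + e^{-v})$ is the logistic loss. For a single negative, this is equivalent to the standard ReLIC objective with an identity critic.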
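The single-negative equivalence can be checked numerically, assuming the loss term in eq. (5) applies $\ell$ to the score difference $f(x)^\top f(x^+) - f(x)^\top f(x^-)$. With an identity critic and unit temperature, the two-way softmax cross-entropy $-\log \frac{e^{s^+}}{e^{s^+} + e^{s^-}}$ equals $\log(1 + e^{-(s^+ - s^-)}) = \ell(s^+ - s^-)$. The embedding vectors below are arbitrary, and the helper names are hypothetical.

```python
import math


def logistic_loss(v):
    """l(v) = log(1 + exp(-v))."""
    return math.log1p(math.exp(-v))


def two_way_softmax_ce(s_pos, s_neg):
    """Cross-entropy of the positive under a softmax over {positive, one negative}."""
    m = max(s_pos, s_neg)  # shift for numerical stability; it cancels out
    return -(s_pos - m) + math.log(math.exp(s_pos - m) + math.exp(s_neg - m))


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


# Arbitrary embeddings f(x), f(x+), f(x-); the critic is the identity.
fx, fpos, fneg = [0.3, -1.2, 0.5], [0.1, -0.9, 0.7], [1.0, 0.4, -0.2]
lhs = logistic_loss(dot(fx, [p - n for p, n in zip(fpos, fneg)]))
rhs = two_way_softmax_ce(dot(fx, fpos), dot(fx, fneg))
print(math.isclose(lhs, rhs))  # True
```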
Assumptions. We require that the following assumptions hold: (A1) $f$ is $L$-Lipschitz and minimizes eq. (5) such that the constraint is active, and (A2) $x$ is a bounded random variable.
Lemma 1 (Concentration).
If assumption (A1) holds for $f$ and (A2) holds for $x$, then $f(x)$ is a sub-Gaussian random variable with parameter .
See Appendix C for the proof. This result states that the Euclidean version of our invariance regularizer has the effect of contracting the within-class variance of the learned representation. Figure 2 shows that this holds in practice for the original version of our objective in eq. (3). This ensures that the following generalization result from (Saunshi et al., 2019) holds. For brevity, we state an informal version of the theorem, with details deferred to the original publication.
Theorem 2 (Generalization. Adapted from Lemma B.2. from (Saunshi et al., 2019)).
Let be the standard -wise hinge loss of the linear classification function whose column is the mean of representations corresponding to class . Further, let use the hinge loss with margin with constant and . If is the minimizer of eq. (5) and if Assumptions (A1) and (A2) hold, then with high probability
Here, $\mathrm{Gen}_M$ is a standard generalization bound which depends on the Rademacher complexity of the function class and the sample size $M$.
For all practical purposes, the final generalization result is identical to that of (Saunshi et al., 2019), stating that $\hat f$, which is learned by minimizing a contrastive objective on unlabelled data, performs well on labelled data. However, this crucially depends on the intra-class concentration of the representation, namely that $f(x)$ is sub-Gaussian within each class. Whereas in (Saunshi et al., 2019) this was assumed to hold, our Lemma 1 shows that the necessary concentration is ensured by our invariance penalty. Experimentally, we observe that this property holds in practice (Figure 2).
Appendix C Additional Results
Proof of Lemma 1.
Assume the data is -sub-Gaussian. In practice this holds since is bounded. It immediately follows that the -Lipschitz function is sub-Gaussian with parameter at most . We now characterize the reduction in variance from to . Assume there is a ball of radius around each point such that for any augmentation of . By assumption (A1), we have that . This implies that for points and such that , there exists a region of overlap so that .
In practice, this says that there are augmentations of which are sufficiently similar to augmentations of that their representations should be similar, thereby driving and closer together.
The variance of points in space is
The overlap induces a graph where we say . For samples we can decompose the variance as
By smoothness of we always have that . By the constraint we have that and for .
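The decomposition above splits the variance into contributions from pairs of points. A minimal sketch of this idea, assuming the standard pairwise identity $\frac{1}{n}\sum_i \|z_i - \bar z\|^2 = \frac{1}{2n^2}\sum_{i,j}\|z_i - z_j\|^2$ (the exact form used in the proof is not reproduced here), separates the pairs joined by an overlap edge from the remaining pairs. The random representations and adjacency matrix are hypothetical.

```python
import numpy as np


def pairwise_variance_split(z, adjacency):
    """Split the total sample variance of z into connected and unconnected pairs.

    Uses (1/n) sum_i ||z_i - mean||^2 = (1/(2 n^2)) sum_{i,j} ||z_i - z_j||^2.
    adjacency: symmetric boolean (n, n) matrix of overlap edges, no self-loops.
    """
    n = z.shape[0]
    sq_dists = np.sum((z[:, None, :] - z[None, :, :]) ** 2, axis=-1)  # (n, n)
    connected = sq_dists[adjacency].sum() / (2 * n ** 2)
    unconnected = sq_dists[~adjacency].sum() / (2 * n ** 2)  # diagonal terms are zero
    return connected, unconnected


rng = np.random.default_rng(0)
z = rng.normal(size=(6, 3))
adj = np.triu(rng.random((6, 6)) < 0.5, 1)
adj = adj | adj.T                                  # symmetric, no self-loops
conn, unconn = pairwise_variance_split(z, adj)
direct = np.mean(np.sum((z - z.mean(axis=0)) ** 2, axis=1))
print(conn, unconn, direct)
```

The constraint bounds every term in the connected part, so only the unconnected part can remain large; this is why connectivity of the overlap graph matters in what follows.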
Constant proportion overlap. Now, assuming that for each point a constant proportion of the points lies in the set , we obtain the following inequality
For we require . Since both terms are positive, we separately require:
This condition makes sense since the larger , the fewer unconnected components in the graph. If the above holds, we also require to ensure the sum is bounded above by 1. This implies .
However, is a property of the augmentation set and not a directly user-controllable parameter. Hence, if is too small or the function is not smooth enough, it might not be possible to set in such a way as to induce contraction in .
In the next section we derive a tighter concentration based on the structure of random graphs which are induced by the connectivity between data points and their augmentations.
Random graphs. Consider the graph induced by the constraints . Call the set of neighbours of point . For points, if there is a constant probability that , then is an Erdős-Rényi graph.
From Theorem 3, if for then with high probability, there are no unconnected components in . That is, every vertex in V is reachable from any other vertex in a finite number of steps. We can then decompose the contribution to the variance in terms of components in the graph that are adjacent and those which are reachable within a certain number of steps.
Theorem 3 (Connectedness (Erdős and Rényi, 1960)).
If where with high probability then the graph has no unconnected components.
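Theorem 3 is a connectivity result for Erdős-Rényi graphs; for $G(n, p)$ the classical threshold lies at $p \approx \log n / n$. The simulation below is a sketch for building intuition rather than part of the proof: it samples graphs below and above that threshold and checks connectivity with breadth-first search. The graph size, trial count and seed are arbitrary choices.

```python
import math
from collections import deque

import numpy as np


def sample_gnp(n, p, rng):
    """Adjacency lists of an Erdős-Rényi graph G(n, p): each pair is an edge w.p. p."""
    upper = np.triu(rng.random((n, n)) < p, k=1)
    adj = upper | upper.T
    return [list(np.flatnonzero(adj[i])) for i in range(n)]


def is_connected(neighbours):
    """Breadth-first search from vertex 0; connected iff every vertex is reached."""
    n = len(neighbours)
    if n == 0:
        return True
    seen = {0}
    queue = deque([0])
    while queue:
        u = queue.popleft()
        for v in neighbours[u]:
            if v not in seen:
                seen.add(v)
                queue.append(v)
    return len(seen) == n


def connectivity_rate(n, p, trials=200, seed=0):
    """Fraction of sampled G(n, p) graphs that are connected."""
    rng = np.random.default_rng(seed)
    return sum(is_connected(sample_gnp(n, p, rng)) for _ in range(trials)) / trials


n = 200
threshold = math.log(n) / n
# Below the threshold isolated vertices are common; above it they almost never occur.
print(connectivity_rate(n, 0.5 * threshold), connectivity_rate(n, 2.0 * threshold))
```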
Definition 1 (Diameter).
For a connected graph $G = (V, E)$, the diameter is $\operatorname{diam}(G) = \max_{u, v \in V} d(u, v)$, where $d(u, v)$ is the minimum number of edges in a path between $u$ and $v$.
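Definition 1 can be computed directly with breadth-first search from every vertex. The sketch below does this for graphs given as adjacency lists (an assumed input format) and rejects disconnected graphs, for which the diameter as defined here does not apply.

```python
from collections import deque


def bfs_distances(neighbours, source):
    """Minimum number of edges from source to every reachable vertex."""
    dist = {source: 0}
    queue = deque([source])
    while queue:
        u = queue.popleft()
        for v in neighbours[u]:
            if v not in dist:
                dist[v] = dist[u] + 1
                queue.append(v)
    return dist


def diameter(neighbours):
    """diam(G) = max over vertex pairs of the shortest-path length; G must be connected."""
    n = len(neighbours)
    best = 0
    for s in range(n):
        dist = bfs_distances(neighbours, s)
        if len(dist) != n:
            raise ValueError("diameter is only defined here for connected graphs")
        best = max(best, max(dist.values()))
    return best


# The path 0-1-2-3 has diameter 3; adding the edge 0-3 gives a 4-cycle with diameter 2.
path = [[1], [0, 2], [1, 3], [2]]
cycle = [[1, 3], [0, 2], [1, 3], [2, 0]]
print(diameter(path), diameter(cycle))
```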
Theorem 4 (Diameter of random graphs (Frieze and Karoński, 2016)).
Let be a fixed positive integer. For and
Then with probability and with probability .
Appendix D Generalizing Contrastive Learning
On the unsupervised observed data , any task as defined by targets induces an equivalence relation, i.e. partitions into equivalence classes. It divides based on values of the target, where for some is the set of target values. Here the equivalence relation associates datapoints based on the value of the target they predict. For example, if is a set of images of cats and dogs and denotes labels cat and dog, then is partitioned into two equivalence classes corresponding to cat and dog images by .
Intuitively, a refinement is a subdivision of an existing partition. For a visualization of a refinement of a set of tasks see Figure 4. To mathematically define refinements, we first need to introduce what it means for an equivalence relation to be finer than another equivalence relation.
Definition (Fineness). Let and be two equivalence relations on the set . If every equivalence class of is a subset of an equivalence class of , we say that is finer than .
Now we define refinements.
Definition (Refinement). Let be sets of equivalence classes induced by equivalence relations and over the set . If is finer than , then we call a refinement of B.
Furthermore, we can relate the corresponding sets of equivalence classes.
Let and be two equivalence relations on the set and denote the corresponding induced partitions by and . If is finer than , then every equivalence class of is a union of equivalence classes of .
Coming back to the example of cats and dogs, let be the relation that associates cats with cats and dogs with dogs. Now the relation which associates both cats and dogs with their specific breed (e.g. poodles with other poodles) is finer than . Note that partitions into breeds, and so we can easily generate the sets of cats and dogs (i.e. equivalence classes of ) by taking a union over all the corresponding breeds.
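The cats-and-dogs example can be checked mechanically. The sketch below uses hypothetical toy records of the form (image id, breed, species). It builds the partitions induced by the two targets, tests fineness, and confirms that each species class is a union of breed classes.

```python
from collections import defaultdict


def partition(items, target):
    """Equivalence classes induced by a task: items grouped by their target value."""
    classes = defaultdict(set)
    for x in items:
        classes[target(x)].add(x)
    return list(classes.values())


def is_finer(fine, coarse):
    """True if every class of `fine` is a subset of some class of `coarse`."""
    return all(any(a <= b for b in coarse) for a in fine)


def is_union_of(coarse, fine):
    """True if every class of `coarse` equals the union of the `fine` classes it contains."""
    return all(set().union(*[a for a in fine if a <= b]) == b for b in coarse)


# Hypothetical toy records: (image id, breed, species).
images = [
    ("img0", "poodle", "dog"), ("img1", "poodle", "dog"), ("img2", "beagle", "dog"),
    ("img3", "siamese", "cat"), ("img4", "persian", "cat"),
]
species = partition(images, target=lambda x: x[2])  # cats vs dogs
breeds = partition(images, target=lambda x: x[1])   # one class per breed
print(is_finer(breeds, species), is_finer(species, breeds), is_union_of(species, breeds))
```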
D.2 Proof of Theorem 1
Definition (Invariant Representation). Let and be the covariates and target, respectively. We call an invariant representation for under style if
where denotes assigning the value and is the domain of .
Theorem 1. Let be a family of downstream tasks. Let be a refinement for all tasks in . If is an invariant representation for under changes in style , then is an invariant representation for all tasks in under changes in style , i.e.
for all and for all with . Thus, is a representation that generalizes to .
Proof. Let . We have
For the second and last equalities, we used that the mechanism of is independent of , i.e. . The third equality follows from the assumption that is an invariant representation for under changes in . Thus, we get that is an invariant representation for under changes in . Specifically, for a representation to be an invariant representation for , it is a sufficient condition for it to be an invariant representation for . ∎
Appendix E Experimental Details
E.1 Image Augmentations
For pretraining the representations in ReLIC, we apply the augmentation scheme proposed in SimCLR (Chen et al., 2020a) and used in (Grill et al., 2020). It consists of the following augmentations, applied in the order listed:
random crop – we randomly crop the image using an area randomly selected between 8% and 100% of the image with a logarithmically sampled aspect ratio between 3/4 and 4/3. After this, we resize the patch to 224 × 224;
random horizontal flip;
color jittering – we apply in random order perturbations to brightness, contrast, saturation and hue of the image by shifting them by a random uniform offset;
grayscale – we randomly apply grayscaling;
Gaussian blurring – we blur the image using a square Gaussian kernel with standard deviation uniformly sampled in [0.1, 2.0];
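solarization – we transform all the pixels with $x \mapsto x \cdot \mathbb{1}_{\{x < 0.5\}} + (1 - x) \cdot \mathbb{1}_{\{x \ge 0.5\}}$, for pixel values $x \in [0, 1]$.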
We use the same parameters for the augmentations and probabilities of applying individual augmentations as SimCLR (Chen et al., 2020a). After applying augmentations, we normalize the images with the mean and standard deviation computed on ImageNet across the color channels.
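For concreteness, a NumPy sketch of three of these steps follows: sampling the crop box, solarization, and grayscale conversion. It assumes the usual SimCLR/BYOL crop ranges (area fraction in [0.08, 1], log-uniform aspect ratio in [3/4, 4/3]), pixel values in [0, 1], and ITU-R BT.601 luma weights for grayscale. The fallback to the full image is our own simplification; this is an illustrative sketch, not a reference implementation.

```python
import math

import numpy as np


def sample_crop(height, width, rng, scale=(0.08, 1.0), ratio=(3 / 4, 4 / 3), attempts=10):
    """Sample (top, left, h, w) of a random crop with uniform area and log-uniform aspect."""
    area = height * width
    log_lo, log_hi = math.log(ratio[0]), math.log(ratio[1])
    for _ in range(attempts):
        target_area = area * rng.uniform(*scale)          # uniform area fraction
        aspect = math.exp(rng.uniform(log_lo, log_hi))    # log-uniform aspect ratio
        w = int(round(math.sqrt(target_area * aspect)))
        h = int(round(math.sqrt(target_area / aspect)))
        if 0 < w <= width and 0 < h <= height:
            top = int(rng.integers(0, height - h + 1))
            left = int(rng.integers(0, width - w + 1))
            return top, left, h, w
    # Fallback (simplification): use the full image if no valid crop was found.
    return 0, 0, height, width


def solarize(img):
    """Map each pixel x in [0, 1] to x if x < 0.5, else to 1 - x."""
    return np.where(img < 0.5, img, 1.0 - img)


def to_grayscale(img):
    """BT.601 luma of an (H, W, 3) image, replicated over the 3 channels."""
    gray = img @ np.array([0.299, 0.587, 0.114])
    return np.repeat(gray[..., None], 3, axis=-1)


rng = np.random.default_rng(0)
print(sample_crop(224, 224, rng), solarize(np.array([0.2, 0.5, 0.9])))
```

The crop box would then be resized to 224 × 224; resizing, colour jitter and blur are omitted to keep the sketch dependency-free.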
We test ReLIC on two different architectures: ResNet-50 (He et al., 2016) and ResNet-50 with a target network as in (Grill et al., 2020). For ResNet-50, we use version 1 with post-activation. We take the representation to be the output of the final average pooling layer, which has dimension 2048. As in SimCLR (Chen et al., 2020a), we use a critic network, a multi-layer perceptron (MLP), to project the representation to a lower-dimensional space. When using ResNet-50 as the encoder, we treat the parameters of the MLP (e.g. depth and width) as hyperparameters and sweep over them. This MLP has batch normalization (Ioffe and Szegedy, 2015) after every layer and rectified linear activations (ReLU) (Nair and Hinton, 2010). We used a 4 layer MLP with widths and output size with ResNet-50. When using a ResNet-50 with a target network as in (Grill et al., 2020), we exactly follow their architecture settings.
We use a batch size of 4096 and the LARS optimizer (You et al., 2017) with a cosine decay learning rate schedule (Loshchilov and Hutter, 2017) for epochs, with epochs of warm-up. We exclude the biases and batch normalization parameters from LARS adaptation. We use as the base learning rate for ResNet-50 and for ResNet-50 with a target network. We scale this learning rate by the batch size and use a global weight decay parameter of , again excluding the biases and batch normalization parameters from weight decay. For the target network, we follow the approach of BYOL (Grill et al., 2020): we start the exponential moving average parameter $\tau$ at a base value $\tau_{\text{base}}$ and increase it to one during training via $\tau = 1 - (1 - \tau_{\text{base}}) \cdot (\cos(\pi k / K) + 1) / 2$, with $k$ the current training step and $K$ the maximum number of training steps.
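The two schedules described above can be written down in a few lines. In the sketch below, the EMA coefficient follows the BYOL form stated in the text, while the linear warm-up, the reference batch size of 256 used for learning-rate scaling, and all argument names are assumptions for illustration; the actual ReLIC values are not reproduced.

```python
import math


def ema_tau(k, K, tau_base):
    """BYOL-style target EMA coefficient: tau_base at k = 0, rising to 1 at k = K."""
    return 1.0 - (1.0 - tau_base) * (math.cos(math.pi * k / K) + 1.0) / 2.0


def learning_rate(step, total_steps, warmup_steps, base_lr, batch_size):
    """Linear warm-up, then cosine decay to zero; peak LR = base_lr * batch_size / 256 (assumed)."""
    peak = base_lr * batch_size / 256
    if step < warmup_steps:
        return peak * step / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return peak * 0.5 * (1.0 + math.cos(math.pi * progress))


# Hypothetical values, only to show the shape of both schedules.
for step in (0, 10, 55, 100):
    print(step, learning_rate(step, 100, 10, 0.3, 4096), ema_tau(step, 100, 0.99))
```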
E.4 Evaluation on ImageNet
We follow the standard linear evaluation protocol on ImageNet as in (Kolesnikov et al., 2019; Chen et al., 2020a; Grill et al., 2020). We train a linear classifier on top of the fixed representation, i.e. we do not update the network parameters or the batch statistics. For training, we randomly crop and resize images to 224 × 224, and then randomly flip them horizontally. For testing, the images are resized to 256 pixels along the shorter dimension with bicubic resampling, after which we take a center crop of size 224 × 224. For both training and testing, the images are normalized after the augmentations by subtracting the mean and dividing by the standard deviation across the color channels, both computed on ImageNet. We use Stochastic Gradient Descent with a Nesterov momentum of and train for epochs with a batch size of . We do not use any regularization techniques, e.g. weight decay.
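The test-time preprocessing reduces to simple arithmetic on image sizes. The helpers below are a sketch assuming the 256-pixel shorter-side resize and 224 × 224 center crop, with Python's built-in `round` as the (assumed) rounding rule. For example, a 375 × 500 image is resized to 256 × 341 and cropped to rows 16 to 240 and columns 58 to 282.

```python
def resize_shorter_side(height, width, target=256):
    """New (height, width) after resizing so that the shorter side equals `target`."""
    scale = target / min(height, width)
    return round(height * scale), round(width * scale)


def center_crop_box(height, width, size=224):
    """(top, left, bottom, right) of a centered size x size crop."""
    top = (height - size) // 2
    left = (width - size) // 2
    return top, left, top + size, left + size


h, w = resize_shorter_side(375, 500)
print((h, w), center_crop_box(h, w))
```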
E.5 Robustness and Generalization
ImageNet-C. The ImageNet-C dataset (Hendrycks and Dietterich, 2019) consists of 15 different types of corruptions from the noise, blur, weather, and digital categories applied to the validation images of ImageNet. This dataset is used for measuring semantic robustness. Figure 5 visualizes the corruption types. Each type of corruption has 5 levels of severity, i.e. there are 75 distinct corruptions in the dataset. In Figure 6, we display the Impulse noise corruption for different severity levels. As can be seen, with increasing severity level the image becomes increasingly corrupted and difficult to parse. In addition to these corruption types, there are 4 additional corruption types (speckle noise, Gaussian blur, spatter and saturate) that are provided as a validation set. We use these additional corruption types for selecting the best hyperparameters. For further details on this dataset, please refer to (Hendrycks and Dietterich, 2019).
ImageNet-R. The ImageNet-R dataset (Hendrycks et al., 2020) consists of images depicting various artistic renditions (e.g., paintings, sculpture, origami, cartoon) of ImageNet object classes. This dataset is used to measure out-of-distribution generalization to various abstract visual renditions, as it emphasizes shape over texture. The data was collected primarily from Flickr and also includes line drawings from (Wang et al., 2019). The images represent naturally occurring objects and have different textures and local image statistics to those of ImageNet. Figure 7 visualizes different images from the dataset. For further details on this dataset, please refer to (Hendrycks et al., 2020).
To evaluate robustness and generalization of the learned representation, we follow the standard linear evaluation protocol on ImageNet as in (Chen et al., 2020b, a; Kolesnikov et al., 2019). We train a linear classifier on top of the frozen representation, i.e. we update neither the network parameters nor the batch statistics. During training, we augment the data by randomly cropping, resizing to 224 × 224 and randomly flipping the image. At test time, images are resized to 256 pixels along the shorter side via bicubic resampling and we take a 224 × 224 center crop. During both training and testing, after applying augmentations we normalize the color channels by subtracting the average color and dividing by the standard deviation, both computed on ImageNet. We optimize the cross-entropy loss using Stochastic Gradient Descent with Nesterov momentum of . We sweep over the number of epochs, learning rates and batch sizes. We select hyperparameters on the validation set provided in ImageNet-C and report the performance on ImageNet-R and on the test set of ImageNet-C under the best validation hyperparameters. We do not use any regularization techniques such as weight decay, gradient clipping, clipping or logits regularization.
Robustness metrics and further results
Let $f$ be a classifier that has not been trained on ImageNet-C. For each corruption type $c$ and level of severity $1 \le s \le 5$, denote the top-1 error of this classifier as $E_{s,c}^f$. Different corruption types pose different levels of difficulty. To make error rates across corruption types more comparable, the error rates are divided by AlexNet's errors. This standardized measure is the Corruption Error and is computed as
$$\mathrm{CE}_c^f = \left(\sum_{s=1}^{5} E_{s,c}^f\right) \Big/ \left(\sum_{s=1}^{5} E_{s,c}^{\mathrm{AlexNet}}\right).$$
The average Corruption Error across all corruption types is called the mean Corruption Error (mCE). Corruption Errors and mCE measure absolute robustness.
To better assess robustness, we also report the relative Corruption Error, which measures relative robustness, i.e. the loss in performance under corruptions. Denote by $E_{\mathrm{clean}}^f$ the top-1 error rate of $f$ on the clean test set of ImageNet. The relative Corruption Error is given as
$$\mathrm{Relative\ CE}_c^f = \left(\sum_{s=1}^{5} \left(E_{s,c}^f - E_{\mathrm{clean}}^f\right)\right) \Big/ \left(\sum_{s=1}^{5} \left(E_{s,c}^{\mathrm{AlexNet}} - E_{\mathrm{clean}}^{\mathrm{AlexNet}}\right)\right).$$
The mean relative Corruption Error (mrCE) is the mean of the relative Corruption Errors across all the corruption types. For more details and intuitions about these measures, please refer to (Hendrycks and Dietterich, 2019).
In Table 6, we report Corruption Errors for Blur, Weather, and Digital corruption types. In Table 7, we report the relative robustness. As per (Hendrycks and Dietterich, 2019), we used the following values as the average AlexNet errors across severities, i.e. , to normalize the Corruption Error values – Gaussian Noise 88.6%, Shot Noise 89.4%, Impulse Noise 92.3%, Defocus Blur 82.0%, Glass Blur 82.6%, Motion Blur 78.6%, Zoom Blur 79.8%, Snow 86.7%, Frost 82.7%, Fog 81.9%, Brightness 56.5%, Contrast 85.3%, Elastic Transformation 64.6%, Pixelate 71.8%, JPEG 60.7%, Speckle Noise 84.5%, Gaussian Blur 78.7%, Spatter 71.8%, Saturate 65.8%.
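These metrics can be computed from a table of per-severity errors. The sketch below assumes errors in percent, five severities per corruption type, and AlexNet references given as means over severities (as in the list above), in which case the ratio of sums equals the ratio of means. The AlexNet clean error is left as an argument, and only three of the listed AlexNet values are copied in.

```python
import numpy as np

# Subset of the average AlexNet top-1 errors (%) across severities listed above.
ALEXNET_MEAN_ERR = {"gaussian_noise": 88.6, "shot_noise": 89.4, "impulse_noise": 92.3}


def corruption_error(errs_by_severity, alexnet_mean_err):
    """CE_c = sum_s E_{s,c}^f / sum_s E_{s,c}^AlexNet, given AlexNet's mean over severities."""
    # With 5 severities on both sides, the ratio of sums equals the ratio of means.
    return np.mean(errs_by_severity) / alexnet_mean_err


def relative_corruption_error(errs_by_severity, clean_err, alexnet_mean_err, alexnet_clean_err):
    """Relative CE_c: increase over the clean error, normalized by AlexNet's increase."""
    return (np.mean(errs_by_severity) - clean_err) / (alexnet_mean_err - alexnet_clean_err)


def mean_corruption_error(model_errs, alexnet_mean_errs):
    """mCE in percent: average CE over the corruption types present in model_errs."""
    return 100.0 * np.mean(
        [corruption_error(model_errs[c], alexnet_mean_errs[c]) for c in model_errs]
    )
```

A model that matches AlexNet on every corruption type gets an mCE of exactly 100, which is the reference point of the normalization.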
|ResNet-50 with target network:|
|ResNet-50 with target network:|
E.6 Evaluation on Atari
For our experiments on Atari, we use the agent from R2D2 (Kapturowski et al., 2019) with standard hyperparameters noted below. We train each agent on approximately 15 billion frames and add a second encoder with the same architecture as the one used in the Q-network of the original agent. This second encoder is trained with a separate optimizer using only a representation learning objective. The agent then takes the output of this encoder as an input. We use standard augmentations from prior work (Kostrikov et al., 2020): we pad the frames on all sides with 4 pixels copied from the borders and then randomly crop 84 × 84 windows. We randomly shift pixel intensity according to the distribution , where is the standard Normal distribution with values clipped between -2 and 2. The sampled value is then multiplied by the original image to return the augmented image.
ReLIC and SimCLR. For our implementation of ReLIC and SimCLR, we do not use a critic embedding at all and instead use the last layer of the encoder for the objective. As in CURL (Srinivas et al., 2020), we use a target encoder for the second augmentation and update its weights with a momentum of . We also clip the gradients of our optimizer by a global norm of 40. We report the hyperparameters in the table below.
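A NumPy sketch of these two augmentations follows. The replicate padding and 84 × 84 crop follow the description above; the form of the intensity factor, $1 + \text{scale} \cdot \text{clip}(\epsilon, -2, 2)$ with $\epsilon \sim \mathcal{N}(0, 1)$, and the value `scale=0.05` are assumptions, since the exact distribution is not reproduced in this text.

```python
import numpy as np


def pad_and_random_crop(frame, rng, pad=4, size=84):
    """Replicate-pad `pad` pixels on every side, then take a random size x size crop.

    frame: (H, W) or (H, W, C) array, e.g. an 84 x 84 Atari observation.
    """
    pad_width = [(pad, pad), (pad, pad)] + [(0, 0)] * (frame.ndim - 2)
    padded = np.pad(frame, pad_width, mode="edge")   # copy border pixels outward
    top = rng.integers(0, padded.shape[0] - size + 1)
    left = rng.integers(0, padded.shape[1] - size + 1)
    return padded[top:top + size, left:left + size]


def random_intensity(frame, rng, scale=0.05):
    """Multiply by 1 + scale * clip(N(0, 1), -2, 2); the form and `scale` are hypothetical."""
    noise = np.clip(rng.standard_normal(), -2.0, 2.0)
    return frame * (1.0 + scale * noise)


rng = np.random.default_rng(0)
frame = rng.random((84, 84))
augmented = random_intensity(pad_and_random_crop(frame, rng), rng)
print(augmented.shape)
```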
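The two mechanisms mentioned above, the momentum-updated target encoder and gradient clipping by global norm, are sketched below for parameters stored as a dictionary of NumPy arrays (a hypothetical layout). The momentum value is left as an argument because it is not given here; the clipping threshold defaults to the stated value of 40.

```python
import numpy as np


def momentum_update(target_params, online_params, momentum):
    """Target-encoder update: theta_target <- m * theta_target + (1 - m) * theta_online."""
    return {
        k: momentum * target_params[k] + (1.0 - momentum) * online_params[k]
        for k in target_params
    }


def clip_by_global_norm(grads, max_norm=40.0):
    """Rescale all gradients jointly so that their global L2 norm is at most max_norm."""
    global_norm = float(np.sqrt(sum(np.sum(g ** 2) for g in grads.values())))
    factor = min(1.0, max_norm / (global_norm + 1e-12))
    return {k: g * factor for k, g in grads.items()}, global_norm


clipped, norm = clip_by_global_norm({"w": np.array([60.0, 80.0])})
print(norm, clipped["w"])
```

Clipping by the global norm rescales all gradients by the same factor, so the update direction is preserved while its length is capped.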
|Scaling of Embeddings||False|
CURL. For CURL, we use a second encoder as noted before. With the exception of the encoder architecture and the optimizer parameters, all hyperparameters are the same as in (Srinivas et al., 2020), including the momentum value for the target network weight updates. We use the same architecture as in the original paper, with a linear layer as a critic embedding for the target encoder.
BYOL. In BYOL, we use two-layer perceptron networks as our predictor and projection networks. For both networks, the two layers have 1024 and 512 hidden units, respectively. We use a target network update momentum of 0.99. The optimizer parameters are the same as in the ReLIC hyperparameter table.
Direct Augmentation. We also compared against direct augmentation of the observations in the replay buffer, as in DrQ (Kostrikov et al., 2020). We keep the architecture the same in this instance and use two duplicate encoders as input to the agent. In this case, the optimizer can jointly update both encoders and train them end-to-end.
|Games||Average Human||Random||ReLIC (ours)||SimCLR||CURL||BYOL||Augmentation|
|kung fu master||22736.30||258.50||230241.57||220076.57||228943.94||208064.38||64632.42|
|name this game||8049.00||2292.30||48669.30||46657.55||47417.82||44848.29||13416.57|
|up n down||11693.20||533.40||577256.03||520666.59||566912.89||552110.67||143512.38|
|wizard of wor||4756.50||563.50||123513.74||89462.62||106801.20||68256.44||5940.82|
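Scores of this kind are often summarized as human-normalized scores, $(\text{agent} - \text{random}) / (\text{human} - \text{random})$, so that 0 corresponds to random play and 1 to the average human. The snippet below applies this standard formula to the ReLIC column of the rows shown above; it is a convenience calculation, not an additional result.

```python
# Scores copied from the table above: (average human, random, ReLIC).
SCORES = {
    "kung fu master": (22736.30, 258.50, 230241.57),
    "name this game": (8049.00, 2292.30, 48669.30),
    "up n down": (11693.20, 533.40, 577256.03),
    "wizard of wor": (4756.50, 563.50, 123513.74),
}


def human_normalized(agent, human, rand_score):
    """Human-normalized score: 0 = random play, 1 = average human."""
    return (agent - rand_score) / (human - rand_score)


for game, (human, rand_score, relic) in SCORES.items():
    print(f"{game}: {human_normalized(relic, human, rand_score):.2f}")
```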
- See (Peters et al., 2017) for a review of causal graphs and causality.
- Since neither content nor style are a priori known, choosing a set of augmentations implicitly defines which aspects of the data are considered style and which are content.
- Note that since refinements are more fine-grained than the original task, if a representation captures a refinement then it also captures the downstream tasks, as strictly more information is needed to solve the refinement.
- The set of augmentations includes Gaussian blurring, various colour distortions, flips and random cropping.
- Learning representations by maximizing mutual information across views. In Advances in Neural Information Processing Systems, pp. 15509–15519.
- The arcade learning environment: an evaluation platform for general agents (extended abstract). J. Artif. Intell. Res. 47, pp. 253–279.
- Unsupervised learning of visual features by contrasting cluster assignments. ArXiv abs/2006.09882.
- Visual causal feature learning. arXiv preprint arXiv:1412.2309.
- A simple framework for contrastive learning of visual representations. arXiv preprint arXiv:2002.05709.
- Improved baselines with momentum contrastive learning. ArXiv abs/2003.04297.
- On the evolution of random graphs. Publ. Math. Inst. Hung. Acad. Sci 5 (1), pp. 17–60.
- The elements of statistical learning. Vol. 2, Springer Series in Statistics, New York.
- Introduction to random graphs. Cambridge University Press.
- Domain adaptation with conditional transferable components. In International Conference on Machine Learning, pp. 2839–2848.
- Bootstrap your own latent: a new approach to self-supervised learning. arXiv preprint arXiv:2006.07733.
- Noise-contrastive estimation: a new estimation principle for unnormalized statistical models. In Proceedings of the Thirteenth International Conference on Artificial Intelligence and Statistics, pp. 297–304.
- Dimensionality reduction by learning an invariant mapping. In 2006 IEEE Computer Society Conference on Computer Vision and Pattern Recognition (CVPR'06), Vol. 2, pp. 1735–1742.
- Momentum contrast for unsupervised visual representation learning. arXiv preprint arXiv:1911.05722.
- Deep residual learning for image recognition. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 770–778.
- Conditional variance penalties and domain shift robustness. arXiv preprint arXiv:1710.11469.
- Data-efficient image recognition with contrastive predictive coding. arXiv preprint arXiv:1905.09272.
- The many faces of robustness: a critical analysis of out-of-distribution generalization. arXiv preprint arXiv:2006.16241.
- Benchmarking neural network robustness to common corruptions and perturbations. In Proceedings of the International Conference on Learning Representations.
- Learning deep representations by mutual information estimation and maximization. arXiv preprint arXiv:1808.06670.
- Batch normalization: accelerating deep network training by reducing internal covariate shift. ArXiv abs/1502.03167.
- Recurrent experience replay in distributed reinforcement learning. In ICLR.
- Revisiting self-supervised visual representation learning. In 2019 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), pp. 1920–1929.
- Image augmentation is all you need: regularizing deep reinforcement learning from pixels. arXiv preprint arXiv:2004.13649.
- SGDR: stochastic gradient descent with warm restarts. In ICLR.
- Self-supervised learning of pretext-invariant representations. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 6707–6717.
- Rectified linear units improve restricted Boltzmann machines. In ICML.
- Representation learning with contrastive predictive coding. arXiv preprint arXiv:1807.03748.
- Causality. Cambridge University Press.
- Causal inference by using invariant prediction: identification and confidence intervals. Journal of the Royal Statistical Society: Series B (Statistical Methodology) 78 (5), pp. 947–1012.
- Elements of causal inference: foundations and learning algorithms. MIT Press.
- On variational bounds of mutual information. In International Conference on Machine Learning, pp. 5171–5180.
- ImageNet large scale visual recognition challenge. International Journal of Computer Vision 115, pp. 211–252.
- A theoretical analysis of contrastive unsupervised representation learning. In International Conference on Machine Learning, pp. 5628–5637.
- CURL: contrastive unsupervised representations for reinforcement learning. arXiv preprint arXiv:2004.04136.
- Contrastive multiview coding. arXiv preprint arXiv:1906.05849.
- What makes for good views for contrastive learning. ArXiv abs/2005.10243.
- On mutual information maximization for representation learning. arXiv preprint arXiv:1907.13625.
- Learning robust global representations by penalizing local predictive power. ArXiv abs/1905.13549.
- Large batch training of convolutional networks. arXiv preprint arXiv:1708.03888.