Layer-Peeled Model:
Toward Understanding Well-Trained
Deep Neural Networks
Abstract
In this paper, we introduce the Layer-Peeled Model, a nonconvex yet analytically tractable optimization program, in a quest to better understand deep neural networks that are trained for a sufficiently long time. As the name suggests, this new model is derived by isolating the topmost layer from the remainder of the neural network, followed by imposing certain constraints separately on the two parts. We demonstrate that the Layer-Peeled Model, albeit simple, inherits many characteristics of well-trained neural networks, thereby offering an effective tool for explaining and predicting common empirical patterns of deep learning training. First, when working on class-balanced datasets, we prove that any solution to this model forms a simplex equiangular tight frame, which in part explains the recently discovered phenomenon of neural collapse in deep learning training [PHD20]. Moreover, when moving to the imbalanced case, our analysis of the Layer-Peeled Model reveals a hitherto unknown phenomenon that we term Minority Collapse, which fundamentally limits the performance of deep learning models on the minority classes. In addition, we use the Layer-Peeled Model to gain insights into how to mitigate Minority Collapse. Interestingly, this phenomenon is first predicted by the Layer-Peeled Model before its confirmation by our computational experiments.
University of Pennsylvania
January 26, 2021
Contents
1 Introduction
In the past decade, deep learning has achieved remarkable performance across a range of scientific and engineering domains [KSH17, LBH15, SHM+16]. Interestingly, these impressive accomplishments were mostly achieved by empirical intuition and various maneuvers, though often plausible, without much principled guidance from a theoretical perspective. On the flip side, however, this reality also suggests the great potential a theory could have for advancing the development of deep learning methodologies in the coming decade.
Unfortunately, it is not easy to develop a theoretical foundation for deep learning. Perhaps the most difficult hurdle lies in the nonconvexity of the optimization problem for training neural networks, which, loosely speaking, stems from the interaction between the different layers of neural networks. To be more precise, consider a neural network for K-class classification as a function, which in its simplest form reads
\bm{f}(\bm{x};\bm{W}_{\textnormal{full}})=\bm{W}_{L}\sigma(\bm{W}_{L-1}\sigma(\cdots\sigma(\bm{W}_{1}\bm{x}))).
Here, \bm{W}_{\textnormal{full}}:=\{\bm{W}_{1},\bm{W}_{2},\ldots,\bm{W}_{L}\} denotes the partition of the weights in matrix form according to layers, and \sigma(\cdot) is a nonlinear activation function such as the ReLU (the function only outputs the logits in \mathbb{R}^{K}; we omit the softmax step). The last-layer weights \bm{W}_{L} consist of K vectors that correspond to the K classes. For simplicity, we omit the bias terms and other operations such as max-pooling. Owing to the complex and nonlinear interaction between the L layers, when applying stochastic gradient descent to the optimization problem
\min_{\bm{W}_{\textnormal{full}}}~\frac{1}{N}\sum_{k=1}^{K}\sum_{i=1}^{n_{k}}\mathcal{L}(\bm{f}(\bm{x}_{k,i};\bm{W}_{\textnormal{full}}),\bm{y}_{k})+\frac{\lambda}{2}\|\bm{W}_{\textnormal{full}}\|^{2}  (1)
with a loss function \mathcal{L} for training the neural network, it becomes very difficult to pinpoint how a given layer influences the output \bm{f}. Above, \{\bm{x}_{k,i}\}_{i=1}^{n_{k}} denotes the training examples in the k-th class with label \bm{y}_{k} (often encoded as a K-dimensional one-hot vector with 1 in the k-th entry), N=n_{1}+\cdots+n_{K} is the total number of training examples, \lambda>0 is the weight decay parameter, and \|\cdot\| throughout the paper denotes the \ell_{2} norm. Worse, this difficulty in analyzing deep learning models is compounded by an ever-growing number of layers.
Therefore, an attempt to develop a tractable and comprehensive theory for demystifying deep learning would presumably first need to simplify the interaction between the large number of layers. Following this intuition, in this paper we introduce the following optimization program as a surrogate for Program (1), aimed at unveiling quantitative patterns of deep neural networks:
\displaystyle\min_{\bm{W}_{L},\bm{H}}\quad\frac{1}{N}\sum_{k=1}^{K}\sum_{i=1}^{n_{k}}\mathcal{L}(\bm{W}_{L}\bm{h}_{k,i},\bm{y}_{k})  (2)
\textnormal{subject to}\quad\frac{1}{K}\sum_{k=1}^{K}\left\|\bm{w}_{k}\right\|^{2}\leq E_{W},
\quad\frac{1}{K}\sum_{k=1}^{K}\frac{1}{n_{k}}\sum_{i=1}^{n_{k}}\left\|\bm{h}_{k,i}\right\|^{2}\leq E_{H},
where the decision variable \bm{W}_{L}=\left[\bm{w}_{1},\ldots,\bm{w}_{K}\right]^{\top}\in\mathbb{R}^{K\times p} comprises, as in Program (1), the K linear classifiers in the last layer, \bm{H}=[\bm{h}_{k,i}:1\leq k\leq K,1\leq i\leq n_{k}]\in\mathbb{R}^{p\times N} collects the p-dimensional last-layer activations/features of all N training examples (strictly speaking, \bm{H} models the activations from the (L-1)-th layer; note that the vector \bm{w}_{k} is also p-dimensional), and E_{H} and E_{W} are two positive scalars. Although still nonconvex, this new optimization program is presumably much more amenable to analysis than the old one (1), as the interaction is now between only two variables.
In relating Program (2) to the optimization problem (1), a first simple observation is that \bm{f}(\bm{x}_{k,i};\bm{W}_{\textnormal{full}})=\bm{W}_{L}\sigma(\bm{W}_{L-1}\sigma(\cdots\sigma(\bm{W}_{1}\bm{x}_{k,i}))) in (1) is replaced by \bm{W}_{L}\bm{h}_{k,i} in (2). Put differently, the black-box nature of the last-layer features, namely \sigma(\bm{W}_{L-1}\sigma(\cdots\sigma(\bm{W}_{1}\bm{x}_{k,i}))), is now modeled by a simple decision variable \bm{h}_{k,i} with a constraint on its \ell_{2} norm. Intuitively speaking, this simplification is achieved by peeling off the topmost layer from the neural network. Thus, we call the optimization program (2) the 1-Layer-Peeled Model, or simply the Layer-Peeled Model.
At a high level, the Layer-Peeled Model takes a top-down approach to the analysis of deep neural networks. As illustrated in Figure 1, the essence of the modeling strategy is to break down the neural network from top to bottom, specifically singling out the topmost layer and modeling all bottom layers collectively as a single variable. In fact, the top-down perspective that we took in developing the Layer-Peeled Model was inspired by a recent breakthrough made by Papyan, Han, and Donoho [PHD20], who discovered a mathematically elegant and pervasive phenomenon, termed neural collapse, through massive deep learning experiments on datasets with balanced classes. Roughly speaking, neural collapse refers to the emergence of certain geometric patterns of the last-layer features \sigma(\bm{W}_{L-1}\sigma(\cdots\sigma(\bm{W}_{1}\bm{x}_{k,i}))) and the last-layer classifiers \bm{W}_{L} when the neural network is well-trained, in the sense that it attains not only zero misclassification error but also negligible cross-entropy loss (in general, a global minimizer of Program (1) does not yield exactly zero cross-entropy loss due to the penalty term). This top-down approach was also taken in [WL90, SHN+18, OS20, YCY+20, SHA20] to investigate various aspects of deep learning models.
1.1 Two Applications
Despite its plausibility, the ultimate test of the Layer-Peeled Model lies in its ability to faithfully approximate deep learning models by explaining empirical observations and even predicting new phenomena. In what follows, we provide convincing evidence that the Layer-Peeled Model is up to this task by presenting two findings. To be concrete, we remark that the results below concern well-trained deep learning models, which correspond, in rough terms, to (near-)optimal solutions of Program (1).
Balanced Data.
When the dataset has the same number of training examples in each class, [PHD20] experimentally observed that neural collapse emerges in well-trained deep learning models (1) with the cross-entropy loss: the last-layer features from the same class tend to be very close to their class mean; the K class means, centered at the global mean, have the same length and form the largest possible equal-sized angles between any pair; moreover, the last-layer classifiers become dual to the class means in the sense that they are equal to each other, class by class, up to a scaling factor.
While it seems hopeless to rigorously prove neural collapse for multiple-layer neural networks (1) at the moment, we alternatively seek to show that this phenomenon emerges in the surrogate model (2). More precisely, when the size of each class satisfies n_{k}=n for all k, is it true that any global minimizer \bm{W}_{L}^{\star}=\left[\bm{w}_{1}^{\star},\ldots,\bm{w}_{K}^{\star}\right]^{\top},\bm{H}^{\star}=[\bm{h}_{k,i}^{\star}:1\leq k\leq K,1\leq i\leq n] of Program (2) exhibits neural collapse (see its formal definition in Section 1.2 and Theorem 3)? The following result answers this question in the affirmative:
Finding 1.
Neural collapse occurs in the Layer-Peeled Model.
A formal statement of this result and a detailed discussion are given in Section 3.
This result applies to a family of loss functions \mathcal{L}, including in particular the cross-entropy loss and the contrastive loss (see, e.g., [CKN+20]). As an immediate implication, this result provides evidence of the Layer-Peeled Model’s ability to characterize well-trained deep learning models.
Imbalanced Data.
While a surrogate model would be satisfactory if it explains already observed phenomena, we set a high standard for the model, asking whether it can predict a new common empirical pattern. Encouragingly, the Layer-Peeled Model happens to meet this standard. Specifically, we consider training deep learning models on imbalanced datasets, where some classes contain many more training examples than others. Despite the pervasiveness of imbalanced classification in many practical applications [JK19], the literature remains scarce on its impact on the trained neural networks from a theoretical standpoint. Here we provide mathematical insights into this problem by using the Layer-Peeled Model. In the following result, we consider optimal solutions to the Layer-Peeled Model on a dataset with two different class sizes: the first K_{A} majority classes each contain n_{A} training examples (n_{1}=n_{2}=\dots=n_{K_{A}}=n_{A}), and the remaining K_{B}:=K-K_{A} minority classes each contain n_{B} examples (n_{K_{A}+1}=n_{K_{A}+2}=\dots=n_{K}=n_{B}). We call R:=n_{A}/n_{B}>1 the imbalance ratio.
Finding 2.
In the Layer-Peeled Model, the last-layer classifiers corresponding to the minority classes, namely \bm{w}^{\star}_{K_{A}+1},\bm{w}^{\star}_{K_{A}+2},\ldots,\bm{w}^{\star}_{K}, collapse to a single vector when R is sufficiently large.
This result is elaborated on in Section 4. The derivation involves some novel elements to tackle the nonconvexity of the Layer-Peeled Model (2) and the asymmetry due to the imbalance in class sizes.
In slightly more detail, we identify a phase transition as the imbalance ratio R increases: when R is below a threshold, the minority classes are distinguishable in terms of their classifiers; when R is above the threshold, they become indistinguishable. While this phenomenon is merely predicted by the simple Layer-Peeled Model (2), it indeed appears in our computational experiments on deep neural networks. More surprisingly, our prediction of the phase transition point is in excellent agreement with the experiments, as shown in Figure 2.
This phenomenon, which we refer to as Minority Collapse, reveals the fundamental difficulty of using deep learning for classification when the dataset is heavily imbalanced, even in terms of optimization, not to mention generalization. This is not a priori evident, given that neural networks have a large approximation capacity (see, e.g., [YAR17]). Importantly, Minority Collapse emerges at a finite value of the imbalance ratio rather than at infinity. Moreover, even below the phase transition point of this ratio, we find that the angles between any pair of the minority classifiers are already smaller than those between the majority classifiers, both theoretically and empirically.
1.2 Related Work
There is a venerable line of work attempting to gain insights into deep learning from a theoretical point of view [JGH18, DLL+19, ALS19, ZCZ+18, COB19, EMW19, BFT17, HS20, PBL20, MMN18, SS19, RV18, FLY+20, KWL+19, SSJ20]. See also the reviews [FDZ21, HT20, FMZ19, SUN19] and references therein.
Within this body of work, the study of neural collapse by [PHD20] is particularly notable for its mathematically elegant and convincing insights. In brief, [PHD20] observed the following four properties of the last-layer features and classifiers in deep learning training (see the mathematical description of neural collapse in Theorem 3):

(NC1) Variability collapse: the within-class variation of the last-layer features becomes 0, which means that these features collapse to their class means.

(NC2) The class means centered at their global mean collapse to the vertices of a simplex equiangular tight frame (ETF) up to scaling.

(NC3) Up to scaling, the last-layer classifiers each collapse to the corresponding class mean.

(NC4) The classifier’s decision collapses to simply choosing the class whose mean is closest in Euclidean distance to the last-layer activations of the test example.
Definition 1.
A K-simplex ETF is a collection of points in \mathbb{R}^{p} specified by the columns of the matrix
{\bm{M}^{\star}}=\sqrt{\frac{K}{K-1}}\bm{P}\left(\bm{I}_{K}-\frac{1}{K}\bm{1}_{K}\bm{1}_{K}^{\top}\right),
where \bm{I}_{K}\in\mathbb{R}^{K\times K} is the identity matrix, \bm{1}_{K} is the all-ones vector, and \bm{P}\in\mathbb{R}^{p\times K} (p\geq K) is a partial orthogonal matrix such that \bm{P}^{\top}\bm{P}=\bm{I}_{K}. (To be complete, we only require p\geq K-1; when p=K-1, we can choose \bm{P} such that \left[\bm{P}^{\top},\bm{1}_{K}\right] is an orthogonal matrix.)
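Definition 1 is easy to instantiate numerically. The following sketch (ours, not from the paper) constructs a K-simplex ETF with numpy and checks its two defining properties: equal vertex lengths and maximal equal pairwise angles, with cosine -1/(K-1).

```python
import numpy as np

def simplex_etf(K, p, seed=0):
    """Columns give the K vertices of a K-simplex ETF in R^p (Definition 1); assumes p >= K."""
    rng = np.random.default_rng(seed)
    # Partial orthogonal P (P^T P = I_K) from the reduced QR of a random matrix.
    P, _ = np.linalg.qr(rng.standard_normal((p, K)))
    return np.sqrt(K / (K - 1)) * P @ (np.eye(K) - np.ones((K, K)) / K)

K, p = 5, 12
M = simplex_etf(K, p)
G = M.T @ M                                   # Gram matrix of the vertices
norms = np.linalg.norm(M, axis=0)
off_diag = G[~np.eye(K, dtype=bool)]

assert np.allclose(norms, 1.0)                # equal lengths
assert np.allclose(off_diag, -1.0 / (K - 1))  # pairwise cosine = -1/(K-1)
```

Since the centering matrix \bm{I}_{K}-\frac{1}{K}\bm{1}_{K}\bm{1}_{K}^{\top} is idempotent, the Gram matrix works out to \frac{K}{K-1}(\bm{I}_{K}-\frac{1}{K}\bm{1}_{K}\bm{1}_{K}^{\top}), which is exactly what the assertions verify.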
These four properties emerge in massive experiments on popular network architectures during the terminal phase of training—when the trained model interpolates the in-sample training data—and a shared setting of these experiments is the use of balanced datasets and the cross-entropy loss with \ell_{2} regularization. Using convincing arguments and numerical evidence, [PHD20] demonstrated that the symmetry and stability of neural collapse improve deep learning training in terms of generalization, robustness, and interpretability. Notably, these improvements occur alongside the benign overfitting phenomenon in deep neural networks [MBB18, BHM+19, LR20, BLL+20, LSS20]. As an aside, while we were preparing the manuscript, we became aware of [MPP20, EW20, LS20], which produced neural collapse using different models.
2 Derivation
In this section, we intuitively derive the Layer-Peeled Model as an analytical surrogate for well-trained neural networks. Although our derivation lacks rigor, the priority is to reduce the complexity of the optimization problem (1) while roughly maintaining its structure. Notably, the penalty \frac{\lambda}{2}\|\bm{W}_{\textnormal{full}}\|^{2} corresponds to the weight decay used in training deep learning models, which is necessary for preventing this optimization program from attaining its minimum at infinity when \mathcal{L} is the cross-entropy loss.
Taking a top-down standpoint, our modeling strategy starts by singling out the weights \bm{W}_{L} of the topmost layer and rewriting (1) as
\min_{\bm{W}_{L},\bm{W}_{-L}}~\frac{1}{N}\sum_{k=1}^{K}\sum_{i=1}^{n_{k}}\mathcal{L}(\bm{W}_{L}\bm{h}(\bm{x}_{k,i};\bm{W}_{-L}),\bm{y}_{k})+\frac{\lambda}{2}\|\bm{W}_{L}\|^{2}+\frac{\lambda}{2}\|\bm{W}_{-L}\|^{2},  (3)
where \bm{W}_{-L} denotes the weights from all layers except the last one (for simplicity, we do not assume any bias terms). From the Lagrangian dual viewpoint, a minimizer of the optimization program above is also an optimal solution to
\displaystyle\min_{\bm{W}_{L},\bm{W}_{-L}}\quad\frac{1}{N}\sum_{k=1}^{K}\sum_{i=1}^{n_{k}}\mathcal{L}(\bm{W}_{L}\bm{h}(\bm{x}_{k,i};\bm{W}_{-L}),\bm{y}_{k})  (4)
\displaystyle\mathrm{s.t.}\quad\|\bm{W}_{L}\|^{2}\leq C_{1},\quad\|\bm{W}_{-L}\|^{2}\leq C_{2},
for some positive numbers C_{1} and C_{2}. (Denoting by (\bm{W}_{L}^{\star},\bm{W}_{-L}^{\star}) an optimal solution to (3), we can take C_{1}=\|\bm{W}_{L}^{\star}\|^{2} and C_{2}=\|\bm{W}_{-L}^{\star}\|^{2}. To clear up any confusion, note that due to its nonconvexity, (3) may admit multiple global minima, each of which in general corresponds to different values of C_{1},C_{2}.) Next, we can equivalently write (4) as
\displaystyle\min_{\bm{W}_{L},\bm{H}}\quad\frac{1}{N}\sum_{k=1}^{K}\sum_{i=1}^{n_{k}}\mathcal{L}(\bm{W}_{L}\bm{h}_{k,i},\bm{y}_{k})  (5)
\displaystyle\mathrm{s.t.}\quad\|\bm{W}_{L}\|^{2}\leq C_{1},\quad\bm{H}\in\left\{\bm{H}(\bm{W}_{-L}):\|\bm{W}_{-L}\|^{2}\leq C_{2}\right\},
where \bm{H}=[\bm{h}_{k,i}:1\leq k\leq K,1\leq i\leq n_{k}] denotes the decision variable and the function \bm{H}(\bm{W}_{-L}) is defined as \bm{H}(\bm{W}_{-L}):=\left[\bm{h}(\bm{x}_{k,i};\bm{W}_{-L}):1\leq k\leq K,1\leq i\leq n_{k}\right] for any \bm{W}_{-L}.
To simplify (5), we make the ansatz that the range of \bm{h}(\bm{x}_{k,i};\bm{W}_{-L}) under the constraint \|\bm{W}_{-L}\|^{2}\leq C_{2} is approximately an ellipse in the sense that
\left\{\bm{H}(\bm{W}_{-L}):\|\bm{W}_{-L}\|^{2}\leq C_{2}\right\}\approx\left\{\bm{H}:\sum_{k=1}^{K}\frac{1}{n_{k}}\sum_{i=1}^{n_{k}}\|\bm{h}_{k,i}\|^{2}\leq C_{2}^{\prime}\right\}  (6)
for some C_{2}^{\prime}>0. Loosely speaking, this ansatz asserts that \bm{H} should be regarded as a variable in an \ell_{2} space. To shed light on this point, note that \bm{h}_{k,i} intuitively lives in the dual space of \bm{W} in view of the product \bm{W}\bm{h}_{k,i} in the objective; furthermore, \bm{W} lies in an \ell_{2} space owing to the \ell_{2} constraint on it. Hence, the rationale behind the ansatz follows from the self-duality of \ell_{2} spaces.
Inserting this approximation into (5), we obtain the following optimization program, which we call the Layer-Peeled Model:
\displaystyle\min_{\bm{W},\bm{H}}\quad\frac{1}{N}\sum_{k=1}^{K}\sum_{i=1}^{n_{k}}\mathcal{L}(\bm{W}\bm{h}_{k,i},\bm{y}_{k})  (7)
\displaystyle\mathrm{s.t.}\quad\frac{1}{K}\sum_{k=1}^{K}\left\|\bm{w}_{k}\right\|^{2}\leq E_{W},
\quad\frac{1}{K}\sum_{k=1}^{K}\frac{1}{n_{k}}\sum_{i=1}^{n_{k}}\left\|\bm{h}_{k,i}\right\|^{2}\leq E_{H},
where, for simplicity, we henceforth write \bm{W}:=\bm{W}_{L}\equiv[\bm{w}_{1},\ldots,\bm{w}_{K}]^{\top} for the last-layer classifiers/weights, and the thresholds are E_{W}=C_{1}/K and E_{H}=C_{2}^{\prime}/K.
This optimization program is nonconvex but, as we will show soon, is generally mathematically tractable for analysis. On the surface, the Layer-Peeled Model appears to have no dependence on the data \{\bm{x}_{k,i}\}; this is not the correct picture, however, since the dependence has been implicitly incorporated into the threshold E_{H}.
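To make the model concrete, here is a minimal numerical sketch (our own illustration, not part of the paper's derivation) that minimizes the Layer-Peeled Model (7) with the cross-entropy loss by projected gradient descent; the dimensions, step size, and iteration count are all arbitrary choices.

```python
import numpy as np

rng = np.random.default_rng(0)
K, n, p = 3, 10, 4                         # classes, examples per class, feature dim
E_W = E_H = 1.0

W = 0.1 * rng.standard_normal((K, p))      # last-layer classifiers w_k (as rows)
H = 0.1 * rng.standard_normal((K, n, p))   # last-layer features h_{k,i}
Y = np.zeros((K, n, K))
Y[np.arange(K), :, np.arange(K)] = 1.0     # one-hot labels y_k

def project(A, E):
    """Rescale A so that the average squared norm constraint (<= E) holds."""
    avg = np.mean(np.sum(A ** 2, axis=-1))
    return A if avg <= E else A * np.sqrt(E / avg)

def ce_loss(W, H):
    Z = H @ W.T                            # logits z = W h, shape (K, n, K)
    lse = np.log(np.exp(Z - Z.max(-1, keepdims=True)).sum(-1)) + Z.max(-1)
    return np.mean(lse - Z[np.arange(K), :, np.arange(K)])

init, lr = ce_loss(W, H), 2.0
for _ in range(3000):
    Z = H @ W.T
    Pr = np.exp(Z - Z.max(-1, keepdims=True))
    Pr /= Pr.sum(-1, keepdims=True)
    G = (Pr - Y) / (K * n)                 # dLoss/dZ
    gW = np.einsum('kic,kip->cp', G, H)    # dLoss/dW
    gH = G @ W                             # dLoss/dH
    W, H = project(W - lr * gW, E_W), project(H - lr * gH, E_H)

final = ce_loss(W, H)
means = H.mean(axis=1)                     # class-mean features
means /= np.linalg.norm(means, axis=1, keepdims=True)
print(final, "\n", means @ means.T)        # pairwise cosines of class means
```

If the analysis of the next section is correct, the printed cosines should drift toward -1/(K-1) as the iterates approach a global minimizer, though this sketch only guarantees a feasible point with decreased loss.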
3 Layer-Peeled Model for Explaining Neural Collapse
In this section, we consider training deep neural networks on a balanced dataset, meaning n_{k}=n for all classes 1\leq k\leq K, and our main finding is that the Layer-Peeled Model displays the neural collapse phenomenon, just as in deep learning training [PHD20]. The proofs are all deferred to Appendix A.1. Throughout this section, we assume p\geq K-1 unless otherwise specified. This condition is satisfied by many popular architectures, where p is usually tens or hundreds of times larger than K.
3.1 CrossEntropy Loss
The cross-entropy loss is perhaps the most popular loss used in training deep learning models for classification tasks. This loss function takes the form
\mathcal{L}(\bm{z},\bm{y}_{k})=-\log\left(\frac{\exp(\bm{z}(k))}{\sum_{k^{\prime}=1}^{K}\exp(\bm{z}(k^{\prime}))}\right),
where \bm{z}(k^{\prime}) denotes the k^{\prime}-th entry of \bm{z}. Recall that \bm{y}_{k} is the label of the k-th class and the logit vector \bm{z} is set to \bm{W}\bm{h}_{k,i} in the Layer-Peeled Model (7). In contrast to complex deep neural networks, which are often considered a black box, the Layer-Peeled Model is much more amenable to analysis. As an exemplary use case, the following result shows that any minimizer of the Layer-Peeled Model (7) with the cross-entropy loss admits an almost closed-form expression.
Theorem 3.
In the balanced case, any global minimizer \bm{W}^{\star}\equiv\left[\bm{w}_{1}^{\star},\ldots,\bm{w}_{K}^{\star}\right]^{\top},\bm{H}^{\star}\equiv[\bm{h}_{k,i}^{\star}:1\leq k\leq K,1\leq i\leq n] of (7) with the cross-entropy loss obeys
\bm{h}_{k,i}^{\star}=C\bm{w}_{k}^{\star}=C^{\prime}\bm{m}_{k}^{\star}  (8)
for all 1\leq i\leq n,1\leq k\leq K, where the constants C=\sqrt{E_{H}/E_{W}} and C^{\prime}=\sqrt{E_{H}}, and the matrix [\bm{m}_{1}^{\star},\ldots,\bm{m}_{K}^{\star}] forms a K-simplex ETF as specified in Definition 1.
Remark 4.
Note that the minimizers (\bm{W}^{\star},\bm{H}^{\star}) are equivalent to each other up to rotation because of the rotational invariance of simplex ETFs (see the rotation \bm{P} in Definition 1).
This theorem demonstrates the highly symmetric geometry of the last-layer features and weights of the Layer-Peeled Model, which is precisely the phenomenon of neural collapse. Explicitly, (8) says that all within-class (last-layer) features are the same: \bm{h}_{k,i}^{\star}=\bm{h}_{k,i^{\prime}}^{\star} for all 1\leq i,i^{\prime}\leq n; next, the K class-mean features \bm{h}_{k}^{\star}:=\bm{h}_{k,i}^{\star} together exhibit a K-simplex ETF up to scaling, from which we immediately conclude that
\cos\measuredangle(\bm{h}_{k}^{\star},\bm{h}_{k^{\prime}}^{\star})=-\frac{1}{K-1}  (9)
for any k\neq k^{\prime} by Definition 1 (the cosine value -\frac{1}{K-1} corresponds to the largest possible angle for any K points that have an equal \ell_{2} norm and equal-sized angles between any pair; as pointed out in [PHD20], the largest angle implies a large-margin solution [SHN+18]); in addition, (8) also displays the precise duality between the last-layer classifiers and features. Taken together, these facts indicate that the minimizer \left(\bm{W}^{\star},\bm{H}^{\star}\right) satisfies exactly (NC1)–(NC3). Last, Property (NC4) is also satisfied by recognizing that, for any given last-layer feature \bm{h}, the predicted class is \operatorname*{arg\,max}_{k}\bm{w}_{k}^{\star}\cdot\bm{h}, where \bm{a}\cdot\bm{b} denotes the inner product of the two vectors; this prediction satisfies
\operatorname*{arg\,max}_{k}\bm{w}_{k}^{\star}\cdot\bm{h}=\operatorname*{arg\,max}_{k}\bm{h}_{k}^{\star}\cdot\bm{h}=\operatorname*{arg\,min}_{k}\|\bm{h}_{k}^{\star}-\bm{h}\|^{2}.
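The (NC4) identity above reduces to elementary linear algebra: when all class means have equal norm, the inner-product rule and the nearest-class-mean rule pick the same class, since \|\bm{h}_{k}^{\star}-\bm{h}\|^{2}=\|\bm{h}_{k}^{\star}\|^{2}-2\bm{h}_{k}^{\star}\cdot\bm{h}+\|\bm{h}\|^{2}. A small numerical check of ours (an ETF stands in for the class means):

```python
import numpy as np

rng = np.random.default_rng(1)
K, p = 4, 6
# Class means forming a K-simplex ETF; all columns have equal norm by construction.
P, _ = np.linalg.qr(rng.standard_normal((p, K)))
M = np.sqrt(K / (K - 1)) * P @ (np.eye(K) - np.ones((K, K)) / K)

for _ in range(100):
    h = rng.standard_normal(p)                    # an arbitrary test feature
    by_inner = np.argmax(M.T @ h)                 # arg max_k  h_k . h
    by_dist = np.argmin(np.linalg.norm(M - h[:, None], axis=0))  # arg min_k ||h_k - h||
    assert by_inner == by_dist
```

Since the last-layer classifiers are proportional to the class means by (8), the same argument covers \operatorname*{arg\,max}_{k}\bm{w}_{k}^{\star}\cdot\bm{h} as well.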
Conversely, the presence of neural collapse in the Layer-Peeled Model offers evidence of the effectiveness of our model as a tool for analyzing neural networks. To be complete, we remark that other models were very recently proposed to justify the neural collapse phenomenon [MPP20, EW20, LS20] (see also [PL20]). For example, [EW20, LS20] considered models that impose a norm constraint on each individual class, rather than an overall constraint as employed in the Layer-Peeled Model.
3.2 Extensions to Other Loss Functions
In the modern practice of deep learning, various loss functions are employed to account for problem characteristics. Here we show that the Layer-Peeled Model continues to exhibit the phenomenon of neural collapse for some popular loss functions.
Contrastive Loss.
Contrastive losses have been extensively used recently in both supervised and unsupervised deep learning [PSM14, AKK+19, CKN+20, BZM+20]. These losses pull similar training examples together in their embedding space while pushing apart dissimilar examples. Here we consider the supervised contrastive loss [KTW+20], which (in the balanced case) is defined through the last-layer features as
\mathcal{L}_{c}(\bm{h}_{k,i},\bm{y}_{k})=-\frac{1}{n}\sum_{j=1}^{n}\log\left(\frac{\exp(\bm{h}_{k,i}\cdot\bm{h}_{k,j}/\tau)}{\sum_{k^{\prime}=1}^{K}\sum_{\ell=1}^{n}\exp(\bm{h}_{k,i}\cdot\bm{h}_{k^{\prime},\ell}/\tau)}\right),  (10)
where \tau>0 is a temperature parameter. As the loss does not involve the last-layer classifiers explicitly, the Layer-Peeled Model in the case of the supervised contrastive loss takes the form (in (10), \bm{h}_{k,i}\equiv\bm{h}(\bm{x}_{k,i};\bm{W}_{-L}) depends on the data, whereas in (11) the \bm{h}_{k,i}'s form the decision variable \bm{H})
\displaystyle\min_{\bm{H}}\quad\frac{1}{N}\sum_{k=1}^{K}\sum_{i=1}^{n}\mathcal{L}_{c}(\bm{h}_{k,i},\bm{y}_{k})  (11)
\displaystyle\mathrm{s.t.}\quad\frac{1}{K}\sum_{k=1}^{K}\frac{1}{n}\sum_{i=1}^{n}\left\|\bm{h}_{k,i}\right\|^{2}\leq E_{H}.
We show that this Layer-Peeled Model also exhibits neural collapse in its last-layer features, even though the label information is only implicitly used in the loss.
Theorem 5.
Any global minimizer of (11) satisfies
\bm{h}_{k,i}^{\star}=\sqrt{E_{H}}\bm{m}_{k}^{\star}  (12) 
for all 1\leq k\leq K and 1\leq i\leq n, where [\bm{m}_{1}^{\star},\ldots,\bm{m}_{K}^{\star}] forms a K-simplex ETF.
Theorem 5 shows that the contrastive loss in the associated Layer-Peeled Model does a perfect job of pulling together training examples from the same class. Moreover, as seen from the denominator in (10), minimizing this loss intuitively renders the between-class inner products of the last-layer features as small as possible, thereby pushing the features to form the vertices of a K-simplex ETF up to scaling.
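Theorem 5 can also be probed numerically. In the sketch below (ours; the sizes and temperature are arbitrary choices), the collapsed ETF configuration of Theorem 5 attains a smaller supervised contrastive loss (10) than a random feasible configuration of unit-norm features:

```python
import numpy as np

def sup_con_loss(H, tau=0.5):
    """Average of the supervised contrastive loss (10) over all anchors h_{k,i}."""
    K, n, p = H.shape
    F = H.reshape(K * n, p)
    S = F @ F.T / tau                           # pairwise similarities h . h' / tau
    logZ = np.log(np.exp(S).sum(axis=1))        # log of the full denominator in (10)
    total = 0.0
    for k in range(K):
        for i in range(n):
            a = k * n + i
            # Positives: all examples of class k (including the anchor itself, as in (10)).
            total += np.mean(logZ[a] - S[a, k * n:(k + 1) * n])
    return total / (K * n)

rng = np.random.default_rng(0)
K, n, p, E_H = 4, 5, 8, 1.0

# Collapsed ETF features h_{k,i} = sqrt(E_H) * m_k, the minimizer form in Theorem 5.
P, _ = np.linalg.qr(rng.standard_normal((p, K)))
M = np.sqrt(K / (K - 1)) * P @ (np.eye(K) - np.ones((K, K)) / K)
H_etf = np.sqrt(E_H) * np.repeat(M.T[:, None, :], n, axis=1)

# A random feasible competitor: unit-norm features.
H_rand = rng.standard_normal((K, n, p))
H_rand /= np.linalg.norm(H_rand, axis=-1, keepdims=True)

assert sup_con_loss(H_etf) < sup_con_loss(H_rand)
```

This is of course only a spot check against one competitor, not a proof; the theorem asserts optimality over all feasible \bm{H}.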
SoftmaxBased Loss.
The cross-entropy loss can be thought of as a softmax-based loss. To see this, define the softmax transform as
\bm{S}(\bm{z})=\left[\frac{\exp(\bm{z}(1))}{\sum_{k=1}^{K}\exp(\bm{z}(k))},\ldots,\frac{\exp(\bm{z}(K))}{\sum_{k=1}^{K}\exp(\bm{z}(k))}\right]^{\top}
for \bm{z}\in\mathbb{R}^{K}. Let g_{1} be any nonincreasing convex function and g_{2} be any nondecreasing function, both defined on (0,1). We consider a softmax-based loss function that takes the form
\mathcal{L}(\bm{z},\bm{y}_{k})=g_{1}\left(\bm{S}(\bm{z})(k)\right)+\sum_{k^{\prime}=1,\,k^{\prime}\neq k}^{K}g_{2}\left(\bm{S}(\bm{z})(k^{\prime})\right).  (13)
Here, \bm{S}(\bm{z})(k) denotes the k-th element of \bm{S}(\bm{z}). Taking g_{1}(x)=-\log x and g_{2}\equiv 0, we recover the cross-entropy loss. Another example is to take g_{1}(x)=(1-x)^{q} and g_{2}(x)=x^{q} for q>1, which can be implemented in most deep learning libraries such as PyTorch [PGM+19].
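As a quick check of the definition (13), the following sketch (ours) implements the softmax-based loss for generic g_{1},g_{2} and confirms that the choice g_{1}(x)=-\log x, g_{2}\equiv 0 reproduces the cross-entropy loss; the logits are arbitrary numbers.

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())      # subtract max for numerical stability
    return e / e.sum()

def softmax_based_loss(z, k, g1, g2):
    """Loss (13): g1 on the true-class probability, g2 summed over the rest."""
    s = softmax(z)
    return g1(s[k]) + sum(g2(s[j]) for j in range(len(z)) if j != k)

z, k = np.array([2.0, -0.5, 0.3]), 0
ce_direct = -z[k] + np.log(np.exp(z).sum())   # cross-entropy, computed directly
ce_via_13 = softmax_based_loss(z, k, lambda x: -np.log(x), lambda x: 0.0)
assert np.isclose(ce_direct, ce_via_13)

q = 2.0                                       # the second example in the text, q > 1
alt = softmax_based_loss(z, k, lambda x: (1.0 - x) ** q, lambda x: x ** q)
```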
We have the following theorem regarding softmax-based loss functions in the balanced case.
Theorem 6.
Assume \sqrt{E_{H}E_{W}}>\frac{K-1}{K}\log\left(K^{2}\sqrt{E_{H}E_{W}}+(2K-1)(K-1)\right). For any loss function defined in (13), (\bm{W}^{\star},\bm{H}^{\star}) given by (8) is a global minimizer of Program (7). Moreover, if g_{2} is strictly convex and at least one of g_{1},g_{2} is strictly monotone, then any global minimizer must be given by (8).
In other words, neural collapse continues to emerge with softmax-based losses under mild regularity conditions. The first part of this theorem does not preclude the possibility that the Layer-Peeled Model admits solutions other than (8). When applied to the cross-entropy loss, it is worth pointing out that this theorem is a weaker version of Theorem 3, albeit more general. Regarding the first assumption in Theorem 6, note that E_{H} and E_{W} would be arbitrarily large if the weight decay \lambda in (1) were sufficiently small, thereby meeting the condition on \sqrt{E_{H}E_{W}}.
We remark that Theorem 6 does not require the convexity of the loss \mathcal{L}. To circumvent the hurdle of nonconvexity, our proof in Appendix A.1 presents several novel elements.
In passing, we leave the experimental confirmation of neural collapse with these loss functions for future work.
4 Layer-Peeled Model for Predicting Minority Collapse
Deep learning models are often trained on datasets where there is a disproportionate ratio of observations in each class [WLW+16, HLL+16, MR17]. For example, in the Places2 challenge dataset [ZKL+16], the number of images in its majority scene categories is about eight times that in its minority classes. Another example is the Ontonotes dataset for part-of-speech tagging [HMP+06], where the number of words in its majority classes can be more than one hundred times that in its minority classes. While empirically the imbalance in class sizes often leads to inferior model performance of deep learning (see, e.g., [JK19]), there remains a lack of a solid theoretical footing for understanding its effect, perhaps due to the complex details of deep learning training.
In this section, we use the Layer-Peeled Model to seek a fine-grained characterization of how class imbalance impacts neural networks that are trained for a sufficiently long time. In short, our analysis predicts a phenomenon we term Minority Collapse, which fundamentally limits the performance of deep learning, especially on the minority classes, both theoretically and empirically. All omitted proofs are relegated to Appendix A.2.
4.1 Technique: Convex Relaxation
When it comes to imbalanced datasets, the Layer-Peeled Model no longer admits a simple expression for its minimizers as in the balanced case, due to the lack of symmetry between classes. This fact results in, among other things, an added burden in numerically computing the solutions of the Layer-Peeled Model.
To overcome this difficulty, we introduce a convex optimization program as a relaxation of the nonconvex Layer-Peeled Model (7), relying on the well-known result for relaxing a quadratically constrained quadratic program into a semidefinite program (see, e.g., [SZ03]). To begin with, defining \bm{h}_{k} as the feature mean of the k-th class (i.e., \bm{h}_{k}:=\frac{1}{n_{k}}\sum_{i=1}^{n_{k}}\bm{h}_{k,i}), we introduce a new decision variable \bm{X}:=\left[\bm{h}_{1},\bm{h}_{2},\dots,\bm{h}_{K},\bm{W}^{\top}\right]^{\top}\left[\bm{h}_{1},\bm{h}_{2},\dots,\bm{h}_{K},\bm{W}^{\top}\right]\in\mathbb{R}^{2K\times 2K}. By definition, \bm{X} is positive semidefinite and satisfies
\frac{1}{K}\sum_{k=1}^{K}\bm{X}(k,k)=\frac{1}{K}\sum_{k=1}^{K}\|\bm{h}_{k}\|^{2}\overset{a}{\leq}\frac{1}{K}\sum_{k=1}^{K}\frac{1}{n_{k}}\sum_{i=1}^{n_{k}}\left\|\bm{h}_{k,i}\right\|^{2}\leq E_{H}
and
\frac{1}{K}\sum_{k=K+1}^{2K}\bm{X}(k,k)=\frac{1}{K}\sum_{k=1}^{K}\|\bm{w}_{k}\|^{2}\leq E_{W},
where \overset{a}{\leq} follows from the Cauchy–Schwarz inequality. Thus, we consider the following semidefinite programming problem (although the program involves a semidefinite constraint, it is not a semidefinite program in the strict sense because a semidefinite program uses a linear objective function):
\displaystyle\min_{\bm{X}\in\mathbb{R}^{2K\times 2K}}  \displaystyle\sum_{k=1}^{K}\frac{n_{k}}{N}\mathcal{L}(\bm{z}_{k},\bm{y}_{k})  (13)  
\displaystyle\mathrm{s.t.}  \displaystyle\bm{z}_{k}=\left[\bm{X}(k,K+1),\bm{X}(k,K+2),\dots,\bm{X}(k,2K)\right]^{\top},~\text{ for all }1\leq k\leq K,  
\displaystyle\frac{1}{K}\sum_{k=1}^{K}\bm{X}(k,k)\leq E_{H},\quad\frac{1}{K}% \sum_{k=K+1}^{2K}\bm{X}(k,k)\leq E_{W},  
\displaystyle\bm{X}\succeq 0. 
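To make program (13) concrete, below is a self-contained numerical sketch that minimizes the weighted cross-entropy objective over the constraint set by projected gradient descent. All function names, step sizes, and energy levels here are our own illustrative choices, not prescriptions from the paper.

```python
import numpy as np

def solve_relaxed_lpm(n, E_H=1.0, E_W=1.0, steps=2000, lr=0.5):
    """Projected gradient descent on the relaxation (13) with the
    cross-entropy loss (illustrative hyperparameters)."""
    n = np.asarray(n, dtype=float)
    K = len(n)
    q = n / n.sum()                            # class weights n_k / N
    X = np.eye(2 * K) * min(E_H, E_W)          # feasible starting point
    for _ in range(steps):
        Z = X[:K, K:]                          # row k holds the logits z_k
        P = np.exp(Z - Z.max(axis=1, keepdims=True))
        P /= P.sum(axis=1, keepdims=True)      # row-wise softmax
        G = np.zeros_like(X)
        G[:K, K:] = q[:, None] * (P - np.eye(K))   # d(objective)/dZ
        X -= lr * (G + G.T) / 2                # symmetric gradient step
        w, V = np.linalg.eigh(X)
        X = (V * np.clip(w, 0.0, None)) @ V.T  # project onto the PSD cone
        # rescale the diagonal blocks to restore the two energy constraints
        sH = min(1.0, E_H / (np.trace(X[:K, :K]) / K + 1e-12)) ** 0.5
        sW = min(1.0, E_W / (np.trace(X[K:, K:]) / K + 1e-12)) ** 0.5
        D = np.diag(np.r_[np.full(K, sH), np.full(K, sW)])
        X = D @ X @ D                          # D X D remains PSD
    return X

def weighted_ce(X, n):
    """Objective of (13): sum_k (n_k/N) * cross-entropy(z_k, y_k)."""
    n = np.asarray(n, dtype=float)
    K = len(n)
    Z = X[:K, K:]
    m = Z.max(axis=1)
    lse = np.log(np.exp(Z - m[:, None]).sum(axis=1)) + m
    return float(np.sum(n / n.sum() * (lse - np.diag(Z))))
```

With imbalanced class sizes n, the bottom-right block of the returned X encodes the classifier inner products, from which between-class cosines can be read off.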
Lemma 1.
Assume p\geq 2K and the loss function \mathcal{L} is convex in its first argument. Let \bm{X}^{\star} be a minimizer of the convex program (13). Define \left(\bm{H}^{\star},\bm{W}^{\star}\right) as
\displaystyle\left[\bm{h}_{1}^{\star},\bm{h}_{2}^{\star},\dots,\bm{h}_{K}^{\star},~(\bm{W}^{\star})^{\top}\right]=\bm{P}(\bm{X}^{\star})^{1/2},  (14)  
\displaystyle\bm{h}_{k,i}^{\star}=\bm{h}_{k}^{\star},~\text{ for all }1\leq i\leq n_{k},\,1\leq k\leq K, 
where (\bm{X}^{\star})^{1/2} denotes the positive square root of \bm{X}^{\star} and \bm{P}\in\mathbb{R}^{p\times 2K} is any partial orthogonal matrix such that \bm{P}^{\top}\bm{P}=\bm{I}_{2K}. Then (\bm{H}^{\star},\bm{W}^{\star}) is a minimizer of (7). Moreover, if all \bm{X}^{\star}’s satisfy \frac{1}{K}\sum_{k=1}^{K}\bm{X}^{\star}(k,k)=E_{H}, then all the solutions of (7) are in the form of (14).
This lemma in effect says that the relaxation does not lead to any loss of information when we study the Layer-Peeled Model through a convex program, thereby offering a computationally efficient tool for gaining insights into the terminal phase of training deep neural networks on imbalanced datasets. An appealing feature is that the size of the program (13) is independent of the number of training examples. Besides, this lemma predicts that even in the imbalanced case the last-layer features collapse to their class means under mild conditions. Therefore, Property (NC1) is satisfied (see more discussion about the condition in Section B).
The assumption that \mathcal{L} is convex in its first argument is satisfied by a large class of loss functions, such as the cross-entropy loss. We also remark that (13) is not the only possible convex relaxation. An alternative is to relax (7) via a nuclear-norm-constrained convex program [BMP08, HV19] (see more details in Section B).
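The recovery map (14) of Lemma 1 is straightforward to implement; the sketch below (with our own function name) embeds via the first 2K coordinates, which is one valid choice of the partial orthogonal matrix \bm{P} under the assumption p ≥ 2K.

```python
import numpy as np

def recover_solution(X_star, p):
    """Recover (h_1*, ..., h_K*, W*) from a minimizer X* of (13) via the
    map (14); assumes p >= 2K."""
    K2 = X_star.shape[0]                     # K2 = 2K
    K = K2 // 2
    w, V = np.linalg.eigh(X_star)
    sqrtX = (V * np.sqrt(np.clip(w, 0.0, None))) @ V.T  # PSD square root
    P = np.eye(p)[:, :K2]                    # one choice with P^T P = I_{2K}
    M = P @ sqrtX                            # columns: h_1*,...,h_K*, w_1*,...,w_K*
    H = M[:, :K]                             # p x K, column k is h_k*
    W = M[:, K:].T                           # K x p, row k is w_k*
    return H, W
```

Since M^T M = X*, the recovered pair reproduces the logits: h_k*^T w_{k'}* = X*(k, K+k').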
4.2 Minority Collapse
With the technique of convex relaxation in place, we now numerically solve the Layer-Peeled Model on imbalanced datasets, with the goal of identifying nontrivial patterns in this regime. As a worthwhile starting point, we consider a dataset that has K_{A} majority classes each containing n_{A} training examples and K_{B} minority classes each containing n_{B} training examples. That is, assume n_{1}=n_{2}=\dots=n_{K_{A}}=n_{A} and n_{K_{A}+1}=n_{K_{A}+2}=\dots=n_{K}=n_{B}. For convenience, call R:=n_{A}/n_{B}>1 the imbalance ratio. Note that the case R=1 reduces to the balanced setting.
An important question is to understand how the K_{B} last-layer minority classifiers behave as the imbalance ratio R increases, as this is directly related to the model performance on the minority classes. To address this question, we plot in Figure 3 the average cosine of the angles between all pairs of the K_{B} minority classifiers, obtained by solving the simple convex program (13). This figure reveals a two-phase behavior of the minority classifiers \bm{w}^{\star}_{K_{A}+1},\bm{w}^{\star}_{K_{A}+2},\ldots,\bm{w}^{\star}_{K} as R increases:

(1) When R<R_{0} for some R_{0}>0, the average between-minority-class angle becomes smaller as R increases.

(2) Once R\geq R_{0}, the average between-minority-class angle becomes zero, implying that all the minority classifiers collapse to a single vector.

Above, the phase transition point R_{0} depends on the imbalance configuration K_{A},K_{B} and the thresholds E_{H},E_{W}.
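The between-minority-class statistic behind this two-phase picture can be computed directly from the classifier matrix; a minimal numpy sketch (the function name is ours):

```python
import numpy as np

def avg_minority_cosine(W, K_A):
    """Average cosine between all pairs of minority classifiers w_k,
    k > K_A (rows of W); assumes the rows are nonzero."""
    Wb = W[K_A:]
    Wb = Wb / np.linalg.norm(Wb, axis=1, keepdims=True)
    C = Wb @ Wb.T                       # pairwise cosines
    iu = np.triu_indices(len(Wb), k=1)  # strictly upper-triangular pairs
    return float(C[iu].mean())
```

An average cosine of 1 corresponds to the second phase: all minority classifiers point in the same direction.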
We refer to the phenomenon that appears in the second phase as Minority Collapse. While it is to be expected that the minority classifiers move closer to each other as the level of imbalance increases, it is surprising that these classifiers become completely indistinguishable once R hits a finite value. Once Minority Collapse takes place, the neural network predicts equal probabilities for all the minority classes regardless of the input. As such, when restricted to the minority classes, its predictive ability is no better than random guessing, in terms of both optimization and generalization, and this situation only worsens in the presence of adversarial perturbations. This phenomenon is especially detrimental when the minority classes are more frequent in the application domain than in the training data.
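The uniform-prediction claim follows directly from the softmax form of the bias-free last layer: identical classifiers yield identical logits, and hence identical probabilities, for every feature vector:

```latex
\bm{w}_{k}=\bm{w}_{k^{\prime}}
\;\Longrightarrow\;
\frac{\exp(\bm{w}_{k}^{\top}\bm{h})}{\sum_{j=1}^{K}\exp(\bm{w}_{j}^{\top}\bm{h})}
=\frac{\exp(\bm{w}_{k^{\prime}}^{\top}\bm{h})}{\sum_{j=1}^{K}\exp(\bm{w}_{j}^{\top}\bm{h})}
\qquad\text{for all }\bm{h}\in\mathbb{R}^{p},
```

so once all K_{B} minority classifiers coincide, the K_{B} corresponding probabilities are equal for every input.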
From an optimization point of view, the emergence of Minority Collapse prevents the model from achieving zero training error, since its prediction is simply uniform over the minority classes. While this seems to contradict conventional wisdom on the approximation power of deep learning, a careful examination indicates that the occurrence can be attributed to the two constraints in the Layer-Peeled Model or the \ell_{2} penalty in (1). However, this issue does not disappear by simply setting a small penalty coefficient \lambda in deep learning training, because the imbalance ratio can be arbitrarily large. Even outside the regime of Minority Collapse, the classification might still be unreliable if the imbalance ratio is large, since the softmax predictions for the minority classes can be close to each other.
To put the observations in Figure 3 on a firm footing, we prove that Minority Collapse indeed emerges in the Layer-Peeled Model as R tends to infinity.
Theorem 7.
Assume p\geq K, and fix K_{A} and K_{B}. Let \left(\bm{H}^{\star},\bm{W}^{\star}\right) be any global minimizer of the Layer-Peeled Model (7) with the cross-entropy loss. As R\equiv n_{A}/n_{B}\to\infty, we have
\lim\left(\bm{w}^{\star}_{k}-\bm{w}^{\star}_{k^{\prime}}\right)=\bm{0}_{p},~\text{ for all }K_{A}<k<k^{\prime}\leq K. 
To see intuitively why Minority Collapse occurs, first note that the majority classes become the predominant part of the risk function as the level of imbalance increases. The minimization of the objective therefore places too much emphasis on the majority classifiers, encouraging the between-majority-class angles to grow while shrinking the between-minority-class angles to zero. As an aside, an interesting question for future work is to prove that \bm{w}^{\star}_{k} and \bm{w}^{\star}_{k^{\prime}} are exactly equal for sufficiently large R.
4.3 Experiments
At the moment, Minority Collapse is merely a prediction of the Layer-Peeled Model. An immediate question thus is: does this phenomenon really occur in real-world neural networks? At first glance, it does not necessarily have to be the case, since the Layer-Peeled Model is a dramatic simplification of deep neural networks.
To this end, we resort to computational experiments.^{11} Explicitly, we consider training two network architectures, VGG and ResNet [HZR+16], on the Fashion-MNIST [XRV17] and CIFAR10 datasets; in particular, we replace the dropout layers in VGG with batch normalization [IS15]. As both datasets have 10 classes, we use three combinations (K_{A},K_{B})=(3,7),(5,5),(7,3) to split the data into majority and minority classes. In the case of Fashion-MNIST (CIFAR10), we let the K_{A} majority classes each contain all the n_{A}=6000 (n_{A}=5000) training examples from the corresponding class, and the K_{B} minority classes each have n_{B}=6000/R (n_{B}=5000/R) examples randomly sampled from the corresponding class. The rest of the experimental setup is essentially the same as in [PHD20]. In detail, we use the cross-entropy loss and stochastic gradient descent with momentum 0.9 and weight decay \lambda=5\times 10^{-4}. The networks are trained for 350 epochs with a batch size of 128. The initial learning rate is annealed by a factor of 10 at 1/3 and 2/3 of the 350 epochs. The only difference from [PHD20] is that we simply set the learning rate to 0.1 instead of sweeping over 25 learning rates between 0.0001 and 0.25, because the test performance of our trained models is already comparable with their best reported test accuracy.

^{11}Our code is publicly available at https://github.com/HornHehhf/LPM.
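The imbalanced splits described above are easy to construct from a label vector; a sketch follows, where `imbalanced_split` is our own hypothetical helper and the convention that the first K_A labels are the majority classes is our assumption.

```python
import numpy as np

def imbalanced_split(labels, K_A, n_A, R, seed=0):
    """Index set for an imbalanced split: classes 0..K_A-1 keep all n_A
    examples; every other class keeps n_B = n_A // R randomly sampled
    examples (illustrative helper, not the paper's code)."""
    rng = np.random.default_rng(seed)
    n_B = n_A // R
    keep = []
    for k in np.unique(labels):
        cls = np.flatnonzero(labels == k)
        keep.append(cls if k < K_A else rng.choice(cls, size=n_B, replace=False))
    return np.concatenate(keep)
```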
The results of the experiments above are displayed in Figure 4. This figure clearly indicates that the angles between the minority classifiers collapse to zero once R is large enough. Moreover, the numerical examination in Table 1 shows that the norms of the minority classifiers are essentially constant across these classes. Taken together, these two pieces of evidence clearly point to the emergence of Minority Collapse in these neural networks, thereby further demonstrating the effectiveness of our Layer-Peeled Model. Besides, Figure 4 also shows that the issue of Minority Collapse is compounded when there are more majority classes, which is consistent with Figure 3. For completeness, we remark that, as with neural collapse, Minority Collapse only occurs during the terminal phase of training with a non-diminishing weight decay parameter.
In order to get a handle on how Minority Collapse impacts the test accuracy, we plot the results of another numerical study in Figure 5. The setting is the same as in Figure 4, except that now we randomly sample 6 or 5 examples per minority class, depending on whether the dataset is Fashion-MNIST or CIFAR10. The results show that the performance of the trained model deteriorates on the test data when the imbalance ratio is R=1000, at which point Minority Collapse has occurred or is about to occur. This is by no means intuitive a priori, as the test performance here is restricted to the minority classes, and a large value of R only adds training data to the majority classes without affecting the minority classes.
Dataset  Fashion-MNIST  
Network architecture  VGG11  ResNet18  
No. of majority classes  K_{A}=3  K_{A}=5  K_{A}=7  K_{A}=3  K_{A}=5  K_{A}=7 
Norm variation  2.7\times 10^{-5}  4.4\times 10^{-8}  6.0\times 10^{-8}  1.4\times 10^{-5}  5.0\times 10^{-8}  6.3\times 10^{-8} 
Dataset  CIFAR10  
Network architecture  VGG13  ResNet18  
No. of majority classes  K_{A}=3  K_{A}=5  K_{A}=7  K_{A}=3  K_{A}=5  K_{A}=7 
Norm variation  1.4\times 10^{-4}  9.0\times 10^{-7}  5.2\times 10^{-8}  5.4\times 10^{-5}  3.5\times 10^{-7}  5.4\times 10^{-8} 
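The exact definition of the "norm variation" statistic in Table 1 is not given in this excerpt; one plausible reading, sketched below, is the standard deviation of the minority classifier norms divided by their mean. This definition is our assumption.

```python
import numpy as np

def minority_norm_variation(W, K_A):
    """One plausible reading of Table 1's 'norm variation': the
    coefficient of variation of the minority classifier norms
    (this exact definition is an assumption)."""
    norms = np.linalg.norm(W[K_A:], axis=1)
    return float(norms.std() / norms.mean())
```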
5 How to Mitigate Minority Collapse?
In this section, we further exploit the use of the Layer-Peeled Model in an attempt to lessen the detrimental effect of Minority Collapse. Instead of aiming to develop a full set of methodologies to overcome this issue, which is beyond the scope of the paper, our focus is on the evaluation of some simple techniques used for imbalanced datasets.
Among many approaches to handling class imbalance in deep learning (see the review [JK19]), perhaps the most popular one is to oversample training examples from the minority classes [BMM18, SXY+19, CJL+19, CWG+19]. In its simplest form, this sampling scheme retains all majority training examples while duplicating each training example from the minority classes w_{r} times, where the oversampling rate w_{r} is a positive integer. Oversampling in effect amounts to minimizing an adjusted optimization problem that is derived by replacing the risk in the optimization program (1) with
\frac{1}{n_{A}K_{A}+w_{r}n_{B}K_{B}}\left[\sum_{k=1}^{K_{A}}\sum_{i=1}^{n_{A}}\mathcal{L}(\bm{f}(\bm{x}_{k,i};\bm{W}_{\textnormal{full}}),\bm{y}_{k})+w_{r}\sum_{k=K_{A}+1}^{K}\sum_{i=1}^{n_{B}}\mathcal{L}(\bm{f}(\bm{x}_{k,i};\bm{W}_{\textnormal{full}}),\bm{y}_{k})\right]  (15)
while keeping the penalty term \frac{\lambda}{2}\left\|\bm{W}_{\textnormal{full}}\right\|^{2}. Note that oversampling is closely related to weight adjusting (see more discussion in Section B).
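In code, the adjusted risk (15) is just the ordinary empirical risk on a dataset whose minority examples are duplicated; a sketch follows (`oversample_indices` is our own helper, and it assumes the first K_A labels are the majority classes).

```python
import numpy as np

def oversample_indices(labels, K_A, w_r):
    """Realize the reweighted risk (15) as an ordinary empirical risk
    by repeating every minority-class index w_r times."""
    reps = np.where(np.asarray(labels) < K_A, 1, w_r)
    return np.repeat(np.arange(len(labels)), reps)
```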
A close look at (15) suggests that the neural network obtained by minimizing this new program might behave as if it were trained on a (larger) dataset with n_{A} and w_{r}n_{B} examples in each majority class and minority class, respectively. To formalize this intuition, as earlier, we start by considering the Layer-Peeled Model in the case of oversampling:
\displaystyle\min_{\bm{H},\bm{W}}  \displaystyle\frac{1}{n_{A}K_{A}+w_{r}n_{B}K_{B}}\left[\sum_{k=1}^{K_{A}}\sum_{i=1}^{n_{A}}\mathcal{L}(\bm{W}\bm{h}_{k,i},\bm{y}_{k})+w_{r}\sum_{k=K_{A}+1}^{K}\sum_{i=1}^{n_{B}}\mathcal{L}(\bm{W}\bm{h}_{k,i},\bm{y}_{k})\right]  (16)  
\displaystyle\mathrm{s.t.}  \displaystyle\frac{1}{K}\sum_{k=1}^{K}\left\|\bm{w}_{k}\right\|^{2}\leq E_{W},  
\displaystyle\frac{1}{K}\sum_{k=1}^{K_{A}}\frac{1}{n_{A}}\sum_{i=1}^{n_{A}}\left\|\bm{h}_{k,i}\right\|^{2}+\frac{1}{K}\sum_{k=K_{A}+1}^{K}\frac{1}{n_{B}}\sum_{i=1}^{n_{B}}\left\|\bm{h}_{k,i}\right\|^{2}\leq E_{H}. 
The following result confirms our intuition that oversampling indeed boosts the size of the minority classes for the Layer-Peeled Model.
Proposition 1.
Assume p\geq 2K and the loss function \mathcal{L} is convex in the first argument. Let \bm{X}^{\star} be any minimizer of the convex program (13) with n_{1}=n_{2}=\dots=n_{K_{A}}=n_{A} and n_{K_{A}+1}=n_{K_{A}+2}=\dots=n_{K}=w_{r}n_{B}. Define \left(\bm{H}^{\star},\bm{W}^{\star}\right) as
\displaystyle\left[\bm{h}_{1}^{\star},\bm{h}_{2}^{\star},\dots,\bm{h}_{K}^{\star},(\bm{W}^{\star})^{\top}\right]=\bm{P}(\bm{X}^{\star})^{1/2},  (17)  
\displaystyle\bm{h}_{k,i}^{\star}=\bm{h}_{k}^{\star},~\text{ for all }1\leq i\leq n_{A},\,1\leq k\leq K_{A},  
\displaystyle\bm{h}_{k,i}^{\star}=\bm{h}_{k}^{\star},~\text{ for all }1\leq i\leq n_{B},\,K_{A}<k\leq K, 
where \bm{P}\in\mathbb{R}^{p\times 2K} is any partial orthogonal matrix such that \bm{P}^{\top}\bm{P}=\bm{I}_{2K}. Then (\bm{H}^{\star},\bm{W}^{\star}) is a global minimizer of the oversampling-adjusted Layer-Peeled Model (16). Moreover, if all \bm{X}^{\star}’s satisfy \frac{1}{K}\sum_{k=1}^{K}\bm{X}^{\star}(k,k)=E_{H}, then all the solutions of (16) are in the form of (17).
Together with Lemma 1, Proposition 1 shows that the number of training examples in each minority class is now in effect w_{r}n_{B}, instead of n_{B}.
We turn to Figure 6 for an illustration of the effects of oversampling on real-world deep learning models, using the same experimental setup as in Figure 5. From Figure 6, we see that the angles between pairs of the minority classifiers become larger as the oversampling rate w_{r} increases. Consequently, the issue of Minority Collapse becomes less detrimental in terms of training accuracy as w_{r} increases. This again corroborates the predictive ability of the Layer-Peeled Model.
Network architecture  VGG11  ResNet18  
No. of majority classes  K_{A}=3  K_{A}=5  K_{A}=7  K_{A}=3  K_{A}=5  K_{A}=7 
Original (minority)  15.29  20.30  17.00  30.66  34.26  5.53 
Oversampling (minority)  41.13  57.22  30.50  37.86  53.46  8.13 
Improvement (minority)  25.84  36.92  13.50  7.20  19.20  2.60 
Original (overall)  40.10  57.61  69.09  50.88  64.89  66.13 
Oversampling (overall)  58.25  76.17  73.37  55.91  74.56  67.10 
Improvement (overall)  18.15  18.56  4.28  5.03  9.67  0.97 
Next, we refer to Table 2 for the effect of oversampling on test performance. The results clearly demonstrate the improvement in test accuracy brought by oversampling for certain choices of the oversampling rate. The improvement is noticeable both on the minority classes and over all classes.
A closer look at the results of Table 2, however, reveals that issues remain when addressing Minority Collapse by oversampling. Perhaps the most critical one is that although oversampling with a very large value of w_{r} can mitigate Minority Collapse on the training set, it comes at the cost of degraded test accuracy. This raises the question of how to efficiently select an oversampling rate for optimal test performance. More broadly, Minority Collapse does not seem likely to be fully resolved by sampling-based approaches alone, and the door is wide open for future investigation.
6 Discussion
In this paper, we have developed the Layer-Peeled Model as a simple yet effective modeling strategy toward understanding well-trained deep neural networks. The derivation of this model follows a top-down strategy by isolating the last layer from the remaining layers. Owing to the analytical and numerical tractability of the Layer-Peeled Model, we provide some explanation of a recently observed phenomenon called neural collapse in deep neural networks trained on balanced datasets [PHD20]. Moving to imbalanced datasets, an analysis of this model suggests that the last-layer classifiers corresponding to the minority classes collapse to a single point once the imbalance level is above a certain threshold. This new phenomenon, which we refer to as Minority Collapse, occurs consistently in our computational experiments.
The efficacy of the Layer-Peeled Model in analyzing well-trained deep learning models implies that the ansatz (6), a crucial step in the derivation of this model, is at least a useful approximation. Moreover, this ansatz can be indirectly justified by the following result, which, together with Theorem 3, shows that the \ell_{2} norm suggested by the ansatz is the only choice among all \ell_{q} norms that is consistent with empirical observations. Its proof is given in Appendix A.1.
Proposition 2.
Assume p\geq K. For any q\in(1,2)\cup(2,\infty), consider the optimization problem
\displaystyle\min_{\bm{W},\bm{H}}  \displaystyle\frac{1}{N}\sum_{k=1}^{K}\sum_{i=1}^{n}\mathcal{L}(\bm{W}\bm{h}_{k,i},\bm{y}_{k})  
\displaystyle\mathrm{s.t.}  \displaystyle\frac{1}{K}\sum_{k=1}^{K}\left\|\bm{w}_{k}\right\|^{2}\leq E_{W},  
\displaystyle\frac{1}{K}\sum_{k=1}^{K}\frac{1}{n}\sum_{i=1}^{n}\left\|\bm{h}_{k,i}\right\|_{q}^{q}\leq E_{H}, 
where \mathcal{L} is the cross-entropy loss. Then, no global minimizer of this program satisfies (8) for any positive numbers C and C^{\prime}. That is, neural collapse does not emerge in this model.
While the paper has demonstrated its noticeable effectiveness, the Layer-Peeled Model requires future investigation for consolidation and extension. First, an important question is to better justify the ansatz (6) used in the development of this model, or equivalently, the second constraint of (7). For example, is the permutation invariance of the weights within the same layer useful for the justification? Moreover, an analysis of the gap between the Layer-Peeled Model and well-trained deep learning models would be a welcome advance. For example, how does the gap depend on the neural network architecture? From a different angle, a possible extension is to retain multiple layers following the top-down viewpoint. Explicitly, letting 1\leq m<L be the number of top layers we wish to retain in the model, we can represent the prediction of the neural network as \bm{f}(\bm{x};\bm{W}_{\textnormal{full}})=\bm{f}(\bm{h}(\bm{x};\bm{W}_{1:(L-m)}),\bm{W}_{(L-m+1):L}), where \bm{W}_{1:(L-m)} and \bm{W}_{(L-m+1):L} denote the first L-m layers and the last m layers, respectively. Consider the m-Layer-Peeled Model:
\displaystyle\min_{\bm{W},\bm{H}}  \displaystyle\frac{1}{N}\sum_{k=1}^{K}\sum_{i=1}^{n_{k}}\mathcal{L}(\bm{f}(\bm{h}_{k,i};\bm{W}_{(L-m+1):L}),\bm{y}_{k})  
\displaystyle\mathrm{s.t.}  \displaystyle\frac{1}{K}\left\|\bm{W}_{(L-m+1):L}\right\|^{2}\leq E_{W},  
\displaystyle\frac{1}{K}\sum_{k=1}^{K}\frac{1}{n_{k}}\sum_{i=1}^{n_{k}}\left\|\bm{h}_{k,i}\right\|^{2}\leq E_{H}. 
The two constraints might be modified to take into account the network architectures. An immediate question is whether this model with m=2 is capable of capturing new patterns of deep learning training.
From a practical standpoint, the Layer-Peeled Model together with its convex relaxation (13) offers an analytically and computationally efficient technique to identify and mitigate the bias induced by class imbalance when training deep learning models. First, an interesting question is to extend Minority Collapse from the case of two-valued class sizes to general imbalanced datasets. Second, as suggested by our findings in Section 5, how should we choose loss functions in order to mitigate Minority Collapse [CWG+19]? Last, a possible use case of the Layer-Peeled Model is to design more efficient sampling schemes that take fairness considerations into account [BG18, ZS18, MMS+19].
Broadly speaking, insights can be gained not only from the Layer-Peeled Model but also from its modeling strategy. The details of empirical deep learning models, though formidable, can often be simplified by rendering part of the network modular. When the interest is in the top few layers, for example, this paper clearly demonstrates the benefits of taking a top-down strategy for modeling neural networks, especially in consolidating our understanding of previous results and in discovering new patterns. Owing to its mathematical convenience, the Layer-Peeled Model shall open the door for future research extending these benefits.
Acknowledgments
We are grateful to X.Y. Han for helpful discussions about some results of [PHD20]. This work was supported in part by NIH through RF1-AG063481, NSF through CAREER DMS-1847415 and CCF-1934876, an Alfred Sloan Research Fellowship, and the Wharton Dean’s Research Fund.
References
 [ALS19] (2019) A convergence theory for deep learning via overparameterization. In International Conference on Machine Learning, pp. 2388–2464. Cited by: §1.2.
 [AKK+19] (2019) A theoretical analysis of contrastive unsupervised representation learning. arXiv preprint arXiv:1902.09229. Cited by: §3.2.
 [BMP08] (2008) Convex sparse matrix factorizations. arXiv preprint arXiv:0812.1869. Cited by: Appendix B, §4.1.
 [BZM+20] (2020) Wav2vec 2.0: a framework for selfsupervised learning of speech representations. arXiv preprint arXiv:2006.11477. Cited by: §3.2.
 [BFT17] (2017) Spectrallynormalized margin bounds for neural networks. Advances in Neural Information Processing Systems 30, pp. 6241–6250. Cited by: §1.2.
 [BLL+20] (2020) Benign overfitting in linear regression. Proceedings of the National Academy of Sciences. Cited by: §1.2.
 [BHM+19] (2019) Reconciling modern machine-learning practice and the classical bias–variance tradeoff. Proceedings of the National Academy of Sciences 116 (32), pp. 15849–15854. Cited by: §1.2.
 [BCN18] (2018) Optimization methods for largescale machine learning. Siam Review 60 (2), pp. 223–311. Cited by: Appendix B.
 [BMM18] (2018) A systematic study of the class imbalance problem in convolutional neural networks. Neural Networks 106, pp. 249–259. Cited by: §5.
 [BG18] (2018) Gender shades: intersectional accuracy disparities in commercial gender classification. In Conference on fairness, accountability and transparency, pp. 77–91. Cited by: §6.
 [CWG+19] (2019) Learning imbalanced datasets with label-distribution-aware margin loss. In Advances in Neural Information Processing Systems, Vol. 32, pp. 1567–1578. External Links: Link Cited by: §5, §6.
 [CKN+20] (2020) A simple framework for contrastive learning of visual representations. arXiv preprint arXiv:2002.05709. Cited by: §1.1, §3.2.
 [COB19] (2019) On lazy training in differentiable programming. In Advances in Neural Information Processing Systems, Cited by: §1.2.
 [CJL+19] (2019) Class-balanced loss based on effective number of samples. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 9268–9277. Cited by: §5.
 [DLL+19] (2019) Gradient descent finds global minima of deep neural networks. In International Conference on Machine Learning, Cited by: §1.2.
 [EMW19] (2019) A comparative analysis of the optimization and generalization property of two-layer neural network and random feature models under gradient descent dynamics. arXiv preprint arXiv:1904.04326. Cited by: §1.2.
 [EW20] (2020) On the emergence of tetrahedral symmetry in the final and penultimate layers of neural network classifiers. arXiv preprint arXiv:2012.05420. Cited by: §1.2, §3.1.
 [FMZ19] (2019) A selective overview of deep learning. arXiv preprint arXiv:1904.05526. Cited by: §1.2.
 [FDZ21] (2021) Mathematical models of overparameterized neural networks. Proceedings of the IEEE (), pp. 1–21. External Links: Document Cited by: §1.2.
 [FLY+20] (2020) Modeling from features: a meanfield framework for overparameterized deep neural networks. arXiv preprint arXiv:2007.01452. Cited by: §1.2.
 [FLL+18] (2018) Spider: near-optimal nonconvex optimization via stochastic path-integrated differential estimator. In Advances in Neural Information Processing Systems, pp. 689–699. Cited by: Appendix B.
 [FLZ19] (2019) Sharp analysis for nonconvex SGD escaping from saddle points. In Annual Conference on Learning Theory, pp. 1192–1234. Cited by: Appendix B.
 [HV19] (2019) Structured low-rank matrix factorization: global optimality, algorithms, and applications. IEEE Transactions on Pattern Analysis and Machine Intelligence 42 (6), pp. 1468–1482. Cited by: Appendix B, §4.1.
 [HT20] (2020) Recent advances in deep learning theory. arXiv preprint arXiv:2012.10931. Cited by: §1.2.
 [HS20] (2020) The local elasticity of neural networks. In International Conference on Learning Representations, Cited by: §1.2.
 [HZR+16] (2016) Deep residual learning for image recognition. In Proceedings of the IEEE conference on computer vision and pattern recognition, pp. 770–778. Cited by: §4.3.
 [HMP+06] (2006) OntoNotes: the 90% solution. In Proceedings of the human language technology conference of the NAACL, Companion Volume: Short Papers, pp. 57–60. Cited by: §4.
 [HLL+16] (2016) Learning deep representation for imbalanced classification. In Proceedings of the IEEE conference on computer vision and pattern recognition, pp. 5375–5384. Cited by: §4.
 [IS15] (2015) Batch normalization: accelerating deep network training by reducing internal covariate shift. In International Conference on Machine Learning, pp. 448–456. Cited by: §4.3.
 [JGH18] (2018) Neural tangent kernel: convergence and generalization in neural networks. In Advances in Neural Information Processing Systems, Cited by: §1.2.
 [JK19] (2019) Survey on deep learning with class imbalance. Journal of Big Data 6 (1), pp. 27. Cited by: Appendix B, §1.1, §4, §5.
 [KTW+20] (2020) Supervised contrastive learning. arXiv preprint arXiv:2004.11362. Cited by: §3.2.
 [KRI09] (2009) Learning multiple layers of features from tiny images. Master’s thesis, University of Toronto. Cited by: Figure 2.
 [KSH17] (2017) Imagenet classification with deep convolutional neural networks. Communications of the ACM 60 (6), pp. 84–90. Cited by: §1.
 [KWL+19] (2019) Explaining landscape connectivity of low-cost solutions for multilayer nets. In Advances in Neural Information Processing Systems, pp. 14601–14610. Cited by: §1.2.
 [LBH15] (2015) Deep learning. Nature 521 (7553), pp. 436–444. Cited by: §1.
 [LSS20] (2020) Benign overfitting and noisy features. arXiv preprint arXiv:2008.02901. Cited by: §1.2.
 [LR20] (2020) Just interpolate: kernel “ridgeless” regression can generalize. Annals of Statistics 48 (3), pp. 1329–1347. Cited by: §1.2.
 [LS20] (2020) Neural collapse with crossentropy loss. arXiv preprint arXiv:2012.08465. Cited by: §1.2, §3.1.
 [MBB18] (2018) The power of interpolation: understanding the effectiveness of sgd in modern overparametrized learning. In International Conference on Machine Learning, pp. 3325–3334. Cited by: §1.2.
 [MR17] (2017) Data imbalance and classifiers: impact and solutions from a big data perspective. International Journal of Computational Intelligence Research 13 (9), pp. 2267–2281. Cited by: §4.
 [MMS+19] (2019) A survey on bias and fairness in machine learning. arXiv preprint arXiv:1908.09635. Cited by: §6.
 [MMN18] (2018) A mean field view of the landscape of two-layer neural networks. Proceedings of the National Academy of Sciences 115 (33), pp. E7665–E7671. External Links: ISSN 00278424 Cited by: §1.2.
 [MPP20] (2020) Neural collapse with unconstrained features. arXiv preprint arXiv:2011.11619. Cited by: §1.2, §3.1.
 [OS20] (2020) Towards moderate overparameterization: global convergence guarantees for training shallow neural networks. IEEE Journal on Selected Areas in Information Theory. Cited by: §1.
 [PHD20] (2020) Prevalence of neural collapse during the terminal phase of deep learning training. Proceedings of the National Academy of Sciences 117 (40), pp. 24652–24663. Cited by: Layer-Peeled Model: Toward Understanding Well-Trained Deep Neural Networks, §1.1, §1.2, §1.2, §1.2, §1, §3, §4.3, §6, §6, footnote 8.
 [PGM+19] (2019) PyTorch: an imperative style, high-performance deep learning library. In Advances in Neural Information Processing Systems, pp. 8026–8037. Cited by: §3.2.
 [PSM14] (2014) Glove: global vectors for word representation. In Proceedings of the 2014 conference on empirical methods in natural language processing (EMNLP), pp. 1532–1543. Cited by: §3.2.
 [PBL20] (2020) Theoretical issues in deep networks. Proceedings of the National Academy of Sciences. Cited by: §1.2.
 [PL20] (2020) Explicit regularization and implicit bias in deep network classifiers trained with the square loss. arXiv preprint arXiv:2101.00072. Cited by: §3.1.
 [RV18] (2018) Neural networks as interacting particle systems: asymptotic convexity of the loss landscape and universal scaling of the approximation error. In Advances in Neural Information Processing Systems, Cited by: §1.2.
 [SHA20] (2020) Gradient methods never overfit on separable data. arXiv preprint arXiv:2007.00028. Cited by: §1.
 [SSJ20] (2020) On learning rates and Schrödinger operators. arXiv preprint arXiv:2004.06977. Cited by: §1.2.
 [SXY+19] (2019) Meta-weight-net: learning an explicit mapping for sample weighting. arXiv preprint arXiv:1902.07379. Cited by: §5.
 [SHM+16] (2016) Mastering the game of go with deep neural networks and tree search. Nature 529 (7587), pp. 484–489. Cited by: §1.
 [SZ14] (2014) Very deep convolutional networks for large-scale image recognition. arXiv preprint arXiv:1409.1556. Cited by: Figure 2.
 [SS19] (2019) Mean field analysis of neural networks: a central limit theorem. Stochastic Processes and their Applications. Cited by: §1.2.
 [SHN+18] (2018) The implicit bias of gradient descent on separable data. The Journal of Machine Learning Research 19 (1), pp. 2822–2878. Cited by: §1, footnote 8.
 [SH03] (2003) Grassmannian frames with applications to coding and communication. Applied and Computational Harmonic Analysis 14 (3), pp. 257–275. Cited by: §1.2.
 [SZ03] (2003) On cones of nonnegative quadratic functions. Mathematics of Operations Research 28 (2), pp. 246–267. Cited by: §4.1.
 [SUN19] (2019) Optimization for deep learning: theory and algorithms. arXiv preprint arXiv:1912.08957. Cited by: §1.2.
 [WLW+16] (2016) Training deep neural networks on imbalanced data sets. In 2016 international joint conference on neural networks (IJCNN), pp. 4368–4374. Cited by: §4.
 [WL90] (1990) The optimised internal representation of multilayer classifier networks performs nonlinear discriminant analysis. Neural Networks 3 (4), pp. 367–375. Cited by: §1.
 [XRV17] (2017) Fashion-MNIST: a novel image dataset for benchmarking machine learning algorithms. arXiv preprint arXiv:1708.07747. Cited by: §4.3.
 [YAR17] (2017) Error bounds for approximations with deep ReLU networks. Neural Networks 94, pp. 103–114. Cited by: §1.1.
 [YCY+20] (2020) Learning diverse and discriminative representations via the principle of maximal coding rate reduction. Advances in Neural Information Processing Systems 33. Cited by: §1.
 [ZKL+16] (2016) Places: an image database for deep scene understanding. arXiv preprint arXiv:1610.02055. Cited by: §4.
 [ZCZ+18] (2018) Stochastic gradient descent optimizes over-parameterized deep ReLU networks. In Advances in Neural Information Processing Systems, Cited by: §1.2.
 [ZS18] (2018) AI can be sexist and racist—it’s time to make it fair. Nature Publishing Group. Cited by: §6.
Appendix A Proofs
For simplicity, in this appendix we define [m_{1}:m_{2}]:=\{m_{1},m_{1}+1,\dots,m_{2}\} for m_{1},m_{2}\in\mathbb{N} with m_{1}\leq m_{2} and [m_{2}]:=[1:m_{2}] for m_{2}\geq 1.
A.1 Balanced Case
A.1.1 Proofs of Theorem 3 and Proposition 2
Because the objective function involves products of the optimization variables, Program (7) is nonconvex, so the KKT conditions are not sufficient for optimality. To prove Theorem 3, we instead determine the global minimum of (7) directly. A key step is to show that Program (7) is equivalent to minimizing the symmetric quadratic function
\sum_{i=1}^{n}\left[\left(\sum_{k=1}^{K}\bm{h}_{k,i}\right)^{\top}\left(\sum_{k=1}^{K}\bm{w}_{k}\right)-K\sum_{k=1}^{K}\bm{h}_{k,i}^{\top}\bm{w}_{k}\right]
under the same constraints with suitable conditions. Finally, by identifying when all the intermediate inequalities hold with equality, we obtain the minimizer of (7). The details are given below.
Proof of Theorem 3.
By the concavity of \log(\cdot), for any \bm{z}\in\mathbb{R}_{+}^{K}, k\in[K], and constants C_{a},C_{b}>0, letting C_{c}=\frac{C_{b}}{(C_{a}+C_{b})(K-1)}, we have
\displaystyle-\log\left(\frac{\bm{z}(k)}{\sum_{{k^{\prime}}=1}^{K}\bm{z}({k^{\prime}})}\right)  
\displaystyle=  \displaystyle-\log(\bm{z}(k))+\log\left(\frac{C_{a}}{C_{a}+C_{b}}\left(\frac{(C_{a}+C_{b})\,\bm{z}(k)}{C_{a}}\right)+C_{c}\sum_{{k^{\prime}}=1,{k^{\prime}}\neq k}^{K}\frac{\bm{z}({k^{\prime}})}{C_{c}}\right)  
\displaystyle\overset{a}{\geq}  \displaystyle-\log(\bm{z}(k))+\frac{C_{a}}{C_{a}+C_{b}}\log\left(\frac{(C_{a}+C_{b})\,\bm{z}(k)}{C_{a}}\right)+C_{c}\sum_{{k^{\prime}}=1,{k^{\prime}}\neq k}^{K}\log\left(\frac{\bm{z}({k^{\prime}})}{C_{c}}\right)  
\displaystyle\overset{b}{=}  \displaystyle-\frac{C_{b}}{C_{a}+C_{b}}\left[\log(\bm{z}(k))-\frac{1}{K-1}\sum_{{k^{\prime}}=1,{k^{\prime}}\neq k}^{K}\log(\bm{z}({k^{\prime}}))\right]+C_{d},  (18) 
where \overset{a}{\geq} applies the concavity of \log(\cdot) and in \overset{b}{=}, we define C_{d}:=\frac{C_{a}}{C_{a}+C_{b}}\log\left(\frac{C_{a}+C_{b}}{C_{a}}\right)+\frac{C_{b}}{C_{a}+C_{b}}\log(1/C_{c}). Note that in (18), C_{a} and C_{b} can be any positive numbers. To prove Theorem 3, we set C_{a}:=\exp\left(\sqrt{E_{H}E_{W}}\right) and C_{b}:=\exp\left(-\sqrt{E_{H}E_{W}}/(K-1)\right), which leads to the tightest lower bound for the objective of (7). Applying (18) to the objective, we have
\displaystyle\frac{1}{N}\sum_{k=1}^{K}\sum_{i=1}^{n}\mathcal{L}(\bm{W}\bm{h}_{k,i},\bm{y}_{k})  (19)  
\displaystyle\geq  \displaystyle\frac{C_{b}}{(C_{a}+C_{b})N(K-1)}\sum_{i=1}^{n}\left[\left(\sum_{k=1}^{K}\bm{h}_{k,i}\right)^{\top}\left(\sum_{k=1}^{K}\bm{w}_{k}\right)-K\sum_{k=1}^{K}\bm{h}_{k,i}^{\top}\bm{w}_{k}\right]+C_{d}. 
Defining \bar{\bm{h}}_{i}:=\frac{1}{K}\sum_{k=1}^{K}\bm{h}_{k,i} for i\in[n], it follows by Young’s inequality that
\displaystyle\sum_{i=1}^{n}\left[\left(\sum_{k=1}^{K}\bm{h}_{k,i}\right)^{\top}\left(\sum_{k=1}^{K}\bm{w}_{k}\right)-K\sum_{k=1}^{K}\bm{h}_{k,i}^{\top}\bm{w}_{k}\right]  
\displaystyle=  \displaystyle K\sum_{i=1}^{n}\sum_{k=1}^{K}(\bar{\bm{h}}_{i}-\bm{h}_{k,i})^{\top}\bm{w}_{k}  
\displaystyle\geq  \displaystyle-\frac{K}{2}\sum_{k=1}^{K}\sum_{i=1}^{n}\|\bar{\bm{h}}_{i}-\bm{h}_{k,i}\|^{2}/C_{e}-\frac{C_{e}N}{2}\sum_{k=1}^{K}\|\bm{w}_{k}\|^{2},  (20) 
where we pick C_{e}:=\sqrt{E_{H}/E_{W}}. The two terms on the right-hand side of (20) can be bounded via the constraints of (7). In particular, we have
\frac{C_{e}N}{2}\sum_{k=1}^{K}\|\bm{w}_{k}\|^{2}\leq\frac{KN\sqrt{E_{H}E_{W}}}{2},  (21) 
and
\displaystyle\frac{K}{2}\sum_{k=1}^{K}\sum_{i=1}^{n}\|\bar{\bm{h}}_{i}-\bm{h}_{k,i}\|^{2}/C_{e}  \displaystyle\overset{a}{=}\frac{K^{2}}{2C_{e}}\sum_{i=1}^{n}\left(\frac{1}{K}\sum_{k=1}^{K}\|\bm{h}_{k,i}\|^{2}-\|\bar{\bm{h}}_{i}\|^{2}\right)  
\displaystyle\leq\frac{K}{2C_{e}}\sum_{k=1}^{K}\sum_{i=1}^{n}\|\bm{h}_{k,i}\|^{2}\leq\frac{KN\sqrt{E_{H}E_{W}}}{2},  (22) 
where \overset{a}{=} uses the fact that \mathbb{E}\|\bm{a}-\mathbb{E}[\bm{a}]\|^{2}=\mathbb{E}\|\bm{a}\|^{2}-\|\mathbb{E}[\bm{a}]\|^{2}. Thus, plugging (20), (21), and (22) into (19), we have
\frac{1}{N}\sum_{k=1}^{K}\sum_{i=1}^{n}\mathcal{L}(\bm{W}\bm{h}_{k,i},\bm{y}_{k})\geq-\frac{C_{b}}{C_{a}+C_{b}}\frac{K\sqrt{E_{H}E_{W}}}{K-1}+C_{d}=:L_{0}.  (23) 
Now we check the conditions under which the equality in (23) holds.
By the strict concavity of \log(\cdot), the equality in (19) holds if and only if
\bm{h}_{k,i}\cdot\bm{w}_{k}=\bm{h}_{k,i}\cdot\bm{w}_{{k^{\prime}}}-\log\left(\frac{C_{b}}{C_{a}}\right), 
for all (k,i,{k^{\prime}})\in\{(k,i,{k^{\prime}}):k\in[K],{k^{\prime}}\in[K],{k^{\prime}}\neq k,i\in[n]\}. The equality in (20) holds if and only if
\bm{h}_{k,i}-\bar{\bm{h}}_{i}=C_{e}\bm{w}_{k},\quad k\in[K],~i\in[n]. 
The equalities in (21) and (22) hold if and only if:
\frac{1}{K}\sum_{k=1}^{K}\frac{1}{n}\sum_{i=1}^{n}\left\|\bm{h}_{k,i}\right\|^{2}=E_{H},\quad\frac{1}{K}\sum_{k=1}^{K}\left\|\bm{w}_{k}\right\|^{2}=E_{W},\quad\bar{\bm{h}}_{i}=\bm{0}_{p},~i\in[n]. 
Applying Lemma 2, stated at the end of this subsection, we conclude that \left(\bm{H},\bm{W}\right) satisfies (8). ∎
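As a numerical sanity check of the lower bound (18), the following sketch evaluates both sides for random positive vectors \bm{z}. The values K=4, C_a=2, C_b=0.5 are arbitrary illustrative choices (the bound holds for any positive constants, as noted above):

```python
import math
import random

# Check (18): -log(z(k)/sum(z)) >= -(C_b/(C_a+C_b)) * [log z(k)
#   - (1/(K-1)) * sum_{k' != k} log z(k')] + C_d, for random positive z.
K = 4
Ca, Cb = 2.0, 0.5  # arbitrary positive constants (illustrative only)
Cc = Cb / ((Ca + Cb) * (K - 1))
Cd = Ca / (Ca + Cb) * math.log((Ca + Cb) / Ca) + Cb / (Ca + Cb) * math.log(1 / Cc)

random.seed(0)
max_gap = -float("inf")  # largest (rhs - lhs) observed; should stay <= 0
for _ in range(1000):
    z = [random.uniform(0.1, 5.0) for _ in range(K)]
    k = 0
    lhs = -math.log(z[k] / sum(z))
    rhs = -(Cb / (Ca + Cb)) * (
        math.log(z[k]) - sum(math.log(z[j]) for j in range(K) if j != k) / (K - 1)
    ) + Cd
    max_gap = max(max_gap, rhs - lhs)
assert max_gap <= 1e-10
```

The same check with other positive choices of C_a and C_b illustrates that (18) is a family of lower bounds, of which the stated choice is the tightest at the minimizer.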
Proof of Proposition 2.
We introduce the set \mathcal{S}_{R} as
\mathcal{S}_{R}:=\left\{\left(\bm{H},\bm{W}\right):\begin{matrix}[\bm{h}_{1},\ldots,\bm{h}_{K}]=B_{1}b\bm{P}\left[(a+1)\bm{I}_{K}-\bm{1}_{K}\bm{1}_{K}^{\top}\right],\\ \bm{W}=B_{2}B_{3}b\left[(a+1)\bm{I}_{K}-\bm{1}_{K}\bm{1}_{K}^{\top}\right]^{\top}\bm{P}^{\top},\\ \bm{h}_{k,i}=\bm{h}_{k},\quad k\in[K],~i\in[n],\\ b\geq 0,~a\geq 0,~b^{q}[a^{q}+(K-1)]=1,\\ B_{1}\leq\sqrt{E_{H}},~B_{2}\leq\sqrt{E_{W}},~B_{3}\geq 0,~B_{3}^{2}b^{2}[a^{2}+(K-1)]=1,\\ \bm{P}\in\mathbb{R}^{p\times K},~\bm{P}^{\top}\bm{P}=\bm{I}_{K}.\end{matrix}\right\} 
One can check that every \left(\bm{H},\bm{W}\right)\in\mathcal{S}_{R} satisfies the constraints of (7) and is therefore a feasible solution. Moreover, such a feasible solution has a special symmetric structure: for each k\in[K], the features in class k collapse to their mean \bm{h}_{k}, i.e., (NC1), and \bm{w}_{k} is parallel to \bm{h}_{k}, i.e., (NC3). However, the weights do not form the vertices of an ETF unless a=K-1. Therefore, it suffices to show that the minimizer of \frac{1}{N}\sum_{k=1}^{K}\sum_{i=1}^{n}\mathcal{L}(\bm{W}\bm{h}_{k,i},\bm{y}_{k}) over the set \mathcal{S}_{R} does not satisfy a=K-1.
In fact, for any \left(\bm{H},\bm{W}\right)\in\mathcal{S}_{R}, the objective function value can be written as a function of B_{1}, B_{2}, B_{3}, a, and b. We have
\displaystyle\frac{1}{N}\sum_{k=1}^{K}\sum_{i=1}^{n}\mathcal{L}(\bm{W}\bm{h}_{k,i},\bm{y}_{k})  
\displaystyle=  \displaystyle-\log\left(\frac{\exp(B_{1}B_{2}B_{3}b^{2}[a^{2}+(K-1)])}{\exp(B_{1}B_{2}B_{3}b^{2}[a^{2}+K-1])+(K-1)\exp(B_{1}B_{2}B_{3}b^{2}[K-2-2a])}\right)  
\displaystyle=  \displaystyle-\log\left(\frac{1}{1+(K-1)\exp(-B_{1}B_{2}B_{3}b^{2}(a+1)^{2})}\right). 
Minimizing the objective is thus equivalent to maximizing B_{1}B_{2}B_{3}b^{2}(a+1)^{2}, or equivalently \left[B_{1}B_{2}B_{3}b^{2}(a+1)^{2}\right]^{2}. By B_{3}^{2}b^{2}[a^{2}+(K-1)]=1 and b^{q}[a^{q}+(K-1)]=1, we have
\displaystyle\left[B_{1}B_{2}B_{3}b^{2}(a+1)^{2}\right]^{2}  \displaystyle\overset{a}{\leq}E_{H}E_{W}\left[B_{3}^{2}b^{2}(a+1)^{2}\right]\left[b^{2}(a+1)^{2}\right]  
\displaystyle=E_{H}E_{W}\left[\frac{(a+1)^{2}}{a^{2}+(K-1)}\right]\left[\frac{(a+1)^{q}}{a^{q}+K-1}\right]^{2/q},  (24) 
where \overset{a}{\leq} picks B_{1}=\sqrt{E_{H}} and B_{2}=\sqrt{E_{W}}. Consider the function g:[0,+\infty)\to\mathbb{R} given by g(x)=\left[\frac{(x+1)^{2}}{x^{2}+(K-1)}\right]\left[\frac{(x+1)^{q}}{x^{q}+K-1}\right]^{2/q}. By first-order optimality, if g^{\prime}(K-1)\neq 0, then (24) cannot achieve its maximum at a=K-1, which is the desired result. Indeed, we have
g^{\prime}(K-1)=\frac{2K^{4}}{\left[(K-1)^{2}+(K-1)\right]\left[(K-1)^{q}+K-1\right]^{2/q+1}}\left[(K-1)-(K-1)^{q-1}\right]. 
So a=K-1 is not the maximizer of (24) unless q=2. We complete the proof. ∎
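The behavior of g can also be checked numerically. The snippet below (with the illustrative choice K=5) verifies that x=K-1 is a maximizer when q=2 but not when q=4, where g^{\prime}(K-1)<0 since (K-1)-(K-1)^{q-1}<0:

```python
# Numerical check of Proposition 2: a = K-1 maximizes g only when q = 2.
K = 5

def g(x, q):
    # g(x) = [(x+1)^2/(x^2+K-1)] * [(x+1)^q/(x^q+K-1)]^(2/q)
    return ((x + 1) ** 2 / (x ** 2 + K - 1)) * ((x + 1) ** q / (x ** q + K - 1)) ** (2.0 / q)

# q = 2: x = K-1 is the maximizer, so both neighbors give no larger value.
assert g(K - 1, 2) >= g(K - 2, 2) and g(K - 1, 2) >= g(K, 2)
# q = 4: g'(K-1) < 0, so g is larger just to the left of K-1.
assert g(K - 2, 4) > g(K - 1, 4)
```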
Lemma 2.
Suppose \left(\bm{H},\bm{W}\right) satisfies
\bm{h}_{k,i}-\bar{\bm{h}}_{i}=\sqrt{\frac{E_{H}}{E_{W}}}\,\bm{w}_{k},\quad k\in[K],\quad i\in[n],  (25) 
and
\frac{1}{K}\sum_{k=1}^{K}\frac{1}{n}\sum_{i=1}^{n}\left\|\bm{h}_{k,i}\right\|^{2}=E_{H},\quad\frac{1}{K}\sum_{k=1}^{K}\left\|\bm{w}_{k}\right\|^{2}=E_{W},\quad\bar{\bm{h}}_{i}=\bm{0}_{p},~i\in[n],  (26) 
where \bar{\bm{h}}_{i}:=\frac{1}{K}\sum_{k=1}^{K}\bm{h}_{k,i} with i\in[n]. Moreover, there exists a constant C such that for all (k,i,{k^{\prime}})\in\{(k,i,{k^{\prime}}):k\in[K],{k^{\prime}}\in[K],{k^{% \prime}}\neq k,i\in[n]\}, we have
\bm{h}_{k,i}\cdot\bm{w}_{k}=\bm{h}_{k,i}\cdot\bm{w}_{{k^{\prime}}}+C.  (27) 
Then \left(\bm{H},\bm{W}\right) satisfies (8).
Proof.
Combining (25) with the last equality in (26), we have
\bm{W}=\sqrt{\frac{E_{W}}{E_{H}}}~{}\bigg{[}\bm{h}_{1},\ldots,\bm{h}_{K}\bigg{% ]}^{\top},\quad\quad\bm{h}_{k,i}=\bm{h}_{k},~{}k\in[K],~{}i\in[n]. 
Thus it remains to show
\displaystyle\bm{W}=\sqrt{E_{W}}~{}\left({\bm{M}^{\star}}\right)^{\top},  (28) 
where {\bm{M}^{\star}} is a Ksimplex ETF.
Plugging \bm{h}_{k}=\bm{h}_{k,i}=\sqrt{\frac{E_{H}}{E_{W}}}\bm{w}_{k} into (27), we have, for all (k,{k^{\prime}})\in\{(k,{k^{\prime}}):k\in[K],{k^{\prime}}\in[K],{k^{\prime}}\neq k\},
\sqrt{\frac{E_{H}}{E_{W}}}\|\bm{w}_{k}\|^{2}=\bm{h}_{k}\cdot\bm{w}_{k}=\bm{h}_{k}\cdot\bm{w}_{{k^{\prime}}}+C, 
and, using \bm{h}_{k}\cdot\bm{w}_{{k^{\prime}}}=\sqrt{\frac{E_{H}}{E_{W}}}\,\bm{w}_{k}\cdot\bm{w}_{{k^{\prime}}}=\bm{h}_{{k^{\prime}}}\cdot\bm{w}_{k},
\sqrt{\frac{E_{H}}{E_{W}}}\|\bm{w}_{{k^{\prime}}}\|^{2}=\bm{h}_{{k^{\prime}}}\cdot\bm{w}_{{k^{\prime}}}=\bm{h}_{{k^{\prime}}}\cdot\bm{w}_{k}+C=\bm{h}_{k}\cdot\bm{w}_{{k^{\prime}}}+C=\sqrt{\frac{E_{H}}{E_{W}}}\|\bm{w}_{k}\|^{2}. 
Hence the norms \|\bm{w}_{k}\| are all equal, and from \frac{1}{K}\sum_{k=1}^{K}\left\|\bm{w}_{k}\right\|^{2}=E_{W} we obtain \|\bm{w}_{k}\|=\sqrt{E_{W}} and \bm{h}_{k}\cdot\bm{w}_{{k^{\prime}}}=C^{\prime}:=\sqrt{E_{H}E_{W}}-C.
On the other hand, recalling that \bar{\bm{h}}_{i}=\bm{0}_{p} for i\in[n], we have \sum_{k=1}^{K}\bm{h}_{k}=\bm{0}_{p}, which further yields \sum_{k=1}^{K}\bm{h}_{k}\cdot\bm{w}_{k^{\prime}}=0 for {k^{\prime}}\in[K]. Then it follows from \bm{h}_{k}\cdot\bm{w}_{{k^{\prime}}}=C^{\prime} and \bm{h}_{k}\cdot\bm{w}_{k}=\sqrt{E_{H}E_{W}} that \bm{h}_{k}\cdot\bm{w}_{{k^{\prime}}}=-\sqrt{E_{H}E_{W}}/(K-1) for {k^{\prime}}\neq k. Thus we obtain
\bm{W}\bm{W}^{\top}=\sqrt{\frac{E_{W}}{E_{H}}}\bm{W}[\bm{h}_{1},\ldots,\bm{h}_{K}]=E_{W}\left[\frac{K}{K-1}\left(\bm{I}_{K}-\frac{1}{K}\bm{1}_{K}\bm{1}_{K}^{\top}\right)\right], 
which implies (28). We complete the proof. ∎
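The structure established in Lemma 2 can be verified numerically. The sketch below (assuming numpy is available; it takes \bm{P}=\bm{I}_{K} and arbitrary illustrative values of E_H and E_W) builds a K-simplex ETF and checks the inner products claimed above:

```python
import numpy as np

# Build a K-simplex ETF with P = I_K, then set W and H as in (8).
K, EH, EW = 5, 2.0, 3.0
M = np.sqrt(K / (K - 1)) * (np.eye(K) - np.ones((K, K)) / K)  # columns m_k
W = np.sqrt(EW) * M.T   # rows w_k
H = np.sqrt(EH) * M     # columns h_k

G = W @ H               # G[k, k'] = w_k . h_{k'}
# Diagonal: h_k . w_k = sqrt(E_H E_W); off-diagonal: -sqrt(E_H E_W)/(K-1).
assert np.allclose(np.diag(G), np.sqrt(EH * EW))
off = G[~np.eye(K, dtype=bool)]
assert np.allclose(off, -np.sqrt(EH * EW) / (K - 1))
# Class means sum to zero and the weight norms match the constraint.
assert np.allclose(H.sum(axis=1), 0)
assert np.allclose(np.linalg.norm(W, axis=1) ** 2, EW)
```

This confirms, in particular, that \bm{W}\bm{W}^{\top} has the Gram structure displayed above.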
A.1.2 Proofs of Theorems 5 and 6
Proof of Theorem 5.
For k\in[K], i\in[n], and {k^{\prime}}\in[K], define
E_{k,i,{k^{\prime}}}:=\frac{1}{n}\sum_{j=1}^{n}\exp(\bm{h}_{k,i}\cdot\bm{h}_{{% k^{\prime}},j}/\tau). 
For constants C_{a}:=\exp\left(\sqrt{E_{H}E_{W}}\right) and C_{b}:=\exp\left(-\sqrt{E_{H}E_{W}}/(K-1)\right), let C_{c}:=\frac{C_{b}}{(C_{a}+C_{b})(K-1)}. Using an argument similar to (18), we have, for j\in[n],
\displaystyle-\log\left(\frac{\exp(\bm{h}_{k,i}\cdot\bm{h}_{k,j}/\tau)}{\sum_{{k^{\prime}}=1}^{K}E_{k,i,{k^{\prime}}}}\right)  (29)  
\displaystyle=  \displaystyle-\bm{h}_{k,i}\cdot\bm{h}_{k,j}/\tau+\log\left(\frac{C_{a}}{C_{a}+C_{b}}\left(\frac{(C_{a}+C_{b})\,E_{k,i,k}}{C_{a}}\right)+C_{c}\sum_{{k^{\prime}}=1,~{k^{\prime}}\neq k}^{K}\frac{E_{k,i,{k^{\prime}}}}{C_{c}}\right)  
\displaystyle\overset{a}{\geq}  \displaystyle-\bm{h}_{k,i}\cdot\bm{h}_{k,j}/\tau+\frac{C_{a}}{C_{a}+C_{b}}\log\left(\frac{(C_{a}+C_{b})\,E_{k,i,k}}{C_{a}}\right)+C_{c}\sum_{{k^{\prime}}=1,~{k^{\prime}}\neq k}^{K}\log\left(\frac{E_{k,i,{k^{\prime}}}}{C_{c}}\right)  
\displaystyle\overset{b}{=}  \displaystyle-\bm{h}_{k,i}\cdot\bm{h}_{k,j}/\tau+\frac{C_{a}}{C_{a}+C_{b}}\log\left(E_{k,i,k}\right)+C_{c}\sum_{{k^{\prime}}=1,~{k^{\prime}}\neq k}^{K}\log\left(E_{k,i,{k^{\prime}}}\right)+C_{d}  
\displaystyle\overset{c}{\geq}  \displaystyle-\bm{h}_{k,i}\cdot\bm{h}_{k,j}/\tau+\frac{C_{a}}{(C_{a}+C_{b})n}\sum_{\ell=1}^{n}\bm{h}_{k,i}\cdot\bm{h}_{k,\ell}/\tau+\frac{C_{c}}{n}\sum_{{k^{\prime}}=1,~{k^{\prime}}\neq k}^{K}\sum_{\ell=1}^{n}\bm{h}_{k,i}\cdot\bm{h}_{{k^{\prime}},\ell}/\tau+C_{d}, 
where \overset{a}{\geq} and \overset{c}{\geq} apply the concavity of \log(\cdot), and in \overset{b}{=} we define C_{d}:=\frac{C_{a}}{C_{a}+C_{b}}\log\left(\frac{C_{a}+C_{b}}{C_{a}}\right)+\frac{C_{b}}{C_{a}+C_{b}}\log(1/C_{c}). Then, plugging (29) into the objective function, we have
\displaystyle-\frac{1}{N}\sum_{k=1}^{K}\sum_{i=1}^{n}\frac{1}{n}\sum_{j=1}^{n}\log\left(\frac{\exp(\bm{h}_{k,i}\cdot\bm{h}_{k,j}/\tau)}{\sum_{{k^{\prime}}=1}^{K}\sum_{\ell=1}^{n}\exp(\bm{h}_{k,i}\cdot\bm{h}_{{k^{\prime}},\ell}/\tau)}\right)  (30)  
\displaystyle=  \displaystyle-\frac{1}{N}\sum_{k=1}^{K}\sum_{i=1}^{n}\frac{1}{n}\sum_{j=1}^{n}\log\left(\frac{\exp(\bm{h}_{k,i}\cdot\bm{h}_{k,j}/\tau)}{\sum_{{k^{\prime}}=1}^{K}E_{k,i,{k^{\prime}}}}\right)+\log(n)  
\displaystyle\overset{(29)}{\geq}  \displaystyle-\frac{C_{b}K}{(C_{a}+C_{b})N(K-1)\tau}\sum_{k=1}^{K}\sum_{i=1}^{n}\left(\frac{1}{n}\sum_{j=1}^{n}\left(\bm{h}_{k,i}\cdot\bm{h}_{k,j}-\frac{1}{K}\sum_{{k^{\prime}}=1}^{K}\bm{h}_{k,i}\cdot\bm{h}_{{k^{\prime}},j}\right)\right)+C_{d}+\log(n). 
Now, defining \bar{\bm{h}}_{i}:=\frac{1}{K}\sum_{k=1}^{K}\bm{h}_{k,i} for i\in[n], arguments similar to (20) and (22) give
\displaystyle-\sum_{k=1}^{K}\sum_{i=1}^{n}\left(\frac{1}{n}\sum_{j=1}^{n}\left(\bm{h}_{k,i}\cdot\bm{h}_{k,j}-\frac{1}{K}\sum_{{k^{\prime}}=1}^{K}\bm{h}_{k,i}\cdot\bm{h}_{{k^{\prime}},j}\right)\right)  
\displaystyle=  \displaystyle-\sum_{k=1}^{K}\sum_{i=1}^{n}\left(\frac{1}{n}\sum_{j=1}^{n}\bm{h}_{k,i}\cdot(\bm{h}_{k,j}-\bar{\bm{h}}_{j})\right)  
\displaystyle\overset{a}{\geq}  \displaystyle-\frac{1}{2}\sum_{k=1}^{K}\sum_{i=1}^{n}\left\|\bm{h}_{k,i}\right\|^{2}-\frac{1}{2}\sum_{k=1}^{K}\sum_{i=1}^{n}\left\|\bm{h}_{k,i}-\bar{\bm{h}}_{i}\right\|^{2}  
\displaystyle\overset{b}{\geq}  \displaystyle-\frac{1}{2}\sum_{k=1}^{K}\sum_{i=1}^{n}\left\|\bm{h}_{k,i}\right\|^{2}-\frac{K}{2}\sum_{i=1}^{n}\left(\frac{1}{K}\sum_{k=1}^{K}\left\|\bm{h}_{k,i}\right\|^{2}-\left\|\bar{\bm{h}}_{i}\right\|^{2}\right)  
\displaystyle\geq  \displaystyle-\sum_{k=1}^{K}\sum_{i=1}^{n}\left\|\bm{h}_{k,i}\right\|^{2}\overset{c}{\geq}-NE_{H},  (31) 
where \overset{a}{\geq} follows from Young’s inequality, \overset{b}{\geq} follows from \mathbb{E}\|\bm{a}-\mathbb{E}[\bm{a}]\|^{2}=\mathbb{E}\|\bm{a}\|^{2}-\|\mathbb{E}[\bm{a}]\|^{2}, and \overset{c}{\geq} uses the constraint of (10). Therefore, plugging (31) into (30) yields
\displaystyle-\frac{1}{N}\sum_{k=1}^{K}\sum_{i=1}^{n}\frac{1}{n}\sum_{j=1}^{n}\log\left(\frac{\exp(\bm{h}_{k,i}\cdot\bm{h}_{k,j}/\tau)}{\sum_{{k^{\prime}}=1}^{K}\sum_{\ell=1}^{n}\exp(\bm{h}_{k,i}\cdot\bm{h}_{{k^{\prime}},\ell}/\tau)}\right)  
\displaystyle\geq  \displaystyle-\frac{C_{b}KE_{H}}{(C_{a}+C_{b})(K-1)\tau}+C_{d}+\log(n).  (32) 
Now we check the conditions under which the equality in (32) holds. By the strict concavity of \log(\cdot), the equality in (29) holds only if, for all (k,i,{k^{\prime}})\in\{(k,i,{k^{\prime}}):k\in[K],{k^{\prime}}\in[K],{k^{\prime}}\neq k,i\in[n]\},
\frac{E_{k,i,k}}{C_{a}}=\frac{E_{k,i,{k^{\prime}}}}{C_{b}}.  (33) 
The equality in (31) holds if and only if:
\bm{h}_{k,i}=\bm{h}_{k},~i\in[n],~k\in[K],\quad\frac{1}{K}\sum_{k=1}^{K}\left\|\bm{h}_{k}\right\|^{2}=E_{H},\quad\sum_{k=1}^{K}\bm{h}_{k}=\bm{0}_{p}.  (34) 
Plugging \bm{h}_{k,i}=\bm{h}_{k} into (33), we have, for all (k,{k^{\prime}})\in\{(k,{k^{\prime}}):k\in[K],{k^{\prime}}\in[K],{k^{\prime}}\neq k\},
\frac{\exp(\|\bm{h}_{k}\|^{2}/\tau)}{C_{a}}=\frac{\exp(\bm{h}_{k}\cdot\bm{h}_{{k^{\prime}}}/\tau)}{C_{b}}=\frac{\exp(\|\bm{h}_{k^{\prime}}\|^{2}/\tau)}{C_{a}}. 
Then it follows from \frac{1}{K}\sum_{k=1}^{K}\left\|\bm{h}_{k}\right\|^{2}=E_{H} that \|\bm{h}_{k}\|^{2}=E_{H} for k\in[K]. On the other hand, since \sum_{k=1}^{K}\bm{h}_{k}=\bm{0}_{p}, we obtain
\bm{h}_{k}\cdot\bm{h}_{{k^{\prime}}}=-\frac{E_{H}}{K-1} 
for all (k,{k^{\prime}})\in\{(k,{k^{\prime}}):k\in[K],{k^{\prime}}\in[K],{k^{\prime}}\neq k\}. Therefore,
[\bm{h}_{1},\ldots,\bm{h}_{K}]^{\top}[\bm{h}_{1},\ldots,\bm{h}_{K}]=E_{H}\left[\frac{K}{K-1}\left(\bm{I}_{K}-\frac{1}{K}\bm{1}_{K}\bm{1}_{K}^{\top}\right)\right], 
which implies (12). ∎
Proof of Theorem 6.
We first determine the minimum value of (7). For brevity, we introduce \bm{z}_{k,i}:=\bm{W}\bm{h}_{k,i} for k\in[K] and i\in[n]. By the convexity of g_{2}, for any k\in[K] and i\in[n], we have
\displaystyle\sum_{{k^{\prime}}=1,~{k^{\prime}}\neq k}^{K}g_{2}\left(\bm{S}(\bm{z}_{k,i})({k^{\prime}})\right)  \displaystyle\geq(K-1)g_{2}\left(\frac{1}{K-1}\sum_{{k^{\prime}}=1,~{k^{\prime}}\neq k}^{K}\bm{S}(\bm{z}_{k,i})({k^{\prime}})\right)  
\displaystyle\overset{a}{=}(K-1)g_{2}\left(\frac{1}{K-1}\left(1-\bm{S}(\bm{z}_{k,i})(k)\right)\right),  (35) 
where \overset{a}{=} uses \sum_{k=1}^{K}\bm{S}(\bm{a})(k)=1 for any \bm{a}\in\mathbb{R}^{K}. Then it follows by the convexity of g_{1} and g_{2} that
\displaystyle\frac{1}{N}\sum_{k=1}^{K}\sum_{i=1}^{n}\mathcal{L}(\bm{W}\bm{h}_{k,i},\bm{y}_{k})  (36)  
\displaystyle=  \displaystyle\frac{1}{N}\sum_{i=1}^{n}\sum_{k=1}^{K}\left[g_{1}\left(\bm{S}(\bm{z}_{k,i})(k)\right)+\sum_{{k^{\prime}}=1,~{k^{\prime}}\neq k}^{K}g_{2}\left(\bm{S}(\bm{z}_{k,i})({k^{\prime}})\right)\right]  
\displaystyle\overset{(35)}{\geq}  \displaystyle\frac{1}{N}\sum_{i=1}^{n}\sum_{k=1}^{K}\left[g_{1}\left(\bm{S}(\bm{z}_{k,i})(k)\right)+(K-1)g_{2}\left(\frac{1}{K-1}\left(1-\bm{S}(\bm{z}_{k,i})(k)\right)\right)\right]  
\displaystyle\geq  \displaystyle g_{1}\left(\frac{1}{N}\sum_{i=1}^{n}\sum_{k=1}^{K}\bm{S}(\bm{z}_{k,i})(k)\right)+(K-1)g_{2}\left(\frac{1}{K-1}\left(1-\frac{1}{N}\sum_{i=1}^{n}\sum_{k=1}^{K}\bm{S}(\bm{z}_{k,i})(k)\right)\right). 
Because g_{1}(x)+(K-1)g_{2}\left(\frac{1}{K-1}(1-x)\right) is monotonically decreasing, it suffices to maximize
\frac{1}{N}\sum_{i=1}^{n}\sum_{k=1}^{K}\bm{S}(\bm{z}_{k,i})(k). 
To begin with, for any \bm{z}_{k,i} with k\in[K] and i\in[n], by the convexity of the exponential function and the monotonicity of q(x)=\frac{a}{a+x} in x>0 for a>0, we have
\displaystyle\bm{S}(\bm{z}_{k,i})(k)  \displaystyle=\frac{\exp(\bm{z}_{k,i}(k))}{\sum_{{k^{\prime}}=1}^{K}\exp(\bm{z}_{k,i}({k^{\prime}}))}  
\displaystyle\leq\frac{\exp(\bm{z}_{k,i}(k))}{\exp(\bm{z}_{k,i}(k))+(K-1)\exp\left(\frac{1}{K-1}\sum_{{k^{\prime}}=1,~{k^{\prime}}\neq k}^{K}\bm{z}_{k,i}({k^{\prime}})\right)}  
\displaystyle=\frac{1}{1+(K-1)\exp\left(\frac{1}{K-1}\sum_{{k^{\prime}}=1,~{k^{\prime}}\neq k}^{K}\bm{z}_{k,i}({k^{\prime}})-\bm{z}_{k,i}(k)\right)}.  (37) 
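The bound (37) is a direct consequence of Jensen's inequality applied to the denominator; the snippet below checks it on random logit vectors (K=6 and the Gaussian sampling are illustrative choices only):

```python
import math
import random

# Verify (37): softmax(z)(k) <= 1 / (1 + (K-1) * exp(mean_{k' != k} z(k') - z(k))).
random.seed(1)
K = 6
max_excess = -float("inf")  # largest (s - bound) observed; should stay <= 0
for _ in range(500):
    z = [random.gauss(0, 2) for _ in range(K)]
    k = random.randrange(K)
    s = math.exp(z[k]) / sum(math.exp(v) for v in z)
    m = (sum(z) - z[k]) / (K - 1)  # mean of the other K-1 logits
    bound = 1 / (1 + (K - 1) * math.exp(m - z[k]))
    max_excess = max(max_excess, s - bound)
assert max_excess <= 1e-12
```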
Consider the function g_{0}:\mathbb{R}\to\mathbb{R} defined by g_{0}(x)=\frac{1}{1+C\exp(x)} with C:=K-1\geq 1. We have
g_{0}^{\prime\prime}(x)=\frac{C\exp(x)\left(1+C\exp(x)\right)\left(C\exp(x)-1\right)}{(1+C\exp(x))^{4}}.  (38) 
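By (38), g_{0} changes curvature at the point where C\exp(x)=1, i.e., at x=\log(1/(K-1)); it is concave to the left of this point and convex to the right, which is the property used below. A quick finite-difference check (with the illustrative choice K=4):

```python
import math

# g0(x) = 1/(1 + C e^x), C = K-1, is concave for x < log(1/C) and convex after.
K = 4
C = K - 1

def g0(x):
    return 1.0 / (1.0 + C * math.exp(x))

def second_diff(x, h=1e-4):
    # central second difference, approximates g0''(x)
    return (g0(x + h) - 2 * g0(x) + g0(x - h)) / h ** 2

x0 = math.log(1.0 / C)  # inflection point: C * exp(x0) = 1
assert second_diff(x0 - 1.0) < 0  # concave region
assert second_diff(x0 + 1.0) > 0  # convex region
```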
For any feasible solution \left(\bm{H},\bm{W}\right) of (7), we divide the index set [n] into two subsets \mathcal{S}_{1} and \mathcal{S}_{2} defined below:

(A)
i\in\mathcal{S}_{1} if there exists at least one k\in[K] such that
\frac{1}{K-1}\sum_{{k^{\prime}}=1,~{k^{\prime}}\neq k}^{K}\bm{z}_{k,i}({k^{\prime}})-\bm{z}_{k,i}(k)\geq\log\left(\frac{1}{K-1}\right). 
(B)
i\in\mathcal{S}_{2} if for all k\in[K], \frac{1}{K-1}\sum_{{k^{\prime}}=1,~{k^{\prime}}\neq k}^{K}\bm{z}_{k,i}({k^{\prime}})-\bm{z}_{k,i}(k)<\log\left(\frac{1}{K-1}\right).
Clearly, \mathcal{S}_{1}\cap\mathcal{S}_{2}=\varnothing. Let |\mathcal{S}_{1}|=t; then |\mathcal{S}_{2}|=n-t. Define the function L:[0:n]\to\mathbb{R} as
L(t):=\begin{cases}N\left(\frac{1}{2}t+\frac{K(n-t)}{1+\exp\left(-\left(\frac{K}{K-1}\sqrt{n/(n-t)}\sqrt{E_{H}E_{W}}-\log(K-1)\right)\right)}\right),&\quad t\in[0:n-1],\\ N\frac{n}{2},&\quad t=n.\end{cases}  (39) 
We show in Lemma 3 (see the end of the proof) that
\frac{1}{N}\sum_{i=1}^{n}\sum_{k=1}^{K}\bm{S}(\bm{z}_{k,i})(k)\leq\frac{1}{N}L(0).  (40) 
Plugging (40) into (36), the objective function can be lower bounded as:
\frac{1}{N}\sum_{k=1}^{K}\sum_{i=1}^{n}\mathcal{L}(\bm{W}\bm{h}_{k,i},\bm{y}_{k})\geq g_{1}\left(\frac{1}{N}L(0)\right)+(K-1)g_{2}\left(\frac{1}{K-1}\left(1-\frac{1}{N}L(0)\right)\right)=:L_{0}.  (41) 
On the other hand, one can directly verify that equality in (41) is attained when (\bm{H},\bm{W}) satisfies (8). So L_{0} is the global minimum of (7), and any (\bm{H},\bm{W}) of the form (8) is a minimizer of (7).
Now we show that all solutions are of the form (8) under the assumption that g_{2} is strictly convex and g_{1} (or g_{2}) is strictly monotone.
By the strict convexity of g_{2}, the equality in (35) holds if and only if, for any k\in[K], i\in[n], and {k^{\prime}},{k^{\prime\prime}}\in[K] with {k^{\prime}}\neq k and {k^{\prime\prime}}\neq k, we have
\bm{S}(\bm{z}_{k,i})({k^{\prime}})=\bm{S}(\bm{z}_{k,i})({k^{\prime\prime}}), 
which indicates that
\bm{h}_{k,i}\cdot\bm{w}_{{k^{\prime}}}=\bm{h}_{k,i}\cdot\bm{w}_{{k^{\prime\prime}}}.  (42) 
Again, by the strict convexity of g_{2}, the equality in (36) holds only if, for all k\in[K], i\in[n], and a suitable number C^{\prime}\in(0,1), we have
\bm{S}(\bm{z}_{k,i})(k)=C^{\prime}.  (43) 
Combining (42) with (43), we have for all (k,i,{k^{\prime}})\in\{(k,i,{k^{\prime}}):k\in[K],{k^{\prime}}\in[K],{k^{% \prime}}\neq k,i\in[n]\},
\frac{\exp(\bm{h}_{k,i}\cdot\bm{w}_{k})}{\exp(\bm{h}_{k,i}\cdot\bm{w}_{{k^{\prime}}})}=\frac{C^{\prime}(K-1)}{1-C^{\prime}}, 
which implies that
\bm{h}_{k,i}\cdot\bm{w}_{k}=\bm{h}_{k,i}\cdot\bm{w}_{{k^{\prime}}}+\log\left(\frac{C^{\prime}(K-1)}{1-C^{\prime}}\right). 
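The logit-gap relation just derived can be checked directly: if the true-class softmax output is C^{\prime} and all other entries are equal, the gap between the true-class logit and any other logit is \log\bigl(C^{\prime}(K-1)/(1-C^{\prime})\bigr). A small numeric check (K=5 and the gap value are illustrative):

```python
import math

# If z has one logit exceeding the rest (all equal) by `gap`, then
# log(C' (K-1) / (1 - C')) recovers `gap`, where C' is the top softmax entry.
K = 5
gap = 1.7
z = [gap] + [0.0] * (K - 1)
Cp = math.exp(z[0]) / sum(math.exp(v) for v in z)  # C' in the text
recovered = math.log(Cp * (K - 1) / (1 - Cp))
assert abs(recovered - gap) < 1e-9
```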
On the other hand, by the strict monotonicity of g_{1}(x)+(K-1)g_{2}\left(\frac{1}{K-1}(1-x)\right), the equality in (41) holds if and only if \frac{1}{N}\sum_{i=1}^{n}\sum_{k=1}^{K}\bm{S}(\bm{z}_{k,i})(k)=\frac{1}{N}L(0). Thus Lemma 3 gives
\bm{h}_{k,i}-\bar{\bm{h}}_{i}=\sqrt{\frac{E_{H}}{E_{W}}}\,\bm{w}_{k},\quad k\in[K],\quad i\in[n], 
and
\frac{1}{K}\sum_{k=1}^{K}\frac{1}{n}\sum_{i=1}^{n}\left\|\bm{h}_{k,i}\right\|^{2}=E_{H},\quad\frac{1}{K}\sum_{k=1}^{K}\left\|\bm{w}_{k}\right\|^{2}=E_{W},\quad\bar{\bm{h}}_{i}=\bm{0}_{p},~i\in[n], 
where \bar{\bm{h}}_{i}:=\frac{1}{K}\sum_{k=1}^{K}\bm{h}_{k,i} for i\in[n]. In summary, by Lemma 2, \left(\bm{H},\bm{W}\right) satisfies (8), which establishes uniqueness. This completes the proof of Theorem 6. ∎
Lemma 3.
For any feasible solution \left(\bm{H},\bm{W}\right), we have
\sum_{i=1}^{n}\sum_{k=1}^{K}\bm{S}(\bm{W}\bm{h}_{k,i})(k)\leq L(0),  (44) 
with L defined in (39). Moreover, recalling the definitions of \mathcal{S}_{1} and \mathcal{S}_{2} in (A) and (B), the equality in (44) holds if and only if |\mathcal{S}_{1}|=0,
\bm{h}_{k,i}-\bar{\bm{h}}_{i}=\sqrt{\frac{E_{H}}{E_{W}}}\,\bm{w}_{k},\quad k\in[K],\quad i\in[n], 
and
\frac{1}{K}\sum_{k=1}^{K}\frac{1}{n}\sum_{i=1}^{n}\left\|\bm{h}_{k,i}\right\|^{2}=E_{H},\quad\frac{1}{K}\sum_{k=1}^{K}\left\|\bm{w}_{k}\right\|^{2}=E_{W},\quad\bar{\bm{h}}_{i}=\bm{0}_{p},~i\in[n], 
where \bar{\bm{h}}_{i}:=\frac{1}{K}\sum_{k=1}^{K}\bm{h}_{k,i} with i\in[n].
Proof of Lemma 3.
For any feasible solution \left(\bm{H},\bm{W}\right), we consider \mathcal{S}_{1} and \mathcal{S}_{2}, defined in (A) and (B), separately. Let t:=|\mathcal{S}_{1}|.

•
For i\in\mathcal{S}_{1}, let k\in[K] be any index such that \frac{1}{K-1}\sum_{{k^{\prime}}\neq k}\bm{z}_{k,i}({k^{\prime}})-\bm{z}_{k,i}(k)\geq\log\left(\frac{1}{K-1}\right), where \bm{z}_{k,i}:=\bm{W}\bm{h}_{k,i}. By the monotonicity of g_{0}(x), it follows from (37) that \bm{S}(\bm{z}_{k,i})(k)\leq 1/2. Furthermore, for any other index {k^{\prime}}\in[K] with {k^{\prime}}\neq k, using \frac{\exp(\bm{z}_{{k^{\prime}},i}({k^{\prime}}))}{\sum_{{k^{\prime\prime}}=1}^{K}\exp(\bm{z}_{{k^{\prime}},i}({k^{\prime\prime}}))}\leq 1, we have
\sum_{i\in\mathcal{S}_{1}}\sum_{k=1}^{K}\bm{S}(\bm{z}_{k,i})(k)\leq t(1/2+K-1).  (45) 
•
For i\in\mathcal{S}_{2}, by the concavity of g_{0}(x) for x<\log\left(\frac{1}{K-1}\right) from (38), we have, for \mathcal{S}_{2}\neq\varnothing,
\displaystyle\sum_{i\in\mathcal{S}_{2}}\sum_{k=1}^{K}\bm{S}(\bm{z}_{k,i})(k)  (46)  
\displaystyle\overset{(37)}{\leq}  \displaystyle\sum_{i\in\mathcal{S}_{2}}\sum_{k=1}^{K}\frac{1}{1+(K-1)\exp\left(\frac{1}{K-1}\sum_{{k^{\prime}}=1,~{k^{\prime}}\neq k}^{K}\bm{z}_{k,i}({k^{\prime}})-\bm{z}_{k,i}(k)\right)}  
\displaystyle\leq  \displaystyle\frac{(n-t)K}{1+(K-1)\exp\left(\frac{1}{(n-t)K}\sum_{i\in\mathcal{S}_{2}}\sum_{k=1}^{K}\left(\frac{1}{K-1}\sum_{{k^{\prime}}=1,~{k^{\prime}}\neq k}^{K}\bm{z}_{k,i}({k^{\prime}})-\bm{z}_{k,i}(k)\right)\right)}. 
We can bound \sum_{i\in\mathcal{S}_{2}}\sum_{k=1}^{K}\left(\frac{1}{K-1}\sum_{{k^{\prime}}=1,~{k^{\prime}}\neq k}^{K}\bm{z}_{k,i}({k^{\prime}})-\bm{z}_{k,i}(k)\right) using arguments similar to (20) and (22). In particular, recalling \bar{\bm{h}}_{i}=\frac{1}{K}\sum_{k=1}^{K}\bm{h}_{k,i} for i\in[n], we have
\displaystyle\sum_{i\in\mathcal{S}_{2}}\sum_{k=1}^{K}\left(\frac{1}{K-1}\sum_{{k^{\prime}}=1,~{k^{\prime}}\neq k}^{K}\bm{z}_{k,i}({k^{\prime}})-\bm{z}_{k,i}(k)\right)  (47)  
\displaystyle=  \displaystyle\frac{1}{K-1}\sum_{i\in\mathcal{S}_{2}}\left[\left(\sum_{k=1}^{K}\bm{h}_{k,i}\right)^{\top}\left(\sum_{k=1}^{K}\bm{w}_{k}\right)-K\sum_{k=1}^{K}\bm{h}_{k,i}^{\top}\bm{w}_{k}\right]  
\displaystyle\overset{(20)}{\geq}  \displaystyle-\frac{K}{2(K-1)}\sum_{k=1}^{K}\sum_{i\in\mathcal{S}_{2}}\|\bar{\bm{h}}_{i}-\bm{h}_{k,i}\|^{2}/C^{\prime\prime}-\frac{C^{\prime\prime}K(n-t)}{2(K-1)}\sum_{k=1}^{K}\|\bm{w}_{k}\|^{2}  
\displaystyle\overset{(22)}{\geq}  \displaystyle-\frac{K}{2(K-1)}\sum_{k=1}^{K}\sum_{i\in\mathcal{S}_{2}}\|\bm{h}_{k,i}\|^{2}/C^{\prime\prime}-\frac{C^{\prime\prime}K(n-t)}{2(K-1)}\sum_{k=1}^{K}\|\bm{w}_{k}\|^{2}  
\displaystyle\geq  \displaystyle-\frac{K}{2(K-1)}\sum_{k=1}^{K}\sum_{i=1}^{n}\|\bm{h}_{k,i}\|^{2}/C^{\prime\prime}-\frac{C^{\prime\prime}K(n-t)}{2(K-1)}\sum_{k=1}^{K}\|\bm{w}_{k}\|^{2}  
\displaystyle\geq  \displaystyle-\frac{K^{2}}{K-1}\sqrt{E_{H}E_{W}(n-t)n}, 
where in the last inequality we use the constraints of (7) and set C^{\prime\prime}:=\sqrt{\frac{nE_{H}}{(n-t)E_{W}}}.
We now combine the two cases. When t\in[0:n-1], plugging (47) into (46), using the monotonicity of g_{0}(x), and adding (45), we have
\displaystyle\sum_{k=1}^{K}\sum_{i=1}^{n}\bm{S}(\bm{z}_{k,i})(k)  \displaystyle\leq N\left(\frac{1}{2}t+\frac{K}{1+\exp\left(-\left(\frac{K}{K-1}\sqrt{n/(n-t)}\sqrt{E_{H}E_{W}}-\log(K-1)\right)\right)}(n-t)\right)  
\displaystyle=L(t).  (48) 
When t=n, it follows directly from (45) that
\sum_{k=1}^{K}\sum_{i=1}^{n}\bm{S}(\bm{z}_{k,i})(k)\leq N\frac{n}{2}=L(n). 
Therefore, it suffices to show L(t)\leq L(0) for all t\in[0:n]. We first consider the case t\in[0:n-1] and show that L(t) is monotonically decreasing there. Indeed, define
q(t):=\frac{K}{1+\exp\left(-\left(\frac{K}{K-1}\sqrt{n/(n-t)}\sqrt{E_{H}E_{W}}-\log(K-1)\right)\right)}. 
We have
\displaystyle q^{\prime}(t)  \displaystyle=\frac{\frac{1}{2}K\exp\left(\frac{K}{K-1}\sqrt{n/(n-t)}\sqrt{E_{H}E_{W}}-\log(K-1)\right)\frac{K}{K-1}\sqrt{E_{H}E_{W}n}\,(n-t)^{-3/2}}{\left[1+\exp\left(\frac{K}{K-1}\sqrt{n/(n-t)}\sqrt{E_{H}E_{W}}-\log(K-1)\right)\right]^{2}}  
\displaystyle\leq\frac{\frac{1}{2}\frac{K^{2}}{K-1}\sqrt{E_{H}E_{W}n}\,(n-t)^{-3/2}}{1+\exp\left(\frac{K}{K-1}\sqrt{n/(n-t)}\sqrt{E_{H}E_{W}}-\log(K-1)\right)}, 
which implies that
\displaystyle L^{\prime}(t)=N\left[\frac{1}{2}-q(t)+q^{\prime}(t)(n-t)\right]  
\displaystyle\leq  \displaystyle N\left[\frac{\frac{1}{2}\frac{K^{2}}{K-1}\sqrt{E_{H}E_{W}n}\,(n-t)^{-1/2}+K}{1+\exp\left(\frac{K}{K-1}\sqrt{n/(n-t)}\sqrt{E_{H}E_{W}}-\log(K-1)\right)}-\frac{1}{2}\right]  
\displaystyle=  \displaystyle N\,\frac{K\left(\frac{K}{K-1}\sqrt{n/(n-t)}\sqrt{E_{H}E_{W}}\right)+2K-1-\exp\left(\frac{K}{K-1}\sqrt{n/(n-t)}\sqrt{E_{H}E_{W}}-\log(K-1)\right)}{2\left[1+\exp\left(\frac{K}{K-1}\sqrt{n/(n-t)}\sqrt{E_{H}E_{W}}-\log(K-1)\right)\right]}. 
Consider the function f:\left[\frac{K}{K-1}\sqrt{E_{H}E_{W}},\frac{K}{K-1}\sqrt{E_{H}E_{W}n}\right]\to\mathbb{R} given by
f(x)=Kx+2K-1-\exp\left(x-\log(K-1)\right). 
We have
f^{\prime}(x)=K-\exp(x)/(K-1)<0 
for x\in\left[\frac{K}{K-1}\sqrt{E_{H}E_{W}},\frac{K}{K-1}\sqrt{E_{H}E_{W}n}\right], where we use the assumption that
\sqrt{E_{H}E_{W}}>\frac{K-1}{K}\log\left(K^{2}\sqrt{E_{H}E_{W}}+(2K-1)(K-1)\right)\geq\frac{K-1}{K}\log\left(K(K-1)\right). 
Therefore, for all x\in\left[\frac{K}{K1}\sqrt{E_{H}E_{W}},\frac{K}{K1}\sqrt{E_{H}E_{W}n}\right], we have
f(x)\leq f\left(\frac{K}{K-1}\sqrt{E_{H}E_{W}}\right)=\frac{K^{2}}{K-1}\sqrt{E_{H}E_{W}}+2K-1-\frac{1}{K-1}\exp\left(\frac{K}{K-1}\sqrt{E_{H}E_{W}}\right)\overset{a}{<}0, 
where \overset{a}{<} uses our assumption again. We obtain L^{\prime}(t)<0 for all t\in[0:n-1], so over [0:n-1] the function L(t) attains its maximum only at t=0. Moreover, under our assumption, one can verify that L(n)<L(0). We obtain (44) from (48) with t=0.
When t=0, the equality in the first inequality of (47) holds if and only if:
\bm{h}_{k,i}-\bar{\bm{h}}_{i}=\sqrt{\frac{E_{H}}{E_{W}}}\,\bm{w}_{k},\quad k\in[K],\quad i\in[n]. 
The equality in the second and third inequalities of (47) holds if and only if:
\frac{1}{K}\sum_{k=1}^{K}\frac{1}{n}\sum_{i=1}^{n}\left\|\bm{h}_{k,i}\right\|^{2}=E_{H},\quad\frac{1}{K}\sum_{k=1}^{K}\left\|\bm{w}_{k}\right\|^{2}=E_{W},\quad\bar{\bm{h}}_{i}=\bm{0}_{p},~i\in[n]. 
We obtain Lemma 3. ∎
A.2 Imbalanced Case
A.2.1 Proofs of Lemma 1 and Proposition 1
Proof of Lemma 1.
For any feasible solution \left(\bm{H},\bm{W}\right) for the original program (7), we define
\bm{h}_{k}:=\frac{1}{n_{k}}\sum_{i=1}^{n_{k}}\bm{h}_{k,i},~{}k\in[K],\quad% \text{and}\quad\bm{X}:=\left[\bm{h}_{1},\bm{h}_{2},\dots,\bm{h}_{K},\bm{W}^{% \top}\right]^{\top}\left[\bm{h}_{1},\bm{h}_{2},\dots,\bm{h}_{K},\bm{W}^{\top}% \right]. 
Clearly, \bm{X}\succeq 0. For the other two constraints of (13), we have
\frac{1}{K}\sum_{k=1}^{K}\bm{X}(k,k)=\frac{1}{K}\sum_{k=1}^{K}\|\bm{h}_{k}\|^{2}\overset{a}{\leq}\frac{1}{K}\sum_{k=1}^{K}\frac{1}{n_{k}}\sum_{i=1}^{n_{k}}\left\|\bm{h}_{k,i}\right\|^{2}\overset{b}{\leq}E_{H}, 
and
\frac{1}{K}\sum_{k=K+1}^{2K}\bm{X}(k,k)=\frac{1}{K}\sum_{k=1}^{K}\|\bm{w}_{k}\|^{2}\overset{c}{\leq}E_{W}, 
where \overset{a}{\leq} applies Jensen’s inequality and \overset{b}{\leq} and \overset{c}{\leq} use that \left(\bm{H},\bm{W}\right) is a feasible solution. So \bm{X} is a feasible solution for the convex program (13). Letting L_{0} be the global minimum of (13), for any feasible solution \left(\bm{H},\bm{W}\right), we obtain
\displaystyle\frac{1}{N}\sum_{k=1}^{K}\sum_{i=1}^{n_{k}}\mathcal{L}(\bm{W}\bm{h}_{k,i},\bm{y}_{k})  \displaystyle=\sum_{k=1}^{K}\frac{n_{k}}{N}\left[\frac{1}{n_{k}}\sum_{i=1}^{n_{k}}\mathcal{L}(\bm{W}\bm{h}_{k,i},\bm{y}_{k})\right]  
\displaystyle\overset{a}{\geq}\sum_{k=1}^{K}\frac{n_{k}}{N}\mathcal{L}(\bm{W}\bm{h}_{k},\bm{y}_{k})=\sum_{k=1}^{K}\frac{n_{k}}{N}\mathcal{L}(\bm{z}_{k},\bm{y}_{k})\geq L_{0},  (49) 
where \overset{a}{\geq} uses Jensen's inequality together with the fact that \mathcal{L} is convex in its first argument, so that \mathcal{L}(\bm{W}\bm{h},\bm{y}_{k}) is convex in \bm{h} for fixed \bm{W} and k\in[K]; here \bm{z}_{k}:=\bm{W}\bm{h}_{k}.
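The convexity step \overset{a}{\geq} can be illustrated numerically: the cross-entropy loss evaluated at the class-mean feature never exceeds the average loss over the class. The sketch below uses random data (all dimensions and the seed are illustrative choices):

```python
import math
import random

random.seed(2)
K, n, p = 4, 6, 5
W = [[random.gauss(0, 1) for _ in range(p)] for _ in range(K)]

def ce_loss(h, k):
    # cross-entropy -log softmax_k(W h), computed stably
    logits = [sum(wq * hq for wq, hq in zip(row, h)) for row in W]
    m = max(logits)
    return -(logits[k] - m - math.log(sum(math.exp(l - m) for l in logits)))

feats = [[random.gauss(0, 1) for _ in range(p)] for _ in range(n)]
h_bar = [sum(c) / n for c in zip(*feats)]  # class-mean feature
mean_loss = sum(ce_loss(h, 0) for h in feats) / n
# Jensen: loss at the mean feature <= mean loss (convexity in h for fixed W).
assert ce_loss(h_bar, 0) <= mean_loss + 1e-12
```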
On the other hand, consider the solution \left(\bm{H}^{\star},\bm{W}^{\star}\right) defined in (14) with \bm{X}^{\star} a minimizer of (13). Then \left[\bm{h}_{1}^{\star},\bm{h}_{2}^{\star},\dots,\bm{h}_{K}^{\star},(\bm{W}^{\star})^{\top}\right]^{\top}\left[\bm{h}_{1}^{\star},\bm{h}_{2}^{\star},\dots,\bm{h}_{K}^{\star},(\bm{W}^{\star})^{\top}\right]=\bm{X}^{\star} (the condition p\geq 2K guarantees the existence of \left[\bm{h}_{1}^{\star},\bm{h}_{2}^{\star},\dots,\bm{h}_{K}^{\star},(\bm{W}^{\star})^{\top}\right]). One can verify that \left(\bm{H}^{\star},\bm{W}^{\star}\right) is a feasible solution to (7), and we have
\frac{1}{N}\sum_{k=1}^{K}\sum_{i=1}^{n_{k}}\mathcal{L}(\bm{W}^{\star}\bm{h}_{k% ,i}^{\star},\bm{y}_{k})=\sum_{k=1}^{K}\frac{n_{k}}{N}\mathcal{L}(\bm{z}_{k}^{% \star},\bm{y}_{k})=L_{0},  (50) 
where \bm{z}_{k}^{\star}=\left[\bm{X}^{\star}(k,1+K),\bm{X}^{\star}(k,2+K),\dots,\bm% {X}^{\star}(k,2K)~{}\right]^{\top} for k\in[K].
Combining (49) and (50), we conclude that L_{0} is the global minimum of (7) and that (\bm{H}^{\star},\bm{W}^{\star}) is a minimizer.
Suppose there is a minimizer \left(\bm{H}^{\prime},\bm{W}^{\prime}\right) that cannot be written as (14). Let
\bm{h}_{k}^{\prime}=\frac{1}{n_{k}}\sum_{i=1}^{n_{k}}\bm{h}_{k,i}^{\prime},~{}% k\in[K],\quad\text{and}\quad\bm{X}^{\prime}=\left[\bm{h}_{1}^{\prime},\bm{h}_{% 2}^{\prime},\dots,\bm{h}_{K}^{\prime},(\bm{W}^{\prime})^{\top}\right]^{\top}% \left[\bm{h}_{1}^{\prime},\bm{h}_{2}^{\prime},\dots,\bm{h}_{K}^{\prime},(\bm{W% }^{\prime})^{\top}\right]. 
Then (49) implies that \bm{X}^{\prime} is a minimizer of (13). Since \left(\bm{H}^{\prime},\bm{W}^{\prime}\right) cannot be written as (14) with \bm{X}^{\star}=\bm{X}^{\prime}, there exist {k^{\prime}}\in[K] and i,j\in[n_{k^{\prime}}] with i\neq j such that \bm{h}_{{k^{\prime}},i}^{\prime}\neq\bm{h}_{{k^{\prime}},j}^{\prime}. We have
\displaystyle\frac{1}{K}\sum_{k=1}^{K}\bm{X}^{\prime}(k,k)=\frac{1}{K}\sum_{k=1}^{K}\|\bm{h}_{k}^{\prime}\|^{2}  
\displaystyle=  \displaystyle\frac{1}{K}\sum_{k=1}^{K}\frac{1}{n_{k}}\sum_{i=1}^{n_{k}}\left\|\bm{h}_{k,i}^{\prime}\right\|^{2}-\frac{1}{K}\sum_{k=1}^{K}\frac{1}{n_{k}}\sum_{i=1}^{n_{k}}\|\bm{h}_{k,i}^{\prime}-\bm{h}_{k}^{\prime}\|^{2}  
\displaystyle\leq  \displaystyle\frac{1}{K}\sum_{k=1}^{K}\frac{1}{n_{k}}\sum_{i=1}^{n_{k}}\left\|\bm{h}_{k,i}^{\prime}\right\|^{2}-\frac{1}{K}\frac{1}{n_{k^{\prime}}}\left(\|\bm{h}_{{k^{\prime}},i}^{\prime}-\bm{h}_{{k^{\prime}}}^{\prime}\|^{2}+\|\bm{h}_{{k^{\prime}},j}^{\prime}-\bm{h}_{{k^{\prime}}}^{\prime}\|^{2}\right)  
\displaystyle\leq  \displaystyle\frac{1}{K}\sum_{k=1}^{K}\frac{1}{n_{k}}\sum_{i=1}^{n_{k}}\left\|\bm{h}_{k,i}^{\prime}\right\|^{2}-\frac{1}{K}\frac{1}{2n_{k^{\prime}}}\|\bm{h}_{{k^{\prime}},i}^{\prime}-\bm{h}_{{k^{\prime}},j}^{\prime}\|^{2}  
\displaystyle<  \displaystyle E_{H}. 
By contraposition, if every minimizer \bm{X}^{\star} satisfies \frac{1}{K}\sum_{k=1}^{K}\bm{X}^{\star}(k,k)=E_{H}, then all the solutions of (7) are in the form of (14). This completes the proof. ∎
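As a quick numeric sanity check (ours, not part of the proof), the chain of inequalities above rests on the bias–variance-style identity \frac{1}{n}\sum_{i}\|\bm{h}_{i}\|^{2}=\|\bar{\bm{h}}\|^{2}+\frac{1}{n}\sum_{i}\|\bm{h}_{i}-\bar{\bm{h}}\|^{2}, which is why averaging distinct within-class features strictly decreases \|\bar{\bm{h}}\|^{2}. The following sketch verifies the identity on random vectors in pure Python:

```python
import random

# Verify (1/n) sum_i ||h_i||^2 = ||h_bar||^2 + (1/n) sum_i ||h_i - h_bar||^2
# for random feature vectors; the residual should be numerically zero.
random.seed(0)

def check_identity(n: int, p: int) -> float:
    h = [[random.gauss(0, 1) for _ in range(p)] for _ in range(n)]
    h_bar = [sum(v[j] for v in h) / n for j in range(p)]
    mean_sq = sum(sum(x * x for x in v) for v in h) / n
    centered = sum(
        sum((v[j] - h_bar[j]) ** 2 for j in range(p)) for v in h
    ) / n
    norm_bar = sum(x * x for x in h_bar)
    return abs(mean_sq - (norm_bar + centered))

assert check_identity(7, 5) < 1e-9
assert check_identity(50, 3) < 1e-9
```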
A.2.2 Proof of Theorem 7
To prove Theorem 7, we first study a limiting case in which classification is learned only for a subset of the classes. Specifically, we solve the optimization program:
\displaystyle\min_{\bm{H},\bm{W}}  \displaystyle\frac{1}{K_{A}n_{A}}\sum_{k=1}^{K_{A}}\sum_{i=1}^{n_{A}}\mathcal{L}(\bm{W}\bm{h}_{k,i},\bm{y}_{k})  (50)  
\displaystyle\mathrm{s.t.}  \displaystyle\frac{1}{K}\sum_{k=1}^{K}\frac{1}{n_{k}}\sum_{i=1}^{n_{k}}\left\|\bm{h}_{k,i}\right\|^{2}\leq E_{H},  
\displaystyle\frac{1}{K}\sum_{k=1}^{K}\left\|\bm{w}_{k}\right\|^{2}\leq E_{W}, 
where n_{1}=n_{2}=\dots=n_{K_{A}}=n_{A} and n_{K_{A}+1}=n_{K_{A}+2}=\dots=n_{K}=n_{B}. Lemma 4 characterizes useful properties of the minimizers of (50).
Lemma 4.
Let (\bm{H},\bm{W}) be a minimizer of (50). Then \bm{h}_{k,i}=\bm{0}_{p} for all k\in[K_{A}+1:K] and i\in[n_{B}]. Define L_{0} as the global minimum of (50), i.e.,
L_{0}:=\frac{1}{K_{A}n_{A}}\sum_{k=1}^{K_{A}}\sum_{i=1}^{n_{A}}\mathcal{L}(\bm% {W}\bm{h}_{k,i},\bm{y}_{k}). 
Then L_{0} only depends on K_{A}, K_{B}, E_{H}, and E_{W}. Moreover, for any feasible solution \left(\bm{H}^{\prime},\bm{W}^{\prime}\right), if there exist k,{k^{\prime}}\in[K_{A}+1:K] such that \left\|\bm{w}_{k}-\bm{w}_{k^{\prime}}\right\|=\varepsilon>0, we have
\frac{1}{K_{A}n_{A}}\sum_{k=1}^{K_{A}}\sum_{i=1}^{n_{A}}\mathcal{L}\left(\bm{W% }^{\prime}\bm{h}_{k,i}^{\prime},\bm{y}_{k}\right)\geq L_{0}+\varepsilon^{% \prime}, 
where \varepsilon^{\prime}>0 depends on \varepsilon, K_{A}, K_{B}, E_{H}, and E_{W}.
Now we are ready to prove Theorem 7. The proof is by contradiction.
Proof of Theorem 7.
Consider sequences n_{A}^{\ell} and n_{B}^{\ell} with R^{\ell}:=n_{A}^{\ell}/n^{\ell}_{B} for \ell=1,2,\dots, where R^{\ell}\to\infty. For each optimization program indexed by \ell\in\mathbb{N}_{+}, we define (\bm{H}^{\ell,\star},\bm{W}^{\ell,\star}) as a minimizer and separate the objective function into two parts. That is, we introduce
\mathcal{L}^{\ell}\left(\bm{H}^{\ell},\bm{W}^{\ell}\right):=\frac{K_{A}n_{A}^{% \ell}}{K_{A}n_{A}^{\ell}+K_{B}n_{B}^{\ell}}\mathcal{L}^{\ell}_{A}\left(\bm{H}^% {\ell},\bm{W}^{\ell}\right)+\frac{K_{B}n_{B}^{\ell}}{K_{A}n_{A}^{\ell}+K_{B}n_% {B}^{\ell}}\mathcal{L}^{\ell}_{B}\left(\bm{H}^{\ell},\bm{W}^{\ell}\right), 
with
\mathcal{L}^{\ell}_{A}\left(\bm{H}^{\ell},\bm{W}^{\ell}\right):=\frac{1}{K_{A}% n_{A}^{\ell}}\sum_{k=1}^{K_{A}}\sum_{i=1}^{n_{A}^{\ell}}\mathcal{L}\left(\bm{W% }^{\ell}\bm{h}_{k,i}^{\ell},\bm{y}_{k}\right) 
and
\mathcal{L}^{\ell}_{B}\left(\bm{H}^{\ell},\bm{W}^{\ell}\right):=\frac{1}{K_{B}% n_{B}^{\ell}}\sum_{k=K_{A}+1}^{K}\sum_{i=1}^{n_{B}^{\ell}}\mathcal{L}\left(\bm% {W}^{\ell}\bm{h}_{k,i}^{\ell},\bm{y}_{k}\right). 
We define \left(\bm{H}^{\ell,A},\bm{W}^{\ell,A}\right) as a minimizer of the optimization program:
\displaystyle\min_{\bm{H}^{\ell},\bm{W}^{\ell}}  \displaystyle\mathcal{L}^{\ell}_{A}\left(\bm{H}^{\ell},\bm{W}^{\ell}\right)  (51)  
\displaystyle\mathrm{s.t.}  \displaystyle\frac{1}{K}\sum_{k=1}^{K}\left\|\bm{w}_{k}^{\ell}\right\|^{2}\leq E_{W},  
\displaystyle\frac{1}{K}\sum_{k=1}^{K_{A}}\frac{1}{n_{A}^{\ell}}\sum_{i=1}^{n_{A}^{\ell}}\left\|\bm{h}_{k,i}\right\|^{2}+\frac{1}{K}\sum_{k=K_{A}+1}^{K}\frac{1}{n_{B}^{\ell}}\sum_{i=1}^{n_{B}^{\ell}}\left\|\bm{h}_{k,i}\right\|^{2}\leq E_{H}, 
and \left(\bm{H}^{\ell,B},\bm{W}^{\ell,B}\right) as a minimizer of the optimization program:
\displaystyle\min_{\bm{H}^{\ell},\bm{W}^{\ell}}  \displaystyle\mathcal{L}^{\ell}_{B}\left(\bm{H}^{\ell},\bm{W}^{\ell}\right)  (52)  
\displaystyle\mathrm{s.t.}  \displaystyle\frac{1}{K}\sum_{k=1}^{K}\left\|\bm{w}_{k}^{\ell}\right\|^{2}\leq E_{W},  
\displaystyle\frac{1}{K}\sum_{k=1}^{K_{A}}\frac{1}{n_{A}^{\ell}}\sum_{i=1}^{n_{A}^{\ell}}\left\|\bm{h}_{k,i}\right\|^{2}+\frac{1}{K}\sum_{k=K_{A}+1}^{K}\frac{1}{n_{B}^{\ell}}\sum_{i=1}^{n_{B}^{\ell}}\left\|\bm{h}_{k,i}\right\|^{2}\leq E_{H}. 
Note that Programs (51) and (52) and their minimizers have been studied in Lemma 4. We define:
L_{A}:=\mathcal{L}^{\ell}_{A}\left(\bm{H}^{\ell,A},\bm{W}^{\ell,A}\right)\quad% \text{and}\quad L_{B}:=\mathcal{L}^{\ell}_{B}\left(\bm{H}^{\ell,B},\bm{W}^{% \ell,B}\right). 
Then Lemma 4 implies that L_{A} and L_{B} depend only on K_{A}, K_{B}, E_{H}, and E_{W}, and are independent of \ell. Moreover, since \bm{h}_{k,i}^{\ell,A}=\bm{0}_{p} for all k\in[K_{A}+1:K] and i\in[n_{B}^{\ell}], we have
\mathcal{L}^{\ell}_{B}\left(\bm{H}^{\ell,A},\bm{W}^{\ell,A}\right)=\log(K).  (53) 
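As a small sanity check (ours, not part of the proof) of (53): when the class-B features are zero, the logits \bm{W}\bm{h} form the zero vector, so the softmax cross-entropy equals \log(K) regardless of \bm{W} and of the label. A minimal sketch in Python:

```python
import math

# Cross-entropy of a logit vector against a one-hot label, computed via a
# numerically stabilized log-sum-exp.
def cross_entropy(logits, label):
    m = max(logits)
    lse = m + math.log(sum(math.exp(z - m) for z in logits))
    return lse - logits[label]

K = 10
zero_logits = [0.0] * K
# zero logits give loss log(K) for every label, as claimed in (53)
for label in range(K):
    assert abs(cross_entropy(zero_logits, label) - math.log(K)) < 1e-12
```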
Now we prove Theorem 7 by contradiction. Suppose there exists a pair (k,{k^{\prime}}) such that \bm{w}^{\ell,\star}_{k}-\bm{w}^{\ell,\star}_{k^{\prime}} does not converge to \bm{0}_{p}. Then there exist \varepsilon>0, a subsequence \left\{\left(\bm{H}^{a_{\ell},\star},\bm{W}^{a_{\ell},\star}\right)\right\}_{\ell=1}^{\infty}, and an index \ell_{0} such that \left\|\bm{w}^{a_{\ell},\star}_{k}-\bm{w}^{a_{\ell},\star}_{k^{\prime}}\right\|\geq\varepsilon for all \ell\geq\ell_{0}. We derive a contradiction by estimating the objective value at \left(\bm{H}^{a_{\ell},\star},\bm{W}^{a_{\ell},\star}\right). Indeed, because \left(\bm{H}^{a_{\ell},\star},\bm{W}^{a_{\ell},\star}\right) is a minimizer of \mathcal{L}^{a_{\ell}}(\bm{H},\bm{W}), we have
\displaystyle\mathcal{L}^{a_{\ell}}\left(\bm{H}^{a_{\ell},\star},\bm{W}^{a_{\ell},\star}\right)\leq\mathcal{L}^{a_{\ell}}\left(\bm{H}^{a_{\ell},A},\bm{W}^{a_{\ell},A}\right)  \displaystyle\overset{(53)}{=}\frac{K_{A}n_{A}^{a_{\ell}}}{K_{A}n_{A}^{a_{\ell}}+K_{B}n_{B}^{a_{\ell}}}L_{A}+\frac{K_{B}n_{B}^{a_{\ell}}}{K_{A}n_{A}^{a_{\ell}}+K_{B}n_{B}^{a_{\ell}}}\log(K)  
\displaystyle=L_{A}+\frac{1}{K_{R}R^{a_{\ell}}+1}\left(\log(K)-L_{A}\right)\overset{\ell\to\infty}{\to}L_{A},  (54) 
where we define K_{R}:=K_{A}/K_{B} and use R^{\ell}=n_{A}^{\ell}/n_{B}^{\ell}.
However, when \ell>\ell_{0}, because \left\|\bm{w}^{a_{\ell},\star}_{k}-\bm{w}^{a_{\ell},\star}_{k^{\prime}}\right\|\geq\varepsilon>0, Lemma 4 implies that
\mathcal{L}^{a_{\ell}}_{A}\left(\bm{H}^{a_{\ell},\star},\bm{W}^{a_{\ell},\star% }\right)\geq L_{A}+\varepsilon_{2}, 
where \varepsilon_{2}>0 only depends on \varepsilon, K_{A}, K_{B}, E_{H}, and E_{W}, and is independent of \ell. We obtain
\displaystyle\mathcal{L}^{a_{\ell}}\left(\bm{H}^{a_{\ell},\star},\bm{W}^{a_{\ell},\star}\right)  \displaystyle=\frac{K_{A}n_{A}^{a_{\ell}}}{K_{A}n_{A}^{a_{\ell}}+K_{B}n_{B}^{a_{\ell}}}\mathcal{L}^{a_{\ell}}_{A}\left(\bm{H}^{a_{\ell},\star},\bm{W}^{a_{\ell},\star}\right)+\frac{K_{B}n_{B}^{a_{\ell}}}{K_{A}n_{A}^{a_{\ell}}+K_{B}n_{B}^{a_{\ell}}}\mathcal{L}^{a_{\ell}}_{B}\left(\bm{H}^{a_{\ell},\star},\bm{W}^{a_{\ell},\star}\right)  
\displaystyle\overset{a}{\geq}\frac{K_{A}n_{A}^{a_{\ell}}}{K_{A}n_{A}^{a_{\ell}}+K_{B}n_{B}^{a_{\ell}}}\mathcal{L}^{a_{\ell}}_{A}\left(\bm{H}^{a_{\ell},\star},\bm{W}^{a_{\ell},\star}\right)+\frac{K_{B}n_{B}^{a_{\ell}}}{K_{A}n_{A}^{a_{\ell}}+K_{B}n_{B}^{a_{\ell}}}\mathcal{L}^{a_{\ell}}_{B}\left(\bm{H}^{a_{\ell},B},\bm{W}^{a_{\ell},B}\right)  
\displaystyle\geq\frac{K_{A}n_{A}^{a_{\ell}}}{K_{A}n_{A}^{a_{\ell}}+K_{B}n_{B}^{a_{\ell}}}(L_{A}+\varepsilon_{2})+\frac{K_{B}n_{B}^{a_{\ell}}}{K_{A}n_{A}^{a_{\ell}}+K_{B}n_{B}^{a_{\ell}}}L_{B}  
\displaystyle=L_{A}+\varepsilon_{2}+\frac{1}{K_{R}R^{a_{\ell}}+1}(L_{B}-L_{A}-\varepsilon_{2})\overset{\ell\to\infty}{\to}L_{A}+\varepsilon_{2},  (55) 
where \overset{a}{\geq} uses the fact that \left(\bm{H}^{a_{\ell},B},\bm{W}^{a_{\ell},B}\right) is a minimizer of (52), and the second inequality uses \mathcal{L}^{a_{\ell}}_{A}\left(\bm{H}^{a_{\ell},\star},\bm{W}^{a_{\ell},\star}\right)\geq L_{A}+\varepsilon_{2}. Comparing (54) with (55) yields a contradiction, which establishes Theorem 7. ∎
Proof of Lemma 4.
For any constants C_{a}>0, C_{b}>0, and C_{c}>0, define C_{a}^{\prime}:=\frac{C_{a}}{C_{a}+(K_{A}-1)C_{b}+K_{B}C_{c}}\in(0,1), C_{b}^{\prime}:=\frac{C_{b}}{C_{a}+(K_{A}-1)C_{b}+K_{B}C_{c}}\in(0,1), C_{c}^{\prime}:=\frac{C_{c}}{C_{a}+(K_{A}-1)C_{b}+K_{B}C_{c}}\in(0,1), C_{d}:=-C_{a}^{\prime}\log(C_{a}^{\prime})-(K_{A}-1)C_{b}^{\prime}\log(C_{b}^{\prime})-K_{B}C_{c}^{\prime}\log(C_{c}^{\prime}), C_{e}:=\frac{K_{A}C_{b}}{K_{A}C_{b}+K_{B}C_{c}}\in(0,1), C_{f}:=\frac{K_{B}C_{c}}{K_{A}C_{b}+K_{B}C_{c}}\in(0,1), and C_{g}:=\frac{K_{A}C_{b}+K_{B}C_{c}}{C_{a}+(K_{A}-1)C_{b}+K_{B}C_{c}}>0. Using an argument similar to that of Theorem 3, we show in Lemma 5 (stated after this proof) that for any feasible solution (\bm{H},\bm{W}) of (50), the objective value can be bounded from below by:
\displaystyle\frac{1}{K_{A}n_{A}}\sum_{k=1}^{K_{A}}\sum_{i=1}^{n_{A}}\mathcal{L}(\bm{W}\bm{h}_{k,i},\bm{y}_{k})  (56)  
\displaystyle\overset{a}{\geq}  \displaystyle-\frac{C_{g}}{K_{A}}\sqrt{KE_{H}}\sqrt{\sum_{k=1}^{K_{A}}\left\|C_{e}\bm{w}_{A}+C_{f}\bm{w}_{B}-\bm{w}_{k}\right\|^{2}}+C_{d}  
\displaystyle\overset{b}{\geq}  \displaystyle-\frac{C_{g}}{K_{A}}\sqrt{KE_{H}}\sqrt{KE_{W}-K_{A}\left(1/K_{R}-C_{f}^{2}-\frac{C_{f}^{4}}{C_{e}(2-C_{e})}\right)\|\bm{w}_{B}\|^{2}-\sum_{k=K_{A}+1}^{K}\left\|\bm{w}_{k}-\bm{w}_{B}\right\|^{2}}+C_{d}, 
where \bm{w}_{A}:=\frac{1}{K_{A}}\sum_{k=1}^{K_{A}}\bm{w}_{k}, \bm{w}_{B}:=\frac{1}{K_{B}}\sum_{k=K_{A}+1}^{K}\bm{w}_{k}, and K_{R}:=\frac{K_{A}}{K_{B}}. Moreover, the equality in \overset{a}{\geq} holds only if \bm{h}_{k,i}=\bm{0}_{p} for all k\in[K_{A}+1:K] and i\in[n_{B}].
Although C_{a}, C_{b}, and C_{c} can be any positive numbers, we need to choose them carefully so as to exactly reach the global minimum of (50). In the following, we separately consider three cases according to the values of K_{A}, K_{B}, and E_{H}E_{W}.

(i)
Consider the case where K_{A}=1. We pick C_{a}:=\exp\left(\sqrt{K_{B}(1+K_{B})E_{H}E_{W}}\right), C_{b}:=1, and C_{c}:=\exp\left(-\sqrt{(1+K_{B})E_{H}E_{W}/K_{B}}\right).
Then from \overset{a}{\geq} in (56), we have
\displaystyle\frac{1}{K_{A}n_{A}}\sum_{k=1}^{K_{A}}\sum_{i=1}^{n_{A}}\mathcal{L}(\bm{W}\bm{h}_{k,i},\bm{y}_{k}) \displaystyle\overset{a}{\geq} \displaystyle-C_{g}C_{f}\sqrt{KE_{H}}\sqrt{\left\|\bm{w}_{1}-\bm{w}_{B}\right\|^{2}}+C_{d} \displaystyle= \displaystyle-C_{g}C_{f}\sqrt{KE_{H}}\sqrt{\|\bm{w}_{1}\|^{2}-2\bm{w}_{1}^{\top}\bm{w}_{B}+\|\bm{w}_{B}\|^{2}}+C_{d} \displaystyle\overset{b}{\geq} \displaystyle-C_{g}C_{f}\sqrt{KE_{H}}\sqrt{(1+1/K_{B})(\|\bm{w}_{1}\|^{2}+K_{B}\|\bm{w}_{B}\|^{2})}+C_{d} \displaystyle\overset{c}{\geq} \displaystyle-C_{g}C_{f}\sqrt{KE_{H}}\sqrt{(1+1/K_{B})\left(KE_{W}-\sum_{k=2}^{K}\|\bm{w}_{k}-\bm{w}_{B}\|^{2}\right)}+C_{d} \displaystyle\geq \displaystyle-C_{g}C_{f}\sqrt{KE_{H}}\sqrt{(1+1/K_{B})KE_{W}}+C_{d}:=L_{1}, (57) where \overset{a}{\geq} uses C_{e}+C_{f}=1, \overset{b}{\geq} follows from Young’s inequality, i.e., -2\bm{w}_{1}^{\top}\bm{w}_{B}\leq(1/K_{B})\|\bm{w}_{1}\|^{2}+K_{B}\|\bm{w}_{B}\|^{2}, and \overset{c}{\geq} follows from \sum_{k=2}^{K}\|\bm{w}_{k}\|^{2}=K_{B}\|\bm{w}_{B}\|^{2}+\sum_{k=2}^{K}\|\bm{w}_{k}-\bm{w}_{B}\|^{2} and the constraint that \sum_{k=1}^{K}\|\bm{w}_{k}\|^{2}\leq KE_{W}.
On the other hand, when (\bm{H},\bm{W}) satisfies that
\displaystyle\begin{aligned} \displaystyle\bm{w}_{1}&\displaystyle=\sqrt{K_{B}E_{W}}\bm{u},\quad\bm{w}_{k}=-\sqrt{E_{W}/K_{B}}\,\bm{u},~{}k\in[2:K],\\ \displaystyle\bm{h}_{1,i}=&\displaystyle\sqrt{(1+K_{B})E_{H}}\bm{u},~{}i\in[n_{A}],\quad\quad\bm{h}_{k,i}=\bm{0}_{p},~{}k\in[2:K],~{}i\in[n_{B}],\\ \end{aligned} where \bm{u} is any unit vector, (\bm{H},\bm{W}) achieves the equality in (57). So L_{1} is the global minimum of (50). Moreover, L_{1} is achieved only if the equality in \overset{a}{\geq} in (56) holds. From Lemma 5, any minimizer satisfies \bm{h}_{k,i}=\bm{0}_{p} for all k\in[K_{A}+1:K] and i\in[n_{B}].
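As a numeric check (ours, and using the sign convention \bm{w}_{k}=-\sqrt{E_{W}/K_{B}}\bm{u} for k\geq 2 required for the bound), the case-(i) construction makes both constraints tight and reproduces the constants C_{a} and C_{c} as the exponentials of the logits. A sketch with arbitrary positive test values of E_{H} and E_{W}:

```python
import math

E_H, E_W, K_B = 2.0, 3.0, 4
K = 1 + K_B  # K_A = 1

# scalar coefficients of the unit vector u in the construction
w1 = math.sqrt(K_B * E_W)
wk = -math.sqrt(E_W / K_B)
h1 = math.sqrt((1 + K_B) * E_H)

# weight constraint is tight: (1/K) sum_k ||w_k||^2 = E_W
assert abs((w1 ** 2 + K_B * wk ** 2) / K - E_W) < 1e-12
# feature constraint is tight: (1/K) ||h_1||^2 = E_H (class-B features zero)
assert abs(h1 ** 2 / K - E_H) < 1e-12
# logits match C_a = exp(sqrt(K_B (1+K_B) E_H E_W)) and
# C_c = exp(-sqrt((1+K_B) E_H E_W / K_B))
assert abs(h1 * w1 - math.sqrt(K_B * (1 + K_B) * E_H * E_W)) < 1e-12
assert abs(h1 * wk + math.sqrt((1 + K_B) * E_H * E_W / K_B)) < 1e-12
```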
Finally, for any feasible solution \left(\bm{H}^{\prime},\bm{W}^{\prime}\right), if there exist k,{k^{\prime}}\in[K_{A}+1:K] such that \left\|\bm{w}_{k}-\bm{w}_{k^{\prime}}\right\|=\varepsilon>0, we have
\sum_{k=K_{A}+1}^{K}\|\bm{w}_{k}-\bm{w}_{B}\|^{2}\geq\|\bm{w}_{k}-\bm{w}_{B}\|^{2}+\|\bm{w}_{k^{\prime}}-\bm{w}_{B}\|^{2}\geq\frac{\|\bm{w}_{k}-\bm{w}_{k^{\prime}}\|^{2}}{2}=\varepsilon^{2}/2. (58) It follows from \overset{c}{\geq} in (57) that
\frac{1}{K_{A}n_{A}}\sum_{k=1}^{K_{A}}\sum_{i=1}^{n_{A}}\mathcal{L}(\bm{W}\bm{h}_{k,i},\bm{y}_{k})\geq-C_{g}C_{f}\sqrt{KE_{H}}\sqrt{(1+1/K_{B})\left(KE_{W}-\varepsilon^{2}/2\right)}+C_{d}:=L_{1}+\varepsilon_{1}, with \varepsilon_{1}>0 depending on \varepsilon, K_{A}, K_{B}, E_{H}, and E_{W}.

(ii)
Consider the case where K_{A}>1 and \exp\left((1+1/K_{R})\sqrt{E_{H}E_{W}}/(K_{A}-1)\right)<\sqrt{1+K_{R}}+1. Let us pick C_{a}:=\exp\left((1+1/K_{R})\sqrt{E_{H}E_{W}}\right), C_{b}:=\exp\left(-\frac{1}{K_{A}-1}(1+1/K_{R})\sqrt{E_{H}E_{W}}\right), and C_{c}:=1.
It follows from \overset{b}{\geq} in (56) that if 1/K_{R}-C_{f}^{2}-\frac{C_{f}^{4}}{C_{e}(2-C_{e})}>0, then
\displaystyle\frac{1}{K_{A}n_{A}}\sum_{k=1}^{K_{A}}\sum_{i=1}^{n_{A}}\mathcal{L}(\bm{W}\bm{h}_{k,i},\bm{y}_{k})\geq-C_{g}(1+1/K_{R})\sqrt{E_{H}E_{W}}+C_{d}:=L_{2}. (59) In fact, we do have 1/K_{R}-C_{f}^{2}-\frac{C_{f}^{4}}{C_{e}(2-C_{e})}>0 because
\displaystyle\begin{aligned} &\displaystyle 1/K_{R}>C_{f}^{2}+\frac{C_{f}^{4}}{C_{e}(2-C_{e})}\\ \displaystyle\iff&\displaystyle C_{f}<\sqrt{\frac{1}{1+K_{R}}}\quad\quad\quad\left(\text{by~{}}C_{e}+C_{f}=1\right)\\ \displaystyle\iff&\displaystyle\frac{C_{b}}{C_{c}}>\frac{1}{\sqrt{1+K_{R}}+1}\quad\quad\quad\left(\text{by~{}}C_{f}=\frac{K_{B}C_{c}}{K_{A}C_{b}+K_{B}C_{c}}\right)\\ \displaystyle\iff&\displaystyle\exp\left((1+1/K_{R})\sqrt{E_{H}E_{W}}/(K_{A}-1)\right)<\sqrt{1+K_{R}}+1.\end{aligned} On the other hand, when (\bm{H},\bm{W}) satisfies
\displaystyle\begin{aligned} \displaystyle\left[\bm{w}_{1},\bm{w}_{2},\ldots,\bm{w}_{K_{A}}\right]=&\displaystyle\sqrt{\frac{E_{W}}{E_{H}}}~{}\bigg{[}\bm{h}_{1},\ldots,\bm{h}_{K_{A}}\bigg{]}=\sqrt{(1+1/K_{R})E_{W}}~{}\bm{M}_{A}^{\star},\\ \displaystyle\bm{h}_{k,i}=&\displaystyle\bm{h}_{k},\quad k\in[K_{A}],~{}i\in[n_{A}],\\ \displaystyle\bm{h}_{k,i}=&\displaystyle\bm{w}_{k}=\bm{0}_{p},\quad k\in[K_{A}+1:K],~{}i\in[n_{B}],\\ \end{aligned} where \bm{M}_{A}^{\star}\in\mathbb{R}^{p\times K_{A}} is a K_{A}-simplex ETF, (\bm{H},\bm{W}) achieves the equality in (59). So L_{2} is the global minimum of (50). Moreover, L_{2} is achieved only if the equality in \overset{a}{\geq} of (56) holds. From Lemma 5, any minimizer satisfies \bm{h}_{k,i}=\bm{0}_{p} for all k\in[K_{A}+1:K] and i\in[n_{B}].
Finally, for any feasible solution \left(\bm{H}^{\prime},\bm{W}^{\prime}\right), if there exist k,{k^{\prime}}\in[K_{A}+1:K] such that \left\|\bm{w}_{k}-\bm{w}_{k^{\prime}}\right\|=\varepsilon>0, then plugging (58) into (56) yields
\frac{1}{K_{A}n_{A}}\sum_{k=1}^{K_{A}}\sum_{i=1}^{n_{A}}\mathcal{L}(\bm{W}\bm{h}_{k,i},\bm{y}_{k})\geq-\frac{C_{g}}{K_{A}}\sqrt{KE_{H}}\sqrt{KE_{W}-\varepsilon^{2}/2}+C_{d}:=L_{2}+\varepsilon_{2}, (60) with \varepsilon_{2}>0 depending on \varepsilon, K_{A}, K_{B}, E_{H}, and E_{W}.
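The algebra behind the equivalence chain in this case can be verified numerically. A small sanity check (ours, with randomly drawn values): under C_{e}+C_{f}=1, the quantity C_{f}^{2}+C_{f}^{4}/(C_{e}(2-C_{e})) simplifies to C_{f}^{2}/(1-C_{f}^{2}), so 1/K_{R} exceeds it precisely when C_{f}<1/\sqrt{1+K_{R}}:

```python
import random

random.seed(1)
for _ in range(1000):
    cf = random.uniform(0.01, 0.99)
    ce = 1.0 - cf
    lhs = cf ** 2 + cf ** 4 / (ce * (2 - ce))
    # algebraic simplification used in the chain of equivalences
    assert abs(lhs - cf ** 2 / (1 - cf ** 2)) < 1e-9
    # the threshold condition on K_R is equivalent to C_f < 1/sqrt(1+K_R)
    k_r = random.uniform(0.1, 10.0)
    assert (1.0 / k_r > lhs) == (cf < (1.0 + k_r) ** -0.5)
```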

(iii)
Consider the case where K_{A}>1 and \exp\left((1+1/K_{R})\sqrt{E_{H}E_{W}}/(K_{A}-1)\right)\geq\sqrt{1+K_{R}}+1. Let C_{f}^{\prime}:=\frac{1}{\sqrt{K_{R}+1}} and C_{e}^{\prime}:=1-C_{f}^{\prime}. For x\in[0,1], we define:
\displaystyle\begin{aligned} \displaystyle g_{N}(x):&\displaystyle=\sqrt{\frac{(1+K_{R})E_{W}}{K_{R}x^{2}+(K_{R}+K_{R}^{2})(1-x)^{2}}},\\ \displaystyle g_{a}(x):&\displaystyle=\exp\left(\frac{g_{N}(x)\sqrt{(1+K_{R})E_{H}/K_{R}}}{\sqrt{x^{2}+\left(1+\frac{C_{e}^{\prime}}{C_{f}^{\prime}}\right)^{2}(1-x)^{2}}}\left[x^{2}+\left(1+\frac{C_{e}^{\prime}}{C_{f}^{\prime}}\right)(1-x)^{2}\right]\right),\\ \displaystyle g_{b}(x):&\displaystyle=\exp\left(\frac{g_{N}(x)\sqrt{(1+K_{R})E_{H}/K_{R}}}{\sqrt{x^{2}+\left(1+\frac{C_{e}^{\prime}}{C_{f}^{\prime}}\right)^{2}(1-x)^{2}}}\left[-\frac{1}{K_{A}-1}x^{2}+\left(1+\frac{C_{e}^{\prime}}{C_{f}^{\prime}}\right)(1-x)^{2}\right]\right),\\ \displaystyle g_{c}(x):&\displaystyle=\exp\left(\frac{g_{N}(x)\sqrt{(1+K_{R})E_{H}/K_{R}}}{\sqrt{x^{2}+\left(1+\frac{C_{e}^{\prime}}{C_{f}^{\prime}}\right)^{2}(1-x)^{2}}}\left[-\left(1+\frac{C_{e}^{\prime}}{C_{f}^{\prime}}\right)K_{R}(1-x)^{2}\right]\right).\end{aligned} Let x_{0}\in[0,1] be a root of the equation
g_{b}(x)/g_{c}(x)=\frac{1/C_{f}^{\prime}-1}{K_{R}}. We first show that such a root x_{0} exists. One can directly verify that g_{b}(x)/g_{c}(x) is continuous on [0,1]. Hence it suffices to prove that (1) g_{b}(0)/g_{c}(0)\geq\frac{1/C_{f}^{\prime}-1}{K_{R}} and (2) g_{b}(1)/g_{c}(1)\leq\frac{1/C_{f}^{\prime}-1}{K_{R}}.

(1)
When x=0, we have g_{b}(x)/g_{c}(x)\geq\exp(0)=1. At the same time, \frac{1/C_{f}^{\prime}-1}{K_{R}}=\frac{\sqrt{K_{R}+1}-1}{K_{R}}=\frac{1}{\sqrt{K_{R}+1}+1}\leq 1. Thus (1) holds.

(2)
When x=1, we have g_{N}(1)=\sqrt{(1+1/K_{R})E_{W}}, so
\displaystyle g_{b}(1)/g_{c}(1)=\exp\left(-(1+1/K_{R})\sqrt{E_{H}E_{W}}/(K_{A}-1)\right)\overset{a}{\leq}\frac{1}{\sqrt{K_{R}+1}+1}=\frac{1/C_{f}^{\prime}-1}{K_{R}}, where \overset{a}{\leq} follows from the condition that
\exp\left((1+1/K_{R})\sqrt{E_{H}E_{W}}/(K_{A}-1)\right)\geq\sqrt{1+K_{R}}+1. Thus (2) holds.
Now we pick C_{a}:=g_{a}(x_{0}), C_{b}:=g_{b}(x_{0}), and C_{c}:=g_{c}(x_{0}). Because \frac{C_{b}}{C_{c}}=\frac{1/C_{f}^{\prime}-1}{K_{R}}, we have C_{e}=C_{e}^{\prime}, C_{f}=C_{f}^{\prime}, and 1/K_{R}=C_{f}^{2}+\frac{C_{f}^{4}}{C_{e}(2-C_{e})}. It then follows from \overset{b}{\geq} in (56) that
\displaystyle\frac{1}{K_{A}n_{A}}\sum_{k=1}^{K_{A}}\sum_{i=1}^{n_{A}}\mathcal{L}(\bm{W}\bm{h}_{k,i},\bm{y}_{k})\geq-C_{g}(1+1/K_{R})\sqrt{E_{H}E_{W}}+C_{d}=L_{2}. (61) On the other hand, consider the solution (\bm{H},\bm{W}) that satisfies
\displaystyle\begin{aligned} &\displaystyle\bm{w}_{k}=g_{N}(x_{0})\bm{P}_{A}\left[\frac{x_{0}}{\sqrt{(K_{A}-1)K_{A}}}(K_{A}\bm{y}_{k}-\bm{1}_{K_{A}})+\frac{1-x_{0}}{\sqrt{K_{A}}}\bm{1}_{K_{A}}\right],\quad k\in[K_{A}],\\ &\displaystyle\bm{w}_{k}=-\frac{C_{e}(2-C_{e})}{C_{f}^{2}K_{A}}\sum_{k^{\prime}=1}^{K_{A}}\bm{w}_{k^{\prime}},\quad k\in[K_{A}+1:K],\\ &\displaystyle\bm{h}_{k,i}=\frac{\sqrt{(1+1/K_{R})E_{H}}}{\left\|\bm{w}_{k}+\frac{C_{e}}{C_{f}K_{A}}\sum_{k^{\prime}=1}^{K_{A}}\bm{w}_{k^{\prime}}\right\|}\left[\bm{w}_{k}+\frac{C_{e}}{C_{f}K_{A}}\sum_{k^{\prime}=1}^{K_{A}}\bm{w}_{k^{\prime}}\right],\quad k\in[K_{A}],~{}i\in[n_{A}],\\ &\displaystyle\bm{h}_{k,i}=\bm{0}_{p},\quad k\in[K_{A}+1:K],~{}i\in[n_{B}],\end{aligned} where, with slight abuse of notation, \bm{y}_{k}\in\mathbb{R}^{K_{A}} denotes the vector containing one in the kth entry and zero elsewhere, and \bm{P}_{A}\in\mathbb{R}^{p\times K_{A}} is a partial orthogonal matrix such that \bm{P}^{\top}_{A}\bm{P}_{A}=\bm{I}_{K_{A}}. We have \exp\left(\bm{h}_{k,i}^{\top}\bm{w}_{k}\right)=g_{a}(x_{0}) for i\in[n_{A}] and k\in[K_{A}]; \exp\left(\bm{h}_{k,i}^{\top}\bm{w}_{k^{\prime}}\right)=g_{b}(x_{0}) for i\in[n_{A}] and k,{k^{\prime}}\in[K_{A}] with k\neq{k^{\prime}}; and \exp\left(\bm{h}_{k,i}^{\top}\bm{w}_{k^{\prime}}\right)=g_{c}(x_{0}) for i\in[n_{A}], k\in[K_{A}], and {k^{\prime}}\in[K_{A}+1:K]. Moreover, (\bm{H},\bm{W}) achieves the equality in (61). Finally, following the same argument as in Case (ii), we have that (1) L_{2} is the global minimum of (50); (2) any minimizer satisfies \bm{h}_{k,i}=\bm{0}_{p} for all k\in[K_{A}+1:K] and i\in[n_{B}]; and (3) for any feasible solution \left(\bm{H}^{\prime},\bm{W}^{\prime}\right), if there exist k,{k^{\prime}}\in[K_{A}+1:K] such that \left\|\bm{w}_{k}-\bm{w}_{k^{\prime}}\right\|=\varepsilon>0, then (60) holds.

Combining the three cases, we obtain Lemma 4, completing the proof. ∎
Lemma 5.
For any constants C_{a}>0, C_{b}>0, and C_{c}>0, define C_{a}^{\prime}:=\frac{C_{a}}{C_{a}+(K_{A}-1)C_{b}+K_{B}C_{c}}\in(0,1), C_{b}^{\prime}:=\frac{C_{b}}{C_{a}+(K_{A}-1)C_{b}+K_{B}C_{c}}\in(0,1), C_{c}^{\prime}:=\frac{C_{c}}{C_{a}+(K_{A}-1)C_{b}+K_{B}C_{c}}\in(0,1), C_{d}:=-C_{a}^{\prime}\log(C_{a}^{\prime})-(K_{A}-1)C_{b}^{\prime}\log(C_{b}^{\prime})-K_{B}C_{c}^{\prime}\log(C_{c}^{\prime}), C_{e}:=\frac{K_{A}C_{b}}{K_{A}C_{b}+K_{B}C_{c}}\in(0,1), C_{f}:=\frac{K_{B}C_{c}}{K_{A}C_{b}+K_{B}C_{c}}\in(0,1), and C_{g}:=\frac{K_{A}C_{b}+K_{B}C_{c}}{C_{a}+(K_{A}-1)C_{b}+K_{B}C_{c}}>0. For any feasible solution (\bm{H},\bm{W}) of (50), the objective value of (50) can be bounded from below by:
\displaystyle\frac{1}{K_{A}n_{A}}\sum_{k=1}^{K_{A}}\sum_{i=1}^{n_{A}}\mathcal{L}(\bm{W}\bm{h}_{k,i},\bm{y}_{k})  (62)  
\displaystyle\overset{a}{\geq}  \displaystyle-\frac{C_{g}}{K_{A}}\sqrt{KE_{H}}\sqrt{\sum_{k=1}^{K_{A}}\left\|C_{e}\bm{w}_{A}+C_{f}\bm{w}_{B}-\bm{w}_{k}\right\|^{2}}+C_{d}  
\displaystyle\overset{b}{\geq}  \displaystyle-\frac{C_{g}}{K_{A}}\sqrt{KE_{H}}\sqrt{KE_{W}-K_{A}\left(1/K_{R}-C_{f}^{2}-\frac{C_{f}^{4}}{C_{e}(2-C_{e})}\right)\|\bm{w}_{B}\|^{2}-\sum_{k=K_{A}+1}^{K}\left\|\bm{w}_{k}-\bm{w}_{B}\right\|^{2}}+C_{d}, 
where \bm{w}_{A}:=\frac{1}{K_{A}}\sum_{k=1}^{K_{A}}\bm{w}_{k}, \bm{w}_{B}:=\frac{1}{K_{B}}\sum_{k=K_{A}+1}^{K}\bm{w}_{k}, and K_{R}:=\frac{K_{A}}{K_{B}}. Moreover, the equality in \overset{a}{\geq} holds only if \bm{h}_{k,i}=\bm{0}_{p} for all k\in[K_{A}+1:K] and i\in[n_{B}].
Proof of Lemma 5.
For k\in[K_{A}] and i\in[n_{k}], we introduce \bm{z}_{k,i}=\bm{W}\bm{h}_{k,i}. Because C_{a}^{\prime}+(K_{A}-1)C_{b}^{\prime}+K_{B}C_{c}^{\prime}=1 with C_{a}^{\prime}>0, C_{b}^{\prime}>0, and C_{c}^{\prime}>0, the concavity of \log(\cdot) gives
\displaystyle\begin{aligned} &\displaystyle-\log\left(\frac{\exp(\bm{z}_{k,i}(k))}{\sum_{{k^{\prime}}=1}^{K}\exp(\bm{z}_{k,i}({k^{\prime}}))}\right)\\ \displaystyle=&\displaystyle-\bm{z}_{k,i}(k)+\log\left(C_{a}^{\prime}\left(\frac{\exp(\bm{z}_{k,i}(k))}{C_{a}^{\prime}}\right)+\sum_{{k^{\prime}}=1,~{}{k^{\prime}}\neq k}^{K_{A}}C_{b}^{\prime}\left(\frac{\exp(\bm{z}_{k,i}({k^{\prime}}))}{C_{b}^{\prime}}\right)+\sum_{{k^{\prime}}=K_{A}+1}^{K}C_{c}^{\prime}\left(\frac{\exp(\bm{z}_{k,i}({k^{\prime}}))}{C_{c}^{\prime}}\right)\right)\\ \displaystyle\geq&\displaystyle-\bm{z}_{k,i}(k)+C_{a}^{\prime}\bm{z}_{k,i}(k)+C_{b}^{\prime}\sum_{{k^{\prime}}=1,~{}{k^{\prime}}\neq k}^{K_{A}}\bm{z}_{k,i}({k^{\prime}})+C_{c}^{\prime}\sum_{{k^{\prime}}=K_{A}+1}^{K}\bm{z}_{k,i}({k^{\prime}})+C_{d}\\ \displaystyle=&\displaystyle C_{g}C_{e}\left(\frac{1}{K_{A}}\sum_{{k^{\prime}}=1}^{K_{A}}\bm{z}_{k,i}({k^{\prime}})-\bm{z}_{k,i}(k)\right)+C_{g}C_{f}\left(\frac{1}{K_{B}}\sum_{{k^{\prime}}=K_{A}+1}^{K}\bm{z}_{k,i}({k^{\prime}})-\bm{z}_{k,i}(k)\right)+C_{d}.\end{aligned} 
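As a numeric sanity check (ours, not part of the proof), the concavity bound above can be tested directly: for any positive C_{a}, C_{b}, C_{c} and any logit vector \bm{z} with true class k\leq K_{A}, the cross-entropy loss should dominate the linear lower bound. A sketch with randomly drawn values:

```python
import math
import random

random.seed(2)
K_A, K_B = 3, 2
K = K_A + K_B
for _ in range(500):
    C_a, C_b, C_c = (random.uniform(0.1, 5.0) for _ in range(3))
    denom = C_a + (K_A - 1) * C_b + K_B * C_c
    Cap, Cbp, Ccp = C_a / denom, C_b / denom, C_c / denom
    # C_d is the entropy-like constant from the lemma's definitions
    C_d = -Cap * math.log(Cap) - (K_A - 1) * Cbp * math.log(Cbp) \
          - K_B * Ccp * math.log(Ccp)
    C_e = K_A * C_b / (K_A * C_b + K_B * C_c)
    C_f = K_B * C_c / (K_A * C_b + K_B * C_c)
    C_g = (K_A * C_b + K_B * C_c) / denom
    z = [random.uniform(-3, 3) for _ in range(K)]
    k = random.randrange(K_A)
    loss = math.log(sum(math.exp(v) for v in z)) - z[k]
    mean_A = sum(z[:K_A]) / K_A
    mean_B = sum(z[K_A:]) / K_B
    bound = C_g * C_e * (mean_A - z[k]) + C_g * C_f * (mean_B - z[k]) + C_d
    assert loss >= bound - 1e-9  # the log-sum-exp lower bound holds
```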
Therefore, averaging the bound above over k\in[K_{A}] and i\in[n_{A}], and recalling that \bm{w}_{A}=\frac{1}{K_{A}}\sum_{k=1}^{K_{A}}\bm{w}_{k} and \bm{w}_{B}=\frac{1}{K_{B}}\sum_{k=K_{A}+1}^{K}\bm{w}_{k}, we have
\displaystyle\frac{1}{K_{A}n_{A}}\sum_{k=1}^{K_{A}}\sum_{i=1}^{n_{A}}\mathcal{L}(\bm{W}\bm{h}_{k,i},\bm{y}_{k})  (63)  
\displaystyle\geq  \displaystyle\frac{1}{K_{A}n_{A}}\sum_{k=1}^{K_{A}}\sum_{i=1}^{n_{A}}C_{g}\left[C_{e}(\bm{h}_{k,i}^{\top}\bm{w}_{A}-\bm{h}_{k,i}^{\top}\bm{w}_{k})+C_{f}(\bm{h}_{k,i}^{\top}\bm{w}_{B}-\bm{h}_{k,i}^{\top}\bm{w}_{k})\right]+C_{d}  
\displaystyle\overset{a}{=}  \displaystyle\frac{C_{g}}{K_{A}}\sum_{k=1}^{K_{A}}\bm{h}_{k}^{\top}(C_{e}\bm{w}_{A}+C_{f}\bm{w}_{B}-\bm{w}_{k})+C_{d}, 
where in \overset{a}{=}, we introduce \bm{h}_{k}:=\frac{1}{n_{k}}\sum_{i=1}^{n_{k}}\bm{h}_{k,i} for k\in[K] and use C_{e}+C_{f}=1. It then suffices to bound \sum_{k=1}^{K_{A}}\bm{h}_{k}^{\top}(C_{e}\bm{w}_{A}+C_{f}\bm{w}_{B}-\bm{w}_{k}). By the Cauchy–Schwarz inequality, we have
\displaystyle\sum_{k=1}^{K_{A}}\bm{h}_{k}^{\top}(C_{e}\bm{w}_{A}+C_{f}\bm{w}_{B}-\bm{w}_{k})\geq  \displaystyle-\sqrt{\sum_{k=1}^{K_{A}}\|\bm{h}_{k}\|^{2}}\sqrt{\sum_{k=1}^{K_{A}}\left\|C_{e}\bm{w}_{A}+C_{f}\bm{w}_{B}-\bm{w}_{k}\right\|^{2}}  
\displaystyle\overset{a}{\geq}  \displaystyle-\sqrt{\sum_{k=1}^{K_{A}}\frac{1}{n_{k}}\sum_{i=1}^{n_{k}}\|\bm{h}_{k,i}\|^{2}}\sqrt{\sum_{k=1}^{K_{A}}\left\|C_{e}\bm{w}_{A}+C_{f}\bm{w}_{B}-\bm{w}_{k}\right\|^{2}}  
\displaystyle\overset{b}{\geq}  \displaystyle-\sqrt{KE_{H}}\sqrt{\sum_{k=1}^{K_{A}}\left\|C_{e}\bm{w}_{A}+C_{f}\bm{w}_{B}-\bm{w}_{k}\right\|^{2}},  (64) 
where \overset{a}{\geq} follows from Jensen’s inequality \frac{1}{n_{k}}\sum_{i=1}^{n_{k}}\|\bm{h}_{k,i}\|^{2}\geq\|\bm{h}_{k}\|^{2} for k\in[K_{A}], and \overset{b}{\geq} uses the constraint that \frac{1}{K}\sum_{k=1}^{K}\frac{1}{n_{k}}\sum_{i=1}^{n_{k}}\left\|\bm{h}_{k,i}\right\|^{2}\leq E_{H}. Moreover, \sum_{k=1}^{K_{A}}\frac{1}{n_{k}}\sum_{i=1}^{n_{k}}\left\|\bm{h}_{k,i}\right\|^{2}=KE_{H} only if \bm{h}_{k,i}=\bm{0}_{p} for all k\in[K_{A}+1:K] and i\in[n_{B}]. Plugging (64) into (63), we obtain \overset{a}{\geq} in (62).
We then bound \sum_{k=1}^{K_{A}}\left\|C_{e}\bm{w}_{A}+C_{f}\bm{w}_{B}-\bm{w}_{k}\right\|^{2}. We have
\displaystyle\frac{1}{K_{A}}\sum_{k=1}^{K_{A}}\left\|C_{e}\bm{w}_{A}+C_{f}\bm{w}_{B}-\bm{w}_{k}\right\|^{2}  
\displaystyle=  \displaystyle\frac{1}{K_{A}}\sum_{k=1}^{K_{A}}\|\bm{w}_{k}\|^{2}-\frac{2}{K_{A}}\sum_{k=1}^{K_{A}}\bm{w}_{k}^{\top}(C_{e}\bm{w}_{A}+C_{f}\bm{w}_{B})+\|C_{e}\bm{w}_{A}+C_{f}\bm{w}_{B}\|^{2}  
\displaystyle\overset{a}{=}  \displaystyle\frac{1}{K_{A}}\sum_{k=1}^{K_{A}}\|\bm{w}_{k}\|^{2}-2C_{f}^{2}\bm{w}_{A}^{\top}\bm{w}_{B}-C_{e}(2-C_{e})\|\bm{w}_{A}\|^{2}+C_{f}^{2}\|\bm{w}_{B}\|^{2},  (65) 
where \overset{a}{=} uses \sum_{k=1}^{K_{A}}\bm{w}_{k}=K_{A}\bm{w}_{A} and C_{e}+C_{f}=1. Then using the constraint that \sum_{k=1}^{K}\|\bm{w}_{k}\|^{2}\leq KE_{W} yields
\displaystyle\frac{1}{K_{A}}\sum_{k=1}^{K_{A}}\|\bm{w}_{k}\|^{2}-2C_{f}^{2}\bm{w}_{A}^{\top}\bm{w}_{B}-C_{e}(2-C_{e})\|\bm{w}_{A}\|^{2}+C_{f}^{2}\|\bm{w}_{B}\|^{2}  (66)  
\displaystyle\leq  \displaystyle\frac{K}{K_{A}}E_{W}-\frac{1}{K_{A}}\sum_{k=K_{A}+1}^{K}\|\bm{w}_{k}\|^{2}-C_{e}(2-C_{e})\left\|\bm{w}_{A}+\frac{C_{f}^{2}}{C_{e}(2-C_{e})}\bm{w}_{B}\right\|^{2}+\left(C_{f}^{2}+\frac{C_{f}^{4}}{C_{e}(2-C_{e})}\right)\|\bm{w}_{B}\|^{2}  
\displaystyle\overset{a}{\leq}  \displaystyle\frac{K}{K_{A}}E_{W}-\left(1/K_{R}-C_{f}^{2}-\frac{C_{f}^{4}}{C_{e}(2-C_{e})}\right)\|\bm{w}_{B}\|^{2}-\frac{1}{K_{A}}\sum_{k=K_{A}+1}^{K}\left\|\bm{w}_{k}-\bm{w}_{B}\right\|^{2}, 
where \overset{a}{\leq} drops the nonpositive term -C_{e}(2-C_{e})\left\|\bm{w}_{A}+\frac{C_{f}^{2}}{C_{e}(2-C_{e})}\bm{w}_{B}\right\|^{2} and applies \sum_{k=K_{A}+1}^{K}\|\bm{w}_{k}\|^{2}=K_{B}\|\bm{w}_{B}\|^{2}+\sum_{k=K_{A}+1}^{K}\left\|\bm{w}_{k}-\bm{w}_{B}\right\|^{2}. Plugging (65) and (66) into \overset{a}{\geq} in (62), we obtain \overset{b}{\geq} in (62), completing the proof. ∎
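As a numeric check (ours, not part of the proof) of the expansion (65): with \bm{w}_{A} and \bm{w}_{B} the class-A and class-B means of the classifier vectors and C_{e}+C_{f}=1, both sides should agree for arbitrary vectors. A pure-Python sketch:

```python
import random

random.seed(3)
K_A, K_B, p = 4, 3, 6
cf = random.uniform(0.05, 0.95)
ce = 1.0 - cf
W = [[random.gauss(0, 1) for _ in range(p)] for _ in range(K_A + K_B)]
w_A = [sum(W[k][j] for k in range(K_A)) / K_A for j in range(p)]
w_B = [sum(W[k][j] for k in range(K_A, K_A + K_B)) / K_B for j in range(p)]

# left-hand side of (65)
lhs = sum(
    sum((ce * w_A[j] + cf * w_B[j] - W[k][j]) ** 2 for j in range(p))
    for k in range(K_A)
) / K_A
dot = lambda u, v: sum(x * y for x, y in zip(u, v))
# right-hand side of (65)
rhs = (sum(dot(W[k], W[k]) for k in range(K_A)) / K_A
       - 2 * cf ** 2 * dot(w_A, w_B)
       - ce * (2 - ce) * dot(w_A, w_A)
       + cf ** 2 * dot(w_B, w_B))
assert abs(lhs - rhs) < 1e-9
```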
Appendix B Additional Results
Comparison of Oversampling and Weight Adjusting.
Oversampling and weight adjusting are two commonly used tricks in deep learning [JK19]. Both consider the same objective as (15) but apply different optimization algorithms to minimize it. It has been observed that oversampling is more stable than weight adjusting in optimization. As a by-product of this work, we compare the two algorithms below and show that the variance of the updates for oversampling is potentially much smaller than that for weight adjusting. It is well known in stochastic optimization that the variance of the updates governs the convergence of an optimization algorithm (see, e.g., [BCN18, FLL+18, FLZ19]). Thus we offer a reasonable justification for the stability of the oversampling technique. For simplicity, we consider sampling the training data with replacement, which differs slightly from deep learning training in practice, and we sample a single datum in each update. The analysis extends directly to the mini-batch setting.
We first introduce the two methods. In each update, the weight adjusting algorithm randomly samples a training datum and updates the parameters \bm{W}_{\textnormal{full}} by stochastic gradient descent as
\displaystyle\bm{W}_{\textnormal{full}}^{t+1}=\bm{W}_{\textnormal{full}}^{t}-\eta_{w}\bm{v}_{w}^{t},\quad t=0,1,2,\dots,  (67) 
where \bm{W}_{\textnormal{full}}^{t} denotes the parameters at iteration step t, \eta_{w} is a positive step size, and the stochastic gradient \bm{v}_{w}^{t} satisfies that
\bm{v}_{w}^{t}=\begin{cases}\nabla_{\bm{W}_{\textnormal{full}}}\mathcal{L}(f(\bm{x}_{k,i};\bm{W}_{\textnormal{full}}^{t}),\bm{y}_{k}),&k\in[K_{A}],i\in[n_{A}],\text{~{}with probability~{}}\frac{1}{K_{A}n_{A}+K_{B}n_{B}},\\ w_{r}\nabla_{\bm{W}_{\textnormal{full}}}\mathcal{L}(f(\bm{x}_{k,i};\bm{W}_{\textnormal{full}}^{t}),\bm{y}_{k}),&k\in[K_{A}+1:K],i\in[n_{B}],\text{~{}with probability~{}}\frac{1}{K_{A}n_{A}+K_{B}n_{B}}.\end{cases} 
We have
\displaystyle\mathbb{E}\left[\bm{v}_{w}^{t}\mid\bm{W}_{\textnormal{full}}^{t}\right]  (68)  
\displaystyle=  \displaystyle\frac{1}{n_{A}K_{A}+n_{B}K_{B}}\left[\sum_{k=1}^{K_{A}}\sum_{i=1}% ^{n_{A}}\nabla_{\bm{W}_{\textnormal{full}}}\mathcal{L}(f(\bm{x}_{k,i};\bm{W}_{% \textnormal{full}}^{t}),\bm{y}_{k})+w_{r}\!\!\sum_{k=K_{A}+1}^{K}\!\sum_{i=1}^% {n_{B}}\nabla_{\bm{W}_{\textnormal{full}}}\mathcal{L}(f(\bm{x}_{k,i};\bm{W}_{% \textnormal{full}}^{t}),\bm{y}_{k})\right], 
and
\displaystyle\mathbb{E}\left[\|\bm{v}_{w}^{t}\|^{2}\mid\bm{W}_{\textnormal{full}}^{t}\right]=  \displaystyle\frac{1}{n_{A}K_{A}+n_{B}K_{B}}\sum_{k=1}^{K_{A}}\sum_{i=1}^{n_{A}}\left\|\nabla_{\bm{W}_{\textnormal{full}}}\mathcal{L}(f(\bm{x}_{k,i};\bm{W}_{\textnormal{full}}^{t}),\bm{y}_{k})\right\|^{2}  
\displaystyle+\frac{w_{r}^{2}}{n_{A}K_{A}+n_{B}K_{B}}\sum_{k=K_{A}+1}^{K}\sum_{i=1}^{n_{B}}\left\|\nabla_{\bm{W}_{\textnormal{full}}}\mathcal{L}(f(\bm{x}_{k,i};\bm{W}_{\textnormal{full}}^{t}),\bm{y}_{k})\right\|^{2}.  (69) 
For the oversampling method, the algorithm in effect duplicates the minority data w_{r} times and runs stochastic gradient descent on the resulting “whole” dataset. Therefore, the update goes as
\displaystyle\bm{W}_{\textnormal{full}}^{t+1}=\bm{W}_{\textnormal{full}}^{t}-\eta_{s}\bm{v}_{s}^{t},\quad t=0,1,2,\dots,  (70) 
where \bm{v}_{s}^{t} satisfies that
\bm{v}_{s}^{t}=\begin{cases}\nabla_{\bm{W}_{\textnormal{full}}}\mathcal{L}(f(\bm{x}_{k,i};\bm{W}_{\textnormal{full}}^{t}),\bm{y}_{k}),&k\in[K_{A}],i\in[n_{A}],\text{~{}with probability~{}}\frac{1}{K_{A}n_{A}+K_{B}w_{r}n_{B}},\\ \nabla_{\bm{W}_{\textnormal{full}}}\mathcal{L}(f(\bm{x}_{k,i};\bm{W}_{\textnormal{full}}^{t}),\bm{y}_{k}),&k\in[K_{A}+1:K],i\in[n_{B}],\text{~{}with probability~{}}\frac{w_{r}}{K_{A}n_{A}+K_{B}w_{r}n_{B}}.\end{cases} 
We obtain
\displaystyle\mathbb{E}\left[\bm{v}_{s}^{t}\mid\bm{W}_{\textnormal{full}}^{t}% \right]=  \displaystyle\frac{1}{n_{A}K_{A}+w_{r}n_{B}K_{B}}\sum_{k=1}^{K_{A}}\sum_{i=1}^% {n_{A}}\nabla_{\bm{W}_{\textnormal{full}}}\mathcal{L}(f(\bm{x}_{k,i};\bm{W}_{% \textnormal{full}}^{t}),\bm{y}_{k})  
\displaystyle+\frac{w_{r}}{n_{A}K_{A}+w_{r}n_{B}K_{B}}\sum_{k=K_{A}+1}^{K}\sum% _{i=1}^{n_{B}}\nabla_{\bm{W}_{\textnormal{full}}}\mathcal{L}(f(\bm{x}_{k,i};% \bm{W}_{\textnormal{full}}^{t}),\bm{y}_{k}), 
and
\displaystyle\mathbb{E}\left[\|\bm{v}_{s}^{t}\|^{2}\mid\bm{W}_{\textnormal{full}}^{t}\right]=  \displaystyle\frac{1}{n_{A}K_{A}+w_{r}n_{B}K_{B}}\sum_{k=1}^{K_{A}}\sum_{i=1}^{n_{A}}\left\|\nabla_{\bm{W}_{\textnormal{full}}}\mathcal{L}(f(\bm{x}_{k,i};\bm{W}_{\textnormal{full}}^{t}),\bm{y}_{k})\right\|^{2}  
\displaystyle+\frac{w_{r}}{n_{A}K_{A}+w_{r}n_{B}K_{B}}\sum_{k=K_{A}+1}^{K}\sum_{i=1}^{n_{B}}\left\|\nabla_{\bm{W}_{\textnormal{full}}}\mathcal{L}(f(\bm{x}_{k,i};\bm{W}_{\textnormal{full}}^{t}),\bm{y}_{k})\right\|^{2}.  (71) 
We suppose the two updates are of the same scale in expectation; that is, we assume \eta_{w}=\frac{n_{A}K_{A}+w_{r}n_{B}K_{B}}{n_{A}K_{A}+n_{B}K_{B}}\eta_{s}, so that \eta_{w}\mathbb{E}\left[\bm{v}_{w}^{t}\mid\bm{W}_{\textnormal{full}}^{t}\right]=\eta_{s}\mathbb{E}\left[\bm{v}_{s}^{t}\mid\bm{W}_{\textnormal{full}}^{t}\right]. In fact, if K_{A}\asymp 1, K_{B}\asymp 1, n_{A}\gg n_{B}, and 1\ll w_{r}\lesssim n_{A}/n_{B}, we have \frac{n_{A}K_{A}+w_{r}n_{B}K_{B}}{n_{A}K_{A}+n_{B}K_{B}}\asymp 1 and so \eta_{w}\asymp\eta_{s}. Now, comparing (69) with (71), we see that the second moment of \eta_{w}\bm{v}_{w}^{t} is potentially much larger than that of \eta_{s}\bm{v}_{s}^{t}, since the order of w_{r} in the former is higher by one. For example, assume that all the gradient norms are of the same order, i.e., \left\|\nabla_{\bm{W}_{\textnormal{full}}}\mathcal{L}(f(\bm{x}_{k,i};\bm{W}_{\textnormal{full}}^{t}),\bm{y}_{k})\right\|\asymp a for all k and i, where a>0. Then (71) implies that \mathbb{E}\left[\|\eta_{s}\bm{v}_{s}^{t}\|^{2}\mid\bm{W}_{\textnormal{full}}^{t}\right]\asymp\eta_{s}^{2}a^{2}, whereas (69) yields \mathbb{E}\left[\|\eta_{w}\bm{v}_{w}^{t}\|^{2}\mid\bm{W}_{\textnormal{full}}^{t}\right]\asymp\eta_{s}^{2}\frac{n_{A}K_{A}+w_{r}^{2}n_{B}K_{B}}{n_{A}K_{A}+w_{r}n_{B}K_{B}}a^{2}. Furthermore, if we set w_{r}\asymp n_{A}/n_{B}, then \mathbb{E}\left[\|\eta_{w}\bm{v}_{w}^{t}\|^{2}\mid\bm{W}_{\textnormal{full}}^{t}\right]\asymp\eta_{s}^{2}w_{r}a^{2}. Thus the second moment of \eta_{w}\bm{v}_{w}^{t} is around w_{r} times that of \eta_{s}\bm{v}_{s}^{t}. The same holds for the variance, because \left\|\eta_{s}\mathbb{E}\left[\bm{v}_{s}^{t}\mid\bm{W}_{\textnormal{full}}^{t}\right]\right\|\asymp\eta_{s}a and \mathbb{E}\|\bm{x}-\mathbb{E}[\bm{x}]\|^{2}=\mathbb{E}\|\bm{x}\|^{2}-\|\mathbb{E}[\bm{x}]\|^{2} for any random vector \bm{x}. 
Therefore, we conclude that the variance of the updates under oversampling is potentially much smaller than under weight adjusting.
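As a sanity check on this order-of-magnitude comparison, the following sketch computes the exact single-sample second moments in a simplified model of the two schemes. The construction is our own: one majority and one minority class (K_{A}=K_{B}=1), all per-example gradient norms equal to a, and the two updates normalized so that they agree in expectation.

```python
import numpy as np

# Toy instance: n_A majority and n_B minority examples, weight ratio
# w_r = n_A / n_B, as in the text (K_A = K_B = 1).
n_A, n_B = 1000, 10
w_r = n_A / n_B
N = n_A + n_B
a = 1.0  # common gradient-norm scale, as assumed in the text

w = np.concatenate([np.ones(n_A), w_r * np.ones(n_B)])  # per-example weights
W = w.sum()

eta_s = 0.1
eta_w = W / N * eta_s  # the learning-rate matching from the text

# Exact single-sample second moments of the step (our simplified model):
# oversampling: draw j with prob w_j / W, step eta_s * (W / N) * g_j
m_s = eta_s**2 * (W / N) ** 2 * a**2 * np.sum(w / W)
# weight adjusting: draw j uniformly, step eta_w * (N * w_j / W) * g_j
m_w = eta_w**2 * np.mean((N * w / W) ** 2) * a**2

print(m_w / m_s)  # 25.5025: much larger than 1, growing like w_r
```

In this model the ratio works out to N \sum_j w_j^2 / W^2, which grows linearly in w_r when w_r \asymp n_A/n_B, matching the order predicted above.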
More Discussion on the Convex Relaxation and the Cross-Entropy Loss.
We show that Program (7) can also be relaxed to a nuclear-norm-constrained convex optimization problem. The result relies heavily on results on matrix factorization, e.g., [BMP08, HV19]. We will use the identity (see, e.g., [BMP08, Section 2]) that for any matrix \bm{Z} and a>0,
\|\bm{Z}\|_{*}=\inf_{r\in\mathbb{N}_{+}}\inf_{\bm{U},\bm{V}:\bm{U}\bm{V}^{\top}=\bm{Z}}\frac{a}{2}\|\bm{U}\|^{2}+\frac{1}{2a}\|\bm{V}\|^{2},  (72) 
where r is the number of columns of \bm{U} and \|\cdot\|_{*} denotes the nuclear norm.
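The identity (72) is easy to verify numerically: the SVD of \bm{Z} supplies a factorization attaining the infimum, while any other factorization only upper-bounds the nuclear norm. A small sketch (the matrix sizes and the value of a are arbitrary choices):

```python
import numpy as np

rng = np.random.default_rng(0)
Z = rng.standard_normal((4, 4))
a = 1.7  # any a > 0

U_s, sig, Vt = np.linalg.svd(Z)
nuclear = sig.sum()  # ||Z||_* is the sum of singular values

# The factorization Z = U V^T built from the SVD attains the infimum in (72):
# scale the left singular vectors by sqrt(sig / a), the right by sqrt(a * sig).
U = U_s * np.sqrt(sig / a)
V_t = np.sqrt(a * sig)[:, None] * Vt
assert np.allclose(U @ V_t, Z)
best = a / 2 * np.sum(U**2) + 1 / (2 * a) * np.sum(V_t**2)

# Any other factorization of Z can only give a larger value.
B = rng.standard_normal((4, 6))   # full row rank almost surely
C_t = np.linalg.pinv(B) @ Z       # so B @ C_t == Z
other = a / 2 * np.sum(B**2) + 1 / (2 * a) * np.sum(C_t**2)

print(abs(best - nuclear) < 1e-8, other >= nuclear)  # True True
```

The a/2 and 1/(2a) weights trade Frobenius mass between the two factors, which is why the infimum is independent of a.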
For any feasible solution \left(\bm{H},\bm{W}\right) for the original program (7), we define
\bm{h}_{k}=\frac{1}{n_{k}}\sum_{i=1}^{n_{k}}\bm{h}_{k,i},~{}k\in[K],\quad\tilde{\bm{H}}=[\bm{h}_{1},\bm{h}_{2},\dots,\bm{h}_{K}]\in\mathbb{R}^{p\times K},~{}~{}\text{and}~{}~{}\bm{Z}=\bm{W}\tilde{\bm{H}}\in\mathbb{R}^{K\times K}.  (73) 
We consider the convex program:
\displaystyle\min_{\bm{Z}\in\mathbb{R}^{K\times K}}  \displaystyle\sum_{k=1}^{K}\frac{n_{k}}{N}\mathcal{L}(\bm{Z}_{k},\bm{y}_{k})  (74)  
\displaystyle\mathrm{s.t.}  \displaystyle\|\bm{Z}\|_{*}\leq K\sqrt{E_{H}E_{W}}, 
where \bm{Z}_{k} denotes the kth column of \bm{Z} for k\in[K].
Lemma 6.
Assume p\geq K and that the loss function \mathcal{L} is convex in its first argument. Let \bm{Z}^{\star} be a minimizer of the convex program (74), let r be the rank of \bm{Z}^{\star}, and consider the thin singular value decomposition (SVD) \bm{Z}^{\star}=\bm{U}^{\star}\bm{\Sigma}^{\star}\bm{V}^{\star\top}. Introduce two diagonal matrices \bm{\Sigma}_{1}^{\star} and \bm{\Sigma}_{2}^{\star} with entries \bm{\Sigma}_{1}^{\star}(i,i)=\sqrt{\frac{E_{W}}{E_{H}}}\sqrt{\bm{\Sigma}^{\star}(i,i)} and \bm{\Sigma}_{2}^{\star}(i,i)=\sqrt{\frac{E_{H}}{E_{W}}}\sqrt{\bm{\Sigma}^{\star}(i,i)} for i\in[r], respectively. Let \left(\bm{H}^{\star},\bm{W}^{\star}\right) be given by
\displaystyle\bm{W}^{\star}=\bm{U}^{\star}\bm{\Sigma}_{1}^{\star}\bm{P}^{\top},\quad\left[\bm{h}_{1}^{\star},\bm{h}_{2}^{\star},\dots,\bm{h}_{K}^{\star}\right]=\bm{P}\bm{\Sigma}_{2}^{\star}\bm{V}^{\star\top},  (75)  
\displaystyle\bm{h}_{k,i}^{\star}=\bm{h}_{k}^{\star},\quad k\in[K],~{}i\in[n_{k}], 
where \bm{P}\in\mathbb{R}^{p\times r} is any partial orthogonal matrix such that \bm{P}^{\top}\bm{P}=\bm{I}_{r}. Then (\bm{H}^{\star},\bm{W}^{\star}) is a minimizer of (7).
Proof of Lemma 6.
For any feasible solution \left(\bm{H},\bm{W}\right) of the original program (7), define \bm{h}_{k} for k\in[K], \tilde{\bm{H}}, and \bm{Z} by (73). We show that \bm{Z} is a feasible solution for the convex program (74). In fact, applying (72) with r=p and a=\sqrt{E_{H}/E_{W}} to the factorization \bm{Z}=\bm{W}\tilde{\bm{H}}, we have
\displaystyle\left\|\bm{Z}\right\|_{*}  \displaystyle\leq\frac{\sqrt{E_{H}/E_{W}}}{2}\left\|\bm{W}\right\|^{2}+\frac{\sqrt{E_{W}/E_{H}}}{2}\left\|\tilde{\bm{H}}\right\|^{2}  
\displaystyle\overset{a}{\leq}\frac{\sqrt{E_{H}/E_{W}}}{2}\sum_{k=1}^{K}\|\bm{w}_{k}\|^{2}+\frac{\sqrt{E_{W}/E_{H}}}{2}\sum_{k=1}^{K}\frac{1}{n_{k}}\sum_{i=1}^{n_{k}}\left\|\bm{h}_{k,i}\right\|^{2}  
\displaystyle\leq K\sqrt{E_{H}E_{W}},  (76) 
where step \overset{a}{\leq} applies Jensen’s inequality:
\left\|\tilde{\bm{H}}\right\|^{2}=\sum_{k=1}^{K}\|\bm{h}_{k}\|^{2}\leq\sum_{k=1}^{K}\frac{1}{n_{k}}\sum_{i=1}^{n_{k}}\left\|\bm{h}_{k,i}\right\|^{2}. 
Let L_{0} be the global minimum of the convex program (74). Since \mathcal{L} is convex in its first argument, the same argument as in (A.2.1) gives, for any feasible solution \left(\bm{H},\bm{W}\right),
\displaystyle\frac{1}{N}\sum_{k=1}^{K}\sum_{i=1}^{n_{k}}\mathcal{L}(\bm{W}\bm{h}_{k,i},\bm{y}_{k})  \displaystyle=\sum_{k=1}^{K}\frac{n_{k}}{N}\left[\frac{1}{n_{k}}\sum_{i=1}^{n_{k}}\mathcal{L}(\bm{W}\bm{h}_{k,i},\bm{y}_{k})\right]  
\displaystyle\geq\sum_{k=1}^{K}\frac{n_{k}}{N}\mathcal{L}(\bm{W}\bm{h}_{k},\bm{y}_{k})=\sum_{k=1}^{K}\frac{n_{k}}{N}\mathcal{L}(\bm{Z}_{k},\bm{y}_{k})\geq L_{0}.  (77) 
On the other hand, for the solution \left(\bm{H}^{\star},\bm{W}^{\star}\right) constructed from \bm{Z}^{\star} in (75), one can verify that \left(\bm{H}^{\star},\bm{W}^{\star}\right) is a feasible solution for (7) and
\frac{1}{N}\sum_{k=1}^{K}\sum_{i=1}^{n_{k}}\mathcal{L}(\bm{W}^{\star}\bm{h}_{k,i}^{\star},\bm{y}_{k})=\sum_{k=1}^{K}\frac{n_{k}}{N}\mathcal{L}(\bm{Z}_{k}^{\star},\bm{y}_{k})=L_{0}.  (78) 
Combining (77) and (78), we conclude that L_{0} is the global minimum of (7) and that (\bm{H}^{\star},\bm{W}^{\star}) is a minimizer. ∎
Property 1.
For the cross-entropy loss, we have the following properties.

(A)
Any minimizer \bm{Z}^{\star} of (74) satisfies \|\bm{Z}^{\star}\|_{*}=K\sqrt{E_{H}E_{W}}.

(B)
Any minimizer (\bm{H}^{\star},\bm{W}^{\star}) of (7) satisfies
\frac{1}{K}\sum_{k=1}^{K}\frac{1}{n}\sum_{i=1}^{n}\left\|\bm{h}_{k,i}^{\star}\right\|^{2}=E_{H},\quad\text{and}\quad\frac{1}{K}\sum_{k=1}^{K}\left\|\bm{w}_{k}^{\star}\right\|^{2}=E_{W}. 
(C)
Any minimizer \bm{X}^{\star} of (13) satisfies
\frac{1}{K}\sum_{k=1}^{K}\bm{X}^{\star}(k,k)=E_{H},\quad\text{and}\quad\frac{1}{K}\sum_{k=K+1}^{2K}\bm{X}^{\star}(k,k)=E_{W}.
Proof of Property 1.
We first prove (A). Let \bm{Z}^{\star} be any minimizer of (74). By the Karush–Kuhn–Tucker conditions, there is a pair (\lambda,\bm{\xi}) with \lambda\geq 0 and \bm{\xi}\in\partial\|\bm{Z}^{\star}\|_{*} such that
\nabla_{\bm{Z}}\left[\sum_{k=1}^{K}\frac{n_{k}}{N}\mathcal{L}(\bm{Z}_{k}^{\star},\bm{y}_{k})\right]+\lambda\bm{\xi}=\bm{0}^{K\times K}, 
where \partial\|\bm{Z}\|_{*} denotes the subdifferential of \|\bm{Z}\|_{*}. For the cross-entropy loss, one can verify that \nabla_{\bm{Z}}\left[\sum_{k=1}^{K}\frac{n_{k}}{N}\mathcal{L}(\bm{Z}_{k},\bm{y}_{k})\right]\neq\bm{0}^{K\times K} for all \bm{Z}, so \lambda>0. By the complementary slackness condition, \bm{Z}^{\star} must lie on the boundary of the constraint set, which proves (A).
For (B), suppose for contradiction that there is a minimizer (\bm{H}^{\star},\bm{W}^{\star}) of (7) such that \frac{1}{K}\sum_{k=1}^{K}\frac{1}{n}\sum_{i=1}^{n}\left\|\bm{h}_{k,i}^{\star}\right\|^{2}<E_{H} or \frac{1}{K}\sum_{k=1}^{K}\left\|\bm{w}_{k}^{\star}\right\|^{2}<E_{W}. Letting \bm{Z}^{\star} be defined as in (73), it follows from (77) and (78) that \bm{Z}^{\star} is a minimizer of (74). However, by (76), we would have \|\bm{Z}^{\star}\|_{*}<K\sqrt{E_{H}E_{W}}, which contradicts (A). This proves (B).
For (C), suppose there is a minimizer \bm{X}^{\star} of (13) such that \frac{1}{K}\sum_{k=1}^{K}\bm{X}^{\star}(k,k)<E_{H} or \frac{1}{K}\sum_{k=K+1}^{2K}\bm{X}^{\star}(k,k)<E_{W}. Then, letting (\bm{H}^{\star},\bm{W}^{\star}) be defined as in (14), Theorem 1 implies that (\bm{H}^{\star},\bm{W}^{\star}) is a minimizer of (7). However, we would have \frac{1}{K}\sum_{k=1}^{K}\frac{1}{n}\sum_{i=1}^{n}\left\|\bm{h}_{k,i}^{\star}\right\|^{2}<E_{H} or \frac{1}{K}\sum_{k=1}^{K}\left\|\bm{w}_{k}^{\star}\right\|^{2}<E_{W}, which contradicts (B). This completes the proof.
∎
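Property 1(A) can also be observed numerically: running projected gradient descent on the convex program (74) with the cross-entropy loss drives the iterate onto the boundary of the nuclear-norm ball. The following is a minimal sketch of our own construction (K, E_{H}, E_{W}, the step size, and the iteration count are arbitrary choices); the projection reduces to projecting the singular values onto a scaled simplex.

```python
import numpy as np

def proj_nuclear_ball(Z, radius):
    """Project Z onto {X : ||X||_* <= radius} by projecting its
    singular values onto the l1 ball of the given radius."""
    U, s, Vt = np.linalg.svd(Z, full_matrices=False)
    if s.sum() <= radius:
        return Z
    # standard l1-ball projection of the (nonnegative) singular values
    u = np.sort(s)[::-1]
    css = np.cumsum(u)
    rho = np.nonzero(u - (css - radius) / (np.arange(len(u)) + 1) > 0)[0][-1]
    theta = (css[rho] - radius) / (rho + 1)
    return (U * np.maximum(s - theta, 0.0)) @ Vt

K, E_H, E_W = 4, 1.0, 1.0
radius = K * np.sqrt(E_H * E_W)  # the constraint in (74)

rng = np.random.default_rng(0)
Z = proj_nuclear_ball(rng.standard_normal((K, K)), radius)
for _ in range(2000):
    P = np.exp(Z - Z.max(axis=0))  # column-wise softmax (stabilized)
    P /= P.sum(axis=0)
    grad = (P - np.eye(K)) / K     # gradient of (1/K) sum_k CE(Z_k, y_k)
    Z = proj_nuclear_ball(Z - 0.5 * grad, radius)

print(np.linalg.svd(Z, compute_uv=False).sum())  # approx. 4.0 = K sqrt(E_H E_W)
```

Since the cross-entropy gradient never vanishes, any fixed point of the projected iteration sits on the boundary, exactly as the complementary-slackness argument above predicts.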