Multi-Branch Tensor Network Structure for Tensor-Train Discriminant Analysis
Abstract
Higher-order data with high dimensionality arise in a diverse set of application areas such as computer vision, video analytics and medical imaging. Tensors provide a natural tool for representing these types of data. Although there has been extensive work on tensor decomposition and low-rank tensor approximation, extensions to supervised learning, feature extraction and classification are still limited. Moreover, most existing supervised tensor learning approaches are based on the orthogonal Tucker model. However, this model has limitations for large tensors, including high memory and computational costs. In this paper, we introduce a supervised learning approach for tensor classification based on the tensor-train model. In particular, we introduce a multi-branch tensor network structure for efficient implementation of tensor-train discriminant analysis (TTDA). The proposed approach takes advantage of the flexibility of the tensor-train structure to implement various computationally efficient versions of TTDA. This approach is then evaluated on image and video classification tasks with respect to computation time, storage cost and classification accuracy, and is compared to both vector- and tensor-based discriminant analysis methods.
Index Terms: Tensor-Train, Tensor Networks, Multidimensional Discriminant Analysis, Supervised Tensor-Train.
I Introduction
Tensors, which are higher-order generalizations of matrices and vectors, provide a natural way to represent multidimensional data objects whose entries are indexed by several continuous or discrete variables. Employing tensors and their decompositions to process data objects has become increasingly popular [1, 2, 3]. For instance, a color image is a third-order tensor defined by two indices for the spatial variables and one index for the color mode. A video comprising color images is a fourth-order tensor with an additional index for the time variable. The majority of current work on tensor decomposition focuses on unsupervised learning of low-rank representations of the tensor object, e.g., PARAFAC [4] and Tucker decomposition [5, 3, 6, 7].
In many real-world applications such as computer vision, data instances are more naturally represented as second-order or higher-order tensors, where the order of a tensor corresponds to its number of modes. Conventional supervised learning approaches, including LDA applied to vectorized tensor samples, are inadequate for massive multidimensional data, as they cannot capture the cross-couplings across the different modes and suffer from increasing storage and computational costs [2, 8, 1]. Therefore, there is a growing need for new methods that account for the intrinsic structure of the data while learning discriminant subspaces.
In recent years, supervised and unsupervised tensor subspace learning approaches based on the Tucker model have been proposed [9, 10, 11, 12, 13]. Some of these approaches, such as Multilinear Principal Component Analysis (MPCA) [13], are successful at dimensionality reduction but are not necessarily suitable for discriminative feature extraction. Others, such as Multilinear Discriminant Analysis (MDA) [9], become impractical as the number of modes increases, due to the exponential increase in the storage cost of the Tucker model [14].
The Tensor-Train (TT) model, on the other hand, provides better compression than Tucker models, especially for higher-order tensors, as it expresses a given high-dimensional tensor as the product of low-rank, three-mode tensors [2]. The TT model has been employed in various applications such as PCA [15], manifold learning [16] and deep learning [17]. In this paper, we introduce a discriminant subspace learning approach using the TT model, namely Tensor-Train Discriminant Analysis (TTDA). The proposed approach is based on linear discriminant analysis (LDA) and approximates the LDA subspace by constraining it to have a TT structure. Although this constraint provides an efficient structure for storing the learned subspaces, optimizing under it is computationally prohibitive. For this reason, we propose several computationally efficient implementations of TTDA that exploit the flexibility of the tensor network structure. Previous work on TT, and on tensor networks in general, raises the question of whether reshaping high-dimensional vector- and matrix-type data into tensors and then processing them using TT decomposition provides any significant benefits [18]. Several papers have employed this idea of reshaping matrices into tensors, known as ket augmentation and quantized TT (QTT), for better compression and higher computational efficiency [19, 16, 18, 20]. In QTT, a vector or a matrix is tensorized into a higher-order tensor to improve both the computational complexity and the storage cost of TT. Inspired by this idea, we propose to tensorize the projected training samples in a learning framework. Once the projections are tensorized, the corresponding TT subspaces become smaller in dimension. Using this structural approximation to TTDA, we first propose approximating 2DLDA by TT and then generalize by increasing the number of modes (or branches) of the projected training samples.
This paper differs from existing work in two ways. First, we generalize discriminant analysis using the TT model and propose several approaches that implement this generalization efficiently through a multi-branch tensor network structure. Second, we analyze the computational and storage complexity as a function of the number of modes of the projection tensors. Based on this analysis, we provide a method to find the most efficient way to implement the TT model for given input tensor dimensions.
The rest of the paper is organized as follows. In Section II, we provide background on tensor operations, TT and Tucker decompositions, LDA and MDA. In Section III, we propose several approaches to implement tensor-train discriminant analysis. We also derive the storage and computational complexities of the proposed approaches and propose an optimal structure for decomposing a given tensor object in terms of storage complexity. In Section IV, we compare the proposed methods with state-of-the-art tensor-based discriminant analysis methods for classification applications.
II Background
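Let $\mathcal{Y} \in \mathbb{R}^{I_1\times I_2\times\cdots\times I_N\times K}$ be the collection of $K$ tensor training samples. Define $\mathcal{Y}_c^k \in \mathbb{R}^{I_1\times\cdots\times I_N}$ as the sample tensors, where $c \in \{1,\dots,C\}$ is the class index and $k \in \{1,\dots,K_c\}$ is the sample index, from a given training set with $C$ classes where class $c$ has $K_c$ samples.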
Iia Notation
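Definition 1. (Vectorization, Matricization and Reshaping) For $\mathcal{X} \in \mathbb{R}^{I_1\times\cdots\times I_N}$, $\mathbf{V}(\cdot)$ is a vectorization operator such that $\mathbf{V}(\mathcal{X}) \in \mathbb{R}^{I_1 I_2\cdots I_N}$. $\mathbf{T}_n(\cdot)$ is a tensor-to-matrix reshaping operator defined as $\mathbf{T}_n(\mathcal{X}) \in \mathbb{R}^{I_1\cdots I_n \times I_{n+1}\cdots I_N}$, and the inverse operator is denoted as $\mathbf{T}_n^{-1}(\cdot)$.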
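Definition 2. (Left and right unfolding) The left unfolding operator creates a matrix from a tensor by taking all modes except the last as row indices and the last mode as column indices, i.e., $\mathbf{L}(\mathcal{X}) \in \mathbb{R}^{I_1\cdots I_{N-1} \times I_N}$, which is equivalent to $\mathbf{T}_{N-1}(\mathcal{X})$. Right unfolding transforms a tensor into a matrix by taking all the first-mode fibers as column vectors, i.e., $\mathbf{R}(\mathcal{X}) \in \mathbb{R}^{I_1 \times I_2\cdots I_N}$, which is equivalent to $\mathbf{T}_1(\mathcal{X})$. The inverses of these operators are denoted as $\mathbf{L}^{-1}(\cdot)$ and $\mathbf{R}^{-1}(\cdot)$, respectively.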
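As a hedged illustration, the reshaping operators of Definitions 1 and 2 map onto NumPy reshapes. This sketch assumes a row-major (C-order) linearization of multi-indices, which is NumPy's default; the paper's own index ordering may differ. The function names are hypothetical.

```python
import numpy as np

# Assumption: row-major (C-order) index linearization, NumPy's default.
def reshape_to_matrix(X, n):
    """T_n(X): the first n modes index rows, the remaining modes index columns."""
    rows = int(np.prod(X.shape[:n]))
    return X.reshape(rows, -1)

def inverse_reshape(M, shape):
    """T_n^{-1}: fold a matrix back into a tensor of the given shape."""
    return M.reshape(shape)

def left_unfold(X):
    """L(X) = T_{N-1}(X): all modes except the last index the rows."""
    return reshape_to_matrix(X, X.ndim - 1)

def right_unfold(X):
    """R(X) = T_1(X): first-mode fibers become the columns."""
    return reshape_to_matrix(X, 1)

X = np.arange(24).reshape(2, 3, 4)
print(left_unfold(X).shape)   # (6, 4)
print(right_unfold(X).shape)  # (2, 12)
```

Because these are pure reshapes, the inverse operators cost nothing beyond restoring the original shape.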
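Definition 3. (Tensor trace) The tensor trace is applied to matrix slices of a tensor and contracts each matrix to a scalar. Let $\mathcal{X} \in \mathbb{R}^{I_1\times\cdots\times I_N}$ with $I_n = I_m$; then the trace operation on modes $n$ and $m$ is defined as:
$$\operatorname{tr}_n^m(\mathcal{X}) = \sum_{i=1}^{I_n} \mathcal{X}(i_1,\dots,i_{n-1},i,i_{n+1},\dots,i_{m-1},i,i_{m+1},\dots,i_N), \tag{1}$$
where $\operatorname{tr}_n^m(\mathcal{X})$ is an $(N-2)$-mode tensor.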
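The trace of Definition 3 can be sketched with `numpy.trace`, which sums the diagonal over a chosen pair of axes and drops both. Modes are 0-based here, unlike the 1-based indexing in the text; `tensor_trace` is a hypothetical helper name.

```python
import numpy as np

def tensor_trace(X, n, m):
    """Trace over modes n and m (0-based); requires X.shape[n] == X.shape[m].
    The result has X.ndim - 2 modes."""
    if X.shape[n] != X.shape[m]:
        raise ValueError("traced modes must have equal size")
    return np.trace(X, axis1=n, axis2=m)

X = np.random.default_rng(0).standard_normal((3, 5, 3, 2))
T = tensor_trace(X, 0, 2)
print(T.shape)  # (5, 2)

# The same contraction written as an explicit sum over the shared index i.
T_loop = sum(X[i, :, i, :] for i in range(3))
```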
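Definition 4. (Tensor Merging Product) The tensor merging product connects two tensors along given sets of modes. For two tensors $\mathcal{A} \in \mathbb{R}^{I_1\times\cdots\times I_N}$ and $\mathcal{B} \in \mathbb{R}^{J_1\times\cdots\times J_M}$, where $I_n = J_m$ for paired modes $n \in \mathcal{N}$ and $m \in \mathcal{M}$, for some $\mathcal{N} \subset \{1,\dots,N\}$ and $\mathcal{M} \subset \{1,\dots,M\}$ with $|\mathcal{N}| = |\mathcal{M}|$, the tensor merging product is given by [14]:
$$\mathcal{C} = \mathcal{A} \times_{\mathcal{N}}^{\mathcal{M}} \mathcal{B}. \tag{2}$$
$\mathcal{C}$ is an $(N + M - 2|\mathcal{N}|)$-mode tensor that is calculated as:
$$\mathcal{C}(i_{\bar{\mathcal{N}}}, j_{\bar{\mathcal{M}}}) = \sum_{i_{\mathcal{N}} = j_{\mathcal{M}}} \mathcal{A}(i_1,\dots,i_N)\,\mathcal{B}(j_1,\dots,j_M), \tag{3}$$
where $i_{\bar{\mathcal{N}}}$ and $j_{\bar{\mathcal{M}}}$ denote the indices of the modes that are not merged, and the sum runs over the merged indices with $i_n = j_m$ for each paired $(n, m)$.
A graphical representation of tensors $\mathcal{A}$ and $\mathcal{B}$ and of the tensor merging product above is given in Fig. 1.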
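A special case of the tensor merging product arises when $I_n = J_n$ for all $n \in \{1,\dots,S\}$. In this case, the tensor merging product across the first $S$ modes is defined as:
$$\mathcal{C} = \mathcal{A} \times_{1,\dots,S}^{1,\dots,S} \mathcal{B}, \tag{4}$$
where $\mathcal{C} \in \mathbb{R}^{I_{S+1}\times\cdots\times I_N\times J_{S+1}\times\cdots\times J_M}$. This can equivalently be written as:
$$\mathbf{T}_{N-S}(\mathcal{C}) = \mathbf{T}_S(\mathcal{A})^\top\,\mathbf{T}_S(\mathcal{B}), \tag{5}$$
where $\mathbf{T}_S(\mathcal{A}) \in \mathbb{R}^{I_1\cdots I_S \times I_{S+1}\cdots I_N}$ and $\mathbf{T}_S(\mathcal{B}) \in \mathbb{R}^{I_1\cdots I_S \times J_{S+1}\cdots J_M}$.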
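A minimal NumPy sketch of the merging product, assuming 0-based modes and row-major reshaping: `numpy.tensordot` performs the general contraction, and the first-modes special case reduces to a single matrix product of reshaped operands, mirroring the matrix form of the special case. Helper names are hypothetical.

```python
import numpy as np

def merging_product(A, B, modes_a, modes_b):
    """Contract A and B over paired modes (0-based). The result keeps the
    free modes of A followed by the free modes of B."""
    return np.tensordot(A, B, axes=(modes_a, modes_b))

def merge_first_modes(A, B, S):
    """Special case: contract the first S modes of A and B via the matrix
    product T_S(A)^T T_S(B), then fold the result back into a tensor."""
    assert A.shape[:S] == B.shape[:S]
    rows = int(np.prod(A.shape[:S]))
    C = A.reshape(rows, -1).T @ B.reshape(rows, -1)
    return C.reshape(A.shape[S:] + B.shape[S:])

rng = np.random.default_rng(1)
A = rng.standard_normal((2, 3, 4))
B = rng.standard_normal((2, 3, 5, 6))
C1 = merging_product(A, B, [0, 1], [0, 1])
C2 = merge_first_modes(A, B, 2)
print(C1.shape)  # (4, 5, 6)
```

Both routes give the same tensor; the matrix form is what makes the special case cheap to compute with standard BLAS calls.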
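Definition 5. (Tensor-Train Decomposition (TT Decomposition)) Using tensor-train decomposition, each element of $\mathcal{Y}_c^k$ can be represented as:
$$\mathcal{Y}_c^k(i_1,\dots,i_N) = \sum_{r_1,\dots,r_N} \mathcal{U}_1(1,i_1,r_1)\,\mathcal{U}_2(r_1,i_2,r_2)\cdots\mathcal{U}_N(r_{N-1},i_N,r_N)\,\mathbf{x}_c^k(r_N), \tag{6}$$
where $\mathcal{U}_n \in \mathbb{R}^{R_{n-1}\times I_n\times R_n}$ are the three-mode low-rank tensor factors, $R_n$ are the TT-ranks of the corresponding modes with $R_0 = 1$, and $\mathbf{x}_c^k \in \mathbb{R}^{R_N}$ is the projected sample vector. Using the tensor merging product, (6) can be rewritten as
$$\mathcal{Y}_c^k = \mathcal{U}_1 \times_2^1 \mathcal{U}_2 \times_3^1 \cdots \times_N^1 \mathcal{U}_N \times_{N+1}^1 \mathbf{x}_c^k, \tag{7}$$
where $\mathcal{U}_1$ is treated as an $I_1 \times R_1$ matrix since $R_0 = 1$. A graphical representation of (7) can be seen in Fig. 2. If $\mathcal{Y}_c^k$ is vectorized, another equivalent expression for (6) in terms of a matrix projection is obtained as:
$$\mathbf{V}(\mathcal{Y}_c^k) = \mathbf{L}(\mathcal{U}_1 \times_2^1 \mathcal{U}_2 \times_3^1 \cdots \times_N^1 \mathcal{U}_N)\,\mathbf{x}_c^k. \tag{8}$$
For the sake of simplicity, we define $\mathbf{U} = \mathbf{L}(\mathcal{U}_1 \times_2^1 \cdots \times_N^1 \mathcal{U}_N)$, where $\mathbf{U} \in \mathbb{R}^{I_1 I_2\cdots I_N \times R_N}$. When the $\mathcal{U}_n$ are left orthogonal, $\mathbf{U}$ is also left orthogonal [21], i.e., $\mathbf{L}(\mathcal{U}_n)^\top\mathbf{L}(\mathcal{U}_n) = \mathbf{I}_{R_n}$ for all $n$ leads to $\mathbf{U}^\top\mathbf{U} = \mathbf{I}_{R_N}$, where $\mathbf{I}_{R_N}$ is the $R_N \times R_N$ identity matrix.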
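The left-orthogonality property can be checked numerically. The sketch below uses TT-SVD (sequential truncated SVDs) as a stand-in for the TT decomposition used for initialization; we assume, but have not confirmed, that [22] is of this kind. The last mode of the input holds the $K$ samples, so the remainder left after the final core plays the role of the projected vectors $\mathbf{x}_c^k$. A single `max_rank` cap for all TT-ranks is a simplifying assumption, and the helper names are hypothetical.

```python
import numpy as np

def tt_svd(Y, max_rank):
    """Sequential truncated SVDs over the first N = Y.ndim - 1 modes.
    Returns left-orthogonal cores of shape (R_{n-1}, I_n, R_n) and the
    remainder matrix (R_N x K) holding the projections of the K samples."""
    dims = Y.shape[:-1]
    cores, r_prev, M = [], 1, Y
    for I_n in dims:
        M = M.reshape(r_prev * I_n, -1)
        Uf, s, Vt = np.linalg.svd(M, full_matrices=False)
        r = min(max_rank, len(s))
        cores.append(Uf[:, :r].reshape(r_prev, I_n, r))  # left-orthogonal core
        M = s[:r, None] * Vt[:r]                         # carry remainder forward
        r_prev = r
    return cores, M

def merge_cores(cores):
    """Left unfolding of the merged cores: U of shape (I_1...I_N, R_N)."""
    U = cores[0].reshape(-1, cores[0].shape[-1])  # R_0 = 1
    for G in cores[1:]:
        r_prev, I_n, r = G.shape
        U = (U @ G.reshape(r_prev, I_n * r)).reshape(-1, r)
    return U

rng = np.random.default_rng(0)
Y = rng.standard_normal((4, 3, 5, 10))  # N = 3 modes plus K = 10 samples
cores, X = tt_svd(Y, max_rank=100)
U = merge_cores(cores)
```

Without truncation, `U @ X` reproduces the vectorized samples exactly. With a smaller `max_rank`, `U` still has orthonormal columns, so a new sample can be projected as `U.T @ y.reshape(-1)`.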
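Definition 6. (Tucker Decomposition (TD)) If the number of modes of the projected samples equals the number of modes of the input tensors $\mathcal{Y}_c^k$, the TT model becomes equivalent to Tucker decomposition. In this case, the projected sample $\mathcal{X}_c^k$ is also known as the core tensor. This is given in (9) and shown in Fig. 3:
$$\mathcal{Y}_c^k = \mathcal{X}_c^k \times_1 \mathbf{U}_1 \times_2 \mathbf{U}_2 \cdots \times_N \mathbf{U}_N, \tag{9}$$
where $\times_n$ denotes the mode-$n$ product, $\mathbf{U}_n \in \mathbb{R}^{I_n\times R_n}$ and $\mathcal{X}_c^k \in \mathbb{R}^{R_1\times\cdots\times R_N}$.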
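For comparison, a hedged sketch of Tucker projection and reconstruction with orthonormal factor matrices: projecting with $\mathbf{U}_n^\top$ along every mode recovers the core exactly when the sample lies in the Tucker subspace. The mode-$n$ product is implemented with `tensordot` plus `moveaxis`; helper names are hypothetical.

```python
import numpy as np

def mode_product(X, M, n):
    """Mode-n product X x_n M, where M has shape (J, I_n)."""
    Y = np.tensordot(M, X, axes=([1], [n]))  # the new mode ends up first
    return np.moveaxis(Y, 0, n)

def tucker_project(Y, factors):
    """Core X = Y x_1 U_1^T ... x_N U_N^T for orthonormal U_n (I_n x R_n)."""
    X = Y
    for n, U in enumerate(factors):
        X = mode_product(X, U.T, n)
    return X

def tucker_reconstruct(X, factors):
    """Y = X x_1 U_1 ... x_N U_N."""
    Y = X
    for n, U in enumerate(factors):
        Y = mode_product(Y, U, n)
    return Y

rng = np.random.default_rng(0)
dims, ranks = (5, 4, 6), (2, 3, 2)
factors = [np.linalg.qr(rng.standard_normal((I, R)))[0] for I, R in zip(dims, ranks)]
core = rng.standard_normal(ranks)
Y = tucker_reconstruct(core, factors)
```

Note that the per-sample core has $\prod_n R_n$ entries, which is the source of the exponential growth in storage with the number of modes.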
II-B Linear Discriminant Analysis (LDA)
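LDA for vectorized tensor data finds an orthogonal projection $\mathbf{U}$ that maximizes the discriminability of the projections:
$$\mathbf{U}^* = \arg\min_{\mathbf{U}^\top\mathbf{U} = \mathbf{I}} \operatorname{tr}\left(\mathbf{U}^\top(\mathbf{S}_W - \lambda\mathbf{S}_B)\mathbf{U}\right), \tag{10}$$
where $\mathbf{U} \in \mathbb{R}^{I_1 I_2\cdots I_N \times R_N}$ and $\lambda$ is the regularization parameter that controls the trade-off between $\mathbf{S}_W$ and $\mathbf{S}_B$, the within-class and between-class scatter matrices, respectively, given by:
$$\mathbf{S}_W = \sum_{c=1}^{C}\sum_{k=1}^{K_c} \mathbf{V}(\mathcal{Y}_c^k - \mathcal{M}_c)\,\mathbf{V}(\mathcal{Y}_c^k - \mathcal{M}_c)^\top, \tag{11}$$
$$\mathbf{S}_B = \sum_{c=1}^{C} K_c\,\mathbf{V}(\mathcal{M}_c - \mathcal{M})\,\mathbf{V}(\mathcal{M}_c - \mathcal{M})^\top, \tag{12}$$
where $\mathcal{M}_c = \frac{1}{K_c}\sum_{k=1}^{K_c}\mathcal{Y}_c^k$ is the sample mean of class $c$ and $\mathcal{M}$ is the total mean of all sample tensors. Since $\mathbf{U}$ is an orthogonal projection matrix, (10) is equivalent to minimizing the within-class scatter and maximizing the between-class scatter of the projections. It can be solved by taking as the columns of $\mathbf{U}$ the eigenvectors corresponding to the $R_N$ lowest eigenvalues of $\mathbf{S}_W - \lambda\mathbf{S}_B$.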
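The eigenvector solution described above can be sketched in NumPy. Rows of `samples` are vectorized tensors. The between-class term is weighted by the class size $K_c$, a common convention that we assume here, and the function names are hypothetical.

```python
import numpy as np

def scatter_matrices(samples, labels):
    """samples: (K, D) vectorized tensors. Returns (S_W, S_B)."""
    mean_all = samples.mean(axis=0)
    D = samples.shape[1]
    S_W, S_B = np.zeros((D, D)), np.zeros((D, D))
    for c in np.unique(labels):
        Xc = samples[labels == c]
        mean_c = Xc.mean(axis=0)
        centered = Xc - mean_c
        S_W += centered.T @ centered                # within-class scatter
        diff = (mean_c - mean_all)[:, None]
        S_B += len(Xc) * (diff @ diff.T)            # size-weighted between-class
    return S_W, S_B

def regularized_lda(samples, labels, n_components, lam):
    """Columns of U: eigenvectors of S_W - lam * S_B with the lowest eigenvalues."""
    S_W, S_B = scatter_matrices(samples, labels)
    _, evecs = np.linalg.eigh(S_W - lam * S_B)      # eigenvalues in ascending order
    return evecs[:, :n_components]

rng = np.random.default_rng(0)
samples = np.vstack([rng.standard_normal((20, 6)),
                     rng.standard_normal((20, 6)) + 3.0])
labels = np.repeat([0, 1], 20)
U = regularized_lda(samples, labels, n_components=2, lam=1.0)
```

Because `eigh` returns orthonormal eigenvectors, the constraint $\mathbf{U}^\top\mathbf{U} = \mathbf{I}$ is satisfied automatically. The cost is dominated by the $D \times D$ eigendecomposition, with $D = I_1\cdots I_N$, which is exactly what becomes prohibitive for large tensors.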
II-C Multilinear Discriminant Analysis (MDA)
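MDA extends TD to supervised learning by finding a subspace $\mathbf{U}_n$ for each mode that maximizes the discriminability along that mode [10, 11, 9]. When the number of modes is one, MDA is equivalent to LDA. In MDA, the within-class scatter along mode $n$ is defined as:
$$\mathbf{S}_W^{(n)} = \sum_{c=1}^{C}\sum_{k=1}^{K_c} \mathbf{D}_{c,k}^{(n)}\mathbf{D}_{c,k}^{(n)\top}, \qquad \mathbf{D}_{c,k}^{(n)} = \left[(\mathcal{Y}_c^k - \mathcal{M}_c)\times_{m\neq n}\mathbf{U}_m^\top\right]_{(n)}, \tag{13}$$
where $\times_{m\neq n}$ denotes mode products along all modes except $n$ and $[\cdot]_{(n)}$ denotes mode-$n$ unfolding. The between-class scatter $\mathbf{S}_B^{(n)}$ is found in a similar manner. Using these definitions of the scatter matrices, each $\mathbf{U}_n$ is found by optimizing [11]:
$$\mathbf{U}_n^* = \arg\min_{\mathbf{U}_n^\top\mathbf{U}_n = \mathbf{I}} \operatorname{tr}\left(\mathbf{U}_n^\top(\mathbf{S}_W^{(n)} - \lambda\mathbf{S}_B^{(n)})\mathbf{U}_n\right). \tag{14}$$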
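A sketch of the mode-$n$ scatter computation follows. It assumes that the other modes are projected with $\mathbf{U}_m^\top$ before unfolding, as in standard MDA. Passing identity factors skips the projection, which gives the DGTDA-style scatter. The $K_c$ weighting of the between-class term and all helper names are our assumptions.

```python
import numpy as np

def unfold(X, n):
    """Mode-n unfolding: mode n indexes the rows."""
    return np.moveaxis(X, n, 0).reshape(X.shape[n], -1)

def project_except(X, factors, n):
    """Apply X x_m U_m^T for every mode m != n (factors[m] is I_m x R_m)."""
    for m, U in enumerate(factors):
        if m != n:
            X = np.moveaxis(np.tensordot(U.T, X, axes=([1], [m])), 0, m)
    return X

def mode_scatter(samples, labels, factors, n):
    """Within/between-class scatter along mode n. samples: (K, I_1, ..., I_N)."""
    mean_all = samples.mean(axis=0)
    I_n = samples.shape[n + 1]
    S_W, S_B = np.zeros((I_n, I_n)), np.zeros((I_n, I_n))
    for c in np.unique(labels):
        Xc = samples[labels == c]
        mean_c = Xc.mean(axis=0)
        for Y in Xc:
            D = unfold(project_except(Y - mean_c, factors, n), n)
            S_W += D @ D.T
        D = unfold(project_except(mean_c - mean_all, factors, n), n)
        S_B += len(Xc) * (D @ D.T)
    return S_W, S_B

def update_factor(S_W, S_B, rank, lam=1.0):
    """Lowest eigenvectors of S_W - lam * S_B, as in the LDA case."""
    return np.linalg.eigh(S_W - lam * S_B)[1][:, :rank]

rng = np.random.default_rng(0)
samples = rng.standard_normal((12, 4, 5, 3))
labels = np.repeat([0, 1, 2], 4)
factors = [np.eye(4), np.eye(5), np.eye(3)]  # identity: no projection (DGTDA-style)
S_W, S_B = mode_scatter(samples, labels, factors, n=1)
U1 = update_factor(S_W, S_B, rank=2)
```

An iterative (CMDA-style) scheme would call `mode_scatter` and `update_factor` for each mode in turn, replacing `factors[n]` with the new estimate each time.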
Different implementations of multiway discriminant analysis have been introduced, including Discriminant Analysis with Tensor Representation (DATER), Direct Generalized Tensor Discriminant Analysis (DGTDA) and Constrained MDA (CMDA). DATER minimizes the ratio of within-class to between-class scatter of the projections [10]. DGTDA, on the other hand, finds the solution for each mode independently of the other modes, i.e., it computes the scatter matrices for each mode without projecting the inputs onto $\mathbf{U}_m$, $m \neq n$, and updates $\mathbf{U}_n$ using these scatter matrices [9]. CMDA finds the solution iteratively [9], where each subspace is found by fixing all other subspaces. As these approaches use Tucker decomposition, their storage cost grows rapidly with the number of modes and samples. In this paper, we propose novel TT-based models that overcome this problem while achieving similar or better computational complexity.
III Methods for Tensor-Train Discriminant Analysis
When the data are higher-order tensors, LDA must first vectorize them and then find an optimal projection as shown in (10). This creates several problems: the intrinsic structure of the data is destroyed, and the dimensionality, computation time and storage cost grow exponentially with the number of modes. Thus, in this section we propose to solve the above problem by constraining $\mathbf{U}$ to be a TT subspace, which reduces the computational and storage complexity and yields a solution that preserves the inherent structure of the data. With this approach, the obtained $\mathbf{U}$ still yields discriminative features while having the additional structure imposed by the TT model.
Iiia TT Discriminant Analysis (TTDA):
The goal of TTDA is to learn left-orthogonal tensor factors $\mathcal{U}_n$, $n \in \{1,\dots,N\}$, using the TT model such that the discriminability of the projections is maximized. The factors can be initialized using the TT decomposition proposed in [22]. To optimize the $\mathcal{U}_n$ for discriminability, we need to solve (10) for each $\mathcal{U}_k$, which can be rewritten using the definition of $\mathbf{U}$ as:
(15) 
Using the definitions presented in (4) and (5), we can express (15) in terms of tensor merging product:
(16) 
where . Let and . By rearranging the terms in (16), we can first compute all merging products and trace operations that do not involve and then write the optimization problem only in terms of . Thus, we first define:
(17) 
where (refer to Fig. 4 for a graphical representation of (17)). Then, (15) is equivalent to:
(18) 
Let , then (18) can be rewritten as:
(19) 
This is a quadratic objective with orthogonality (unitary) constraints and can be solved by the algorithm proposed in [23]. However, the procedure described above for finding the subspaces is computationally expensive and slow, due to the complexity of finding each factor [16].
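Problem (19) is not reproduced here. As a hedged illustration, assume that it reduces to minimizing a quadratic form $\mathbf{w}^\top\mathbf{A}\mathbf{w}$, where $\mathbf{w}$ is the vectorization of a matrix $\mathbf{W}$ (standing in for the left unfolding of the factor being updated) with orthonormal columns. The sketch below does not implement the method of [23]. It swaps in plain Riemannian gradient descent on the Stiefel manifold with a QR retraction. The step size, iteration count and function name are hypothetical choices.

```python
import numpy as np

def stiefel_quadratic_min(A, shape, steps=2000, step_size=None, seed=0):
    """Approximately minimize f(W) = vec(W)^T A vec(W) subject to W^T W = I.
    vec uses row-major order, i.e. W.reshape(-1)."""
    A_sym = 0.5 * (A + A.T)
    if step_size is None:
        step_size = 0.25 / np.linalg.norm(A_sym, 2)  # conservative fixed step
    rng = np.random.default_rng(seed)
    W, _ = np.linalg.qr(rng.standard_normal(shape))  # feasible starting point
    for _ in range(steps):
        G = (2.0 * A_sym @ W.reshape(-1)).reshape(shape)  # Euclidean gradient
        WtG = W.T @ G
        rgrad = G - W @ (0.5 * (WtG + WtG.T))             # tangent-space projection
        W, _ = np.linalg.qr(W - step_size * rgrad)        # QR retraction
    return W

# Check on a case with a known answer: with A = kron(S, I_2), f(W) = tr(W^T S W),
# whose minimum over orthonormal W is the sum of the two smallest eigenvalues of S.
rng = np.random.default_rng(1)
Q, _ = np.linalg.qr(rng.standard_normal((6, 6)))
S = Q @ np.diag(np.arange(1.0, 7.0)) @ Q.T
A = np.kron(S, np.eye(2))
W = stiefel_quadratic_min(A, (6, 2))
```

A fixed step keeps the sketch short. A Cayley-transform update or a line search would typically converge faster on ill-conditioned problems.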
When , (19) does not apply as is not defined and the trace operation is defined on the third mode of . To update , the following can be used:
(20) 
where . See Fig. 5 for a graphical representation of . The pseudocode for TTDA is given in Algorithm 1.
III-B Two-Way Tensor-Train Discriminant Analysis (2WTTDA):
Just as LDA finds a subspace that maximizes discriminability for vector-type data, 2DLDA finds two subspaces that maximize discriminability for matrix-type data [24]. If one considers the matricized version of $\mathcal{Y}_c^k$ along mode $k$, i.e., $\mathbf{T}_k(\mathcal{Y}_c^k) \in \mathbb{R}^{I_1\cdots I_k \times I_{k+1}\cdots I_N}$, the equivalent orthogonal projection can be written as:
(21) 
where and .
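To see why a two-sided projection is a structured special case of LDA, recall the identity $\mathbf{V}(\mathbf{A}\mathbf{X}\mathbf{B}) = (\mathbf{B}^\top \otimes \mathbf{A})\mathbf{V}(\mathbf{X})$ for column-major vectorization. Projecting a matricized sample from the left and right is therefore the same as a single LDA-style projection with a Kronecker-structured, still orthonormal, matrix. The sketch below checks this numerically; the sizes and variable names are illustrative only.

```python
import numpy as np

rng = np.random.default_rng(0)
P, Q, r1, r2 = 6, 4, 3, 2
X = rng.standard_normal((P, Q))                         # a matricized sample T_k(Y)
U_left, _ = np.linalg.qr(rng.standard_normal((P, r1)))
U_right, _ = np.linalg.qr(rng.standard_normal((Q, r2)))

Y_two_sided = U_left.T @ X @ U_right                    # 2DLDA-style projection
U_kron = np.kron(U_right, U_left)                       # one (PQ x r1 r2) projection
y_single = U_kron.T @ X.reshape(-1, order="F")          # column-major vectorization

assert np.allclose(y_single, Y_two_sided.reshape(-1, order="F"))
assert np.allclose(U_kron.T @ U_kron, np.eye(r1 * r2))  # still orthonormal
```

Storing the two factors needs $P r_1 + Q r_2$ numbers instead of $PQ\,r_1 r_2$ for the explicit Kronecker matrix, which is the kind of saving the two-branch structure exploits.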
In TTDA, since the projections are vectors, the subspace $\mathbf{U}$ is analogous to the solution of LDA with the constraint that the subspace admits a TT model. If we instead consider the projections and the input samples as matrices, we can impose a TT structure on the two subspaces (the left and right subspaces), analogous to 2DLDA instead of LDA. In other words, one can find two sets of TT representations corresponding to the left and right subspaces in (21). Using this analogy, (21) can be rewritten as:
(22) 
which is equivalent to the following tensor-train representation:
(23) 
This tensor-train formulation is graphically represented in Fig. 6, where the decomposition has two branches; we therefore denote it as Two-Way Tensor-Train Decomposition (2WTT).
To maximize discriminability using 2WTT, an optimization scheme that alternates between the two sets of TT subspaces can be utilized. This reduces the dimensionality of the input by projecting it onto a lower-rank matrix. The computational complexity is then reduced, as both the cost of computing the scatter matrices and the number of matrix multiplications needed in (17) decrease. We propose the procedure given in Algorithm 2 to implement this approach and refer to this implementation of TTDA as Two-Way Tensor-Train Discriminant Analysis (2WTTDA).
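The alternating idea can be sketched with a simplification: each branch is represented by a plain orthonormal matrix rather than by a set of TT factors, so this is a 2DLDA-style alternation, not Algorithm 2 itself. With one side fixed, the samples are projected on that side, scatter matrices are formed for the other side, and its subspace is updated from the lowest eigenvectors of $\mathbf{S}_W - \lambda\mathbf{S}_B$. The $K_c$ weighting, the identity initialization and all names are assumptions.

```python
import numpy as np

def side_scatter(mats, labels, other, left):
    """Scatter of the row space (left=True) or column space (left=False)
    after projecting the opposite side of each sample onto `other`."""
    mean_all = mats.mean(axis=0)
    dim = mats.shape[1] if left else mats.shape[2]
    S_W, S_B = np.zeros((dim, dim)), np.zeros((dim, dim))
    for c in np.unique(labels):
        Mc = mats[labels == c]
        mean_c = Mc.mean(axis=0)
        for X in Mc:
            D = (X - mean_c) @ other if left else (X - mean_c).T @ other
            S_W += D @ D.T
        diff = mean_c - mean_all
        D = diff @ other if left else diff.T @ other
        S_B += len(Mc) * (D @ D.T)
    return S_W, S_B

def two_sided_discriminant(mats, labels, r_left, r_right, lam=1.0, iters=10):
    """Alternate between the left and right subspaces. Each update takes the
    lowest eigenvectors of S_W - lam * S_B with the other side held fixed."""
    U_right = np.eye(mats.shape[2])[:, :r_right]  # simple initialization
    for _ in range(iters):
        S_W, S_B = side_scatter(mats, labels, U_right, left=True)
        U_left = np.linalg.eigh(S_W - lam * S_B)[1][:, :r_left]
        S_W, S_B = side_scatter(mats, labels, U_left, left=False)
        U_right = np.linalg.eigh(S_W - lam * S_B)[1][:, :r_right]
    return U_left, U_right

rng = np.random.default_rng(0)
mats = np.concatenate([rng.standard_normal((15, 6, 5)),
                       rng.standard_normal((15, 6, 5)) + 1.0])
labels = np.repeat([0, 1], 15)
U_left, U_right = two_sided_discriminant(mats, labels, r_left=2, r_right=3)
```

Each half-step minimizes the same trace objective exactly over one side, so the objective does not increase across iterations. The matrices being decomposed are only $P \times P$ and $Q \times Q$, rather than $PQ \times PQ$.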
To determine the value of $k$ in (22), we use a center-of-mass approach and find the $k$ that minimizes $\left|\prod_{i=1}^{k} I_i - \prod_{i=k+1}^{N} I_i\right|$. In this manner, the problem is separated into two parts with similar computational complexities.
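The center-of-mass split amounts to a one-pass search over the $N-1$ possible split points. The following is a minimal sketch; the function name is hypothetical, and ties are broken in favor of the smallest $k$.

```python
from math import prod

def center_of_mass_split(dims):
    """Return k (1 <= k < N) minimizing |prod(dims[:k]) - prod(dims[k:])|."""
    best_k, best_gap = None, None
    for k in range(1, len(dims)):
        gap = abs(prod(dims[:k]) - prod(dims[k:]))
        if best_gap is None or gap < best_gap:
            best_k, best_gap = k, gap
    return best_k

print(center_of_mass_split([10, 20, 30, 40]))  # 2: 200 vs 1200 is the closest pair
```

`math.prod` uses Python integers, so the products cannot overflow even for large tensors.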
III-C Three-Way Tensor-Train Discriminant Analysis (3WTTDA):
Extending the idea of 2WTTDA, one can increase the number of modes of the projected samples, which increases the number of tensor factor sets, or equivalently the number of subspaces to be approximated using the TT structure. For example, one might choose the number of modes of the projections to be three. This model, named Three-Way Tensor-Train Decomposition (3WTT), is given in (24) and represented graphically in Fig. 7.
(24) 
To maximize discriminability using 3WTT, one can use an iterative approach as in 2WTTDA: the input tensors are projected onto all tensor factor sets except the one being optimized, and TTDA is then applied to the projections. This procedure is repeated until a convergence criterion is met or a maximum number of iterations is reached. The values of $k_1$ and $k_2$ are chosen such that the product of the dimensions corresponding to each set is as close to $\left(\prod_{n=1}^{N} I_n\right)^{1/3}$ as possible. Note that 3WTT is only meaningful for tensors of order three or higher; for third-order tensors, 3WTT is equivalent to the Tucker model.
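Choosing $k_1$ and $k_2$ can be done by brute force over all $\binom{N-1}{2}$ split pairs. The sketch assumes that "as close as possible" means the smallest sum of absolute deviations of the three group products from the cube-root target; other closeness measures, such as deviations in log scale, are equally plausible. The function name is hypothetical.

```python
from itertools import combinations
from math import prod

def three_way_split(dims):
    """Choose 1 <= k1 < k2 < N so that the products of the three groups of
    modes are as close as possible to prod(dims) ** (1/3)."""
    target = prod(dims) ** (1.0 / 3.0)
    best, best_cost = None, float("inf")
    for k1, k2 in combinations(range(1, len(dims)), 2):
        groups = (dims[:k1], dims[k1:k2], dims[k2:])
        cost = sum(abs(prod(g) - target) for g in groups)  # assumed closeness measure
        if cost < best_cost:
            best, best_cost = (k1, k2), cost
    return best

print(three_way_split([4, 4, 4, 4, 4, 4]))  # (2, 4): three groups of size 16
```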
III-D Computational and Storage Complexity
In this section, we analyze the computational and storage complexities of the aforementioned algorithms.
Storage Complexity: Let $R_n = R$ and $I_n = I$ for all $n$. Assuming $N$ is a multiple of both 2 and 3, the total storage complexities are:

for TT Decomposition, where ;

for TwoWay TT Decomposition, where ;

for ThreeWay TT Decomposition, where ;

for Tucker Decomposition, where .
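To make the TT versus Tucker gap concrete, the parameter counts can be computed directly from the factor shapes. This is a hedged sketch that counts entries rather than reproducing the closed-form expressions above. It assumes TT cores of shape $R_{n-1}\times I_n\times R_n$ with $R_0 = 1$ and all other ranks equal to `rank`, one length-`rank` projection per sample, Tucker factors of shape $I_n\times R$, and one $R^N$ core per sample. The function names are hypothetical.

```python
def tt_storage(dims, rank, n_samples):
    """Entries of the TT cores (R_{n-1} x I_n x R_n, R_0 = 1) plus one
    length-`rank` projected vector per sample."""
    ranks = [1] + [rank] * len(dims)
    cores = sum(ranks[n] * I * ranks[n + 1] for n, I in enumerate(dims))
    return cores + n_samples * rank

def tucker_storage(dims, rank, n_samples):
    """Entries of the factor matrices (I_n x rank) plus one rank**N core per sample."""
    factors = sum(I * rank for I in dims)
    return factors + n_samples * rank ** len(dims)

print(tt_storage([8] * 6, 4, 100))      # 1072
print(tucker_storage([8] * 6, 4, 100))  # 409792
```

The per-sample Tucker core dominates, growing exponentially in $N$, whereas the TT count grows only linearly in $N$.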
Subspace  Projection  

TT  
2WTT 