Matrix Sketching for Secure Collaborative Machine Learning
Collaborative machine learning (ML), also known as federated ML, allows participants to jointly train a model without sharing their data. To update the model parameters, the central parameter server broadcasts the model parameters to the participants, and the participants send ascending directions, such as gradients, back to the server. Although the data never leave a participant’s device, the communicated gradients and parameters can leak the participant’s private information. Prior work proposed attacks that infer a participant’s private information from the gradients and parameters, and showed that simple defenses such as dropout and differential privacy do not help much.
To defend against such privacy leakage, we propose a method called Double-Blind Collaborative Learning (DBCL), which is based on random matrix sketching. The high-level idea is to apply a random transformation to the parameters, data, and gradients in every iteration so that the existing attacks fail or become less effective. While it improves the security of collaborative ML, DBCL does not increase the computation and communication cost much and does not hurt prediction accuracy at all. DBCL can potentially be applied to decentralized collaborative ML to defend against privacy leakage.
Keywords: collaborative machine learning, federated learning, randomized linear algebra, matrix sketching, data privacy.
1 Introduction
Collaborative machine learning (ML) enables learning a model from participants’ data without sharing the data. Collaborative ML is motivated by real-world applications such as the following.
Suppose Amazon, Facebook, or Google wants to train a model using users’ data that can be accessed through their apps. A straightforward solution is to collect users’ data and then train a model on the server. However, what if users are unwilling to upload their data to the server?
Suppose some hospitals want to use machine learning to automatically diagnose a disease. Every hospital can train a model using its own data; however, such a model is typically inferior to a model learned on all the hospitals’ data. Moreover, laws and policies often forbid giving patients’ medical data to another party.
Distributed stochastic gradient descent (SGD), as illustrated in Figure 1(a), is perhaps the simplest approach to collaborative learning. Specifically, the central parameter server broadcasts the model parameters to the participants, every participant uses a batch of local data to evaluate a stochastic gradient, and the server aggregates the stochastic gradients and updates the model parameters. McMahan et al. (2017) proposed the federated averaging (FedAvg) algorithm for higher communication efficiency. FedAvg lets the participants locally run multiple SGD steps, instead of only one, so that the aggregated update is a better descending direction.
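To make the two aggregation schemes concrete, here is a minimal sketch, assuming a toy one-parameter least-squares model $y \approx wx$, equally weighted participants, and hypothetical helper names (`local_grad`, `distributed_sgd_round`, `fedavg_round`). It only illustrates the message flow described above; it is not the full algorithm of McMahan et al. (2017), which, for example, weights participants by their local data size.

```python
import random

def local_grad(w, batch):
    # gradient of the mean squared error (w * x - y)^2 over a mini-batch
    return sum(2.0 * (w * x - y) * x for x, y in batch) / len(batch)

def distributed_sgd_round(w, participants, lr, batch_size, rng):
    # every participant evaluates one stochastic gradient at the broadcast w
    grads = [local_grad(w, rng.sample(data, batch_size)) for data in participants]
    # the server averages the stochastic gradients and takes a single step
    return w - lr * sum(grads) / len(grads)

def fedavg_round(w, participants, lr, batch_size, local_steps, rng):
    # every participant runs several local SGD steps starting from the broadcast w
    local_models = []
    for data in participants:
        w_local = w
        for _ in range(local_steps):
            w_local -= lr * local_grad(w_local, rng.sample(data, batch_size))
        local_models.append(w_local)
    # the server averages the locally updated parameters (equal weights assumed)
    return sum(local_models) / len(local_models)

# two participants holding different inputs of the same target y = 3x
participants = [
    [(0.5 + 0.05 * k, 3.0 * (0.5 + 0.05 * k)) for k in range(11)],
    [(1.0 + 0.05 * k, 3.0 * (1.0 + 0.05 * k)) for k in range(11)],
]
rng = random.Random(0)

w_sgd = 0.0
for _ in range(100):   # 100 communication rounds, one gradient step each
    w_sgd = distributed_sgd_round(w_sgd, participants, lr=0.1, batch_size=4, rng=rng)

w_avg = 0.0
for _ in range(20):    # 20 communication rounds, five local steps each
    w_avg = fedavg_round(w_avg, participants, lr=0.1, batch_size=4, local_steps=5, rng=rng)

print(w_sgd, w_avg)
```

Both loops drive $w$ toward the true value 3; FedAvg uses fewer communication rounds because every round carries several local steps.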
Collaboratively training a model may seem to preserve participants’ privacy, but recent studies have shown that this is unfortunately not true (Hitaj et al., 2017; Melis et al., 2019). Even if a participant’s data do not leave his device, important properties of his data can be inferred by an adversary (the party who seeks to infer others’ data). Surprisingly, Hitaj et al. (2017) and Melis et al. (2019) showed that to infer other participants’ data, the adversary only needs to control one participant’s device and access the model parameters in every iteration; the adversary does not need to take control of the server.
The attacks are possible because the model parameters and gradients carry important information about the data (Ateniese et al., 2015; Fredrikson et al., 2015). Hitaj et al. (2017) used the joint model as a discriminator to train a generator that roughly depicts other participants’ data. Melis et al. (2019) used the difference between two consecutive parameter tensors to infer the gradient and then conducted property inference using the gradient.
Hitaj et al. (2017) and Melis et al. (2019) studied several simple defenses, but none of them works well. While differential privacy, i.e., adding noise to the data, parameters, or gradients, can stop the privacy leakage, the noise inevitably hurts the accuracy and may even stop the learning from making progress (Hitaj et al., 2017); if the noise is not sufficiently strong, participants’ information will still be leaked. Dropout training randomly masks a fraction of the parameters, so the participants have access to only part of the parameters in each iteration. However, knowing part of the parameters is sufficient for conducting the attacks.
1.1 Our Contributions
To protect participants’ data from leaking during collaborative learning, we propose a method called Double-Blind Collaborative Learning (DBCL). DBCL applies random sketching to some or all layers of a neural network, and the random sketching varies from iteration to iteration. Throughout the training, the participants never see the true model parameters, and the server never sees any real (stochastic) gradient or descending direction. This is why our method is called double-blind.
DBCL has the following desirable properties. First, DBCL defeats the property inference attack (PIA) developed by Melis et al. (2019). Second, DBCL does not hurt prediction accuracy at all. Third, DBCL does not increase the per-iteration time complexity or communication complexity, although it moderately increases the number of iterations needed to converge. Last but not least, applying DBCL to fully-connected and convolutional layers requires no additional tuning. In sum, DBCL can increase security at little cost.
Furthermore, we observe that the property inference attack developed by Melis et al. (2019) also applies to decentralized ML (Bianchi et al., 2013; Lan et al., 2017; Lian et al., 2017; Ram et al., 2010a, b; Sirb and Ye, 2016; Srivastava and Nedic, 2011). A typical decentralized algorithm works as follows: a node collects its neighbors’ model parameters, takes their average, and then performs a local SGD step (Lian et al., 2017); see the illustration in Figure 1(b). DBCL can be straightforwardly applied to protect decentralized ML.
Fully-connected (FC) layer.
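Let $d_{\text{in}}$ be the input shape, $d_{\text{out}}$ be the output shape, and $b$ be the batch size. Let $X \in \mathbb{R}^{b \times d_{\text{in}}}$ be the input, $W \in \mathbb{R}^{d_{\text{out}} \times d_{\text{in}}}$ be the parameter matrix, and $Z = X W^\top \in \mathbb{R}^{b \times d_{\text{out}}}$ be the output. After the FC layer, an activation function is typically applied elementwise.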
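Let $L(W)$ be the loss evaluated on a batch of training samples. We derive backpropagation for the FC layer defined above. (We follow the convention of PyTorch.) Let $G = \frac{\partial L}{\partial Z} \in \mathbb{R}^{b \times d_{\text{out}}}$ be the gradient received from the upper layer. We need to compute
$$\frac{\partial L}{\partial W} = G^\top X \in \mathbb{R}^{d_{\text{out}} \times d_{\text{in}}} \qquad \text{and} \qquad \frac{\partial L}{\partial X} = G W \in \mathbb{R}^{b \times d_{\text{in}}},$$
which can be established by the chain rule. We use $\frac{\partial L}{\partial W}$ to update the parameter matrix by, e.g., $W \leftarrow W - \eta \frac{\partial L}{\partial W}$ with step size $\eta$, and pass $\frac{\partial L}{\partial X}$ to the lower layer.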
Uniform sampling matrix.
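We call $S \in \mathbb{R}^{d \times s}$ a uniform sampling matrix if its $s$ columns are sampled from the set
$$\left\{ \sqrt{d/s}\, \mathbf{e}_1, \; \sqrt{d/s}\, \mathbf{e}_2, \; \dots, \; \sqrt{d/s}\, \mathbf{e}_d \right\}$$
uniformly at random. Here, $\mathbf{e}_i$ is the $i$-th standard basis vector of $\mathbb{R}^d$. We call $S$ a uniform sampling matrix because $AS$ contains $s$ randomly sampled (and rescaled) columns of $A$. For any $A \in \mathbb{R}^{n \times d}$ and $B \in \mathbb{R}^{p \times d}$, $\mathbb{E}[A S S^\top B^\top] = A B^\top$, and the variance $\mathbb{E}\|A S S^\top B^\top - A B^\top\|_F^2$ is bounded.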
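A minimal sketch of a uniform sampling matrix, assuming sampling with replacement (the text does not say whether columns are drawn with or without replacement). The hypothetical helper `exact_expected_SST` averages $SS^\top$ over all $d^s$ equally likely index tuples, which shows $\mathbb{E}[SS^\top] = I$, and hence $\mathbb{E}[ASS^\top B^\top] = AB^\top$, for tiny $d$ and $s$.

```python
import itertools
import math
import random

def matmul(A, B):
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def transpose(A):
    return [list(col) for col in zip(*A)]

def sampling_matrix(d, indices):
    # d x s matrix whose k-th column is sqrt(d / s) * e_{indices[k]}
    s = len(indices)
    scale = math.sqrt(d / s)
    S = [[0.0] * s for _ in range(d)]
    for k, j in enumerate(indices):
        S[j][k] = scale
    return S

def uniform_sampling_matrix(d, s, rng):
    # column indices drawn uniformly with replacement (an assumption of this sketch)
    return sampling_matrix(d, [rng.randrange(d) for _ in range(s)])

def exact_expected_SST(d, s):
    # average S S^T over all d**s equally likely index tuples
    tuples = list(itertools.product(range(d), repeat=s))
    E = [[0.0] * d for _ in range(d)]
    for idx in tuples:
        S = sampling_matrix(d, idx)
        SST = matmul(S, transpose(S))
        for i in range(d):
            for j in range(d):
                E[i][j] += SST[i][j] / len(tuples)
    return E

A = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
AS = matmul(A, sampling_matrix(3, [2, 0]))  # rescaled 3rd and 1st columns of A
E = exact_expected_SST(3, 2)                # the 3 x 3 identity, up to rounding
S_random = uniform_sampling_matrix(3, 2, random.Random(0))
```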
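We call $S \in \mathbb{R}^{d \times s}$ a count sketch matrix if it is constructed in the following way. Every row of $S$ has exactly one nonzero entry, whose position is sampled uniformly from $\{1, \dots, s\}$ and whose value is sampled uniformly from $\{+1, -1\}$. For example, with $d = 5$ and $s = 3$, one possible count sketch matrix is
$$S = \begin{bmatrix} 0 & 1 & 0 \\ -1 & 0 & 0 \\ 0 & 0 & 1 \\ 0 & -1 & 0 \\ 1 & 0 & 0 \end{bmatrix}.$$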
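Given $A \in \mathbb{R}^{n \times d}$, the sketch $AS$ can be computed in $O(\mathrm{nnz}(A))$ time, where $\mathrm{nnz}(A)$ is the number of nonzero entries of $A$. CountSketch is thus much faster than standard matrix multiplication, which has $O(nds)$ time complexity. For any $A$ and $B$, $\mathbb{E}[A S S^\top B^\top] = A B^\top$, and the variance is bounded. In practice, $S$ is never explicitly constructed. CountSketch has a sparse $S$ (with $d$ nonzero entries), but it is not as sparse as uniform sampling (with $s$ nonzero entries).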
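The $O(\mathrm{nnz}(A))$ cost comes from never materializing $S$: row $i$ of $S$ is fully described by a bucket $h_i$ and a sign $g_i$. The sketch below, with hypothetical function names, stores only the lists `h` and `g`, computes $AS$ in one pass over the nonzero entries of $A$, and checks the result against an explicitly built $S$.

```python
import random

def matmul(A, B):
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def countsketch_params(d, s, rng):
    # row i of the implicit d x s matrix S holds sign g[i] in column h[i]
    h = [rng.randrange(s) for _ in range(d)]
    g = [rng.choice((1.0, -1.0)) for _ in range(d)]
    return h, g

def countsketch_apply(A, h, g, s):
    # A S without forming S: one pass over the nonzero entries of A,
    # adding g[j] * A[r][j] into bucket h[j] of row r
    out = [[0.0] * s for _ in A]
    for r, row in enumerate(A):
        for j, a in enumerate(row):
            if a != 0.0:
                out[r][h[j]] += g[j] * a
    return out

def countsketch_dense(h, g, s):
    # explicit S, built only to check the implicit version
    S = [[0.0] * s for _ in h]
    for i, (bucket, sign) in enumerate(zip(h, g)):
        S[i][bucket] = sign
    return S

A = [[1.0, 2.0, 3.0, 4.0], [0.0, 5.0, 0.0, 6.0]]
h, g = [0, 1, 0, 1], [1.0, -1.0, -1.0, 1.0]
AS_fast = countsketch_apply(A, h, g, s=2)   # [[-2.0, 2.0], [0.0, 1.0]]
AS_dense = matmul(A, countsketch_dense(h, g, s=2))
h_rand, g_rand = countsketch_params(1000, 64, random.Random(0))
```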
Matrix sketching means constructing a small matrix that captures much of the useful information in a big matrix. In the above, $AS$ is a sketch of $A$. Uniform sampling and CountSketch are the fastest matrix sketching methods. More “accurate” but more expensive sketching methods include Gaussian projection (Johnson and Lindenstrauss, 1984), the subsampled randomized Hadamard transform (Drineas et al., 2011; Lu et al., 2013; Tropp, 2011), leverage score sampling (Drineas et al., 2006, 2008, 2012), etc. Readers can refer to the surveys (Drineas and Mahoney, 2016; Halko et al., 2011; Mahoney, 2011; Woodruff, 2014).
To understand why we use random projection rather than sampling, we use the example of Drineas et al. (2008). Let $A$ be a matrix of people’s features.
Uniform sampling returns a scaled subset of $A$’s columns, so every entry of $AS$ is still a (rescaled) feature of some person. In contrast, every column produced by CountSketch is a signed sum of several columns of $A$,
whose entries do not have any concrete meaning. Moreover, given $AS$ and even the sketching matrix $S$ (with $s < d$), one cannot recover $A$.
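The non-recoverability claim can be illustrated concretely. In this sketch the feature matrix is made up, and `null_direction` is a hypothetical helper: when $s < d$, two rows of a count sketch matrix must share a column, which yields a nonzero $\mathbf{v}$ with $\mathbf{v}^\top S = \mathbf{0}$; adding any multiple of $\mathbf{v}^\top$ to the rows of $A$ changes $A$ but not $AS$.

```python
def countsketch_apply(A, h, g, s):
    # A S for the count sketch S whose row i has sign g[i] in column h[i]
    out = [[0.0] * s for _ in A]
    for r, row in enumerate(A):
        for j, a in enumerate(row):
            out[r][h[j]] += g[j] * a
    return out

def null_direction(h, g):
    # nonzero v with v^T S = 0, built from two rows of S that share a column
    # (such a pair exists whenever d > s, by the pigeonhole principle)
    seen = {}
    for i, bucket in enumerate(h):
        if bucket in seen:
            j = seen[bucket]
            v = [0.0] * len(h)
            v[i], v[j] = g[i], -g[j]
            return v
        seen[bucket] = i
    return None

# made-up data: 2 people x 4 features (e.g., age, height, weight, children); s = 2 < d = 4
A = [[25.0, 170.0, 60.0, 3.0], [40.0, 180.0, 80.0, 1.0]]
h, g = [0, 1, 0, 1], [1.0, -1.0, -1.0, 1.0]
v = null_direction(h, g)
A_other = [[a + 10.0 * vi for a, vi in zip(row, v)] for row in A]  # a different matrix
sketch = countsketch_apply(A, h, g, 2)
sketch_other = countsketch_apply(A_other, h, g, 2)                 # identical sketch
```

The sketched entries (for instance $-35$, which mixes the first and third features with opposite signs) also show that CountSketch outputs have no direct feature-level meaning.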
3 Threat Model
3.1 Attacking Centralized Collaborative ML
Here we describe the property inference attack (PIA) proposed by Melis et al. (2019). Centralized collaborative ML, aka federated ML, has a central parameter server that coordinates the participants (aka clients, workers, etc.). In every round, the server broadcasts the current model parameters $\mathbf{w}_t$ to the $m$ participants, the participants use the model and their local data to compute ascending directions $\mathbf{g}_1, \dots, \mathbf{g}_m$ (e.g., gradients or stochastic gradients), and the server updates the model parameters by
$$\mathbf{w}_{t+1} = \mathbf{w}_t - \eta \sum_{j=1}^{m} \mathbf{g}_j,$$
where $\eta$ is the step size.
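Since a participant (say the $i$-th) knows $\mathbf{w}_t$, $\mathbf{w}_{t+1}$, and his own direction $\mathbf{g}_i$, he can use the above equation to calculate the sum of the other participants’ directions: $\sum_{j \neq i} \mathbf{g}_j = \frac{1}{\eta}(\mathbf{w}_t - \mathbf{w}_{t+1}) - \mathbf{g}_i$. In the case of two-party collaborative ML, that is, $m = 2$, one participant knows the updating direction of the other participant in every iteration.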
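The arithmetic behind this step is a rearrangement of the update rule. A minimal sketch, assuming plain lists as parameter vectors, a step size `lr` known to every participant, and made-up directions for a two-party setting; the function name is hypothetical:

```python
def others_direction_sum(w_t, w_next, own_direction, lr):
    # server rule: w_next = w_t - lr * (g_1 + ... + g_m), coordinate-wise,
    # hence sum_{j != i} g_j = (w_t - w_next) / lr - g_i
    return [(a - b) / lr - gi for a, b, gi in zip(w_t, w_next, own_direction)]

lr = 0.1
w_t = [1.0, 1.0]
g_attacker, g_victim = [0.5, -1.0], [2.0, 0.25]   # m = 2 participants
w_next = [w - lr * (ga + gv) for w, ga, gv in zip(w_t, g_attacker, g_victim)]
recovered = others_direction_sum(w_t, w_next, g_attacker, lr)  # equals g_victim up to rounding
```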
Since $\mathbf{g}_j$ is computed using a batch of the $j$-th participant’s data, it is not surprising that $\mathbf{g}_j$ can disclose that participant’s properties. Melis et al. (2019) proposed to take $\mathbf{g}_j$ as the input feature and train a classifier (e.g., logistic regression) to predict the participant’s properties. (This classifier is separate from the jointly trained model.) The participant’s data cannot be recovered; however, the classifier can tell that “the participant is likely female”, “the participant’s income is very likely low”, etc. The PIA is very effective in the 2-party experiments; it becomes less effective as the number of participants, $m$, increases.
3.2 Attacking Decentralized Collaborative ML
We observe that the PIA proposed by Melis et al. (2019) can be extended to decentralized ML. In decentralized ML, there is no central server; every participant (i.e., a worker node) is connected with several others, which are called “neighbors”. There are different decentralized optimization algorithms (Bianchi et al., 2013; Yuan et al., 2016; Sirb and Ye, 2016; Colin et al., 2016; Lian et al., 2017; Lan et al., 2017; Tang et al., 2018). For example, every worker node aggregates its neighbors’ model parameters, takes a weighted average of their parameters and its own as the intermediate parameters, and then locally performs an SGD update.
This kind of decentralized ML is vulnerable to the PIA. One participant (the adversary) knows his neighbor’s (the victim’s) parameters in two consecutive rounds, denoted $\mathbf{w}_t$ and $\mathbf{w}_{t+1}$. The difference $\mathbf{w}_{t+1} - \mathbf{w}_t$ is mainly the gradient evaluated on the victim’s data. (Besides the victim’s gradient, $\mathbf{w}_{t+1} - \mathbf{w}_t$ contains the victim’s neighbors’ gradients, but their weights are lower than that of the victim’s gradient.) The adversary can use a classifier in the same way as Melis et al. (2019) to infer the victim’s properties.
4 Double-Blind Collaborative Learning (DBCL)
4.1 High-Level Ideas
One reason why collaborative ML is unsafe is that the true model parameters are known to every participant and that a participant’s ascending direction, such as the gradient, is known to the others. To address this shortcoming, we propose Double-Blind Collaborative Learning (DBCL). As its name suggests, with DBCL the server cannot see the participants’ ascending directions, and the participants never see the true model parameters. Instead, the server broadcasts sketched model parameters, and the participants evaluate gradients using the sketched parameters and sketched local data. Because the sketching matrix varies from iteration to iteration, a participant cannot infer the difference between the parameters in two consecutive rounds, and thus the property inference attack (PIA) fails.
4.2 Algorithm Description
We describe DBCL for collaborative learning with a central parameter server. In the following, we discuss backpropagation for a sketched fully-connected (FC) layer; backpropagation for a convolutional layer is similar because convolution can be expressed as matrix multiplication.
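The central server generates a new seed $\psi_t$ and, from it, a random sketching matrix $S_t \in \mathbb{R}^{d_{\text{in}} \times s}$; it then computes the sketched parameters $\widetilde{W}_t = W_t S_t \in \mathbb{R}^{d_{\text{out}} \times s}$. It then broadcasts $\psi_t$ and $\widetilde{W}_t$ to all the devices through message passing.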
Local forward pass.
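The $i$-th device randomly selects a batch of $b$ samples from its local data and then locally performs a forward pass. Suppose the input of the sketched FC layer is $X_i \in \mathbb{R}^{b \times d_{\text{in}}}$. The device uses the seed $\psi_t$ to draw the same sketch $S_t$ and outputs $Z_i = X_i S_t \widetilde{W}_t^\top \in \mathbb{R}^{b \times d_{\text{out}}}$. Let $L_i$ be the loss evaluated on the batch of samples after the forward pass.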
Local backward pass.
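Let the local gradient propagated to this sketched FC layer be $G_i = \frac{\partial L_i}{\partial Z_i} \in \mathbb{R}^{b \times d_{\text{out}}}$. The device calculates
$$\Gamma_i = G_i^\top (X_i S_t) \in \mathbb{R}^{d_{\text{out}} \times s} \qquad \text{and} \qquad \frac{\partial L_i}{\partial X_i} = G_i \widetilde{W}_t S_t^\top \in \mathbb{R}^{b \times d_{\text{in}}}.$$
The gradient $\frac{\partial L_i}{\partial X_i}$ is propagated to the lower layer to continue the backpropagation.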
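The server collects $\Gamma_i$ for all (or some) $i \in \{1, \dots, m\}$ through communication and computes $\Gamma = \frac{1}{m} \sum_{i=1}^{m} \Gamma_i$. Let $L = \frac{1}{m} \sum_{i=1}^{m} L_i$ be the loss evaluated on the $mb$ training samples. It can be shown that
$$\frac{\partial L}{\partial W_t} = \Gamma S_t^\top.$$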
The server then updates the parameters by, e.g., $W_{t+1} = W_t - \eta\, \Gamma S_t^\top$.
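The steps above can be traced end to end on a toy layer. This is a minimal sketch, not the author’s implementation: it assumes CountSketch for $S_t$, builds $S_t$ densely from the shared seed for readability, uses two devices with a hypothetical linear loss $L_i = \langle Z_i, C_i \rangle$ (so that $G_i = C_i$), and all function names are hypothetical. Only the seed, $\widetilde{W}_t = W_t S_t$, and the $d_{\text{out}} \times s$ matrices $\Gamma_i$ cross the network.

```python
import random

def matmul(A, B):
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def transpose(A):
    return [list(col) for col in zip(*A)]

def countsketch(d, s, seed):
    # d x s CountSketch drawn from a shared seed (formed densely here only for
    # readability; a practical implementation would keep it implicit)
    rng = random.Random(seed)
    S = [[0.0] * s for _ in range(d)]
    for i in range(d):
        S[i][rng.randrange(s)] = rng.choice((1.0, -1.0))
    return S

def server_broadcast(W, seed, s):
    # W_tilde = W S, shape (d_out, s); W itself is never sent
    return matmul(W, countsketch(len(W[0]), s, seed))

def device_step(X, W_tilde, seed, upstream_grad):
    S = countsketch(len(X[0]), len(W_tilde[0]), seed)  # same S as the server
    XS = matmul(X, S)                                  # (b, s)
    Z = matmul(XS, transpose(W_tilde))                 # forward: Z = X S W_tilde^T
    G = upstream_grad(Z)                               # G = dL_i/dZ, (b, d_out)
    Gamma = matmul(transpose(G), XS)                   # uploaded, (d_out, s)
    dX = matmul(matmul(G, W_tilde), transpose(S))      # passed to the lower layer
    return Gamma, dX

def server_update(W, Gammas, seed, lr):
    s = len(Gammas[0][0])
    S = countsketch(len(W[0]), s, seed)
    m = len(Gammas)
    Gamma = [[sum(Gm[r][c] for Gm in Gammas) / m for c in range(s)]
             for r in range(len(W))]
    grad_W = matmul(Gamma, transpose(S))               # dL/dW = Gamma S^T
    return [[w - lr * gw for w, gw in zip(wr, gr)] for wr, gr in zip(W, grad_W)]

# toy layer: d_in = 4, d_out = 2, s = 2, two devices with b = 2 each
W = [[0.1, -0.2, 0.3, 0.4], [0.5, 0.6, -0.7, 0.8]]
Xs = [[[1.0, 2.0, 0.0, -1.0], [0.5, 0.0, 1.5, 2.0]],
      [[-1.0, 1.0, 1.0, 0.0], [2.0, -0.5, 0.0, 1.0]]]
Cs = [[[1.0, 0.0], [0.0, 1.0]], [[0.5, -1.0], [2.0, 0.5]]]
seed, lr = 2024, 0.1
W_tilde = server_broadcast(W, seed, s=2)
uploads = [device_step(X, W_tilde, seed, lambda Z, C=C: C)[0] for X, C in zip(Xs, Cs)]
W_next = server_update(W, uploads, seed, lr)
```

Comparing `W_next` with $W_t - \eta \cdot \frac{1}{m}\sum_i C_i^\top X_i S_t S_t^\top$ confirms that $\Gamma S_t^\top$ is the gradient of the sketched model, even though no device ever held $W_t$.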
4.3 Pros and Cons
DBCL does not hurt test accuracy at all; on the contrary, like dropout, DBCL can be thought of as a regularization that improves generalization. DBCL does not need extra tuning; setting the sketch size $s$ to half of the input shape $d_{\text{in}}$ typically works very well. DBCL has a lower per-iteration time complexity and communication complexity. Compared with standard collaborative ML, DBCL needs moderately more iterations to converge, which is similar to the behavior of dropout training.
Per-iteration time complexity.
The computation performed by a worker node is dominated by the matrix-matrix multiplications in the forward and backward passes. Backpropagation for a layer with input shape $d_{\text{in}}$ and output shape $d_{\text{out}}$ has a time complexity of $O(b\, d_{\text{in}} d_{\text{out}})$, where $b$ is the batch size of mini-batch SGD. If $S$ is a CountSketch matrix, then the backpropagation for this layer has $O(b\, d_{\text{in}} + b\, s\, d_{\text{out}})$ time complexity. Since $s < d_{\text{in}}$, sketching actually makes backpropagation faster. Specifically, the per-iteration time complexity is reduced to roughly $s / d_{\text{in}}$ of the original.
Per-iteration communication complexity.
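The standard collaborative ML needs to communicate $W \in \mathbb{R}^{d_{\text{out}} \times d_{\text{in}}}$ and $\frac{\partial L}{\partial W} \in \mathbb{R}^{d_{\text{out}} \times d_{\text{in}}}$. Using sketching, DBCL communicates $\widetilde{W} = WS \in \mathbb{R}^{d_{\text{out}} \times s}$ and $\Gamma_i \in \mathbb{R}^{d_{\text{out}} \times s}$. The per-iteration communication complexity (number of words) is thus reduced to $s / d_{\text{in}}$ of the original.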
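Both ratios can be sanity-checked with a rough operation count. This sketch assumes dense inputs, counts multiply-adds only, ignores constant factors and the server’s own cost of forming $W_t S_t$, and uses hypothetical function names and layer sizes:

```python
def standard_cost(b, d_in, d_out):
    # multiply-adds: forward X W^T, backward G^T X and G W
    flops = 3 * b * d_in * d_out
    words = 2 * d_out * d_in        # broadcast W, upload dL/dW
    return flops, words

def dbcl_cost(b, d_in, d_out, s):
    # CountSketch products X S and (G W_tilde) S^T cost about b * d_in each;
    # (X S) W_tilde^T, G^T (X S), and G W_tilde cost b * s * d_out each
    flops = 2 * b * d_in + 3 * b * s * d_out
    words = 2 * d_out * s           # broadcast W_tilde, upload Gamma_i
    return flops, words

std = standard_cost(b=32, d_in=1024, d_out=512)
dbcl = dbcl_cost(b=32, d_in=1024, d_out=512, s=512)
time_ratio = dbcl[0] / std[0]       # about 0.501, close to s / d_in = 0.5
comm_ratio = dbcl[1] / std[1]       # exactly s / d_in = 0.5
```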
In every iteration, the server and the worker nodes need to agree upon the sketching matrix $S_t$, which seems to be a big overhead. However, it is unnecessary to communicate $S_t$. Instead, they only need to agree upon the random seed $\psi_t$ and then use the same pseudo-random generator to produce the same $S_t$.
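A minimal sketch of the seed-based agreement, assuming CountSketch and Python’s `random.Random` as the shared pseudo-random generator (both are assumptions; the text does not name a generator). Only the integer seed is transmitted; each side regenerates the same implicit $S_t$, i.e., the same bucket list `h` and sign list `g`. The function name is hypothetical.

```python
import random

def countsketch_from_seed(seed, d, s):
    # both parties call this with the same seed and obtain the same (h, g),
    # i.e., the same implicit d x s CountSketch matrix S_t
    rng = random.Random(seed)
    h = [rng.randrange(s) for _ in range(d)]
    g = [rng.choice((1, -1)) for _ in range(d)]
    return h, g

seed_t = random.SystemRandom().randrange(2**32)             # server draws a fresh seed
server_view = countsketch_from_seed(seed_t, d=1024, s=512)
device_view = countsketch_from_seed(seed_t, d=1024, s=512)  # after receiving seed_t
```

`random.Random` is a Mersenne Twister and is not cryptographically secure; whether a deployment needs an unpredictable generator depends on the threat model.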
Currently, the convergence rate of DBCL is unclear (even for convex optimization). Our empirical studies show that sketching slows down convergence, and a smaller sketch size leads to slower convergence. This is not surprising, because it is well known that dropout makes convergence somewhat slower in exchange for better generalization.
Our empirical study shows that DBCL does not harm test accuracy at all; on the contrary, it alleviates overfitting. If $S$ is a uniform sampling matrix, then DBCL is exactly dropout training. Since random projections such as CountSketch have properties similar to those of uniform sampling, it is not surprising that DBCL does not hurt test accuracy.
4.4 Theoretical Perspectives
Defending against the property inference attack (PIA).
The reason why DBCL makes the PIA fail can be explained as follows. Let $W_t$ and $W_{t+1}$ be the parameter matrices of two consecutive iterations. What a participant sees are
$$\widetilde{W}_t = W_t S_t \qquad \text{and} \qquad \widetilde{W}_{t+1} = W_{t+1} S_{t+1}.$$
The participant also knows $S_t$ and $S_{t+1}$. However, with these four matrices, the participant does not have enough information to recover $W_{t+1} - W_t$, and thus he cannot conduct the PIA. Blocki et al. (2012) showed that random projection preserves differential privacy, but their theory is not applicable to our problem.
Let $\Delta = W_{t+1} - W_t$ and $D = \widetilde{W}_{t+1} S_{t+1}^\top - \widetilde{W}_t S_t^\top$. It is easy to show that $D$ is an unbiased estimate of $\Delta$: $\mathbb{E}[D] = \Delta$, where the expectation is taken w.r.t. the randomness in the two sketching matrices. If the “mass” of $W_t$ is substantially larger than that of $\Delta$, then the signal (i.e., $\Delta$) will be overwhelmed by random noise. We formally state the problem in the following conjecture; it has been empirically verified but not formally proved. If the conjecture is proved, it will imply that for any linear classifier, using $D$ will not be better than a random guess.
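The unbiasedness of $D$ and the dominance of the noise can be checked exactly on a tiny example by enumerating every CountSketch matrix. This sketch assumes $d = 3$ and $s = 2$ (so there are $(2s)^d = 64$ equally likely matrices), independent $S_t$ and $S_{t+1}$, made-up matrices $W_t$ and $\Delta$ in which $W_t$ has much larger mass than $\Delta$, and hypothetical helper names.

```python
import itertools

def matmul(A, B):
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def transpose(A):
    return [list(col) for col in zip(*A)]

def sub(A, B):
    return [[a - b for a, b in zip(ra, rb)] for ra, rb in zip(A, B)]

def all_countsketch_SST(d, s):
    # S S^T for each of the (2s)**d equally likely d x s CountSketch matrices
    out = []
    for h in itertools.product(range(s), repeat=d):
        for g in itertools.product((1.0, -1.0), repeat=d):
            S = [[g[i] if h[i] == c else 0.0 for c in range(s)] for i in range(d)]
            out.append(matmul(S, transpose(S)))
    return out

W_t = [[10.0, -8.0, 12.0], [9.0, 11.0, -7.0]]   # large "mass"
Delta = [[0.1, 0.2, -0.1], [0.0, -0.2, 0.1]]     # small update
W_next = [[w + x for w, x in zip(rw, rd)] for rw, rd in zip(W_t, Delta)]

SSTs = all_countsketch_SST(d=3, s=2)              # 64 equally likely draws
seen_next = [matmul(W_next, P) for P in SSTs]     # W_{t+1} S_{t+1} S_{t+1}^T
seen_t = [matmul(W_t, P) for P in SSTs]           # W_t S_t S_t^T
n_pairs = len(SSTs) ** 2                          # S_t and S_{t+1} independent

mean_D = [[0.0] * 3 for _ in range(2)]
mean_sq_err = 0.0
for a_next in seen_next:
    for a_t in seen_t:
        D = sub(a_next, a_t)                      # D = W~_{t+1} S_{t+1}^T - W~_t S_t^T
        err = sub(D, Delta)
        mean_sq_err += sum(e * e for row in err for e in row) / n_pairs
        for r in range(2):
            for c in range(3):
                mean_D[r][c] += D[r][c] / n_pairs
delta_sq = sum(x * x for row in Delta for x in row)
```

Here `mean_D` equals $\Delta$ up to rounding, while the expected squared error $\mathbb{E}\|D - \Delta\|_F^2$ equals $\frac{d-1}{s}\big(\|W_t\|_F^2 + \|W_{t+1}\|_F^2\big) = 1108.71$, roughly $10^4$ times $\|\Delta\|_F^2 = 0.11$.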
Let $\Theta$ be any fixed matrix, e.g., the parameters in the classifier of the PIA. Assume, first, that the “mass” of $W_t$ is substantially bigger than that of $\Delta$ and, second, that the sketching matrices are not too sparse. Then $D$ is no better than a random guess in the sense that
Here, can be arbitrarily small, and it depends on and the sparsity of the sketching.
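In the following, we discuss necessary conditions for the conjecture to hold. The sketching matrix must vary from iteration to iteration; otherwise, the conjecture cannot hold. If $S_t = S_{t+1} = S$, then $D$ would be
$$D = (W_{t+1} - W_t)\, S S^\top = \Delta S S^\top.$$
It is well known from the literature on randomized linear algebra that $\langle \Theta, \Delta S S^\top \rangle = \langle \Theta S, \Delta S \rangle$ can be very close to $\langle \Theta, \Delta \rangle$, and thus the bound in the conjecture would not hold.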
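The matrices $W_t$ and $W_{t+1}$ must have much larger “mass” than the matrix $\Delta$, which is the case in deep learning. What if $W_t$ has a small mass? For example, if $W_t = \mathbf{0}$, then $D$ will be $W_{t+1} S_{t+1} S_{t+1}^\top = \Delta S_{t+1} S_{t+1}^\top$, and it can be used to approximate $\Delta$.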
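The sketching matrix cannot be too sparse; otherwise, the conjecture cannot hold. Suppose $S_t, S_{t+1} \in \mathbb{R}^{d \times s}$ are uniform sampling matrices, each of which has only $s$ nonzero entries. Then the sketches
$$\widetilde{W}_t = W_t S_t \qquad \text{and} \qquad \widetilde{W}_{t+1} = W_{t+1} S_{t+1}$$
reveal $s$ of the $d$ columns of $W_t$ and $W_{t+1}$, respectively (up to a known scaling). The two random sampling matrices overlap in about $s^2/d$ columns in expectation (a quarter of the columns if $s = d/2$), so $D$ reveals the corresponding columns of $\Delta$ exactly and is therefore better than a random guess. Thus, uniform sampling does not satisfy the conjecture, and this is why dropout does not defend against the PIA.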
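The leakage through overlapping samples can be demonstrated directly. In this sketch ($W_t$ and $\Delta$ are made up, the sampled indices are fixed lists instead of random draws so that the overlap is visible, sampling is without replacement, and function names are hypothetical), the adversary knows $\widetilde{W}_t$, $\widetilde{W}_{t+1}$, and the sampled indices, and recovers every column of $\Delta$ sampled in both rounds.

```python
import math

def sampled_columns(W, indices, d):
    # W S for a uniform sampling matrix S whose k-th column is sqrt(d/s) e_{indices[k]}
    scale = math.sqrt(d / len(indices))
    return [[scale * row[j] for j in indices] for row in W]

def leaked_delta_columns(Wt_sk, Wn_sk, idx_t, idx_next, d):
    # columns sampled in both rounds reveal the matching columns of Delta exactly
    scale = math.sqrt(d / len(idx_t))
    leaked = {}
    for k_next, j in enumerate(idx_next):
        if j in idx_t:
            k_t = idx_t.index(j)
            leaked[j] = [(rn[k_next] - rt[k_t]) / scale for rn, rt in zip(Wn_sk, Wt_sk)]
    return leaked

d = 8
W_t = [[float(r * d + c) for c in range(d)] for r in range(2)]
Delta = [[0.01 * (c + 1) * (-1) ** r for c in range(d)] for r in range(2)]
W_next = [[w + x for w, x in zip(rw, rd)] for rw, rd in zip(W_t, Delta)]
idx_t, idx_next = [0, 2, 4, 6], [1, 2, 5, 6]    # s = d / 2
leaked = leaked_delta_columns(sampled_columns(W_t, idx_t, d),
                              sampled_columns(W_next, idx_next, d),
                              idx_t, idx_next, d)   # columns 2 and 6 of Delta
```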
Convergence of DBCL.
Our empirical study shows that sketching does not hinder training and test accuracy, although it slows down convergence. It would be highly interesting to prove the convergence rate under some assumptions, e.g., convexity and smoothness. If we apply DBCL to a generalized linear model, then training amounts to solving the following problem:
$$\min_{\mathbf{w}} \; \frac{1}{n} \sum_{j=1}^{n} \mathbb{E}_{S}\Big[ \ell\big(\mathbf{x}_j^\top S S^\top \mathbf{w}, \; y_j\big) \Big],$$
where $(\mathbf{x}_1, y_1), \dots, (\mathbf{x}_n, y_n)$ are the training samples and $\ell$ measures the difference between the prediction and the target. If $S$ is a uniform sampling matrix, then the problem is empirical risk minimization with dropout. The contemporaneous work (Khaled and Richtárik, 2019) analyzed gradient descent with dropout and showed that the algorithm converges to a neighborhood of the optimal solution.
4.5 Empirical Study of the Convergence
We build a four-layer convolutional neural network that has two convolutional layers and two fully-connected (FC) layers. We apply CountSketch to all the layers except for the output FC layer. We conduct experiments on the MNIST dataset, which has 60,000 training samples of $28 \times 28$ images. We report the results in Figure 1. After 17 epochs, the validation error (using the server’s parameters) drops to a low level and soon decreases further. This accuracy is comparable to the best results for MNIST, which indicates that sketching does not affect accuracy. The validation error then remains low throughout, which means overfitting does not happen.
5 Related Work
This work is most relevant to Melis et al. (2019), in which the property inference attack (PIA) is proposed; our work aims to defend against this PIA. An earlier work (Hitaj et al., 2017) was the first to realize that collaborative ML can leak users’ privacy; it is still unclear whether our proposed DBCL can defend against their attack, which we leave for future investigation. The very recent works (Chen et al., 2018; Triastcyn and Faltings, 2019) proposed defenses against the attacks, but these defenses hurt accuracy and efficiency to some extent. Besides, DBCL can be integrated into their frameworks.
Our methodology is based on matrix sketching (Johnson and Lindenstrauss, 1984; Drineas et al., 2008; Halko et al., 2011; Mahoney, 2011; Woodruff, 2014; Drineas and Mahoney, 2016). In particular, we find that after a random projection, the “signals” that leak participants’ data are outweighed by “noise”. After finishing this work, we realized that an earlier work (Blocki et al., 2012) found that random projection preserves differential privacy. However, the theory of Blocki et al. (2012) does not answer our question (Conjecture 4.4).
We develop our algorithmic framework based on the connection between sketching (Woodruff, 2014) and dropout training (Srivastava et al., 2014); in particular, if $S$ is a uniform sampling matrix, then DBCL is essentially dropout. Our approach is different from that of Hanzely et al. (2018), which directly applies matrix sketching to the true gradients; DBCL, like dropout, applies matrix sketching to the data and model parameters. Our approach is somewhat similar to the contemporaneous work (Khaled and Richtárik, 2019), which was developed for computational benefits.
Collaborative machine learning (ML) has attracted much attention in recent years because it enables multiple parties to jointly learn a model without data sharing. Unfortunately, prior work showed that collaborative ML in the standard form can easily leak participants’ privacy. We propose Double-Blind Collaborative Learning (DBCL) to improve the security of collaborative ML; in particular, we show that DBCL can defeat the property inference attack. While it improves security, it neither hurts accuracy nor much increases the computational cost. This work leaves two open questions regarding security and convergence.
The author thanks Michael Mahoney, Richard Peng, Peter Richtárik, and David Woodruff for their helpful suggestions.
- Ateniese et al. (2015) Giuseppe Ateniese, Luigi V. Mancini, Angelo Spognardi, Antonio Villani, Domenico Vitali, and Giovanni Felici. Hacking smart machines with smarter ones: How to extract meaningful data from machine learning classifiers. International Journal of Security and Networks, 10(3):137–150, September 2015. ISSN 1747-8405.
- Bianchi et al. (2013) Pascal Bianchi, Gersende Fort, and Walid Hachem. Performance of a distributed stochastic approximation algorithm. IEEE Transactions on Information Theory, 59(11):7405–7418, 2013.
- Blocki et al. (2012) Jeremiah Blocki, Avrim Blum, Anupam Datta, and Or Sheffet. The Johnson-Lindenstrauss transform itself preserves differential privacy. In Annual Symposium on Foundations of Computer Science (FOCS), 2012.
- Chen et al. (2018) Qingrong Chen, Chong Xiang, Minhui Xue, Bo Li, Nikita Borisov, Dali Kaarfar, and Haojin Zhu. Differentially private data generative models. arXiv preprint arXiv:1812.02274, 2018.
- Colin et al. (2016) Igor Colin, Aurélien Bellet, Joseph Salmon, and Stéphan Clémençon. Gossip dual averaging for decentralized optimization of pairwise functions. arXiv preprint arXiv:1606.02421, 2016.
- Drineas and Mahoney (2016) Petros Drineas and Michael W Mahoney. RandNLA: randomized numerical linear algebra. Communications of the ACM, 59(6):80–90, 2016.
- Drineas et al. (2006) Petros Drineas, Michael W. Mahoney, and S. Muthukrishnan. Sampling algorithms for ℓ2 regression and applications. In Annual ACM-SIAM Symposium on Discrete Algorithms (SODA), 2006.
- Drineas et al. (2008) Petros Drineas, Michael W. Mahoney, and S. Muthukrishnan. Relative-error CUR matrix decompositions. SIAM Journal on Matrix Analysis and Applications, 30(2):844–881, September 2008.
- Drineas et al. (2011) Petros Drineas, Michael W. Mahoney, S. Muthukrishnan, and Tamás Sarlós. Faster least squares approximation. Numerische Mathematik, 117(2):219–249, 2011.
- Drineas et al. (2012) Petros Drineas, Malik Magdon-Ismail, Michael W. Mahoney, and David P. Woodruff. Fast approximation of matrix coherence and statistical leverage. Journal of Machine Learning Research, 13:3441–3472, 2012.
- Dwork (2011) Cynthia Dwork. Differential privacy. Encyclopedia of Cryptography and Security, pages 338–340, 2011.
- Dwork and Naor (2010) Cynthia Dwork and Moni Naor. On the difficulties of disclosure prevention in statistical databases or the case for differential privacy. Journal of Privacy and Confidentiality, 2(1), 2010.
- Fredrikson et al. (2015) Matt Fredrikson, Somesh Jha, and Thomas Ristenpart. Model inversion attacks that exploit confidence information and basic countermeasures. In Proceedings of the 22nd ACM SIGSAC Conference on Computer and Communications Security, 2015.
- Halko et al. (2011) Nathan Halko, Per-Gunnar Martinsson, and Joel A. Tropp. Finding structure with randomness: Probabilistic algorithms for constructing approximate matrix decompositions. SIAM Review, 53(2):217–288, 2011.
- Hanzely et al. (2018) Filip Hanzely, Konstantin Mishchenko, and Peter Richtárik. SEGA: Variance reduction via gradient sketching. In Advances in Neural Information Processing Systems (NeurIPS), 2018.
- Hitaj et al. (2017) Briland Hitaj, Giuseppe Ateniese, and Fernando Perez-Cruz. Deep models under the GAN: information leakage from collaborative deep learning. In Proceedings of the 2017 ACM SIGSAC Conference on Computer and Communications Security, 2017.
- Johnson and Lindenstrauss (1984) William B. Johnson and Joram Lindenstrauss. Extensions of Lipschitz mappings into a Hilbert space. Contemporary Mathematics, 26:189–206, 1984.
- Khaled and Richtárik (2019) Ahmed Khaled and Peter Richtárik. Gradient descent with compressed iterates. arXiv, 2019.
- Lan et al. (2017) Guanghui Lan, Soomin Lee, and Yi Zhou. Communication-efficient algorithms for decentralized and stochastic optimization. Mathematical Programming, pages 1–48, 2017.
- Lian et al. (2017) Xiangru Lian, Ce Zhang, Huan Zhang, Cho-Jui Hsieh, Wei Zhang, and Ji Liu. Can decentralized algorithms outperform centralized algorithms? a case study for decentralized parallel stochastic gradient descent. In Advances in Neural Information Processing Systems (NIPS), 2017.
- Lu et al. (2013) Yichao Lu, Paramveer Dhillon, Dean P Foster, and Lyle Ungar. Faster ridge regression via the subsampled randomized Hadamard transform. In Advances in Neural Information Processing Systems (NIPS), 2013.
- Mahoney (2011) Michael W. Mahoney. Randomized algorithms for matrices and data. Foundations and Trends in Machine Learning, 3(2):123–224, 2011.
- McMahan et al. (2017) Brendan McMahan, Eider Moore, Daniel Ramage, Seth Hampson, and Blaise Aguera y Arcas. Communication-efficient learning of deep networks from decentralized data. In Artificial Intelligence and Statistics (AISTATS), 2017.
- Melis et al. (2019) Luca Melis, Congzheng Song, Emiliano De Cristofaro, and Vitaly Shmatikov. Exploiting unintended feature leakage in collaborative learning. In IEEE Symposium on Security and Privacy (SP), 2019.
- Ram et al. (2010a) S Sundhar Ram, Angelia Nedić, and Venu V Veeravalli. Asynchronous gossip algorithm for stochastic optimization: Constant stepsize analysis. In Recent Advances in Optimization and its Applications in Engineering, pages 51–60. Springer, 2010a.
- Ram et al. (2010b) S Sundhar Ram, Angelia Nedić, and Venugopal V Veeravalli. Distributed stochastic subgradient projection algorithms for convex optimization. Journal of optimization theory and applications, 147(3):516–545, 2010b.
- Sirb and Ye (2016) Benjamin Sirb and Xiaojing Ye. Consensus optimization with delayed and stochastic gradients on decentralized networks. In 2016 IEEE International Conference on Big Data (Big Data), pages 76–85. IEEE, 2016.
- Srivastava and Nedic (2011) Kunal Srivastava and Angelia Nedic. Distributed asynchronous constrained stochastic optimization. IEEE Journal of Selected Topics in Signal Processing, 5(4):772–790, 2011.
- Srivastava et al. (2014) Nitish Srivastava, Geoffrey Hinton, Alex Krizhevsky, Ilya Sutskever, and Ruslan Salakhutdinov. Dropout: a simple way to prevent neural networks from overfitting. Journal of Machine Learning Research, 15(1):1929–1958, 2014.
- Tang et al. (2018) Hanlin Tang, Xiangru Lian, Ming Yan, Ce Zhang, and Ji Liu. D2: Decentralized training over decentralized data. arXiv preprint arXiv:1803.07068, 2018.
- Triastcyn and Faltings (2019) Aleksei Triastcyn and Boi Faltings. Federated generative privacy. EPFL Tech. Report, 2019.
- Tropp (2011) Joel A Tropp. Improved analysis of the subsampled randomized Hadamard transform. Advances in Adaptive Data Analysis, 3(01n02):115–126, 2011.
- Woodruff (2014) David P Woodruff. Sketching as a tool for numerical linear algebra. Foundations and Trends® in Theoretical Computer Science, 10(1–2):1–157, 2014.
- Yuan et al. (2016) Kun Yuan, Qing Ling, and Wotao Yin. On the convergence of decentralized gradient descent. SIAM Journal on Optimization, 26(3):1835–1854, 2016.
Appendix A Algorithm Derivation
A.1 Fully-Connected (FC) Layers
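For simplicity, we study the case of batch size $b = 1$ for an FC layer. Let $\mathbf{x} \in \mathbb{R}^{d_{\text{in}}}$ be the input, $W \in \mathbb{R}^{d_{\text{out}} \times d_{\text{in}}}$ be the parameter matrix, $S \in \mathbb{R}^{d_{\text{in}} \times s}$ ($s < d_{\text{in}}$) be a sketching matrix, and
$$\mathbf{z} = W S S^\top \mathbf{x} \in \mathbb{R}^{d_{\text{out}}}$$
be the output (during training). For test, sketching is not applied; equivalently, $S S^\top$ is replaced by the identity matrix $I_{d_{\text{in}}}$.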
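For completeness, the backward pass for this $b = 1$ sketched layer follows from the chain rule; the shorthand $\mathbf{u} = S^\top \mathbf{x}$ and $\widetilde{W} = WS$ below is introduced here only for convenience, and the result is the $b = 1$ case of the sketched backward pass described in Section 4.2.

```latex
% Sketched FC layer, batch size b = 1 (training mode)
\mathbf{u} = S^\top \mathbf{x} \in \mathbb{R}^{s}, \qquad
\widetilde{W} = W S \in \mathbb{R}^{d_{\text{out}} \times s}, \qquad
\mathbf{z} = W S S^\top \mathbf{x} = \widetilde{W} \mathbf{u}.

% Let g = dL/dz in R^{d_out}. Since z_r = \sum_c W_{rc} (S S^T x)_c, the chain rule gives
\frac{\partial L}{\partial W}
  = \mathbf{g} \, (S S^\top \mathbf{x})^\top
  = \big(\mathbf{g} \, \mathbf{u}^\top\big) S^\top
  \in \mathbb{R}^{d_{\text{out}} \times d_{\text{in}}},
\qquad
\frac{\partial L}{\partial \mathbf{x}}
  = S S^\top W^\top \mathbf{g}
  = S \, \widetilde{W}^\top \mathbf{g}
  \in \mathbb{R}^{d_{\text{in}}}.
```

A worker therefore only needs $\widetilde{W}$ and the seed of $S$: it sends the $d_{\text{out}} \times s$ matrix $\mathbf{g}\mathbf{u}^\top$, and the server multiplies it by $S^\top$.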
A.2 Extension to Convolutional Layers
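Let $\mathcal{X} \in \mathbb{R}^{d_1 \times d_2 \times c}$ be a tensor and $\mathcal{K} \in \mathbb{R}^{k \times k \times c}$ be a kernel. The convolution $\mathcal{X} * \mathcal{K}$ (with stride 1 and no padding) outputs a $(d_1 - k + 1) \times (d_2 - k + 1)$ matrix. The convolution can be equivalently written as a matrix-vector multiplication in the following way.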
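We extract the (overlapping) $k \times k \times c$ patches of $\mathcal{X}$ and then reshape every patch to a $k^2 c$-dimensional vector. Let $\mathbf{p}_i$ be the $i$-th patch (vector). Tensor $\mathcal{X}$ has $q = (d_1 - k + 1)(d_2 - k + 1)$ such patches. Let
$$P = [\mathbf{p}_1, \mathbf{p}_2, \dots, \mathbf{p}_q]^\top \in \mathbb{R}^{q \times k^2 c}$$
be the concatenation of the patches.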
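Let $\mathbf{w} = \mathrm{vec}(\mathcal{K}) \in \mathbb{R}^{k^2 c}$ be the vectorization of the kernel $\mathcal{K}$ (using the same ordering as the patches). The matrix-vector product $P\mathbf{w} \in \mathbb{R}^{q}$ is indeed the vectorization of the convolution $\mathcal{X} * \mathcal{K}$.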
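In practice, we typically use multiple kernels for the convolution; let $W = [\mathbf{w}_1, \dots, \mathbf{w}_{d_{\text{out}}}]^\top \in \mathbb{R}^{d_{\text{out}} \times k^2 c}$ be the concatenation of $d_{\text{out}}$ different (vectorized) kernels. In this way, the convolution of $\mathcal{X}$ with $d_{\text{out}}$ different kernels, which outputs a $(d_1 - k + 1) \times (d_2 - k + 1) \times d_{\text{out}}$ tensor, is the reshape of $P W^\top \in \mathbb{R}^{q \times d_{\text{out}}}$.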
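The patch-matrix identity can be checked numerically. This sketch assumes the deep-learning convention (cross-correlation), stride 1, no padding, an (row, column, channel) flattening order shared by patches and kernels, made-up tensors, and hypothetical function names; it compares $P W^\top$ with a direct convolution for two kernels.

```python
def matmul(A, B):
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*B)] for row in A]

def transpose(A):
    return [list(col) for col in zip(*A)]

def conv_direct(X, K):
    # X is d1 x d2 x c, K is k x k x c; output is (d1-k+1) x (d2-k+1)
    d1, d2, k, c = len(X), len(X[0]), len(K), len(X[0][0])
    return [[sum(X[i + a][j + b][ch] * K[a][b][ch]
                 for a in range(k) for b in range(k) for ch in range(c))
             for j in range(d2 - k + 1)] for i in range(d1 - k + 1)]

def flatten_patch(X, i, j, k):
    # vectorize the k x k x c patch whose top-left corner is (i, j)
    c = len(X[0][0])
    return [X[i + a][j + b][ch] for a in range(k) for b in range(k) for ch in range(c)]

def im2col(X, k):
    # P has q = (d1-k+1)(d2-k+1) rows (one per patch) and k*k*c columns
    return [flatten_patch(X, i, j, k)
            for i in range(len(X) - k + 1) for j in range(len(X[0]) - k + 1)]

def vec_kernel(K):
    # same (a, b, ch) ordering as flatten_patch
    k, c = len(K), len(K[0][0])
    return [K[a][b][ch] for a in range(k) for b in range(k) for ch in range(c)]

# X: 4 x 3 x 2 tensor; two 2 x 2 x 2 kernels, so d_out = 2
X = [[[float(i + 2 * j + 3 * ch) for ch in range(2)] for j in range(3)] for i in range(4)]
K1 = [[[1.0, 0.0], [0.0, -1.0]], [[2.0, 1.0], [-1.0, 0.5]]]
K2 = [[[0.5, -0.5], [1.0, 0.0]], [[0.0, 2.0], [1.5, -1.0]]]
P = im2col(X, 2)                          # 6 x 8
W = [vec_kernel(K1), vec_kernel(K2)]      # d_out x k*k*c = 2 x 8
PW = matmul(P, transpose(W))              # row i * 2 + j holds output pixel (i, j)
```

Replacing $W^\top$ by $S S^\top W^\top$ with a random $S \in \mathbb{R}^{k^2 c \times s}$ then sketches the convolutional layer exactly as in the FC case.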
We have shown above that tensor convolution can be equivalently expressed as matrix-matrix multiplication. Therefore, we can apply matrix sketching to a convolutional layer in the same way as to the FC layer. Specifically, let $S \in \mathbb{R}^{k^2 c \times s}$ be a random sketching matrix. Then $P S S^\top W^\top$ is an approximation to $P W^\top$, and the backpropagation is derived accordingly using matrix differentiation.