Unsupervised Learning with
Stein’s Unbiased Risk Estimator
Learning from unlabeled and noisy data is one of the grand challenges of machine learning. As such, it has seen a flurry of research with new ideas proposed continuously. In this work, we revisit a classical idea: Stein’s Unbiased Risk Estimator (SURE). We show that, in the context of image recovery, SURE and its generalizations can be used to train convolutional neural networks (CNNs) for a range of image denoising and recovery problems without any ground truth data.
Specifically, our goal is to reconstruct an image $x$ from a noisy linear transformation (measurement) of the image. We consider two scenarios: one where no additional data is available, and one where we have noisy measurements of other images drawn from the same distribution as $x$ but no access to the clean images. Such is the case, for instance, in medical imaging, microscopy, and astronomy, where noiseless ground truth data is rarely available.
We show that in this situation, SURE can be used to estimate the mean-squared-error loss associated with an estimate of $x$. Using this estimate of the loss, we train networks to perform denoising and compressed sensing recovery. In addition, we use the SURE framework to partially explain and improve upon an intriguing result presented by Ulyanov et al. in DeepImagePrior (): a network initialized with random weights and fit to a single noisy image can effectively denoise that image.
Christopher A. Metzler Rice University Houston, TX 77005 firstname.lastname@example.org Ali Mousavi Rice University Houston, TX 77005 email@example.com Reinhard Heckel Rice University Houston, TX 77005 firstname.lastname@example.org Richard G. Baraniuk Rice University Houston, TX 77005 email@example.com
Preprint. Work in progress.
In this work we consider reconstructing an unknown image $x \in \mathbb{R}^n$ from measurements of the form $y = Hx + w$, where $y \in \mathbb{R}^m$ are the measurements, $H \in \mathbb{R}^{m \times n}$ is the linear measurement operator, and $w$ denotes noise. This problem arises in numerous applications, including denoising, inpainting, superresolution, deblurring, and compressive sensing. The goal of an image recovery algorithm is to use prior information about the image’s distribution and knowledge of the measurement operator to reconstruct $x$.
The key determinant of an image recovery algorithm’s accuracy is the accuracy of its prior. Accordingly, over the past decades a large variety of increasingly complex image priors have been considered, ranging from simple sparse models SureShrink (), to non-local self-similarity priors BM3D (), all the way to neural-network-based priors, in particular CNN-based priors DnCNN (). Among these methods, CNN priors often offer the best performance. It is widely believed that the key to the success of CNN-based priors is their ability to process and learn from vast amounts of training data, although recent work suggests that the structure of a CNN itself encodes strong prior information DeepImagePrior (). Given a neural-network-based prior, an image can be reconstructed by enforcing the prior via empirical risk minimization hand_global_2017 (), or by training a neural network to directly reconstruct the image.
CNN image recovery methods are typically trained by taking a representative set of images $x_1, \dots, x_L$, drawn from the same distribution as $x$, and capturing a set of measurements $y_1, \dots, y_L$, with $y_\ell = Hx_\ell + w_\ell$, either physically or in simulation. The network $f_\theta$ then learns the mapping from observations back to images by minimizing a loss function, typically the mean-squared error (MSE) between $f_\theta(y_\ell)$ and $x_\ell$. This presents a challenge in applications where we do not have access to example images. Moreover, even if we have example images, we might also have a large set of measurements and would like to use that set to refine our reconstruction algorithm.
In the nomenclature of machine learning, the measurements $y_\ell$ would be considered features and the images $x_\ell$ the labels. Thus, when $H = I$, the training problem simplifies to learning from noisy labels. When $H \neq I$, the training problem is to learn from noisy linear transformations of the labels.
Learning from noisy labels has been extensively studied in the context of classification; see for instance natarajan2013learning (); xiao2015learning (); liu2016classification (); sukhbaatar2014training (); sukhbaatar2014learning (). However, the problem of learning from noisy data has been studied far less in the context of image recovery. In this work we show that the SURE framework can be used to i) denoise an image with a neural network (NN) without any training data, ii) train NNs to denoise images from noisy training data, and iii) train a NN, using only noisy measurements, to solve the compressive sensing problem.
A very recent preprint overlaps with some of the aspirations of the second of our contributions. Specifically, MCSUREtraining () also demonstrates that SURE can be used to train CNN-based denoisers without ground truth data. However, in this work we go significantly beyond the setup in MCSUREtraining () and show that SURE can be applied much more broadly.
2 SURE and its Generalizations
The goal of this work is to reconstruct an image $x$ from noisy linear observations $y = Hx + w$ and knowledge of the linear measurement operator $H$. In addition to $y$, we assume that we are given training measurements $y_1, \dots, y_L$ but not the images $x_1, \dots, x_L$ that produced them (we also consider the case where no training measurements are given). Without access to $x_1, \dots, x_L$ we cannot fit a model that minimizes the MSE, but we can minimize a loss based on Stein’s Unbiased Risk Estimator (SURE). In this section, we introduce SURE and its generalizations.
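SURE is a model selection technique that was first proposed by its namesake in SURE (). SURE provides an unbiased estimate of the MSE of an estimator of the mean of a Gaussian distributed random vector with unknown mean. Let $x \in \mathbb{R}^n$ denote a vector we would like to estimate from noisy observations $y = x + w$, where $w \sim \mathcal{N}(0, \sigma^2 I)$. Also, assume $f_\theta(\cdot)$ is a weakly differentiable function parameterized by $\theta$ which receives noisy observations $y$ as input and provides an estimate of $x$ as output. Then, according to SURE (), we can write the expectation of the MSE with respect to the random variable $w$ as

$$\mathbb{E}_w\left[\frac{1}{n}\|x - f_\theta(y)\|_2^2\right] = \mathbb{E}_w\left[\frac{1}{n}\|y - f_\theta(y)\|_2^2 - \sigma^2 + \frac{2\sigma^2}{n}\operatorname{div}_y\big(f_\theta(y)\big)\right], \tag{1}$$

where $\operatorname{div}_y(\cdot)$ stands for divergence and is defined as

$$\operatorname{div}_y\big(f_\theta(y)\big) = \sum_{i=1}^{n} \frac{\partial f_{\theta,i}(y)}{\partial y_i}. \tag{2}$$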
Note that two terms within the SURE loss (1) depend on the parameter $\theta$. The first term, $\frac{1}{n}\|y - f_\theta(y)\|_2^2$, penalizes the difference between the estimate and the observations (bias). The second term, $\frac{2\sigma^2}{n}\operatorname{div}_y(f_\theta(y))$, penalizes the denoiser for varying as its input is changed (variance). Thus SURE is a natural way to control the bias-variance trade-off of a recovery algorithm.
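As a sanity check of (1), the following NumPy sketch (our illustration, not code from the experiments) evaluates the SURE estimate $\frac{1}{n}\|y - f(y)\|_2^2 - \sigma^2 + \frac{2\sigma^2}{n}\operatorname{div}_y f(y)$ for a toy linear shrinkage estimator $f(y) = cy$, whose divergence is exactly $cn$, and compares it with the true MSE, which needs the clean signal. The signal distribution and the values of $c$ and $\sigma$ are arbitrary assumptions.

```python
import numpy as np

def sure_linear_shrink(y, c, sigma):
    """SURE estimate of the per-pixel MSE of the toy estimator f(y) = c * y."""
    n = y.size
    fidelity = np.sum((y - c * y) ** 2) / n   # (1/n) ||y - f(y)||^2
    divergence = c * n                        # div_y(c * y) = sum_i c = c * n (exact)
    return fidelity - sigma**2 + 2 * sigma**2 * divergence / n

rng = np.random.default_rng(0)
n, sigma, c = 100_000, 0.5, 0.7
x = rng.standard_normal(n)                    # hypothetical clean signal
y = x + sigma * rng.standard_normal(n)        # y = x + w, w ~ N(0, sigma^2 I)

true_mse = np.mean((x - c * y) ** 2)          # requires x, unavailable in practice
sure_mse = sure_linear_shrink(y, c, sigma)    # requires only y and sigma
print(f"true MSE = {true_mse:.4f}, SURE = {sure_mse:.4f}")
```

For this setup both numbers should lie close to the expected value $(1-c)^2 + c^2\sigma^2 = 0.2125$.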
The central challenge in using SURE in practice is computing the divergence. With advanced denoisers the divergence is hard or even impossible to express analytically.
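MC-SURE, proposed in MCSURE (), is a Monte Carlo method to estimate the divergence, and thus the SURE loss. In particular, the authors show that for bounded functions $f_\theta$ we have

$$\operatorname{div}_y\big(f_\theta(y)\big) = \lim_{\epsilon \to 0} \mathbb{E}_b\left\{ b^\top \left(\frac{f_\theta(y + \epsilon b) - f_\theta(y)}{\epsilon}\right)\right\}, \tag{3}$$

where $b$ is a random vector with i.i.d. zero-mean, unit-variance Gaussian elements.

By the law of large numbers, this expectation can be approximated with Monte Carlo sampling. Thanks to the high dimensionality of images, a single sample approximates the expectation well. The limit can be approximated well by using a small value of $\epsilon$, which we do throughout. This approximation leaves us with

$$\operatorname{div}_y\big(f_\theta(y)\big) \approx b^\top \left(\frac{f_\theta(y + \epsilon b) - f_\theta(y)}{\epsilon}\right). \tag{4}$$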
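The single-sample estimate $b^\top (f(y + \epsilon b) - f(y))/\epsilon$ is easy to try on a denoiser whose divergence is known in closed form. The NumPy sketch below uses soft thresholding, whose divergence is the number of entries with $|y_i| > \lambda$. The signal, the threshold, and $\epsilon = 10^{-3}$ are illustrative assumptions, not values from our experiments.

```python
import numpy as np

def soft_threshold(y, lam):
    """Soft-thresholding denoiser; its divergence equals #{i : |y_i| > lam}."""
    return np.sign(y) * np.maximum(np.abs(y) - lam, 0.0)

def mc_divergence(f, y, eps, rng):
    """One-sample Monte Carlo divergence estimate b^T (f(y + eps*b) - f(y)) / eps."""
    b = rng.standard_normal(y.shape)          # i.i.d. N(0, 1) probe vector
    return np.dot(b.ravel(), (f(y + eps * b) - f(y)).ravel()) / eps

rng = np.random.default_rng(1)
y = 2.0 * rng.standard_normal(200_000)
lam = 1.0

def denoiser(v):
    return soft_threshold(v, lam)

exact = np.count_nonzero(np.abs(y) > lam)
estimate = mc_divergence(denoiser, y, eps=1e-3, rng=rng)
print(f"exact divergence = {exact}, MC estimate = {estimate:.1f}")
```

The relative error of a single probe shrinks roughly like $1/\sqrt{n}$, which is why one sample suffices for image-sized inputs.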
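GSURE was proposed in GSURE () to estimate the MSE associated with estimates of $x$ from a linear measurement $y = Hx + w$, where $w$ has known covariance and follows any distribution from the exponential family. For the special case of i.i.d. Gaussian noise, $w \sim \mathcal{N}(0, \sigma^2 I)$, the estimate simplifies to

$$\mathbb{E}_w\left[\frac{1}{n}\|Px - Pf_\theta(y)\|_2^2\right] = \mathbb{E}_w\left[\frac{1}{n}\|Px\|_2^2 + \frac{1}{n}\|Pf_\theta(y)\|_2^2 - \frac{2}{n} f_\theta(y)^\top H^\dagger y + \frac{2\sigma^2}{n}\operatorname{div}_y\big((H^\dagger)^\top f_\theta(y)\big)\right], \tag{5}$$

where $P = H^\dagger H$ denotes orthogonal projection onto the range space of $H^\top$ and $H^\dagger$ is the pseudoinverse of $H$. Note that while this expression involves the unknown $x$, it can be minimized with respect to $\theta$ without knowledge of $x$, because the only term containing $x$, $\|Px\|_2^2$, does not depend on $\theta$.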
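For i.i.d. Gaussian noise, GSURE estimates $\frac{1}{n}\|P(x - f_\theta(y))\|_2^2$ by $\frac{1}{n}\big(\|Px\|_2^2 + \|Pf_\theta(y)\|_2^2 - 2f_\theta(y)^\top H^\dagger y + 2\sigma^2 \operatorname{div}_y((H^\dagger)^\top f_\theta(y))\big)$ with $P = H^\dagger H$, and this can be checked numerically. The NumPy sketch below assumes small sizes, a random $H$, and a toy linear estimator $f(y) = cH^\dagger y$, for which the divergence term is the exact trace $\operatorname{tr}((H^\dagger)^\top cH^\dagger)$. It averages the estimate over many noise draws and compares it with the true projected error; only the $\|Px\|_2^2$ term touches $x$.

```python
import numpy as np

rng = np.random.default_rng(2)
m, n, sigma, c, trials = 20, 40, 0.3, 0.8, 20_000

H = rng.standard_normal((m, n)) / np.sqrt(m)   # hypothetical measurement matrix
H_pinv = np.linalg.pinv(H)                     # H^dagger, shape (n, m)
P = H_pinv @ H                                 # orthogonal projector onto range(H^T)
M = c * H_pinv                                 # toy linear estimator f(y) = M y
x = rng.standard_normal(n)                     # hypothetical clean signal

W = sigma * rng.standard_normal((trials, m))   # one noise draw per row
Y = x @ H.T + W                                # rows are y = H x + w
F = Y @ M.T                                    # rows are f(y)

true_err = np.sum(((x - F) @ P) ** 2, axis=1) / n       # (1/n) ||P(x - f(y))||^2
div_term = np.trace(H_pinv.T @ M)                       # div_y((H^dagger)^T M y), exact
gsure = (np.sum((P @ x) ** 2)                           # ||P x||^2 (constant in theta)
         + np.sum((F @ P) ** 2, axis=1)                 # ||P f(y)||^2
         - 2 * np.sum(F * (Y @ H_pinv.T), axis=1)       # 2 f(y)^T H^dagger y
         + 2 * sigma**2 * div_term) / n
print(f"mean true error = {true_err.mean():.4f}, mean GSURE = {gsure.mean():.4f}")
```

Individual GSURE values fluctuate around the true error; only their averages are expected to agree.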
Minimizing SURE requires propagating gradients through the Monte Carlo estimate of the divergence (4). This is challenging to do by hand, but made easy by the automatic differentiation capabilities of TensorFlow and PyTorch, which we used in most of our experiments below.
3 Denoising Without Pre-Training
CNNs offer state-of-the-art performance for many image recovery problems including super-resolution SRCNN (), inpainting yang2017high (), and denoising DnCNN (). Typically, these networks are trained on a large dataset of images. However, it was recently shown that if just fit to a single corrupted image, without any pre-training, CNNs can still perform effective image recovery. Specifically, the recent work Deep Image Prior DeepImagePrior () demonstrated that a randomly initialized expansive CNN, trained so that for a fixed input the output matches the corrupted image well, performs exceptionally well at the aforementioned inverse problems.
Similar to more traditional image recovery methods like BM3D BM3D (), the deep image prior method only exploits structure within a given image and does not use any external training data. In this paper, we consider a slight variation of the original deep image prior that empirically performs better on the examples considered. This variation—as well as the original deep prior work—requires the training to be stopped early, and the performance depends on hitting the right stopping time. We then show that, in the context of denoising, training using the SURE loss allows us to train without early stopping and improves the performance over the original deep image prior trained with least squares loss.
3.1 Deep Image Prior
The deep image prior is a CNN, denoted by $f_\theta$ and parameterized by the weight and bias vector $\theta$, that maps a latent representation $z$ to an image $f_\theta(z)$. The paper DeepImagePrior () proposes to minimize

$$\mathcal{L}\big(y, f_\theta(z)\big)$$

over the parameters $\theta$, starting from a random initialization, using a method such as gradient descent. Here, $\mathcal{L}$ is a loss function, chosen as the squared $\ell_2$-norm (i.e., $\mathcal{L}(y, f_\theta(z)) = \|y - f_\theta(z)\|_2^2$), and $z$ is the latent representation, potentially chosen at random.
In this paper, we consider a variant of the original Deep Image Prior work, where $z$ is set equal to $y$. Following DeepImagePrior (), our goal is to train the network such that it fits the image but not the noise. If the network is trained using too few iterations, then the network does not represent the image well. If the network is trained until convergence, however, then the network’s parameters are overfit and describe the noise along with the actual image.
Thus, there is a sweet spot at which the network minimizes the error between the true image and the reconstructed image. This is illustrated in Figure 1(a). To obtain Figure 1(a), we repeated the experiments in DeepImagePrior () with a very large U-net architecture Unet (), where we replaced $z$ with the noisy observation $y$. We note that this network offered performance superior to the network originally used in DeepImagePrior (), which takes a random $z$ as input rather than the noisy image itself. However, choosing $z = y$ and choosing $z$ randomly perform very similarly, and qualitatively both choices behave the same in that both require early stopping to achieve optimal performance.
In more detail, U-net Unet () consists of a cascade of convolution and ReLU layers that are sequentially downsampled and then upsampled. The convolutional layers at each “level” of upsampling/downsampling are connected by skip connections. The U-net we used had the same number of features at each convolution layer. This particular recovery problem was to denoise a Mandrill image contaminated with additive white Gaussian noise (AWGN) with standard deviation 25. It can be seen that after a few iterations, the network begins to overfit to the noisy image and the normalized MSE (NMSE) starts to increase. Unfortunately, by itself the training loss gives little indication of when to stop training; it smoothly decays towards zero.
3.2 SURE Deep Image Prior
Returning to Figure 1(a), we observe that the MSE is minimized at a point where the training loss is reasonably small, but at the same time the network divergence (scaled by $2\sigma^2/n$) is not too large; thus the sum of the two terms is small. In fact, if one plots the SURE estimate of the loss (not shown here), it lies directly on top of the observed NMSE.
Inspired by this observation, we propose to train with the SURE loss instead of the fidelity ($\ell_2$) loss. The results are shown in Figure 1(b). Not only does the network avoid overfitting, but its final NMSE is also lower than that achieved by optimal early stopping with the fidelity loss.
We tested the performance of this large U-net, trained only on the image to be denoised, against the state-of-the-art trained and untrained algorithms DnCNN DnCNN () and CBM3D BM3D (), respectively. Results are presented in Figure 2. When trained with the SURE loss, the U-net fit only to the image to be denoised offers performance competitive with these two methods. However, the figure also demonstrates that the deep prior method is slow and that training across additional images, as exemplified by DnCNN, is beneficial.
The experiments in this section were conducted using Torch implementations of U-net and MatConvNet implementations of DnCNN. The BM3D and CBM3D experiments used Matlab implementations based on compiled C code. Public implementations of the algorithms can be found at the following locations: DnCNN MatConvNet: https://github.com/cszn/DnCNN; U-net Torch: https://github.com/cszn/DnCNN; BM3D Matlab: http://www.cs.tut.fi/~foi/GCF-BM3D/index.html#ref_software. The U-net was trained using the Adam optimizer ADAMopt () with a learning rate of 0.001.
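Throughout the paper we report recovery accuracies in terms of PSNR, defined as

$$\mathrm{PSNR} = 10 \log_{10}\left(\frac{v_{\max}^2}{\frac{1}{n}\|x - \hat{x}\|_2^2}\right),$$

where $\hat{x}$ is the reconstruction of $x$ and all images lie in the range $[0, v_{\max}]$. All experiments were performed on a desktop with an Intel 6800K CPU and an Nvidia Pascal Titan X GPU.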
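The PSNR definition translates directly into code. In this NumPy sketch the peak value `v_max` is a parameter that must match the range of the images (for example 255 for 8-bit images or 1 for images scaled to $[0, 1]$); the function name is our own.

```python
import numpy as np

def psnr(x, x_hat, v_max):
    """PSNR in dB between a reference image x and a reconstruction x_hat in [0, v_max]."""
    x = np.asarray(x, dtype=float)
    x_hat = np.asarray(x_hat, dtype=float)
    mse = np.mean((x - x_hat) ** 2)           # (1/n) ||x - x_hat||^2
    return 10.0 * np.log10(v_max**2 / mse)

# An error of 25.5 at every pixel of an 8-bit image gives MSE = 255^2 / 100, i.e. 20 dB.
x = np.zeros((8, 8))
print(psnr(x, x + 25.5, v_max=255.0))
```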
4 Denoising with Noisy Training Data
In the previous section we showed that SURE can be used to fit a CNN denoiser without training data. This is useful in regimes where no training data whatsoever is available. However, as the previous section demonstrated, these untrained neural networks are computationally very expensive to apply.
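In this section we focus on the scenario where we can capture additional noisy training data with which to train a neural network for denoising. We show that if provided a set of noisy images $\{y_1, \dots, y_L\}$, training the network with the SURE loss improves upon training with the MSE loss. Specifically, we optimize the network’s parameters by minimizing

$$\frac{1}{L}\sum_{\ell=1}^{L}\left(\frac{1}{n}\|y_\ell - f_\theta(y_\ell)\|_2^2 - \sigma^2 + \frac{2\sigma^2}{n}\operatorname{div}_{y_\ell}\big(f_\theta(y_\ell)\big)\right).$$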
As before we will use the Monte-Carlo estimate of the divergence (3).
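Minimizing the batch-averaged SURE loss over noisy training images can be illustrated without a CNN. The NumPy sketch below treats a soft-thresholding denoiser with a single parameter $\lambda$ as a stand-in for $f_\theta$ (an assumption for illustration; our experiments use DnCNN), evaluates the average SURE loss with the one-sample Monte Carlo divergence for each $\lambda$ on a grid, and compares the result with the oracle MSE that needs the clean signals. The sparse synthetic signals and all sizes are hypothetical.

```python
import numpy as np

def soft_threshold(y, lam):
    return np.sign(y) * np.maximum(np.abs(y) - lam, 0.0)

def mc_sure(y, lam, sigma, b, eps=1e-3):
    """Per-image SURE loss of soft thresholding with a one-probe MC divergence."""
    n = y.size
    fy = soft_threshold(y, lam)
    div = np.dot(b, soft_threshold(y + eps * b, lam) - fy) / eps
    return np.sum((y - fy) ** 2) / n - sigma**2 + 2 * sigma**2 * div / n

rng = np.random.default_rng(3)
L, n, sigma = 20, 5_000, 0.5
support = rng.random((L, n)) < 0.1                              # ~10% nonzero entries
X = np.where(support, 2.0 * rng.standard_normal((L, n)), 0.0)   # hypothetical clean signals
Y = X + sigma * rng.standard_normal((L, n))                     # noisy training data
B = rng.standard_normal((L, n))                                 # one fixed probe per image

lams = np.linspace(0.0, 2.0, 41)
avg_sure = np.array([np.mean([mc_sure(Y[l], lam, sigma, B[l]) for l in range(L)])
                     for lam in lams])                          # needs only Y and sigma
avg_mse = np.array([np.mean((soft_threshold(Y, lam) - X) ** 2) for lam in lams])  # oracle

lam_sure = lams[np.argmin(avg_sure)]
lam_oracle = lams[np.argmin(avg_mse)]
print(f"lambda chosen by SURE = {lam_sure:.2f}, oracle lambda = {lam_oracle:.2f}")
```

Reusing the same probe vectors for every $\lambda$ keeps Monte Carlo noise from making the SURE curve jagged; for a CNN the grid search would be replaced by gradient descent on $\theta$.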
4.1 DnCNN and Experimental Setup
In this section, we consider the DnCNN image denoiser, trained on grayscale images from Berkeley’s BSD-500 dataset BSDDataset () using either the MSE or the SURE loss. Example images from this dataset are presented in Figure 4. The training images were cropped, rescaled, flipped, and rotated to form a set of 204 800 overlapping patches. We tested the methods on 6 standard test images presented in Figure 4.
The DnCNN image denoiser DnCNN () consists of 16 sequential convolutional layers with the same number of features at each layer. Sandwiched between these layers are ReLU and batch-normalization operations. DnCNN has a single skip connection to its output and is trained using residual learning residuallearning ().
The experiments in this section were performed using a TensorFlow implementation of DnCNN available at https://github.com/crisb-DUT/DnCNN-tensorflow. We trained all networks for 50 epochs using the Adam optimizer ADAMopt () with a learning rate of 0.001, which was reduced to 0.0001 after 30 epochs. We used mini-batches of 128 patches.
The results of training DnCNN using the MSE and SURE loss are presented in Figure 5 and Table 1. The results show that training DnCNN with SURE on noisy data results in reconstructions almost equivalent to training with the true MSE on clean images. Moreover, both DnCNN trained with SURE and with the true MSE outperform BM3D. As expected, because calculating the SURE loss requires two calls to the denoiser, it takes roughly twice as long to train as does training with the MSE.
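| Method | Training time | Test time | Image 1 (dB) | Image 2 (dB) | Image 3 (dB) | Image 4 (dB) | Image 5 (dB) | Image 6 (dB) |
|---|---|---|---|---|---|---|---|---|
| DnCNN SURE | 8.1 hrs | 0.01 sec | 29.1 | 28.9 | 29.2 | 32.4 | 26.1 | 26.6 |
| DnCNN MSE | 4.3 hrs | 0.01 sec | 29.5 | 29.2 | 29.6 | 33.1 | 26.4 | 26.8 |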
5 Compressive Sensing Recovery with Noisy Measurements
In this section we study the problem of image recovery from linear undersampled measurements. The main novelty of this section is that we do not have training labels (i.e., ground truth images). Instead, and unlike conventional learning approaches in compressive sensing (CS), we train the recovery algorithm using only noisy undersampled linear measurements.
The proposed method is closely related to Parameterless Optimal AMP ParameterlessAMP (), which used the SURE loss to tune AMP AMP () to reconstruct signals with unknown sparsity in a known dictionary. It is also somewhat related to blind CS BlindCS (), wherein signals that are sparse with respect to an unknown dictionary are reconstructed from compressive measurements. At a high level, the proposed method can be considered a practical form of universal CS jalali2014minimum (); BaronUniversalCS1 (); BaronUniversalCS2 (); jalali2017universal (), wherein a signal’s distribution is estimated from noisy linear measurements and used to reconstruct the signal.
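We used the Learned Denoising-based Approximate Message Passing (LDAMP) network proposed in LDAMP () as our recovery method. LDAMP offers state-of-the-art CS recovery for measurement matrices with i.i.d. Gaussian distributed elements. LDAMP is an unrolled version of the D-AMP DAMP () algorithm, which decouples signal recovery into a series of denoising problems, one solved at every layer. With denoiser $D_{\theta^l}$ at layer $l$, the recursion reads

$$\begin{aligned}
z^{l} &= y - Hx^{l} + \frac{1}{m}\, z^{l-1}\, \operatorname{div} D_{\theta^{l-1}}\big(x^{l-1} + H^{*}z^{l-1};\, \hat{\sigma}^{l-1}\big),\\
\hat{\sigma}^{l} &= \frac{\|z^{l}\|_2}{\sqrt{m}},\\
x^{l+1} &= D_{\theta^{l}}\big(x^{l} + H^{*}z^{l};\, \hat{\sigma}^{l}\big),
\end{aligned}$$

where $H^{*}$ denotes the adjoint of $H$. In other words, LDAMP receives a noisy version of $x$ (i.e., $x^{l} + H^{*}z^{l} = x + v^{l}$) at every layer and tries to reconstruct $x$ by eliminating the effective noise vector $v^{l}$.

The success of LDAMP stems from its Onsager correction term (i.e., the last term on the first line of the recursion), which removes the bias from intermediate solutions. As a result, at each layer the effective noise term $v^{l}$ follows a Gaussian distribution whose variance is accurately predicted by $(\hat{\sigma}^{l})^2$ MalekiThesis ().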
In this work, we use an LDAMP network in which each layer itself contains a DnCNN denoiser $D_{\theta^l}$. Below we use $f_\theta$ to denote LDAMP.
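To make the D-AMP recursion concrete, the NumPy sketch below runs it with soft thresholding at $\alpha\hat{\sigma}^l$ standing in for the learned DnCNN denoiser; with this swapped-in denoiser the loop reduces to classical AMP. The Onsager term uses the exact divergence of soft thresholding. The matrix normalization, the sparse test signal, $\alpha = 1.5$, and the iteration count are assumptions chosen for illustration, not LDAMP settings.

```python
import numpy as np

def soft_threshold(r, lam):
    return np.sign(r) * np.maximum(np.abs(r) - lam, 0.0)

def amp_soft(y, H, n_iter=50, alpha=1.5):
    """D-AMP-style recursion with soft thresholding in place of a learned denoiser."""
    m, n = H.shape
    x = np.zeros(n)
    z = y.copy()
    for _ in range(n_iter):
        sigma_hat = np.linalg.norm(z) / np.sqrt(m)   # predicted std of effective noise v
        r = x + H.T @ z                              # r = x_true + v, v roughly Gaussian
        lam = alpha * sigma_hat
        x_next = soft_threshold(r, lam)
        div = np.count_nonzero(np.abs(r) > lam)      # divergence of soft thresholding at r
        z = y - H @ x_next + (div / m) * z           # residual plus Onsager correction
        x = x_next
    return x

rng = np.random.default_rng(4)
m, n, k = 500, 1000, 100
H = rng.standard_normal((m, n)) / np.sqrt(m)         # columns have roughly unit norm
x_true = np.zeros(n)
x_true[rng.choice(n, size=k, replace=False)] = rng.standard_normal(k)
y = H @ x_true + 0.01 * rng.standard_normal(m)

x_hat = amp_soft(y, H)
nmse = np.sum((x_hat - x_true) ** 2) / np.sum(x_true ** 2)
print(f"NMSE = {nmse:.2e}")
```

Dropping the `(div / m) * z` term turns the loop into plain iterative thresholding, whose effective noise is no longer well modeled as Gaussian with variance $\hat{\sigma}^2$.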
5.2 Training LDAMP and Experimental Setup
We compare three methods of training LDAMP. All three methods utilize layer-by-layer training, which in the context of LDAMP is minimum MSE optimal LDAMP ().
The first method, LDAMP MSE, simply minimizes the MSE with respect to the training data. The second method, LDAMP SURE, uses SURE to train LDAMP using only noisy measurements. This method takes advantage of the fact that at each layer LDAMP is solving denoising problems with known noise variance.
The third and final method, LDAMP GSURE, uses generalized SURE to train LDAMP using only noisy measurements.
The SURE and GSURE methods both rely upon Monte-Carlo estimation of the divergence (3).
The experiments in this section were performed using a TensorFlow implementation of LDAMP available online at https://github.com/ricedsp/D-AMP_Toolbox. We used dense Gaussian measurement matrices for our low-resolution training and coded diffraction patterns, which offer fast $O(n \log n)$ multiplications, for our high-resolution testing. For both training and testing we used the same sampling rate and measurement noise level. See LDAMP () for more details about the measurement process. Other experimental settings, such as batch sizes and learning rates, are the same as those in Section 4.1.
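Coded diffraction patterns owe their speed to the FFT. The NumPy sketch below assumes one common construction, a random unit-modulus phase mask followed by an orthonormal 2-D FFT and random selection of $m$ Fourier coefficients; the class name and this exact construction are illustrative assumptions and may differ from the D-AMP toolbox. Both the forward operator and its adjoint cost $O(n \log n)$ and never form $H$ explicitly; the adjoint test at the end is the usual check that the two are consistent.

```python
import numpy as np

class CodedDiffraction:
    """Hypothetical coded-diffraction operator H = S F D (mask, FFT, subsample)."""

    def __init__(self, shape, m, rng):
        self.shape = shape
        n = shape[0] * shape[1]
        self.mask = np.exp(2j * np.pi * rng.random(shape))      # D: random phases
        self.keep = rng.choice(n, size=m, replace=False)        # S: kept coefficients

    def forward(self, x):
        coeffs = np.fft.fft2(self.mask * x, norm="ortho")       # F D x
        return coeffs.ravel()[self.keep]                        # S F D x

    def adjoint(self, y):
        full = np.zeros(self.shape[0] * self.shape[1], dtype=complex)
        full[self.keep] = y                                     # S^T y (zero-fill)
        img = np.fft.ifft2(full.reshape(self.shape), norm="ortho")  # F^H S^T y
        return np.conj(self.mask) * img                         # D^H F^H S^T y

rng = np.random.default_rng(5)
op = CodedDiffraction((64, 64), m=1024, rng=rng)
x = rng.standard_normal((64, 64))
y = rng.standard_normal(1024) + 1j * rng.standard_normal(1024)
# Adjoint test: <Hx, y> should equal <x, H^* y> up to rounding.
lhs = np.vdot(op.forward(x), y)
rhs = np.vdot(x, op.adjoint(y))
print(abs(lhs - rhs))
```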
The networks resulting from training LDAMP with the MSE, SURE, and GSURE losses are compared to BM3D-AMP DAMP () in Figure 6 and Table 2. The results demonstrate that LDAMP networks trained with SURE and MSE both perform roughly on par with BM3D-AMP and run significantly faster. The results also demonstrate that LDAMP trained with GSURE offers significantly reduced performance. This can be understood by returning to the original GSURE cost (5). Minimizing GSURE minimizes the distance between $Px$ and $Pf_\theta(y)$, not between $x$ and $f_\theta(y)$, where $P$ denotes orthogonal projection onto the range space of $H^\top$. In the context of compressive sensing, where $m < n$, this range space has dimension at most $m$, so these two distances are not necessarily close to one another.
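| Training time | Test time | Image 1 (dB) | Image 2 (dB) | Image 3 (dB) | Image 4 (dB) | Image 5 (dB) | Image 6 (dB) |
|---|---|---|---|---|---|---|---|
| 34.7 hrs | 0.4 sec | 30.6 | 31.9 | 31.7 | 36.6 | 26.4 | 28.0 |
| 43.2 hrs | 0.4 sec | 26.3 | 25.7 | 25.9 | 28.6 | 24.5 | 24.4 |
| 20.7 hrs | 0.4 sec | 32.1 | 34.6 | 35.0 | 38.6 | 27.3 | 28.6 |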
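Why GSURE struggles in compressive sensing can be seen with a few lines of NumPy. For a random $H$ with $m \ll n$ (the sizes are assumptions), the projector $P = H^\dagger H$ keeps only about a fraction $m/n$ of the energy of a generic error vector, and an error lying entirely in the null space of $H$ is invisible to $\|P(x - f_\theta(y))\|_2^2$.

```python
import numpy as np

rng = np.random.default_rng(6)
m, n = 200, 2000
H = rng.standard_normal((m, n))           # hypothetical undersampling operator
P = np.linalg.pinv(H) @ H                 # orthogonal projector onto range(H^T), rank m

e = rng.standard_normal(n)                # stand-in for a reconstruction error x - f(y)
ratio = np.sum((P @ e) ** 2) / np.sum(e ** 2)
e_null = e - P @ e                        # component of e in the null space of H
print(f"||Pe||^2/||e||^2 = {ratio:.3f} (m/n = {m / n:.3f})")
print(f"||P e_null|| = {np.linalg.norm(P @ e_null):.1e}")
```

A GSURE-trained network therefore receives almost no training signal about the roughly $(n - m)/n$ share of the error that lies outside the range of $H^\top$.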
We have made three distinct contributions. First we showed that SURE can be used to denoise an image using a CNN without any training data. Second, we demonstrated that SURE can be used to train a CNN denoiser using only noisy training data. Third, we showed that SURE can be used to train a neural network, using only noisy measurements, to solve the compressive sensing problem.
In the context of imaging, our work suggests a new hands-off approach to reconstruct images. Using SURE, one could toss a sensor into a novel imaging environment and have the sensor itself figure out and then apply the appropriate prior to reconstruct images.
In the context of machine learning, our work suggests that divergence may be an overlooked proxy for variance in an estimator. Thus, while SURE is applicable in only fairly specific circumstances, penalizing divergence could be applied more broadly as a tool to help attack overfitting.
- (1) D. Ulyanov, A. Vedaldi, and V. S. Lempitsky, “Deep image prior,” CoRR, vol. abs/1711.10925, 2017.
- (2) D. L. Donoho and I. M. Johnstone, “Adapting to unknown smoothness via wavelet shrinkage,” Journal of the American Statistical Association, vol. 90, no. 432, pp. 1200–1224, 1995.
- (3) K. Dabov, A. Foi, V. Katkovnik, and K. Egiazarian, “Image denoising by sparse 3-D transform-domain collaborative filtering,” IEEE Transactions on Image Processing, vol. 16, no. 8, pp. 2080–2095, 2007.
- (4) K. Zhang, W. Zuo, Y. Chen, D. Meng, and L. Zhang, “Beyond a Gaussian denoiser: Residual learning of deep CNN for image denoising,” IEEE Transactions on Image Processing, vol. 26, no. 7, pp. 3142–3155, 2017.
- (5) P. Hand and V. Voroninski, “Global guarantees for enforcing deep generative priors by empirical risk,” in Conference on Learning Theory, 2018, arXiv:1705.07576.
- (6) N. Natarajan, I. S. Dhillon, P. K. Ravikumar, and A. Tewari, “Learning with noisy labels,” in Proc. Adv. in Neural Information Processing Systems (NIPS), 2013, pp. 1196–1204.
- (7) T. Xiao, T. Xia, Y. Yang, C. Huang, and X. Wang, “Learning from massive noisy labeled data for image classification,” in Proc. IEEE Int. Conf. Comp. Vision, and Pattern Recognition, 2015, pp. 2691–2699.
- (8) T. Liu and D. Tao, “Classification with noisy labels by importance reweighting,” IEEE Trans. Pattern Anal. Machine Intell., vol. 38, no. 3, pp. 447–461, 2016.
- (9) S. Sukhbaatar, J. Bruna, M. Paluri, L. Bourdev, and R. Fergus, “Training convolutional networks with noisy labels,” arXiv preprint arXiv:1406.2080, 2014.
- (10) S. Sukhbaatar and R. Fergus, “Learning from noisy labels with deep neural networks,” arXiv preprint arXiv:1406.2080, vol. 2, no. 3, p. 4, 2014.
- (11) S. Soltanayev and S. Y. Chun, “Training deep learning based denoisers without ground truth data,” arXiv preprint arXiv:1803.01314, 2018.
- (12) C. M. Stein, “Estimation of the mean of a multivariate normal distribution,” The Annals of Statistics, pp. 1135–1151, 1981.
- (13) S. Ramani, T. Blu, and M. Unser, “Monte-Carlo SURE: A black-box optimization of regularization parameters for general denoising algorithms,” IEEE Transactions on Image Processing, vol. 17, no. 9, pp. 1540–1554, 2008.
- (14) Y. C. Eldar, “Generalized SURE for exponential families: Applications to regularization,” IEEE Transactions on Signal Processing, vol. 57, no. 2, pp. 471–481, 2009.
- (15) C. Dong, C. Loy, K. He, and X. Tang, “Learning a deep convolutional network for image super-resolution,” in European Conference on Computer Vision. Springer, 2014, pp. 184–199.
- (16) C. Yang, X. Lu, Z. Lin, E. Shechtman, O. Wang, and H. Li, “High-resolution image inpainting using multi-scale neural patch synthesis,” in The IEEE Conference on Computer Vision and Pattern Recognition (CVPR), vol. 1, no. 2, 2017, p. 3.
- (17) O. Ronneberger, P. Fischer, and T. Brox, “U-net: Convolutional networks for biomedical image segmentation,” in International Conference on Medical Image Computing and Computer-Assisted Intervention. Springer, 2015, pp. 234–241.
- (18) D. Kingma and J. Ba, “Adam: A method for stochastic optimization,” arXiv preprint arXiv:1412.6980, 2014.
- (19) D. Martin, C. Fowlkes, D. Tal, and J. Malik, “A database of human segmented natural images and its application to evaluating segmentation algorithms and measuring ecological statistics,” Proc. Int. Conf. Computer Vision, vol. 2, pp. 416–423, July 2001.
- (20) K. He, X. Zhang, S. Ren, and J. Sun, “Deep residual learning for image recognition,” Proc. IEEE Int. Conf. Comp. Vision, and Pattern Recognition, pp. 770–778, 2016.
- (21) A. Mousavi, A. Maleki, and R. G. Baraniuk, “Parameterless optimal approximate message passing,” arXiv preprint arXiv:1311.0035, 2013.
- (22) D. L. Donoho, A. Maleki, and A. Montanari, “Message-passing algorithms for compressed sensing,” Proceedings of the National Academy of Sciences, vol. 106, no. 45, pp. 18 914–18 919, 2009.
- (23) S. Gleichman and Y. C. Eldar, “Blind compressed sensing,” IEEE Transactions on Information Theory, vol. 57, no. 10, pp. 6958–6975, 2011.
- (24) S. Jalali, A. Maleki, and R. G. Baraniuk, “Minimum complexity pursuit for universal compressed sensing,” IEEE Transactions on Information Theory, vol. 60, no. 4, pp. 2253–2268, 2014.
- (25) D. Baron and M. F. Duarte, “Universal map estimation in compressed sensing,” in Communication, Control, and Computing (Allerton), 2011 49th Annual Allerton Conference on. IEEE, 2011, pp. 768–775.
- (26) J. Zhu, D. Baron, and M. F. Duarte, “Recovery from linear measurements with complexity-matching universal signal estimation.” IEEE Trans. Signal Processing, vol. 63, no. 6, pp. 1512–1527, 2015.
- (27) S. Jalali and H. V. Poor, “Universal compressed sensing for almost lossless recovery,” IEEE Transactions on Information Theory, vol. 63, no. 5, pp. 2933–2953, 2017.
- (28) C. Metzler, A. Mousavi, and R. Baraniuk, “Learned D-AMP: Principled neural network based compressive image recovery,” in Advances in Neural Information Processing Systems, 2017, pp. 1770–1781.
- (29) C. A. Metzler, A. Maleki, and R. G. Baraniuk, “From denoising to compressed sensing,” IEEE Transactions on Information Theory, vol. 62, no. 9, pp. 5117–5144, 2016.
- (30) A. Maleki, “Approximate message passing algorithm for compressed sensing,” Stanford University PhD Thesis, Nov. 2010.