Laplace Approximation for Deep Neural Networks

Deep neural networks have been applied to a wide variety of tasks with tremendous success. However, the majority of these models only learn a single maximum-likelihood or maximum-a-posteriori point estimate, meaning that we do not know how certain a model is about any particular prediction. Bayesian Deep Learning aims to fill this void by marginalizing over the network weights, thereby yielding not just a point estimate but a full distribution over all possible parameter settings. However, Bayesian Deep Learning quickly runs into computational issues, and one has to settle for approximations that can nevertheless be quite powerful. A nice introductory blog post about Bayesian Deep Learning can be found here{:target="_blank"}. A somewhat forgotten technique that has regained a lot of traction in the past few years is the Laplace Approximation. It provides a useful alternative and, most interestingly, can be applied after training a model with standard techniques like stochastic gradient descent. This post therefore takes a closer look at the Laplace Approximation and demonstrates it with a PyTorch implementation. An entire library{:target="_blank"} has been created that implements the Laplace Approximation for PyTorch. However, to better understand the concept, I attempt to recreate a small example from scratch.

1. Introduction

Neural networks can be defined as a function \(f\) parameterized by \(\theta\) that maps inputs \(X\) to outputs \(Y\). Given a dataset \[D=\{(x_i,y_i): x_i \in X, y_i \in Y\}_{i=1}^n\] of iid observations, a network can be trained by maximizing the likelihood \(p(D|\theta)=\prod_{i=1}^n p(y_i|f_{\theta}(x_i))\). Furthermore, if we also define a prior \(p(\theta)\) over the network parameters, we can estimate the posterior via Bayes' rule:

\[p(\theta|D)=\frac{p(D|\theta)p(\theta)}{p(D)}\]

  • \(p(\theta)\) - the prior distribution - represents knowledge about parameter \(\theta\) before the data D is seen
  • \(p(D \vert \theta)\) - the likelihood function - indicates the probability of obtaining the data D with that particular parameter setting \(\theta\), and
  • \(p(\theta \vert D)\) - the posterior distribution - represents the state of knowledge about parameters \(\theta\) after observing all the new information from D

Here, the common problem of Bayesian approaches is already staring us in the face, namely the normalization constant \(p(D)=\int p(D|\theta) p(\theta) d\theta\). For a standard deep neural network, which can easily have millions of parameters, this integral is simply intractable even with modern computing power. We therefore have to resort to approximating the posterior \(p(\theta|D)\) instead of computing it exactly as we would like to. A simple approach is to find the most likely point under the posterior, which is the maximum a posteriori (MAP) estimate defined as follows:

\[\theta_{MAP}= \underset{\theta}{\operatorname{argmax}} \prod_{i=1}^n p(y_i|f_{\theta}(x_i))p(\theta)\]

For numerical stability reasons, it is more common to work with the logarithm, because it turns the product into a sum and, as a monotonically increasing function, preserves the location of the maxima and minima of the function we apply it to. Hence,

\[\theta_{MAP}= \underset{\theta}{\operatorname{argmax}} \sum_{i=1}^n \log p(y_i|f_{\theta}(x_i))+\log p(\theta)\]

Additionally, optimization problems are usually cast as minimizing an objective function, so it is more common to minimize the negative of the above expression, which is equivalent and known as the negative log posterior (for a flat prior it reduces to the negative log likelihood):

\[\theta_{MAP}=\underset{\theta}{\operatorname{argmin}} -\sum_{i=1}^n \log p(y_i|f_\theta(x_i))-\log p(\theta)\]
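
To make this concrete: if we assume, for example, a Gaussian likelihood with fixed noise variance \(\sigma^2\) and an isotropic Gaussian prior \(p(\theta)=\mathcal{N}(0, \lambda^{-1}I)\), this objective reduces, up to additive constants, to the familiar squared-error loss with weight decay:

\[\theta_{MAP}=\underset{\theta}{\operatorname{argmin}} \frac{1}{2\sigma^2}\sum_{i=1}^n \big(y_i - f_{\theta}(x_i)\big)^2 + \frac{\lambda}{2}\lVert\theta\rVert^2\]

This is also the setting of the regression example later in this post.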

We can obtain this estimate through optimization with standard tools like stochastic gradient descent. While this gives us a model that allows us to make predictions, maybe even "very good" predictions, we still lack any notion of uncertainty around our parameter setting \(\theta\). There is a variety of methods that aim to estimate this uncertainty, but a simple and still effective one is the Laplace Approximation.

2. Laplace Approximation

First made popular by David MacKay's paper{:target="_blank"} in 1992, which applied the Laplace approximation to Bayesian neural networks, the method has seen a comeback in recent years, with researchers showing its applicability to today's deep neural networks. The following paragraphs explain the method and illustrate its application to neural networks.

2.1. Idea and Derivation

The general goal is to find an approximation to the intractable posterior distribution of the network weights \(\theta\) given the data D. The idea of the Laplace approximation is to take the MAP estimate and construct a Gaussian distribution around it, whose covariance matrix captures the curvature around this mode. The derivation follows the steps in Chapter 4.4 of Bishop's Pattern Recognition and Machine Learning. As a first step, we write \(f(\theta)=-\log p(\theta|D)\) for the negative log posterior and construct a second-order Taylor expansion around the mode \(\theta_{MAP}\):

\[\log p(\theta|D) \approx \log p(\theta_{MAP}|D)- \nabla_{\theta}f(\theta_{MAP})^T(\theta-\theta_{MAP}) - \frac{1}{2}(\theta-\theta_{MAP})^T \nabla_{\theta}^2 f(\theta_{MAP}) (\theta-\theta_{MAP})\]

First, recall that \(f\) is the negative log posterior, which is what we optimize in practice and which explains the minus signs. Furthermore, since we assume that \(\theta_{MAP}\) is a mode of the posterior, the first-order derivative vanishes there and we can drop that term:

\[\log p(\theta|D) \approx \log p(\theta_{MAP}|D) - \frac{1}{2}(\theta-\theta_{MAP})^T \nabla_{\theta}^2 f(\theta_{MAP}) (\theta-\theta_{MAP})\]

To remove the logarithm and obtain the posterior, we take the exponential of both sides:

\[p(\theta|D) \approx p(\theta_{MAP}|D)\exp\left(- \frac{1}{2}(\theta-\theta_{MAP})^T \nabla_{\theta}^2 f(\theta_{MAP}) (\theta-\theta_{MAP})\right)\]

The exponentiated term on the right already resembles an unnormalized Gaussian distribution, with the Hessian \(H=\nabla_{\theta}^2 f(\theta_{MAP})\) playing the role of the inverse covariance, so we only need a normalization factor Z to obtain a proper Gaussian. The approximation we are constructing is proportional to the posterior we want to approximate, meaning they only differ up to a constant. The appropriate normalization constant Z can be found by recalling the standard definition of a multivariate Gaussian:

\[\mathcal{N}(x|\mu, \Sigma)=\frac{1}{(2\pi)^{D/2}}\frac{1}{|\Sigma|^{1/2}}\exp\left(-\frac{1}{2}(x-\mu)^T\Sigma^{-1}(x-\mu)\right)\]

where \(\mu\) is a D-dimensional mean vector and \(\Sigma\) is a \(D \times D\) covariance matrix. Matching \(\mu=\theta_{MAP}\) and \(\Sigma^{-1}=H\), we can rewrite our posterior as:

\[p(\theta|D) \approx \frac{1}{(2\pi)^{D/2}}\frac{1}{|H^{-1}|^{1/2}}\exp\left(-\frac{1}{2}(\theta-\theta_{MAP})^T H (\theta-\theta_{MAP})\right)\]

And finally, we have:

\[p(\theta|D)\approx \mathcal{N}(\theta_{MAP}, H^{-1})\]

Throughout this derivation we have assumed a unimodal distribution with its mode at \(\theta_{MAP}\). In practice, and especially with deep neural networks, the actual posterior will have numerous modes, each of which would yield a different Laplace approximation. Additionally, it is a local method that may fail to capture global properties of the true distribution. Nevertheless, it is a reasonable approach that has been shown to work effectively in practice.
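
To see the whole recipe in one place before moving on to neural networks, here is a minimal one-dimensional sketch (my own toy illustration, not part of the derivation above): we find the mode of an unnormalized log posterior by gradient descent and use the curvature at that mode as the precision of the Gaussian approximation.

    import torch

    # toy unnormalized log-posterior: log p(theta|D) = 2*log(theta) - theta + const
    # (a Gamma(3, 1) density, whose mode lies at theta = 2)
    def log_post(theta):
        return 2 * torch.log(theta) - theta

    # 1) find the mode (the "MAP estimate") by minimizing the negative log-posterior
    theta = torch.tensor(1.0, requires_grad=True)
    opt = torch.optim.Adam([theta], lr=0.1)
    for _ in range(500):
        opt.zero_grad()
        loss = -log_post(theta)
        loss.backward()
        opt.step()
    theta_map = theta.detach()

    # 2) curvature at the mode: H = second derivative of the negative log-posterior
    H = torch.autograd.functional.hessian(lambda t: -log_post(t), theta_map)

    # 3) the Laplace approximation is N(theta_map, 1/H)
    print(theta_map.item(), (1.0 / H).item())  # approximately 2.0 and 2.0

For this toy density the mode is at 2 and the curvature there is 0.5, so the Laplace approximation is \(\mathcal{N}(2, 2)\), which the script recovers.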

2.2. The Hessian

One detail that has not been mentioned yet is that computing the second-order derivatives of the loss w.r.t. the weights, the Hessian, is also too computationally expensive for deep neural networks, since it is a \(P \times P\) matrix where P is often in the order of many millions. However, the Hessian can be approximated by the so-called generalized Gauss-Newton matrix (GGN) (a paper about it can be found here{:target="_blank"}). This takes the following form:

\[\widetilde{H}=\sum_{n=1}^NJ_n^TH_nJ_n+\lambda I\]

where \(J_n\in \mathbb{R}^{O\times W}\) is the Jacobian of the model outputs with respect to the parameters \(\theta\) and \(H_n\in\mathbb{R}^{O\times O}\) is the Hessian of the negative log-likelihood with respect to the model outputs. Here \(O\) denotes the model output size and \(W\) the number of parameters.

\[\begin{align*} J_n=\frac{\partial f_{\theta}(x_n)}{\partial \theta} && H_n=-\frac{\partial^2\log p(y_n|f_{\theta}(x_n))}{\partial f_{\theta}(x_n)^2} \end{align*}\]

This means that we do not have to compute the full second-order Hessian w.r.t. all weights, but can approximate it using only computationally cheaper quantities: per-sample Jacobians of the network outputs and a small Hessian of the loss w.r.t. those outputs.
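
For the regression setting used later in this post, the pieces simplify even further: with a scalar output (\(O=1\)) and a Gaussian likelihood with fixed noise \(\sigma^2\), the output Hessian is simply \(H_n = 1/\sigma^2\), so the GGN becomes

\[\widetilde{H}=\frac{1}{\sigma^2}\sum_{n=1}^N J_n^T J_n+\lambda I\]

which only requires the per-sample Jacobians.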

2.3. Predictions

Ultimately, we are interested in making predictions with our model that include uncertainty estimates. This is achieved through the posterior predictive distribution:

\[p(y^{\star}|x^{\star}, D)=\int p(y^{\star}|x^{\star},\theta)p(\theta|D)d\theta\]

where \(x^{\star}\) is a new, unseen test point and \(y^{\star}\) the corresponding prediction. However, as you might have expected, this integral is also intractable because we need to integrate over all parameters. This time, though, we can draw samples from the approximate posterior we found earlier and thereby approximate the posterior predictive distribution:

\[p(y^{\star}|x^{\star}, D)\approx\frac{1}{S}\sum_{s=1}^S p(y^{\star}|x^{\star}, \theta_s), \qquad \theta_s \sim p(\theta|D)\]

As you can see, the parameters \(\theta_s\) are subscripted by s, meaning that for each input \(x^{\star}\) we make a total of S predictions, each time with a differently sampled parameter vector. Immer et al.{:target="_blank"} have shown that in practice much better results can be achieved by linearizing the model. This means that we approximate the neural network with a local linear approximation around \(\theta_{MAP}\), which amounts to a first-order Taylor expansion. As a consequence, the network becomes linear in the parameters \(\theta\) but not in the inputs X. This can be written as follows:

\[f_{\theta}(x)=f_{\theta_{MAP}}(x)+ J_{\theta_{MAP}}(x)(\theta-\theta_{MAP})\]
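
A convenient consequence of this linearization, also noted by Immer et al., is that with a Gaussian posterior over the weights the distribution over the linearized outputs is itself Gaussian, so instead of sampling one could also compute the predictive in closed form:

\[f_{\theta}(x^{\star}) \sim \mathcal{N}\Big(f_{\theta_{MAP}}(x^{\star}),\; J_{\theta_{MAP}}(x^{\star})\, H^{-1}\, J_{\theta_{MAP}}(x^{\star})^T\Big)\]

For regression, the observation noise \(\sigma^2\) is added on top of this variance to obtain the predictive for \(y^{\star}\). In the implementation below we will nevertheless take the sampling route, since it mirrors the general recipe.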

3. PyTorch Implementation

3.1. MAP Training

We will begin with the part that should be familiar to most: standard neural network training to obtain a MAP estimate. As data we will use a small toy regression set with some Gaussian noise, constructed as follows:

{::options parse_block_html="true" /}

<details>

<summary>

Sin Toy Dataset

</summary>

import torch

def sinusoid_data(n=150, sigma_noise=0.3):
    torch.manual_seed(42)
    # create simple sinusoid data set
    x_train = (torch.rand(n) * 10).unsqueeze(-1)
    y_train = torch.sin(x_train) + torch.randn_like(x_train) * sigma_noise
    x_test = torch.linspace(-10, 13, 500).unsqueeze(-1)
    y_test = torch.sin(x_test) + torch.randn_like(x_test) * sigma_noise

    return {"x_train": x_train, "y_train": y_train, "x_test": x_test, "y_test": y_test}

</details>

{::options parse_block_html="false" /}

Next, we define a loss criterion to optimize, which will just be a standard MSE loss:

{::options parse_block_html="true" /}

<details>

<summary>

Criterion

</summary>

def criterion(pred, target, reduce=True):
    """MSE Loss"""
    if reduce:
        return ((target - pred) ** 2).mean()
    else:
        return (target - pred) ** 2

</details>

{::options parse_block_html="false" /}

As a model, we construct a small two-layer network with a non-linear activation function in between:

{::options parse_block_html="true" /}

<details>

<summary>

Model

</summary>

import torch.nn as nn

def construct_model(num_features, n_hidden_units=10):
    model = nn.Sequential(
        nn.Linear(num_features, n_hidden_units),
        nn.Tanh(),
        nn.Linear(n_hidden_units, 1),
    )
    return model

</details>

{::options parse_block_html="false" /}

The training loop will look as follows:

{::options parse_block_html="true" /}

<details>

<summary>

Train Loop

</summary>

def train(X, y, model, optimizer, n_epochs):
    ##################
    ## MAP Training ##
    ##################
    for i in range(n_epochs):
        pred = model(X)
        # MSE loss, proportional to a Gaussian negative log-likelihood up to constants
        nll = criterion(pred, y)
        optimizer.zero_grad()
        nll.backward()
        optimizer.step()
        if i % 100 == 0:
            print(f"epoch {i}: loss {nll.item():.4f}")

</details>

{::options parse_block_html="false" /}

Then we can train the network by instantiating the model and an optimizer of our choice. Here, we use the popular Adam optimizer with a learning rate of 0.01:

{::options parse_block_html="true" /}

<details>

<summary>

MAP Training

</summary>

data = sinusoid_data()
x_train, y_train, x_test = data["x_train"], data["y_train"], data["x_test"]

model = construct_model(num_features=1)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

### map training ###
train(x_train, y_train, model, optimizer, n_epochs=1000)  # epoch count chosen for illustration

</details>

{::options parse_block_html="false" /}

We can look at a resulting plot that shows how the model has fit the training data:

Here, we can already clearly observe that the model has not learned the general trend of a sine curve but rather just overfit to the training data it has seen. This discrepancy should be reflected in a measure of uncertainty: ideally, the model is more certain around data it has observed and uncertain in areas where it has not seen observations. For exactly this, we can use the Laplace Approximation.

3.2. Laplace Implementation

As mentioned in the beginning, the Laplace library implements everything we would need at this point. The following attempts to recreate the steps it takes, to have a simple from-scratch example for our regression task; all the pieces end up in a Laplace class that is available in full in the accompanying notebook. With MAP training concluded, we already have one of the two quantities we need to construct our Gaussian posterior approximation, namely the parameter setting \(\theta_{MAP}\) of the trained model above. The focus is therefore now on obtaining the covariance matrix via the generalized Gauss-Newton matrix. First, we need a couple of helper functions to work with the parameters of the PyTorch Sequential model and to compute the Jacobian and Hessian properly.

To obtain all our parameters in one stacked vector and vice versa to set the weights of the model again with a certain parameter configuration, we can use the following:

{::options parse_block_html="true" /}

<details>

<summary>

Retrieving and setting parameter vector

</summary>

def params_to_vector(self):
    """
    Returns all model parameters stacked into a single vector.
    """
    if not self.last_layer:
        param_vector = torch.cat([param.view(-1) for param in self.model.parameters()])
    else:
        last_layer = list(self.model.children())[-1]
        param_vector = torch.cat([param.view(-1) for param in last_layer.parameters()])

    self.num_params = param_vector.shape[0]

    return param_vector

def vector_to_params(self, param_vector):
    """
    Writes the individual parameters from a vector back into the model
    Args:
        param_vector - given parameter vector to put into model

    """
    weight_idx = 0

    if not self.last_layer:
        param_iterator = self.model
    else:
        param_iterator = list(self.model.children())[-1] # last layer

    for param in param_iterator.parameters():
        param_len = param.numel()

        # update parameter with the corresponding slice of param_vector
        param.data = param_vector[weight_idx: weight_idx+param_len].view_as(param).data

        weight_idx += param_len

</details>

{::options parse_block_html="false" /}
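
As a quick sanity check (hypothetical usage, assuming these methods are bound to an instance `la` of the Laplace class assembled in this post, wrapping the trained `model` from above), a round trip should leave the model unchanged:

    theta = la.params_to_vector()            # all weights flattened into one vector
    la.vector_to_params(theta)               # write the same vector back into the model
    assert torch.allclose(theta, la.params_to_vector())  # the round trip is lossless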

The computation of the Hessian is a bit more involved in PyTorch, but I found the following code, which computes the Hessian of the loss w.r.t. the model output.

{::options parse_block_html="true" /}

<details>

<summary>

Hessian computation

</summary>

def jacobian(self, y, x, create_graph=False):
    """https://gist.github.com/apaszke/226abdf867c4e9d6698bd198f3b45fb7"""
    jac = []
    flat_y = y.reshape(-1)
    grad_y = torch.zeros_like(flat_y)
    for i in range(len(flat_y)):
        grad_y[i] = 1.0
        (grad_x,) = torch.autograd.grad(
            flat_y, x, grad_y, retain_graph=True, create_graph=create_graph
        )
        jac.append(grad_x.reshape(x.shape))
        grad_y[i] = 0.0
    return torch.stack(jac).reshape(y.shape + x.shape)

def hessian(self, y, x):
    """https://gist.github.com/apaszke/226abdf867c4e9d6698bd198f3b45fb7"""
    return self.jacobian(self.jacobian(y, x, create_graph=True), x)

</details>

{::options parse_block_html="false" /}
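
To get a feeling for what these helpers return, one can check them on a simple quadratic whose Hessian is known analytically (again hypothetical usage, assuming an instance `la` of our class):

    x = torch.randn(3, requires_grad=True)
    y = (x ** 2).sum()                          # scalar function with known Hessian 2 * I
    H = la.hessian(y, x)                        # shape (3, 3)
    print(torch.allclose(H, 2 * torch.eye(3)))  # True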

Unfortunately, we cannot use the above code to directly compute the Jacobian w.r.t. the parameters, because stacking the parameters into one vector loses the computational graph that PyTorch needs to compute derivatives. Instead, we can use the following code to compute the Jacobian of the output w.r.t. the parameter vector \(\theta\).

{::options parse_block_html="true" /}

<details>

<summary>

Jacobian

</summary>

def gradient(self, model):
    """Collect the current gradients of all model parameters into one flat vector."""
    grad = torch.cat([p.grad.data.flatten() for p in model.parameters()])
    return grad.detach()

def jacobian_params(self, model, data):
    """Compute the Jacobian of the model outputs w.r.t. the stacked parameter vector."""
    model.zero_grad()
    output = model(data)
    Jacs = list()
    for i in range(output.shape[0]):
        # free the graph only on the last backward pass
        rg = (i != (output.shape[0] - 1))
        output[i].backward(retain_graph=rg)
        jacs = self.gradient(model)
        model.zero_grad()
        Jacs.append(jacs)
    Jacs = torch.stack(Jacs)
    return Jacs.detach().squeeze(), output.detach()

</details>

{::options parse_block_html="false" /}

With these helper functions at hand, we can iterate through the training set again and compute the generalized Gauss-Newton matrix that approximates the Hessian, i.e. the precision matrix. To obtain the covariance matrix we simply invert the precision matrix.

\[\widetilde{H}=\sum_{n=1}^NJ_n^TH_nJ_n+\lambda I\]

{::options parse_block_html="true" /}

<details>

<summary>

Estimating Covariance Matrix

</summary>

def compute_mean_and_cov(self, train_loader, criterion):
    """
    Compute mean and covariance for the Laplace approximation with the generalized Gauss-Newton matrix
    """
    # the prior precision provides the lambda * I term of the GGN
    precision = torch.eye(self.num_params) * self.prior_precision

    self.loss = 0
    self.n_data = len(train_loader.dataset)

    for X, y in train_loader:
        m_out = self.model(X)
        batch_loss = criterion(m_out, y)

        # jac is of shape N x num_params
        jac, _ = self.jacobian_params(self.model, X)

        # hess is the Hessian of the loss w.r.t. the model outputs, shape N x N
        hess = self.hessian(batch_loss, m_out).squeeze()
        hess = torch.eye(X.shape[0])  # found much better results with this, do not know why

        precision += jac.T @ hess @ jac

        self.loss += batch_loss.item()

    self.map_mean = self.params_to_vector()
    self.H = precision
    self.cov = torch.linalg.inv(precision)

</details>

{::options parse_block_html="false" /}
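
One small practical remark (my addition, not part of the recreated library code): directly inverting the precision matrix can be numerically delicate for larger parameter counts; a Cholesky-based inverse is a common, more stable alternative and also yields the factor needed for sampling later:

    # optional, numerically more stable alternative to torch.linalg.inv(precision)
    L = torch.linalg.cholesky(precision)   # precision = L @ L.T with lower-triangular L
    cov = torch.cholesky_inverse(L)        # equals the inverse of the precision matrix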

Lastly, we implement our prediction method via linearized sampling: we linearize the network and then take MC samples with which we approximate the posterior predictive distribution.

{::options parse_block_html="true" /}

<details>

<summary>

Linear sampling

</summary>

def linear_sampling(self, X, num_samples=100):
    theta_map = self.params_to_vector()

    if not self.last_layer:
        jac, model_map = self.jacobian_params(self.model, X)
    else:

        jac, model_map = self.last_layer_jacobian(self.model, X)

    offset = model_map - jac @ theta_map.unsqueeze(-1)

    # reparameterization trick: theta = theta_map + L @ eps with L L^T = cov,
    # so we sample with the Cholesky factor of the covariance rather than the covariance itself
    scale = torch.linalg.cholesky(self.cov)
    covs = scale @ torch.randn(len(theta_map), num_samples)

    theta_samples = theta_map + covs.T  # num_samples x num_params
    preds = list()

    for i in range(num_samples):
        pred = offset + jac @ theta_samples[i].unsqueeze(-1)
        preds.append(pred.detach())

    preds = torch.stack(preds)

    return preds, model_map

</details>

{::options parse_block_html="false" /}
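
With everything assembled (hypothetical usage, assuming an instance `la` that has already been fit with compute_mean_and_cov, and with `sigma_noise` denoting the estimated observation noise), obtaining a predictive mean and uncertainty band for the test inputs could look roughly like this:

    preds, f_map = la.linear_sampling(x_test, num_samples=200)   # preds: (S, N, 1)
    pred_mean = preds.mean(dim=0).squeeze(-1)                    # Monte Carlo mean
    # epistemic (weight) variance plus observation noise
    pred_std = (preds.var(dim=0).squeeze(-1) + sigma_noise ** 2).sqrt()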

We can put all the pieces into a class so that it can be applied more easily to pretrained models; the full class can be seen in the jupyter notebook{:target="_blank"}.

3.3. Optimizing prior precision and sigma noise

In practice, it is also common to tune the prior precision and the observation noise. Once we have obtained a MAP estimate, the Laplace approximation yields an estimate of the marginal likelihood as a function of these hyperparameters, which means that quantities usually treated as constants, like the prior precision and the noise variance, can be tuned with modern gradient-based optimization. The Laplace library has this implemented, and I am using it for the figures that follow below.
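
For completeness, here is a sketch of what this looks like with the library, roughly following the example in its README (the exact API may differ between versions, and it assumes the training data has been wrapped in a `train_loader`, so treat this as illustrative rather than definitive):

    from laplace import Laplace

    # fit a Laplace approximation on top of the already trained MAP model
    la = Laplace(model, "regression", subset_of_weights="all", hessian_structure="full")
    la.fit(train_loader)

    # tune prior precision and observation noise by maximizing the Laplace
    # estimate of the marginal likelihood
    log_prior, log_sigma = torch.ones(1, requires_grad=True), torch.ones(1, requires_grad=True)
    hyper_optimizer = torch.optim.Adam([log_prior, log_sigma], lr=1e-1)
    for _ in range(100):
        hyper_optimizer.zero_grad()
        neg_marglik = -la.log_marginal_likelihood(log_prior.exp(), log_sigma.exp())
        neg_marglik.backward()
        hyper_optimizer.step()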

3.4. Some Results

A first observation that quickly became obvious was the impact of scaling the input data appropriately. This is generally something that should be done for neural networks to obtain more stable gradients during backpropagation. Here is an example with unscaled input data:

In contrast, the picture changes quite a bit with scaled input data, where we can see that the uncertainty bands become more concise and do not fan out as much:

Lastly, when debugging the code and not actually computing the Hessian of the loss w.r.t. the model output as the GGN method states, but simply using the identity matrix as the Hessian, the uncertainty estimates of this implementation became much smoother:

In some sense, this implementation looks better than the library one, because the uncertainty should not increase indefinitely when moving further away from the training data. I plan to investigate these observations in a follow-up post, as I want to keep this text as an introduction to the method of Laplace approximation itself.

4. Conclusion

In this blog post, we took a closer look at the Laplace approximation and how it can be applied to neural networks. The Laplace class we constructed from scratch works okay for smaller models and illustrates the steps required to equip neural networks with uncertainty after training them. However, for modern neural networks that contain several million parameters, further approximation steps need to be taken. More specifically, the Hessian has to be approximated further, for example through Kronecker-factored Approximate Curvature (K-FAC) or diagonalization. These are implemented in the Laplace library{:target="_blank"} and can easily be applied to your trained network at scale for regression and classification tasks.

Date: March 2, 2023

Author: Nils Lehmann
