Homework 11: Detecting Distribution Shift on MNIST using Bayesian Neural Networks#
Overview#
In this exercise we will compare Bayesian NNs with deterministic NNs on a distribution shift detection task. To do this, we’ll monitor the predictive entropy as the distribution gradually shifts. A model with better uncertainty quantification should become less certain—that is, have a more entropic predictive distribution—as the input distribution shifts. Mathematically, our quantity of interest is:

\[
\mathbb{H}\left[ p(y \mid x^{*}, D) \right] = - \sum_{y} p(y \mid x^{*}, D) \log p(y \mid x^{*}, D),
\]

where \(p(y \mid x^{*}, D)\) is the predictive distribution:

\[
p(y \mid x^{*}, D) = \int p(y \mid x^{*}, \theta)\, p(\theta \mid D)\, \mathrm{d}\theta.
\]
The goal is to obtain something similar to Figure 4 from the paper Multiplicative Normalizing Flows for Variational Bayesian Neural Networks, comparing MC dropout, ensembles, and a Bayesian NN.
We will be using the well-known MNIST dataset, a set of 70,000 hand-written digit images, and we will generate a gradual distribution shift on the dataset by rotating the images. As such, the final plot will depict the change in the entropy of the predictive distribution (y-axis) as the rotation angle increases (x-axis). The paper above shows the result for one image. We, on the other hand, will average over multiple images to make a better comparison between models.
We’ll use rotation to simulate a smooth shift. Here’s how you can rotate a given image:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import trange, tqdm
torch.manual_seed(42)
np.random.seed(42)
from PIL import Image
from torchvision import datasets
from torch.nn.functional import softmax
from torchvision.transforms.functional import rotate
def imshow(image):
plt.imshow(image, cmap='gray', vmin=0, vmax=255)
plt.show()
def show_rotation_on_mnist_example_image():
mnist_train = datasets.MNIST('./tmp_data', train=True, download=True)
image = Image.fromarray(mnist_train.data[0].numpy())
imshow(image)
rotated_image = rotate(image, angle=90)
imshow(rotated_image)
show_rotation_on_mnist_example_image()
Let’s set up the training and testing data:
def get_mnist_data(train=True):
mnist_data = datasets.MNIST('../data', train=train, download=True)
x = mnist_data.data.reshape(-1, 28 * 28).float()
y = mnist_data.targets
return x, y
x_train, y_train = get_mnist_data(train=True)
x_test, y_test = get_mnist_data(train=False)
Now that we have the data, let’s start training neural networks.
Define non-Bayesian (Deterministic) Neural Network#
We will reuse our MLP network architecture, with different hyperparameters.
First, let’s create our point-estimate neural network, in other words a standard fully connected MLP. We will define the number of hidden layers dynamically so we can reuse the same class for different depths. We will also add a dropout flag, which will allow us to easily use the same architecture for our BNN.
class MLP(nn.Module):
def __init__(self, input_dim=1, output_dim=1, hidden_dim=10, n_hidden_layers=1, use_dropout=False):
super().__init__()
self.use_dropout = use_dropout
if use_dropout:
self.dropout = nn.Dropout(p=0.5)
self.activation = nn.Tanh()
# dynamically define architecture
self.layer_sizes = [input_dim] + n_hidden_layers * [hidden_dim] + [output_dim]
layer_list = [nn.Linear(self.layer_sizes[idx - 1], self.layer_sizes[idx]) for idx in
range(1, len(self.layer_sizes))]
self.layers = nn.ModuleList(layer_list)
def forward(self, input):
hidden = self.activation(self.layers[0](input))
for layer in self.layers[1:-1]:
hidden_temp = self.activation(layer(hidden))
if self.use_dropout:
hidden_temp = self.dropout(hidden_temp)
hidden = hidden_temp + hidden # residual connection
output_mean = self.layers[-1](hidden).squeeze()
return output_mean
Problem 1: Deterministic Neural Network#
net = MLP(input_dim=784, output_dim=10, hidden_dim=30, n_hidden_layers=3)
Training#
def train_on_mnist(net):
x_train, y_train = get_mnist_data(train=True)
optimizer = torch.optim.Adam(params=net.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()
batch_size = 250
progress_bar = trange(20)
for _ in progress_bar:
for batch_idx in range(int(x_train.shape[0] / batch_size)):
batch_low, batch_high = batch_idx * batch_size, (batch_idx + 1) * batch_size
optimizer.zero_grad()
loss = criterion(target=y_train[batch_low:batch_high], input=net(x_train[batch_low:batch_high]))
            progress_bar.set_postfix(loss=f'{loss.item():.3f}')  # CrossEntropyLoss already averages over the batch
loss.backward()
optimizer.step()
return net
net = train_on_mnist(net)
Testing#
def accuracy(targets, predictions):
return (targets == predictions).sum() / targets.shape[0]
def evaluate_accuracy_on_mnist(net):
    x_test, y_test = get_mnist_data(train=False)
    net.eval()
    with torch.no_grad():  # no gradients are needed for evaluation
        y_preds = net(x_test).argmax(1)
    acc = accuracy(y_test, y_preds)
    print("Test accuracy is %.2f%%" % (acc.item() * 100))
evaluate_accuracy_on_mnist(net)
Rotating the Images#
Now let’s compute predictive entropy on some rotated images…
First we will generate rotated versions of the test images with increasing rotation angles. We use a subset of the MNIST test set for evaluation:
def get_mnist_test_subset(n_test_images):
mnist_test = datasets.MNIST('../data', train=False, download=True)
x = mnist_test.data[:n_test_images].float()
y = mnist_test.targets[:n_test_images]
return x, y
n_test_images = 100
x_test_subset, y_test_subset = get_mnist_test_subset(n_test_images=n_test_images)
rotation_angles = [3 * i for i in range(0, 31)] # use angles from 0 to 90 degrees
rotated_images = [rotate(x_test_subset, angle).reshape(-1, 28 * 28) for angle in rotation_angles]
Evaluating on the Rotated Images#
y_preds_deterministic = [softmax(net(images), dim=-1) for images in rotated_images]
The information entropy \(H\) of a probability distribution \(p\) over a discrete random variable \(X\) with possible outcomes \(x_1, \ldots, x_N\), occurring with probabilities \(p(x_i) := p_i\), is given by:

\[
H(p) = - \sum_{i=1}^{N} p_i \log p_i.
\]

The entropy quantifies the uncertainty of a probability distribution: the more uncertain the outcome of a hypothetical draw from the distribution, the higher the entropy. It is highest when the probability mass is spread equally over all possible outcomes; for example, a uniform distribution over the ten digit classes has entropy \(\log 10 \approx 2.30\) (using the natural logarithm). In our case, the deterministic NN estimates a probability distribution over the ten MNIST digit classes for each image. For the rotated images we can thus calculate the entropy as a function of the rotation angle.
Question: How do you expect the entropy to behave with increasing rotation angle of the images? Answer in the cell below:
Implement a function for calculating the entropy according to the formula above:
def entropy(p):
# YOUR CODE HERE
raise NotImplementedError()
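For reference, one way this could look (a minimal NumPy sketch; the helper name entropy_sketch and the epsilon guard against \(\log 0\) are illustrative choices, not part of the assignment):
def entropy_sketch(p, eps=1e-12):
    # Entropy of each row of p, where each row is a probability vector
    # over the classes; eps avoids log(0) for near-one-hot outputs.
    return -np.sum(p * np.log(p + eps), axis=-1)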
Now we can calculate the accuracies and entropies for all rotated images and plot both:
def calculate_accuracies_and_entropies(y_preds):
accuracies = [accuracy(y_test_subset, p.argmax(axis=1)) for p in y_preds]
entropies = [np.mean(entropy(p.detach().numpy())) for p in y_preds]
return accuracies, entropies
def plot_accuracy_and_entropy(add_to_plot):
fig, ax = plt.subplots(figsize=(10, 5))
plt.xlim([0, 90])
plt.xlabel("Rotation Angle", fontsize=20)
add_to_plot(ax)
plt.legend()
plt.show()
def add_deterministic(ax):
accuracies, entropies = calculate_accuracies_and_entropies(y_preds_deterministic)
ax.plot(rotation_angles, accuracies, 'b--', linewidth=3, label="Accuracy, Deterministic")
ax.plot(rotation_angles, entropies, 'b-', linewidth=3, label="Entropy, Deterministic")
plot_accuracy_and_entropy(add_deterministic)
Question: What is your interpretation of the plot above: Is the predictive entropy changing? If so, how would you explain this? Answer in the cell below:
Problem 2: Monte Carlo Dropout Network#
Let’s create our Dropout Network. We keep the network depth and hidden layer size the same as for the MLP for a fair model comparison.
net_dropout = MLP(input_dim=784, output_dim=10, hidden_dim=30, n_hidden_layers=3, use_dropout=True)
Training#
net_dropout = train_on_mnist(net_dropout)
Testing#
evaluate_accuracy_on_mnist(net_dropout)
Evaluating on the Rotated Images#
Sample 100 different dropout masks and average the predictions over them (store the predictions in a list called y_preds_dropout).
n_dropout_samples = 100
net_dropout.train()  # set the model to train mode to 'activate' the dropout layers
# YOUR CODE HERE
raise NotImplementedError()
# y_preds_dropout = NotImplemented
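For orientation, a possible sketch (the name y_preds_dropout_sketch is illustrative; it keeps dropout active via train mode and averages the softmax probabilities over the sampled masks):
with torch.no_grad():
    y_preds_dropout_sketch = []
    for images in rotated_images:
        # one stochastic forward pass per sampled dropout mask
        samples = torch.stack([softmax(net_dropout(images), dim=-1)
                               for _ in range(n_dropout_samples)])
        y_preds_dropout_sketch.append(samples.mean(dim=0))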
Question: What is the best way to average over the predictions? Should you first average the network output and then apply the softmax, or the other way around? Answer in the cell below:
def add_deterministic_and_dropout(ax):
accuracies, entropies = calculate_accuracies_and_entropies(y_preds_deterministic)
ax.plot(rotation_angles, accuracies, 'b--', linewidth=3, label="Accuracy, Deterministic")
ax.plot(rotation_angles, entropies, 'b-', linewidth=3, label="Entropy, Deterministic")
accuracies, entropies = calculate_accuracies_and_entropies(y_preds_dropout)
ax.plot(rotation_angles, accuracies, 'r--', linewidth=3, label="Accuracy, MC Dropout")
ax.plot(rotation_angles, entropies, 'r-', linewidth=3, label="Entropy, MC Dropout")
plot_accuracy_and_entropy(add_deterministic_and_dropout)
Question: How does MLP compare with MC-Dropout Network? (Are there any benefits of the Bayesian approach?). Answer in the cell below:
Problem 3: Deep Ensemble#
Deep ensembles were first introduced by Lakshminarayanan et al. (2017). As the name implies, multiple point-estimate NNs are trained (an ensemble), and the final prediction is computed as an average across the models. From a Bayesian perspective, the different point estimates correspond to modes of a Bayesian posterior. This can be interpreted as approximating the posterior with a distribution parametrized as a mixture of Dirac deltas:

\[
p(\theta \mid D) \approx \sum_{i} \alpha_{\theta_{i}}\, \delta(\theta - \theta_{i}),
\]

where \(\alpha_{\theta_{i}}\) are positive constants such that their sum is equal to one.
Now let’s investigate Deep Ensemble performance. We will use the exact same network hyperparameters as for the MLP:
Define and train an ensemble of five MLPs with the same hyperparameters as above. First generate the ensemble members (store them in a list called ensemble).
ensemble_size = 5
# YOUR CODE HERE
raise NotImplementedError()
# ensemble = NotImplemented
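As a hint, a minimal sketch (assuming the MLP class and hyperparameters from Problem 1; the name ensemble_sketch is illustrative):
ensemble_sketch = [MLP(input_dim=784, output_dim=10, hidden_dim=30, n_hidden_layers=3)
                   for _ in range(ensemble_size)]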
Training#
for net in ensemble:
train_on_mnist(net)
Testing#
Evaluate the accuracy of the ensemble prediction. Hints: How do you best aggregate the multiple different predictions given by the members of the ensemble? What is the difference from the regression setting above?
# YOUR CODE HERE
raise NotImplementedError()
# y_preds = NotImplemented
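One plausible aggregation, sketched below: average the members' softmax probabilities, then take the argmax of the averaged distribution (names are illustrative, and ensemble is assumed to be filled in above):
with torch.no_grad():
    # shape (ensemble_size, n_test, 10) -> average over the members
    probs = torch.stack([softmax(member(x_test), dim=-1) for member in ensemble])
    y_preds_sketch = probs.mean(dim=0).argmax(dim=1)
print("Ensemble test accuracy is %.2f%%" % (accuracy(y_test, y_preds_sketch).item() * 100))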
Evaluating on the Rotated Images#
Again, average the predictions, but this time over the members of the ensemble (store the predictions in a list called y_preds_ensemble).
# YOUR CODE HERE
raise NotImplementedError()
# y_preds_ensemble = NotImplemented
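A compact sketch of this step, reusing the same averaging idea (illustrative name):
with torch.no_grad():
    y_preds_ensemble_sketch = [
        torch.stack([softmax(member(images), dim=-1) for member in ensemble]).mean(dim=0)
        for images in rotated_images
    ]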
def add_deep_ensemble(ax):
accuracies, entropies = calculate_accuracies_and_entropies(y_preds_deterministic)
ax.plot(rotation_angles, accuracies, 'b--', linewidth=3, label="Accuracy, Deterministic")
ax.plot(rotation_angles, entropies, 'b-', linewidth=3, label="Entropy, Deterministic")
accuracies, entropies = calculate_accuracies_and_entropies(y_preds_dropout)
ax.plot(rotation_angles, accuracies, 'r--', linewidth=3, label="Accuracy, MC Dropout")
ax.plot(rotation_angles, entropies, 'r-', linewidth=3, label="Entropy, MC Dropout")
accuracies, entropies = calculate_accuracies_and_entropies(y_preds_ensemble)
ax.plot(rotation_angles, accuracies, 'g--', linewidth=3, label="Accuracy, Deep Ensemble")
ax.plot(rotation_angles, entropies, 'g-', linewidth=3, label="Entropy, Deep Ensemble")
plot_accuracy_and_entropy(add_deep_ensemble)
Question: Are there any differences in the performance? Explain why you see or don’t see any differences. Answer in the cell below:
[EXTRA CREDIT] Problem 4: Bayesian Neural Network#
First, install the pyro-ppl package:
%pip install pyro-ppl
import pyro
import pyro.distributions as dist
from pyro.nn import PyroModule, PyroSample
from pyro.infer import Predictive
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoDiagonalNormal
from pyro.distributions import Normal, Categorical
from torch.nn.functional import softmax
from tqdm.auto import trange, tqdm
Implement a Bayesian Neural Network for classifying MNIST digits. For background on deep Bayesian networks, see the lecture notebook on Uncertainty Quantification.
As a backbone use the MLP architecture introduced in the beginning of the notebook. However, because we will implement a custom guide(), define every layer explicitly.
class My_MLP(nn.Module):
'''
    Implement an MLP with 3 hidden layers, Tanh activations, and no dropout or residual connections
'''
def __init__(self, in_dim=784, out_dim=10, hid_dim=200):
super().__init__()
assert in_dim > 0
assert out_dim > 0
assert hid_dim > 0
# Define the activation
# YOUR CODE HERE
raise NotImplementedError()
# self.act =
# Define the 3 hidden layers:
# YOUR CODE HERE
raise NotImplementedError()
# self.fc1 =
# self.fc2 =
# self.fc3 =
# self.out =
def forward(self, x):
# YOUR CODE HERE
raise NotImplementedError()
return pred
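For reference, one possible realization of this architecture (a sketch using the layer names suggested in the stub and hid_dim=200 from the signature):
class My_MLP_sketch(nn.Module):
    def __init__(self, in_dim=784, out_dim=10, hid_dim=200):
        super().__init__()
        self.act = nn.Tanh()
        self.fc1 = nn.Linear(in_dim, hid_dim)
        self.fc2 = nn.Linear(hid_dim, hid_dim)
        self.fc3 = nn.Linear(hid_dim, hid_dim)
        self.out = nn.Linear(hid_dim, out_dim)  # logits over the 10 classes

    def forward(self, x):
        h = self.act(self.fc1(x))
        h = self.act(self.fc2(h))
        h = self.act(self.fc3(h))
        pred = self.out(h)
        return pred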
Initialize the network. You will have to access its layers in your model and guide functions:
net = My_MLP()
# confirm your layer names
for name, _ in net.named_parameters():
print(name)
Define the model:
Probabilistic models in Pyro are specified as model() functions, which define how the output data is generated. Within the model() function, the Pyro primitive random_module() first converts the parameters of our NN into random variables that have prior probability distributions. Second, a pyro.sample statement declares that the output of the network is categorical, while pyro.plate allows us to vectorize this function for computational efficiency.
Hint: remember we are doing a classification instead of regression!
You can ‘cheat’ a little: to speed up training and further limit the number of parameters we need to optimize, implement a BNN where only the last layer is Bayesian!
def model(x_data, y_data):
# YOUR CODE HERE
raise NotImplementedError()
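For orientation, a sketch of a last-layer-Bayesian model along these lines (illustrative; it assumes the layer name out from the stub above and uses random_module as described):
def model_sketch(x_data, y_data):
    # standard-Normal priors on the output layer only (last-layer BNN)
    out_w_prior = Normal(loc=torch.zeros_like(net.out.weight),
                         scale=torch.ones_like(net.out.weight))
    out_b_prior = Normal(loc=torch.zeros_like(net.out.bias),
                         scale=torch.ones_like(net.out.bias))
    priors = {'out.weight': out_w_prior, 'out.bias': out_b_prior}
    # lift the deterministic net into a distribution over networks
    lifted_module = pyro.random_module("module", net, priors)
    lifted_net = lifted_module()  # sample one concrete network from the priors
    logits = lifted_net(x_data)
    with pyro.plate("data", x_data.shape[0]):
        pyro.sample("obs", Categorical(logits=logits), obs=y_data)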
Implement the guide(), the variational distribution:
The guide allows us to initialize a well-behaved distribution which we can later optimize to approximate the true posterior.
softplus = torch.nn.Softplus()
def my_guide(x_data, y_data):
# YOUR CODE HERE
raise NotImplementedError()
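A possible guide sketch (illustrative parameter names; softplus keeps the scales positive, and the sample sites mirror the model via random_module):
def my_guide_sketch(x_data, y_data):
    # variational parameters for the output layer's weights and biases
    out_w_mu = pyro.param("out_w_mu", 0.1 * torch.randn_like(net.out.weight))
    out_w_sigma = softplus(pyro.param("out_w_sigma_raw", -3.0 + 0.1 * torch.randn_like(net.out.weight)))
    out_b_mu = pyro.param("out_b_mu", 0.1 * torch.randn_like(net.out.bias))
    out_b_sigma = softplus(pyro.param("out_b_sigma_raw", -3.0 + 0.1 * torch.randn_like(net.out.bias)))
    dists = {'out.weight': Normal(loc=out_w_mu, scale=out_w_sigma),
             'out.bias': Normal(loc=out_b_mu, scale=out_b_sigma)}
    lifted_module = pyro.random_module("module", net, dists)
    return lifted_module()  # sample a network from the variational distribution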
Initialize the stochastic variational inference (SVI):
adam = pyro.optim.Adam({"lr": 1e-3})
# YOUR CODE HERE
raise NotImplementedError()
# svi = NotImplemented
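This is the standard wiring; a minimal sketch, assuming model and my_guide are implemented above:
svi = SVI(model, my_guide, adam, loss=Trace_ELBO())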
Training#
pyro.clear_param_store()
batch_size = 250
bar = trange(30)
for epoch in bar:
for batch_idx in range(int(x_train.shape[0] / batch_size)):
batch_low, batch_high = batch_idx * batch_size, (batch_idx+1) * batch_size
loss = svi.step(x_train[batch_low:batch_high], y_train[batch_low:batch_high])
bar.set_postfix(loss=f'{loss / batch_size:.3f}')
Testing#
Use the learned guide() function to do predictions. Why? Because the model() function knows the priors for the weights and biases, not the learned posterior. The guide() contains the approximate posterior distributions of the parameter values, which we want to use to make the predictions.
num_samples = 10
# YOUR CODE HERE
raise NotImplementedError()
# y_preds = NotImplemented
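One way this could look (a sketch; the helper name bnn_predict_sketch is illustrative, and it assumes a guide that ignores its arguments at prediction time):
def bnn_predict_sketch(x, n_samples):
    # each call to the guide returns a network sampled from the approximate posterior
    sampled_nets = [my_guide(None, None) for _ in range(n_samples)]
    probs = torch.stack([softmax(m(x), dim=-1) for m in sampled_nets])
    return probs.mean(dim=0)

y_preds_sketch = bnn_predict_sketch(x_test, num_samples).argmax(dim=1)
print("BNN test accuracy is %.2f%%" % (accuracy(y_test, y_preds_sketch).item() * 100))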
Evaluating on Rotated Images#
Store the predictions in a list called y_preds_bnn
num_samples = 50
# YOUR CODE HERE
raise NotImplementedError()
# y_preds_bnn = NotImplemented
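A short sketch, reusing the hypothetical helper from the previous cell:
y_preds_bnn_sketch = [bnn_predict_sketch(images, num_samples) for images in rotated_images]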
Show the entropies for all four models:
# add the computed values for BNN
def add_bnn(ax):
accuracies, entropies = calculate_accuracies_and_entropies(y_preds_deterministic)
ax.plot(rotation_angles, accuracies, 'b--', linewidth=3, label="Accuracy, Deterministic")
ax.plot(rotation_angles, entropies, 'b-', linewidth=3, label="Entropy, Deterministic")
accuracies, entropies = calculate_accuracies_and_entropies(y_preds_dropout)
ax.plot(rotation_angles, accuracies, 'r--', linewidth=3, label="Accuracy, MC Dropout")
ax.plot(rotation_angles, entropies, 'r-', linewidth=3, label="Entropy, MC Dropout")
accuracies, entropies = calculate_accuracies_and_entropies(y_preds_ensemble)
ax.plot(rotation_angles, accuracies, 'g--', linewidth=3, label="Accuracy, Deep Ensemble")
ax.plot(rotation_angles, entropies, 'g-', linewidth=3, label="Entropy, Deep Ensemble")
accuracies, entropies = calculate_accuracies_and_entropies(y_preds_bnn)
ax.plot(rotation_angles, accuracies, 'y--', linewidth=3, label="Accuracy, BNN")
ax.plot(rotation_angles, entropies, 'y-', linewidth=3, label="Entropy, BNN")
plot_accuracy_and_entropy(add_bnn)
Question: Which method is the best at detecting the distribution shift? How can you interpret this? Answer in the cell below:
Acknowledgments#
Initial version: Mark Neubauer
© Copyright 2024