Copyright 2021-2023 Lawrence Livermore National Security, LLC and other MuyGPyS Project Developers. See the top-level COPYRIGHT file for details.

SPDX-License-Identifier: MIT

Deep Kernels with MuyGPs in PyTorch Tutorial

In this tutorial, we outline how to construct a simple deep kernel model using the PyTorch implementation of MuyGPs.

We use the MNIST classification problem as a benchmark. We will use the deep kernel MuyGPs model to classify images of handwritten digits between 0 and 9. In order to reduce the runtime of the training loop, we will use a fully-connected architecture, meaning we will have to vectorize each image prior to training. We download the training and testing data using the torchvision.datasets API.

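First, we import the necessary dependencies. We also force MuyGPyS to use the "torch" backend with 32-bit floating-point values by setting the MUYGPYS_BACKEND and MUYGPYS_FTYPE environment variables to "torch" and "32", respectively. In the notebook we do this with the %env magic.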

[2]:
%env MUYGPYS_BACKEND=torch
%env MUYGPYS_FTYPE=32
env: MUYGPYS_BACKEND=torch
env: MUYGPYS_FTYPE=32
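The %env magic only works in a notebook. In a plain Python script, a minimal equivalent is to set the same variables through os.environ. We assume that MuyGPyS reads them when it is first imported, so they must be set before any MuyGPyS import.

```python
import os

# Select the PyTorch backend and 32-bit floats for MuyGPyS.
# Assumption: MuyGPyS reads these variables when it is first imported,
# so they must be set before any `import MuyGPyS...` statement.
os.environ["MUYGPYS_BACKEND"] = "torch"
os.environ["MUYGPYS_FTYPE"] = "32"

# ... import MuyGPyS modules only after this point ...
```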
[3]:
from MuyGPyS.gp.distortion import l2

import numpy as np
import torch
import torchvision
import os
from torch.nn.functional import one_hot
root = './data'
if not os.path.exists(root):
    os.mkdir(root)

We use torch’s utilities to download MNIST and transform it into an appropriately normalized tensor.

[4]:
trans = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5,),(1.0,)),
    ]
)
train_set = torchvision.datasets.MNIST(
    root=root, train=True, transform=trans, download=True
)
test_set = torchvision.datasets.MNIST(
    root=root, train=False, transform=trans, download=True
)
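For reference, ToTensor converts each 8-bit image to floats in [0, 1] by dividing by 255. Normalize((0.5,), (1.0,)) then computes (x - 0.5) / 1.0 per channel, so pixel values end up in [-0.5, 0.5]. The NumPy sketch below reproduces that arithmetic on a random stand-in image. It does not call torchvision.

```python
import numpy as np

# Random stand-in for one 28x28 grayscale MNIST image with 8-bit pixels.
rng = np.random.default_rng(0)
raw = rng.integers(0, 256, size=(28, 28), dtype=np.uint8)
raw[0, 0], raw[0, 1] = 0, 255  # make sure both pixel extremes occur

# ToTensor: scale to [0, 1] and add a leading channel axis -> shape (1, 28, 28).
x = raw.astype(np.float32)[None, :, :] / 255.0

# Normalize(mean=(0.5,), std=(1.0,)): (x - mean) / std for the single channel.
mean, std = 0.5, 1.0
normalized = (x - mean) / std

print(normalized.shape, normalized.min(), normalized.max())
```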

MNIST is a popular benchmark dataset of handwritten digits 0-9. Each digit is a 28x28 grayscale image, giving 28 × 28 = 784 pixel features. To reduce runtime, we flatten each image into a 784-dimensional vector and use these vectors as our features.

[5]:
num_classes = 10
num_train_samples = 60000
num_test_samples = 10000
num_pixels = 784

We collect all 60,000 training samples and all 10,000 test samples, vectorize each image, and one-hot encode its class label.

[6]:
train_features = torch.zeros((num_train_samples,num_pixels))
train_responses = torch.zeros((num_train_samples,num_classes))

for i in range(num_train_samples):
    train_features[i,:] = train_set[i][0].flatten()
    train_responses[i,:] = one_hot(
        torch.tensor(train_set[i][1]).to(torch.int64),
        num_classes=num_classes,
    )

test_features = torch.zeros((num_test_samples,num_pixels))
test_responses = torch.zeros((num_test_samples,num_classes))

for i in range(num_test_samples):
    test_features[i,:] = test_set[i][0].flatten()
    test_responses[i,:] = one_hot(
        torch.tensor(test_set[i][1]).to(torch.int64),
        num_classes=num_classes,
    )
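The two loops above can also be written without per-sample Python iteration. The sketch below uses random NumPy stand-ins rather than the real train_set and shows the two operations: reshaping each 28x28 image into a 784-vector, and one-hot encoding integer labels by indexing rows of an identity matrix.

```python
import numpy as np

num_classes = 10
num_pixels = 28 * 28  # 784 features per vectorized image

# Random stand-ins for a small batch of normalized images and their labels.
rng = np.random.default_rng(0)
images = rng.uniform(-0.5, 0.5, size=(5, 1, 28, 28)).astype(np.float32)
labels = np.array([3, 0, 9, 3, 7])

# Vectorize: flatten each (1, 28, 28) image into a row of 784 features.
features = images.reshape(len(images), -1)

# One-hot encode: row k of the identity matrix is the indicator vector for class k.
responses = np.eye(num_classes, dtype=np.float32)[labels]

print(features.shape, responses.shape)
```

In PyTorch, the same steps are `x.flatten(start_dim=1)` and `one_hot(labels, num_classes=num_classes)`.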

We set up our nearest neighbor lookup structure using the NN_Wrapper data structure in MuyGPyS. We choose the "exact" algorithm, so queries return the exact 30 nearest neighbors rather than an approximation. We also seed NumPy's random number generator for reproducibility.

[7]:
from torch import nn
import random
from torch.optim.lr_scheduler import ExponentialLR
torch.autograd.set_detect_anomaly(True)
np.random.seed(0)
test_count, _ = test_features.shape
train_count, _ = train_features.shape


from MuyGPyS.neighbors import NN_Wrapper
nn_count = 30
nbrs_lookup = NN_Wrapper(train_features, nn_count, nn_method="exact")
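Conceptually, exact k-nearest-neighbor search under the Euclidean (l2) metric amounts to the brute-force computation sketched below in NumPy. This only illustrates what "exact" means. It is not NN_Wrapper's implementation, which may use a tree-based or otherwise optimized search.

```python
import numpy as np

def exact_knn(train, queries, k):
    """Brute-force exact k nearest neighbors under squared Euclidean distance."""
    # Pairwise squared distances, shape (num_queries, num_train).
    d2 = (
        np.sum(queries**2, axis=1)[:, None]
        - 2.0 * queries @ train.T
        + np.sum(train**2, axis=1)[None, :]
    )
    # argpartition finds the k smallest per row; argsort then orders them by distance.
    idx = np.argpartition(d2, kth=k - 1, axis=1)[:, :k]
    rows = np.arange(len(queries))[:, None]
    order = np.argsort(d2[rows, idx], axis=1)
    nn_idx = idx[rows, order]
    return nn_idx, d2[rows, nn_idx]

rng = np.random.default_rng(0)
train = rng.normal(size=(200, 16))
queries = rng.normal(size=(4, 16))
nn_idx, nn_d2 = exact_knn(train, queries, k=30)
print(nn_idx.shape)  # (4, 30)
```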

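We then sample a training batch of 500 elements and record their indices and those of their nearest neighbors. We also collect the features and targets of the batch elements along with the targets of their neighbors. If a GPU is available, we move the training and test tensors to it.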

[8]:
# We will make use of batching in our hyperparameter training
from MuyGPyS.optimize.batch import sample_batch

batch_count = 500
batch_indices, batch_nn_indices = sample_batch(
    nbrs_lookup, batch_count, train_count
)

# Move the data to the GPU (if present) before slicing out the batch,
# so that the batch tensors live on the same device as the model.
if torch.cuda.is_available():
    train_features = train_features.cuda()
    train_responses = train_responses.cuda()
    test_features = test_features.cuda()
    test_responses = test_responses.cuda()

batch_features = train_features[batch_indices, :]
batch_targets = train_responses[batch_indices, :]
batch_nn_targets = train_responses[batch_nn_indices, :]
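The sample_batch call draws batch_count training indices and looks up each one's neighbors in the training set. Below is a minimal sketch of that idea using a hypothetical helper, sample_batch_sketch. It assumes that indices are drawn uniformly without replacement and that each batch point is excluded from its own neighbor set, as leave-one-out style training requires. These behaviors are our assumptions, not a description of sample_batch itself.

```python
import numpy as np

def sample_batch_sketch(train_features, batch_count, nn_count, rng):
    """Hypothetical stand-in for sample_batch: random batch + exact neighbors, self excluded."""
    train_count = train_features.shape[0]
    batch_indices = rng.choice(train_count, size=batch_count, replace=False)
    queries = train_features[batch_indices]
    # Squared Euclidean distances from each batch point to every training point.
    d2 = ((queries[:, None, :] - train_features[None, :, :]) ** 2).sum(axis=-1)
    # Exclude each batch point from its own neighbor list.
    d2[np.arange(batch_count), batch_indices] = np.inf
    batch_nn_indices = np.argsort(d2, axis=1)[:, :nn_count]
    return batch_indices, batch_nn_indices

rng = np.random.default_rng(0)
train_features = rng.normal(size=(300, 8))
train_responses = np.eye(10)[rng.integers(0, 10, size=300)]

batch_indices, batch_nn_indices = sample_batch_sketch(train_features, 50, 30, rng)
# Fancy indexing gathers targets the same way as in the tutorial cell.
batch_targets = train_responses[batch_indices, :]        # shape (50, 10)
batch_nn_targets = train_responses[batch_nn_indices, :]  # shape (50, 30, 10)
```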

We now define a stochastic variational deep kernel MuyGPs class. This class composes a dense neural network embedding with a MuyGPyS.torch.muygps_layer Gaussian process layer. Presently, this layer supports only the Matérn kernel, with the smoothness parameter nu restricted to the special values 0.5, 1.5, 2.5, or ∞. For these values the Matérn kernel reduces to a closed form built from exponentials and polynomials. General values require modified Bessel functions of the second kind, which torch does not implement. Future versions of the library will also support other kernel types.

[9]:
from MuyGPyS.torch import MuyGPs_layer
print('Building Stochastic Variational Deep Kernel MuyGPs model')

class SVDKMuyGPs(nn.Module):
    def __init__(
        self,
        muygps_model,
        batch_indices,
        batch_nn_indices,
        batch_targets,
        batch_nn_targets,
    ):
        super().__init__()
        self.embedding = nn.Sequential(
            nn.Linear(784,400),
            nn.ReLU(),
            nn.Linear(400,200),
            nn.ReLU(),
            nn.Linear(200,100),
        )
        self.batch_indices = batch_indices
        self.batch_nn_indices = batch_nn_indices
        self.batch_targets = batch_targets
        self.batch_nn_targets = batch_nn_targets
        self.GP_layer = MuyGPs_layer(
            muygps_model,
            batch_indices,
            batch_nn_indices,
            batch_targets,
            batch_nn_targets,
        )

    def forward(self,x):
        predictions = self.embedding(x)
        predictions, variances = self.GP_layer(predictions)
        return predictions, variances
Building Stochastic Variational Deep Kernel MuyGPs model
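The closed forms behind the allowed smoothness values can be written down directly. The NumPy sketch below uses the standard textbook parameterization in terms of distance r and length scale ℓ. MuyGPyS's internal scaling conventions may differ in detail.

```python
import numpy as np

def matern(r, nu, length_scale=1.0):
    """Closed-form Matern correlation for the special smoothness values."""
    s = np.asarray(r, dtype=float) / length_scale
    if nu == 0.5:
        return np.exp(-s)
    if nu == 1.5:
        return (1.0 + np.sqrt(3.0) * s) * np.exp(-np.sqrt(3.0) * s)
    if nu == 2.5:
        return (1.0 + np.sqrt(5.0) * s + 5.0 * s**2 / 3.0) * np.exp(-np.sqrt(5.0) * s)
    if nu == np.inf:
        return np.exp(-0.5 * s**2)  # squared-exponential (RBF) limit
    raise ValueError("closed form implemented only for nu in {0.5, 1.5, 2.5, inf}")

r = np.linspace(0.0, 3.0, 7)
for nu in (0.5, 1.5, 2.5, np.inf):
    print(nu, np.round(matern(r, nu), 3))
```

Larger nu gives smoother functions. The tutorial fixes nu = 0.5, the roughest of the four options.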

Training a Deep Kernel MuyGPs Model

We instantiate an SVDKMuyGPs model with initial-guess hyperparameters. We fix the Matérn smoothness parameter nu at 0.5, set the length scale to 1.0, and use a Gaussian homoscedastic noise prior with variance 1e-6.

[10]:
from MuyGPyS.gp import MuyGPS
from MuyGPyS.gp.noise import HomoscedasticNoise
from MuyGPyS.gp.hyperparameter import ScalarHyperparameter
from MuyGPyS.gp.kernels import Matern
from MuyGPyS.gp.distortion import IsotropicDistortion

model_nu = 0.5
model_length_scale = 1.0
measurement_eps = 1e-6

muygps_model = MuyGPS(
    kernel=Matern(
        nu=ScalarHyperparameter(model_nu),
        metric=IsotropicDistortion(l2,
            length_scale=ScalarHyperparameter(model_length_scale)
        ),
    ),
    eps=HomoscedasticNoise(measurement_eps),
)


model = SVDKMuyGPs(
    muygps_model = muygps_model,
    batch_indices=batch_indices,
    batch_nn_indices=batch_nn_indices,
    batch_targets=batch_targets,
    batch_nn_targets=batch_nn_targets,
)
if torch.cuda.is_available():
    model = model.cuda()
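MuyGPs predicts each batch point from its nearest neighbors only, using the usual Gaussian process conditional. The noise variance measurement_eps is added to the diagonal of the neighbor-neighbor kernel matrix before solving. The NumPy sketch below shows this with a nu = 0.5 Matérn kernel and unit variance scale. It omits any variance-scaling hyperparameter and follows the textbook formula, not MuyGPyS's exact code path.

```python
import numpy as np

def local_gp_predict(x, nn_x, nn_y, length_scale=1.0, eps=1e-6):
    """Posterior mean/variance at x from its neighbors (Matern nu=0.5, unit variance)."""
    def k(a, b):
        r = np.linalg.norm(a[:, None, :] - b[None, :, :], axis=-1)
        return np.exp(-r / length_scale)

    K_nn = k(nn_x, nn_x) + eps * np.eye(len(nn_x))  # homoscedastic noise on the diagonal
    K_xn = k(x[None, :], nn_x)                       # shape (1, nn_count)
    # Solve a linear system instead of explicitly inverting K_nn.
    weights = np.linalg.solve(K_nn, K_xn.T).T        # shape (1, nn_count)
    mean = weights @ nn_y                            # shape (1, num_classes)
    var = 1.0 - (weights @ K_xn.T).item()            # k(x, x) = 1 for this kernel
    return mean.ravel(), var

rng = np.random.default_rng(0)
nn_x = rng.normal(size=(30, 5))
nn_y = np.eye(10)[rng.integers(0, 10, size=30)]  # one-hot neighbor targets
x = rng.normal(size=5)

mean, var = local_gp_predict(x, nn_x, nn_y)
print(mean.shape, var)
```

With a tiny eps, predicting at one of the neighbors nearly reproduces that neighbor's target with near-zero variance. A larger eps smooths the prediction instead.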

We use the Adam optimizer for 10 training iterations, with an initial learning rate of 1e-2 that decays exponentially by a factor of 0.97 per iteration.

[11]:
training_iterations = 10
optimizer = torch.optim.Adam(
    [{'params': model.parameters()}], lr=1e-2
)
scheduler = ExponentialLR(optimizer, gamma=0.97)
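ExponentialLR multiplies the learning rate by gamma each time scheduler.step() is called. The training loop calls it once per iteration, so iteration i (counting from 0) runs with learning rate 1e-2 · 0.97^i. A quick check of the resulting schedule:

```python
initial_lr = 1e-2
gamma = 0.97
training_iterations = 10

# Learning rate used during iteration i (0-based): initial_lr * gamma**i.
lrs = [initial_lr * gamma**i for i in range(training_iterations)]
for i, lr in enumerate(lrs):
    print(f"iter {i + 1}: lr = {lr:.6f}")
```

With only 10 iterations the decay is mild. The last iteration still runs at roughly 0.76 times the initial rate.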

We use cross-entropy loss, as it typically performs well on classification problems. Other PyTorch losses, such as those commented out in the cell below, can be substituted.

[12]:
ce_loss = nn.CrossEntropyLoss()
# mse_loss = nn.MSELoss()
# l1_loss = nn.L1Loss()
# bce_loss = nn.BCELoss()
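Here the targets passed to nn.CrossEntropyLoss are one-hot rows rather than integer class labels. PyTorch 1.10 and later accept such class-probability targets. With the default mean reduction, the loss is the average over samples of -Σ_c y_c log softmax(z)_c, where the GP posterior means z act as unnormalized logits. The NumPy sketch below reproduces that computation. As a sanity check, uniform logits give log(10) ≈ 2.303, so the training losses printed below (between about 1.36 and 1.52) are under this chance-level value.

```python
import numpy as np

def cross_entropy_prob_targets(logits, targets):
    """Mean over samples of -sum_c targets[c] * log_softmax(logits)[c]."""
    shifted = logits - logits.max(axis=1, keepdims=True)  # for numerical stability
    log_softmax = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return float(-(targets * log_softmax).sum(axis=1).mean())

num_classes = 10
targets = np.eye(num_classes)[[3, 7]]  # two one-hot targets
uniform = np.zeros((2, num_classes))   # uninformative logits
confident = 5.0 * targets              # logits favoring the true class

print(cross_entropy_prob_targets(uniform, targets))  # equals log(10)
print(cross_entropy_prob_targets(confident, targets))
```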

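We construct a standard PyTorch training loop function. After each optimization step, it recomputes the batch's nearest neighbors in the learned embedding space and stores the updated neighbor indices and targets on the model. Once training finishes, it rebuilds the neighbor lookup with the final embedding so that it can be used for prediction.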

[13]:
def train(nbrs_lookup, update_frequency=1):
    for i in range(training_iterations):
        model.train()
        optimizer.zero_grad()
        # Embed all training points; the GP layer then predicts the batch
        # points from their nearest neighbors in embedding space.
        predictions,variances = model(train_features)
        loss = ce_loss(predictions,batch_targets)
        loss.backward()
        optimizer.step()
        scheduler.step()
        if np.mod(i, update_frequency) == 0:
            print(f"Iter {i + 1}/{training_iterations} - Loss: {loss.item()}")
            # Recompute exact nearest neighbors in the updated embedding space.
            # .cpu() is needed before .numpy() when the model lives on a GPU.
            model.eval()
            nbrs_lookup = NN_Wrapper(
                model.embedding(train_features).detach().cpu().numpy(),
                nn_count, nn_method="exact"
            )
            batch_nn_indices,_ = nbrs_lookup._get_nns(
                model.embedding(batch_features).detach().cpu().numpy(),
                nn_count=nn_count,
            )
            batch_nn_targets = train_responses[batch_nn_indices, :]
            model.batch_nn_indices = batch_nn_indices
            model.batch_nn_targets = batch_nn_targets
        torch.cuda.empty_cache()
    # Rebuild the neighbor lookup with the final embedding for prediction.
    nbrs_lookup = NN_Wrapper(
        model.embedding(train_features).detach().cpu().numpy(),
        nn_count,
        nn_method="exact",
    )
    batch_nn_indices,_ = nbrs_lookup._get_nns(
        model.embedding(batch_features).detach().cpu().numpy(),
        nn_count=nn_count,
    )
    batch_nn_targets = train_responses[batch_nn_indices, :]
    model.batch_nn_indices = batch_nn_indices
    model.batch_nn_targets = batch_nn_targets
    return nbrs_lookup, model

Finally, we execute the training function and switch the trained model to evaluation mode.

[14]:
nbrs_lookup, model_trained = train(nbrs_lookup)
model_trained.eval()
Iter 1/10 - Loss: 1.5154695510864258
Iter 2/10 - Loss: 1.4770054817199707
Iter 3/10 - Loss: 1.441899299621582
Iter 4/10 - Loss: 1.4263132810592651
Iter 5/10 - Loss: 1.4199936389923096
Iter 6/10 - Loss: 1.40185546875
Iter 7/10 - Loss: 1.3923559188842773
Iter 8/10 - Loss: 1.380581259727478
Iter 9/10 - Loss: 1.372625470161438
Iter 10/10 - Loss: 1.3624086380004883
[14]:
SVDKMuyGPs(
  (embedding): Sequential(
    (0): Linear(in_features=784, out_features=400, bias=True)
    (1): ReLU()
    (2): Linear(in_features=400, out_features=200, bias=True)
    (3): ReLU()
    (4): Linear(in_features=200, out_features=100, bias=True)
  )
  (GP_layer): MuyGPs_layer()
)
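The printed architecture lets us count the embedding's trainable parameters. Each Linear(in, out) layer has an in × out weight matrix plus out biases, and ReLU has no parameters. The arithmetic below covers the embedding only. Any parameters held by the GP layer are not counted.

```python
# (in_features, out_features) for the three Linear layers in the embedding.
layers = [(784, 400), (400, 200), (200, 100)]

# Each Linear layer has in*out weights plus out biases; ReLU has no parameters.
counts = [n_in * n_out + n_out for n_in, n_out in layers]
print(counts, sum(counts))  # [314000, 80200, 20100] 414300
```

In PyTorch, `sum(p.numel() for p in model_trained.embedding.parameters())` should report the same total.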

We then predict the test responses with the trained model and report the resulting classification accuracy.

[15]:
from MuyGPyS.examples.muygps_torch import predict_model
predictions, variances = predict_model(
    model=model_trained,
    test_features=test_features,
    train_features=train_features,
    train_responses=train_responses,
    nbrs_lookup=nbrs_lookup,
    nn_count=nn_count,
)
print("MNIST Prediction Accuracy Using Low-Level Torch Implementation:")
print(
    (
        torch.sum(
            torch.argmax(predictions,dim=1) == torch.argmax(test_responses,dim=1)
        ) / num_test_samples
    ).cpu().numpy()
)
MNIST Prediction Accuracy Using Low-Level Torch Implementation:
0.9398
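The accuracy expression takes the argmax of each predicted row and of each one-hot test row, then averages the matches. The GP predictions are posterior means rather than probabilities, so only their argmax matters here. A small NumPy version of the same computation:

```python
import numpy as np

def accuracy(predictions, one_hot_targets):
    """Fraction of rows whose predicted class (argmax) matches the true class."""
    predicted = np.argmax(predictions, axis=1)
    true = np.argmax(one_hot_targets, axis=1)
    return float(np.mean(predicted == true))

targets = np.eye(10)[[1, 4, 4, 9]]
predictions = np.array(targets, copy=True)
predictions[3] = np.eye(10)[2]  # one deliberate mistake
print(accuracy(predictions, targets))  # 0.75
```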

Training a Deep Kernel MuyGPs Model Using Our Example API Function

Similar to the one-line regression API from our other tutorials, we provide a one-line deep kernel MuyGPs training API. The snippet below performs the same work as above with a single function call.

[16]:
#Import high-level API function train_deep_kernel_muygps
from MuyGPyS.examples.muygps_torch import train_deep_kernel_muygps

model_nu = 0.5
model_length_scale = 1.0
measurement_eps = 1e-6

muygps_model = MuyGPS(
    kernel=Matern(
        nu=ScalarHyperparameter(model_nu),
        metric=IsotropicDistortion(l2,
            length_scale=ScalarHyperparameter(model_length_scale)
        ),
    ),
    eps=HomoscedasticNoise(measurement_eps),
)

# Build a fresh model; it is trained below with cross-entropy loss ("ce")
model = SVDKMuyGPs(
    muygps_model=muygps_model,
    batch_indices=batch_indices,
    batch_nn_indices=batch_nn_indices,
    batch_targets=batch_targets,
    batch_nn_targets=batch_nn_targets,
)

nbrs_lookup, model_trained = train_deep_kernel_muygps(
    model=model,
    train_features=train_features,
    train_responses=train_responses,
    batch_indices=batch_indices,
    nbrs_lookup=nbrs_lookup,
    training_iterations=10,
    optimizer_method=torch.optim.Adam,
    learning_rate=1e-2,
    scheduler_decay=0.97,
    loss_function="ce",
    update_frequency=1,
    verbose=True,
)

model_trained.eval()
Iter 1/10 - Loss: 1.5164387226
Iter 2/10 - Loss: 1.4815564156
Iter 3/10 - Loss: 1.4386521578
Iter 4/10 - Loss: 1.4218910933
Iter 5/10 - Loss: 1.4173457623
Iter 6/10 - Loss: 1.4027299881
Iter 7/10 - Loss: 1.3904489279
Iter 8/10 - Loss: 1.3764212132
Iter 9/10 - Loss: 1.3704509735
Iter 10/10 - Loss: 1.3603465557
[16]:
SVDKMuyGPs(
  (embedding): Sequential(
    (0): Linear(in_features=784, out_features=400, bias=True)
    (1): ReLU()
    (2): Linear(in_features=400, out_features=200, bias=True)
    (3): ReLU()
    (4): Linear(in_features=200, out_features=100, bias=True)
  )
  (GP_layer): MuyGPs_layer()
)

We similarly report our prediction performance on the test responses using this trained model.

[17]:
from MuyGPyS.examples.muygps_torch import predict_model
predictions,variances = predict_model(
    model=model_trained,
    test_features=test_features,
    train_features=train_features,
    train_responses=train_responses,
    nbrs_lookup=nbrs_lookup,
    nn_count=nn_count,
)

print("MNIST Prediction Accuracy Using High-Level Training API:")
print(
    (
        torch.sum(
            torch.argmax(predictions,dim=1) == torch.argmax(test_responses,dim=1)
        ) / num_test_samples
    ).cpu().numpy()
)
MNIST Prediction Accuracy Using High-Level Training API:
0.9348

We note that this is mediocre performance for MNIST. To reduce notebook runtime, we have used a simple fully-connected neural network to construct the Gaussian process kernel. To achieve results closer to the state of the art (near 100% accuracy), we recommend more complex architectures that integrate convolutional layers into the embedding.