Copyright 2021-2023 Lawrence Livermore National Security, LLC and other MuyGPyS Project Developers. See the top-level COPYRIGHT file for details.

SPDX-License-Identifier: MIT

Deep Kernels with MuyGPs in PyTorch Tutorial

In this tutorial, we outline how to construct a simple deep kernel model using the PyTorch implementation of MuyGPs.

We use the MNIST classification problem as a benchmark. We will use the deep kernel MuyGPs model to classify images of handwritten digits 0 through 9. In order to reduce the runtime of the training loop, we will use a fully-connected architecture, meaning we will have to vectorize each image prior to training. We download the training and testing data using the torchvision.datasets API.

First, we will import necessary dependencies. We also force MuyGPyS to use the "torch" backend. This can also be done by setting the MUYGPYS_BACKEND environment variable to "torch".

[1]:
from MuyGPyS import config
config.update("muygpys_backend","torch")

import numpy as np
import torch
import torchvision
import os
from torch.nn.functional import one_hot
root = './data'
if not os.path.exists(root):
    os.mkdir(root)
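
Alternatively, the backend can be selected through the environment rather than config.update. The following is a minimal sketch, assuming the MUYGPYS_BACKEND variable is read when MuyGPyS is first imported; it is equivalent to running export MUYGPYS_BACKEND=torch in the shell before launching Python.

import os

# Hypothetical alternative: select the torch backend via the environment.
# This must happen before MuyGPyS is first imported.
os.environ["MUYGPYS_BACKEND"] = "torch"

from MuyGPyS import config  # the torch backend is now active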

We use torch’s utilities to download MNIST and transform it into an appropriately normalized tensor.

[2]:
trans = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5,),(1.0,)),
    ]
)
train_set = torchvision.datasets.MNIST(
    root=root, train=True, transform=trans, download=True
)
test_set = torchvision.datasets.MNIST(
    root=root, train=False, transform=trans, download=True
)
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz
Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz
Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz
Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw

MNIST is a popular benchmark dataset of hand-written digits, 0-9. Each digit is a 28x28 pixel image, for a total of 784 pixel features. In the interest of reducing runtime, we will use the flattened (vectorized) images as our features.

[3]:
num_classes = 10
num_train_samples = 60000
num_test_samples = 10000
num_pixels = 784

We will collect 60,000 training samples and 10,000 test samples. We vectorize the images and one-hot encode the class labels.

[4]:
train_features = torch.zeros((num_train_samples,num_pixels))
train_responses = torch.zeros((num_train_samples,num_classes))

for i in range(num_train_samples):
    train_features[i,:] = train_set[i][0].flatten()
    train_responses[i,:] = one_hot(
        torch.tensor(train_set[i][1]).to(torch.int64),
        num_classes=num_classes,
    )

test_features = torch.zeros((num_test_samples,num_pixels))
test_responses = torch.zeros((num_test_samples,num_classes))

for i in range(num_test_samples):
    test_features[i,:] = test_set[i][0].flatten()
    test_responses[i,:] = one_hot(
        torch.tensor(test_set[i][1]).to(torch.int64),
        num_classes=num_classes,
    )

We set up our nearest neighbor lookup structure using the NN_Wrapper data structure in MuyGPyS. We then define our batch and construct tensors containing the features and targets of the batched elements and of their 30 nearest neighbors. We choose an algorithm that will return the exact nearest neighbors. We set a random seed for reproducibility.

[5]:
from torch import nn
import random
from torch.optim.lr_scheduler import ExponentialLR
torch.autograd.set_detect_anomaly(True)
np.random.seed(0)
test_count, _ = test_features.shape
train_count, _ = train_features.shape


from MuyGPyS.neighbors import NN_Wrapper
nn_count = 30
nbrs_lookup = NN_Wrapper(train_features, nn_count, nn_method="exact")

We sample a training batch of 500 elements and record their indices and those of their nearest neighbors.

[6]:

#We will make use of batching in our hyperparameter training
from MuyGPyS.optimize.batch import sample_batch
batch_count = 500
batch_indices, batch_nn_indices = sample_batch(
    nbrs_lookup, batch_count, train_count
)

batch_indices = batch_indices.astype(np.int64)
batch_nn_indices = batch_nn_indices.astype(np.int64)
batch_indices = torch.from_numpy(batch_indices)
batch_nn_indices = torch.from_numpy(batch_nn_indices)

batch_features = train_features[batch_indices,:]
batch_targets = train_responses[batch_indices, :]
batch_nn_targets = train_responses[batch_nn_indices, :]

if torch.cuda.is_available():
    train_features = train_features.cuda()
    train_responses = train_responses.cuda()
    test_features = test_features.cuda()
    test_responses = test_responses.cuda()

We now define a stochastic variational deep kernel MuyGPs class. This class composes a dense neural network embedding with a MuyGPyS.torch.muygps_layer Gaussian process layer. Presently, this layer only supports the Matérn kernel, with the nu (smoothness) parameter restricted to the special values 0.5, 1.5, 2.5, or \(\infty\). The smoothness values are limited because torch does not implement modified Bessel functions of the second kind. Future versions of the library will also support other kernel types.
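
For reference, at these special smoothness values the Matérn kernel reduces to standard closed forms that require only elementary functions, which is why no Bessel evaluations are needed. Writing \(d\) for the distance between (embedded) features and \(\ell\) for the length scale:

\(k_{0.5}(d) = \exp(-d/\ell)\)
\(k_{1.5}(d) = \left(1 + \tfrac{\sqrt{3}d}{\ell}\right)\exp\left(-\tfrac{\sqrt{3}d}{\ell}\right)\)
\(k_{2.5}(d) = \left(1 + \tfrac{\sqrt{5}d}{\ell} + \tfrac{5d^2}{3\ell^2}\right)\exp\left(-\tfrac{\sqrt{5}d}{\ell}\right)\)
\(k_{\infty}(d) = \exp\left(-\tfrac{d^2}{2\ell^2}\right)\)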

[7]:
from MuyGPyS.torch.muygps_layer import MuyGPs_layer
print('Building Stochastic Variational Deep Kernel MuyGPs model')

class SVDKMuyGPs(nn.Module):
    def __init__(
        self,
        num_models,
        kernel_eps,
        nu,
        length_scale,
        batch_indices,
        batch_nn_indices,
        batch_targets,
        batch_nn_targets,
    ):
        super().__init__()
        self.embedding = nn.Sequential(
            nn.Linear(784,400),
            nn.ReLU(),
            nn.Linear(400,200),
            nn.ReLU(),
            nn.Linear(200,100),
        )
        self.eps = kernel_eps
        self.nu = nu
        self.length_scale = length_scale
        self.batch_indices = batch_indices
        self.num_models = num_models
        self.batch_nn_indices = batch_nn_indices
        self.batch_targets = batch_targets
        self.batch_nn_targets = batch_nn_targets
        self.GP_layer = MuyGPs_layer(
            kernel_eps,
            nu,
            length_scale,
            batch_indices,
            batch_nn_indices,
            batch_targets,
            batch_nn_targets,
        )

    def forward(self,x):
        predictions = self.embedding(x)
        predictions, variances, sigma_sq = self.GP_layer(predictions)
        return predictions, variances, sigma_sq
Building Stochastic Variational Deep Kernel MuyGPs model

Training a Deep Kernel MuyGPs Model

We instantiate an SVDKMuyGPs model with initial guess hyperparameters. We fix a Matérn kernel smoothness parameter of 0.5 and a Gaussian homoscedastic noise prior variance of 1e-6.

[8]:
model = SVDKMuyGPs(
    num_models=num_classes,
    kernel_eps=1e-6,
    nu=0.5,
    length_scale=1.0,
    batch_indices=batch_indices,
    batch_nn_indices=batch_nn_indices,
    batch_targets=batch_targets,
    batch_nn_targets=batch_nn_targets,
)
if torch.cuda.is_available():
    model = model.cuda()

We use the Adam optimizer over 10 training iterations, with an initial learning rate of 1e-2 and decay of 0.97.

[9]:
training_iterations = 10
optimizer = torch.optim.Adam(
    [{'params': model.parameters()}], lr=1e-2
)
scheduler = ExponentialLR(optimizer, gamma=0.97)

We will use the cross-entropy loss, as it tends to perform well on classification problems. Other losses are available.

[10]:
ce_loss = nn.CrossEntropyLoss()
# mse_loss = nn.MSELoss()
# l1_loss = nn.L1Loss()
# bce_loss = nn.BCELoss()

We construct a standard PyTorch training loop function.

[11]:
def train(nbrs_lookup):
    for i in range(training_iterations):
        model.train()
        optimizer.zero_grad()
        predictions,variances,sigma_sq = model(train_features)
        loss = ce_loss(predictions,batch_targets)
        loss.backward()
        optimizer.step()
        scheduler.step()
        if np.mod(i,1) == 0:
            print(f"Iter {i + 1}/{training_iterations} - Loss: {loss.item()}")
            model.eval()
            nbrs_lookup = NN_Wrapper(
                model.embedding(train_features).detach().numpy(),
                nn_count, nn_method="exact"
            )
            batch_nn_indices,_ = nbrs_lookup._get_nns(
                model.embedding(batch_features).detach().numpy(),
                nn_count=nn_count,
            )
            batch_nn_indices = torch.from_numpy(
                batch_nn_indices.astype(np.int64)
            )
            batch_nn_targets = train_responses[batch_nn_indices, :]
            model.batch_nn_indices = batch_nn_indices
            model.batch_nn_targets = batch_nn_targets
        torch.cuda.empty_cache()
    nbrs_lookup = NN_Wrapper(
        model.embedding(train_features).detach().numpy(),
        nn_count,
        nn_method="exact",
    )
    batch_nn_indices,_ = nbrs_lookup._get_nns(
        model.embedding(batch_features).detach().numpy(),
        nn_count=nn_count,
    )
    batch_nn_indices = torch.from_numpy(
        batch_nn_indices.astype(np.int64)
    )
    batch_nn_targets = train_responses[batch_nn_indices, :]
    model.batch_nn_indices = batch_nn_indices
    model.batch_nn_targets = batch_nn_targets
    return nbrs_lookup, model

Finally, we execute the training function and evaluate the trained model.

[12]:
nbrs_lookup, model_trained = train(nbrs_lookup)
model_trained.eval()
Iter 1/10 - Loss: 1.5117967128753662
Iter 2/10 - Loss: 1.47725510597229
Iter 3/10 - Loss: 1.4376972913742065
Iter 4/10 - Loss: 1.426263689994812
Iter 5/10 - Loss: 1.415531039237976
Iter 6/10 - Loss: 1.40116286277771
Iter 7/10 - Loss: 1.3888448476791382
Iter 8/10 - Loss: 1.376481056213379
Iter 9/10 - Loss: 1.3672336339950562
Iter 10/10 - Loss: 1.3580222129821777
[12]:
SVDKMuyGPs(
  (embedding): Sequential(
    (0): Linear(in_features=784, out_features=400, bias=True)
    (1): ReLU()
    (2): Linear(in_features=400, out_features=200, bias=True)
    (3): ReLU()
    (4): Linear(in_features=200, out_features=100, bias=True)
  )
  (GP_layer): MuyGPs_layer()
)

We then use this trained model to predict the test responses and report its classification accuracy.

[13]:
from MuyGPyS.examples.muygps_torch import predict_model
predictions, variances,sigma_sq = predict_model(
    model=model_trained,
    test_features=test_features,
    train_features=train_features,
    train_responses=train_responses,
    nbrs_lookup=nbrs_lookup,
    nn_count=nn_count,
)
print("MNIST Prediction Accuracy Using Low-Level Torch Implementation:")
print(
    (
        torch.sum(
            torch.argmax(predictions,dim=1) == torch.argmax(test_responses,dim=1)
        ) / 10000
    ).numpy()
)
MNIST Prediction Accuracy Using Low-Level Torch Implementation:
0.9385

Training a Deep Kernel MuyGPs Model Using Our Example API Function

Similar to our one-line regression tutorial API, we support a one-line deep kernel MuyGPs training API. This snippet performs the same work as above with a single function call.

[14]:
#Import high-level API function train_deep_kernel_muygps
from MuyGPyS.examples.muygps_torch import train_deep_kernel_muygps

#Use leave-one-out-likelihood loss function to train model
model = SVDKMuyGPs(
    num_models=num_classes,
    kernel_eps=1e-6,
    nu=0.5,
    length_scale=1.0,
    batch_indices=batch_indices,
    batch_nn_indices=batch_nn_indices,
    batch_targets=batch_targets,
    batch_nn_targets=batch_nn_targets)

nbrs_lookup, model_trained = train_deep_kernel_muygps(
    model=model,
    train_features=train_features,
    train_responses=train_responses,
    batch_indices=batch_indices,
    nbrs_lookup=nbrs_lookup,
    training_iterations=10,
    optimizer_method=torch.optim.Adam,
    learning_rate=1e-2,
    scheduler_decay=0.97,
    loss_function="ce",
    update_frequency=1,
    verbose=True,
)

model_trained.eval()
Iter 1/10 - Loss: 1.5146902800
Iter 2/10 - Loss: 1.4795417786
Iter 3/10 - Loss: 1.4437427521
Iter 4/10 - Loss: 1.4335327148
Iter 5/10 - Loss: 1.4196856022
Iter 6/10 - Loss: 1.4036072493
Iter 7/10 - Loss: 1.3896402121
Iter 8/10 - Loss: 1.3768211603
Iter 9/10 - Loss: 1.3684673309
Iter 10/10 - Loss: 1.3605209589
[14]:
SVDKMuyGPs(
  (embedding): Sequential(
    (0): Linear(in_features=784, out_features=400, bias=True)
    (1): ReLU()
    (2): Linear(in_features=400, out_features=200, bias=True)
    (3): ReLU()
    (4): Linear(in_features=200, out_features=100, bias=True)
  )
  (GP_layer): MuyGPs_layer()
)

We similarly report our prediction performance on the test responses using this trained model.

[15]:
from MuyGPyS.examples.muygps_torch import predict_model
predictions,variances,sigma_sq = predict_model(
    model=model_trained,
    test_features=test_features,
    train_features=train_features,
    train_responses=train_responses,
    nbrs_lookup=nbrs_lookup,
    nn_count=nn_count,
)

print("MNIST Prediction Accuracy Using High-Level Training API:")
print(
    (
        torch.sum(
            torch.argmax(predictions,dim=1) == torch.argmax(test_responses,dim=1)
        ) / 10000
    ).numpy()
)
MNIST Prediction Accuracy Using High-Level Training API:
0.934

We note that this is quite mediocre performance on MNIST. In the interest of reducing notebook runtime we have used a simple fully-connected neural network model to construct the Gaussian process kernel. To achieve results closer to the state-of-the-art (near 100% accuracy), we recommend using more complex architectures that integrate convolutional layers into the model, for example as sketched below.
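
As an illustration only (not part of this tutorial's workflow), a convolutional embedding along the following lines could replace the fully-connected self.embedding in SVDKMuyGPs. This hypothetical sketch assumes the images are passed in their original (1, 28, 28) shape rather than flattened, so the surrounding data handling would need to change accordingly.

from torch import nn  # already imported above

conv_embedding = nn.Sequential(
    nn.Conv2d(1, 32, kernel_size=3, padding=1),   # (1, 28, 28) -> (32, 28, 28)
    nn.ReLU(),
    nn.MaxPool2d(2),                              # -> (32, 14, 14)
    nn.Conv2d(32, 64, kernel_size=3, padding=1),  # -> (64, 14, 14)
    nn.ReLU(),
    nn.MaxPool2d(2),                              # -> (64, 7, 7)
    nn.Flatten(),                                 # -> 64 * 7 * 7 = 3136 features
    nn.Linear(64 * 7 * 7, 100),                   # embed into the same 100-dimensional space
)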