muygps_torch

Resources and high-level API for deep kernel learning with MuyGPs.

train_deep_kernel_muygps() is a high-level API for training deep kernel MuyGPs models for regression.

predict_model() is a high-level API for generating predictions at test locations given a trained model.

MuyGPyS.examples.muygps_torch.predict_model(model, test_features, train_features, train_responses, nbrs_lookup, nn_count)

Generate predictions using a PyTorch model containing a MuyGPyS.torch.muygps_layer.MuyGPs_layer layer or a MuyGPyS.torch.muygps_layer.MultivariateMuyGPs_layer layer in its structure. Note that the custom PyTorch layers for MuyGPs objects only support the Matern kernel. Support for more kernels will be added in future releases.

Example

>>> # model must be a PyTorch model inheriting from torch.nn.Module
... # with two components: model.embedding (e.g., a neural net)
... # and model.GP_layer.
>>> import torch
>>> from MuyGPyS.testing.test_utils import _make_gaussian_data
>>> from MuyGPyS.neighbors import NN_Wrapper
>>> from MuyGPyS.examples.muygps_torch import predict_model
>>> train, test = _make_gaussian_data(10000, 1000, 100, 10)
>>> nn_count = 10
>>> nbrs_lookup = NN_Wrapper(train['input'], nn_count, nn_method="hnsw")
>>> predictions, variances = predict_model(
... model,
... torch.from_numpy(test['input']),
... torch.from_numpy(train['input']),
... torch.from_numpy(train['output']),
... nbrs_lookup,
... nn_count)
Parameters:
  • model – A custom torch.nn.Module object containing an embedding component and one MuyGPs_layer or MultivariateMuyGPs_layer layer.

  • test_features (Tensor) – A torch.Tensor of shape (test_count, feature_count) containing the test features to be regressed.

  • train_features (Tensor) – A torch.Tensor of shape (train_count, feature_count) containing the training features.

  • train_responses (Tensor) – A torch.Tensor of shape (train_count, response_count) containing the training responses corresponding to each feature.

  • nbrs_lookup (NN_Wrapper) – A NN_Wrapper nearest neighbor lookup data structure.

  • nn_count (int) – The number of nearest neighbors to use for each prediction.

Returns:

  • predictions – A torch.Tensor of shape (test_count, response_count) whose rows are the predicted responses for each of the given test features.

  • variances – A torch.Tensor of shape (batch_count,) consisting of the diagonal elements of the posterior variance, or a matrix of shape (batch_count, response_count) for a multidimensional response.
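
As an illustration of how the two returned values are typically combined, the sketch below turns a predicted mean and its posterior variance into an approximate 95% interval. It assumes a Gaussian posterior and a univariate response. `gaussian_interval` is a hypothetical helper, not part of MuyGPyS, and plain Python floats stand in for the tensor entries.

```python
import math


def gaussian_interval(mean, variance, z=1.96):
    """Return (lower, upper) bounds for a Gaussian with the given mean and variance.

    Hypothetical helper for illustration; z=1.96 gives roughly 95% coverage.
    """
    if variance < 0:
        raise ValueError("variance must be non-negative")
    half_width = z * math.sqrt(variance)
    return mean - half_width, mean + half_width


# Stand-ins for the rows of `predictions` and the entries of `variances`.
predictions = [0.5, -1.2, 3.0]
variances = [0.04, 0.25, 1.0]
intervals = [gaussian_interval(m, v) for m, v in zip(predictions, variances)]
```

With real output, the same arithmetic applies elementwise to the returned tensors.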

MuyGPyS.examples.muygps_torch.predict_multiple_model(model, test_features, train_features, train_responses, nbrs_lookup, nn_count)

Generate predictions using a PyTorch model containing a MuyGPyS.torch.muygps_layer.MultivariateMuyGPs_layer in its structure. Meant for the case in which there is more than one GP model used to model multiple outputs. Note that the custom PyTorch MultivariateMuyGPs_layer objects only support the Matern kernel. Support for more kernels will be added in future releases.

Parameters:
  • model – A custom torch.nn.Module object containing an embedding component and one MuyGPyS.torch.muygps_layer.MultivariateMuyGPs_layer layer.

  • test_features (Tensor) – A torch.Tensor of shape (test_count, feature_count) containing the test features to be regressed.

  • train_features (Tensor) – A torch.Tensor of shape (train_count, feature_count) containing the training features.

  • train_responses (Tensor) – A torch.Tensor of shape (train_count, response_count) containing the training responses corresponding to each feature.

  • nbrs_lookup (NN_Wrapper) – A NN_Wrapper nearest neighbor lookup data structure.

  • nn_count (int) – The number of nearest neighbors to use for each prediction.

Returns:

  • predictions – A torch.Tensor of shape (test_count, response_count) whose rows are the predicted responses for each of the given test features.

  • variances – A torch.Tensor of shape (batch_count,) consisting of the diagonal elements of the posterior variance, or a matrix of shape (batch_count, response_count) for a multidimensional response.

MuyGPyS.examples.muygps_torch.predict_single_model(model, test_features, train_features, train_responses, nbrs_lookup, nn_count)

Generate predictions using a PyTorch model containing at least one MuyGPyS.torch.muygps_layer.MuyGPs_layer in its structure. Note that the custom PyTorch MuyGPs_layer objects only support the Matern kernel. Support for more kernels will be added in future releases.

Parameters:
  • model – A custom torch.nn.Module object containing an embedding component and one MuyGPyS.torch.muygps_layer.MuyGPs_layer layer.

  • test_features (Tensor) – A torch.Tensor of shape (test_count, feature_count) containing the test features to be regressed.

  • train_features (Tensor) – A torch.Tensor of shape (train_count, feature_count) containing the training features.

  • train_responses (Tensor) – A torch.Tensor of shape (train_count, response_count) containing the training responses corresponding to each feature.

  • nbrs_lookup (NN_Wrapper) – A NN_Wrapper nearest neighbor lookup data structure.

  • nn_count (int) – The number of nearest neighbors to use for each prediction.

Returns:

  • predictions – A torch.Tensor of shape (test_count, response_count) whose rows are the predicted responses for each of the given test features.

  • variances – A torch.Tensor of shape (batch_count, response_count) consisting of the diagonal elements of the posterior variance.

MuyGPyS.examples.muygps_torch.train_deep_kernel_muygps(model, train_features, train_responses, batch_indices, nbrs_lookup, training_iterations=10, optimizer_method=<class 'torch.optim.adam.Adam'>, learning_rate=0.001, scheduler_decay=0.95, loss_function='lool', update_frequency=1, verbose=False, nn_kwargs={})

Train a PyTorch model containing an embedding component and a MuyGPyS.torch.muygps_layer.MuyGPs_layer layer or a MuyGPyS.torch.muygps_layer.MultivariateMuyGPs_layer layer in its structure. Note that the custom PyTorch layers for MuyGPs models only support the Matern kernel. Support for more kernels will be added in future releases.

Example

>>> # model must be a PyTorch model inheriting from torch.nn.Module
... # with two components: model.embedding (e.g., a neural net)
... # and model.GP_layer.
>>> import torch
>>> from MuyGPyS.testing.test_utils import _make_gaussian_data
>>> from MuyGPyS.neighbors import NN_Wrapper
>>> from MuyGPyS.optimize.batch import sample_batch
>>> from MuyGPyS.examples.muygps_torch import train_deep_kernel_muygps
>>> train, test = _make_gaussian_data(10000, 1000, 100, 10)
>>> nn_count = 10
>>> nbrs_lookup = NN_Wrapper(train['input'], nn_count, nn_method="hnsw")
>>> batch_count = 100
>>> train_count = 10000
>>> batch_indices, batch_nn_indices = sample_batch(
... nbrs_lookup, batch_count, train_count)
>>> nbrs_struct, model_trained = train_deep_kernel_muygps(
... model=model,
... train_features=torch.from_numpy(train['input']),
... train_responses=torch.from_numpy(train['output']),
... batch_indices=torch.from_numpy(batch_indices),
... nbrs_lookup=nbrs_lookup,
... training_iterations=10,
... optimizer_method=torch.optim.Adam,
... learning_rate=1e-3,
... scheduler_decay=0.95,
... loss_function="lool",
... update_frequency=1)
Parameters:
  • model – A custom torch.nn.Module object containing at least one embedding layer and one MuyGPs_layer or MultivariateMuyGPs_layer layer.

  • train_features (Tensor) – A torch.Tensor of shape (train_count, feature_count) containing the training features.

  • train_responses (Tensor) – A torch.Tensor of shape (train_count, response_count) containing the training responses corresponding to each feature.

  • batch_indices (Tensor) – A torch.Tensor of shape (batch_count,) containing the indices of the training batch.

  • nbrs_lookup (NN_Wrapper) – A NN_Wrapper nearest neighbor lookup data structure.

  • training_iterations – The number of training iterations to be used in training.

  • optimizer_method – An optimization method from the torch.optim module.

  • learning_rate – The learning rate to be applied during training.

  • scheduler_decay – The exponential decay rate to be applied to the learning rate.

  • loss_function – The loss function to be used in training. Defaults to “lool” for leave-one-out likelihood. Other options are “mse” for mean-squared error, “ce” for cross entropy loss, “bce” for binary cross entropy loss, and “l1” for L1 loss.

  • update_frequency – How often, in epochs, the nearest neighbor structure is updated. An update frequency of n means the structure is updated every n epochs.

  • verbose – Whether to print progress information during training.

  • nn_kwargs (Dict) – Parameters for the nearest neighbors wrapper. See MuyGPyS.neighbors.NN_Wrapper for the supported methods and their parameters.
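
To make the loss options listed above more concrete, here is a rough pure-Python sketch of three of them for a univariate response. The `lool` form shown (squared error scaled by the posterior variance, plus the log variance, summed over the batch) is an assumption about the usual Gaussian leave-one-out objective. The library's exact constants and scaling may differ. The function names are illustrative only and are not MuyGPyS calls.

```python
import math


def mse_loss(predictions, targets):
    """Mean squared error over the batch."""
    return sum((p - t) ** 2 for p, t in zip(predictions, targets)) / len(targets)


def l1_loss(predictions, targets):
    """Mean absolute error over the batch."""
    return sum(abs(p - t) for p, t in zip(predictions, targets)) / len(targets)


def lool_loss(predictions, targets, variances):
    """Assumed leave-one-out likelihood form.

    Sums the variance-scaled squared error plus the log variance.
    """
    return sum(
        (p - t) ** 2 / v + math.log(v)
        for p, t, v in zip(predictions, targets, variances)
    )


# Toy batch: predicted means, true targets, and posterior variances.
predictions = [1.0, 2.0, 4.0]
targets = [1.0, 3.0, 2.0]
variances = [1.0, 1.0, 4.0]

mse = mse_loss(predictions, targets)
l1 = l1_loss(predictions, targets)
lool = lool_loss(predictions, targets, variances)
```

Unlike the mean-squared and L1 losses, a loss of this form also depends on the predicted variances, so it penalizes predictions that are confidently wrong.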

Returns:

  • nbrs_lookup – A NN_Wrapper object containing the nearest neighbors of the embedded training data.

  • model – A trained deep kernel MuyGPs model.
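
The interaction of learning_rate, scheduler_decay, and update_frequency can be tabulated before a run. The sketch below assumes the decay is applied once per training iteration (as torch.optim.lr_scheduler.ExponentialLR does on each scheduler step) and that neighbors are refreshed on iterations whose index is a multiple of update_frequency. Both conventions are assumptions about the training loop, and `training_schedule` is a hypothetical helper.

```python
def training_schedule(training_iterations=10, learning_rate=1e-3,
                      scheduler_decay=0.95, update_frequency=1):
    """Return a list of (iteration, learning_rate, refresh_neighbors) tuples."""
    schedule = []
    for iteration in range(training_iterations):
        # Exponential decay: lr_k = lr_0 * decay**k.
        lr = learning_rate * scheduler_decay ** iteration
        # Assumed convention: refresh on every update_frequency-th iteration.
        refresh = iteration % update_frequency == 0
        schedule.append((iteration, lr, refresh))
    return schedule


schedule = training_schedule(training_iterations=6, update_frequency=3)
for iteration, lr, refresh in schedule:
    print(f"iter {iteration}: lr={lr:.6f} refresh_neighbors={refresh}")
```

A larger update_frequency trades fresher neighborhoods for fewer rebuilds of the nearest neighbor index, which is usually the more expensive step.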

MuyGPyS.examples.muygps_torch.update_nearest_neighbors(model, train_features, train_responses, batch_indices, nn_count, nn_kwargs={})

Update the nearest neighbors after deformation via a PyTorch model containing an embedding component and a MuyGPyS.torch.muygps_layer.MuyGPs_layer layer or a MuyGPyS.torch.muygps_layer.MultivariateMuyGPs_layer layer in its structure.

Example

>>> # model must be a PyTorch model inheriting from torch.nn.Module
... # with two components: model.embedding (e.g., a neural net)
... # and model.GP_layer.
>>> import torch
>>> from MuyGPyS.testing.test_utils import _make_gaussian_data
>>> from MuyGPyS.neighbors import NN_Wrapper
>>> from MuyGPyS.optimize.batch import sample_batch
>>> from MuyGPyS.examples.muygps_torch import update_nearest_neighbors
>>> train, test = _make_gaussian_data(10000, 1000, 100, 10)
>>> nn_count = 10
>>> # Build the initial lookup so that a training batch can be sampled.
>>> nbrs_lookup = NN_Wrapper(train['input'], nn_count, nn_method="hnsw")
>>> batch_count = 100
>>> train_count = 10000
>>> batch_indices, batch_nn_indices = sample_batch(
... nbrs_lookup, batch_count, train_count)
>>> nbrs_struct, model_trained = update_nearest_neighbors(
... model=model,
... train_features=torch.from_numpy(train['input']),
... train_responses=torch.from_numpy(train['output']),
... batch_indices=torch.from_numpy(batch_indices),
... nn_count=nn_count)
Parameters:
  • model – A custom torch.nn.Module object containing at least one embedding layer and one MuyGPs_layer or MultivariateMuyGPs_layer layer.

  • train_features (Tensor) – A torch.Tensor of shape (train_count, feature_count) containing the training features.

  • train_responses (Tensor) – A torch.Tensor of shape (train_count, response_count) containing the training responses corresponding to each feature.

  • batch_indices (Tensor) – A torch.Tensor of shape (batch_count,) containing the indices of the training batch.

  • nn_count (int) – An integer giving the number of nearest neighbors.

  • nn_kwargs (Dict) – Parameters for the nearest neighbors wrapper. See MuyGPyS.neighbors.NN_Wrapper for the supported methods and their parameters.

Returns:

  • nbrs_lookup – A NN_Wrapper object containing the updated nearest neighbors of the embedded training data.

  • model – A deep kernel MuyGPs model with updated nearest neighbors.
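
A brute-force example shows why the neighbor structure needs refreshing at all: once features pass through an embedding, distances change, so the nearest neighbors can change as well. The sketch below uses plain Python, with a hypothetical fixed `embed` function standing in for `model.embedding`. It does not use NN_Wrapper or any other MuyGPyS call.

```python
def nearest_neighbors(points, query_index, nn_count):
    """Brute-force indices of the nn_count points closest to points[query_index]."""
    query = points[query_index]

    def sq_dist(point):
        return sum((a - b) ** 2 for a, b in zip(point, query))

    others = [i for i in range(len(points)) if i != query_index]
    return sorted(others, key=lambda i: sq_dist(points[i]))[:nn_count]


def embed(point):
    """Hypothetical embedding: keeps the first coordinate, shrinks the second."""
    x, y = point
    return (x, 0.01 * y)


features = [(0.0, 0.0), (1.0, 0.0), (0.0, 2.0), (3.0, 3.0)]
embedded = [embed(p) for p in features]

# Nearest neighbor of point 0 in the raw and in the embedded feature space.
before = nearest_neighbors(features, 0, nn_count=1)
after = nearest_neighbors(embedded, 0, nn_count=1)
```

Here point 1 is the nearest neighbor of point 0 in the raw space, while point 2 takes its place after the embedding. During training the embedding changes with every optimizer step, which is why the lookup is rebuilt periodically.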