{ "cells": [ { "cell_type": "markdown", "id": "2c123e93", "metadata": {}, "source": [ "Copyright 2021-2024 Lawrence Livermore National Security, LLC and other MuyGPyS Project Developers. See the top-level COPYRIGHT file for details.\n", "\n", "SPDX-License-Identifier: MIT\n", "\n", "# Deep Kernels with MuyGPs in PyTorch Tutorial\n", "\n", "In this tutorial, we outline how to construct a simple deep kernel model using the PyTorch implementation of MuyGPs.\n" ] }, { "cell_type": "markdown", "id": "b83dc4cd", "metadata": {}, "source": [ "We use the MNIST classification problem as a benchmark.\n", "We will use the deep kernel MuyGPs model to classify images of the handwritten digits 0 through 9.\n", "To reduce the runtime of the training loop, we will use a fully-connected architecture, which means that we must vectorize each image prior to training.\n", "We download the training and testing data using the `torchvision.datasets` API." ] }, { "cell_type": "code", "execution_count": 1, "id": "3c5b93d3", "metadata": { "nbsphinx": "hidden" }, "outputs": [], "source": [ "import sys\n", "# Iterate over a copy of the keys so that popping entries does not invalidate the iterator.\n", "for m in list(sys.modules.keys()):\n", "    if m.startswith(\"Muy\"):\n", "        sys.modules.pop(m)" ] }, { "cell_type": "markdown", "id": "e2a52af1", "metadata": {}, "source": [ "First, we will import necessary dependencies.\n", "We also force MuyGPyS to use the `\"torch\"` backend.\n", "Here we do this by setting the `MUYGPYS_BACKEND` environment variable to `\"torch\"`. 
" ] }, { "cell_type": "code", "execution_count": 2, "id": "ae9bb692", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "env: MUYGPYS_BACKEND=torch\n", "env: MUYGPYS_FTYPE=32\n" ] } ], "source": [ "%env MUYGPYS_BACKEND=torch\n", "%env MUYGPYS_FTYPE=32" ] }, { "cell_type": "code", "execution_count": 3, "id": "9f72675f", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n", "I0000 00:00:1707872829.179408 1 tfrt_cpu_pjrt_client.cc:349] TfrtCpuClient created.\n" ] } ], "source": [ "import numpy as np\n", "import random\n", "import matplotlib.pyplot as plt\n", "import os\n", "import torch\n", "import torchvision\n", "from MuyGPyS.examples.muygps_torch import predict_model\n", "from MuyGPyS.gp import MuyGPS \n", "from MuyGPyS.gp.deformation import l2, Isotropy\n", "from MuyGPyS.gp.hyperparameter import Parameter\n", "from MuyGPyS.gp.kernels import Matern\n", "from MuyGPyS.gp.noise import HomoscedasticNoise\n", "from MuyGPyS.neighbors import NN_Wrapper\n", "from MuyGPyS.optimize.batch import sample_batch\n", "from MuyGPyS.torch import MuyGPs_layer\n", "from torch import nn\n", "from torch.nn.functional import one_hot\n", "from torch.optim.lr_scheduler import ExponentialLR" ] }, { "cell_type": "markdown", "id": "14b93a3d", "metadata": {}, "source": [ "We set the target directory for torch to download MNIST." ] }, { "cell_type": "code", "execution_count": 4, "id": "888637aa", "metadata": {}, "outputs": [], "source": [ "root = './data'\n", "if not os.path.exists(root):\n", " os.mkdir(root)" ] }, { "cell_type": "markdown", "id": "124c8d0e", "metadata": {}, "source": [ "We use torch's utilities to download MNIST and transform it into an appropriately normalized tensor." 
] }, { "cell_type": "code", "execution_count": 5, "id": "56e51b89", "metadata": {}, "outputs": [], "source": [ "trans = torchvision.transforms.Compose(\n", " [\n", " torchvision.transforms.ToTensor(),\n", " torchvision.transforms.Normalize((0.5,),(1.0,)),\n", " ]\n", ")\n", "train_set = torchvision.datasets.MNIST(\n", " root=root, train=True, transform=trans, download=True\n", ")\n", "test_set = torchvision.datasets.MNIST(\n", " root=root, train=False, transform=trans, download=True\n", ")" ] }, { "cell_type": "markdown", "id": "2914f92e", "metadata": {}, "source": [ "MNIST is a popular benchmark dataset of hand-written digits, 0-9.\n", "Each digit is a 28x28 pixel image, with 784 total pixel features." ] }, { "cell_type": "code", "execution_count": 6, "id": "c49a0be9", "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAx8AAACHCAYAAAB3eeQWAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAW5klEQVR4nO3deXxU1d3H8TOZBJJgWBJIIBDWEMIqKAhYJKJItQ+CCoioj5QuFhBUxOWRp9VasdVqVUBwLeDSoi9xo260WKQqmyigIgQEEoKBsIdAFpKZef7o0wzf22ayzplM8nn/0/vtvXPnSE7uzMk9v3tcPp/PZwAAAAAgyCJC3QAAAAAAjQODDwAAAABWMPgAAAAAYAWDDwAAAABWMPgAAAAAYAWDDwAAAABWMPgAAAAAYEVkTV/o9XpNbm6uiYuLMy6Xqy7bhCDz+XymoKDAJCcnm4iIuht/0ifCU7D6gzH0iXDFNQJO9Ak40SfgVNU+UePBR25urklJSanpy1EP5OTkmA4dOtTZ+egT4a2u+4Mx9IlwxzUCTvQJONEn4FRZn6jx4CMuLs4YY8ww8yMTaaJqehqEQJkpNZ+a98t/hnWFPhGegtUfjKFPhCuuEXCiT8CJPgGnqvaJGg8+/nUrLNJEmUgXnSOs+P75P3V9O5M+EaaC1B/OPid9IsxwjYATfQJO9Ak4VbFPUHAOAAAAwAoGHwAAAACsYPABAAAAwAoGHwAAAACsYPABAAAAwAoGHwAAAACsYPABAAAAwAoGHwAAAACsYPABAAAAwAoGHwAAAACsYPABAAAAwAoGHwAAAACsYPABAAAAwIrIUDcACFdll5wv+cD0Eslbh74o+dx1kyUnL2wi2b36yzpsHQAAQP3DnQ8AAAAAVjD4AAAAAGAFgw8AAAAAVjTqmg9XpP7nu9u0rvJrM+/sLNkT65XcqdshybHTXZIPPq7z/b8c+JrkI57Tkge/Plty6h3rq9xW1A1vxgDJ8xc/JTk1SvuT9ghjNg9dIjlzoEfyXZ2H1K6BaHBOjx8s+ZHfPy35wWtvkuzb9E3Q24Tg2f3oUMnbr9drTJTLLXn49Jslx7y9MTgNAxA07oR4ya4WzSXvG5csubi1T3LqA1slewsL67B1wcGdDwAAAABWMPgAAAAAYAWDDwAAAABW
hHXNh7tnd8m+plGSczNaSi4aonUU8S00f3Ku1l3UxgeFcZIfeepyyRv6/lny3tIiyQ/nXSY5+ROd4wc7SkcNLN++e9HLsi8tSut2vI4qjz2lpZLzvU0lD9BoSq4YJDlm9dd6/uLiyhvcABWNvUBzgs57j1+8zmZzrDo0UP8+9GDWlSFqCYLh4KwLJX888feSS316jfk3fCwAYSGiT3r59q57Y2TfT/qulTw7YWW1zt0zaark7j/+opqts487HwAAAACsYPABAAAAwAoGHwAAAACsCKuaD8/F50l+fOlCyc45+DaV+nTNhvsW/Fhy5GmdnDv09RmS474vk9z0iNaAxG7aUMsW4j9xN9fnaZ8eni551hP+2pwRMaccrw48dl96XOdzf7RIn+H/2a/nS/7bC89I7vWK9pGu9zTc2oZAcofrv3NstxN6wGJ7bQm6CK1n8XXU68CliTskf+TSPobwcipF68TiI0L3GYaaO/PDgZKzb/D/XKedt0b23d5qZ8Bz9X1hpuTYA/rd4cSFJZI7/Umvj01WbgrcWASFa1Bfyd/N0mv5x8P8a/a0cWvBZ4Tju8R7ha0k7ylJlHxLq0zJLw9/XvKDgyZL9n2u9aP1AXc+AAAAAFjB4AMAAACAFQw+AAAAAFgRVjUfTTNzJX9RnCI5LSqvzt5r9oEhkvecai15abflkvO9Oi8zab4+t7m6eHy7Hftfai/580ELKziy+n6T+LnkD8/R+flTskZJfrHzKsnNex2ts7aEswdGvy75ke2jKjgy/Lm7dZK8I0MLWvpvvFFycj2cy4uKnZowWPIbV89zHOGS9MwJrUFbda3WFjTL3iZZK0gQLIenav3egrv1c2NgU38NqHM+/+SskZIHtNgneevPnH1COc93YfwkyfHVWyICVeRu00byznn63eEvFy6S3DVK150zxrGw11mWnNTvsm+PGybZ61jD7pZ3tebj7P5mjDFFSbqOSHSF7xw63PkAAAAAYAWDDwAAAABWMPgAAAAAYEVY1XyUHTgoecEjEyQ/dPlpye6vzpG8dfqCgOefe6Rf+fZ3I2Nln+fEAcnXD50uOetWPVcXszXgeyE0yi45X/Ky/k9JjjAVP2d/Svalkjet6in565/quVYX6UzLxE26ZsN3x3U+d9RvV2tbdPp3oxXlKqv8oAYi8oXCgPuLdjcPuB/1S/HoCyTf/zut4UmLCvxL/uLzl0tu+23taglRNS7HmmHFI8+V/Ma9j0pOjtT5/D/Nvqx8O/uxHrKv2XtbJK+O7Sh5zVtp+l7dVwRs68ktCZLjAx6Nmvr+xu6St2U4a3OcNR4Ve8VZ43GV1oN6MnUtGNeA3lU+d7jgzgcAAAAAKxh8AAAAALCCwQcAAAAAK8Kq5sMpfsk6yW3+onMfPUePSe7d5yeStw3X+bcrnsso3048EXhurWud1nR0WVfBgQgpb8YAyfMXa11GapT+CngdT8ofs+Pq8m33eK0pavlfuhpLr5dnSE5bmCM5Imez5FafaFtLH9Jndb/RT/vnT0ZoYZF79ZemIfIO6y/5ouhPQ9OQEOjcLPDaLimrPAH3o345cGOx5BExxY4j3JKca0C0nUeNRygcmKHrqWy80zm/X2s8Jnx3peSycaXl27FHNsg+5xpeuTdrHeKG7oHX+figME5y6rP6OdN4KuTsaj8mq1rHLz/VVvLjO/01o0l3ay/wZO4KeK7jfRterR93PgAAAABYweADAAAAgBUMPgAAAABYEdY1H06eI4HnS5eerHgNB2OM6X3Dt+Xbh5/WubjGy1zrcOA6X5+HfeQOXVsjzfH89i9K9PV/P9VL8tFX/c/jTjiuhT0tXlmv2dGW2s69TXLrvOKjt+saEIm6LEiDkT06RnKiO7aCI8NfZGd9xv/4+MDP9I/Ze1wyV6X6JbJDe8nbLloiudSnP7HtpRLNvsd1jYdmRusFEBy7FgyWnHmNrgmmlYDG9PzbVMnpd2ZJruy7yNmmTnunyscaY8zchyZL
bpVDwakVP9fP4163zJSc8jf93W62Tdela53tX7ujutftwqSGt+gXdz4AAAAAWMHgAwAAAIAVDD4AAAAAWNGgaj4q0/OenZKn9L1U8pJOH5VvZ0y4RfbFvabz+1E/RMRqPUDZ709KXp/+puS9ZWck3zFntuRWn+yTnNjsUPl2qOfXX9AuW3JWaJoRdJGpBQH3F+9oaachFuQ82UzyD5rq7PI/nuygLzih/Ruh5+7do3x74J+/qdZrJ76pa/d0e4PPGRt2/2GI5MxrFkrO9+p6LBN2XC+5x0z9LuEpqPiaFdFMf8ePju8neew5j+rxRmve0l/X7yKpS6nxCAXPd3slp87aW8GR/1SX662UDgr8mRiOuPMBAAAAwAoGHwAAAACsYPABAAAAwIpGVfPhOZEv+ei0npL3rfCvCfE/c1+Sffdee7Vk32Zd1SHlIcc8TJ+vps1ENRRl6LoeK9MXBTz+Z7fNkhz3ts6xrst5mgiOxE3Op+7XH+7WCZLzxum6DfHX7pe8Ju2PjjNES3p64VWSE/PW1qp9qHvZY/w/8+UJmx17db2o63dfKTnt4d2SQ11X1lC5kxIlv3i1fk54HSt5OGs8mlym9XaVXYEi+vvXi+qzeLvsm5s033G0rh/xgy3XSe7xa309fSQ87bvvwvLtsljH90PnMh6O3dd0D1znM2P/xZJjPvwy0OnqBe58AAAAALCCwQcAAAAAKxrVtCsn71a9nXndA3eVb//p/sdk35YhOg3L6JP6TO9mMyR3f/6A5LI9WTVrJALq9+AWyRGO8fSUbH2ccszbG4PdpBqLcukUjVLHvVK3qz7ePLWvKF5/xs0qOK4i3osGSPa59Z53zkidBnEmuVRyRBP/xIe/XrRA9kU5bp8f9Oi5frVHp28e8+oEjtgInVSRtEEfsUgPCL1jU4ZKfmvq2Y9KjZJ9U3MyJJdO1v7gOayP9kZwuKL1331g08CTl2JubaKv75QieddUfQT2qJE6zWVW4nPl2x0j9dG5zilbHscUbddrrXX/iV0B24rQcDdvLrn4gu6So+7Nk/xVun5WyLH/9tkfuH+uLtIlBvbf3FGyr0y/29ZH3PkAAAAAYAWDDwAAAABWMPgAAAAAYEWjrvlwil/sf5zZjMxbZF/zh/URmcu6rpS87aanJKen/Exyjwd0nOfZtafG7WzMTvy3zrf+ZZLW5niNztX94q+9JHc09fdRpc55ns7HP364Xf9buhudZ9xQlBTrvHmvo9JhyZwnJK+Y0b9a578n4QXJEY7nHBb5zkjO9ejP5anDF5dvj1x1u+xruVn7X7u/6rxfV7ZeRw5v1/ngSW6tL/F9/rVBaLl795C8du5TjiOiTUXW7e8sOSXrmzpqFarDV1wieUOJXmMGN9Xfu3dWvSrZeS2uzKoif93GLkfx3oiYU5I3ndFrRsuXAj9WFXa4mjpq/zL6Sp616GXJI2I+kpzn0T63uqhV+fZ9O8fKvmW9l0pOjtT3doqO0P6659qWkrtm6jXJW1wc8HyhwJ0PAAAAAFYw+AAAAABgBYMPAAAAAFZQ81EB12dbJBeOT5Q8aOJMyRvumSd5xwidV35D51GS84fVsoGNVJlOkTctInS+7LpinSvZ9aVcfX1QWlU1EbH6bO4dj/VxHPGFpBv2XCE5/ba9kgM/CTx8pd64WXLv3+kaOimDvq/V+VcfSpN8+AN9Zn/CNp1P2+TDzx1n8O9PM5sCvpfzZ/T9PRdKHtRU53e/eqp9wPPBvp1z9Pe2smfwn63jw5pZpyU0PHmHJN8/TWsyH3tmkeR++rFiXjmp63zMXTNGctpSnVMfmZdfvp247JjsG5Hyd8mTV2tbKrumIDgiorVO4uhEXQ/qk9/OD/j63sv0O2GH1XqdaPqe/3MkoZ3W/Sxbeb7k2QmBa8OcNUpf/VjbNjTnVslJL22V7C0sDHh+G7jzAQAAAMAKBh8AAAAArGDwAQAAAMAKaj6qyDlnNGm+5uK7tZog1qWTRp/v/K7k
0Vffrse/taGWLYQxxhz1nCO5bE9WaBpi/r3GI/NhfU74jrG6XsAHhS0k5y5MlRx3fH0dti58dLk3uM+9b2f2BfX8Z4sdfjjg/l+uHic5zWwMZnPwH3gzdK733IFvV/m1l31zneRzNrGuR33UZKXWVczpckG1Xl/Z72XBWP/53uv4juwr9enffGOyHAUmsMK5jseOx/tpHhu4xmNs5lWS0x7Vtduc3xkjU/y1heeu0M+cuxK+lZzv1bWmBr8xW3K7dD33R31fk7zuV9r2iZNGSz4yX7+LRB/VGhIn98d1v6YYdz4AAAAAWMHgAwAAAIAVDD4AAAAAWEHNRwW8w/pL3j1BnwHdp3+WZGeNh9OCYzqPOPYdnuUdDHd+NkFymmPtjGByzhU/dEeR5O0Dtcbj0q8nSm52uc4ZjTONs8ajMev0DitBhNpDS5+T3Ccq8M/kzgPDy7dbTDou+xrqWjwIrCzG/3dd57owXuOV3GWpzv8P5VpUDZkrUr/uZj55ruQdYxZK3l9WInnMs3dL7rx4t+QyR41H6Uhdu6PPI/71q+5P1O8lS052kvzy/14pOfVN/S7gbp0g+eLLdI2R0xPzJb814HnJHeZrvYvTu6f1/M+ldQ14fE1w5wMAAACAFQw+AAAAAFjB4AMAAACAFY265sM1sI/knbf66zae/8GLsm94tD53uTIlPn1u8vpjXfQA74FqnQ//z6UxwjF+njdsmeSFJi1oTcn+zVDJb9z0uOS0KK0DOm/jZMnJV+uzvQGE3oAmek1xztl3WrfkvPLtxONrg9ImhJe4V8+ao/+H0LUDfjl36VouO8bMk5zrqPGY8PBdkju/rTWZxy7R73S+G+MkL++j52/j9tdZ9H5VazTSnjsiOTYz8LpvniNHJTdf5sx6/PjpWq+SND474PnN7JaO/2Nb4ONrgDsfAAAAAKxg8AEAAADACgYfAAAAAKxo0DUfkV302cm7pyRL/vXEVyWPO0fn3VXHnLyBktfMGyK51YvranxunMXxyH3nM9MzYnTu4+1L9Vnb3Zbo8VEHCyTnZbSRHD9xf/n2zI4fyb4rYvVZ3StOJ0m+6evLJbd+tpkBzuZ26d9/jqdFSW77gc3WNE45y7X2L8q1pVqvb/ex/3ODdT1gjDEF1539+W9vrSlU7OmfLwq4P9pRT3rl1H9Ibn+rruEzuflfKnlHXUuj959vLd9Ovfdz2ecpC+7qLomLtBbNF/ifwhjzfdDa8i/c+QAAAABgBYMPAAAAAFYw+AAAAABgRVjXfER27ig5//x2kif+5kPJU1u+WeP3mn1AazjWLdIaj/ilGyW38lLjEQrRLu3S2y97RvKnF0VL3lXSVvKUFllVfq/bci+S/OHa/pK737beAIF4fFqDxJ+Dgs+bMUDyk/1fkexc1yPfWyx50Ae3S07PZr0eqPyu/CLXN/84lS55cNOvJce7tUZjTustAc83esc1kvet6yC56/J8yanb/LU/viDXeIQDfkMAAAAAWMHgAwAAAIAVDD4AAAAAWFGvaz4i2+l8/GOLdZ2EaV3WSJ4Ul1er95vx/bDy7S+f7i/7Wi//RnJ8ATUdoZD08SHJ9/xiqORH2gb+uQyPPiN5WHRWwOM3l/jH55PW3Cz70qbo89u7G2o8UDuFgwpD3YQGrzi+ieRh0acdR7glrSzU2sK0m/UZ/Y6qHcC0X+P/PY6aof2p1Oc8GjasHaHrvA2+4RLJ+efqd4PIw7rmUtozuvZF5EH9LtK5OEcy14XAuPMBAAAAwAoGHwAAAACsYPABAAAAwIqQ1nyc+aGulXFm1jHJc1Lflzwqxjk3t3ryPEWSh6+YLTn9lzvKt+NPaO0A8/fqB8/O3ZJ3TegsudfMmZK/vXZBtc6f/v50yT0W+efupm3+wnk4UCtuF3//ARoa12dbyreXnkyUfZPitHagsLeuT9YkZ3/Q2tWYeY7q98uk+Ws1V/J6VuaoW3zyAQAAALCCwQcAAAAAKxh8AAAAALAipDUf
WVfp2Gdn39er9fqFJ7pJnrdmlGSXxyU5fe5eyd3zNkj2VOvdUR+U7cmSnDpL85hZg6p1vjSjz/DnkeyoSyWr2kj29KeazLbmWw5Knrlfn/f/TIquHwXUxhPPjpc86c55ktv96jvJR0/00xOs/yoo7QJCiTsfAAAAAKxg8AEAAADACgYfAAAAAKwIac1H2rSNkkdPO7925zMbA+6npgNAKLV9Qp8t/6MnzpPc1Wyx2JrGqWxvtuT9Q3T/aFO7zyHgbO1fzpQ88arRkl9LfVdyxn2TJMdf30Ky50R+HbYOCA3ufAAAAACwgsEHAAAAACsYfAAAAACwIqQ1HwAAAA2V58hRyWfGJUju+YdfSN4+8lnJY9J/qidk3Q80ANz5AAAAAGAFgw8AAAAAVjD4AAAAAGAFNR8AAAAWOGtAuk/WPMYMcryCGg80PNz5AAAAAGAFgw8AAAAAVtR42pXP5zPGGFNmSo3x1Vl7YEGZKTXG+H+GdYU+EZ6C1R/OPid9IrxwjYATfQJO9Ak4VbVP1HjwUVBQYIwx5lPzfk1PgRArKCgwLVq0qNPzGUOfCFd13R/+dU5j6BPhimsEnOgTcKJPwKmyPuHy1XDI6vV6TW5uromLizMul6vGDYR9Pp/PFBQUmOTkZBMRUXcz7+gT4SlY/cEY+kS44hoBJ/oEnOgTcKpqn6jx4AMAAAAAqoOCcwAAAABWMPgAAAAAYAWDDwAAAABWMPgAAAAAYAWDDwAAAABWMPgAAAAAYAWDDwAAAABWMPgAAAAAYAWDDwAAAABWMPgAAAAAYAWDDwAAAABWMPgAAAAAYMX/ASkuSHkaemLtAAAAAElFTkSuQmCC", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "fig, axes = plt.subplots(1, 6, figsize=(10, 3))\n", "for i, ax in enumerate(axes):\n", " ax.imshow(train_set.data[i, :, :])\n", " ax.set_xticks([])\n", " ax.set_yticks([])\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "66125f37", "metadata": {}, "source": [ "In the interest of reducing the runtime of this example, we will use vectorized images as our features in\n", "this dataset.\n", "We will collect 60,000 training samples and 10,000 test samples." ] }, { "cell_type": "code", "execution_count": 7, "id": "dabad0cf", "metadata": {}, "outputs": [], "source": [ "class_count = len(train_set.classes)\n", "train_count, x_pixel_count, y_pixel_count = train_set.data.shape\n", "test_count, _, _ = test_set.data.shape\n", "feature_count = x_pixel_count * y_pixel_count" ] }, { "cell_type": "markdown", "id": "aadcde3c", "metadata": {}, "source": [ "We vectorize the images and one-hot encode the class labels.\n" ] }, { "cell_type": "code", "execution_count": 8, "id": "f71dd370", "metadata": {}, "outputs": [], "source": [ "train_features = torch.zeros((train_count, feature_count))\n", "train_responses = torch.zeros((train_count, class_count))\n", "\n", "for i in range(train_count):\n", " train_features[i,:] = train_set[i][0].flatten()\n", " train_responses[i,:] = one_hot(\n", " torch.tensor(train_set[i][1]).to(torch.int64),\n", " num_classes=class_count,\n", " )\n", "\n", "test_features = torch.zeros((test_count, feature_count))\n", "test_responses = torch.zeros((test_count, class_count))\n", "\n", "for i in range(test_count):\n", " test_features[i,:] = test_set[i][0].flatten()\n", " test_responses[i,:] = one_hot(\n", " torch.tensor(test_set[i][1]).to(torch.int64),\n", " num_classes=class_count,\n", " )" ] }, { "cell_type": "markdown", "id": "319abcfa", "metadata": {}, "source": [ "We set up our nearest neighbor lookup structure using the NN_Wrapper data structure in MuyGPs.\n", "We then define our 
batch and construct tensors containing the features and targets of the batched elements and their 30 nearest neighbors.\n", "We choose an algorithm that will return the exact nearest neighbors.\n", "We set a random seed for reproducibility." ] }, { "cell_type": "code", "execution_count": 9, "id": "9bd9469e", "metadata": {}, "outputs": [], "source": [ "torch.autograd.set_detect_anomaly(True)\n", "np.random.seed(0)\n", "test_count, _ = test_features.shape\n", "train_count, _ = train_features.shape\n", "\n", "nn_count = 30\n", "nbrs_lookup = NN_Wrapper(train_features, nn_count, nn_method=\"exact\")" ] }, { "cell_type": "markdown", "id": "42073b2d", "metadata": {}, "source": [ "We sample a training batch of 500 elements and record their indices and those of their nearest neighbors." ] }, { "cell_type": "code", "execution_count": 10, "id": "fa31e88f", "metadata": {}, "outputs": [], "source": [ "batch_count = 500\n", "batch_indices, batch_nn_indices = sample_batch(\n", "    nbrs_lookup, batch_count, train_count\n", ")\n", "\n", "batch_features = train_features[batch_indices, :]\n", "batch_targets = train_responses[batch_indices, :]\n", "batch_nn_targets = train_responses[batch_nn_indices, :]\n", "\n", "if torch.cuda.is_available():\n", "    train_features = train_features.cuda()\n", "    train_responses = train_responses.cuda()\n", "    test_features = test_features.cuda()\n", "    test_responses = test_responses.cuda()" ] }, { "cell_type": "markdown", "id": "67ef039a", "metadata": {}, "source": [ "We now define a stochastic variational deep kernel MuyGPs class.\n", "This class composes a dense neural network embedding with a `MuyGPyS.torch.muygps_layer` Gaussian process layer.\n", "Presently, this layer only supports the Matérn kernel with the smoothness parameter set to one of the special values 0.5, 1.5, 2.5, or $\\infty$.\n", "The smoothness values are limited because `torch` does not implement modified Bessel functions of the second kind.\n", "Future versions of the library will also 
support other kernel types." ] }, { "cell_type": "code", "execution_count": 11, "id": "afa79b85", "metadata": {}, "outputs": [], "source": [ "class SVDKMuyGPs(nn.Module):\n", "    def __init__(\n", "        self,\n", "        muygps_model,\n", "        batch_indices,\n", "        batch_nn_indices,\n", "        batch_targets,\n", "        batch_nn_targets,\n", "    ):\n", "        super().__init__()\n", "        # Dense embedding from the 784 pixel features down to 100 dimensions.\n", "        self.embedding = nn.Sequential(\n", "            nn.Linear(784, 400),\n", "            nn.ReLU(),\n", "            nn.Linear(400, 200),\n", "            nn.ReLU(),\n", "            nn.Linear(200, 100),\n", "        )\n", "        self.batch_indices = batch_indices\n", "        self.batch_nn_indices = batch_nn_indices\n", "        self.batch_targets = batch_targets\n", "        self.batch_nn_targets = batch_nn_targets\n", "        # Gaussian process layer applied to the embedded features.\n", "        self.GP_layer = MuyGPs_layer(\n", "            muygps_model,\n", "            batch_indices,\n", "            batch_nn_indices,\n", "            batch_targets,\n", "            batch_nn_targets,\n", "        )\n", "        self.deformation = self.GP_layer.deformation\n", "\n", "    def forward(self, x):\n", "        predictions = self.embedding(x)\n", "        predictions, variances = self.GP_layer(predictions)\n", "        return predictions, variances" ] }, { "cell_type": "markdown", "id": "14df56f9", "metadata": {}, "source": [ "## Training a Deep Kernel MuyGPs Model" ] }, { "cell_type": "markdown", "id": "25af10b7", "metadata": {}, "source": [ "We will use a Matérn kernel with a smoothness parameter of 0.5 and a Gaussian homoscedastic noise prior with variance `1e-6`." ] }, { "cell_type": "markdown", "id": "62bc1010", "metadata": {}, "source": [ "⚠️ Presently the torch backend only supports fixed special case Matérn smoothness parameters with values 0.5, 1.5, or 2.5. ⚠️" ] }, { "cell_type": "markdown", "id": "ae5e29b8", "metadata": {}, "source": [ "⚠️ An isotropic length scale is the only torch-optimizable parameter. 
⚠️" ] }, { "cell_type": "code", "execution_count": 12, "id": "d50b9f56", "metadata": {}, "outputs": [], "source": [ "muygps_model = MuyGPS(\n", "    kernel=Matern(\n", "        smoothness=Parameter(0.5),\n", "        deformation=Isotropy(\n", "            l2,\n", "            length_scale=Parameter(1.0, (0.1, 2))\n", "        ),\n", "    ),\n", "    noise=HomoscedasticNoise(1e-6),\n", ")" ] }, { "cell_type": "markdown", "id": "9cdc2d09", "metadata": {}, "source": [ "We instantiate an `SVDKMuyGPs` model using this `MuyGPS` model." ] }, { "cell_type": "code", "execution_count": 13, "id": "67b3a311", "metadata": {}, "outputs": [], "source": [ "model = SVDKMuyGPs(\n", "    muygps_model=muygps_model,\n", "    batch_indices=batch_indices,\n", "    batch_nn_indices=batch_nn_indices,\n", "    batch_targets=batch_targets,\n", "    batch_nn_targets=batch_nn_targets,\n", ")\n", "if torch.cuda.is_available():\n", "    model = model.cuda()" ] }, { "cell_type": "markdown", "id": "c92a3525", "metadata": {}, "source": [ "We use the Adam optimizer over 10 training iterations, with an initial learning rate of `1e-2` that decays exponentially by a factor of `0.97` per iteration." ] }, { "cell_type": "code", "execution_count": 14, "id": "96b50e7d", "metadata": {}, "outputs": [], "source": [ "training_iterations = 10\n", "optimizer = torch.optim.Adam(\n", "    [{'params': model.parameters()}], lr=1e-2\n", ")\n", "scheduler = ExponentialLR(optimizer, gamma=0.97)" ] }, { "cell_type": "markdown", "id": "e8a69655", "metadata": {}, "source": [ "We will use cross-entropy loss, as it typically performs well on classification problems.\n", "Other losses are available." ] }, { "cell_type": "code", "execution_count": 15, "id": "63e1b669", "metadata": {}, "outputs": [], "source": [ "ce_loss = nn.CrossEntropyLoss()" ] }, { "cell_type": "markdown", "id": "14c025ff", "metadata": {}, "source": [ "We construct a standard PyTorch training loop function." 
] }, { "cell_type": "code", "execution_count": 16, "id": "62ce4411", "metadata": {}, "outputs": [], "source": [ "def train(nbrs_lookup):\n", "    for i in range(training_iterations):\n", "        model.train()\n", "        optimizer.zero_grad()\n", "        predictions, variances = model(train_features)\n", "        loss = ce_loss(predictions, batch_targets)\n", "        loss.backward()\n", "        optimizer.step()\n", "        scheduler.step()\n", "        # The modulus sets the refresh interval; with 1, the neighbors are refreshed every iteration.\n", "        if np.mod(i, 1) == 0:\n", "            print(f\"Iter {i + 1}/{training_iterations} - Loss: {loss.item()}\")\n", "            model.eval()\n", "            # Rebuild the nearest neighbor index in the current embedding space.\n", "            nbrs_lookup = NN_Wrapper(\n", "                model.embedding(train_features).detach().numpy(),\n", "                nn_count, nn_method=\"exact\"\n", "            )\n", "            batch_nn_indices, _ = nbrs_lookup._get_nns(\n", "                model.embedding(batch_features).detach().numpy(),\n", "                nn_count=nn_count,\n", "            )\n", "            batch_nn_targets = train_responses[batch_nn_indices, :]\n", "            model.batch_nn_indices = batch_nn_indices\n", "            model.batch_nn_targets = batch_nn_targets\n", "        torch.cuda.empty_cache()\n", "    # Rebuild the lookup once more using the final embedding.\n", "    nbrs_lookup = NN_Wrapper(\n", "        model.embedding(train_features).detach().numpy(),\n", "        nn_count,\n", "        nn_method=\"exact\",\n", "    )\n", "    batch_nn_indices, _ = nbrs_lookup._get_nns(\n", "        model.embedding(batch_features).detach().numpy(),\n", "        nn_count=nn_count,\n", "    )\n", "    batch_nn_targets = train_responses[batch_nn_indices, :]\n", "    model.batch_nn_indices = batch_nn_indices\n", "    model.batch_nn_targets = batch_nn_targets\n", "    return nbrs_lookup, model" ] }, { "cell_type": "markdown", "id": "b11548ab", "metadata": {}, "source": [ "Finally, we execute the training function and evaluate the trained model." ] }, { "cell_type": "code", "execution_count": 17, "id": "eb287061", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Iter 1/10 - Loss: 1.514917016029358\n", "Iter 2/10 - Loss: 1.4779890775680542\n", "Iter 3/10 - Loss: 1.4398393630981445\n", "Iter 4/10 - Loss: 1.423111081123352\n", "Iter 5/10 - Loss: 1.4219379425048828\n", "Iter 6/10 - Loss: 
1.4020675420761108\n", "Iter 7/10 - Loss: 1.3868385553359985\n", "Iter 8/10 - Loss: 1.3743661642074585\n", "Iter 9/10 - Loss: 1.3675148487091064\n", "Iter 10/10 - Loss: 1.3568177223205566\n" ] } ], "source": [ "nbrs_lookup, model_trained = train(nbrs_lookup)" ] }, { "cell_type": "markdown", "id": "357019a6", "metadata": {}, "source": [ "Our final model parameters look like the following:" ] }, { "cell_type": "code", "execution_count": 18, "id": "4316f1a4", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "embedding.0.weight, torch.Size([400, 784])\n", "embedding.0.bias, torch.Size([400])\n", "embedding.2.weight, torch.Size([200, 400])\n", "embedding.2.bias, torch.Size([200])\n", "embedding.4.weight, torch.Size([100, 200])\n", "embedding.4.bias, torch.Size([100])\n", "GP_layer.length_scale, 1.0085885524749756\n" ] } ], "source": [ "for n, p in model_trained.named_parameters():\n", "    print(f\"{n}, {p.shape if p.shape != torch.Size([]) else p.item()}\")" ] }, { "cell_type": "markdown", "id": "a4f54d60", "metadata": {}, "source": [ "We then compute and report the performance of the predicted test responses using this trained model." 
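, "\n", "\n", "The performance measure computed below is classification accuracy. Writing $\\hat{Y}$ for the predicted responses and $Y$ for the one-hot test responses (notation introduced here only for illustration), with $N$ test samples,\n", "\n", "$$\\mathrm{accuracy} = \\frac{1}{N} \\sum_{i=1}^{N} \\mathbb{1}\\left[\\arg\\max_j \\hat{Y}_{ij} = \\arg\\max_j Y_{ij}\\right].$$"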
] }, { "cell_type": "code", "execution_count": 19, "id": "101aa630", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "MNIST Prediction Accuracy Using Hybrid Torch Model:\n", "0.9361\n" ] } ], "source": [ "predictions, variances = predict_model(\n", "    model=model_trained,\n", "    test_features=test_features,\n", "    train_features=train_features,\n", "    train_responses=train_responses,\n", "    nbrs_lookup=nbrs_lookup,\n", "    nn_count=nn_count,\n", ")\n", "print(\"MNIST Prediction Accuracy Using Hybrid Torch Model:\")\n", "print(\n", "    (\n", "        torch.sum(\n", "            torch.argmax(predictions, dim=1) == torch.argmax(test_responses, dim=1)\n", "        ) / test_count\n", "    ).numpy()\n", ")" ] }, { "cell_type": "markdown", "id": "ae2a5434", "metadata": {}, "source": [ "This accuracy is mediocre for MNIST. To keep the notebook runtime short, we have used a simple fully-connected neural network to construct the Gaussian process kernel. To achieve results closer to the state of the art (near 100% accuracy), we recommend using more complex architectures that integrate convolutional layers into the model." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.9" }, "vscode": { "interpreter": { "hash": "434749ae7207e94f9d6928c9f347c2cd1a679cf18b55a36c093c3f406aed8e17" } } }, "nbformat": 4, "nbformat_minor": 5 }