Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Random Walk kernel #30

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
Empty file added GP/__init__.py
Empty file.
Empty file added GP/kernel_modules/__init__.py
Empty file.
43 changes: 43 additions & 0 deletions GP/kernel_modules/kernel_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Author: Henry Moss & Ryan-Rhys Griffiths
"""
Utility methods for graph-based kernels
"""

import tensorflow as tf
import numpy as np


def normalize(k_matrix):
k_matrix_diagonal = tf.linalg.diag_part(k_matrix)
squared_normalization_factor = tf.multiply(tf.expand_dims(k_matrix_diagonal, 1),
tf.expand_dims(k_matrix_diagonal, 0))

return tf.divide(k_matrix, tf.sqrt(squared_normalization_factor))


def pad_tensor(tensor, target_dim):
return tf.pad(tensor, [[0, target_dim - tensor.shape[0]], [0, target_dim - tensor.shape[0]]], 'CONSTANT')


def pad_tensors(tensor_list):
max_dim = max(tensor_list, key=lambda x: x.shape[0]).shape[0]
return [pad_tensor(tensor, max_dim) for tensor in tensor_list]


def unpad_tensor(tensor):
mask = tf.reduce_sum(tensor, 0) != 0
rows_unpadded = tf.boolean_mask(tensor, mask, axis=0)
fully_unpadded = tf.boolean_mask(rows_unpadded, mask, axis=1)
return fully_unpadded


def preprocess_adjacency_matrix_inputs(adj_mat_list):
padded_adj_mats = pad_tensors(adj_mat_list)
flattened_padded_adj_mats = tf.reshape(padded_adj_mats, (len(padded_adj_mats), padded_adj_mats[0].shape[0]**2))
return flattened_padded_adj_mats


def extract_adj_mats_from_vector_inputs(preprocessed_data):
adj_mat_dim = int(np.sqrt(preprocessed_data.shape[1]))
rehydrated_adj_mats = tf.reshape(preprocessed_data, (len(preprocessed_data), adj_mat_dim, adj_mat_dim))
return rehydrated_adj_mats
124 changes: 124 additions & 0 deletions GP/kernel_modules/random_walk.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
# Author: Henry Moss & Ryan-Rhys Griffiths
"""
Molecule kernels for Gaussian Process Regression implemented in GPflow.
"""

from math import factorial

import gpflow
import tensorflow as tf

from .kernel_utils import normalize, unpad_tensor, extract_adj_mats_from_vector_inputs


class RandomWalk(gpflow.kernels.Kernel):
def __init__(self, normalize=True, weight=0.1, series_type='geometric', p=None, uniform_probabilities=False):
super().__init__()
self.normalize = normalize
self.weight = weight
if series_type == 'geometric':
self.geometric = True
elif series_type == 'exponential':
self.geometric = False
self.p = p
self.uniform_probabilities = uniform_probabilities

def K(self, X, X2=None):
"""
Compute the random walk graph kernel (Gartner et al. 2003),
specifically using the spectral decomposition approach
given by https://www.jmlr.org/papers/volume11/vishwanathan10a/vishwanathan10a.pdf

:param X: N x D array.
:param X2: M x D array. If None, compute the N x N kernel matrix for X.
:return: The kernel matrix of dimension N x M
"""

X = extract_adj_mats_from_vector_inputs(X)

if X2 is None:
X2 = X
X_is_X2 = True
else:
X2 = extract_adj_mats_from_vector_inputs(X2)
X_is_X2 = False
self.normalize = False

flattened_k_matrix = tf.TensorArray(tf.float64, size=len(X)*len(X2))
matrix_idx = 0

for idx_1 in range(len(X)):

adj_mat_1 = tf.cast(unpad_tensor(X[idx_1]), tf.float64)
eigenval_1, eigenvec_1, flanking_factor_1 = self._eigendecompose_and_calculate_flanking_factor(adj_mat_1)

for idx_2 in range(len(X2)):

if X_is_X2 and idx_1 == idx_2:
eigenval_2, eigenval_2, flanking_factor_2 = eigenval_1, eigenval_1, flanking_factor_1
else:
adj_mat_2 = tf.cast(unpad_tensor(X2[idx_2]), tf.float64)
eigenval_2, eigenvec_2, flanking_factor_2 = self._eigendecompose_and_calculate_flanking_factor(
adj_mat_2)

flanking_factor = tf.linalg.LinearOperatorKronecker(
[tf.linalg.LinearOperatorFullMatrix(flanking_factor_1),
tf.linalg.LinearOperatorFullMatrix(flanking_factor_2)
]).to_dense()

diagonal = self.weight * tf.linalg.LinearOperatorKronecker(
[tf.linalg.LinearOperatorFullMatrix(tf.expand_dims(eigenval_1, axis=0)),
tf.linalg.LinearOperatorFullMatrix(tf.expand_dims(eigenval_2, axis=0))
]).to_dense()

if self.p is not None:
power_series = tf.ones_like(diagonal)
temp_diagonal = tf.ones_like(diagonal)

for k in range(self.p):
temp_diagonal = tf.multiply(temp_diagonal, diagonal)
if not self.geometric:
temp_diagonal = tf.divide(temp_diagonal, factorial(k+1))
power_series = tf.add(power_series, temp_diagonal)

power_series = tf.linalg.diag(power_series)
else:
if self.geometric:
power_series = tf.linalg.diag(1 / (1 - diagonal))
else:
power_series = tf.linalg.diag(tf.exp(diagonal))

matrix_entry = tf.linalg.matmul(
flanking_factor,
tf.linalg.matmul(
power_series,
tf.transpose(flanking_factor, perm=[1, 0])
)
)

flattened_k_matrix = flattened_k_matrix.write(matrix_idx, matrix_entry)
matrix_idx += 1

k_matrix = tf.reshape(flattened_k_matrix.stack(), (len(X), len(X2)))

if self.normalize:
return normalize(k_matrix)

return k_matrix

def K_diag(self, X):
"""
Compute the diagonal of the N x N kernel matrix of X
:param X: N x D array
:return: N x 1 array
"""
return tf.linalg.tensor_diag_part(self.K(X))

def _eigendecompose_and_calculate_flanking_factor(self, adjacency_matrix):
eigenval, eigenvec = tf.linalg.eigh(adjacency_matrix)
start_stop_probs = tf.ones((1, tf.shape(eigenvec)[0]), tf.float64)
if self.uniform_probabilities:
start_stop_probs = tf.divide(start_stop_probs, tf.shape(eigenvec)[0])
flanking_factor = tf.linalg.matmul(start_stop_probs, eigenvec)

return eigenval, eigenvec, flanking_factor
152 changes: 152 additions & 0 deletions examples/Gaussian_Process_RandomWalk_Kernel.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"import os\n",
"import pandas as pd\n",
"sys.path.append('..') # to import from GP.kernels and property_predition.data_utils"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import gpflow\n",
"from gpflow.mean_functions import Constant\n",
"from gpflow.utilities import print_summary\n",
"import numpy as np\n",
"from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error\n",
"from sklearn.model_selection import train_test_split\n",
"import tensorflow as tf\n",
"from rdkit.Chem import MolFromSmiles\n",
"from rdkit.Chem.rdmolops import GetAdjacencyMatrix\n",
"\n",
"from property_prediction.data_utils import transform_data\n",
"from GP.kernel_modules.random_walk import RandomWalk\n",
"from GP.kernel_modules.kernel_utils import pad_tensors\n",
"from GP.kernel_modules.kernel_utils import preprocess_adjacency_matrix_inputs"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"df = pd.read_csv(\"../datasets/ESOL.csv\")[:50]\n",
"smiles = df[\"smiles\"].to_numpy()\n",
"y = df['measured log solubility in mols per litre'].to_numpy()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"X = [tf.convert_to_tensor(GetAdjacencyMatrix(MolFromSmiles(smiles))) for smiles in smiles]\n",
"X = preprocess_adjacency_matrix_inputs(X)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We define the Gaussian Process Regression training objective"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"def objective_closure():\n",
" return -m.log_marginal_likelihood()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"X_train, X_test, y_train, y_test = train_test_split(X.numpy(), y, test_size=0.2, random_state=0)\n",
"y_train = y_train.reshape(-1, 1)\n",
"y_test = y_test.reshape(-1, 1)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
" We standardise the outputs but leave the inputs unchanged"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"_, y_train, _, y_test, y_scaler = transform_data(X_train, y_train, X_test, y_test)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"k = RandomWalk()\n",
"m = gpflow.models.GPR(data=(X_train, y_train), mean_function=Constant(np.mean(y_train)), kernel=k, noise_variance=1)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Optimise the kernel variance and noise level by the marginal likelihood"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"opt = gpflow.optimizers.Scipy()\n",
"opt.minimize(objective_closure, m.trainable_variables, options=dict(maxiter=100))\n",
"print_summary(m) # Model summary"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Empty file added tests/__init__.py
Empty file.
Empty file added tests/kernels/__init__.py
Empty file.
61 changes: 61 additions & 0 deletions tests/kernels/test_random_walk.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# Author: Henry Moss & Ryan-Rhys Griffiths
"""
Verifies the FlowMO implementation of the Random Walk graph kernel
against GraKel
"""

import os

import grakel
import numpy.testing as npt
import pandas as pd
import pytest
import tensorflow as tf
from rdkit.Chem import MolFromSmiles
from rdkit.Chem.rdmolops import GetAdjacencyMatrix

from GP.kernel_modules.random_walk import RandomWalk
from GP.kernel_modules.kernel_utils import preprocess_adjacency_matrix_inputs

@pytest.fixture
def load_data():
benchmark_path = os.path.abspath(
os.path.join(
os.getcwd(), '..', '..', 'datasets', 'ESOL.csv'
)
)
df = pd.read_csv(benchmark_path)
smiles = df["smiles"].to_list()

adj_mats = [GetAdjacencyMatrix(MolFromSmiles(smiles)) for smiles in smiles[:50]]
tensor_adj_mats = [tf.convert_to_tensor(adj_mat) for adj_mat in adj_mats]
preprocessed_tensor_adj_mats = preprocess_adjacency_matrix_inputs(tensor_adj_mats)
grakel_graphs = [grakel.Graph(adj_mat) for adj_mat in adj_mats]

return preprocessed_tensor_adj_mats, grakel_graphs


@pytest.mark.parametrize(
'weight, series_type, p',
[
(0.1, 'geometric', None),
(0.1, 'exponential', None),
#(0.3, 'geometric', None), #Requires `method_type="baseline" in GraKel kernel constructor
(0.3, 'exponential', None),
(0.3, 'geometric', 3), #Doesn't pass due to suspected GraKel bug, see https://github.com/ysig/GraKeL/issues/71
(0.8, 'exponential', 3), #Same issue as above test
]
)
def test_random_walk_unlabelled(weight, series_type, p, load_data):
preprocessed_tensor_adj_mats, grakel_graphs = load_data

random_walk_grakel = grakel.kernels.RandomWalk(normalize=True, lamda=weight, kernel_type=series_type, p=p)
grakel_results = random_walk_grakel.fit_transform(grakel_graphs)

random_walk_FlowMo = RandomWalk(normalize=True, weight=weight, series_type=series_type, p=p)
FlowMo_results = random_walk_FlowMo.K(preprocessed_tensor_adj_mats, preprocessed_tensor_adj_mats)

npt.assert_almost_equal(
grakel_results, FlowMo_results.numpy(),
decimal=2
)