Global inducing points implementation #50
base: develop
Conversation
This is a good draft PR. I've left some comments and questions. It would be good to have a notebook showing how to use the code. A few tests (as discussed in person) will also be required. I also think the formatting is off in some places; make sure to run `make format` before pushing.
class GIGPLayer(tf.keras.layers.Layer):
    """
    A sparse variational multioutput GP layer. This layer holds the kernel,
    inducing variables and variational distribution, and mean function.
Suggested change:
-    inducing variables and variational distribution, and mean function.
+    inducing variables, variational distribution, and mean function.
    Calculates the log probability of a zero-mean multivariate Gaussian with covariance sigma
    and evaluation points X, with batching of both the covariance and X.

    TODO: look into whether this can be replaced with a tfp.distributions.Distribution
Yes, this should be possible, and I would advise doing that.
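For reference, the batched zero-mean Gaussian log density the quoted docstring describes can be written directly from the closed form, with broadcasting over leading batch dimensions of both the covariance and the evaluation points. The sketch below uses NumPy for illustration; in the PR itself this would presumably become something like `tfp.distributions.MultivariateNormalTriL(scale_tril=tf.linalg.cholesky(sigma)).log_prob(X)`, as the review suggests.

```python
import numpy as np

def batched_mvn_logpdf(X, sigma):
    """Log density of a zero-mean multivariate Gaussian.

    Batches over the leading dimensions of both the covariance
    `sigma` [..., D, D] and the evaluation points `X` [..., D],
    using standard NumPy broadcasting.
    """
    D = X.shape[-1]
    L = np.linalg.cholesky(sigma)                     # [..., D, D]
    # Solve L a = X so that a^T a = X^T sigma^{-1} X (the quadratic term).
    alpha = np.linalg.solve(L, X[..., None])[..., 0]  # [..., D]
    quad = np.sum(alpha ** 2, axis=-1)
    # log|sigma| = 2 * sum(log diag(L)) from the Cholesky factor.
    logdet = 2.0 * np.sum(
        np.log(np.diagonal(L, axis1=-2, axis2=-1)), axis=-1
    )
    return -0.5 * (D * np.log(2.0 * np.pi) + logdet + quad)
```

Using a distribution object instead of a hand-rolled density also gives sampling and KL computations for free, which is the main argument for the TODO.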
from gpflux.layers import LayerWithObservations, SampleBasedGaussianLikelihoodLayer


class GIDeepGP(Module):
There appears to be a lot of code duplication from `DeepGP`. Can you not inherit from `DeepGP` and override the appropriate methods?
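The suggested refactor is the standard subclass-and-override pattern: keep all the shared construction and plumbing in the parent and redefine only the behaviour that differs. A minimal, self-contained sketch (the `DeepGP` stand-in and method names here are placeholders, not the actual gpflux API):

```python
class DeepGP:
    """Stand-in for gpflux.models.DeepGP: owns the layers and
    defines a plain layer-by-layer forward pass."""

    def __init__(self, f_layers):
        self.f_layers = f_layers

    def call(self, inputs):
        out = inputs
        for layer in self.f_layers:
            out = layer(out)
        return out


class GIDeepGP(DeepGP):
    """Hypothetical refactor: inherits construction and utilities
    from DeepGP and overrides only the forward pass, e.g. to thread
    jointly sampled inducing outputs through every layer."""

    def call(self, inputs, num_samples=1):
        # Same propagation here for illustration; the real override
        # would differ only where the global-inducing scheme requires it.
        out = inputs
        for layer in self.f_layers:
            out = layer(out)
        return out
```

Anything that is byte-for-byte identical to `DeepGP` (ELBO assembly, Keras model wrapping, etc.) would then be deleted from `GIDeepGP` and inherited instead.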
""" | ||
mean_function = self.mean_function(inputs) | ||
|
||
Kuu = self.kernel(inputs[..., :self.num_inducing, :]) |
Can you add some comments explaining how `inputs` is structured? What exactly are the inputs and outputs of `call`?
""" | ||
Samples function values f based off samples of u. | ||
|
||
:param u: Samples of the inducing points, shape [S, Lout, M, 1] |
What is `Lout`?
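If `Lout` denotes the number of latent output GPs in the layer (my assumption; the docstring should state this), a quick shape check shows how a `u` of shape `[S, Lout, M, 1]` would propagate through the usual conditional `f = Kfu Kuu^{-1} u`, broadcasting over the sample and latent-GP axes:

```python
import numpy as np

# Assumed meanings: S = num samples, Lout = num latent output GPs,
# M = num inducing points, N = num evaluation points.
S, Lout, M, N = 4, 3, 8, 10

u = np.random.randn(S, Lout, M, 1)   # samples of inducing outputs
Kuu = np.eye(M)                      # [M, M], shared across S and Lout here
Kfu = np.random.randn(N, M)          # [N, M]

# f = Kfu @ Kuu^{-1} @ u, broadcast over the leading [S, Lout] axes.
f = Kfu @ np.linalg.solve(Kuu, u)
assert f.shape == (S, Lout, N, 1)
```

Documenting the axis meanings next to the shape annotation would answer this question in the code itself.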
Quick implementation of "global inducing" approximate posteriors for basic DGPs (see Ober and Aitchison (2021): https://arxiv.org/abs/2005.08140). TODOs: