Remove default independent sampler jitter but ensure positive variance #888
base: develop
Changes from all commits
8ef3e46
e284c87
7fd07e6
ff62b0c
7ef0492
97900aa
51adc46
bfccfd8
0555a84
@@ -32,6 +32,7 @@
 from ...space import EncoderFunction
 from ...types import TensorType
 from ...utils import DEFAULTS, flatten_leading_dims
+from ...utils.misc import ensure_positive
 from ..interfaces import (
     ProbabilisticModel,
     ReparametrizationSampler,
@@ -114,7 +115,7 @@ def __init__(
     "at: [N..., 1, D]  # IndependentReparametrizationSampler only supports batch sizes of one",
     "return: [N..., S, 1, L]",
 )
-def sample(self, at: TensorType, *, jitter: float = DEFAULTS.JITTER) -> TensorType:
+def sample(self, at: TensorType, *, jitter: float = 0.0) -> TensorType:
     """
     Return approximate samples from the `model` specified at :meth:`__init__`. Multiple calls to
     :meth:`sample`, for any given :class:`IndependentReparametrizationSampler` and ``at``, will
@@ -133,7 +134,7 @@ def sample(self, at: TensorType, *, jitter: float = DEFAULTS.JITTER) -> TensorType:
     tf.debugging.assert_greater_equal(jitter, 0.0)

     mean, var = self._model.predict(at[..., None, :, :])  # [..., 1, 1, L], [..., 1, 1, L]
-    var = var + jitter
+    var = ensure_positive(var + jitter)
Review discussion on this line:

- (Note that we could alternatively ignore the jitter argument here, even if it's explicitly provided, if we think that would be better.)
- This version might be a bit difficult to read and debug, as we are potentially applying a correction twice (we apply the jitter with the sum, then ensure_positive potentially applies a further correction). But I'm not sure a better alternative exists.
- One solution to both this comment and the one at the end would be to change the default value to -1, and comment that this magic value adds no jitter but ensures that the variance is positive. Then, if the user specifies an explicit non-negative jitter, we can use that unmodified. (Engineering-wise it would be nicer to make jitter an …)
- I would explicitly ignore the jitter here and add to the docstrings that it is ignored - perhaps let's also do it properly and change it to be optional (a sketch of this alternative follows the diff below).
- Here I think there should be no reason for the user to want a different jitter, right @vpicheny?
- Can you also please check the GPflux and Keras samplers?
- I think the main case for the jitter here is when the sampling is used with an acquisition function, possibly using sqrt(var), log(var) or cdf(mean, var), which would fail if the variance is numerically zero but negative. Otherwise we would probably just want to avoid any offset that would get in the way: say the output is not rescaled and has very, very small values, then adding 1e-6 would change everything. We could leave this logic to the acquisition function, or just ensure here that we are "just positive".

     def sample_eps() -> tf.Tensor:
         self._initialized.assign(True)
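A minimal sketch of the optional-jitter alternative discussed in the thread above (not part of this diff; the helper name `_positive_variance` is made up for illustration). When `jitter` is `None`, the variance is only clamped to a small dtype-dependent floor; an explicitly supplied non-negative jitter is used unmodified.

```python
from typing import Optional

import tensorflow as tf


def _positive_variance(var: tf.Tensor, jitter: Optional[float] = None) -> tf.Tensor:
    """Illustrative only: clamp ``var`` to be strictly positive when no jitter is
    given; otherwise honour the caller's explicit, non-negative jitter unmodified."""
    if jitter is None:
        # Mirror ensure_positive: a small dtype-dependent floor, no offset added.
        return tf.math.maximum(var, 1e-6 if var.dtype == tf.float32 else 1e-16)
    tf.debugging.assert_greater_equal(jitter, 0.0)
    return var + jitter
```

Inside `sample`, the changed line would then read `var = _positive_variance(var, jitter)`, and the docstring could state that no jitter is added by default.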
@@ -462,3 +462,9 @@ def _flatten_module(  # type: ignore[no-untyped-def]
     for subvalue in subvalues:
         # Predicate is already tested for these values.
         yield subvalue
+
+
+def ensure_positive(x: TensorType) -> TensorType:
+    """Ensure that all the elements in `x` are strictly positive (using a dtype-dependent
+    capping threshold)."""
+    return tf.math.maximum(x, 1e-6 if x.dtype == tf.float32 else 1e-16)
Review discussion on this change:

- Naive question: is 1e-6 the lowest we can have with single precision?
- Not at all. This was just based on scaling up the suggested value of 1e-16 for float64. Both numbers can go significantly smaller if we want: float32 can go down to around 1e-38 and float64 to 2e-308. Do you have any intuition for how small we should make these?
- I think we may be fine with the smallest number for each precision that keeps it positive, though it may depend on the usage downstream. At the moment we are just taking sqrt and doing some multiplication, which will take it to equal 0, but in this use case that should be fine, I think? The eps contribution would be removed in these cases, but I'm not sure whether that's relevant.
- I agree, I would probably vote for a very small value in both cases; 1e-6 is way too high. And maybe we do not need to differentiate between single and double precision? Both could be e.g. 1e-32 or something.
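For reference on the values quoted above, the smallest positive normal numbers for each precision can be read off with NumPy's `finfo`. The variant below is only a sketch of how `ensure_positive` could derive its floor from the input's dtype instead of hard-coding thresholds (the name `ensure_positive_tiny` is hypothetical):

```python
import numpy as np
import tensorflow as tf

# Smallest positive normal values for the two precisions discussed above.
print(np.finfo(np.float32).tiny)  # ~1.18e-38
print(np.finfo(np.float64).tiny)  # ~2.23e-308


def ensure_positive_tiny(x: tf.Tensor) -> tf.Tensor:
    """Sketch: clamp to the smallest positive normal value of ``x``'s own dtype."""
    floor = np.finfo(x.dtype.as_numpy_dtype).tiny
    return tf.math.maximum(x, tf.cast(floor, x.dtype))
```

Whether such a small floor is sufficient depends on the downstream use (e.g. sqrt, log or a cdf of the variance), as discussed in the thread above.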
Review discussion on the test changes:

- I am not sure what this test is doing... does setting the kernel amplitude to 0 make the model variance equal to zero? Should we then check that the model prediction variance is zero, but the sampler applies the right fix?
- Yes, that's right. I've now added an assert that the model variance is zero.
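A rough sketch of the kind of check described here, using the `ensure_positive` helper added in this diff (the test name, tensor shape and import path are assumptions; a zero tensor stands in for the prediction of a model whose kernel amplitude is zero):

```python
import tensorflow as tf

from trieste.utils.misc import ensure_positive  # helper added in this diff (path assumed)


def test_sampler_fixes_zero_model_variance() -> None:
    # Stand-in for a predictive variance that is exactly zero everywhere,
    # as produced by a kernel with zero amplitude.
    var = tf.zeros([4, 1, 1, 2], dtype=tf.float64)
    tf.debugging.assert_equal(var, tf.zeros_like(var))  # the model variance really is zero

    clamped = ensure_positive(var)
    tf.debugging.assert_positive(clamped)  # the fix keeps the variance strictly positive
    # Downstream operations such as sqrt now stay finite.
    tf.debugging.assert_all_finite(tf.sqrt(clamped), "stddev must be finite")
```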