Skip to content

Commit

Permalink
Merge pull request #154 from finncatling/lap-risk-153
Browse files Browse the repository at this point in the history
Simpler example risk distributions
  • Loading branch information
finncatling authored May 13, 2021
2 parents 774419d + af6e9e7 commit 8c98296
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 7 deletions.
4 changes: 2 additions & 2 deletions 09_plot_novel_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,15 +85,15 @@
patient_indices=(9942, 3094),
kde_bandwidths=(0.008, 0.04),
output_dir=FIGURES_OUTPUT_DIR,
output_filename="09_novel_model_2_example_risk_distributions",
output_filename="09_novel_model_2_example_risk_distributions_simpler",
)
plot_saver(
plot_example_risk_distributions,
y_pred_samples=y_pred_samples,
patient_indices=(9942, 6530, 3094),
kde_bandwidths=(0.008, 0.012, 0.04),
output_dir=FIGURES_OUTPUT_DIR,
output_filename="09_novel_model_3_example_risk_distributions",
output_filename="09_novel_model_3_example_risk_distributions_simpler",
)


Expand Down
22 changes: 17 additions & 5 deletions utils/plot/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
from matplotlib import pyplot as plt
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from arviz.stats.density_utils import kde
import matplotlib.ticker as mtick
from arviz.stats.density_utils import _kde_linear

from utils.evaluate import stratify_y_pred

Expand Down Expand Up @@ -84,9 +85,17 @@ def plot_example_risk_distributions(
"""Plot predicted risk distributions (and corresponding point estimates)
for example patients. y_pred_samples is (n_sampled_risks, n_patients).
kde_bandwidths should be same length as patient_indices."""
fig, ax = plt.subplots()
fig, ax = plt.subplots(figsize=(5, 3))

for i, j in enumerate(patient_indices):
grid, pdf = kde(y_pred_samples[:, j], bw=kde_bandwidths[i])
grid, pdf = _kde_linear(
y_pred_samples[:, j],
bw=kde_bandwidths[i],
bound_correction=False,
extend=True,
extend_fct=2.0,
adaptive=True
)
ax.fill_between(grid, pdf, alpha=0.4, label=f'Patient {i + 1}')
if i:
ax.axvline(np.median(y_pred_samples[:, j]), c='black', ls=':')
Expand All @@ -95,15 +104,18 @@ def plot_example_risk_distributions(
np.median(y_pred_samples[:, j]),
c='black',
ls=':',
label='Point prediction'
label='Point estimate'
)

ax.set_ylim(bottom=0)
ax.set(
xlim=(0, 1),
xlabel='Predicted risk of death',
ylabel='Probability density'
ylabel='Probability density',
yticks=([]) # turn off major y ticks
)
ax.set_yticks([], minor=True) # turn off minor y ticks
ax.xaxis.set_major_formatter(mtick.PercentFormatter(xmax=1.0, decimals=0))
ax.legend()

return fig, ax

0 comments on commit 8c98296

Please sign in to comment.