Commit

save notebook cumreward plots to disk (#145)
* notebook cumreward plot to disk

---------

Co-authored-by: William Blum <[email protected]>
blumu and William Blum authored Aug 7, 2024
1 parent 09622b8 commit 4eabac5
Showing 11 changed files with 168 additions and 163 deletions.
4 changes: 3 additions & 1 deletion cyberbattle/agents/baseline/plotting.py
@@ -96,12 +96,14 @@ def plot_all_episodes(r):
plt.show()


def plot_averaged_cummulative_rewards(title, all_runs, show=True):
def plot_averaged_cummulative_rewards(title, all_runs, show=True, save_at=None):
"""Plot averaged cumulative rewards"""
new_plot(title)
for r in all_runs:
plot_episodes_rewards_averaged(r)
plt.legend(loc="lower right")
if save_at:
plt.savefig(save_at)
if show:
plt.show()

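The new save_at parameter defaults to None, so existing callers are unaffected; when a path is supplied, the figure is written with plt.savefig before the optional plt.show(). A minimal sketch of the calling pattern the notebook changes below adopt — plots_dir, the output filename, and the run variables are illustrative placeholders, not values fixed by this commit:

import os
import cyberbattle.agents.baseline.plotting as p

plots_dir = "plots"  # notebooks in this commit use "plots" or "output/plots"
os.makedirs(plots_dir, exist_ok=True)  # create the directory before any savefig call

# Placeholder: run results produced by earlier training cells in a notebook
all_runs = [credlookup_run, tabularq_run, dql_run]

p.plot_averaged_cummulative_rewards(
    title="Agent Benchmark top contenders\n",
    all_runs=all_runs,
    show=True,                                          # default; the plot is still displayed
    save_at=os.path.join(plots_dir, "cumrewards.png"),  # new: also persist the figure to disk
)

Saving before showing matters: with many backends plt.show() clears the current figure, so writing the file first, as the patch does, keeps the saved image intact.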
59 changes: 9 additions & 50 deletions notebooks/notebook_benchmark-chain.ipynb
@@ -72,6 +72,7 @@
"outputs": [],
"source": [
"import sys\n",
"import os\n",
"import logging\n",
"import gymnasium as gym\n",
"import cyberbattle.agents.baseline.learner as learner\n",
@@ -81,6 +82,7 @@
"import cyberbattle.agents.baseline.agent_tabularqlearning as tqa\n",
"import cyberbattle.agents.baseline.agent_dql as dqla\n",
"from cyberbattle.agents.baseline.agent_wrapper import Verbosity\n",
"from cyberbattle._env.cyberbattle_env import CyberBattleEnv\n",
"\n",
"logging.basicConfig(stream=sys.stdout, level=logging.ERROR, format=\"%(levelname)s: %(message)s\")\n",
"%matplotlib inline"
@@ -111,59 +113,14 @@
"outputs": [],
"source": [
"# Papermill notebook parameters\n",
"\n",
"#############\n",
"# gymid = 'CyberBattleTiny-v0'\n",
"#############\n",
"gymid = \"CyberBattleToyCtf-v0\"\n",
"env_size = None\n",
"iteration_count = 1500\n",
"training_episode_count = 20\n",
"eval_episode_count = 10\n",
"maximum_node_count = 12\n",
"maximum_total_credentials = 10\n",
"#############\n",
"# gymid = \"CyberBattleChain-v0\"\n",
"# env_size = 10\n",
"# iteration_count = 9000\n",
"# training_episode_count = 50\n",
"# eval_episode_count = 5\n",
"# maximum_node_count = 22\n",
"# maximum_total_credentials = 22"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "encouraging-shoot",
"metadata": {
"execution": {
"iopub.execute_input": "2024-08-04T03:01:55.636085Z",
"iopub.status.busy": "2024-08-04T03:01:55.635325Z",
"iopub.status.idle": "2024-08-04T03:01:55.641049Z",
"shell.execute_reply": "2024-08-04T03:01:55.640123Z"
},
"papermill": {
"duration": 0.011052,
"end_time": "2024-08-04T03:01:55.642618",
"exception": false,
"start_time": "2024-08-04T03:01:55.631566",
"status": "completed"
},
"tags": [
"injected-parameters"
]
},
"outputs": [],
"source": [
"# Parameters\n",
"gymid = \"CyberBattleChain-v0\"\n",
"iteration_count = 9000\n",
"training_episode_count = 50\n",
"eval_episode_count = 5\n",
"maximum_node_count = 22\n",
"maximum_total_credentials = 22\n",
"env_size = 10"
"env_size = 10\n",
"plots_dir = \"plots\"\n"
]
},
{
@@ -188,7 +145,7 @@
},
"outputs": [],
"source": [
"from cyberbattle._env.cyberbattle_env import CyberBattleEnv\n",
"os.makedirs(plots_dir, exist_ok=True)\n",
"\n",
"# Load the Gym environment\n",
"if env_size:\n",
@@ -144988,6 +144945,7 @@
" f\"State: {[f.name() for f in themodel.state_space.feature_selection]} \"\n",
" f\"({len(themodel.state_space.feature_selection)}\\n\"\n",
" f\"Action: abstract_action ({themodel.action_space.flat_size()})\",\n",
" save_at=os.path.join(plots_dir, \"benchmark-chain-cumrewards.png\"),\n",
")"
]
},
@@ -145037,7 +144995,8 @@
"source": [
"contenders = [credlookup_run, tabularq_run, dql_run, dql_exploit_run]\n",
"p.plot_episodes_length(contenders)\n",
"p.plot_averaged_cummulative_rewards(title=f\"Agent Benchmark top contenders\\n\" f\"max_nodes:{ep.maximum_node_count}\\n\", all_runs=contenders)"
"p.plot_averaged_cummulative_rewards(title=f\"Agent Benchmark top contenders\\n\" f\"max_nodes:{ep.maximum_node_count}\\n\", all_runs=contenders,\n",
" save_at=os.path.join(plots_dir, \"benchmark-chain-cumreward_contenders.png\"))"
]
},
{
@@ -145154,4 +145113,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}
}
10 changes: 4 additions & 6 deletions notebooks/notebook_benchmark-tiny.ipynb
@@ -71,7 +71,6 @@
"import cyberbattle.agents.baseline.agent_dql as dqla\n",
"from cyberbattle.agents.baseline.agent_wrapper import Verbosity\n",
"import os\n",
"import matplotlib.pyplot as plt\n",
"\n",
"logging.basicConfig(stream=sys.stdout, level=logging.ERROR, format=\"%(levelname)s: %(message)s\")\n",
"%matplotlib inline"
@@ -470,6 +469,7 @@
" f\"State: {[f.name() for f in themodel.state_space.feature_selection]} \"\n",
" f\"({len(themodel.state_space.feature_selection)}\\n\"\n",
" f\"Action: abstract_action ({themodel.action_space.flat_size()})\",\n",
" save_at=os.path.join(plots_dir, \"benchmark-tiny-cumrewards.png\"),\n",
")"
]
},
@@ -498,10 +498,8 @@
"source": [
"contenders = [credlookup_run, tabularq_run, dql_run, dql_exploit_run]\n",
"p.plot_episodes_length(contenders)\n",
"p.plot_averaged_cummulative_rewards(title=f\"Agent Benchmark top contenders\\n\" f\"max_nodes:{ep.maximum_node_count}\\n\", all_runs=contenders, show=False)\n",
"\n",
"plt.savefig(os.path.join(plots_dir, \"benchmark-tiny-finalplot.png\"))\n",
"plt.show()"
"p.plot_averaged_cummulative_rewards(title=f\"Agent Benchmark top contenders\\n\" f\"max_nodes:{ep.maximum_node_count}\\n\", all_runs=contenders,\n",
" save_at=os.path.join(plots_dir, \"benchmark-tiny-cumreward_contenders.png\"))"
]
},
{
@@ -576,4 +574,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}
}
26 changes: 11 additions & 15 deletions notebooks/notebook_benchmark-toyctf.ipynb
@@ -62,6 +62,7 @@
"outputs": [],
"source": [
"import sys\n",
"import os\n",
"import logging\n",
"import gymnasium as gym\n",
"import cyberbattle.agents.baseline.learner as learner\n",
@@ -125,25 +126,14 @@
"outputs": [],
"source": [
"# Papermill notebook parameters\n",
"\n",
"#############\n",
"# gymid = 'CyberBattleTiny-v0'\n",
"#############\n",
"gymid = \"CyberBattleToyCtf-v0\"\n",
"env_size = None\n",
"iteration_count = 1500\n",
"training_episode_count = 20\n",
"eval_episode_count = 10\n",
"maximum_node_count = 12\n",
"maximum_total_credentials = 10\n",
"#############\n",
"# gymid = \"CyberBattleChain-v0\"\n",
"# env_size = 10\n",
"# iteration_count = 9000\n",
"# training_episode_count = 50\n",
"# eval_episode_count = 5\n",
"# maximum_node_count = 22\n",
"# maximum_total_credentials = 22"
"plots_dir = \"output/plots\"\n"
]
},
{
@@ -176,7 +166,8 @@
"training_episode_count = 20\n",
"eval_episode_count = 10\n",
"maximum_node_count = 12\n",
"maximum_total_credentials = 10"
"maximum_total_credentials = 10\n",
"plots_dir = \"output/plots\""
]
},
{
@@ -201,6 +192,8 @@
},
"outputs": [],
"source": [
"os.makedirs(plots_dir, exist_ok=True)\n",
"\n",
"# Load the Gym environment\n",
"if env_size:\n",
" _gym_env = gym.make(gymid, size=env_size)\n",
@@ -192540,6 +192533,8 @@
" f\"State: {[f.name() for f in themodel.state_space.feature_selection]} \"\n",
" f\"({len(themodel.state_space.feature_selection)}\\n\"\n",
" f\"Action: abstract_action ({themodel.action_space.flat_size()})\",\n",
" save_at=os.path.join(plots_dir, \"benchmark-toyctf-cumrewards.png\"),\n",
"\n",
")"
]
},
@@ -192589,7 +192584,8 @@
"source": [
"contenders = [credlookup_run, tabularq_run, dql_run, dql_exploit_run]\n",
"p.plot_episodes_length(contenders)\n",
"p.plot_averaged_cummulative_rewards(title=f\"Agent Benchmark top contenders\\n\" f\"max_nodes:{ep.maximum_node_count}\\n\", all_runs=contenders)"
"p.plot_averaged_cummulative_rewards(title=f\"Agent Benchmark top contenders\\n\" f\"max_nodes:{ep.maximum_node_count}\\n\", all_runs=contenders,\n",
" save_at=os.path.join(plots_dir, \"benchmark-toyctf-cumrewards_contenders.png\"))"
]
},
{
@@ -192705,4 +192701,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}
}
15 changes: 13 additions & 2 deletions notebooks/notebook_dql_transfer.ipynb
@@ -226,7 +226,18 @@
"source": [
"iteration_count = 9000\n",
"training_episode_count = 50\n",
"eval_episode_count = 10"
"eval_episode_count = 10\n",
"plots_dir = \"output/images\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "65e34f4d",
"metadata": {},
"outputs": [],
"source": [
"os.makedirs(plots_dir, exist_ok=True)"
]
},
{
@@ -43536,7 +43547,7 @@
" iteration_count=iteration_count,\n",
" epsilon=0.0, # 0.35,\n",
" render=False,\n",
" render_last_episode_rewards_to=\"images/chain10\",\n",
" render_last_episode_rewards_to=os.path.join(plots_dir, \"dql_transfer-chain10\"),\n",
" title=\"Exploiting DQL\",\n",
" verbosity=Verbosity.Quiet,\n",
")"
17 changes: 11 additions & 6 deletions notebooks/notebook_randlookups.ipynb
@@ -68,12 +68,13 @@
},
"outputs": [],
"source": [
"from cyberbattle._env.cyberbattle_env import AttackerGoal\n",
"from cyberbattle.agents.baseline.agent_randomcredlookup import CredentialCacheExploiter\n",
"import cyberbattle.agents.baseline.learner as learner\n",
"import os\n",
"import gymnasium as gym\n",
"import logging\n",
"import sys\n",
"from cyberbattle._env.cyberbattle_env import AttackerGoal\n",
"from cyberbattle.agents.baseline.agent_randomcredlookup import CredentialCacheExploiter\n",
"import cyberbattle.agents.baseline.learner as learner\n",
"import cyberbattle.agents.baseline.plotting as p\n",
"import cyberbattle.agents.baseline.agent_wrapper as w\n",
"from cyberbattle.agents.baseline.agent_wrapper import Verbosity"
@@ -194,7 +195,8 @@
"source": [
"iteration_count = 9000\n",
"training_episode_count = 50\n",
"eval_episode_count = 5"
"eval_episode_count = 5\n",
"plots_dir = 'plots'"
]
},
{
@@ -59089,6 +59091,8 @@
}
],
"source": [
"os.makedirs(plots_dir, exist_ok=True)\n",
"\n",
"credexplot = learner.epsilon_greedy_search(\n",
" cyberbattlechain_10,\n",
" learner=CredentialCacheExploiter(),\n",
@@ -63805,7 +63809,8 @@
"p.plot_all_episodes(credexplot)\n",
"\n",
"all_runs = [credexplot, randomlearning_results]\n",
"p.plot_averaged_cummulative_rewards(title=f\"Benchmark -- max_nodes={ep.maximum_node_count}, episodes={eval_episode_count},\\n\", all_runs=all_runs)"
"p.plot_averaged_cummulative_rewards(title=f\"Benchmark -- max_nodes={ep.maximum_node_count}, episodes={eval_episode_count},\\n\", all_runs=all_runs,\n",
" save_at=os.path.join(plots_dir, \"randlookups-cumreward.png\"))"
]
},
{
@@ -63862,4 +63867,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}
}
23 changes: 18 additions & 5 deletions notebooks/notebook_tabularq.ipynb
@@ -69,11 +69,12 @@
"outputs": [],
"source": [
"import sys\n",
"import os\n",
"import logging\n",
"from typing import cast\n",
"import gymnasium as gym\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt # type:ignore\n",
"import matplotlib.pyplot as plt\n",
"from cyberbattle.agents.baseline.learner import TrainedLearner\n",
"import cyberbattle.agents.baseline.plotting as p\n",
"import cyberbattle.agents.baseline.agent_wrapper as w\n",
@@ -172,7 +173,8 @@
"eval_episode_count = 5\n",
"gamma_sweep = [\n",
" 0.015, # about right\n",
"]"
"]\n",
"plots_dir = 'output/plots'"
]
},
{
@@ -181,6 +183,16 @@
"id": "0cdf621d",
"metadata": {},
"outputs": [],
"source": [
"os.makedirs(plots_dir, exist_ok=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "004c0ad8",
"metadata": {},
"outputs": [],
"source": [
"def qlearning_run(gamma, gym_env):\n",
" \"\"\"Execute one run of the q-learning algorithm for the\n",
@@ -38410,6 +38422,7 @@
" f\"Q1={[f.name() for f in Q_source_10.state_space.feature_selection]} \"\n",
" f\"-> {[f.name() for f in Q_source_10.action_space.feature_selection]})\\n\"\n",
" f\"Q2={[f.name() for f in Q_attack_10.state_space.feature_selection]} -> 'action'\",\n",
" save_at=os.path.join(plots_dir, \"benchmark-tabularq-cumrewards.png\")\n",
")"
]
},
@@ -72401,9 +72414,9 @@
"cell_metadata_filter": "title,-all"
},
"kernelspec": {
"display_name": "Python [conda env:cybersim]",
"display_name": "cybersim",
"language": "python",
"name": "conda-env-cybersim-py"
"name": "python3"
},
"language_info": {
"codemirror_mode": {
Expand Down Expand Up @@ -72432,4 +72445,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}
}