Skip to content

Commit

Permalink
upgrade pyright and fix new errors
Browse files Browse the repository at this point in the history
  • Loading branch information
William Blum authored and blumu committed Jun 22, 2021
1 parent 68b0192 commit 56fdb67
Show file tree
Hide file tree
Showing 7 changed files with 18 additions and 8 deletions.
5 changes: 4 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
on:
pull_request:
branches:
- master
push:
branches:
- main
Expand Down Expand Up @@ -38,7 +41,7 @@ jobs:
- name: Pull pip dependencies
run: ./install-pythonpackages.sh
- name: Install pyright
run: npm install -g pyright
run: npm install -g pyright@1.1.151
- name: Pull typing stubs from cache
uses: actions/cache@v2
with:
Expand Down
8 changes: 7 additions & 1 deletion createstubs.sh
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ set -e

. ./getpythonpath.sh

pushd "$(dirname "$0")"

echo "$(tput setaf 2)Creating type stubs$(tput sgr0)"
createstub() {
local name=$1
Expand Down Expand Up @@ -42,12 +44,16 @@ fi

if [ ! -d "boolean" ]; then
pyright --createstub boolean
sed -i '/class BooleanAlgebra(object):/a\ TRUE = ...\n FALSE = ...' typings/boolean/boolean.pyi
sed -i '/class BooleanAlgebra:/a\ TRUE = ...\n FALSE = ...' typings/boolean/boolean.pyi
else
echo stub 'boolean' already created
fi

echo 'Typing stub generation completed'

# Stubs that needed manual patching and that
# were instead checked in to git
# pyright --createstub boolean
# pyright --createstub gym

popd
2 changes: 1 addition & 1 deletion cyberbattle/_env/graph_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ class CyberBattleGraph(gym.Wrapper):
def __init__(self, env, maximum_total_credentials=22, maximum_node_count=22):
super().__init__(env)
self._bounds = self.env._bounds
self.__graph = None
self.__graph = nx.DiGraph()
self.observation_space = DiGraph(self._bounds.maximum_node_count)

def reset(self):
Expand Down
3 changes: 2 additions & 1 deletion cyberbattle/agents/baseline/agent_dql.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,8 @@ def optimize_model(self, norm_clipping=False):
torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 1.0)
else:
for param in self.policy_net.parameters():
param.grad.data.clamp_(-1, 1)
if param.grad is not None:
param.grad.data.clamp_(-1, 1)
self.optimizer.step()

def get_actor_state_vector(self, global_state: ndarray, actor_features: ndarray) -> ndarray:
Expand Down
4 changes: 2 additions & 2 deletions cyberbattle/agents/baseline/agent_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class StateAugmentation:
"""Default agent state augmentation, consisting of the gym environment
observation itself and nothing more."""

def __init__(self, observation: Optional[cyberbattle_env.Observation] = None):
def __init__(self, observation: cyberbattle_env.Observation):
self.observation = observation

def on_step(self, action: cyberbattle_env.Action, reward: float, done: bool, observation: cyberbattle_env.Observation):
Expand Down Expand Up @@ -424,7 +424,7 @@ class ActionTrackingStateAugmentation(StateAugmentation):
- failed_action_count: count of action taken and failed at the current node
"""

def __init__(self, p: EnvironmentBounds, observation: Optional[cyberbattle_env.Observation] = None):
def __init__(self, p: EnvironmentBounds, observation: cyberbattle_env.Observation):
self.aa = AbstractAction(p)
self.success_action_count = np.zeros(shape=(p.maximum_node_count, self.aa.n_actions), dtype=np.int32)
self.failed_action_count = np.zeros(shape=(p.maximum_node_count, self.aa.n_actions), dtype=np.int32)
Expand Down
2 changes: 1 addition & 1 deletion cyberbattle/agents/baseline/learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def epsilon_greedy_search(
all_episodes_availability = []

wrapped_env = AgentWrapper(cyberbattle_gym_env,
ActionTrackingStateAugmentation(environment_properties))
ActionTrackingStateAugmentation(environment_properties, cyberbattle_gym_env.reset()))
steps_done = 0
plot_title = f"{title} (epochs={episode_count}, ϵ={initial_epsilon}, ϵ_min={epsilon_minimum}," \
+ (f"ϵ_multdecay={epsilon_multdecay}," if epsilon_multdecay else '') \
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
gym~=0.17.3
numpy==1.19.4
boolean.py~=3.7
boolean.py==3.8
networkx==2.4
pyyaml~=5.4.1
setuptools~=49.2.1
Expand Down

0 comments on commit 56fdb67

Please sign in to comment.