Skip to content

Commit

Permalink
fix for pytorch 1.13.1 (#9)
Browse files Browse the repository at this point in the history
* fix for pytorch 1.13.1

* lint

* nbconvert dependency

* update gh actions
  • Loading branch information
svenkreiss authored Mar 13, 2023
1 parent c1dde8c commit 492316d
Show file tree
Hide file tree
Showing 8 changed files with 25 additions and 19 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/deploy-guide.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ jobs:
deploy-guide:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v3

- name: Set up Python 3.8
uses: actions/setup-python@v1
uses: actions/setup-python@v4
with:
python-version: 3.8

Expand All @@ -28,7 +28,7 @@ jobs:
unzip train.zip -d data-trajnet
- name: Install ffmpeg
uses: FedericoCarboni/setup-ffmpeg@v1
uses: FedericoCarboni/setup-ffmpeg@v2
id: setup-ffmpeg
- name: ffmpeg codecs
run: ffmpeg -codecs
Expand Down
10 changes: 5 additions & 5 deletions .github/workflows/pypi-deploy.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@ jobs:
name: Build source distribution
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v3
with:
fetch-depth: 0

- uses: actions/setup-python@v2
- uses: actions/setup-python@v4
name: Install Python
with:
python-version: '3.7'
Expand All @@ -38,7 +38,7 @@ jobs:
- name: Build sdist
run: python setup.py sdist

- uses: actions/upload-artifact@v2
- uses: actions/upload-artifact@v3
with:
path: dist/*.tar.gz

Expand All @@ -50,12 +50,12 @@ jobs:
# alternatively, to publish when a GitHub Release is created, use the following rule:
if: github.event_name == 'release' && github.event.action == 'published'
steps:
- uses: actions/download-artifact@v2
- uses: actions/download-artifact@v3
with:
name: artifact
path: dist

- uses: pypa/gh-action-pypi-publish@master
- uses: pypa/gh-action-pypi-publish@release/v1
with:
user: __token__
password: ${{ secrets.PYPI_TOKEN }}
Expand Down
12 changes: 6 additions & 6 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,6 @@ jobs:
strategy:
matrix:
include:
- os: ubuntu-latest
python: 3.6
torch: 1.9.0+cpu
torch-source: https://download.pytorch.org/whl/torch_stable.html
- os: ubuntu-latest
python: 3.7
torch: 1.9.0+cpu
Expand All @@ -26,6 +22,10 @@ jobs:
python: 3.9
torch: 1.9.0+cpu
torch-source: https://download.pytorch.org/whl/torch_stable.html
- os: ubuntu-latest
python: "3.10"
torch: 1.13.1+cpu
torch-source: https://download.pytorch.org/whl/torch_stable.html
- os: macos-11
python: 3.7
torch: 1.9.0
Expand All @@ -50,7 +50,7 @@ jobs:
steps:
- run: ls -n /Applications/ | grep Xcode*
if: matrix.os == 'macos-11'
- uses: actions/checkout@v2
- uses: actions/checkout@v3
- name: Data setup
if: matrix.slow
run: |
Expand All @@ -59,7 +59,7 @@ jobs:
unzip train.zip -d data-trajnet
- name: Set up Python ${{ matrix.python }}
if: ${{ !matrix.conda }}
uses: actions/setup-python@v2
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python }}
- name: Set up Conda
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
'dev': [
'jupyter-book',
'matplotlib',
'nbconvert',
'nbstripout',
'nbval',
'pycodestyle',
Expand Down
7 changes: 6 additions & 1 deletion socialforce/potentials/pedped_1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,12 @@ def norm_r_ab(r_ab):
zero vector gives nan gradients.
"""
out = torch.linalg.norm(r_ab, ord=2, dim=2, keepdim=False)
torch.diagonal(out)[:] = 0.0

# only take the upper and lower triangles and leaving the
# diagonal at zero and do it in a differentiable way
# without inplace ops
out = torch.triu(out, diagonal=1) + torch.tril(out, diagonal=-1)

return out


Expand Down
2 changes: 1 addition & 1 deletion socialforce/potentials/pedped_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def __init__(self, *,
lin_out = torch.nn.Linear(hidden_units, 1)

# activation_function = torch.nn.Softplus
activation_function = lambda: torch.nn.Softplus(beta=5)
activation_function = lambda: torch.nn.Softplus(beta=5) # pylint: disable=unnecessary-lambda-assignment
self.mlp = torch.nn.Sequential(
lin_in, activation_function(),
*[layer for lin in lin_hidden for layer in (lin, activation_function())],
Expand Down
2 changes: 1 addition & 1 deletion socialforce/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def keep(state1, state2):
acc = (torch.abs(state1[:, 4:6]) > acc_abs) | (torch.abs(state2[:, 4:6]) > acc_abs)
acc = torch.any(acc, dim=-1)
# keep 10% of samples without acc:
acc[~acc] = (torch.rand(acc[~acc].shape) < 0.1)
acc[~acc] = torch.rand(acc[~acc].shape) < 0.1
acc[:] = torch.any(acc) # symmetrize

return valid_state1 & valid_state2 & small_distance & acc
Expand Down
4 changes: 2 additions & 2 deletions socialforce/trajnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class Reader:
"""
def __init__(self, input_file, scene_type=None):
if scene_type is not None and scene_type not in {'rows', 'paths', 'tags'}:
raise Exception('scene_type not supported')
raise ValueError('scene_type not supported')
self.scene_type = scene_type

self.tracks_by_frame = defaultdict(list)
Expand Down Expand Up @@ -102,7 +102,7 @@ def paths_to_xy(paths):
def scene(self, scene_id):
scene = self.scenes_by_id.get(scene_id)
if scene is None:
raise Exception('scene with that id not found')
raise ValueError('scene with that id not found')

frames = range(scene.start, scene.end + 1)
track_rows = [r
Expand Down

0 comments on commit 492316d

Please sign in to comment.