Skip to content

Commit

Permalink
Fix distconv test
Browse files Browse the repository at this point in the history
  • Loading branch information
fiedorowicz1 committed Jun 18, 2024
1 parent 8dc219b commit 356d44b
Showing 1 changed file with 10 additions and 3 deletions.
13 changes: 10 additions & 3 deletions ci_test/unit_tests/test_unit_layer_upsample_distconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import os
import os.path
import sys
import pytest
import numpy as np
import lbann.contrib.args

Expand Down Expand Up @@ -112,6 +113,12 @@ def construct_model(lbann):
"""

num_height_groups = tools.gpus_per_node(lbann)
if num_height_groups == 0:
e = 'this test requires GPUs.'
print('Skip - ' + e)
pytest.skip(e)

# Input data
# Note: Sum with a weights layer so that gradient checking will
# verify that error signals are correct.
Expand Down Expand Up @@ -173,15 +180,15 @@ def construct_model(lbann):
dilation=dilations,
has_bias=False,
parallel_strategy=create_parallel_strategy(
4),
num_height_groups),
name=f'conv1_{uname}')

y = lbann.Upsample(x,
num_dims=len(u["scale_factors"]),
has_vectors=True,
scale_factors=u["scale_factors"],
upsample_mode=u['upsample_mode'],
parallel_strategy=create_parallel_strategy(4),
parallel_strategy=create_parallel_strategy(num_height_groups),
name=f'upsample_{uname}')


Expand Down Expand Up @@ -215,7 +222,7 @@ def construct_model(lbann):
dilation=dilations,
has_bias=False,
parallel_strategy=create_parallel_strategy(
4),
num_height_groups),
name=f'conv2_{uname}')

y = lbann.Identity(y, name=f'out_{uname}')
Expand Down

0 comments on commit 356d44b

Please sign in to comment.