Skip to content

Commit

Permalink
move gpus extension
Browse files Browse the repository at this point in the history
  • Loading branch information
SamratThapa120 committed Dec 12, 2024
1 parent bd7fcb5 commit 5c2eba8
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 69 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
'env = rocker.extensions:Environment',
'expose = rocker.extensions:Expose',
'git = rocker.git_extension:Git',
'gpus = rocker.extensions:Gpus',
'gpus = rocker.nvidia_extension:Gpus',
'group_add = rocker.extensions:GroupAdd',
'home = rocker.extensions:HomeDir',
'hostname = rocker.extensions:Hostname',
Expand Down
30 changes: 1 addition & 29 deletions src/rocker/extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,32 +491,4 @@ def get_docker_args(self, cliargs):
def register_arguments(parser, defaults={}):
parser.add_argument('--shm-size',
default=defaults.get('shm_size', None),
help="Set the size of the shared memory for the container (e.g., 512m, 1g).")


class Gpus(RockerExtension):
@staticmethod
def get_name():
return 'gpus'

def __init__(self):
self.name = Gpus.get_name()

def get_preamble(self, cliargs):
return ''

def get_docker_args(self, cliargs):
# The gpu ids will be set in the nvidia extension, if the nvidia argument is passed.
if cliargs.get('nvidia', None):
return ''
args = ''
gpus = cliargs.get('gpus', None)
if gpus:
args += f' --gpus {gpus} '
return args

@staticmethod
def register_arguments(parser, defaults={}):
parser.add_argument('--gpus',
default=defaults.get('gpus', None),
help="Set the indices of GPUs to use")
help="Set the size of the shared memory for the container (e.g., 512m, 1g).")
27 changes: 27 additions & 0 deletions src/rocker/nvidia_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,3 +239,30 @@ def register_arguments(parser, defaults):
action='store_true',
default=defaults.get('cuda', None),
help="Install cuda and nvidia-cuda-dev into the container")

class Gpus(RockerExtension):
@staticmethod
def get_name():
return 'gpus'

def __init__(self):
self.name = Gpus.get_name()

def get_preamble(self, cliargs):
return ''

def get_docker_args(self, cliargs):
# The gpu ids will be set in the nvidia extension, if the nvidia argument is passed.
if cliargs.get('nvidia', None):
return ''
args = ''
gpus = cliargs.get('gpus', None)
if gpus:
args += f' --gpus {gpus} '
return args

@staticmethod
def register_arguments(parser, defaults={}):
parser.add_argument('--gpus',
default=defaults.get('gpus', None),
help="Set the indices of GPUs to use")
40 changes: 1 addition & 39 deletions test/test_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -648,42 +648,4 @@ def test_shm_size_extension(self):

mock_cliargs = {'shm_size': '12g'}
args = p.get_docker_args(mock_cliargs)
self.assertIn('--shm-size 12g', args)

class GpusExtensionTest(unittest.TestCase):

def setUp(self):
# Work around interference between empy Interpreter
# stdout proxy and test runner. empy installs a proxy on stdout
# to be able to capture the information.
# And the test runner creates a new stdout object for each test.
# This breaks empy as it assumes that the proxy has persistent
# between instances of the Interpreter class
# empy will error with the exception
# "em.Error: interpreter stdout proxy lost"
em.Interpreter._wasProxyInstalled = False

@pytest.mark.docker
def test_gpus_extension(self):
plugins = list_plugins()
gpus_plugin = plugins['gpus']
self.assertEqual(gpus_plugin.get_name(), 'gpus')

p = gpus_plugin()
self.assertTrue(plugin_load_parser_correctly(gpus_plugin))

# Test when no GPUs are specified
mock_cliargs = {}
self.assertEqual(p.get_snippet(mock_cliargs), '')
self.assertEqual(p.get_preamble(mock_cliargs), '')
args = p.get_docker_args(mock_cliargs)
self.assertNotIn('--gpus', args)

# Test when GPUs are specified
mock_cliargs = {'gpus': 'all'}
args = p.get_docker_args(mock_cliargs)
self.assertIn('--gpus all', args)

mock_cliargs = {'gpus': '0,1'}
args = p.get_docker_args(mock_cliargs)
self.assertIn('--gpus 0,1', args)
self.assertIn('--shm-size 12g', args)
38 changes: 38 additions & 0 deletions test/test_nvidia.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,3 +328,41 @@ def test_cuda_env_subs(self):
with self.assertRaises(SystemExit) as cm:
p.get_environment_subs(mock_cliargs)
self.assertEqual(cm.exception.code, 1)

class GpusExtensionTest(unittest.TestCase):

def setUp(self):
# Work around interference between empy Interpreter
# stdout proxy and test runner. empy installs a proxy on stdout
# to be able to capture the information.
# And the test runner creates a new stdout object for each test.
# This breaks empy as it assumes that the proxy has persistent
# between instances of the Interpreter class
# empy will error with the exception
# "em.Error: interpreter stdout proxy lost"
em.Interpreter._wasProxyInstalled = False

@pytest.mark.docker
def test_gpus_extension(self):
plugins = list_plugins()
gpus_plugin = plugins['gpus']
self.assertEqual(gpus_plugin.get_name(), 'gpus')

p = gpus_plugin()
self.assertTrue(plugin_load_parser_correctly(gpus_plugin))

# Test when no GPUs are specified
mock_cliargs = {}
self.assertEqual(p.get_snippet(mock_cliargs), '')
self.assertEqual(p.get_preamble(mock_cliargs), '')
args = p.get_docker_args(mock_cliargs)
self.assertNotIn('--gpus', args)

# Test when GPUs are specified
mock_cliargs = {'gpus': 'all'}
args = p.get_docker_args(mock_cliargs)
self.assertIn('--gpus all', args)

mock_cliargs = {'gpus': '0,1'}
args = p.get_docker_args(mock_cliargs)
self.assertIn('--gpus 0,1', args)

0 comments on commit 5c2eba8

Please sign in to comment.