diff --git a/setup.py b/setup.py index a4d6479..41d904f 100644 --- a/setup.py +++ b/setup.py @@ -51,7 +51,7 @@ 'env = rocker.extensions:Environment', 'expose = rocker.extensions:Expose', 'git = rocker.git_extension:Git', - 'gpus = rocker.extensions:Gpus', + 'gpus = rocker.nvidia_extension:Gpus', 'group_add = rocker.extensions:GroupAdd', 'home = rocker.extensions:HomeDir', 'hostname = rocker.extensions:Hostname', diff --git a/src/rocker/extensions.py b/src/rocker/extensions.py index 3a0474c..bee9363 100644 --- a/src/rocker/extensions.py +++ b/src/rocker/extensions.py @@ -491,32 +491,4 @@ def get_docker_args(self, cliargs): def register_arguments(parser, defaults={}): parser.add_argument('--shm-size', default=defaults.get('shm_size', None), - help="Set the size of the shared memory for the container (e.g., 512m, 1g).") - - -class Gpus(RockerExtension): - @staticmethod - def get_name(): - return 'gpus' - - def __init__(self): - self.name = Gpus.get_name() - - def get_preamble(self, cliargs): - return '' - - def get_docker_args(self, cliargs): - # The gpu ids will be set in the nvidia extension, if the nvidia argument is passed. - if cliargs.get('nvidia', None): - return '' - args = '' - gpus = cliargs.get('gpus', None) - if gpus: - args += f' --gpus {gpus} ' - return args - - @staticmethod - def register_arguments(parser, defaults={}): - parser.add_argument('--gpus', - default=defaults.get('gpus', None), - help="Set the indices of GPUs to use") \ No newline at end of file + help="Set the size of the shared memory for the container (e.g., 512m, 1g).") \ No newline at end of file diff --git a/src/rocker/nvidia_extension.py b/src/rocker/nvidia_extension.py index deedeb9..e4e47e3 100644 --- a/src/rocker/nvidia_extension.py +++ b/src/rocker/nvidia_extension.py @@ -239,3 +239,30 @@ def register_arguments(parser, defaults): action='store_true', default=defaults.get('cuda', None), help="Install cuda and nvidia-cuda-dev into the container") + +class Gpus(RockerExtension): + @staticmethod + def get_name(): + return 'gpus' + + def __init__(self): + self.name = Gpus.get_name() + + def get_preamble(self, cliargs): + return '' + + def get_docker_args(self, cliargs): + # The gpu ids will be set in the nvidia extension, if the nvidia argument is passed. + if cliargs.get('nvidia', None): + return '' + args = '' + gpus = cliargs.get('gpus', None) + if gpus: + args += f' --gpus {gpus} ' + return args + + @staticmethod + def register_arguments(parser, defaults={}): + parser.add_argument('--gpus', + default=defaults.get('gpus', None), + help="Set the indices of GPUs to use") \ No newline at end of file diff --git a/test/test_extension.py b/test/test_extension.py index 26558ab..1ad88d8 100644 --- a/test/test_extension.py +++ b/test/test_extension.py @@ -648,42 +648,4 @@ def test_shm_size_extension(self): mock_cliargs = {'shm_size': '12g'} args = p.get_docker_args(mock_cliargs) - self.assertIn('--shm-size 12g', args) - -class GpusExtensionTest(unittest.TestCase): - - def setUp(self): - # Work around interference between empy Interpreter - # stdout proxy and test runner. empy installs a proxy on stdout - # to be able to capture the information. - # And the test runner creates a new stdout object for each test. - # This breaks empy as it assumes that the proxy has persistent - # between instances of the Interpreter class - # empy will error with the exception - # "em.Error: interpreter stdout proxy lost" - em.Interpreter._wasProxyInstalled = False - - @pytest.mark.docker - def test_gpus_extension(self): - plugins = list_plugins() - gpus_plugin = plugins['gpus'] - self.assertEqual(gpus_plugin.get_name(), 'gpus') - - p = gpus_plugin() - self.assertTrue(plugin_load_parser_correctly(gpus_plugin)) - - # Test when no GPUs are specified - mock_cliargs = {} - self.assertEqual(p.get_snippet(mock_cliargs), '') - self.assertEqual(p.get_preamble(mock_cliargs), '') - args = p.get_docker_args(mock_cliargs) - self.assertNotIn('--gpus', args) - - # Test when GPUs are specified - mock_cliargs = {'gpus': 'all'} - args = p.get_docker_args(mock_cliargs) - self.assertIn('--gpus all', args) - - mock_cliargs = {'gpus': '0,1'} - args = p.get_docker_args(mock_cliargs) - self.assertIn('--gpus 0,1', args) + self.assertIn('--shm-size 12g', args) \ No newline at end of file diff --git a/test/test_nvidia.py b/test/test_nvidia.py index 6fb7e63..9cfd2f5 100644 --- a/test/test_nvidia.py +++ b/test/test_nvidia.py @@ -328,3 +328,41 @@ def test_cuda_env_subs(self): with self.assertRaises(SystemExit) as cm: p.get_environment_subs(mock_cliargs) self.assertEqual(cm.exception.code, 1) + +class GpusExtensionTest(unittest.TestCase): + + def setUp(self): + # Work around interference between empy Interpreter + # stdout proxy and test runner. empy installs a proxy on stdout + # to be able to capture the information. + # And the test runner creates a new stdout object for each test. + # This breaks empy as it assumes that the proxy has persistent + # between instances of the Interpreter class + # empy will error with the exception + # "em.Error: interpreter stdout proxy lost" + em.Interpreter._wasProxyInstalled = False + + @pytest.mark.docker + def test_gpus_extension(self): + plugins = list_plugins() + gpus_plugin = plugins['gpus'] + self.assertEqual(gpus_plugin.get_name(), 'gpus') + + p = gpus_plugin() + self.assertTrue(plugin_load_parser_correctly(gpus_plugin)) + + # Test when no GPUs are specified + mock_cliargs = {} + self.assertEqual(p.get_snippet(mock_cliargs), '') + self.assertEqual(p.get_preamble(mock_cliargs), '') + args = p.get_docker_args(mock_cliargs) + self.assertNotIn('--gpus', args) + + # Test when GPUs are specified + mock_cliargs = {'gpus': 'all'} + args = p.get_docker_args(mock_cliargs) + self.assertIn('--gpus all', args) + + mock_cliargs = {'gpus': '0,1'} + args = p.get_docker_args(mock_cliargs) + self.assertIn('--gpus 0,1', args)