Skip to content

Commit

Permalink
add nvidia compatibility
Browse files Browse the repository at this point in the history
  • Loading branch information
SamratThapa120 committed Dec 8, 2024
1 parent 06dbabb commit bd7fcb5
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 2 deletions.
3 changes: 3 additions & 0 deletions src/rocker/extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,6 +506,9 @@ def get_preamble(self, cliargs):
return ''

def get_docker_args(self, cliargs):
# The gpu ids will be set in the nvidia extension, if the nvidia argument is passed.
if cliargs.get('nvidia', None):
return ''
args = ''
gpus = cliargs.get('gpus', None)
if gpus:
Expand Down
7 changes: 5 additions & 2 deletions src/rocker/nvidia_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,12 +152,15 @@ def get_snippet(self, cliargs):

def get_docker_args(self, cliargs):
force_flag = cliargs.get('nvidia', None)
gpus_ids_flag = cliargs.get('gpus', None)
if gpus_ids_flag is None:
gpus_ids_flag = 'all'
if force_flag == 'runtime':
return " --runtime=nvidia"
if force_flag == 'gpus':
return " --gpus all"
return f" --gpus {gpus_ids_flag}"
if get_docker_version() >= Version("19.03"):
return " --gpus all"
return f" --gpus {gpus_ids_flag}"
return " --runtime=nvidia"

@staticmethod
Expand Down

0 comments on commit bd7fcb5

Please sign in to comment.