Skip to content

Commit

Permalink
Ensure host has GPU when trying to execute locally with GPU image
Browse files: browse the repository at this point in the history
  • Loading branch information
thvasilo committed Dec 19, 2024
1 parent 4d266e5 commit 2edf3a6
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 1 deletion.
3 changes: 2 additions & 1 deletion sagemaker/pipeline/create_sm_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ def _create_pipeline_parameters(self, args: PipelineArgs):
self.partition_algorithm_param = self._create_string_parameter(
"PartitionAlgorithm", args.partition_config.partition_algorithm
)
# TODO: Probably should not be a parameter
self.graph_name_param = self._create_string_parameter(
"GraphName", args.task_config.graph_name
)
Expand Down Expand Up @@ -314,7 +315,7 @@ def _create_gconstruct_step(self, args: PipelineArgs) -> ProcessingStep:
gc_proc_output = ProcessingOutput(
source=gc_local_output_path,
destination=gconstruct_s3_output,
output_name=f"{self.graph_name_param}-gconstruct",
output_name=f"{self.args.task_config.graph_name}-gconstruct",
)

gconstruct_arguments = [
Expand Down
10 changes: 10 additions & 0 deletions sagemaker/pipeline/execute_sm_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import argparse
import os
import subprocess
import sys
import warnings

Expand Down Expand Up @@ -140,6 +141,15 @@ def main():
deploy_time_hash = pipeline_deploy_args.get_hash_hex()

if args.local_execution:
# Ensure GPU is available if trying to execute with GPU locally
if not pipeline_deploy_args.instance_config.train_on_cpu:
try:
subprocess.check_output('nvidia-smi')
except Exception:
raise RuntimeError(
'Need host with NVidia GPU to run training on GPU! '
"Try re-deploying the pipeline with --train-on-cpu set."
)
# Use local pipeline and session
local_session = LocalPipelineSession()
pipeline_generator = GraphStormPipelineGenerator(
Expand Down

0 comments on commit 2edf3a6

Please sign in to comment.