Skip to content

Commit

Permalink
Fix v6e TPU Scripts and RayJob CRs (#2447)
Browse files: browse the repository at this point in the history
  • Loading branch information
ryanaoleary authored Oct 16, 2024
1 parent 175a1f7 commit 047699f
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 26 deletions.
10 changes: 2 additions & 8 deletions ray-operator/config/samples/ray-job.tpu-v6e-multihost.yaml
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
apiVersion: ray.io/v1
kind: RayJob
metadata:
name: ray-multi-host-job
name: v6e-256-job
spec:
entrypoint: python ray-operator/config/samples/ray-tpu-scripts/tpu_list_devices.py
entrypoint: python ray-operator/config/samples/tpu/tpu_list_devices.py
runtimeEnvYAML: |
working_dir: "https://github.com/ray-project/kuberay/archive/master.zip"
pip:
- jax[tpu]==0.4.33
- -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
env_vars:
TPU_HOSTS: "64"
rayClusterSpec:
rayVersion: '2.37.0'
headGroupSpec:
Expand Down Expand Up
@@ -50,10 +48,6 @@ spec:
containers:
- name: ray-worker
image: rayproject/ray:2.37.0-py310
lifecycle:
preStop:
exec:
command: [ "/bin/sh","-c","ray stop" ]
resources:
limits:
cpu: "100"
Expand Down
10 changes: 2 additions & 8 deletions ray-operator/config/samples/ray-job.tpu-v6e-singlehost.yaml
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
apiVersion: ray.io/v1
kind: RayJob
metadata:
name: ray-single-host-job
name: v6e-4-job
spec:
entrypoint: python ray-operator/config/samples/ray-tpu-scripts/tpu_list_devices.py
entrypoint: python ray-operator/config/samples/tpu/tpu_list_devices.py
runtimeEnvYAML: |
working_dir: "https://github.com/ray-project/kuberay/archive/master.zip"
pip:
- jax[tpu]==0.4.33
- -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
env_vars:
TPU_HOSTS: "1"
rayClusterSpec:
rayVersion: '2.37.0'
headGroupSpec:
Expand Down Expand Up
@@ -50,10 +48,6 @@ spec:
containers:
- name: ray-worker
image: rayproject/ray:2.37.0-py310
lifecycle:
preStop:
exec:
command: [ "/bin/sh","-c","ray stop" ]
resources:
limits:
cpu: "100"
Expand Down
13 changes: 3 additions & 10 deletions ray-operator/config/samples/tpu/tpu_list_devices.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,12 @@
import ray
import jax

ray.init(
runtime_env={
"pip": [
"jax[tpu]==0.4.33",
"-f https://storage.googleapis.com/jax-releases/libtpu_releases.html",
]
}
)
ray.init()

@ray.remote(resources={"TPU": 4})
def tpu_cores():
import jax
return "TPU cores:" + str(jax.device_count())

num_workers = int(os.environ['TPU_HOSTS']) # Set in env of RayJob or RayCluster.
num_workers = int(ray.available_resources()["TPU"]) // 4
result = [tpu_cores.remote() for _ in range(num_workers)]
print(ray.get(result))

0 comments on commit 047699f

Please sign in to comment.