Skip to content

Commit

Permalink
short-circut checking validity of token
Browse files Browse the repository at this point in the history
  • Loading branch information
SebastianMorawiec committed May 21, 2024
1 parent 398ee23 commit bd2fb5d
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from pathlib import Path
from typing import Dict, List, Optional, Protocol, Sequence, Union

from orquestra.sdk._client._base._jwt import check_jwt_without_signature_verification
from orquestra.sdk._shared import (
Project,
ProjectRef,
Expand All @@ -16,7 +17,12 @@
serde,
)
from orquestra.sdk._shared.abc import RuntimeInterface
from orquestra.sdk._shared.exceptions import IgnoredFieldWarning
from orquestra.sdk._shared.exceptions import (
ExpiredTokenError,
IgnoredFieldWarning,
InvalidTokenError,
UnauthorizedError,
)
from orquestra.sdk._shared.kubernetes.quantity import parse_quantity
from orquestra.sdk._shared.logs import (
LogAccumulator,
Expand Down Expand Up @@ -64,6 +70,14 @@ def __call__(
...


def _check_token_validity(token: str):
try:
check_jwt_without_signature_verification(token)
raise UnauthorizedError()
except (ExpiredTokenError, InvalidTokenError):
return False


def _get_max_resources(workflow_def: WorkflowDef) -> _models.Resources:
max_gpu = None
max_memory = None
Expand Down Expand Up @@ -121,6 +135,7 @@ def __init__(
self,
config: RuntimeConfiguration,
client: _client.DriverClient,
token: str,
verbose: bool = False,
):
"""Initialiser for the CERuntime interface.
Expand All @@ -132,6 +147,7 @@ def __init__(
client: The DriverClient through which the runtime should communicate.
verbose: if `True`, CERuntime may print debug information about
its inner working to stderr.
token: bearers token used to authenticate with the cluster
Raises:
RuntimeConfigError: when the config is invalid.
Expand All @@ -140,6 +156,7 @@ def __init__(
self._verbose = verbose

self._client = client
self._token = token

def create_workflow_run(
self, workflow_def: WorkflowDef, project: Optional[ProjectRef], dry_run: bool
Expand All @@ -162,6 +179,7 @@ def create_workflow_run(
Returns:
the workflow run ID.
"""
_check_token_validity(self._token)
max_invocation_resources = _get_max_resources(workflow_def)

if workflow_def.resources is not None:
Expand Down Expand Up @@ -257,6 +275,8 @@ def get_workflow_run_status(self, workflow_run_id: WorkflowRunId) -> WorkflowRun
Returns:
The status of the workflow run
"""
_check_token_validity(self._token)

try:
return self._client.get_workflow_run(workflow_run_id)
except (_exceptions.InvalidWorkflowRunID, _exceptions.WorkflowRunNotFound) as e:
Expand Down Expand Up @@ -293,6 +313,8 @@ def get_workflow_run_outputs_non_blocking(
Returns:
the outputs associated with the workflow run
"""
_check_token_validity(self._token)

try:
result_ids = self._client.get_workflow_run_results(workflow_run_id)
except (
Expand Down Expand Up @@ -377,6 +399,8 @@ def get_available_outputs(
a mapping between task invocation ID and the available artifacts from the
matching task run.
"""
_check_token_validity(self._token)

try:
artifact_map = self._client.get_workflow_run_artifacts(workflow_run_id)
except (_exceptions.InvalidWorkflowRunID, _exceptions.WorkflowRunNotFound) as e:
Expand Down Expand Up @@ -433,6 +457,8 @@ def stop_workflow_run(
Raises:
WorkflowRunCanNotBeTerminated: if workflow run is cannot be terminated.
"""
_check_token_validity(self._token)

try:
self._client.terminate_workflow_run(workflow_run_id, force)
except _exceptions.WorkflowRunNotFound:
Expand Down Expand Up @@ -539,6 +565,8 @@ def list_workflow_run_summaries(
Raises:
UnauthorizedError: if the remote cluster rejects the token
"""
_check_token_validity(self._token)

func = self._client.list_workflow_run_summaries
try:
return self._list_wf_runs(
Expand Down Expand Up @@ -573,6 +601,8 @@ def list_workflow_runs(
Returns:
A list of the workflow runs
"""
_check_token_validity(self._token)

func = self._client.list_workflow_runs
try:
return self._list_wf_runs(
Expand All @@ -595,6 +625,8 @@ def get_workflow_logs(self, wf_run_id: WorkflowRunId) -> WorkflowLogs:
WorkflowRunNotFound: if the workflow run cannot be found
UnauthorizedError: if the remote cluster rejects the token
"""
_check_token_validity(self._token)

try:
messages = self._client.get_workflow_run_logs(wf_run_id)
sys_messages = self._client.get_system_logs(wf_run_id)
Expand Down Expand Up @@ -673,6 +705,8 @@ def get_task_logs(self, wf_run_id: WorkflowRunId, task_inv_id: TaskInvocationId)
InvalidWorkflowRunLogsError: if the logs could not be decoded.
UnauthorizedError: if the remote cluster rejects the token.
"""
_check_token_validity(self._token)

try:
messages = self._client.get_task_run_logs(wf_run_id, task_inv_id)
except (_exceptions.InvalidWorkflowRunID, _exceptions.TaskRunLogsNotFound) as e:
Expand All @@ -697,6 +731,8 @@ def get_task_logs(self, wf_run_id: WorkflowRunId, task_inv_id: TaskInvocationId)
return LogOutput(out=task_logs.out, err=task_logs.err)

def list_workspaces(self):
_check_token_validity(self._token)

try:
workspaces = self._client.list_workspaces()
return [
Expand All @@ -709,6 +745,8 @@ def list_workspaces(self):
) from e

def list_projects(self, workspace_id: str):
_check_token_validity(self._token)

try:
projects = self._client.list_projects(workspace_id)
return [
Expand All @@ -730,4 +768,6 @@ def list_projects(self, workspace_id: str):
) from e

def get_workflow_project(self, wf_run_id: WorkflowRunId) -> ProjectRef:
_check_token_validity(self._token)

return self._client.get_workflow_project(wf_run_id)
Original file line number Diff line number Diff line change
Expand Up @@ -66,5 +66,6 @@ def _build_ce_runtime(config: RuntimeConfiguration, verbose: bool):
return orquestra.sdk._client._base._driver._ce_runtime.CERuntime(
config=config,
client=client,
token=token,
verbose=verbose,
)

0 comments on commit bd2fb5d

Please sign in to comment.