Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support proxies #211

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions diracx-cli/src/diracx/cli/internal/legacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from typer import Option

from diracx.core.config import Config
from diracx.core.config.schema import Field, SupportInfo
from diracx.core.config.schema import Field, SupportInfo, VomsServerConfig

from ..utils import AsyncTyper

Expand All @@ -34,6 +34,7 @@ class VOConfig(BaseModel):
IdP: IdPConfig
UserSubjects: dict[str, str]
Support: SupportInfo = Field(default_factory=SupportInfo)
VOMSServers: dict[str, VomsServerConfig] = {}


class ConversionConfig(BaseModel):
Expand Down Expand Up @@ -133,10 +134,12 @@ def _apply_fixes(raw):
raw["Registry"][vo]["DefaultProxyLifeTime"] = original_registry[
"DefaultProxyLifeTime"
]
# Copy over the necessary parts of the VO section
for key in {"VOMSName"}:
if key in original_registry.get("VO", {}).get(vo, {}):
raw["Registry"][vo][key] = original_registry["VO"][vo][key]
# Copy over information about the VOMS server
raw["Registry"][vo]["VOMS"] = {"Servers": vo_meta.VOMSServers}
if "VOMSName" in original_registry.get("VO", {}).get(vo, {}):
raw["Registry"][vo]["VOMS"]["Name"] = original_registry["VO"][vo][
"VOMSName"
]
# Find the groups that belong to this VO
vo_users = set()
for name, info in original_registry["Groups"].items():
Expand Down
6 changes: 5 additions & 1 deletion diracx-cli/tests/legacy/cs_sync/integration_test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,9 @@ Registry:
- /C=ch/O=DIRAC/OU=DIRAC CI/CN=ciuser
Email: [email protected]
PreferedUsername: adminusername
VOMSName: myVOMS
VOMS:
Name: myVOMS
Servers: {}
vo:
DefaultGroup: dirac_user
Groups:
Expand Down Expand Up @@ -131,6 +133,8 @@ Registry:
- /C=ch/O=DIRAC/OU=DIRAC CI/CN=trialUser
Email: [email protected]
PreferedUsername: trialUser
VOMS:
Servers: {}
Resources:
FTSEndpoints:
FTS3:
Expand Down
52 changes: 52 additions & 0 deletions diracx-client/src/diracx/client/aio/operations/_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
build_auth_do_device_flow_request,
build_auth_finish_device_flow_request,
build_auth_finished_request,
build_auth_get_proxy_request,
build_auth_get_refresh_tokens_request,
build_auth_initiate_device_flow_request,
build_auth_revoke_refresh_token_request,
Expand Down Expand Up @@ -763,6 +764,57 @@ async def userinfo(self, **kwargs: Any) -> _models.UserInfoResponse:

return deserialized

@distributed_trace_async
async def get_proxy(self, **kwargs: Any) -> Any:
"""Get Proxy.

Get Proxy.

:return: any
:rtype: any
:raises ~azure.core.exceptions.HttpResponseError:
"""
error_map = {
401: ClientAuthenticationError,
404: ResourceNotFoundError,
409: ResourceExistsError,
304: ResourceNotModifiedError,
}
error_map.update(kwargs.pop("error_map", {}) or {})

_headers = kwargs.pop("headers", {}) or {}
_params = kwargs.pop("params", {}) or {}

cls: ClsType[Any] = kwargs.pop("cls", None)

request = build_auth_get_proxy_request(
headers=_headers,
params=_params,
)
request.url = self._client.format_url(request.url)

_stream = False
pipeline_response: PipelineResponse = (
await self._client._pipeline.run( # pylint: disable=protected-access
request, stream=_stream, **kwargs
)
)

response = pipeline_response.http_response

if response.status_code not in [200]:
map_error(
status_code=response.status_code, response=response, error_map=error_map
)
raise HttpResponseError(response=response)

deserialized = self._deserialize("object", pipeline_response)

if cls:
return cls(pipeline_response, deserialized, {})

return deserialized


class ConfigOperations:
"""
Expand Down
65 changes: 65 additions & 0 deletions diracx-client/src/diracx/client/operations/_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,20 @@ def build_auth_userinfo_request(**kwargs: Any) -> HttpRequest:
return HttpRequest(method="GET", url=_url, headers=_headers, **kwargs)


def build_auth_get_proxy_request(**kwargs: Any) -> HttpRequest:
_headers = case_insensitive_dict(kwargs.pop("headers", {}) or {})

accept = _headers.pop("Accept", "application/json")

# Construct URL
_url = "/api/auth/proxy"

# Construct headers
_headers["Accept"] = _SERIALIZER.header("accept", accept, "str")

return HttpRequest(method="GET", url=_url, headers=_headers, **kwargs)


def build_config_serve_config_request(
vo: str,
*,
Expand Down Expand Up @@ -1480,6 +1494,57 @@ def userinfo(self, **kwargs: Any) -> _models.UserInfoResponse:

return deserialized

@distributed_trace
def get_proxy(self, **kwargs: Any) -> Any:
"""Get Proxy.

Get Proxy.

:return: any
:rtype: any
:raises ~azure.core.exceptions.HttpResponseError:
"""
error_map = {
401: ClientAuthenticationError,
404: ResourceNotFoundError,
409: ResourceExistsError,
304: ResourceNotModifiedError,
}
error_map.update(kwargs.pop("error_map", {}) or {})

_headers = kwargs.pop("headers", {}) or {}
_params = kwargs.pop("params", {}) or {}

cls: ClsType[Any] = kwargs.pop("cls", None)

request = build_auth_get_proxy_request(
headers=_headers,
params=_params,
)
request.url = self._client.format_url(request.url)

_stream = False
pipeline_response: PipelineResponse = (
self._client._pipeline.run( # pylint: disable=protected-access
request, stream=_stream, **kwargs
)
)

response = pipeline_response.http_response

if response.status_code not in [200]:
map_error(
status_code=response.status_code, response=response, error_map=error_map
)
raise HttpResponseError(response=response)

deserialized = self._deserialize("object", pipeline_response)

if cls:
return cls(pipeline_response, deserialized, {})

return deserialized


class ConfigOperations:
"""
Expand Down
13 changes: 13 additions & 0 deletions diracx-core/src/diracx/core/config/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,13 +74,26 @@ class SupportInfo(BaseModel):
Message: str = "Please contact system administrator"


class VomsServerConfig(BaseModel):
# Taken from one of the lines in $X509_VOMSES/$VO_NAME
Info: str
# Taken from $X509_VOMSDIR/$VO_NAME/$HOSTNAME.lsc
Chain: list[str]


class VomsConfig(BaseModel):
Name: str | None
Servers: dict[str, VomsServerConfig] = {}


class RegistryConfig(BaseModel):
IdP: IdpConfig
Support: SupportInfo = Field(default_factory=SupportInfo)
DefaultGroup: str
DefaultStorageQuota: float = 0
DefaultProxyLifeTime: int = 12 * 60 * 60
VOMSName: str | None = None
VOMS: VomsConfig = Field(default_factory=VomsConfig)

Users: dict[str, UserConfig]
Groups: dict[str, GroupConfig]
Expand Down
8 changes: 8 additions & 0 deletions diracx-core/src/diracx/core/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,11 @@ class JobNotFound(Exception):
def __init__(self, job_id: int):
self.job_id: int = job_id
super().__init__(f"Job {job_id} not found")


class VOMSInitError(DiracError):
"""Adding VOMS attributes to a proxy failed"""


class ProxyNotFoundError(DiracError):
"""There are no valid proxies for the given user"""
1 change: 1 addition & 0 deletions diracx-db/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ testing = [
AuthDB = "diracx.db.sql:AuthDB"
JobDB = "diracx.db.sql:JobDB"
JobLoggingDB = "diracx.db.sql:JobLoggingDB"
ProxyDB = "diracx.db.sql:ProxyDB"
SandboxMetadataDB = "diracx.db.sql:SandboxMetadataDB"
TaskQueueDB = "diracx.db.sql:TaskQueueDB"

Expand Down
10 changes: 9 additions & 1 deletion diracx-db/src/diracx/db/sql/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,15 @@
from __future__ import annotations

__all__ = ("AuthDB", "JobDB", "JobLoggingDB", "SandboxMetadataDB", "TaskQueueDB")
__all__ = (
"AuthDB",
"JobDB",
"JobLoggingDB",
"ProxyDB",
"SandboxMetadataDB",
"TaskQueueDB",
)

from .auth.db import AuthDB
from .jobs.db import JobDB, JobLoggingDB, TaskQueueDB
from .proxy.db import ProxyDB
from .sandbox_metadata.db import SandboxMetadataDB
Empty file.
122 changes: 122 additions & 0 deletions diracx-db/src/diracx/db/sql/proxy/db.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
from __future__ import annotations

import asyncio
import os
import stat
from datetime import datetime, timezone
from pathlib import Path
from subprocess import DEVNULL, PIPE, STDOUT
from tempfile import TemporaryDirectory

from DIRAC.Core.Security import Locations
from DIRAC.Core.Security.VOMS import voms_init_cmd
from DIRAC.Core.Security.X509Chain import X509Chain
from DIRAC.Core.Utilities.ReturnValues import returnValueOrRaise
from sqlalchemy import select

from diracx.core.exceptions import ProxyNotFoundError, VOMSInitError
from diracx.db.sql.utils import BaseSQLDB, utcnow

from .schema import Base as ProxyDBBase
from .schema import CleanProxies

PROXY_PROVIDER = "Certificate"


class ProxyDB(BaseSQLDB):
metadata = ProxyDBBase.metadata

async def get_proxy(
self,
dn: str,
vo: str,
dirac_group: str,
voms_attr: str | None,
lifetime_seconds: int,
vomses: Path,
vomsdir: Path,
) -> str:
"""Generate a new proxy for the given DN as PEM with the given VOMS extension"""
original_chain = await self.get_stored_proxy(
dn, min_lifetime_seconds=lifetime_seconds
)

proxy_string = returnValueOrRaise(
original_chain.generateProxyToString(
lifetime_seconds,
diracGroup=dirac_group,
strength=returnValueOrRaise(original_chain.getStrength()),
)
)
proxy_chain = X509Chain()
proxy_chain.loadProxyFromString(proxy_string)

with TemporaryDirectory() as tmpdir:
in_fn = Path(tmpdir) / "in.pem"
in_fn.touch(stat.S_IRUSR | stat.S_IWUSR)
in_fn.write_text(proxy_string)
out_fn = Path(tmpdir) / "out.pem"

cmd = voms_init_cmd(
vo,
voms_attr,
proxy_chain,
str(in_fn),
str(out_fn),
str(vomses),
)
proc = await asyncio.create_subprocess_exec(
*cmd,
stdin=DEVNULL,
stdout=PIPE,
stderr=STDOUT,
env=os.environ
| {
"X509_CERT_DIR": Locations.getCAsLocationNoConfig(),
"X509_VOMS_DIR": str(vomsdir),
},
)
await proc.wait()
if proc.returncode != 0:
assert proc.stdout
message = (await proc.stdout.read()).decode("utf-8", "backslashreplace")
raise VOMSInitError(
f"voms-proxy-init failed with return code {proc.returncode}: {message}"
)

voms_string = out_fn.read_text()

return voms_string

async def get_stored_proxy(
self, dn: str, *, min_lifetime_seconds: int
) -> X509Chain:
"""Get the X509 proxy that is stored in the DB for the given DN

NOTE: This is the original long-lived proxy and should only be used to
generate short-lived proxies!!!
"""
stmt = select(CleanProxies.Pem, CleanProxies.ExpirationTime)
stmt = stmt.where(
CleanProxies.UserDN == dn,
CleanProxies.ExpirationTime > utcnow(),
CleanProxies.ProxyProvider == PROXY_PROVIDER,
)

for pem_data, expiration_time in (await self.conn.execute(stmt)).all():
seconds_remaining = (
expiration_time.replace(tzinfo=timezone.utc)
- datetime.now(timezone.utc)
).total_seconds()
if seconds_remaining <= min_lifetime_seconds:
continue

pem_data = pem_data.decode("ascii")
if not pem_data:
continue
chain = X509Chain()
returnValueOrRaise(chain.loadProxyFromString(pem_data))
return chain
raise ProxyNotFoundError(
f"No proxy found for {dn} with over {min_lifetime_seconds} seconds of life"
)
Loading
Loading