-
Notifications
You must be signed in to change notification settings - Fork 20
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
8 changed files
with
291 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,15 @@ | ||
from __future__ import annotations | ||
|
||
__all__ = ("AuthDB", "JobDB", "JobLoggingDB", "SandboxMetadataDB", "TaskQueueDB") | ||
__all__ = ( | ||
"AuthDB", | ||
"JobDB", | ||
"JobLoggingDB", | ||
"ProxyDB", | ||
"SandboxMetadataDB", | ||
"TaskQueueDB", | ||
) | ||
|
||
from .auth.db import AuthDB | ||
from .jobs.db import JobDB, JobLoggingDB, TaskQueueDB | ||
from .proxy.db import ProxyDB | ||
from .sandbox_metadata.db import SandboxMetadataDB |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
from __future__ import annotations | ||
|
||
import asyncio | ||
import os | ||
import stat | ||
from datetime import datetime, timezone | ||
from pathlib import Path | ||
from subprocess import DEVNULL, PIPE, STDOUT | ||
from tempfile import TemporaryDirectory | ||
|
||
from DIRAC.Core.Security import Locations | ||
from DIRAC.Core.Security.VOMS import voms_init_cmd | ||
from DIRAC.Core.Security.X509Chain import X509Chain | ||
from DIRAC.Core.Utilities.ReturnValues import returnValueOrRaise | ||
from sqlalchemy import select | ||
|
||
from diracx.core.exceptions import ProxyNotFoundError, VOMSInitError | ||
from diracx.db.sql.utils import BaseSQLDB, utcnow | ||
|
||
from .schema import Base as ProxyDBBase | ||
from .schema import CleanProxies | ||
|
||
PROXY_PROVIDER = "Certificate" | ||
|
||
|
||
class ProxyDB(BaseSQLDB): | ||
metadata = ProxyDBBase.metadata | ||
|
||
async def get_proxy( | ||
self, | ||
dn: str, | ||
vo: str, | ||
dirac_group: str, | ||
voms_attr: str | None, | ||
lifetime_seconds: int, | ||
) -> str: | ||
"""Generate a new proxy for the given DN as PEM with the given VOMS extension""" | ||
original_chain = await self.get_stored_proxy( | ||
dn, min_lifetime_seconds=lifetime_seconds | ||
) | ||
|
||
proxy_string = returnValueOrRaise( | ||
original_chain.generateProxyToString( | ||
lifetime_seconds, | ||
diracGroup=dirac_group, | ||
strength=returnValueOrRaise(original_chain.getStrength()), | ||
) | ||
) | ||
proxy_chain = X509Chain() | ||
proxy_chain.loadProxyFromString(proxy_string) | ||
|
||
with TemporaryDirectory() as tmpdir: | ||
in_fn = Path(tmpdir) / "in.pem" | ||
in_fn.touch(stat.S_IRUSR | stat.S_IWUSR) | ||
in_fn.write_text(proxy_string) | ||
out_fn = Path(tmpdir) / "out.pem" | ||
|
||
cmd = voms_init_cmd( | ||
vo, | ||
voms_attr, | ||
proxy_chain, | ||
str(in_fn), | ||
str(out_fn), | ||
Locations.getVomsesLocation(), | ||
) | ||
proc = await asyncio.create_subprocess_exec( | ||
*cmd, | ||
stdin=DEVNULL, | ||
stdout=PIPE, | ||
stderr=STDOUT, | ||
env=os.environ | ||
| { | ||
"X509_CERT_DIR": Locations.getCAsLocationNoConfig(), | ||
"X509_VOMS_DIR": Locations.getVomsdirLocation(), | ||
}, | ||
) | ||
await proc.wait() | ||
if proc.returncode != 0: | ||
assert proc.stdout | ||
message = (await proc.stdout.read()).decode("utf-8", "backslashreplace") | ||
raise VOMSInitError( | ||
f"voms-proxy-init failed with return code {proc.returncode}: {message}" | ||
) | ||
|
||
voms_string = out_fn.read_text() | ||
|
||
return voms_string | ||
|
||
async def get_stored_proxy( | ||
self, dn: str, *, min_lifetime_seconds: int | ||
) -> X509Chain: | ||
"""Get the X509 proxy that is stored in the DB for the given DN | ||
NOTE: This is the original long-lived proxy and should only be used to | ||
generate short-lived proxies!!! | ||
""" | ||
stmt = select(CleanProxies.Pem, CleanProxies.ExpirationTime) | ||
stmt = stmt.where( | ||
CleanProxies.UserDN == dn, | ||
CleanProxies.ExpirationTime > utcnow(), | ||
CleanProxies.ProxyProvider == PROXY_PROVIDER, | ||
) | ||
|
||
for pem_data, expiration_time in (await self.conn.execute(stmt)).all(): | ||
seconds_remaining = ( | ||
expiration_time.replace(tzinfo=timezone.utc) | ||
- datetime.now(timezone.utc) | ||
).total_seconds() | ||
if seconds_remaining <= min_lifetime_seconds: | ||
continue | ||
|
||
pem_data = pem_data.decode("ascii") | ||
if not pem_data: | ||
continue | ||
chain = X509Chain() | ||
returnValueOrRaise(chain.loadProxyFromString(pem_data)) | ||
return chain | ||
raise ProxyNotFoundError( | ||
f"No proxy found for {dn} with over {min_lifetime_seconds} seconds of life" | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
from __future__ import annotations | ||
|
||
from sqlalchemy import ( | ||
BLOB, | ||
DateTime, | ||
String, | ||
) | ||
from sqlalchemy.orm import declarative_base | ||
|
||
from diracx.db.sql.utils import Column, NullColumn | ||
|
||
Base = declarative_base() | ||
|
||
|
||
class CleanProxies(Base): | ||
__tablename__ = "ProxyDB_CleanProxies" | ||
UserName = Column(String(64)) | ||
Pem = NullColumn(BLOB) | ||
ProxyProvider = Column(String(64), default="Certificate") | ||
ExpirationTime = NullColumn(DateTime) | ||
UserDN = Column(String(255), primary_key=True) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,129 @@ | ||
from __future__ import annotations | ||
|
||
from datetime import datetime, timedelta, timezone | ||
from functools import wraps | ||
from pathlib import Path | ||
from typing import AsyncGenerator | ||
|
||
import pytest | ||
from DIRAC.Core.Security.VOMS import voms_init_cmd | ||
from DIRAC.Core.Security.X509Chain import X509Chain | ||
from DIRAC.Core.Utilities.ReturnValues import returnValueOrRaise | ||
from sqlalchemy import insert | ||
|
||
from diracx.core.exceptions import DiracError | ||
from diracx.db.sql.proxy.db import ProxyDB | ||
from diracx.db.sql.proxy.schema import CleanProxies | ||
|
||
TEST_NAME = "testuser" | ||
TEST_DN = "/O=Dirac Computing/O=CERN/CN=MrUser" | ||
TEST_DATA_DIR = Path(__file__).parent / "data" | ||
TEST_PEM_PATH = TEST_DATA_DIR / "proxy.pem" | ||
|
||
|
||
@pytest.fixture | ||
async def empty_proxy_db(tmp_path) -> AsyncGenerator[ProxyDB, None]: | ||
proxy_db = ProxyDB("sqlite+aiosqlite:///:memory:") | ||
async with proxy_db.engine_context(): | ||
async with proxy_db.engine.begin() as conn: | ||
await conn.run_sync(proxy_db.metadata.create_all) | ||
yield proxy_db | ||
|
||
|
||
@pytest.fixture | ||
async def proxy_db(empty_proxy_db) -> AsyncGenerator[ProxyDB, None]: | ||
async with empty_proxy_db.engine.begin() as conn: | ||
await conn.execute( | ||
insert(CleanProxies).values( | ||
UserName=TEST_NAME, | ||
UserDN=TEST_DN, | ||
ProxyProvider="Certificate", | ||
Pem=TEST_PEM_PATH.read_bytes(), | ||
ExpirationTime=datetime(2033, 11, 25, 21, 25, 23, tzinfo=timezone.utc), | ||
) | ||
) | ||
yield empty_proxy_db | ||
|
||
|
||
async def test_get_stored_proxy(proxy_db: ProxyDB): | ||
async with proxy_db as proxy_db: | ||
proxy = await proxy_db.get_stored_proxy(TEST_DN, min_lifetime_seconds=3600) | ||
assert proxy | ||
|
||
|
||
async def test_no_proxy_for_dn_1(empty_proxy_db: ProxyDB): | ||
async with empty_proxy_db as proxy_db: | ||
with pytest.raises(DiracError, match="No proxy found"): | ||
await proxy_db.get_stored_proxy(TEST_DN, min_lifetime_seconds=3600) | ||
|
||
|
||
async def test_no_proxy_for_dn_2(empty_proxy_db: ProxyDB): | ||
async with empty_proxy_db as proxy_db: | ||
with pytest.raises(DiracError, match="No proxy found"): | ||
await proxy_db.get_stored_proxy( | ||
"/O=OtherOrg/O=CERN/CN=MrUser", min_lifetime_seconds=3600 | ||
) | ||
|
||
|
||
async def test_proxy_not_long_enough(proxy_db: ProxyDB): | ||
async with proxy_db as proxy_db: | ||
with pytest.raises(DiracError, match="No proxy found"): | ||
# The test proxy we use is valid for 10 years | ||
# If this code still exists in 2028 we might start having problems with 2K38 | ||
await proxy_db.get_stored_proxy( | ||
TEST_DN, min_lifetime_seconds=10 * 365 * 24 * 3600 | ||
) | ||
|
||
|
||
@wraps(voms_init_cmd) | ||
def voms_init_cmd_fake(*args, **kwargs): | ||
cmd = voms_init_cmd(*args, **kwargs) | ||
|
||
new_cmd = ["voms-proxy-fake"] | ||
i = 1 | ||
while i < len(cmd): | ||
# Some options are not supported by voms-proxy-fake | ||
if cmd[i] in {"-valid", "-vomses", "-timeout"}: | ||
i += 2 | ||
continue | ||
new_cmd.append(cmd[i]) | ||
i += 1 | ||
new_cmd.extend( | ||
[ | ||
"-hostcert", | ||
f"{TEST_DATA_DIR}/certs/host/hostcert.pem", | ||
"-hostkey", | ||
f"{TEST_DATA_DIR}/certs/host/hostkey.pem", | ||
"-fqan", | ||
"/fakevo/Role=NULL/Capability=NULL", | ||
] | ||
) | ||
return new_cmd | ||
|
||
|
||
async def test_get_proxy(proxy_db: ProxyDB, monkeypatch): | ||
monkeypatch.setenv("X509_CERT_DIR", str(TEST_DATA_DIR / "certs")) | ||
monkeypatch.setattr("diracx.db.sql.proxy.db.voms_init_cmd", voms_init_cmd_fake) | ||
|
||
async with proxy_db as proxy_db: | ||
proxy_pem = await proxy_db.get_proxy( | ||
TEST_DN, "fakevo", "fakevo_user", "/fakevo", 3600 | ||
) | ||
|
||
proxy_chain = X509Chain() | ||
returnValueOrRaise(proxy_chain.loadProxyFromString(proxy_pem)) | ||
|
||
# Check validity | ||
not_after = returnValueOrRaise(proxy_chain.getNotAfterDate()).replace( | ||
tzinfo=timezone.utc | ||
) | ||
# The proxy should currently be valid | ||
assert datetime.now(timezone.utc) < not_after | ||
# The proxy should be invalid in less than 3601 seconds | ||
time_left = not_after - datetime.now(timezone.utc) | ||
assert time_left < timedelta(hours=1, seconds=1) | ||
|
||
# Check VOMS data | ||
voms_data = returnValueOrRaise(proxy_chain.getVOMSData()) | ||
assert voms_data["vo"] == "fakevo" | ||
assert voms_data["fqan"] == ["/fakevo/Role=NULL/Capability=NULL"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters