Skip to content

Commit

Permalink
use sh.git instead of git module to fix #319
Browse files Browse the repository at this point in the history
  • Loading branch information
chaen committed Dec 5, 2024
1 parent 52460ef commit f1f0831
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 16 deletions.
1 change: 1 addition & 0 deletions diracx-core/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ dependencies = [
"pydantic >=2.10",
"pydantic-settings",
"pyyaml",
"sh",
]
dynamic = ["version"]

Expand Down
46 changes: 30 additions & 16 deletions diracx-core/src/diracx/core/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from tempfile import TemporaryDirectory
from typing import Annotated

import git
import sh
import yaml
from cachetools import Cache, LRUCache, TTLCache, cachedmethod
from pydantic import AnyUrl, BeforeValidator, TypeAdapter, UrlConstraints
Expand Down Expand Up @@ -117,10 +117,9 @@ class BaseGitConfigSource(ConfigSource):
The caching is based on 2 caches:
* TTL to find the latest commit hashes
* LRU to keep in memory the last few versions.
"""

repo: git.Repo
repo_location: Path

# Needed because of the ConfigSource.__init_subclass__
scheme = "basegit"
Expand All @@ -134,22 +133,31 @@ def __init__(self, *, backend_url: ConfigSourceUrl) -> None:
@cachedmethod(lambda self: self._latest_revision_cache)
def latest_revision(self) -> tuple[str, datetime]:
try:
rev = self.repo.rev_parse(DEFAULT_GIT_BRANCH)
except git.exc.ODBError as e: # type: ignore
rev = sh.git(
"rev-parse", DEFAULT_GIT_BRANCH, _cwd=self.repo_location, _tty_out=False
).strip()
commit_info = sh.git.show(
"-s", "--format=%ct", rev, _cwd=self.repo_location, _tty_out=False
).strip()
modified = datetime.fromtimestamp(int(commit_info), tz=timezone.utc)
except sh.ErrorReturnCode as e:
raise BadConfigurationVersion(f"Error parsing latest revision: {e}") from e
modified = rev.committed_datetime.astimezone(timezone.utc)
logger.debug(
"Latest revision for %s is %s with mtime %s", self, rev.hexsha, modified
)
return rev.hexsha, modified
logger.debug("Latest revision for %s is %s with mtime %s", self, rev, modified)
return rev, modified

@cachedmethod(lambda self: self._read_raw_cache)
def read_raw(self, hexsha: str, modified: datetime) -> Config:
""":param: hexsha commit hash"""
logger.debug("Reading %s for %s with mtime %s", self, hexsha, modified)
rev = self.repo.rev_parse(hexsha)
blob = rev.tree / DEFAULT_CONFIG_FILE
raw_obj = yaml.safe_load(blob.data_stream.read().decode())
try:
blob = sh.git.show(
f"{hexsha}:{DEFAULT_CONFIG_FILE}",
_cwd=self.repo_location,
_tty_out=False,
)
raw_obj = yaml.safe_load(blob)
except sh.ErrorReturnCode as e:
raise BadConfigurationVersion(f"Error reading configuration: {e}") from e

config_class: Config = select_from_extension(group="diracx", name="config")[
0
Expand Down Expand Up @@ -177,7 +185,13 @@ def __init__(self, *, backend_url: ConfigSourceUrl) -> None:
raise ValueError("Empty path for LocalGitConfigSource")

self.repo_location = Path(backend_url.path)
self.repo = git.Repo(self.repo_location)
# Check if it's a valid git repository
try:
sh.git("rev-parse", "--git-dir", _cwd=self.repo_location, _tty_out=False)
except sh.ErrorReturnCode as e:
raise ValueError(
f"{self.repo_location} is not a valid git repository"
) from e

def __hash__(self):
return hash(self.repo_location)
Expand All @@ -197,7 +211,7 @@ def __init__(self, *, backend_url: ConfigSourceUrl) -> None:
self.remote_url = str(backend_url).replace("git+", "")
self._temp_dir = TemporaryDirectory()
self.repo_location = Path(self._temp_dir.name)
self.repo = git.Repo.clone_from(self.remote_url, self.repo_location)
sh.git.clone(self.remote_url, self.repo_location)
self._pull_cache: Cache = TTLCache(
MAX_PULL_CACHED_VERSIONS, DEFAULT_PULL_CACHE_TTL
)
Expand All @@ -212,7 +226,7 @@ def __hash__(self):
@cachedmethod(lambda self: self._pull_cache)
def _pull(self):
"""Git pull from remote repo."""
self.repo.remotes.origin.pull()
sh.git.pull(_cwd=self.repo_location)

def latest_revision(self) -> tuple[str, datetime]:
self._pull()
Expand Down

0 comments on commit f1f0831

Please sign in to comment.