From c4f98e35059b146392df9d2a56f6fd08de6a39c9 Mon Sep 17 00:00:00 2001 From: AkashiCoin Date: Mon, 16 Sep 2024 18:39:21 +0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=8E=88=20perf(github=5Futils):=20?= =?UTF-8?q?=E6=94=AF=E6=8C=81github=20url=E4=B8=8B=E8=BD=BD=E9=81=8D?= =?UTF-8?q?=E5=8E=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .vscode/settings.json | 3 +- .../auto_update/_data_source.py | 12 +- .../plugin_store/data_source.py | 18 +-- .../web_ui/public/data_source.py | 6 +- zhenxun/utils/github_utils/__init__.py | 32 +++-- zhenxun/utils/github_utils/func.py | 20 +-- zhenxun/utils/github_utils/models.py | 50 ++++--- zhenxun/utils/http_utils.py | 123 ++++++++++-------- 8 files changed, 152 insertions(+), 112 deletions(-) diff --git a/.vscode/settings.json b/.vscode/settings.json index 0a666ba9c..6ab1cbda8 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -23,7 +23,8 @@ "ujson", "unban", "userinfo", - "zhenxun" + "zhenxun", + "jsdelivr" ], "python.analysis.autoImportCompletions": true, "python.testing.pytestArgs": ["tests"], diff --git a/zhenxun/builtin_plugins/auto_update/_data_source.py b/zhenxun/builtin_plugins/auto_update/_data_source.py index b089fc2eb..be5e3b230 100644 --- a/zhenxun/builtin_plugins/auto_update/_data_source.py +++ b/zhenxun/builtin_plugins/auto_update/_data_source.py @@ -10,8 +10,8 @@ from zhenxun.services.log import logger from zhenxun.utils.http_utils import AsyncHttpx from zhenxun.utils.platform import PlatformUtils +from zhenxun.utils.github_utils import GithubUtils from zhenxun.utils.github_utils.models import RepoInfo -from zhenxun.utils.github_utils import parse_github_url from .config import ( TMP_PATH, @@ -170,19 +170,19 @@ async def update(cls, bot: Bot, user_id: str, version_type: str) -> str | None: cur_version = cls.__get_version() url = None new_version = None - repo_info = parse_github_url(DEFAULT_GITHUB_URL) + repo_info = GithubUtils.parse_github_url(DEFAULT_GITHUB_URL) if version_type in {"dev", "main"}: repo_info.branch = version_type new_version = await cls.__get_version_from_repo(repo_info) if new_version: new_version = new_version.split(":")[-1].strip() - url = await repo_info.get_archive_download_url() + url = await repo_info.get_archive_download_urls() elif version_type == "release": data = await cls.__get_latest_data() if not data: return "获取更新版本失败..." new_version = data.get("name", "") - url = await repo_info.get_release_source_download_url_tgz(new_version) + url = await repo_info.get_release_source_download_urls_tgz(new_version) if not url: return "获取版本下载链接失败..." if TMP_PATH.exists(): @@ -200,7 +200,7 @@ async def update(cls, bot: Bot, user_id: str, version_type: str) -> str | None: download_file = ( DOWNLOAD_GZ_FILE if version_type == "release" else DOWNLOAD_ZIP_FILE ) - if await AsyncHttpx.download_file(url, download_file): + if await AsyncHttpx.download_file(url, download_file, stream=True): logger.debug("下载真寻最新版文件完成...", "检查更新") await _file_handle(new_version) return ( @@ -253,7 +253,7 @@ async def __get_version_from_repo(cls, repo_info: RepoInfo) -> str: 返回: str: 版本号 """ - version_url = await repo_info.get_raw_download_url(path="__version__") + version_url = await repo_info.get_raw_download_urls(path="__version__") try: res = await AsyncHttpx.get(version_url) if res.status_code == 200: diff --git a/zhenxun/builtin_plugins/plugin_store/data_source.py b/zhenxun/builtin_plugins/plugin_store/data_source.py index 6c35890c8..407c53e1d 100644 --- a/zhenxun/builtin_plugins/plugin_store/data_source.py +++ b/zhenxun/builtin_plugins/plugin_store/data_source.py @@ -8,8 +8,8 @@ from zhenxun.services.log import logger from zhenxun.utils.http_utils import AsyncHttpx from zhenxun.models.plugin_info import PluginInfo +from zhenxun.utils.github_utils import GithubUtils from zhenxun.utils.github_utils.models import RepoAPI -from zhenxun.utils.github_utils import api_strategy, parse_github_url from zhenxun.builtin_plugins.plugin_store.models import StorePluginInfo from zhenxun.utils.image_utils import RowStyle, BuildImage, ImageTemplate from zhenxun.builtin_plugins.auto_update.config import REQ_TXT_FILE_STRING @@ -78,12 +78,12 @@ async def __get_data(cls) -> dict[str, StorePluginInfo]: 返回: dict: 插件信息数据 """ - default_github_url = await parse_github_url( + default_github_url = await GithubUtils.parse_github_url( DEFAULT_GITHUB_URL - ).get_raw_download_url("plugins.json") - extra_github_url = await parse_github_url( + ).get_raw_download_urls("plugins.json") + extra_github_url = await GithubUtils.parse_github_url( EXTRA_GITHUB_URL - ).get_raw_download_url("plugins.json") + ).get_raw_download_urls("plugins.json") res = await AsyncHttpx.get(default_github_url) res2 = await AsyncHttpx.get(extra_github_url) @@ -210,9 +210,9 @@ async def install_plugin_with_repo( ): files: list[str] repo_api: RepoAPI - repo_info = parse_github_url(github_url) + repo_info = GithubUtils.parse_github_url(github_url) logger.debug(f"成功获取仓库信息: {repo_info}", "插件管理") - for repo_api in api_strategy: + for repo_api in GithubUtils.iter_api_strategies(): try: await repo_api.parse_repo_info(repo_info) break @@ -227,7 +227,7 @@ async def install_plugin_with_repo( module_path=module_path.replace(".", "/") + ("" if is_dir else ".py"), is_dir=is_dir, ) - download_urls = [await repo_info.get_raw_download_url(file) for file in files] + download_urls = [await repo_info.get_raw_download_urls(file) for file in files] base_path = BASE_PATH / "plugins" if is_external else BASE_PATH download_paths: list[Path | str] = [base_path / file for file in files] logger.debug(f"插件下载路径: {download_paths}", "插件管理") @@ -242,7 +242,7 @@ async def install_plugin_with_repo( req_files.extend(repo_api.get_files("requirement.txt", False)) logger.debug(f"获取插件依赖文件列表: {req_files}", "插件管理") req_download_urls = [ - await repo_info.get_raw_download_url(file) for file in req_files + await repo_info.get_raw_download_urls(file) for file in req_files ] req_paths: list[Path | str] = [plugin_path / file for file in req_files] logger.debug(f"插件依赖文件下载路径: {req_paths}", "插件管理") diff --git a/zhenxun/builtin_plugins/web_ui/public/data_source.py b/zhenxun/builtin_plugins/web_ui/public/data_source.py index 8134433c5..8d094d99f 100644 --- a/zhenxun/builtin_plugins/web_ui/public/data_source.py +++ b/zhenxun/builtin_plugins/web_ui/public/data_source.py @@ -6,7 +6,7 @@ from zhenxun.services.log import logger from zhenxun.utils.http_utils import AsyncHttpx -from zhenxun.utils.github_utils import parse_github_url +from zhenxun.utils.github_utils import GithubUtils from ..config import TMP_PATH, PUBLIC_PATH, WEBUI_DIST_GITHUB_URL @@ -15,9 +15,9 @@ async def update_webui_assets(): webui_assets_path = TMP_PATH / "webui_assets.zip" - download_url = await parse_github_url( + download_url = await GithubUtils.parse_github_url( WEBUI_DIST_GITHUB_URL - ).get_archive_download_url() + ).get_archive_download_urls() if await AsyncHttpx.download_file( download_url, webui_assets_path, follow_redirects=True ): diff --git a/zhenxun/utils/github_utils/__init__.py b/zhenxun/utils/github_utils/__init__.py index 56fd10089..89b0a80a5 100644 --- a/zhenxun/utils/github_utils/__init__.py +++ b/zhenxun/utils/github_utils/__init__.py @@ -1,23 +1,27 @@ +from collections.abc import Generator + from .consts import GITHUB_REPO_URL_PATTERN -from .func import get_fastest_raw_format, get_fastest_archive_format +from .func import get_fastest_raw_formats, get_fastest_archive_formats from .models import RepoAPI, RepoInfo, GitHubStrategy, JsdelivrStrategy __all__ = [ - "parse_github_url", - "get_fastest_raw_format", - "get_fastest_archive_format", - "api_strategy", + "get_fastest_raw_formats", + "get_fastest_archive_formats", + "GithubUtils", ] -def parse_github_url(github_url: str) -> "RepoInfo": - if matched := GITHUB_REPO_URL_PATTERN.match(github_url): - return RepoInfo(**{k: v for k, v in matched.groupdict().items() if v}) - raise ValueError("github地址格式错误") - +class GithubUtils: + # 使用 + jsdelivr_api = RepoAPI(JsdelivrStrategy()) # type: ignore + github_api = RepoAPI(GitHubStrategy()) # type: ignore -# 使用 -jsdelivr_api = RepoAPI(JsdelivrStrategy()) # type: ignore -github_api = RepoAPI(GitHubStrategy()) # type: ignore + @classmethod + def iter_api_strategies(cls) -> Generator[RepoAPI]: + yield from [cls.github_api, cls.jsdelivr_api] -api_strategy = [github_api, jsdelivr_api] + @classmethod + def parse_github_url(cls, github_url: str) -> "RepoInfo": + if matched := GITHUB_REPO_URL_PATTERN.match(github_url): + return RepoInfo(**{k: v for k, v in matched.groupdict().items() if v}) + raise ValueError("github地址格式错误") diff --git a/zhenxun/utils/github_utils/func.py b/zhenxun/utils/github_utils/func.py index 145dbd967..95d2a3ef7 100644 --- a/zhenxun/utils/github_utils/func.py +++ b/zhenxun/utils/github_utils/func.py @@ -9,15 +9,15 @@ ) -async def __get_fastest_format(formats: dict[str, str]) -> str: +async def __get_fastest_formats(formats: dict[str, str]) -> list[str]: sorted_urls = await AsyncHttpx.get_fastest_mirror(list(formats.keys())) if not sorted_urls: raise Exception("无法获取任意GitHub资源加速地址,请检查网络") - return formats[sorted_urls[0]] + return [formats[url] for url in sorted_urls] @cached() -async def get_fastest_raw_format() -> str: +async def get_fastest_raw_formats() -> list[str]: """获取最快的raw下载地址格式""" formats: dict[str, str] = { "https://raw.githubusercontent.com/": RAW_CONTENT_FORMAT, @@ -26,11 +26,11 @@ async def get_fastest_raw_format() -> str: "https://gh-proxy.com/": f"https://gh-proxy.com/{RAW_CONTENT_FORMAT}", "https://cdn.jsdelivr.net/": "https://cdn.jsdelivr.net/gh/{owner}/{repo}@{branch}/{path}", } - return await __get_fastest_format(formats) + return await __get_fastest_formats(formats) @cached() -async def get_fastest_archive_format() -> str: +async def get_fastest_archive_formats() -> list[str]: """获取最快的归档下载地址格式""" formats: dict[str, str] = { "https://github.com/": ARCHIVE_URL_FORMAT, @@ -38,11 +38,11 @@ async def get_fastest_archive_format() -> str: "https://mirror.ghproxy.com/": f"https://mirror.ghproxy.com/{ARCHIVE_URL_FORMAT}", "https://gh-proxy.com/": f"https://gh-proxy.com/{ARCHIVE_URL_FORMAT}", } - return await __get_fastest_format(formats) + return await __get_fastest_formats(formats) @cached() -async def get_fastest_release_format() -> str: +async def get_fastest_release_formats() -> list[str]: """获取最快的发行版资源下载地址格式""" formats: dict[str, str] = { "https://objects.githubusercontent.com/": RELEASE_ASSETS_FORMAT, @@ -50,14 +50,14 @@ async def get_fastest_release_format() -> str: "https://mirror.ghproxy.com/": f"https://mirror.ghproxy.com/{RELEASE_ASSETS_FORMAT}", "https://gh-proxy.com/": f"https://gh-proxy.com/{RELEASE_ASSETS_FORMAT}", } - return await __get_fastest_format(formats) + return await __get_fastest_formats(formats) @cached() -async def get_fastest_release_source_format() -> str: +async def get_fastest_release_source_formats() -> list[str]: """获取最快的发行版源码下载地址格式""" formats: dict[str, str] = { "https://codeload.github.com/": RELEASE_SOURCE_FORMAT, "https://p.102333.xyz/": f"https://p.102333.xyz/{RELEASE_SOURCE_FORMAT}", } - return await __get_fastest_format(formats) + return await __get_fastest_formats(formats) diff --git a/zhenxun/utils/github_utils/models.py b/zhenxun/utils/github_utils/models.py index 98ce02914..170892815 100644 --- a/zhenxun/utils/github_utils/models.py +++ b/zhenxun/utils/github_utils/models.py @@ -7,9 +7,9 @@ from ..http_utils import AsyncHttpx from .consts import CACHED_API_TTL, GIT_API_TREES_FORMAT, JSD_PACKAGE_API_FORMAT from .func import ( - get_fastest_raw_format, - get_fastest_archive_format, - get_fastest_release_source_format, + get_fastest_raw_formats, + get_fastest_archive_formats, + get_fastest_release_source_formats, ) @@ -20,21 +20,41 @@ class RepoInfo(BaseModel): repo: str branch: str = "main" - async def get_raw_download_url(self, path: str): - url_format = await get_fastest_raw_format() - return url_format.format(**self.dict(), path=path) + async def get_raw_download_url(self, path: str) -> str: + return (await self.get_raw_download_urls(path))[0] - async def get_archive_download_url(self): - url_format = await get_fastest_archive_format() - return url_format.format(**self.dict()) + async def get_archive_download_url(self) -> str: + return (await self.get_archive_download_urls())[0] - async def get_release_source_download_url_tgz(self, version: str): - url_format = await get_fastest_release_source_format() - return url_format.format(**self.dict(), version=version, compress="tar.gz") + async def get_release_source_download_url_tgz(self, version: str) -> str: + return (await self.get_release_source_download_urls_tgz(version))[0] - async def get_release_source_download_url_zip(self, version: str): - url_format = await get_fastest_release_source_format() - return url_format.format(**self.dict(), version=version, compress="zip") + async def get_release_source_download_url_zip(self, version: str) -> str: + return (await self.get_release_source_download_urls_zip(version))[0] + + async def get_raw_download_urls(self, path: str) -> list[str]: + url_formats = await get_fastest_raw_formats() + return [ + url_format.format(**self.dict(), path=path) for url_format in url_formats + ] + + async def get_archive_download_urls(self) -> list[str]: + url_formats = await get_fastest_archive_formats() + return [url_format.format(**self.dict()) for url_format in url_formats] + + async def get_release_source_download_urls_tgz(self, version: str) -> list[str]: + url_formats = await get_fastest_release_source_formats() + return [ + url_format.format(**self.dict(), version=version, compress="tar.gz") + for url_format in url_formats + ] + + async def get_release_source_download_urls_zip(self, version: str) -> list[str]: + url_formats = await get_fastest_release_source_formats() + return [ + url_format.format(**self.dict(), version=version, compress="zip") + for url_format in url_formats + ] class APIStrategy(Protocol): diff --git a/zhenxun/utils/http_utils.py b/zhenxun/utils/http_utils.py index 98f2b74ee..f45f84ec5 100644 --- a/zhenxun/utils/http_utils.py +++ b/zhenxun/utils/http_utils.py @@ -33,7 +33,7 @@ class AsyncHttpx: @retry(stop_max_attempt_number=3) async def get( cls, - url: str, + url: str | list[str], *, params: dict[str, Any] | None = None, headers: dict[str, str] | None = None, @@ -56,18 +56,28 @@ async def get( proxy: 指定代理 timeout: 超时时间 """ + if not isinstance(url, list): + url = [url] if not headers: headers = get_user_agent() + last_exception = Exception _proxy = proxy if proxy else cls.proxy if use_proxy else None - async with httpx.AsyncClient(proxies=_proxy, verify=verify) as client: # type: ignore - return await client.get( - url, - params=params, - headers=headers, - cookies=cookies, - timeout=timeout, - **kwargs, - ) + for u in url: + try: + async with httpx.AsyncClient(proxies=_proxy, verify=verify) as client: # type: ignore + return await client.get( + u, + params=params, + headers=headers, + cookies=cookies, + timeout=timeout, + **kwargs, + ) + except Exception: + last_exception = Exception + if u != url[-1]: + logger.warning(f"获取 {u} 失败, 尝试下一个") + raise last_exception @classmethod async def head( @@ -162,7 +172,7 @@ async def post( @classmethod async def download_file( cls, - url: str, + url: str | list[str], path: str | Path, *, params: dict[str, str] | None = None, @@ -220,50 +230,55 @@ async def download_file( if not headers: headers = get_user_agent() _proxy = proxy if proxy else cls.proxy if use_proxy else None - try: - async with httpx.AsyncClient( - proxies=_proxy, # type: ignore - verify=verify, - ) as client: - async with client.stream( - "GET", - url, - params=params, - headers=headers, - cookies=cookies, - timeout=timeout, - follow_redirects=True, - **kwargs, - ) as response: - response.raise_for_status() - logger.info( - f"开始下载 {path.name}.. Path: {path.absolute()}" - ) - async with aiofiles.open(path, "wb") as wf: - total = int(response.headers["Content-Length"]) - with rich.progress.Progress( # type: ignore - rich.progress.TextColumn(path.name), # type: ignore - "[progress.percentage]{task.percentage:>3.0f}%", # type: ignore - rich.progress.BarColumn(bar_width=None), # type: ignore - rich.progress.DownloadColumn(), # type: ignore - rich.progress.TransferSpeedColumn(), # type: ignore - ) as progress: - download_task = progress.add_task( - "Download", total=total - ) - async for chunk in response.aiter_bytes(): - await wf.write(chunk) - await wf.flush() - progress.update( - download_task, - completed=response.num_bytes_downloaded, - ) + if not isinstance(url, list): + url = [url] + for u in url: + try: + async with httpx.AsyncClient( + proxies=_proxy, # type: ignore + verify=verify, + ) as client: + async with client.stream( + "GET", + u, + params=params, + headers=headers, + cookies=cookies, + timeout=timeout, + follow_redirects=True, + **kwargs, + ) as response: + response.raise_for_status() logger.info( - f"下载 {url} 成功.. Path:{path.absolute()}" + f"开始下载 {path.name}.. " + f"Path: {path.absolute()}" ) - return True - except (TimeoutError, ConnectTimeout): - pass + async with aiofiles.open(path, "wb") as wf: + total = int(response.headers["Content-Length"]) + with rich.progress.Progress( # type: ignore + rich.progress.TextColumn(path.name), # type: ignore + "[progress.percentage]{task.percentage:>3.0f}%", # type: ignore + rich.progress.BarColumn(bar_width=None), # type: ignore + rich.progress.DownloadColumn(), # type: ignore + rich.progress.TransferSpeedColumn(), # type: ignore + ) as progress: + download_task = progress.add_task( + "Download", total=total + ) + async for chunk in response.aiter_bytes(): + await wf.write(chunk) + await wf.flush() + progress.update( + download_task, + completed=response.num_bytes_downloaded, + ) + logger.info( + f"下载 {u} 成功.. " + f"Path:{path.absolute()}" + ) + return True + except (TimeoutError, ConnectTimeout): + pass else: logger.error(f"下载 {url} 下载超时.. Path:{path.absolute()}") except Exception as e: @@ -273,7 +288,7 @@ async def download_file( @classmethod async def gather_download_file( cls, - url_list: list[str], + url_list: list[str] | list[list[str]], path_list: list[str | Path], *, limit_async_number: int | None = None,