diff --git a/pyproject.toml b/pyproject.toml index c1391eb6..ceb15576 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -131,10 +131,25 @@ win32-setctime = "1.1.0" wordcloud = "1.8.2.2" [tool.nonebot] -plugins = ["nonebot_plugin_guild_patch", "nonebot_bison"] +plugins = ["nonebot_plugin_guild_patch"] plugin_dirs = ["src/plugins"] adapters = [{name = "OneBot V11", module_name = "nonebot.adapters.onebot.v11", project_link = "nonebot-adapter-onebot", desc = "OneBot V11 协议"}] +# https://black.readthedocs.io/en/stable/the_black_code_style/current_style.html +[tool.black] +line-length = 500 + +# https://beta.ruff.rs/docs/settings/ +[tool.ruff] +line-length = 500 +# https://beta.ruff.rs/docs/rules/ +select = ["E", "W", "F"] +ignore = ["F401"] +# Exclude a variety of commonly ignored directories. +respect-gitignore = true +ignore-init-module-imports = true + + [build-system] requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" diff --git a/src/plugins/nonebot_bison/__init__.py b/src/plugins/nonebot_bison/__init__.py new file mode 100644 index 00000000..90da7a1a --- /dev/null +++ b/src/plugins/nonebot_bison/__init__.py @@ -0,0 +1,47 @@ +from nonebot.plugin import PluginMetadata, require + +require("nonebot_plugin_apscheduler") +require("nonebot_plugin_datastore") +require("nonebot_plugin_saa") + +import nonebot_plugin_saa + +from .plugin_config import PlugConfig, plugin_config +from . import post, send, theme, types, utils, config, platform, bootstrap, scheduler, admin_page, sub_manager + +__help__version__ = "0.8.2" +nonebot_plugin_saa.enable_auto_select_bot() + +__help__plugin__name__ = "nonebot_bison" +__usage__ = ( + "本bot可以提供b站、微博等社交媒体的消息订阅,详情请查看本bot文档," + f"或者{'at本bot' if plugin_config.bison_to_me else '' }发送“添加订阅”订阅第一个帐号," + "发送“查询订阅”或“删除订阅”管理订阅" +) + +__supported_adapters__ = nonebot_plugin_saa.__plugin_meta__.supported_adapters + +__plugin_meta__ = PluginMetadata( + name="Bison", + description="通用订阅推送插件", + usage=__usage__, + type="application", + homepage="https://github.com/felinae98/nonebot-bison", + config=PlugConfig, + supported_adapters=__supported_adapters__, + extra={"version": __help__version__, "docs": "https://nonebot-bison.netlify.app/"}, +) + +__all__ = [ + "admin_page", + "bootstrap", + "config", + "sub_manager", + "post", + "scheduler", + "send", + "platform", + "types", + "utils", + "theme", +] diff --git a/src/plugins/nonebot_bison/admin_page/__init__.py b/src/plugins/nonebot_bison/admin_page/__init__.py new file mode 100644 index 00000000..4c2d4e68 --- /dev/null +++ b/src/plugins/nonebot_bison/admin_page/__init__.py @@ -0,0 +1,94 @@ +import os +from pathlib import Path +from typing import TYPE_CHECKING + +from nonebot.log import logger +from nonebot.rule import to_me +from nonebot.typing import T_State +from nonebot import get_driver, on_command +from nonebot.adapters.onebot.v11 import Bot +from nonebot.adapters.onebot.v11.event import PrivateMessageEvent + +from .api import router as api_router +from ..plugin_config import plugin_config +from .token_manager import token_manager as tm + +if TYPE_CHECKING: + from nonebot.drivers.fastapi import Driver + + +STATIC_PATH = (Path(__file__).parent / "dist").resolve() + + +def init_fastapi(driver: "Driver"): + import socketio + from fastapi.applications import FastAPI + from fastapi.staticfiles import StaticFiles + + sio = socketio.AsyncServer(async_mode="asgi", cors_allowed_origins="*") + socket_app = socketio.ASGIApp(sio, socketio_path="socket") + + class SinglePageApplication(StaticFiles): + def 
__init__(self, directory: os.PathLike, index="index.html"): + self.index = index + super().__init__(directory=directory, packages=None, html=True, check_dir=True) + + def lookup_path(self, path: str) -> tuple[str, os.stat_result | None]: + full_path, stat_res = super().lookup_path(path) + if stat_res is None: + return super().lookup_path(self.index) + return (full_path, stat_res) + + def register_router_fastapi(driver: "Driver", socketio): + static_path = STATIC_PATH + nonebot_app = FastAPI( + title="nonebot-bison", + description="nonebot-bison webui and api", + ) + nonebot_app.include_router(api_router) + nonebot_app.mount("/", SinglePageApplication(directory=static_path), name="bison-frontend") + + app = driver.server_app + app.mount("/bison", nonebot_app, "nonebot-bison") + + register_router_fastapi(driver, socket_app) + host = str(driver.config.host) + port = driver.config.port + if host in ["0.0.0.0", "127.0.0.1"]: + host = "localhost" + logger.opt(colors=True).info(f"Nonebot Bison frontend will be running at: http://{host}:{port}/bison") + logger.opt(colors=True).info("该页面不能被直接访问,请私聊bot 后台管理 以获取可访问地址") + + +def register_get_token_handler(): + get_token = on_command("后台管理", rule=to_me(), priority=5, aliases={"管理后台"}) + + @get_token.handle() + async def send_token(bot: "Bot", event: PrivateMessageEvent, state: T_State): + token = tm.get_user_token((event.get_user_id(), event.sender.nickname)) + await get_token.finish(f"请访问: {plugin_config.outer_url / 'auth' / token}") + + get_token.__help__name__ = "获取后台管理地址" # type: ignore + get_token.__help__info__ = "获取管理bot后台的地址,该地址会在一段时间过后过期,请不要泄漏该地址" # type: ignore + + +def get_fastapi_driver() -> "Driver | None": + try: + from nonebot.drivers.fastapi import Driver + + if (driver := get_driver()) and isinstance(driver, Driver): + return driver + return None + + except ImportError: + return None + + +if (STATIC_PATH / "index.html").exists(): + if driver := get_fastapi_driver(): + init_fastapi(driver) + register_get_token_handler() + else: + logger.warning("your driver is not fastapi, webui feature will be disabled") +else: + logger.warning("Frontend file not found, please compile it or use docker or pypi version") diff --git a/src/plugins/nonebot_bison/admin_page/api.py b/src/plugins/nonebot_bison/admin_page/api.py new file mode 100644 index 00000000..afe834e0 --- /dev/null +++ b/src/plugins/nonebot_bison/admin_page/api.py @@ -0,0 +1,199 @@ +import nonebot +from fastapi import status +from fastapi.routing import APIRouter +from fastapi.param_functions import Depends +from fastapi.exceptions import HTTPException +from nonebot_plugin_saa import TargetQQGroup +from nonebot_plugin_saa.auto_select_bot import get_bot +from fastapi.security.oauth2 import OAuth2PasswordBearer + +from ..types import WeightConfig +from ..apis import check_sub_target +from .jwt import load_jwt, pack_jwt +from ..types import Target as T_Target +from ..utils.get_bot import get_groups +from ..platform import platform_manager +from .token_manager import token_manager +from ..config.db_config import SubscribeDupException +from ..config import NoSuchUserException, NoSuchTargetException, NoSuchSubscribeException, config +from .types import ( + TokenResp, + GlobalConf, + StatusResp, + SubscribeResp, + PlatformConfig, + AddSubscribeReq, + SubscribeConfig, + SubscribeGroupDetail, +) + +router = APIRouter(prefix="/api", tags=["api"]) + +oath_scheme = OAuth2PasswordBearer(tokenUrl="token") + + +async def get_jwt_obj(token: str = Depends(oath_scheme)): + obj = load_jwt(token) + if not 
obj: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) + return obj + + +async def check_group_permission(groupNumber: int, token_obj: dict = Depends(get_jwt_obj)): + groups = token_obj["groups"] + for group in groups: + if int(groupNumber) == group["id"]: + return + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) + + +async def check_is_superuser(token_obj: dict = Depends(get_jwt_obj)): + if token_obj.get("type") != "admin": + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) + + +@router.get("/global_conf") +async def get_global_conf() -> GlobalConf: + res = {} + for platform_name, platform in platform_manager.items(): + res[platform_name] = PlatformConfig( + platformName=platform_name, + categories=platform.categories, + enabledTag=platform.enable_tag, + name=platform.name, + hasTarget=getattr(platform, "has_target"), + ) + return GlobalConf(platformConf=res) + + +async def get_admin_groups(qq: int): + res = [] + for group in await get_groups(): + group_id = group["group_id"] + bot = get_bot(TargetQQGroup(group_id=group_id)) + if not bot: + continue + users = await bot.get_group_member_list(group_id=group_id) + for user in users: + if user["user_id"] == qq and user["role"] in ("owner", "admin"): + res.append({"id": group_id, "name": group["group_name"]}) + return res + + +@router.get("/auth") +async def auth(token: str) -> TokenResp: + if qq_tuple := token_manager.get_user(token): + qq, nickname = qq_tuple + if str(qq) in nonebot.get_driver().config.superusers: + jwt_obj = { + "id": qq, + "type": "admin", + "groups": [ + { + "id": info["group_id"], + "name": info["group_name"], + } + for info in await get_groups() + ], + } + ret_obj = TokenResp( + type="admin", + name=nickname, + id=qq, + token=pack_jwt(jwt_obj), + ) + return ret_obj + if admin_groups := await get_admin_groups(int(qq)): + jwt_obj = {"id": str(qq), "type": "user", "groups": admin_groups} + ret_obj = TokenResp( + type="user", + name=nickname, + id=qq, + token=pack_jwt(jwt_obj), + ) + return ret_obj + else: + raise HTTPException(400, "permission denied") + else: + raise HTTPException(400, "code error") + + +@router.get("/subs") +async def get_subs_info(jwt_obj: dict = Depends(get_jwt_obj)) -> SubscribeResp: + groups = jwt_obj["groups"] + res: SubscribeResp = {} + for group in groups: + group_id = group["id"] + raw_subs = await config.list_subscribe(TargetQQGroup(group_id=group_id)) + subs = [ + SubscribeConfig( + platformName=sub.target.platform_name, + targetName=sub.target.target_name, + cats=sub.categories, + tags=sub.tags, + target=sub.target.target, + ) + for sub in raw_subs + ] + res[group_id] = SubscribeGroupDetail(name=group["name"], subscribes=subs) + return res + + +@router.get("/target_name", dependencies=[Depends(get_jwt_obj)]) +async def get_target_name(platformName: str, target: str): + return {"targetName": await check_sub_target(platformName, T_Target(target))} + + +@router.post("/subs", dependencies=[Depends(check_group_permission)]) +async def add_group_sub(groupNumber: int, req: AddSubscribeReq) -> StatusResp: + try: + await config.add_subscribe( + TargetQQGroup(group_id=groupNumber), + T_Target(req.target), + req.targetName, + req.platformName, + req.cats, + req.tags, + ) + return StatusResp(ok=True, msg="") + except SubscribeDupException: + raise HTTPException(status.HTTP_400_BAD_REQUEST, "subscribe duplicated") + + +@router.delete("/subs", dependencies=[Depends(check_group_permission)]) +async def del_group_sub(groupNumber: int, platformName: str, target: str): + try: 
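+        # config.del_subscribe may raise NoSuchUserException or
+        # NoSuchSubscribeException when the row does not exist; the except
+        # clause below maps both onto a 400 response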
+ await config.del_subscribe(TargetQQGroup(group_id=groupNumber), target, platformName) + except (NoSuchUserException, NoSuchSubscribeException): + raise HTTPException(status.HTTP_400_BAD_REQUEST, "no such user or subscribe") + return StatusResp(ok=True, msg="") + + +@router.patch("/subs", dependencies=[Depends(check_group_permission)]) +async def update_group_sub(groupNumber: int, req: AddSubscribeReq): + try: + await config.update_subscribe( + TargetQQGroup(group_id=groupNumber), + req.target, + req.targetName, + req.platformName, + req.cats, + req.tags, + ) + except (NoSuchUserException, NoSuchSubscribeException): + raise HTTPException(status.HTTP_400_BAD_REQUEST, "no such user or subscribe") + return StatusResp(ok=True, msg="") + + +@router.get("/weight", dependencies=[Depends(check_is_superuser)]) +async def get_weight_config(): + return await config.get_all_weight_config() + + +@router.put("/weight", dependencies=[Depends(check_is_superuser)]) +async def update_weigth_config(platformName: str, target: str, weight_config: WeightConfig): + try: + await config.update_time_weight_config(T_Target(target), platformName, weight_config) + except NoSuchTargetException: + raise HTTPException(status.HTTP_400_BAD_REQUEST, "no such subscribe") + return StatusResp(ok=True, msg="") diff --git a/src/plugins/nonebot_bison/admin_page/dist/.gitkeep b/src/plugins/nonebot_bison/admin_page/dist/.gitkeep new file mode 100644 index 00000000..e69de29b diff --git a/src/plugins/nonebot_bison/admin_page/jwt.py b/src/plugins/nonebot_bison/admin_page/jwt.py new file mode 100644 index 00000000..866c184d --- /dev/null +++ b/src/plugins/nonebot_bison/admin_page/jwt.py @@ -0,0 +1,22 @@ +import random +import string +import datetime + +import jwt + +_key = "".join(random.SystemRandom().choice(string.ascii_letters) for _ in range(16)) + + +def pack_jwt(obj: dict) -> str: + return jwt.encode( + {"exp": datetime.datetime.utcnow() + datetime.timedelta(hours=1), **obj}, + _key, + algorithm="HS256", + ) + + +def load_jwt(token: str) -> dict | None: + try: + return jwt.decode(token, _key, algorithms=["HS256"]) + except Exception: + return None diff --git a/src/plugins/nonebot_bison/admin_page/token_manager.py b/src/plugins/nonebot_bison/admin_page/token_manager.py new file mode 100644 index 00000000..bb62d0ad --- /dev/null +++ b/src/plugins/nonebot_bison/admin_page/token_manager.py @@ -0,0 +1,25 @@ +import random +import string + +from expiringdict import ExpiringDict + + +class TokenManager: + def __init__(self): + self.token_manager = ExpiringDict(max_len=100, max_age_seconds=60 * 10) + + def get_user(self, token: str) -> tuple | None: + res = self.token_manager.get(token) + assert res is None or isinstance(res, tuple) + return res + + def save_user(self, token: str, qq: tuple) -> None: + self.token_manager[token] = qq + + def get_user_token(self, qq: tuple) -> str: + token = "".join(random.choices(string.ascii_letters + string.digits, k=16)) + self.save_user(token, qq) + return token + + +token_manager = TokenManager() diff --git a/src/plugins/nonebot_bison/admin_page/types.py b/src/plugins/nonebot_bison/admin_page/types.py new file mode 100644 index 00000000..7a18b67f --- /dev/null +++ b/src/plugins/nonebot_bison/admin_page/types.py @@ -0,0 +1,52 @@ +from pydantic import BaseModel + + +class PlatformConfig(BaseModel): + name: str + categories: dict[int, str] + enabledTag: bool + platformName: str + hasTarget: bool + + +AllPlatformConf = dict[str, PlatformConfig] + + +class GlobalConf(BaseModel): + platformConf: 
AllPlatformConf + + +class TokenResp(BaseModel): + token: str + type: str + id: int + name: str + + +class SubscribeConfig(BaseModel): + platformName: str + target: str + targetName: str + cats: list[int] + tags: list[str] + + +class SubscribeGroupDetail(BaseModel): + name: str + subscribes: list[SubscribeConfig] + + +SubscribeResp = dict[int, SubscribeGroupDetail] + + +class AddSubscribeReq(BaseModel): + platformName: str + target: str + targetName: str + cats: list[int] + tags: list[str] + + +class StatusResp(BaseModel): + ok: bool + msg: str diff --git a/src/plugins/nonebot_bison/apis.py b/src/plugins/nonebot_bison/apis.py new file mode 100644 index 00000000..6d5130ea --- /dev/null +++ b/src/plugins/nonebot_bison/apis.py @@ -0,0 +1,12 @@ +from .types import Target +from .scheduler import scheduler_dict +from .platform import platform_manager + + +async def check_sub_target(platform_name: str, target: Target): + platform = platform_manager[platform_name] + scheduler_conf_class = platform.scheduler + scheduler = scheduler_dict[scheduler_conf_class] + client = await scheduler.scheduler_config_obj.get_query_name_client() + + return await platform_manager[platform_name].get_target_name(client, target) diff --git a/src/plugins/nonebot_bison/bootstrap.py b/src/plugins/nonebot_bison/bootstrap.py new file mode 100644 index 00000000..92d06a84 --- /dev/null +++ b/src/plugins/nonebot_bison/bootstrap.py @@ -0,0 +1,49 @@ +from nonebot.log import logger +from sqlalchemy import text, inspect +from nonebot_plugin_datastore.db import get_engine, pre_db_init, post_db_init + +from .config.db_migration import data_migrate +from .scheduler.manager import init_scheduler +from .config.config_legacy import start_up as legacy_db_startup + + +@pre_db_init +async def pre(): + def _has_table(conn, table_name): + insp = inspect(conn) + return insp.has_table(table_name) + + async with get_engine().begin() as conn: + if not await conn.run_sync(_has_table, "alembic_version"): + logger.debug("未发现默认版本数据库,开始初始化") + return + + logger.debug("发现默认版本数据库,开始检查版本") + t = await conn.scalar(text("select version_num from alembic_version")) + if t not in [ + "4a46ba54a3f3", # alter_type + "5f3370328e44", # add_time_weight_table + "0571870f5222", # init_db + "a333d6224193", # add_last_scheduled_time + "c97c445e2bdb", # add_constraint + ]: + logger.warning(f"当前数据库版本:{t},不是插件的版本,已跳过。") + return + + logger.debug(f"当前数据库版本:{t},是插件的版本,开始迁移。") + # 删除可能存在的版本数据库 + if await conn.run_sync(_has_table, "nonebot_bison_alembic_version"): + await conn.execute(text("drop table nonebot_bison_alembic_version")) + + await conn.execute(text("alter table alembic_version rename to nonebot_bison_alembic_version")) + + +@post_db_init +async def post(): + # legacy db + legacy_db_startup() + # migrate data + await data_migrate() + # init scheduler + await init_scheduler() + logger.info("nonebot-bison bootstrap done") diff --git a/src/plugins/nonebot_bison/compat.py b/src/plugins/nonebot_bison/compat.py new file mode 100644 index 00000000..d4a65a5b --- /dev/null +++ b/src/plugins/nonebot_bison/compat.py @@ -0,0 +1,28 @@ +from typing import Literal, overload + +from pydantic import BaseModel +from nonebot.compat import PYDANTIC_V2 + +__all__ = ("model_validator", "model_rebuild") + + +if PYDANTIC_V2: + from pydantic import model_validator as model_validator + + def model_rebuild(model: type[BaseModel]): + return model.model_rebuild() + +else: + from pydantic import root_validator + + @overload + def model_validator(*, mode: Literal["before"]): ... 
+ + @overload + def model_validator(*, mode: Literal["after"]): ... + + def model_validator(*, mode: Literal["before", "after"]): + return root_validator(pre=mode == "before", allow_reuse=True) + + def model_rebuild(model: type[BaseModel]): + return model.update_forward_refs() diff --git a/src/plugins/nonebot_bison/config/__init__.py b/src/plugins/nonebot_bison/config/__init__.py new file mode 100644 index 00000000..a04d41f0 --- /dev/null +++ b/src/plugins/nonebot_bison/config/__init__.py @@ -0,0 +1,4 @@ +from .db_config import config as config +from .utils import NoSuchUserException as NoSuchUserException +from .utils import NoSuchTargetException as NoSuchTargetException +from .utils import NoSuchSubscribeException as NoSuchSubscribeException diff --git a/src/plugins/nonebot_bison/config/config_legacy.py b/src/plugins/nonebot_bison/config/config_legacy.py new file mode 100644 index 00000000..24e7e4dd --- /dev/null +++ b/src/plugins/nonebot_bison/config/config_legacy.py @@ -0,0 +1,252 @@ +import os +import json +from os import path +from pathlib import Path +from datetime import datetime +from collections import defaultdict +from typing import Literal, TypedDict + +from nonebot.log import logger +from tinydb import Query, TinyDB + +from ..utils import Singleton +from ..types import User, Target +from ..platform import platform_manager +from ..plugin_config import plugin_config +from .utils import NoSuchUserException, NoSuchSubscribeException + +supported_target_type = platform_manager.keys() + + +def get_config_path() -> tuple[str, str]: + if plugin_config.bison_config_path: + data_dir = plugin_config.bison_config_path + else: + working_dir = os.getcwd() + data_dir = path.join(working_dir, "data") + old_path = path.join(data_dir, "hk_reporter.json") + new_path = path.join(data_dir, "bison.json") + deprecated_maker_path = path.join(data_dir, "bison.json.deprecated") + if os.path.exists(old_path) and not os.path.exists(new_path): + os.rename(old_path, new_path) + return new_path, deprecated_maker_path + + +def drop(): + config = Config() + if plugin_config.bison_config_path: + data_dir = plugin_config.bison_config_path + else: + working_dir = os.getcwd() + data_dir = path.join(working_dir, "data") + old_path = path.join(data_dir, "bison.json") + deprecated_marker_path = path.join(data_dir, "bison.json.deprecated") + if os.path.exists(old_path): + config.db.close() + config.available = False + with open(deprecated_marker_path, "w") as file: + content = { + "migration_time": datetime.now().isoformat(), + } + file.write(json.dumps(content)) + return True + return False + + +class SubscribeContent(TypedDict): + target: str + target_type: str + target_name: str + cats: list[int] + tags: list[str] + + +class ConfigContent(TypedDict): + user: int + user_type: Literal["group", "private"] + subs: list[SubscribeContent] + + +class Config(metaclass=Singleton): + "Dropping it!" 
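+    # `Config` is constructed through the `Singleton` metaclass imported from
+    # `..utils`, so repeated `Config()` calls share one instance. A minimal
+    # sketch of such a metaclass, assuming the conventional implementation
+    # (the real one lives in `..utils` and may differ):
+    #
+    #     class Singleton(type):
+    #         _instances: dict[type, object] = {}
+    #
+    #         def __call__(cls, *args, **kwargs):
+    #             if cls not in cls._instances:
+    #                 cls._instances[cls] = super().__call__(*args, **kwargs)
+    #             return cls._instances[cls]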
+ + migrate_version = 2 + + def __init__(self): + self._do_init() + + def _do_init(self): + path, deprecated_marker_path = get_config_path() + if Path(deprecated_marker_path).exists(): + self.available = False + elif Path(path).exists(): + self.available = True + self.db = TinyDB(path, encoding="utf-8") + self.kv_config = self.db.table("kv") + self.user_target = self.db.table("user_target") + self.target_user_cache: dict[str, defaultdict[Target, list[User]]] = {} + self.target_user_cat_cache = {} + self.target_user_tag_cache = {} + self.target_list = {} + self.next_index: defaultdict[str, int] = defaultdict(lambda: 0) + else: + self.available = False + + def add_subscribe(self, user, user_type, target, target_name, target_type, cats, tags): + user_query = Query() + query = (user_query.user == user) & (user_query.user_type == user_type) + if user_data := self.user_target.get(query): + # update + assert not isinstance(user_data, list) + subs: list = user_data.get("subs", []) + subs.append( + { + "target": target, + "target_type": target_type, + "target_name": target_name, + "cats": cats, + "tags": tags, + } + ) + self.user_target.update({"subs": subs}, query) + else: + # insert + self.user_target.insert( + { + "user": user, + "user_type": user_type, + "subs": [ + { + "target": target, + "target_type": target_type, + "target_name": target_name, + "cats": cats, + "tags": tags, + } + ], + } + ) + self.update_send_cache() + + def list_subscribe(self, user, user_type) -> list[SubscribeContent]: + query = Query() + if user_sub := self.user_target.get((query.user == user) & (query.user_type == user_type)): + assert not isinstance(user_sub, list) + return user_sub["subs"] + return [] + + def get_all_subscribe(self): + return self.user_target + + def del_subscribe(self, user, user_type, target, target_type): + user_query = Query() + query = (user_query.user == user) & (user_query.user_type == user_type) + if not (query_res := self.user_target.get(query)): + raise NoSuchUserException() + assert not isinstance(query_res, list) + subs = query_res.get("subs", []) + for idx, sub in enumerate(subs): + if sub.get("target") == target and sub.get("target_type") == target_type: + subs.pop(idx) + self.user_target.update({"subs": subs}, query) + self.update_send_cache() + return + raise NoSuchSubscribeException() + + def update_subscribe(self, user, user_type, target, target_name, target_type, cats, tags): + user_query = Query() + query = (user_query.user == user) & (user_query.user_type == user_type) + if user_data := self.user_target.get(query): + # update + assert not isinstance(user_data, list) + subs: list = user_data.get("subs", []) + find_flag = False + for item in subs: + if item["target"] == target and item["target_type"] == target_type: + item["target_name"], item["cats"], item["tags"] = ( + target_name, + cats, + tags, + ) + find_flag = True + break + if not find_flag: + raise NoSuchSubscribeException() + self.user_target.update({"subs": subs}, query) + else: + raise NoSuchUserException() + self.update_send_cache() + + def update_send_cache(self): + res = {target_type: defaultdict(list) for target_type in supported_target_type} + cat_res = {target_type: defaultdict(lambda: defaultdict(list)) for target_type in supported_target_type} + tag_res = {target_type: defaultdict(lambda: defaultdict(list)) for target_type in supported_target_type} + # res = {target_type: defaultdict(lambda: defaultdict(list)) for target_type in supported_target_type} + to_del = [] + for user in self.user_target.all(): + for sub 
in user.get("subs", []): + if sub.get("target_type") not in supported_target_type: + to_del.append( + { + "user": user["user"], + "user_type": user["user_type"], + "target": sub["target"], + "target_type": sub["target_type"], + } + ) + continue + res[sub["target_type"]][sub["target"]].append(User(user["user"], user["user_type"])) + cat_res[sub["target_type"]][sub["target"]]["{}-{}".format(user["user_type"], user["user"])] = sub[ + "cats" + ] + tag_res[sub["target_type"]][sub["target"]]["{}-{}".format(user["user_type"], user["user"])] = sub[ + "tags" + ] + self.target_user_cache = res + self.target_user_cat_cache = cat_res + self.target_user_tag_cache = tag_res + for target_type in self.target_user_cache: + self.target_list[target_type] = list(self.target_user_cache[target_type].keys()) + + logger.info(f"Deleting {to_del}") + for d in to_del: + self.del_subscribe(**d) + + def get_sub_category(self, target_type, target, user_type, user): + return self.target_user_cat_cache[target_type][target][f"{user_type}-{user}"] + + def get_sub_tags(self, target_type, target, user_type, user): + return self.target_user_tag_cache[target_type][target][f"{user_type}-{user}"] + + def get_next_target(self, target_type): + # FIXME 插入或删除target后对队列的影响(但是并不是大问题 + if not self.target_list[target_type]: + return None + self.next_index[target_type] %= len(self.target_list[target_type]) + res = self.target_list[target_type][self.next_index[target_type]] + self.next_index[target_type] += 1 + return res + + +def start_up(): + config = Config() + if not config.available: + return + if not (search_res := config.kv_config.search(Query().name == "version")): + config.kv_config.insert({"name": "version", "value": config.migrate_version}) + elif search_res[0].get("value") < config.migrate_version: # type: ignore + query = Query() + version_query = query.name == "version" + cur_version = search_res[0].get("value") + if cur_version == 1: + cur_version = 2 + for user_conf in config.user_target.all(): + conf_id = user_conf.doc_id + subs = user_conf["subs"] + for sub in subs: + sub["cats"] = [] + sub["tags"] = [] + config.user_target.update({"subs": subs}, doc_ids=[conf_id]) + config.kv_config.update({"value": config.migrate_version}, version_query) + # do migration + config.update_send_cache() diff --git a/src/plugins/nonebot_bison/config/db_config.py b/src/plugins/nonebot_bison/config/db_config.py new file mode 100644 index 00000000..157b1ef6 --- /dev/null +++ b/src/plugins/nonebot_bison/config/db_config.py @@ -0,0 +1,263 @@ +import asyncio +from collections import defaultdict +from datetime import time, datetime +from collections.abc import Callable, Sequence, Awaitable + +from nonebot.compat import model_dump +from sqlalchemy.orm import selectinload +from sqlalchemy.exc import IntegrityError +from sqlalchemy import func, delete, select +from nonebot_plugin_saa import PlatformTarget +from nonebot_plugin_datastore import create_session + +from ..types import Tag +from ..types import Target as T_Target +from .utils import NoSuchTargetException +from .db_model import User, Target, Subscribe, ScheduleTimeWeight +from ..types import Category, UserSubInfo, WeightConfig, TimeWeightConfig, PlatformWeightConfigResp + + +def _get_time(): + dt = datetime.now() + cur_time = time(hour=dt.hour, minute=dt.minute, second=dt.second) + return cur_time + + +class SubscribeDupException(Exception): ... 
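+# `DBConfig` below keeps two hook lists so that other components (for example
+# the scheduler manager) can react when a platform target is first subscribed
+# or loses its last subscriber. A hypothetical registration (the callback name
+# is illustrative, not part of this module):
+#
+#     async def _on_target_added(platform_name: str, target: T_Target):
+#         ...  # e.g. insert the target into the scheduler's queue
+#
+#     config.register_add_target_hook(_on_target_added)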
+ + +class DBConfig: + def __init__(self): + self.add_target_hook: list[Callable[[str, T_Target], Awaitable]] = [] + self.delete_target_hook: list[Callable[[str, T_Target], Awaitable]] = [] + + def register_add_target_hook(self, fun: Callable[[str, T_Target], Awaitable]): + self.add_target_hook.append(fun) + + def register_delete_target_hook(self, fun: Callable[[str, T_Target], Awaitable]): + self.delete_target_hook.append(fun) + + async def add_subscribe( + self, + user: PlatformTarget, + target: T_Target, + target_name: str, + platform_name: str, + cats: list[Category], + tags: list[Tag], + ): + async with create_session() as session: + db_user_stmt = select(User).where(User.user_target == model_dump(user)) + db_user: User | None = await session.scalar(db_user_stmt) + if not db_user: + db_user = User(user_target=model_dump(user)) + session.add(db_user) + db_target_stmt = select(Target).where(Target.platform_name == platform_name).where(Target.target == target) + db_target: Target | None = await session.scalar(db_target_stmt) + if not db_target: + db_target = Target(target=target, platform_name=platform_name, target_name=target_name) + await asyncio.gather(*[hook(platform_name, target) for hook in self.add_target_hook]) + else: + db_target.target_name = target_name + subscribe = Subscribe( + categories=cats, + tags=tags, + user=db_user, + target=db_target, + ) + session.add(subscribe) + try: + await session.commit() + except IntegrityError as e: + if len(e.args) > 0 and "UNIQUE constraint failed" in e.args[0]: + raise SubscribeDupException() + raise e + + async def list_subscribe(self, user: PlatformTarget) -> Sequence[Subscribe]: + async with create_session() as session: + query_stmt = ( + select(Subscribe) + .where(User.user_target == model_dump(user)) + .join(User) + .options(selectinload(Subscribe.target)) + ) + subs = (await session.scalars(query_stmt)).all() + return subs + + async def list_subs_with_all_info(self) -> Sequence[Subscribe]: + """获取数据库中带有user、target信息的subscribe数据""" + async with create_session() as session: + query_stmt = ( + select(Subscribe).join(User).options(selectinload(Subscribe.target), selectinload(Subscribe.user)) + ) + subs = (await session.scalars(query_stmt)).all() + + return subs + + async def del_subscribe(self, user: PlatformTarget, target: str, platform_name: str): + async with create_session() as session: + user_obj = await session.scalar(select(User).where(User.user_target == model_dump(user))) + target_obj = await session.scalar( + select(Target).where(Target.platform_name == platform_name, Target.target == target) + ) + await session.execute(delete(Subscribe).where(Subscribe.user == user_obj, Subscribe.target == target_obj)) + target_count = await session.scalar( + select(func.count()).select_from(Subscribe).where(Subscribe.target == target_obj) + ) + if target_count == 0: + # delete empty target + await asyncio.gather(*[hook(platform_name, T_Target(target)) for hook in self.delete_target_hook]) + await session.commit() + + async def update_subscribe( + self, + user: PlatformTarget, + target: str, + target_name: str, + platform_name: str, + cats: list, + tags: list, + ): + async with create_session() as sess: + subscribe_obj: Subscribe = await sess.scalar( + select(Subscribe) + .where( + User.user_target == model_dump(user), + Target.target == target, + Target.platform_name == platform_name, + ) + .join(User) + .join(Target) + .options(selectinload(Subscribe.target)) # type:ignore + ) + subscribe_obj.tags = tags # type:ignore + 
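+            # plain ORM attribute assignment: the session tracks these
+            # changes and writes all three fields back on the commit below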
subscribe_obj.categories = cats # type:ignore + subscribe_obj.target.target_name = target_name + await sess.commit() + + async def get_platform_target(self, platform_name: str) -> Sequence[Target]: + async with create_session() as sess: + subq = select(Subscribe.target_id).distinct().subquery() + query = select(Target).join(subq).where(Target.platform_name == platform_name) + return (await sess.scalars(query)).all() + + async def get_time_weight_config(self, target: T_Target, platform_name: str) -> WeightConfig: + async with create_session() as sess: + time_weight_conf = ( + await sess.scalars( + select(ScheduleTimeWeight) + .where(Target.platform_name == platform_name, Target.target == target) + .join(Target) + ) + ).all() + targetObj = await sess.scalar( + select(Target).where(Target.platform_name == platform_name, Target.target == target) + ) + assert targetObj + return WeightConfig( + default=targetObj.default_schedule_weight, + time_config=[ + TimeWeightConfig( + start_time=time_conf.start_time, + end_time=time_conf.end_time, + weight=time_conf.weight, + ) + for time_conf in time_weight_conf + ], + ) + + async def update_time_weight_config(self, target: T_Target, platform_name: str, conf: WeightConfig): + async with create_session() as sess: + targetObj = await sess.scalar( + select(Target).where(Target.platform_name == platform_name, Target.target == target) + ) + if not targetObj: + raise NoSuchTargetException() + target_id = targetObj.id + targetObj.default_schedule_weight = conf.default + delete_statement = delete(ScheduleTimeWeight).where(ScheduleTimeWeight.target_id == target_id) + await sess.execute(delete_statement) + for time_conf in conf.time_config: + new_conf = ScheduleTimeWeight( + start_time=time_conf.start_time, + end_time=time_conf.end_time, + weight=time_conf.weight, + target=targetObj, + ) + sess.add(new_conf) + + await sess.commit() + + async def get_current_weight_val(self, platform_list: list[str]) -> dict[str, int]: + res = {} + cur_time = _get_time() + async with create_session() as sess: + targets = ( + await sess.scalars( + select(Target) + .where(Target.platform_name.in_(platform_list)) + .options(selectinload(Target.time_weight)) + ) + ).all() + for target in targets: + key = f"{target.platform_name}-{target.target}" + weight = target.default_schedule_weight + for time_conf in target.time_weight: + if time_conf.start_time <= cur_time and time_conf.end_time > cur_time: + weight = time_conf.weight + break + res[key] = weight + return res + + async def get_platform_target_subscribers(self, platform_name: str, target: T_Target) -> list[UserSubInfo]: + async with create_session() as sess: + query = ( + select(Subscribe) + .join(Target) + .where(Target.platform_name == platform_name, Target.target == target) + .options(selectinload(Subscribe.user)) + ) + subsribes = (await sess.scalars(query)).all() + return [ + UserSubInfo( + PlatformTarget.deserialize(subscribe.user.user_target), + subscribe.categories, + subscribe.tags, + ) + for subscribe in subsribes + ] + + async def get_all_weight_config( + self, + ) -> dict[str, dict[str, PlatformWeightConfigResp]]: + res: dict[str, dict[str, PlatformWeightConfigResp]] = defaultdict(dict) + async with create_session() as sess: + query = select(Target) + targets = (await sess.scalars(query)).all() + query = select(ScheduleTimeWeight).options(selectinload(ScheduleTimeWeight.target)) + time_weights = (await sess.scalars(query)).all() + + for target in targets: + platform_name = target.platform_name + if platform_name not in 
res.keys(): + res[platform_name][target.target] = PlatformWeightConfigResp( + target=T_Target(target.target), + target_name=target.target_name, + platform_name=platform_name, + weight=WeightConfig(default=target.default_schedule_weight, time_config=[]), + ) + + for time_weight_config in time_weights: + platform_name = time_weight_config.target.platform_name + target = time_weight_config.target.target + res[platform_name][target].weight.time_config.append( + TimeWeightConfig( + start_time=time_weight_config.start_time, + end_time=time_weight_config.end_time, + weight=time_weight_config.weight, + ) + ) + return res + + +config = DBConfig() diff --git a/src/plugins/nonebot_bison/config/db_migration.py b/src/plugins/nonebot_bison/config/db_migration.py new file mode 100644 index 00000000..75080ad6 --- /dev/null +++ b/src/plugins/nonebot_bison/config/db_migration.py @@ -0,0 +1,71 @@ +from nonebot.log import logger +from nonebot.compat import model_dump +from nonebot_plugin_datastore.db import get_engine +from sqlalchemy.ext.asyncio.session import AsyncSession +from nonebot_plugin_saa import TargetQQGroup, TargetQQPrivate + +from .db_model import User, Target, Subscribe +from .config_legacy import Config, ConfigContent, drop + + +async def data_migrate(): + config = Config() + if config.available: + logger.warning("You are still using legacy db, migrating to sqlite") + all_subs: list[ConfigContent] = [ConfigContent(**item) for item in config.get_all_subscribe().all()] + async with AsyncSession(get_engine()) as sess: + user_to_create = [] + subscribe_to_create = [] + platform_target_map: dict[str, tuple[Target, str, int]] = {} + for user in all_subs: + if user["user_type"] == "group": + user_target = TargetQQGroup(group_id=user["user"]) + else: + user_target = TargetQQPrivate(user_id=user["user"]) + db_user = User(user_target=model_dump(user_target)) + user_to_create.append(db_user) + user_sub_set = set() + for sub in user["subs"]: + target = sub["target"] + platform_name = sub["target_type"] + target_name = sub["target_name"] + key = f"{target}-{platform_name}" + if key in user_sub_set: + # a user subscribe a target twice + logger.error( + f"用户 {user['user_type']}-{user['user']} 订阅了 {platform_name}-{target_name} 两次," + "随机采用了一个订阅", + ) + continue + user_sub_set.add(key) + if key in platform_target_map.keys(): + target_obj, ext_user_type, ext_user = platform_target_map[key] + if target_obj.target_name != target_name: + # GG + logger.error( + f"你的旧版本数据库中存在数据不一致问题,请完成迁移后执行重新添加{platform_name}平台的{target}" + f"它的名字可能为{target_obj.target_name}或{target_name}" + ) + + else: + target_obj = Target( + platform_name=platform_name, + target_name=target_name, + target=target, + ) + platform_target_map[key] = ( + target_obj, + user["user_type"], + user["user"], + ) + subscribe_obj = Subscribe( + user=db_user, + target=target_obj, + categories=sub["cats"], + tags=sub["tags"], + ) + subscribe_to_create.append(subscribe_obj) + sess.add_all(user_to_create + [x[0] for x in platform_target_map.values()] + subscribe_to_create) + await sess.commit() + drop() + logger.info("migrate success") diff --git a/src/plugins/nonebot_bison/config/db_model.py b/src/plugins/nonebot_bison/config/db_model.py new file mode 100644 index 00000000..849094d1 --- /dev/null +++ b/src/plugins/nonebot_bison/config/db_model.py @@ -0,0 +1,68 @@ +import datetime +from pathlib import Path + +from nonebot_plugin_saa import PlatformTarget +from sqlalchemy.dialects.postgresql import JSONB +from nonebot.compat import PYDANTIC_V2, ConfigDict +from 
nonebot_plugin_datastore import get_plugin_data +from sqlalchemy.orm import Mapped, relationship, mapped_column +from sqlalchemy import JSON, String, ForeignKey, UniqueConstraint + +from ..types import Tag, Category + +Model = get_plugin_data().Model +get_plugin_data().set_migration_dir(Path(__file__).parent / "migrations") + + +class User(Model): + id: Mapped[int] = mapped_column(primary_key=True) + user_target: Mapped[dict] = mapped_column(JSON().with_variant(JSONB, "postgresql")) + + subscribes: Mapped[list["Subscribe"]] = relationship(back_populates="user") + + @property + def saa_target(self) -> PlatformTarget: + return PlatformTarget.deserialize(self.user_target) + + +class Target(Model): + __table_args__ = (UniqueConstraint("target", "platform_name", name="unique-target-constraint"),) + + id: Mapped[int] = mapped_column(primary_key=True) + platform_name: Mapped[str] = mapped_column(String(20)) + target: Mapped[str] = mapped_column(String(1024)) + target_name: Mapped[str] = mapped_column(String(1024)) + default_schedule_weight: Mapped[int] = mapped_column(default=10) + + subscribes: Mapped[list["Subscribe"]] = relationship(back_populates="target") + time_weight: Mapped[list["ScheduleTimeWeight"]] = relationship(back_populates="target") + + +class ScheduleTimeWeight(Model): + id: Mapped[int] = mapped_column(primary_key=True) + target_id: Mapped[int] = mapped_column(ForeignKey("nonebot_bison_target.id")) + start_time: Mapped[datetime.time] + end_time: Mapped[datetime.time] + weight: Mapped[int] + + target: Mapped[Target] = relationship(back_populates="time_weight") + + if PYDANTIC_V2: + model_config = ConfigDict(arbitrary_types_allowed=True) + else: + + class Config: + arbitrary_types_allowed = True + + +class Subscribe(Model): + __table_args__ = (UniqueConstraint("target_id", "user_id", name="unique-subscribe-constraint"),) + + id: Mapped[int] = mapped_column(primary_key=True) + target_id: Mapped[int] = mapped_column(ForeignKey("nonebot_bison_target.id")) + user_id: Mapped[int] = mapped_column(ForeignKey("nonebot_bison_user.id")) + categories: Mapped[list[Category]] = mapped_column(JSON) + tags: Mapped[list[Tag]] = mapped_column(JSON) + + target: Mapped[Target] = relationship(back_populates="subscribes") + user: Mapped[User] = relationship(back_populates="subscribes") diff --git a/src/plugins/nonebot_bison/config/migrations/0571870f5222_init_db.py b/src/plugins/nonebot_bison/config/migrations/0571870f5222_init_db.py new file mode 100644 index 00000000..391433f1 --- /dev/null +++ b/src/plugins/nonebot_bison/config/migrations/0571870f5222_init_db.py @@ -0,0 +1,61 @@ +"""init db + +Revision ID: 0571870f5222 +Revises: +Create Date: 2022-03-21 19:18:13.762626 + +""" + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "0571870f5222" +down_revision = None +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + op.create_table( + "target", + sa.Column("id", sa.Integer(), autoincrement=True, nullable=False), + sa.Column("platform_name", sa.String(length=20), nullable=False), + sa.Column("target", sa.String(length=1024), nullable=False), + sa.Column("target_name", sa.String(length=1024), nullable=False), + sa.PrimaryKeyConstraint("id"), + ) + op.create_table( + "user", + sa.Column("id", sa.Integer(), autoincrement=True, nullable=False), + sa.Column("type", sa.String(length=20), nullable=False), + sa.Column("uid", sa.Integer(), nullable=False), + sa.PrimaryKeyConstraint("id"), + ) + op.create_table( + "subscribe", + sa.Column("id", sa.Integer(), autoincrement=True, nullable=False), + sa.Column("target_id", sa.Integer(), nullable=True), + sa.Column("user_id", sa.Integer(), nullable=True), + sa.Column("categories", sa.String(length=1024), nullable=True), + sa.Column("tags", sa.String(length=1024), nullable=True), + sa.ForeignKeyConstraint( + ["target_id"], + ["target.id"], + ), + sa.ForeignKeyConstraint( + ["user_id"], + ["user.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table("subscribe") + op.drop_table("user") + op.drop_table("target") + # ### end Alembic commands ### diff --git a/src/plugins/nonebot_bison/config/migrations/4a46ba54a3f3_alter_type.py b/src/plugins/nonebot_bison/config/migrations/4a46ba54a3f3_alter_type.py new file mode 100644 index 00000000..4dbeefed --- /dev/null +++ b/src/plugins/nonebot_bison/config/migrations/4a46ba54a3f3_alter_type.py @@ -0,0 +1,56 @@ +"""alter type + +Revision ID: 4a46ba54a3f3 +Revises: c97c445e2bdb +Create Date: 2022-03-27 21:50:10.911649 + +""" + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "4a46ba54a3f3" +down_revision = "c97c445e2bdb" +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("subscribe", schema=None) as batch_op: + batch_op.alter_column( + "categories", + existing_type=sa.VARCHAR(length=1024), + type_=sa.JSON(), + existing_nullable=True, + postgresql_using="categories::json", + ) + batch_op.alter_column( + "tags", + existing_type=sa.VARCHAR(length=1024), + type_=sa.JSON(), + existing_nullable=True, + postgresql_using="tags::json", + ) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("subscribe", schema=None) as batch_op: + batch_op.alter_column( + "tags", + existing_type=sa.JSON(), + type_=sa.VARCHAR(length=1024), + existing_nullable=True, + ) + batch_op.alter_column( + "categories", + existing_type=sa.JSON(), + type_=sa.VARCHAR(length=1024), + existing_nullable=True, + ) + + # ### end Alembic commands ### diff --git a/src/plugins/nonebot_bison/config/migrations/5da28f6facb3_rename_tables.py b/src/plugins/nonebot_bison/config/migrations/5da28f6facb3_rename_tables.py new file mode 100644 index 00000000..dd86893b --- /dev/null +++ b/src/plugins/nonebot_bison/config/migrations/5da28f6facb3_rename_tables.py @@ -0,0 +1,33 @@ +"""rename tables + +Revision ID: 5da28f6facb3 +Revises: 5f3370328e44 +Create Date: 2023-01-15 19:04:54.987491 + +""" + +from alembic import op + +# revision identifiers, used by Alembic. 
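+# (each migration declares its own `revision` id and points `down_revision`
+# at its parent; Alembic derives the upgrade order from these pairs)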
+revision = "5da28f6facb3" +down_revision = "5f3370328e44" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.rename_table("target", "nonebot_bison_target") + op.rename_table("user", "nonebot_bison_user") + op.rename_table("schedule_time_weight", "nonebot_bison_scheduletimeweight") + op.rename_table("subscribe", "nonebot_bison_subscribe") + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.rename_table("nonebot_bison_subscribe", "subscribe") + op.rename_table("nonebot_bison_scheduletimeweight", "schedule_time_weight") + op.rename_table("nonebot_bison_user", "user") + op.rename_table("nonebot_bison_target", "target") + # ### end Alembic commands ### diff --git a/src/plugins/nonebot_bison/config/migrations/5f3370328e44_add_time_weight_table.py b/src/plugins/nonebot_bison/config/migrations/5f3370328e44_add_time_weight_table.py new file mode 100644 index 00000000..696dfa71 --- /dev/null +++ b/src/plugins/nonebot_bison/config/migrations/5f3370328e44_add_time_weight_table.py @@ -0,0 +1,48 @@ +"""add time-weight table + +Revision ID: 5f3370328e44 +Revises: a333d6224193 +Create Date: 2022-05-31 22:05:13.235981 + +""" + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "5f3370328e44" +down_revision = "a333d6224193" +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "schedule_time_weight", + sa.Column("id", sa.Integer(), autoincrement=True, nullable=False), + sa.Column("target_id", sa.Integer(), nullable=True), + sa.Column("start_time", sa.Time(), nullable=True), + sa.Column("end_time", sa.Time(), nullable=True), + sa.Column("weight", sa.Integer(), nullable=True), + sa.ForeignKeyConstraint( + ["target_id"], + ["target.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + with op.batch_alter_table("target", schema=None) as batch_op: + batch_op.add_column(sa.Column("default_schedule_weight", sa.Integer(), nullable=True)) + batch_op.drop_column("last_schedule_time") + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("target", schema=None) as batch_op: + batch_op.add_column(sa.Column("last_schedule_time", sa.DATETIME(), nullable=True)) + batch_op.drop_column("default_schedule_weight") + + op.drop_table("schedule_time_weight") + # ### end Alembic commands ### diff --git a/src/plugins/nonebot_bison/config/migrations/632b8086bc2b_add_user_target.py b/src/plugins/nonebot_bison/config/migrations/632b8086bc2b_add_user_target.py new file mode 100644 index 00000000..a6f5e3a5 --- /dev/null +++ b/src/plugins/nonebot_bison/config/migrations/632b8086bc2b_add_user_target.py @@ -0,0 +1,41 @@ +"""add user_target + +Revision ID: 632b8086bc2b +Revises: aceef470d69c +Create Date: 2023-03-20 00:39:30.199915 + +""" + +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects.postgresql import JSONB + +# revision identifiers, used by Alembic. +revision = "632b8086bc2b" +down_revision = "aceef470d69c" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table("nonebot_bison_user", schema=None) as batch_op: + batch_op.drop_constraint("unique-user-constraint", type_="unique") + batch_op.add_column( + sa.Column( + "user_target", + sa.JSON().with_variant(JSONB, "postgresql"), + nullable=True, + ) + ) + + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("nonebot_bison_user", schema=None) as batch_op: + batch_op.drop_column("user_target") + batch_op.create_unique_constraint("unique-user-constraint", ["type", "uid"]) + + # ### end Alembic commands ### diff --git a/src/plugins/nonebot_bison/config/migrations/67c38b3f39c2_make_user_target_not_nullable.py b/src/plugins/nonebot_bison/config/migrations/67c38b3f39c2_make_user_target_not_nullable.py new file mode 100644 index 00000000..1f3e07a7 --- /dev/null +++ b/src/plugins/nonebot_bison/config/migrations/67c38b3f39c2_make_user_target_not_nullable.py @@ -0,0 +1,45 @@ +"""make user_target not nullable + +Revision ID: 67c38b3f39c2 +Revises: a5466912fad0 +Create Date: 2023-03-20 11:08:42.883556 + +""" + +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects.postgresql import JSONB + +# revision identifiers, used by Alembic. +revision = "67c38b3f39c2" +down_revision = "a5466912fad0" +branch_labels = None +depends_on = None + + +def jsonb_if_postgresql_else_json(): + return sa.JSON().with_variant(JSONB, "postgresql") + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("nonebot_bison_user", schema=None) as batch_op: + batch_op.alter_column( + "user_target", + existing_type=jsonb_if_postgresql_else_json(), + nullable=False, + ) + + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("nonebot_bison_user", schema=None) as batch_op: + batch_op.alter_column( + "user_target", + existing_type=jsonb_if_postgresql_else_json(), + nullable=True, + ) + + # ### end Alembic commands ### diff --git a/src/plugins/nonebot_bison/config/migrations/8d3863e9d74b_remove_uid_and_type.py b/src/plugins/nonebot_bison/config/migrations/8d3863e9d74b_remove_uid_and_type.py new file mode 100644 index 00000000..649e7f66 --- /dev/null +++ b/src/plugins/nonebot_bison/config/migrations/8d3863e9d74b_remove_uid_and_type.py @@ -0,0 +1,34 @@ +"""remove uid and type + +Revision ID: 8d3863e9d74b +Revises: 67c38b3f39c2 +Create Date: 2023-03-20 15:38:20.220599 + +""" + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "8d3863e9d74b" +down_revision = "67c38b3f39c2" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("nonebot_bison_user", schema=None) as batch_op: + batch_op.drop_column("uid") + batch_op.drop_column("type") + + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table("nonebot_bison_user", schema=None) as batch_op: + batch_op.add_column(sa.Column("type", sa.VARCHAR(length=20), nullable=False)) + batch_op.add_column(sa.Column("uid", sa.INTEGER(), nullable=False)) + + # ### end Alembic commands ### diff --git a/src/plugins/nonebot_bison/config/migrations/a333d6224193_add_last_scheduled_time.py b/src/plugins/nonebot_bison/config/migrations/a333d6224193_add_last_scheduled_time.py new file mode 100644 index 00000000..ad0892bc --- /dev/null +++ b/src/plugins/nonebot_bison/config/migrations/a333d6224193_add_last_scheduled_time.py @@ -0,0 +1,32 @@ +"""add last scheduled time + +Revision ID: a333d6224193 +Revises: 4a46ba54a3f3 +Create Date: 2022-03-29 21:01:38.213153 + +""" + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "a333d6224193" +down_revision = "4a46ba54a3f3" +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("target", schema=None) as batch_op: + batch_op.add_column(sa.Column("last_schedule_time", sa.DateTime(timezone=True), nullable=True)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("target", schema=None) as batch_op: + batch_op.drop_column("last_schedule_time") + + # ### end Alembic commands ### diff --git a/src/plugins/nonebot_bison/config/migrations/a5466912fad0_map_user.py b/src/plugins/nonebot_bison/config/migrations/a5466912fad0_map_user.py new file mode 100644 index 00000000..c89098f8 --- /dev/null +++ b/src/plugins/nonebot_bison/config/migrations/a5466912fad0_map_user.py @@ -0,0 +1,52 @@ +"""map user + +Revision ID: a5466912fad0 +Revises: 632b8086bc2b +Create Date: 2023-03-20 01:14:42.623789 + +""" + +import sqlalchemy as sa +from alembic import op +from sqlalchemy.orm import Session +from sqlalchemy.ext.automap import automap_base + +# revision identifiers, used by Alembic. 
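+# (this data migration reflects the live schema via SQLAlchemy's automap
+# instead of importing the ORM models, so it does not depend on the current
+# shape of the model classes)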
+revision = "a5466912fad0" +down_revision = "632b8086bc2b" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + Base = automap_base() + Base.prepare(op.get_bind()) + User = Base.classes.nonebot_bison_user + with Session(op.get_bind()) as sess: + users = sess.scalars(sa.select(User)).all() + for user in users: + if user.type == "group": + user.user_target = {"platform_type": "QQ Group", "group_id": user.uid} + elif user.type == "private": + user.user_target = {"platform_type": "QQ Private", "user_id": user.uid} + else: + sess.delete(user) + sess.add_all(users) + sess.commit() + + +def downgrade() -> None: + Base = automap_base() + Base.prepare(op.get_bind()) + User = Base.classes.nonebot_bison_user + with Session(op.get_bind()) as sess: + users = sess.scalars(sa.select(User)).all() + for user in users: + if user.user_target["platform_type"] == "QQ Group": + user.uid = user.user_target["group_id"] + user.type = "group" + else: + user.uid = user.user_target["user_id"] + user.type = "private" + sess.add_all(users) + sess.commit() diff --git a/src/plugins/nonebot_bison/config/migrations/aceef470d69c_alter_fields_not_null.py b/src/plugins/nonebot_bison/config/migrations/aceef470d69c_alter_fields_not_null.py new file mode 100644 index 00000000..c51a400e --- /dev/null +++ b/src/plugins/nonebot_bison/config/migrations/aceef470d69c_alter_fields_not_null.py @@ -0,0 +1,52 @@ +"""alter fields not null + +Revision ID: aceef470d69c +Revises: bd92923c218f +Create Date: 2023-03-09 19:10:42.168133 + +""" + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "aceef470d69c" +down_revision = "bd92923c218f" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("nonebot_bison_scheduletimeweight", schema=None) as batch_op: + batch_op.alter_column("target_id", existing_type=sa.INTEGER(), nullable=False) + batch_op.alter_column("start_time", existing_type=sa.TIME(), nullable=False) + batch_op.alter_column("end_time", existing_type=sa.TIME(), nullable=False) + batch_op.alter_column("weight", existing_type=sa.INTEGER(), nullable=False) + + with op.batch_alter_table("nonebot_bison_subscribe", schema=None) as batch_op: + batch_op.alter_column("target_id", existing_type=sa.INTEGER(), nullable=False) + batch_op.alter_column("user_id", existing_type=sa.INTEGER(), nullable=False) + + with op.batch_alter_table("nonebot_bison_target", schema=None) as batch_op: + batch_op.alter_column("default_schedule_weight", existing_type=sa.INTEGER(), nullable=False) + + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table("nonebot_bison_target", schema=None) as batch_op: + batch_op.alter_column("default_schedule_weight", existing_type=sa.INTEGER(), nullable=True) + + with op.batch_alter_table("nonebot_bison_subscribe", schema=None) as batch_op: + batch_op.alter_column("user_id", existing_type=sa.INTEGER(), nullable=True) + batch_op.alter_column("target_id", existing_type=sa.INTEGER(), nullable=True) + + with op.batch_alter_table("nonebot_bison_scheduletimeweight", schema=None) as batch_op: + batch_op.alter_column("weight", existing_type=sa.INTEGER(), nullable=True) + batch_op.alter_column("end_time", existing_type=sa.TIME(), nullable=True) + batch_op.alter_column("start_time", existing_type=sa.TIME(), nullable=True) + batch_op.alter_column("target_id", existing_type=sa.INTEGER(), nullable=True) + + # ### end Alembic commands ### diff --git a/src/plugins/nonebot_bison/config/migrations/bd92923c218f_alter_json_not_null.py b/src/plugins/nonebot_bison/config/migrations/bd92923c218f_alter_json_not_null.py new file mode 100644 index 00000000..aa3f2ff9 --- /dev/null +++ b/src/plugins/nonebot_bison/config/migrations/bd92923c218f_alter_json_not_null.py @@ -0,0 +1,53 @@ +"""alter_json_not_null + +Revision ID: bd92923c218f +Revises: 5da28f6facb3 +Create Date: 2023-03-02 14:04:16.492133 + +""" + +import sqlalchemy as sa +from alembic import op +from sqlalchemy import select +from sqlalchemy.orm import Session +from sqlalchemy.ext.automap import automap_base + +# revision identifiers, used by Alembic. +revision = "bd92923c218f" +down_revision = "5da28f6facb3" +branch_labels = None +depends_on = None + + +def set_default_value(): + Base = automap_base() + Base.prepare(autoload_with=op.get_bind()) + Subscribe = Base.classes.nonebot_bison_subscribe + with Session(op.get_bind()) as session: + select_statement = select(Subscribe) + results = session.scalars(select_statement) + for subscribe in results: + if subscribe.categories is None: + subscribe.categories = [] + if subscribe.tags is None: + subscribe.tags = [] + session.commit() + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + set_default_value() + with op.batch_alter_table("nonebot_bison_subscribe", schema=None) as batch_op: + batch_op.alter_column("categories", existing_type=sa.JSON(), nullable=False) + batch_op.alter_column("tags", existing_type=sa.JSON(), nullable=False) + + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("nonebot_bison_subscribe", schema=None) as batch_op: + batch_op.alter_column("tags", existing_type=sa.JSON(), nullable=True) + batch_op.alter_column("categories", existing_type=sa.JSON(), nullable=True) + + # ### end Alembic commands ### diff --git a/src/plugins/nonebot_bison/config/migrations/c97c445e2bdb_add_constraint.py b/src/plugins/nonebot_bison/config/migrations/c97c445e2bdb_add_constraint.py new file mode 100644 index 00000000..0388316e --- /dev/null +++ b/src/plugins/nonebot_bison/config/migrations/c97c445e2bdb_add_constraint.py @@ -0,0 +1,43 @@ +"""add constraint + +Revision ID: c97c445e2bdb +Revises: 0571870f5222 +Create Date: 2022-03-26 19:46:50.910721 + +""" + +from alembic import op + +# revision identifiers, used by Alembic. +revision = "c97c445e2bdb" +down_revision = "0571870f5222" +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table("subscribe", schema=None) as batch_op: + batch_op.create_unique_constraint("unique-subscribe-constraint", ["target_id", "user_id"]) + + with op.batch_alter_table("target", schema=None) as batch_op: + batch_op.create_unique_constraint("unique-target-constraint", ["target", "platform_name"]) + + with op.batch_alter_table("user", schema=None) as batch_op: + batch_op.create_unique_constraint("unique-user-constraint", ["type", "uid"]) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("user", schema=None) as batch_op: + batch_op.drop_constraint("unique-user-constraint", type_="unique") + + with op.batch_alter_table("target", schema=None) as batch_op: + batch_op.drop_constraint("unique-target-constraint", type_="unique") + + with op.batch_alter_table("subscribe", schema=None) as batch_op: + batch_op.drop_constraint("unique-subscribe-constraint", type_="unique") + + # ### end Alembic commands ### diff --git a/src/plugins/nonebot_bison/config/migrations/f9baef347cc8_remove_old_target.py b/src/plugins/nonebot_bison/config/migrations/f9baef347cc8_remove_old_target.py new file mode 100644 index 00000000..fbed082d --- /dev/null +++ b/src/plugins/nonebot_bison/config/migrations/f9baef347cc8_remove_old_target.py @@ -0,0 +1,34 @@ +"""remove_old_target + +Revision ID: f9baef347cc8 +Revises: 8d3863e9d74b +Create Date: 2023-08-25 00:20:51.511329 + +""" + +import sqlalchemy as sa +from alembic import op +from sqlalchemy.orm import Session +from sqlalchemy.ext.automap import automap_base + +# revision identifiers, used by Alembic. +revision = "f9baef347cc8" +down_revision = "8d3863e9d74b" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + Base = automap_base() + Base.prepare(op.get_bind()) + User = Base.classes.nonebot_bison_user + with Session(op.get_bind()) as sess: + users = sess.scalars(sa.select(User)).all() + for user in users: + if user.user_target["platform_type"] == "Unknow Onebot 12 Platform": + sess.delete(user) + sess.commit() + + +def downgrade() -> None: + pass diff --git a/src/plugins/nonebot_bison/config/subs_io/__init__.py b/src/plugins/nonebot_bison/config/subs_io/__init__.py new file mode 100644 index 00000000..55ab0170 --- /dev/null +++ b/src/plugins/nonebot_bison/config/subs_io/__init__.py @@ -0,0 +1,3 @@ +from .subs_io import subscribes_export, subscribes_import + +__all__ = ["subscribes_export", "subscribes_import"] diff --git a/src/plugins/nonebot_bison/config/subs_io/nbesf_model/__init__.py b/src/plugins/nonebot_bison/config/subs_io/nbesf_model/__init__.py new file mode 100644 index 00000000..8dae14f7 --- /dev/null +++ b/src/plugins/nonebot_bison/config/subs_io/nbesf_model/__init__.py @@ -0,0 +1,6 @@ +"""nbesf is Nonebot Bison Enchangable Subscribes File!""" + +from . 
import v1, v2 +from .base import NBESFBase + +__all__ = ["v1", "v2", "NBESFBase"] diff --git a/src/plugins/nonebot_bison/config/subs_io/nbesf_model/base.py b/src/plugins/nonebot_bison/config/subs_io/nbesf_model/base.py new file mode 100644 index 00000000..426c8199 --- /dev/null +++ b/src/plugins/nonebot_bison/config/subs_io/nbesf_model/base.py @@ -0,0 +1,35 @@ +from abc import ABC + +from pydantic import BaseModel +from nonebot.compat import PYDANTIC_V2, ConfigDict +from nonebot_plugin_saa.registries import AllSupportedPlatformTarget as UserInfo + +from ....types import Tag, Category + + +class NBESFBase(BaseModel, ABC): + version: int # 表示nbesf格式版本,有效版本从1开始 + groups: list = [] + + if PYDANTIC_V2: + model_config = ConfigDict(from_attributes=True) + else: + + class Config: + orm_mode = True + + +class SubReceipt(BaseModel): + """ + 快递包中每件货物的收据 + + 导入订阅时的Model + """ + + user: UserInfo + target: str + target_name: str + platform_name: str + cats: list[Category] + tags: list[Tag] + # default_schedule_weight: int diff --git a/src/plugins/nonebot_bison/config/subs_io/nbesf_model/v1.py b/src/plugins/nonebot_bison/config/subs_io/nbesf_model/v1.py new file mode 100644 index 00000000..324edf36 --- /dev/null +++ b/src/plugins/nonebot_bison/config/subs_io/nbesf_model/v1.py @@ -0,0 +1,130 @@ +"""nbesf is Nonebot Bison Enchangable Subscribes File! ver.1""" + +from typing import Any +from functools import partial + +from nonebot.log import logger +from pydantic import BaseModel +from nonebot_plugin_saa import TargetQQGroup, TargetQQPrivate +from nonebot.compat import PYDANTIC_V2, ConfigDict, model_dump, type_validate_json, type_validate_python + +from ..utils import NBESFParseErr +from ....types import Tag, Category +from .base import NBESFBase, SubReceipt +from ...db_config import SubscribeDupException, config + +# ===== nbesf 定义格式 ====== # +NBESF_VERSION = 1 + + +class UserHead(BaseModel): + """Bison快递包收货信息""" + + type: str + uid: int + + if PYDANTIC_V2: + model_config = ConfigDict(from_attributes=True) + else: + + class Config: + orm_mode = True + + +class Target(BaseModel): + """Bsion快递包发货信息""" + + target_name: str + target: str + platform_name: str + default_schedule_weight: int + + if PYDANTIC_V2: + model_config = ConfigDict(from_attributes=True) + else: + + class Config: + orm_mode = True + + +class SubPayload(BaseModel): + """Bison快递包里的单件货物""" + + categories: list[Category] + tags: list[Tag] + target: Target + + if PYDANTIC_V2: + model_config = ConfigDict(from_attributes=True) + else: + + class Config: + orm_mode = True + + +class SubPack(BaseModel): + """Bison给指定用户派送的快递包""" + + user: UserHead + subs: list[SubPayload] + + +class SubGroup( + NBESFBase, +): + """ + Bison的全部订单(按用户分组) + + 结构参见`nbesf_model`下的对应版本 + """ + + version: int = NBESF_VERSION + groups: list[SubPack] + + +# ======================= # + + +async def subs_receipt_gen(nbesf_data: SubGroup): + for item in nbesf_data.groups: + match item.user.type: + case "group": + user = TargetQQGroup(group_id=item.user.uid) + case "private": + user = TargetQQPrivate(user_id=item.user.uid) + case _: + raise NotImplementedError(f"nbesf v1 不支持的用户类型:{item.user.type}") + + sub_receipt = partial(SubReceipt, user=user) + + for sub in item.subs: + receipt = sub_receipt( + target=sub.target.target, + target_name=sub.target.target_name, + platform_name=sub.target.platform_name, + cats=sub.categories, + tags=sub.tags, + ) + try: + await config.add_subscribe(receipt.user, **model_dump(receipt, exclude={"user"})) + except SubscribeDupException: + 
logger.warning(f"!添加订阅条目 {repr(receipt)} 失败: 相同的订阅已存在") + except Exception as e: + logger.error(f"!添加订阅条目 {repr(receipt)} 失败: {repr(e)}") + else: + logger.success(f"添加订阅条目 {repr(receipt)} 成功!") + + +def nbesf_parser(raw_data: Any) -> SubGroup: + try: + if isinstance(raw_data, str): + nbesf_data = type_validate_json(SubGroup, raw_data) + else: + nbesf_data = type_validate_python(SubGroup, raw_data) + + except Exception as e: + logger.error("数据解析失败,该数据格式可能不满足NBESF格式标准!") + raise NBESFParseErr("数据解析失败") from e + else: + logger.success("NBESF文件解析成功.") + return nbesf_data diff --git a/src/plugins/nonebot_bison/config/subs_io/nbesf_model/v2.py b/src/plugins/nonebot_bison/config/subs_io/nbesf_model/v2.py new file mode 100644 index 00000000..7b2a1884 --- /dev/null +++ b/src/plugins/nonebot_bison/config/subs_io/nbesf_model/v2.py @@ -0,0 +1,106 @@ +"""nbesf is Nonebot Bison Enchangable Subscribes File! ver.2""" + +from typing import Any +from functools import partial + +from nonebot.log import logger +from pydantic import BaseModel +from nonebot_plugin_saa.registries import AllSupportedPlatformTarget +from nonebot.compat import PYDANTIC_V2, ConfigDict, model_dump, type_validate_json, type_validate_python + +from ..utils import NBESFParseErr +from ....types import Tag, Category +from .base import NBESFBase, SubReceipt +from ...db_config import SubscribeDupException, config + +# ===== nbesf 定义格式 ====== # +NBESF_VERSION = 2 + + +class Target(BaseModel): + """Bsion快递包发货信息""" + + target_name: str + target: str + platform_name: str + default_schedule_weight: int + + if PYDANTIC_V2: + model_config = ConfigDict(from_attributes=True) + else: + + class Config: + orm_mode = True + + +class SubPayload(BaseModel): + """Bison快递包里的单件货物""" + + categories: list[Category] + tags: list[Tag] + target: Target + + if PYDANTIC_V2: + model_config = ConfigDict(from_attributes=True) + else: + + class Config: + orm_mode = True + + +class SubPack(BaseModel): + """Bison给指定用户派送的快递包""" + + # user_target: Bison快递包收货信息 + user_target: AllSupportedPlatformTarget + subs: list[SubPayload] + + +class SubGroup(NBESFBase): + """ + Bison的全部订单(按用户分组) + + 结构参见`nbesf_model`下的对应版本 + """ + + version: int = NBESF_VERSION + groups: list[SubPack] + + +# ======================= # + + +async def subs_receipt_gen(nbesf_data: SubGroup): + for item in nbesf_data.groups: + sub_receipt = partial(SubReceipt, user=item.user_target) + + for sub in item.subs: + receipt = sub_receipt( + target=sub.target.target, + target_name=sub.target.target_name, + platform_name=sub.target.platform_name, + cats=sub.categories, + tags=sub.tags, + ) + try: + await config.add_subscribe(receipt.user, **model_dump(receipt, exclude={"user"})) + except SubscribeDupException: + logger.warning(f"!添加订阅条目 {repr(receipt)} 失败: 相同的订阅已存在") + except Exception as e: + logger.error(f"!添加订阅条目 {repr(receipt)} 失败: {repr(e)}") + else: + logger.success(f"添加订阅条目 {repr(receipt)} 成功!") + + +def nbesf_parser(raw_data: Any) -> SubGroup: + try: + if isinstance(raw_data, str): + nbesf_data = type_validate_json(SubGroup, raw_data) + else: + nbesf_data = type_validate_python(SubGroup, raw_data) + + except Exception as e: + logger.error("数据解析失败,该数据格式可能不满足NBESF格式标准!") + raise NBESFParseErr("数据解析失败") from e + else: + return nbesf_data diff --git a/src/plugins/nonebot_bison/config/subs_io/subs_io.py b/src/plugins/nonebot_bison/config/subs_io/subs_io.py new file mode 100644 index 00000000..ec826957 --- /dev/null +++ b/src/plugins/nonebot_bison/config/subs_io/subs_io.py @@ -0,0 +1,77 @@ +from typing import 
cast
+from collections import defaultdict
+from collections.abc import Callable
+
+from sqlalchemy import select
+from nonebot.log import logger
+from sqlalchemy.sql.selectable import Select
+from nonebot_plugin_saa import PlatformTarget
+from nonebot.compat import type_validate_python
+from nonebot_plugin_datastore.db import create_session
+from sqlalchemy.orm.strategy_options import selectinload
+
+from .utils import NBESFVerMatchErr
+from ..db_model import User, Subscribe
+from .nbesf_model import NBESFBase, v1, v2
+
+
+async def subscribes_export(selector: Callable[[Select], Select]) -> v2.SubGroup:
+    """
+    将Bison订阅导出为 Nonebot Bison Exchangable Subscribes File 标准格式的 SubGroup 类型数据
+
+    selector:
+        对 sqlalchemy Select 对象的操作函数,用于限定查询范围
+        e.g. lambda stmt: stmt.where(User.uid == 2233, User.type == "group")
+    """
+    async with create_session() as sess:
+        sub_stmt = select(Subscribe).join(User)
+        sub_stmt = selector(sub_stmt).options(selectinload(Subscribe.target))
+        sub_stmt = cast(Select[tuple[Subscribe]], sub_stmt)
+        sub_data = await sess.scalars(sub_stmt)
+
+        user_stmt = select(User).join(Subscribe)
+        user_stmt = selector(user_stmt).distinct()
+        user_stmt = cast(Select[tuple[User]], user_stmt)
+        user_data = await sess.scalars(user_stmt)
+
+        groups: list[v2.SubPack] = []
+        user_id_sub_dict: dict[int, list[v2.SubPayload]] = defaultdict(list)
+
+        for sub in sub_data:
+            sub_payload = type_validate_python(v2.SubPayload, sub)
+            user_id_sub_dict[sub.user_id].append(sub_payload)
+
+        for user in user_data:
+            assert isinstance(user, User)
+            sub_pack = v2.SubPack(
+                user_target=PlatformTarget.deserialize(user.user_target),
+                subs=user_id_sub_dict[user.id],
+            )
+            groups.append(sub_pack)
+
+    sub_group = v2.SubGroup(groups=groups)
+
+    return sub_group
+
+
+async def subscribes_import(
+    nbesf_data: NBESFBase,
+):
+    """
+    从 Nonebot Bison Exchangable Subscribes File 标准格式的数据中导入订阅
+
+    nbesf_data:
+        符合nbesf_model标准的 SubGroup 类型数据
+    """
+
+    logger.info("开始添加订阅流程")
+    match nbesf_data.version:
+        case 1:
+            assert isinstance(nbesf_data, v1.SubGroup)
+            await v1.subs_receipt_gen(nbesf_data)
+        case 2:
+            assert isinstance(nbesf_data, v2.SubGroup)
+            await v2.subs_receipt_gen(nbesf_data)
+        case _:
+            raise NBESFVerMatchErr(f"不支持的NBESF版本:{nbesf_data.version}")
+    logger.info("订阅流程结束,请检查所有订阅记录是否全部添加成功") diff --git a/src/plugins/nonebot_bison/config/subs_io/utils.py b/src/plugins/nonebot_bison/config/subs_io/utils.py new file mode 100644 index 00000000..181769a2 --- /dev/null +++ b/src/plugins/nonebot_bison/config/subs_io/utils.py @@ -0,0 +1,4 @@ +class NBESFVerMatchErr(Exception): ...
+
+
+class NBESFParseErr(Exception): ...
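
The selector-driven export and the version-dispatching import above pair naturally; the sketch below shows one round trip. It is illustrative only: the import paths and the identity selector are assumptions, while `model_dump`, `subscribes_export`, `subscribes_import`, and `v2.nbesf_parser` all come from the code in this diff.

```python
# Usage sketch (assumed context: inside the bot's running event loop).
import json

from nonebot.compat import model_dump

# Paths are illustrative; they mirror the package layout introduced in this diff.
from src.plugins.nonebot_bison.config.subs_io import subscribes_export, subscribes_import
from src.plugins.nonebot_bison.config.subs_io.nbesf_model import v2


async def nbesf_roundtrip() -> None:
    # Identity selector: no extra WHERE clause, so every subscription is exported.
    sub_group = await subscribes_export(lambda stmt: stmt)
    raw = json.dumps(model_dump(sub_group), ensure_ascii=False)

    # nbesf_parser accepts either a JSON string or pre-parsed data, and raises
    # NBESFParseErr when the payload does not match the schema.
    await subscribes_import(v2.nbesf_parser(raw))  # dispatches on version (here: 2)
```
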
diff --git a/src/plugins/nonebot_bison/config/utils.py b/src/plugins/nonebot_bison/config/utils.py new file mode 100644 index 00000000..8c064974 --- /dev/null +++ b/src/plugins/nonebot_bison/config/utils.py @@ -0,0 +1,10 @@ +class NoSuchUserException(Exception): + pass + + +class NoSuchSubscribeException(Exception): + pass + + +class NoSuchTargetException(Exception): + pass diff --git a/src/plugins/nonebot_bison/platform/__init__.py b/src/plugins/nonebot_bison/platform/__init__.py new file mode 100644 index 00000000..c99ce122 --- /dev/null +++ b/src/plugins/nonebot_bison/platform/__init__.py @@ -0,0 +1,24 @@ +from pathlib import Path +from pkgutil import iter_modules +from collections import defaultdict +from importlib import import_module + +from .platform import Platform, make_no_target_group + +_package_dir = str(Path(__file__).resolve().parent) +for _, module_name, _ in iter_modules([_package_dir]): + import_module(f"{__name__}.{module_name}") + + +_platform_list: defaultdict[str, list[type[Platform]]] = defaultdict(list) +for _platform in Platform.registry: + if not _platform.enabled: + continue + _platform_list[_platform.platform_name].append(_platform) + +platform_manager: dict[str, type[Platform]] = {} +for name, platform_list in _platform_list.items(): + if len(platform_list) == 1: + platform_manager[name] = platform_list[0] + else: + platform_manager[name] = make_no_target_group(platform_list) diff --git a/src/plugins/nonebot_bison/platform/arknights.py b/src/plugins/nonebot_bison/platform/arknights.py new file mode 100644 index 00000000..e931d69e --- /dev/null +++ b/src/plugins/nonebot_bison/platform/arknights.py @@ -0,0 +1,251 @@ +from typing import Any +from functools import partial + +from yarl import URL +from httpx import AsyncClient +from bs4 import BeautifulSoup as bs +from pydantic import Field, BaseModel +from nonebot.compat import type_validate_python + +from ..post import Post +from ..types import Target, RawPost, Category +from .platform import NewMessage, StatusChange +from ..utils.scheduler_config import SchedulerConfig + + +class ArkResponseBase(BaseModel): + code: int + msg: str + + +class BulletinListItem(BaseModel): + cid: str + title: str + category: int + display_time: str = Field(alias="displayTime") + updated_at: int = Field(alias="updatedAt") + sticky: bool + + +class BulletinList(BaseModel): + list: list[BulletinListItem] + + +class BulletinData(BaseModel): + cid: str + display_type: int = Field(alias="displayType") + title: str + category: int + header: str + content: str + jump_link: str = Field(alias="jumpLink") + banner_image_url: str = Field(alias="bannerImageUrl") + display_time: str = Field(alias="displayTime") + updated_at: int = Field(alias="updatedAt") + + +class ArkBulletinListResponse(ArkResponseBase): + data: BulletinList + + +class ArkBulletinResponse(ArkResponseBase): + data: BulletinData + + +class ArknightsSchedConf(SchedulerConfig): + name = "arknights" + schedule_type = "interval" + schedule_setting = {"seconds": 30} + + +class Arknights(NewMessage): + categories = {1: "游戏公告"} + platform_name = "arknights" + name = "明日方舟游戏信息" + enable_tag = False + enabled = True + is_common = False + scheduler = ArknightsSchedConf + has_target = False + default_theme = "arknights" + + @classmethod + async def get_target_name(cls, client: AsyncClient, target: Target) -> str | None: + return "明日方舟游戏信息" + + async def get_sub_list(self, _) -> list[BulletinListItem]: + raw_data = await 
self.client.get("https://ak-webview.hypergryph.com/api/game/bulletinList?target=IOS") + return type_validate_python(ArkBulletinListResponse, raw_data.json()).data.list + + def get_id(self, post: BulletinListItem) -> Any: + return post.cid + + def get_date(self, post: BulletinListItem) -> Any: + # 为什么不使用post.updated_at? + # update_at的时间是上传鹰角服务器的时间,而不是公告发布的时间 + # 也就是说鹰角可能会在中午就把晚上的公告上传到服务器,但晚上公告才会显示,但是update_at就是中午的时间不会改变 + # 如果指定了get_date,那么get_date会被优先使用, 并在获取到的值超过2小时时忽略这条post,导致其不会被发送 + return None + + def get_category(self, _) -> Category: + return Category(1) + + async def parse(self, raw_post: BulletinListItem) -> Post: + raw_data = await self.client.get( + f"https://ak-webview.hypergryph.com/api/game/bulletin/{self.get_id(post=raw_post)}" + ) + data = type_validate_python(ArkBulletinResponse, raw_data.json()).data + + def title_escape(text: str) -> str: + return text.replace("\\n", " - ") + + # gen title, content + if data.header: + # header是title的更详细版本 + # header会和content一起出现 + title = data.header + else: + # 只有一张图片 + title = title_escape(data.title) + + return Post( + self, + content=data.content, + title=title, + nickname="明日方舟游戏内公告", + images=[data.banner_image_url] if data.banner_image_url else None, + url=(url.human_repr() if (url := URL(data.jump_link)).scheme.startswith("http") else None), + timestamp=data.updated_at, + compress=True, + ) + + +class AkVersion(StatusChange): + categories = {2: "更新信息"} + platform_name = "arknights" + name = "明日方舟游戏信息" + enable_tag = False + enabled = True + is_common = False + scheduler = ArknightsSchedConf + has_target = False + default_theme = "brief" + + @classmethod + async def get_target_name(cls, client: AsyncClient, target: Target) -> str | None: + return "明日方舟游戏信息" + + async def get_status(self, _): + res_ver = await self.client.get("https://ak-conf.hypergryph.com/config/prod/official/IOS/version") + res_preanounce = await self.client.get( + "https://ak-conf.hypergryph.com/config/prod/announce_meta/IOS/preannouncement.meta.json" + ) + res = res_ver.json() + res.update(res_preanounce.json()) + return res + + def compare_status(self, _, old_status, new_status): + res = [] + ArkUpdatePost = partial(Post, self, "", nickname="明日方舟更新信息") + if old_status.get("preAnnounceType") == 2 and new_status.get("preAnnounceType") == 0: + res.append(ArkUpdatePost(title="登录界面维护公告上线(大概是开始维护了)")) + elif old_status.get("preAnnounceType") == 0 and new_status.get("preAnnounceType") == 2: + res.append(ArkUpdatePost(title="登录界面维护公告下线(大概是开服了,冲!)")) + if old_status.get("clientVersion") != new_status.get("clientVersion"): + res.append(ArkUpdatePost(title="游戏本体更新(大更新)")) + if old_status.get("resVersion") != new_status.get("resVersion"): + res.append(ArkUpdatePost(title="游戏资源更新(小更新)")) + return res + + def get_category(self, _): + return Category(2) + + async def parse(self, raw_post): + return raw_post + + +class MonsterSiren(NewMessage): + categories = {3: "塞壬唱片新闻"} + platform_name = "arknights" + name = "明日方舟游戏信息" + enable_tag = False + enabled = True + is_common = False + scheduler = ArknightsSchedConf + has_target = False + + @classmethod + async def get_target_name(cls, client: AsyncClient, target: Target) -> str | None: + return "明日方舟游戏信息" + + async def get_sub_list(self, _) -> list[RawPost]: + raw_data = await self.client.get("https://monster-siren.hypergryph.com/api/news") + return raw_data.json()["data"]["list"] + + def get_id(self, post: RawPost) -> Any: + return post["cid"] + + def get_date(self, _) -> None: + return None + + def get_category(self, _) -> 
Category:
+        return Category(3)
+
+    async def parse(self, raw_post: RawPost) -> Post:
+        url = f'https://monster-siren.hypergryph.com/info/{raw_post["cid"]}'
+        res = await self.client.get(f'https://monster-siren.hypergryph.com/api/news/{raw_post["cid"]}')
+        raw_data = res.json()
+        content = raw_data["data"]["content"]
+        # 正文为 HTML,解析前先把 <br> 标签替换为换行符
+        content = content.replace("<br>", "\n")
+        soup = bs(content, "html.parser")
+        imgs = [x["src"] for x in soup("img")]
+        text = f'{raw_post["title"]}\n{soup.text.strip()}'
+        return Post(
+            self,
+            text,
+            images=imgs,
+            url=url,
+            nickname="塞壬唱片新闻",
+            compress=True,
+        )
+
+
+class TerraHistoricusComic(NewMessage):
+    categories = {4: "泰拉记事社漫画"}
+    platform_name = "arknights"
+    name = "明日方舟游戏信息"
+    enable_tag = False
+    enabled = True
+    is_common = False
+    scheduler = ArknightsSchedConf
+    has_target = False
+    default_theme = "brief"
+
+    @classmethod
+    async def get_target_name(cls, client: AsyncClient, target: Target) -> str | None:
+        return "明日方舟游戏信息"
+
+    async def get_sub_list(self, _) -> list[RawPost]:
+        raw_data = await self.client.get("https://terra-historicus.hypergryph.com/api/recentUpdate")
+        return raw_data.json()["data"]
+
+    def get_id(self, post: RawPost) -> Any:
+        return f'{post["comicCid"]}/{post["episodeCid"]}'
+
+    def get_date(self, _) -> None:
+        return None
+
+    def get_category(self, _) -> Category:
+        return Category(4)
+
+    async def parse(self, raw_post: RawPost) -> Post:
+        url = f'https://terra-historicus.hypergryph.com/comic/{raw_post["comicCid"]}/episode/{raw_post["episodeCid"]}'
+        return Post(
+            self,
+            raw_post["subtitle"],
+            title=f'{raw_post["title"]} - {raw_post["episodeShortTitle"]}',
+            images=[raw_post["coverUrl"]],
+            url=url,
+            nickname="泰拉记事社漫画",
+            compress=True,
+        ) diff --git a/src/plugins/nonebot_bison/platform/bilibili.py b/src/plugins/nonebot_bison/platform/bilibili.py new file mode 100644 index 00000000..e8936593 --- /dev/null +++ b/src/plugins/nonebot_bison/platform/bilibili.py @@ -0,0 +1,567 @@ +import re
+import json
+from abc import ABC
+from copy import deepcopy
+from enum import Enum, unique
+from typing_extensions import Self
+from datetime import datetime, timedelta
+from typing import Any, TypeVar, TypeAlias, NamedTuple
+
+from httpx import AsyncClient
+from nonebot.log import logger
+from pydantic import Field, BaseModel
+from nonebot.compat import PYDANTIC_V2, ConfigDict, type_validate_json, type_validate_python
+
+from nonebot_bison.compat import model_rebuild
+
+from ..post import Post
+from ..types import Tag, Target, RawPost, ApiError, Category
+from ..utils import SchedulerConfig, http_client, text_similarity
+from .platform import NewMessage, StatusChange, CategoryNotSupport, CategoryNotRecognize
+
+TBaseModel = TypeVar("TBaseModel", bound=type[BaseModel])
+
+
+# 不能当成装饰器用
+# 当装饰器用时,global namespace 中还没有被装饰的类,会报错
+def model_rebuild_recurse(cls: TBaseModel) -> TBaseModel:
+    """Recursively rebuild all BaseModel subclasses in the class."""
+    if not PYDANTIC_V2:
+        from inspect import isclass, getmembers
+
+        for _, sub_cls in getmembers(cls, lambda x: isclass(x) and issubclass(x, BaseModel)):
+            model_rebuild_recurse(sub_cls)
+    model_rebuild(cls)
+    return cls
+
+
+class Base(BaseModel):
+    if PYDANTIC_V2:
+        model_config = ConfigDict(from_attributes=True)
+    else:
+
+        class Config:
+            orm_mode = True
+
+
+class APIBase(Base):
+    """Bilibili API返回的基础数据"""
+
+    code: int
+    message: str
+
+
+class UserAPI(APIBase):
+    class Card(Base):
+        name: str
+
+    class Data(Base):
+        card: "UserAPI.Card"
+
+    data: Data | None = None
+
+
+class PostAPI(APIBase):
+    class Info(Base):
+        uname: str
+
+    class UserProfile(Base):
+        info: "PostAPI.Info"
+
+    class Origin(Base):
+        uid: int
+        dynamic_id: int
+        dynamic_id_str: str
+        timestamp: int
+        type: int
+        rid: int
+        bvid: str | None = None
+
+    class Desc(Base):
+        dynamic_id: int
+        dynamic_id_str: str
+        timestamp: int
+        type: int
+        user_profile: 
"PostAPI.UserProfile" + rid: int + bvid: str | None = None + + origin: "PostAPI.Origin | None" = None + + class Card(Base): + desc: "PostAPI.Desc" + card: str + + class Data(Base): + cards: "list[PostAPI.Card] | None" + + data: Data | None = None + + +DynRawPost: TypeAlias = PostAPI.Card + +model_rebuild_recurse(UserAPI) +model_rebuild_recurse(PostAPI) + + +class BilibiliClient: + _client: AsyncClient + _refresh_time: datetime + cookie_expire_time = timedelta(hours=5) + + def __init__(self) -> None: + self._client = http_client() + self._refresh_time = datetime(year=2000, month=1, day=1) # an expired time + + async def _init_session(self): + res = await self._client.get("https://www.bilibili.com/") + if res.status_code != 200: + logger.warning("unable to refresh temp cookie") + else: + self._refresh_time = datetime.now() + + async def _refresh_client(self): + if datetime.now() - self._refresh_time > self.cookie_expire_time: + await self._init_session() + + async def get_client(self) -> AsyncClient: + await self._refresh_client() + return self._client + + +bilibili_client = BilibiliClient() + + +class BaseSchedConf(ABC, SchedulerConfig): + schedule_type = "interval" + bilibili_client: BilibiliClient + + def __init__(self): + super().__init__() + self.bilibili_client = bilibili_client + + async def get_client(self, _: Target) -> AsyncClient: + return await self.bilibili_client.get_client() + + async def get_query_name_client(self) -> AsyncClient: + return await self.bilibili_client.get_client() + + +class BilibiliSchedConf(BaseSchedConf): + name = "bilibili.com" + schedule_setting = {"seconds": 10} + + +class BililiveSchedConf(BaseSchedConf): + name = "live.bilibili.com" + schedule_setting = {"seconds": 3} + + +class Bilibili(NewMessage): + categories = { + 1: "一般动态", + 2: "专栏文章", + 3: "视频", + 4: "纯文字", + 5: "转发", + # 5: "短视频" + } + platform_name = "bilibili" + enable_tag = True + enabled = True + is_common = True + scheduler = BilibiliSchedConf + name = "B站" + has_target = True + parse_target_promot = "请输入用户主页的链接" + + @classmethod + async def get_target_name(cls, client: AsyncClient, target: Target) -> str | None: + res = await client.get("https://api.bilibili.com/x/web-interface/card", params={"mid": target}) + res.raise_for_status() + res_data = type_validate_json(UserAPI, res.content) + if res_data.code != 0: + return None + return res_data.data.card.name if res_data.data else None + + @classmethod + async def parse_target(cls, target_text: str) -> Target: + if re.match(r"\d+", target_text): + return Target(target_text) + elif m := re.match(r"(?:https?://)?space\.bilibili\.com/(\d+)", target_text): + return Target(m.group(1)) + else: + raise cls.ParseTargetException() + + async def get_sub_list(self, target: Target) -> list[DynRawPost]: + params = {"host_uid": target, "offset": 0, "need_top": 0} + res = await self.client.get( + "https://api.vc.bilibili.com/dynamic_svr/v1/dynamic_svr/space_history", + params=params, + timeout=4.0, + ) + res.raise_for_status() + res_obj = type_validate_json(PostAPI, res.content) + + if res_obj.code == 0: + if (data := res_obj.data) and (card := data.cards): + return card + return [] + raise ApiError(res.request.url) + + def get_id(self, post: DynRawPost) -> int: + return post.desc.dynamic_id + + def get_date(self, post: DynRawPost) -> int: + return post.desc.timestamp + + def _do_get_category(self, post_type: int) -> Category: + match post_type: + case 2: + return Category(1) + case 64: + return Category(2) + case 8: + return Category(3) + case 4: + return 
Category(4) + case 1: + # 转发 + return Category(5) + case unknown_type: + raise CategoryNotRecognize(unknown_type) + + def get_category(self, post: DynRawPost) -> Category: + post_type = post.desc.type + return self._do_get_category(post_type) + + def get_tags(self, raw_post: DynRawPost) -> list[Tag]: + card_content = json.loads(raw_post.card) + text: str = card_content["item"]["content"] + result: list[str] = re.findall(r"#(.*?)#", text) + return result + + def _text_process(self, dynamic: str, desc: str, title: str) -> str: + similarity = 1.0 if len(dynamic) == 0 or len(desc) == 0 else text_similarity(dynamic, desc) + if len(dynamic) == 0 and len(desc) == 0: + text = title + elif similarity > 0.8: + text = title + "\n\n" + desc if len(dynamic) < len(desc) else dynamic + "\n=================\n" + title + else: + text = dynamic + "\n=================\n" + title + "\n\n" + desc + return text + + def _raw_post_parse(self, raw_post: DynRawPost, in_repost: bool = False): + class ParsedPost(NamedTuple): + text: str + pics: list[str] + url: str | None + repost_owner: str | None = None + repost: "ParsedPost | None" = None + + card_content: dict[str, Any] = json.loads(raw_post.card) + repost_owner: str | None = ou["info"]["uname"] if (ou := card_content.get("origin_user")) else None + + def extract_url_id(url_template: str, name: str) -> str | None: + if in_repost: + if origin := raw_post.desc.origin: + return url_template.format(getattr(origin, name)) + return None + return url_template.format(getattr(raw_post.desc, name)) + + match self._do_get_category(raw_post.desc.type): + case 1: + # 一般动态 + url = extract_url_id("https://t.bilibili.com/{}", "dynamic_id_str") + text: str = card_content["item"]["description"] + pic: list[str] = [img["img_src"] for img in card_content["item"]["pictures"]] + return ParsedPost(text, pic, url, repost_owner) + case 2: + # 专栏文章 + url = extract_url_id("https://www.bilibili.com/read/cv{}", "rid") + text = "{} {}".format(card_content["title"], card_content["summary"]) + pic = card_content["image_urls"] + return ParsedPost(text, pic, url, repost_owner) + case 3: + # 视频 + url = extract_url_id("https://www.bilibili.com/video/{}", "bvid") + dynamic = card_content.get("dynamic", "") + title = card_content["title"] + desc = card_content.get("desc", "") + text = self._text_process(dynamic, desc, title) + pic = [card_content["pic"]] + return ParsedPost(text, pic, url, repost_owner) + case 4: + # 纯文字 + url = extract_url_id("https://t.bilibili.com/{}", "dynamic_id_str") + text = card_content["item"]["content"] + pic = [] + return ParsedPost(text, pic, url, repost_owner) + case 5: + # 转发 + url = extract_url_id("https://t.bilibili.com/{}", "dynamic_id_str") + text = card_content["item"]["content"] + orig_type: int = card_content["item"]["orig_type"] + orig_card: str = card_content["origin"] + orig_post = DynRawPost(desc=raw_post.desc, card=orig_card) + orig_post.desc.type = orig_type + + orig_parsed_post = self._raw_post_parse(orig_post, in_repost=True) + return ParsedPost(text, [], url, repost_owner, orig_parsed_post) + case unsupported_type: + raise CategoryNotSupport(unsupported_type) + + async def parse(self, raw_post: DynRawPost) -> Post: + parsed_raw_post = self._raw_post_parse(raw_post) + + post = Post( + self, + parsed_raw_post.text, + url=parsed_raw_post.url, + images=list(parsed_raw_post.pics), + nickname=raw_post.desc.user_profile.info.uname, + ) + if rp := parsed_raw_post.repost: + post.repost = Post( + self, + rp.text, + url=rp.url, + images=list(rp.pics), + 
nickname=rp.repost_owner, + ) + return post + + +class Bilibililive(StatusChange): + categories = {1: "开播提醒", 2: "标题更新提醒", 3: "下播提醒"} + platform_name = "bilibili-live" + enable_tag = False + enabled = True + is_common = True + scheduler = BililiveSchedConf + name = "Bilibili直播" + has_target = True + use_batch = True + default_theme = "brief" + + @unique + class LiveStatus(Enum): + # 直播状态 + # 0: 未开播 + # 1: 正在直播 + # 2: 轮播中 + OFF = 0 + ON = 1 + CYCLE = 2 + + @unique + class LiveAction(Enum): + # 当前直播行为,由新旧直播状态对比决定 + # on: 正在直播 + # off: 未开播 + # turn_on: 状态变更为正在直播 + # turn_off: 状态变更为未开播 + # title_update: 标题更新 + TURN_ON = "turn_on" + TURN_OFF = "turn_off" + ON = "on" + OFF = "off" + TITLE_UPDATE = "title_update" + + class Info(BaseModel): + title: str + room_id: int # 直播间号 + uid: int # 主播uid + live_time: int # 开播时间 + live_status: "Bilibililive.LiveStatus" + area_name: str = Field(alias="area_v2_name") # 新版分区名 + uname: str # 主播名 + face: str # 头像url + cover: str = Field(alias="cover_from_user") # 封面url + keyframe: str # 关键帧url,可能会有延迟 + category: Category = Field(default=Category(0)) + + def get_live_action(self, old_info: Self) -> "Bilibililive.LiveAction": + status = Bilibililive.LiveStatus + action = Bilibililive.LiveAction + if old_info.live_status in [status.OFF, status.CYCLE] and self.live_status == status.ON: + return action.TURN_ON + elif old_info.live_status == status.ON and self.live_status in [ + status.OFF, + status.CYCLE, + ]: + return action.TURN_OFF + elif old_info.live_status == status.ON and self.live_status == status.ON: + if old_info.title != self.title: + # 开播时通常会改标题,避免短时间推送两次 + return action.TITLE_UPDATE + else: + return action.ON + else: + return action.OFF + + @classmethod + async def get_target_name(cls, client: AsyncClient, target: Target) -> str | None: + res = await client.get("https://api.bilibili.com/x/web-interface/card", params={"mid": target}) + res_data = json.loads(res.text) + if res_data["code"]: + return None + return res_data["data"]["card"]["name"] + + def _gen_empty_info(self, uid: int) -> Info: + """返回一个空的Info,用于该用户没有直播间的情况""" + return Bilibililive.Info( + title="", + room_id=0, + uid=uid, + live_time=0, + live_status=Bilibililive.LiveStatus.OFF, + area_v2_name="", + uname="", + face="", + cover_from_user="", + keyframe="", + ) + + async def batch_get_status(self, targets: list[Target]) -> list[Info]: + # https://github.com/SocialSisterYi/bilibili-API-collect/blob/master/docs/live/info.md#批量查询直播间状态 + res = await self.client.get( + "https://api.live.bilibili.com/room/v1/Room/get_status_info_by_uids", + params={"uids[]": targets}, + timeout=4.0, + ) + res_dict = res.json() + + if res_dict["code"] != 0: + raise self.FetchError() + + data = res_dict.get("data", {}) + infos = [] + for target in targets: + if target in data.keys(): + infos.append(type_validate_python(self.Info, data[target])) + else: + infos.append(self._gen_empty_info(int(target))) + return infos + + def compare_status(self, _: Target, old_status: Info, new_status: Info) -> list[RawPost]: + action = Bilibililive.LiveAction + match new_status.get_live_action(old_status): + case action.TURN_ON: + return self._gen_current_status(new_status, 1) + case action.TITLE_UPDATE: + return self._gen_current_status(new_status, 2) + case action.TURN_OFF: + return self._gen_current_status(new_status, 3) + case _: + return [] + + def _gen_current_status(self, new_status: Info, category: Category): + current_status = deepcopy(new_status) + current_status.category = Category(category) + return [current_status] + + 
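
A concrete trace makes the `LiveStatus`/`LiveAction` table above easier to audit. The sketch below is illustrative only: all field values are invented, and constructor arguments use the pydantic aliases (`area_v2_name`, `cover_from_user`) declared on `Info`; only the OFF→ON edge is shown.

```python
from copy import deepcopy

# Illustrative OFF -> ON transition; every value here is made up.
old = Bilibililive.Info(
    title="昨天的直播",
    room_id=1,
    uid=2233,
    live_time=0,
    live_status=Bilibililive.LiveStatus.OFF,
    area_v2_name="虚拟主播",
    uname="某主播",
    face="",
    cover_from_user="",
    keyframe="",
)
new = deepcopy(old)
new.live_status = Bilibililive.LiveStatus.ON

# get_live_action compares self (the new status) against old_info:
assert new.get_live_action(old) is Bilibililive.LiveAction.TURN_ON
# compare_status then emits one RawPost tagged Category(1), i.e. "开播提醒".
```
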
def get_category(self, status: Info) -> Category: + assert status.category != Category(0) + return status.category + + async def parse(self, raw_post: Info) -> Post: + url = f"https://live.bilibili.com/{raw_post.room_id}" + pic = [raw_post.cover] if raw_post.category == Category(1) else [raw_post.keyframe] + title = f"[{self.categories[raw_post.category].rstrip('提醒')}] {raw_post.title}" + target_name = f"{raw_post.uname} {raw_post.area_name}" + return Post( + self, + "", + title=title, + url=url, + images=list(pic), + nickname=target_name, + compress=True, + ) + + +class BilibiliBangumi(StatusChange): + categories = {} + platform_name = "bilibili-bangumi" + enable_tag = False + enabled = True + is_common = True + scheduler = BilibiliSchedConf + name = "Bilibili剧集" + has_target = True + parse_target_promot = "请输入剧集主页" + default_theme = "brief" + + _url = "https://api.bilibili.com/pgc/review/user" + + @classmethod + async def get_target_name(cls, client: AsyncClient, target: Target) -> str | None: + res = await client.get(cls._url, params={"media_id": target}) + res_data = res.json() + if res_data["code"]: + return None + return res_data["result"]["media"]["title"] + + @classmethod + async def parse_target(cls, target_string: str) -> Target: + if re.match(r"\d+", target_string): + return Target(target_string) + elif m := re.match(r"md(\d+)", target_string): + return Target(m[1]) + elif m := re.match(r"(?:https?://)?www\.bilibili\.com/bangumi/media/md(\d+)", target_string): + return Target(m[1]) + raise cls.ParseTargetException() + + async def get_status(self, target: Target): + res = await self.client.get( + self._url, + params={"media_id": target}, + timeout=4.0, + ) + res_dict = res.json() + if res_dict["code"] == 0: + return { + "index": res_dict["result"]["media"]["new_ep"]["index"], + "index_show": res_dict["result"]["media"]["new_ep"]["index_show"], + "season_id": res_dict["result"]["media"]["season_id"], + } + else: + raise self.FetchError + + def compare_status(self, target: Target, old_status, new_status) -> list[RawPost]: + if new_status["index"] != old_status["index"]: + return [new_status] + else: + return [] + + async def parse(self, raw_post: RawPost) -> Post: + detail_res = await self.client.get( + f'https://api.bilibili.com/pgc/view/web/season?season_id={raw_post["season_id"]}' + ) + detail_dict = detail_res.json() + lastest_episode = None + for episode in detail_dict["result"]["episodes"][::-1]: + if episode["badge"] in ("", "会员"): + lastest_episode = episode + break + if not lastest_episode: + lastest_episode = detail_dict["result"]["episodes"] + + url = lastest_episode["link"] + pic: list[str] = [lastest_episode["cover"]] + target_name = detail_dict["result"]["season_title"] + content = raw_post["index_show"] + title = lastest_episode["share_copy"] + return Post( + self, + content, + title=title, + url=url, + images=list(pic), + nickname=target_name, + compress=True, + ) + + +model_rebuild(Bilibililive.Info) diff --git a/src/plugins/nonebot_bison/platform/ff14.py b/src/plugins/nonebot_bison/platform/ff14.py new file mode 100644 index 00000000..e050aaef --- /dev/null +++ b/src/plugins/nonebot_bison/platform/ff14.py @@ -0,0 +1,46 @@ +from typing import Any + +from httpx import AsyncClient + +from ..post import Post +from ..utils import scheduler +from .platform import NewMessage +from ..types import Target, RawPost + + +class FF14(NewMessage): + categories = {} + platform_name = "ff14" + name = "最终幻想XIV官方公告" + enable_tag = False + enabled = True + is_common = False + 
scheduler_class = "ff14" + scheduler = scheduler("interval", {"seconds": 60}) + has_target = False + + @classmethod + async def get_target_name(cls, client: AsyncClient, target: Target) -> str | None: + return "最终幻想XIV官方公告" + + async def get_sub_list(self, _) -> list[RawPost]: + raw_data = await self.client.get( + "https://cqnews.web.sdo.com/api/news/newsList?gameCode=ff&CategoryCode=5309,5310,5311,5312,5313&pageIndex=0&pageSize=5" + ) + return raw_data.json()["Data"] + + def get_id(self, post: RawPost) -> Any: + """用发布时间当作 ID + + 因为有时候官方会直接编辑以前的文章内容 + """ + return post["PublishDate"] + + def get_date(self, _: RawPost) -> None: + return None + + async def parse(self, raw_post: RawPost) -> Post: + title = raw_post["Title"] + text = raw_post["Summary"] + url = raw_post["Author"] + return Post(self, text, title=title, url=url, nickname="最终幻想XIV官方公告") diff --git a/src/plugins/nonebot_bison/platform/ncm.py b/src/plugins/nonebot_bison/platform/ncm.py new file mode 100644 index 00000000..031dc93b --- /dev/null +++ b/src/plugins/nonebot_bison/platform/ncm.py @@ -0,0 +1,129 @@ +import re +from typing import Any + +from httpx import AsyncClient + +from ..post import Post +from .platform import NewMessage +from ..utils import SchedulerConfig +from ..types import Target, RawPost, ApiError + + +class NcmSchedConf(SchedulerConfig): + name = "music.163.com" + schedule_type = "interval" + schedule_setting = {"minutes": 1} + + +class NcmArtist(NewMessage): + categories = {} + platform_name = "ncm-artist" + enable_tag = False + enabled = True + is_common = True + scheduler = NcmSchedConf + name = "网易云-歌手" + has_target = True + parse_target_promot = "请输入歌手主页(包含数字ID)的链接" + + @classmethod + async def get_target_name(cls, client: AsyncClient, target: Target) -> str | None: + res = await client.get( + f"https://music.163.com/api/artist/albums/{target}", + headers={"Referer": "https://music.163.com/"}, + ) + res_data = res.json() + if res_data["code"] != 200: + raise ApiError(res.request.url) + return res_data["artist"]["name"] + + @classmethod + async def parse_target(cls, target_text: str) -> Target: + if re.match(r"^\d+$", target_text): + return Target(target_text) + elif match := re.match(r"(?:https?://)?music\.163\.com/#/artist\?id=(\d+)", target_text): + return Target(match.group(1)) + else: + raise cls.ParseTargetException() + + async def get_sub_list(self, target: Target) -> list[RawPost]: + res = await self.client.get( + f"https://music.163.com/api/artist/albums/{target}", + headers={"Referer": "https://music.163.com/"}, + ) + res_data = res.json() + if res_data["code"] != 200: + return [] + else: + return res_data["hotAlbums"] + + def get_id(self, post: RawPost) -> Any: + return post["id"] + + def get_date(self, post: RawPost) -> int: + return post["publishTime"] // 1000 + + async def parse(self, raw_post: RawPost) -> Post: + text = "新专辑发布:{}".format(raw_post["name"]) + target_name = raw_post["artist"]["name"] + pics = [raw_post["picUrl"]] + url = "https://music.163.com/#/album?id={}".format(raw_post["id"]) + return Post(self, text, url=url, images=pics, nickname=target_name) + + +class NcmRadio(NewMessage): + categories = {} + platform_name = "ncm-radio" + enable_tag = False + enabled = True + is_common = False + scheduler = NcmSchedConf + name = "网易云-电台" + has_target = True + parse_target_promot = "请输入主播电台主页(包含数字ID)的链接" + + @classmethod + async def get_target_name(cls, client: AsyncClient, target: Target) -> str | None: + res = await client.post( + "http://music.163.com/api/dj/program/byradio", + 
headers={"Referer": "https://music.163.com/"}, + data={"radioId": target, "limit": 1000, "offset": 0}, + ) + res_data = res.json() + if res_data["code"] != 200 or res_data["programs"] == 0: + return + return res_data["programs"][0]["radio"]["name"] + + @classmethod + async def parse_target(cls, target_text: str) -> Target: + if re.match(r"^\d+$", target_text): + return Target(target_text) + elif match := re.match(r"(?:https?://)?music\.163\.com/#/djradio\?id=(\d+)", target_text): + return Target(match.group(1)) + else: + raise cls.ParseTargetException() + + async def get_sub_list(self, target: Target) -> list[RawPost]: + res = await self.client.post( + "http://music.163.com/api/dj/program/byradio", + headers={"Referer": "https://music.163.com/"}, + data={"radioId": target, "limit": 1000, "offset": 0}, + ) + res_data = res.json() + if res_data["code"] != 200: + return [] + else: + return res_data["programs"] + + def get_id(self, post: RawPost) -> Any: + return post["id"] + + def get_date(self, post: RawPost) -> int: + return post["createTime"] // 1000 + + async def parse(self, raw_post: RawPost) -> Post: + text = "网易云电台更新:{}".format(raw_post["name"]) + target_name = raw_post["radio"]["name"] + pics = [raw_post["coverUrl"]] + url = "https://music.163.com/#/program/{}".format(raw_post["id"]) + return Post(self, text, url=url, images=pics, nickname=target_name) diff --git a/src/plugins/nonebot_bison/platform/platform.py b/src/plugins/nonebot_bison/platform/platform.py new file mode 100644 index 00000000..0c902c65 --- /dev/null +++ b/src/plugins/nonebot_bison/platform/platform.py @@ -0,0 +1,501 @@ +import ssl +import json +import time +import typing +from dataclasses import dataclass +from abc import ABC, abstractmethod +from collections import defaultdict +from typing import Any, TypeVar, ParamSpec +from collections.abc import Callable, Awaitable, Collection + +import httpx +from httpx import AsyncClient +from nonebot.log import logger +from nonebot_plugin_saa import PlatformTarget + +from ..post import Post +from ..plugin_config import plugin_config +from ..utils import ProcessContext, SchedulerConfig +from ..types import Tag, Target, RawPost, SubUnit, Category + + +class CategoryNotSupport(Exception): + """raise in get_category, when you know the category of the post + but don't want to support it or don't support its parsing yet + """ + + +class CategoryNotRecognize(Exception): + """raise in get_category, when you don't know the category of post""" + + +class RegistryMeta(type): + def __new__(cls, name, bases, namespace, **kwargs): + return super().__new__(cls, name, bases, namespace) + + def __init__(cls, name, bases, namespace, **kwargs): + if kwargs.get("base"): + # this is the base class + cls.registry = [] + elif not kwargs.get("abstract"): + # this is the subclass + cls.registry.append(cls) + + super().__init__(name, bases, namespace, **kwargs) + + +P = ParamSpec("P") +R = TypeVar("R") + + +async def catch_network_error(func: Callable[P, Awaitable[R]], *args: P.args, **kwargs: P.kwargs) -> R | None: + try: + return await func(*args, **kwargs) + except httpx.RequestError as err: + if plugin_config.bison_show_network_warning: + logger.warning(f"network connection error: {type(err)}, url: {err.request.url}") + return None + except ssl.SSLError as err: + if plugin_config.bison_show_network_warning: + logger.warning(f"ssl error: {err}") + return None + except json.JSONDecodeError as err: + logger.warning(f"json error, parsing: {err.doc}") + raise err + + +class PlatformMeta(RegistryMeta): 
+ categories: dict[Category, str] + store: dict[Target, Any] + + def __init__(cls, name, bases, namespace, **kwargs): + cls.reverse_category = {} + cls.store = {} + if hasattr(cls, "categories") and cls.categories: + for key, val in cls.categories.items(): + cls.reverse_category[val] = key + super().__init__(name, bases, namespace, **kwargs) + + +class PlatformABCMeta(PlatformMeta, ABC): ... + + +class Platform(metaclass=PlatformABCMeta, base=True): + scheduler: type[SchedulerConfig] + ctx: ProcessContext + is_common: bool + enabled: bool + name: str + has_target: bool + categories: dict[Category, str] + enable_tag: bool + platform_name: str + parse_target_promot: str | None = None + registry: list[type["Platform"]] + client: AsyncClient + reverse_category: dict[str, Category] + use_batch: bool = False + # TODO: 限定可使用的theme名称 + default_theme: str = "basic" + + @classmethod + @abstractmethod + async def get_target_name(cls, client: AsyncClient, target: Target) -> str | None: ... + + @abstractmethod + async def fetch_new_post(self, sub_unit: SubUnit) -> list[tuple[PlatformTarget, list[Post]]]: ... + + async def do_fetch_new_post(self, sub_unit: SubUnit) -> list[tuple[PlatformTarget, list[Post]]]: + return await catch_network_error(self.fetch_new_post, sub_unit) or [] + + @abstractmethod + async def batch_fetch_new_post(self, sub_units: list[SubUnit]) -> list[tuple[PlatformTarget, list[Post]]]: ... + + async def do_batch_fetch_new_post(self, sub_units: list[SubUnit]) -> list[tuple[PlatformTarget, list[Post]]]: + return await catch_network_error(self.batch_fetch_new_post, sub_units) or [] + + @abstractmethod + async def parse(self, raw_post: RawPost) -> Post: ... + + async def do_parse(self, raw_post: RawPost) -> Post: + "actually function called" + return await self.parse(raw_post) + + def __init__(self, context: ProcessContext, client: AsyncClient): + super().__init__() + self.client = client + self.ctx = context + + class ParseTargetException(Exception): + pass + + @classmethod + async def parse_target(cls, target_string: str) -> Target: + return Target(target_string) + + @abstractmethod + def get_tags(self, raw_post: RawPost) -> Collection[Tag] | None: + "Return Tag list of given RawPost" + + @classmethod + def get_stored_data(cls, target: Target) -> Any: + return cls.store.get(target) + + @classmethod + def set_stored_data(cls, target: Target, data: Any): + cls.store[target] = data + + def tag_separator(self, stored_tags: list[Tag]) -> tuple[list[Tag], list[Tag]]: + """返回分离好的正反tag元组""" + subscribed_tags = [] + banned_tags = [] + for tag in stored_tags: + if tag.startswith("~"): + banned_tags.append(tag.lstrip("~")) + else: + subscribed_tags.append(tag) + return subscribed_tags, banned_tags + + def is_banned_post( + self, + post_tags: Collection[Tag], + subscribed_tags: list[Tag], + banned_tags: list[Tag], + ) -> bool: + """只要存在任意屏蔽tag则返回真,此行为优先级最高。 + 存在任意被订阅tag则返回假,此行为优先级次之。 + 若被订阅tag为空,则返回假。 + """ + # 存在任意需要屏蔽的tag则为真 + if banned_tags: + for tag in post_tags or []: + if tag in banned_tags: + return True + # 检测屏蔽tag后,再检测订阅tag + # 存在任意需要订阅的tag则为假 + if subscribed_tags: + ban_it = True + for tag in post_tags or []: + if tag in subscribed_tags: + ban_it = False + return ban_it + else: + return False + + async def filter_user_custom( + self, raw_post_list: list[RawPost], cats: list[Category], tags: list[Tag] + ) -> list[RawPost]: + res: list[RawPost] = [] + for raw_post in raw_post_list: + if self.categories: + cat = self.get_category(raw_post) + if cats and cat not in cats: + continue + if 
self.enable_tag and tags: + raw_post_tags = self.get_tags(raw_post) + if isinstance(raw_post_tags, Collection) and self.is_banned_post( + raw_post_tags, *self.tag_separator(tags) + ): + continue + res.append(raw_post) + return res + + async def dispatch_user_post( + self, new_posts: list[RawPost], sub_unit: SubUnit + ) -> list[tuple[PlatformTarget, list[Post]]]: + res: list[tuple[PlatformTarget, list[Post]]] = [] + for user, cats, required_tags in sub_unit.user_sub_infos: + user_raw_post = await self.filter_user_custom(new_posts, cats, required_tags) + user_post: list[Post] = [] + for raw_post in user_raw_post: + user_post.append(await self.do_parse(raw_post)) + res.append((user, user_post)) + return res + + @abstractmethod + def get_category(self, post: RawPost) -> Category | None: + "Return category of given Rawpost" + raise NotImplementedError() + + +class MessageProcess(Platform, abstract=True): + "General message process fetch, parse, filter progress" + + def __init__(self, ctx: ProcessContext, client: AsyncClient): + super().__init__(ctx, client) + self.parse_cache: dict[Any, Post] = {} + + @abstractmethod + def get_id(self, post: RawPost) -> Any: + "Get post id of given RawPost" + + async def do_parse(self, raw_post: RawPost) -> Post: + post_id = self.get_id(raw_post) + if post_id not in self.parse_cache: + retry_times = 3 + while retry_times: + try: + self.parse_cache[post_id] = await self.parse(raw_post) + break + except Exception as err: + retry_times -= 1 + if not retry_times: + raise err + return self.parse_cache[post_id] + + @abstractmethod + async def get_sub_list(self, target: Target) -> list[RawPost]: + "Get post list of the given target" + raise NotImplementedError() + + @abstractmethod + async def batch_get_sub_list(self, targets: list[Target]) -> list[list[RawPost]]: + "Get post list of the given targets" + raise NotImplementedError() + + @abstractmethod + def get_date(self, post: RawPost) -> int | None: + "Get post timestamp and return, return None if can't get the time" + + async def filter_common(self, raw_post_list: list[RawPost]) -> list[RawPost]: + res = [] + for raw_post in raw_post_list: + # post_id = self.get_id(raw_post) + # if post_id in exists_posts_set: + # continue + if ( + (post_time := self.get_date(raw_post)) + and time.time() - post_time > 2 * 60 * 60 + and plugin_config.bison_init_filter + ): + continue + try: + self.get_category(raw_post) + except CategoryNotSupport as e: + logger.info("未支持解析的推文类别:" + repr(e) + ",忽略") + continue + except CategoryNotRecognize as e: + logger.warning("未知推文类别:" + repr(e)) + msgs = self.ctx.gen_req_records() + for m in msgs: + logger.warning(m) + continue + except NotImplementedError: + pass + res.append(raw_post) + return res + + +class NewMessage(MessageProcess, abstract=True): + "Fetch a list of messages, filter the new messages, dispatch it to different users" + + @dataclass + class MessageStorage: + inited: bool + exists_posts: set[Any] + + async def filter_common_with_diff(self, target: Target, raw_post_list: list[RawPost]) -> list[RawPost]: + filtered_post = await self.filter_common(raw_post_list) + store = self.get_stored_data(target) or self.MessageStorage(False, set()) + res = [] + if not store.inited and plugin_config.bison_init_filter: + # target not init + for raw_post in filtered_post: + post_id = self.get_id(raw_post) + store.exists_posts.add(post_id) + logger.info(f"init {self.platform_name}-{target} with {store.exists_posts}") + store.inited = True + else: + for raw_post in filtered_post: + post_id = 
self.get_id(raw_post) + if post_id in store.exists_posts: + continue + res.append(raw_post) + store.exists_posts.add(post_id) + self.set_stored_data(target, store) + return res + + async def _handle_new_post( + self, + post_list: list[RawPost], + sub_unit: SubUnit, + ) -> list[tuple[PlatformTarget, list[Post]]]: + new_posts = await self.filter_common_with_diff(sub_unit.sub_target, post_list) + if not new_posts: + return [] + else: + for post in new_posts: + logger.info( + "fetch new post from {} {}: {}".format( + self.platform_name, + sub_unit.sub_target if self.has_target else "-", + self.get_id(post), + ) + ) + res = await self.dispatch_user_post(new_posts, sub_unit) + self.parse_cache = {} + return res + + async def fetch_new_post(self, sub_unit: SubUnit) -> list[tuple[PlatformTarget, list[Post]]]: + post_list = await self.get_sub_list(sub_unit.sub_target) + return await self._handle_new_post(post_list, sub_unit) + + async def batch_fetch_new_post(self, sub_units: list[SubUnit]) -> list[tuple[PlatformTarget, list[Post]]]: + if not self.has_target: + raise RuntimeError("Target without target should not use batch api") # pragma: no cover + posts_set = await self.batch_get_sub_list([x[0] for x in sub_units]) + res = [] + for sub_unit, posts in zip(sub_units, posts_set): + res.extend(await self._handle_new_post(posts, sub_unit)) + return res + + +class StatusChange(Platform, abstract=True): + "Watch a status, and fire a post when status changes" + + class FetchError(RuntimeError): + pass + + @abstractmethod + async def get_status(self, target: Target) -> Any: ... + + @abstractmethod + async def batch_get_status(self, targets: list[Target]) -> list[Any]: ... + + @abstractmethod + def compare_status(self, target: Target, old_status, new_status) -> list[RawPost]: ... + + @abstractmethod + async def parse(self, raw_post: RawPost) -> Post: ... 
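
Before the dispatch helpers below, a compact picture of what the `StatusChange` contract asks of a subclass. This is a hypothetical platform (its name, endpoint, and JSON shape are invented) that polls an integer version number and posts only on change; it relies only on names already imported at the top of platform.py.

```python
# Hypothetical sketch, not part of the diff: everything specific here is assumed.
class ExampleSchedConf(SchedulerConfig):
    name = "example.com"
    schedule_type = "interval"
    schedule_setting = {"minutes": 5}


class VersionWatcher(StatusChange):
    categories = {}
    platform_name = "version-watcher"
    name = "版本监视器"
    enable_tag = False
    enabled = False  # sketch only; keep it out of the real scheduler
    is_common = False
    scheduler = ExampleSchedConf
    has_target = True

    @classmethod
    async def get_target_name(cls, client: AsyncClient, target: Target) -> str | None:
        return f"版本监视器 {target}"

    async def get_status(self, target: Target) -> int:
        res = await self.client.get(f"https://example.com/api/{target}/version")  # invented endpoint
        return res.json()["version"]

    async def batch_get_status(self, targets: list[Target]) -> list[int]:
        return [await self.get_status(t) for t in targets]

    def compare_status(self, target: Target, old_status: int, new_status: int) -> list[RawPost]:
        # One RawPost per observed change; the very first poll only stores the
        # status without firing (see _handle_status_change below).
        return [{"version": new_status}] if new_status != old_status else []

    def get_category(self, post: RawPost) -> None:
        return None

    def get_tags(self, raw_post: RawPost) -> None:
        return None

    async def parse(self, raw_post: RawPost) -> Post:
        return Post(self, f"版本更新:{raw_post['version']}", nickname=self.name)
```
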
+ + async def _handle_status_change( + self, new_status: Any, sub_unit: SubUnit + ) -> list[tuple[PlatformTarget, list[Post]]]: + res = [] + if old_status := self.get_stored_data(sub_unit.sub_target): + diff = self.compare_status(sub_unit.sub_target, old_status, new_status) + if diff: + logger.info( + "status changes {} {}: {} -> {}".format( + self.platform_name, + sub_unit.sub_target if self.has_target else "-", + old_status, + new_status, + ) + ) + res = await self.dispatch_user_post(diff, sub_unit) + self.set_stored_data(sub_unit.sub_target, new_status) + return res + + async def fetch_new_post(self, sub_unit: SubUnit) -> list[tuple[PlatformTarget, list[Post]]]: + try: + new_status = await self.get_status(sub_unit.sub_target) + except self.FetchError as err: + logger.warning(f"fetching {self.name}-{sub_unit.sub_target} error: {err}") + raise + return await self._handle_status_change(new_status, sub_unit) + + async def batch_fetch_new_post(self, sub_units: list[SubUnit]) -> list[tuple[PlatformTarget, list[Post]]]: + if not self.has_target: + raise RuntimeError("Target without target should not use batch api") # pragma: no cover + new_statuses = await self.batch_get_status([x[0] for x in sub_units]) + res = [] + for sub_unit, new_status in zip(sub_units, new_statuses): + res.extend(await self._handle_status_change(new_status, sub_unit)) + return res + + +class SimplePost(NewMessage, abstract=True): + "Fetch a list of messages, dispatch it to different users" + + async def _handle_new_post( + self, + new_posts: list[RawPost], + sub_unit: SubUnit, + ) -> list[tuple[PlatformTarget, list[Post]]]: + if not new_posts: + return [] + else: + for post in new_posts: + logger.info( + "fetch new post from {} {}: {}".format( + self.platform_name, + sub_unit.sub_target if self.has_target else "-", + self.get_id(post), + ) + ) + res = await self.dispatch_user_post(new_posts, sub_unit) + self.parse_cache = {} + return res + + +def make_no_target_group(platform_list: list[type[Platform]]) -> type[Platform]: + if typing.TYPE_CHECKING: + + class NoTargetGroup(Platform, abstract=True): + platform_list: list[type[Platform]] + platform_obj_list: list[Platform] + + DUMMY_STR = "_DUMMY" + + platform_name = platform_list[0].platform_name + name = DUMMY_STR + categories_keys = set() + categories = {} + scheduler = platform_list[0].scheduler + + for platform in platform_list: + if platform.has_target: + raise RuntimeError(f"Platform {platform.name} should have no target") + if name == DUMMY_STR: + name = platform.name + elif name != platform.name: + raise RuntimeError(f"Platform name for {platform_name} not fit") + platform_category_key_set = set(platform.categories.keys()) + if platform_category_key_set & categories_keys: + raise RuntimeError(f"Platform categories for {platform_name} duplicate") + categories_keys |= platform_category_key_set + categories.update(platform.categories) + if platform.scheduler != scheduler: + raise RuntimeError(f"Platform scheduler for {platform_name} not fit") + + def __init__(self: "NoTargetGroup", ctx: ProcessContext, client: AsyncClient): + Platform.__init__(self, ctx, client) + self.platform_obj_list = [] + for platform_class in self.platform_list: + self.platform_obj_list.append(platform_class(ctx, client)) + + def __str__(self: "NoTargetGroup") -> str: + return "[" + " ".join(x.name for x in self.platform_list) + "]" + + @classmethod + async def get_target_name(cls, client: AsyncClient, target: Target): + return await platform_list[0].get_target_name(client, target) + + async 
+
+
+class SimplePost(NewMessage, abstract=True):
+    "Fetch a list of messages and dispatch them to different users"
+
+    async def _handle_new_post(
+        self,
+        new_posts: list[RawPost],
+        sub_unit: SubUnit,
+    ) -> list[tuple[PlatformTarget, list[Post]]]:
+        if not new_posts:
+            return []
+        else:
+            for post in new_posts:
+                logger.info(
+                    "fetch new post from {} {}: {}".format(
+                        self.platform_name,
+                        sub_unit.sub_target if self.has_target else "-",
+                        self.get_id(post),
+                    )
+                )
+            res = await self.dispatch_user_post(new_posts, sub_unit)
+            self.parse_cache = {}
+            return res
+
+
+def make_no_target_group(platform_list: list[type[Platform]]) -> type[Platform]:
+    if typing.TYPE_CHECKING:
+
+        class NoTargetGroup(Platform, abstract=True):
+            platform_list: list[type[Platform]]
+            platform_obj_list: list[Platform]
+
+    DUMMY_STR = "_DUMMY"
+
+    platform_name = platform_list[0].platform_name
+    name = DUMMY_STR
+    categories_keys = set()
+    categories = {}
+    scheduler = platform_list[0].scheduler
+
+    for platform in platform_list:
+        if platform.has_target:
+            raise RuntimeError(f"Platform {platform.name} should have no target")
+        if name == DUMMY_STR:
+            name = platform.name
+        elif name != platform.name:
+            raise RuntimeError(f"Platform names in group {platform_name} do not match")
+        platform_category_key_set = set(platform.categories.keys())
+        if platform_category_key_set & categories_keys:
+            raise RuntimeError(f"Platform categories in group {platform_name} overlap")
+        categories_keys |= platform_category_key_set
+        categories.update(platform.categories)
+        if platform.scheduler != scheduler:
+            raise RuntimeError(f"Platform schedulers in group {platform_name} do not match")
+
+    def __init__(self: "NoTargetGroup", ctx: ProcessContext, client: AsyncClient):
+        Platform.__init__(self, ctx, client)
+        self.platform_obj_list = []
+        for platform_class in self.platform_list:
+            self.platform_obj_list.append(platform_class(ctx, client))
+
+    def __str__(self: "NoTargetGroup") -> str:
+        return "[" + " ".join(x.name for x in self.platform_list) + "]"
+
+    @classmethod
+    async def get_target_name(cls, client: AsyncClient, target: Target):
+        return await platform_list[0].get_target_name(client, target)
+
+    async def fetch_new_post(self: "NoTargetGroup", sub_unit: SubUnit):
+        res = defaultdict(list)
+        for platform in self.platform_obj_list:
+            platform_res = await platform.fetch_new_post(sub_unit)
+            for user, posts in platform_res:
+                res[user].extend(posts)
+        # return (target, posts) tuples, matching the other fetch_new_post implementations
+        return [(key, val) for key, val in res.items()]
+
+    return type(
+        "NoTargetGroup",
+        (Platform,),
+        {
+            "platform_list": platform_list,
+            "platform_name": platform_list[0].platform_name,
+            "name": name,
+            "categories": categories,
+            "scheduler": scheduler,
+            "is_common": platform_list[0].is_common,
+            "enabled": True,
+            "has_target": False,
+            "enable_tag": False,
+            "__init__": __init__,
+            "get_target_name": get_target_name,
+            "fetch_new_post": fetch_new_post,
+        },
+        abstract=True,
+    )
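+
+# Hypothetical usage sketch (PlatformA / PlatformB are placeholders, not
+# classes defined in this diff): group several target-less platforms that
+# share a platform_name into one dispatchable platform.
+#
+#     Combined = make_no_target_group([PlatformA, PlatformB])
+#     # Combined behaves like a single Platform: one fetch_new_post call
+#     # aggregates the new posts of every grouped platform per user.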
diff --git a/src/plugins/nonebot_bison/platform/rss.py b/src/plugins/nonebot_bison/platform/rss.py
new file mode 100644
index 00000000..a7af5929
--- /dev/null
+++ b/src/plugins/nonebot_bison/platform/rss.py
@@ -0,0 +1,81 @@
+import time
+import calendar
+from typing import Any
+
+import feedparser
+from httpx import AsyncClient
+from bs4 import BeautifulSoup as bs
+
+from ..post import Post
+from .platform import NewMessage
+from ..types import Target, RawPost
+from ..utils import SchedulerConfig, text_similarity
+
+
+class RssSchedConf(SchedulerConfig):
+    name = "rss"
+    schedule_type = "interval"
+    schedule_setting = {"seconds": 30}
+
+
+class Rss(NewMessage):
+    categories = {}
+    enable_tag = False
+    platform_name = "rss"
+    name = "Rss"
+    enabled = True
+    is_common = True
+    scheduler = RssSchedConf
+    has_target = True
+
+    @classmethod
+    async def get_target_name(cls, client: AsyncClient, target: Target) -> str | None:
+        res = await client.get(target, timeout=10.0)
+        feed = feedparser.parse(res.text)
+        return feed["feed"]["title"]
+
+    def get_date(self, post: RawPost) -> int:
+        if hasattr(post, "published_parsed"):
+            return calendar.timegm(post.published_parsed)
+        elif hasattr(post, "updated_parsed"):
+            return calendar.timegm(post.updated_parsed)
+        else:
+            return calendar.timegm(time.gmtime())
+
+    def get_id(self, post: RawPost) -> Any:
+        return post.id
+
+    async def get_sub_list(self, target: Target) -> list[RawPost]:
+        res = await self.client.get(target, timeout=10.0)
+        # parse the response text, as in get_target_name above
+        feed = feedparser.parse(res.text)
+        entries = feed.entries
+        for entry in entries:
+            entry["_target_name"] = feed.feed.title
+        return entries
+
+    def _text_process(self, title: str, desc: str) -> tuple[str | None, str]:
+        """If title and description are similar, drop the title (return None) and
+        keep only the longer text; otherwise return both unchanged."""
+        similarity = 1.0 if len(title) == 0 or len(desc) == 0 else text_similarity(title, desc)
+        if similarity > 0.8:
+            return None, title if len(title) > len(desc) else desc
+
+        return title, desc
+
+    async def parse(self, raw_post: RawPost) -> Post:
+        title = raw_post.get("title", "")
+        soup = bs(raw_post.description, "html.parser")
+        desc = soup.text.strip()
+        title, desc = self._text_process(title, desc)
+        pics = [x.attrs["src"] for x in soup("img")]
+        if raw_post.get("media_content"):
+            for media in raw_post["media_content"]:
+                if media.get("medium") == "image" and media.get("url"):
+                    pics.append(media.get("url"))
+        return Post(
+            self,
+            desc,
+            title=title,
+            url=raw_post.link,
+            images=pics,
+            nickname=raw_post["_target_name"],
+        )
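+
+# Illustration of the title/description dedup above (inputs are made up, and
+# the return values assume text_similarity lands above / below the 0.8
+# threshold respectively):
+#
+#     rss._text_process("Release v1.0", "Release v1.0")
+#     # -> (None, "Release v1.0"): near-duplicate, keep only the longer text
+#     rss._text_process("Release v1.0", "Changelog: bug fixes and docs")
+#     # -> ("Release v1.0", "Changelog: bug fixes and docs"): kept separate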
diff --git a/src/plugins/nonebot_bison/platform/weibo.py b/src/plugins/nonebot_bison/platform/weibo.py
new file mode 100644
index 00000000..54c2a52a
--- /dev/null
+++ b/src/plugins/nonebot_bison/platform/weibo.py
@@ -0,0 +1,191 @@
+import re
+import json
+from typing import Any
+from datetime import datetime
+from urllib.parse import unquote
+
+from yarl import URL
+from lxml import etree
+from httpx import AsyncClient
+from nonebot.log import logger
+from bs4 import BeautifulSoup as bs
+
+from ..post import Post
+from .platform import NewMessage
+from ..utils import SchedulerConfig, http_client
+from ..types import Tag, Target, RawPost, ApiError, Category
+
+_HEADER = {
+    "accept": (
+        "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,"
+        "*/*;q=0.8,application/signed-exchange;v=b3;q=0.9"
+    ),
+    "accept-language": "zh-CN,zh;q=0.9",
+    "authority": "m.weibo.cn",
+    "cache-control": "max-age=0",
+    "sec-fetch-dest": "empty",
+    "sec-fetch-mode": "same-origin",
+    "sec-fetch-site": "same-origin",
+    "upgrade-insecure-requests": "1",
+    "user-agent": (
+        "Mozilla/5.0 (Linux; Android 6.0; Nexus 5 Build/MRA58N) "
+        "AppleWebKit/537.36 (KHTML, like Gecko) Chrome/89.0.4389.72 "
+        "Mobile Safari/537.36"
+    ),
+}
+
+
+class WeiboSchedConf(SchedulerConfig):
+    name = "weibo.com"
+    schedule_type = "interval"
+    schedule_setting = {"seconds": 3}
+
+
+class Weibo(NewMessage):
+    categories = {
+        1: "转发",
+        2: "视频",
+        3: "图文",
+        4: "文字",
+    }
+    enable_tag = True
+    platform_name = "weibo"
+    name = "新浪微博"
+    enabled = True
+    is_common = True
+    scheduler = WeiboSchedConf
+    has_target = True
+    parse_target_promot = "请输入用户主页(包含数字UID)的链接"
+
+    @classmethod
+    async def get_target_name(cls, client: AsyncClient, target: Target) -> str | None:
+        param = {"containerid": "100505" + target}
+        res = await client.get("https://m.weibo.cn/api/container/getIndex", params=param)
+        res_dict = json.loads(res.text)
+        if res_dict.get("ok") == 1:
+            return res_dict["data"]["userInfo"]["screen_name"]
+        else:
+            return None
+
+    @classmethod
+    async def parse_target(cls, target_text: str) -> Target:
+        # fullmatch so that a bare UID with trailing junk is rejected
+        if re.fullmatch(r"\d+", target_text):
+            return Target(target_text)
+        elif match := re.match(r"(?:https?://)?weibo\.com/u/(\d+)", target_text):
+            # plain http should be long gone by now, but accept it just in case
+            return Target(match.group(1))
+        else:
+            raise cls.ParseTargetException()
+
+    async def get_sub_list(self, target: Target) -> list[RawPost]:
+        params = {"containerid": "107603" + target}
+        res = await self.client.get("https://m.weibo.cn/api/container/getIndex", params=params, timeout=4.0)
+        res_data = json.loads(res.text)
+        if not res_data["ok"] and res_data["msg"] != "这里还没有内容":
+            raise ApiError(res.request.url)
+
+        def custom_filter(d: RawPost) -> bool:
+            return d["card_type"] == 9
+
+        return list(filter(custom_filter, res_data["data"]["cards"]))
+
+    def get_id(self, post: RawPost) -> Any:
+        return post["mblog"]["id"]
+
+    def filter_platform_custom(self, raw_post: RawPost) -> bool:
+        return raw_post["card_type"] == 9
+
+    def get_date(self, raw_post: RawPost) -> float:
+        created_time = datetime.strptime(raw_post["mblog"]["created_at"], "%a %b %d %H:%M:%S %z %Y")
+        return created_time.timestamp()
+
+    def get_tags(self, raw_post: RawPost) -> list[Tag] | None:
+        "Return the tag list of a given RawPost"
+        text = raw_post["mblog"]["text"]
+        soup = bs(text, "html.parser")
+        res = [
+            x[1:-1]
+            for x in filter(
+                lambda s: s[0] == "#" and s[-1] == "#",
+                (x.text for x in soup.find_all("span", class_="surl-text")),
+            )
+        ]
+        super_topic_img = soup.find("img", src=re.compile(r"timeline_card_small_super_default"))
+        if super_topic_img:
+            try:
+                res.append(super_topic_img.parent.parent.find("span", class_="surl-text").text + "超话")  # type: ignore
+            except Exception:
+                logger.info(f"super_topic extract error: {text}")
+        return res
+
+    def get_category(self, raw_post: RawPost) -> Category:
+        if raw_post["mblog"].get("retweeted_status"):
+            return Category(1)
+        elif raw_post["mblog"].get("page_info") and raw_post["mblog"]["page_info"].get("type") == "video":
+            return Category(2)
+        elif raw_post["mblog"].get("pics"):
+            return Category(3)
+        else:
+            return Category(4)
+
+    def _get_text(self, raw_text: str) -> str:
+        text = raw_text.replace("