From d85c18a62acedecc0db6da673c85344e2b607038 Mon Sep 17 00:00:00 2001 From: LWR Date: Sat, 18 Feb 2023 00:29:02 +0800 Subject: [PATCH] refactor: Remove key attribute from PushTarget and rename groups table name --- starbot/core/bot.py | 8 +-- starbot/core/datasource.py | 104 ++++++++++++------------------------- starbot/core/model.py | 7 +-- starbot/core/server.py | 66 +++++++++++++++++------ starbot/utils/config.py | 6 ++- starbot/utils/redis.py | 4 ++ 6 files changed, 99 insertions(+), 96 deletions(-) diff --git a/starbot/core/bot.py b/starbot/core/bot.py index 14691f6..465569b 100644 --- a/starbot/core/bot.py +++ b/starbot/core/bot.py @@ -59,6 +59,10 @@ class StarBot: logger.error(ex.msg) return + if not self.__datasource.bots: + logger.error("数据源配置为空, 请先在数据源中配置完毕后再重新运行") + return + # 连接 Redis try: await redis.init() @@ -120,10 +124,6 @@ class StarBot: logger.success("用户自定义命令模块载入完毕") # 启动消息推送模块 - if not self.__datasource.bots: - logger.error("不存在需要启动的 Bot 账号, 请先在数据源中配置完毕后再重新运行") - return - Ariadne.options["default_account"] = self.__datasource.bots[0].qq logger.info("开始运行 Ariadne 消息推送模块") diff --git a/starbot/core/datasource.py b/starbot/core/datasource.py index 0ebd5d6..d2e23cc 100644 --- a/starbot/core/datasource.py +++ b/starbot/core/datasource.py @@ -24,9 +24,6 @@ class DataSource(metaclass=abc.ABCMeta): self.__up_list: List[Up] = [] self.__up_map: Dict[int, Up] = {} self.__uid_list: List[int] = [] - self.__target_list: List[PushTarget] = [] - self.__target_key_map: Dict[str, PushTarget] = {} - self.__target_bot_map: Dict[str, Bot] = {} @abc.abstractmethod async def load(self): @@ -47,13 +44,6 @@ class DataSource(metaclass=abc.ABCMeta): self.__uid_list = list(self.__up_map.keys()) if len(set(self.__uid_list)) < len(self.__uid_list): raise DataSourceException("配置中不可含有重复的 UID") - self.__target_list = [x for target in map(lambda up: up.targets, self.__up_list) for x in target] - self.__target_key_map = dict(zip(map(lambda target: target.key, self.__target_list), self.__target_list)) - - for bot in self.bots: - for up in bot.ups: - for target in up.targets: - self.__target_bot_map[target.key] = bot def get_up_list(self) -> List[Up]: """ @@ -91,12 +81,12 @@ class DataSource(metaclass=abc.ABCMeta): raise DataSourceException(f"不存在的 UID: {uid}") return up - def get_bot(self, qq: int) -> Bot: + def get_bot(self, qq: Optional[int] = None) -> Bot: """ 根据 QQ 获取 Bot 实例 Args: - qq: 需要获取 Bot 的 QQ + qq: 需要获取 Bot 的 QQ,单 Bot 推送时可不传入 Returns: Bot 实例 @@ -104,6 +94,11 @@ class DataSource(metaclass=abc.ABCMeta): Raises: DataSourceException: QQ 不存在 """ + if qq is None: + if len(self.bots) != 1: + raise DataSourceException(f"多 Bot 推送时需明确指定要获取的 Bot QQ") + return self.bots[0] + bot = next((b for b in self.bots if b.qq == qq), None) if bot is None: raise DataSourceException(f"不存在的 QQ: {qq}") @@ -130,42 +125,6 @@ class DataSource(metaclass=abc.ABCMeta): return ups - def get_target_by_key(self, key: str) -> PushTarget: - """ - 根据推送 key 获取 PushTarget 实例,用于 HTTP API 推送 - - Args: - key: 需要获取 PushTarget 的推送 key - - Returns: - PushTarget 实例 - - Raises: - DataSourceException: key 不存在 - """ - target = self.__target_key_map.get(key) - if target is None: - raise DataSourceException(f"不存在的推送 key: {key}") - return target - - def get_bot_by_key(self, key: str) -> Bot: - """ - 根据推送 key 获取其所在的 Bot 实例,用于 HTTP API 推送 - - Args: - key: 需要获取所在 Bot 的推送 key - - Returns: - Bot 实例 - - Raises: - DataSourceException: key 不存在 - """ - bot = self.__target_bot_map.get(key) - if bot is None: - raise DataSourceException(f"不存在的推送 key: {key}") - return bot - async def wait_for_connects(self): """ 等待所有 Up 实例连接直播间完毕 @@ -301,37 +260,37 @@ class MySQLDataSource(DataSource): 推送目标列表 """ live_on = await self.__query( - "SELECT g.`uid`, g.`uname`, g.`room_id`, `key`, `type`, `num`, `enabled`, `message` " - "FROM `groups` AS `g` LEFT JOIN `live_on` AS `l` " - "ON g.`uid` = l.`uid` AND g.`index` = l.`index` " - f"WHERE g.`uid` = {uid} " - "ORDER BY g.`index`" + "SELECT t.`uid`, t.`uname`, t.`room_id`, `type`, `num`, `enabled`, `message` " + "FROM `targets` AS `t` LEFT JOIN `live_on` AS `l` " + "ON t.`uid` = l.`uid` AND t.`id` = l.`id` " + f"WHERE t.`uid` = {uid} " + "ORDER BY t.`id`" ) live_off = await self.__query( - "SELECT g.`uid`, g.`uname`, g.`room_id`, `key`, `type`, `num`, `enabled`, `message` " - "FROM `groups` AS `g` LEFT JOIN `live_off` AS `l` " - "ON g.`uid` = l.`uid` AND g.`index` = l.`index` " - f"WHERE g.`uid` = {uid} " - "ORDER BY g.`index`" + "SELECT t.`uid`, t.`uname`, t.`room_id`, `type`, `num`, `enabled`, `message` " + "FROM `targets` AS `t` LEFT JOIN `live_off` AS `l` " + "ON t.`uid` = l.`uid` AND t.`id` = l.`id` " + f"WHERE t.`uid` = {uid} " + "ORDER BY t.`id`" ) live_report = await self.__query( - "SELECT g.`uid`, g.`uname`, g.`room_id`, `key`, `type`, `num`, " + "SELECT t.`uid`, t.`uname`, t.`room_id`, `type`, `num`, " "`enabled`, `logo`, `logo_base64`, `time`, `fans_change`, `fans_medal_change`, `guard_change`, " "`danmu`, `box`, `gift`, `sc`, `guard`, " "`danmu_ranking`, `box_ranking`, `box_profit_ranking`, `gift_ranking`, `sc_ranking`, " "`guard_list`, `box_profit_diagram`, `danmu_diagram`, `box_diagram`, `gift_diagram`, " "`sc_diagram`, `guard_diagram`, `danmu_cloud` " - "FROM `groups` AS `g` LEFT JOIN `live_report` AS `l` " - "ON g.`uid` = l.`uid` AND g.`index` = l.`index` " - f"WHERE g.`uid` = {uid} " - "ORDER BY g.`index`" + "FROM `targets` AS `t` LEFT JOIN `live_report` AS `l` " + "ON t.`uid` = l.`uid` AND t.`id` = l.`id` " + f"WHERE t.`uid` = {uid} " + "ORDER BY t.`id`" ) dynamic_update = await self.__query( - "SELECT g.`uid`, g.`uname`, g.`room_id`, `key`, `type`, `num`, `enabled`, `message` " - "FROM `groups` AS `g` LEFT JOIN `dynamic_update` AS `d` " - "ON g.`uid` = d.`uid` AND g.`index` = d.`index` " - f"WHERE g.`uid` = {uid} " - "ORDER BY g.`index`" + "SELECT t.`uid`, t.`uname`, t.`room_id`, `type`, `num`, `enabled`, `message` " + "FROM `targets` AS `t` LEFT JOIN `dynamic_update` AS `d` " + "ON t.`uid` = d.`uid` AND t.`id` = d.`id` " + f"WHERE t.`uid` = {uid} " + "ORDER BY t.`id`" ) targets = [] @@ -430,15 +389,20 @@ class MySQLDataSource(DataSource): Args: uid: 需要追加读取配置的 UID """ + if uid in self.get_uid_list(): + raise DataSourceException(f"载入 UID: {uid} 的推送配置失败, 不可重复载入") + user = await self.__query(f"SELECT * FROM `bot` WHERE uid = {uid}") if len(user) == 0: logger.error(f"载入 UID: {uid} 的推送配置失败, UID 不存在") raise DataSourceException(f"载入 UID: {uid} 的推送配置失败, UID 不存在") - bot = user[0].get("bot") + qq = user[0].get("bot") targets = await self.__load_targets(uid) up = Up(uid=uid, targets=targets) - self.get_bot(bot).ups.append(up) + bot = self.get_bot(qq) + bot.ups.append(up) + up.inject_bot(bot) super().format_data() logger.success(f"已成功载入 UID: {uid} 的推送配置") diff --git a/starbot/core/model.py b/starbot/core/model.py index 88fc276..253d48f 100644 --- a/starbot/core/model.py +++ b/starbot/core/model.py @@ -264,13 +264,8 @@ class PushTarget(BaseModel): dynamic_update: Optional[DynamicUpdate] = DynamicUpdate() """动态推送配置。默认:DynamicUpdate()""" - key: Optional[str] = None - """推送 Key,可选功能,可使用此 Key 通过 HTTP API 向对应的好友或群推送消息。默认:str(id)-str(type)""" - def __init__(self, **data: Any): super().__init__(**data) - if not self.key: - self.key = "-".join([str(self.id), str(self.type.value)]) self.__raise_for_not_invalid_placeholders() def __raise_for_not_invalid_placeholders(self): @@ -287,7 +282,7 @@ class PushTarget(BaseModel): return False def __hash__(self): - return hash(self.key) + return hash(self.id) ^ hash(self.type.value) class Message(BaseModel): diff --git a/starbot/core/server.py b/starbot/core/server.py index 1e0afd6..a7c96af 100644 --- a/starbot/core/server.py +++ b/starbot/core/server.py @@ -3,10 +3,11 @@ from typing import Optional import aiohttp from aiohttp import web from aiohttp.web_routedef import RouteTableDef +from graia.ariadne.exception import UnknownTarget from loguru import logger from .datasource import DataSource -from .model import Message +from .model import Message, PushType from ..exception import DataSourceException from ..utils import config @@ -14,21 +15,55 @@ routes = web.RouteTableDef() datasource: Optional[DataSource] = None -@routes.get("/send/{key}/{message}") -async def send(request: aiohttp.web.Request) -> aiohttp.web.Response: - key = request.match_info['key'] - message = request.match_info['message'] +@routes.get("/send/{type}/{key}/{message}") +async def send(request: aiohttp.web.Request, qq: int = None) -> aiohttp.web.Response: + if len(datasource.bots) == 1: + bot = datasource.get_bot() + else: + if qq is None: + qq = config.get("HTTP_API_DEAFULT_BOT") + if qq is None: + logger.warning("HTTP API 推送失败, 多 Bot 推送时使用 HTTP API 需填写 HTTP_API_DEAFULT_BOT 配置项") + return web.Response(text="fail") + + try: + bot = datasource.get_bot(qq) + except DataSourceException: + logger.warning("HTTP API 推送失败, 填写的 HTTP_API_DEAFULT_BOT 配置项不正确") + return web.Response(text="fail") + + if not str(request.match_info['key']).isdigit(): + logger.warning("HTTP API 推送失败, 传入的 QQ 或群号格式不正确") + return web.Response(text="fail") + + type_map = { + "friend": PushType.Friend, + "group": PushType.Group + } + _type = type_map.get(str(request.match_info['type']), None) + if _type is None: + logger.warning("HTTP API 推送失败, 传入的推送类型格式不正确") + return web.Response(text="fail") + + key = int(request.match_info['key']) + message = Message(id=key, content=str(request.match_info['message']), type=_type) try: - target = datasource.get_target_by_key(key) - bot = datasource.get_bot_by_key(key) - msg = Message(id=target.id, content=message, type=target.type) - await bot.send_message(msg) - return web.Response(text="success") - except DataSourceException: - logger.warning(f"HTTP API 推送失败, 不存在的推送 key: {key}") + await bot.send_message(message) + except UnknownTarget: + pass + + return web.Response(text="success") + + +@routes.get("/send/{bot}/{type}/{key}/{message}") +async def send_by_bot(request: aiohttp.web.Request) -> aiohttp.web.Response: + if not str(request.match_info['bot']).isdigit(): + logger.warning("HTTP API 推送失败, 传入的 Bot QQ 格式不正确") return web.Response(text="fail") + return await send(request, int(request.match_info['bot'])) + def get_routes() -> RouteTableDef: """ @@ -43,6 +78,7 @@ def get_routes() -> RouteTableDef: async def http_init(source: DataSource): global datasource datasource = source + port = config.get("HTTP_API_PORT") logger.info("开始启动 HTTP API 推送服务") @@ -50,10 +86,10 @@ async def http_init(source: DataSource): app.add_routes(routes) runner = web.AppRunner(app) await runner.setup() - site = web.TCPSite(runner, 'localhost', config.get("HTTP_API_PORT")) + site = web.TCPSite(runner, 'localhost', port) try: await site.start() except OSError: - logger.error(f"设定的 HTTP API 端口 {config.get('HTTP_API_PORT')} 已被占用, HTTP API 推送服务启动失败") + logger.error(f"设定的 HTTP API 端口 {port} 已被占用, HTTP API 推送服务启动失败") return - logger.success("成功启动 HTTP API 推送服务") + logger.success(f"成功启动 HTTP API 推送服务: http://localhost:{port}") diff --git a/starbot/utils/config.py b/starbot/utils/config.py index ffe759b..7e3efe1 100644 --- a/starbot/utils/config.py +++ b/starbot/utils/config.py @@ -81,6 +81,8 @@ SIMPLE_CONFIG = { "USE_HTTP_API": False, # HTTP API 端口 "HTTP_API_PORT": 8088, + # 默认 HTTP API 推送 Bot QQ,多 Bot 推送时必填 + "HTTP_API_DEAFULT_BOT": None, # 命令触发前缀 "COMMAND_PREFIX": "", @@ -175,9 +177,11 @@ FULL_CONFIG = { "PROXY": "", # 是否使用 HTTP API 推送 - "USE_HTTP_API": True, + "USE_HTTP_API": False, # HTTP API 端口 "HTTP_API_PORT": 8088, + # 默认 HTTP API 推送 Bot QQ,多 Bot 推送时必填 + "HTTP_API_DEAFULT_BOT": None, # 命令触发前缀 "COMMAND_PREFIX": "", diff --git a/starbot/utils/redis.py b/starbot/utils/redis.py index c83aeef..8d7c8e3 100644 --- a/starbot/utils/redis.py +++ b/starbot/utils/redis.py @@ -121,6 +121,10 @@ async def hincrbyfloat(key: str, hkey: Union[str, int], value: float = 1.0) -> f return await __redis.hincrbyfloat(key, hkey, value) +async def hdel(key: str, hkey: Union[str, int]): + await __redis.hdel(key, hkey) + + # Set async def scard(key: str) -> int: