refactor: Remove key attribute from PushTarget and rename groups table name

This commit is contained in:
LWR 2023-02-18 00:29:02 +08:00
parent 063608a3ad
commit d85c18a62a
6 changed files with 99 additions and 96 deletions

View File

@ -59,6 +59,10 @@ class StarBot:
logger.error(ex.msg)
return
if not self.__datasource.bots:
logger.error("数据源配置为空, 请先在数据源中配置完毕后再重新运行")
return
# 连接 Redis
try:
await redis.init()
@ -120,10 +124,6 @@ class StarBot:
logger.success("用户自定义命令模块载入完毕")
# 启动消息推送模块
if not self.__datasource.bots:
logger.error("不存在需要启动的 Bot 账号, 请先在数据源中配置完毕后再重新运行")
return
Ariadne.options["default_account"] = self.__datasource.bots[0].qq
logger.info("开始运行 Ariadne 消息推送模块")

View File

@ -24,9 +24,6 @@ class DataSource(metaclass=abc.ABCMeta):
self.__up_list: List[Up] = []
self.__up_map: Dict[int, Up] = {}
self.__uid_list: List[int] = []
self.__target_list: List[PushTarget] = []
self.__target_key_map: Dict[str, PushTarget] = {}
self.__target_bot_map: Dict[str, Bot] = {}
@abc.abstractmethod
async def load(self):
@ -47,13 +44,6 @@ class DataSource(metaclass=abc.ABCMeta):
self.__uid_list = list(self.__up_map.keys())
if len(set(self.__uid_list)) < len(self.__uid_list):
raise DataSourceException("配置中不可含有重复的 UID")
self.__target_list = [x for target in map(lambda up: up.targets, self.__up_list) for x in target]
self.__target_key_map = dict(zip(map(lambda target: target.key, self.__target_list), self.__target_list))
for bot in self.bots:
for up in bot.ups:
for target in up.targets:
self.__target_bot_map[target.key] = bot
def get_up_list(self) -> List[Up]:
"""
@ -91,12 +81,12 @@ class DataSource(metaclass=abc.ABCMeta):
raise DataSourceException(f"不存在的 UID: {uid}")
return up
def get_bot(self, qq: int) -> Bot:
def get_bot(self, qq: Optional[int] = None) -> Bot:
"""
根据 QQ 获取 Bot 实例
Args:
qq: 需要获取 Bot QQ
qq: 需要获取 Bot QQ Bot 推送时可不传入
Returns:
Bot 实例
@ -104,6 +94,11 @@ class DataSource(metaclass=abc.ABCMeta):
Raises:
DataSourceException: QQ 不存在
"""
if qq is None:
if len(self.bots) != 1:
raise DataSourceException(f"多 Bot 推送时需明确指定要获取的 Bot QQ")
return self.bots[0]
bot = next((b for b in self.bots if b.qq == qq), None)
if bot is None:
raise DataSourceException(f"不存在的 QQ: {qq}")
@ -130,42 +125,6 @@ class DataSource(metaclass=abc.ABCMeta):
return ups
def get_target_by_key(self, key: str) -> PushTarget:
"""
根据推送 key 获取 PushTarget 实例用于 HTTP API 推送
Args:
key: 需要获取 PushTarget 的推送 key
Returns:
PushTarget 实例
Raises:
DataSourceException: key 不存在
"""
target = self.__target_key_map.get(key)
if target is None:
raise DataSourceException(f"不存在的推送 key: {key}")
return target
def get_bot_by_key(self, key: str) -> Bot:
"""
根据推送 key 获取其所在的 Bot 实例用于 HTTP API 推送
Args:
key: 需要获取所在 Bot 的推送 key
Returns:
Bot 实例
Raises:
DataSourceException: key 不存在
"""
bot = self.__target_bot_map.get(key)
if bot is None:
raise DataSourceException(f"不存在的推送 key: {key}")
return bot
async def wait_for_connects(self):
"""
等待所有 Up 实例连接直播间完毕
@ -301,37 +260,37 @@ class MySQLDataSource(DataSource):
推送目标列表
"""
live_on = await self.__query(
"SELECT g.`uid`, g.`uname`, g.`room_id`, `key`, `type`, `num`, `enabled`, `message` "
"FROM `groups` AS `g` LEFT JOIN `live_on` AS `l` "
"ON g.`uid` = l.`uid` AND g.`index` = l.`index` "
f"WHERE g.`uid` = {uid} "
"ORDER BY g.`index`"
"SELECT t.`uid`, t.`uname`, t.`room_id`, `type`, `num`, `enabled`, `message` "
"FROM `targets` AS `t` LEFT JOIN `live_on` AS `l` "
"ON t.`uid` = l.`uid` AND t.`id` = l.`id` "
f"WHERE t.`uid` = {uid} "
"ORDER BY t.`id`"
)
live_off = await self.__query(
"SELECT g.`uid`, g.`uname`, g.`room_id`, `key`, `type`, `num`, `enabled`, `message` "
"FROM `groups` AS `g` LEFT JOIN `live_off` AS `l` "
"ON g.`uid` = l.`uid` AND g.`index` = l.`index` "
f"WHERE g.`uid` = {uid} "
"ORDER BY g.`index`"
"SELECT t.`uid`, t.`uname`, t.`room_id`, `type`, `num`, `enabled`, `message` "
"FROM `targets` AS `t` LEFT JOIN `live_off` AS `l` "
"ON t.`uid` = l.`uid` AND t.`id` = l.`id` "
f"WHERE t.`uid` = {uid} "
"ORDER BY t.`id`"
)
live_report = await self.__query(
"SELECT g.`uid`, g.`uname`, g.`room_id`, `key`, `type`, `num`, "
"SELECT t.`uid`, t.`uname`, t.`room_id`, `type`, `num`, "
"`enabled`, `logo`, `logo_base64`, `time`, `fans_change`, `fans_medal_change`, `guard_change`, "
"`danmu`, `box`, `gift`, `sc`, `guard`, "
"`danmu_ranking`, `box_ranking`, `box_profit_ranking`, `gift_ranking`, `sc_ranking`, "
"`guard_list`, `box_profit_diagram`, `danmu_diagram`, `box_diagram`, `gift_diagram`, "
"`sc_diagram`, `guard_diagram`, `danmu_cloud` "
"FROM `groups` AS `g` LEFT JOIN `live_report` AS `l` "
"ON g.`uid` = l.`uid` AND g.`index` = l.`index` "
f"WHERE g.`uid` = {uid} "
"ORDER BY g.`index`"
"FROM `targets` AS `t` LEFT JOIN `live_report` AS `l` "
"ON t.`uid` = l.`uid` AND t.`id` = l.`id` "
f"WHERE t.`uid` = {uid} "
"ORDER BY t.`id`"
)
dynamic_update = await self.__query(
"SELECT g.`uid`, g.`uname`, g.`room_id`, `key`, `type`, `num`, `enabled`, `message` "
"FROM `groups` AS `g` LEFT JOIN `dynamic_update` AS `d` "
"ON g.`uid` = d.`uid` AND g.`index` = d.`index` "
f"WHERE g.`uid` = {uid} "
"ORDER BY g.`index`"
"SELECT t.`uid`, t.`uname`, t.`room_id`, `type`, `num`, `enabled`, `message` "
"FROM `targets` AS `t` LEFT JOIN `dynamic_update` AS `d` "
"ON t.`uid` = d.`uid` AND t.`id` = d.`id` "
f"WHERE t.`uid` = {uid} "
"ORDER BY t.`id`"
)
targets = []
@ -430,15 +389,20 @@ class MySQLDataSource(DataSource):
Args:
uid: 需要追加读取配置的 UID
"""
if uid in self.get_uid_list():
raise DataSourceException(f"载入 UID: {uid} 的推送配置失败, 不可重复载入")
user = await self.__query(f"SELECT * FROM `bot` WHERE uid = {uid}")
if len(user) == 0:
logger.error(f"载入 UID: {uid} 的推送配置失败, UID 不存在")
raise DataSourceException(f"载入 UID: {uid} 的推送配置失败, UID 不存在")
bot = user[0].get("bot")
qq = user[0].get("bot")
targets = await self.__load_targets(uid)
up = Up(uid=uid, targets=targets)
self.get_bot(bot).ups.append(up)
bot = self.get_bot(qq)
bot.ups.append(up)
up.inject_bot(bot)
super().format_data()
logger.success(f"已成功载入 UID: {uid} 的推送配置")

View File

@ -264,13 +264,8 @@ class PushTarget(BaseModel):
dynamic_update: Optional[DynamicUpdate] = DynamicUpdate()
"""动态推送配置。默认DynamicUpdate()"""
key: Optional[str] = None
"""推送 Key可选功能可使用此 Key 通过 HTTP API 向对应的好友或群推送消息。默认str(id)-str(type)"""
def __init__(self, **data: Any):
super().__init__(**data)
if not self.key:
self.key = "-".join([str(self.id), str(self.type.value)])
self.__raise_for_not_invalid_placeholders()
def __raise_for_not_invalid_placeholders(self):
@ -287,7 +282,7 @@ class PushTarget(BaseModel):
return False
def __hash__(self):
return hash(self.key)
return hash(self.id) ^ hash(self.type.value)
class Message(BaseModel):

View File

@ -3,10 +3,11 @@ from typing import Optional
import aiohttp
from aiohttp import web
from aiohttp.web_routedef import RouteTableDef
from graia.ariadne.exception import UnknownTarget
from loguru import logger
from .datasource import DataSource
from .model import Message
from .model import Message, PushType
from ..exception import DataSourceException
from ..utils import config
@ -14,21 +15,55 @@ routes = web.RouteTableDef()
datasource: Optional[DataSource] = None
@routes.get("/send/{key}/{message}")
async def send(request: aiohttp.web.Request) -> aiohttp.web.Response:
key = request.match_info['key']
message = request.match_info['message']
@routes.get("/send/{type}/{key}/{message}")
async def send(request: aiohttp.web.Request, qq: int = None) -> aiohttp.web.Response:
if len(datasource.bots) == 1:
bot = datasource.get_bot()
else:
if qq is None:
qq = config.get("HTTP_API_DEAFULT_BOT")
if qq is None:
logger.warning("HTTP API 推送失败, 多 Bot 推送时使用 HTTP API 需填写 HTTP_API_DEAFULT_BOT 配置项")
return web.Response(text="fail")
try:
bot = datasource.get_bot(qq)
except DataSourceException:
logger.warning("HTTP API 推送失败, 填写的 HTTP_API_DEAFULT_BOT 配置项不正确")
return web.Response(text="fail")
if not str(request.match_info['key']).isdigit():
logger.warning("HTTP API 推送失败, 传入的 QQ 或群号格式不正确")
return web.Response(text="fail")
type_map = {
"friend": PushType.Friend,
"group": PushType.Group
}
_type = type_map.get(str(request.match_info['type']), None)
if _type is None:
logger.warning("HTTP API 推送失败, 传入的推送类型格式不正确")
return web.Response(text="fail")
key = int(request.match_info['key'])
message = Message(id=key, content=str(request.match_info['message']), type=_type)
try:
target = datasource.get_target_by_key(key)
bot = datasource.get_bot_by_key(key)
msg = Message(id=target.id, content=message, type=target.type)
await bot.send_message(msg)
return web.Response(text="success")
except DataSourceException:
logger.warning(f"HTTP API 推送失败, 不存在的推送 key: {key}")
await bot.send_message(message)
except UnknownTarget:
pass
return web.Response(text="success")
@routes.get("/send/{bot}/{type}/{key}/{message}")
async def send_by_bot(request: aiohttp.web.Request) -> aiohttp.web.Response:
if not str(request.match_info['bot']).isdigit():
logger.warning("HTTP API 推送失败, 传入的 Bot QQ 格式不正确")
return web.Response(text="fail")
return await send(request, int(request.match_info['bot']))
def get_routes() -> RouteTableDef:
"""
@ -43,6 +78,7 @@ def get_routes() -> RouteTableDef:
async def http_init(source: DataSource):
global datasource
datasource = source
port = config.get("HTTP_API_PORT")
logger.info("开始启动 HTTP API 推送服务")
@ -50,10 +86,10 @@ async def http_init(source: DataSource):
app.add_routes(routes)
runner = web.AppRunner(app)
await runner.setup()
site = web.TCPSite(runner, 'localhost', config.get("HTTP_API_PORT"))
site = web.TCPSite(runner, 'localhost', port)
try:
await site.start()
except OSError:
logger.error(f"设定的 HTTP API 端口 {config.get('HTTP_API_PORT')} 已被占用, HTTP API 推送服务启动失败")
logger.error(f"设定的 HTTP API 端口 {port} 已被占用, HTTP API 推送服务启动失败")
return
logger.success("成功启动 HTTP API 推送服务")
logger.success(f"成功启动 HTTP API 推送服务: http://localhost:{port}")

View File

@ -81,6 +81,8 @@ SIMPLE_CONFIG = {
"USE_HTTP_API": False,
# HTTP API 端口
"HTTP_API_PORT": 8088,
# 默认 HTTP API 推送 Bot QQ多 Bot 推送时必填
"HTTP_API_DEAFULT_BOT": None,
# 命令触发前缀
"COMMAND_PREFIX": "",
@ -175,9 +177,11 @@ FULL_CONFIG = {
"PROXY": "",
# 是否使用 HTTP API 推送
"USE_HTTP_API": True,
"USE_HTTP_API": False,
# HTTP API 端口
"HTTP_API_PORT": 8088,
# 默认 HTTP API 推送 Bot QQ多 Bot 推送时必填
"HTTP_API_DEAFULT_BOT": None,
# 命令触发前缀
"COMMAND_PREFIX": "",

View File

@ -121,6 +121,10 @@ async def hincrbyfloat(key: str, hkey: Union[str, int], value: float = 1.0) -> f
return await __redis.hincrbyfloat(key, hkey, value)
async def hdel(key: str, hkey: Union[str, int]):
await __redis.hdel(key, hkey)
# Set
async def scard(key: str) -> int: