refactor: Remove key attribute from PushTarget and rename groups table name
This commit is contained in:
parent
063608a3ad
commit
d85c18a62a
@ -59,6 +59,10 @@ class StarBot:
|
||||
logger.error(ex.msg)
|
||||
return
|
||||
|
||||
if not self.__datasource.bots:
|
||||
logger.error("数据源配置为空, 请先在数据源中配置完毕后再重新运行")
|
||||
return
|
||||
|
||||
# 连接 Redis
|
||||
try:
|
||||
await redis.init()
|
||||
@ -120,10 +124,6 @@ class StarBot:
|
||||
logger.success("用户自定义命令模块载入完毕")
|
||||
|
||||
# 启动消息推送模块
|
||||
if not self.__datasource.bots:
|
||||
logger.error("不存在需要启动的 Bot 账号, 请先在数据源中配置完毕后再重新运行")
|
||||
return
|
||||
|
||||
Ariadne.options["default_account"] = self.__datasource.bots[0].qq
|
||||
|
||||
logger.info("开始运行 Ariadne 消息推送模块")
|
||||
|
@ -24,9 +24,6 @@ class DataSource(metaclass=abc.ABCMeta):
|
||||
self.__up_list: List[Up] = []
|
||||
self.__up_map: Dict[int, Up] = {}
|
||||
self.__uid_list: List[int] = []
|
||||
self.__target_list: List[PushTarget] = []
|
||||
self.__target_key_map: Dict[str, PushTarget] = {}
|
||||
self.__target_bot_map: Dict[str, Bot] = {}
|
||||
|
||||
@abc.abstractmethod
|
||||
async def load(self):
|
||||
@ -47,13 +44,6 @@ class DataSource(metaclass=abc.ABCMeta):
|
||||
self.__uid_list = list(self.__up_map.keys())
|
||||
if len(set(self.__uid_list)) < len(self.__uid_list):
|
||||
raise DataSourceException("配置中不可含有重复的 UID")
|
||||
self.__target_list = [x for target in map(lambda up: up.targets, self.__up_list) for x in target]
|
||||
self.__target_key_map = dict(zip(map(lambda target: target.key, self.__target_list), self.__target_list))
|
||||
|
||||
for bot in self.bots:
|
||||
for up in bot.ups:
|
||||
for target in up.targets:
|
||||
self.__target_bot_map[target.key] = bot
|
||||
|
||||
def get_up_list(self) -> List[Up]:
|
||||
"""
|
||||
@ -91,12 +81,12 @@ class DataSource(metaclass=abc.ABCMeta):
|
||||
raise DataSourceException(f"不存在的 UID: {uid}")
|
||||
return up
|
||||
|
||||
def get_bot(self, qq: int) -> Bot:
|
||||
def get_bot(self, qq: Optional[int] = None) -> Bot:
|
||||
"""
|
||||
根据 QQ 获取 Bot 实例
|
||||
|
||||
Args:
|
||||
qq: 需要获取 Bot 的 QQ
|
||||
qq: 需要获取 Bot 的 QQ,单 Bot 推送时可不传入
|
||||
|
||||
Returns:
|
||||
Bot 实例
|
||||
@ -104,6 +94,11 @@ class DataSource(metaclass=abc.ABCMeta):
|
||||
Raises:
|
||||
DataSourceException: QQ 不存在
|
||||
"""
|
||||
if qq is None:
|
||||
if len(self.bots) != 1:
|
||||
raise DataSourceException(f"多 Bot 推送时需明确指定要获取的 Bot QQ")
|
||||
return self.bots[0]
|
||||
|
||||
bot = next((b for b in self.bots if b.qq == qq), None)
|
||||
if bot is None:
|
||||
raise DataSourceException(f"不存在的 QQ: {qq}")
|
||||
@ -130,42 +125,6 @@ class DataSource(metaclass=abc.ABCMeta):
|
||||
|
||||
return ups
|
||||
|
||||
def get_target_by_key(self, key: str) -> PushTarget:
|
||||
"""
|
||||
根据推送 key 获取 PushTarget 实例,用于 HTTP API 推送
|
||||
|
||||
Args:
|
||||
key: 需要获取 PushTarget 的推送 key
|
||||
|
||||
Returns:
|
||||
PushTarget 实例
|
||||
|
||||
Raises:
|
||||
DataSourceException: key 不存在
|
||||
"""
|
||||
target = self.__target_key_map.get(key)
|
||||
if target is None:
|
||||
raise DataSourceException(f"不存在的推送 key: {key}")
|
||||
return target
|
||||
|
||||
def get_bot_by_key(self, key: str) -> Bot:
|
||||
"""
|
||||
根据推送 key 获取其所在的 Bot 实例,用于 HTTP API 推送
|
||||
|
||||
Args:
|
||||
key: 需要获取所在 Bot 的推送 key
|
||||
|
||||
Returns:
|
||||
Bot 实例
|
||||
|
||||
Raises:
|
||||
DataSourceException: key 不存在
|
||||
"""
|
||||
bot = self.__target_bot_map.get(key)
|
||||
if bot is None:
|
||||
raise DataSourceException(f"不存在的推送 key: {key}")
|
||||
return bot
|
||||
|
||||
async def wait_for_connects(self):
|
||||
"""
|
||||
等待所有 Up 实例连接直播间完毕
|
||||
@ -301,37 +260,37 @@ class MySQLDataSource(DataSource):
|
||||
推送目标列表
|
||||
"""
|
||||
live_on = await self.__query(
|
||||
"SELECT g.`uid`, g.`uname`, g.`room_id`, `key`, `type`, `num`, `enabled`, `message` "
|
||||
"FROM `groups` AS `g` LEFT JOIN `live_on` AS `l` "
|
||||
"ON g.`uid` = l.`uid` AND g.`index` = l.`index` "
|
||||
f"WHERE g.`uid` = {uid} "
|
||||
"ORDER BY g.`index`"
|
||||
"SELECT t.`uid`, t.`uname`, t.`room_id`, `type`, `num`, `enabled`, `message` "
|
||||
"FROM `targets` AS `t` LEFT JOIN `live_on` AS `l` "
|
||||
"ON t.`uid` = l.`uid` AND t.`id` = l.`id` "
|
||||
f"WHERE t.`uid` = {uid} "
|
||||
"ORDER BY t.`id`"
|
||||
)
|
||||
live_off = await self.__query(
|
||||
"SELECT g.`uid`, g.`uname`, g.`room_id`, `key`, `type`, `num`, `enabled`, `message` "
|
||||
"FROM `groups` AS `g` LEFT JOIN `live_off` AS `l` "
|
||||
"ON g.`uid` = l.`uid` AND g.`index` = l.`index` "
|
||||
f"WHERE g.`uid` = {uid} "
|
||||
"ORDER BY g.`index`"
|
||||
"SELECT t.`uid`, t.`uname`, t.`room_id`, `type`, `num`, `enabled`, `message` "
|
||||
"FROM `targets` AS `t` LEFT JOIN `live_off` AS `l` "
|
||||
"ON t.`uid` = l.`uid` AND t.`id` = l.`id` "
|
||||
f"WHERE t.`uid` = {uid} "
|
||||
"ORDER BY t.`id`"
|
||||
)
|
||||
live_report = await self.__query(
|
||||
"SELECT g.`uid`, g.`uname`, g.`room_id`, `key`, `type`, `num`, "
|
||||
"SELECT t.`uid`, t.`uname`, t.`room_id`, `type`, `num`, "
|
||||
"`enabled`, `logo`, `logo_base64`, `time`, `fans_change`, `fans_medal_change`, `guard_change`, "
|
||||
"`danmu`, `box`, `gift`, `sc`, `guard`, "
|
||||
"`danmu_ranking`, `box_ranking`, `box_profit_ranking`, `gift_ranking`, `sc_ranking`, "
|
||||
"`guard_list`, `box_profit_diagram`, `danmu_diagram`, `box_diagram`, `gift_diagram`, "
|
||||
"`sc_diagram`, `guard_diagram`, `danmu_cloud` "
|
||||
"FROM `groups` AS `g` LEFT JOIN `live_report` AS `l` "
|
||||
"ON g.`uid` = l.`uid` AND g.`index` = l.`index` "
|
||||
f"WHERE g.`uid` = {uid} "
|
||||
"ORDER BY g.`index`"
|
||||
"FROM `targets` AS `t` LEFT JOIN `live_report` AS `l` "
|
||||
"ON t.`uid` = l.`uid` AND t.`id` = l.`id` "
|
||||
f"WHERE t.`uid` = {uid} "
|
||||
"ORDER BY t.`id`"
|
||||
)
|
||||
dynamic_update = await self.__query(
|
||||
"SELECT g.`uid`, g.`uname`, g.`room_id`, `key`, `type`, `num`, `enabled`, `message` "
|
||||
"FROM `groups` AS `g` LEFT JOIN `dynamic_update` AS `d` "
|
||||
"ON g.`uid` = d.`uid` AND g.`index` = d.`index` "
|
||||
f"WHERE g.`uid` = {uid} "
|
||||
"ORDER BY g.`index`"
|
||||
"SELECT t.`uid`, t.`uname`, t.`room_id`, `type`, `num`, `enabled`, `message` "
|
||||
"FROM `targets` AS `t` LEFT JOIN `dynamic_update` AS `d` "
|
||||
"ON t.`uid` = d.`uid` AND t.`id` = d.`id` "
|
||||
f"WHERE t.`uid` = {uid} "
|
||||
"ORDER BY t.`id`"
|
||||
)
|
||||
|
||||
targets = []
|
||||
@ -430,15 +389,20 @@ class MySQLDataSource(DataSource):
|
||||
Args:
|
||||
uid: 需要追加读取配置的 UID
|
||||
"""
|
||||
if uid in self.get_uid_list():
|
||||
raise DataSourceException(f"载入 UID: {uid} 的推送配置失败, 不可重复载入")
|
||||
|
||||
user = await self.__query(f"SELECT * FROM `bot` WHERE uid = {uid}")
|
||||
if len(user) == 0:
|
||||
logger.error(f"载入 UID: {uid} 的推送配置失败, UID 不存在")
|
||||
raise DataSourceException(f"载入 UID: {uid} 的推送配置失败, UID 不存在")
|
||||
|
||||
bot = user[0].get("bot")
|
||||
qq = user[0].get("bot")
|
||||
targets = await self.__load_targets(uid)
|
||||
up = Up(uid=uid, targets=targets)
|
||||
self.get_bot(bot).ups.append(up)
|
||||
bot = self.get_bot(qq)
|
||||
bot.ups.append(up)
|
||||
up.inject_bot(bot)
|
||||
super().format_data()
|
||||
logger.success(f"已成功载入 UID: {uid} 的推送配置")
|
||||
|
||||
|
@ -264,13 +264,8 @@ class PushTarget(BaseModel):
|
||||
dynamic_update: Optional[DynamicUpdate] = DynamicUpdate()
|
||||
"""动态推送配置。默认:DynamicUpdate()"""
|
||||
|
||||
key: Optional[str] = None
|
||||
"""推送 Key,可选功能,可使用此 Key 通过 HTTP API 向对应的好友或群推送消息。默认:str(id)-str(type)"""
|
||||
|
||||
def __init__(self, **data: Any):
|
||||
super().__init__(**data)
|
||||
if not self.key:
|
||||
self.key = "-".join([str(self.id), str(self.type.value)])
|
||||
self.__raise_for_not_invalid_placeholders()
|
||||
|
||||
def __raise_for_not_invalid_placeholders(self):
|
||||
@ -287,7 +282,7 @@ class PushTarget(BaseModel):
|
||||
return False
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.key)
|
||||
return hash(self.id) ^ hash(self.type.value)
|
||||
|
||||
|
||||
class Message(BaseModel):
|
||||
|
@ -3,10 +3,11 @@ from typing import Optional
|
||||
import aiohttp
|
||||
from aiohttp import web
|
||||
from aiohttp.web_routedef import RouteTableDef
|
||||
from graia.ariadne.exception import UnknownTarget
|
||||
from loguru import logger
|
||||
|
||||
from .datasource import DataSource
|
||||
from .model import Message
|
||||
from .model import Message, PushType
|
||||
from ..exception import DataSourceException
|
||||
from ..utils import config
|
||||
|
||||
@ -14,21 +15,55 @@ routes = web.RouteTableDef()
|
||||
datasource: Optional[DataSource] = None
|
||||
|
||||
|
||||
@routes.get("/send/{key}/{message}")
|
||||
async def send(request: aiohttp.web.Request) -> aiohttp.web.Response:
|
||||
key = request.match_info['key']
|
||||
message = request.match_info['message']
|
||||
@routes.get("/send/{type}/{key}/{message}")
|
||||
async def send(request: aiohttp.web.Request, qq: int = None) -> aiohttp.web.Response:
|
||||
if len(datasource.bots) == 1:
|
||||
bot = datasource.get_bot()
|
||||
else:
|
||||
if qq is None:
|
||||
qq = config.get("HTTP_API_DEAFULT_BOT")
|
||||
if qq is None:
|
||||
logger.warning("HTTP API 推送失败, 多 Bot 推送时使用 HTTP API 需填写 HTTP_API_DEAFULT_BOT 配置项")
|
||||
return web.Response(text="fail")
|
||||
|
||||
try:
|
||||
bot = datasource.get_bot(qq)
|
||||
except DataSourceException:
|
||||
logger.warning("HTTP API 推送失败, 填写的 HTTP_API_DEAFULT_BOT 配置项不正确")
|
||||
return web.Response(text="fail")
|
||||
|
||||
if not str(request.match_info['key']).isdigit():
|
||||
logger.warning("HTTP API 推送失败, 传入的 QQ 或群号格式不正确")
|
||||
return web.Response(text="fail")
|
||||
|
||||
type_map = {
|
||||
"friend": PushType.Friend,
|
||||
"group": PushType.Group
|
||||
}
|
||||
_type = type_map.get(str(request.match_info['type']), None)
|
||||
if _type is None:
|
||||
logger.warning("HTTP API 推送失败, 传入的推送类型格式不正确")
|
||||
return web.Response(text="fail")
|
||||
|
||||
key = int(request.match_info['key'])
|
||||
message = Message(id=key, content=str(request.match_info['message']), type=_type)
|
||||
|
||||
try:
|
||||
target = datasource.get_target_by_key(key)
|
||||
bot = datasource.get_bot_by_key(key)
|
||||
msg = Message(id=target.id, content=message, type=target.type)
|
||||
await bot.send_message(msg)
|
||||
return web.Response(text="success")
|
||||
except DataSourceException:
|
||||
logger.warning(f"HTTP API 推送失败, 不存在的推送 key: {key}")
|
||||
await bot.send_message(message)
|
||||
except UnknownTarget:
|
||||
pass
|
||||
|
||||
return web.Response(text="success")
|
||||
|
||||
|
||||
@routes.get("/send/{bot}/{type}/{key}/{message}")
|
||||
async def send_by_bot(request: aiohttp.web.Request) -> aiohttp.web.Response:
|
||||
if not str(request.match_info['bot']).isdigit():
|
||||
logger.warning("HTTP API 推送失败, 传入的 Bot QQ 格式不正确")
|
||||
return web.Response(text="fail")
|
||||
|
||||
return await send(request, int(request.match_info['bot']))
|
||||
|
||||
|
||||
def get_routes() -> RouteTableDef:
|
||||
"""
|
||||
@ -43,6 +78,7 @@ def get_routes() -> RouteTableDef:
|
||||
async def http_init(source: DataSource):
|
||||
global datasource
|
||||
datasource = source
|
||||
port = config.get("HTTP_API_PORT")
|
||||
|
||||
logger.info("开始启动 HTTP API 推送服务")
|
||||
|
||||
@ -50,10 +86,10 @@ async def http_init(source: DataSource):
|
||||
app.add_routes(routes)
|
||||
runner = web.AppRunner(app)
|
||||
await runner.setup()
|
||||
site = web.TCPSite(runner, 'localhost', config.get("HTTP_API_PORT"))
|
||||
site = web.TCPSite(runner, 'localhost', port)
|
||||
try:
|
||||
await site.start()
|
||||
except OSError:
|
||||
logger.error(f"设定的 HTTP API 端口 {config.get('HTTP_API_PORT')} 已被占用, HTTP API 推送服务启动失败")
|
||||
logger.error(f"设定的 HTTP API 端口 {port} 已被占用, HTTP API 推送服务启动失败")
|
||||
return
|
||||
logger.success("成功启动 HTTP API 推送服务")
|
||||
logger.success(f"成功启动 HTTP API 推送服务: http://localhost:{port}")
|
||||
|
@ -81,6 +81,8 @@ SIMPLE_CONFIG = {
|
||||
"USE_HTTP_API": False,
|
||||
# HTTP API 端口
|
||||
"HTTP_API_PORT": 8088,
|
||||
# 默认 HTTP API 推送 Bot QQ,多 Bot 推送时必填
|
||||
"HTTP_API_DEAFULT_BOT": None,
|
||||
|
||||
# 命令触发前缀
|
||||
"COMMAND_PREFIX": "",
|
||||
@ -175,9 +177,11 @@ FULL_CONFIG = {
|
||||
"PROXY": "",
|
||||
|
||||
# 是否使用 HTTP API 推送
|
||||
"USE_HTTP_API": True,
|
||||
"USE_HTTP_API": False,
|
||||
# HTTP API 端口
|
||||
"HTTP_API_PORT": 8088,
|
||||
# 默认 HTTP API 推送 Bot QQ,多 Bot 推送时必填
|
||||
"HTTP_API_DEAFULT_BOT": None,
|
||||
|
||||
# 命令触发前缀
|
||||
"COMMAND_PREFIX": "",
|
||||
|
@ -121,6 +121,10 @@ async def hincrbyfloat(key: str, hkey: Union[str, int], value: float = 1.0) -> f
|
||||
return await __redis.hincrbyfloat(key, hkey, value)
|
||||
|
||||
|
||||
async def hdel(key: str, hkey: Union[str, int]):
|
||||
await __redis.hdel(key, hkey)
|
||||
|
||||
|
||||
# Set
|
||||
|
||||
async def scard(key: str) -> int:
|
||||
|
Loading…
x
Reference in New Issue
Block a user