feat: HTTP API support

This commit is contained in:
LWR 2022-11-03 19:15:00 +08:00
parent f2c271c3b8
commit 6ce2e4c262
5 changed files with 157 additions and 30 deletions

View File

@ -1,3 +1,4 @@
import asyncio
import sys
from creart import create
@ -6,9 +7,10 @@ from graia.broadcast import Broadcast
from loguru import logger
from .datasource import DataSource
from .server import http_init
from ..exception.DataSourceException import DataSourceException
from ..exception.RedisException import RedisException
from ..utils import redis
from ..utils import redis, config
class StarBot:
@ -59,9 +61,24 @@ class StarBot:
logger.error(ex.msg)
return
# 启动 HTTP API 服务
if config.get("USE_HTTP_API"):
asyncio.get_event_loop().create_task(http_init(self.__datasource))
# 启动 Bot
logger.info("开始启动 Ariadne 消息推送模块")
Ariadne.options["default_account"] = 1499887988
if not self.__datasource.bots:
logger.error("不存在需要启动的 Bot 账号, 请先在数据源中配置完毕后再重新运行")
return
Ariadne.options["default_account"] = self.__datasource.bots[0].qq
logger.info("开始运行 Ariadne 消息推送模块")
logger.disable("graia.ariadne.service")
logger.disable("launart")
for bot in self.__datasource.bots:
bot.start_sender()
try:
Ariadne.launch_blocking()
except RuntimeError as ex:

View File

@ -20,10 +20,13 @@ class DataSource(metaclass=abc.ABCMeta):
"""
def __init__(self):
self.bots = {}
self.__up_list = []
self.__uid_list = []
self.__up_map = {}
self.bots: List[Bot] = []
self.__up_list: List[Up] = []
self.__up_map: Dict[int, Up] = {}
self.__uid_list: List[int] = []
self.__target_list: List[PushTarget] = []
self.__target_key_map: Dict[str, PushTarget] = {}
self.__target_bot_map: Dict[str, Bot] = {}
@abc.abstractmethod
async def load(self):
@ -39,11 +42,18 @@ class DataSource(metaclass=abc.ABCMeta):
Raises:
DataSourceException: 配置中包含重复 uid
"""
self.__up_list = [x for up in map(lambda bot: bot.ups, self.bots.values()) for x in up]
self.__uid_list = list(map(lambda up: up.uid, self.__up_list))
self.__up_list = [x for up in map(lambda bot: bot.ups, self.bots) for x in up]
self.__up_map = dict(zip(map(lambda up: up.uid, self.__up_list), self.__up_list))
self.__uid_list = list(self.__up_map.keys())
if len(set(self.__uid_list)) < len(self.__uid_list):
raise DataSourceException("配置中不可含有重复的 UID")
self.__up_map = dict(zip(map(lambda up: up.uid, self.__up_list), self.__up_list))
self.__target_list = [x for target in map(lambda up: up.targets, self.__up_list) for x in target]
self.__target_key_map = dict(zip(map(lambda target: target.key, self.__target_list), self.__target_list))
for bot in self.bots:
for up in bot.ups:
for target in up.targets:
self.__target_bot_map[target.key] = bot
def get_up_list(self) -> List[Up]:
"""
@ -81,6 +91,42 @@ class DataSource(metaclass=abc.ABCMeta):
raise DataSourceException(f"不存在的 UID: {uid}")
return up
def get_target_by_key(self, key: str) -> PushTarget:
"""
根据推送 key 获取 PushTarget 实例用于 HTTP API 推送
Args:
key: 需要获取 PushTarget 的推送 key
Returns:
PushTarget 实例
Raises:
DataSourceException: key 不存在
"""
target = self.__target_key_map.get(key)
if target is None:
raise DataSourceException(f"不存在的推送 key: {key}")
return target
def get_bot_by_key(self, key: str) -> Bot:
"""
根据推送 key 获取其所在的 Bot 实例用于 HTTP API 推送
Args:
key: 需要获取所在 Bot 的推送 key
Returns:
Bot 实例
Raises:
DataSourceException: key 不存在
"""
bot = self.__target_bot_map.get(key)
if bot is None:
raise DataSourceException(f"不存在的推送 key: {key}")
return bot
class DictDataSource(DataSource):
"""
@ -105,22 +151,22 @@ class DictDataSource(DataSource):
Raises:
DataSourceException: 配置字典格式错误或缺少必要参数
"""
if not self.bots:
logger.info("已选用 Dict 作为 Bot 数据源")
logger.info("开始从 Dict 中初始化 Bot 配置")
else:
logger.info("开始从 Dict 中更新 Bot 配置")
if self.bots:
return
logger.info("已选用 Dict 作为 Bot 数据源")
logger.info("开始从 Dict 中初始化 Bot 配置")
for bot in self.__config:
if "qq" not in bot:
raise DataSourceException("提供的配置字典中未提供 Bot 的 QQ 号参数")
try:
self.bots.update({bot["qq"]: Bot(**bot)})
self.bots.append(Bot(**bot))
except ValidationError as ex:
raise DataSourceException(f"提供的配置字典中缺少必须的 {ex.errors()[0].get('loc')[-1]} 参数")
super().format_data()
logger.success(f"成功从 Dict 中导入 {len(self.get_up_list())} 个 UP 主")
logger.success(f"成功从 Dict 中导入 {len(self.get_up_list())} 个 UP 主")
class MySQLDataSource(DataSource):
@ -196,13 +242,14 @@ class MySQLDataSource(DataSource):
async def load(self):
"""
MySQL 读取配置
MySQL 初始化配置
"""
if not self.bots:
logger.info("已选用 MySQL 作为 Bot 数据源")
logger.info("开始从 MySQL 中初始化 Bot 配置")
else:
logger.info("开始从 MySQL 中更新 Bot 配置")
if self.bots:
return
logger.info("已选用 MySQL 作为 Bot 数据源")
logger.info("开始从 MySQL 中初始化 Bot 配置")
if not self.__pool:
await self.__connect()
@ -283,7 +330,7 @@ class MySQLDataSource(DataSource):
ups.append(Up(uid=uid, targets=targets))
self.bots.update({bot: Bot(qq=bot, ups=ups)})
self.bots.append(Bot(qq=bot, ups=ups))
super().format_data()
logger.success(f"成功从 MySQL 中导入了 {len(self.get_up_list())} 个 UP 主")

View File

@ -4,7 +4,6 @@ from typing import List, Optional, Any
from graia.ariadne import Ariadne
from graia.ariadne.connection.config import config as AriadneConfig, HttpClientConfig, WebsocketClientConfig
from graia.ariadne.event.lifecycle import ApplicationLaunched
from graia.ariadne.message.chain import MessageChain
from graia.ariadne.message.element import At, AtAll
from graia.ariadne.model import LogConfig, MemberPerm
@ -55,11 +54,9 @@ class Bot(BaseModel, AsyncEvent):
for up in self.ups:
up.inject_bot(self)
# Ariadne 启动成功后启动消息发送模块
@self.__bot.broadcast.receiver(ApplicationLaunched)
async def start_sender():
logger.success(f"Bot [{self.qq}] 已启动")
self.__loop.create_task(self.__sender())
def start_sender(self):
self.__loop.create_task(self.__sender())
logger.success(f"Bot [{self.qq}] 已启动")
def send_message(self, msg: Message):
self.__queue.append(msg)
@ -68,6 +65,8 @@ class Bot(BaseModel, AsyncEvent):
"""
消息发送模块
"""
interval = config.get("MESSAGE_SEND_INTERVAL")
while True:
if self.__queue:
msg = self.__queue[0]
@ -80,6 +79,7 @@ class Bot(BaseModel, AsyncEvent):
logger.info(f"{self.qq} -> 群[{msg.id}] : {message}")
await self.__bot.send_group_message(msg.id, message)
self.__queue.pop(0)
await asyncio.sleep(interval)
else:
await asyncio.sleep(0.1)

55
starbot/core/server.py Normal file
View File

@ -0,0 +1,55 @@
from typing import Optional
import aiohttp
from aiohttp import web
from aiohttp.web_routedef import RouteTableDef
from loguru import logger
from .datasource import DataSource
from .model import Message
from ..exception import DataSourceException
from ..utils import config
routes = web.RouteTableDef()
datasource: Optional[DataSource] = None
@routes.get("/send/{key}/{message}")
async def send(request: aiohttp.web.Request) -> aiohttp.web.Response:
key = request.match_info['key']
message = request.match_info['message']
try:
target = datasource.get_target_by_key(key)
bot = datasource.get_bot_by_key(key)
msg = Message(id=target.id, content=message, type=target.type)
bot.send_message(msg)
return web.Response(text="success")
except DataSourceException:
logger.warning(f"HTTP API 推送失败, 不存在的推送 key: {key}")
return web.Response(text="fail")
def get_routes() -> RouteTableDef:
"""
获取路由可用于外部扩展功能
Returns:
路由实例
"""
return routes
async def http_init(source: DataSource):
global datasource
datasource = source
logger.info("开始启动 HTTP API 推送服务")
app = web.Application()
app.add_routes(routes)
runner = web.AppRunner(app)
await runner.setup()
site = web.TCPSite(runner, 'localhost', config.get("HTTP_API_PORT"))
await site.start()
logger.success("成功启动 HTTP API 推送服务")

View File

@ -52,6 +52,9 @@ SIMPLE_CONFIG = {
# HTTP API 端口
"HTTP_API_PORT": 8088,
# 消息发送间隔,消息发送过快容易被风控,单位:秒
"MESSAGE_SEND_INTERVAL": 0.5,
# 命令触发前缀
"COMMAND_PREFIX": "",
@ -117,6 +120,9 @@ FULL_CONFIG = {
# HTTP API 端口
"HTTP_API_PORT": 8088,
# 消息发送间隔,消息发送过快容易被风控,单位:秒
"MESSAGE_SEND_INTERVAL": 0.5,
# 命令触发前缀
"COMMAND_PREFIX": "",
@ -158,6 +164,7 @@ def use_simple_config():
未设置 Bot 主人 QQ
不使用 HTTP 代理
不开启 HTTP API 推送
消息发送间隔 0.5
无命令触发前缀
不开启风控消息补发
"""
@ -178,6 +185,7 @@ def use_full_config():
未设置 Bot 主人 QQ
不使用 HTTP 代理
开启 HTTP API 推送 (port: 8088)
消息发送间隔 0.5
无命令触发前缀
开启风控消息补发 仅补发推送消息
"""