Source code for iamai.adapter.utils

"""适配器实用工具。

这里定义了一些在编写适配器时常用的基类,适配器开发者可以直接继承自这里的类或者用作参考。
"""

import asyncio
from abc import ABCMeta, abstractmethod
from typing import Literal, Optional, Union

import aiohttp
from aiohttp import web

from iamai.adapter import Adapter
from iamai.log import logger
from iamai.typing import ConfigT, EventT

# Public API of this module: the adapter base classes that adapter authors may
# subclass directly or use as reference implementations.
__all__ = [
    "PollingAdapter",
    "HttpClientAdapter",
    "WebSocketClientAdapter",
    "HttpServerAdapter",
    "WebSocketServerAdapter",
    "WebSocketAdapter",
]


class PollingAdapter(Adapter[EventT, ConfigT], metaclass=ABCMeta):
    """Polling adapter example.

    Calls :meth:`on_tick` once every ``delay`` seconds until
    ``self.bot.should_exit`` is set.

    Attributes:
        delay: Seconds to sleep before each tick.
        create_task: When true, each tick is scheduled with
            :func:`asyncio.create_task` instead of being awaited inline, so a
            slow ``on_tick`` does not hold back the next tick.
    """

    delay: float = 0.1
    create_task: bool = False
    # The most recently scheduled tick task (only set when ``create_task`` is true).
    _on_tick_task: Optional["asyncio.Task[None]"] = None
    # Strong references to every still-running tick task; created in ``run``.
    _on_tick_tasks: "set[asyncio.Task[None]]"

    async def run(self) -> None:
        """Run the adapter: tick repeatedly until the bot is asked to exit."""
        # The event loop only keeps weak references to tasks. Keeping solely the
        # latest task in ``_on_tick_task`` would let earlier, still-running tick
        # tasks be garbage-collected mid-execution, so every in-flight task is
        # held here and removed once it finishes.
        self._on_tick_tasks = set()
        while not self.bot.should_exit.is_set():
            await asyncio.sleep(self.delay)
            if self.create_task:
                task = asyncio.create_task(self.on_tick())
                self._on_tick_tasks.add(task)
                task.add_done_callback(self._on_tick_tasks.discard)
                self._on_tick_task = task
            else:
                await self.on_tick()

    @abstractmethod
    async def on_tick(self) -> None:
        """Handle one poll; called once per tick."""
class HttpClientAdapter(PollingAdapter[EventT, ConfigT], metaclass=ABCMeta):
    """HTTP client adapter example.

    A polling adapter that creates an :class:`aiohttp.ClientSession` in
    :meth:`startup` and closes it again in :meth:`shutdown`.
    """

    session: aiohttp.ClientSession

    async def startup(self) -> None:
        """Create the adapter's HTTP client session."""
        new_session = aiohttp.ClientSession()
        self.session = new_session

    @abstractmethod
    async def on_tick(self) -> None:
        """Handle one poll; called once per tick."""

    async def shutdown(self) -> None:
        """Close the HTTP client session."""
        client = self.session
        await client.close()
class WebSocketClientAdapter(Adapter[EventT, ConfigT], metaclass=ABCMeta):
    """WebSocket client adapter example.

    Connects to ``url`` and forwards every received message to
    :meth:`handle_response`.
    """

    url: str

    async def run(self) -> None:
        """Connect to ``url`` and dispatch messages until exit or an error frame."""
        async with aiohttp.ClientSession() as session:
            async with session.ws_connect(self.url) as ws:
                async for message in ws:
                    # Stop when the bot is shutting down or the connection
                    # reports an error; otherwise hand the message on.
                    stop = (
                        self.bot.should_exit.is_set()
                        or message.type == aiohttp.WSMsgType.ERROR
                    )
                    if stop:
                        break
                    await self.handle_response(message)

    @abstractmethod
    async def handle_response(self, msg: aiohttp.WSMessage) -> None:
        """Handle one message received from the WebSocket server."""
class HttpServerAdapter(Adapter[EventT, ConfigT], metaclass=ABCMeta):
    """HTTP server adapter example.

    Serves ``GET get_url`` and ``POST post_url`` on ``host:port``; both routes
    are handled by :meth:`handle_response`.
    """

    app: web.Application
    runner: web.AppRunner
    site: web.TCPSite
    host: str
    port: int
    get_url: str
    post_url: str

    async def startup(self) -> None:
        """Build the aiohttp application and register its routes."""
        self.app = web.Application()
        handler = self.handle_response
        self.app.add_routes(
            [
                web.get(self.get_url, handler),
                web.post(self.post_url, handler),
            ]
        )

    async def run(self) -> None:
        """Start serving the application on ``host:port``."""
        runner = web.AppRunner(self.app)
        self.runner = runner
        await runner.setup()
        site = web.TCPSite(runner, self.host, self.port)
        self.site = site
        await site.start()

    async def shutdown(self) -> None:
        """Stop the server and release its resources."""
        runner = self.runner
        await runner.cleanup()

    @abstractmethod
    async def handle_response(self, request: web.Request) -> web.StreamResponse:
        """Handle one incoming HTTP request and return the response."""
class WebSocketServerAdapter(Adapter[EventT, ConfigT], metaclass=ABCMeta):
    """WebSocket server adapter example.

    Serves a WebSocket endpoint at ``url`` on ``host:port`` and passes every
    TEXT message from the connected client to :meth:`handle_ws_response`.
    """

    app: web.Application
    runner: web.AppRunner
    site: web.TCPSite
    # The most recently accepted client connection (assigned in handle_response).
    websocket: web.WebSocketResponse
    host: str
    port: int
    url: str

    async def startup(self) -> None:
        """Build the aiohttp application and register the WebSocket route."""
        self.app = web.Application()
        self.app.add_routes([web.get(self.url, self.handle_response)])

    async def run(self) -> None:
        """Start serving the application on ``host:port``."""
        self.runner = web.AppRunner(self.app)
        await self.runner.setup()
        self.site = web.TCPSite(self.runner, self.host, self.port)
        await self.site.start()

    async def shutdown(self) -> None:
        """Close the client connection (if any) and stop the server."""
        # ``websocket`` is only assigned once a client has connected. Without
        # this guard, shutting down a server that never received a connection
        # raised AttributeError and skipped the site/runner cleanup below.
        websocket: Optional[web.WebSocketResponse] = getattr(self, "websocket", None)
        if websocket is not None:
            await websocket.close()
        await self.site.stop()
        await self.runner.cleanup()

    async def handle_response(self, request: web.Request) -> web.WebSocketResponse:
        """Accept a WebSocket connection and dispatch its messages.

        TEXT messages go to :meth:`handle_ws_response`. An ERROR message ends
        the loop, and other message types are ignored.
        """
        ws = web.WebSocketResponse()
        await ws.prepare(request)
        self.websocket = ws
        msg: aiohttp.WSMessage
        async for msg in ws:
            if msg.type == aiohttp.WSMsgType.TEXT:
                await self.handle_ws_response(msg)
            elif msg.type == aiohttp.WSMsgType.ERROR:
                break
        return ws

    @abstractmethod
    async def handle_ws_response(self, msg: aiohttp.WSMessage) -> None:
        """Handle one TEXT message received from the client."""
class WebSocketAdapter(Adapter[EventT, ConfigT], metaclass=ABCMeta):
    """WebSocket adapter example.

    Supports both a WebSocket client (``adapter_type == "ws"``, forward
    connection to ``ws://host:port/url`` with automatic reconnection) and a
    WebSocket server (``adapter_type == "reverse-ws"``, serving ``url`` on
    ``host:port``).
    """

    # Current connection; either a server-side or a client-side WebSocket.
    websocket: Union[web.WebSocketResponse, aiohttp.ClientWebSocketResponse, None] = (
        None
    )

    # ws (client mode). The ``None`` defaults make the ``is not None`` checks in
    # ``shutdown`` valid even when ``startup``/``run`` never assigned them.
    session: Optional[aiohttp.ClientSession] = None

    # reverse-ws (server mode)
    app: Optional[web.Application] = None
    runner: Optional[web.AppRunner] = None
    site: Optional[web.TCPSite] = None

    # config
    adapter_type: Literal["ws", "reverse-ws"]
    host: str
    port: int
    url: str
    # Seconds to wait before reconnecting in client mode.
    reconnect_interval: int = 3

    async def startup(self) -> None:
        """Create the client session or the server application for ``adapter_type``."""
        if self.adapter_type == "ws":
            self.session = aiohttp.ClientSession()
        elif self.adapter_type == "reverse-ws":
            self.app = web.Application()
            self.app.add_routes([web.get(self.url, self.handle_reverse_ws_response)])
        else:
            # f-string instead of ``+``: a misconfigured, non-string value must
            # still produce this log line rather than raise TypeError.
            logger.error(
                f'Config "adapter_type" must be "ws" or "reverse-ws", not {self.adapter_type}'
            )

    async def run(self) -> None:
        """Run the adapter.

        In client mode, keep (re)connecting until the bot is asked to exit,
        sleeping ``reconnect_interval`` seconds between attempts. In server
        mode, start serving and return.
        """
        if self.adapter_type == "ws":
            while True:
                try:
                    await self.websocket_connect()
                except aiohttp.ClientError as e:
                    self.bot.error_or_exception("WebSocket connection error:", e)
                if self.bot.should_exit.is_set():
                    break
                await asyncio.sleep(self.reconnect_interval)
        elif self.adapter_type == "reverse-ws":
            assert self.app is not None
            self.runner = web.AppRunner(self.app)
            await self.runner.setup()
            self.site = web.TCPSite(self.runner, self.host, self.port)
            await self.site.start()

    async def shutdown(self) -> None:
        """Close the connection and release mode-specific resources."""
        if self.websocket is not None:
            await self.websocket.close()
        if self.adapter_type == "ws":
            if self.session is not None:
                await self.session.close()
        elif self.adapter_type == "reverse-ws":
            if self.site is not None:
                await self.site.stop()
            if self.runner is not None:
                await self.runner.cleanup()

    async def handle_reverse_ws_response(
        self, request: web.Request
    ) -> web.WebSocketResponse:
        """Accept an incoming connection on the aiohttp WebSocket server."""
        self.websocket = web.WebSocketResponse()
        await self.websocket.prepare(request)
        await self.reverse_ws_connection_hook()
        await self.handle_websocket()
        return self.websocket

    async def reverse_ws_connection_hook(self) -> None:
        """Hook called when a reverse WebSocket connection is established."""
        logger.info("WebSocket connected!")

    async def websocket_connect(self) -> None:
        """Open a forward WebSocket connection and process it until it closes."""
        assert self.session is not None
        logger.info("Trying to connect to WebSocket server...")
        async with self.session.ws_connect(
            f"ws://{self.host}:{self.port}{self.url}"
        ) as self.websocket:
            await self.handle_websocket()

    async def handle_websocket(self) -> None:
        """Pass every message on the current connection to ``handle_websocket_msg``."""
        if self.websocket is None or self.websocket.closed:
            return
        async for msg in self.websocket:
            await self.handle_websocket_msg(msg)
        # Only warn about an unexpected close, not one caused by shutdown.
        if not self.bot.should_exit.is_set():
            logger.warning("WebSocket connection closed!")

    @abstractmethod
    async def handle_websocket_msg(self, msg: aiohttp.WSMessage) -> None:
        """Handle one WebSocket message."""
        raise NotImplementedError