diff --git a/matrix_alertbot/alertmanager.py b/matrix_alertbot/alertmanager.py index eb55a15..b2caa5b 100644 --- a/matrix_alertbot/alertmanager.py +++ b/matrix_alertbot/alertmanager.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import datetime from typing import Any, Dict, List @@ -20,6 +22,9 @@ class AlertmanagerClient(AsyncContextManager): self.cache = cache self.session = aiohttp.ClientSession() + def __aenter__(self) -> AlertmanagerClient: + return self + async def __aexit__(self, *args: Any, **kwargs: Any) -> None: await super().__aexit__(*args, **kwargs) await self.close() diff --git a/matrix_alertbot/main.py b/matrix_alertbot/main.py index b75514c..a35c2a3 100644 --- a/matrix_alertbot/main.py +++ b/matrix_alertbot/main.py @@ -5,6 +5,7 @@ import sys from asyncio import TimeoutError from time import sleep +import aiotools from aiohttp import ClientConnectionError, ServerDisconnectedError from diskcache import Cache from nio import ( @@ -52,7 +53,7 @@ def create_matrix_client(config: Config) -> AsyncClient: async def start_matrix_client(cache: Cache, config: Config) -> bool: - async with create_matrix_client(config) as client: + async with aiotools.closing_async(create_matrix_client(config)) as client: # Configure Alertmanager client async with AlertmanagerClient(config.alertmanager_url, cache) as alertmanager: # Set up event callbacks @@ -114,9 +115,9 @@ async def start_matrix_client(cache: Cache, config: Config) -> bool: async def start_webhook_server(cache: Cache, config: Config) -> None: - async with create_matrix_client(config) as client: - webhook_server = Webhook(client, cache, config) - await webhook_server.start() + async with aiotools.closing_async(create_matrix_client(config)) as client: + async with Webhook(client, cache, config) as webhook_server: + await webhook_server.start() def main() -> None: diff --git a/matrix_alertbot/webhook.py b/matrix_alertbot/webhook.py index b2b353c..f2254e4 100644 --- a/matrix_alertbot/webhook.py +++ b/matrix_alertbot/webhook.py @@ -1,6 +1,10 @@ +from __future__ import annotations + import logging +from typing import Any from aiohttp import web, web_request +from aiotools import AsyncContextManager from diskcache import Cache from nio import AsyncClient, SendRetryError @@ -47,7 +51,7 @@ async def create_alert(request: web_request.Request) -> web.Response: return web.Response(status=200) -class Webhook: +class Webhook(AsyncContextManager): def __init__(self, client: AsyncClient, cache: Cache, config: Config) -> None: self.app = web.Application(logger=logger) self.app["client"] = client @@ -61,6 +65,13 @@ class Webhook: self.port = config.port self.socket = config.socket + def __aenter__(self) -> Webhook: + return self + + async def __aexit__(self, *args: Any, **kwargs: Any) -> None: + await super().__aexit__(*args, **kwargs) + await self.close() + async def start(self) -> None: await self.runner.setup()