104 lines
3.3 KiB
Python
104 lines
3.3 KiB
Python
from __future__ import annotations
|
|
|
|
import logging
|
|
|
|
import prometheus_client
|
|
from aiohttp import ClientError, web, web_request
|
|
from aiohttp_prometheus_exporter.handler import metrics
|
|
from aiohttp_prometheus_exporter.middleware import prometheus_middleware_factory
|
|
from diskcache import Cache
|
|
from nio import AsyncClient, LocalProtocolError
|
|
|
|
from matrix_alertbot.alert import Alert
|
|
from matrix_alertbot.chat_functions import send_text_to_room
|
|
from matrix_alertbot.config import Config
|
|
|
|
# Module-level logger; it is also passed to the aiohttp application created
# in Webhook.__init__.
logger: logging.Logger = logging.getLogger(__name__)

# Route table that collects the request handlers declared with
# ``@routes.<method>(...)`` in this module; Webhook registers it on its app.
routes: web.RouteTableDef = web.RouteTableDef()
|
|
|
|
|
|
@routes.post("/alert")
async def create_alert(request: web_request.Request) -> web.Response:
    """Forward the alerts of a webhook payload to the configured Matrix room.

    The JSON body is expected to be an object with a non-empty ``alerts``
    list (presumably an Alertmanager webhook payload — confirm with senders).
    Every alert is parsed with ``Alert.from_dict`` and must carry a
    ``fingerprint``.

    The rendered alerts are joined into one message:
    - the plaintext parts are separated by newlines;
    - the HTML parts are separated by ``<br/>``.

    That message is sent to ``config.room_id``. On success, the tuple of
    fingerprints is cached under the ID of the sent event, with tag
    ``"event"`` and expiry ``config.cache_expire_time``.

    Returns:
        200 on success.
        400 when the body is not valid JSON or does not have the expected
        structure. Nothing is sent in that case.
        500 when sending fails with ``LocalProtocolError`` or ``ClientError``.
    """
    try:
        data = await request.json()
    except ValueError:
        # json.JSONDecodeError and UnicodeDecodeError are both ValueErrors:
        # a malformed body is a client error, not an internal server error.
        return web.Response(status=400, body="Invalid JSON data.")
    logger.info("Received alert: %s", data)

    client: AsyncClient = request.app["client"]
    config: Config = request.app["config"]
    cache: Cache = request.app["cache"]

    # A non-object body (e.g. a JSON list or string) cannot carry an "alerts"
    # key; without this check `in` would do a substring/element test and
    # data["alerts"] would then fail with a 500.
    if not isinstance(data, dict) or "alerts" not in data:
        return web.Response(status=400, body="Data must contain 'alerts' key.")

    alerts_data = data["alerts"]
    if not isinstance(alerts_data, list):
        return web.Response(status=400, body="Alerts must be a list.")
    if not alerts_data:
        return web.Response(status=400, body="Alerts cannot be empty.")

    plaintext_parts = []
    html_parts = []
    fingerprints = []
    for alert_data in alerts_data:
        # Validate every alert, including the fingerprint needed for the cache
        # entry below, before anything is sent to the room.
        if not isinstance(alert_data, dict):
            return web.Response(status=400, body=f"Invalid alert: {alert_data}.")
        try:
            alert = Alert.from_dict(alert_data)
            fingerprint = alert_data["fingerprint"]
        except KeyError:
            return web.Response(status=400, body=f"Invalid alert: {alert_data}.")

        plaintext_parts.append(alert.plaintext())
        html_parts.append(alert.html())
        fingerprints.append(fingerprint)

    plaintext = "\n".join(plaintext_parts)
    html = "<br/>\n".join(html_parts)

    try:
        event = await send_text_to_room(
            client, config.room_id, plaintext, html, notice=False
        )
    except (LocalProtocolError, ClientError) as e:
        logger.error(e)
        return web.Response(
            status=500, body="An error occured when sending alerts to Matrix room."
        )

    cache.set(
        event.event_id,
        tuple(fingerprints),
        expire=config.cache_expire_time,
        tag="event",
    )
    return web.Response(status=200)
|
|
|
|
|
|
class Webhook:
    """aiohttp application exposing the alert webhook and Prometheus metrics.

    The application registers:
    - the handlers collected on ``routes`` (e.g. ``POST /alert``);
    - ``GET /metrics``, which serves the request metrics recorded by the
      Prometheus middleware.

    The Matrix client, cache and config are stored on the application so that
    handlers can reach them through ``request.app``.
    """

    def __init__(self, client: AsyncClient, cache: Cache, config: Config) -> None:
        """Build the application and its runner.

        The server is not started here; call ``start()`` for that.
        """
        self.app = web.Application(logger=logger)
        self.app["client"] = client
        self.app["config"] = config
        self.app["cache"] = cache
        self.app.add_routes(routes)

        prometheus_registry = prometheus_client.CollectorRegistry(auto_describe=True)
        self.app.middlewares.append(
            prometheus_middleware_factory(registry=prometheus_registry)
        )
        # The middleware records into this dedicated registry, so /metrics
        # must export the same one. The handler's default registry is the
        # global one, which would not contain these request metrics.
        self.app.router.add_get("/metrics", metrics(registry=prometheus_registry))

        self.runner = web.AppRunner(self.app)

        self.config = config
        self.address = config.address
        self.port = config.port
        self.socket = config.socket

    async def start(self) -> None:
        """Set up the runner and start listening.

        Listens on TCP ``address:port`` when both are set. Otherwise it
        listens on the unix socket path ``socket``.

        Raises:
            ValueError: if neither an address/port pair nor a socket is
                configured.
        """
        use_tcp = bool(self.address and self.port)
        # Check before runner.setup() so nothing is left half-initialised;
        # previously this case crashed with UnboundLocalError on `site`.
        if not use_tcp and not self.socket:
            raise ValueError(
                "Webhook needs either an address and a port, or a unix socket"
            )

        await self.runner.setup()

        site: web.BaseSite
        if use_tcp:
            site = web.TCPSite(self.runner, self.address, self.port)
            listen_url = f"{self.address}:{self.port}"
        else:
            site = web.UnixSite(self.runner, self.socket)
            listen_url = f"unix://{self.socket}"

        await site.start()
        # Logged only once the site is actually accepting connections.
        logger.info(f"Listenning on {listen_url}")

    async def close(self) -> None:
        """Stop the server and release the runner's resources."""
        await self.runner.cleanup()
|