From c8846da599ab572515b42b17f1248f2dbedf9b6d Mon Sep 17 00:00:00 2001 From: Andrew Morgan Date: Sun, 23 Feb 2020 23:17:25 +0000 Subject: [PATCH] Helper method for loading config file options --- config.py | 71 ++++++++++++++++++++++++++++++++---------------- requirements.txt | 2 +- 2 files changed, 48 insertions(+), 25 deletions(-) diff --git a/config.py b/config.py index bc63f83..5aa6cf7 100644 --- a/config.py +++ b/config.py @@ -3,6 +3,7 @@ import re import os import yaml import sys +from typing import List, Any from errors import ConfigError logger = logging.getLogger() @@ -19,56 +20,78 @@ class Config(object): # Load in the config file at the given filepath with open(filepath) as file_stream: - config = yaml.full_load(file_stream.read()) + self.config = yaml.safe_load(file_stream.read()) # Logging setup formatter = logging.Formatter('%(asctime)s | %(name)s [%(levelname)s] %(message)s') - log_dict = config.get("logging", {}) - log_level = log_dict.get("level", "INFO") + log_level = self._get_cfg(["logging", "level"], default="INFO") logger.setLevel(log_level) - file_logging = log_dict.get("file_logging", {}) - file_logging_enabled = file_logging.get("enabled", False) - file_logging_filepath = file_logging.get("filepath", "bot.log") + file_logging_enabled = self._get_cfg(["logging", "file_logging", "enabled"], default=False) + file_logging_filepath = self._get_cfg(["logging", "file_logging", "filepath"], default="bot.log") if file_logging_enabled: handler = logging.FileHandler(file_logging_filepath) handler.setFormatter(formatter) logger.addHandler(handler) - console_logging = log_dict.get("console_logging", {}) - console_logging_enabled = console_logging.get("enabled", True) + console_logging_enabled = self._get_cfg(["logging", "console_logging", "enabled"], default=True) if console_logging_enabled: handler = logging.StreamHandler(sys.stdout) handler.setFormatter(formatter) logger.addHandler(handler) # Database setup - database_dict = config.get("database", {}) - self.database_filepath = database_dict.get("filepath") + self.database_filepath = self._get_cfg(["database", "filepath"], required=True) # Matrix bot account setup - matrix = config.get("matrix", {}) - - self.user_id = matrix.get("user_id") - if not self.user_id: - raise ConfigError("matrix.user_id is a required field") - elif not re.match("@.*:.*", self.user_id): + self.user_id = self._get_cfg(["matrix", "user_id"], required=True) + if not re.match("@.*:.*", self.user_id): raise ConfigError("matrix.user_id must be in the form @name:domain") - self.access_token = matrix.get("access_token") - if not self.access_token: - raise ConfigError("matrix.access_token is a required field") + self.access_token = self._get_cfg(["matrix", "access_token"], required=True) - self.device_id = matrix.get("device_id") + self.device_id = self._get_cfg(["matrix", "device_id"]) if not self.device_id: logger.warning( "Config option matrix.device_id is not provided, which means " "that end-to-end encryption won't work correctly" ) - self.homeserver_url = matrix.get("homeserver_url") - if not self.homeserver_url: - raise ConfigError("matrix.homeserver_url is a required field") + self.homeserver_url = self._get_cfg(["matrix", "homeserver_url"], required=True) - self.command_prefix = config.get("command_prefix", "!c") + " " + self.command_prefix = self._get_cfg(["command_prefix"], default="!c") + " " + + def _get_cfg( + self, + path: List[str], + default: Any = None, + required: bool = True, + ) -> Any: + """Get a config option from a path and option name, specifying whether it is + required. + + Raises: + ConfigError: If required is specified and the object is not found + (and there is no default value provided), this error will be raised + """ + path_str = '.'.join(path) + + # Sift through the the config until we reach our option + config = self.config + for name in path: + print("Name is", name) + config = config.get(name) + print("Config is", config) + + # If at any point we don't get our expected option... + if config is None: + # Raise an error if it was required + if required or not default: + raise ConfigError(f"Config option {path_str} is required") + + # or return the default value + return default + + # We found the option. Return it + return config diff --git a/requirements.txt b/requirements.txt index fa9ef1a..26f4dcf 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,3 @@ -matrix-nio>=0.6 +matrix-nio>=0.8.0 Markdown>=3.1.1 PyYAML>=5.1.2