diff --git a/my_project_name/config.py b/my_project_name/config.py index 292895f..1def0d2 100644 --- a/my_project_name/config.py +++ b/my_project_name/config.py @@ -52,9 +52,6 @@ class Config(object): logger.addHandler(handler) # Storage setup - self.database_filepath = self._get_cfg( - ["storage", "database_filepath"], required=True - ) self.store_filepath = self._get_cfg( ["storage", "store_filepath"], required=True ) @@ -68,6 +65,23 @@ class Config(object): f"storage.store_filepath '{self.store_filepath}' is not a directory" ) + # Database setup + database_path = self._get_cfg(["storage", "database"], required=True) + + # Support both SQLite and Postgres backends + # Determine which one the user intends + sqlite_scheme = "sqlite://" + postgres_scheme = "postgres://" + if database_path.startswith(sqlite_scheme): + self.database = { + "type": "sqlite", + "connection_string": database_path[len(sqlite_scheme) :], + } + elif database_path.startswith(postgres_scheme): + self.database = {"type": "postgres", "connection_string": database_path} + else: + raise ConfigError("Invalid connection string for storage.database") + # Matrix bot account setup self.user_id = self._get_cfg(["matrix", "user_id"], required=True) if not re.match("@.*:.*", self.user_id): diff --git a/my_project_name/main.py b/my_project_name/main.py index 1081f7a..59141c9 100644 --- a/my_project_name/main.py +++ b/my_project_name/main.py @@ -32,7 +32,7 @@ async def main(): config = Config(config_filepath) # Configure the database - store = Storage(config.database_filepath) + store = Storage(config) # Configuration options for the AsyncClient client_config = AsyncClientConfig( diff --git a/my_project_name/storage.py b/my_project_name/storage.py index eeb5e38..1e84ea0 100644 --- a/my_project_name/storage.py +++ b/my_project_name/storage.py @@ -1,50 +1,117 @@ import logging -import os.path -import sqlite3 -latest_db_version = 0 +# The latest migration version of the database. +# +# Database migrations are applied starting from the number specified in the database's +# `migration_version` table + 1 (or from 0 if this table does not yet exist) up until +# the version specified here. +# +# When a migration is performed, the `migration_version` table should be incremented. +latest_migration_version = 0 logger = logging.getLogger(__name__) class Storage(object): - def __init__(self, db_path): + def __init__(self, database_config): """Setup the database Runs an initial setup or migrations depending on whether a database file has already been created Args: - db_path (str): The name of the database file + database_config: a dictionary containing the following keys: + * type: A string, one of "sqlite" or "postgres" + * connection_string: A string, featuring a connection string that + be fed to each respective db library's `connect` method """ - self.db_path = db_path + self.conn = self._get_database_connection( + database_config["type"], database_config["connection_string"] + ) + self.cursor = self.conn.cursor() + self.db_type = database_config["type"] - # Check if a database has already been connected - if os.path.isfile(self.db_path): - self._run_migrations() - else: + # Try to check the current migration version + migration_level = 0 + try: + self._execute("SELECT version FROM migration_version") + row = self.cursor.fetchone() + migration_level = row[0] + except Exception: self._initial_setup() + finally: + if migration_level < latest_migration_version: + self._run_migrations(migration_level) + + logger.info(f"Database initialization of type '{self.db_type}' complete") + + def _get_database_connection(self, database_type: str, connection_string: str): + if database_type == "sqlite": + import sqlite3 + + # Initialize a connection to the database, with autocommit on + return sqlite3.connect(connection_string, isolation_level=None) + elif database_type == "postgres": + import psycopg2 + + conn = psycopg2.connect(connection_string) + + # Autocommit on + conn.set_isolation_level(0) + + return conn def _initial_setup(self): """Initial setup of the database""" logger.info("Performing initial database setup...") - # Initialize a connection to the database - self.conn = sqlite3.connect(self.db_path) - self.cursor = self.conn.cursor() - - # Sync token table - self.cursor.execute( - "CREATE TABLE sync_token (" - "dedupe_id INTEGER PRIMARY KEY, " - "token TEXT NOT NULL" - ")" + # Set up the migration_version table + self._execute( + """ + CREATE TABLE migration_version ( + version INTEGER PRIMARY KEY + ) + """ ) + # Initially set the migration version to 0 + self._execute( + """ + INSERT INTO migration_version ( + version + ) VALUES (?) + """, + (0,), + ) + + # Set up any other necessary database tables here + logger.info("Database setup complete") - def _run_migrations(self): - """Execute database migrations""" - # Initialize a connection to the database - self.conn = sqlite3.connect(self.db_path) - self.cursor = self.conn.cursor() + def _run_migrations(self, current_migration_version: int): + """Execute database migrations. Migrates the database to the + `latest_migration_version` + + Args: + current_migration_version: The migration version that the database is + currently at + """ + logger.debug("Checking for necessary database migrations...") + + # if current_migration_version < 1: + # logger.info("Migrating the database from v0 to v1...") + # + # # Add new table, delete old ones, etc. + # + # # Update the stored migration version + # self._execute("UPDATE migration_version SET version = 1") + # + # logger.info("Database migrated to v1") + + def _execute(self, *args): + """A wrapper around cursor.execute that transforms placeholder ?'s to %s for postgres + """ + if self.db_type == "postgres": + self.cursor.execute(args[0].replace("?", "%s"), *args[1:]) + else: + self.cursor.execute(*args) diff --git a/sample.config.yaml b/sample.config.yaml index 1606b7f..969b3a5 100644 --- a/sample.config.yaml +++ b/sample.config.yaml @@ -17,14 +17,18 @@ matrix: # If this device ID already exists, messages will be dropped silently in encrypted rooms device_id: ABCDEFGHIJ # What to name the logged in device - device_name: nio-template + device_name: my-project-name storage: - # The path to the database - database_filepath: "bot.db" + # The database connection string + # For SQLite3, this would look like: + # database: "sqlite://bot.db" + # For Postgres, this would look like: + # database: "postgres://username:password@localhost/dbname?sslmode=disable" + database: "sqlite://bot.db" # The path to a directory for internal bot storage # containing encryption keys, sync tokens, etc. - store_filepath: "./store" + store_path: "./store" # Logging setup logging: