Add support for postgres as a db backend

This commit is contained in:
Andrew Morgan 2020-08-16 15:51:59 +01:00
parent dc3b83694b
commit ec25fead72
4 changed files with 118 additions and 33 deletions

View file

@ -52,9 +52,6 @@ class Config(object):
logger.addHandler(handler) logger.addHandler(handler)
# Storage setup # Storage setup
self.database_filepath = self._get_cfg(
["storage", "database_filepath"], required=True
)
self.store_filepath = self._get_cfg( self.store_filepath = self._get_cfg(
["storage", "store_filepath"], required=True ["storage", "store_filepath"], required=True
) )
@ -68,6 +65,23 @@ class Config(object):
f"storage.store_filepath '{self.store_filepath}' is not a directory" f"storage.store_filepath '{self.store_filepath}' is not a directory"
) )
# Database setup
database_path = self._get_cfg(["storage", "database"], required=True)
# Support both SQLite and Postgres backends
# Determine which one the user intends
sqlite_scheme = "sqlite://"
postgres_scheme = "postgres://"
if database_path.startswith(sqlite_scheme):
self.database = {
"type": "sqlite",
"connection_string": database_path[len(sqlite_scheme) :],
}
elif database_path.startswith(postgres_scheme):
self.database = {"type": "postgres", "connection_string": database_path}
else:
raise ConfigError("Invalid connection string for storage.database")
# Matrix bot account setup # Matrix bot account setup
self.user_id = self._get_cfg(["matrix", "user_id"], required=True) self.user_id = self._get_cfg(["matrix", "user_id"], required=True)
if not re.match("@.*:.*", self.user_id): if not re.match("@.*:.*", self.user_id):

View file

@ -32,7 +32,7 @@ async def main():
config = Config(config_filepath) config = Config(config_filepath)
# Configure the database # Configure the database
store = Storage(config.database_filepath) store = Storage(config)
# Configuration options for the AsyncClient # Configuration options for the AsyncClient
client_config = AsyncClientConfig( client_config = AsyncClientConfig(

View file

@ -1,50 +1,117 @@
import logging import logging
import os.path
import sqlite3
latest_db_version = 0 # The latest migration version of the database.
#
# Database migrations are applied starting from the number specified in the database's
# `migration_version` table + 1 (or from 0 if this table does not yet exist) up until
# the version specified here.
#
# When a migration is performed, the `migration_version` table should be incremented.
latest_migration_version = 0
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class Storage(object): class Storage(object):
def __init__(self, db_path): def __init__(self, database_config):
"""Setup the database """Setup the database
Runs an initial setup or migrations depending on whether a database file has already Runs an initial setup or migrations depending on whether a database file has already
been created been created
Args: Args:
db_path (str): The name of the database file database_config: a dictionary containing the following keys:
* type: A string, one of "sqlite" or "postgres"
* connection_string: A string, featuring a connection string that
be fed to each respective db library's `connect` method
""" """
self.db_path = db_path self.conn = self._get_database_connection(
database_config["type"], database_config["connection_string"]
)
self.cursor = self.conn.cursor()
self.db_type = database_config["type"]
# Check if a database has already been connected # Try to check the current migration version
if os.path.isfile(self.db_path): migration_level = 0
self._run_migrations() try:
else: self._execute("SELECT version FROM migration_version")
row = self.cursor.fetchone()
migration_level = row[0]
except Exception:
self._initial_setup() self._initial_setup()
finally:
if migration_level < latest_migration_version:
self._run_migrations(migration_level)
logger.info(f"Database initialization of type '{self.db_type}' complete")
def _get_database_connection(self, database_type: str, connection_string: str):
if database_type == "sqlite":
import sqlite3
# Initialize a connection to the database, with autocommit on
return sqlite3.connect(connection_string, isolation_level=None)
elif database_type == "postgres":
import psycopg2
conn = psycopg2.connect(connection_string)
# Autocommit on
conn.set_isolation_level(0)
return conn
def _initial_setup(self): def _initial_setup(self):
"""Initial setup of the database""" """Initial setup of the database"""
logger.info("Performing initial database setup...") logger.info("Performing initial database setup...")
# Initialize a connection to the database # Set up the migration_version table
self.conn = sqlite3.connect(self.db_path) self._execute(
self.cursor = self.conn.cursor() """
CREATE TABLE migration_version (
# Sync token table version INTEGER PRIMARY KEY
self.cursor.execute(
"CREATE TABLE sync_token ("
"dedupe_id INTEGER PRIMARY KEY, "
"token TEXT NOT NULL"
")"
) )
"""
)
# Initially set the migration version to 0
self._execute(
"""
INSERT INTO migration_version (
version
) VALUES (?)
""",
(0,),
)
# Set up any other necessary database tables here
logger.info("Database setup complete") logger.info("Database setup complete")
def _run_migrations(self): def _run_migrations(self, current_migration_version: int):
"""Execute database migrations""" """Execute database migrations. Migrates the database to the
# Initialize a connection to the database `latest_migration_version`
self.conn = sqlite3.connect(self.db_path)
self.cursor = self.conn.cursor() Args:
current_migration_version: The migration version that the database is
currently at
"""
logger.debug("Checking for necessary database migrations...")
# if current_migration_version < 1:
# logger.info("Migrating the database from v0 to v1...")
#
# # Add new table, delete old ones, etc.
#
# # Update the stored migration version
# self._execute("UPDATE migration_version SET version = 1")
#
# logger.info("Database migrated to v1")
def _execute(self, *args):
"""A wrapper around cursor.execute that transforms placeholder ?'s to %s for postgres
"""
if self.db_type == "postgres":
self.cursor.execute(args[0].replace("?", "%s"), *args[1:])
else:
self.cursor.execute(*args)

View file

@ -17,14 +17,18 @@ matrix:
# If this device ID already exists, messages will be dropped silently in encrypted rooms # If this device ID already exists, messages will be dropped silently in encrypted rooms
device_id: ABCDEFGHIJ device_id: ABCDEFGHIJ
# What to name the logged in device # What to name the logged in device
device_name: nio-template device_name: my-project-name
storage: storage:
# The path to the database # The database connection string
database_filepath: "bot.db" # For SQLite3, this would look like:
# database: "sqlite://bot.db"
# For Postgres, this would look like:
# database: "postgres://username:password@localhost/dbname?sslmode=disable"
database: "sqlite://bot.db"
# The path to a directory for internal bot storage # The path to a directory for internal bot storage
# containing encryption keys, sync tokens, etc. # containing encryption keys, sync tokens, etc.
store_filepath: "./store" store_path: "./store"
# Logging setup # Logging setup
logging: logging: