This commit is contained in:
2021-08-02 22:12:33 +02:00
parent fed29ebbbe
commit 7b3498bd27
57 changed files with 2778 additions and 0 deletions

View File

@@ -0,0 +1,8 @@
import sys
# Check that we're not running on an unsupported Python version.
if sys.version_info < (3, 5):
print("my_project_name requires Python 3.5 or above.")
sys.exit(1)
__version__ = "0.0.1"

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@@ -0,0 +1,120 @@
from nio import AsyncClient, MatrixRoom, RoomMessageText
from my_project_name.chat_functions import react_to_event, send_text_to_room
from my_project_name.config import Config
from my_project_name.storage import Storage
from my_project_name.caldav_handler import CaldavHandler
class Command:
def __init__(
self,
client: AsyncClient,
store: Storage,
config: Config,
command: str,
room: MatrixRoom,
event: RoomMessageText,
):
"""A command made by a user.
Args:
client: The client to communicate to matrix with.
store: Bot storage.
config: Bot configuration parameters.
command: The command and arguments.
room: The room the command was sent in.
event: The event describing the command.
"""
self.client = client
self.store = store
self.config = config
self.command = command
self.room = room
self.event = event
self.args = self.command.split()[1:]
async def process(self):
"""Process the command"""
#if self.command.startswith("echo"):
# await self._echo()
if self.command.startswith("react"):
await self._react()
elif self.command.startswith("help"):
await self._show_help()
elif self.command.startswith("today"):
await self._show_today()
elif self.command.startswith("week"):
await self._show_week()
elif self.command.startswith("month"):
await self._show_month()
#else:
# await self._unknown_command()
async def _show_today(self):
handler = CaldavHandler()
response = handler.print_today()
if len(response) == 0:
response = "today is nothing planned yet. riot or read theory"
await send_text_to_room(self.client, self.room.room_id, response)
async def _show_week(self):
handler = CaldavHandler()
response = handler.print_week()
await send_text_to_room(self.client, self.room.room_id, response)
async def _show_month(self):
handler = CaldavHandler()
response = handler.print_month()
await send_text_to_room(self.client, self.room.room_id, response, markdown_convert=True)
async def _echo(self):
"""Echo back the command's arguments"""
response = " ".join(self.args)
await send_text_to_room(self.client, self.room.room_id, response)
async def _react(self):
"""Make the bot react to the command message"""
# React with a start emoji
reaction = ""
await react_to_event(
self.client, self.room.room_id, self.event.event_id, reaction
)
# React with some generic text
reaction = "(A)"
await react_to_event(
self.client, self.room.room_id, self.event.event_id, reaction
)
async def _show_help(self):
"""Show the help text"""
if not self.args:
text = (
"Hello, I am kallauser MC's (A)wesome calendar bot <3! Use `help commands` to view "
"available commands.\n"
"Use `help rules` to view the rules"
)
await send_text_to_room(self.client, self.room.room_id, text)
return
topic = self.args[0]
if topic == "rules":
text = "be nice to each other."
elif topic == "commands":
text = "Available commands: today, week, month"
else:
text = "I dont know what you are talking about.."
await send_text_to_room(self.client, self.room.room_id, text)
async def _unknown_command(self):
await send_text_to_room(
self.client,
self.room.room_id,
f"Unknown command '{self.command}'. Try the 'help' command for more information.",
)

View File

@@ -0,0 +1,105 @@
import sys
import time
import logging
import collections
from os.path import exists
import json
import dateutil.rrule as rrule
import caldav
import pytz
from icalendar import Calendar, Event
import datetime
import caldav
class CaldavHandler:
def get_config(self, path):
with open("./config.json") as f:
return json.load(f)
def __init__(self):
self._config_path = "./config.json"
if not exists(self._config_path):
print("No config file found. Aborting.")
self._config = self.get_config(self._config_path)
self._caldavclient = caldav.DAVClient(self._config["caldav"]["url"],
username=self._config["caldav"]["username"],
password=self._config["caldav"]["password"])
def get_event_map(self, events, time_span):
result = {}
for event in events:
event.load()
e = event.instance.vevent
list_of_occurences = []
if e.getChildValue('rrule') == None:
list_of_occurences.append(e.getChildValue('dtstart'))
else:
#recurring events only return with the date of the first recurring event ever created
#we have to use rrule manually to expand the dates to display them correctly
rule = rrule.rrulestr(e.getChildValue('rrule'), dtstart=e.getChildValue('dtstart'))
list_of_occurences = rule.between(datetime.datetime.now(datetime.timezone.utc),
datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(days=time_span),
inc=True)
for datetime_of_event in list_of_occurences:
datestr = datetime_of_event.strftime("%x")
eventstr = str( "(" + e.dtstart.value.strftime("%H:%M") + " - " + e.dtend.value.strftime("%H:%M") + ") " + e.summary.value)
if datestr in result:
result[datestr].append(eventstr)
else:
result[datestr] = [ eventstr ]
#sort start times
for key in result:
result[key].sort()
od = collections.OrderedDict(sorted(result.items()))
return od
def event_map_to_string(self, event_map):
result = ""
for k, v in event_map.items():
dt_string = k
format = "%x"
dt_object = datetime.datetime.strptime(dt_string, format)
result += "##### " + dt_object.strftime("%A, %d. of %B") + ":\n"
for event in v:
result += "* " + event + "\n"
print(result)
return result
def event_to_string(self, event):
event.load()
e = event.instance.vevent
datestr = e.dtstart.value.strftime("%X")
return str( "(" + e.dtstart.value.strftime("%a, %-d. %b - %H:%M") + " - " + e.dtend.value.strftime("%H:%M") + ") " + e.summary.value)
def get_events(self, start_time, end_time):
cal = self._caldavclient.principal().calendars()
for ca in cal:
events = ca.date_search(start=start_time, end=end_time, expand=True)
return events
def send_events(self, events, time_span):
return self.event_map_to_string(self.get_event_map(events, time_span))
def print_month(self):
events = self.get_events(datetime.date.today(), datetime.date.today() + datetime.timedelta(days=30))
return self.send_events(events, 30)
def print_week(self):
events = self.get_events(datetime.date.today(), datetime.date.today() + datetime.timedelta(days=7))
return self.send_events(events, 7)
def print_today(self):
events = self.get_events(datetime.date.today(), datetime.date.today() + datetime.timedelta(days=1))
return self.send_events(events, 1)

View File

@@ -0,0 +1,198 @@
import logging
from nio import (
AsyncClient,
InviteMemberEvent,
JoinError,
MatrixRoom,
MegolmEvent,
RoomGetEventError,
RoomMessageText,
UnknownEvent,
)
from my_project_name.bot_commands import Command
from my_project_name.chat_functions import make_pill, react_to_event, send_text_to_room
from my_project_name.config import Config
from my_project_name.message_responses import Message
from my_project_name.storage import Storage
logger = logging.getLogger(__name__)
class Callbacks:
def __init__(self, client: AsyncClient, store: Storage, config: Config):
"""
Args:
client: nio client used to interact with matrix.
store: Bot storage.
config: Bot configuration parameters.
"""
self.client = client
self.store = store
self.config = config
self.command_prefix = config.command_prefix
async def message(self, room: MatrixRoom, event: RoomMessageText) -> None:
"""Callback for when a message event is received
Args:
room: The room the event came from.
event: The event defining the message.
"""
# Extract the message text
msg = event.body
# Ignore messages from ourselves
if event.sender == self.client.user:
return
logger.debug(
f"Bot message received for room {room.display_name} | "
f"{room.user_name(event.sender)}: {msg}"
)
# Process as message if in a public room without command prefix
has_command_prefix = msg.startswith(self.command_prefix)
# room.is_group is often a DM, but not always.
# room.is_group does not allow room aliases
# room.member_count > 2 ... we assume a public room
# room.member_count <= 2 ... we assume a DM
if not has_command_prefix and room.member_count > 2:
# General message listener
message = Message(self.client, self.store, self.config, msg, room, event)
await message.process()
return
# Otherwise if this is in a 1-1 with the bot or features a command prefix,
# treat it as a command
if has_command_prefix:
# Remove the command prefix
msg = msg[len(self.command_prefix) :]
command = Command(self.client, self.store, self.config, msg, room, event)
await command.process()
async def invite(self, room: MatrixRoom, event: InviteMemberEvent) -> None:
"""Callback for when an invite is received. Join the room specified in the invite.
Args:
room: The room that we are invited to.
event: The invite event.
"""
logger.debug(f"Got invite to {room.room_id} from {event.sender}.")
# Attempt to join 3 times before giving up
for attempt in range(3):
result = await self.client.join(room.room_id)
if type(result) == JoinError:
logger.error(
f"Error joining room {room.room_id} (attempt %d): %s",
attempt,
result.message,
)
else:
break
else:
logger.error("Unable to join room: %s", room.room_id)
# Successfully joined room
logger.info(f"Joined {room.room_id}")
async def _reaction(
self, room: MatrixRoom, event: UnknownEvent, reacted_to_id: str
) -> None:
"""A reaction was sent to one of our messages. Let's send a reply acknowledging it.
Args:
room: The room the reaction was sent in.
event: The reaction event.
reacted_to_id: The event ID that the reaction points to.
"""
logger.debug(f"Got reaction to {room.room_id} from {event.sender}.")
# Get the original event that was reacted to
event_response = await self.client.room_get_event(room.room_id, reacted_to_id)
if isinstance(event_response, RoomGetEventError):
logger.warning(
"Error getting event that was reacted to (%s)", reacted_to_id
)
return
reacted_to_event = event_response.event
# Only acknowledge reactions to events that we sent
if reacted_to_event.sender != self.config.user_id:
return
# Send a message acknowledging the reaction
reaction_sender_pill = make_pill(event.sender)
reaction_content = (
event.source.get("content", {}).get("m.relates_to", {}).get("key")
)
message = (
f"{reaction_sender_pill} reacted to this event with `{reaction_content}`!"
)
await send_text_to_room(
self.client,
room.room_id,
message,
reply_to_event_id=reacted_to_id,
)
async def decryption_failure(self, room: MatrixRoom, event: MegolmEvent) -> None:
"""Callback for when an event fails to decrypt. Inform the user.
Args:
room: The room that the event that we were unable to decrypt is in.
event: The encrypted event that we were unable to decrypt.
"""
logger.error(
f"Failed to decrypt event '{event.event_id}' in room '{room.room_id}'!"
f"\n\n"
f"Tip: try using a different device ID in your config file and restart."
f"\n\n"
f"If all else fails, delete your store directory and let the bot recreate "
f"it (your reminders will NOT be deleted, but the bot may respond to existing "
f"commands a second time)."
)
red_x_and_lock_emoji = "❌ 🔐"
# React to the undecryptable event with some emoji
await react_to_event(
self.client,
room.room_id,
event.event_id,
red_x_and_lock_emoji,
)
async def unknown(self, room: MatrixRoom, event: UnknownEvent) -> None:
"""Callback for when an event with a type that is unknown to matrix-nio is received.
Currently this is used for reaction events, which are not yet part of a released
matrix spec (and are thus unknown to nio).
Args:
room: The room the reaction was sent in.
event: The event itself.
"""
if event.type == "m.reaction":
# Get the ID of the event this was a reaction to
relation_dict = event.source.get("content", {}).get("m.relates_to", {})
reacted_to = relation_dict.get("event_id")
if reacted_to and relation_dict.get("rel_type") == "m.annotation":
await self._reaction(room, event, reacted_to)
return
logger.debug(
f"Got unknown event with type to {event.type} from {event.sender} in {room.room_id}."
)

View File

@@ -0,0 +1,154 @@
import logging
from typing import Optional, Union
from markdown import markdown
from nio import (
AsyncClient,
ErrorResponse,
MatrixRoom,
MegolmEvent,
Response,
RoomSendResponse,
SendRetryError,
)
logger = logging.getLogger(__name__)
async def send_text_to_room(
client: AsyncClient,
room_id: str,
message: str,
notice: bool = True,
markdown_convert: bool = True,
reply_to_event_id: Optional[str] = None,
) -> Union[RoomSendResponse, ErrorResponse]:
"""Send text to a matrix room.
Args:
client: The client to communicate to matrix with.
room_id: The ID of the room to send the message to.
message: The message content.
notice: Whether the message should be sent with an "m.notice" message type
(will not ping users).
markdown_convert: Whether to convert the message content to markdown.
Defaults to true.
reply_to_event_id: Whether this message is a reply to another event. The event
ID this is message is a reply to.
Returns:
A RoomSendResponse if the request was successful, else an ErrorResponse.
"""
# Determine whether to ping room members or not
msgtype = "m.notice" if notice else "m.text"
content = {
"msgtype": msgtype,
"format": "org.matrix.custom.html",
"body": message,
}
if markdown_convert:
content["formatted_body"] = markdown(message)
if reply_to_event_id:
content["m.relates_to"] = {"m.in_reply_to": {"event_id": reply_to_event_id}}
try:
return await client.room_send(
room_id,
"m.room.message",
content,
ignore_unverified_devices=True,
)
except SendRetryError:
logger.exception(f"Unable to send message response to {room_id}")
def make_pill(user_id: str, displayname: str = None) -> str:
"""Convert a user ID (and optionally a display name) to a formatted user 'pill'
Args:
user_id: The MXID of the user.
displayname: An optional displayname. Clients like Element will figure out the
correct display name no matter what, but other clients may not. If not
provided, the MXID will be used instead.
Returns:
The formatted user pill.
"""
if not displayname:
# Use the user ID as the displayname if not provided
displayname = user_id
return f'<a href="https://matrix.to/#/{user_id}">{displayname}</a>'
async def react_to_event(
client: AsyncClient,
room_id: str,
event_id: str,
reaction_text: str,
) -> Union[Response, ErrorResponse]:
"""Reacts to a given event in a room with the given reaction text
Args:
client: The client to communicate to matrix with.
room_id: The ID of the room to send the message to.
event_id: The ID of the event to react to.
reaction_text: The string to react with. Can also be (one or more) emoji characters.
Returns:
A nio.Response or nio.ErrorResponse if an error occurred.
Raises:
SendRetryError: If the reaction was unable to be sent.
"""
content = {
"m.relates_to": {
"rel_type": "m.annotation",
"event_id": event_id,
"key": reaction_text,
}
}
return await client.room_send(
room_id,
"m.reaction",
content,
ignore_unverified_devices=True,
)
async def decryption_failure(self, room: MatrixRoom, event: MegolmEvent) -> None:
"""Callback for when an event fails to decrypt. Inform the user"""
logger.error(
f"Failed to decrypt event '{event.event_id}' in room '{room.room_id}'!"
f"\n\n"
f"Tip: try using a different device ID in your config file and restart."
f"\n\n"
f"If all else fails, delete your store directory and let the bot recreate "
f"it (your reminders will NOT be deleted, but the bot may respond to existing "
f"commands a second time)."
)
user_msg = (
"Unable to decrypt this message. "
"Check whether you've chosen to only encrypt to trusted devices."
)
await send_text_to_room(
self.client,
room.room_id,
user_msg,
reply_to_event_id=event.event_id,
)

136
my_project_name/config.py Normal file
View File

@@ -0,0 +1,136 @@
import logging
import os
import re
import sys
from typing import Any, List, Optional
import yaml
from my_project_name.errors import ConfigError
logger = logging.getLogger()
logging.getLogger("peewee").setLevel(
logging.INFO
) # Prevent debug messages from peewee lib
class Config:
"""Creates a Config object from a YAML-encoded config file from a given filepath"""
def __init__(self, filepath: str):
self.filepath = filepath
if not os.path.isfile(filepath):
raise ConfigError(f"Config file '{filepath}' does not exist")
# Load in the config file at the given filepath
with open(filepath) as file_stream:
self.config_dict = yaml.safe_load(file_stream.read())
# Parse and validate config options
self._parse_config_values()
def _parse_config_values(self):
"""Read and validate each config option"""
# Logging setup
formatter = logging.Formatter(
"%(asctime)s | %(name)s [%(levelname)s] %(message)s"
)
log_level = self._get_cfg(["logging", "level"], default="INFO")
logger.setLevel(log_level)
file_logging_enabled = self._get_cfg(
["logging", "file_logging", "enabled"], default=False
)
file_logging_filepath = self._get_cfg(
["logging", "file_logging", "filepath"], default="bot.log"
)
if file_logging_enabled:
handler = logging.FileHandler(file_logging_filepath)
handler.setFormatter(formatter)
logger.addHandler(handler)
console_logging_enabled = self._get_cfg(
["logging", "console_logging", "enabled"], default=True
)
if console_logging_enabled:
handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(formatter)
logger.addHandler(handler)
# Storage setup
self.store_path = self._get_cfg(["storage", "store_path"], required=True)
# Create the store folder if it doesn't exist
if not os.path.isdir(self.store_path):
if not os.path.exists(self.store_path):
os.mkdir(self.store_path)
else:
raise ConfigError(
f"storage.store_path '{self.store_path}' is not a directory"
)
# Database setup
database_path = self._get_cfg(["storage", "database"], required=True)
# Support both SQLite and Postgres backends
# Determine which one the user intends
sqlite_scheme = "sqlite://"
postgres_scheme = "postgres://"
if database_path.startswith(sqlite_scheme):
self.database = {
"type": "sqlite",
"connection_string": database_path[len(sqlite_scheme) :],
}
elif database_path.startswith(postgres_scheme):
self.database = {"type": "postgres", "connection_string": database_path}
else:
raise ConfigError("Invalid connection string for storage.database")
# Matrix bot account setup
self.user_id = self._get_cfg(["matrix", "user_id"], required=True)
if not re.match("@.*:.*", self.user_id):
raise ConfigError("matrix.user_id must be in the form @name:domain")
self.user_password = self._get_cfg(["matrix", "user_password"], required=False)
self.user_token = self._get_cfg(["matrix", "user_token"], required=False)
if not self.user_token and not self.user_password:
raise ConfigError("Must supply either user token or password")
self.device_id = self._get_cfg(["matrix", "device_id"], required=True)
self.device_name = self._get_cfg(
["matrix", "device_name"], default="nio-template"
)
self.homeserver_url = self._get_cfg(["matrix", "homeserver_url"], required=True)
self.command_prefix = self._get_cfg(["command_prefix"], default="!c") + " "
def _get_cfg(
self,
path: List[str],
default: Optional[Any] = None,
required: Optional[bool] = True,
) -> Any:
"""Get a config option from a path and option name, specifying whether it is
required.
Raises:
ConfigError: If required is True and the object is not found (and there is
no default value provided), a ConfigError will be raised.
"""
# Sift through the the config until we reach our option
config = self.config_dict
for name in path:
config = config.get(name)
# If at any point we don't get our expected option...
if config is None:
# Raise an error if it was required
if required and not default:
raise ConfigError(f"Config option {'.'.join(path)} is required")
# or return the default value
return default
# We found the option. Return it.
return config

12
my_project_name/errors.py Normal file
View File

@@ -0,0 +1,12 @@
# This file holds custom error types that you can define for your application.
class ConfigError(RuntimeError):
"""An error encountered during reading the config file.
Args:
msg: The message displayed to the user on error.
"""
def __init__(self, msg: str):
super(ConfigError, self).__init__("%s" % (msg,))

119
my_project_name/main.py Normal file
View File

@@ -0,0 +1,119 @@
#!/usr/bin/env python3
import asyncio
import logging
import sys
from time import sleep
from aiohttp import ClientConnectionError, ServerDisconnectedError
from nio import (
AsyncClient,
AsyncClientConfig,
InviteMemberEvent,
LocalProtocolError,
LoginError,
MegolmEvent,
RoomMessageText,
UnknownEvent,
)
from my_project_name.callbacks import Callbacks
from my_project_name.config import Config
from my_project_name.storage import Storage
logger = logging.getLogger(__name__)
async def main():
"""The first function that is run when starting the bot"""
# Read user-configured options from a config file.
# A different config file path can be specified as the first command line argument
if len(sys.argv) > 1:
config_path = sys.argv[1]
else:
config_path = "config.yaml"
# Read the parsed config file and create a Config object
config = Config(config_path)
# Configure the database
store = Storage(config.database)
# Configuration options for the AsyncClient
client_config = AsyncClientConfig(
max_limit_exceeded=0,
max_timeouts=0,
store_sync_tokens=True,
encryption_enabled=True,
)
# Initialize the matrix client
client = AsyncClient(
config.homeserver_url,
config.user_id,
device_id=config.device_id,
store_path=config.store_path,
config=client_config,
)
if config.user_token:
client.access_token = config.user_token
client.user_id = config.user_id
# Set up event callbacks
callbacks = Callbacks(client, store, config)
client.add_event_callback(callbacks.message, (RoomMessageText,))
client.add_event_callback(callbacks.invite, (InviteMemberEvent,))
client.add_event_callback(callbacks.decryption_failure, (MegolmEvent,))
client.add_event_callback(callbacks.unknown, (UnknownEvent,))
# Keep trying to reconnect on failure (with some time in-between)
while True:
try:
if config.user_token:
# Use token to log in
client.load_store()
# Sync encryption keys with the server
if client.should_upload_keys:
await client.keys_upload()
else:
# Try to login with the configured username/password
try:
login_response = await client.login(
password=config.user_password,
device_name=config.device_name,
)
# Check if login failed
if type(login_response) == LoginError:
logger.error("Failed to login: %s", login_response.message)
return False
except LocalProtocolError as e:
# There's an edge case here where the user hasn't installed the correct C
# dependencies. In that case, a LocalProtocolError is raised on login.
logger.fatal(
"Failed to login. Have you installed the correct dependencies? "
"https://github.com/poljar/matrix-nio#installation "
"Error: %s",
e,
)
return False
# Login succeeded!
logger.info(f"Logged in as {config.user_id}")
await client.sync_forever(timeout=30000, full_state=True)
except (ClientConnectionError, ServerDisconnectedError):
logger.warning("Unable to connect to homeserver, retrying in 15s...")
# Sleep so we don't bombard the server with login requests
sleep(15)
finally:
# Make sure to close the client connection on disconnect
await client.close()
# Run the main function in an asyncio event loop
asyncio.get_event_loop().run_until_complete(main())

View File

@@ -0,0 +1,52 @@
import logging
from nio import AsyncClient, MatrixRoom, RoomMessageText
from my_project_name.chat_functions import send_text_to_room
from my_project_name.config import Config
from my_project_name.storage import Storage
logger = logging.getLogger(__name__)
class Message:
def __init__(
self,
client: AsyncClient,
store: Storage,
config: Config,
message_content: str,
room: MatrixRoom,
event: RoomMessageText,
):
"""Initialize a new Message
Args:
client: nio client used to interact with matrix.
store: Bot storage.
config: Bot configuration parameters.
message_content: The body of the message.
room: The room the event came from.
event: The event defining the message.
"""
self.client = client
self.store = store
self.config = config
self.message_content = message_content
self.room = room
self.event = event
async def process(self) -> None:
"""Process and possibly respond to the message"""
if self.message_content.lower() == "hello world":
await self._hello_world()
async def _hello_world(self) -> None:
"""Say hello"""
text = "Hello, world!"
await send_text_to_room(self.client, self.room.room_id, text)

126
my_project_name/storage.py Normal file
View File

@@ -0,0 +1,126 @@
import logging
from typing import Any, Dict
# The latest migration version of the database.
#
# Database migrations are applied starting from the number specified in the database's
# `migration_version` table + 1 (or from 0 if this table does not yet exist) up until
# the version specified here.
#
# When a migration is performed, the `migration_version` table should be incremented.
latest_migration_version = 0
logger = logging.getLogger(__name__)
class Storage:
def __init__(self, database_config: Dict[str, str]):
"""Setup the database.
Runs an initial setup or migrations depending on whether a database file has already
been created.
Args:
database_config: a dictionary containing the following keys:
* type: A string, one of "sqlite" or "postgres".
* connection_string: A string, featuring a connection string that
be fed to each respective db library's `connect` method.
"""
self.conn = self._get_database_connection(
database_config["type"], database_config["connection_string"]
)
self.cursor = self.conn.cursor()
self.db_type = database_config["type"]
# Try to check the current migration version
migration_level = 0
try:
self._execute("SELECT version FROM migration_version")
row = self.cursor.fetchone()
migration_level = row[0]
except Exception:
self._initial_setup()
finally:
if migration_level < latest_migration_version:
self._run_migrations(migration_level)
logger.info(f"Database initialization of type '{self.db_type}' complete")
def _get_database_connection(
self, database_type: str, connection_string: str
) -> Any:
"""Creates and returns a connection to the database"""
if database_type == "sqlite":
import sqlite3
# Initialize a connection to the database, with autocommit on
return sqlite3.connect(connection_string, isolation_level=None)
elif database_type == "postgres":
import psycopg2
conn = psycopg2.connect(connection_string)
# Autocommit on
conn.set_isolation_level(0)
return conn
def _initial_setup(self) -> None:
"""Initial setup of the database"""
logger.info("Performing initial database setup...")
# Set up the migration_version table
self._execute(
"""
CREATE TABLE migration_version (
version INTEGER PRIMARY KEY
)
"""
)
# Initially set the migration version to 0
self._execute(
"""
INSERT INTO migration_version (
version
) VALUES (?)
""",
(0,),
)
# Set up any other necessary database tables here
logger.info("Database setup complete")
def _run_migrations(self, current_migration_version: int) -> None:
"""Execute database migrations. Migrates the database to the
`latest_migration_version`.
Args:
current_migration_version: The migration version that the database is
currently at.
"""
logger.debug("Checking for necessary database migrations...")
# if current_migration_version < 1:
# logger.info("Migrating the database from v0 to v1...")
#
# # Add new table, delete old ones, etc.
#
# # Update the stored migration version
# self._execute("UPDATE migration_version SET version = 1")
#
# logger.info("Database migrated to v1")
def _execute(self, *args) -> None:
"""A wrapper around cursor.execute that transforms placeholder ?'s to %s for postgres.
This allows for the support of queries that are compatible with both postgres and sqlite.
Args:
args: Arguments passed to cursor.execute.
"""
if self.db_type == "postgres":
self.cursor.execute(args[0].replace("?", "%s"), *args[1:])
else:
self.cursor.execute(*args)