@@ -22,10 +22,26 @@ import sys
import time
import traceback
from types import TracebackType
from typing import Dict, Iterable, Optional, Set, Tuple, Type, cast
from typing import (
Any,
Awaitable,
Callable,
Dict,
Generator,
Iterable,
List,
NoReturn,
Optional,
Set,
Tuple,
Type,
TypeVar,
cast,
)
import yaml
from matrix_common.versionstring import get_distribution_version_string
from typing_extensions import TypedDict
from twisted.internet import defer, reactor as reactor_
@@ -36,7 +52,7 @@ from synapse.logging.context import (
make_deferred_yieldable,
run_in_background,
)
from synapse.storage.database import DatabasePool, make_conn
from synapse.storage.database import DatabasePool, LoggingTransaction, make_conn
from synapse.storage.databases.main import PushRuleStore
from synapse.storage.databases.main.account_data import AccountDataWorkerStore
from synapse.storage.databases.main.client_ips import ClientIpBackgroundUpdateStore
@@ -173,6 +189,8 @@ end_error_exec_info: Optional[
Tuple[Type[BaseException], BaseException, TracebackType]
] = None
R = TypeVar("R")
class Store(
ClientIpBackgroundUpdateStore,
@@ -195,17 +213,19 @@ class Store(
PresenceBackgroundUpdateStore,
GroupServerWorkerStore,
):
def execute(self, f, *args, **kwargs):
def execute(self, f: Callable[..., R], *args: Any, **kwargs: Any) -> Awaitable[R]:
return self.db_pool.runInteraction(f.__name__, f, *args, **kwargs)
def execute_sql(self, sql, *args):
def r(txn):
def execute_sql(self, sql: str, *args: object) -> Awaitable[List[Tuple]]:
def r(txn: LoggingTransaction) -> List[Tuple]:
txn.execute(sql, args)
return txn.fetchall()
return self.db_pool.runInteraction("execute_sql", r)
def insert_many_txn(self, txn, table, headers, rows):
def insert_many_txn(
self, txn: LoggingTransaction, table: str, headers: List[str], rows: List[Tuple]
) -> None:
sql = "INSERT INTO %s (%s) VALUES (%s)" % (
table,
", ".join(k for k in headers),
@@ -218,14 +238,15 @@ class Store(
logger.exception("Failed to insert: %s", table)
raise
def set_room_is_public(self, room_id, is_public):
# Note: the parent method is an `async def`.
def set_room_is_public(self, room_id: str, is_public: bool) -> NoReturn:
raise Exception(
"Attempt to set room_is_public during port_db: database not empty?"
)
class MockHomeserver:
def __init__(self, config):
def __init__(self, config: HomeServerConfig):
self.clock = Clock(reactor)
self.config = config
self.hostname = config.server.server_name
@@ -233,24 +254,30 @@ class MockHomeserver:
"matrix-synapse"
)
def get_clock(self):
def get_clock(self) -> Clock:
return self.clock
def get_reactor(self):
def get_reactor(self) -> ISynapseReactor:
return reactor
def get_instance_name(self):
def get_instance_name(self) -> str:
return "master"
class Porter:
def __init__(self, sqlite_config, progress, batch_size, hs_config):
def __init__(
self,
sqlite_config: Dict[str, Any],
progress: "Progress",
batch_size: int,
hs_config: HomeServerConfig,
):
self.sqlite_config = sqlite_config
self.progress = progress
self.batch_size = batch_size
self.hs_config = hs_config
async def setup_table(self, table):
async def setup_table(self, table: str) -> Tuple[str, int, int, int, int]:
if table in APPEND_ONLY_TABLES:
# It's safe to just carry on inserting.
row = await self.postgres_store.db_pool.simple_select_one(
@@ -292,7 +319,7 @@ class Porter:
)
else:
def delete_all(txn):
def delete_all(txn: LoggingTransaction) -> None:
txn.execute(
"DELETE FROM port_from_sqlite3 WHERE table_name = %s", (table,)
)
@@ -317,7 +344,7 @@ class Porter:
async def get_table_constraints(self) -> Dict[str, Set[str]]:
"""Returns a map of tables that have foreign key constraints to tables they depend on."""
def _get_constraints(txn):
def _get_constraints(txn: LoggingTransaction) -> Dict[str, Set[str]]:
# We can pull the information about foreign key constraints out from
# the postgres schema tables.
sql = """
@@ -343,8 +370,13 @@ class Porter:
)
async def handle_table(
self, table, postgres_size, table_size, forward_chunk, backward_chunk
):
self,
table: str,
postgres_size: int,
table_size: int,
forward_chunk: int,
backward_chunk: int,
) -> None:
logger.info(
"Table %s: %i/%i (rows %i-%i) already ported",
table,
@@ -391,7 +423,9 @@ class Porter:
while True:
def r(txn):
def r(
txn: LoggingTransaction,
) -> Tuple[Optional[List[str]], List[Tuple], List[Tuple]]:
forward_rows = []
backward_rows = []
if do_forward[0]:
@@ -418,6 +452,7 @@ class Porter:
)
if frows or brows:
assert headers is not None
if frows:
forward_chunk = max(row[0] for row in frows) + 1
if brows:
@@ -426,7 +461,8 @@ class Porter:
rows = frows + brows
rows = self._convert_rows(table, headers, rows)
def insert(txn):
def insert(txn: LoggingTransaction) -> None:
assert headers is not None
self.postgres_store.insert_many_txn(txn, table, headers[1:], rows)
self.postgres_store.db_pool.simple_update_one_txn(
@@ -448,8 +484,12 @@ class Porter:
return
async def handle_search_table(
self, postgres_size, table_size, forward_chunk, backward_chunk
):
self,
postgres_size: int,
table_size: int,
forward_chunk: int,
backward_chunk: int,
) -> None:
select = (
"SELECT es.rowid, es.*, e.origin_server_ts, e.stream_ordering"
" FROM event_search as es"
@@ -460,7 +500,7 @@ class Porter:
while True:
def r(txn):
def r(txn: LoggingTransaction) -> Tuple[List[str], List[Tuple]]:
txn.execute(select, (forward_chunk, self.batch_size))
rows = txn.fetchall()
headers = [column[0] for column in txn.description]
@@ -474,7 +514,7 @@ class Porter:
# We have to treat event_search differently since it has a
# different structure in the two different databases.
def insert(txn):
def insert(txn: LoggingTransaction) -> None:
sql = (
"INSERT INTO event_search (event_id, room_id, key,"
" sender, vector, origin_server_ts, stream_ordering)"
@@ -528,7 +568,7 @@ class Porter:
self,
db_config: DatabaseConnectionConfig,
allow_outdated_version: bool = False,
):
) -> Store:
"""Builds and returns a database store using the provided configuration.
Args:
@@ -556,7 +596,7 @@ class Porter:
return store
async def run_background_updates_on_postgres(self):
async def run_background_updates_on_postgres(self) -> None:
# Manually apply all background updates on the PostgreSQL database.
postgres_ready = (
await self.postgres_store.db_pool.updates.has_completed_background_updates()
@@ -568,12 +608,12 @@ class Porter:
self.progress.set_state("Running background updates on PostgreSQL")
while not postgres_ready:
await self.postgres_store.db_pool.updates.do_next_background_update(100)
await self.postgres_store.db_pool.updates.do_next_background_update(True)
postgres_ready = await (
self.postgres_store.db_pool.updates.has_completed_background_updates()
)
async def run(self):
async def run(self) -> None:
"""Ports the SQLite database to a PostgreSQL database.
When a fatal error is met, its message is assigned to the global "end_error"
@@ -609,7 +649,7 @@ class Porter:
self.progress.set_state("Creating port tables")
def create_port_table(txn):
def create_port_table(txn: LoggingTransaction) -> None:
txn.execute(
"CREATE TABLE IF NOT EXISTS port_from_sqlite3 ("
" table_name varchar(100) NOT NULL UNIQUE,"
@@ -622,7 +662,7 @@ class Porter:
# We want people to be able to rerun this script from an old port
# so that they can pick up any missing events that were not
# ported across.
def alter_table(txn):
def alter_table(txn: LoggingTransaction) -> None:
txn.execute(
"ALTER TABLE IF EXISTS port_from_sqlite3"
" RENAME rowid TO forward_rowid"
@@ -742,7 +782,9 @@ class Porter:
finally:
reactor.stop()
def _convert_rows(self, table, headers, rows):
def _convert_rows(
self, table: str, headers: List[str], rows: List[Tuple]
) -> List[Tuple]:
bool_col_names = BOOLEAN_COLUMNS.get(table, [])
bool_cols = [i for i, h in enumerate(headers) if h in bool_col_names]
@@ -750,7 +792,7 @@ class Porter:
class BadValueException(Exception):
pass
def conv(j, col):
def conv(j: int, col: object) -> object:
if j in bool_cols:
return bool(col)
if isinstance(col, bytes):
@@ -776,7 +818,7 @@ class Porter:
return outrows
async def _setup_sent_transactions(self):
async def _setup_sent_transactions(self) -> Tuple[int, int, int]:
# Only save things from the last day
yesterday = int(time.time() * 1000) - 86400000
@@ -788,10 +830,10 @@ class Porter:
")"
)
def r(txn):
def r(txn: LoggingTransaction) -> Tuple[List[str], List[Tuple]]:
txn.execute(select)
rows = txn.fetchall()
headers = [column[0] for column in txn.description]
headers: List[str] = [column[0] for column in txn.description]
ts_ind = headers.index("ts")
@@ -805,7 +847,7 @@ class Porter:
if inserted_rows:
max_inserted_rowid = max(r[0] for r in rows)
def insert(txn):
def insert(txn: LoggingTransaction) -> None:
self.postgres_store.insert_many_txn(
txn, "sent_transactions", headers[1:], rows
)
@@ -814,7 +856,7 @@ class Porter:
else:
max_inserted_rowid = 0
def get_start_id(txn):
def get_start_id(txn: LoggingTransaction) -> int:
txn.execute(
"SELECT rowid FROM sent_transactions WHERE ts >= ?"
" ORDER BY rowid ASC LIMIT 1",
@@ -839,12 +881,13 @@ class Porter:
},
)
def get_sent_table_size(txn):
def get_sent_table_size(txn: LoggingTransaction) -> int:
txn.execute(
"SELECT count(*) FROM sent_transactions" " WHERE ts >= ?", (yesterday,)
)
(size,) = txn.fetchone()
return int(size)
result = txn.fetchone()
assert result is not None
return int(result[0])
remaining_count = await self.sqlite_store.execute(get_sent_table_size)
@@ -852,25 +895,35 @@ class Porter:
return next_chunk, inserted_rows, total_count
async def _get_remaining_count_to_port(self, table, forward_chunk, backward_chunk):
frows = await self.sqlite_store.execute_sql(
"SELECT count(*) FROM %s WHERE rowid >= ?" % (table,), forward_chunk
async def _get_remaining_count_to_port(
self, table: str, forward_chunk: int, backward_chunk: int
) -> int:
frows = cast(
List[Tuple[int]],
await self.sqlite_store.execute_sql(
"SELECT count(*) FROM %s WHERE rowid >= ?" % (table,), forward_chunk
),
)
brows = await self.sqlite_store.execute_sql(
"SELECT count(*) FROM %s WHERE rowid <= ?" % (table,), backward_chunk
brows = cast(
List[Tuple[int]],
await self.sqlite_store.execute_sql(
"SELECT count(*) FROM %s WHERE rowid <= ?" % (table,), backward_chunk
),
)
return frows[0][0] + brows[0][0]
async def _get_already_ported_count(self, table):
async def _get_already_ported_count(self, table: str) -> int:
rows = await self.postgres_store.execute_sql(
"SELECT count(*) FROM %s" % (table,)
)
return rows[0][0]
async def _get_total_count_to_port(self, table, forward_chunk, backward_chunk):
async def _get_total_count_to_port(
self, table: str, forward_chunk: int, backward_chunk: int
) -> Tuple[int, int]:
remaining, done = await make_deferred_yieldable(
defer.gatherResults(
[
@@ -891,14 +944,17 @@ class Porter:
return done, remaining + done
async def _setup_state_group_id_seq(self) -> None:
curr_id = await self.sqlite_store.db_pool.simple_select_one_onecol(
curr_id: Optional[
int
] = await self.sqlite_store.db_pool.simple_select_one_onecol(
table="state_groups", keyvalues={}, retcol="MAX(id)", allow_none=True
)
if not curr_id:
return
def r(txn):
def r(txn: LoggingTransaction) -> None:
assert curr_id is not None
next_id = curr_id + 1
txn.execute("ALTER SEQUENCE state_group_id_seq RESTART WITH %s", (next_id,))
@@ -909,7 +965,7 @@ class Porter:
"setup_user_id_seq", find_max_generated_user_id_localpart
)
def r(txn):
def r(txn: LoggingTransaction) -> None:
next_id = curr_id + 1
txn.execute("ALTER SEQUENCE user_id_seq RESTART WITH %s", (next_id,))
@@ -931,7 +987,7 @@ class Porter:
allow_none=True,
)
def _setup_events_stream_seqs_set_pos(txn):
def _setup_events_stream_seqs_set_pos(txn: LoggingTransaction) -> None:
if curr_forward_id:
txn.execute(
"ALTER SEQUENCE events_stream_seq RESTART WITH %s",
@@ -955,17 +1011,20 @@ class Porter:
"""Set a sequence to the correct value."""
current_stream_ids = []
for stream_id_table in stream_id_tables:
max_stream_id = await self.sqlite_store.db_pool.simple_select_one_onecol(
table=stream_id_table,
keyvalues={},
retcol="COALESCE(MAX(stream_id), 1)",
allow_none=True,
max_stream_id = cast(
int,
await self.sqlite_store.db_pool.simple_select_one_onecol(
table=stream_id_table,
keyvalues={},
retcol="COALESCE(MAX(stream_id), 1)",
allow_none=True,
),
)
current_stream_ids.append(max_stream_id)
next_id = max(current_stream_ids) + 1
def r(txn):
def r(txn: LoggingTransaction) -> None:
sql = "ALTER SEQUENCE %s RESTART WITH" % (sequence_name,)
txn.execute(sql + " %s", (next_id,))
@@ -974,14 +1033,18 @@ class Porter:
)
async def _setup_auth_chain_sequence(self) -> None:
curr_chain_id = await self.sqlite_store.db_pool.simple_select_one_onecol(
curr_chain_id: Optional[
int
] = await self.sqlite_store.db_pool.simple_select_one_onecol(
table="event_auth_chains",
keyvalues={},
retcol="MAX(chain_id)",
allow_none=True,
)
def r(txn):
def r(txn: LoggingTransaction) -> None:
# Presumably there is at least one row in event_auth_chains.
assert curr_chain_id is not None
txn.execute(
"ALTER SEQUENCE event_auth_chain_id RESTART WITH %s",
(curr_chain_id + 1,),
@@ -999,15 +1062,22 @@ class Porter:
##############################################
class Progress(object):
class TableProgress(TypedDict):
start: int
num_done: int
total: int
perc: int
class Progress:
"""Used to report progress of the port"""
def __init__(self):
self.tables = {}
def __init__(self) -> None:
self.tables: Dict[str, TableProgress] = {}
self.start_time = int(time.time())
def add_table(self, table, cur, size):
def add_table(self, table: str, cur: int, size: int) -> None:
self.tables[table] = {
"start": cur,
"num_done": cur,
@@ -1015,19 +1085,22 @@ class Progress(object):
"perc": int(cur * 100 / size),
}
def update(self, table, num_done):
def update(self, table: str, num_done: int) -> None:
data = self.tables[table]
data["num_done"] = num_done
data["perc"] = int(num_done * 100 / data["total"])
def done(self):
def done(self) -> None:
pass
def set_state(self, state: str) -> None:
pass
class CursesProgress(Progress):
"""Reports progress to a curses window"""
def __init__(self, stdscr):
def __init__(self, stdscr: "curses.window"):
self.stdscr = stdscr
curses.use_default_colors()
@@ -1045,7 +1118,7 @@ class CursesProgress(Progress):
super(CursesProgress, self).__init__()
def update(self, table, num_done):
def update(self, table: str, num_done: int) -> None:
super(CursesProgress, self).update(table, num_done)
self.total_processed = 0
@@ -1056,7 +1129,7 @@ class CursesProgress(Progress):
self.render()
def render(self, force=False):
def render(self, force: bool = False) -> None:
now = time.time()
if not force and now - self.last_update < 0.2:
@@ -1128,12 +1201,12 @@ class CursesProgress(Progress):
self.stdscr.refresh()
self.last_update = time.time()
def done(self):
def done(self) -> None:
self.finished = True
self.render(True)
self.stdscr.getch()
def set_state(self, state):
def set_state(self, state: str) -> None:
self.stdscr.clear()
self.stdscr.addstr(0, 0, state + "...", curses.A_BOLD)
self.stdscr.refresh()
@@ -1142,7 +1215,7 @@ class CursesProgress(Progress):
class TerminalProgress(Progress):
"""Just prints progress to the terminal"""
def update(self, table, num_done):
def update(self, table: str, num_done: int) -> None:
super(TerminalProgress, self).update(table, num_done)
data = self.tables[table]
@@ -1151,7 +1224,7 @@ class TerminalProgress(Progress):
"%s: %d%% (%d/%d)" % (table, data["perc"], data["num_done"], data["total"])
)
def set_state(self, state):
def set_state(self, state: str) -> None:
print(state + "...")
@@ -1159,7 +1232,7 @@ class TerminalProgress(Progress):
##############################################
def main():
def main() -> None:
parser = argparse.ArgumentParser(
description="A script to port an existing synapse SQLite database to"
" a new PostgreSQL database."
@@ -1225,7 +1298,7 @@ def main():
config = HomeServerConfig()
config.parse_config_dict(hs_config, "", "")
def start(stdscr=None):
def start(stdscr: Optional["curses.window"] = None) -> None:
progress: Progress
if stdscr:
progress = CursesProgress(stdscr)
@@ -1240,7 +1313,7 @@ def main():
)
@defer.inlineCallbacks
def run():
def run() -> Generator["defer.Deferred[Any]", Any, None]:
with LoggingContext("synapse_port_db_run"):
yield defer.ensureDeferred(porter.run())