|
|
@@ -16,42 +16,62 @@ import logging |
|
|
|
import threading |
|
|
|
from collections import OrderedDict |
|
|
|
from contextlib import contextmanager |
|
|
|
from typing import Dict, Iterable, List, Optional, Set, Tuple, Union |
|
|
|
from types import TracebackType |
|
|
|
from typing import ( |
|
|
|
AsyncContextManager, |
|
|
|
ContextManager, |
|
|
|
Dict, |
|
|
|
Generator, |
|
|
|
Generic, |
|
|
|
Iterable, |
|
|
|
List, |
|
|
|
Optional, |
|
|
|
Sequence, |
|
|
|
Set, |
|
|
|
Tuple, |
|
|
|
Type, |
|
|
|
TypeVar, |
|
|
|
Union, |
|
|
|
cast, |
|
|
|
) |
|
|
|
|
|
|
|
import attr |
|
|
|
from sortedcontainers import SortedSet |
|
|
|
|
|
|
|
from synapse.metrics.background_process_metrics import run_as_background_process |
|
|
|
from synapse.storage.database import DatabasePool, LoggingTransaction |
|
|
|
from synapse.storage.database import ( |
|
|
|
DatabasePool, |
|
|
|
LoggingDatabaseConnection, |
|
|
|
LoggingTransaction, |
|
|
|
) |
|
|
|
from synapse.storage.types import Cursor |
|
|
|
from synapse.storage.util.sequence import PostgresSequenceGenerator |
|
|
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
|
|
|
T = TypeVar("T") |
|
|
|
|
|
|
|
|
|
|
|
class IdGenerator: |
|
|
|
def __init__(self, db_conn, table, column): |
|
|
|
def __init__( |
|
|
|
self, |
|
|
|
db_conn: LoggingDatabaseConnection, |
|
|
|
table: str, |
|
|
|
column: str, |
|
|
|
): |
|
|
|
self._lock = threading.Lock() |
|
|
|
self._next_id = _load_current_id(db_conn, table, column) |
|
|
|
|
|
|
|
def get_next(self): |
|
|
|
def get_next(self) -> int: |
|
|
|
with self._lock: |
|
|
|
self._next_id += 1 |
|
|
|
return self._next_id |
|
|
|
|
|
|
|
|
|
|
|
def _load_current_id(db_conn, table, column, step=1): |
|
|
|
""" |
|
|
|
|
|
|
|
Args: |
|
|
|
db_conn (object): |
|
|
|
table (str): |
|
|
|
column (str): |
|
|
|
step (int): |
|
|
|
|
|
|
|
Returns: |
|
|
|
int |
|
|
|
""" |
|
|
|
def _load_current_id( |
|
|
|
db_conn: LoggingDatabaseConnection, table: str, column: str, step: int = 1 |
|
|
|
) -> int: |
|
|
|
# debug logging for https://github.com/matrix-org/synapse/issues/7968 |
|
|
|
logger.info("initialising stream generator for %s(%s)", table, column) |
|
|
|
cur = db_conn.cursor(txn_name="_load_current_id") |
|
|
@@ -59,7 +79,9 @@ def _load_current_id(db_conn, table, column, step=1): |
|
|
|
cur.execute("SELECT MAX(%s) FROM %s" % (column, table)) |
|
|
|
else: |
|
|
|
cur.execute("SELECT MIN(%s) FROM %s" % (column, table)) |
|
|
|
(val,) = cur.fetchone() |
|
|
|
result = cur.fetchone() |
|
|
|
assert result is not None |
|
|
|
(val,) = result |
|
|
|
cur.close() |
|
|
|
current_id = int(val) if val else step |
|
|
|
return (max if step > 0 else min)(current_id, step) |
|
|
@@ -93,16 +115,16 @@ class StreamIdGenerator: |
|
|
|
|
|
|
|
def __init__( |
|
|
|
self, |
|
|
|
db_conn, |
|
|
|
table, |
|
|
|
column, |
|
|
|
db_conn: LoggingDatabaseConnection, |
|
|
|
table: str, |
|
|
|
column: str, |
|
|
|
extra_tables: Iterable[Tuple[str, str]] = (), |
|
|
|
step=1, |
|
|
|
): |
|
|
|
step: int = 1, |
|
|
|
) -> None: |
|
|
|
assert step != 0 |
|
|
|
self._lock = threading.Lock() |
|
|
|
self._step = step |
|
|
|
self._current = _load_current_id(db_conn, table, column, step) |
|
|
|
self._step: int = step |
|
|
|
self._current: int = _load_current_id(db_conn, table, column, step) |
|
|
|
for table, column in extra_tables: |
|
|
|
self._current = (max if step > 0 else min)( |
|
|
|
self._current, _load_current_id(db_conn, table, column, step) |
|
|
@@ -115,7 +137,7 @@ class StreamIdGenerator: |
|
|
|
# The key and values are the same, but we never look at the values. |
|
|
|
self._unfinished_ids: OrderedDict[int, int] = OrderedDict() |
|
|
|
|
|
|
|
def get_next(self): |
|
|
|
def get_next(self) -> AsyncContextManager[int]: |
|
|
|
""" |
|
|
|
Usage: |
|
|
|
async with stream_id_gen.get_next() as stream_id: |
|
|
@@ -128,7 +150,7 @@ class StreamIdGenerator: |
|
|
|
self._unfinished_ids[next_id] = next_id |
|
|
|
|
|
|
|
@contextmanager |
|
|
|
def manager(): |
|
|
|
def manager() -> Generator[int, None, None]: |
|
|
|
try: |
|
|
|
yield next_id |
|
|
|
finally: |
|
|
@@ -137,7 +159,7 @@ class StreamIdGenerator: |
|
|
|
|
|
|
|
return _AsyncCtxManagerWrapper(manager()) |
|
|
|
|
|
|
|
def get_next_mult(self, n): |
|
|
|
def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]: |
|
|
|
""" |
|
|
|
Usage: |
|
|
|
async with stream_id_gen.get_next_mult(n) as stream_ids: |
|
|
@@ -155,7 +177,7 @@ class StreamIdGenerator: |
|
|
|
self._unfinished_ids[next_id] = next_id |
|
|
|
|
|
|
|
@contextmanager |
|
|
|
def manager(): |
|
|
|
def manager() -> Generator[Sequence[int], None, None]: |
|
|
|
try: |
|
|
|
yield next_ids |
|
|
|
finally: |
|
|
@@ -215,7 +237,7 @@ class MultiWriterIdGenerator: |
|
|
|
|
|
|
|
def __init__( |
|
|
|
self, |
|
|
|
db_conn, |
|
|
|
db_conn: LoggingDatabaseConnection, |
|
|
|
db: DatabasePool, |
|
|
|
stream_name: str, |
|
|
|
instance_name: str, |
|
|
@@ -223,7 +245,7 @@ class MultiWriterIdGenerator: |
|
|
|
sequence_name: str, |
|
|
|
writers: List[str], |
|
|
|
positive: bool = True, |
|
|
|
): |
|
|
|
) -> None: |
|
|
|
self._db = db |
|
|
|
self._stream_name = stream_name |
|
|
|
self._instance_name = instance_name |
|
|
@@ -285,9 +307,9 @@ class MultiWriterIdGenerator: |
|
|
|
|
|
|
|
def _load_current_ids( |
|
|
|
self, |
|
|
|
db_conn, |
|
|
|
db_conn: LoggingDatabaseConnection, |
|
|
|
tables: List[Tuple[str, str, str]], |
|
|
|
): |
|
|
|
) -> None: |
|
|
|
cur = db_conn.cursor(txn_name="_load_current_ids") |
|
|
|
|
|
|
|
# Load the current positions of all writers for the stream. |
|
|
@@ -335,7 +357,9 @@ class MultiWriterIdGenerator: |
|
|
|
"agg": "MAX" if self._positive else "-MIN", |
|
|
|
} |
|
|
|
cur.execute(sql) |
|
|
|
(stream_id,) = cur.fetchone() |
|
|
|
result = cur.fetchone() |
|
|
|
assert result is not None |
|
|
|
(stream_id,) = result |
|
|
|
|
|
|
|
max_stream_id = max(max_stream_id, stream_id) |
|
|
|
|
|
|
@@ -354,7 +378,7 @@ class MultiWriterIdGenerator: |
|
|
|
|
|
|
|
self._persisted_upto_position = min_stream_id |
|
|
|
|
|
|
|
rows = [] |
|
|
|
rows: List[Tuple[str, int]] = [] |
|
|
|
for table, instance_column, id_column in tables: |
|
|
|
sql = """ |
|
|
|
SELECT %(instance)s, %(id)s FROM %(table)s |
|
|
@@ -367,7 +391,8 @@ class MultiWriterIdGenerator: |
|
|
|
} |
|
|
|
cur.execute(sql, (min_stream_id * self._return_factor,)) |
|
|
|
|
|
|
|
rows.extend(cur) |
|
|
|
# Cast safety: this corresponds to the types returned by the query above. |
|
|
|
rows.extend(cast(Iterable[Tuple[str, int]], cur)) |
|
|
|
|
|
|
|
# Sort so that we handle rows in order for each instance. |
|
|
|
rows.sort() |
|
|
@@ -385,13 +410,13 @@ class MultiWriterIdGenerator: |
|
|
|
|
|
|
|
cur.close() |
|
|
|
|
|
|
|
def _load_next_id_txn(self, txn) -> int: |
|
|
|
def _load_next_id_txn(self, txn: Cursor) -> int: |
|
|
|
return self._sequence_gen.get_next_id_txn(txn) |
|
|
|
|
|
|
|
def _load_next_mult_id_txn(self, txn, n: int) -> List[int]: |
|
|
|
def _load_next_mult_id_txn(self, txn: Cursor, n: int) -> List[int]: |
|
|
|
return self._sequence_gen.get_next_mult_txn(txn, n) |
|
|
|
|
|
|
|
def get_next(self): |
|
|
|
def get_next(self) -> AsyncContextManager[int]: |
|
|
|
""" |
|
|
|
Usage: |
|
|
|
async with stream_id_gen.get_next() as stream_id: |
|
|
@@ -403,9 +428,12 @@ class MultiWriterIdGenerator: |
|
|
|
if self._writers and self._instance_name not in self._writers: |
|
|
|
raise Exception("Tried to allocate stream ID on non-writer") |
|
|
|
|
|
|
|
return _MultiWriterCtxManager(self) |
|
|
|
# Cast safety: the second argument to _MultiWriterCtxManager, multiple_ids, |
|
|
|
# controls the return type. If `None` or omitted, the context manager yields |
|
|
|
# a single integer stream_id; otherwise it yields a list of stream_ids. |
|
|
|
return cast(AsyncContextManager[int], _MultiWriterCtxManager(self)) |
|
|
|
|
|
|
|
def get_next_mult(self, n: int): |
|
|
|
def get_next_mult(self, n: int) -> AsyncContextManager[List[int]]: |
|
|
|
""" |
|
|
|
Usage: |
|
|
|
async with stream_id_gen.get_next_mult(5) as stream_ids: |
|
|
@@ -417,9 +445,10 @@ class MultiWriterIdGenerator: |
|
|
|
if self._writers and self._instance_name not in self._writers: |
|
|
|
raise Exception("Tried to allocate stream ID on non-writer") |
|
|
|
|
|
|
|
return _MultiWriterCtxManager(self, n) |
|
|
|
# Cast safety: see get_next. |
|
|
|
return cast(AsyncContextManager[List[int]], _MultiWriterCtxManager(self, n)) |
|
|
|
|
|
|
|
def get_next_txn(self, txn: LoggingTransaction): |
|
|
|
def get_next_txn(self, txn: LoggingTransaction) -> int: |
|
|
|
""" |
|
|
|
Usage: |
|
|
|
|
|
|
@@ -457,7 +486,7 @@ class MultiWriterIdGenerator: |
|
|
|
|
|
|
|
return self._return_factor * next_id |
|
|
|
|
|
|
|
def _mark_id_as_finished(self, next_id: int): |
|
|
|
def _mark_id_as_finished(self, next_id: int) -> None: |
|
|
|
"""The ID has finished being processed so we should advance the |
|
|
|
current position if possible. |
|
|
|
""" |
|
|
@@ -534,7 +563,7 @@ class MultiWriterIdGenerator: |
|
|
|
for name, i in self._current_positions.items() |
|
|
|
} |
|
|
|
|
|
|
|
def advance(self, instance_name: str, new_id: int): |
|
|
|
def advance(self, instance_name: str, new_id: int) -> None: |
|
|
|
"""Advance the position of the named writer to the given ID, if greater |
|
|
|
than existing entry. |
|
|
|
""" |
|
|
@@ -560,7 +589,7 @@ class MultiWriterIdGenerator: |
|
|
|
with self._lock: |
|
|
|
return self._return_factor * self._persisted_upto_position |
|
|
|
|
|
|
|
def _add_persisted_position(self, new_id: int): |
|
|
|
def _add_persisted_position(self, new_id: int) -> None: |
|
|
|
"""Record that we have persisted a position. |
|
|
|
|
|
|
|
This is used to keep the `_current_positions` up to date. |
|
|
@@ -606,7 +635,7 @@ class MultiWriterIdGenerator: |
|
|
|
# do. |
|
|
|
break |
|
|
|
|
|
|
|
def _update_stream_positions_table_txn(self, txn: Cursor): |
|
|
|
def _update_stream_positions_table_txn(self, txn: Cursor) -> None: |
|
|
|
"""Update the `stream_positions` table with newly persisted position.""" |
|
|
|
|
|
|
|
if not self._writers: |
|
|
@@ -628,20 +657,25 @@ class MultiWriterIdGenerator: |
|
|
|
txn.execute(sql, (self._stream_name, self._instance_name, pos)) |
|
|
|
|
|
|
|
|
|
|
|
@attr.s(slots=True) |
|
|
|
class _AsyncCtxManagerWrapper: |
|
|
|
@attr.s(frozen=True, auto_attribs=True) |
|
|
|
class _AsyncCtxManagerWrapper(Generic[T]): |
|
|
|
"""Helper class to convert a plain context manager to an async one. |
|
|
|
|
|
|
|
This is mainly useful if you have a plain context manager but the interface |
|
|
|
requires an async one. |
|
|
|
""" |
|
|
|
|
|
|
|
inner = attr.ib() |
|
|
|
inner: ContextManager[T] |
|
|
|
|
|
|
|
async def __aenter__(self): |
|
|
|
async def __aenter__(self) -> T: |
|
|
|
return self.inner.__enter__() |
|
|
|
|
|
|
|
async def __aexit__(self, exc_type, exc, tb): |
|
|
|
async def __aexit__( |
|
|
|
self, |
|
|
|
exc_type: Optional[Type[BaseException]], |
|
|
|
exc: Optional[BaseException], |
|
|
|
tb: Optional[TracebackType], |
|
|
|
) -> Optional[bool]: |
|
|
|
return self.inner.__exit__(exc_type, exc, tb) |
|
|
|
|
|
|
|
|
|
|
@@ -671,7 +705,12 @@ class _MultiWriterCtxManager: |
|
|
|
else: |
|
|
|
return [i * self.id_gen._return_factor for i in self.stream_ids] |
|
|
|
|
|
|
|
async def __aexit__(self, exc_type, exc, tb): |
|
|
|
async def __aexit__( |
|
|
|
self, |
|
|
|
exc_type: Optional[Type[BaseException]], |
|
|
|
exc: Optional[BaseException], |
|
|
|
tb: Optional[TracebackType], |
|
|
|
) -> bool: |
|
|
|
for i in self.stream_ids: |
|
|
|
self.id_gen._mark_id_as_finished(i) |
|
|
|
|
|
|
|