
Annotate synapse.storage.util (#10892)

Also mark `synapse.streams` as having no untyped defs

Co-authored-by: Sean Quah <8349537+squahtx@users.noreply.github.com>
tags/v1.45.0rc1
David Robertson 2 years ago
committed by GitHub
Parent
Commit
51a5da74cc
No GPG key could be found for this signature. GPG key ID: 4AEE18F83AFDEB23
8 changed files with 124 additions and 65 deletions
  1. changelog.d/10892.misc (+1 -0)
  2. mypy.ini (+6 -0)
  3. synapse/replication/slave/storage/_slaved_id_tracker.py (+2 -2)
  4. synapse/replication/slave/storage/pushers.py (+7 -3)
  5. synapse/storage/databases/main/pusher.py (+7 -3)
  6. synapse/storage/databases/main/registration.py (+7 -2)
  7. synapse/storage/util/id_generators.py (+91 -52)
  8. synapse/storage/util/sequence.py (+3 -3)

changelog.d/10892.misc (+1 -0)

@@ -0,0 +1 @@
+Add further type hints to `synapse.storage.util`.

mypy.ini (+6 -0)

@@ -105,6 +105,12 @@ disallow_untyped_defs = True
 [mypy-synapse.state.*]
 disallow_untyped_defs = True
 
+[mypy-synapse.storage.util.*]
+disallow_untyped_defs = True
+
+[mypy-synapse.streams.*]
+disallow_untyped_defs = True
+
 [mypy-synapse.util.batching_queue]
 disallow_untyped_defs = True

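For context, `disallow_untyped_defs = True` makes mypy reject any function in the matched modules whose parameters or return type are unannotated. A minimal illustration (hypothetical module, not part of this commit):

def bad(x):  # error: Function is missing a type annotation
    return x + 1

def good(x: int) -> int:  # fully annotated, so accepted
    return x + 1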

synapse/replication/slave/storage/_slaved_id_tracker.py (+2 -2)

@@ -13,14 +13,14 @@
 # limitations under the License.
 from typing import List, Optional, Tuple
 
-from synapse.storage.types import Connection
+from synapse.storage.database import LoggingDatabaseConnection
 from synapse.storage.util.id_generators import _load_current_id
 
 
 class SlavedIdTracker:
     def __init__(
         self,
-        db_conn: Connection,
+        db_conn: LoggingDatabaseConnection,
         table: str,
         column: str,
         extra_tables: Optional[List[Tuple[str, str]]] = None,

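The switch from `Connection` to `LoggingDatabaseConnection` is not cosmetic: `_load_current_id` calls `db_conn.cursor(txn_name="_load_current_id")`, and the plain DB-API `Connection` protocol has no `txn_name` parameter, so only the wrapper type checks. A rough sketch of the idea, using hypothetical names rather than Synapse's actual implementation:

from typing import Any, Optional

class LoggingConnectionSketch:
    """Hypothetical stand-in for LoggingDatabaseConnection: a wrapper
    whose cursor() accepts a txn_name used for query logging."""

    def __init__(self, conn: Any) -> None:
        self._conn = conn

    def cursor(self, txn_name: Optional[str] = None) -> Any:
        # A real implementation would wrap the returned cursor and log
        # queries under txn_name; this sketch merely delegates.
        return self._conn.cursor()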

synapse/replication/slave/storage/pushers.py (+7 -3)

@@ -15,9 +15,8 @@
 from typing import TYPE_CHECKING
 
 from synapse.replication.tcp.streams import PushersStream
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.databases.main.pusher import PusherWorkerStore
-from synapse.storage.types import Connection
 
 from ._base import BaseSlavedStore
 from ._slaved_id_tracker import SlavedIdTracker
@@ -27,7 +26,12 @@ if TYPE_CHECKING:
 
 
 class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
-    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
         self._pushers_id_gen = SlavedIdTracker(  # type: ignore
             db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]


synapse/storage/databases/main/pusher.py (+7 -3)

@@ -18,8 +18,7 @@ from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Optional,
 
 from synapse.push import PusherConfig, ThrottleParams
 from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import DatabasePool
-from synapse.storage.types import Connection
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.util.id_generators import StreamIdGenerator
 from synapse.types import JsonDict
 from synapse.util import json_encoder
@@ -32,7 +31,12 @@ logger = logging.getLogger(__name__)
 
 
 class PusherWorkerStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
         self._pushers_id_gen = StreamIdGenerator(
             db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]

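The `StreamIdGenerator` constructed here is used as an async context manager when persisting pushers; roughly, following the `get_next` docstring in `id_generators.py` below (the method and helper names here are hypothetical):

async def add_pusher_sketch(self) -> None:
    async with self._pushers_id_gen.get_next() as stream_id:
        # The new ID only becomes "finished" once this block exits, so
        # readers never see a gap that might still be filled in.
        await self._store_pusher_row(stream_id)  # hypothetical helper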

synapse/storage/databases/main/registration.py (+7 -2)

@@ -26,7 +26,7 @@ from synapse.metrics.background_process_metrics import wrap_as_background_proces
 from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
 from synapse.storage.databases.main.stats import StatsStore
-from synapse.storage.types import Connection, Cursor
+from synapse.storage.types import Cursor
 from synapse.storage.util.id_generators import IdGenerator
 from synapse.storage.util.sequence import build_sequence_generator
 from synapse.types import UserID, UserInfo
@@ -1775,7 +1775,12 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
 
 
 class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
-    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         self._ignore_unknown_session_error = (


synapse/storage/util/id_generators.py (+91 -52)

@@ -16,42 +16,62 @@ import logging
 import threading
 from collections import OrderedDict
 from contextlib import contextmanager
-from typing import Dict, Iterable, List, Optional, Set, Tuple, Union
+from types import TracebackType
+from typing import (
+    AsyncContextManager,
+    ContextManager,
+    Dict,
+    Generator,
+    Generic,
+    Iterable,
+    List,
+    Optional,
+    Sequence,
+    Set,
+    Tuple,
+    Type,
+    TypeVar,
+    Union,
+    cast,
+)
 
 import attr
 from sortedcontainers import SortedSet
 
 from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
 from synapse.storage.types import Cursor
 from synapse.storage.util.sequence import PostgresSequenceGenerator
 
 logger = logging.getLogger(__name__)
 
 
+T = TypeVar("T")
+
+
 class IdGenerator:
-    def __init__(self, db_conn, table, column):
+    def __init__(
+        self,
+        db_conn: LoggingDatabaseConnection,
+        table: str,
+        column: str,
+    ):
         self._lock = threading.Lock()
         self._next_id = _load_current_id(db_conn, table, column)
 
-    def get_next(self):
+    def get_next(self) -> int:
         with self._lock:
             self._next_id += 1
             return self._next_id
 
 
-def _load_current_id(db_conn, table, column, step=1):
-    """
-
-    Args:
-        db_conn (object):
-        table (str):
-        column (str):
-        step (int):
-
-    Returns:
-        int
-    """
+def _load_current_id(
+    db_conn: LoggingDatabaseConnection, table: str, column: str, step: int = 1
+) -> int:
     # debug logging for https://github.com/matrix-org/synapse/issues/7968
     logger.info("initialising stream generator for %s(%s)", table, column)
     cur = db_conn.cursor(txn_name="_load_current_id")
@@ -59,7 +79,9 @@ def _load_current_id(db_conn, table, column, step=1):
cur.execute("SELECT MAX(%s) FROM %s" % (column, table))
else:
cur.execute("SELECT MIN(%s) FROM %s" % (column, table))
(val,) = cur.fetchone()
result = cur.fetchone()
assert result is not None
(val,) = result
cur.close()
current_id = int(val) if val else step
return (max if step > 0 else min)(current_id, step)
@@ -93,16 +115,16 @@ class StreamIdGenerator:

     def __init__(
         self,
-        db_conn,
-        table,
-        column,
+        db_conn: LoggingDatabaseConnection,
+        table: str,
+        column: str,
         extra_tables: Iterable[Tuple[str, str]] = (),
-        step=1,
-    ):
+        step: int = 1,
+    ) -> None:
         assert step != 0
         self._lock = threading.Lock()
-        self._step = step
-        self._current = _load_current_id(db_conn, table, column, step)
+        self._step: int = step
+        self._current: int = _load_current_id(db_conn, table, column, step)
         for table, column in extra_tables:
             self._current = (max if step > 0 else min)(
                 self._current, _load_current_id(db_conn, table, column, step)
@@ -115,7 +137,7 @@ class StreamIdGenerator:
         # The key and values are the same, but we never look at the values.
         self._unfinished_ids: OrderedDict[int, int] = OrderedDict()
 
-    def get_next(self):
+    def get_next(self) -> AsyncContextManager[int]:
         """
         Usage:
             async with stream_id_gen.get_next() as stream_id:
@@ -128,7 +150,7 @@ class StreamIdGenerator:
             self._unfinished_ids[next_id] = next_id
 
         @contextmanager
-        def manager():
+        def manager() -> Generator[int, None, None]:
             try:
                 yield next_id
             finally:
@@ -137,7 +159,7 @@ class StreamIdGenerator:

         return _AsyncCtxManagerWrapper(manager())
 
-    def get_next_mult(self, n):
+    def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]:
         """
         Usage:
             async with stream_id_gen.get_next(n) as stream_ids:
@@ -155,7 +177,7 @@ class StreamIdGenerator:
             self._unfinished_ids[next_id] = next_id
 
         @contextmanager
-        def manager():
+        def manager() -> Generator[Sequence[int], None, None]:
             try:
                 yield next_ids
             finally:
@@ -215,7 +237,7 @@ class MultiWriterIdGenerator:

     def __init__(
         self,
-        db_conn,
+        db_conn: LoggingDatabaseConnection,
         db: DatabasePool,
         stream_name: str,
         instance_name: str,
@@ -223,7 +245,7 @@ class MultiWriterIdGenerator:
         sequence_name: str,
         writers: List[str],
         positive: bool = True,
-    ):
+    ) -> None:
         self._db = db
         self._stream_name = stream_name
         self._instance_name = instance_name
@@ -285,9 +307,9 @@ class MultiWriterIdGenerator:

     def _load_current_ids(
         self,
-        db_conn,
+        db_conn: LoggingDatabaseConnection,
         tables: List[Tuple[str, str, str]],
-    ):
+    ) -> None:
         cur = db_conn.cursor(txn_name="_load_current_ids")
 
         # Load the current positions of all writers for the stream.
@@ -335,7 +357,9 @@ class MultiWriterIdGenerator:
"agg": "MAX" if self._positive else "-MIN",
}
cur.execute(sql)
(stream_id,) = cur.fetchone()
result = cur.fetchone()
assert result is not None
(stream_id,) = result

max_stream_id = max(max_stream_id, stream_id)

@@ -354,7 +378,7 @@ class MultiWriterIdGenerator:

             self._persisted_upto_position = min_stream_id
 
-            rows = []
+            rows: List[Tuple[str, int]] = []
             for table, instance_column, id_column in tables:
                 sql = """
                     SELECT %(instance)s, %(id)s FROM %(table)s
@@ -367,7 +391,8 @@ class MultiWriterIdGenerator:
                 }
                 cur.execute(sql, (min_stream_id * self._return_factor,))
 
-                rows.extend(cur)
+                # Cast safety: this corresponds to the types returned by the query above.
+                rows.extend(cast(Iterable[Tuple[str, int]], cur))
 
             # Sort so that we handle rows in order for each instance.
             rows.sort()
@@ -385,13 +410,13 @@ class MultiWriterIdGenerator:

         cur.close()
 
-    def _load_next_id_txn(self, txn) -> int:
+    def _load_next_id_txn(self, txn: Cursor) -> int:
         return self._sequence_gen.get_next_id_txn(txn)
 
-    def _load_next_mult_id_txn(self, txn, n: int) -> List[int]:
+    def _load_next_mult_id_txn(self, txn: Cursor, n: int) -> List[int]:
         return self._sequence_gen.get_next_mult_txn(txn, n)
 
-    def get_next(self):
+    def get_next(self) -> AsyncContextManager[int]:
         """
         Usage:
             async with stream_id_gen.get_next() as stream_id:
@@ -403,9 +428,12 @@ class MultiWriterIdGenerator:
         if self._writers and self._instance_name not in self._writers:
             raise Exception("Tried to allocate stream ID on non-writer")
 
-        return _MultiWriterCtxManager(self)
+        # Cast safety: the second argument to _MultiWriterCtxManager, multiple_ids,
+        # controls the return type. If `None` or omitted, the context manager yields
+        # a single integer stream_id; otherwise it yields a list of stream_ids.
+        return cast(AsyncContextManager[int], _MultiWriterCtxManager(self))
 
-    def get_next_mult(self, n: int):
+    def get_next_mult(self, n: int) -> AsyncContextManager[List[int]]:
         """
         Usage:
             async with stream_id_gen.get_next_mult(5) as stream_ids:
@@ -417,9 +445,10 @@ class MultiWriterIdGenerator:
         if self._writers and self._instance_name not in self._writers:
             raise Exception("Tried to allocate stream ID on non-writer")
 
-        return _MultiWriterCtxManager(self, n)
+        # Cast safety: see get_next.
+        return cast(AsyncContextManager[List[int]], _MultiWriterCtxManager(self, n))
 
-    def get_next_txn(self, txn: LoggingTransaction):
+    def get_next_txn(self, txn: LoggingTransaction) -> int:
         """
         Usage:

@@ -457,7 +486,7 @@ class MultiWriterIdGenerator:

         return self._return_factor * next_id
 
-    def _mark_id_as_finished(self, next_id: int):
+    def _mark_id_as_finished(self, next_id: int) -> None:
         """The ID has finished being processed so we should advance the
         current position if possible.
         """
@@ -534,7 +563,7 @@ class MultiWriterIdGenerator:
             for name, i in self._current_positions.items()
         }
 
-    def advance(self, instance_name: str, new_id: int):
+    def advance(self, instance_name: str, new_id: int) -> None:
         """Advance the position of the named writer to the given ID, if greater
         than existing entry.
         """
@@ -560,7 +589,7 @@ class MultiWriterIdGenerator:
         with self._lock:
             return self._return_factor * self._persisted_upto_position
 
-    def _add_persisted_position(self, new_id: int):
+    def _add_persisted_position(self, new_id: int) -> None:
         """Record that we have persisted a position.
 
         This is used to keep the `_current_positions` up to date.
@@ -606,7 +635,7 @@ class MultiWriterIdGenerator:
                 # do.
                 break
 
-    def _update_stream_positions_table_txn(self, txn: Cursor):
+    def _update_stream_positions_table_txn(self, txn: Cursor) -> None:
         """Update the `stream_positions` table with newly persisted position."""
 
         if not self._writers:
@@ -628,20 +657,25 @@ class MultiWriterIdGenerator:
         txn.execute(sql, (self._stream_name, self._instance_name, pos))
 
 
-@attr.s(slots=True)
-class _AsyncCtxManagerWrapper:
+@attr.s(frozen=True, auto_attribs=True)
+class _AsyncCtxManagerWrapper(Generic[T]):
     """Helper class to convert a plain context manager to an async one.
 
     This is mainly useful if you have a plain context manager but the interface
     requires an async one.
     """
 
-    inner = attr.ib()
+    inner: ContextManager[T]
 
-    async def __aenter__(self):
+    async def __aenter__(self) -> T:
         return self.inner.__enter__()
 
-    async def __aexit__(self, exc_type, exc, tb):
+    async def __aexit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc: Optional[BaseException],
+        tb: Optional[TracebackType],
+    ) -> Optional[bool]:
         return self.inner.__exit__(exc_type, exc, tb)


@@ -671,7 +705,12 @@ class _MultiWriterCtxManager:
         else:
             return [i * self.id_gen._return_factor for i in self.stream_ids]
 
-    async def __aexit__(self, exc_type, exc, tb):
+    async def __aexit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc: Optional[BaseException],
+        tb: Optional[TracebackType],
+    ) -> bool:
         for i in self.stream_ids:
             self.id_gen._mark_id_as_finished(i)

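The recurring `result = cur.fetchone()` / `assert result is not None` pattern above is the usual way to satisfy strict mypy here: `fetchone()` is typed as returning an optional row, so unpacking it directly fails to type-check, while the assert both narrows the type and documents that these queries always return a row. In isolation (illustrative function, not from the codebase):

from typing import Any

def first_value(cur: Any) -> Any:
    # (val,) = cur.fetchone() would be rejected under strict typing,
    # because fetchone() may return None.
    result = cur.fetchone()
    assert result is not None  # this query always yields one row
    (val,) = result
    return val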

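The `cast(...)` calls in `get_next` and `get_next_mult` are worth noting: `_MultiWriterCtxManager` yields an `int` or a `List[int]` depending on a constructor argument, which a single class annotation cannot express, so the cast records an invariant that mypy does not verify. In general `cast` is an unchecked assertion (illustrative sketch, not Synapse code):

from typing import List, Union, cast

def narrow_ids(ids: Union[int, List[int]]) -> List[int]:
    # The caller knows which branch of the union applies; cast() tells
    # mypy but performs no runtime check, so correctness rests on the
    # documented invariant, just like the "Cast safety" comments above.
    return cast(List[int], ids)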

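`_AsyncCtxManagerWrapper`, now generic in `T`, is a reusable trick: it adapts a synchronous context manager to the `async with` protocol without awaiting anything. A self-contained sketch of the same pattern (assuming nothing from Synapse):

import asyncio
from contextlib import contextmanager
from types import TracebackType
from typing import ContextManager, Generator, Generic, Optional, Type, TypeVar

T = TypeVar("T")

class AsyncWrapperSketch(Generic[T]):
    """Wraps a plain context manager so it works with `async with`."""

    def __init__(self, inner: ContextManager[T]) -> None:
        self.inner = inner

    async def __aenter__(self) -> T:
        return self.inner.__enter__()

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc: Optional[BaseException],
        tb: Optional[TracebackType],
    ) -> Optional[bool]:
        return self.inner.__exit__(exc_type, exc, tb)

@contextmanager
def numbered() -> Generator[int, None, None]:
    yield 42

async def main() -> None:
    async with AsyncWrapperSketch(numbered()) as n:
        print(n)  # prints 42

asyncio.run(main())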
synapse/storage/util/sequence.py (+3 -3)

@@ -81,7 +81,7 @@ class SequenceGenerator(metaclass=abc.ABCMeta):
         id_column: str,
         stream_name: Optional[str] = None,
         positive: bool = True,
-    ):
+    ) -> None:
         """Should be called during start up to test that the current value of
         the sequence is greater than or equal to the maximum ID in the table.

@@ -122,7 +122,7 @@ class PostgresSequenceGenerator(SequenceGenerator):
         id_column: str,
         stream_name: Optional[str] = None,
         positive: bool = True,
-    ):
+    ) -> None:
         """See SequenceGenerator.check_consistency for docstring."""
 
         txn = db_conn.cursor(txn_name="sequence.check_consistency")
@@ -244,7 +244,7 @@ class LocalSequenceGenerator(SequenceGenerator):
         id_column: str,
         stream_name: Optional[str] = None,
         positive: bool = True,
-    ):
+    ) -> None:
         # There is nothing to do for in memory sequences
         pass


