@@ -0,0 +1 @@ | |||
Add type hints to storage classes. |
@@ -27,8 +27,6 @@ exclude = (?x) | |||
|synapse/storage/databases/main/__init__.py | |||
|synapse/storage/databases/main/account_data.py | |||
|synapse/storage/databases/main/cache.py | |||
|synapse/storage/databases/main/censor_events.py | |||
|synapse/storage/databases/main/deviceinbox.py | |||
|synapse/storage/databases/main/devices.py | |||
|synapse/storage/databases/main/directory.py | |||
|synapse/storage/databases/main/e2e_room_keys.py | |||
@@ -38,19 +36,15 @@ exclude = (?x) | |||
|synapse/storage/databases/main/events_bg_updates.py | |||
|synapse/storage/databases/main/events_forward_extremities.py | |||
|synapse/storage/databases/main/events_worker.py | |||
|synapse/storage/databases/main/filtering.py | |||
|synapse/storage/databases/main/group_server.py | |||
|synapse/storage/databases/main/lock.py | |||
|synapse/storage/databases/main/media_repository.py | |||
|synapse/storage/databases/main/metrics.py | |||
|synapse/storage/databases/main/monthly_active_users.py | |||
|synapse/storage/databases/main/openid.py | |||
|synapse/storage/databases/main/presence.py | |||
|synapse/storage/databases/main/profile.py | |||
|synapse/storage/databases/main/purge_events.py | |||
|synapse/storage/databases/main/push_rule.py | |||
|synapse/storage/databases/main/receipts.py | |||
|synapse/storage/databases/main/rejections.py | |||
|synapse/storage/databases/main/room.py | |||
|synapse/storage/databases/main/room_batch.py | |||
|synapse/storage/databases/main/roommember.py | |||
@@ -59,7 +53,6 @@ exclude = (?x) | |||
|synapse/storage/databases/main/state.py | |||
|synapse/storage/databases/main/state_deltas.py | |||
|synapse/storage/databases/main/stats.py | |||
|synapse/storage/databases/main/tags.py | |||
|synapse/storage/databases/main/transactions.py | |||
|synapse/storage/databases/main/user_directory.py | |||
|synapse/storage/databases/main/user_erasure_store.py | |||
@@ -13,12 +13,12 @@ | |||
# limitations under the License. | |||
import logging | |||
from typing import TYPE_CHECKING | |||
from typing import TYPE_CHECKING, Optional | |||
from synapse.events.utils import prune_event_dict | |||
from synapse.metrics.background_process_metrics import wrap_as_background_process | |||
from synapse.storage._base import SQLBaseStore | |||
from synapse.storage.database import DatabasePool | |||
from synapse.storage.database import DatabasePool, LoggingTransaction | |||
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore | |||
from synapse.storage.databases.main.events_worker import EventsWorkerStore | |||
from synapse.util import json_encoder | |||
@@ -41,7 +41,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase | |||
hs.get_clock().looping_call(self._censor_redactions, 5 * 60 * 1000) | |||
@wrap_as_background_process("_censor_redactions") | |||
async def _censor_redactions(self): | |||
async def _censor_redactions(self) -> None: | |||
"""Censors all redactions older than the configured period that haven't | |||
been censored yet. | |||
@@ -105,7 +105,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase | |||
and original_event.internal_metadata.is_redacted() | |||
): | |||
# Redaction was allowed | |||
pruned_json = json_encoder.encode( | |||
pruned_json: Optional[str] = json_encoder.encode( | |||
prune_event_dict( | |||
original_event.room_version, original_event.get_dict() | |||
) | |||
@@ -116,7 +116,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase | |||
updates.append((redaction_id, event_id, pruned_json)) | |||
def _update_censor_txn(txn): | |||
def _update_censor_txn(txn: LoggingTransaction) -> None: | |||
for redaction_id, event_id, pruned_json in updates: | |||
if pruned_json: | |||
self._censor_event_txn(txn, event_id, pruned_json) | |||
@@ -130,14 +130,16 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase | |||
await self.db_pool.runInteraction("_update_censor_txn", _update_censor_txn) | |||
def _censor_event_txn(self, txn, event_id, pruned_json): | |||
def _censor_event_txn( | |||
self, txn: LoggingTransaction, event_id: str, pruned_json: str | |||
) -> None: | |||
"""Censor an event by replacing its JSON in the event_json table with the | |||
provided pruned JSON. | |||
Args: | |||
txn (LoggingTransaction): The database transaction. | |||
event_id (str): The ID of the event to censor. | |||
pruned_json (str): The pruned JSON | |||
txn: The database transaction. | |||
event_id: The ID of the event to censor. | |||
pruned_json: The pruned JSON | |||
""" | |||
self.db_pool.simple_update_one_txn( | |||
txn, | |||
@@ -157,7 +159,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase | |||
# Try to retrieve the event's content from the database or the event cache. | |||
event = await self.get_event(event_id) | |||
def delete_expired_event_txn(txn): | |||
def delete_expired_event_txn(txn: LoggingTransaction) -> None: | |||
# Delete the expiry timestamp associated with this event from the database. | |||
self._delete_event_expiry_txn(txn, event_id) | |||
@@ -194,14 +196,14 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase | |||
"delete_expired_event", delete_expired_event_txn | |||
) | |||
def _delete_event_expiry_txn(self, txn, event_id): | |||
def _delete_event_expiry_txn(self, txn: LoggingTransaction, event_id: str) -> None: | |||
"""Delete the expiry timestamp associated with an event ID without deleting the | |||
actual event. | |||
Args: | |||
txn (LoggingTransaction): The transaction to use to perform the deletion. | |||
event_id (str): The event ID to delete the associated expiry timestamp of. | |||
txn: The transaction to use to perform the deletion. | |||
event_id: The event ID to delete the associated expiry timestamp of. | |||
""" | |||
return self.db_pool.simple_delete_txn( | |||
self.db_pool.simple_delete_txn( | |||
txn=txn, table="event_expiry", keyvalues={"event_id": event_id} | |||
) |
@@ -1,4 +1,5 @@ | |||
# Copyright 2016 OpenMarket Ltd | |||
# Copyright 2021 The Matrix.org Foundation C.I.C. | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
@@ -19,9 +20,17 @@ from synapse.logging import issue9533_logger | |||
from synapse.logging.opentracing import log_kv, set_tag, trace | |||
from synapse.replication.tcp.streams import ToDeviceStream | |||
from synapse.storage._base import SQLBaseStore, db_to_json | |||
from synapse.storage.database import DatabasePool, LoggingTransaction | |||
from synapse.storage.database import ( | |||
DatabasePool, | |||
LoggingDatabaseConnection, | |||
LoggingTransaction, | |||
) | |||
from synapse.storage.engines import PostgresEngine | |||
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator | |||
from synapse.storage.util.id_generators import ( | |||
AbstractStreamIdGenerator, | |||
MultiWriterIdGenerator, | |||
StreamIdGenerator, | |||
) | |||
from synapse.types import JsonDict | |||
from synapse.util import json_encoder | |||
from synapse.util.caches.expiringcache import ExpiringCache | |||
@@ -34,14 +43,21 @@ logger = logging.getLogger(__name__) | |||
class DeviceInboxWorkerStore(SQLBaseStore): | |||
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): | |||
def __init__( | |||
self, | |||
database: DatabasePool, | |||
db_conn: LoggingDatabaseConnection, | |||
hs: "HomeServer", | |||
): | |||
super().__init__(database, db_conn, hs) | |||
self._instance_name = hs.get_instance_name() | |||
# Map of (user_id, device_id) to the last stream_id that has been | |||
# deleted up to. This is so that we can no op deletions. | |||
self._last_device_delete_cache = ExpiringCache( | |||
self._last_device_delete_cache: ExpiringCache[ | |||
Tuple[str, Optional[str]], int | |||
] = ExpiringCache( | |||
cache_name="last_device_delete_cache", | |||
clock=self._clock, | |||
max_len=10000, | |||
@@ -53,14 +69,16 @@ class DeviceInboxWorkerStore(SQLBaseStore): | |||
self._instance_name in hs.config.worker.writers.to_device | |||
) | |||
self._device_inbox_id_gen = MultiWriterIdGenerator( | |||
db_conn=db_conn, | |||
db=database, | |||
stream_name="to_device", | |||
instance_name=self._instance_name, | |||
tables=[("device_inbox", "instance_name", "stream_id")], | |||
sequence_name="device_inbox_sequence", | |||
writers=hs.config.worker.writers.to_device, | |||
self._device_inbox_id_gen: AbstractStreamIdGenerator = ( | |||
MultiWriterIdGenerator( | |||
db_conn=db_conn, | |||
db=database, | |||
stream_name="to_device", | |||
instance_name=self._instance_name, | |||
tables=[("device_inbox", "instance_name", "stream_id")], | |||
sequence_name="device_inbox_sequence", | |||
writers=hs.config.worker.writers.to_device, | |||
) | |||
) | |||
else: | |||
self._can_write_to_device = True | |||
@@ -101,6 +119,8 @@ class DeviceInboxWorkerStore(SQLBaseStore): | |||
def process_replication_rows(self, stream_name, instance_name, token, rows): | |||
if stream_name == ToDeviceStream.NAME: | |||
# If replication is happening than postgres must be being used. | |||
assert isinstance(self._device_inbox_id_gen, MultiWriterIdGenerator) | |||
self._device_inbox_id_gen.advance(instance_name, token) | |||
for row in rows: | |||
if row.entity.startswith("@"): | |||
@@ -220,11 +240,11 @@ class DeviceInboxWorkerStore(SQLBaseStore): | |||
log_kv({"message": f"deleted {count} messages for device", "count": count}) | |||
# Update the cache, ensuring that we only ever increase the value | |||
last_deleted_stream_id = self._last_device_delete_cache.get( | |||
updated_last_deleted_stream_id = self._last_device_delete_cache.get( | |||
(user_id, device_id), 0 | |||
) | |||
self._last_device_delete_cache[(user_id, device_id)] = max( | |||
last_deleted_stream_id, up_to_stream_id | |||
updated_last_deleted_stream_id, up_to_stream_id | |||
) | |||
return count | |||
@@ -432,7 +452,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): | |||
) | |||
async with self._device_inbox_id_gen.get_next() as stream_id: | |||
now_ms = self.clock.time_msec() | |||
now_ms = self._clock.time_msec() | |||
await self.db_pool.runInteraction( | |||
"add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id | |||
) | |||
@@ -483,7 +503,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): | |||
) | |||
async with self._device_inbox_id_gen.get_next() as stream_id: | |||
now_ms = self.clock.time_msec() | |||
now_ms = self._clock.time_msec() | |||
await self.db_pool.runInteraction( | |||
"add_messages_from_remote_to_device_inbox", | |||
add_messages_txn, | |||
@@ -1,4 +1,5 @@ | |||
# Copyright 2015, 2016 OpenMarket Ltd | |||
# Copyright 2021 The Matrix.org Foundation C.I.C. | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
@@ -18,6 +19,7 @@ from canonicaljson import encode_canonical_json | |||
from synapse.api.errors import Codes, SynapseError | |||
from synapse.storage._base import SQLBaseStore, db_to_json | |||
from synapse.storage.database import LoggingTransaction | |||
from synapse.types import JsonDict | |||
from synapse.util.caches.descriptors import cached | |||
@@ -49,7 +51,7 @@ class FilteringStore(SQLBaseStore): | |||
# Need an atomic transaction to SELECT the maximal ID so far then | |||
# INSERT a new one | |||
def _do_txn(txn): | |||
def _do_txn(txn: LoggingTransaction) -> int: | |||
sql = ( | |||
"SELECT filter_id FROM user_filters " | |||
"WHERE user_id = ? AND filter_json = ?" | |||
@@ -61,7 +63,7 @@ class FilteringStore(SQLBaseStore): | |||
sql = "SELECT MAX(filter_id) FROM user_filters WHERE user_id = ?" | |||
txn.execute(sql, (user_localpart,)) | |||
max_id = txn.fetchone()[0] | |||
max_id = txn.fetchone()[0] # type: ignore[index] | |||
if max_id is None: | |||
filter_id = 0 | |||
else: | |||
@@ -13,7 +13,7 @@ | |||
# limitations under the License. | |||
import logging | |||
from types import TracebackType | |||
from typing import TYPE_CHECKING, Dict, Optional, Tuple, Type | |||
from typing import TYPE_CHECKING, Optional, Tuple, Type | |||
from weakref import WeakValueDictionary | |||
from twisted.internet.interfaces import IReactorCore | |||
@@ -62,7 +62,9 @@ class LockStore(SQLBaseStore): | |||
# A map from `(lock_name, lock_key)` to the token of any locks that we | |||
# think we currently hold. | |||
self._live_tokens: Dict[Tuple[str, str], Lock] = WeakValueDictionary() | |||
self._live_tokens: WeakValueDictionary[ | |||
Tuple[str, str], Lock | |||
] = WeakValueDictionary() | |||
# When we shut down we want to remove the locks. Technically this can | |||
# lead to a race, as we may drop the lock while we are still processing. | |||
@@ -1,6 +1,21 @@ | |||
# Copyright 2019-2021 The Matrix.org Foundation C.I.C. | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
from typing import Optional | |||
from synapse.storage._base import SQLBaseStore | |||
from synapse.storage.database import LoggingTransaction | |||
class OpenIdStore(SQLBaseStore): | |||
@@ -20,7 +35,7 @@ class OpenIdStore(SQLBaseStore): | |||
async def get_user_id_for_open_id_token( | |||
self, token: str, ts_now_ms: int | |||
) -> Optional[str]: | |||
def get_user_id_for_token_txn(txn): | |||
def get_user_id_for_token_txn(txn: LoggingTransaction) -> Optional[str]: | |||
sql = ( | |||
"SELECT user_id FROM open_id_tokens" | |||
" WHERE token = ? AND ? <= ts_valid_until_ms" | |||
@@ -1,5 +1,6 @@ | |||
# Copyright 2014-2016 OpenMarket Ltd | |||
# Copyright 2018 New Vector Ltd | |||
# Copyright 2021 The Matrix.org Foundation C.I.C. | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
@@ -14,9 +15,10 @@ | |||
# limitations under the License. | |||
import logging | |||
from typing import Dict, List, Tuple | |||
from typing import Dict, List, Tuple, cast | |||
from synapse.storage._base import db_to_json | |||
from synapse.storage.database import LoggingTransaction | |||
from synapse.storage.databases.main.account_data import AccountDataWorkerStore | |||
from synapse.types import JsonDict | |||
from synapse.util import json_encoder | |||
@@ -50,7 +52,7 @@ class TagsWorkerStore(AccountDataWorkerStore): | |||
async def get_all_updated_tags( | |||
self, instance_name: str, last_id: int, current_id: int, limit: int | |||
) -> Tuple[List[Tuple[int, tuple]], int, bool]: | |||
) -> Tuple[List[Tuple[int, Tuple[str, str, str]]], int, bool]: | |||
"""Get updates for tags replication stream. | |||
Args: | |||
@@ -75,7 +77,9 @@ class TagsWorkerStore(AccountDataWorkerStore): | |||
if last_id == current_id: | |||
return [], current_id, False | |||
def get_all_updated_tags_txn(txn): | |||
def get_all_updated_tags_txn( | |||
txn: LoggingTransaction, | |||
) -> List[Tuple[int, str, str]]: | |||
sql = ( | |||
"SELECT stream_id, user_id, room_id" | |||
" FROM room_tags_revisions as r" | |||
@@ -83,13 +87,16 @@ class TagsWorkerStore(AccountDataWorkerStore): | |||
" ORDER BY stream_id ASC LIMIT ?" | |||
) | |||
txn.execute(sql, (last_id, current_id, limit)) | |||
return txn.fetchall() | |||
# mypy doesn't understand what the query is selecting. | |||
return cast(List[Tuple[int, str, str]], txn.fetchall()) | |||
tag_ids = await self.db_pool.runInteraction( | |||
"get_all_updated_tags", get_all_updated_tags_txn | |||
) | |||
def get_tag_content(txn, tag_ids): | |||
def get_tag_content( | |||
txn: LoggingTransaction, tag_ids | |||
) -> List[Tuple[int, Tuple[str, str, str]]]: | |||
sql = "SELECT tag, content FROM room_tags WHERE user_id=? AND room_id=?" | |||
results = [] | |||
for stream_id, user_id, room_id in tag_ids: | |||
@@ -127,15 +134,15 @@ class TagsWorkerStore(AccountDataWorkerStore): | |||
given version | |||
Args: | |||
user_id(str): The user to get the tags for. | |||
stream_id(int): The earliest update to get for the user. | |||
user_id: The user to get the tags for. | |||
stream_id: The earliest update to get for the user. | |||
Returns: | |||
A mapping from room_id strings to lists of tag strings for all the | |||
rooms that changed since the stream_id token. | |||
""" | |||
def get_updated_tags_txn(txn): | |||
def get_updated_tags_txn(txn: LoggingTransaction) -> List[str]: | |||
sql = ( | |||
"SELECT room_id from room_tags_revisions" | |||
" WHERE user_id = ? AND stream_id > ?" | |||
@@ -200,7 +207,7 @@ class TagsWorkerStore(AccountDataWorkerStore): | |||
content_json = json_encoder.encode(content) | |||
def add_tag_txn(txn, next_id): | |||
def add_tag_txn(txn: LoggingTransaction, next_id: int) -> None: | |||
self.db_pool.simple_upsert_txn( | |||
txn, | |||
table="room_tags", | |||
@@ -224,7 +231,7 @@ class TagsWorkerStore(AccountDataWorkerStore): | |||
""" | |||
assert self._can_write_to_account_data | |||
def remove_tag_txn(txn, next_id): | |||
def remove_tag_txn(txn: LoggingTransaction, next_id: int) -> None: | |||
sql = ( | |||
"DELETE FROM room_tags " | |||
" WHERE user_id = ? AND room_id = ? AND tag = ?" | |||
@@ -1,4 +1,5 @@ | |||
# Copyright 2014-2016 OpenMarket Ltd | |||
# Copyright 2021 The Matrix.org Foundation C.I.C. | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
@@ -11,6 +12,7 @@ | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
import abc | |||
import heapq | |||
import logging | |||
import threading | |||
@@ -87,7 +89,25 @@ def _load_current_id( | |||
return (max if step > 0 else min)(current_id, step) | |||
class StreamIdGenerator: | |||
class AbstractStreamIdGenerator(metaclass=abc.ABCMeta): | |||
@abc.abstractmethod | |||
def get_next(self) -> AsyncContextManager[int]: | |||
raise NotImplementedError() | |||
@abc.abstractmethod | |||
def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]: | |||
raise NotImplementedError() | |||
@abc.abstractmethod | |||
def get_current_token(self) -> int: | |||
raise NotImplementedError() | |||
@abc.abstractmethod | |||
def get_current_token_for_writer(self, instance_name: str) -> int: | |||
raise NotImplementedError() | |||
class StreamIdGenerator(AbstractStreamIdGenerator): | |||
"""Used to generate new stream ids when persisting events while keeping | |||
track of which transactions have been completed. | |||
@@ -209,7 +229,7 @@ class StreamIdGenerator: | |||
return self.get_current_token() | |||
class MultiWriterIdGenerator: | |||
class MultiWriterIdGenerator(AbstractStreamIdGenerator): | |||
"""An ID generator that tracks a stream that can have multiple writers. | |||
Uses a Postgres sequence to coordinate ID assignment, but positions of other | |||