@@ -0,0 +1 @@ | |||
Add some type hints to datastore. |
@@ -15,12 +15,17 @@ | |||
# limitations under the License. | |||
import logging | |||
from typing import TYPE_CHECKING, List, Optional, Tuple | |||
from typing import TYPE_CHECKING, List, Optional, Tuple, cast | |||
from synapse.config.homeserver import HomeServerConfig | |||
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection | |||
from synapse.storage.database import ( | |||
DatabasePool, | |||
LoggingDatabaseConnection, | |||
LoggingTransaction, | |||
) | |||
from synapse.storage.databases.main.stats import UserSortOrder | |||
from synapse.storage.engines import PostgresEngine | |||
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine | |||
from synapse.storage.types import Cursor | |||
from synapse.storage.util.id_generators import ( | |||
IdGenerator, | |||
MultiWriterIdGenerator, | |||
@@ -266,7 +271,9 @@ class DataStore( | |||
A tuple of a list of mappings from user to information and a count of total users. | |||
""" | |||
def get_users_paginate_txn(txn): | |||
def get_users_paginate_txn( | |||
txn: LoggingTransaction, | |||
) -> Tuple[List[JsonDict], int]: | |||
filters = [] | |||
args = [self.hs.config.server.server_name] | |||
@@ -301,7 +308,7 @@ class DataStore( | |||
""" | |||
sql = "SELECT COUNT(*) as total_users " + sql_base | |||
txn.execute(sql, args) | |||
count = txn.fetchone()[0] | |||
count = cast(Tuple[int], txn.fetchone())[0] | |||
sql = f""" | |||
SELECT name, user_type, is_guest, admin, deactivated, shadow_banned, | |||
@@ -338,7 +345,9 @@ class DataStore( | |||
) | |||
def check_database_before_upgrade(cur, database_engine, config: HomeServerConfig): | |||
def check_database_before_upgrade( | |||
cur: Cursor, database_engine: BaseDatabaseEngine, config: HomeServerConfig | |||
) -> None: | |||
"""Called before upgrading an existing database to check that it is broadly sane | |||
compared with the configuration. | |||
""" | |||
@@ -14,7 +14,7 @@ | |||
# limitations under the License. | |||
import logging | |||
import re | |||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Pattern, Tuple | |||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Pattern, Tuple, cast | |||
from synapse.appservice import ( | |||
ApplicationService, | |||
@@ -83,7 +83,7 @@ class ApplicationServiceWorkerStore(RoomMemberWorkerStore): | |||
txn.execute( | |||
"SELECT COALESCE(max(txn_id), 0) FROM application_services_txns" | |||
) | |||
return txn.fetchone()[0] # type: ignore | |||
return cast(Tuple[int], txn.fetchone())[0] | |||
self._as_txn_seq_gen = build_sequence_generator( | |||
db_conn, | |||
@@ -14,7 +14,17 @@ | |||
# limitations under the License. | |||
import logging | |||
from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Set, Tuple, cast | |||
from typing import ( | |||
TYPE_CHECKING, | |||
Collection, | |||
Dict, | |||
Iterable, | |||
List, | |||
Optional, | |||
Set, | |||
Tuple, | |||
cast, | |||
) | |||
from synapse.logging import issue9533_logger | |||
from synapse.logging.opentracing import log_kv, set_tag, trace | |||
@@ -118,7 +128,13 @@ class DeviceInboxWorkerStore(SQLBaseStore): | |||
prefilled_cache=device_outbox_prefill, | |||
) | |||
def process_replication_rows(self, stream_name, instance_name, token, rows): | |||
def process_replication_rows( | |||
self, | |||
stream_name: str, | |||
instance_name: str, | |||
token: int, | |||
rows: Iterable[ToDeviceStream.ToDeviceStreamRow], | |||
) -> None: | |||
if stream_name == ToDeviceStream.NAME: | |||
# If replication is happening than postgres must be being used. | |||
assert isinstance(self._device_inbox_id_gen, MultiWriterIdGenerator) | |||
@@ -134,7 +150,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): | |||
) | |||
return super().process_replication_rows(stream_name, instance_name, token, rows) | |||
def get_to_device_stream_token(self): | |||
def get_to_device_stream_token(self) -> int: | |||
return self._device_inbox_id_gen.get_current_token() | |||
async def get_messages_for_user_devices( | |||
@@ -301,7 +317,9 @@ class DeviceInboxWorkerStore(SQLBaseStore): | |||
if not user_ids_to_query: | |||
return {}, to_stream_id | |||
def get_device_messages_txn(txn: LoggingTransaction): | |||
def get_device_messages_txn( | |||
txn: LoggingTransaction, | |||
) -> Tuple[Dict[Tuple[str, str], List[JsonDict]], int]: | |||
# Build a query to select messages from any of the given devices that | |||
# are between the given stream id bounds. | |||
@@ -428,7 +446,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): | |||
log_kv({"message": "No changes in cache since last check"}) | |||
return 0 | |||
def delete_messages_for_device_txn(txn): | |||
def delete_messages_for_device_txn(txn: LoggingTransaction) -> int: | |||
sql = ( | |||
"DELETE FROM device_inbox" | |||
" WHERE user_id = ? AND device_id = ?" | |||
@@ -455,15 +473,14 @@ class DeviceInboxWorkerStore(SQLBaseStore): | |||
@trace | |||
async def get_new_device_msgs_for_remote( | |||
self, destination, last_stream_id, current_stream_id, limit | |||
) -> Tuple[List[dict], int]: | |||
self, destination: str, last_stream_id: int, current_stream_id: int, limit: int | |||
) -> Tuple[List[JsonDict], int]: | |||
""" | |||
Args: | |||
destination(str): The name of the remote server. | |||
last_stream_id(int|long): The last position of the device message stream | |||
destination: The name of the remote server. | |||
last_stream_id: The last position of the device message stream | |||
that the server sent up to. | |||
current_stream_id(int|long): The current position of the device | |||
message stream. | |||
current_stream_id: The current position of the device message stream. | |||
Returns: | |||
A list of messages for the device and where in the stream the messages got to. | |||
""" | |||
@@ -485,7 +502,9 @@ class DeviceInboxWorkerStore(SQLBaseStore): | |||
return [], last_stream_id | |||
@trace | |||
def get_new_messages_for_remote_destination_txn(txn): | |||
def get_new_messages_for_remote_destination_txn( | |||
txn: LoggingTransaction, | |||
) -> Tuple[List[JsonDict], int]: | |||
sql = ( | |||
"SELECT stream_id, messages_json FROM device_federation_outbox" | |||
" WHERE destination = ?" | |||
@@ -527,7 +546,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): | |||
up_to_stream_id: Where to delete messages up to. | |||
""" | |||
def delete_messages_for_remote_destination_txn(txn): | |||
def delete_messages_for_remote_destination_txn(txn: LoggingTransaction) -> None: | |||
sql = ( | |||
"DELETE FROM device_federation_outbox" | |||
" WHERE destination = ?" | |||
@@ -566,7 +585,9 @@ class DeviceInboxWorkerStore(SQLBaseStore): | |||
if last_id == current_id: | |||
return [], current_id, False | |||
def get_all_new_device_messages_txn(txn): | |||
def get_all_new_device_messages_txn( | |||
txn: LoggingTransaction, | |||
) -> Tuple[List[Tuple[int, tuple]], int, bool]: | |||
# We limit like this as we might have multiple rows per stream_id, and | |||
# we want to make sure we always get all entries for any stream_id | |||
# we return. | |||
@@ -607,8 +628,8 @@ class DeviceInboxWorkerStore(SQLBaseStore): | |||
@trace | |||
async def add_messages_to_device_inbox( | |||
self, | |||
local_messages_by_user_then_device: dict, | |||
remote_messages_by_destination: dict, | |||
local_messages_by_user_then_device: Dict[str, Dict[str, JsonDict]], | |||
remote_messages_by_destination: Dict[str, JsonDict], | |||
) -> int: | |||
"""Used to send messages from this server. | |||
@@ -624,7 +645,9 @@ class DeviceInboxWorkerStore(SQLBaseStore): | |||
assert self._can_write_to_device | |||
def add_messages_txn(txn, now_ms, stream_id): | |||
def add_messages_txn( | |||
txn: LoggingTransaction, now_ms: int, stream_id: int | |||
) -> None: | |||
# Add the local messages directly to the local inbox. | |||
self._add_messages_to_local_device_inbox_txn( | |||
txn, stream_id, local_messages_by_user_then_device | |||
@@ -677,11 +700,16 @@ class DeviceInboxWorkerStore(SQLBaseStore): | |||
return self._device_inbox_id_gen.get_current_token() | |||
async def add_messages_from_remote_to_device_inbox( | |||
self, origin: str, message_id: str, local_messages_by_user_then_device: dict | |||
self, | |||
origin: str, | |||
message_id: str, | |||
local_messages_by_user_then_device: Dict[str, Dict[str, JsonDict]], | |||
) -> int: | |||
assert self._can_write_to_device | |||
def add_messages_txn(txn, now_ms, stream_id): | |||
def add_messages_txn( | |||
txn: LoggingTransaction, now_ms: int, stream_id: int | |||
) -> None: | |||
# Check if we've already inserted a matching message_id for that | |||
# origin. This can happen if the origin doesn't receive our | |||
# acknowledgement from the first time we received the message. | |||
@@ -727,8 +755,11 @@ class DeviceInboxWorkerStore(SQLBaseStore): | |||
return stream_id | |||
def _add_messages_to_local_device_inbox_txn( | |||
self, txn, stream_id, messages_by_user_then_device | |||
): | |||
self, | |||
txn: LoggingTransaction, | |||
stream_id: int, | |||
messages_by_user_then_device: Dict[str, Dict[str, JsonDict]], | |||
) -> None: | |||
assert self._can_write_to_device | |||
local_by_user_then_device = {} | |||
@@ -840,8 +871,10 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore): | |||
self._remove_dead_devices_from_device_inbox, | |||
) | |||
async def _background_drop_index_device_inbox(self, progress, batch_size): | |||
def reindex_txn(conn): | |||
async def _background_drop_index_device_inbox( | |||
self, progress: JsonDict, batch_size: int | |||
) -> int: | |||
def reindex_txn(conn: LoggingDatabaseConnection) -> None: | |||
txn = conn.cursor() | |||
txn.execute("DROP INDEX IF EXISTS device_inbox_stream_id") | |||
txn.close() | |||
@@ -25,6 +25,7 @@ from typing import ( | |||
Optional, | |||
Set, | |||
Tuple, | |||
cast, | |||
) | |||
from synapse.api.errors import Codes, StoreError | |||
@@ -136,7 +137,9 @@ class DeviceWorkerStore(SQLBaseStore): | |||
Number of devices of this users. | |||
""" | |||
def count_devices_by_users_txn(txn, user_ids): | |||
def count_devices_by_users_txn( | |||
txn: LoggingTransaction, user_ids: List[str] | |||
) -> int: | |||
sql = """ | |||
SELECT count(*) | |||
FROM devices | |||
@@ -149,7 +152,7 @@ class DeviceWorkerStore(SQLBaseStore): | |||
) | |||
txn.execute(sql + clause, args) | |||
return txn.fetchone()[0] | |||
return cast(Tuple[int], txn.fetchone())[0] | |||
if not user_ids: | |||
return 0 | |||
@@ -468,7 +471,7 @@ class DeviceWorkerStore(SQLBaseStore): | |||
""" | |||
txn.execute(sql, (destination, from_stream_id, now_stream_id, limit)) | |||
return list(txn) | |||
return cast(List[Tuple[str, str, int, Optional[str]]], txn.fetchall()) | |||
async def _get_device_update_edus_by_remote( | |||
self, | |||
@@ -549,7 +552,7 @@ class DeviceWorkerStore(SQLBaseStore): | |||
async def _get_last_device_update_for_remote_user( | |||
self, destination: str, user_id: str, from_stream_id: int | |||
) -> int: | |||
def f(txn): | |||
def f(txn: LoggingTransaction) -> int: | |||
prev_sent_id_sql = """ | |||
SELECT coalesce(max(stream_id), 0) as stream_id | |||
FROM device_lists_outbound_last_success | |||
@@ -767,7 +770,7 @@ class DeviceWorkerStore(SQLBaseStore): | |||
if not user_ids_to_check: | |||
return set() | |||
def _get_users_whose_devices_changed_txn(txn): | |||
def _get_users_whose_devices_changed_txn(txn: LoggingTransaction) -> Set[str]: | |||
changes = set() | |||
stream_id_where_clause = "stream_id > ?" | |||
@@ -966,7 +969,9 @@ class DeviceWorkerStore(SQLBaseStore): | |||
async def mark_remote_user_device_list_as_unsubscribed(self, user_id: str) -> None: | |||
"""Mark that we no longer track device lists for remote user.""" | |||
def _mark_remote_user_device_list_as_unsubscribed_txn(txn): | |||
def _mark_remote_user_device_list_as_unsubscribed_txn( | |||
txn: LoggingTransaction, | |||
) -> None: | |||
self.db_pool.simple_delete_txn( | |||
txn, | |||
table="device_lists_remote_extremeties", | |||
@@ -1004,7 +1009,7 @@ class DeviceWorkerStore(SQLBaseStore): | |||
) | |||
def _store_dehydrated_device_txn( | |||
self, txn, user_id: str, device_id: str, device_data: str | |||
self, txn: LoggingTransaction, user_id: str, device_id: str, device_data: str | |||
) -> Optional[str]: | |||
old_device_id = self.db_pool.simple_select_one_onecol_txn( | |||
txn, | |||
@@ -1081,7 +1086,7 @@ class DeviceWorkerStore(SQLBaseStore): | |||
""" | |||
yesterday = self._clock.time_msec() - prune_age | |||
def _prune_txn(txn): | |||
def _prune_txn(txn: LoggingTransaction) -> None: | |||
# look for (user, destination) pairs which have an update older than | |||
# the cutoff. | |||
# | |||
@@ -1204,8 +1209,10 @@ class DeviceBackgroundUpdateStore(SQLBaseStore): | |||
"drop_device_lists_outbound_last_success_non_unique_idx", | |||
) | |||
async def _drop_device_list_streams_non_unique_indexes(self, progress, batch_size): | |||
def f(conn): | |||
async def _drop_device_list_streams_non_unique_indexes( | |||
self, progress: JsonDict, batch_size: int | |||
) -> int: | |||
def f(conn: LoggingDatabaseConnection) -> None: | |||
txn = conn.cursor() | |||
txn.execute("DROP INDEX IF EXISTS device_lists_remote_cache_id") | |||
txn.execute("DROP INDEX IF EXISTS device_lists_remote_extremeties_id") | |||
@@ -1217,7 +1224,9 @@ class DeviceBackgroundUpdateStore(SQLBaseStore): | |||
) | |||
return 1 | |||
async def _remove_duplicate_outbound_pokes(self, progress, batch_size): | |||
async def _remove_duplicate_outbound_pokes( | |||
self, progress: JsonDict, batch_size: int | |||
) -> int: | |||
# for some reason, we have accumulated duplicate entries in | |||
# device_lists_outbound_pokes, which makes prune_outbound_device_list_pokes less | |||
# efficient. | |||
@@ -1230,7 +1239,7 @@ class DeviceBackgroundUpdateStore(SQLBaseStore): | |||
{"stream_id": 0, "destination": "", "user_id": "", "device_id": ""}, | |||
) | |||
def _txn(txn): | |||
def _txn(txn: LoggingTransaction) -> int: | |||
clause, args = make_tuple_comparison_clause( | |||
[(x, last_row[x]) for x in KEY_COLS] | |||
) | |||
@@ -1602,7 +1611,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): | |||
context = get_active_span_text_map() | |||
def add_device_changes_txn(txn, stream_ids): | |||
def add_device_changes_txn( | |||
txn: LoggingTransaction, stream_ids: List[int] | |||
) -> None: | |||
self._add_device_change_to_stream_txn( | |||
txn, | |||
user_id, | |||
@@ -1635,8 +1646,8 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): | |||
txn: LoggingTransaction, | |||
user_id: str, | |||
device_ids: Collection[str], | |||
stream_ids: List[str], | |||
): | |||
stream_ids: List[int], | |||
) -> None: | |||
txn.call_after( | |||
self._device_list_stream_cache.entity_has_changed, | |||
user_id, | |||
@@ -1720,7 +1731,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): | |||
user_id: str, | |||
device_ids: Iterable[str], | |||
room_ids: Collection[str], | |||
stream_ids: List[str], | |||
stream_ids: List[int], | |||
context: Dict[str, str], | |||
) -> None: | |||
"""Record the user in the room has updated their device.""" | |||
@@ -1775,7 +1786,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): | |||
LIMIT ? | |||
""" | |||
def get_uncoverted_outbound_room_pokes_txn(txn): | |||
def get_uncoverted_outbound_room_pokes_txn( | |||
txn: LoggingTransaction, | |||
) -> List[Tuple[str, str, str, int, Optional[Dict[str, str]]]]: | |||
txn.execute(sql, (limit,)) | |||
return [ | |||
@@ -1808,7 +1821,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): | |||
Marks the associated row in `device_lists_changes_in_room` as handled. | |||
""" | |||
def add_device_list_outbound_pokes_txn(txn, stream_ids: List[int]): | |||
def add_device_list_outbound_pokes_txn( | |||
txn: LoggingTransaction, stream_ids: List[int] | |||
) -> None: | |||
if hosts: | |||
self._add_device_outbound_poke_to_stream_txn( | |||
txn, | |||
@@ -522,7 +522,9 @@ class GroupServerWorkerStore(SQLBaseStore): | |||
desc="get_joined_groups", | |||
) | |||
async def get_all_groups_for_user(self, user_id, now_token) -> List[JsonDict]: | |||
async def get_all_groups_for_user( | |||
self, user_id: str, now_token: int | |||
) -> List[JsonDict]: | |||
def _get_all_groups_for_user_txn(txn: LoggingTransaction) -> List[JsonDict]: | |||
sql = """ | |||
SELECT group_id, type, membership, u.content | |||
@@ -15,11 +15,12 @@ | |||
import itertools | |||
import logging | |||
from typing import Dict, Iterable, List, Optional, Tuple | |||
from typing import Any, Dict, Iterable, List, Optional, Tuple | |||
from signedjson.key import decode_verify_key_bytes | |||
from synapse.storage._base import SQLBaseStore | |||
from synapse.storage.database import LoggingTransaction | |||
from synapse.storage.keys import FetchKeyResult | |||
from synapse.storage.types import Cursor | |||
from synapse.util.caches.descriptors import cached, cachedList | |||
@@ -35,7 +36,9 @@ class KeyStore(SQLBaseStore): | |||
"""Persistence for signature verification keys""" | |||
@cached() | |||
def _get_server_verify_key(self, server_name_and_key_id): | |||
def _get_server_verify_key( | |||
self, server_name_and_key_id: Tuple[str, str] | |||
) -> FetchKeyResult: | |||
raise NotImplementedError() | |||
@cachedList( | |||
@@ -179,19 +182,21 @@ class KeyStore(SQLBaseStore): | |||
async def get_server_keys_json( | |||
self, server_keys: Iterable[Tuple[str, Optional[str], Optional[str]]] | |||
) -> Dict[Tuple[str, Optional[str], Optional[str]], List[dict]]: | |||
) -> Dict[Tuple[str, Optional[str], Optional[str]], List[Dict[str, Any]]]: | |||
"""Retrieve the key json for a list of server_keys and key ids. | |||
If no keys are found for a given server, key_id and source then | |||
that server, key_id, and source triplet entry will be an empty list. | |||
The JSON is returned as a byte array so that it can be efficiently | |||
used in an HTTP response. | |||
Args: | |||
server_keys (list): List of (server_name, key_id, source) triplets. | |||
server_keys: List of (server_name, key_id, source) triplets. | |||
Returns: | |||
A mapping from (server_name, key_id, source) triplets to a list of dicts | |||
""" | |||
def _get_server_keys_json_txn(txn): | |||
def _get_server_keys_json_txn( | |||
txn: LoggingTransaction, | |||
) -> Dict[Tuple[str, Optional[str], Optional[str]], List[Dict[str, Any]]]: | |||
results = {} | |||
for server_name, key_id, from_server in server_keys: | |||
keyvalues = {"server_name": server_name} | |||
@@ -388,7 +388,14 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): | |||
return await self.db_pool.runInteraction("get_url_cache", get_url_cache_txn) | |||
async def store_url_cache( | |||
self, url, response_code, etag, expires_ts, og, media_id, download_ts | |||
self, | |||
url: str, | |||
response_code: int, | |||
etag: Optional[str], | |||
expires_ts: int, | |||
og: Optional[str], | |||
media_id: str, | |||
download_ts: int, | |||
) -> None: | |||
await self.db_pool.simple_insert( | |||
"local_media_repository_url_cache", | |||
@@ -441,7 +448,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): | |||
) | |||
async def get_cached_remote_media( | |||
self, origin, media_id: str | |||
self, origin: str, media_id: str | |||
) -> Optional[Dict[str, Any]]: | |||
return await self.db_pool.simple_select_one( | |||
"remote_media_cache", | |||
@@ -608,7 +615,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): | |||
) | |||
async def delete_remote_media(self, media_origin: str, media_id: str) -> None: | |||
def delete_remote_media_txn(txn): | |||
def delete_remote_media_txn(txn: LoggingTransaction) -> None: | |||
self.db_pool.simple_delete_txn( | |||
txn, | |||
"remote_media_cache", | |||
@@ -12,7 +12,7 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple, cast | |||
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Tuple, cast | |||
from synapse.api.presence import PresenceState, UserPresenceState | |||
from synapse.replication.tcp.streams import PresenceStream | |||
@@ -103,7 +103,9 @@ class PresenceStore(PresenceBackgroundUpdateStore): | |||
prefilled_cache=presence_cache_prefill, | |||
) | |||
async def update_presence(self, presence_states) -> Tuple[int, int]: | |||
async def update_presence( | |||
self, presence_states: List[UserPresenceState] | |||
) -> Tuple[int, int]: | |||
assert self._can_persist_presence | |||
stream_ordering_manager = self._presence_id_gen.get_next_mult( | |||
@@ -121,7 +123,10 @@ class PresenceStore(PresenceBackgroundUpdateStore): | |||
return stream_orderings[-1], self._presence_id_gen.get_current_token() | |||
def _update_presence_txn( | |||
self, txn: LoggingTransaction, stream_orderings, presence_states | |||
self, | |||
txn: LoggingTransaction, | |||
stream_orderings: List[int], | |||
presence_states: List[UserPresenceState], | |||
) -> None: | |||
for stream_id, state in zip(stream_orderings, presence_states): | |||
txn.call_after( | |||
@@ -405,7 +410,13 @@ class PresenceStore(PresenceBackgroundUpdateStore): | |||
self._presence_on_startup = [] | |||
return active_on_startup | |||
def process_replication_rows(self, stream_name, instance_name, token, rows) -> None: | |||
def process_replication_rows( | |||
self, | |||
stream_name: str, | |||
instance_name: str, | |||
token: int, | |||
rows: Iterable[Any], | |||
) -> None: | |||
if stream_name == PresenceStream.NAME: | |||
self._presence_id_gen.advance(instance_name, token) | |||
for row in rows: | |||
@@ -14,11 +14,25 @@ | |||
# limitations under the License. | |||
import logging | |||
from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Optional, Tuple | |||
from typing import ( | |||
TYPE_CHECKING, | |||
Any, | |||
Dict, | |||
Iterable, | |||
Iterator, | |||
List, | |||
Optional, | |||
Tuple, | |||
cast, | |||
) | |||
from synapse.push import PusherConfig, ThrottleParams | |||
from synapse.storage._base import SQLBaseStore, db_to_json | |||
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection | |||
from synapse.storage.database import ( | |||
DatabasePool, | |||
LoggingDatabaseConnection, | |||
LoggingTransaction, | |||
) | |||
from synapse.storage.util.id_generators import StreamIdGenerator | |||
from synapse.types import JsonDict | |||
from synapse.util import json_encoder | |||
@@ -117,7 +131,7 @@ class PusherWorkerStore(SQLBaseStore): | |||
return self._decode_pushers_rows(ret) | |||
async def get_all_pushers(self) -> Iterator[PusherConfig]: | |||
def get_pushers(txn): | |||
def get_pushers(txn: LoggingTransaction) -> Iterator[PusherConfig]: | |||
txn.execute("SELECT * FROM pushers") | |||
rows = self.db_pool.cursor_to_dict(txn) | |||
@@ -152,7 +166,9 @@ class PusherWorkerStore(SQLBaseStore): | |||
if last_id == current_id: | |||
return [], current_id, False | |||
def get_all_updated_pushers_rows_txn(txn): | |||
def get_all_updated_pushers_rows_txn( | |||
txn: LoggingTransaction, | |||
) -> Tuple[List[Tuple[int, tuple]], int, bool]: | |||
sql = """ | |||
SELECT id, user_name, app_id, pushkey | |||
FROM pushers | |||
@@ -160,10 +176,13 @@ class PusherWorkerStore(SQLBaseStore): | |||
ORDER BY id ASC LIMIT ? | |||
""" | |||
txn.execute(sql, (last_id, current_id, limit)) | |||
updates = [ | |||
(stream_id, (user_name, app_id, pushkey, False)) | |||
for stream_id, user_name, app_id, pushkey in txn | |||
] | |||
updates = cast( | |||
List[Tuple[int, tuple]], | |||
[ | |||
(stream_id, (user_name, app_id, pushkey, False)) | |||
for stream_id, user_name, app_id, pushkey in txn | |||
], | |||
) | |||
sql = """ | |||
SELECT stream_id, user_id, app_id, pushkey | |||
@@ -192,12 +211,12 @@ class PusherWorkerStore(SQLBaseStore): | |||
) | |||
@cached(num_args=1, max_entries=15000) | |||
async def get_if_user_has_pusher(self, user_id: str): | |||
async def get_if_user_has_pusher(self, user_id: str) -> None: | |||
# This only exists for the cachedList decorator | |||
raise NotImplementedError() | |||
async def update_pusher_last_stream_ordering( | |||
self, app_id, pushkey, user_id, last_stream_ordering | |||
self, app_id: str, pushkey: str, user_id: str, last_stream_ordering: int | |||
) -> None: | |||
await self.db_pool.simple_update_one( | |||
"pushers", | |||
@@ -291,7 +310,7 @@ class PusherWorkerStore(SQLBaseStore): | |||
last_user = progress.get("last_user", "") | |||
def _delete_pushers(txn) -> int: | |||
def _delete_pushers(txn: LoggingTransaction) -> int: | |||
sql = """ | |||
SELECT name FROM users | |||
@@ -339,7 +358,7 @@ class PusherWorkerStore(SQLBaseStore): | |||
last_pusher = progress.get("last_pusher", 0) | |||
def _delete_pushers(txn) -> int: | |||
def _delete_pushers(txn: LoggingTransaction) -> int: | |||
sql = """ | |||
SELECT p.id, access_token FROM pushers AS p | |||
@@ -396,7 +415,7 @@ class PusherWorkerStore(SQLBaseStore): | |||
last_pusher = progress.get("last_pusher", 0) | |||
def _delete_pushers(txn) -> int: | |||
def _delete_pushers(txn: LoggingTransaction) -> int: | |||
sql = """ | |||
SELECT p.id, p.user_name, p.app_id, p.pushkey | |||
@@ -502,7 +521,7 @@ class PusherStore(PusherWorkerStore): | |||
async def delete_pusher_by_app_id_pushkey_user_id( | |||
self, app_id: str, pushkey: str, user_id: str | |||
) -> None: | |||
def delete_pusher_txn(txn, stream_id): | |||
def delete_pusher_txn(txn: LoggingTransaction, stream_id: int) -> None: | |||
self._invalidate_cache_and_stream( # type: ignore[attr-defined] | |||
txn, self.get_if_user_has_pusher, (user_id,) | |||
) | |||
@@ -547,7 +566,7 @@ class PusherStore(PusherWorkerStore): | |||
# account. | |||
pushers = list(await self.get_pushers_by_user_id(user_id)) | |||
def delete_pushers_txn(txn, stream_ids): | |||
def delete_pushers_txn(txn: LoggingTransaction, stream_ids: List[int]) -> None: | |||
self._invalidate_cache_and_stream( # type: ignore[attr-defined] | |||
txn, self.get_if_user_has_pusher, (user_id,) | |||
) | |||
@@ -370,10 +370,10 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): | |||
def _update_state_for_partial_state_event_txn( | |||
self, | |||
txn, | |||
txn: LoggingTransaction, | |||
event: EventBase, | |||
context: EventContext, | |||
): | |||
) -> None: | |||
# we shouldn't have any outliers here | |||
assert not event.internal_metadata.is_outlier() | |||
@@ -131,7 +131,7 @@ class UIAuthWorkerStore(SQLBaseStore): | |||
session_id: str, | |||
stage_type: str, | |||
result: Union[str, bool, JsonDict], | |||
): | |||
) -> None: | |||
""" | |||
Mark a session stage as completed. | |||
@@ -200,7 +200,9 @@ class UIAuthWorkerStore(SQLBaseStore): | |||
desc="set_ui_auth_client_dict", | |||
) | |||
async def set_ui_auth_session_data(self, session_id: str, key: str, value: Any): | |||
async def set_ui_auth_session_data( | |||
self, session_id: str, key: str, value: Any | |||
) -> None: | |||
""" | |||
Store a key-value pair into the sessions data associated with this | |||
request. This data is stored server-side and cannot be modified by | |||
@@ -223,7 +225,7 @@ class UIAuthWorkerStore(SQLBaseStore): | |||
def _set_ui_auth_session_data_txn( | |||
self, txn: LoggingTransaction, session_id: str, key: str, value: Any | |||
): | |||
) -> None: | |||
# Get the current value. | |||
result = cast( | |||
Dict[str, Any], | |||
@@ -275,7 +277,7 @@ class UIAuthWorkerStore(SQLBaseStore): | |||
session_id: str, | |||
user_agent: str, | |||
ip: str, | |||
): | |||
) -> None: | |||
"""Add the given user agent / IP to the tracking table""" | |||
await self.db_pool.simple_upsert( | |||
table="ui_auth_sessions_ips", | |||
@@ -318,7 +320,7 @@ class UIAuthWorkerStore(SQLBaseStore): | |||
def _delete_old_ui_auth_sessions_txn( | |||
self, txn: LoggingTransaction, expiration_time: int | |||
): | |||
) -> None: | |||
# Get the expired sessions. | |||
sql = "SELECT session_id FROM ui_auth_sessions WHERE creation_time <= ?" | |||
txn.execute(sql, [expiration_time]) | |||