@@ -0,0 +1 @@ | |||
Add missing type hints to storage classes. |
@@ -31,14 +31,11 @@ exclude = (?x) | |||
|synapse/storage/databases/main/group_server.py | |||
|synapse/storage/databases/main/metrics.py | |||
|synapse/storage/databases/main/monthly_active_users.py | |||
|synapse/storage/databases/main/presence.py | |||
|synapse/storage/databases/main/purge_events.py | |||
|synapse/storage/databases/main/push_rule.py | |||
|synapse/storage/databases/main/receipts.py | |||
|synapse/storage/databases/main/roommember.py | |||
|synapse/storage/databases/main/search.py | |||
|synapse/storage/databases/main/state.py | |||
|synapse/storage/databases/main/user_directory.py | |||
|synapse/storage/schema/ | |||
|tests/api/test_auth.py | |||
@@ -204,25 +204,27 @@ class BasePresenceHandler(abc.ABC): | |||
Returns: | |||
dict: `user_id` -> `UserPresenceState` | |||
""" | |||
states = { | |||
user_id: self.user_to_current_state.get(user_id, None) | |||
for user_id in user_ids | |||
} | |||
states = {} | |||
missing = [] | |||
for user_id in user_ids: | |||
state = self.user_to_current_state.get(user_id, None) | |||
if state: | |||
states[user_id] = state | |||
else: | |||
missing.append(user_id) | |||
missing = [user_id for user_id, state in states.items() if not state] | |||
if missing: | |||
# There are things not in our in memory cache. Lets pull them out of | |||
# the database. | |||
res = await self.store.get_presence_for_users(missing) | |||
states.update(res) | |||
missing = [user_id for user_id, state in states.items() if not state] | |||
if missing: | |||
new = { | |||
user_id: UserPresenceState.default(user_id) for user_id in missing | |||
} | |||
states.update(new) | |||
self.user_to_current_state.update(new) | |||
for user_id in missing: | |||
# if user has no state in database, create the state | |||
if not res.get(user_id, None): | |||
new_state = UserPresenceState.default(user_id) | |||
states[user_id] = new_state | |||
self.user_to_current_state[user_id] = new_state | |||
return states | |||
@@ -12,15 +12,23 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple | |||
from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple, cast | |||
from synapse.api.presence import PresenceState, UserPresenceState | |||
from synapse.replication.tcp.streams import PresenceStream | |||
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause | |||
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection | |||
from synapse.storage.database import ( | |||
DatabasePool, | |||
LoggingDatabaseConnection, | |||
LoggingTransaction, | |||
) | |||
from synapse.storage.engines import PostgresEngine | |||
from synapse.storage.types import Connection | |||
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator | |||
from synapse.storage.util.id_generators import ( | |||
AbstractStreamIdGenerator, | |||
MultiWriterIdGenerator, | |||
StreamIdGenerator, | |||
) | |||
from synapse.util.caches.descriptors import cached, cachedList | |||
from synapse.util.caches.stream_change_cache import StreamChangeCache | |||
from synapse.util.iterutils import batch_iter | |||
@@ -35,7 +43,7 @@ class PresenceBackgroundUpdateStore(SQLBaseStore): | |||
database: DatabasePool, | |||
db_conn: LoggingDatabaseConnection, | |||
hs: "HomeServer", | |||
): | |||
) -> None: | |||
super().__init__(database, db_conn, hs) | |||
# Used by `PresenceStore._get_active_presence()` | |||
@@ -54,11 +62,14 @@ class PresenceStore(PresenceBackgroundUpdateStore): | |||
database: DatabasePool, | |||
db_conn: LoggingDatabaseConnection, | |||
hs: "HomeServer", | |||
): | |||
) -> None: | |||
super().__init__(database, db_conn, hs) | |||
self._instance_name = hs.get_instance_name() | |||
self._presence_id_gen: AbstractStreamIdGenerator | |||
self._can_persist_presence = ( | |||
hs.get_instance_name() in hs.config.worker.writers.presence | |||
self._instance_name in hs.config.worker.writers.presence | |||
) | |||
if isinstance(database.engine, PostgresEngine): | |||
@@ -109,7 +120,9 @@ class PresenceStore(PresenceBackgroundUpdateStore): | |||
return stream_orderings[-1], self._presence_id_gen.get_current_token() | |||
def _update_presence_txn(self, txn, stream_orderings, presence_states): | |||
def _update_presence_txn( | |||
self, txn: LoggingTransaction, stream_orderings, presence_states | |||
) -> None: | |||
for stream_id, state in zip(stream_orderings, presence_states): | |||
txn.call_after( | |||
self.presence_stream_cache.entity_has_changed, state.user_id, stream_id | |||
@@ -183,19 +196,23 @@ class PresenceStore(PresenceBackgroundUpdateStore): | |||
if last_id == current_id: | |||
return [], current_id, False | |||
def get_all_presence_updates_txn(txn): | |||
def get_all_presence_updates_txn( | |||
txn: LoggingTransaction, | |||
) -> Tuple[List[Tuple[int, list]], int, bool]: | |||
sql = """ | |||
SELECT stream_id, user_id, state, last_active_ts, | |||
last_federation_update_ts, last_user_sync_ts, | |||
status_msg, | |||
currently_active | |||
status_msg, currently_active | |||
FROM presence_stream | |||
WHERE ? < stream_id AND stream_id <= ? | |||
ORDER BY stream_id ASC | |||
LIMIT ? | |||
""" | |||
txn.execute(sql, (last_id, current_id, limit)) | |||
updates = [(row[0], row[1:]) for row in txn] | |||
updates = cast( | |||
List[Tuple[int, list]], | |||
[(row[0], row[1:]) for row in txn], | |||
) | |||
upper_bound = current_id | |||
limited = False | |||
@@ -210,7 +227,7 @@ class PresenceStore(PresenceBackgroundUpdateStore): | |||
) | |||
@cached() | |||
def _get_presence_for_user(self, user_id): | |||
def _get_presence_for_user(self, user_id: str) -> None: | |||
raise NotImplementedError() | |||
@cachedList( | |||
@@ -218,7 +235,9 @@ class PresenceStore(PresenceBackgroundUpdateStore): | |||
list_name="user_ids", | |||
num_args=1, | |||
) | |||
async def get_presence_for_users(self, user_ids): | |||
async def get_presence_for_users( | |||
self, user_ids: Iterable[str] | |||
) -> Dict[str, UserPresenceState]: | |||
rows = await self.db_pool.simple_select_many_batch( | |||
table="presence_stream", | |||
column="user_id", | |||
@@ -257,7 +276,9 @@ class PresenceStore(PresenceBackgroundUpdateStore): | |||
True if the user should have full presence sent to them, False otherwise. | |||
""" | |||
def _should_user_receive_full_presence_with_token_txn(txn): | |||
def _should_user_receive_full_presence_with_token_txn( | |||
txn: LoggingTransaction, | |||
) -> bool: | |||
sql = """ | |||
SELECT 1 FROM users_to_send_full_presence_to | |||
WHERE user_id = ? | |||
@@ -271,7 +292,7 @@ class PresenceStore(PresenceBackgroundUpdateStore): | |||
_should_user_receive_full_presence_with_token_txn, | |||
) | |||
async def add_users_to_send_full_presence_to(self, user_ids: Iterable[str]): | |||
async def add_users_to_send_full_presence_to(self, user_ids: Iterable[str]) -> None: | |||
"""Adds to the list of users who should receive a full snapshot of presence | |||
upon their next sync. | |||
@@ -353,10 +374,10 @@ class PresenceStore(PresenceBackgroundUpdateStore): | |||
return users_to_state | |||
def get_current_presence_token(self): | |||
def get_current_presence_token(self) -> int: | |||
return self._presence_id_gen.get_current_token() | |||
def _get_active_presence(self, db_conn: Connection): | |||
def _get_active_presence(self, db_conn: Connection) -> List[UserPresenceState]: | |||
"""Fetch non-offline presence from the database so that we can register | |||
the appropriate time outs. | |||
""" | |||
@@ -379,12 +400,12 @@ class PresenceStore(PresenceBackgroundUpdateStore): | |||
return [UserPresenceState(**row) for row in rows] | |||
def take_presence_startup_info(self): | |||
def take_presence_startup_info(self) -> List[UserPresenceState]: | |||
active_on_startup = self._presence_on_startup | |||
self._presence_on_startup = None | |||
self._presence_on_startup = [] | |||
return active_on_startup | |||
def process_replication_rows(self, stream_name, instance_name, token, rows): | |||
def process_replication_rows(self, stream_name, instance_name, token, rows) -> None: | |||
if stream_name == PresenceStream.NAME: | |||
self._presence_id_gen.advance(instance_name, token) | |||
for row in rows: | |||
@@ -13,9 +13,10 @@ | |||
# limitations under the License. | |||
import logging | |||
from typing import Any, List, Set, Tuple | |||
from typing import Any, List, Set, Tuple, cast | |||
from synapse.api.errors import SynapseError | |||
from synapse.storage.database import LoggingTransaction | |||
from synapse.storage.databases.main import CacheInvalidationWorkerStore | |||
from synapse.storage.databases.main.state import StateGroupWorkerStore | |||
from synapse.types import RoomStreamToken | |||
@@ -55,7 +56,11 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): | |||
) | |||
def _purge_history_txn( | |||
self, txn, room_id: str, token: RoomStreamToken, delete_local_events: bool | |||
self, | |||
txn: LoggingTransaction, | |||
room_id: str, | |||
token: RoomStreamToken, | |||
delete_local_events: bool, | |||
) -> Set[int]: | |||
# Tables that should be pruned: | |||
# event_auth | |||
@@ -273,7 +278,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): | |||
""", | |||
(room_id,), | |||
) | |||
(min_depth,) = txn.fetchone() | |||
(min_depth,) = cast(Tuple[int], txn.fetchone()) | |||
logger.info("[purge] updating room_depth to %d", min_depth) | |||
@@ -318,7 +323,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): | |||
"purge_room", self._purge_room_txn, room_id | |||
) | |||
def _purge_room_txn(self, txn, room_id: str) -> List[int]: | |||
def _purge_room_txn(self, txn: LoggingTransaction, room_id: str) -> List[int]: | |||
# First we fetch all the state groups that should be deleted, before | |||
# we delete that information. | |||
txn.execute( | |||
@@ -58,7 +58,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): | |||
database: DatabasePool, | |||
db_conn: LoggingDatabaseConnection, | |||
hs: "HomeServer", | |||
): | |||
) -> None: | |||
super().__init__(database, db_conn, hs) | |||
self.server_name = hs.hostname | |||
@@ -234,10 +234,10 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): | |||
processed_event_count = 0 | |||
for room_id, event_count in rooms_to_work_on: | |||
is_in_room = await self.is_host_joined(room_id, self.server_name) | |||
is_in_room = await self.is_host_joined(room_id, self.server_name) # type: ignore[attr-defined] | |||
if is_in_room: | |||
users_with_profile = await self.get_users_in_room_with_profiles(room_id) | |||
users_with_profile = await self.get_users_in_room_with_profiles(room_id) # type: ignore[attr-defined] | |||
# Throw away users excluded from the directory. | |||
users_with_profile = { | |||
user_id: profile | |||
@@ -368,7 +368,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): | |||
for user_id in users_to_work_on: | |||
if await self.should_include_local_user_in_dir(user_id): | |||
profile = await self.get_profileinfo(get_localpart_from_id(user_id)) | |||
profile = await self.get_profileinfo(get_localpart_from_id(user_id)) # type: ignore[attr-defined] | |||
await self.update_profile_in_user_dir( | |||
user_id, profile.display_name, profile.avatar_url | |||
) | |||
@@ -397,7 +397,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): | |||
# technically it could be DM-able. In the future, this could potentially | |||
# be configurable per-appservice whether the appservice sender can be | |||
# contacted. | |||
if self.get_app_service_by_user_id(user) is not None: | |||
if self.get_app_service_by_user_id(user) is not None: # type: ignore[attr-defined] | |||
return False | |||
# We're opting to exclude appservice users (anyone matching the user | |||
@@ -405,17 +405,17 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): | |||
# they could be DM-able. In the future, this could potentially | |||
# be configurable per-appservice whether the appservice users can be | |||
# contacted. | |||
if self.get_if_app_services_interested_in_user(user): | |||
if self.get_if_app_services_interested_in_user(user): # type: ignore[attr-defined] | |||
# TODO we might want to make this configurable for each app service | |||
return False | |||
# Support users are for diagnostics and should not appear in the user directory. | |||
if await self.is_support_user(user): | |||
if await self.is_support_user(user): # type: ignore[attr-defined] | |||
return False | |||
# Deactivated users aren't contactable, so should not appear in the user directory. | |||
try: | |||
if await self.get_user_deactivated_status(user): | |||
if await self.get_user_deactivated_status(user): # type: ignore[attr-defined] | |||
return False | |||
except StoreError: | |||
# No such user in the users table. No need to do this when calling | |||
@@ -433,20 +433,20 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): | |||
(EventTypes.RoomHistoryVisibility, ""), | |||
) | |||
current_state_ids = await self.get_filtered_current_state_ids( | |||
current_state_ids = await self.get_filtered_current_state_ids( # type: ignore[attr-defined] | |||
room_id, StateFilter.from_types(types_to_filter) | |||
) | |||
join_rules_id = current_state_ids.get((EventTypes.JoinRules, "")) | |||
if join_rules_id: | |||
join_rule_ev = await self.get_event(join_rules_id, allow_none=True) | |||
join_rule_ev = await self.get_event(join_rules_id, allow_none=True) # type: ignore[attr-defined] | |||
if join_rule_ev: | |||
if join_rule_ev.content.get("join_rule") == JoinRules.PUBLIC: | |||
return True | |||
hist_vis_id = current_state_ids.get((EventTypes.RoomHistoryVisibility, "")) | |||
if hist_vis_id: | |||
hist_vis_ev = await self.get_event(hist_vis_id, allow_none=True) | |||
hist_vis_ev = await self.get_event(hist_vis_id, allow_none=True) # type: ignore[attr-defined] | |||
if hist_vis_ev: | |||
if ( | |||
hist_vis_ev.content.get("history_visibility") | |||
@@ -51,7 +51,7 @@ from synapse.util.stringutils import parse_and_validate_server_name | |||
if TYPE_CHECKING: | |||
from synapse.appservice.api import ApplicationService | |||
from synapse.storage.databases.main import DataStore | |||
from synapse.storage.databases.main import DataStore, PurgeEventsStore | |||
# Define a state map type from type/state_key to T (usually an event ID or | |||
# event) | |||
@@ -485,7 +485,7 @@ class RoomStreamToken: | |||
) | |||
@classmethod | |||
async def parse(cls, store: "DataStore", string: str) -> "RoomStreamToken": | |||
async def parse(cls, store: "PurgeEventsStore", string: str) -> "RoomStreamToken": | |||
try: | |||
if string[0] == "s": | |||
return cls(topological=None, stream=int(string[1:])) | |||
@@ -502,7 +502,7 @@ class RoomStreamToken: | |||
instance_id = int(key) | |||
pos = int(value) | |||
instance_name = await store.get_name_from_instance_id(instance_id) | |||
instance_name = await store.get_name_from_instance_id(instance_id) # type: ignore[attr-defined] | |||
instance_map[instance_name] = pos | |||
return cls( | |||