@@ -0,0 +1 @@ | |||
Improve type hints. |
@@ -17,7 +17,7 @@ import logging | |||
import os | |||
import sys | |||
import tempfile | |||
from typing import List, Mapping, Optional | |||
from typing import List, Mapping, Optional, Sequence | |||
from twisted.internet import defer, task | |||
@@ -57,7 +57,7 @@ from synapse.storage.databases.main.state import StateGroupWorkerStore | |||
from synapse.storage.databases.main.stream import StreamWorkerStore | |||
from synapse.storage.databases.main.tags import TagsWorkerStore | |||
from synapse.storage.databases.main.user_erasure_store import UserErasureWorkerStore | |||
from synapse.types import JsonDict, StateMap | |||
from synapse.types import JsonMapping, StateMap | |||
from synapse.util import SYNAPSE_VERSION | |||
from synapse.util.logcontext import LoggingContext | |||
@@ -198,7 +198,7 @@ class FileExfiltrationWriter(ExfiltrationWriter): | |||
for event in state.values(): | |||
json.dump(event, fp=f) | |||
def write_profile(self, profile: JsonDict) -> None: | |||
def write_profile(self, profile: JsonMapping) -> None: | |||
user_directory = os.path.join(self.base_directory, "user_data") | |||
os.makedirs(user_directory, exist_ok=True) | |||
profile_file = os.path.join(user_directory, "profile") | |||
@@ -206,7 +206,7 @@ class FileExfiltrationWriter(ExfiltrationWriter): | |||
with open(profile_file, "a") as f: | |||
json.dump(profile, fp=f) | |||
def write_devices(self, devices: List[JsonDict]) -> None: | |||
def write_devices(self, devices: Sequence[JsonMapping]) -> None: | |||
user_directory = os.path.join(self.base_directory, "user_data") | |||
os.makedirs(user_directory, exist_ok=True) | |||
device_file = os.path.join(user_directory, "devices") | |||
@@ -215,7 +215,7 @@ class FileExfiltrationWriter(ExfiltrationWriter): | |||
with open(device_file, "a") as f: | |||
json.dump(device, fp=f) | |||
def write_connections(self, connections: List[JsonDict]) -> None: | |||
def write_connections(self, connections: Sequence[JsonMapping]) -> None: | |||
user_directory = os.path.join(self.base_directory, "user_data") | |||
os.makedirs(user_directory, exist_ok=True) | |||
connection_file = os.path.join(user_directory, "connections") | |||
@@ -225,7 +225,7 @@ class FileExfiltrationWriter(ExfiltrationWriter): | |||
json.dump(connection, fp=f) | |||
def write_account_data( | |||
self, file_name: str, account_data: Mapping[str, JsonDict] | |||
self, file_name: str, account_data: Mapping[str, JsonMapping] | |||
) -> None: | |||
account_data_directory = os.path.join( | |||
self.base_directory, "user_data", "account_data" | |||
@@ -237,7 +237,7 @@ class FileExfiltrationWriter(ExfiltrationWriter): | |||
with open(account_data_file, "a") as f: | |||
json.dump(account_data, fp=f) | |||
def write_media_id(self, media_id: str, media_metadata: JsonDict) -> None: | |||
def write_media_id(self, media_id: str, media_metadata: JsonMapping) -> None: | |||
file_directory = os.path.join(self.base_directory, "media_ids") | |||
os.makedirs(file_directory, exist_ok=True) | |||
media_id_file = os.path.join(file_directory, media_id) | |||
@@ -14,11 +14,11 @@ | |||
import abc | |||
import logging | |||
from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Set | |||
from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Sequence, Set | |||
from synapse.api.constants import Direction, Membership | |||
from synapse.events import EventBase | |||
from synapse.types import JsonDict, RoomStreamToken, StateMap, UserID, UserInfo | |||
from synapse.types import JsonMapping, RoomStreamToken, StateMap, UserID, UserInfo | |||
from synapse.visibility import filter_events_for_client | |||
if TYPE_CHECKING: | |||
@@ -35,7 +35,7 @@ class AdminHandler: | |||
self._state_storage_controller = self._storage_controllers.state | |||
self._msc3866_enabled = hs.config.experimental.msc3866.enabled | |||
async def get_whois(self, user: UserID) -> JsonDict: | |||
async def get_whois(self, user: UserID) -> JsonMapping: | |||
connections = [] | |||
sessions = await self._store.get_user_ip_and_agents(user) | |||
@@ -55,7 +55,7 @@ class AdminHandler: | |||
return ret | |||
async def get_user(self, user: UserID) -> Optional[JsonDict]: | |||
async def get_user(self, user: UserID) -> Optional[JsonMapping]: | |||
"""Function to get user details""" | |||
user_info: Optional[UserInfo] = await self._store.get_user_by_id( | |||
user.to_string() | |||
@@ -344,7 +344,7 @@ class ExfiltrationWriter(metaclass=abc.ABCMeta): | |||
raise NotImplementedError() | |||
@abc.abstractmethod | |||
def write_profile(self, profile: JsonDict) -> None: | |||
def write_profile(self, profile: JsonMapping) -> None: | |||
"""Write the profile of a user. | |||
Args: | |||
@@ -353,7 +353,7 @@ class ExfiltrationWriter(metaclass=abc.ABCMeta): | |||
raise NotImplementedError() | |||
@abc.abstractmethod | |||
def write_devices(self, devices: List[JsonDict]) -> None: | |||
def write_devices(self, devices: Sequence[JsonMapping]) -> None: | |||
"""Write the devices of a user. | |||
Args: | |||
@@ -362,7 +362,7 @@ class ExfiltrationWriter(metaclass=abc.ABCMeta): | |||
raise NotImplementedError() | |||
@abc.abstractmethod | |||
def write_connections(self, connections: List[JsonDict]) -> None: | |||
def write_connections(self, connections: Sequence[JsonMapping]) -> None: | |||
"""Write the connections of a user. | |||
Args: | |||
@@ -372,7 +372,7 @@ class ExfiltrationWriter(metaclass=abc.ABCMeta): | |||
@abc.abstractmethod | |||
def write_account_data( | |||
self, file_name: str, account_data: Mapping[str, JsonDict] | |||
self, file_name: str, account_data: Mapping[str, JsonMapping] | |||
) -> None: | |||
"""Write the account data of a user. | |||
@@ -383,7 +383,7 @@ class ExfiltrationWriter(metaclass=abc.ABCMeta): | |||
raise NotImplementedError() | |||
@abc.abstractmethod | |||
def write_media_id(self, media_id: str, media_metadata: JsonDict) -> None: | |||
def write_media_id(self, media_id: str, media_metadata: JsonMapping) -> None: | |||
"""Write the media's metadata of a user. | |||
Exports only the metadata, as this can be fetched from the database via | |||
read only. In order to access the files, a connection to the correct | |||
@@ -57,6 +57,7 @@ from synapse.storage.roommember import MemberSummary | |||
from synapse.types import ( | |||
DeviceListUpdates, | |||
JsonDict, | |||
JsonMapping, | |||
MutableStateMap, | |||
Requester, | |||
RoomStreamToken, | |||
@@ -1793,19 +1794,23 @@ class SyncHandler: | |||
) | |||
if push_rules_changed: | |||
global_account_data = dict(global_account_data) | |||
global_account_data[ | |||
AccountDataTypes.PUSH_RULES | |||
] = await self._push_rules_handler.push_rules_for_user(sync_config.user) | |||
global_account_data = { | |||
AccountDataTypes.PUSH_RULES: await self._push_rules_handler.push_rules_for_user( | |||
sync_config.user | |||
), | |||
**global_account_data, | |||
} | |||
else: | |||
all_global_account_data = await self.store.get_global_account_data_for_user( | |||
user_id | |||
) | |||
global_account_data = dict(all_global_account_data) | |||
global_account_data[ | |||
AccountDataTypes.PUSH_RULES | |||
] = await self._push_rules_handler.push_rules_for_user(sync_config.user) | |||
global_account_data = { | |||
AccountDataTypes.PUSH_RULES: await self._push_rules_handler.push_rules_for_user( | |||
sync_config.user | |||
), | |||
**all_global_account_data, | |||
} | |||
account_data_for_user = ( | |||
await sync_config.filter_collection.filter_global_account_data( | |||
@@ -1909,7 +1914,7 @@ class SyncHandler: | |||
blocks_all_rooms | |||
or sync_result_builder.sync_config.filter_collection.blocks_all_room_account_data() | |||
): | |||
account_data_by_room: Mapping[str, Mapping[str, JsonDict]] = {} | |||
account_data_by_room: Mapping[str, Mapping[str, JsonMapping]] = {} | |||
elif since_token and not sync_result_builder.full_state: | |||
account_data_by_room = ( | |||
await self.store.get_updated_room_account_data_for_user( | |||
@@ -2349,8 +2354,8 @@ class SyncHandler: | |||
sync_result_builder: "SyncResultBuilder", | |||
room_builder: "RoomSyncResultBuilder", | |||
ephemeral: List[JsonDict], | |||
tags: Optional[Mapping[str, Mapping[str, Any]]], | |||
account_data: Mapping[str, JsonDict], | |||
tags: Optional[Mapping[str, JsonMapping]], | |||
account_data: Mapping[str, JsonMapping], | |||
always_include: bool = False, | |||
) -> None: | |||
"""Populates the `joined` and `archived` section of `sync_result_builder` | |||
@@ -39,7 +39,7 @@ from synapse.rest.admin._base import ( | |||
from synapse.rest.client._base import client_patterns | |||
from synapse.storage.databases.main.registration import ExternalIDReuseException | |||
from synapse.storage.databases.main.stats import UserSortOrder | |||
from synapse.types import JsonDict, UserID | |||
from synapse.types import JsonDict, JsonMapping, UserID | |||
if TYPE_CHECKING: | |||
from synapse.server import HomeServer | |||
@@ -211,7 +211,7 @@ class UserRestServletV2(RestServlet): | |||
async def on_GET( | |||
self, request: SynapseRequest, user_id: str | |||
) -> Tuple[int, JsonDict]: | |||
) -> Tuple[int, JsonMapping]: | |||
await assert_requester_is_admin(self.auth, request) | |||
target_user = UserID.from_string(user_id) | |||
@@ -226,7 +226,7 @@ class UserRestServletV2(RestServlet): | |||
async def on_PUT( | |||
self, request: SynapseRequest, user_id: str | |||
) -> Tuple[int, JsonDict]: | |||
) -> Tuple[int, JsonMapping]: | |||
requester = await self.auth.get_user_by_req(request) | |||
await assert_user_is_admin(self.auth, requester) | |||
@@ -658,7 +658,7 @@ class WhoisRestServlet(RestServlet): | |||
async def on_GET( | |||
self, request: SynapseRequest, user_id: str | |||
) -> Tuple[int, JsonDict]: | |||
) -> Tuple[int, JsonMapping]: | |||
target_user = UserID.from_string(user_id) | |||
requester = await self.auth.get_user_by_req(request) | |||
@@ -20,7 +20,7 @@ from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError | |||
from synapse.http.server import HttpServer | |||
from synapse.http.servlet import RestServlet, parse_json_object_from_request | |||
from synapse.http.site import SynapseRequest | |||
from synapse.types import JsonDict, RoomID | |||
from synapse.types import JsonDict, JsonMapping, RoomID | |||
from ._base import client_patterns | |||
@@ -95,7 +95,7 @@ class AccountDataServlet(RestServlet): | |||
async def on_GET( | |||
self, request: SynapseRequest, user_id: str, account_data_type: str | |||
) -> Tuple[int, JsonDict]: | |||
) -> Tuple[int, JsonMapping]: | |||
requester = await self.auth.get_user_by_req(request) | |||
if user_id != requester.user.to_string(): | |||
raise AuthError(403, "Cannot get account data for other users.") | |||
@@ -106,7 +106,7 @@ class AccountDataServlet(RestServlet): | |||
and account_data_type == AccountDataTypes.PUSH_RULES | |||
): | |||
account_data: Optional[ | |||
JsonDict | |||
JsonMapping | |||
] = await self._push_rules_handler.push_rules_for_user(requester.user) | |||
else: | |||
account_data = await self.store.get_global_account_data_by_type_for_user( | |||
@@ -236,7 +236,7 @@ class RoomAccountDataServlet(RestServlet): | |||
user_id: str, | |||
room_id: str, | |||
account_data_type: str, | |||
) -> Tuple[int, JsonDict]: | |||
) -> Tuple[int, JsonMapping]: | |||
requester = await self.auth.get_user_by_req(request) | |||
if user_id != requester.user.to_string(): | |||
raise AuthError(403, "Cannot get account data for other users.") | |||
@@ -253,7 +253,7 @@ class RoomAccountDataServlet(RestServlet): | |||
self._hs.config.experimental.msc4010_push_rules_account_data | |||
and account_data_type == AccountDataTypes.PUSH_RULES | |||
): | |||
account_data: Optional[JsonDict] = {} | |||
account_data: Optional[JsonMapping] = {} | |||
else: | |||
account_data = await self.store.get_account_data_for_room_and_type( | |||
user_id, room_id, account_data_type | |||
@@ -43,7 +43,7 @@ from synapse.storage.util.id_generators import ( | |||
MultiWriterIdGenerator, | |||
StreamIdGenerator, | |||
) | |||
from synapse.types import JsonDict | |||
from synapse.types import JsonDict, JsonMapping | |||
from synapse.util import json_encoder | |||
from synapse.util.caches.descriptors import cached | |||
from synapse.util.caches.stream_change_cache import StreamChangeCache | |||
@@ -119,7 +119,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) | |||
@cached() | |||
async def get_global_account_data_for_user( | |||
self, user_id: str | |||
) -> Mapping[str, JsonDict]: | |||
) -> Mapping[str, JsonMapping]: | |||
""" | |||
Get all the global client account_data for a user. | |||
@@ -164,7 +164,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) | |||
@cached() | |||
async def get_room_account_data_for_user( | |||
self, user_id: str | |||
) -> Mapping[str, Mapping[str, JsonDict]]: | |||
) -> Mapping[str, Mapping[str, JsonMapping]]: | |||
""" | |||
Get all of the per-room client account_data for a user. | |||
@@ -213,7 +213,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) | |||
@cached(num_args=2, max_entries=5000, tree=True) | |||
async def get_global_account_data_by_type_for_user( | |||
self, user_id: str, data_type: str | |||
) -> Optional[JsonDict]: | |||
) -> Optional[JsonMapping]: | |||
""" | |||
Returns: | |||
The account data. | |||
@@ -265,7 +265,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) | |||
@cached(num_args=2, tree=True) | |||
async def get_account_data_for_room( | |||
self, user_id: str, room_id: str | |||
) -> Mapping[str, JsonDict]: | |||
) -> Mapping[str, JsonMapping]: | |||
"""Get all the client account_data for a user for a room. | |||
Args: | |||
@@ -296,7 +296,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) | |||
@cached(num_args=3, max_entries=5000, tree=True) | |||
async def get_account_data_for_room_and_type( | |||
self, user_id: str, room_id: str, account_data_type: str | |||
) -> Optional[JsonDict]: | |||
) -> Optional[JsonMapping]: | |||
"""Get the client account_data of given type for a user for a room. | |||
Args: | |||
@@ -394,7 +394,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) | |||
async def get_updated_global_account_data_for_user( | |||
self, user_id: str, stream_id: int | |||
) -> Dict[str, JsonDict]: | |||
) -> Mapping[str, JsonMapping]: | |||
"""Get all the global account_data that's changed for a user. | |||
Args: | |||
@@ -12,11 +12,10 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
from typing import TYPE_CHECKING, Dict | |||
from typing import TYPE_CHECKING, Dict, FrozenSet | |||
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection | |||
from synapse.storage.databases.main import CacheInvalidationWorkerStore | |||
from synapse.types import StrCollection | |||
from synapse.util.caches.descriptors import cached | |||
if TYPE_CHECKING: | |||
@@ -34,7 +33,7 @@ class ExperimentalFeaturesStore(CacheInvalidationWorkerStore): | |||
super().__init__(database, db_conn, hs) | |||
@cached() | |||
async def list_enabled_features(self, user_id: str) -> StrCollection: | |||
async def list_enabled_features(self, user_id: str) -> FrozenSet[str]: | |||
""" | |||
Checks to see what features are enabled for a given user | |||
Args: | |||
@@ -49,7 +48,7 @@ class ExperimentalFeaturesStore(CacheInvalidationWorkerStore): | |||
["feature"], | |||
) | |||
return [feature["feature"] for feature in enabled] | |||
return frozenset(feature["feature"] for feature in enabled) | |||
async def set_features_for_user( | |||
self, | |||
@@ -23,7 +23,7 @@ from synapse.storage._base import db_to_json | |||
from synapse.storage.database import LoggingTransaction | |||
from synapse.storage.databases.main.account_data import AccountDataWorkerStore | |||
from synapse.storage.util.id_generators import AbstractStreamIdGenerator | |||
from synapse.types import JsonDict | |||
from synapse.types import JsonDict, JsonMapping | |||
from synapse.util import json_encoder | |||
from synapse.util.caches.descriptors import cached | |||
@@ -34,7 +34,7 @@ class TagsWorkerStore(AccountDataWorkerStore): | |||
@cached() | |||
async def get_tags_for_user( | |||
self, user_id: str | |||
) -> Mapping[str, Mapping[str, JsonDict]]: | |||
) -> Mapping[str, Mapping[str, JsonMapping]]: | |||
"""Get all the tags for a user. | |||
@@ -109,7 +109,7 @@ class TagsWorkerStore(AccountDataWorkerStore): | |||
async def get_updated_tags( | |||
self, user_id: str, stream_id: int | |||
) -> Mapping[str, Mapping[str, JsonDict]]: | |||
) -> Mapping[str, Mapping[str, JsonMapping]]: | |||
"""Get all the tags for the rooms where the tags have changed since the | |||
given version | |||