@@ -0,0 +1 @@ | |||
Convert internal type variable syntax to reflect wider ecosystem use. |
@@ -63,9 +63,9 @@ class Auth: | |||
self.store = hs.get_datastore() | |||
self.state = hs.get_state_handler() | |||
self.token_cache = LruCache( | |||
self.token_cache: LruCache[str, Tuple[str, bool]] = LruCache( | |||
10000, "token_cache" | |||
) # type: LruCache[str, Tuple[str, bool]] | |||
) | |||
self._auth_blocking = AuthBlocking(self.hs) | |||
@@ -118,7 +118,7 @@ class RedirectException(CodeMessageException): | |||
super().__init__(code=http_code, msg=msg) | |||
self.location = location | |||
self.cookies = [] # type: List[bytes] | |||
self.cookies: List[bytes] = [] | |||
class SynapseError(CodeMessageException): | |||
@@ -160,7 +160,7 @@ class ProxiedRequestError(SynapseError): | |||
): | |||
super().__init__(code, msg, errcode) | |||
if additional_fields is None: | |||
self._additional_fields = {} # type: Dict | |||
self._additional_fields: Dict = {} | |||
else: | |||
self._additional_fields = dict(additional_fields) | |||
@@ -289,7 +289,7 @@ class Filter: | |||
room_id = None | |||
ev_type = "m.presence" | |||
contains_url = False | |||
labels = [] # type: List[str] | |||
labels: List[str] = [] | |||
else: | |||
sender = event.get("sender", None) | |||
if not sender: | |||
@@ -46,9 +46,7 @@ class Ratelimiter: | |||
# * How many times an action has occurred since a point in time | |||
# * The point in time | |||
# * The rate_hz of this particular entry. This can vary per request | |||
self.actions = ( | |||
OrderedDict() | |||
) # type: OrderedDict[Hashable, Tuple[float, int, float]] | |||
self.actions: OrderedDict[Hashable, Tuple[float, int, float]] = OrderedDict() | |||
async def can_do_action( | |||
self, | |||
@@ -195,7 +195,7 @@ class RoomVersions: | |||
) | |||
KNOWN_ROOM_VERSIONS = { | |||
KNOWN_ROOM_VERSIONS: Dict[str, RoomVersion] = { | |||
v.identifier: v | |||
for v in ( | |||
RoomVersions.V1, | |||
@@ -209,4 +209,4 @@ KNOWN_ROOM_VERSIONS = { | |||
RoomVersions.V7, | |||
) | |||
# Note that we do not include MSC2043 here unless it is enabled in the config. | |||
} # type: Dict[str, RoomVersion] | |||
} |
@@ -270,7 +270,7 @@ class GenericWorkerServer(HomeServer): | |||
site_tag = port | |||
# We always include a health resource. | |||
resources = {"/health": HealthResource()} # type: Dict[str, IResource] | |||
resources: Dict[str, IResource] = {"/health": HealthResource()} | |||
for res in listener_config.http_options.resources: | |||
for name in res.names: | |||
@@ -88,9 +88,9 @@ class ApplicationServiceApi(SimpleHttpClient): | |||
super().__init__(hs) | |||
self.clock = hs.get_clock() | |||
self.protocol_meta_cache = ResponseCache( | |||
self.protocol_meta_cache: ResponseCache[Tuple[str, str]] = ResponseCache( | |||
hs.get_clock(), "as_protocol_meta", timeout_ms=HOUR_IN_MS | |||
) # type: ResponseCache[Tuple[str, str]] | |||
) | |||
async def query_user(self, service, user_id): | |||
if service.url is None: | |||
@@ -57,8 +57,8 @@ def load_appservices(hostname, config_files): | |||
return [] | |||
# Dicts of value -> filename | |||
seen_as_tokens = {} # type: Dict[str, str] | |||
seen_ids = {} # type: Dict[str, str] | |||
seen_as_tokens: Dict[str, str] = {} | |||
seen_ids: Dict[str, str] = {} | |||
appservices = [] | |||
@@ -25,7 +25,7 @@ from ._base import Config, ConfigError | |||
_CACHE_PREFIX = "SYNAPSE_CACHE_FACTOR" | |||
# Map from canonicalised cache name to cache. | |||
_CACHES = {} # type: Dict[str, Callable[[float], None]] | |||
_CACHES: Dict[str, Callable[[float], None]] = {} | |||
# a lock on the contents of _CACHES | |||
_CACHES_LOCK = threading.Lock() | |||
@@ -157,7 +157,7 @@ class CacheConfig(Config): | |||
self.event_cache_size = self.parse_size( | |||
config.get("event_cache_size", _DEFAULT_EVENT_CACHE_SIZE) | |||
) | |||
self.cache_factors = {} # type: Dict[str, float] | |||
self.cache_factors: Dict[str, float] = {} | |||
cache_config = config.get("caches") or {} | |||
self.global_factor = cache_config.get( | |||
@@ -134,9 +134,9 @@ class EmailConfig(Config): | |||
# trusted_third_party_id_servers does not contain a scheme whereas | |||
# account_threepid_delegate_email is expected to. Presume https | |||
self.account_threepid_delegate_email = ( | |||
self.account_threepid_delegate_email: Optional[str] = ( | |||
"https://" + first_trusted_identity_server | |||
) # type: Optional[str] | |||
) | |||
self.using_identity_server_from_trusted_list = True | |||
else: | |||
raise ConfigError( | |||
@@ -25,10 +25,10 @@ class ExperimentalConfig(Config): | |||
experimental = config.get("experimental_features") or {} | |||
# MSC2858 (multiple SSO identity providers) | |||
self.msc2858_enabled = experimental.get("msc2858_enabled", False) # type: bool | |||
self.msc2858_enabled: bool = experimental.get("msc2858_enabled", False) | |||
# MSC3026 (busy presence state) | |||
self.msc3026_enabled = experimental.get("msc3026_enabled", False) # type: bool | |||
self.msc3026_enabled: bool = experimental.get("msc3026_enabled", False) | |||
# MSC2716 (backfill existing history) | |||
self.msc2716_enabled = experimental.get("msc2716_enabled", False) # type: bool | |||
self.msc2716_enabled: bool = experimental.get("msc2716_enabled", False) |
@@ -22,7 +22,7 @@ class FederationConfig(Config): | |||
def read_config(self, config, **kwargs): | |||
# FIXME: federation_domain_whitelist needs sytests | |||
self.federation_domain_whitelist = None # type: Optional[dict] | |||
self.federation_domain_whitelist: Optional[dict] = None | |||
federation_domain_whitelist = config.get("federation_domain_whitelist", None) | |||
if federation_domain_whitelist is not None: | |||
@@ -460,7 +460,7 @@ def _parse_oidc_config_dict( | |||
) from e | |||
client_secret_jwt_key_config = oidc_config.get("client_secret_jwt_key") | |||
client_secret_jwt_key = None # type: Optional[OidcProviderClientSecretJwtKey] | |||
client_secret_jwt_key: Optional[OidcProviderClientSecretJwtKey] = None | |||
if client_secret_jwt_key_config is not None: | |||
keyfile = client_secret_jwt_key_config.get("key_file") | |||
if keyfile: | |||
@@ -25,7 +25,7 @@ class PasswordAuthProviderConfig(Config): | |||
section = "authproviders" | |||
def read_config(self, config, **kwargs): | |||
self.password_providers = [] # type: List[Any] | |||
self.password_providers: List[Any] = [] | |||
providers = [] | |||
# We want to be backwards compatible with the old `ldap_config` | |||
@@ -62,7 +62,7 @@ def parse_thumbnail_requirements(thumbnail_sizes): | |||
Dictionary mapping from media type string to list of | |||
ThumbnailRequirement tuples. | |||
""" | |||
requirements = {} # type: Dict[str, List] | |||
requirements: Dict[str, List] = {} | |||
for size in thumbnail_sizes: | |||
width = size["width"] | |||
height = size["height"] | |||
@@ -141,7 +141,7 @@ class ContentRepositoryConfig(Config): | |||
# | |||
# We don't create the storage providers here as not all workers need | |||
# them to be started. | |||
self.media_storage_providers = [] # type: List[tuple] | |||
self.media_storage_providers: List[tuple] = [] | |||
for i, provider_config in enumerate(storage_providers): | |||
# We special case the module "file_system" so as not to need to | |||
@@ -505,7 +505,7 @@ class ServerConfig(Config): | |||
" greater than 'allowed_lifetime_max'" | |||
) | |||
self.retention_purge_jobs = [] # type: List[Dict[str, Optional[int]]] | |||
self.retention_purge_jobs: List[Dict[str, Optional[int]]] = [] | |||
for purge_job_config in retention_config.get("purge_jobs", []): | |||
interval_config = purge_job_config.get("interval") | |||
@@ -688,23 +688,21 @@ class ServerConfig(Config): | |||
# not included in the sample configuration file on purpose as it's a temporary | |||
# hack, so that some users can trial the new defaults without impacting every | |||
# user on the homeserver. | |||
users_new_default_push_rules = ( | |||
users_new_default_push_rules: list = ( | |||
config.get("users_new_default_push_rules") or [] | |||
) # type: list | |||
) | |||
if not isinstance(users_new_default_push_rules, list): | |||
raise ConfigError("'users_new_default_push_rules' must be a list") | |||
# Turn the list into a set to improve lookup speed. | |||
self.users_new_default_push_rules = set( | |||
users_new_default_push_rules | |||
) # type: set | |||
self.users_new_default_push_rules: set = set(users_new_default_push_rules) | |||
# Whitelist of domain names that given next_link parameters must have | |||
next_link_domain_whitelist = config.get( | |||
next_link_domain_whitelist: Optional[List[str]] = config.get( | |||
"next_link_domain_whitelist" | |||
) # type: Optional[List[str]] | |||
) | |||
self.next_link_domain_whitelist = None # type: Optional[Set[str]] | |||
self.next_link_domain_whitelist: Optional[Set[str]] = None | |||
if next_link_domain_whitelist is not None: | |||
if not isinstance(next_link_domain_whitelist, list): | |||
raise ConfigError("'next_link_domain_whitelist' must be a list") | |||
@@ -34,7 +34,7 @@ class SpamCheckerConfig(Config): | |||
section = "spamchecker" | |||
def read_config(self, config, **kwargs): | |||
self.spam_checkers = [] # type: List[Tuple[Any, Dict]] | |||
self.spam_checkers: List[Tuple[Any, Dict]] = [] | |||
spam_checkers = config.get("spam_checker") or [] | |||
if isinstance(spam_checkers, dict): | |||
@@ -39,7 +39,7 @@ class SSOConfig(Config): | |||
section = "sso" | |||
def read_config(self, config, **kwargs): | |||
sso_config = config.get("sso") or {} # type: Dict[str, Any] | |||
sso_config: Dict[str, Any] = config.get("sso") or {} | |||
# The sso-specific template_dir | |||
self.sso_template_dir = sso_config.get("template_dir") | |||
@@ -80,7 +80,7 @@ class TlsConfig(Config): | |||
fed_whitelist_entries = [] | |||
# Support globs (*) in whitelist values | |||
self.federation_certificate_verification_whitelist = [] # type: List[Pattern] | |||
self.federation_certificate_verification_whitelist: List[Pattern] = [] | |||
for entry in fed_whitelist_entries: | |||
try: | |||
entry_regex = glob_to_regex(entry.encode("ascii").decode("ascii")) | |||
@@ -132,8 +132,8 @@ class TlsConfig(Config): | |||
"use_insecure_ssl_client_just_for_testing_do_not_use" | |||
) | |||
self.tls_certificate = None # type: Optional[crypto.X509] | |||
self.tls_private_key = None # type: Optional[crypto.PKey] | |||
self.tls_certificate: Optional[crypto.X509] = None | |||
self.tls_private_key: Optional[crypto.PKey] = None | |||
def is_disk_cert_valid(self, allow_self_signed=True): | |||
""" | |||
@@ -170,11 +170,13 @@ class Keyring: | |||
) | |||
self._key_fetchers = key_fetchers | |||
self._server_queue = BatchingQueue( | |||
self._server_queue: BatchingQueue[ | |||
_FetchKeyRequest, Dict[str, Dict[str, FetchKeyResult]] | |||
] = BatchingQueue( | |||
"keyring_server", | |||
clock=hs.get_clock(), | |||
process_batch_callback=self._inner_fetch_key_requests, | |||
) # type: BatchingQueue[_FetchKeyRequest, Dict[str, Dict[str, FetchKeyResult]]] | |||
) | |||
async def verify_json_for_server( | |||
self, | |||
@@ -330,7 +332,7 @@ class Keyring: | |||
# First we need to deduplicate requests for the same key. We do this by | |||
# taking the *maximum* requested `minimum_valid_until_ts` for each pair | |||
# of server name/key ID. | |||
server_to_key_to_ts = {} # type: Dict[str, Dict[str, int]] | |||
server_to_key_to_ts: Dict[str, Dict[str, int]] = {} | |||
for request in requests: | |||
by_server = server_to_key_to_ts.setdefault(request.server_name, {}) | |||
for key_id in request.key_ids: | |||
@@ -355,7 +357,7 @@ class Keyring: | |||
# We now convert the returned list of results into a map from server | |||
# name to key ID to FetchKeyResult, to return. | |||
to_return = {} # type: Dict[str, Dict[str, FetchKeyResult]] | |||
to_return: Dict[str, Dict[str, FetchKeyResult]] = {} | |||
for (request, results) in zip(deduped_requests, results_per_request): | |||
to_return_by_server = to_return.setdefault(request.server_name, {}) | |||
for key_id, key_result in results.items(): | |||
@@ -455,7 +457,7 @@ class StoreKeyFetcher(KeyFetcher): | |||
) | |||
res = await self.store.get_server_verify_keys(key_ids_to_fetch) | |||
keys = {} # type: Dict[str, Dict[str, FetchKeyResult]] | |||
keys: Dict[str, Dict[str, FetchKeyResult]] = {} | |||
for (server_name, key_id), key in res.items(): | |||
keys.setdefault(server_name, {})[key_id] = key | |||
return keys | |||
@@ -603,7 +605,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher): | |||
).addErrback(unwrapFirstError) | |||
) | |||
union_of_keys = {} # type: Dict[str, Dict[str, FetchKeyResult]] | |||
union_of_keys: Dict[str, Dict[str, FetchKeyResult]] = {} | |||
for result in results: | |||
for server_name, keys in result.items(): | |||
union_of_keys.setdefault(server_name, {}).update(keys) | |||
@@ -656,8 +658,8 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher): | |||
except HttpResponseException as e: | |||
raise KeyLookupError("Remote server returned an error: %s" % (e,)) | |||
keys = {} # type: Dict[str, Dict[str, FetchKeyResult]] | |||
added_keys = [] # type: List[Tuple[str, str, FetchKeyResult]] | |||
keys: Dict[str, Dict[str, FetchKeyResult]] = {} | |||
added_keys: List[Tuple[str, str, FetchKeyResult]] = [] | |||
time_now_ms = self.clock.time_msec() | |||
@@ -805,7 +807,7 @@ class ServerKeyFetcher(BaseV2KeyFetcher): | |||
Raises: | |||
KeyLookupError if there was a problem making the lookup | |||
""" | |||
keys = {} # type: Dict[str, FetchKeyResult] | |||
keys: Dict[str, FetchKeyResult] = {} | |||
for requested_key_id in key_ids: | |||
# we may have found this key as a side-effect of asking for another. | |||
@@ -531,7 +531,7 @@ def _check_power_levels( | |||
user_level = get_user_power_level(event.user_id, auth_events) | |||
# Check other levels: | |||
levels_to_check = [ | |||
levels_to_check: List[Tuple[str, Optional[str]]] = [ | |||
("users_default", None), | |||
("events_default", None), | |||
("state_default", None), | |||
@@ -539,7 +539,7 @@ def _check_power_levels( | |||
("redact", None), | |||
("kick", None), | |||
("invite", None), | |||
] # type: List[Tuple[str, Optional[str]]] | |||
] | |||
old_list = current_state.content.get("users", {}) | |||
for user in set(list(old_list) + list(user_list)): | |||
@@ -569,12 +569,12 @@ def _check_power_levels( | |||
new_loc = new_loc.get(dir, {}) | |||
if level_to_check in old_loc: | |||
old_level = int(old_loc[level_to_check]) # type: Optional[int] | |||
old_level: Optional[int] = int(old_loc[level_to_check]) | |||
else: | |||
old_level = None | |||
if level_to_check in new_loc: | |||
new_level = int(new_loc[level_to_check]) # type: Optional[int] | |||
new_level: Optional[int] = int(new_loc[level_to_check]) | |||
else: | |||
new_level = None | |||
@@ -105,28 +105,28 @@ class _EventInternalMetadata: | |||
self._dict = dict(internal_metadata_dict) | |||
# the stream ordering of this event. None, until it has been persisted. | |||
self.stream_ordering = None # type: Optional[int] | |||
self.stream_ordering: Optional[int] = None | |||
# whether this event is an outlier (ie, whether we have the state at that point | |||
# in the DAG) | |||
self.outlier = False | |||
out_of_band_membership = DictProperty("out_of_band_membership") # type: bool | |||
send_on_behalf_of = DictProperty("send_on_behalf_of") # type: str | |||
recheck_redaction = DictProperty("recheck_redaction") # type: bool | |||
soft_failed = DictProperty("soft_failed") # type: bool | |||
proactively_send = DictProperty("proactively_send") # type: bool | |||
redacted = DictProperty("redacted") # type: bool | |||
txn_id = DictProperty("txn_id") # type: str | |||
token_id = DictProperty("token_id") # type: int | |||
historical = DictProperty("historical") # type: bool | |||
out_of_band_membership: bool = DictProperty("out_of_band_membership") | |||
send_on_behalf_of: str = DictProperty("send_on_behalf_of") | |||
recheck_redaction: bool = DictProperty("recheck_redaction") | |||
soft_failed: bool = DictProperty("soft_failed") | |||
proactively_send: bool = DictProperty("proactively_send") | |||
redacted: bool = DictProperty("redacted") | |||
txn_id: str = DictProperty("txn_id") | |||
token_id: int = DictProperty("token_id") | |||
historical: bool = DictProperty("historical") | |||
# XXX: These are set by StreamWorkerStore._set_before_and_after. | |||
# I'm pretty sure that these are never persisted to the database, so shouldn't | |||
# be here | |||
before = DictProperty("before") # type: RoomStreamToken | |||
after = DictProperty("after") # type: RoomStreamToken | |||
order = DictProperty("order") # type: Tuple[int, int] | |||
before: RoomStreamToken = DictProperty("before") | |||
after: RoomStreamToken = DictProperty("after") | |||
order: Tuple[int, int] = DictProperty("order") | |||
def get_dict(self) -> JsonDict: | |||
return dict(self._dict) | |||
@@ -132,12 +132,12 @@ class EventBuilder: | |||
format_version = self.room_version.event_format | |||
if format_version == EventFormatVersions.V1: | |||
# The types of auth/prev events changes between event versions. | |||
auth_events = await self._store.add_event_hashes( | |||
auth_event_ids | |||
) # type: Union[List[str], List[Tuple[str, Dict[str, str]]]] | |||
prev_events = await self._store.add_event_hashes( | |||
prev_event_ids | |||
) # type: Union[List[str], List[Tuple[str, Dict[str, str]]]] | |||
auth_events: Union[ | |||
List[str], List[Tuple[str, Dict[str, str]]] | |||
] = await self._store.add_event_hashes(auth_event_ids) | |||
prev_events: Union[ | |||
List[str], List[Tuple[str, Dict[str, str]]] | |||
] = await self._store.add_event_hashes(prev_event_ids) | |||
else: | |||
auth_events = auth_event_ids | |||
prev_events = prev_event_ids | |||
@@ -156,7 +156,7 @@ class EventBuilder: | |||
# the db) | |||
depth = min(depth, MAX_DEPTH) | |||
event_dict = { | |||
event_dict: Dict[str, Any] = { | |||
"auth_events": auth_events, | |||
"prev_events": prev_events, | |||
"type": self.type, | |||
@@ -166,7 +166,7 @@ class EventBuilder: | |||
"unsigned": self.unsigned, | |||
"depth": depth, | |||
"prev_state": [], | |||
} # type: Dict[str, Any] | |||
} | |||
if self.is_state(): | |||
event_dict["state_key"] = self._state_key | |||
@@ -76,7 +76,7 @@ def load_legacy_spam_checkers(hs: "synapse.server.HomeServer"): | |||
"""Wrapper that loads spam checkers configured using the old configuration, and | |||
registers the spam checker hooks they implement. | |||
""" | |||
spam_checkers = [] # type: List[Any] | |||
spam_checkers: List[Any] = [] | |||
api = hs.get_module_api() | |||
for module, config in hs.config.spam_checkers: | |||
# Older spam checkers don't accept the `api` argument, so we | |||
@@ -239,7 +239,7 @@ class SpamChecker: | |||
will be used as the error message returned to the user. | |||
""" | |||
for callback in self._check_event_for_spam_callbacks: | |||
res = await callback(event) # type: Union[bool, str] | |||
res: Union[bool, str] = await callback(event) | |||
if res: | |||
return res | |||
@@ -86,7 +86,7 @@ class FederationClient(FederationBase): | |||
def __init__(self, hs: "HomeServer"): | |||
super().__init__(hs) | |||
self.pdu_destination_tried = {} # type: Dict[str, Dict[str, int]] | |||
self.pdu_destination_tried: Dict[str, Dict[str, int]] = {} | |||
self._clock.looping_call(self._clear_tried_cache, 60 * 1000) | |||
self.state = hs.get_state_handler() | |||
self.transport_layer = hs.get_federation_transport_client() | |||
@@ -94,13 +94,13 @@ class FederationClient(FederationBase): | |||
self.hostname = hs.hostname | |||
self.signing_key = hs.signing_key | |||
self._get_pdu_cache = ExpiringCache( | |||
self._get_pdu_cache: ExpiringCache[str, EventBase] = ExpiringCache( | |||
cache_name="get_pdu_cache", | |||
clock=self._clock, | |||
max_len=1000, | |||
expiry_ms=120 * 1000, | |||
reset_expiry_on_get=False, | |||
) # type: ExpiringCache[str, EventBase] | |||
) | |||
def _clear_tried_cache(self): | |||
"""Clear pdu_destination_tried cache""" | |||
@@ -293,10 +293,10 @@ class FederationClient(FederationBase): | |||
transaction_data, | |||
) | |||
pdu_list = [ | |||
pdu_list: List[EventBase] = [ | |||
event_from_pdu_json(p, room_version, outlier=outlier) | |||
for p in transaction_data["pdus"] | |||
] # type: List[EventBase] | |||
] | |||
if pdu_list and pdu_list[0]: | |||
pdu = pdu_list[0] | |||
@@ -122,12 +122,12 @@ class FederationServer(FederationBase): | |||
# origins that we are currently processing a transaction from. | |||
# a dict from origin to txn id. | |||
self._active_transactions = {} # type: Dict[str, str] | |||
self._active_transactions: Dict[str, str] = {} | |||
# We cache results for transaction with the same ID | |||
self._transaction_resp_cache = ResponseCache( | |||
self._transaction_resp_cache: ResponseCache[Tuple[str, str]] = ResponseCache( | |||
hs.get_clock(), "fed_txn_handler", timeout_ms=30000 | |||
) # type: ResponseCache[Tuple[str, str]] | |||
) | |||
self.transaction_actions = TransactionActions(self.store) | |||
@@ -135,12 +135,12 @@ class FederationServer(FederationBase): | |||
# We cache responses to state queries, as they take a while and often | |||
# come in waves. | |||
self._state_resp_cache = ResponseCache( | |||
hs.get_clock(), "state_resp", timeout_ms=30000 | |||
) # type: ResponseCache[Tuple[str, Optional[str]]] | |||
self._state_ids_resp_cache = ResponseCache( | |||
self._state_resp_cache: ResponseCache[ | |||
Tuple[str, Optional[str]] | |||
] = ResponseCache(hs.get_clock(), "state_resp", timeout_ms=30000) | |||
self._state_ids_resp_cache: ResponseCache[Tuple[str, str]] = ResponseCache( | |||
hs.get_clock(), "state_ids_resp", timeout_ms=30000 | |||
) # type: ResponseCache[Tuple[str, str]] | |||
) | |||
self._federation_metrics_domains = ( | |||
hs.config.federation.federation_metrics_domains | |||
@@ -337,7 +337,7 @@ class FederationServer(FederationBase): | |||
origin_host, _ = parse_server_name(origin) | |||
pdus_by_room = {} # type: Dict[str, List[EventBase]] | |||
pdus_by_room: Dict[str, List[EventBase]] = {} | |||
newest_pdu_ts = 0 | |||
@@ -516,9 +516,9 @@ class FederationServer(FederationBase): | |||
self, room_id: str, event_id: Optional[str] | |||
) -> Dict[str, list]: | |||
if event_id: | |||
pdus = await self.handler.get_state_for_pdu( | |||
pdus: Iterable[EventBase] = await self.handler.get_state_for_pdu( | |||
room_id, event_id | |||
) # type: Iterable[EventBase] | |||
) | |||
else: | |||
pdus = (await self.state.get_current_state(room_id)).values() | |||
@@ -791,7 +791,7 @@ class FederationServer(FederationBase): | |||
log_kv({"message": "Claiming one time keys.", "user, device pairs": query}) | |||
results = await self.store.claim_e2e_one_time_keys(query) | |||
json_result = {} # type: Dict[str, Dict[str, dict]] | |||
json_result: Dict[str, Dict[str, dict]] = {} | |||
for user_id, device_keys in results.items(): | |||
for device_id, keys in device_keys.items(): | |||
for key_id, json_str in keys.items(): | |||
@@ -1119,17 +1119,13 @@ class FederationHandlerRegistry: | |||
self._get_query_client = ReplicationGetQueryRestServlet.make_client(hs) | |||
self._send_edu = ReplicationFederationSendEduRestServlet.make_client(hs) | |||
self.edu_handlers = ( | |||
{} | |||
) # type: Dict[str, Callable[[str, dict], Awaitable[None]]] | |||
self.query_handlers = ( | |||
{} | |||
) # type: Dict[str, Callable[[dict], Awaitable[JsonDict]]] | |||
self.edu_handlers: Dict[str, Callable[[str, dict], Awaitable[None]]] = {} | |||
self.query_handlers: Dict[str, Callable[[dict], Awaitable[JsonDict]]] = {} | |||
# Map from type to instance names that we should route EDU handling to. | |||
# We randomly choose one instance from the list to route to for each new | |||
# EDU received. | |||
self._edu_type_to_instance = {} # type: Dict[str, List[str]] | |||
self._edu_type_to_instance: Dict[str, List[str]] = {} | |||
def register_edu_handler( | |||
self, edu_type: str, handler: Callable[[str, JsonDict], Awaitable[None]] | |||
@@ -71,34 +71,32 @@ class FederationRemoteSendQueue(AbstractFederationSender): | |||
# We may have multiple federation sender instances, so we need to track | |||
# their positions separately. | |||
self._sender_instances = hs.config.worker.federation_shard_config.instances | |||
self._sender_positions = {} # type: Dict[str, int] | |||
self._sender_positions: Dict[str, int] = {} | |||
# Pending presence map user_id -> UserPresenceState | |||
self.presence_map = {} # type: Dict[str, UserPresenceState] | |||
self.presence_map: Dict[str, UserPresenceState] = {} | |||
# Stores the destinations we need to explicitly send presence to about a | |||
# given user. | |||
# Stream position -> (user_id, destinations) | |||
self.presence_destinations = ( | |||
SortedDict() | |||
) # type: SortedDict[int, Tuple[str, Iterable[str]]] | |||
self.presence_destinations: SortedDict[ | |||
int, Tuple[str, Iterable[str]] | |||
] = SortedDict() | |||
# (destination, key) -> EDU | |||
self.keyed_edu = {} # type: Dict[Tuple[str, tuple], Edu] | |||
self.keyed_edu: Dict[Tuple[str, tuple], Edu] = {} | |||
# stream position -> (destination, key) | |||
self.keyed_edu_changed = ( | |||
SortedDict() | |||
) # type: SortedDict[int, Tuple[str, tuple]] | |||
self.keyed_edu_changed: SortedDict[int, Tuple[str, tuple]] = SortedDict() | |||
self.edus = SortedDict() # type: SortedDict[int, Edu] | |||
self.edus: SortedDict[int, Edu] = SortedDict() | |||
# stream ID for the next entry into keyed_edu_changed/edus. | |||
self.pos = 1 | |||
# map from stream ID to the time that stream entry was generated, so that we | |||
# can clear out entries after a while | |||
self.pos_time = SortedDict() # type: SortedDict[int, int] | |||
self.pos_time: SortedDict[int, int] = SortedDict() | |||
# EVERYTHING IS SAD. In particular, python only makes new scopes when | |||
# we make a new function, so we need to make a new function so the inner | |||
@@ -291,7 +289,7 @@ class FederationRemoteSendQueue(AbstractFederationSender): | |||
# list of tuple(int, BaseFederationRow), where the first is the position | |||
# of the federation stream. | |||
rows = [] # type: List[Tuple[int, BaseFederationRow]] | |||
rows: List[Tuple[int, BaseFederationRow]] = [] | |||
# Fetch presence to send to destinations | |||
i = self.presence_destinations.bisect_right(from_token) | |||
@@ -445,11 +443,11 @@ class EduRow(BaseFederationRow, namedtuple("EduRow", ("edu",))): # Edu | |||
buff.edus.setdefault(self.edu.destination, []).append(self.edu) | |||
_rowtypes = ( | |||
_rowtypes: Tuple[Type[BaseFederationRow], ...] = ( | |||
PresenceDestinationsRow, | |||
KeyedEduRow, | |||
EduRow, | |||
) # type: Tuple[Type[BaseFederationRow], ...] | |||
) | |||
TypeToRow = {Row.TypeId: Row for Row in _rowtypes} | |||
@@ -148,14 +148,14 @@ class FederationSender(AbstractFederationSender): | |||
self.clock = hs.get_clock() | |||
self.is_mine_id = hs.is_mine_id | |||
self._presence_router = None # type: Optional[PresenceRouter] | |||
self._presence_router: Optional["PresenceRouter"] = None | |||
self._transaction_manager = TransactionManager(hs) | |||
self._instance_name = hs.get_instance_name() | |||
self._federation_shard_config = hs.config.worker.federation_shard_config | |||
# map from destination to PerDestinationQueue | |||
self._per_destination_queues = {} # type: Dict[str, PerDestinationQueue] | |||
self._per_destination_queues: Dict[str, PerDestinationQueue] = {} | |||
LaterGauge( | |||
"synapse_federation_transaction_queue_pending_destinations", | |||
@@ -192,9 +192,7 @@ class FederationSender(AbstractFederationSender): | |||
# awaiting a call to flush_read_receipts_for_room. The presence of an entry | |||
# here for a given room means that we are rate-limiting RR flushes to that room, | |||
# and that there is a pending call to _flush_rrs_for_room in the system. | |||
self._queues_awaiting_rr_flush_by_room = ( | |||
{} | |||
) # type: Dict[str, Set[PerDestinationQueue]] | |||
self._queues_awaiting_rr_flush_by_room: Dict[str, Set[PerDestinationQueue]] = {} | |||
self._rr_txn_interval_per_room_ms = ( | |||
1000.0 / hs.config.federation_rr_transactions_per_room_per_second | |||
@@ -265,7 +263,7 @@ class FederationSender(AbstractFederationSender): | |||
if not event.internal_metadata.should_proactively_send(): | |||
return | |||
destinations = None # type: Optional[Set[str]] | |||
destinations: Optional[Set[str]] = None | |||
if not event.prev_event_ids(): | |||
# If there are no prev event IDs then the state is empty | |||
# and so no remote servers in the room | |||
@@ -331,7 +329,7 @@ class FederationSender(AbstractFederationSender): | |||
for event in events: | |||
await handle_event(event) | |||
events_by_room = {} # type: Dict[str, List[EventBase]] | |||
events_by_room: Dict[str, List[EventBase]] = {} | |||
for event in events: | |||
events_by_room.setdefault(event.room_id, []).append(event) | |||
@@ -628,7 +626,7 @@ class FederationSender(AbstractFederationSender): | |||
In order to reduce load spikes, adds a delay between each destination. | |||
""" | |||
last_processed = None # type: Optional[str] | |||
last_processed: Optional[str] = None | |||
while True: | |||
destinations_to_wake = ( | |||
@@ -105,34 +105,34 @@ class PerDestinationQueue: | |||
# catch-up at startup. | |||
# New events will only be sent once this is finished, at which point | |||
# _catching_up is flipped to False. | |||
self._catching_up = True # type: bool | |||
self._catching_up: bool = True | |||
# The stream_ordering of the most recent PDU that was discarded due to | |||
# being in catch-up mode. | |||
self._catchup_last_skipped = 0 # type: int | |||
self._catchup_last_skipped: int = 0 | |||
# Cache of the last successfully-transmitted stream ordering for this | |||
# destination (we are the only updater so this is safe) | |||
self._last_successful_stream_ordering = None # type: Optional[int] | |||
self._last_successful_stream_ordering: Optional[int] = None | |||
# a queue of pending PDUs | |||
self._pending_pdus = [] # type: List[EventBase] | |||
self._pending_pdus: List[EventBase] = [] | |||
# XXX this is never actually used: see | |||
# https://github.com/matrix-org/synapse/issues/7549 | |||
self._pending_edus = [] # type: List[Edu] | |||
self._pending_edus: List[Edu] = [] | |||
# Pending EDUs by their "key". Keyed EDUs are EDUs that get clobbered | |||
# based on their key (e.g. typing events by room_id) | |||
# Map of (edu_type, key) -> Edu | |||
self._pending_edus_keyed = {} # type: Dict[Tuple[str, Hashable], Edu] | |||
self._pending_edus_keyed: Dict[Tuple[str, Hashable], Edu] = {} | |||
# Map of user_id -> UserPresenceState of pending presence to be sent to this | |||
# destination | |||
self._pending_presence = {} # type: Dict[str, UserPresenceState] | |||
self._pending_presence: Dict[str, UserPresenceState] = {} | |||
# room_id -> receipt_type -> user_id -> receipt_dict | |||
self._pending_rrs = {} # type: Dict[str, Dict[str, Dict[str, dict]]] | |||
self._pending_rrs: Dict[str, Dict[str, Dict[str, dict]]] = {} | |||
self._rrs_pending_flush = False | |||
# stream_id of last successfully sent to-device message. | |||
@@ -243,7 +243,7 @@ class PerDestinationQueue: | |||
) | |||
async def _transaction_transmission_loop(self) -> None: | |||
pending_pdus = [] # type: List[EventBase] | |||
pending_pdus: List[EventBase] = [] | |||
try: | |||
self.transmission_loop_running = True | |||
@@ -395,9 +395,9 @@ class TransportLayerClient: | |||
# this uses MSC2197 (Search Filtering over Federation) | |||
path = _create_v1_path("/publicRooms") | |||
data = { | |||
data: Dict[str, Any] = { | |||
"include_all_networks": "true" if include_all_networks else "false" | |||
} # type: Dict[str, Any] | |||
} | |||
if third_party_instance_id: | |||
data["third_party_instance_id"] = third_party_instance_id | |||
if limit: | |||
@@ -423,9 +423,9 @@ class TransportLayerClient: | |||
else: | |||
path = _create_v1_path("/publicRooms") | |||
args = { | |||
args: Dict[str, Any] = { | |||
"include_all_networks": "true" if include_all_networks else "false" | |||
} # type: Dict[str, Any] | |||
} | |||
if third_party_instance_id: | |||
args["third_party_instance_id"] = (third_party_instance_id,) | |||
if limit: | |||
@@ -1013,7 +1013,7 @@ class PublicRoomList(BaseFederationServlet): | |||
if not self.allow_access: | |||
raise FederationDeniedError(origin) | |||
limit = int(content.get("limit", 100)) # type: Optional[int] | |||
limit: Optional[int] = int(content.get("limit", 100)) | |||
since_token = content.get("since", None) | |||
search_filter = content.get("filter", None) | |||
@@ -1991,7 +1991,7 @@ class RoomComplexityServlet(BaseFederationServlet): | |||
return 200, complexity | |||
FEDERATION_SERVLET_CLASSES = ( | |||
FEDERATION_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = ( | |||
FederationSendServlet, | |||
FederationEventServlet, | |||
FederationStateV1Servlet, | |||
@@ -2019,15 +2019,13 @@ FEDERATION_SERVLET_CLASSES = ( | |||
FederationSpaceSummaryServlet, | |||
FederationV1SendKnockServlet, | |||
FederationMakeKnockServlet, | |||
) # type: Tuple[Type[BaseFederationServlet], ...] | |||
) | |||
OPENID_SERVLET_CLASSES = ( | |||
OpenIdUserInfo, | |||
) # type: Tuple[Type[BaseFederationServlet], ...] | |||
OPENID_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = (OpenIdUserInfo,) | |||
ROOM_LIST_CLASSES = (PublicRoomList,) # type: Tuple[Type[PublicRoomList], ...] | |||
ROOM_LIST_CLASSES: Tuple[Type[PublicRoomList], ...] = (PublicRoomList,) | |||
GROUP_SERVER_SERVLET_CLASSES = ( | |||
GROUP_SERVER_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = ( | |||
FederationGroupsProfileServlet, | |||
FederationGroupsSummaryServlet, | |||
FederationGroupsRoomsServlet, | |||
@@ -2046,19 +2044,19 @@ GROUP_SERVER_SERVLET_CLASSES = ( | |||
FederationGroupsAddRoomsServlet, | |||
FederationGroupsAddRoomsConfigServlet, | |||
FederationGroupsSettingJoinPolicyServlet, | |||
) # type: Tuple[Type[BaseFederationServlet], ...] | |||
) | |||
GROUP_LOCAL_SERVLET_CLASSES = ( | |||
GROUP_LOCAL_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = ( | |||
FederationGroupsLocalInviteServlet, | |||
FederationGroupsRemoveLocalUserServlet, | |||
FederationGroupsBulkPublicisedServlet, | |||
) # type: Tuple[Type[BaseFederationServlet], ...] | |||
) | |||
GROUP_ATTESTATION_SERVLET_CLASSES = ( | |||
GROUP_ATTESTATION_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = ( | |||
FederationGroupsRenewAttestaionServlet, | |||
) # type: Tuple[Type[BaseFederationServlet], ...] | |||
) | |||
DEFAULT_SERVLET_GROUPS = ( | |||
@@ -707,9 +707,9 @@ class GroupsServerHandler(GroupsServerWorkerHandler): | |||
See accept_invite, join_group. | |||
""" | |||
if not self.hs.is_mine_id(user_id): | |||
local_attestation = self.attestations.create_attestation( | |||
group_id, user_id | |||
) # type: Optional[JsonDict] | |||
local_attestation: Optional[ | |||
JsonDict | |||
] = self.attestations.create_attestation(group_id, user_id) | |||
remote_attestation = content["attestation"] | |||
@@ -868,9 +868,9 @@ class GroupsServerHandler(GroupsServerWorkerHandler): | |||
remote_attestation, user_id=requester_user_id, group_id=group_id | |||
) | |||
local_attestation = self.attestations.create_attestation( | |||
group_id, requester_user_id | |||
) # type: Optional[JsonDict] | |||
local_attestation: Optional[ | |||
JsonDict | |||
] = self.attestations.create_attestation(group_id, requester_user_id) | |||
else: | |||
local_attestation = None | |||
remote_attestation = None | |||
@@ -69,7 +69,7 @@ def _get_requested_host(request: IRequest) -> bytes: | |||
return hostname | |||
# no Host header, use the address/port that the request arrived on | |||
host = request.getHost() # type: Union[address.IPv4Address, address.IPv6Address] | |||
host: Union[address.IPv4Address, address.IPv6Address] = request.getHost() | |||
hostname = host.host.encode("ascii") | |||
@@ -160,7 +160,7 @@ class _IPBlacklistingResolver: | |||
def resolveHostName( | |||
self, recv: IResolutionReceiver, hostname: str, portNumber: int = 0 | |||
) -> IResolutionReceiver: | |||
addresses = [] # type: List[IAddress] | |||
addresses: List[IAddress] = [] | |||
def _callback() -> None: | |||
has_bad_ip = False | |||
@@ -333,9 +333,9 @@ class SimpleHttpClient: | |||
if self._ip_blacklist: | |||
# If we have an IP blacklist, we need to use a DNS resolver which | |||
# filters out blacklisted IP addresses, to prevent DNS rebinding. | |||
self.reactor = BlacklistingReactorWrapper( | |||
self.reactor: ISynapseReactor = BlacklistingReactorWrapper( | |||
hs.get_reactor(), self._ip_whitelist, self._ip_blacklist | |||
) # type: ISynapseReactor | |||
) | |||
else: | |||
self.reactor = hs.get_reactor() | |||
@@ -349,14 +349,14 @@ class SimpleHttpClient: | |||
pool.maxPersistentPerHost = max((100 * hs.config.caches.global_factor, 5)) | |||
pool.cachedConnectionTimeout = 2 * 60 | |||
self.agent = ProxyAgent( | |||
self.agent: IAgent = ProxyAgent( | |||
self.reactor, | |||
hs.get_reactor(), | |||
connectTimeout=15, | |||
contextFactory=self.hs.get_http_client_context_factory(), | |||
pool=pool, | |||
use_proxy=use_proxy, | |||
) # type: IAgent | |||
) | |||
if self._ip_blacklist: | |||
# If we have an IP blacklist, we then install the blacklisting Agent | |||
@@ -411,7 +411,7 @@ class SimpleHttpClient: | |||
cooperator=self._cooperator, | |||
) | |||
request_deferred = treq.request( | |||
request_deferred: defer.Deferred = treq.request( | |||
method, | |||
uri, | |||
agent=self.agent, | |||
@@ -421,7 +421,7 @@ class SimpleHttpClient: | |||
# response bodies. | |||
unbuffered=True, | |||
**self._extra_treq_args, | |||
) # type: defer.Deferred | |||
) | |||
# we use our own timeout mechanism rather than treq's as a workaround | |||
# for https://twistedmatrix.com/trac/ticket/9534. | |||
@@ -772,7 +772,7 @@ class BodyExceededMaxSize(Exception): | |||
class _DiscardBodyWithMaxSizeProtocol(protocol.Protocol): | |||
"""A protocol which immediately errors upon receiving data.""" | |||
transport = None # type: Optional[ITCPTransport] | |||
transport: Optional[ITCPTransport] = None | |||
def __init__(self, deferred: defer.Deferred): | |||
self.deferred = deferred | |||
@@ -798,7 +798,7 @@ class _DiscardBodyWithMaxSizeProtocol(protocol.Protocol): | |||
class _ReadBodyWithMaxSizeProtocol(protocol.Protocol): | |||
"""A protocol which reads body to a stream, erroring if the body exceeds a maximum size.""" | |||
transport = None # type: Optional[ITCPTransport] | |||
transport: Optional[ITCPTransport] = None | |||
def __init__( | |||
self, stream: ByteWriteable, deferred: defer.Deferred, max_size: Optional[int] | |||
@@ -106,7 +106,7 @@ class ByteParser(ByteWriteable, Generic[T], abc.ABC): | |||
the parsed data. | |||
""" | |||
CONTENT_TYPE = abc.abstractproperty() # type: str # type: ignore | |||
CONTENT_TYPE: str = abc.abstractproperty() # type: ignore | |||
"""The expected content type of the response, e.g. `application/json`. If | |||
the content type doesn't match we fail the request. | |||
""" | |||
@@ -327,11 +327,11 @@ class MatrixFederationHttpClient: | |||
# We need to use a DNS resolver which filters out blacklisted IP | |||
# addresses, to prevent DNS rebinding. | |||
self.reactor = BlacklistingReactorWrapper( | |||
self.reactor: ISynapseReactor = BlacklistingReactorWrapper( | |||
hs.get_reactor(), | |||
hs.config.federation_ip_range_whitelist, | |||
hs.config.federation_ip_range_blacklist, | |||
) # type: ISynapseReactor | |||
) | |||
user_agent = hs.version_string | |||
if hs.config.user_agent_suffix: | |||
@@ -504,7 +504,7 @@ class MatrixFederationHttpClient: | |||
) | |||
# Inject the span into the headers | |||
headers_dict = {} # type: Dict[bytes, List[bytes]] | |||
headers_dict: Dict[bytes, List[bytes]] = {} | |||
opentracing.inject_header_dict(headers_dict, request.destination) | |||
headers_dict[b"User-Agent"] = [self.version_string_bytes] | |||
@@ -533,9 +533,9 @@ class MatrixFederationHttpClient: | |||
destination_bytes, method_bytes, url_to_sign_bytes, json | |||
) | |||
data = encode_canonical_json(json) | |||
producer = QuieterFileBodyProducer( | |||
producer: Optional[IBodyProducer] = QuieterFileBodyProducer( | |||
BytesIO(data), cooperator=self._cooperator | |||
) # type: Optional[IBodyProducer] | |||
) | |||
else: | |||
producer = None | |||
auth_headers = self.build_auth_headers( | |||
@@ -81,7 +81,7 @@ def return_json_error(f: failure.Failure, request: SynapseRequest) -> None: | |||
if f.check(SynapseError): | |||
# mypy doesn't understand that f.check asserts the type. | |||
exc = f.value # type: SynapseError # type: ignore | |||
exc: SynapseError = f.value # type: ignore | |||
error_code = exc.code | |||
error_dict = exc.error_dict() | |||
@@ -132,7 +132,7 @@ def return_html_error( | |||
""" | |||
if f.check(CodeMessageException): | |||
# mypy doesn't understand that f.check asserts the type. | |||
cme = f.value # type: CodeMessageException # type: ignore | |||
cme: CodeMessageException = f.value # type: ignore | |||
code = cme.code | |||
msg = cme.msg | |||
@@ -404,7 +404,7 @@ class JsonResource(DirectServeJsonResource): | |||
key word arguments to pass to the callback | |||
""" | |||
# At this point the path must be bytes. | |||
request_path_bytes = request.path # type: bytes # type: ignore | |||
request_path_bytes: bytes = request.path # type: ignore | |||
request_path = request_path_bytes.decode("ascii") | |||
# Treat HEAD requests as GET requests. | |||
request_method = request.method | |||
@@ -557,7 +557,7 @@ class _ByteProducer: | |||
request: Request, | |||
iterator: Iterator[bytes], | |||
): | |||
self._request = request # type: Optional[Request] | |||
self._request: Optional[Request] = request | |||
self._iterator = iterator | |||
self._paused = False | |||
@@ -205,7 +205,7 @@ def parse_string( | |||
parameter is present, must be one of a list of allowed values and | |||
is not one of those allowed values. | |||
""" | |||
args = request.args # type: Dict[bytes, List[bytes]] # type: ignore | |||
args: Dict[bytes, List[bytes]] = request.args # type: ignore | |||
return parse_string_from_args( | |||
args, | |||
name, | |||
@@ -64,16 +64,16 @@ class SynapseRequest(Request): | |||
def __init__(self, channel, *args, max_request_body_size=1024, **kw): | |||
Request.__init__(self, channel, *args, **kw) | |||
self._max_request_body_size = max_request_body_size | |||
self.site = channel.site # type: SynapseSite | |||
self.site: SynapseSite = channel.site | |||
self._channel = channel # this is used by the tests | |||
self.start_time = 0.0 | |||
# The requester, if authenticated. For federation requests this is the | |||
# server name, for client requests this is the Requester object. | |||
self._requester = None # type: Optional[Union[Requester, str]] | |||
self._requester: Optional[Union[Requester, str]] = None | |||
# we can't yet create the logcontext, as we don't know the method. | |||
self.logcontext = None # type: Optional[LoggingContext] | |||
self.logcontext: Optional[LoggingContext] = None | |||
global _next_request_seq | |||
self.request_seq = _next_request_seq | |||
@@ -152,7 +152,7 @@ class SynapseRequest(Request): | |||
Returns: | |||
The redacted URI as a string. | |||
""" | |||
uri = self.uri # type: Union[bytes, str] | |||
uri: Union[bytes, str] = self.uri | |||
if isinstance(uri, bytes): | |||
uri = uri.decode("ascii", errors="replace") | |||
return redact_uri(uri) | |||
@@ -167,7 +167,7 @@ class SynapseRequest(Request): | |||
Returns: | |||
The request method as a string. | |||
""" | |||
method = self.method # type: Union[bytes, str] | |||
method: Union[bytes, str] = self.method | |||
if isinstance(method, bytes): | |||
return self.method.decode("ascii") | |||
return method | |||
@@ -434,8 +434,8 @@ class XForwardedForRequest(SynapseRequest): | |||
""" | |||
# the client IP and ssl flag, as extracted from the headers. | |||
_forwarded_for = None # type: Optional[_XForwardedForAddress] | |||
_forwarded_https = False # type: bool | |||
_forwarded_for: "Optional[_XForwardedForAddress]" = None | |||
_forwarded_https: bool = False | |||
def requestReceived(self, command, path, version): | |||
# this method is called by the Channel once the full request has been | |||
@@ -110,9 +110,9 @@ class RemoteHandler(logging.Handler): | |||
self.port = port | |||
self.maximum_buffer = maximum_buffer | |||
self._buffer = deque() # type: Deque[logging.LogRecord] | |||
self._connection_waiter = None # type: Optional[Deferred] | |||
self._producer = None # type: Optional[LogProducer] | |||
self._buffer: Deque[logging.LogRecord] = deque() | |||
self._connection_waiter: Optional[Deferred] = None | |||
self._producer: Optional[LogProducer] = None | |||
# Connect without DNS lookups if it's a direct IP. | |||
if _reactor is None: | |||
@@ -123,9 +123,9 @@ class RemoteHandler(logging.Handler): | |||
try: | |||
ip = ip_address(self.host) | |||
if isinstance(ip, IPv4Address): | |||
endpoint = TCP4ClientEndpoint( | |||
endpoint: IStreamClientEndpoint = TCP4ClientEndpoint( | |||
_reactor, self.host, self.port | |||
) # type: IStreamClientEndpoint | |||
) | |||
elif isinstance(ip, IPv6Address): | |||
endpoint = TCP6ClientEndpoint(_reactor, self.host, self.port) | |||
else: | |||
@@ -165,7 +165,7 @@ class RemoteHandler(logging.Handler): | |||
def writer(result: Protocol) -> None: | |||
# Force recognising transport as a Connection and not the more | |||
# generic ITransport. | |||
transport = result.transport # type: Connection # type: ignore | |||
transport: Connection = result.transport # type: ignore | |||
# We have a connection. If we already have a producer, and its | |||
# transport is the same, just trigger a resumeProducing. | |||
@@ -188,7 +188,7 @@ class RemoteHandler(logging.Handler): | |||
self._producer.resumeProducing() | |||
self._connection_waiter = None | |||
deferred = self._service.whenConnected(failAfterFailures=1) # type: Deferred | |||
deferred: Deferred = self._service.whenConnected(failAfterFailures=1) | |||
deferred.addCallbacks(writer, fail) | |||
self._connection_waiter = deferred | |||
@@ -63,7 +63,7 @@ def parse_drain_configs( | |||
DrainType.CONSOLE_JSON, | |||
DrainType.FILE_JSON, | |||
): | |||
formatter = "json" # type: Optional[str] | |||
formatter: Optional[str] = "json" | |||
elif logging_type in ( | |||
DrainType.CONSOLE_JSON_TERSE, | |||
DrainType.NETWORK_JSON_TERSE, | |||
@@ -113,13 +113,13 @@ class ContextResourceUsage: | |||
self.reset() | |||
else: | |||
# FIXME: mypy can't infer the types set via reset() above, so specify explicitly for now | |||
self.ru_utime = copy_from.ru_utime # type: float | |||
self.ru_stime = copy_from.ru_stime # type: float | |||
self.db_txn_count = copy_from.db_txn_count # type: int | |||
self.ru_utime: float = copy_from.ru_utime | |||
self.ru_stime: float = copy_from.ru_stime | |||
self.db_txn_count: int = copy_from.db_txn_count | |||
self.db_txn_duration_sec = copy_from.db_txn_duration_sec # type: float | |||
self.db_sched_duration_sec = copy_from.db_sched_duration_sec # type: float | |||
self.evt_db_fetch_count = copy_from.evt_db_fetch_count # type: int | |||
self.db_txn_duration_sec: float = copy_from.db_txn_duration_sec | |||
self.db_sched_duration_sec: float = copy_from.db_sched_duration_sec | |||
self.evt_db_fetch_count: int = copy_from.evt_db_fetch_count | |||
def copy(self) -> "ContextResourceUsage": | |||
return ContextResourceUsage(copy_from=self) | |||
@@ -289,12 +289,12 @@ class LoggingContext: | |||
# The thread resource usage when the logcontext became active. None | |||
# if the context is not currently active. | |||
self.usage_start = None # type: Optional[resource._RUsage] | |||
self.usage_start: Optional[resource._RUsage] = None | |||
self.main_thread = get_thread_id() | |||
self.request = None | |||
self.tag = "" | |||
self.scope = None # type: Optional[_LogContextScope] | |||
self.scope: Optional["_LogContextScope"] = None | |||
# keep track of whether we have hit the __exit__ block for this context | |||
# (suggesting that the the thing that created the context thinks it should | |||
@@ -251,7 +251,7 @@ try: | |||
except Exception: | |||
logger.exception("Failed to report span") | |||
RustReporter = _WrappedRustReporter # type: Optional[Type[_WrappedRustReporter]] | |||
RustReporter: Optional[Type[_WrappedRustReporter]] = _WrappedRustReporter | |||
except ImportError: | |||
RustReporter = None | |||
@@ -286,7 +286,7 @@ class SynapseBaggage: | |||
# Block everything by default | |||
# A regex which matches the server_names to expose traces for. | |||
# None means 'block everything'. | |||
_homeserver_whitelist = None # type: Optional[Pattern[str]] | |||
_homeserver_whitelist: Optional[Pattern[str]] = None | |||
# Util methods | |||
@@ -662,7 +662,7 @@ def inject_header_dict( | |||
span = opentracing.tracer.active_span | |||
carrier = {} # type: Dict[str, str] | |||
carrier: Dict[str, str] = {} | |||
opentracing.tracer.inject(span.context, opentracing.Format.HTTP_HEADERS, carrier) | |||
for key, value in carrier.items(): | |||
@@ -704,7 +704,7 @@ def get_active_span_text_map(destination=None): | |||
if destination and not whitelisted_homeserver(destination): | |||
return {} | |||
carrier = {} # type: Dict[str, str] | |||
carrier: Dict[str, str] = {} | |||
opentracing.tracer.inject( | |||
opentracing.tracer.active_span.context, opentracing.Format.TEXT_MAP, carrier | |||
) | |||
@@ -718,7 +718,7 @@ def active_span_context_as_string(): | |||
Returns: | |||
The active span context encoded as a string. | |||
""" | |||
carrier = {} # type: Dict[str, str] | |||
carrier: Dict[str, str] = {} | |||
if opentracing: | |||
opentracing.tracer.inject( | |||
opentracing.tracer.active_span.context, opentracing.Format.TEXT_MAP, carrier | |||
@@ -46,7 +46,7 @@ logger = logging.getLogger(__name__) | |||
METRICS_PREFIX = "/_synapse/metrics" | |||
running_on_pypy = platform.python_implementation() == "PyPy" | |||
all_gauges = {} # type: Dict[str, Union[LaterGauge, InFlightGauge]] | |||
all_gauges: "Dict[str, Union[LaterGauge, InFlightGauge]]" = {} | |||
HAVE_PROC_SELF_STAT = os.path.exists("/proc/self/stat") | |||
@@ -130,7 +130,7 @@ class InFlightGauge: | |||
) | |||
# Counts number of in flight blocks for a given set of label values | |||
self._registrations = {} # type: Dict | |||
self._registrations: Dict = {} | |||
# Protects access to _registrations | |||
self._lock = threading.Lock() | |||
@@ -248,7 +248,7 @@ class GaugeBucketCollector: | |||
# We initially set this to None. We won't report metrics until | |||
# this has been initialised after a successful data update | |||
self._metric = None # type: Optional[GaugeHistogramMetricFamily] | |||
self._metric: Optional[GaugeHistogramMetricFamily] = None | |||
registry.register(self) | |||
@@ -125,7 +125,7 @@ def generate_latest(registry, emit_help=False): | |||
) | |||
output.append("# TYPE {0} {1}\n".format(mname, mtype)) | |||
om_samples = {} # type: Dict[str, List[str]] | |||
om_samples: Dict[str, List[str]] = {} | |||
for s in metric.samples: | |||
for suffix in ["_created", "_gsum", "_gcount"]: | |||
if s.name == metric.name + suffix: | |||
@@ -93,7 +93,7 @@ _background_process_db_sched_duration = Counter( | |||
# map from description to a counter, so that we can name our logcontexts | |||
# incrementally. (It actually duplicates _background_process_start_count, but | |||
# it's much simpler to do so than to try to combine them.) | |||
_background_process_counts = {} # type: Dict[str, int] | |||
_background_process_counts: Dict[str, int] = {} | |||
# Set of all running background processes that became active active since the | |||
# last time metrics were scraped (i.e. background processes that performed some | |||
@@ -103,7 +103,7 @@ _background_process_counts = {} # type: Dict[str, int] | |||
# background processes stacking up behind a lock or linearizer, where we then | |||
# only need to iterate over and update metrics for the process that have | |||
# actually been active and can ignore the idle ones. | |||
_background_processes_active_since_last_scrape = set() # type: Set[_BackgroundProcess] | |||
_background_processes_active_since_last_scrape: "Set[_BackgroundProcess]" = set() | |||
# A lock that covers the above set and dict | |||
_bg_metrics_lock = threading.Lock() | |||
@@ -54,7 +54,7 @@ class ModuleApi: | |||
self._state = hs.get_state_handler() | |||
# We expose these as properties below in order to attach a helpful docstring. | |||
self._http_client = hs.get_simple_http_client() # type: SimpleHttpClient | |||
self._http_client: SimpleHttpClient = hs.get_simple_http_client() | |||
self._public_room_list_manager = PublicRoomListManager(hs) | |||
self._spam_checker = hs.get_spam_checker() | |||
@@ -203,21 +203,21 @@ class Notifier: | |||
UNUSED_STREAM_EXPIRY_MS = 10 * 60 * 1000 | |||
def __init__(self, hs: "synapse.server.HomeServer"): | |||
self.user_to_user_stream = {} # type: Dict[str, _NotifierUserStream] | |||
self.room_to_user_streams = {} # type: Dict[str, Set[_NotifierUserStream]] | |||
self.user_to_user_stream: Dict[str, _NotifierUserStream] = {} | |||
self.room_to_user_streams: Dict[str, Set[_NotifierUserStream]] = {} | |||
self.hs = hs | |||
self.storage = hs.get_storage() | |||
self.event_sources = hs.get_event_sources() | |||
self.store = hs.get_datastore() | |||
self.pending_new_room_events = [] # type: List[_PendingRoomEventEntry] | |||
self.pending_new_room_events: List[_PendingRoomEventEntry] = [] | |||
# Called when there are new things to stream over replication | |||
self.replication_callbacks = [] # type: List[Callable[[], None]] | |||
self.replication_callbacks: List[Callable[[], None]] = [] | |||
# Called when remote servers have come back online after having been | |||
# down. | |||
self.remote_server_up_callbacks = [] # type: List[Callable[[str], None]] | |||
self.remote_server_up_callbacks: List[Callable[[str], None]] = [] | |||
self.clock = hs.get_clock() | |||
self.appservice_handler = hs.get_application_service_handler() | |||
@@ -237,7 +237,7 @@ class Notifier: | |||
# when rendering the metrics page, which is likely once per minute at | |||
# most when scraping it. | |||
def count_listeners(): | |||
all_user_streams = set() # type: Set[_NotifierUserStream] | |||
all_user_streams: Set[_NotifierUserStream] = set() | |||
for streams in list(self.room_to_user_streams.values()): | |||
all_user_streams |= streams | |||
@@ -329,8 +329,8 @@ class Notifier: | |||
pending = self.pending_new_room_events | |||
self.pending_new_room_events = [] | |||
users = set() # type: Set[UserID] | |||
rooms = set() # type: Set[str] | |||
users: Set[UserID] = set() | |||
rooms: Set[str] = set() | |||
for entry in pending: | |||
if entry.event_pos.persisted_after(max_room_stream_token): | |||
@@ -580,7 +580,7 @@ class Notifier: | |||
if after_token == before_token: | |||
return EventStreamResult([], (from_token, from_token)) | |||
events = [] # type: List[EventBase] | |||
events: List[EventBase] = [] | |||
end_token = from_token | |||
for name, source in self.event_sources.sources.items(): | |||
@@ -194,7 +194,7 @@ class BulkPushRuleEvaluator: | |||
count_as_unread = _should_count_as_unread(event, context) | |||
rules_by_user = await self._get_rules_for_event(event, context) | |||
actions_by_user = {} # type: Dict[str, List[Union[dict, str]]] | |||
actions_by_user: Dict[str, List[Union[dict, str]]] = {} | |||
room_members = await self.store.get_joined_users_from_context(event, context) | |||
@@ -207,7 +207,7 @@ class BulkPushRuleEvaluator: | |||
event, len(room_members), sender_power_level, power_levels | |||
) | |||
condition_cache = {} # type: Dict[str, bool] | |||
condition_cache: Dict[str, bool] = {} | |||
# If the event is not a state event check if any users ignore the sender. | |||
if not event.is_state(): | |||
@@ -26,10 +26,10 @@ def format_push_rules_for_user(user: UserID, ruleslist) -> Dict[str, Dict[str, l | |||
# We're going to be mutating this a lot, so do a deep copy | |||
ruleslist = copy.deepcopy(ruleslist) | |||
rules = { | |||
rules: Dict[str, Dict[str, List[Dict[str, Any]]]] = { | |||
"global": {}, | |||
"device": {}, | |||
} # type: Dict[str, Dict[str, List[Dict[str, Any]]]] | |||
} | |||
rules["global"] = _add_empty_priority_class_arrays(rules["global"]) | |||
@@ -66,8 +66,8 @@ class EmailPusher(Pusher): | |||
self.store = self.hs.get_datastore() | |||
self.email = pusher_config.pushkey | |||
self.timed_call = None # type: Optional[IDelayedCall] | |||
self.throttle_params = {} # type: Dict[str, ThrottleParams] | |||
self.timed_call: Optional[IDelayedCall] = None | |||
self.throttle_params: Dict[str, ThrottleParams] = {} | |||
self._inited = False | |||
self._is_processing = False | |||
@@ -168,7 +168,7 @@ class EmailPusher(Pusher): | |||
) | |||
) | |||
soonest_due_at = None # type: Optional[int] | |||
soonest_due_at: Optional[int] = None | |||
if not unprocessed: | |||
await self.save_last_stream_ordering_and_success(self.max_stream_ordering) | |||
@@ -71,7 +71,7 @@ class HttpPusher(Pusher): | |||
self.data = pusher_config.data | |||
self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC | |||
self.failing_since = pusher_config.failing_since | |||
self.timed_call = None # type: Optional[IDelayedCall] | |||
self.timed_call: Optional[IDelayedCall] = None | |||
self._is_processing = False | |||
self._group_unread_count_by_room = hs.config.push_group_unread_count_by_room | |||
self._pusherpool = hs.get_pusherpool() | |||
@@ -110,7 +110,7 @@ class Mailer: | |||
self.state_handler = self.hs.get_state_handler() | |||
self.storage = hs.get_storage() | |||
self.app_name = app_name | |||
self.email_subjects = hs.config.email_subjects # type: EmailSubjectConfig | |||
self.email_subjects: EmailSubjectConfig = hs.config.email_subjects | |||
logger.info("Created Mailer for app_name %s" % app_name) | |||
@@ -230,7 +230,7 @@ class Mailer: | |||
[pa["event_id"] for pa in push_actions] | |||
) | |||
notifs_by_room = {} # type: Dict[str, List[Dict[str, Any]]] | |||
notifs_by_room: Dict[str, List[Dict[str, Any]]] = {} | |||
for pa in push_actions: | |||
notifs_by_room.setdefault(pa["room_id"], []).append(pa) | |||
@@ -356,13 +356,13 @@ class Mailer: | |||
room_name = await calculate_room_name(self.store, room_state_ids, user_id) | |||
room_vars = { | |||
room_vars: Dict[str, Any] = { | |||
"title": room_name, | |||
"hash": string_ordinal_total(room_id), # See sender avatar hash | |||
"notifs": [], | |||
"invite": is_invite, | |||
"link": self._make_room_link(room_id), | |||
} # type: Dict[str, Any] | |||
} | |||
if not is_invite: | |||
for n in notifs: | |||
@@ -460,9 +460,9 @@ class Mailer: | |||
type_state_key = ("m.room.member", event.sender) | |||
sender_state_event_id = room_state_ids.get(type_state_key) | |||
if sender_state_event_id: | |||
sender_state_event = await self.store.get_event( | |||
sender_state_event: Optional[EventBase] = await self.store.get_event( | |||
sender_state_event_id | |||
) # type: Optional[EventBase] | |||
) | |||
else: | |||
# Attempt to check the historical state for the room. | |||
historical_state = await self.state_store.get_state_for_event( | |||
@@ -199,7 +199,7 @@ def name_from_member_event(member_event: EventBase) -> str: | |||
def _state_as_two_level_dict(state: StateMap[str]) -> Dict[str, Dict[str, str]]: | |||
ret = {} # type: Dict[str, Dict[str, str]] | |||
ret: Dict[str, Dict[str, str]] = {} | |||
for k, v in state.items(): | |||
ret.setdefault(k[0], {})[k[1]] = v | |||
return ret | |||
@@ -195,9 +195,9 @@ class PushRuleEvaluatorForEvent: | |||
# Caches (string, is_glob, word_boundary) -> regex for push. See _glob_matches | |||
regex_cache = LruCache( | |||
regex_cache: LruCache[Tuple[str, bool, bool], Pattern] = LruCache( | |||
50000, "regex_push_cache" | |||
) # type: LruCache[Tuple[str, bool, bool], Pattern] | |||
) | |||
def _glob_matches(glob: str, value: str, word_boundary: bool = False) -> bool: | |||
@@ -31,13 +31,13 @@ class PusherFactory: | |||
self.hs = hs | |||
self.config = hs.config | |||
self.pusher_types = { | |||
self.pusher_types: Dict[str, Callable[[HomeServer, PusherConfig], Pusher]] = { | |||
"http": HttpPusher | |||
} # type: Dict[str, Callable[[HomeServer, PusherConfig], Pusher]] | |||
} | |||
logger.info("email enable notifs: %r", hs.config.email_enable_notifs) | |||
if hs.config.email_enable_notifs: | |||
self.mailers = {} # type: Dict[str, Mailer] | |||
self.mailers: Dict[str, Mailer] = {} | |||
self._notif_template_html = hs.config.email_notif_template_html | |||
self._notif_template_text = hs.config.email_notif_template_text | |||
@@ -87,7 +87,7 @@ class PusherPool: | |||
self._last_room_stream_id_seen = self.store.get_room_max_stream_ordering() | |||
# map from user id to app_id:pushkey to pusher | |||
self.pushers = {} # type: Dict[str, Dict[str, Pusher]] | |||
self.pushers: Dict[str, Dict[str, Pusher]] = {} | |||
def start(self) -> None: | |||
"""Starts the pushers off in a background process.""" | |||
@@ -115,7 +115,7 @@ CONDITIONAL_REQUIREMENTS = { | |||
"cache_memory": ["pympler"], | |||
} | |||
ALL_OPTIONAL_REQUIREMENTS = set() # type: Set[str] | |||
ALL_OPTIONAL_REQUIREMENTS: Set[str] = set() | |||
for name, optional_deps in CONDITIONAL_REQUIREMENTS.items(): | |||
# Exclude systemd as it's a system-based requirement. | |||
@@ -193,7 +193,7 @@ def check_requirements(for_feature=None): | |||
if not for_feature: | |||
# Check the optional dependencies are up to date. We allow them to not be | |||
# installed. | |||
OPTS = sum(CONDITIONAL_REQUIREMENTS.values(), []) # type: List[str] | |||
OPTS: List[str] = sum(CONDITIONAL_REQUIREMENTS.values(), []) | |||
for dependency in OPTS: | |||
try: | |||
@@ -85,17 +85,17 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): | |||
is received. | |||
""" | |||
NAME = abc.abstractproperty() # type: str # type: ignore | |||
PATH_ARGS = abc.abstractproperty() # type: Tuple[str, ...] # type: ignore | |||
NAME: str = abc.abstractproperty() # type: ignore | |||
PATH_ARGS: Tuple[str, ...] = abc.abstractproperty() # type: ignore | |||
METHOD = "POST" | |||
CACHE = True | |||
RETRY_ON_TIMEOUT = True | |||
def __init__(self, hs: "HomeServer"): | |||
if self.CACHE: | |||
self.response_cache = ResponseCache( | |||
self.response_cache: ResponseCache[str] = ResponseCache( | |||
hs.get_clock(), "repl." + self.NAME, timeout_ms=30 * 60 * 1000 | |||
) # type: ResponseCache[str] | |||
) | |||
# We reserve `instance_name` as a parameter to sending requests, so we | |||
# assert here that sub classes don't try and use the name. | |||
@@ -232,7 +232,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): | |||
# have a good idea that the request has either succeeded or failed on | |||
# the master, and so whether we should clean up or not. | |||
while True: | |||
headers = {} # type: Dict[bytes, List[bytes]] | |||
headers: Dict[bytes, List[bytes]] = {} | |||
# Add an authorization header, if configured. | |||
if replication_secret: | |||
headers[b"Authorization"] = [b"Bearer " + replication_secret] | |||
@@ -27,7 +27,9 @@ class BaseSlavedStore(CacheInvalidationWorkerStore): | |||
def __init__(self, database: DatabasePool, db_conn, hs): | |||
super().__init__(database, db_conn, hs) | |||
if isinstance(self.database_engine, PostgresEngine): | |||
self._cache_id_gen = MultiWriterIdGenerator( | |||
self._cache_id_gen: Optional[ | |||
MultiWriterIdGenerator | |||
] = MultiWriterIdGenerator( | |||
db_conn, | |||
database, | |||
stream_name="caches", | |||
@@ -41,7 +43,7 @@ class BaseSlavedStore(CacheInvalidationWorkerStore): | |||
], | |||
sequence_name="cache_invalidation_stream_seq", | |||
writers=[], | |||
) # type: Optional[MultiWriterIdGenerator] | |||
) | |||
else: | |||
self._cache_id_gen = None | |||
@@ -23,9 +23,9 @@ class SlavedClientIpStore(BaseSlavedStore): | |||
def __init__(self, database: DatabasePool, db_conn, hs): | |||
super().__init__(database, db_conn, hs) | |||
self.client_ip_last_seen = LruCache( | |||
self.client_ip_last_seen: LruCache[tuple, int] = LruCache( | |||
cache_name="client_ip_last_seen", max_size=50000 | |||
) # type: LruCache[tuple, int] | |||
) | |||
async def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id): | |||
now = int(self._clock.time_msec()) | |||
@@ -121,13 +121,13 @@ class ReplicationDataHandler: | |||
self._pusher_pool = hs.get_pusherpool() | |||
self._presence_handler = hs.get_presence_handler() | |||
self.send_handler = None # type: Optional[FederationSenderHandler] | |||
self.send_handler: Optional[FederationSenderHandler] = None | |||
if hs.should_send_federation(): | |||
self.send_handler = FederationSenderHandler(hs) | |||
# Map from stream to list of deferreds waiting for the stream to | |||
# arrive at a particular position. The lists are sorted by stream position. | |||
self._streams_to_waiters = {} # type: Dict[str, List[Tuple[int, Deferred]]] | |||
self._streams_to_waiters: Dict[str, List[Tuple[int, Deferred]]] = {} | |||
async def on_rdata( | |||
self, stream_name: str, instance_name: str, token: int, rows: list | |||
@@ -173,7 +173,7 @@ class ReplicationDataHandler: | |||
if entities: | |||
self.notifier.on_new_event("to_device_key", token, users=entities) | |||
elif stream_name == DeviceListsStream.NAME: | |||
all_room_ids = set() # type: Set[str] | |||
all_room_ids: Set[str] = set() | |||
for row in rows: | |||
if row.entity.startswith("@"): | |||
room_ids = await self.store.get_rooms_for_user(row.entity) | |||
@@ -201,7 +201,7 @@ class ReplicationDataHandler: | |||
if row.data.rejected: | |||
continue | |||
extra_users = () # type: Tuple[UserID, ...] | |||
extra_users: Tuple[UserID, ...] = () | |||
if row.data.type == EventTypes.Member and row.data.state_key: | |||
extra_users = (UserID.from_string(row.data.state_key),) | |||
@@ -348,7 +348,7 @@ class FederationSenderHandler: | |||
# Stores the latest position in the federation stream we've gotten up | |||
# to. This is always set before we use it. | |||
self.federation_position = None # type: Optional[int] | |||
self.federation_position: Optional[int] = None | |||
self._fed_position_linearizer = Linearizer(name="_fed_position_linearizer") | |||
@@ -34,7 +34,7 @@ class Command(metaclass=abc.ABCMeta): | |||
A full command line on the wire is constructed from `NAME + " " + to_line()` | |||
""" | |||
NAME = None # type: str | |||
NAME: str | |||
@classmethod | |||
@abc.abstractmethod | |||
@@ -380,7 +380,7 @@ class RemoteServerUpCommand(_SimpleCommand): | |||
NAME = "REMOTE_SERVER_UP" | |||
_COMMANDS = ( | |||
_COMMANDS: Tuple[Type[Command], ...] = ( | |||
ServerCommand, | |||
RdataCommand, | |||
PositionCommand, | |||
@@ -393,7 +393,7 @@ _COMMANDS = ( | |||
UserIpCommand, | |||
RemoteServerUpCommand, | |||
ClearUserSyncsCommand, | |||
) # type: Tuple[Type[Command], ...] | |||
) | |||
# Map of command name to command type. | |||
COMMAND_MAP = {cmd.NAME: cmd for cmd in _COMMANDS} | |||
@@ -105,12 +105,12 @@ class ReplicationCommandHandler: | |||
hs.get_instance_name() in hs.config.worker.writers.presence | |||
) | |||
self._streams = { | |||
self._streams: Dict[str, Stream] = { | |||
stream.NAME: stream(hs) for stream in STREAMS_MAP.values() | |||
} # type: Dict[str, Stream] | |||
} | |||
# List of streams that this instance is the source of | |||
self._streams_to_replicate = [] # type: List[Stream] | |||
self._streams_to_replicate: List[Stream] = [] | |||
for stream in self._streams.values(): | |||
if hs.config.redis.redis_enabled and stream.NAME == CachesStream.NAME: | |||
@@ -180,14 +180,14 @@ class ReplicationCommandHandler: | |||
# Map of stream name to batched updates. See RdataCommand for info on | |||
# how batching works. | |||
self._pending_batches = {} # type: Dict[str, List[Any]] | |||
self._pending_batches: Dict[str, List[Any]] = {} | |||
# The factory used to create connections. | |||
self._factory = None # type: Optional[ReconnectingClientFactory] | |||
self._factory: Optional[ReconnectingClientFactory] = None | |||
# The currently connected connections. (The list of places we need to send | |||
# outgoing replication commands to.) | |||
self._connections = [] # type: List[IReplicationConnection] | |||
self._connections: List[IReplicationConnection] = [] | |||
LaterGauge( | |||
"synapse_replication_tcp_resource_total_connections", | |||
@@ -200,7 +200,7 @@ class ReplicationCommandHandler: | |||
# them in order in a separate background process. | |||
# the streams which are currently being processed by _unsafe_process_queue | |||
self._processing_streams = set() # type: Set[str] | |||
self._processing_streams: Set[str] = set() | |||
# for each stream, a queue of commands that are awaiting processing, and the | |||
# connection that they arrived on. | |||
@@ -210,7 +210,7 @@ class ReplicationCommandHandler: | |||
# For each connection, the incoming stream names that have received a POSITION | |||
# from that connection. | |||
self._streams_by_connection = {} # type: Dict[IReplicationConnection, Set[str]] | |||
self._streams_by_connection: Dict[IReplicationConnection, Set[str]] = {} | |||
LaterGauge( | |||
"synapse_replication_tcp_command_queue", | |||
@@ -102,7 +102,7 @@ tcp_outbound_commands_counter = Counter( | |||
# A list of all connected protocols. This allows us to send metrics about the | |||
# connections. | |||
connected_connections = [] # type: List[BaseReplicationStreamProtocol] | |||
connected_connections: "List[BaseReplicationStreamProtocol]" = [] | |||
logger = logging.getLogger(__name__) | |||
@@ -146,15 +146,15 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): | |||
# The transport is going to be an ITCPTransport, but that doesn't have the | |||
# (un)registerProducer methods, those are only on the implementation. | |||
transport = None # type: Connection | |||
transport: Connection | |||
delimiter = b"\n" | |||
# Valid commands we expect to receive | |||
VALID_INBOUND_COMMANDS = [] # type: Collection[str] | |||
VALID_INBOUND_COMMANDS: Collection[str] = [] | |||
# Valid commands we can send | |||
VALID_OUTBOUND_COMMANDS = [] # type: Collection[str] | |||
VALID_OUTBOUND_COMMANDS: Collection[str] = [] | |||
max_line_buffer = 10000 | |||
@@ -165,7 +165,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): | |||
self.last_received_command = self.clock.time_msec() | |||
self.last_sent_command = 0 | |||
# When we requested the connection be closed | |||
self.time_we_closed = None # type: Optional[int] | |||
self.time_we_closed: Optional[int] = None | |||
self.received_ping = False # Have we received a ping from the other side | |||
@@ -175,10 +175,10 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): | |||
self.conn_id = random_string(5) # To dedupe in case of name clashes. | |||
# List of pending commands to send once we've established the connection | |||
self.pending_commands = [] # type: List[Command] | |||
self.pending_commands: List[Command] = [] | |||
# The LoopingCall for sending pings. | |||
self._send_ping_loop = None # type: Optional[task.LoopingCall] | |||
self._send_ping_loop: Optional[task.LoopingCall] = None | |||
# a logcontext which we use for processing incoming commands. We declare it as a | |||
# background process so that the CPU stats get reported to prometheus. | |||
@@ -57,7 +57,7 @@ class ConstantProperty(Generic[T, V]): | |||
it. | |||
""" | |||
constant = attr.ib() # type: V | |||
constant: V = attr.ib() | |||
def __get__(self, obj: Optional[T], objtype: Optional[Type[T]] = None) -> V: | |||
return self.constant | |||
@@ -91,9 +91,9 @@ class RedisSubscriber(txredisapi.SubscriberProtocol): | |||
commands. | |||
""" | |||
synapse_handler = None # type: ReplicationCommandHandler | |||
synapse_stream_name = None # type: str | |||
synapse_outbound_redis_connection = None # type: txredisapi.RedisProtocol | |||
synapse_handler: "ReplicationCommandHandler" | |||
synapse_stream_name: str | |||
synapse_outbound_redis_connection: txredisapi.RedisProtocol | |||
def __init__(self, *args, **kwargs): | |||
super().__init__(*args, **kwargs) | |||
@@ -85,9 +85,9 @@ class Stream: | |||
time it was called. | |||
""" | |||
NAME = None # type: str # The name of the stream | |||
NAME: str # The name of the stream | |||
# The type of the row. Used by the default impl of parse_row. | |||
ROW_TYPE = None # type: Any | |||
ROW_TYPE: Any = None | |||
@classmethod | |||
def parse_row(cls, row: StreamRow): | |||
@@ -283,9 +283,7 @@ class PresenceStream(Stream): | |||
assert isinstance(presence_handler, PresenceHandler) | |||
update_function = ( | |||
presence_handler.get_all_presence_updates | |||
) # type: UpdateFunction | |||
update_function: UpdateFunction = presence_handler.get_all_presence_updates | |||
else: | |||
# Query presence writer process | |||
update_function = make_http_update_function(hs, self.NAME) | |||
@@ -334,9 +332,9 @@ class TypingStream(Stream): | |||
if writer_instance == hs.get_instance_name(): | |||
# On the writer, query the typing handler | |||
typing_writer_handler = hs.get_typing_writer_handler() | |||
update_function = ( | |||
typing_writer_handler.get_all_typing_updates | |||
) # type: Callable[[str, int, int, int], Awaitable[Tuple[List[Tuple[int, Any]], int, bool]]] | |||
update_function: Callable[ | |||
[str, int, int, int], Awaitable[Tuple[List[Tuple[int, Any]], int, bool]] | |||
] = typing_writer_handler.get_all_typing_updates | |||
current_token_function = typing_writer_handler.get_current_token | |||
else: | |||
# Query the typing writer process | |||
@@ -65,7 +65,7 @@ class BaseEventsStreamRow: | |||
""" | |||
# Unique string that ids the type. Must be overridden in sub classes. | |||
TypeId = None # type: str | |||
TypeId: str | |||
@classmethod | |||
def from_data(cls, data): | |||
@@ -103,10 +103,10 @@ class EventsStreamCurrentStateRow(BaseEventsStreamRow): | |||
event_id = attr.ib() # str, optional | |||
_EventRows = ( | |||
_EventRows: Tuple[Type[BaseEventsStreamRow], ...] = ( | |||
EventsStreamEventRow, | |||
EventsStreamCurrentStateRow, | |||
) # type: Tuple[Type[BaseEventsStreamRow], ...] | |||
) | |||
TypeToRow = {Row.TypeId: Row for Row in _EventRows} | |||
@@ -157,9 +157,9 @@ class EventsStream(Stream): | |||
# now we fetch up to that many rows from the events table | |||
event_rows = await self._store.get_all_new_forward_event_rows( | |||
event_rows: List[Tuple] = await self._store.get_all_new_forward_event_rows( | |||
instance_name, from_token, current_token, target_row_count | |||
) # type: List[Tuple] | |||
) | |||
# we rely on get_all_new_forward_event_rows strictly honouring the limit, so | |||
# that we know it is safe to just take upper_limit = event_rows[-1][0]. | |||
@@ -172,7 +172,7 @@ class EventsStream(Stream): | |||
if len(event_rows) == target_row_count: | |||
limited = True | |||
upper_limit = event_rows[-1][0] # type: int | |||
upper_limit: int = event_rows[-1][0] | |||
else: | |||
limited = False | |||
upper_limit = current_token | |||
@@ -191,30 +191,30 @@ class EventsStream(Stream): | |||
# finally, fetch the ex-outliers rows. We assume there are few enough of these | |||
# not to bother with the limit. | |||
ex_outliers_rows = await self._store.get_ex_outlier_stream_rows( | |||
ex_outliers_rows: List[Tuple] = await self._store.get_ex_outlier_stream_rows( | |||
instance_name, from_token, upper_limit | |||
) # type: List[Tuple] | |||
) | |||
# we now need to turn the raw database rows returned into tuples suitable | |||
# for the replication protocol (basically, we add an identifier to | |||
# distinguish the row type). At the same time, we can limit the event_rows | |||
# to the max stream_id from state_rows. | |||
event_updates = ( | |||
event_updates: Iterable[Tuple[int, Tuple]] = ( | |||
(stream_id, (EventsStreamEventRow.TypeId, rest)) | |||
for (stream_id, *rest) in event_rows | |||
if stream_id <= upper_limit | |||
) # type: Iterable[Tuple[int, Tuple]] | |||
) | |||
state_updates = ( | |||
state_updates: Iterable[Tuple[int, Tuple]] = ( | |||
(stream_id, (EventsStreamCurrentStateRow.TypeId, rest)) | |||
for (stream_id, *rest) in state_rows | |||
) # type: Iterable[Tuple[int, Tuple]] | |||
) | |||
ex_outliers_updates = ( | |||
ex_outliers_updates: Iterable[Tuple[int, Tuple]] = ( | |||
(stream_id, (EventsStreamEventRow.TypeId, rest)) | |||
for (stream_id, *rest) in ex_outliers_rows | |||
) # type: Iterable[Tuple[int, Tuple]] | |||
) | |||
# we need to return a sorted list, so merge them together. | |||
updates = list(heapq.merge(event_updates, state_updates, ex_outliers_updates)) | |||
@@ -51,9 +51,9 @@ class FederationStream(Stream): | |||
current_token = current_token_without_instance( | |||
federation_sender.get_current_token | |||
) | |||
update_function = ( | |||
federation_sender.get_replication_rows | |||
) # type: Callable[[str, int, int, int], Awaitable[Tuple[List[Tuple[int, Any]], int, bool]]] | |||
update_function: Callable[ | |||
[str, int, int, int], Awaitable[Tuple[List[Tuple[int, Any]], int, bool]] | |||
] = federation_sender.get_replication_rows | |||
elif hs.should_send_federation(): | |||
# federation sender: Query master process | |||
@@ -247,15 +247,15 @@ class HomeServer(metaclass=abc.ABCMeta): | |||
# the key we use to sign events and requests | |||
self.signing_key = config.key.signing_key[0] | |||
self.config = config | |||
self._listening_services = [] # type: List[twisted.internet.tcp.Port] | |||
self.start_time = None # type: Optional[int] | |||
self._listening_services: List[twisted.internet.tcp.Port] = [] | |||
self.start_time: Optional[int] = None | |||
self._instance_id = random_string(5) | |||
self._instance_name = config.worker.instance_name | |||
self.version_string = version_string | |||
self.datastores = None # type: Optional[Databases] | |||
self.datastores: Optional[Databases] = None | |||
self._module_web_resources: Dict[str, IResource] = {} | |||
self._module_web_resources_consumed = False | |||
@@ -34,7 +34,7 @@ class ConsentServerNotices: | |||
self._server_notices_manager = hs.get_server_notices_manager() | |||
self._store = hs.get_datastore() | |||
self._users_in_progress = set() # type: Set[str] | |||
self._users_in_progress: Set[str] = set() | |||
self._current_consent_version = hs.config.user_consent_version | |||
self._server_notice_content = hs.config.user_consent_server_notice_content | |||
@@ -205,7 +205,7 @@ class ResourceLimitsServerNotices: | |||
# The user has yet to join the server notices room | |||
pass | |||
referenced_events = [] # type: List[str] | |||
referenced_events: List[str] = [] | |||
if pinned_state_event is not None: | |||
referenced_events = list(pinned_state_event.content.get("pinned", [])) | |||
@@ -32,10 +32,12 @@ class ServerNoticesSender(WorkerServerNoticesSender): | |||
def __init__(self, hs: "HomeServer"): | |||
super().__init__(hs) | |||
self._server_notices = ( | |||
self._server_notices: Iterable[ | |||
Union[ConsentServerNotices, ResourceLimitsServerNotices] | |||
] = ( | |||
ConsentServerNotices(hs), | |||
ResourceLimitsServerNotices(hs), | |||
) # type: Iterable[Union[ConsentServerNotices, ResourceLimitsServerNotices]] | |||
) | |||
async def on_user_syncing(self, user_id: str) -> None: | |||
"""Called when the user performs a sync operation. | |||
@@ -309,9 +309,9 @@ class StateHandler: | |||
if old_state: | |||
# if we're given the state before the event, then we use that | |||
state_ids_before_event = { | |||
state_ids_before_event: StateMap[str] = { | |||
(s.type, s.state_key): s.event_id for s in old_state | |||
} # type: StateMap[str] | |||
} | |||
state_group_before_event = None | |||
state_group_before_event_prev_group = None | |||
deltas_to_state_group_before_event = None | |||
@@ -513,23 +513,25 @@ class StateResolutionHandler: | |||
self.resolve_linearizer = Linearizer(name="state_resolve_lock") | |||
# dict of set of event_ids -> _StateCacheEntry. | |||
self._state_cache = ExpiringCache( | |||
self._state_cache: ExpiringCache[ | |||
FrozenSet[int], _StateCacheEntry | |||
] = ExpiringCache( | |||
cache_name="state_cache", | |||
clock=self.clock, | |||
max_len=100000, | |||
expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000, | |||
iterable=True, | |||
reset_expiry_on_get=True, | |||
) # type: ExpiringCache[FrozenSet[int], _StateCacheEntry] | |||
) | |||
# | |||
# stuff for tracking time spent on state-res by room | |||
# | |||
# tracks the amount of work done on state res per room | |||
self._state_res_metrics = defaultdict( | |||
self._state_res_metrics: DefaultDict[str, _StateResMetrics] = defaultdict( | |||
_StateResMetrics | |||
) # type: DefaultDict[str, _StateResMetrics] | |||
) | |||
self.clock.looping_call(self._report_metrics, 120 * 1000) | |||
@@ -700,9 +702,9 @@ class StateResolutionHandler: | |||
items = self._state_res_metrics.items() | |||
# log the N biggest rooms | |||
biggest = heapq.nlargest( | |||
biggest: List[Tuple[str, _StateResMetrics]] = heapq.nlargest( | |||
n_to_log, items, key=lambda i: extract_key(i[1]) | |||
) # type: List[Tuple[str, _StateResMetrics]] | |||
) | |||
metrics_logger.debug( | |||
"%i biggest rooms for state-res by %s: %s", | |||
len(biggest), | |||
@@ -754,7 +756,7 @@ def _make_state_cache_entry( | |||
# failing that, look for the closest match. | |||
prev_group = None | |||
delta_ids = None # type: Optional[StateMap[str]] | |||
delta_ids: Optional[StateMap[str]] = None | |||
for old_group, old_state in state_groups_ids.items(): | |||
n_delta_ids = {k: v for k, v in new_state.items() if old_state.get(k) != v} | |||
@@ -159,7 +159,7 @@ def _seperate( | |||
""" | |||
state_set_iterator = iter(state_sets) | |||
unconflicted_state = dict(next(state_set_iterator)) | |||
conflicted_state = {} # type: MutableStateMap[Set[str]] | |||
conflicted_state: MutableStateMap[Set[str]] = {} | |||
for state_set in state_set_iterator: | |||
for key, value in state_set.items(): | |||
@@ -276,7 +276,7 @@ async def _get_auth_chain_difference( | |||
# event IDs if they appear in the `event_map`. This is the intersection of | |||
# the event's auth chain with the events in the `event_map` *plus* their | |||
# auth event IDs. | |||
events_to_auth_chain = {} # type: Dict[str, Set[str]] | |||
events_to_auth_chain: Dict[str, Set[str]] = {} | |||
for event in event_map.values(): | |||
chain = {event.event_id} | |||
events_to_auth_chain[event.event_id] = chain | |||
@@ -301,17 +301,17 @@ async def _get_auth_chain_difference( | |||
# ((type, state_key)->event_id) mappings; and (b) we have stripped out | |||
# unpersisted events and replaced them with the persisted events in | |||
# their auth chain. | |||
state_sets_ids = [] # type: List[Set[str]] | |||
state_sets_ids: List[Set[str]] = [] | |||
# For each state set, the unpersisted event IDs reachable (by their auth | |||
# chain) from the events in that set. | |||
unpersisted_set_ids = [] # type: List[Set[str]] | |||
unpersisted_set_ids: List[Set[str]] = [] | |||
for state_set in state_sets: | |||
set_ids = set() # type: Set[str] | |||
set_ids: Set[str] = set() | |||
state_sets_ids.append(set_ids) | |||
unpersisted_ids = set() # type: Set[str] | |||
unpersisted_ids: Set[str] = set() | |||
unpersisted_set_ids.append(unpersisted_ids) | |||
for event_id in state_set.values(): | |||
@@ -334,7 +334,7 @@ async def _get_auth_chain_difference( | |||
union = unpersisted_set_ids[0].union(*unpersisted_set_ids[1:]) | |||
intersection = unpersisted_set_ids[0].intersection(*unpersisted_set_ids[1:]) | |||
difference_from_event_map = union - intersection # type: Collection[str] | |||
difference_from_event_map: Collection[str] = union - intersection | |||
else: | |||
difference_from_event_map = () | |||
state_sets_ids = [set(state_set.values()) for state_set in state_sets] | |||
@@ -458,7 +458,7 @@ async def _reverse_topological_power_sort( | |||
The sorted list | |||
""" | |||
graph = {} # type: Dict[str, Set[str]] | |||
graph: Dict[str, Set[str]] = {} | |||
for idx, event_id in enumerate(event_ids, start=1): | |||
await _add_event_and_auth_chain_to_graph( | |||
graph, room_id, event_id, event_map, state_res_store, auth_diff | |||
@@ -657,7 +657,7 @@ async def _get_mainline_depth_for_event( | |||
""" | |||
room_id = event.room_id | |||
tmp_event = event # type: Optional[EventBase] | |||
tmp_event: Optional[EventBase] = event | |||
# We do an iterative search, replacing `event with the power level in its | |||
# auth events (if any) | |||
@@ -767,7 +767,7 @@ def lexicographical_topological_sort( | |||
# outgoing edges, c.f. | |||
# https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm | |||
outdegree_map = graph | |||
reverse_graph = {} # type: Dict[str, Set[str]] | |||
reverse_graph: Dict[str, Set[str]] = {} | |||
# Lists of nodes with zero out degree. Is actually a tuple of | |||
# `(key(node), node)` so that sorting does the right thing | |||
@@ -32,9 +32,9 @@ class EventSources: | |||
} | |||
def __init__(self, hs): | |||
self.sources = { | |||
self.sources: Dict[str, Any] = { | |||
name: cls(hs) for name, cls in EventSources.SOURCE_TYPES.items() | |||
} # type: Dict[str, Any] | |||
} | |||
self.store = hs.get_datastore() | |||
def get_current_token(self) -> StreamToken: | |||
@@ -210,7 +210,7 @@ class DomainSpecificString(metaclass=abc.ABCMeta): | |||
'domain' : The domain part of the name | |||
""" | |||
SIGIL = abc.abstractproperty() # type: str # type: ignore | |||
SIGIL: str = abc.abstractproperty() # type: ignore | |||
localpart = attr.ib(type=str) | |||
domain = attr.ib(type=str) | |||
@@ -304,7 +304,7 @@ class GroupID(DomainSpecificString): | |||
@classmethod | |||
def from_string(cls: Type[DS], s: str) -> DS: | |||
group_id = super().from_string(s) # type: DS # type: ignore | |||
group_id: DS = super().from_string(s) # type: ignore | |||
if not group_id.localpart: | |||
raise SynapseError(400, "Group ID cannot be empty", Codes.INVALID_PARAM) | |||
@@ -600,7 +600,7 @@ class StreamToken: | |||
groups_key = attr.ib(type=int) | |||
_SEPARATOR = "_" | |||
START = None # type: StreamToken | |||
START: "StreamToken" | |||
@classmethod | |||
async def from_string(cls, store: "DataStore", string: str) -> "StreamToken": | |||
@@ -90,7 +90,7 @@ async def filter_events_for_client( | |||
AccountDataTypes.IGNORED_USER_LIST, user_id | |||
) | |||
ignore_list = frozenset() # type: FrozenSet[str] | |||
ignore_list: FrozenSet[str] = frozenset() | |||
if ignore_dict_content: | |||
ignored_users_dict = ignore_dict_content.get("ignored_users", {}) | |||
if isinstance(ignored_users_dict, dict): | |||