@@ -10,6 +10,7 @@ | |||
*.tac | |||
_trial_temp/ | |||
_trial_temp*/ | |||
/out | |||
# stuff that is likely to exist when you run a server locally | |||
/*.db | |||
@@ -0,0 +1 @@ | |||
Expand type-checking on modules imported by synapse.config. |
@@ -17,6 +17,7 @@ | |||
"""Contains exceptions and error codes.""" | |||
import logging | |||
from typing import Dict | |||
from six import iteritems | |||
from six.moves import http_client | |||
@@ -111,7 +112,7 @@ class ProxiedRequestError(SynapseError): | |||
def __init__(self, code, msg, errcode=Codes.UNKNOWN, additional_fields=None): | |||
super(ProxiedRequestError, self).__init__(code, msg, errcode) | |||
if additional_fields is None: | |||
self._additional_fields = {} | |||
self._additional_fields = {} # type: Dict | |||
else: | |||
self._additional_fields = dict(additional_fields) | |||
@@ -12,6 +12,9 @@ | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
from typing import Dict | |||
import attr | |||
@@ -102,4 +105,4 @@ KNOWN_ROOM_VERSIONS = { | |||
RoomVersions.V4, | |||
RoomVersions.V5, | |||
) | |||
} # type: dict[str, RoomVersion] | |||
} # type: Dict[str, RoomVersion] |
@@ -263,7 +263,9 @@ def start(hs, listeners=None): | |||
refresh_certificate(hs) | |||
# Start the tracer | |||
synapse.logging.opentracing.init_tracer(hs.config) | |||
synapse.logging.opentracing.init_tracer( # type: ignore[attr-defined] # noqa | |||
hs.config | |||
) | |||
# It is now safe to start your Synapse. | |||
hs.start_listening(listeners) | |||
@@ -13,6 +13,7 @@ | |||
# limitations under the License. | |||
import logging | |||
from typing import Dict | |||
from six import string_types | |||
from six.moves.urllib import parse as urlparse | |||
@@ -56,8 +57,8 @@ def load_appservices(hostname, config_files): | |||
return [] | |||
# Dicts of value -> filename | |||
seen_as_tokens = {} | |||
seen_ids = {} | |||
seen_as_tokens = {} # type: Dict[str, str] | |||
seen_ids = {} # type: Dict[str, str] | |||
appservices = [] | |||
@@ -73,8 +73,8 @@ DEFAULT_CONFIG = """\ | |||
class ConsentConfig(Config): | |||
def __init__(self): | |||
super(ConsentConfig, self).__init__() | |||
def __init__(self, *args): | |||
super(ConsentConfig, self).__init__(*args) | |||
self.user_consent_version = None | |||
self.user_consent_template_dir = None | |||
@@ -13,6 +13,8 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
from typing import Any, List | |||
from synapse.util.module_loader import load_module | |||
from ._base import Config | |||
@@ -22,7 +24,7 @@ LDAP_PROVIDER = "ldap_auth_provider.LdapAuthProvider" | |||
class PasswordAuthProviderConfig(Config): | |||
def read_config(self, config, **kwargs): | |||
self.password_providers = [] | |||
self.password_providers = [] # type: List[Any] | |||
providers = [] | |||
# We want to be backwards compatible with the old `ldap_config` | |||
@@ -15,6 +15,7 @@ | |||
import os | |||
from collections import namedtuple | |||
from typing import Dict, List | |||
from synapse.python_dependencies import DependencyException, check_requirements | |||
from synapse.util.module_loader import load_module | |||
@@ -61,7 +62,7 @@ def parse_thumbnail_requirements(thumbnail_sizes): | |||
Dictionary mapping from media type string to list of | |||
ThumbnailRequirement tuples. | |||
""" | |||
requirements = {} | |||
requirements = {} # type: Dict[str, List] | |||
for size in thumbnail_sizes: | |||
width = size["width"] | |||
height = size["height"] | |||
@@ -130,7 +131,7 @@ class ContentRepositoryConfig(Config): | |||
# | |||
# We don't create the storage providers here as not all workers need | |||
# them to be started. | |||
self.media_storage_providers = [] | |||
self.media_storage_providers = [] # type: List[tuple] | |||
for provider_config in storage_providers: | |||
# We special case the module "file_system" so as not to need to | |||
@@ -19,6 +19,7 @@ import logging | |||
import os.path | |||
import re | |||
from textwrap import indent | |||
from typing import List | |||
import attr | |||
import yaml | |||
@@ -243,7 +244,7 @@ class ServerConfig(Config): | |||
# events with profile information that differ from the target's global profile. | |||
self.allow_per_room_profiles = config.get("allow_per_room_profiles", True) | |||
self.listeners = [] | |||
self.listeners = [] # type: List[dict] | |||
for listener in config.get("listeners", []): | |||
if not isinstance(listener.get("port", None), int): | |||
raise ConfigError( | |||
@@ -287,7 +288,10 @@ class ServerConfig(Config): | |||
validator=attr.validators.instance_of(bool), default=False | |||
) | |||
complexity = attr.ib( | |||
validator=attr.validators.instance_of((int, float)), default=1.0 | |||
validator=attr.validators.instance_of( | |||
(float, int) # type: ignore[arg-type] # noqa | |||
), | |||
default=1.0, | |||
) | |||
complexity_error = attr.ib( | |||
validator=attr.validators.instance_of(str), | |||
@@ -366,7 +370,7 @@ class ServerConfig(Config): | |||
"cleanup_extremities_with_dummy_events", True | |||
) | |||
def has_tls_listener(self): | |||
def has_tls_listener(self) -> bool: | |||
return any(l["tls"] for l in self.listeners) | |||
def generate_config_section( | |||
@@ -59,8 +59,8 @@ class ServerNoticesConfig(Config): | |||
None if server notices are not enabled. | |||
""" | |||
def __init__(self): | |||
super(ServerNoticesConfig, self).__init__() | |||
def __init__(self, *args): | |||
super(ServerNoticesConfig, self).__init__(*args) | |||
self.server_notices_mxid = None | |||
self.server_notices_mxid_display_name = None | |||
self.server_notices_mxid_avatar_url = None | |||
@@ -170,6 +170,7 @@ import inspect | |||
import logging | |||
import re | |||
from functools import wraps | |||
from typing import Dict | |||
from canonicaljson import json | |||
@@ -547,7 +548,7 @@ def inject_active_span_twisted_headers(headers, destination, check_destination=T | |||
return | |||
span = opentracing.tracer.active_span | |||
carrier = {} | |||
carrier = {} # type: Dict[str, str] | |||
opentracing.tracer.inject(span, opentracing.Format.HTTP_HEADERS, carrier) | |||
for key, value in carrier.items(): | |||
@@ -584,7 +585,7 @@ def inject_active_span_byte_dict(headers, destination, check_destination=True): | |||
span = opentracing.tracer.active_span | |||
carrier = {} | |||
carrier = {} # type: Dict[str, str] | |||
opentracing.tracer.inject(span, opentracing.Format.HTTP_HEADERS, carrier) | |||
for key, value in carrier.items(): | |||
@@ -639,7 +640,7 @@ def get_active_span_text_map(destination=None): | |||
if destination and not whitelisted_homeserver(destination): | |||
return {} | |||
carrier = {} | |||
carrier = {} # type: Dict[str, str] | |||
opentracing.tracer.inject( | |||
opentracing.tracer.active_span, opentracing.Format.TEXT_MAP, carrier | |||
) | |||
@@ -653,7 +654,7 @@ def active_span_context_as_string(): | |||
Returns: | |||
The active span context encoded as a string. | |||
""" | |||
carrier = {} | |||
carrier = {} # type: Dict[str, str] | |||
if opentracing: | |||
opentracing.tracer.inject( | |||
opentracing.tracer.active_span, opentracing.Format.TEXT_MAP, carrier | |||
@@ -119,7 +119,11 @@ def trace_function(f): | |||
logger = logging.getLogger(name) | |||
level = logging.DEBUG | |||
s = inspect.currentframe().f_back | |||
frame = inspect.currentframe() | |||
if frame is None: | |||
raise Exception("Can't get current frame!") | |||
s = frame.f_back | |||
to_print = [ | |||
"\t%s:%s %s. Args: args=%s, kwargs=%s" | |||
@@ -144,7 +148,7 @@ def trace_function(f): | |||
pathname=pathname, | |||
lineno=lineno, | |||
msg=msg, | |||
args=None, | |||
args=tuple(), | |||
exc_info=None, | |||
) | |||
@@ -157,7 +161,12 @@ def trace_function(f): | |||
def get_previous_frames(): | |||
s = inspect.currentframe().f_back.f_back | |||
frame = inspect.currentframe() | |||
if frame is None: | |||
raise Exception("Can't get current frame!") | |||
s = frame.f_back.f_back | |||
to_return = [] | |||
while s: | |||
if s.f_globals["__name__"].startswith("synapse"): | |||
@@ -174,7 +183,10 @@ def get_previous_frames(): | |||
def get_previous_frame(ignore=[]): | |||
s = inspect.currentframe().f_back.f_back | |||
frame = inspect.currentframe() | |||
if frame is None: | |||
raise Exception("Can't get current frame!") | |||
s = frame.f_back.f_back | |||
while s: | |||
if s.f_globals["__name__"].startswith("synapse"): | |||
@@ -125,7 +125,7 @@ class InFlightGauge(object): | |||
) | |||
# Counts number of in flight blocks for a given set of label values | |||
self._registrations = {} | |||
self._registrations = {} # type: Dict | |||
# Protects access to _registrations | |||
self._lock = threading.Lock() | |||
@@ -226,7 +226,7 @@ class BucketCollector(object): | |||
# Fetch the data -- this must be synchronous! | |||
data = self.data_collector() | |||
buckets = {} | |||
buckets = {} # type: Dict[float, int] | |||
res = [] | |||
for x in data.keys(): | |||
@@ -36,9 +36,9 @@ from twisted.web.resource import Resource | |||
try: | |||
from prometheus_client.samples import Sample | |||
except ImportError: | |||
Sample = namedtuple( | |||
Sample = namedtuple( # type: ignore[no-redef] # noqa | |||
"Sample", ["name", "labels", "value", "timestamp", "exemplar"] | |||
) # type: ignore | |||
) | |||
CONTENT_TYPE_LATEST = str("text/plain; version=0.0.4; charset=utf-8") | |||
@@ -15,7 +15,7 @@ | |||
# limitations under the License. | |||
import logging | |||
from typing import Set | |||
from typing import List, Set | |||
from pkg_resources import ( | |||
DistributionNotFound, | |||
@@ -73,6 +73,7 @@ REQUIREMENTS = [ | |||
"netaddr>=0.7.18", | |||
"Jinja2>=2.9", | |||
"bleach>=1.4.3", | |||
"typing-extensions>=3.7.4", | |||
] | |||
CONDITIONAL_REQUIREMENTS = { | |||
@@ -144,7 +145,11 @@ def check_requirements(for_feature=None): | |||
deps_needed.append(dependency) | |||
errors.append( | |||
"Needed %s, got %s==%s" | |||
% (dependency, e.dist.project_name, e.dist.version) | |||
% ( | |||
dependency, | |||
e.dist.project_name, # type: ignore[attr-defined] # noqa | |||
e.dist.version, # type: ignore[attr-defined] # noqa | |||
) | |||
) | |||
except DistributionNotFound: | |||
deps_needed.append(dependency) | |||
@@ -159,7 +164,7 @@ def check_requirements(for_feature=None): | |||
if not for_feature: | |||
# Check the optional dependencies are up to date. We allow them to not be | |||
# installed. | |||
OPTS = sum(CONDITIONAL_REQUIREMENTS.values(), []) | |||
OPTS = sum(CONDITIONAL_REQUIREMENTS.values(), []) # type: List[str] | |||
for dependency in OPTS: | |||
try: | |||
@@ -168,7 +173,11 @@ def check_requirements(for_feature=None): | |||
deps_needed.append(dependency) | |||
errors.append( | |||
"Needed optional %s, got %s==%s" | |||
% (dependency, e.dist.project_name, e.dist.version) | |||
% ( | |||
dependency, | |||
e.dist.project_name, # type: ignore[attr-defined] # noqa | |||
e.dist.version, # type: ignore[attr-defined] # noqa | |||
) | |||
) | |||
except DistributionNotFound: | |||
# If it's not found, we don't care | |||
@@ -318,6 +318,7 @@ class StreamToken( | |||
) | |||
): | |||
_SEPARATOR = "_" | |||
START = None # type: StreamToken | |||
@classmethod | |||
def from_string(cls, string): | |||
@@ -402,7 +403,7 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")): | |||
followed by the "stream_ordering" id of the event it comes after. | |||
""" | |||
__slots__ = [] | |||
__slots__ = [] # type: list | |||
@classmethod | |||
def parse(cls, string): | |||
@@ -13,9 +13,11 @@ | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
import collections | |||
import logging | |||
from contextlib import contextmanager | |||
from typing import Dict, Sequence, Set, Union | |||
from six.moves import range | |||
@@ -213,7 +215,9 @@ class Linearizer(object): | |||
# the first element is the number of things executing, and | |||
# the second element is an OrderedDict, where the keys are deferreds for the | |||
# things blocked from executing. | |||
self.key_to_defer = {} | |||
self.key_to_defer = ( | |||
{} | |||
) # type: Dict[str, Sequence[Union[int, Dict[defer.Deferred, int]]]] | |||
def queue(self, key): | |||
# we avoid doing defer.inlineCallbacks here, so that cancellation works correctly. | |||
@@ -340,10 +344,10 @@ class ReadWriteLock(object): | |||
def __init__(self): | |||
# Latest readers queued | |||
self.key_to_current_readers = {} | |||
self.key_to_current_readers = {} # type: Dict[str, Set[defer.Deferred]] | |||
# Latest writer queued | |||
self.key_to_current_writer = {} | |||
self.key_to_current_writer = {} # type: Dict[str, defer.Deferred] | |||
@defer.inlineCallbacks | |||
def read(self, key): | |||
@@ -16,6 +16,7 @@ | |||
import logging | |||
import os | |||
from typing import Dict | |||
import six | |||
from six.moves import intern | |||
@@ -37,7 +38,7 @@ def get_cache_factor_for(cache_name): | |||
caches_by_name = {} | |||
collectors_by_name = {} | |||
collectors_by_name = {} # type: Dict | |||
cache_size = Gauge("synapse_util_caches_cache:size", "", ["name"]) | |||
cache_hits = Gauge("synapse_util_caches_cache:hits", "", ["name"]) | |||
@@ -18,10 +18,12 @@ import inspect | |||
import logging | |||
import threading | |||
from collections import namedtuple | |||
from typing import Any, cast | |||
from six import itervalues | |||
from prometheus_client import Gauge | |||
from typing_extensions import Protocol | |||
from twisted.internet import defer | |||
@@ -37,6 +39,18 @@ from . import register_cache | |||
logger = logging.getLogger(__name__) | |||
class _CachedFunction(Protocol): | |||
invalidate = None # type: Any | |||
invalidate_all = None # type: Any | |||
invalidate_many = None # type: Any | |||
prefill = None # type: Any | |||
cache = None # type: Any | |||
num_args = None # type: Any | |||
def __name__(self): | |||
... | |||
cache_pending_metric = Gauge( | |||
"synapse_util_caches_cache_pending", | |||
"Number of lookups currently pending for this cache", | |||
@@ -245,7 +259,9 @@ class Cache(object): | |||
class _CacheDescriptorBase(object): | |||
def __init__(self, orig, num_args, inlineCallbacks, cache_context=False): | |||
def __init__( | |||
self, orig: _CachedFunction, num_args, inlineCallbacks, cache_context=False | |||
): | |||
self.orig = orig | |||
if inlineCallbacks: | |||
@@ -404,7 +420,7 @@ class CacheDescriptor(_CacheDescriptorBase): | |||
return tuple(get_cache_key_gen(args, kwargs)) | |||
@functools.wraps(self.orig) | |||
def wrapped(*args, **kwargs): | |||
def _wrapped(*args, **kwargs): | |||
# If we're passed a cache_context then we'll want to call its invalidate() | |||
# whenever we are invalidated | |||
invalidate_callback = kwargs.pop("on_invalidate", None) | |||
@@ -440,6 +456,8 @@ class CacheDescriptor(_CacheDescriptorBase): | |||
return make_deferred_yieldable(observer) | |||
wrapped = cast(_CachedFunction, _wrapped) | |||
if self.num_args == 1: | |||
wrapped.invalidate = lambda key: cache.invalidate(key[0]) | |||
wrapped.prefill = lambda key, val: cache.prefill(key[0], val) | |||
@@ -1,3 +1,5 @@ | |||
from typing import Dict | |||
from six import itervalues | |||
SENTINEL = object() | |||
@@ -12,7 +14,7 @@ class TreeCache(object): | |||
def __init__(self): | |||
self.size = 0 | |||
self.root = {} | |||
self.root = {} # type: Dict | |||
def __setitem__(self, key, value): | |||
return self.set(key, value) | |||
@@ -54,5 +54,5 @@ def load_python_module(location: str): | |||
if spec is None: | |||
raise Exception("Unable to load module at %s" % (location,)) | |||
mod = importlib.util.module_from_spec(spec) | |||
spec.loader.exec_module(mod) | |||
spec.loader.exec_module(mod) # type: ignore | |||
return mod |