Selaa lähdekoodia

Fix up some typechecking (#6150)

* type checking fixes

* changelog
tags/v1.5.0rc1
Amber Brown 4 vuotta sitten
committed by GitHub
vanhempi
commit
864f144543
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
22 muutettua tiedostoa jossa 104 lisäystä ja 40 poistoa
  1. +1
    -0
      .gitignore
  2. +1
    -0
      changelog.d/6150.misc
  3. +2
    -1
      synapse/api/errors.py
  4. +4
    -1
      synapse/api/room_versions.py
  5. +3
    -1
      synapse/app/_base.py
  6. +3
    -2
      synapse/config/appservice.py
  7. +2
    -2
      synapse/config/consent_config.py
  8. +3
    -1
      synapse/config/password_auth_providers.py
  9. +3
    -2
      synapse/config/repository.py
  10. +7
    -3
      synapse/config/server.py
  11. +2
    -2
      synapse/config/server_notices_config.py
  12. +5
    -4
      synapse/logging/opentracing.py
  13. +16
    -4
      synapse/logging/utils.py
  14. +2
    -2
      synapse/metrics/__init__.py
  15. +2
    -2
      synapse/metrics/_exposition.py
  16. +13
    -4
      synapse/python_dependencies.py
  17. +2
    -1
      synapse/types.py
  18. +7
    -3
      synapse/util/async_helpers.py
  19. +2
    -1
      synapse/util/caches/__init__.py
  20. +20
    -2
      synapse/util/caches/descriptors.py
  21. +3
    -1
      synapse/util/caches/treecache.py
  22. +1
    -1
      synapse/util/module_loader.py

+ 1
- 0
.gitignore Näytä tiedosto

@@ -10,6 +10,7 @@
*.tac
_trial_temp/
_trial_temp*/
/out

# stuff that is likely to exist when you run a server locally
/*.db


+ 1
- 0
changelog.d/6150.misc Näytä tiedosto

@@ -0,0 +1 @@
Expand type-checking on modules imported by synapse.config.

+ 2
- 1
synapse/api/errors.py Näytä tiedosto

@@ -17,6 +17,7 @@
"""Contains exceptions and error codes."""

import logging
from typing import Dict

from six import iteritems
from six.moves import http_client
@@ -111,7 +112,7 @@ class ProxiedRequestError(SynapseError):
def __init__(self, code, msg, errcode=Codes.UNKNOWN, additional_fields=None):
super(ProxiedRequestError, self).__init__(code, msg, errcode)
if additional_fields is None:
self._additional_fields = {}
self._additional_fields = {} # type: Dict
else:
self._additional_fields = dict(additional_fields)



+ 4
- 1
synapse/api/room_versions.py Näytä tiedosto

@@ -12,6 +12,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict

import attr


@@ -102,4 +105,4 @@ KNOWN_ROOM_VERSIONS = {
RoomVersions.V4,
RoomVersions.V5,
)
} # type: dict[str, RoomVersion]
} # type: Dict[str, RoomVersion]

+ 3
- 1
synapse/app/_base.py Näytä tiedosto

@@ -263,7 +263,9 @@ def start(hs, listeners=None):
refresh_certificate(hs)

# Start the tracer
synapse.logging.opentracing.init_tracer(hs.config)
synapse.logging.opentracing.init_tracer( # type: ignore[attr-defined] # noqa
hs.config
)

# It is now safe to start your Synapse.
hs.start_listening(listeners)


+ 3
- 2
synapse/config/appservice.py Näytä tiedosto

@@ -13,6 +13,7 @@
# limitations under the License.

import logging
from typing import Dict

from six import string_types
from six.moves.urllib import parse as urlparse
@@ -56,8 +57,8 @@ def load_appservices(hostname, config_files):
return []

# Dicts of value -> filename
seen_as_tokens = {}
seen_ids = {}
seen_as_tokens = {} # type: Dict[str, str]
seen_ids = {} # type: Dict[str, str]

appservices = []



+ 2
- 2
synapse/config/consent_config.py Näytä tiedosto

@@ -73,8 +73,8 @@ DEFAULT_CONFIG = """\


class ConsentConfig(Config):
def __init__(self):
super(ConsentConfig, self).__init__()
def __init__(self, *args):
super(ConsentConfig, self).__init__(*args)

self.user_consent_version = None
self.user_consent_template_dir = None


+ 3
- 1
synapse/config/password_auth_providers.py Näytä tiedosto

@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, List

from synapse.util.module_loader import load_module

from ._base import Config
@@ -22,7 +24,7 @@ LDAP_PROVIDER = "ldap_auth_provider.LdapAuthProvider"

class PasswordAuthProviderConfig(Config):
def read_config(self, config, **kwargs):
self.password_providers = []
self.password_providers = [] # type: List[Any]
providers = []

# We want to be backwards compatible with the old `ldap_config`


+ 3
- 2
synapse/config/repository.py Näytä tiedosto

@@ -15,6 +15,7 @@

import os
from collections import namedtuple
from typing import Dict, List

from synapse.python_dependencies import DependencyException, check_requirements
from synapse.util.module_loader import load_module
@@ -61,7 +62,7 @@ def parse_thumbnail_requirements(thumbnail_sizes):
Dictionary mapping from media type string to list of
ThumbnailRequirement tuples.
"""
requirements = {}
requirements = {} # type: Dict[str, List]
for size in thumbnail_sizes:
width = size["width"]
height = size["height"]
@@ -130,7 +131,7 @@ class ContentRepositoryConfig(Config):
#
# We don't create the storage providers here as not all workers need
# them to be started.
self.media_storage_providers = []
self.media_storage_providers = [] # type: List[tuple]

for provider_config in storage_providers:
# We special case the module "file_system" so as not to need to


+ 7
- 3
synapse/config/server.py Näytä tiedosto

@@ -19,6 +19,7 @@ import logging
import os.path
import re
from textwrap import indent
from typing import List

import attr
import yaml
@@ -243,7 +244,7 @@ class ServerConfig(Config):
# events with profile information that differ from the target's global profile.
self.allow_per_room_profiles = config.get("allow_per_room_profiles", True)

self.listeners = []
self.listeners = [] # type: List[dict]
for listener in config.get("listeners", []):
if not isinstance(listener.get("port", None), int):
raise ConfigError(
@@ -287,7 +288,10 @@ class ServerConfig(Config):
validator=attr.validators.instance_of(bool), default=False
)
complexity = attr.ib(
validator=attr.validators.instance_of((int, float)), default=1.0
validator=attr.validators.instance_of(
(float, int) # type: ignore[arg-type] # noqa
),
default=1.0,
)
complexity_error = attr.ib(
validator=attr.validators.instance_of(str),
@@ -366,7 +370,7 @@ class ServerConfig(Config):
"cleanup_extremities_with_dummy_events", True
)

def has_tls_listener(self):
def has_tls_listener(self) -> bool:
return any(l["tls"] for l in self.listeners)

def generate_config_section(


+ 2
- 2
synapse/config/server_notices_config.py Näytä tiedosto

@@ -59,8 +59,8 @@ class ServerNoticesConfig(Config):
None if server notices are not enabled.
"""

def __init__(self):
super(ServerNoticesConfig, self).__init__()
def __init__(self, *args):
super(ServerNoticesConfig, self).__init__(*args)
self.server_notices_mxid = None
self.server_notices_mxid_display_name = None
self.server_notices_mxid_avatar_url = None


+ 5
- 4
synapse/logging/opentracing.py Näytä tiedosto

@@ -170,6 +170,7 @@ import inspect
import logging
import re
from functools import wraps
from typing import Dict

from canonicaljson import json

@@ -547,7 +548,7 @@ def inject_active_span_twisted_headers(headers, destination, check_destination=T
return

span = opentracing.tracer.active_span
carrier = {}
carrier = {} # type: Dict[str, str]
opentracing.tracer.inject(span, opentracing.Format.HTTP_HEADERS, carrier)

for key, value in carrier.items():
@@ -584,7 +585,7 @@ def inject_active_span_byte_dict(headers, destination, check_destination=True):

span = opentracing.tracer.active_span

carrier = {}
carrier = {} # type: Dict[str, str]
opentracing.tracer.inject(span, opentracing.Format.HTTP_HEADERS, carrier)

for key, value in carrier.items():
@@ -639,7 +640,7 @@ def get_active_span_text_map(destination=None):
if destination and not whitelisted_homeserver(destination):
return {}

carrier = {}
carrier = {} # type: Dict[str, str]
opentracing.tracer.inject(
opentracing.tracer.active_span, opentracing.Format.TEXT_MAP, carrier
)
@@ -653,7 +654,7 @@ def active_span_context_as_string():
Returns:
The active span context encoded as a string.
"""
carrier = {}
carrier = {} # type: Dict[str, str]
if opentracing:
opentracing.tracer.inject(
opentracing.tracer.active_span, opentracing.Format.TEXT_MAP, carrier


+ 16
- 4
synapse/logging/utils.py Näytä tiedosto

@@ -119,7 +119,11 @@ def trace_function(f):
logger = logging.getLogger(name)
level = logging.DEBUG

s = inspect.currentframe().f_back
frame = inspect.currentframe()
if frame is None:
raise Exception("Can't get current frame!")

s = frame.f_back

to_print = [
"\t%s:%s %s. Args: args=%s, kwargs=%s"
@@ -144,7 +148,7 @@ def trace_function(f):
pathname=pathname,
lineno=lineno,
msg=msg,
args=None,
args=tuple(),
exc_info=None,
)

@@ -157,7 +161,12 @@ def trace_function(f):


def get_previous_frames():
s = inspect.currentframe().f_back.f_back

frame = inspect.currentframe()
if frame is None:
raise Exception("Can't get current frame!")

s = frame.f_back.f_back
to_return = []
while s:
if s.f_globals["__name__"].startswith("synapse"):
@@ -174,7 +183,10 @@ def get_previous_frames():


def get_previous_frame(ignore=[]):
s = inspect.currentframe().f_back.f_back
frame = inspect.currentframe()
if frame is None:
raise Exception("Can't get current frame!")
s = frame.f_back.f_back

while s:
if s.f_globals["__name__"].startswith("synapse"):


+ 2
- 2
synapse/metrics/__init__.py Näytä tiedosto

@@ -125,7 +125,7 @@ class InFlightGauge(object):
)

# Counts number of in flight blocks for a given set of label values
self._registrations = {}
self._registrations = {} # type: Dict

# Protects access to _registrations
self._lock = threading.Lock()
@@ -226,7 +226,7 @@ class BucketCollector(object):
# Fetch the data -- this must be synchronous!
data = self.data_collector()

buckets = {}
buckets = {} # type: Dict[float, int]

res = []
for x in data.keys():


+ 2
- 2
synapse/metrics/_exposition.py Näytä tiedosto

@@ -36,9 +36,9 @@ from twisted.web.resource import Resource
try:
from prometheus_client.samples import Sample
except ImportError:
Sample = namedtuple(
Sample = namedtuple( # type: ignore[no-redef] # noqa
"Sample", ["name", "labels", "value", "timestamp", "exemplar"]
) # type: ignore
)


CONTENT_TYPE_LATEST = str("text/plain; version=0.0.4; charset=utf-8")


+ 13
- 4
synapse/python_dependencies.py Näytä tiedosto

@@ -15,7 +15,7 @@
# limitations under the License.

import logging
from typing import Set
from typing import List, Set

from pkg_resources import (
DistributionNotFound,
@@ -73,6 +73,7 @@ REQUIREMENTS = [
"netaddr>=0.7.18",
"Jinja2>=2.9",
"bleach>=1.4.3",
"typing-extensions>=3.7.4",
]

CONDITIONAL_REQUIREMENTS = {
@@ -144,7 +145,11 @@ def check_requirements(for_feature=None):
deps_needed.append(dependency)
errors.append(
"Needed %s, got %s==%s"
% (dependency, e.dist.project_name, e.dist.version)
% (
dependency,
e.dist.project_name, # type: ignore[attr-defined] # noqa
e.dist.version, # type: ignore[attr-defined] # noqa
)
)
except DistributionNotFound:
deps_needed.append(dependency)
@@ -159,7 +164,7 @@ def check_requirements(for_feature=None):
if not for_feature:
# Check the optional dependencies are up to date. We allow them to not be
# installed.
OPTS = sum(CONDITIONAL_REQUIREMENTS.values(), [])
OPTS = sum(CONDITIONAL_REQUIREMENTS.values(), []) # type: List[str]

for dependency in OPTS:
try:
@@ -168,7 +173,11 @@ def check_requirements(for_feature=None):
deps_needed.append(dependency)
errors.append(
"Needed optional %s, got %s==%s"
% (dependency, e.dist.project_name, e.dist.version)
% (
dependency,
e.dist.project_name, # type: ignore[attr-defined] # noqa
e.dist.version, # type: ignore[attr-defined] # noqa
)
)
except DistributionNotFound:
# If it's not found, we don't care


+ 2
- 1
synapse/types.py Näytä tiedosto

@@ -318,6 +318,7 @@ class StreamToken(
)
):
_SEPARATOR = "_"
START = None # type: StreamToken

@classmethod
def from_string(cls, string):
@@ -402,7 +403,7 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
followed by the "stream_ordering" id of the event it comes after.
"""

__slots__ = []
__slots__ = [] # type: list

@classmethod
def parse(cls, string):


+ 7
- 3
synapse/util/async_helpers.py Näytä tiedosto

@@ -13,9 +13,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
import logging
from contextlib import contextmanager
from typing import Dict, Sequence, Set, Union

from six.moves import range

@@ -213,7 +215,9 @@ class Linearizer(object):
# the first element is the number of things executing, and
# the second element is an OrderedDict, where the keys are deferreds for the
# things blocked from executing.
self.key_to_defer = {}
self.key_to_defer = (
{}
) # type: Dict[str, Sequence[Union[int, Dict[defer.Deferred, int]]]]

def queue(self, key):
# we avoid doing defer.inlineCallbacks here, so that cancellation works correctly.
@@ -340,10 +344,10 @@ class ReadWriteLock(object):

def __init__(self):
# Latest readers queued
self.key_to_current_readers = {}
self.key_to_current_readers = {} # type: Dict[str, Set[defer.Deferred]]

# Latest writer queued
self.key_to_current_writer = {}
self.key_to_current_writer = {} # type: Dict[str, defer.Deferred]

@defer.inlineCallbacks
def read(self, key):


+ 2
- 1
synapse/util/caches/__init__.py Näytä tiedosto

@@ -16,6 +16,7 @@

import logging
import os
from typing import Dict

import six
from six.moves import intern
@@ -37,7 +38,7 @@ def get_cache_factor_for(cache_name):


caches_by_name = {}
collectors_by_name = {}
collectors_by_name = {} # type: Dict

cache_size = Gauge("synapse_util_caches_cache:size", "", ["name"])
cache_hits = Gauge("synapse_util_caches_cache:hits", "", ["name"])


+ 20
- 2
synapse/util/caches/descriptors.py Näytä tiedosto

@@ -18,10 +18,12 @@ import inspect
import logging
import threading
from collections import namedtuple
from typing import Any, cast

from six import itervalues

from prometheus_client import Gauge
from typing_extensions import Protocol

from twisted.internet import defer

@@ -37,6 +39,18 @@ from . import register_cache
logger = logging.getLogger(__name__)


class _CachedFunction(Protocol):
invalidate = None # type: Any
invalidate_all = None # type: Any
invalidate_many = None # type: Any
prefill = None # type: Any
cache = None # type: Any
num_args = None # type: Any

def __name__(self):
...


cache_pending_metric = Gauge(
"synapse_util_caches_cache_pending",
"Number of lookups currently pending for this cache",
@@ -245,7 +259,9 @@ class Cache(object):


class _CacheDescriptorBase(object):
def __init__(self, orig, num_args, inlineCallbacks, cache_context=False):
def __init__(
self, orig: _CachedFunction, num_args, inlineCallbacks, cache_context=False
):
self.orig = orig

if inlineCallbacks:
@@ -404,7 +420,7 @@ class CacheDescriptor(_CacheDescriptorBase):
return tuple(get_cache_key_gen(args, kwargs))

@functools.wraps(self.orig)
def wrapped(*args, **kwargs):
def _wrapped(*args, **kwargs):
# If we're passed a cache_context then we'll want to call its invalidate()
# whenever we are invalidated
invalidate_callback = kwargs.pop("on_invalidate", None)
@@ -440,6 +456,8 @@ class CacheDescriptor(_CacheDescriptorBase):

return make_deferred_yieldable(observer)

wrapped = cast(_CachedFunction, _wrapped)

if self.num_args == 1:
wrapped.invalidate = lambda key: cache.invalidate(key[0])
wrapped.prefill = lambda key, val: cache.prefill(key[0], val)


+ 3
- 1
synapse/util/caches/treecache.py Näytä tiedosto

@@ -1,3 +1,5 @@
from typing import Dict

from six import itervalues

SENTINEL = object()
@@ -12,7 +14,7 @@ class TreeCache(object):

def __init__(self):
self.size = 0
self.root = {}
self.root = {} # type: Dict

def __setitem__(self, key, value):
return self.set(key, value)


+ 1
- 1
synapse/util/module_loader.py Näytä tiedosto

@@ -54,5 +54,5 @@ def load_python_module(location: str):
if spec is None:
raise Exception("Unable to load module at %s" % (location,))
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)
spec.loader.exec_module(mod) # type: ignore
return mod

Ladataan…
Peruuta
Tallenna