@@ -0,0 +1 @@ | |||||
Run `pyupgrade` for Python 3.8+. |
@@ -769,7 +769,7 @@ def main(server_url, identity_server_url, username, token, config_path): | |||||
global CONFIG_JSON | global CONFIG_JSON | ||||
CONFIG_JSON = config_path # bit cheeky, but just overwrite the global | CONFIG_JSON = config_path # bit cheeky, but just overwrite the global | ||||
try: | try: | ||||
with open(config_path, "r") as config: | |||||
with open(config_path) as config: | |||||
syn_cmd.config = json.load(config) | syn_cmd.config = json.load(config) | ||||
try: | try: | ||||
http_client.verbose = "on" == syn_cmd.config["verbose"] | http_client.verbose = "on" == syn_cmd.config["verbose"] | ||||
@@ -861,7 +861,7 @@ def generate_worker_files( | |||||
# Then a worker config file | # Then a worker config file | ||||
convert( | convert( | ||||
"/conf/worker.yaml.j2", | "/conf/worker.yaml.j2", | ||||
"/conf/workers/{name}.yaml".format(name=worker_name), | |||||
f"/conf/workers/{worker_name}.yaml", | |||||
**worker_config, | **worker_config, | ||||
worker_log_config_filepath=log_config_filepath, | worker_log_config_filepath=log_config_filepath, | ||||
using_unix_sockets=using_unix_sockets, | using_unix_sockets=using_unix_sockets, | ||||
@@ -82,7 +82,7 @@ def generate_config_from_template( | |||||
with open(filename) as handle: | with open(filename) as handle: | ||||
value = handle.read() | value = handle.read() | ||||
else: | else: | ||||
log("Generating a random secret for {}".format(secret)) | |||||
log(f"Generating a random secret for {secret}") | |||||
value = codecs.encode(os.urandom(32), "hex").decode() | value = codecs.encode(os.urandom(32), "hex").decode() | ||||
with open(filename, "w") as handle: | with open(filename, "w") as handle: | ||||
handle.write(value) | handle.write(value) | ||||
@@ -47,7 +47,7 @@ can be passed on the commandline for debugging. | |||||
projdir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) | projdir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) | ||||
class Builder(object): | |||||
class Builder: | |||||
def __init__( | def __init__( | ||||
self, | self, | ||||
redirect_stdout: bool = False, | redirect_stdout: bool = False, | ||||
@@ -43,7 +43,7 @@ def main(force_colors: bool) -> None: | |||||
diffs: List[git.Diff] = repo.remote().refs.develop.commit.diff(None) | diffs: List[git.Diff] = repo.remote().refs.develop.commit.diff(None) | ||||
# Get the schema version of the local file to check against current schema on develop | # Get the schema version of the local file to check against current schema on develop | ||||
with open("synapse/storage/schema/__init__.py", "r") as file: | |||||
with open("synapse/storage/schema/__init__.py") as file: | |||||
local_schema = file.read() | local_schema = file.read() | ||||
new_locals: Dict[str, Any] = {} | new_locals: Dict[str, Any] = {} | ||||
exec(local_schema, new_locals) | exec(local_schema, new_locals) | ||||
@@ -247,7 +247,7 @@ def main() -> None: | |||||
def read_args_from_config(args: argparse.Namespace) -> None: | def read_args_from_config(args: argparse.Namespace) -> None: | ||||
with open(args.config, "r") as fh: | |||||
with open(args.config) as fh: | |||||
config = yaml.safe_load(fh) | config = yaml.safe_load(fh) | ||||
if not args.server_name: | if not args.server_name: | ||||
@@ -1,5 +1,4 @@ | |||||
#!/usr/bin/env python | #!/usr/bin/env python | ||||
# -*- coding: utf-8 -*- | |||||
# Copyright 2020 The Matrix.org Foundation C.I.C. | # Copyright 2020 The Matrix.org Foundation C.I.C. | ||||
# | # | ||||
# Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
@@ -145,7 +145,7 @@ Example usage: | |||||
def read_args_from_config(args: argparse.Namespace) -> None: | def read_args_from_config(args: argparse.Namespace) -> None: | ||||
with open(args.config, "r") as fh: | |||||
with open(args.config) as fh: | |||||
config = yaml.safe_load(fh) | config = yaml.safe_load(fh) | ||||
if not args.server_name: | if not args.server_name: | ||||
args.server_name = config["server_name"] | args.server_name = config["server_name"] | ||||
@@ -25,7 +25,11 @@ from synapse.util.rust import check_rust_lib_up_to_date | |||||
from synapse.util.stringutils import strtobool | from synapse.util.stringutils import strtobool | ||||
# Check that we're not running on an unsupported Python version. | # Check that we're not running on an unsupported Python version. | ||||
if sys.version_info < (3, 8): | |||||
# | |||||
# Note that we use an (unneeded) variable here so that pyupgrade doesn't nuke the | |||||
# if-statement completely. | |||||
py_version = sys.version_info | |||||
if py_version < (3, 8): | |||||
print("Synapse requires Python 3.8 or above.") | print("Synapse requires Python 3.8 or above.") | ||||
sys.exit(1) | sys.exit(1) | ||||
@@ -78,7 +82,7 @@ try: | |||||
except ImportError: | except ImportError: | ||||
pass | pass | ||||
import synapse.util | |||||
import synapse.util # noqa: E402 | |||||
__version__ = synapse.util.SYNAPSE_VERSION | __version__ = synapse.util.SYNAPSE_VERSION | ||||
@@ -1205,10 +1205,10 @@ class CursesProgress(Progress): | |||||
self.total_processed = 0 | self.total_processed = 0 | ||||
self.total_remaining = 0 | self.total_remaining = 0 | ||||
super(CursesProgress, self).__init__() | |||||
super().__init__() | |||||
def update(self, table: str, num_done: int) -> None: | def update(self, table: str, num_done: int) -> None: | ||||
super(CursesProgress, self).update(table, num_done) | |||||
super().update(table, num_done) | |||||
self.total_processed = 0 | self.total_processed = 0 | ||||
self.total_remaining = 0 | self.total_remaining = 0 | ||||
@@ -1304,7 +1304,7 @@ class TerminalProgress(Progress): | |||||
"""Just prints progress to the terminal""" | """Just prints progress to the terminal""" | ||||
def update(self, table: str, num_done: int) -> None: | def update(self, table: str, num_done: int) -> None: | ||||
super(TerminalProgress, self).update(table, num_done) | |||||
super().update(table, num_done) | |||||
data = self.tables[table] | data = self.tables[table] | ||||
@@ -38,7 +38,7 @@ class MockHomeserver(HomeServer): | |||||
DATASTORE_CLASS = DataStore # type: ignore [assignment] | DATASTORE_CLASS = DataStore # type: ignore [assignment] | ||||
def __init__(self, config: HomeServerConfig): | def __init__(self, config: HomeServerConfig): | ||||
super(MockHomeserver, self).__init__( | |||||
super().__init__( | |||||
hostname=config.server.server_name, | hostname=config.server.server_name, | ||||
config=config, | config=config, | ||||
reactor=reactor, | reactor=reactor, | ||||
@@ -18,8 +18,7 @@ | |||||
"""Contains constants from the specification.""" | """Contains constants from the specification.""" | ||||
import enum | import enum | ||||
from typing_extensions import Final | |||||
from typing import Final | |||||
# the max size of a (canonical-json-encoded) event | # the max size of a (canonical-json-encoded) event | ||||
MAX_PDU_SIZE = 65536 | MAX_PDU_SIZE = 65536 | ||||
@@ -32,6 +32,7 @@ from typing import ( | |||||
Any, | Any, | ||||
Callable, | Callable, | ||||
Collection, | Collection, | ||||
ContextManager, | |||||
Dict, | Dict, | ||||
Generator, | Generator, | ||||
Iterable, | Iterable, | ||||
@@ -43,7 +44,6 @@ from typing import ( | |||||
) | ) | ||||
from prometheus_client import Counter | from prometheus_client import Counter | ||||
from typing_extensions import ContextManager | |||||
import synapse.metrics | import synapse.metrics | ||||
from synapse.api.constants import EduTypes, EventTypes, Membership, PresenceState | from synapse.api.constants import EduTypes, EventTypes, Membership, PresenceState | ||||
@@ -24,13 +24,14 @@ from typing import ( | |||||
Iterable, | Iterable, | ||||
List, | List, | ||||
Mapping, | Mapping, | ||||
NoReturn, | |||||
Optional, | Optional, | ||||
Set, | Set, | ||||
) | ) | ||||
from urllib.parse import urlencode | from urllib.parse import urlencode | ||||
import attr | import attr | ||||
from typing_extensions import NoReturn, Protocol | |||||
from typing_extensions import Protocol | |||||
from twisted.web.iweb import IRequest | from twisted.web.iweb import IRequest | ||||
from twisted.web.server import Request | from twisted.web.server import Request | ||||
@@ -791,7 +792,7 @@ class SsoHandler: | |||||
if code != 200: | if code != 200: | ||||
raise Exception( | raise Exception( | ||||
"GET request to download sso avatar image returned {}".format(code) | |||||
f"GET request to download sso avatar image returned {code}" | |||||
) | ) | ||||
# upload name includes hash of the image file's content so that we can | # upload name includes hash of the image file's content so that we can | ||||
@@ -14,9 +14,15 @@ | |||||
# limitations under the License. | # limitations under the License. | ||||
import logging | import logging | ||||
from collections import Counter | from collections import Counter | ||||
from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Tuple | |||||
from typing_extensions import Counter as CounterType | |||||
from typing import ( | |||||
TYPE_CHECKING, | |||||
Any, | |||||
Counter as CounterType, | |||||
Dict, | |||||
Iterable, | |||||
Optional, | |||||
Tuple, | |||||
) | |||||
from synapse.api.constants import EventContentFields, EventTypes, Membership | from synapse.api.constants import EventContentFields, EventTypes, Membership | ||||
from synapse.metrics import event_processing_positions | from synapse.metrics import event_processing_positions | ||||
@@ -1442,11 +1442,9 @@ class SyncHandler: | |||||
# Now we have our list of joined room IDs, exclude as configured and freeze | # Now we have our list of joined room IDs, exclude as configured and freeze | ||||
joined_room_ids = frozenset( | joined_room_ids = frozenset( | ||||
( | |||||
room_id | |||||
for room_id in mutable_joined_room_ids | |||||
if room_id not in mutable_rooms_to_exclude | |||||
) | |||||
room_id | |||||
for room_id in mutable_joined_room_ids | |||||
if room_id not in mutable_rooms_to_exclude | |||||
) | ) | ||||
logger.debug( | logger.debug( | ||||
@@ -18,10 +18,9 @@ import traceback | |||||
from collections import deque | from collections import deque | ||||
from ipaddress import IPv4Address, IPv6Address, ip_address | from ipaddress import IPv4Address, IPv6Address, ip_address | ||||
from math import floor | from math import floor | ||||
from typing import Callable, Optional | |||||
from typing import Callable, Deque, Optional | |||||
import attr | import attr | ||||
from typing_extensions import Deque | |||||
from zope.interface import implementer | from zope.interface import implementer | ||||
from twisted.application.internet import ClientService | from twisted.application.internet import ClientService | ||||
@@ -426,9 +426,7 @@ class SpamCheckerModuleApiCallbacks: | |||||
generally discouraged as it doesn't support internationalization. | generally discouraged as it doesn't support internationalization. | ||||
""" | """ | ||||
for callback in self._check_event_for_spam_callbacks: | for callback in self._check_event_for_spam_callbacks: | ||||
with Measure( | |||||
self.clock, "{}.{}".format(callback.__module__, callback.__qualname__) | |||||
): | |||||
with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"): | |||||
res = await delay_cancellation(callback(event)) | res = await delay_cancellation(callback(event)) | ||||
if res is False or res == self.NOT_SPAM: | if res is False or res == self.NOT_SPAM: | ||||
# This spam-checker accepts the event. | # This spam-checker accepts the event. | ||||
@@ -481,9 +479,7 @@ class SpamCheckerModuleApiCallbacks: | |||||
True if the event should be silently dropped | True if the event should be silently dropped | ||||
""" | """ | ||||
for callback in self._should_drop_federated_event_callbacks: | for callback in self._should_drop_federated_event_callbacks: | ||||
with Measure( | |||||
self.clock, "{}.{}".format(callback.__module__, callback.__qualname__) | |||||
): | |||||
with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"): | |||||
res: Union[bool, str] = await delay_cancellation(callback(event)) | res: Union[bool, str] = await delay_cancellation(callback(event)) | ||||
if res: | if res: | ||||
return res | return res | ||||
@@ -505,9 +501,7 @@ class SpamCheckerModuleApiCallbacks: | |||||
NOT_SPAM if the operation is permitted, [Codes, Dict] otherwise. | NOT_SPAM if the operation is permitted, [Codes, Dict] otherwise. | ||||
""" | """ | ||||
for callback in self._user_may_join_room_callbacks: | for callback in self._user_may_join_room_callbacks: | ||||
with Measure( | |||||
self.clock, "{}.{}".format(callback.__module__, callback.__qualname__) | |||||
): | |||||
with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"): | |||||
res = await delay_cancellation(callback(user_id, room_id, is_invited)) | res = await delay_cancellation(callback(user_id, room_id, is_invited)) | ||||
# Normalize return values to `Codes` or `"NOT_SPAM"`. | # Normalize return values to `Codes` or `"NOT_SPAM"`. | ||||
if res is True or res is self.NOT_SPAM: | if res is True or res is self.NOT_SPAM: | ||||
@@ -546,9 +540,7 @@ class SpamCheckerModuleApiCallbacks: | |||||
NOT_SPAM if the operation is permitted, Codes otherwise. | NOT_SPAM if the operation is permitted, Codes otherwise. | ||||
""" | """ | ||||
for callback in self._user_may_invite_callbacks: | for callback in self._user_may_invite_callbacks: | ||||
with Measure( | |||||
self.clock, "{}.{}".format(callback.__module__, callback.__qualname__) | |||||
): | |||||
with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"): | |||||
res = await delay_cancellation( | res = await delay_cancellation( | ||||
callback(inviter_userid, invitee_userid, room_id) | callback(inviter_userid, invitee_userid, room_id) | ||||
) | ) | ||||
@@ -593,9 +585,7 @@ class SpamCheckerModuleApiCallbacks: | |||||
NOT_SPAM if the operation is permitted, Codes otherwise. | NOT_SPAM if the operation is permitted, Codes otherwise. | ||||
""" | """ | ||||
for callback in self._user_may_send_3pid_invite_callbacks: | for callback in self._user_may_send_3pid_invite_callbacks: | ||||
with Measure( | |||||
self.clock, "{}.{}".format(callback.__module__, callback.__qualname__) | |||||
): | |||||
with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"): | |||||
res = await delay_cancellation( | res = await delay_cancellation( | ||||
callback(inviter_userid, medium, address, room_id) | callback(inviter_userid, medium, address, room_id) | ||||
) | ) | ||||
@@ -630,9 +620,7 @@ class SpamCheckerModuleApiCallbacks: | |||||
userid: The ID of the user attempting to create a room | userid: The ID of the user attempting to create a room | ||||
""" | """ | ||||
for callback in self._user_may_create_room_callbacks: | for callback in self._user_may_create_room_callbacks: | ||||
with Measure( | |||||
self.clock, "{}.{}".format(callback.__module__, callback.__qualname__) | |||||
): | |||||
with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"): | |||||
res = await delay_cancellation(callback(userid)) | res = await delay_cancellation(callback(userid)) | ||||
if res is True or res is self.NOT_SPAM: | if res is True or res is self.NOT_SPAM: | ||||
continue | continue | ||||
@@ -666,9 +654,7 @@ class SpamCheckerModuleApiCallbacks: | |||||
""" | """ | ||||
for callback in self._user_may_create_room_alias_callbacks: | for callback in self._user_may_create_room_alias_callbacks: | ||||
with Measure( | |||||
self.clock, "{}.{}".format(callback.__module__, callback.__qualname__) | |||||
): | |||||
with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"): | |||||
res = await delay_cancellation(callback(userid, room_alias)) | res = await delay_cancellation(callback(userid, room_alias)) | ||||
if res is True or res is self.NOT_SPAM: | if res is True or res is self.NOT_SPAM: | ||||
continue | continue | ||||
@@ -701,9 +687,7 @@ class SpamCheckerModuleApiCallbacks: | |||||
room_id: The ID of the room that would be published | room_id: The ID of the room that would be published | ||||
""" | """ | ||||
for callback in self._user_may_publish_room_callbacks: | for callback in self._user_may_publish_room_callbacks: | ||||
with Measure( | |||||
self.clock, "{}.{}".format(callback.__module__, callback.__qualname__) | |||||
): | |||||
with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"): | |||||
res = await delay_cancellation(callback(userid, room_id)) | res = await delay_cancellation(callback(userid, room_id)) | ||||
if res is True or res is self.NOT_SPAM: | if res is True or res is self.NOT_SPAM: | ||||
continue | continue | ||||
@@ -742,9 +726,7 @@ class SpamCheckerModuleApiCallbacks: | |||||
True if the user is spammy. | True if the user is spammy. | ||||
""" | """ | ||||
for callback in self._check_username_for_spam_callbacks: | for callback in self._check_username_for_spam_callbacks: | ||||
with Measure( | |||||
self.clock, "{}.{}".format(callback.__module__, callback.__qualname__) | |||||
): | |||||
with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"): | |||||
# Make a copy of the user profile object to ensure the spam checker cannot | # Make a copy of the user profile object to ensure the spam checker cannot | ||||
# modify it. | # modify it. | ||||
res = await delay_cancellation(callback(user_profile.copy())) | res = await delay_cancellation(callback(user_profile.copy())) | ||||
@@ -776,9 +758,7 @@ class SpamCheckerModuleApiCallbacks: | |||||
""" | """ | ||||
for callback in self._check_registration_for_spam_callbacks: | for callback in self._check_registration_for_spam_callbacks: | ||||
with Measure( | |||||
self.clock, "{}.{}".format(callback.__module__, callback.__qualname__) | |||||
): | |||||
with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"): | |||||
behaviour = await delay_cancellation( | behaviour = await delay_cancellation( | ||||
callback(email_threepid, username, request_info, auth_provider_id) | callback(email_threepid, username, request_info, auth_provider_id) | ||||
) | ) | ||||
@@ -820,9 +800,7 @@ class SpamCheckerModuleApiCallbacks: | |||||
""" | """ | ||||
for callback in self._check_media_file_for_spam_callbacks: | for callback in self._check_media_file_for_spam_callbacks: | ||||
with Measure( | |||||
self.clock, "{}.{}".format(callback.__module__, callback.__qualname__) | |||||
): | |||||
with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"): | |||||
res = await delay_cancellation(callback(file_wrapper, file_info)) | res = await delay_cancellation(callback(file_wrapper, file_info)) | ||||
# Normalize return values to `Codes` or `"NOT_SPAM"`. | # Normalize return values to `Codes` or `"NOT_SPAM"`. | ||||
if res is False or res is self.NOT_SPAM: | if res is False or res is self.NOT_SPAM: | ||||
@@ -869,9 +847,7 @@ class SpamCheckerModuleApiCallbacks: | |||||
""" | """ | ||||
for callback in self._check_login_for_spam_callbacks: | for callback in self._check_login_for_spam_callbacks: | ||||
with Measure( | |||||
self.clock, "{}.{}".format(callback.__module__, callback.__qualname__) | |||||
): | |||||
with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"): | |||||
res = await delay_cancellation( | res = await delay_cancellation( | ||||
callback( | callback( | ||||
user_id, | user_id, | ||||
@@ -17,6 +17,7 @@ from typing import ( | |||||
TYPE_CHECKING, | TYPE_CHECKING, | ||||
Any, | Any, | ||||
Awaitable, | Awaitable, | ||||
Deque, | |||||
Dict, | Dict, | ||||
Iterable, | Iterable, | ||||
Iterator, | Iterator, | ||||
@@ -29,7 +30,6 @@ from typing import ( | |||||
) | ) | ||||
from prometheus_client import Counter | from prometheus_client import Counter | ||||
from typing_extensions import Deque | |||||
from twisted.internet.protocol import ReconnectingClientFactory | from twisted.internet.protocol import ReconnectingClientFactory | ||||
@@ -13,10 +13,9 @@ | |||||
# See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||
# limitations under the License. | # limitations under the License. | ||||
from typing import Optional, Tuple, Union, cast | |||||
from typing import TYPE_CHECKING, Optional, Tuple, Union, cast | |||||
from canonicaljson import encode_canonical_json | from canonicaljson import encode_canonical_json | ||||
from typing_extensions import TYPE_CHECKING | |||||
from synapse.api.errors import Codes, StoreError, SynapseError | from synapse.api.errors import Codes, StoreError, SynapseError | ||||
from synapse.storage._base import SQLBaseStore, db_to_json | from synapse.storage._base import SQLBaseStore, db_to_json | ||||
@@ -188,7 +188,7 @@ class KeyStore(SQLBaseStore): | |||||
# invalidate takes a tuple corresponding to the params of | # invalidate takes a tuple corresponding to the params of | ||||
# _get_server_keys_json. _get_server_keys_json only takes one | # _get_server_keys_json. _get_server_keys_json only takes one | ||||
# param, which is itself the 2-tuple (server_name, key_id). | # param, which is itself the 2-tuple (server_name, key_id). | ||||
self._get_server_keys_json.invalidate((((server_name, key_id),))) | |||||
self._get_server_keys_json.invalidate(((server_name, key_id),)) | |||||
@cached() | @cached() | ||||
def _get_server_keys_json( | def _get_server_keys_json( | ||||
@@ -19,6 +19,7 @@ from itertools import chain | |||||
from typing import ( | from typing import ( | ||||
TYPE_CHECKING, | TYPE_CHECKING, | ||||
Any, | Any, | ||||
Counter, | |||||
Dict, | Dict, | ||||
Iterable, | Iterable, | ||||
List, | List, | ||||
@@ -28,8 +29,6 @@ from typing import ( | |||||
cast, | cast, | ||||
) | ) | ||||
from typing_extensions import Counter | |||||
from twisted.internet.defer import DeferredLock | from twisted.internet.defer import DeferredLock | ||||
from synapse.api.constants import Direction, EventContentFields, EventTypes, Membership | from synapse.api.constants import Direction, EventContentFields, EventTypes, Membership | ||||
@@ -145,5 +145,5 @@ class BaseDatabaseEngine(Generic[ConnectionType, CursorType], metaclass=abc.ABCM | |||||
This is not provided by DBAPI2, and so needs engine-specific support. | This is not provided by DBAPI2, and so needs engine-specific support. | ||||
""" | """ | ||||
with open(filepath, "rt") as f: | |||||
with open(filepath) as f: | |||||
cls.executescript(cursor, f.read()) | cls.executescript(cursor, f.read()) |
@@ -16,10 +16,18 @@ import logging | |||||
import os | import os | ||||
import re | import re | ||||
from collections import Counter | from collections import Counter | ||||
from typing import Collection, Generator, Iterable, List, Optional, TextIO, Tuple | |||||
from typing import ( | |||||
Collection, | |||||
Counter as CounterType, | |||||
Generator, | |||||
Iterable, | |||||
List, | |||||
Optional, | |||||
TextIO, | |||||
Tuple, | |||||
) | |||||
import attr | import attr | ||||
from typing_extensions import Counter as CounterType | |||||
from synapse.config.homeserver import HomeServerConfig | from synapse.config.homeserver import HomeServerConfig | ||||
from synapse.storage.database import LoggingDatabaseConnection, LoggingTransaction | from synapse.storage.database import LoggingDatabaseConnection, LoggingTransaction | ||||
@@ -21,6 +21,7 @@ from typing import ( | |||||
Any, | Any, | ||||
ClassVar, | ClassVar, | ||||
Dict, | Dict, | ||||
Final, | |||||
List, | List, | ||||
Mapping, | Mapping, | ||||
Match, | Match, | ||||
@@ -38,7 +39,7 @@ import attr | |||||
from immutabledict import immutabledict | from immutabledict import immutabledict | ||||
from signedjson.key import decode_verify_key_bytes | from signedjson.key import decode_verify_key_bytes | ||||
from signedjson.types import VerifyKey | from signedjson.types import VerifyKey | ||||
from typing_extensions import Final, TypedDict | |||||
from typing_extensions import TypedDict | |||||
from unpaddedbase64 import decode_base64 | from unpaddedbase64 import decode_base64 | ||||
from zope.interface import Interface | from zope.interface import Interface | ||||
@@ -22,6 +22,7 @@ import logging | |||||
from contextlib import asynccontextmanager | from contextlib import asynccontextmanager | ||||
from typing import ( | from typing import ( | ||||
Any, | Any, | ||||
AsyncContextManager, | |||||
AsyncIterator, | AsyncIterator, | ||||
Awaitable, | Awaitable, | ||||
Callable, | Callable, | ||||
@@ -42,7 +43,7 @@ from typing import ( | |||||
) | ) | ||||
import attr | import attr | ||||
from typing_extensions import AsyncContextManager, Concatenate, Literal, ParamSpec | |||||
from typing_extensions import Concatenate, Literal, ParamSpec | |||||
from twisted.internet import defer | from twisted.internet import defer | ||||
from twisted.internet.defer import CancelledError | from twisted.internet.defer import CancelledError | ||||
@@ -218,7 +218,7 @@ class MacaroonGenerator: | |||||
# to avoid validating those as guest tokens, we explicitely verify if | # to avoid validating those as guest tokens, we explicitely verify if | ||||
# the macaroon includes the "guest = true" caveat. | # the macaroon includes the "guest = true" caveat. | ||||
is_guest = any( | is_guest = any( | ||||
(caveat.caveat_id == "guest = true" for caveat in macaroon.caveats) | |||||
caveat.caveat_id == "guest = true" for caveat in macaroon.caveats | |||||
) | ) | ||||
if not is_guest: | if not is_guest: | ||||
@@ -20,6 +20,7 @@ import typing | |||||
from typing import ( | from typing import ( | ||||
Any, | Any, | ||||
Callable, | Callable, | ||||
ContextManager, | |||||
DefaultDict, | DefaultDict, | ||||
Dict, | Dict, | ||||
Iterator, | Iterator, | ||||
@@ -33,7 +34,6 @@ from typing import ( | |||||
from weakref import WeakSet | from weakref import WeakSet | ||||
from prometheus_client.core import Counter | from prometheus_client.core import Counter | ||||
from typing_extensions import ContextManager | |||||
from twisted.internet import defer | from twisted.internet import defer | ||||
@@ -17,6 +17,7 @@ from enum import Enum, auto | |||||
from typing import ( | from typing import ( | ||||
Collection, | Collection, | ||||
Dict, | Dict, | ||||
Final, | |||||
FrozenSet, | FrozenSet, | ||||
List, | List, | ||||
Mapping, | Mapping, | ||||
@@ -27,7 +28,6 @@ from typing import ( | |||||
) | ) | ||||
import attr | import attr | ||||
from typing_extensions import Final | |||||
from synapse.api.constants import EventTypes, HistoryVisibility, Membership | from synapse.api.constants import EventTypes, HistoryVisibility, Membership | ||||
from synapse.events import EventBase | from synapse.events import EventBase | ||||
@@ -26,7 +26,7 @@ class PhoneHomeR30V2TestCase(HomeserverTestCase): | |||||
def make_homeserver( | def make_homeserver( | ||||
self, reactor: ThreadedMemoryReactorClock, clock: Clock | self, reactor: ThreadedMemoryReactorClock, clock: Clock | ||||
) -> HomeServer: | ) -> HomeServer: | ||||
hs = super(PhoneHomeR30V2TestCase, self).make_homeserver(reactor, clock) | |||||
hs = super().make_homeserver(reactor, clock) | |||||
# We don't want our tests to actually report statistics, so check | # We don't want our tests to actually report statistics, so check | ||||
# that it's not enabled | # that it's not enabled | ||||
@@ -312,7 +312,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): | |||||
[("server9", get_key_id(key1))] | [("server9", get_key_id(key1))] | ||||
) | ) | ||||
result = self.get_success(d) | result = self.get_success(d) | ||||
self.assertEquals(result[("server9", get_key_id(key1))].valid_until_ts, 0) | |||||
self.assertEqual(result[("server9", get_key_id(key1))].valid_until_ts, 0) | |||||
def test_verify_json_dedupes_key_requests(self) -> None: | def test_verify_json_dedupes_key_requests(self) -> None: | ||||
"""Two requests for the same key should be deduped.""" | """Two requests for the same key should be deduped.""" | ||||
@@ -514,7 +514,7 @@ class MatrixFederationAgentTests(unittest.TestCase): | |||||
self.assertEqual(response.code, 200) | self.assertEqual(response.code, 200) | ||||
# Send the body | # Send the body | ||||
request.write('{ "a": 1 }'.encode("ascii")) | |||||
request.write(b'{ "a": 1 }') | |||||
request.finish() | request.finish() | ||||
self.reactor.pump((0.1,)) | self.reactor.pump((0.1,)) | ||||
@@ -757,7 +757,7 @@ class ModuleApiTestCase(BaseModuleApiTestCase): | |||||
self.assertEqual(channel.json_body["creator"], user_id) | self.assertEqual(channel.json_body["creator"], user_id) | ||||
# Check room alias. | # Check room alias. | ||||
self.assertEquals(room_alias, f"#foo-bar:{self.module_api.server_name}") | |||||
self.assertEqual(room_alias, f"#foo-bar:{self.module_api.server_name}") | |||||
# Let's try a room with no alias. | # Let's try a room with no alias. | ||||
room_id, room_alias = self.get_success( | room_id, room_alias = self.get_success( | ||||
@@ -116,7 +116,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase): | |||||
self.assertEqual(request.method, b"GET") | self.assertEqual(request.method, b"GET") | ||||
self.assertEqual( | self.assertEqual( | ||||
request.path, | request.path, | ||||
f"/_matrix/media/r0/download/{target}/{media_id}".encode("utf-8"), | |||||
f"/_matrix/media/r0/download/{target}/{media_id}".encode(), | |||||
) | ) | ||||
self.assertEqual( | self.assertEqual( | ||||
request.requestHeaders.getRawHeaders(b"host"), [target.encode("utf-8")] | request.requestHeaders.getRawHeaders(b"host"), [target.encode("utf-8")] | ||||
@@ -627,8 +627,8 @@ class RedactionsTestCase(HomeserverTestCase): | |||||
redact_event = timeline[-1] | redact_event = timeline[-1] | ||||
self.assertEqual(redact_event["type"], EventTypes.Redaction) | self.assertEqual(redact_event["type"], EventTypes.Redaction) | ||||
# The redacts key should be in the content and the redacts keys. | # The redacts key should be in the content and the redacts keys. | ||||
self.assertEquals(redact_event["content"]["redacts"], event_id) | |||||
self.assertEquals(redact_event["redacts"], event_id) | |||||
self.assertEqual(redact_event["content"]["redacts"], event_id) | |||||
self.assertEqual(redact_event["redacts"], event_id) | |||||
# But it isn't actually part of the event. | # But it isn't actually part of the event. | ||||
def get_event(txn: LoggingTransaction) -> JsonDict: | def get_event(txn: LoggingTransaction) -> JsonDict: | ||||
@@ -642,10 +642,10 @@ class RedactionsTestCase(HomeserverTestCase): | |||||
event_json = self.get_success( | event_json = self.get_success( | ||||
main_datastore.db_pool.runInteraction("get_event", get_event) | main_datastore.db_pool.runInteraction("get_event", get_event) | ||||
) | ) | ||||
self.assertEquals(event_json["type"], EventTypes.Redaction) | |||||
self.assertEqual(event_json["type"], EventTypes.Redaction) | |||||
if expect_content: | if expect_content: | ||||
self.assertNotIn("redacts", event_json) | self.assertNotIn("redacts", event_json) | ||||
self.assertEquals(event_json["content"]["redacts"], event_id) | |||||
self.assertEqual(event_json["content"]["redacts"], event_id) | |||||
else: | else: | ||||
self.assertEquals(event_json["redacts"], event_id) | |||||
self.assertEqual(event_json["redacts"], event_id) | |||||
self.assertNotIn("redacts", event_json["content"]) | self.assertNotIn("redacts", event_json["content"]) |
@@ -129,7 +129,7 @@ class BaseRelationsTestCase(unittest.HomeserverTestCase): | |||||
f"/_matrix/client/v1/rooms/{self.room}/relations/{self.parent_id}", | f"/_matrix/client/v1/rooms/{self.room}/relations/{self.parent_id}", | ||||
access_token=self.user_token, | access_token=self.user_token, | ||||
) | ) | ||||
self.assertEquals(200, channel.code, channel.json_body) | |||||
self.assertEqual(200, channel.code, channel.json_body) | |||||
return [ev["event_id"] for ev in channel.json_body["chunk"]] | return [ev["event_id"] for ev in channel.json_body["chunk"]] | ||||
def _get_bundled_aggregations(self) -> JsonDict: | def _get_bundled_aggregations(self) -> JsonDict: | ||||
@@ -142,7 +142,7 @@ class BaseRelationsTestCase(unittest.HomeserverTestCase): | |||||
f"/_matrix/client/v3/rooms/{self.room}/event/{self.parent_id}", | f"/_matrix/client/v3/rooms/{self.room}/event/{self.parent_id}", | ||||
access_token=self.user_token, | access_token=self.user_token, | ||||
) | ) | ||||
self.assertEquals(200, channel.code, channel.json_body) | |||||
self.assertEqual(200, channel.code, channel.json_body) | |||||
return channel.json_body["unsigned"].get("m.relations", {}) | return channel.json_body["unsigned"].get("m.relations", {}) | ||||
def _find_event_in_chunk(self, events: List[JsonDict]) -> JsonDict: | def _find_event_in_chunk(self, events: List[JsonDict]) -> JsonDict: | ||||
@@ -1602,7 +1602,7 @@ class RelationRedactionTestCase(BaseRelationsTestCase): | |||||
f"/_matrix/client/v1/rooms/{self.room}/threads", | f"/_matrix/client/v1/rooms/{self.room}/threads", | ||||
access_token=self.user_token, | access_token=self.user_token, | ||||
) | ) | ||||
self.assertEquals(200, channel.code, channel.json_body) | |||||
self.assertEqual(200, channel.code, channel.json_body) | |||||
threads = channel.json_body["chunk"] | threads = channel.json_body["chunk"] | ||||
return [ | return [ | ||||
( | ( | ||||
@@ -1634,7 +1634,7 @@ class RelationRedactionTestCase(BaseRelationsTestCase): | |||||
################################################## | ################################################## | ||||
# Check the test data is configured as expected. # | # Check the test data is configured as expected. # | ||||
################################################## | ################################################## | ||||
self.assertEquals(self._get_related_events(), list(reversed(thread_replies))) | |||||
self.assertEqual(self._get_related_events(), list(reversed(thread_replies))) | |||||
relations = self._get_bundled_aggregations() | relations = self._get_bundled_aggregations() | ||||
self.assertDictContainsSubset( | self.assertDictContainsSubset( | ||||
{"count": 3, "current_user_participated": True}, | {"count": 3, "current_user_participated": True}, | ||||
@@ -1655,7 +1655,7 @@ class RelationRedactionTestCase(BaseRelationsTestCase): | |||||
self._redact(thread_replies.pop()) | self._redact(thread_replies.pop()) | ||||
# The thread should still exist, but the latest event should be updated. | # The thread should still exist, but the latest event should be updated. | ||||
self.assertEquals(self._get_related_events(), list(reversed(thread_replies))) | |||||
self.assertEqual(self._get_related_events(), list(reversed(thread_replies))) | |||||
relations = self._get_bundled_aggregations() | relations = self._get_bundled_aggregations() | ||||
self.assertDictContainsSubset( | self.assertDictContainsSubset( | ||||
{"count": 2, "current_user_participated": True}, | {"count": 2, "current_user_participated": True}, | ||||
@@ -1674,7 +1674,7 @@ class RelationRedactionTestCase(BaseRelationsTestCase): | |||||
self._redact(thread_replies.pop(0)) | self._redact(thread_replies.pop(0)) | ||||
# Nothing should have changed (except the thread count). | # Nothing should have changed (except the thread count). | ||||
self.assertEquals(self._get_related_events(), thread_replies) | |||||
self.assertEqual(self._get_related_events(), thread_replies) | |||||
relations = self._get_bundled_aggregations() | relations = self._get_bundled_aggregations() | ||||
self.assertDictContainsSubset( | self.assertDictContainsSubset( | ||||
{"count": 1, "current_user_participated": True}, | {"count": 1, "current_user_participated": True}, | ||||
@@ -1691,11 +1691,11 @@ class RelationRedactionTestCase(BaseRelationsTestCase): | |||||
# Redact the last remaining event. # | # Redact the last remaining event. # | ||||
#################################### | #################################### | ||||
self._redact(thread_replies.pop(0)) | self._redact(thread_replies.pop(0)) | ||||
self.assertEquals(thread_replies, []) | |||||
self.assertEqual(thread_replies, []) | |||||
# The event should no longer be considered a thread. | # The event should no longer be considered a thread. | ||||
self.assertEquals(self._get_related_events(), []) | |||||
self.assertEquals(self._get_bundled_aggregations(), {}) | |||||
self.assertEqual(self._get_related_events(), []) | |||||
self.assertEqual(self._get_bundled_aggregations(), {}) | |||||
self.assertEqual(self._get_threads(), []) | self.assertEqual(self._get_threads(), []) | ||||
def test_redact_parent_edit(self) -> None: | def test_redact_parent_edit(self) -> None: | ||||
@@ -1749,8 +1749,8 @@ class RelationRedactionTestCase(BaseRelationsTestCase): | |||||
# The relations are returned. | # The relations are returned. | ||||
event_ids = self._get_related_events() | event_ids = self._get_related_events() | ||||
relations = self._get_bundled_aggregations() | relations = self._get_bundled_aggregations() | ||||
self.assertEquals(event_ids, [related_event_id]) | |||||
self.assertEquals( | |||||
self.assertEqual(event_ids, [related_event_id]) | |||||
self.assertEqual( | |||||
relations[RelationTypes.REFERENCE], | relations[RelationTypes.REFERENCE], | ||||
{"chunk": [{"event_id": related_event_id}]}, | {"chunk": [{"event_id": related_event_id}]}, | ||||
) | ) | ||||
@@ -1772,7 +1772,7 @@ class RelationRedactionTestCase(BaseRelationsTestCase): | |||||
# The unredacted relation should still exist. | # The unredacted relation should still exist. | ||||
event_ids = self._get_related_events() | event_ids = self._get_related_events() | ||||
relations = self._get_bundled_aggregations() | relations = self._get_bundled_aggregations() | ||||
self.assertEquals(len(event_ids), 1) | |||||
self.assertEqual(len(event_ids), 1) | |||||
self.assertDictContainsSubset( | self.assertDictContainsSubset( | ||||
{ | { | ||||
"count": 1, | "count": 1, | ||||
@@ -1816,7 +1816,7 @@ class ThreadsTestCase(BaseRelationsTestCase): | |||||
f"/_matrix/client/v1/rooms/{self.room}/threads", | f"/_matrix/client/v1/rooms/{self.room}/threads", | ||||
access_token=self.user_token, | access_token=self.user_token, | ||||
) | ) | ||||
self.assertEquals(200, channel.code, channel.json_body) | |||||
self.assertEqual(200, channel.code, channel.json_body) | |||||
threads = self._get_threads(channel.json_body) | threads = self._get_threads(channel.json_body) | ||||
self.assertEqual(threads, [(thread_2, reply_2), (thread_1, reply_1)]) | self.assertEqual(threads, [(thread_2, reply_2), (thread_1, reply_1)]) | ||||
@@ -1829,7 +1829,7 @@ class ThreadsTestCase(BaseRelationsTestCase): | |||||
f"/_matrix/client/v1/rooms/{self.room}/threads", | f"/_matrix/client/v1/rooms/{self.room}/threads", | ||||
access_token=self.user_token, | access_token=self.user_token, | ||||
) | ) | ||||
self.assertEquals(200, channel.code, channel.json_body) | |||||
self.assertEqual(200, channel.code, channel.json_body) | |||||
# Tuple of (thread ID, latest event ID) for each thread. | # Tuple of (thread ID, latest event ID) for each thread. | ||||
threads = self._get_threads(channel.json_body) | threads = self._get_threads(channel.json_body) | ||||
self.assertEqual(threads, [(thread_1, reply_3), (thread_2, reply_2)]) | self.assertEqual(threads, [(thread_1, reply_3), (thread_2, reply_2)]) | ||||
@@ -1850,7 +1850,7 @@ class ThreadsTestCase(BaseRelationsTestCase): | |||||
f"/_matrix/client/v1/rooms/{self.room}/threads?limit=1", | f"/_matrix/client/v1/rooms/{self.room}/threads?limit=1", | ||||
access_token=self.user_token, | access_token=self.user_token, | ||||
) | ) | ||||
self.assertEquals(200, channel.code, channel.json_body) | |||||
self.assertEqual(200, channel.code, channel.json_body) | |||||
thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]] | thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]] | ||||
self.assertEqual(thread_roots, [thread_2]) | self.assertEqual(thread_roots, [thread_2]) | ||||
@@ -1864,7 +1864,7 @@ class ThreadsTestCase(BaseRelationsTestCase): | |||||
f"/_matrix/client/v1/rooms/{self.room}/threads?limit=1&from={next_batch}", | f"/_matrix/client/v1/rooms/{self.room}/threads?limit=1&from={next_batch}", | ||||
access_token=self.user_token, | access_token=self.user_token, | ||||
) | ) | ||||
self.assertEquals(200, channel.code, channel.json_body) | |||||
self.assertEqual(200, channel.code, channel.json_body) | |||||
thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]] | thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]] | ||||
self.assertEqual(thread_roots, [thread_1], channel.json_body) | self.assertEqual(thread_roots, [thread_1], channel.json_body) | ||||
@@ -1899,7 +1899,7 @@ class ThreadsTestCase(BaseRelationsTestCase): | |||||
f"/_matrix/client/v1/rooms/{self.room}/threads", | f"/_matrix/client/v1/rooms/{self.room}/threads", | ||||
access_token=self.user_token, | access_token=self.user_token, | ||||
) | ) | ||||
self.assertEquals(200, channel.code, channel.json_body) | |||||
self.assertEqual(200, channel.code, channel.json_body) | |||||
thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]] | thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]] | ||||
self.assertEqual( | self.assertEqual( | ||||
thread_roots, [thread_3, thread_2, thread_1], channel.json_body | thread_roots, [thread_3, thread_2, thread_1], channel.json_body | ||||
@@ -1911,7 +1911,7 @@ class ThreadsTestCase(BaseRelationsTestCase): | |||||
f"/_matrix/client/v1/rooms/{self.room}/threads?include=participated", | f"/_matrix/client/v1/rooms/{self.room}/threads?include=participated", | ||||
access_token=self.user_token, | access_token=self.user_token, | ||||
) | ) | ||||
self.assertEquals(200, channel.code, channel.json_body) | |||||
self.assertEqual(200, channel.code, channel.json_body) | |||||
thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]] | thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]] | ||||
self.assertEqual(thread_roots, [thread_2, thread_1], channel.json_body) | self.assertEqual(thread_roots, [thread_2, thread_1], channel.json_body) | ||||
@@ -1943,6 +1943,6 @@ class ThreadsTestCase(BaseRelationsTestCase): | |||||
f"/_matrix/client/v1/rooms/{self.room}/threads", | f"/_matrix/client/v1/rooms/{self.room}/threads", | ||||
access_token=self.user_token, | access_token=self.user_token, | ||||
) | ) | ||||
self.assertEquals(200, channel.code, channel.json_body) | |||||
self.assertEqual(200, channel.code, channel.json_body) | |||||
thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]] | thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]] | ||||
self.assertEqual(thread_roots, [thread_1], channel.json_body) | self.assertEqual(thread_roots, [thread_1], channel.json_body) |
@@ -1362,7 +1362,7 @@ class RoomAppserviceTsParamTestCase(unittest.HomeserverTestCase): | |||||
# Ensure the event was persisted with the correct timestamp. | # Ensure the event was persisted with the correct timestamp. | ||||
res = self.get_success(self.main_store.get_event(event_id)) | res = self.get_success(self.main_store.get_event(event_id)) | ||||
self.assertEquals(ts, res.origin_server_ts) | |||||
self.assertEqual(ts, res.origin_server_ts) | |||||
def test_send_state_event_ts(self) -> None: | def test_send_state_event_ts(self) -> None: | ||||
"""Test sending a state event with a custom timestamp.""" | """Test sending a state event with a custom timestamp.""" | ||||
@@ -1384,7 +1384,7 @@ class RoomAppserviceTsParamTestCase(unittest.HomeserverTestCase): | |||||
# Ensure the event was persisted with the correct timestamp. | # Ensure the event was persisted with the correct timestamp. | ||||
res = self.get_success(self.main_store.get_event(event_id)) | res = self.get_success(self.main_store.get_event(event_id)) | ||||
self.assertEquals(ts, res.origin_server_ts) | |||||
self.assertEqual(ts, res.origin_server_ts) | |||||
def test_send_membership_event_ts(self) -> None: | def test_send_membership_event_ts(self) -> None: | ||||
"""Test sending a membership event with a custom timestamp.""" | """Test sending a membership event with a custom timestamp.""" | ||||
@@ -1406,7 +1406,7 @@ class RoomAppserviceTsParamTestCase(unittest.HomeserverTestCase): | |||||
# Ensure the event was persisted with the correct timestamp. | # Ensure the event was persisted with the correct timestamp. | ||||
res = self.get_success(self.main_store.get_event(event_id)) | res = self.get_success(self.main_store.get_event(event_id)) | ||||
self.assertEquals(ts, res.origin_server_ts) | |||||
self.assertEqual(ts, res.origin_server_ts) | |||||
class RoomJoinRatelimitTestCase(RoomBase): | class RoomJoinRatelimitTestCase(RoomBase): | ||||
@@ -26,6 +26,7 @@ from typing import ( | |||||
Any, | Any, | ||||
Awaitable, | Awaitable, | ||||
Callable, | Callable, | ||||
Deque, | |||||
Dict, | Dict, | ||||
Iterable, | Iterable, | ||||
List, | List, | ||||
@@ -41,7 +42,7 @@ from typing import ( | |||||
from unittest.mock import Mock | from unittest.mock import Mock | ||||
import attr | import attr | ||||
from typing_extensions import Deque, ParamSpec | |||||
from typing_extensions import ParamSpec | |||||
from zope.interface import implementer | from zope.interface import implementer | ||||
from twisted.internet import address, threads, udp | from twisted.internet import address, threads, udp | ||||
@@ -40,7 +40,7 @@ from tests.test_utils import make_awaitable | |||||
class ApplicationServiceStoreTestCase(unittest.HomeserverTestCase): | class ApplicationServiceStoreTestCase(unittest.HomeserverTestCase): | ||||
def setUp(self) -> None: | def setUp(self) -> None: | ||||
super(ApplicationServiceStoreTestCase, self).setUp() | |||||
super().setUp() | |||||
self.as_yaml_files: List[str] = [] | self.as_yaml_files: List[str] = [] | ||||
@@ -71,7 +71,7 @@ class ApplicationServiceStoreTestCase(unittest.HomeserverTestCase): | |||||
except Exception: | except Exception: | ||||
pass | pass | ||||
super(ApplicationServiceStoreTestCase, self).tearDown() | |||||
super().tearDown() | |||||
def _add_appservice( | def _add_appservice( | ||||
self, as_token: str, id: str, url: str, hs_token: str, sender: str | self, as_token: str, id: str, url: str, hs_token: str, sender: str | ||||
@@ -110,7 +110,7 @@ class ApplicationServiceStoreTestCase(unittest.HomeserverTestCase): | |||||
class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): | class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): | ||||
def setUp(self) -> None: | def setUp(self) -> None: | ||||
super(ApplicationServiceTransactionStoreTestCase, self).setUp() | |||||
super().setUp() | |||||
self.as_yaml_files: List[str] = [] | self.as_yaml_files: List[str] = [] | ||||
self.hs.config.appservice.app_service_config_files = self.as_yaml_files | self.hs.config.appservice.app_service_config_files = self.as_yaml_files | ||||
@@ -20,7 +20,7 @@ from tests import unittest | |||||
class DataStoreTestCase(unittest.HomeserverTestCase): | class DataStoreTestCase(unittest.HomeserverTestCase): | ||||
def setUp(self) -> None: | def setUp(self) -> None: | ||||
super(DataStoreTestCase, self).setUp() | |||||
super().setUp() | |||||
self.store = self.hs.get_datastores().main | self.store = self.hs.get_datastores().main | ||||
@@ -318,14 +318,14 @@ class MessageSearchTest(HomeserverTestCase): | |||||
result = self.get_success( | result = self.get_success( | ||||
store.search_msgs([self.room_id], query, ["content.body"]) | store.search_msgs([self.room_id], query, ["content.body"]) | ||||
) | ) | ||||
self.assertEquals( | |||||
self.assertEqual( | |||||
result["count"], | result["count"], | ||||
1 if expect_to_contain else 0, | 1 if expect_to_contain else 0, | ||||
f"expected '{query}' to match '{self.PHRASE}'" | f"expected '{query}' to match '{self.PHRASE}'" | ||||
if expect_to_contain | if expect_to_contain | ||||
else f"'{query}' unexpectedly matched '{self.PHRASE}'", | else f"'{query}' unexpectedly matched '{self.PHRASE}'", | ||||
) | ) | ||||
self.assertEquals( | |||||
self.assertEqual( | |||||
len(result["results"]), | len(result["results"]), | ||||
1 if expect_to_contain else 0, | 1 if expect_to_contain else 0, | ||||
"results array length should match count", | "results array length should match count", | ||||
@@ -336,14 +336,14 @@ class MessageSearchTest(HomeserverTestCase): | |||||
result = self.get_success( | result = self.get_success( | ||||
store.search_rooms([self.room_id], query, ["content.body"], 10) | store.search_rooms([self.room_id], query, ["content.body"], 10) | ||||
) | ) | ||||
self.assertEquals( | |||||
self.assertEqual( | |||||
result["count"], | result["count"], | ||||
1 if expect_to_contain else 0, | 1 if expect_to_contain else 0, | ||||
f"expected '{query}' to match '{self.PHRASE}'" | f"expected '{query}' to match '{self.PHRASE}'" | ||||
if expect_to_contain | if expect_to_contain | ||||
else f"'{query}' unexpectedly matched '{self.PHRASE}'", | else f"'{query}' unexpectedly matched '{self.PHRASE}'", | ||||
) | ) | ||||
self.assertEquals( | |||||
self.assertEqual( | |||||
len(result["results"]), | len(result["results"]), | ||||
1 if expect_to_contain else 0, | 1 if expect_to_contain else 0, | ||||
"results array length should match count", | "results array length should match count", | ||||
@@ -31,7 +31,7 @@ TEST_ROOM_ID = "!TEST:ROOM" | |||||
class FilterEventsForServerTestCase(unittest.HomeserverTestCase): | class FilterEventsForServerTestCase(unittest.HomeserverTestCase): | ||||
def setUp(self) -> None: | def setUp(self) -> None: | ||||
super(FilterEventsForServerTestCase, self).setUp() | |||||
super().setUp() | |||||
self.event_creation_handler = self.hs.get_event_creation_handler() | self.event_creation_handler = self.hs.get_event_creation_handler() | ||||
self.event_builder_factory = self.hs.get_event_builder_factory() | self.event_builder_factory = self.hs.get_event_builder_factory() | ||||
self._storage_controllers = self.hs.get_storage_controllers() | self._storage_controllers = self.hs.get_storage_controllers() | ||||