@@ -0,0 +1 @@ | |||||
Improve type hints. |
@@ -60,24 +60,6 @@ disallow_untyped_defs = False | |||||
[mypy-synapse.storage.database] | [mypy-synapse.storage.database] | ||||
disallow_untyped_defs = False | disallow_untyped_defs = False | ||||
[mypy-tests.scripts.test_new_matrix_user] | |||||
disallow_untyped_defs = False | |||||
[mypy-tests.server_notices.test_consent] | |||||
disallow_untyped_defs = False | |||||
[mypy-tests.server_notices.test_resource_limits_server_notices] | |||||
disallow_untyped_defs = False | |||||
[mypy-tests.test_federation] | |||||
disallow_untyped_defs = False | |||||
[mypy-tests.test_utils.*] | |||||
disallow_untyped_defs = False | |||||
[mypy-tests.test_visibility] | |||||
disallow_untyped_defs = False | |||||
[mypy-tests.unittest] | [mypy-tests.unittest] | ||||
disallow_untyped_defs = False | disallow_untyped_defs = False | ||||
@@ -150,7 +150,7 @@ class OidcHandlerTestCase(HomeserverTestCase): | |||||
hs = self.setup_test_homeserver() | hs = self.setup_test_homeserver() | ||||
self.hs_patcher = self.fake_server.patch_homeserver(hs=hs) | self.hs_patcher = self.fake_server.patch_homeserver(hs=hs) | ||||
self.hs_patcher.start() | |||||
self.hs_patcher.start() # type: ignore[attr-defined] | |||||
self.handler = hs.get_oidc_handler() | self.handler = hs.get_oidc_handler() | ||||
self.provider = self.handler._providers["oidc"] | self.provider = self.handler._providers["oidc"] | ||||
@@ -170,7 +170,7 @@ class OidcHandlerTestCase(HomeserverTestCase): | |||||
return hs | return hs | ||||
def tearDown(self) -> None: | def tearDown(self) -> None: | ||||
self.hs_patcher.stop() | |||||
self.hs_patcher.stop() # type: ignore[attr-defined] | |||||
return super().tearDown() | return super().tearDown() | ||||
def reset_mocks(self) -> None: | def reset_mocks(self) -> None: | ||||
@@ -12,29 +12,33 @@ | |||||
# See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||
# limitations under the License. | # limitations under the License. | ||||
from typing import List | |||||
from typing import List, Optional | |||||
from unittest.mock import Mock, patch | from unittest.mock import Mock, patch | ||||
from synapse._scripts.register_new_matrix_user import request_registration | from synapse._scripts.register_new_matrix_user import request_registration | ||||
from synapse.types import JsonDict | |||||
from tests.unittest import TestCase | from tests.unittest import TestCase | ||||
class RegisterTestCase(TestCase): | class RegisterTestCase(TestCase): | ||||
def test_success(self): | |||||
def test_success(self) -> None: | |||||
""" | """ | ||||
The script will fetch a nonce, and then generate a MAC with it, and then | The script will fetch a nonce, and then generate a MAC with it, and then | ||||
post that MAC. | post that MAC. | ||||
""" | """ | ||||
def get(url, verify=None): | |||||
def get(url: str, verify: Optional[bool] = None) -> Mock: | |||||
r = Mock() | r = Mock() | ||||
r.status_code = 200 | r.status_code = 200 | ||||
r.json = lambda: {"nonce": "a"} | r.json = lambda: {"nonce": "a"} | ||||
return r | return r | ||||
def post(url, json=None, verify=None): | |||||
def post( | |||||
url: str, json: Optional[JsonDict] = None, verify: Optional[bool] = None | |||||
) -> Mock: | |||||
# Make sure we are sent the correct info | # Make sure we are sent the correct info | ||||
assert json is not None | |||||
self.assertEqual(json["username"], "user") | self.assertEqual(json["username"], "user") | ||||
self.assertEqual(json["password"], "pass") | self.assertEqual(json["password"], "pass") | ||||
self.assertEqual(json["nonce"], "a") | self.assertEqual(json["nonce"], "a") | ||||
@@ -70,12 +74,12 @@ class RegisterTestCase(TestCase): | |||||
# sys.exit shouldn't have been called. | # sys.exit shouldn't have been called. | ||||
self.assertEqual(err_code, []) | self.assertEqual(err_code, []) | ||||
def test_failure_nonce(self): | |||||
def test_failure_nonce(self) -> None: | |||||
""" | """ | ||||
If the script fails to fetch a nonce, it throws an error and quits. | If the script fails to fetch a nonce, it throws an error and quits. | ||||
""" | """ | ||||
def get(url, verify=None): | |||||
def get(url: str, verify: Optional[bool] = None) -> Mock: | |||||
r = Mock() | r = Mock() | ||||
r.status_code = 404 | r.status_code = 404 | ||||
r.reason = "Not Found" | r.reason = "Not Found" | ||||
@@ -107,20 +111,23 @@ class RegisterTestCase(TestCase): | |||||
self.assertIn("ERROR! Received 404 Not Found", out) | self.assertIn("ERROR! Received 404 Not Found", out) | ||||
self.assertNotIn("Success!", out) | self.assertNotIn("Success!", out) | ||||
def test_failure_post(self): | |||||
def test_failure_post(self) -> None: | |||||
""" | """ | ||||
The script will fetch a nonce, and then if the final POST fails, will | The script will fetch a nonce, and then if the final POST fails, will | ||||
report an error and quit. | report an error and quit. | ||||
""" | """ | ||||
def get(url, verify=None): | |||||
def get(url: str, verify: Optional[bool] = None) -> Mock: | |||||
r = Mock() | r = Mock() | ||||
r.status_code = 200 | r.status_code = 200 | ||||
r.json = lambda: {"nonce": "a"} | r.json = lambda: {"nonce": "a"} | ||||
return r | return r | ||||
def post(url, json=None, verify=None): | |||||
def post( | |||||
url: str, json: Optional[JsonDict] = None, verify: Optional[bool] = None | |||||
) -> Mock: | |||||
# Make sure we are sent the correct info | # Make sure we are sent the correct info | ||||
assert json is not None | |||||
self.assertEqual(json["username"], "user") | self.assertEqual(json["username"], "user") | ||||
self.assertEqual(json["password"], "pass") | self.assertEqual(json["password"], "pass") | ||||
self.assertEqual(json["nonce"], "a") | self.assertEqual(json["nonce"], "a") | ||||
@@ -14,8 +14,12 @@ | |||||
import os | import os | ||||
from twisted.test.proto_helpers import MemoryReactor | |||||
import synapse.rest.admin | import synapse.rest.admin | ||||
from synapse.rest.client import login, room, sync | from synapse.rest.client import login, room, sync | ||||
from synapse.server import HomeServer | |||||
from synapse.util import Clock | |||||
from tests import unittest | from tests import unittest | ||||
@@ -29,7 +33,7 @@ class ConsentNoticesTests(unittest.HomeserverTestCase): | |||||
room.register_servlets, | room.register_servlets, | ||||
] | ] | ||||
def make_homeserver(self, reactor, clock): | |||||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: | |||||
tmpdir = self.mktemp() | tmpdir = self.mktemp() | ||||
os.mkdir(tmpdir) | os.mkdir(tmpdir) | ||||
@@ -53,15 +57,13 @@ class ConsentNoticesTests(unittest.HomeserverTestCase): | |||||
"room_name": "Server Notices", | "room_name": "Server Notices", | ||||
} | } | ||||
hs = self.setup_test_homeserver(config=config) | |||||
return hs | |||||
return self.setup_test_homeserver(config=config) | |||||
def prepare(self, reactor, clock, hs): | |||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: | |||||
self.user_id = self.register_user("bob", "abc123") | self.user_id = self.register_user("bob", "abc123") | ||||
self.access_token = self.login("bob", "abc123") | self.access_token = self.login("bob", "abc123") | ||||
def test_get_sync_message(self): | |||||
def test_get_sync_message(self) -> None: | |||||
""" | """ | ||||
When user consent server notices are enabled, a sync will cause a notice | When user consent server notices are enabled, a sync will cause a notice | ||||
to fire (in a room which the user is invited to). The notice contains | to fire (in a room which the user is invited to). The notice contains | ||||
@@ -24,6 +24,7 @@ from synapse.server import HomeServer | |||||
from synapse.server_notices.resource_limits_server_notices import ( | from synapse.server_notices.resource_limits_server_notices import ( | ||||
ResourceLimitsServerNotices, | ResourceLimitsServerNotices, | ||||
) | ) | ||||
from synapse.types import JsonDict | |||||
from synapse.util import Clock | from synapse.util import Clock | ||||
from tests import unittest | from tests import unittest | ||||
@@ -33,7 +34,7 @@ from tests.utils import default_config | |||||
class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): | class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): | ||||
def default_config(self): | |||||
def default_config(self) -> JsonDict: | |||||
config = default_config("test") | config = default_config("test") | ||||
config.update( | config.update( | ||||
@@ -86,18 +87,18 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): | |||||
self._rlsn._store.get_tags_for_room = Mock(return_value=make_awaitable({})) # type: ignore[assignment] | self._rlsn._store.get_tags_for_room = Mock(return_value=make_awaitable({})) # type: ignore[assignment] | ||||
@override_config({"hs_disabled": True}) | @override_config({"hs_disabled": True}) | ||||
def test_maybe_send_server_notice_disabled_hs(self): | |||||
def test_maybe_send_server_notice_disabled_hs(self) -> None: | |||||
"""If the HS is disabled, we should not send notices""" | """If the HS is disabled, we should not send notices""" | ||||
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) | self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) | ||||
self._send_notice.assert_not_called() | self._send_notice.assert_not_called() | ||||
@override_config({"limit_usage_by_mau": False}) | @override_config({"limit_usage_by_mau": False}) | ||||
def test_maybe_send_server_notice_to_user_flag_off(self): | |||||
def test_maybe_send_server_notice_to_user_flag_off(self) -> None: | |||||
"""If mau limiting is disabled, we should not send notices""" | """If mau limiting is disabled, we should not send notices""" | ||||
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) | self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) | ||||
self._send_notice.assert_not_called() | self._send_notice.assert_not_called() | ||||
def test_maybe_send_server_notice_to_user_remove_blocked_notice(self): | |||||
def test_maybe_send_server_notice_to_user_remove_blocked_notice(self) -> None: | |||||
"""Test when user has blocked notice, but should have it removed""" | """Test when user has blocked notice, but should have it removed""" | ||||
self._rlsn._auth_blocking.check_auth_blocking = Mock( | self._rlsn._auth_blocking.check_auth_blocking = Mock( | ||||
@@ -114,7 +115,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): | |||||
self._rlsn._server_notices_manager.maybe_get_notice_room_for_user.assert_called_once() | self._rlsn._server_notices_manager.maybe_get_notice_room_for_user.assert_called_once() | ||||
self._send_notice.assert_called_once() | self._send_notice.assert_called_once() | ||||
def test_maybe_send_server_notice_to_user_remove_blocked_notice_noop(self): | |||||
def test_maybe_send_server_notice_to_user_remove_blocked_notice_noop(self) -> None: | |||||
""" | """ | ||||
Test when user has blocked notice, but notice ought to be there (NOOP) | Test when user has blocked notice, but notice ought to be there (NOOP) | ||||
""" | """ | ||||
@@ -134,7 +135,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): | |||||
self._send_notice.assert_not_called() | self._send_notice.assert_not_called() | ||||
def test_maybe_send_server_notice_to_user_add_blocked_notice(self): | |||||
def test_maybe_send_server_notice_to_user_add_blocked_notice(self) -> None: | |||||
""" | """ | ||||
Test when user does not have blocked notice, but should have one | Test when user does not have blocked notice, but should have one | ||||
""" | """ | ||||
@@ -147,7 +148,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): | |||||
# Would be better to check contents, but 2 calls == set blocking event | # Would be better to check contents, but 2 calls == set blocking event | ||||
self.assertEqual(self._send_notice.call_count, 2) | self.assertEqual(self._send_notice.call_count, 2) | ||||
def test_maybe_send_server_notice_to_user_add_blocked_notice_noop(self): | |||||
def test_maybe_send_server_notice_to_user_add_blocked_notice_noop(self) -> None: | |||||
""" | """ | ||||
Test when user does not have blocked notice, nor should they (NOOP) | Test when user does not have blocked notice, nor should they (NOOP) | ||||
""" | """ | ||||
@@ -159,7 +160,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): | |||||
self._send_notice.assert_not_called() | self._send_notice.assert_not_called() | ||||
def test_maybe_send_server_notice_to_user_not_in_mau_cohort(self): | |||||
def test_maybe_send_server_notice_to_user_not_in_mau_cohort(self) -> None: | |||||
""" | """ | ||||
Test when user is not part of the MAU cohort - this should not ever | Test when user is not part of the MAU cohort - this should not ever | ||||
happen - but ... | happen - but ... | ||||
@@ -175,7 +176,9 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): | |||||
self._send_notice.assert_not_called() | self._send_notice.assert_not_called() | ||||
@override_config({"mau_limit_alerting": False}) | @override_config({"mau_limit_alerting": False}) | ||||
def test_maybe_send_server_notice_when_alerting_suppressed_room_unblocked(self): | |||||
def test_maybe_send_server_notice_when_alerting_suppressed_room_unblocked( | |||||
self, | |||||
) -> None: | |||||
""" | """ | ||||
Test that when server is over MAU limit and alerting is suppressed, then | Test that when server is over MAU limit and alerting is suppressed, then | ||||
an alert message is not sent into the room | an alert message is not sent into the room | ||||
@@ -191,7 +194,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): | |||||
self.assertEqual(self._send_notice.call_count, 0) | self.assertEqual(self._send_notice.call_count, 0) | ||||
@override_config({"mau_limit_alerting": False}) | @override_config({"mau_limit_alerting": False}) | ||||
def test_check_hs_disabled_unaffected_by_mau_alert_suppression(self): | |||||
def test_check_hs_disabled_unaffected_by_mau_alert_suppression(self) -> None: | |||||
""" | """ | ||||
Test that when a server is disabled, that MAU limit alerting is ignored. | Test that when a server is disabled, that MAU limit alerting is ignored. | ||||
""" | """ | ||||
@@ -207,7 +210,9 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): | |||||
self.assertEqual(self._send_notice.call_count, 2) | self.assertEqual(self._send_notice.call_count, 2) | ||||
@override_config({"mau_limit_alerting": False}) | @override_config({"mau_limit_alerting": False}) | ||||
def test_maybe_send_server_notice_when_alerting_suppressed_room_blocked(self): | |||||
def test_maybe_send_server_notice_when_alerting_suppressed_room_blocked( | |||||
self, | |||||
) -> None: | |||||
""" | """ | ||||
When the room is already in a blocked state, test that when alerting | When the room is already in a blocked state, test that when alerting | ||||
is suppressed that the room is returned to an unblocked state. | is suppressed that the room is returned to an unblocked state. | ||||
@@ -242,7 +247,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase): | |||||
sync.register_servlets, | sync.register_servlets, | ||||
] | ] | ||||
def default_config(self): | |||||
def default_config(self) -> JsonDict: | |||||
c = super().default_config() | c = super().default_config() | ||||
c["server_notices"] = { | c["server_notices"] = { | ||||
"system_mxid_localpart": "server", | "system_mxid_localpart": "server", | ||||
@@ -270,7 +275,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase): | |||||
self.user_id = "@user_id:test" | self.user_id = "@user_id:test" | ||||
def test_server_notice_only_sent_once(self): | |||||
def test_server_notice_only_sent_once(self) -> None: | |||||
self.store.get_monthly_active_count = Mock(return_value=make_awaitable(1000)) | self.store.get_monthly_active_count = Mock(return_value=make_awaitable(1000)) | ||||
self.store.user_last_seen_monthly_active = Mock( | self.store.user_last_seen_monthly_active = Mock( | ||||
@@ -306,7 +311,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase): | |||||
self.assertEqual(count, 1) | self.assertEqual(count, 1) | ||||
def test_no_invite_without_notice(self): | |||||
def test_no_invite_without_notice(self) -> None: | |||||
"""Tests that a user doesn't get invited to a server notices room without a | """Tests that a user doesn't get invited to a server notices room without a | ||||
server notice being sent. | server notice being sent. | ||||
@@ -328,7 +333,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase): | |||||
m.assert_called_once_with(user_id) | m.assert_called_once_with(user_id) | ||||
def test_invite_with_notice(self): | |||||
def test_invite_with_notice(self) -> None: | |||||
"""Tests that, if the MAU limit is hit, the server notices user invites each user | """Tests that, if the MAU limit is hit, the server notices user invites each user | ||||
to a room in which it has sent a notice. | to a room in which it has sent a notice. | ||||
""" | """ | ||||
@@ -12,53 +12,48 @@ | |||||
# See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||
# limitations under the License. | # limitations under the License. | ||||
from typing import Optional, Union | |||||
from unittest.mock import Mock | from unittest.mock import Mock | ||||
from twisted.internet.defer import succeed | from twisted.internet.defer import succeed | ||||
from twisted.test.proto_helpers import MemoryReactor | |||||
from synapse.api.errors import FederationError | from synapse.api.errors import FederationError | ||||
from synapse.api.room_versions import RoomVersions | from synapse.api.room_versions import RoomVersions | ||||
from synapse.events import make_event_from_dict | |||||
from synapse.events import EventBase, make_event_from_dict | |||||
from synapse.events.snapshot import EventContext | |||||
from synapse.federation.federation_base import event_from_pdu_json | from synapse.federation.federation_base import event_from_pdu_json | ||||
from synapse.http.types import QueryParams | |||||
from synapse.logging.context import LoggingContext | from synapse.logging.context import LoggingContext | ||||
from synapse.types import UserID, create_requester | |||||
from synapse.server import HomeServer | |||||
from synapse.types import JsonDict, UserID, create_requester | |||||
from synapse.util import Clock | from synapse.util import Clock | ||||
from synapse.util.retryutils import NotRetryingDestination | from synapse.util.retryutils import NotRetryingDestination | ||||
from tests import unittest | from tests import unittest | ||||
from tests.server import ThreadedMemoryReactorClock, setup_test_homeserver | |||||
from tests.test_utils import make_awaitable | from tests.test_utils import make_awaitable | ||||
class MessageAcceptTests(unittest.HomeserverTestCase): | class MessageAcceptTests(unittest.HomeserverTestCase): | ||||
def setUp(self): | |||||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: | |||||
self.http_client = Mock() | self.http_client = Mock() | ||||
self.reactor = ThreadedMemoryReactorClock() | |||||
self.hs_clock = Clock(self.reactor) | |||||
self.homeserver = setup_test_homeserver( | |||||
self.addCleanup, | |||||
federation_http_client=self.http_client, | |||||
clock=self.hs_clock, | |||||
reactor=self.reactor, | |||||
) | |||||
return self.setup_test_homeserver(federation_http_client=self.http_client) | |||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: | |||||
user_id = UserID("us", "test") | user_id = UserID("us", "test") | ||||
our_user = create_requester(user_id) | our_user = create_requester(user_id) | ||||
room_creator = self.homeserver.get_room_creation_handler() | |||||
room_creator = self.hs.get_room_creation_handler() | |||||
self.room_id = self.get_success( | self.room_id = self.get_success( | ||||
room_creator.create_room( | room_creator.create_room( | ||||
our_user, room_creator._presets_dict["public_chat"], ratelimit=False | our_user, room_creator._presets_dict["public_chat"], ratelimit=False | ||||
) | ) | ||||
)[0]["room_id"] | )[0]["room_id"] | ||||
self.store = self.homeserver.get_datastores().main | |||||
self.store = self.hs.get_datastores().main | |||||
# Figure out what the most recent event is | # Figure out what the most recent event is | ||||
most_recent = self.get_success( | most_recent = self.get_success( | ||||
self.homeserver.get_datastores().main.get_latest_event_ids_in_room( | |||||
self.room_id | |||||
) | |||||
self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id) | |||||
)[0] | )[0] | ||||
join_event = make_event_from_dict( | join_event = make_event_from_dict( | ||||
@@ -78,14 +73,16 @@ class MessageAcceptTests(unittest.HomeserverTestCase): | |||||
} | } | ||||
) | ) | ||||
self.handler = self.homeserver.get_federation_handler() | |||||
federation_event_handler = self.homeserver.get_federation_event_handler() | |||||
self.handler = self.hs.get_federation_handler() | |||||
federation_event_handler = self.hs.get_federation_event_handler() | |||||
async def _check_event_auth(origin, event, context): | |||||
async def _check_event_auth( | |||||
origin: Optional[str], event: EventBase, context: EventContext | |||||
) -> None: | |||||
pass | pass | ||||
federation_event_handler._check_event_auth = _check_event_auth | federation_event_handler._check_event_auth = _check_event_auth | ||||
self.client = self.homeserver.get_federation_client() | |||||
self.client = self.hs.get_federation_client() | |||||
self.client._check_sigs_and_hash_for_pulled_events_and_fetch = ( | self.client._check_sigs_and_hash_for_pulled_events_and_fetch = ( | ||||
lambda dest, pdus, **k: succeed(pdus) | lambda dest, pdus, **k: succeed(pdus) | ||||
) | ) | ||||
@@ -104,16 +101,25 @@ class MessageAcceptTests(unittest.HomeserverTestCase): | |||||
"$join:test.serv", | "$join:test.serv", | ||||
) | ) | ||||
def test_cant_hide_direct_ancestors(self): | |||||
def test_cant_hide_direct_ancestors(self) -> None: | |||||
""" | """ | ||||
If you send a message, you must be able to provide the direct | If you send a message, you must be able to provide the direct | ||||
prev_events that said event references. | prev_events that said event references. | ||||
""" | """ | ||||
async def post_json(destination, path, data, headers=None, timeout=0): | |||||
async def post_json( | |||||
destination: str, | |||||
path: str, | |||||
data: Optional[JsonDict] = None, | |||||
long_retries: bool = False, | |||||
timeout: Optional[int] = None, | |||||
ignore_backoff: bool = False, | |||||
args: Optional[QueryParams] = None, | |||||
) -> Union[JsonDict, list]: | |||||
# If it asks us for new missing events, give them NOTHING | # If it asks us for new missing events, give them NOTHING | ||||
if path.startswith("/_matrix/federation/v1/get_missing_events/"): | if path.startswith("/_matrix/federation/v1/get_missing_events/"): | ||||
return {"events": []} | return {"events": []} | ||||
return {} | |||||
self.http_client.post_json = post_json | self.http_client.post_json = post_json | ||||
@@ -138,7 +144,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase): | |||||
} | } | ||||
) | ) | ||||
federation_event_handler = self.homeserver.get_federation_event_handler() | |||||
federation_event_handler = self.hs.get_federation_event_handler() | |||||
with LoggingContext("test-context"): | with LoggingContext("test-context"): | ||||
failure = self.get_failure( | failure = self.get_failure( | ||||
federation_event_handler.on_receive_pdu("test.serv", lying_event), | federation_event_handler.on_receive_pdu("test.serv", lying_event), | ||||
@@ -158,7 +164,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase): | |||||
extrem = self.get_success(self.store.get_latest_event_ids_in_room(self.room_id)) | extrem = self.get_success(self.store.get_latest_event_ids_in_room(self.room_id)) | ||||
self.assertEqual(extrem[0], "$join:test.serv") | self.assertEqual(extrem[0], "$join:test.serv") | ||||
def test_retry_device_list_resync(self): | |||||
def test_retry_device_list_resync(self) -> None: | |||||
"""Tests that device lists are marked as stale if they couldn't be synced, and | """Tests that device lists are marked as stale if they couldn't be synced, and | ||||
that stale device lists are retried periodically. | that stale device lists are retried periodically. | ||||
""" | """ | ||||
@@ -171,24 +177,26 @@ class MessageAcceptTests(unittest.HomeserverTestCase): | |||||
# When this function is called, increment the number of resync attempts (only if | # When this function is called, increment the number of resync attempts (only if | ||||
# we're querying devices for the right user ID), then raise a | # we're querying devices for the right user ID), then raise a | ||||
# NotRetryingDestination error to fail the resync gracefully. | # NotRetryingDestination error to fail the resync gracefully. | ||||
def query_user_devices(destination, user_id): | |||||
def query_user_devices( | |||||
destination: str, user_id: str, timeout: int = 30000 | |||||
) -> JsonDict: | |||||
if user_id == remote_user_id: | if user_id == remote_user_id: | ||||
self.resync_attempts += 1 | self.resync_attempts += 1 | ||||
raise NotRetryingDestination(0, 0, destination) | raise NotRetryingDestination(0, 0, destination) | ||||
# Register the mock on the federation client. | # Register the mock on the federation client. | ||||
federation_client = self.homeserver.get_federation_client() | |||||
federation_client = self.hs.get_federation_client() | |||||
federation_client.query_user_devices = Mock(side_effect=query_user_devices) | federation_client.query_user_devices = Mock(side_effect=query_user_devices) | ||||
# Register a mock on the store so that the incoming update doesn't fail because | # Register a mock on the store so that the incoming update doesn't fail because | ||||
# we don't share a room with the user. | # we don't share a room with the user. | ||||
store = self.homeserver.get_datastores().main | |||||
store = self.hs.get_datastores().main | |||||
store.get_rooms_for_user = Mock(return_value=make_awaitable(["!someroom:test"])) | store.get_rooms_for_user = Mock(return_value=make_awaitable(["!someroom:test"])) | ||||
# Manually inject a fake device list update. We need this update to include at | # Manually inject a fake device list update. We need this update to include at | ||||
# least one prev_id so that the user's device list will need to be retried. | # least one prev_id so that the user's device list will need to be retried. | ||||
device_list_updater = self.homeserver.get_device_handler().device_list_updater | |||||
device_list_updater = self.hs.get_device_handler().device_list_updater | |||||
self.get_success( | self.get_success( | ||||
device_list_updater.incoming_device_list_update( | device_list_updater.incoming_device_list_update( | ||||
origin=remote_origin, | origin=remote_origin, | ||||
@@ -218,7 +226,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase): | |||||
self.reactor.advance(30) | self.reactor.advance(30) | ||||
self.assertEqual(self.resync_attempts, 2) | self.assertEqual(self.resync_attempts, 2) | ||||
def test_cross_signing_keys_retry(self): | |||||
def test_cross_signing_keys_retry(self) -> None: | |||||
"""Tests that resyncing a device list correctly processes cross-signing keys from | """Tests that resyncing a device list correctly processes cross-signing keys from | ||||
the remote server. | the remote server. | ||||
""" | """ | ||||
@@ -227,7 +235,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase): | |||||
remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ" | remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ" | ||||
# Register mock device list retrieval on the federation client. | # Register mock device list retrieval on the federation client. | ||||
federation_client = self.homeserver.get_federation_client() | |||||
federation_client = self.hs.get_federation_client() | |||||
federation_client.query_user_devices = Mock( | federation_client.query_user_devices = Mock( | ||||
return_value=make_awaitable( | return_value=make_awaitable( | ||||
{ | { | ||||
@@ -252,7 +260,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase): | |||||
) | ) | ||||
# Resync the device list. | # Resync the device list. | ||||
device_handler = self.homeserver.get_device_handler() | |||||
device_handler = self.hs.get_device_handler() | |||||
self.get_success( | self.get_success( | ||||
device_handler.device_list_updater.user_device_resync(remote_user_id), | device_handler.device_list_updater.user_device_resync(remote_user_id), | ||||
) | ) | ||||
@@ -279,7 +287,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase): | |||||
class StripUnsignedFromEventsTestCase(unittest.TestCase): | class StripUnsignedFromEventsTestCase(unittest.TestCase): | ||||
def test_strip_unauthorized_unsigned_values(self): | |||||
def test_strip_unauthorized_unsigned_values(self) -> None: | |||||
event1 = { | event1 = { | ||||
"sender": "@baduser:test.serv", | "sender": "@baduser:test.serv", | ||||
"state_key": "@baduser:test.serv", | "state_key": "@baduser:test.serv", | ||||
@@ -296,7 +304,7 @@ class StripUnsignedFromEventsTestCase(unittest.TestCase): | |||||
# Make sure unauthorized fields are stripped from unsigned | # Make sure unauthorized fields are stripped from unsigned | ||||
self.assertNotIn("more warez", filtered_event.unsigned) | self.assertNotIn("more warez", filtered_event.unsigned) | ||||
def test_strip_event_maintains_allowed_fields(self): | |||||
def test_strip_event_maintains_allowed_fields(self) -> None: | |||||
event2 = { | event2 = { | ||||
"sender": "@baduser:test.serv", | "sender": "@baduser:test.serv", | ||||
"state_key": "@baduser:test.serv", | "state_key": "@baduser:test.serv", | ||||
@@ -323,7 +331,7 @@ class StripUnsignedFromEventsTestCase(unittest.TestCase): | |||||
self.assertIn("invite_room_state", filtered_event2.unsigned) | self.assertIn("invite_room_state", filtered_event2.unsigned) | ||||
self.assertEqual([], filtered_event2.unsigned["invite_room_state"]) | self.assertEqual([], filtered_event2.unsigned["invite_room_state"]) | ||||
def test_strip_event_removes_fields_based_on_event_type(self): | |||||
def test_strip_event_removes_fields_based_on_event_type(self) -> None: | |||||
event3 = { | event3 = { | ||||
"sender": "@baduser:test.serv", | "sender": "@baduser:test.serv", | ||||
"state_key": "@baduser:test.serv", | "state_key": "@baduser:test.serv", | ||||
@@ -20,12 +20,13 @@ import sys | |||||
import warnings | import warnings | ||||
from asyncio import Future | from asyncio import Future | ||||
from binascii import unhexlify | from binascii import unhexlify | ||||
from typing import Awaitable, Callable, Tuple, TypeVar | |||||
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Tuple, TypeVar | |||||
from unittest.mock import Mock | from unittest.mock import Mock | ||||
import attr | import attr | ||||
import zope.interface | import zope.interface | ||||
from twisted.internet.interfaces import IProtocol | |||||
from twisted.python.failure import Failure | from twisted.python.failure import Failure | ||||
from twisted.web.client import ResponseDone | from twisted.web.client import ResponseDone | ||||
from twisted.web.http import RESPONSES | from twisted.web.http import RESPONSES | ||||
@@ -34,6 +35,9 @@ from twisted.web.iweb import IResponse | |||||
from synapse.types import JsonDict | from synapse.types import JsonDict | ||||
if TYPE_CHECKING: | |||||
from sys import UnraisableHookArgs | |||||
TV = TypeVar("TV") | TV = TypeVar("TV") | ||||
@@ -78,25 +82,29 @@ def setup_awaitable_errors() -> Callable[[], None]: | |||||
unraisable_exceptions = [] | unraisable_exceptions = [] | ||||
orig_unraisablehook = sys.unraisablehook | orig_unraisablehook = sys.unraisablehook | ||||
def unraisablehook(unraisable): | |||||
def unraisablehook(unraisable: "UnraisableHookArgs") -> None: | |||||
unraisable_exceptions.append(unraisable.exc_value) | unraisable_exceptions.append(unraisable.exc_value) | ||||
def cleanup(): | |||||
def cleanup() -> None: | |||||
""" | """ | ||||
A method to be used as a clean-up that fails a test-case if there are any new unraisable exceptions. | A method to be used as a clean-up that fails a test-case if there are any new unraisable exceptions. | ||||
""" | """ | ||||
sys.unraisablehook = orig_unraisablehook | sys.unraisablehook = orig_unraisablehook | ||||
if unraisable_exceptions: | if unraisable_exceptions: | ||||
raise unraisable_exceptions.pop() | |||||
exc = unraisable_exceptions.pop() | |||||
assert exc is not None | |||||
raise exc | |||||
sys.unraisablehook = unraisablehook | sys.unraisablehook = unraisablehook | ||||
return cleanup | return cleanup | ||||
def simple_async_mock(return_value=None, raises=None) -> Mock: | |||||
def simple_async_mock( | |||||
return_value: Optional[TV] = None, raises: Optional[Exception] = None | |||||
) -> Mock: | |||||
# AsyncMock is not available in python3.5, this mimics part of its behaviour | # AsyncMock is not available in python3.5, this mimics part of its behaviour | ||||
async def cb(*args, **kwargs): | |||||
async def cb(*args: Any, **kwargs: Any) -> Optional[TV]: | |||||
if raises: | if raises: | ||||
raise raises | raise raises | ||||
return return_value | return return_value | ||||
@@ -125,14 +133,14 @@ class FakeResponse: # type: ignore[misc] | |||||
headers: Headers = attr.Factory(Headers) | headers: Headers = attr.Factory(Headers) | ||||
@property | @property | ||||
def phrase(self): | |||||
def phrase(self) -> bytes: | |||||
return RESPONSES.get(self.code, b"Unknown Status") | return RESPONSES.get(self.code, b"Unknown Status") | ||||
@property | @property | ||||
def length(self): | |||||
def length(self) -> int: | |||||
return len(self.body) | return len(self.body) | ||||
def deliverBody(self, protocol): | |||||
def deliverBody(self, protocol: IProtocol) -> None: | |||||
protocol.dataReceived(self.body) | protocol.dataReceived(self.body) | ||||
protocol.connectionLost(Failure(ResponseDone())) | protocol.connectionLost(Failure(ResponseDone())) | ||||
@@ -12,7 +12,7 @@ | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
# See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||
# limitations under the License. | # limitations under the License. | ||||
from typing import List, Optional, Tuple | |||||
from typing import Any, List, Optional, Tuple | |||||
import synapse.server | import synapse.server | ||||
from synapse.api.constants import EventTypes | from synapse.api.constants import EventTypes | ||||
@@ -32,7 +32,7 @@ async def inject_member_event( | |||||
membership: str, | membership: str, | ||||
target: Optional[str] = None, | target: Optional[str] = None, | ||||
extra_content: Optional[dict] = None, | extra_content: Optional[dict] = None, | ||||
**kwargs, | |||||
**kwargs: Any, | |||||
) -> EventBase: | ) -> EventBase: | ||||
"""Inject a membership event into a room.""" | """Inject a membership event into a room.""" | ||||
if target is None: | if target is None: | ||||
@@ -57,7 +57,7 @@ async def inject_event( | |||||
hs: synapse.server.HomeServer, | hs: synapse.server.HomeServer, | ||||
room_version: Optional[str] = None, | room_version: Optional[str] = None, | ||||
prev_event_ids: Optional[List[str]] = None, | prev_event_ids: Optional[List[str]] = None, | ||||
**kwargs, | |||||
**kwargs: Any, | |||||
) -> EventBase: | ) -> EventBase: | ||||
"""Inject a generic event into a room | """Inject a generic event into a room | ||||
@@ -82,7 +82,7 @@ async def create_event( | |||||
hs: synapse.server.HomeServer, | hs: synapse.server.HomeServer, | ||||
room_version: Optional[str] = None, | room_version: Optional[str] = None, | ||||
prev_event_ids: Optional[List[str]] = None, | prev_event_ids: Optional[List[str]] = None, | ||||
**kwargs, | |||||
**kwargs: Any, | |||||
) -> Tuple[EventBase, EventContext]: | ) -> Tuple[EventBase, EventContext]: | ||||
if room_version is None: | if room_version is None: | ||||
room_version = await hs.get_datastores().main.get_room_version_id( | room_version = await hs.get_datastores().main.get_room_version_id( | ||||
@@ -13,13 +13,13 @@ | |||||
# limitations under the License. | # limitations under the License. | ||||
from html.parser import HTMLParser | from html.parser import HTMLParser | ||||
from typing import Dict, Iterable, List, Optional, Tuple | |||||
from typing import Dict, Iterable, List, NoReturn, Optional, Tuple | |||||
class TestHtmlParser(HTMLParser): | class TestHtmlParser(HTMLParser): | ||||
"""A generic HTML page parser which extracts useful things from the HTML""" | """A generic HTML page parser which extracts useful things from the HTML""" | ||||
def __init__(self): | |||||
def __init__(self) -> None: | |||||
super().__init__() | super().__init__() | ||||
# a list of links found in the doc | # a list of links found in the doc | ||||
@@ -48,5 +48,5 @@ class TestHtmlParser(HTMLParser): | |||||
assert input_name | assert input_name | ||||
self.hiddens[input_name] = attr_dict["value"] | self.hiddens[input_name] = attr_dict["value"] | ||||
def error(_, message): | |||||
def error(self, message: str) -> NoReturn: | |||||
raise AssertionError(message) | raise AssertionError(message) |
@@ -25,7 +25,7 @@ class ToTwistedHandler(logging.Handler): | |||||
tx_log = twisted.logger.Logger() | tx_log = twisted.logger.Logger() | ||||
def emit(self, record): | |||||
def emit(self, record: logging.LogRecord) -> None: | |||||
log_entry = self.format(record) | log_entry = self.format(record) | ||||
log_level = record.levelname.lower().replace("warning", "warn") | log_level = record.levelname.lower().replace("warning", "warn") | ||||
self.tx_log.emit( | self.tx_log.emit( | ||||
@@ -33,7 +33,7 @@ class ToTwistedHandler(logging.Handler): | |||||
) | ) | ||||
def setup_logging(): | |||||
def setup_logging() -> None: | |||||
"""Configure the python logging appropriately for the tests. | """Configure the python logging appropriately for the tests. | ||||
(Logs will end up in _trial_temp.) | (Logs will end up in _trial_temp.) | ||||
@@ -14,7 +14,7 @@ | |||||
import json | import json | ||||
from typing import Any, Dict, List, Optional, Tuple | |||||
from typing import Any, ContextManager, Dict, List, Optional, Tuple | |||||
from unittest.mock import Mock, patch | from unittest.mock import Mock, patch | ||||
from urllib.parse import parse_qs | from urllib.parse import parse_qs | ||||
@@ -77,14 +77,14 @@ class FakeOidcServer: | |||||
self._id_token_overrides: Dict[str, Any] = {} | self._id_token_overrides: Dict[str, Any] = {} | ||||
def reset_mocks(self): | |||||
def reset_mocks(self) -> None: | |||||
self.request.reset_mock() | self.request.reset_mock() | ||||
self.get_jwks_handler.reset_mock() | self.get_jwks_handler.reset_mock() | ||||
self.get_metadata_handler.reset_mock() | self.get_metadata_handler.reset_mock() | ||||
self.get_userinfo_handler.reset_mock() | self.get_userinfo_handler.reset_mock() | ||||
self.post_token_handler.reset_mock() | self.post_token_handler.reset_mock() | ||||
def patch_homeserver(self, hs: HomeServer): | |||||
def patch_homeserver(self, hs: HomeServer) -> ContextManager[Mock]: | |||||
"""Patch the ``HomeServer`` HTTP client to handle requests through the ``FakeOidcServer``. | """Patch the ``HomeServer`` HTTP client to handle requests through the ``FakeOidcServer``. | ||||
This patch should be used whenever the HS is expected to perform request to the | This patch should be used whenever the HS is expected to perform request to the | ||||
@@ -188,7 +188,7 @@ class FakeOidcServer: | |||||
return self._sign(logout_token) | return self._sign(logout_token) | ||||
def id_token_override(self, overrides: dict): | |||||
def id_token_override(self, overrides: dict) -> ContextManager[dict]: | |||||
"""Temporarily patch the ID token generated by the token endpoint.""" | """Temporarily patch the ID token generated by the token endpoint.""" | ||||
return patch.object(self, "_id_token_overrides", overrides) | return patch.object(self, "_id_token_overrides", overrides) | ||||
@@ -247,7 +247,7 @@ class FakeOidcServer: | |||||
metadata: bool = False, | metadata: bool = False, | ||||
token: bool = False, | token: bool = False, | ||||
userinfo: bool = False, | userinfo: bool = False, | ||||
): | |||||
) -> ContextManager[Dict[str, Mock]]: | |||||
"""A context which makes a set of endpoints return a 500 error. | """A context which makes a set of endpoints return a 500 error. | ||||
Args: | Args: | ||||
@@ -258,7 +258,7 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): | |||||
class FilterEventsForClientTestCase(unittest.FederatingHomeserverTestCase): | class FilterEventsForClientTestCase(unittest.FederatingHomeserverTestCase): | ||||
def test_out_of_band_invite_rejection(self): | |||||
def test_out_of_band_invite_rejection(self) -> None: | |||||
# this is where we have received an invite event over federation, and then | # this is where we have received an invite event over federation, and then | ||||
# rejected it. | # rejected it. | ||||
invite_pdu = { | invite_pdu = { | ||||
@@ -315,7 +315,7 @@ class HomeserverTestCase(TestCase): | |||||
# This has to be a function and not just a Mock, because | # This has to be a function and not just a Mock, because | ||||
# `self.helper.auth_user_id` is temporarily reassigned in some tests | # `self.helper.auth_user_id` is temporarily reassigned in some tests | ||||
async def get_requester(*args, **kwargs) -> Requester: | |||||
async def get_requester(*args: Any, **kwargs: Any) -> Requester: | |||||
assert self.helper.auth_user_id is not None | assert self.helper.auth_user_id is not None | ||||
return create_requester( | return create_requester( | ||||
user_id=UserID.from_string(self.helper.auth_user_id), | user_id=UserID.from_string(self.helper.auth_user_id), | ||||