@@ -0,0 +1 @@ | |||
Add spam checker module API for logins. |
@@ -348,6 +348,42 @@ callback returns `False`, Synapse falls through to the next one. The value of th | |||
callback that does not return `False` will be used. If this happens, Synapse will not call | |||
any of the subsequent implementations of this callback. | |||
### `check_login_for_spam` | |||
_First introduced in Synapse v1.87.0_ | |||
```python | |||
async def check_login_for_spam( | |||
user_id: str, | |||
device_id: Optional[str], | |||
initial_display_name: Optional[str], | |||
request_info: Collection[Tuple[Optional[str], str]], | |||
auth_provider_id: Optional[str] = None, | |||
) -> Union["synapse.module_api.NOT_SPAM", "synapse.module_api.errors.Codes"] | |||
``` | |||
Called when a user logs in. | |||
The arguments passed to this callback are: | |||
* `user_id`: The user ID the user is logging in with | |||
* `device_id`: The device ID the user is re-logging into. | |||
* `initial_display_name`: The device display name, if any. | |||
* `request_info`: A collection of tuples, which first item is a user agent, and which | |||
second item is an IP address. These user agents and IP addresses are the ones that were | |||
used during the login process. | |||
* `auth_provider_id`: The identifier of the SSO authentication provider, if any. | |||
If multiple modules implement this callback, they will be considered in order. If a | |||
callback returns `synapse.module_api.NOT_SPAM`, Synapse falls through to the next one. | |||
The value of the first callback that does not return `synapse.module_api.NOT_SPAM` will | |||
be used. If this happens, Synapse will not call any of the subsequent implementations of | |||
this callback. | |||
*Note:* This will not be called when a user registers. | |||
## Example | |||
The example below is a module that implements the spam checker callback | |||
@@ -521,6 +521,11 @@ class SynapseRequest(Request): | |||
else: | |||
return self.getClientAddress().host | |||
def request_info(self) -> "RequestInfo": | |||
h = self.getHeader(b"User-Agent") | |||
user_agent = h.decode("ascii", "replace") if h else None | |||
return RequestInfo(user_agent=user_agent, ip=self.get_client_ip_if_available()) | |||
class XForwardedForRequest(SynapseRequest): | |||
"""Request object which honours proxy headers | |||
@@ -661,3 +666,9 @@ class SynapseSite(Site): | |||
def log(self, request: SynapseRequest) -> None: | |||
pass | |||
@attr.s(auto_attribs=True, frozen=True, slots=True) | |||
class RequestInfo: | |||
user_agent: Optional[str] | |||
ip: str |
@@ -80,6 +80,7 @@ from synapse.module_api.callbacks.account_validity_callbacks import ( | |||
) | |||
from synapse.module_api.callbacks.spamchecker_callbacks import ( | |||
CHECK_EVENT_FOR_SPAM_CALLBACK, | |||
CHECK_LOGIN_FOR_SPAM_CALLBACK, | |||
CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK, | |||
CHECK_REGISTRATION_FOR_SPAM_CALLBACK, | |||
CHECK_USERNAME_FOR_SPAM_CALLBACK, | |||
@@ -302,6 +303,7 @@ class ModuleApi: | |||
CHECK_REGISTRATION_FOR_SPAM_CALLBACK | |||
] = None, | |||
check_media_file_for_spam: Optional[CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK] = None, | |||
check_login_for_spam: Optional[CHECK_LOGIN_FOR_SPAM_CALLBACK] = None, | |||
) -> None: | |||
"""Registers callbacks for spam checking capabilities. | |||
@@ -319,6 +321,7 @@ class ModuleApi: | |||
check_username_for_spam=check_username_for_spam, | |||
check_registration_for_spam=check_registration_for_spam, | |||
check_media_file_for_spam=check_media_file_for_spam, | |||
check_login_for_spam=check_login_for_spam, | |||
) | |||
def register_account_validity_callbacks( | |||
@@ -196,6 +196,26 @@ CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK = Callable[ | |||
] | |||
], | |||
] | |||
CHECK_LOGIN_FOR_SPAM_CALLBACK = Callable[ | |||
[ | |||
str, | |||
Optional[str], | |||
Optional[str], | |||
Collection[Tuple[Optional[str], str]], | |||
Optional[str], | |||
], | |||
Awaitable[ | |||
Union[ | |||
Literal["NOT_SPAM"], | |||
Codes, | |||
# Highly experimental, not officially part of the spamchecker API, may | |||
# disappear without warning depending on the results of ongoing | |||
# experiments. | |||
# Use this to return additional information as part of an error. | |||
Tuple[Codes, JsonDict], | |||
] | |||
], | |||
] | |||
def load_legacy_spam_checkers(hs: "synapse.server.HomeServer") -> None: | |||
@@ -315,6 +335,7 @@ class SpamCheckerModuleApiCallbacks: | |||
self._check_media_file_for_spam_callbacks: List[ | |||
CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK | |||
] = [] | |||
self._check_login_for_spam_callbacks: List[CHECK_LOGIN_FOR_SPAM_CALLBACK] = [] | |||
def register_callbacks( | |||
self, | |||
@@ -335,6 +356,7 @@ class SpamCheckerModuleApiCallbacks: | |||
CHECK_REGISTRATION_FOR_SPAM_CALLBACK | |||
] = None, | |||
check_media_file_for_spam: Optional[CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK] = None, | |||
check_login_for_spam: Optional[CHECK_LOGIN_FOR_SPAM_CALLBACK] = None, | |||
) -> None: | |||
"""Register callbacks from module for each hook.""" | |||
if check_event_for_spam is not None: | |||
@@ -378,6 +400,9 @@ class SpamCheckerModuleApiCallbacks: | |||
if check_media_file_for_spam is not None: | |||
self._check_media_file_for_spam_callbacks.append(check_media_file_for_spam) | |||
if check_login_for_spam is not None: | |||
self._check_login_for_spam_callbacks.append(check_login_for_spam) | |||
@trace | |||
async def check_event_for_spam( | |||
self, event: "synapse.events.EventBase" | |||
@@ -819,3 +844,58 @@ class SpamCheckerModuleApiCallbacks: | |||
return synapse.api.errors.Codes.FORBIDDEN, {} | |||
return self.NOT_SPAM | |||
async def check_login_for_spam( | |||
self, | |||
user_id: str, | |||
device_id: Optional[str], | |||
initial_display_name: Optional[str], | |||
request_info: Collection[Tuple[Optional[str], str]], | |||
auth_provider_id: Optional[str] = None, | |||
) -> Union[Tuple[Codes, dict], Literal["NOT_SPAM"]]: | |||
"""Checks if we should allow the given registration request. | |||
Args: | |||
user_id: The request user ID | |||
request_info: List of tuples of user agent and IP that | |||
were used during the registration process. | |||
auth_provider_id: The SSO IdP the user used, e.g "oidc", "saml", | |||
"cas". If any. Note this does not include users registered | |||
via a password provider. | |||
Returns: | |||
Enum for how the request should be handled | |||
""" | |||
for callback in self._check_login_for_spam_callbacks: | |||
with Measure( | |||
self.clock, "{}.{}".format(callback.__module__, callback.__qualname__) | |||
): | |||
res = await delay_cancellation( | |||
callback( | |||
user_id, | |||
device_id, | |||
initial_display_name, | |||
request_info, | |||
auth_provider_id, | |||
) | |||
) | |||
# Normalize return values to `Codes` or `"NOT_SPAM"`. | |||
if res is self.NOT_SPAM: | |||
continue | |||
elif isinstance(res, synapse.api.errors.Codes): | |||
return res, {} | |||
elif ( | |||
isinstance(res, tuple) | |||
and len(res) == 2 | |||
and isinstance(res[0], synapse.api.errors.Codes) | |||
and isinstance(res[1], dict) | |||
): | |||
return res | |||
else: | |||
logger.warning( | |||
"Module returned invalid value, rejecting login as spam" | |||
) | |||
return synapse.api.errors.Codes.FORBIDDEN, {} | |||
return self.NOT_SPAM |
@@ -50,7 +50,7 @@ from synapse.http.servlet import ( | |||
parse_json_object_from_request, | |||
parse_string, | |||
) | |||
from synapse.http.site import SynapseRequest | |||
from synapse.http.site import RequestInfo, SynapseRequest | |||
from synapse.rest.client._base import client_patterns | |||
from synapse.rest.well_known import WellKnownBuilder | |||
from synapse.types import JsonDict, UserID | |||
@@ -114,6 +114,7 @@ class LoginRestServlet(RestServlet): | |||
self.auth_handler = self.hs.get_auth_handler() | |||
self.registration_handler = hs.get_registration_handler() | |||
self._sso_handler = hs.get_sso_handler() | |||
self._spam_checker = hs.get_module_api_callbacks().spam_checker | |||
self._well_known_builder = WellKnownBuilder(hs) | |||
self._address_ratelimiter = Ratelimiter( | |||
@@ -197,6 +198,8 @@ class LoginRestServlet(RestServlet): | |||
self._refresh_tokens_enabled and client_requested_refresh_token | |||
) | |||
request_info = request.request_info() | |||
try: | |||
if login_submission["type"] == LoginRestServlet.APPSERVICE_TYPE: | |||
requester = await self.auth.get_user_by_req(request) | |||
@@ -216,6 +219,7 @@ class LoginRestServlet(RestServlet): | |||
login_submission, | |||
appservice, | |||
should_issue_refresh_token=should_issue_refresh_token, | |||
request_info=request_info, | |||
) | |||
elif ( | |||
self.jwt_enabled | |||
@@ -227,6 +231,7 @@ class LoginRestServlet(RestServlet): | |||
result = await self._do_jwt_login( | |||
login_submission, | |||
should_issue_refresh_token=should_issue_refresh_token, | |||
request_info=request_info, | |||
) | |||
elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE: | |||
await self._address_ratelimiter.ratelimit( | |||
@@ -235,6 +240,7 @@ class LoginRestServlet(RestServlet): | |||
result = await self._do_token_login( | |||
login_submission, | |||
should_issue_refresh_token=should_issue_refresh_token, | |||
request_info=request_info, | |||
) | |||
else: | |||
await self._address_ratelimiter.ratelimit( | |||
@@ -243,6 +249,7 @@ class LoginRestServlet(RestServlet): | |||
result = await self._do_other_login( | |||
login_submission, | |||
should_issue_refresh_token=should_issue_refresh_token, | |||
request_info=request_info, | |||
) | |||
except KeyError: | |||
raise SynapseError(400, "Missing JSON keys.") | |||
@@ -265,6 +272,8 @@ class LoginRestServlet(RestServlet): | |||
login_submission: JsonDict, | |||
appservice: ApplicationService, | |||
should_issue_refresh_token: bool = False, | |||
*, | |||
request_info: RequestInfo, | |||
) -> LoginResponse: | |||
identifier = login_submission.get("identifier") | |||
logger.info("Got appservice login request with identifier: %r", identifier) | |||
@@ -300,10 +309,15 @@ class LoginRestServlet(RestServlet): | |||
# The user represented by an appservice's configured sender_localpart | |||
# is not actually created in Synapse. | |||
should_check_deactivated=qualified_user_id != appservice.sender, | |||
request_info=request_info, | |||
) | |||
async def _do_other_login( | |||
self, login_submission: JsonDict, should_issue_refresh_token: bool = False | |||
self, | |||
login_submission: JsonDict, | |||
should_issue_refresh_token: bool = False, | |||
*, | |||
request_info: RequestInfo, | |||
) -> LoginResponse: | |||
"""Handle non-token/saml/jwt logins | |||
@@ -333,6 +347,7 @@ class LoginRestServlet(RestServlet): | |||
login_submission, | |||
callback, | |||
should_issue_refresh_token=should_issue_refresh_token, | |||
request_info=request_info, | |||
) | |||
return result | |||
@@ -347,6 +362,8 @@ class LoginRestServlet(RestServlet): | |||
should_issue_refresh_token: bool = False, | |||
auth_provider_session_id: Optional[str] = None, | |||
should_check_deactivated: bool = True, | |||
*, | |||
request_info: RequestInfo, | |||
) -> LoginResponse: | |||
"""Called when we've successfully authed the user and now need to | |||
actually login them in (e.g. create devices). This gets called on | |||
@@ -371,6 +388,7 @@ class LoginRestServlet(RestServlet): | |||
This exists purely for appservice's configured sender_localpart | |||
which doesn't have an associated user in the database. | |||
request_info: The user agent/IP address of the user. | |||
Returns: | |||
Dictionary of account information after successful login. | |||
@@ -417,6 +435,22 @@ class LoginRestServlet(RestServlet): | |||
) | |||
initial_display_name = login_submission.get("initial_device_display_name") | |||
spam_check = await self._spam_checker.check_login_for_spam( | |||
user_id, | |||
device_id=device_id, | |||
initial_display_name=initial_display_name, | |||
request_info=[(request_info.user_agent, request_info.ip)], | |||
auth_provider_id=auth_provider_id, | |||
) | |||
if spam_check != self._spam_checker.NOT_SPAM: | |||
logger.info("Blocking login due to spam checker") | |||
raise SynapseError( | |||
403, | |||
msg="Login was blocked by the server", | |||
errcode=spam_check[0], | |||
additional_fields=spam_check[1], | |||
) | |||
( | |||
device_id, | |||
access_token, | |||
@@ -451,7 +485,11 @@ class LoginRestServlet(RestServlet): | |||
return result | |||
async def _do_token_login( | |||
self, login_submission: JsonDict, should_issue_refresh_token: bool = False | |||
self, | |||
login_submission: JsonDict, | |||
should_issue_refresh_token: bool = False, | |||
*, | |||
request_info: RequestInfo, | |||
) -> LoginResponse: | |||
""" | |||
Handle token login. | |||
@@ -474,10 +512,15 @@ class LoginRestServlet(RestServlet): | |||
auth_provider_id=res.auth_provider_id, | |||
should_issue_refresh_token=should_issue_refresh_token, | |||
auth_provider_session_id=res.auth_provider_session_id, | |||
request_info=request_info, | |||
) | |||
async def _do_jwt_login( | |||
self, login_submission: JsonDict, should_issue_refresh_token: bool = False | |||
self, | |||
login_submission: JsonDict, | |||
should_issue_refresh_token: bool = False, | |||
*, | |||
request_info: RequestInfo, | |||
) -> LoginResponse: | |||
""" | |||
Handle the custom JWT login. | |||
@@ -496,6 +539,7 @@ class LoginRestServlet(RestServlet): | |||
login_submission, | |||
create_non_existent_users=True, | |||
should_issue_refresh_token=should_issue_refresh_token, | |||
request_info=request_info, | |||
) | |||
@@ -13,11 +13,12 @@ | |||
# limitations under the License. | |||
import time | |||
import urllib.parse | |||
from typing import Any, Dict, List, Optional | |||
from typing import Any, Collection, Dict, List, Optional, Tuple, Union | |||
from unittest.mock import Mock | |||
from urllib.parse import urlencode | |||
import pymacaroons | |||
from typing_extensions import Literal | |||
from twisted.test.proto_helpers import MemoryReactor | |||
from twisted.web.resource import Resource | |||
@@ -26,11 +27,12 @@ import synapse.rest.admin | |||
from synapse.api.constants import ApprovalNoticeMedium, LoginType | |||
from synapse.api.errors import Codes | |||
from synapse.appservice import ApplicationService | |||
from synapse.module_api import ModuleApi | |||
from synapse.rest.client import devices, login, logout, register | |||
from synapse.rest.client.account import WhoamiRestServlet | |||
from synapse.rest.synapse.client import build_synapse_client_resource_tree | |||
from synapse.server import HomeServer | |||
from synapse.types import create_requester | |||
from synapse.types import JsonDict, create_requester | |||
from synapse.util import Clock | |||
from tests import unittest | |||
@@ -88,6 +90,56 @@ ADDITIONAL_LOGIN_FLOWS = [ | |||
] | |||
class TestSpamChecker: | |||
def __init__(self, config: None, api: ModuleApi): | |||
api.register_spam_checker_callbacks( | |||
check_login_for_spam=self.check_login_for_spam, | |||
) | |||
@staticmethod | |||
def parse_config(config: JsonDict) -> None: | |||
return None | |||
async def check_login_for_spam( | |||
self, | |||
user_id: str, | |||
device_id: Optional[str], | |||
initial_display_name: Optional[str], | |||
request_info: Collection[Tuple[Optional[str], str]], | |||
auth_provider_id: Optional[str] = None, | |||
) -> Union[ | |||
Literal["NOT_SPAM"], | |||
Tuple["synapse.module_api.errors.Codes", JsonDict], | |||
]: | |||
return "NOT_SPAM" | |||
class DenyAllSpamChecker: | |||
def __init__(self, config: None, api: ModuleApi): | |||
api.register_spam_checker_callbacks( | |||
check_login_for_spam=self.check_login_for_spam, | |||
) | |||
@staticmethod | |||
def parse_config(config: JsonDict) -> None: | |||
return None | |||
async def check_login_for_spam( | |||
self, | |||
user_id: str, | |||
device_id: Optional[str], | |||
initial_display_name: Optional[str], | |||
request_info: Collection[Tuple[Optional[str], str]], | |||
auth_provider_id: Optional[str] = None, | |||
) -> Union[ | |||
Literal["NOT_SPAM"], | |||
Tuple["synapse.module_api.errors.Codes", JsonDict], | |||
]: | |||
# Return an odd set of values to ensure that they get correctly passed | |||
# to the client. | |||
return Codes.LIMIT_EXCEEDED, {"extra": "value"} | |||
class LoginRestServletTestCase(unittest.HomeserverTestCase): | |||
servlets = [ | |||
synapse.rest.admin.register_servlets_for_client_rest_resource, | |||
@@ -469,6 +521,58 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): | |||
], | |||
) | |||
@override_config( | |||
{ | |||
"modules": [ | |||
{ | |||
"module": TestSpamChecker.__module__ | |||
+ "." | |||
+ TestSpamChecker.__qualname__ | |||
} | |||
] | |||
} | |||
) | |||
def test_spam_checker_allow(self) -> None: | |||
"""Check that that adding a spam checker doesn't break login.""" | |||
self.register_user("kermit", "monkey") | |||
body = {"type": "m.login.password", "user": "kermit", "password": "monkey"} | |||
channel = self.make_request( | |||
"POST", | |||
"/_matrix/client/r0/login", | |||
body, | |||
) | |||
self.assertEqual(channel.code, 200, channel.result) | |||
@override_config( | |||
{ | |||
"modules": [ | |||
{ | |||
"module": DenyAllSpamChecker.__module__ | |||
+ "." | |||
+ DenyAllSpamChecker.__qualname__ | |||
} | |||
] | |||
} | |||
) | |||
def test_spam_checker_deny(self) -> None: | |||
"""Check that login""" | |||
self.register_user("kermit", "monkey") | |||
body = {"type": "m.login.password", "user": "kermit", "password": "monkey"} | |||
channel = self.make_request( | |||
"POST", | |||
"/_matrix/client/r0/login", | |||
body, | |||
) | |||
self.assertEqual(channel.code, 403, channel.result) | |||
self.assertDictContainsSubset( | |||
{"errcode": Codes.LIMIT_EXCEEDED, "extra": "value"}, channel.json_body | |||
) | |||
@skip_unless(has_saml2 and HAS_OIDC, "Requires SAML2 and OIDC") | |||
class MultiSSOTestCase(unittest.HomeserverTestCase): | |||