Parcourir la source

Add login spam checker API (#15838)

tags/v1.87.0rc1
Erik Johnston il y a 10 mois
committed by GitHub
Parent
révision
25c55a9d22
Aucune clé connue n'a été trouvée dans la base pour cette signature ID de la clé GPG: 4AEE18F83AFDEB23
7 fichiers modifiés avec 285 ajouts et 6 suppressions
  1. +1
    -0
      changelog.d/15838.feature
  2. +36
    -0
      docs/modules/spam_checker_callbacks.md
  3. +11
    -0
      synapse/http/site.py
  4. +3
    -0
      synapse/module_api/__init__.py
  5. +80
    -0
      synapse/module_api/callbacks/spamchecker_callbacks.py
  6. +48
    -4
      synapse/rest/client/login.py
  7. +106
    -2
      tests/rest/client/test_login.py

+ 1
- 0
changelog.d/15838.feature Voir le fichier

@@ -0,0 +1 @@
Add spam checker module API for logins.

+ 36
- 0
docs/modules/spam_checker_callbacks.md Voir le fichier

@@ -348,6 +348,42 @@ callback returns `False`, Synapse falls through to the next one. The value of th
callback that does not return `False` will be used. If this happens, Synapse will not call
any of the subsequent implementations of this callback.


### `check_login_for_spam`

_First introduced in Synapse v1.87.0_

```python
async def check_login_for_spam(
user_id: str,
device_id: Optional[str],
initial_display_name: Optional[str],
request_info: Collection[Tuple[Optional[str], str]],
auth_provider_id: Optional[str] = None,
) -> Union["synapse.module_api.NOT_SPAM", "synapse.module_api.errors.Codes"]
```

Called when a user logs in.

The arguments passed to this callback are:

* `user_id`: The user ID the user is logging in with
* `device_id`: The device ID the user is re-logging into.
* `initial_display_name`: The device display name, if any.
* `request_info`: A collection of tuples, which first item is a user agent, and which
second item is an IP address. These user agents and IP addresses are the ones that were
used during the login process.
* `auth_provider_id`: The identifier of the SSO authentication provider, if any.

If multiple modules implement this callback, they will be considered in order. If a
callback returns `synapse.module_api.NOT_SPAM`, Synapse falls through to the next one.
The value of the first callback that does not return `synapse.module_api.NOT_SPAM` will
be used. If this happens, Synapse will not call any of the subsequent implementations of
this callback.

*Note:* This will not be called when a user registers.


## Example

The example below is a module that implements the spam checker callback


+ 11
- 0
synapse/http/site.py Voir le fichier

@@ -521,6 +521,11 @@ class SynapseRequest(Request):
else:
return self.getClientAddress().host

def request_info(self) -> "RequestInfo":
h = self.getHeader(b"User-Agent")
user_agent = h.decode("ascii", "replace") if h else None
return RequestInfo(user_agent=user_agent, ip=self.get_client_ip_if_available())


class XForwardedForRequest(SynapseRequest):
"""Request object which honours proxy headers
@@ -661,3 +666,9 @@ class SynapseSite(Site):

def log(self, request: SynapseRequest) -> None:
pass


@attr.s(auto_attribs=True, frozen=True, slots=True)
class RequestInfo:
user_agent: Optional[str]
ip: str

+ 3
- 0
synapse/module_api/__init__.py Voir le fichier

@@ -80,6 +80,7 @@ from synapse.module_api.callbacks.account_validity_callbacks import (
)
from synapse.module_api.callbacks.spamchecker_callbacks import (
CHECK_EVENT_FOR_SPAM_CALLBACK,
CHECK_LOGIN_FOR_SPAM_CALLBACK,
CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK,
CHECK_REGISTRATION_FOR_SPAM_CALLBACK,
CHECK_USERNAME_FOR_SPAM_CALLBACK,
@@ -302,6 +303,7 @@ class ModuleApi:
CHECK_REGISTRATION_FOR_SPAM_CALLBACK
] = None,
check_media_file_for_spam: Optional[CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK] = None,
check_login_for_spam: Optional[CHECK_LOGIN_FOR_SPAM_CALLBACK] = None,
) -> None:
"""Registers callbacks for spam checking capabilities.

@@ -319,6 +321,7 @@ class ModuleApi:
check_username_for_spam=check_username_for_spam,
check_registration_for_spam=check_registration_for_spam,
check_media_file_for_spam=check_media_file_for_spam,
check_login_for_spam=check_login_for_spam,
)

def register_account_validity_callbacks(


+ 80
- 0
synapse/module_api/callbacks/spamchecker_callbacks.py Voir le fichier

@@ -196,6 +196,26 @@ CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK = Callable[
]
],
]
CHECK_LOGIN_FOR_SPAM_CALLBACK = Callable[
[
str,
Optional[str],
Optional[str],
Collection[Tuple[Optional[str], str]],
Optional[str],
],
Awaitable[
Union[
Literal["NOT_SPAM"],
Codes,
# Highly experimental, not officially part of the spamchecker API, may
# disappear without warning depending on the results of ongoing
# experiments.
# Use this to return additional information as part of an error.
Tuple[Codes, JsonDict],
]
],
]


def load_legacy_spam_checkers(hs: "synapse.server.HomeServer") -> None:
@@ -315,6 +335,7 @@ class SpamCheckerModuleApiCallbacks:
self._check_media_file_for_spam_callbacks: List[
CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK
] = []
self._check_login_for_spam_callbacks: List[CHECK_LOGIN_FOR_SPAM_CALLBACK] = []

def register_callbacks(
self,
@@ -335,6 +356,7 @@ class SpamCheckerModuleApiCallbacks:
CHECK_REGISTRATION_FOR_SPAM_CALLBACK
] = None,
check_media_file_for_spam: Optional[CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK] = None,
check_login_for_spam: Optional[CHECK_LOGIN_FOR_SPAM_CALLBACK] = None,
) -> None:
"""Register callbacks from module for each hook."""
if check_event_for_spam is not None:
@@ -378,6 +400,9 @@ class SpamCheckerModuleApiCallbacks:
if check_media_file_for_spam is not None:
self._check_media_file_for_spam_callbacks.append(check_media_file_for_spam)

if check_login_for_spam is not None:
self._check_login_for_spam_callbacks.append(check_login_for_spam)

@trace
async def check_event_for_spam(
self, event: "synapse.events.EventBase"
@@ -819,3 +844,58 @@ class SpamCheckerModuleApiCallbacks:
return synapse.api.errors.Codes.FORBIDDEN, {}

return self.NOT_SPAM

async def check_login_for_spam(
self,
user_id: str,
device_id: Optional[str],
initial_display_name: Optional[str],
request_info: Collection[Tuple[Optional[str], str]],
auth_provider_id: Optional[str] = None,
) -> Union[Tuple[Codes, dict], Literal["NOT_SPAM"]]:
"""Checks if we should allow the given registration request.

Args:
user_id: The request user ID
request_info: List of tuples of user agent and IP that
were used during the registration process.
auth_provider_id: The SSO IdP the user used, e.g "oidc", "saml",
"cas". If any. Note this does not include users registered
via a password provider.

Returns:
Enum for how the request should be handled
"""

for callback in self._check_login_for_spam_callbacks:
with Measure(
self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
):
res = await delay_cancellation(
callback(
user_id,
device_id,
initial_display_name,
request_info,
auth_provider_id,
)
)
# Normalize return values to `Codes` or `"NOT_SPAM"`.
if res is self.NOT_SPAM:
continue
elif isinstance(res, synapse.api.errors.Codes):
return res, {}
elif (
isinstance(res, tuple)
and len(res) == 2
and isinstance(res[0], synapse.api.errors.Codes)
and isinstance(res[1], dict)
):
return res
else:
logger.warning(
"Module returned invalid value, rejecting login as spam"
)
return synapse.api.errors.Codes.FORBIDDEN, {}

return self.NOT_SPAM

+ 48
- 4
synapse/rest/client/login.py Voir le fichier

@@ -50,7 +50,7 @@ from synapse.http.servlet import (
parse_json_object_from_request,
parse_string,
)
from synapse.http.site import SynapseRequest
from synapse.http.site import RequestInfo, SynapseRequest
from synapse.rest.client._base import client_patterns
from synapse.rest.well_known import WellKnownBuilder
from synapse.types import JsonDict, UserID
@@ -114,6 +114,7 @@ class LoginRestServlet(RestServlet):
self.auth_handler = self.hs.get_auth_handler()
self.registration_handler = hs.get_registration_handler()
self._sso_handler = hs.get_sso_handler()
self._spam_checker = hs.get_module_api_callbacks().spam_checker

self._well_known_builder = WellKnownBuilder(hs)
self._address_ratelimiter = Ratelimiter(
@@ -197,6 +198,8 @@ class LoginRestServlet(RestServlet):
self._refresh_tokens_enabled and client_requested_refresh_token
)

request_info = request.request_info()

try:
if login_submission["type"] == LoginRestServlet.APPSERVICE_TYPE:
requester = await self.auth.get_user_by_req(request)
@@ -216,6 +219,7 @@ class LoginRestServlet(RestServlet):
login_submission,
appservice,
should_issue_refresh_token=should_issue_refresh_token,
request_info=request_info,
)
elif (
self.jwt_enabled
@@ -227,6 +231,7 @@ class LoginRestServlet(RestServlet):
result = await self._do_jwt_login(
login_submission,
should_issue_refresh_token=should_issue_refresh_token,
request_info=request_info,
)
elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
await self._address_ratelimiter.ratelimit(
@@ -235,6 +240,7 @@ class LoginRestServlet(RestServlet):
result = await self._do_token_login(
login_submission,
should_issue_refresh_token=should_issue_refresh_token,
request_info=request_info,
)
else:
await self._address_ratelimiter.ratelimit(
@@ -243,6 +249,7 @@ class LoginRestServlet(RestServlet):
result = await self._do_other_login(
login_submission,
should_issue_refresh_token=should_issue_refresh_token,
request_info=request_info,
)
except KeyError:
raise SynapseError(400, "Missing JSON keys.")
@@ -265,6 +272,8 @@ class LoginRestServlet(RestServlet):
login_submission: JsonDict,
appservice: ApplicationService,
should_issue_refresh_token: bool = False,
*,
request_info: RequestInfo,
) -> LoginResponse:
identifier = login_submission.get("identifier")
logger.info("Got appservice login request with identifier: %r", identifier)
@@ -300,10 +309,15 @@ class LoginRestServlet(RestServlet):
# The user represented by an appservice's configured sender_localpart
# is not actually created in Synapse.
should_check_deactivated=qualified_user_id != appservice.sender,
request_info=request_info,
)

async def _do_other_login(
self, login_submission: JsonDict, should_issue_refresh_token: bool = False
self,
login_submission: JsonDict,
should_issue_refresh_token: bool = False,
*,
request_info: RequestInfo,
) -> LoginResponse:
"""Handle non-token/saml/jwt logins

@@ -333,6 +347,7 @@ class LoginRestServlet(RestServlet):
login_submission,
callback,
should_issue_refresh_token=should_issue_refresh_token,
request_info=request_info,
)
return result

@@ -347,6 +362,8 @@ class LoginRestServlet(RestServlet):
should_issue_refresh_token: bool = False,
auth_provider_session_id: Optional[str] = None,
should_check_deactivated: bool = True,
*,
request_info: RequestInfo,
) -> LoginResponse:
"""Called when we've successfully authed the user and now need to
actually login them in (e.g. create devices). This gets called on
@@ -371,6 +388,7 @@ class LoginRestServlet(RestServlet):

This exists purely for appservice's configured sender_localpart
which doesn't have an associated user in the database.
request_info: The user agent/IP address of the user.

Returns:
Dictionary of account information after successful login.
@@ -417,6 +435,22 @@ class LoginRestServlet(RestServlet):
)

initial_display_name = login_submission.get("initial_device_display_name")
spam_check = await self._spam_checker.check_login_for_spam(
user_id,
device_id=device_id,
initial_display_name=initial_display_name,
request_info=[(request_info.user_agent, request_info.ip)],
auth_provider_id=auth_provider_id,
)
if spam_check != self._spam_checker.NOT_SPAM:
logger.info("Blocking login due to spam checker")
raise SynapseError(
403,
msg="Login was blocked by the server",
errcode=spam_check[0],
additional_fields=spam_check[1],
)

(
device_id,
access_token,
@@ -451,7 +485,11 @@ class LoginRestServlet(RestServlet):
return result

async def _do_token_login(
self, login_submission: JsonDict, should_issue_refresh_token: bool = False
self,
login_submission: JsonDict,
should_issue_refresh_token: bool = False,
*,
request_info: RequestInfo,
) -> LoginResponse:
"""
Handle token login.
@@ -474,10 +512,15 @@ class LoginRestServlet(RestServlet):
auth_provider_id=res.auth_provider_id,
should_issue_refresh_token=should_issue_refresh_token,
auth_provider_session_id=res.auth_provider_session_id,
request_info=request_info,
)

async def _do_jwt_login(
self, login_submission: JsonDict, should_issue_refresh_token: bool = False
self,
login_submission: JsonDict,
should_issue_refresh_token: bool = False,
*,
request_info: RequestInfo,
) -> LoginResponse:
"""
Handle the custom JWT login.
@@ -496,6 +539,7 @@ class LoginRestServlet(RestServlet):
login_submission,
create_non_existent_users=True,
should_issue_refresh_token=should_issue_refresh_token,
request_info=request_info,
)




+ 106
- 2
tests/rest/client/test_login.py Voir le fichier

@@ -13,11 +13,12 @@
# limitations under the License.
import time
import urllib.parse
from typing import Any, Dict, List, Optional
from typing import Any, Collection, Dict, List, Optional, Tuple, Union
from unittest.mock import Mock
from urllib.parse import urlencode

import pymacaroons
from typing_extensions import Literal

from twisted.test.proto_helpers import MemoryReactor
from twisted.web.resource import Resource
@@ -26,11 +27,12 @@ import synapse.rest.admin
from synapse.api.constants import ApprovalNoticeMedium, LoginType
from synapse.api.errors import Codes
from synapse.appservice import ApplicationService
from synapse.module_api import ModuleApi
from synapse.rest.client import devices, login, logout, register
from synapse.rest.client.account import WhoamiRestServlet
from synapse.rest.synapse.client import build_synapse_client_resource_tree
from synapse.server import HomeServer
from synapse.types import create_requester
from synapse.types import JsonDict, create_requester
from synapse.util import Clock

from tests import unittest
@@ -88,6 +90,56 @@ ADDITIONAL_LOGIN_FLOWS = [
]


class TestSpamChecker:
def __init__(self, config: None, api: ModuleApi):
api.register_spam_checker_callbacks(
check_login_for_spam=self.check_login_for_spam,
)

@staticmethod
def parse_config(config: JsonDict) -> None:
return None

async def check_login_for_spam(
self,
user_id: str,
device_id: Optional[str],
initial_display_name: Optional[str],
request_info: Collection[Tuple[Optional[str], str]],
auth_provider_id: Optional[str] = None,
) -> Union[
Literal["NOT_SPAM"],
Tuple["synapse.module_api.errors.Codes", JsonDict],
]:
return "NOT_SPAM"


class DenyAllSpamChecker:
def __init__(self, config: None, api: ModuleApi):
api.register_spam_checker_callbacks(
check_login_for_spam=self.check_login_for_spam,
)

@staticmethod
def parse_config(config: JsonDict) -> None:
return None

async def check_login_for_spam(
self,
user_id: str,
device_id: Optional[str],
initial_display_name: Optional[str],
request_info: Collection[Tuple[Optional[str], str]],
auth_provider_id: Optional[str] = None,
) -> Union[
Literal["NOT_SPAM"],
Tuple["synapse.module_api.errors.Codes", JsonDict],
]:
# Return an odd set of values to ensure that they get correctly passed
# to the client.
return Codes.LIMIT_EXCEEDED, {"extra": "value"}


class LoginRestServletTestCase(unittest.HomeserverTestCase):
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
@@ -469,6 +521,58 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
],
)

@override_config(
{
"modules": [
{
"module": TestSpamChecker.__module__
+ "."
+ TestSpamChecker.__qualname__
}
]
}
)
def test_spam_checker_allow(self) -> None:
"""Check that that adding a spam checker doesn't break login."""
self.register_user("kermit", "monkey")

body = {"type": "m.login.password", "user": "kermit", "password": "monkey"}

channel = self.make_request(
"POST",
"/_matrix/client/r0/login",
body,
)
self.assertEqual(channel.code, 200, channel.result)

@override_config(
{
"modules": [
{
"module": DenyAllSpamChecker.__module__
+ "."
+ DenyAllSpamChecker.__qualname__
}
]
}
)
def test_spam_checker_deny(self) -> None:
"""Check that login"""

self.register_user("kermit", "monkey")

body = {"type": "m.login.password", "user": "kermit", "password": "monkey"}

channel = self.make_request(
"POST",
"/_matrix/client/r0/login",
body,
)
self.assertEqual(channel.code, 403, channel.result)
self.assertDictContainsSubset(
{"errcode": Codes.LIMIT_EXCEEDED, "extra": "value"}, channel.json_body
)


@skip_unless(has_saml2 and HAS_OIDC, "Requires SAML2 and OIDC")
class MultiSSOTestCase(unittest.HomeserverTestCase):


Chargement…
Annuler
Enregistrer