@@ -0,0 +1 @@ | |||
Add an admin API to manage ratelimit for a specific user. |
@@ -202,7 +202,7 @@ The following fields are returned in the JSON response body: | |||
- ``users`` - An array of objects, each containing information about an user. | |||
User objects contain the following fields: | |||
- ``name`` - string - Fully-qualified user ID (ex. `@user:server.com`). | |||
- ``name`` - string - Fully-qualified user ID (ex. ``@user:server.com``). | |||
- ``is_guest`` - bool - Status if that user is a guest account. | |||
- ``admin`` - bool - Status if that user is a server administrator. | |||
- ``user_type`` - string - Type of the user. Normal users are type ``None``. | |||
@@ -864,3 +864,118 @@ The following parameters should be set in the URL: | |||
- ``user_id`` - The fully qualified MXID: for example, ``@user:server.com``. The user must | |||
be local. | |||
Override ratelimiting for users | |||
=============================== | |||
This API allows to override or disable ratelimiting for a specific user. | |||
There are specific APIs to set, get and delete a ratelimit. | |||
Get status of ratelimit | |||
----------------------- | |||
The API is:: | |||
GET /_synapse/admin/v1/users/<user_id>/override_ratelimit | |||
To use it, you will need to authenticate by providing an ``access_token`` for a | |||
server admin: see `README.rst <README.rst>`_. | |||
A response body like the following is returned: | |||
.. code:: json | |||
{ | |||
"messages_per_second": 0, | |||
"burst_count": 0 | |||
} | |||
**Parameters** | |||
The following parameters should be set in the URL: | |||
- ``user_id`` - The fully qualified MXID: for example, ``@user:server.com``. The user must | |||
be local. | |||
**Response** | |||
The following fields are returned in the JSON response body: | |||
- ``messages_per_second`` - integer - The number of actions that can | |||
be performed in a second. `0` mean that ratelimiting is disabled for this user. | |||
- ``burst_count`` - integer - How many actions that can be performed before | |||
being limited. | |||
If **no** custom ratelimit is set, an empty JSON dict is returned. | |||
.. code:: json | |||
{} | |||
Set ratelimit | |||
------------- | |||
The API is:: | |||
POST /_synapse/admin/v1/users/<user_id>/override_ratelimit | |||
To use it, you will need to authenticate by providing an ``access_token`` for a | |||
server admin: see `README.rst <README.rst>`_. | |||
A response body like the following is returned: | |||
.. code:: json | |||
{ | |||
"messages_per_second": 0, | |||
"burst_count": 0 | |||
} | |||
**Parameters** | |||
The following parameters should be set in the URL: | |||
- ``user_id`` - The fully qualified MXID: for example, ``@user:server.com``. The user must | |||
be local. | |||
Body parameters: | |||
- ``messages_per_second`` - positive integer, optional. The number of actions that can | |||
be performed in a second. Defaults to ``0``. | |||
- ``burst_count`` - positive integer, optional. How many actions that can be performed | |||
before being limited. Defaults to ``0``. | |||
To disable users' ratelimit set both values to ``0``. | |||
**Response** | |||
The following fields are returned in the JSON response body: | |||
- ``messages_per_second`` - integer - The number of actions that can | |||
be performed in a second. | |||
- ``burst_count`` - integer - How many actions that can be performed before | |||
being limited. | |||
Delete ratelimit | |||
---------------- | |||
The API is:: | |||
DELETE /_synapse/admin/v1/users/<user_id>/override_ratelimit | |||
To use it, you will need to authenticate by providing an ``access_token`` for a | |||
server admin: see `README.rst <README.rst>`_. | |||
An empty JSON dict is returned. | |||
.. code:: json | |||
{} | |||
**Parameters** | |||
The following parameters should be set in the URL: | |||
- ``user_id`` - The fully qualified MXID: for example, ``@user:server.com``. The user must | |||
be local. | |||
@@ -54,6 +54,7 @@ from synapse.rest.admin.users import ( | |||
AccountValidityRenewServlet, | |||
DeactivateAccountRestServlet, | |||
PushersRestServlet, | |||
RateLimitRestServlet, | |||
ResetPasswordRestServlet, | |||
SearchUsersRestServlet, | |||
ShadowBanRestServlet, | |||
@@ -239,6 +240,7 @@ def register_servlets(hs, http_server): | |||
ShadowBanRestServlet(hs).register(http_server) | |||
ForwardExtremitiesRestServlet(hs).register(http_server) | |||
RoomEventContextServlet(hs).register(http_server) | |||
RateLimitRestServlet(hs).register(http_server) | |||
def register_servlets_for_client_rest_resource(hs, http_server): | |||
@@ -981,3 +981,114 @@ class ShadowBanRestServlet(RestServlet): | |||
await self.store.set_shadow_banned(UserID.from_string(user_id), True) | |||
return 200, {} | |||
class RateLimitRestServlet(RestServlet): | |||
"""An admin API to override ratelimiting for an user. | |||
Example: | |||
POST /_synapse/admin/v1/users/@test:example.com/override_ratelimit | |||
{ | |||
"messages_per_second": 0, | |||
"burst_count": 0 | |||
} | |||
200 OK | |||
{ | |||
"messages_per_second": 0, | |||
"burst_count": 0 | |||
} | |||
""" | |||
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/override_ratelimit") | |||
def __init__(self, hs: "HomeServer"): | |||
self.hs = hs | |||
self.store = hs.get_datastore() | |||
self.auth = hs.get_auth() | |||
async def on_GET( | |||
self, request: SynapseRequest, user_id: str | |||
) -> Tuple[int, JsonDict]: | |||
await assert_requester_is_admin(self.auth, request) | |||
if not self.hs.is_mine_id(user_id): | |||
raise SynapseError(400, "Can only lookup local users") | |||
if not await self.store.get_user_by_id(user_id): | |||
raise NotFoundError("User not found") | |||
ratelimit = await self.store.get_ratelimit_for_user(user_id) | |||
if ratelimit: | |||
# convert `null` to `0` for consistency | |||
# both values do the same in retelimit handler | |||
ret = { | |||
"messages_per_second": 0 | |||
if ratelimit.messages_per_second is None | |||
else ratelimit.messages_per_second, | |||
"burst_count": 0 | |||
if ratelimit.burst_count is None | |||
else ratelimit.burst_count, | |||
} | |||
else: | |||
ret = {} | |||
return 200, ret | |||
async def on_POST( | |||
self, request: SynapseRequest, user_id: str | |||
) -> Tuple[int, JsonDict]: | |||
await assert_requester_is_admin(self.auth, request) | |||
if not self.hs.is_mine_id(user_id): | |||
raise SynapseError(400, "Only local users can be ratelimited") | |||
if not await self.store.get_user_by_id(user_id): | |||
raise NotFoundError("User not found") | |||
body = parse_json_object_from_request(request, allow_empty_body=True) | |||
messages_per_second = body.get("messages_per_second", 0) | |||
burst_count = body.get("burst_count", 0) | |||
if not isinstance(messages_per_second, int) or messages_per_second < 0: | |||
raise SynapseError( | |||
400, | |||
"%r parameter must be a positive int" % (messages_per_second,), | |||
errcode=Codes.INVALID_PARAM, | |||
) | |||
if not isinstance(burst_count, int) or burst_count < 0: | |||
raise SynapseError( | |||
400, | |||
"%r parameter must be a positive int" % (burst_count,), | |||
errcode=Codes.INVALID_PARAM, | |||
) | |||
await self.store.set_ratelimit_for_user( | |||
user_id, messages_per_second, burst_count | |||
) | |||
ratelimit = await self.store.get_ratelimit_for_user(user_id) | |||
assert ratelimit is not None | |||
ret = { | |||
"messages_per_second": ratelimit.messages_per_second, | |||
"burst_count": ratelimit.burst_count, | |||
} | |||
return 200, ret | |||
async def on_DELETE( | |||
self, request: SynapseRequest, user_id: str | |||
) -> Tuple[int, JsonDict]: | |||
await assert_requester_is_admin(self.auth, request) | |||
if not self.hs.is_mine_id(user_id): | |||
raise SynapseError(400, "Only local users can be ratelimited") | |||
if not await self.store.get_user_by_id(user_id): | |||
raise NotFoundError("User not found") | |||
await self.store.delete_ratelimit_for_user(user_id) | |||
return 200, {} |
@@ -521,13 +521,11 @@ class RoomWorkerStore(SQLBaseStore): | |||
) | |||
@cached(max_entries=10000) | |||
async def get_ratelimit_for_user(self, user_id): | |||
"""Check if there are any overrides for ratelimiting for the given | |||
user | |||
async def get_ratelimit_for_user(self, user_id: str) -> Optional[RatelimitOverride]: | |||
"""Check if there are any overrides for ratelimiting for the given user | |||
Args: | |||
user_id (str) | |||
user_id: user ID of the user | |||
Returns: | |||
RatelimitOverride if there is an override, else None. If the contents | |||
of RatelimitOverride are None or 0 then ratelimitng has been | |||
@@ -549,6 +547,62 @@ class RoomWorkerStore(SQLBaseStore): | |||
else: | |||
return None | |||
async def set_ratelimit_for_user( | |||
self, user_id: str, messages_per_second: int, burst_count: int | |||
) -> None: | |||
"""Sets whether a user is set an overridden ratelimit. | |||
Args: | |||
user_id: user ID of the user | |||
messages_per_second: The number of actions that can be performed in a second. | |||
burst_count: How many actions that can be performed before being limited. | |||
""" | |||
def set_ratelimit_txn(txn): | |||
self.db_pool.simple_upsert_txn( | |||
txn, | |||
table="ratelimit_override", | |||
keyvalues={"user_id": user_id}, | |||
values={ | |||
"messages_per_second": messages_per_second, | |||
"burst_count": burst_count, | |||
}, | |||
) | |||
self._invalidate_cache_and_stream( | |||
txn, self.get_ratelimit_for_user, (user_id,) | |||
) | |||
await self.db_pool.runInteraction("set_ratelimit", set_ratelimit_txn) | |||
async def delete_ratelimit_for_user(self, user_id: str) -> None: | |||
"""Delete an overridden ratelimit for a user. | |||
Args: | |||
user_id: user ID of the user | |||
""" | |||
def delete_ratelimit_txn(txn): | |||
row = self.db_pool.simple_select_one_txn( | |||
txn, | |||
table="ratelimit_override", | |||
keyvalues={"user_id": user_id}, | |||
retcols=["user_id"], | |||
allow_none=True, | |||
) | |||
if not row: | |||
return | |||
# They are there, delete them. | |||
self.db_pool.simple_delete_one_txn( | |||
txn, "ratelimit_override", keyvalues={"user_id": user_id} | |||
) | |||
self._invalidate_cache_and_stream( | |||
txn, self.get_ratelimit_for_user, (user_id,) | |||
) | |||
await self.db_pool.runInteraction("delete_ratelimit", delete_ratelimit_txn) | |||
@cached() | |||
async def get_retention_policy_for_room(self, room_id): | |||
"""Get the retention policy for a given room. | |||
@@ -3011,3 +3011,287 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase): | |||
# Ensure the user is shadow-banned (and the cache was cleared). | |||
result = self.get_success(self.store.get_user_by_access_token(other_user_token)) | |||
self.assertTrue(result.shadow_banned) | |||
class RateLimitTestCase(unittest.HomeserverTestCase): | |||
servlets = [ | |||
synapse.rest.admin.register_servlets, | |||
login.register_servlets, | |||
] | |||
def prepare(self, reactor, clock, hs): | |||
self.store = hs.get_datastore() | |||
self.admin_user = self.register_user("admin", "pass", admin=True) | |||
self.admin_user_tok = self.login("admin", "pass") | |||
self.other_user = self.register_user("user", "pass") | |||
self.url = ( | |||
"/_synapse/admin/v1/users/%s/override_ratelimit" | |||
% urllib.parse.quote(self.other_user) | |||
) | |||
def test_no_auth(self): | |||
""" | |||
Try to get information of a user without authentication. | |||
""" | |||
channel = self.make_request("GET", self.url, b"{}") | |||
self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) | |||
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) | |||
channel = self.make_request("POST", self.url, b"{}") | |||
self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) | |||
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) | |||
channel = self.make_request("DELETE", self.url, b"{}") | |||
self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) | |||
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) | |||
def test_requester_is_no_admin(self): | |||
""" | |||
If the user is not a server admin, an error is returned. | |||
""" | |||
other_user_token = self.login("user", "pass") | |||
channel = self.make_request( | |||
"GET", | |||
self.url, | |||
access_token=other_user_token, | |||
) | |||
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) | |||
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) | |||
channel = self.make_request( | |||
"POST", | |||
self.url, | |||
access_token=other_user_token, | |||
) | |||
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) | |||
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) | |||
channel = self.make_request( | |||
"DELETE", | |||
self.url, | |||
access_token=other_user_token, | |||
) | |||
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) | |||
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) | |||
def test_user_does_not_exist(self): | |||
""" | |||
Tests that a lookup for a user that does not exist returns a 404 | |||
""" | |||
url = "/_synapse/admin/v1/users/@unknown_person:test/override_ratelimit" | |||
channel = self.make_request( | |||
"GET", | |||
url, | |||
access_token=self.admin_user_tok, | |||
) | |||
self.assertEqual(404, channel.code, msg=channel.json_body) | |||
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) | |||
channel = self.make_request( | |||
"POST", | |||
url, | |||
access_token=self.admin_user_tok, | |||
) | |||
self.assertEqual(404, channel.code, msg=channel.json_body) | |||
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) | |||
channel = self.make_request( | |||
"DELETE", | |||
url, | |||
access_token=self.admin_user_tok, | |||
) | |||
self.assertEqual(404, channel.code, msg=channel.json_body) | |||
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) | |||
def test_user_is_not_local(self): | |||
""" | |||
Tests that a lookup for a user that is not a local returns a 400 | |||
""" | |||
url = ( | |||
"/_synapse/admin/v1/users/@unknown_person:unknown_domain/override_ratelimit" | |||
) | |||
channel = self.make_request( | |||
"GET", | |||
url, | |||
access_token=self.admin_user_tok, | |||
) | |||
self.assertEqual(400, channel.code, msg=channel.json_body) | |||
self.assertEqual("Can only lookup local users", channel.json_body["error"]) | |||
channel = self.make_request( | |||
"POST", | |||
url, | |||
access_token=self.admin_user_tok, | |||
) | |||
self.assertEqual(400, channel.code, msg=channel.json_body) | |||
self.assertEqual( | |||
"Only local users can be ratelimited", channel.json_body["error"] | |||
) | |||
channel = self.make_request( | |||
"DELETE", | |||
url, | |||
access_token=self.admin_user_tok, | |||
) | |||
self.assertEqual(400, channel.code, msg=channel.json_body) | |||
self.assertEqual( | |||
"Only local users can be ratelimited", channel.json_body["error"] | |||
) | |||
def test_invalid_parameter(self): | |||
""" | |||
If parameters are invalid, an error is returned. | |||
""" | |||
# messages_per_second is a string | |||
channel = self.make_request( | |||
"POST", | |||
self.url, | |||
access_token=self.admin_user_tok, | |||
content={"messages_per_second": "string"}, | |||
) | |||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) | |||
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) | |||
# messages_per_second is negative | |||
channel = self.make_request( | |||
"POST", | |||
self.url, | |||
access_token=self.admin_user_tok, | |||
content={"messages_per_second": -1}, | |||
) | |||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) | |||
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) | |||
# burst_count is a string | |||
channel = self.make_request( | |||
"POST", | |||
self.url, | |||
access_token=self.admin_user_tok, | |||
content={"burst_count": "string"}, | |||
) | |||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) | |||
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) | |||
# burst_count is negative | |||
channel = self.make_request( | |||
"POST", | |||
self.url, | |||
access_token=self.admin_user_tok, | |||
content={"burst_count": -1}, | |||
) | |||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) | |||
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) | |||
def test_return_zero_when_null(self): | |||
""" | |||
If values in database are `null` API should return an int `0` | |||
""" | |||
self.get_success( | |||
self.store.db_pool.simple_upsert( | |||
table="ratelimit_override", | |||
keyvalues={"user_id": self.other_user}, | |||
values={ | |||
"messages_per_second": None, | |||
"burst_count": None, | |||
}, | |||
) | |||
) | |||
# request status | |||
channel = self.make_request( | |||
"GET", | |||
self.url, | |||
access_token=self.admin_user_tok, | |||
) | |||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) | |||
self.assertEqual(0, channel.json_body["messages_per_second"]) | |||
self.assertEqual(0, channel.json_body["burst_count"]) | |||
def test_success(self): | |||
""" | |||
Rate-limiting (set/update/delete) should succeed for an admin. | |||
""" | |||
# request status | |||
channel = self.make_request( | |||
"GET", | |||
self.url, | |||
access_token=self.admin_user_tok, | |||
) | |||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) | |||
self.assertNotIn("messages_per_second", channel.json_body) | |||
self.assertNotIn("burst_count", channel.json_body) | |||
# set ratelimit | |||
channel = self.make_request( | |||
"POST", | |||
self.url, | |||
access_token=self.admin_user_tok, | |||
content={"messages_per_second": 10, "burst_count": 11}, | |||
) | |||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) | |||
self.assertEqual(10, channel.json_body["messages_per_second"]) | |||
self.assertEqual(11, channel.json_body["burst_count"]) | |||
# update ratelimit | |||
channel = self.make_request( | |||
"POST", | |||
self.url, | |||
access_token=self.admin_user_tok, | |||
content={"messages_per_second": 20, "burst_count": 21}, | |||
) | |||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) | |||
self.assertEqual(20, channel.json_body["messages_per_second"]) | |||
self.assertEqual(21, channel.json_body["burst_count"]) | |||
# request status | |||
channel = self.make_request( | |||
"GET", | |||
self.url, | |||
access_token=self.admin_user_tok, | |||
) | |||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) | |||
self.assertEqual(20, channel.json_body["messages_per_second"]) | |||
self.assertEqual(21, channel.json_body["burst_count"]) | |||
# delete ratelimit | |||
channel = self.make_request( | |||
"DELETE", | |||
self.url, | |||
access_token=self.admin_user_tok, | |||
) | |||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) | |||
self.assertNotIn("messages_per_second", channel.json_body) | |||
self.assertNotIn("burst_count", channel.json_body) | |||
# request status | |||
channel = self.make_request( | |||
"GET", | |||
self.url, | |||
access_token=self.admin_user_tok, | |||
) | |||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) | |||
self.assertNotIn("messages_per_second", channel.json_body) | |||
self.assertNotIn("burst_count", channel.json_body) |