@@ -25,7 +25,7 @@ from twisted.internet import defer | |||
import synapse.types | |||
from synapse import event_auth | |||
from synapse.api.constants import EventTypes, JoinRules, Membership | |||
from synapse.api.errors import AuthError, Codes | |||
from synapse.api.errors import AuthError, Codes, ResourceLimitError | |||
from synapse.types import UserID | |||
from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache | |||
from synapse.util.caches.lrucache import LruCache | |||
@@ -784,10 +784,11 @@ class Auth(object): | |||
MAU cohort | |||
""" | |||
if self.hs.config.hs_disabled: | |||
raise AuthError( | |||
raise ResourceLimitError( | |||
403, self.hs.config.hs_disabled_message, | |||
errcode=Codes.RESOURCE_LIMIT_EXCEED, | |||
admin_uri=self.hs.config.admin_uri, | |||
limit_type=self.hs.config.hs_disabled_limit_type | |||
) | |||
if self.hs.config.limit_usage_by_mau is True: | |||
# If the user is already part of the MAU cohort | |||
@@ -798,8 +799,9 @@ class Auth(object): | |||
# Else if there is no room in the MAU bucket, bail | |||
current_mau = yield self.store.get_monthly_active_count() | |||
if current_mau >= self.hs.config.max_mau_value: | |||
raise AuthError( | |||
raise ResourceLimitError( | |||
403, "Monthly Active User Limits AU Limit Exceeded", | |||
admin_uri=self.hs.config.admin_uri, | |||
errcode=Codes.RESOURCE_LIMIT_EXCEED | |||
errcode=Codes.RESOURCE_LIMIT_EXCEED, | |||
limit_type="monthly_active_user" | |||
) |
@@ -224,15 +224,34 @@ class NotFoundError(SynapseError): | |||
class AuthError(SynapseError): | |||
"""An error raised when there was a problem authorising an event.""" | |||
def __init__(self, code, msg, errcode=Codes.FORBIDDEN, admin_uri=None): | |||
def __init__(self, *args, **kwargs): | |||
if "errcode" not in kwargs: | |||
kwargs["errcode"] = Codes.FORBIDDEN | |||
super(AuthError, self).__init__(*args, **kwargs) | |||
class ResourceLimitError(SynapseError): | |||
""" | |||
Any error raised when there is a problem with resource usage. | |||
For instance, the monthly active user limit for the server has been exceeded | |||
""" | |||
def __init__( | |||
self, code, msg, | |||
errcode=Codes.RESOURCE_LIMIT_EXCEED, | |||
admin_uri=None, | |||
limit_type=None, | |||
): | |||
self.admin_uri = admin_uri | |||
super(AuthError, self).__init__(code, msg, errcode=errcode) | |||
self.limit_type = limit_type | |||
super(ResourceLimitError, self).__init__(code, msg, errcode=errcode) | |||
def error_dict(self): | |||
return cs_error( | |||
self.msg, | |||
self.errcode, | |||
admin_uri=self.admin_uri, | |||
limit_type=self.limit_type | |||
) | |||
@@ -81,6 +81,7 @@ class ServerConfig(Config): | |||
# Options to disable HS | |||
self.hs_disabled = config.get("hs_disabled", False) | |||
self.hs_disabled_message = config.get("hs_disabled_message", "") | |||
self.hs_disabled_limit_type = config.get("hs_disabled_limit_type", "") | |||
# Admin uri to direct users at should their instance become blocked | |||
# due to resource constraints | |||
@@ -21,7 +21,7 @@ from twisted.internet import defer | |||
import synapse.handlers.auth | |||
from synapse.api.auth import Auth | |||
from synapse.api.errors import AuthError, Codes | |||
from synapse.api.errors import AuthError, Codes, ResourceLimitError | |||
from synapse.types import UserID | |||
from tests import unittest | |||
@@ -455,7 +455,7 @@ class AuthTestCase(unittest.TestCase): | |||
return_value=defer.succeed(lots_of_users) | |||
) | |||
with self.assertRaises(AuthError) as e: | |||
with self.assertRaises(ResourceLimitError) as e: | |||
yield self.auth.check_auth_blocking() | |||
self.assertEquals(e.exception.admin_uri, self.hs.config.admin_uri) | |||
self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEED) | |||
@@ -471,7 +471,7 @@ class AuthTestCase(unittest.TestCase): | |||
def test_hs_disabled(self): | |||
self.hs.config.hs_disabled = True | |||
self.hs.config.hs_disabled_message = "Reason for being disabled" | |||
with self.assertRaises(AuthError) as e: | |||
with self.assertRaises(ResourceLimitError) as e: | |||
yield self.auth.check_auth_blocking() | |||
self.assertEquals(e.exception.admin_uri, self.hs.config.admin_uri) | |||
self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEED) | |||
@@ -20,7 +20,7 @@ from twisted.internet import defer | |||
import synapse | |||
import synapse.api.errors | |||
from synapse.api.errors import AuthError | |||
from synapse.api.errors import ResourceLimitError | |||
from synapse.handlers.auth import AuthHandler | |||
from tests import unittest | |||
@@ -130,13 +130,13 @@ class AuthTestCase(unittest.TestCase): | |||
return_value=defer.succeed(self.large_number_of_users) | |||
) | |||
with self.assertRaises(AuthError): | |||
with self.assertRaises(ResourceLimitError): | |||
yield self.auth_handler.get_access_token_for_user_id('user_a') | |||
self.hs.get_datastore().get_monthly_active_count = Mock( | |||
return_value=defer.succeed(self.large_number_of_users) | |||
) | |||
with self.assertRaises(AuthError): | |||
with self.assertRaises(ResourceLimitError): | |||
yield self.auth_handler.validate_short_term_login_token_and_get_user_id( | |||
self._get_macaroon().serialize() | |||
) | |||
@@ -149,13 +149,13 @@ class AuthTestCase(unittest.TestCase): | |||
self.hs.get_datastore().get_monthly_active_count = Mock( | |||
return_value=defer.succeed(self.hs.config.max_mau_value) | |||
) | |||
with self.assertRaises(AuthError): | |||
with self.assertRaises(ResourceLimitError): | |||
yield self.auth_handler.get_access_token_for_user_id('user_a') | |||
self.hs.get_datastore().get_monthly_active_count = Mock( | |||
return_value=defer.succeed(self.hs.config.max_mau_value) | |||
) | |||
with self.assertRaises(AuthError): | |||
with self.assertRaises(ResourceLimitError): | |||
yield self.auth_handler.validate_short_term_login_token_and_get_user_id( | |||
self._get_macaroon().serialize() | |||
) | |||
@@ -17,7 +17,7 @@ from mock import Mock | |||
from twisted.internet import defer | |||
from synapse.api.errors import AuthError | |||
from synapse.api.errors import ResourceLimitError | |||
from synapse.handlers.register import RegistrationHandler | |||
from synapse.types import UserID, create_requester | |||
@@ -109,13 +109,13 @@ class RegistrationTestCase(unittest.TestCase): | |||
self.store.get_monthly_active_count = Mock( | |||
return_value=defer.succeed(self.lots_of_users) | |||
) | |||
with self.assertRaises(AuthError): | |||
with self.assertRaises(ResourceLimitError): | |||
yield self.handler.get_or_create_user("requester", 'b', "display_name") | |||
self.store.get_monthly_active_count = Mock( | |||
return_value=defer.succeed(self.hs.config.max_mau_value) | |||
) | |||
with self.assertRaises(AuthError): | |||
with self.assertRaises(ResourceLimitError): | |||
yield self.handler.get_or_create_user("requester", 'b', "display_name") | |||
@defer.inlineCallbacks | |||
@@ -124,13 +124,13 @@ class RegistrationTestCase(unittest.TestCase): | |||
self.store.get_monthly_active_count = Mock( | |||
return_value=defer.succeed(self.lots_of_users) | |||
) | |||
with self.assertRaises(AuthError): | |||
with self.assertRaises(ResourceLimitError): | |||
yield self.handler.register(localpart="local_part") | |||
self.store.get_monthly_active_count = Mock( | |||
return_value=defer.succeed(self.hs.config.max_mau_value) | |||
) | |||
with self.assertRaises(AuthError): | |||
with self.assertRaises(ResourceLimitError): | |||
yield self.handler.register(localpart="local_part") | |||
@defer.inlineCallbacks | |||
@@ -139,11 +139,11 @@ class RegistrationTestCase(unittest.TestCase): | |||
self.store.get_monthly_active_count = Mock( | |||
return_value=defer.succeed(self.lots_of_users) | |||
) | |||
with self.assertRaises(AuthError): | |||
with self.assertRaises(ResourceLimitError): | |||
yield self.handler.register_saml2(localpart="local_part") | |||
self.store.get_monthly_active_count = Mock( | |||
return_value=defer.succeed(self.hs.config.max_mau_value) | |||
) | |||
with self.assertRaises(AuthError): | |||
with self.assertRaises(ResourceLimitError): | |||
yield self.handler.register_saml2(localpart="local_part") |
@@ -14,7 +14,7 @@ | |||
# limitations under the License. | |||
from twisted.internet import defer | |||
from synapse.api.errors import AuthError, Codes | |||
from synapse.api.errors import Codes, ResourceLimitError | |||
from synapse.api.filtering import DEFAULT_FILTER_COLLECTION | |||
from synapse.handlers.sync import SyncConfig, SyncHandler | |||
from synapse.types import UserID | |||
@@ -49,7 +49,7 @@ class SyncTestCase(tests.unittest.TestCase): | |||
# Test that global lock works | |||
self.hs.config.hs_disabled = True | |||
with self.assertRaises(AuthError) as e: | |||
with self.assertRaises(ResourceLimitError) as e: | |||
yield self.sync_handler.wait_for_sync_for_user(sync_config) | |||
self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEED) | |||
@@ -57,7 +57,7 @@ class SyncTestCase(tests.unittest.TestCase): | |||
sync_config = self._generate_sync_config(user_id2) | |||
with self.assertRaises(AuthError) as e: | |||
with self.assertRaises(ResourceLimitError) as e: | |||
yield self.sync_handler.wait_for_sync_for_user(sync_config) | |||
self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEED) | |||
@@ -137,6 +137,7 @@ def setup_test_homeserver( | |||
config.limit_usage_by_mau = False | |||
config.hs_disabled = False | |||
config.hs_disabled_message = "" | |||
config.hs_disabled_limit_type = "" | |||
config.max_mau_value = 50 | |||
config.mau_limits_reserved_threepids = [] | |||
config.admin_uri = None | |||