- Update black version to the latest
- Run black auto formatting over the codebase
- Run autoformatting according to [`docs/code_style.md`](80d6dc9783/docs/code_style.md)
- Update `code_style.md` docs around installing black to use the correct version
tags/v1.28.0rc1
@@ -0,0 +1 @@ | |||
Update the version of black used to 20.8b1. |
@@ -92,7 +92,7 @@ class SynapseCmd(cmd.Cmd): | |||
return self.config["user"].split(":")[1] | |||
def do_config(self, line): | |||
""" Show the config for this client: "config" | |||
"""Show the config for this client: "config" | |||
Edit a key value mapping: "config key value" e.g. "config token 1234" | |||
Config variables: | |||
user: The username to auth with. | |||
@@ -360,7 +360,7 @@ class SynapseCmd(cmd.Cmd): | |||
print(e) | |||
def do_topic(self, line): | |||
""""topic [set|get] <roomid> [<newtopic>]" | |||
""" "topic [set|get] <roomid> [<newtopic>]" | |||
Set the topic for a room: topic set <roomid> <newtopic> | |||
Get the topic for a room: topic get <roomid> | |||
""" | |||
@@ -690,7 +690,7 @@ class SynapseCmd(cmd.Cmd): | |||
self._do_presence_state(2, line) | |||
def _parse(self, line, keys, force_keys=False): | |||
""" Parses the given line. | |||
"""Parses the given line. | |||
Args: | |||
line : The line to parse | |||
@@ -721,7 +721,7 @@ class SynapseCmd(cmd.Cmd): | |||
query_params={"access_token": None}, | |||
alt_text=None, | |||
): | |||
""" Runs an HTTP request and pretty prints the output. | |||
"""Runs an HTTP request and pretty prints the output. | |||
Args: | |||
method: HTTP method | |||
@@ -23,11 +23,10 @@ from twisted.web.http_headers import Headers | |||
class HttpClient: | |||
""" Interface for talking json over http | |||
""" | |||
"""Interface for talking json over http""" | |||
def put_json(self, url, data): | |||
""" Sends the specifed json data using PUT | |||
"""Sends the specifed json data using PUT | |||
Args: | |||
url (str): The URL to PUT data to. | |||
@@ -41,7 +40,7 @@ class HttpClient: | |||
pass | |||
def get_json(self, url, args=None): | |||
""" Gets some json from the given host homeserver and path | |||
"""Gets some json from the given host homeserver and path | |||
Args: | |||
url (str): The URL to GET data from. | |||
@@ -58,7 +57,7 @@ class HttpClient: | |||
class TwistedHttpClient(HttpClient): | |||
""" Wrapper around the twisted HTTP client api. | |||
"""Wrapper around the twisted HTTP client api. | |||
Attributes: | |||
agent (twisted.web.client.Agent): The twisted Agent used to send the | |||
@@ -87,8 +86,7 @@ class TwistedHttpClient(HttpClient): | |||
defer.returnValue(json.loads(body)) | |||
def _create_put_request(self, url, json_data, headers_dict={}): | |||
""" Wrapper of _create_request to issue a PUT request | |||
""" | |||
"""Wrapper of _create_request to issue a PUT request""" | |||
if "Content-Type" not in headers_dict: | |||
raise defer.error(RuntimeError("Must include Content-Type header for PUTs")) | |||
@@ -98,8 +96,7 @@ class TwistedHttpClient(HttpClient): | |||
) | |||
def _create_get_request(self, url, headers_dict={}): | |||
""" Wrapper of _create_request to issue a GET request | |||
""" | |||
"""Wrapper of _create_request to issue a GET request""" | |||
return self._create_request("GET", url, headers_dict=headers_dict) | |||
@defer.inlineCallbacks | |||
@@ -127,8 +124,7 @@ class TwistedHttpClient(HttpClient): | |||
@defer.inlineCallbacks | |||
def _create_request(self, method, url, producer=None, headers_dict={}): | |||
""" Creates and sends a request to the given url | |||
""" | |||
"""Creates and sends a request to the given url""" | |||
headers_dict["User-Agent"] = ["Synapse Cmd Client"] | |||
retries_left = 5 | |||
@@ -185,8 +181,7 @@ class _RawProducer: | |||
class _JsonProducer: | |||
""" Used by the twisted http client to create the HTTP body from json | |||
""" | |||
"""Used by the twisted http client to create the HTTP body from json""" | |||
def __init__(self, jsn): | |||
self.data = jsn | |||
@@ -63,8 +63,7 @@ class CursesStdIO: | |||
self.redraw() | |||
def redraw(self): | |||
""" method for redisplaying lines | |||
based on internal list of lines """ | |||
"""method for redisplaying lines based on internal list of lines""" | |||
self.stdscr.clear() | |||
self.paintStatus(self.statusText) | |||
@@ -56,7 +56,7 @@ def excpetion_errback(failure): | |||
class InputOutput: | |||
""" This is responsible for basic I/O so that a user can interact with | |||
"""This is responsible for basic I/O so that a user can interact with | |||
the example app. | |||
""" | |||
@@ -68,8 +68,7 @@ class InputOutput: | |||
self.server = server | |||
def on_line(self, line): | |||
""" This is where we process commands. | |||
""" | |||
"""This is where we process commands.""" | |||
try: | |||
m = re.match(r"^join (\S+)$", line) | |||
@@ -133,7 +132,7 @@ class IOLoggerHandler(logging.Handler): | |||
class Room: | |||
""" Used to store (in memory) the current membership state of a room, and | |||
"""Used to store (in memory) the current membership state of a room, and | |||
which home servers we should send PDUs associated with the room to. | |||
""" | |||
@@ -148,8 +147,7 @@ class Room: | |||
self.have_got_metadata = False | |||
def add_participant(self, participant): | |||
""" Someone has joined the room | |||
""" | |||
"""Someone has joined the room""" | |||
self.participants.add(participant) | |||
self.invited.discard(participant) | |||
@@ -160,14 +158,13 @@ class Room: | |||
self.oldest_server = server | |||
def add_invited(self, invitee): | |||
""" Someone has been invited to the room | |||
""" | |||
"""Someone has been invited to the room""" | |||
self.invited.add(invitee) | |||
self.servers.add(origin_from_ucid(invitee)) | |||
class HomeServer(ReplicationHandler): | |||
""" A very basic home server implentation that allows people to join a | |||
"""A very basic home server implentation that allows people to join a | |||
room and then invite other people. | |||
""" | |||
@@ -181,8 +178,7 @@ class HomeServer(ReplicationHandler): | |||
self.output = output | |||
def on_receive_pdu(self, pdu): | |||
""" We just received a PDU | |||
""" | |||
"""We just received a PDU""" | |||
pdu_type = pdu.pdu_type | |||
if pdu_type == "sy.room.message": | |||
@@ -199,23 +195,20 @@ class HomeServer(ReplicationHandler): | |||
) | |||
def _on_message(self, pdu): | |||
""" We received a message | |||
""" | |||
"""We received a message""" | |||
self.output.print_line( | |||
"#%s %s %s" % (pdu.context, pdu.content["sender"], pdu.content["body"]) | |||
) | |||
def _on_join(self, context, joinee): | |||
""" Someone has joined a room, either a remote user or a local user | |||
""" | |||
"""Someone has joined a room, either a remote user or a local user""" | |||
room = self._get_or_create_room(context) | |||
room.add_participant(joinee) | |||
self.output.print_line("#%s %s %s" % (context, joinee, "*** JOINED")) | |||
def _on_invite(self, origin, context, invitee): | |||
""" Someone has been invited | |||
""" | |||
"""Someone has been invited""" | |||
room = self._get_or_create_room(context) | |||
room.add_invited(invitee) | |||
@@ -228,8 +221,7 @@ class HomeServer(ReplicationHandler): | |||
@defer.inlineCallbacks | |||
def send_message(self, room_name, sender, body): | |||
""" Send a message to a room! | |||
""" | |||
"""Send a message to a room!""" | |||
destinations = yield self.get_servers_for_context(room_name) | |||
try: | |||
@@ -247,8 +239,7 @@ class HomeServer(ReplicationHandler): | |||
@defer.inlineCallbacks | |||
def join_room(self, room_name, sender, joinee): | |||
""" Join a room! | |||
""" | |||
"""Join a room!""" | |||
self._on_join(room_name, joinee) | |||
destinations = yield self.get_servers_for_context(room_name) | |||
@@ -269,8 +260,7 @@ class HomeServer(ReplicationHandler): | |||
@defer.inlineCallbacks | |||
def invite_to_room(self, room_name, sender, invitee): | |||
""" Invite someone to a room! | |||
""" | |||
"""Invite someone to a room!""" | |||
self._on_invite(self.server_name, room_name, invitee) | |||
destinations = yield self.get_servers_for_context(room_name) | |||
@@ -193,15 +193,12 @@ class TrivialXmppClient: | |||
time.sleep(7) | |||
print("SSRC spammer started") | |||
while self.running: | |||
ssrcMsg = ( | |||
"<presence to='%(tojid)s' xmlns='jabber:client'><x xmlns='http://jabber.org/protocol/muc'/><c xmlns='http://jabber.org/protocol/caps' hash='sha-1' node='http://jitsi.org/jitsimeet' ver='0WkSdhFnAUxrz4ImQQLdB80GFlE='/><nick xmlns='http://jabber.org/protocol/nick'>%(nick)s</nick><stats xmlns='http://jitsi.org/jitmeet/stats'><stat name='bitrate_download' value='175'/><stat name='bitrate_upload' value='176'/><stat name='packetLoss_total' value='0'/><stat name='packetLoss_download' value='0'/><stat name='packetLoss_upload' value='0'/></stats><media xmlns='http://estos.de/ns/mjs'><source type='audio' ssrc='%(assrc)s' direction='sendre'/><source type='video' ssrc='%(vssrc)s' direction='sendre'/></media></presence>" | |||
% { | |||
"tojid": "%s@%s/%s" % (ROOMNAME, ROOMDOMAIN, self.shortJid), | |||
"nick": self.userId, | |||
"assrc": self.ssrcs["audio"], | |||
"vssrc": self.ssrcs["video"], | |||
} | |||
) | |||
ssrcMsg = "<presence to='%(tojid)s' xmlns='jabber:client'><x xmlns='http://jabber.org/protocol/muc'/><c xmlns='http://jabber.org/protocol/caps' hash='sha-1' node='http://jitsi.org/jitsimeet' ver='0WkSdhFnAUxrz4ImQQLdB80GFlE='/><nick xmlns='http://jabber.org/protocol/nick'>%(nick)s</nick><stats xmlns='http://jitsi.org/jitmeet/stats'><stat name='bitrate_download' value='175'/><stat name='bitrate_upload' value='176'/><stat name='packetLoss_total' value='0'/><stat name='packetLoss_download' value='0'/><stat name='packetLoss_upload' value='0'/></stats><media xmlns='http://estos.de/ns/mjs'><source type='audio' ssrc='%(assrc)s' direction='sendre'/><source type='video' ssrc='%(vssrc)s' direction='sendre'/></media></presence>" % { | |||
"tojid": "%s@%s/%s" % (ROOMNAME, ROOMDOMAIN, self.shortJid), | |||
"nick": self.userId, | |||
"assrc": self.ssrcs["audio"], | |||
"vssrc": self.ssrcs["video"], | |||
} | |||
res = self.sendIq(ssrcMsg) | |||
print("reply from ssrc announce: ", res) | |||
time.sleep(10) | |||
@@ -8,16 +8,16 @@ errors in code. | |||
The necessary tools are detailed below. | |||
First install them with: | |||
pip install -e ".[lint,mypy]" | |||
- **black** | |||
The Synapse codebase uses [black](https://pypi.org/project/black/) | |||
as an opinionated code formatter, ensuring all comitted code is | |||
properly formatted. | |||
First install `black` with: | |||
pip install --upgrade black | |||
Have `black` auto-format your code (it shouldn't change any | |||
functionality) with: | |||
@@ -28,10 +28,6 @@ The necessary tools are detailed below. | |||
`flake8` is a code checking tool. We require code to pass `flake8` | |||
before being merged into the codebase. | |||
Install `flake8` with: | |||
pip install --upgrade flake8 flake8-comprehensions | |||
Check all application and test code with: | |||
flake8 synapse tests | |||
@@ -41,10 +37,6 @@ The necessary tools are detailed below. | |||
`isort` ensures imports are nicely formatted, and can suggest and | |||
auto-fix issues such as double-importing. | |||
Install `isort` with: | |||
pip install --upgrade isort | |||
Auto-fix imports with: | |||
isort -rc synapse tests | |||
@@ -87,7 +87,9 @@ def cached_function_method_signature(ctx: MethodSigContext) -> CallableType: | |||
arg_kinds.append(ARG_NAMED_OPT) # Arg is an optional kwarg. | |||
signature = signature.copy_modified( | |||
arg_types=arg_types, arg_names=arg_names, arg_kinds=arg_kinds, | |||
arg_types=arg_types, | |||
arg_names=arg_names, | |||
arg_kinds=arg_kinds, | |||
) | |||
return signature | |||
@@ -97,7 +97,7 @@ CONDITIONAL_REQUIREMENTS["all"] = list(ALL_OPTIONAL_REQUIREMENTS) | |||
# We pin black so that our tests don't start failing on new releases. | |||
CONDITIONAL_REQUIREMENTS["lint"] = [ | |||
"isort==5.7.0", | |||
"black==19.10b0", | |||
"black==20.8b1", | |||
"flake8-comprehensions", | |||
"flake8", | |||
] | |||
@@ -89,12 +89,16 @@ class SortedDict(Dict[_KT, _VT]): | |||
def __reduce__( | |||
self, | |||
) -> Tuple[ | |||
Type[SortedDict[_KT, _VT]], Tuple[Callable[[_KT], Any], List[Tuple[_KT, _VT]]], | |||
Type[SortedDict[_KT, _VT]], | |||
Tuple[Callable[[_KT], Any], List[Tuple[_KT, _VT]]], | |||
]: ... | |||
def __repr__(self) -> str: ... | |||
def _check(self) -> None: ... | |||
def islice( | |||
self, start: Optional[int] = ..., stop: Optional[int] = ..., reverse=bool, | |||
self, | |||
start: Optional[int] = ..., | |||
stop: Optional[int] = ..., | |||
reverse=bool, | |||
) -> Iterator[_KT]: ... | |||
def bisect_left(self, value: _KT) -> int: ... | |||
def bisect_right(self, value: _KT) -> int: ... | |||
@@ -31,7 +31,9 @@ class SortedList(MutableSequence[_T]): | |||
DEFAULT_LOAD_FACTOR: int = ... | |||
def __init__( | |||
self, iterable: Optional[Iterable[_T]] = ..., key: Optional[_Key[_T]] = ..., | |||
self, | |||
iterable: Optional[Iterable[_T]] = ..., | |||
key: Optional[_Key[_T]] = ..., | |||
): ... | |||
# NB: currently mypy does not honour return type, see mypy #3307 | |||
@overload | |||
@@ -76,10 +78,18 @@ class SortedList(MutableSequence[_T]): | |||
def __len__(self) -> int: ... | |||
def reverse(self) -> None: ... | |||
def islice( | |||
self, start: Optional[int] = ..., stop: Optional[int] = ..., reverse=bool, | |||
self, | |||
start: Optional[int] = ..., | |||
stop: Optional[int] = ..., | |||
reverse=bool, | |||
) -> Iterator[_T]: ... | |||
def _islice( | |||
self, min_pos: int, min_idx: int, max_pos: int, max_idx: int, reverse: bool, | |||
self, | |||
min_pos: int, | |||
min_idx: int, | |||
max_pos: int, | |||
max_idx: int, | |||
reverse: bool, | |||
) -> Iterator[_T]: ... | |||
def irange( | |||
self, | |||
@@ -168,7 +168,7 @@ class Auth: | |||
rights: str = "access", | |||
allow_expired: bool = False, | |||
) -> synapse.types.Requester: | |||
""" Get a registered user's ID. | |||
"""Get a registered user's ID. | |||
Args: | |||
request: An HTTP request with an access_token query parameter. | |||
@@ -294,9 +294,12 @@ class Auth: | |||
return user_id, app_service | |||
async def get_user_by_access_token( | |||
self, token: str, rights: str = "access", allow_expired: bool = False, | |||
self, | |||
token: str, | |||
rights: str = "access", | |||
allow_expired: bool = False, | |||
) -> TokenLookupResult: | |||
""" Validate access token and get user_id from it | |||
"""Validate access token and get user_id from it | |||
Args: | |||
token: The access token to get the user by | |||
@@ -489,7 +492,7 @@ class Auth: | |||
return service | |||
async def is_server_admin(self, user: UserID) -> bool: | |||
""" Check if the given user is a local server admin. | |||
"""Check if the given user is a local server admin. | |||
Args: | |||
user: user to check | |||
@@ -500,7 +503,10 @@ class Auth: | |||
return await self.store.is_server_admin(user) | |||
def compute_auth_events( | |||
self, event, current_state_ids: StateMap[str], for_verification: bool = False, | |||
self, | |||
event, | |||
current_state_ids: StateMap[str], | |||
for_verification: bool = False, | |||
) -> List[str]: | |||
"""Given an event and current state return the list of event IDs used | |||
to auth an event. | |||
@@ -128,8 +128,7 @@ class UserTypes: | |||
class RelationTypes: | |||
"""The types of relations known to this server. | |||
""" | |||
"""The types of relations known to this server.""" | |||
ANNOTATION = "m.annotation" | |||
REPLACE = "m.replace" | |||
@@ -390,8 +390,7 @@ class InvalidCaptchaError(SynapseError): | |||
class LimitExceededError(SynapseError): | |||
"""A client has sent too many requests and is being throttled. | |||
""" | |||
"""A client has sent too many requests and is being throttled.""" | |||
def __init__( | |||
self, | |||
@@ -408,8 +407,7 @@ class LimitExceededError(SynapseError): | |||
class RoomKeysVersionError(SynapseError): | |||
"""A client has tried to upload to a non-current version of the room_keys store | |||
""" | |||
"""A client has tried to upload to a non-current version of the room_keys store""" | |||
def __init__(self, current_version: str): | |||
""" | |||
@@ -426,7 +424,9 @@ class UnsupportedRoomVersionError(SynapseError): | |||
def __init__(self, msg: str = "Homeserver does not support this room version"): | |||
super().__init__( | |||
code=400, msg=msg, errcode=Codes.UNSUPPORTED_ROOM_VERSION, | |||
code=400, | |||
msg=msg, | |||
errcode=Codes.UNSUPPORTED_ROOM_VERSION, | |||
) | |||
@@ -461,8 +461,7 @@ class IncompatibleRoomVersionError(SynapseError): | |||
class PasswordRefusedError(SynapseError): | |||
"""A password has been refused, either during password reset/change or registration. | |||
""" | |||
"""A password has been refused, either during password reset/change or registration.""" | |||
def __init__( | |||
self, | |||
@@ -470,7 +469,9 @@ class PasswordRefusedError(SynapseError): | |||
errcode: str = Codes.WEAK_PASSWORD, | |||
): | |||
super().__init__( | |||
code=400, msg=msg, errcode=errcode, | |||
code=400, | |||
msg=msg, | |||
errcode=errcode, | |||
) | |||
@@ -493,7 +494,7 @@ class RequestSendFailed(RuntimeError): | |||
def cs_error(msg: str, code: str = Codes.UNKNOWN, **kwargs): | |||
""" Utility method for constructing an error response for client-server | |||
"""Utility method for constructing an error response for client-server | |||
interactions. | |||
Args: | |||
@@ -510,7 +511,7 @@ def cs_error(msg: str, code: str = Codes.UNKNOWN, **kwargs): | |||
class FederationError(RuntimeError): | |||
""" This class is used to inform remote homeservers about erroneous | |||
"""This class is used to inform remote homeservers about erroneous | |||
PDUs they sent us. | |||
FATAL: The remote server could not interpret the source event. | |||
@@ -56,8 +56,7 @@ class UserPresenceState( | |||
@classmethod | |||
def default(cls, user_id): | |||
"""Returns a default presence state. | |||
""" | |||
"""Returns a default presence state.""" | |||
return cls( | |||
user_id=user_id, | |||
state=PresenceState.OFFLINE, | |||
@@ -58,7 +58,7 @@ def register_sighup(func, *args, **kwargs): | |||
def start_worker_reactor(appname, config, run_command=reactor.run): | |||
""" Run the reactor in the main process | |||
"""Run the reactor in the main process | |||
Daemonizes if necessary, and then configures some resources, before starting | |||
the reactor. Pulls configuration from the 'worker' settings in 'config'. | |||
@@ -93,7 +93,7 @@ def start_reactor( | |||
logger, | |||
run_command=reactor.run, | |||
): | |||
""" Run the reactor in the main process | |||
"""Run the reactor in the main process | |||
Daemonizes if necessary, and then configures some resources, before starting | |||
the reactor | |||
@@ -313,9 +313,7 @@ async def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerCon | |||
refresh_certificate(hs) | |||
# Start the tracer | |||
synapse.logging.opentracing.init_tracer( # type: ignore[attr-defined] # noqa | |||
hs | |||
) | |||
synapse.logging.opentracing.init_tracer(hs) # type: ignore[attr-defined] # noqa | |||
# It is now safe to start your Synapse. | |||
hs.start_listening(listeners) | |||
@@ -370,8 +368,7 @@ def setup_sentry(hs): | |||
def setup_sdnotify(hs): | |||
"""Adds process state hooks to tell systemd what we are up to. | |||
""" | |||
"""Adds process state hooks to tell systemd what we are up to.""" | |||
# Tell systemd our state, if we're using it. This will silently fail if | |||
# we're not using systemd. | |||
@@ -405,8 +402,7 @@ def install_dns_limiter(reactor, max_dns_requests_in_flight=100): | |||
class _LimitedHostnameResolver: | |||
"""Wraps a IHostnameResolver, limiting the number of in-flight DNS lookups. | |||
""" | |||
"""Wraps a IHostnameResolver, limiting the number of in-flight DNS lookups.""" | |||
def __init__(self, resolver, max_dns_requests_in_flight): | |||
self._resolver = resolver | |||
@@ -421,8 +421,7 @@ class GenericWorkerPresence(BasePresenceHandler): | |||
] | |||
async def set_state(self, target_user, state, ignore_status_msg=False): | |||
"""Set the presence state of the user. | |||
""" | |||
"""Set the presence state of the user.""" | |||
presence = state["presence"] | |||
valid_presence = ( | |||
@@ -166,7 +166,10 @@ class ApplicationService: | |||
@cached(num_args=1, cache_context=True) | |||
async def matches_user_in_member_list( | |||
self, room_id: str, store: "DataStore", cache_context: _CacheContext, | |||
self, | |||
room_id: str, | |||
store: "DataStore", | |||
cache_context: _CacheContext, | |||
) -> bool: | |||
"""Check if this service is interested a room based upon it's membership | |||
@@ -227,7 +227,9 @@ class ApplicationServiceApi(SimpleHttpClient): | |||
try: | |||
await self.put_json( | |||
uri=uri, json_body=body, args={"access_token": service.hs_token}, | |||
uri=uri, | |||
json_body=body, | |||
args={"access_token": service.hs_token}, | |||
) | |||
sent_transactions_counter.labels(service.id).inc() | |||
sent_events_counter.labels(service.id).inc(len(events)) | |||
@@ -68,7 +68,7 @@ MAX_EPHEMERAL_EVENTS_PER_TRANSACTION = 100 | |||
class ApplicationServiceScheduler: | |||
""" Public facing API for this module. Does the required DI to tie the | |||
"""Public facing API for this module. Does the required DI to tie the | |||
components together. This also serves as the "event_pool", which in this | |||
case is a simple array. | |||
""" | |||
@@ -224,7 +224,9 @@ class Config: | |||
return self.read_templates([filename])[0] | |||
def read_templates( | |||
self, filenames: List[str], custom_template_directory: Optional[str] = None, | |||
self, | |||
filenames: List[str], | |||
custom_template_directory: Optional[str] = None, | |||
) -> List[jinja2.Template]: | |||
"""Load a list of template files from disk using the given variables. | |||
@@ -264,7 +266,10 @@ class Config: | |||
# TODO: switch to synapse.util.templates.build_jinja_env | |||
loader = jinja2.FileSystemLoader(search_directories) | |||
env = jinja2.Environment(loader=loader, autoescape=jinja2.select_autoescape(),) | |||
env = jinja2.Environment( | |||
loader=loader, | |||
autoescape=jinja2.select_autoescape(), | |||
) | |||
# Update the environment with our custom filters | |||
env.filters.update( | |||
@@ -825,8 +830,7 @@ class ShardedWorkerHandlingConfig: | |||
instances = attr.ib(type=List[str]) | |||
def should_handle(self, instance_name: str, key: str) -> bool: | |||
"""Whether this instance is responsible for handling the given key. | |||
""" | |||
"""Whether this instance is responsible for handling the given key.""" | |||
# If multiple instances are not defined we always return true | |||
if not self.instances or len(self.instances) == 1: | |||
return True | |||
@@ -18,8 +18,7 @@ from ._base import Config | |||
class AuthConfig(Config): | |||
"""Password and login configuration | |||
""" | |||
"""Password and login configuration""" | |||
section = "auth" | |||
@@ -207,8 +207,7 @@ class DatabaseConfig(Config): | |||
) | |||
def get_single_database(self) -> DatabaseConnectionConfig: | |||
"""Returns the database if there is only one, useful for e.g. tests | |||
""" | |||
"""Returns the database if there is only one, useful for e.g. tests""" | |||
if not self.databases: | |||
raise Exception("More than one database exists") | |||
@@ -289,7 +289,8 @@ class EmailConfig(Config): | |||
self.email_notif_template_html, | |||
self.email_notif_template_text, | |||
) = self.read_templates( | |||
[notif_template_html, notif_template_text], template_dir, | |||
[notif_template_html, notif_template_text], | |||
template_dir, | |||
) | |||
self.email_notif_for_new_users = email_config.get( | |||
@@ -311,7 +312,8 @@ class EmailConfig(Config): | |||
self.account_validity_template_html, | |||
self.account_validity_template_text, | |||
) = self.read_templates( | |||
[expiry_template_html, expiry_template_text], template_dir, | |||
[expiry_template_html, expiry_template_text], | |||
template_dir, | |||
) | |||
subjects_config = email_config.get("subjects", {}) | |||
@@ -162,7 +162,10 @@ class LoggingConfig(Config): | |||
) | |||
logging_group.add_argument( | |||
"-f", "--log-file", dest="log_file", help=argparse.SUPPRESS, | |||
"-f", | |||
"--log-file", | |||
dest="log_file", | |||
help=argparse.SUPPRESS, | |||
) | |||
def generate_files(self, config, config_dir_path): | |||
@@ -355,9 +355,10 @@ def _parse_oidc_config_dict( | |||
ump_config.setdefault("module", DEFAULT_USER_MAPPING_PROVIDER) | |||
ump_config.setdefault("config", {}) | |||
(user_mapping_provider_class, user_mapping_provider_config,) = load_module( | |||
ump_config, config_path + ("user_mapping_provider",) | |||
) | |||
( | |||
user_mapping_provider_class, | |||
user_mapping_provider_config, | |||
) = load_module(ump_config, config_path + ("user_mapping_provider",)) | |||
# Ensure loaded user mapping module has defined all necessary methods | |||
required_methods = [ | |||
@@ -372,7 +373,11 @@ def _parse_oidc_config_dict( | |||
if missing_methods: | |||
raise ConfigError( | |||
"Class %s is missing required " | |||
"methods: %s" % (user_mapping_provider_class, ", ".join(missing_methods),), | |||
"methods: %s" | |||
% ( | |||
user_mapping_provider_class, | |||
", ".join(missing_methods), | |||
), | |||
config_path + ("user_mapping_provider", "module"), | |||
) | |||
@@ -52,7 +52,7 @@ MediaStorageProviderConfig = namedtuple( | |||
def parse_thumbnail_requirements(thumbnail_sizes): | |||
""" Takes a list of dictionaries with "width", "height", and "method" keys | |||
"""Takes a list of dictionaries with "width", "height", and "method" keys | |||
and creates a map from image media types to the thumbnail size, thumbnailing | |||
method, and thumbnail media type to precalculate | |||
@@ -52,7 +52,12 @@ def _6to4(network: IPNetwork) -> IPNetwork: | |||
hex_network = hex(network.first)[2:] | |||
hex_network = ("0" * (8 - len(hex_network))) + hex_network | |||
return IPNetwork( | |||
"2002:%s:%s::/%d" % (hex_network[:4], hex_network[4:], 16 + network.prefixlen,) | |||
"2002:%s:%s::/%d" | |||
% ( | |||
hex_network[:4], | |||
hex_network[4:], | |||
16 + network.prefixlen, | |||
) | |||
) | |||
@@ -254,7 +259,8 @@ class ServerConfig(Config): | |||
# Whether to require sharing a room with a user to retrieve their | |||
# profile data | |||
self.limit_profile_requests_to_users_who_share_rooms = config.get( | |||
"limit_profile_requests_to_users_who_share_rooms", False, | |||
"limit_profile_requests_to_users_who_share_rooms", | |||
False, | |||
) | |||
if "restrict_public_rooms_to_local_users" in config and ( | |||
@@ -614,7 +620,9 @@ class ServerConfig(Config): | |||
if manhole: | |||
self.listeners.append( | |||
ListenerConfig( | |||
port=manhole, bind_addresses=["127.0.0.1"], type="manhole", | |||
port=manhole, | |||
bind_addresses=["127.0.0.1"], | |||
type="manhole", | |||
) | |||
) | |||
@@ -650,7 +658,8 @@ class ServerConfig(Config): | |||
# and letting the client know which email address is bound to an account and | |||
# which one isn't. | |||
self.request_token_inhibit_3pid_errors = config.get( | |||
"request_token_inhibit_3pid_errors", False, | |||
"request_token_inhibit_3pid_errors", | |||
False, | |||
) | |||
# List of users trialing the new experimental default push rules. This setting is | |||
@@ -35,8 +35,7 @@ class SsoAttributeRequirement: | |||
class SSOConfig(Config): | |||
"""SSO Configuration | |||
""" | |||
"""SSO Configuration""" | |||
section = "sso" | |||
@@ -33,8 +33,7 @@ def _instance_to_list_converter(obj: Union[str, List[str]]) -> List[str]: | |||
@attr.s | |||
class InstanceLocationConfig: | |||
"""The host and port to talk to an instance via HTTP replication. | |||
""" | |||
"""The host and port to talk to an instance via HTTP replication.""" | |||
host = attr.ib(type=str) | |||
port = attr.ib(type=int) | |||
@@ -54,13 +53,19 @@ class WriterLocations: | |||
) | |||
typing = attr.ib(default="master", type=str) | |||
to_device = attr.ib( | |||
default=["master"], type=List[str], converter=_instance_to_list_converter, | |||
default=["master"], | |||
type=List[str], | |||
converter=_instance_to_list_converter, | |||
) | |||
account_data = attr.ib( | |||
default=["master"], type=List[str], converter=_instance_to_list_converter, | |||
default=["master"], | |||
type=List[str], | |||
converter=_instance_to_list_converter, | |||
) | |||
receipts = attr.ib( | |||
default=["master"], type=List[str], converter=_instance_to_list_converter, | |||
default=["master"], | |||
type=List[str], | |||
converter=_instance_to_list_converter, | |||
) | |||
@@ -107,7 +112,9 @@ class WorkerConfig(Config): | |||
if manhole: | |||
self.worker_listeners.append( | |||
ListenerConfig( | |||
port=manhole, bind_addresses=["127.0.0.1"], type="manhole", | |||
port=manhole, | |||
bind_addresses=["127.0.0.1"], | |||
type="manhole", | |||
) | |||
) | |||
@@ -42,7 +42,7 @@ def check( | |||
do_sig_check: bool = True, | |||
do_size_check: bool = True, | |||
) -> None: | |||
""" Checks if this event is correctly authed. | |||
"""Checks if this event is correctly authed. | |||
Args: | |||
room_version_obj: the version of the room | |||
@@ -423,7 +423,9 @@ def _can_send_event(event: EventBase, auth_events: StateMap[EventBase]) -> bool: | |||
def check_redaction( | |||
room_version_obj: RoomVersion, event: EventBase, auth_events: StateMap[EventBase], | |||
room_version_obj: RoomVersion, | |||
event: EventBase, | |||
auth_events: StateMap[EventBase], | |||
) -> bool: | |||
"""Check whether the event sender is allowed to redact the target event. | |||
@@ -459,7 +461,9 @@ def check_redaction( | |||
def _check_power_levels( | |||
room_version_obj: RoomVersion, event: EventBase, auth_events: StateMap[EventBase], | |||
room_version_obj: RoomVersion, | |||
event: EventBase, | |||
auth_events: StateMap[EventBase], | |||
) -> None: | |||
user_list = event.content.get("users", {}) | |||
# Validate users | |||
@@ -98,7 +98,9 @@ class EventBuilder: | |||
return self._state_key is not None | |||
async def build( | |||
self, prev_event_ids: List[str], auth_event_ids: Optional[List[str]], | |||
self, | |||
prev_event_ids: List[str], | |||
auth_event_ids: Optional[List[str]], | |||
) -> EventBase: | |||
"""Transform into a fully signed and hashed event | |||
@@ -341,8 +341,7 @@ def _encode_state_dict(state_dict): | |||
def _decode_state_dict(input): | |||
"""Decodes a state dict encoded using `_encode_state_dict` above | |||
""" | |||
"""Decodes a state dict encoded using `_encode_state_dict` above""" | |||
if input is None: | |||
return None | |||
@@ -40,7 +40,8 @@ class ThirdPartyEventRules: | |||
if module is not None: | |||
self.third_party_rules = module( | |||
config=config, module_api=hs.get_module_api(), | |||
config=config, | |||
module_api=hs.get_module_api(), | |||
) | |||
async def check_event_allowed( | |||
@@ -34,7 +34,7 @@ SPLIT_FIELD_REGEX = re.compile(r"(?<!\\)\.") | |||
def prune_event(event: EventBase) -> EventBase: | |||
""" Returns a pruned version of the given event, which removes all keys we | |||
"""Returns a pruned version of the given event, which removes all keys we | |||
don't know about or think could potentially be dodgy. | |||
This is used when we "redact" an event. We want to remove all fields that | |||
@@ -750,7 +750,11 @@ class FederationClient(FederationBase): | |||
return resp[1] | |||
async def send_invite( | |||
self, destination: str, room_id: str, event_id: str, pdu: EventBase, | |||
self, | |||
destination: str, | |||
room_id: str, | |||
event_id: str, | |||
pdu: EventBase, | |||
) -> EventBase: | |||
room_version = await self.store.get_room_version(room_id) | |||
@@ -85,7 +85,8 @@ received_queries_counter = Counter( | |||
) | |||
pdu_process_time = Histogram( | |||
"synapse_federation_server_pdu_process_time", "Time taken to process an event", | |||
"synapse_federation_server_pdu_process_time", | |||
"Time taken to process an event", | |||
) | |||
@@ -204,7 +205,7 @@ class FederationServer(FederationBase): | |||
async def _handle_incoming_transaction( | |||
self, origin: str, transaction: Transaction, request_time: int | |||
) -> Tuple[int, Dict[str, Any]]: | |||
""" Process an incoming transaction and return the HTTP response | |||
"""Process an incoming transaction and return the HTTP response | |||
Args: | |||
origin: the server making the request | |||
@@ -373,8 +374,7 @@ class FederationServer(FederationBase): | |||
return pdu_results | |||
async def _handle_edus_in_txn(self, origin: str, transaction: Transaction): | |||
"""Process the EDUs in a received transaction. | |||
""" | |||
"""Process the EDUs in a received transaction.""" | |||
async def _process_edu(edu_dict): | |||
received_edus_counter.inc() | |||
@@ -437,7 +437,10 @@ class FederationServer(FederationBase): | |||
raise AuthError(403, "Host not in room.") | |||
resp = await self._state_ids_resp_cache.wrap( | |||
(room_id, event_id), self._on_state_ids_request_compute, room_id, event_id, | |||
(room_id, event_id), | |||
self._on_state_ids_request_compute, | |||
room_id, | |||
event_id, | |||
) | |||
return 200, resp | |||
@@ -679,7 +682,7 @@ class FederationServer(FederationBase): | |||
) | |||
async def _handle_received_pdu(self, origin: str, pdu: EventBase) -> None: | |||
""" Process a PDU received in a federation /send/ transaction. | |||
"""Process a PDU received in a federation /send/ transaction. | |||
If the event is invalid, then this method throws a FederationError. | |||
(The error will then be logged and sent back to the sender (which | |||
@@ -906,13 +909,11 @@ class FederationHandlerRegistry: | |||
self.query_handlers[query_type] = handler | |||
def register_instance_for_edu(self, edu_type: str, instance_name: str): | |||
"""Register that the EDU handler is on a different instance than master. | |||
""" | |||
"""Register that the EDU handler is on a different instance than master.""" | |||
self._edu_type_to_instance[edu_type] = [instance_name] | |||
def register_instances_for_edu(self, edu_type: str, instance_names: List[str]): | |||
"""Register that the EDU handler is on multiple instances. | |||
""" | |||
"""Register that the EDU handler is on multiple instances.""" | |||
self._edu_type_to_instance[edu_type] = instance_names | |||
async def on_edu(self, edu_type: str, origin: str, content: dict): | |||
@@ -30,8 +30,7 @@ logger = logging.getLogger(__name__) | |||
class TransactionActions: | |||
""" Defines persistence actions that relate to handling Transactions. | |||
""" | |||
"""Defines persistence actions that relate to handling Transactions.""" | |||
def __init__(self, datastore): | |||
self.store = datastore | |||
@@ -57,8 +56,7 @@ class TransactionActions: | |||
async def set_response( | |||
self, origin: str, transaction: Transaction, code: int, response: JsonDict | |||
) -> None: | |||
"""Persist how we responded to a transaction. | |||
""" | |||
"""Persist how we responded to a transaction.""" | |||
transaction_id = transaction.transaction_id # type: ignore | |||
if not transaction_id: | |||
raise RuntimeError("Cannot persist a transaction with no transaction_id") | |||
@@ -468,8 +468,7 @@ class KeyedEduRow( | |||
class EduRow(BaseFederationRow, namedtuple("EduRow", ("edu",))): # Edu | |||
"""Streams EDUs that don't have keys. See KeyedEduRow | |||
""" | |||
"""Streams EDUs that don't have keys. See KeyedEduRow""" | |||
TypeId = "e" | |||
@@ -519,7 +518,10 @@ def process_rows_for_federation(transaction_queue, rows): | |||
# them into the appropriate collection and then send them off. | |||
buff = ParsedFederationStreamData( | |||
presence=[], presence_destinations=[], keyed_edus={}, edus={}, | |||
presence=[], | |||
presence_destinations=[], | |||
keyed_edus={}, | |||
edus={}, | |||
) | |||
# Parse the rows in the stream and add to the buffer | |||
@@ -328,7 +328,9 @@ class FederationSender: | |||
# to allow us to perform catch-up later on if the remote is unreachable | |||
# for a while. | |||
await self.store.store_destination_rooms_entries( | |||
destinations, pdu.room_id, pdu.internal_metadata.stream_ordering, | |||
destinations, | |||
pdu.room_id, | |||
pdu.internal_metadata.stream_ordering, | |||
) | |||
for destination in destinations: | |||
@@ -475,7 +477,7 @@ class FederationSender: | |||
self, states: List[UserPresenceState], destinations: List[str] | |||
) -> None: | |||
"""Send the given presence states to the given destinations. | |||
destinations (list[str]) | |||
destinations (list[str]) | |||
""" | |||
if not states or not self.hs.config.use_presence: | |||
@@ -616,8 +618,8 @@ class FederationSender: | |||
last_processed = None # type: Optional[str] | |||
while True: | |||
destinations_to_wake = await self.store.get_catch_up_outstanding_destinations( | |||
last_processed | |||
destinations_to_wake = ( | |||
await self.store.get_catch_up_outstanding_destinations(last_processed) | |||
) | |||
if not destinations_to_wake: | |||
@@ -85,7 +85,8 @@ class PerDestinationQueue: | |||
# processing. We have a guard in `attempt_new_transaction` that | |||
# ensure we don't start sending stuff. | |||
logger.error( | |||
"Create a per destination queue for %s on wrong worker", destination, | |||
"Create a per destination queue for %s on wrong worker", | |||
destination, | |||
) | |||
self._should_send_on_this_instance = False | |||
@@ -440,8 +441,10 @@ class PerDestinationQueue: | |||
if first_catch_up_check: | |||
# first catchup so get last_successful_stream_ordering from database | |||
self._last_successful_stream_ordering = await self._store.get_destination_last_successful_stream_ordering( | |||
self._destination | |||
self._last_successful_stream_ordering = ( | |||
await self._store.get_destination_last_successful_stream_ordering( | |||
self._destination | |||
) | |||
) | |||
if self._last_successful_stream_ordering is None: | |||
@@ -457,7 +460,8 @@ class PerDestinationQueue: | |||
# get at most 50 catchup room/PDUs | |||
while True: | |||
event_ids = await self._store.get_catch_up_room_event_ids( | |||
self._destination, self._last_successful_stream_ordering, | |||
self._destination, | |||
self._last_successful_stream_ordering, | |||
) | |||
if not event_ids: | |||
@@ -65,7 +65,10 @@ class TransactionManager: | |||
@measure_func("_send_new_transaction") | |||
async def send_new_transaction( | |||
self, destination: str, pdus: List[EventBase], edus: List[Edu], | |||
self, | |||
destination: str, | |||
pdus: List[EventBase], | |||
edus: List[Edu], | |||
) -> bool: | |||
""" | |||
Args: | |||
@@ -39,7 +39,7 @@ class TransportLayerClient: | |||
@log_function | |||
def get_room_state_ids(self, destination, room_id, event_id): | |||
""" Requests all state for a given room from the given server at the | |||
"""Requests all state for a given room from the given server at the | |||
given event. Returns the state's event_id's | |||
Args: | |||
@@ -63,7 +63,7 @@ class TransportLayerClient: | |||
@log_function | |||
def get_event(self, destination, event_id, timeout=None): | |||
""" Requests the pdu with give id and origin from the given server. | |||
"""Requests the pdu with give id and origin from the given server. | |||
Args: | |||
destination (str): The host name of the remote homeserver we want | |||
@@ -84,7 +84,7 @@ class TransportLayerClient: | |||
@log_function | |||
def backfill(self, destination, room_id, event_tuples, limit): | |||
""" Requests `limit` previous PDUs in a given context before list of | |||
"""Requests `limit` previous PDUs in a given context before list of | |||
PDUs. | |||
Args: | |||
@@ -118,7 +118,7 @@ class TransportLayerClient: | |||
@log_function | |||
async def send_transaction(self, transaction, json_data_callback=None): | |||
""" Sends the given Transaction to its destination | |||
"""Sends the given Transaction to its destination | |||
Args: | |||
transaction (Transaction) | |||
@@ -551,8 +551,7 @@ class TransportLayerClient: | |||
@log_function | |||
def get_group_profile(self, destination, group_id, requester_user_id): | |||
"""Get a group profile | |||
""" | |||
"""Get a group profile""" | |||
path = _create_v1_path("/groups/%s/profile", group_id) | |||
return self.client.get_json( | |||
@@ -584,8 +583,7 @@ class TransportLayerClient: | |||
@log_function | |||
def get_group_summary(self, destination, group_id, requester_user_id): | |||
"""Get a group summary | |||
""" | |||
"""Get a group summary""" | |||
path = _create_v1_path("/groups/%s/summary", group_id) | |||
return self.client.get_json( | |||
@@ -597,8 +595,7 @@ class TransportLayerClient: | |||
@log_function | |||
def get_rooms_in_group(self, destination, group_id, requester_user_id): | |||
"""Get all rooms in a group | |||
""" | |||
"""Get all rooms in a group""" | |||
path = _create_v1_path("/groups/%s/rooms", group_id) | |||
return self.client.get_json( | |||
@@ -611,8 +608,7 @@ class TransportLayerClient: | |||
def add_room_to_group( | |||
self, destination, group_id, requester_user_id, room_id, content | |||
): | |||
"""Add a room to a group | |||
""" | |||
"""Add a room to a group""" | |||
path = _create_v1_path("/groups/%s/room/%s", group_id, room_id) | |||
return self.client.post_json( | |||
@@ -626,8 +622,7 @@ class TransportLayerClient: | |||
def update_room_in_group( | |||
self, destination, group_id, requester_user_id, room_id, config_key, content | |||
): | |||
"""Update room in group | |||
""" | |||
"""Update room in group""" | |||
path = _create_v1_path( | |||
"/groups/%s/room/%s/config/%s", group_id, room_id, config_key | |||
) | |||
@@ -641,8 +636,7 @@ class TransportLayerClient: | |||
) | |||
def remove_room_from_group(self, destination, group_id, requester_user_id, room_id): | |||
"""Remove a room from a group | |||
""" | |||
"""Remove a room from a group""" | |||
path = _create_v1_path("/groups/%s/room/%s", group_id, room_id) | |||
return self.client.delete_json( | |||
@@ -654,8 +648,7 @@ class TransportLayerClient: | |||
@log_function | |||
def get_users_in_group(self, destination, group_id, requester_user_id): | |||
"""Get users in a group | |||
""" | |||
"""Get users in a group""" | |||
path = _create_v1_path("/groups/%s/users", group_id) | |||
return self.client.get_json( | |||
@@ -667,8 +660,7 @@ class TransportLayerClient: | |||
@log_function | |||
def get_invited_users_in_group(self, destination, group_id, requester_user_id): | |||
"""Get users that have been invited to a group | |||
""" | |||
"""Get users that have been invited to a group""" | |||
path = _create_v1_path("/groups/%s/invited_users", group_id) | |||
return self.client.get_json( | |||
@@ -680,8 +672,7 @@ class TransportLayerClient: | |||
@log_function | |||
def accept_group_invite(self, destination, group_id, user_id, content): | |||
"""Accept a group invite | |||
""" | |||
"""Accept a group invite""" | |||
path = _create_v1_path("/groups/%s/users/%s/accept_invite", group_id, user_id) | |||
return self.client.post_json( | |||
@@ -690,8 +681,7 @@ class TransportLayerClient: | |||
@log_function | |||
def join_group(self, destination, group_id, user_id, content): | |||
"""Attempts to join a group | |||
""" | |||
"""Attempts to join a group""" | |||
path = _create_v1_path("/groups/%s/users/%s/join", group_id, user_id) | |||
return self.client.post_json( | |||
@@ -702,8 +692,7 @@ class TransportLayerClient: | |||
def invite_to_group( | |||
self, destination, group_id, user_id, requester_user_id, content | |||
): | |||
"""Invite a user to a group | |||
""" | |||
"""Invite a user to a group""" | |||
path = _create_v1_path("/groups/%s/users/%s/invite", group_id, user_id) | |||
return self.client.post_json( | |||
@@ -730,8 +719,7 @@ class TransportLayerClient: | |||
def remove_user_from_group( | |||
self, destination, group_id, requester_user_id, user_id, content | |||
): | |||
"""Remove a user from a group | |||
""" | |||
"""Remove a user from a group""" | |||
path = _create_v1_path("/groups/%s/users/%s/remove", group_id, user_id) | |||
return self.client.post_json( | |||
@@ -772,8 +760,7 @@ class TransportLayerClient: | |||
def update_group_summary_room( | |||
self, destination, group_id, user_id, room_id, category_id, content | |||
): | |||
"""Update a room entry in a group summary | |||
""" | |||
"""Update a room entry in a group summary""" | |||
if category_id: | |||
path = _create_v1_path( | |||
"/groups/%s/summary/categories/%s/rooms/%s", | |||
@@ -796,8 +783,7 @@ class TransportLayerClient: | |||
def delete_group_summary_room( | |||
self, destination, group_id, user_id, room_id, category_id | |||
): | |||
"""Delete a room entry in a group summary | |||
""" | |||
"""Delete a room entry in a group summary""" | |||
if category_id: | |||
path = _create_v1_path( | |||
"/groups/%s/summary/categories/%s/rooms/%s", | |||
@@ -817,8 +803,7 @@ class TransportLayerClient: | |||
@log_function | |||
def get_group_categories(self, destination, group_id, requester_user_id): | |||
"""Get all categories in a group | |||
""" | |||
"""Get all categories in a group""" | |||
path = _create_v1_path("/groups/%s/categories", group_id) | |||
return self.client.get_json( | |||
@@ -830,8 +815,7 @@ class TransportLayerClient: | |||
@log_function | |||
def get_group_category(self, destination, group_id, requester_user_id, category_id): | |||
"""Get category info in a group | |||
""" | |||
"""Get category info in a group""" | |||
path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id) | |||
return self.client.get_json( | |||
@@ -845,8 +829,7 @@ class TransportLayerClient: | |||
def update_group_category( | |||
self, destination, group_id, requester_user_id, category_id, content | |||
): | |||
"""Update a category in a group | |||
""" | |||
"""Update a category in a group""" | |||
path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id) | |||
return self.client.post_json( | |||
@@ -861,8 +844,7 @@ class TransportLayerClient: | |||
def delete_group_category( | |||
self, destination, group_id, requester_user_id, category_id | |||
): | |||
"""Delete a category in a group | |||
""" | |||
"""Delete a category in a group""" | |||
path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id) | |||
return self.client.delete_json( | |||
@@ -874,8 +856,7 @@ class TransportLayerClient: | |||
@log_function | |||
def get_group_roles(self, destination, group_id, requester_user_id): | |||
"""Get all roles in a group | |||
""" | |||
"""Get all roles in a group""" | |||
path = _create_v1_path("/groups/%s/roles", group_id) | |||
return self.client.get_json( | |||
@@ -887,8 +868,7 @@ class TransportLayerClient: | |||
@log_function | |||
def get_group_role(self, destination, group_id, requester_user_id, role_id): | |||
"""Get a roles info | |||
""" | |||
"""Get a roles info""" | |||
path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id) | |||
return self.client.get_json( | |||
@@ -902,8 +882,7 @@ class TransportLayerClient: | |||
def update_group_role( | |||
self, destination, group_id, requester_user_id, role_id, content | |||
): | |||
"""Update a role in a group | |||
""" | |||
"""Update a role in a group""" | |||
path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id) | |||
return self.client.post_json( | |||
@@ -916,8 +895,7 @@ class TransportLayerClient: | |||
@log_function | |||
def delete_group_role(self, destination, group_id, requester_user_id, role_id): | |||
"""Delete a role in a group | |||
""" | |||
"""Delete a role in a group""" | |||
path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id) | |||
return self.client.delete_json( | |||
@@ -931,8 +909,7 @@ class TransportLayerClient: | |||
def update_group_summary_user( | |||
self, destination, group_id, requester_user_id, user_id, role_id, content | |||
): | |||
"""Update a users entry in a group | |||
""" | |||
"""Update a users entry in a group""" | |||
if role_id: | |||
path = _create_v1_path( | |||
"/groups/%s/summary/roles/%s/users/%s", group_id, role_id, user_id | |||
@@ -950,8 +927,7 @@ class TransportLayerClient: | |||
@log_function | |||
def set_group_join_policy(self, destination, group_id, requester_user_id, content): | |||
"""Sets the join policy for a group | |||
""" | |||
"""Sets the join policy for a group""" | |||
path = _create_v1_path("/groups/%s/settings/m.join_policy", group_id) | |||
return self.client.put_json( | |||
@@ -966,8 +942,7 @@ class TransportLayerClient: | |||
def delete_group_summary_user( | |||
self, destination, group_id, requester_user_id, user_id, role_id | |||
): | |||
"""Delete a users entry in a group | |||
""" | |||
"""Delete a users entry in a group""" | |||
if role_id: | |||
path = _create_v1_path( | |||
"/groups/%s/summary/roles/%s/users/%s", group_id, role_id, user_id | |||
@@ -983,8 +958,7 @@ class TransportLayerClient: | |||
) | |||
def bulk_get_publicised_groups(self, destination, user_ids): | |||
"""Get the groups a list of users are publicising | |||
""" | |||
"""Get the groups a list of users are publicising""" | |||
path = _create_v1_path("/get_groups_publicised") | |||
@@ -364,7 +364,10 @@ class BaseFederationServlet: | |||
continue | |||
server.register_paths( | |||
method, (pattern,), self._wrap(code), self.__class__.__name__, | |||
method, | |||
(pattern,), | |||
self._wrap(code), | |||
self.__class__.__name__, | |||
) | |||
@@ -381,7 +384,7 @@ class FederationSendServlet(BaseFederationServlet): | |||
# This is when someone is trying to send us a bunch of data. | |||
async def on_PUT(self, origin, content, query, transaction_id): | |||
""" Called on PUT /send/<transaction_id>/ | |||
"""Called on PUT /send/<transaction_id>/ | |||
Args: | |||
request (twisted.web.http.Request): The HTTP request. | |||
@@ -855,8 +858,7 @@ class FederationVersionServlet(BaseFederationServlet): | |||
class FederationGroupsProfileServlet(BaseFederationServlet): | |||
"""Get/set the basic profile of a group on behalf of a user | |||
""" | |||
"""Get/set the basic profile of a group on behalf of a user""" | |||
PATH = "/groups/(?P<group_id>[^/]*)/profile" | |||
@@ -895,8 +897,7 @@ class FederationGroupsSummaryServlet(BaseFederationServlet): | |||
class FederationGroupsRoomsServlet(BaseFederationServlet): | |||
"""Get the rooms in a group on behalf of a user | |||
""" | |||
"""Get the rooms in a group on behalf of a user""" | |||
PATH = "/groups/(?P<group_id>[^/]*)/rooms" | |||
@@ -911,8 +912,7 @@ class FederationGroupsRoomsServlet(BaseFederationServlet): | |||
class FederationGroupsAddRoomsServlet(BaseFederationServlet): | |||
"""Add/remove room from group | |||
""" | |||
"""Add/remove room from group""" | |||
PATH = "/groups/(?P<group_id>[^/]*)/room/(?P<room_id>[^/]*)" | |||
@@ -940,8 +940,7 @@ class FederationGroupsAddRoomsServlet(BaseFederationServlet): | |||
class FederationGroupsAddRoomsConfigServlet(BaseFederationServlet): | |||
"""Update room config in group | |||
""" | |||
"""Update room config in group""" | |||
PATH = ( | |||
"/groups/(?P<group_id>[^/]*)/room/(?P<room_id>[^/]*)" | |||
@@ -961,8 +960,7 @@ class FederationGroupsAddRoomsConfigServlet(BaseFederationServlet): | |||
class FederationGroupsUsersServlet(BaseFederationServlet): | |||
"""Get the users in a group on behalf of a user | |||
""" | |||
"""Get the users in a group on behalf of a user""" | |||
PATH = "/groups/(?P<group_id>[^/]*)/users" | |||
@@ -977,8 +975,7 @@ class FederationGroupsUsersServlet(BaseFederationServlet): | |||
class FederationGroupsInvitedUsersServlet(BaseFederationServlet): | |||
"""Get the users that have been invited to a group | |||
""" | |||
"""Get the users that have been invited to a group""" | |||
PATH = "/groups/(?P<group_id>[^/]*)/invited_users" | |||
@@ -995,8 +992,7 @@ class FederationGroupsInvitedUsersServlet(BaseFederationServlet): | |||
class FederationGroupsInviteServlet(BaseFederationServlet): | |||
"""Ask a group server to invite someone to the group | |||
""" | |||
"""Ask a group server to invite someone to the group""" | |||
PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/invite" | |||
@@ -1013,8 +1009,7 @@ class FederationGroupsInviteServlet(BaseFederationServlet): | |||
class FederationGroupsAcceptInviteServlet(BaseFederationServlet): | |||
"""Accept an invitation from the group server | |||
""" | |||
"""Accept an invitation from the group server""" | |||
PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/accept_invite" | |||
@@ -1028,8 +1023,7 @@ class FederationGroupsAcceptInviteServlet(BaseFederationServlet): | |||
class FederationGroupsJoinServlet(BaseFederationServlet): | |||
"""Attempt to join a group | |||
""" | |||
"""Attempt to join a group""" | |||
PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/join" | |||
@@ -1043,8 +1037,7 @@ class FederationGroupsJoinServlet(BaseFederationServlet): | |||
class FederationGroupsRemoveUserServlet(BaseFederationServlet): | |||
"""Leave or kick a user from the group | |||
""" | |||
"""Leave or kick a user from the group""" | |||
PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/remove" | |||
@@ -1061,8 +1054,7 @@ class FederationGroupsRemoveUserServlet(BaseFederationServlet): | |||
class FederationGroupsLocalInviteServlet(BaseFederationServlet): | |||
"""A group server has invited a local user | |||
""" | |||
"""A group server has invited a local user""" | |||
PATH = "/groups/local/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/invite" | |||
@@ -1076,8 +1068,7 @@ class FederationGroupsLocalInviteServlet(BaseFederationServlet): | |||
class FederationGroupsRemoveLocalUserServlet(BaseFederationServlet): | |||
"""A group server has removed a local user | |||
""" | |||
"""A group server has removed a local user""" | |||
PATH = "/groups/local/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/remove" | |||
@@ -1093,8 +1084,7 @@ class FederationGroupsRemoveLocalUserServlet(BaseFederationServlet): | |||
class FederationGroupsRenewAttestaionServlet(BaseFederationServlet): | |||
"""A group or user's server renews their attestation | |||
""" | |||
"""A group or user's server renews their attestation""" | |||
PATH = "/groups/(?P<group_id>[^/]*)/renew_attestation/(?P<user_id>[^/]*)" | |||
@@ -1156,8 +1146,7 @@ class FederationGroupsSummaryRoomsServlet(BaseFederationServlet): | |||
class FederationGroupsCategoriesServlet(BaseFederationServlet): | |||
"""Get all categories for a group | |||
""" | |||
"""Get all categories for a group""" | |||
PATH = "/groups/(?P<group_id>[^/]*)/categories/?" | |||
@@ -1172,8 +1161,7 @@ class FederationGroupsCategoriesServlet(BaseFederationServlet): | |||
class FederationGroupsCategoryServlet(BaseFederationServlet): | |||
"""Add/remove/get a category in a group | |||
""" | |||
"""Add/remove/get a category in a group""" | |||
PATH = "/groups/(?P<group_id>[^/]*)/categories/(?P<category_id>[^/]+)" | |||
@@ -1218,8 +1206,7 @@ class FederationGroupsCategoryServlet(BaseFederationServlet): | |||
class FederationGroupsRolesServlet(BaseFederationServlet): | |||
"""Get roles in a group | |||
""" | |||
"""Get roles in a group""" | |||
PATH = "/groups/(?P<group_id>[^/]*)/roles/?" | |||
@@ -1234,8 +1221,7 @@ class FederationGroupsRolesServlet(BaseFederationServlet): | |||
class FederationGroupsRoleServlet(BaseFederationServlet): | |||
"""Add/remove/get a role in a group | |||
""" | |||
"""Add/remove/get a role in a group""" | |||
PATH = "/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)" | |||
@@ -1325,8 +1311,7 @@ class FederationGroupsSummaryUsersServlet(BaseFederationServlet): | |||
class FederationGroupsBulkPublicisedServlet(BaseFederationServlet): | |||
"""Get roles in a group | |||
""" | |||
"""Get roles in a group""" | |||
PATH = "/get_groups_publicised" | |||
@@ -1339,8 +1324,7 @@ class FederationGroupsBulkPublicisedServlet(BaseFederationServlet): | |||
class FederationGroupsSettingJoinPolicyServlet(BaseFederationServlet): | |||
"""Sets whether a group is joinable without an invite or knock | |||
""" | |||
"""Sets whether a group is joinable without an invite or knock""" | |||
PATH = "/groups/(?P<group_id>[^/]*)/settings/m.join_policy" | |||
@@ -29,7 +29,7 @@ logger = logging.getLogger(__name__) | |||
@attr.s(slots=True) | |||
class Edu(JsonEncodedObject): | |||
""" An Edu represents a piece of data sent from one homeserver to another. | |||
"""An Edu represents a piece of data sent from one homeserver to another. | |||
In comparison to Pdus, Edus are not persisted for a long time on disk, are | |||
not meaningful beyond a given pair of homeservers, and don't have an | |||
@@ -63,7 +63,7 @@ class Edu(JsonEncodedObject): | |||
class Transaction(JsonEncodedObject): | |||
""" A transaction is a list of Pdus and Edus to be sent to a remote home | |||
"""A transaction is a list of Pdus and Edus to be sent to a remote home | |||
server with some extra metadata. | |||
Example transaction:: | |||
@@ -99,7 +99,7 @@ class Transaction(JsonEncodedObject): | |||
] | |||
def __init__(self, transaction_id=None, pdus=[], **kwargs): | |||
""" If we include a list of pdus then we decode then as PDU's | |||
"""If we include a list of pdus then we decode then as PDU's | |||
automatically. | |||
""" | |||
@@ -111,7 +111,7 @@ class Transaction(JsonEncodedObject): | |||
@staticmethod | |||
def create_new(pdus, **kwargs): | |||
""" Used to create a new transaction. Will auto fill out | |||
"""Used to create a new transaction. Will auto fill out | |||
transaction_id and origin_server_ts keys. | |||
""" | |||
if "origin_server_ts" not in kwargs: | |||
@@ -61,8 +61,7 @@ UPDATE_ATTESTATION_TIME_MS = 1 * 24 * 60 * 60 * 1000 | |||
class GroupAttestationSigning: | |||
"""Creates and verifies group attestations. | |||
""" | |||
"""Creates and verifies group attestations.""" | |||
def __init__(self, hs): | |||
self.keyring = hs.get_keyring() | |||
@@ -125,8 +124,7 @@ class GroupAttestationSigning: | |||
class GroupAttestionRenewer: | |||
"""Responsible for sending and receiving attestation updates. | |||
""" | |||
"""Responsible for sending and receiving attestation updates.""" | |||
def __init__(self, hs): | |||
self.clock = hs.get_clock() | |||
@@ -142,8 +140,7 @@ class GroupAttestionRenewer: | |||
) | |||
async def on_renew_attestation(self, group_id, user_id, content): | |||
"""When a remote updates an attestation | |||
""" | |||
"""When a remote updates an attestation""" | |||
attestation = content["attestation"] | |||
if not self.is_mine_id(group_id) and not self.is_mine_id(user_id): | |||
@@ -161,8 +158,7 @@ class GroupAttestionRenewer: | |||
return run_as_background_process("renew_attestations", self._renew_attestations) | |||
async def _renew_attestations(self): | |||
"""Called periodically to check if we need to update any of our attestations | |||
""" | |||
"""Called periodically to check if we need to update any of our attestations""" | |||
now = self.clock.time_msec() | |||
@@ -165,16 +165,14 @@ class GroupsServerWorkerHandler: | |||
} | |||
async def get_group_categories(self, group_id, requester_user_id): | |||
"""Get all categories in a group (as seen by user) | |||
""" | |||
"""Get all categories in a group (as seen by user)""" | |||
await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) | |||
categories = await self.store.get_group_categories(group_id=group_id) | |||
return {"categories": categories} | |||
async def get_group_category(self, group_id, requester_user_id, category_id): | |||
"""Get a specific category in a group (as seen by user) | |||
""" | |||
"""Get a specific category in a group (as seen by user)""" | |||
await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) | |||
res = await self.store.get_group_category( | |||
@@ -186,24 +184,21 @@ class GroupsServerWorkerHandler: | |||
return res | |||
async def get_group_roles(self, group_id, requester_user_id): | |||
"""Get all roles in a group (as seen by user) | |||
""" | |||
"""Get all roles in a group (as seen by user)""" | |||
await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) | |||
roles = await self.store.get_group_roles(group_id=group_id) | |||
return {"roles": roles} | |||
async def get_group_role(self, group_id, requester_user_id, role_id): | |||
"""Get a specific role in a group (as seen by user) | |||
""" | |||
"""Get a specific role in a group (as seen by user)""" | |||
await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) | |||
res = await self.store.get_group_role(group_id=group_id, role_id=role_id) | |||
return res | |||
async def get_group_profile(self, group_id, requester_user_id): | |||
"""Get the group profile as seen by requester_user_id | |||
""" | |||
"""Get the group profile as seen by requester_user_id""" | |||
await self.check_group_is_ours(group_id, requester_user_id) | |||
@@ -350,8 +345,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler): | |||
async def update_group_summary_room( | |||
self, group_id, requester_user_id, room_id, category_id, content | |||
): | |||
"""Add/update a room to the group summary | |||
""" | |||
"""Add/update a room to the group summary""" | |||
await self.check_group_is_ours( | |||
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id | |||
) | |||
@@ -375,8 +369,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler): | |||
async def delete_group_summary_room( | |||
self, group_id, requester_user_id, room_id, category_id | |||
): | |||
"""Remove a room from the summary | |||
""" | |||
"""Remove a room from the summary""" | |||
await self.check_group_is_ours( | |||
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id | |||
) | |||
@@ -409,8 +402,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler): | |||
async def update_group_category( | |||
self, group_id, requester_user_id, category_id, content | |||
): | |||
"""Add/Update a group category | |||
""" | |||
"""Add/Update a group category""" | |||
await self.check_group_is_ours( | |||
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id | |||
) | |||
@@ -428,8 +420,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler): | |||
return {} | |||
async def delete_group_category(self, group_id, requester_user_id, category_id): | |||
"""Delete a group category | |||
""" | |||
"""Delete a group category""" | |||
await self.check_group_is_ours( | |||
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id | |||
) | |||
@@ -441,8 +432,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler): | |||
return {} | |||
async def update_group_role(self, group_id, requester_user_id, role_id, content): | |||
"""Add/update a role in a group | |||
""" | |||
"""Add/update a role in a group""" | |||
await self.check_group_is_ours( | |||
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id | |||
) | |||
@@ -458,8 +448,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler): | |||
return {} | |||
async def delete_group_role(self, group_id, requester_user_id, role_id): | |||
"""Remove role from group | |||
""" | |||
"""Remove role from group""" | |||
await self.check_group_is_ours( | |||
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id | |||
) | |||
@@ -471,8 +460,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler): | |||
async def update_group_summary_user( | |||
self, group_id, requester_user_id, user_id, role_id, content | |||
): | |||
"""Add/update a users entry in the group summary | |||
""" | |||
"""Add/update a users entry in the group summary""" | |||
await self.check_group_is_ours( | |||
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id | |||
) | |||
@@ -494,8 +482,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler): | |||
async def delete_group_summary_user( | |||
self, group_id, requester_user_id, user_id, role_id | |||
): | |||
"""Remove a user from the group summary | |||
""" | |||
"""Remove a user from the group summary""" | |||
await self.check_group_is_ours( | |||
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id | |||
) | |||
@@ -507,8 +494,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler): | |||
return {} | |||
async def update_group_profile(self, group_id, requester_user_id, content): | |||
"""Update the group profile | |||
""" | |||
"""Update the group profile""" | |||
await self.check_group_is_ours( | |||
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id | |||
) | |||
@@ -539,8 +525,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler): | |||
await self.store.update_group_profile(group_id, profile) | |||
async def add_room_to_group(self, group_id, requester_user_id, room_id, content): | |||
"""Add room to group | |||
""" | |||
"""Add room to group""" | |||
RoomID.from_string(room_id) # Ensure valid room id | |||
await self.check_group_is_ours( | |||
@@ -556,8 +541,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler): | |||
async def update_room_in_group( | |||
self, group_id, requester_user_id, room_id, config_key, content | |||
): | |||
"""Update room in group | |||
""" | |||
"""Update room in group""" | |||
RoomID.from_string(room_id) # Ensure valid room id | |||
await self.check_group_is_ours( | |||
@@ -576,8 +560,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler): | |||
return {} | |||
async def remove_room_from_group(self, group_id, requester_user_id, room_id): | |||
"""Remove room from group | |||
""" | |||
"""Remove room from group""" | |||
await self.check_group_is_ours( | |||
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id | |||
) | |||
@@ -587,8 +570,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler): | |||
return {} | |||
async def invite_to_group(self, group_id, user_id, requester_user_id, content): | |||
"""Invite user to group | |||
""" | |||
"""Invite user to group""" | |||
group = await self.check_group_is_ours( | |||
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id | |||
@@ -724,8 +706,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler): | |||
return {"state": "join", "attestation": local_attestation} | |||
async def knock(self, group_id, requester_user_id, content): | |||
"""A user requests becoming a member of the group | |||
""" | |||
"""A user requests becoming a member of the group""" | |||
await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) | |||
raise NotImplementedError() | |||
@@ -918,8 +899,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler): | |||
def _parse_join_policy_from_contents(content): | |||
"""Given a content for a request, return the specified join policy or None | |||
""" | |||
"""Given a content for a request, return the specified join policy or None""" | |||
join_policy_dict = content.get("m.join_policy") | |||
if join_policy_dict: | |||
@@ -929,8 +909,7 @@ def _parse_join_policy_from_contents(content): | |||
def _parse_join_policy_dict(join_policy_dict): | |||
"""Given a dict for the "m.join_policy" config return the join policy specified | |||
""" | |||
"""Given a dict for the "m.join_policy" config return the join policy specified""" | |||
join_policy_type = join_policy_dict.get("type") | |||
if not join_policy_type: | |||
return "invite" | |||
@@ -203,13 +203,11 @@ class AdminHandler(BaseHandler): | |||
class ExfiltrationWriter(metaclass=abc.ABCMeta): | |||
"""Interface used to specify how to write exported data. | |||
""" | |||
"""Interface used to specify how to write exported data.""" | |||
@abc.abstractmethod | |||
def write_events(self, room_id: str, events: List[EventBase]) -> None: | |||
"""Write a batch of events for a room. | |||
""" | |||
"""Write a batch of events for a room.""" | |||
raise NotImplementedError() | |||
@abc.abstractmethod | |||
@@ -290,7 +290,9 @@ class ApplicationServicesHandler: | |||
if not interested: | |||
continue | |||
presence_events, _ = await presence_source.get_new_events( | |||
user=user, service=service, from_key=from_key, | |||
user=user, | |||
service=service, | |||
from_key=from_key, | |||
) | |||
time_now = self.clock.time_msec() | |||
events.extend( | |||
@@ -120,7 +120,9 @@ def convert_client_dict_legacy_fields_to_identifier( | |||
# Ensure the identifier has a type | |||
if "type" not in identifier: | |||
raise SynapseError( | |||
400, "'identifier' dict has no key 'type'", errcode=Codes.MISSING_PARAM, | |||
400, | |||
"'identifier' dict has no key 'type'", | |||
errcode=Codes.MISSING_PARAM, | |||
) | |||
return identifier | |||
@@ -351,7 +353,11 @@ class AuthHandler(BaseHandler): | |||
try: | |||
result, params, session_id = await self.check_ui_auth( | |||
flows, request, request_body, description, get_new_session_data, | |||
flows, | |||
request, | |||
request_body, | |||
description, | |||
get_new_session_data, | |||
) | |||
except LoginError: | |||
# Update the ratelimiter to say we failed (`can_do_action` doesn't raise). | |||
@@ -379,8 +385,7 @@ class AuthHandler(BaseHandler): | |||
return params, session_id | |||
async def _get_available_ui_auth_types(self, user: UserID) -> Iterable[str]: | |||
"""Get a list of the authentication types this user can use | |||
""" | |||
"""Get a list of the authentication types this user can use""" | |||
ui_auth_types = set() | |||
@@ -723,7 +728,9 @@ class AuthHandler(BaseHandler): | |||
} | |||
def _auth_dict_for_flows( | |||
self, flows: List[List[str]], session_id: str, | |||
self, | |||
flows: List[List[str]], | |||
session_id: str, | |||
) -> Dict[str, Any]: | |||
public_flows = [] | |||
for f in flows: | |||
@@ -880,7 +887,9 @@ class AuthHandler(BaseHandler): | |||
return self._supported_login_types | |||
async def validate_login( | |||
self, login_submission: Dict[str, Any], ratelimit: bool = False, | |||
self, | |||
login_submission: Dict[str, Any], | |||
ratelimit: bool = False, | |||
) -> Tuple[str, Optional[Callable[[Dict[str, str]], Awaitable[None]]]]: | |||
"""Authenticates the user for the /login API | |||
@@ -1023,7 +1032,9 @@ class AuthHandler(BaseHandler): | |||
raise | |||
async def _validate_userid_login( | |||
self, username: str, login_submission: Dict[str, Any], | |||
self, | |||
username: str, | |||
login_submission: Dict[str, Any], | |||
) -> Tuple[str, Optional[Callable[[Dict[str, str]], Awaitable[None]]]]: | |||
"""Helper for validate_login | |||
@@ -1446,7 +1457,8 @@ class AuthHandler(BaseHandler): | |||
# is considered OK since the newest SSO attributes should be most valid. | |||
if extra_attributes: | |||
self._extra_attributes[registered_user_id] = SsoLoginExtraAttributes( | |||
self._clock.time_msec(), extra_attributes, | |||
self._clock.time_msec(), | |||
extra_attributes, | |||
) | |||
# Create a login token | |||
@@ -1702,5 +1714,9 @@ class PasswordProvider: | |||
# This might return an awaitable, if it does block the log out | |||
# until it completes. | |||
await maybe_awaitable( | |||
g(user_id=user_id, device_id=device_id, access_token=access_token,) | |||
g( | |||
user_id=user_id, | |||
device_id=device_id, | |||
access_token=access_token, | |||
) | |||
) |
@@ -33,8 +33,7 @@ logger = logging.getLogger(__name__) | |||
class CasError(Exception): | |||
"""Used to catch errors when validating the CAS ticket. | |||
""" | |||
"""Used to catch errors when validating the CAS ticket.""" | |||
def __init__(self, error, error_description=None): | |||
self.error = error | |||
@@ -100,7 +99,10 @@ class CasHandler: | |||
Returns: | |||
The URL to use as a "service" parameter. | |||
""" | |||
return "%s?%s" % (self._cas_service_url, urllib.parse.urlencode(args),) | |||
return "%s?%s" % ( | |||
self._cas_service_url, | |||
urllib.parse.urlencode(args), | |||
) | |||
async def _validate_ticket( | |||
self, ticket: str, service_args: Dict[str, str] | |||
@@ -296,7 +298,10 @@ class CasHandler: | |||
# first check if we're doing a UIA | |||
if session: | |||
return await self._sso_handler.complete_sso_ui_auth_request( | |||
self.idp_id, cas_response.username, session, request, | |||
self.idp_id, | |||
cas_response.username, | |||
session, | |||
request, | |||
) | |||
# otherwise, we're handling a login request. | |||
@@ -366,7 +371,8 @@ class CasHandler: | |||
user_id = UserID(localpart, self._hostname).to_string() | |||
logger.debug( | |||
"Looking for existing account based on mapped %s", user_id, | |||
"Looking for existing account based on mapped %s", | |||
user_id, | |||
) | |||
users = await self._store.get_users_by_id_case_insensitive(user_id) | |||
@@ -196,8 +196,7 @@ class DeactivateAccountHandler(BaseHandler): | |||
run_as_background_process("user_parter_loop", self._user_parter_loop) | |||
async def _user_parter_loop(self) -> None: | |||
"""Loop that parts deactivated users from rooms | |||
""" | |||
"""Loop that parts deactivated users from rooms""" | |||
self._user_parter_running = True | |||
logger.info("Starting user parter") | |||
try: | |||
@@ -214,8 +213,7 @@ class DeactivateAccountHandler(BaseHandler): | |||
self._user_parter_running = False | |||
async def _part_user(self, user_id: str) -> None: | |||
"""Causes the given user_id to leave all the rooms they're joined to | |||
""" | |||
"""Causes the given user_id to leave all the rooms they're joined to""" | |||
user = UserID.from_string(user_id) | |||
rooms_for_user = await self.store.get_rooms_for_user(user_id) | |||
@@ -86,7 +86,7 @@ class DeviceWorkerHandler(BaseHandler): | |||
@trace | |||
async def get_device(self, user_id: str, device_id: str) -> JsonDict: | |||
""" Retrieve the given device | |||
"""Retrieve the given device | |||
Args: | |||
user_id: The user to get the device from | |||
@@ -341,7 +341,7 @@ class DeviceHandler(DeviceWorkerHandler): | |||
@trace | |||
async def delete_device(self, user_id: str, device_id: str) -> None: | |||
""" Delete the given device | |||
"""Delete the given device | |||
Args: | |||
user_id: The user to delete the device from. | |||
@@ -386,7 +386,7 @@ class DeviceHandler(DeviceWorkerHandler): | |||
await self.delete_devices(user_id, device_ids) | |||
async def delete_devices(self, user_id: str, device_ids: List[str]) -> None: | |||
""" Delete several devices | |||
"""Delete several devices | |||
Args: | |||
user_id: The user to delete devices from. | |||
@@ -417,7 +417,7 @@ class DeviceHandler(DeviceWorkerHandler): | |||
await self.notify_device_update(user_id, device_ids) | |||
async def update_device(self, user_id: str, device_id: str, content: dict) -> None: | |||
""" Update the given device | |||
"""Update the given device | |||
Args: | |||
user_id: The user to update devices of. | |||
@@ -534,7 +534,9 @@ class DeviceHandler(DeviceWorkerHandler): | |||
device id of the dehydrated device | |||
""" | |||
device_id = await self.check_device_registered( | |||
user_id, None, initial_device_display_name, | |||
user_id, | |||
None, | |||
initial_device_display_name, | |||
) | |||
old_device_id = await self.store.store_dehydrated_device( | |||
user_id, device_id, device_data | |||
@@ -803,7 +805,8 @@ class DeviceListUpdater: | |||
try: | |||
# Try to resync the current user's devices list. | |||
result = await self.user_device_resync( | |||
user_id=user_id, mark_failed_as_stale=False, | |||
user_id=user_id, | |||
mark_failed_as_stale=False, | |||
) | |||
# user_device_resync only returns a result if it managed to | |||
@@ -813,14 +816,17 @@ class DeviceListUpdater: | |||
# self.store.update_remote_device_list_cache). | |||
if result: | |||
logger.debug( | |||
"Successfully resynced the device list for %s", user_id, | |||
"Successfully resynced the device list for %s", | |||
user_id, | |||
) | |||
except Exception as e: | |||
# If there was an issue resyncing this user, e.g. if the remote | |||
# server sent a malformed result, just log the error instead of | |||
# aborting all the subsequent resyncs. | |||
logger.debug( | |||
"Could not resync the device list for %s: %s", user_id, e, | |||
"Could not resync the device list for %s: %s", | |||
user_id, | |||
e, | |||
) | |||
finally: | |||
# Allow future calls to retry resyncinc out of sync device lists. | |||
@@ -855,7 +861,9 @@ class DeviceListUpdater: | |||
return None | |||
except (RequestSendFailed, HttpResponseException) as e: | |||
logger.warning( | |||
"Failed to handle device list update for %s: %s", user_id, e, | |||
"Failed to handle device list update for %s: %s", | |||
user_id, | |||
e, | |||
) | |||
if mark_failed_as_stale: | |||
@@ -931,7 +939,9 @@ class DeviceListUpdater: | |||
# Handle cross-signing keys. | |||
cross_signing_device_ids = await self.process_cross_signing_key_update( | |||
user_id, master_key, self_signing_key, | |||
user_id, | |||
master_key, | |||
self_signing_key, | |||
) | |||
device_ids = device_ids + cross_signing_device_ids | |||
@@ -62,7 +62,8 @@ class DeviceMessageHandler: | |||
) | |||
else: | |||
hs.get_federation_registry().register_instances_for_edu( | |||
"m.direct_to_device", hs.config.worker.writers.to_device, | |||
"m.direct_to_device", | |||
hs.config.worker.writers.to_device, | |||
) | |||
# The handler to call when we think a user's device list might be out of | |||
@@ -73,8 +74,8 @@ class DeviceMessageHandler: | |||
hs.get_device_handler().device_list_updater.user_device_resync | |||
) | |||
else: | |||
self._user_device_resync = ReplicationUserDevicesResyncRestServlet.make_client( | |||
hs | |||
self._user_device_resync = ( | |||
ReplicationUserDevicesResyncRestServlet.make_client(hs) | |||
) | |||
async def on_direct_to_device_edu(self, origin: str, content: JsonDict) -> None: | |||
@@ -61,8 +61,8 @@ class E2eKeysHandler: | |||
self._is_master = hs.config.worker_app is None | |||
if not self._is_master: | |||
self._user_device_resync_client = ReplicationUserDevicesResyncRestServlet.make_client( | |||
hs | |||
self._user_device_resync_client = ( | |||
ReplicationUserDevicesResyncRestServlet.make_client(hs) | |||
) | |||
else: | |||
# Only register this edu handler on master as it requires writing | |||
@@ -85,7 +85,7 @@ class E2eKeysHandler: | |||
async def query_devices( | |||
self, query_body: JsonDict, timeout: int, from_user_id: str | |||
) -> JsonDict: | |||
""" Handle a device key query from a client | |||
"""Handle a device key query from a client | |||
{ | |||
"device_keys": { | |||
@@ -391,8 +391,7 @@ class E2eKeysHandler: | |||
async def on_federation_query_client_keys( | |||
self, query_body: Dict[str, Dict[str, Optional[List[str]]]] | |||
) -> JsonDict: | |||
""" Handle a device key query from a federated server | |||
""" | |||
"""Handle a device key query from a federated server""" | |||
device_keys_query = query_body.get( | |||
"device_keys", {} | |||
) # type: Dict[str, Optional[List[str]]] | |||
@@ -1065,7 +1064,9 @@ class E2eKeysHandler: | |||
return key, key_id, verify_key | |||
async def _retrieve_cross_signing_keys_for_remote_user( | |||
self, user: UserID, desired_key_type: str, | |||
self, | |||
user: UserID, | |||
desired_key_type: str, | |||
) -> Tuple[Optional[dict], Optional[str], Optional[VerifyKey]]: | |||
"""Queries cross-signing keys for a remote user and saves them to the database | |||
@@ -1269,8 +1270,7 @@ def _one_time_keys_match(old_key_json: str, new_key: JsonDict) -> bool: | |||
@attr.s(slots=True) | |||
class SignatureListItem: | |||
"""An item in the signature list as used by upload_signatures_for_device_keys. | |||
""" | |||
"""An item in the signature list as used by upload_signatures_for_device_keys.""" | |||
signing_key_id = attr.ib(type=str) | |||
target_user_id = attr.ib(type=str) | |||
@@ -1355,8 +1355,12 @@ class SigningKeyEduUpdater: | |||
logger.info("pending updates: %r", pending_updates) | |||
for master_key, self_signing_key in pending_updates: | |||
new_device_ids = await device_list_updater.process_cross_signing_key_update( | |||
user_id, master_key, self_signing_key, | |||
new_device_ids = ( | |||
await device_list_updater.process_cross_signing_key_update( | |||
user_id, | |||
master_key, | |||
self_signing_key, | |||
) | |||
) | |||
device_ids = device_ids + new_device_ids | |||
@@ -57,8 +57,7 @@ class EventStreamHandler(BaseHandler): | |||
room_id: Optional[str] = None, | |||
is_guest: bool = False, | |||
) -> JsonDict: | |||
"""Fetches the events stream for a given user. | |||
""" | |||
"""Fetches the events stream for a given user.""" | |||
if room_id: | |||
blocked = await self.store.is_room_blocked(room_id) | |||
@@ -111,13 +111,13 @@ class _NewEventInfo: | |||
class FederationHandler(BaseHandler): | |||
"""Handles events that originated from federation. | |||
Responsible for: | |||
a) handling received Pdus before handing them on as Events to the rest | |||
of the homeserver (including auth and state conflict resolutions) | |||
b) converting events that were produced by local clients that may need | |||
to be sent to remote homeservers. | |||
c) doing the necessary dances to invite remote users and join remote | |||
rooms. | |||
Responsible for: | |||
a) handling received Pdus before handing them on as Events to the rest | |||
of the homeserver (including auth and state conflict resolutions) | |||
b) converting events that were produced by local clients that may need | |||
to be sent to remote homeservers. | |||
c) doing the necessary dances to invite remote users and join remote | |||
rooms. | |||
""" | |||
def __init__(self, hs: "HomeServer"): | |||
@@ -150,11 +150,11 @@ class FederationHandler(BaseHandler): | |||
) | |||
if hs.config.worker_app: | |||
self._user_device_resync = ReplicationUserDevicesResyncRestServlet.make_client( | |||
hs | |||
self._user_device_resync = ( | |||
ReplicationUserDevicesResyncRestServlet.make_client(hs) | |||
) | |||
self._maybe_store_room_on_outlier_membership = ReplicationStoreRoomOnOutlierMembershipRestServlet.make_client( | |||
hs | |||
self._maybe_store_room_on_outlier_membership = ( | |||
ReplicationStoreRoomOnOutlierMembershipRestServlet.make_client(hs) | |||
) | |||
else: | |||
self._device_list_updater = hs.get_device_handler().device_list_updater | |||
@@ -172,7 +172,7 @@ class FederationHandler(BaseHandler): | |||
self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages | |||
async def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False) -> None: | |||
""" Process a PDU received via a federation /send/ transaction, or | |||
"""Process a PDU received via a federation /send/ transaction, or | |||
via backfill of missing prev_events | |||
Args: | |||
@@ -368,7 +368,8 @@ class FederationHandler(BaseHandler): | |||
# know about | |||
for p in prevs - seen: | |||
logger.info( | |||
"Requesting state at missing prev_event %s", event_id, | |||
"Requesting state at missing prev_event %s", | |||
event_id, | |||
) | |||
with nested_logging_context(p): | |||
@@ -388,12 +389,14 @@ class FederationHandler(BaseHandler): | |||
event_map[x.event_id] = x | |||
room_version = await self.store.get_room_version_id(room_id) | |||
state_map = await self._state_resolution_handler.resolve_events_with_store( | |||
room_id, | |||
room_version, | |||
state_maps, | |||
event_map, | |||
state_res_store=StateResolutionStore(self.store), | |||
state_map = ( | |||
await self._state_resolution_handler.resolve_events_with_store( | |||
room_id, | |||
room_version, | |||
state_maps, | |||
event_map, | |||
state_res_store=StateResolutionStore(self.store), | |||
) | |||
) | |||
# We need to give _process_received_pdu the actual state events | |||
@@ -687,9 +690,12 @@ class FederationHandler(BaseHandler): | |||
return fetched_events | |||
async def _process_received_pdu( | |||
self, origin: str, event: EventBase, state: Optional[Iterable[EventBase]], | |||
self, | |||
origin: str, | |||
event: EventBase, | |||
state: Optional[Iterable[EventBase]], | |||
): | |||
""" Called when we have a new pdu. We need to do auth checks and put it | |||
"""Called when we have a new pdu. We need to do auth checks and put it | |||
through the StateHandler. | |||
Args: | |||
@@ -801,7 +807,7 @@ class FederationHandler(BaseHandler): | |||
@log_function | |||
async def backfill(self, dest, room_id, limit, extremities): | |||
""" Trigger a backfill request to `dest` for the given `room_id` | |||
"""Trigger a backfill request to `dest` for the given `room_id` | |||
This will attempt to get more events from the remote. If the other side | |||
has no new events to offer, this will return an empty list. | |||
@@ -1204,11 +1210,16 @@ class FederationHandler(BaseHandler): | |||
with nested_logging_context(event_id): | |||
try: | |||
event = await self.federation_client.get_pdu( | |||
[destination], event_id, room_version, outlier=True, | |||
[destination], | |||
event_id, | |||
room_version, | |||
outlier=True, | |||
) | |||
if event is None: | |||
logger.warning( | |||
"Server %s didn't return event %s", destination, event_id, | |||
"Server %s didn't return event %s", | |||
destination, | |||
event_id, | |||
) | |||
return | |||
@@ -1235,7 +1246,8 @@ class FederationHandler(BaseHandler): | |||
if aid not in event_map | |||
] | |||
persisted_events = await self.store.get_events( | |||
auth_events, allow_rejected=True, | |||
auth_events, | |||
allow_rejected=True, | |||
) | |||
event_infos = [] | |||
@@ -1251,7 +1263,9 @@ class FederationHandler(BaseHandler): | |||
event_infos.append(_NewEventInfo(event, None, auth)) | |||
await self._handle_new_events( | |||
destination, room_id, event_infos, | |||
destination, | |||
room_id, | |||
event_infos, | |||
) | |||
def _sanity_check_event(self, ev): | |||
@@ -1287,7 +1301,7 @@ class FederationHandler(BaseHandler): | |||
raise SynapseError(HTTPStatus.BAD_REQUEST, "Too many auth_events") | |||
async def send_invite(self, target_host, event): | |||
""" Sends the invite to the remote server for signing. | |||
"""Sends the invite to the remote server for signing. | |||
Invites must be signed by the invitee's server before distribution. | |||
""" | |||
@@ -1310,7 +1324,7 @@ class FederationHandler(BaseHandler): | |||
async def do_invite_join( | |||
self, target_hosts: Iterable[str], room_id: str, joinee: str, content: JsonDict | |||
) -> Tuple[str, int]: | |||
""" Attempts to join the `joinee` to the room `room_id` via the | |||
"""Attempts to join the `joinee` to the room `room_id` via the | |||
servers contained in `target_hosts`. | |||
This first triggers a /make_join/ request that returns a partial | |||
@@ -1388,7 +1402,8 @@ class FederationHandler(BaseHandler): | |||
# so we can rely on it now. | |||
# | |||
await self.store.upsert_room_on_join( | |||
room_id=room_id, room_version=room_version_obj, | |||
room_id=room_id, | |||
room_version=room_version_obj, | |||
) | |||
max_stream_id = await self._persist_auth_tree( | |||
@@ -1458,7 +1473,7 @@ class FederationHandler(BaseHandler): | |||
async def on_make_join_request( | |||
self, origin: str, room_id: str, user_id: str | |||
) -> EventBase: | |||
""" We've received a /make_join/ request, so we create a partial | |||
"""We've received a /make_join/ request, so we create a partial | |||
join event for the room and return that. We do *not* persist or | |||
process it until the other server has signed it and sent it back. | |||
@@ -1483,7 +1498,8 @@ class FederationHandler(BaseHandler): | |||
is_in_room = await self.auth.check_host_in_room(room_id, self.server_name) | |||
if not is_in_room: | |||
logger.info( | |||
"Got /make_join request for room %s we are no longer in", room_id, | |||
"Got /make_join request for room %s we are no longer in", | |||
room_id, | |||
) | |||
raise NotFoundError("Not an active room on this server") | |||
@@ -1517,7 +1533,7 @@ class FederationHandler(BaseHandler): | |||
return event | |||
async def on_send_join_request(self, origin, pdu): | |||
""" We have received a join event for a room. Fully process it and | |||
"""We have received a join event for a room. Fully process it and | |||
respond with the current state and auth chains. | |||
""" | |||
event = pdu | |||
@@ -1573,7 +1589,7 @@ class FederationHandler(BaseHandler): | |||
async def on_invite_request( | |||
self, origin: str, event: EventBase, room_version: RoomVersion | |||
): | |||
""" We've got an invite event. Process and persist it. Sign it. | |||
"""We've got an invite event. Process and persist it. Sign it. | |||
Respond with the now signed event. | |||
""" | |||
@@ -1700,7 +1716,7 @@ class FederationHandler(BaseHandler): | |||
async def on_make_leave_request( | |||
self, origin: str, room_id: str, user_id: str | |||
) -> EventBase: | |||
""" We've received a /make_leave/ request, so we create a partial | |||
"""We've received a /make_leave/ request, so we create a partial | |||
leave event for the room and return that. We do *not* persist or | |||
process it until the other server has signed it and sent it back. | |||
@@ -1776,8 +1792,7 @@ class FederationHandler(BaseHandler): | |||
return None | |||
async def get_state_for_pdu(self, room_id: str, event_id: str) -> List[EventBase]: | |||
"""Returns the state at the event. i.e. not including said event. | |||
""" | |||
"""Returns the state at the event. i.e. not including said event.""" | |||
event = await self.store.get_event(event_id, check_room_id=room_id) | |||
@@ -1803,8 +1818,7 @@ class FederationHandler(BaseHandler): | |||
return [] | |||
async def get_state_ids_for_pdu(self, room_id: str, event_id: str) -> List[str]: | |||
"""Returns the state at the event. i.e. not including said event. | |||
""" | |||
"""Returns the state at the event. i.e. not including said event.""" | |||
event = await self.store.get_event(event_id, check_room_id=room_id) | |||
state_groups = await self.state_store.get_state_groups_ids(room_id, [event_id]) | |||
@@ -2010,7 +2024,11 @@ class FederationHandler(BaseHandler): | |||
for e_id in missing_auth_events: | |||
m_ev = await self.federation_client.get_pdu( | |||
[origin], e_id, room_version=room_version, outlier=True, timeout=10000, | |||
[origin], | |||
e_id, | |||
room_version=room_version, | |||
outlier=True, | |||
timeout=10000, | |||
) | |||
if m_ev and m_ev.event_id == e_id: | |||
event_map[e_id] = m_ev | |||
@@ -2160,7 +2178,9 @@ class FederationHandler(BaseHandler): | |||
) | |||
logger.debug( | |||
"Doing soft-fail check for %s: state %s", event.event_id, current_state_ids, | |||
"Doing soft-fail check for %s: state %s", | |||
event.event_id, | |||
current_state_ids, | |||
) | |||
# Now check if event pass auth against said current state | |||
@@ -2513,7 +2533,7 @@ class FederationHandler(BaseHandler): | |||
async def construct_auth_difference( | |||
self, local_auth: Iterable[EventBase], remote_auth: Iterable[EventBase] | |||
) -> Dict: | |||
""" Given a local and remote auth chain, find the differences. This | |||
"""Given a local and remote auth chain, find the differences. This | |||
assumes that we have already processed all events in remote_auth | |||
Params: | |||
@@ -146,8 +146,7 @@ class GroupsLocalWorkerHandler: | |||
async def get_users_in_group( | |||
self, group_id: str, requester_user_id: str | |||
) -> JsonDict: | |||
"""Get users in a group | |||
""" | |||
"""Get users in a group""" | |||
if self.is_mine_id(group_id): | |||
return await self.groups_server_handler.get_users_in_group( | |||
group_id, requester_user_id | |||
@@ -283,8 +282,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler): | |||
async def create_group( | |||
self, group_id: str, user_id: str, content: JsonDict | |||
) -> JsonDict: | |||
"""Create a group | |||
""" | |||
"""Create a group""" | |||
logger.info("Asking to create group with ID: %r", group_id) | |||
@@ -314,8 +312,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler): | |||
async def join_group( | |||
self, group_id: str, user_id: str, content: JsonDict | |||
) -> JsonDict: | |||
"""Request to join a group | |||
""" | |||
"""Request to join a group""" | |||
if self.is_mine_id(group_id): | |||
await self.groups_server_handler.join_group(group_id, user_id, content) | |||
local_attestation = None | |||
@@ -361,8 +358,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler): | |||
async def accept_invite( | |||
self, group_id: str, user_id: str, content: JsonDict | |||
) -> JsonDict: | |||
"""Accept an invite to a group | |||
""" | |||
"""Accept an invite to a group""" | |||
if self.is_mine_id(group_id): | |||
await self.groups_server_handler.accept_invite(group_id, user_id, content) | |||
local_attestation = None | |||
@@ -408,8 +404,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler): | |||
async def invite( | |||
self, group_id: str, user_id: str, requester_user_id: str, config: JsonDict | |||
) -> JsonDict: | |||
"""Invite a user to a group | |||
""" | |||
"""Invite a user to a group""" | |||
content = {"requester_user_id": requester_user_id, "config": config} | |||
if self.is_mine_id(group_id): | |||
res = await self.groups_server_handler.invite_to_group( | |||
@@ -434,8 +429,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler): | |||
async def on_invite( | |||
self, group_id: str, user_id: str, content: JsonDict | |||
) -> JsonDict: | |||
"""One of our users were invited to a group | |||
""" | |||
"""One of our users were invited to a group""" | |||
# TODO: Support auto join and rejection | |||
if not self.is_mine_id(user_id): | |||
@@ -466,8 +460,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler): | |||
async def remove_user_from_group( | |||
self, group_id: str, user_id: str, requester_user_id: str, content: JsonDict | |||
) -> JsonDict: | |||
"""Remove a user from a group | |||
""" | |||
"""Remove a user from a group""" | |||
if user_id == requester_user_id: | |||
token = await self.store.register_user_group_membership( | |||
group_id, user_id, membership="leave" | |||
@@ -501,8 +494,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler): | |||
async def user_removed_from_group( | |||
self, group_id: str, user_id: str, content: JsonDict | |||
) -> None: | |||
"""One of our users was removed/kicked from a group | |||
""" | |||
"""One of our users was removed/kicked from a group""" | |||
# TODO: Check if user in group | |||
token = await self.store.register_user_group_membership( | |||
group_id, user_id, membership="leave" | |||
@@ -72,7 +72,10 @@ class IdentityHandler(BaseHandler): | |||
) | |||
def ratelimit_request_token_requests( | |||
self, request: SynapseRequest, medium: str, address: str, | |||
self, | |||
request: SynapseRequest, | |||
medium: str, | |||
address: str, | |||
): | |||
"""Used to ratelimit requests to `/requestToken` by IP and address. | |||
@@ -124,7 +124,8 @@ class InitialSyncHandler(BaseHandler): | |||
joined_rooms = [r.room_id for r in room_list if r.membership == Membership.JOIN] | |||
receipt = await self.store.get_linearized_receipts_for_rooms( | |||
joined_rooms, to_key=int(now_token.receipt_key), | |||
joined_rooms, | |||
to_key=int(now_token.receipt_key), | |||
) | |||
tags_by_room = await self.store.get_tags_for_user(user_id) | |||
@@ -169,7 +170,10 @@ class InitialSyncHandler(BaseHandler): | |||
self.state_handler.get_current_state, event.room_id | |||
) | |||
elif event.membership == Membership.LEAVE: | |||
room_end_token = RoomStreamToken(None, event.stream_ordering,) | |||
room_end_token = RoomStreamToken( | |||
None, | |||
event.stream_ordering, | |||
) | |||
deferred_room_state = run_in_background( | |||
self.state_store.get_state_for_events, [event.event_id] | |||
) | |||
@@ -284,7 +288,9 @@ class InitialSyncHandler(BaseHandler): | |||
membership, | |||
member_event_id, | |||
) = await self.auth.check_user_in_room_or_world_readable( | |||
room_id, user_id, allow_departed_users=True, | |||
room_id, | |||
user_id, | |||
allow_departed_users=True, | |||
) | |||
is_peeking = member_event_id is None | |||
@@ -65,8 +65,7 @@ logger = logging.getLogger(__name__) | |||
class MessageHandler: | |||
"""Contains some read only APIs to get state about a room | |||
""" | |||
"""Contains some read only APIs to get state about a room""" | |||
def __init__(self, hs): | |||
self.auth = hs.get_auth() | |||
@@ -88,9 +87,13 @@ class MessageHandler: | |||
) | |||
async def get_room_data( | |||
self, user_id: str, room_id: str, event_type: str, state_key: str, | |||
self, | |||
user_id: str, | |||
room_id: str, | |||
event_type: str, | |||
state_key: str, | |||
) -> dict: | |||
""" Get data from a room. | |||
"""Get data from a room. | |||
Args: | |||
user_id | |||
@@ -174,7 +177,10 @@ class MessageHandler: | |||
raise NotFoundError("Can't find event for token %s" % (at_token,)) | |||
visible_events = await filter_events_for_client( | |||
self.storage, user_id, last_events, filter_send_to_client=False, | |||
self.storage, | |||
user_id, | |||
last_events, | |||
filter_send_to_client=False, | |||
) | |||
event = last_events[0] | |||
@@ -571,7 +577,7 @@ class EventCreationHandler: | |||
async def _is_exempt_from_privacy_policy( | |||
self, builder: EventBuilder, requester: Requester | |||
) -> bool: | |||
""""Determine if an event to be sent is exempt from having to consent | |||
""" "Determine if an event to be sent is exempt from having to consent | |||
to the privacy policy | |||
Args: | |||
@@ -793,9 +799,10 @@ class EventCreationHandler: | |||
""" | |||
if prev_event_ids is not None: | |||
assert len(prev_event_ids) <= 10, ( | |||
"Attempting to create an event with %i prev_events" | |||
% (len(prev_event_ids),) | |||
assert ( | |||
len(prev_event_ids) <= 10 | |||
), "Attempting to create an event with %i prev_events" % ( | |||
len(prev_event_ids), | |||
) | |||
else: | |||
prev_event_ids = await self.store.get_prev_events_for_room(builder.room_id) | |||
@@ -821,7 +828,8 @@ class EventCreationHandler: | |||
) | |||
if not third_party_result: | |||
logger.info( | |||
"Event %s forbidden by third-party rules", event, | |||
"Event %s forbidden by third-party rules", | |||
event, | |||
) | |||
raise SynapseError( | |||
403, "This event is not allowed in this context", Codes.FORBIDDEN | |||
@@ -1316,7 +1324,11 @@ class EventCreationHandler: | |||
# Since this is a dummy-event it is OK if it is sent by a | |||
# shadow-banned user. | |||
await self.handle_new_client_event( | |||
requester, event, context, ratelimit=False, ignore_shadow_ban=True, | |||
requester, | |||
event, | |||
context, | |||
ratelimit=False, | |||
ignore_shadow_ban=True, | |||
) | |||
return True | |||
except AuthError: | |||
@@ -73,8 +73,7 @@ JWKS = TypedDict("JWKS", {"keys": List[JWK]}) | |||
class OidcHandler: | |||
"""Handles requests related to the OpenID Connect login flow. | |||
""" | |||
"""Handles requests related to the OpenID Connect login flow.""" | |||
def __init__(self, hs: "HomeServer"): | |||
self._sso_handler = hs.get_sso_handler() | |||
@@ -216,8 +215,7 @@ class OidcHandler: | |||
class OidcError(Exception): | |||
"""Used to catch errors when calling the token_endpoint | |||
""" | |||
"""Used to catch errors when calling the token_endpoint""" | |||
def __init__(self, error, error_description=None): | |||
self.error = error | |||
@@ -252,7 +250,9 @@ class OidcProvider: | |||
self._scopes = provider.scopes | |||
self._user_profile_method = provider.user_profile_method | |||
self._client_auth = ClientAuth( | |||
provider.client_id, provider.client_secret, provider.client_auth_method, | |||
provider.client_id, | |||
provider.client_secret, | |||
provider.client_auth_method, | |||
) # type: ClientAuth | |||
self._client_auth_method = provider.client_auth_method | |||
@@ -509,7 +509,10 @@ class OidcProvider: | |||
# We're not using the SimpleHttpClient util methods as we don't want to | |||
# check the HTTP status code and we do the body encoding ourself. | |||
response = await self._http_client.request( | |||
method="POST", uri=uri, data=body.encode("utf-8"), headers=headers, | |||
method="POST", | |||
uri=uri, | |||
data=body.encode("utf-8"), | |||
headers=headers, | |||
) | |||
# This is used in multiple error messages below | |||
@@ -966,7 +969,9 @@ class OidcSessionTokenGenerator: | |||
A signed macaroon token with the session information. | |||
""" | |||
macaroon = pymacaroons.Macaroon( | |||
location=self._server_name, identifier="key", key=self._macaroon_secret_key, | |||
location=self._server_name, | |||
identifier="key", | |||
key=self._macaroon_secret_key, | |||
) | |||
macaroon.add_first_party_caveat("gen = 1") | |||
macaroon.add_first_party_caveat("type = session") | |||
@@ -197,7 +197,8 @@ class PaginationHandler: | |||
stream_ordering = await self.store.find_first_stream_ordering_after_ts(ts) | |||
r = await self.store.get_room_event_before_stream_ordering( | |||
room_id, stream_ordering, | |||
room_id, | |||
stream_ordering, | |||
) | |||
if not r: | |||
logger.warning( | |||
@@ -223,7 +224,12 @@ class PaginationHandler: | |||
# the background so that it's not blocking any other operation apart from | |||
# other purges in the same room. | |||
run_as_background_process( | |||
"_purge_history", self._purge_history, purge_id, room_id, token, True, | |||
"_purge_history", | |||
self._purge_history, | |||
purge_id, | |||
room_id, | |||
token, | |||
True, | |||
) | |||
def start_purge_history( | |||
@@ -389,7 +395,9 @@ class PaginationHandler: | |||
) | |||
await self.hs.get_federation_handler().maybe_backfill( | |||
room_id, curr_topo, limit=pagin_config.limit, | |||
room_id, | |||
curr_topo, | |||
limit=pagin_config.limit, | |||
) | |||
to_room_key = None | |||
@@ -635,8 +635,7 @@ class PresenceHandler(BasePresenceHandler): | |||
self.external_process_last_updated_ms.pop(process_id, None) | |||
async def current_state_for_user(self, user_id): | |||
"""Get the current presence state for a user. | |||
""" | |||
"""Get the current presence state for a user.""" | |||
res = await self.current_state_for_users([user_id]) | |||
return res[user_id] | |||
@@ -678,8 +677,7 @@ class PresenceHandler(BasePresenceHandler): | |||
self.federation.send_presence(states) | |||
async def incoming_presence(self, origin, content): | |||
"""Called when we receive a `m.presence` EDU from a remote server. | |||
""" | |||
"""Called when we receive a `m.presence` EDU from a remote server.""" | |||
if not self._presence_enabled: | |||
return | |||
@@ -729,8 +727,7 @@ class PresenceHandler(BasePresenceHandler): | |||
await self._update_states(updates) | |||
async def set_state(self, target_user, state, ignore_status_msg=False): | |||
"""Set the presence state of the user. | |||
""" | |||
"""Set the presence state of the user.""" | |||
status_msg = state.get("status_msg", None) | |||
presence = state["presence"] | |||
@@ -758,8 +755,7 @@ class PresenceHandler(BasePresenceHandler): | |||
await self._update_states([prev_state.copy_and_replace(**new_fields)]) | |||
async def is_visible(self, observed_user, observer_user): | |||
"""Returns whether a user can see another user's presence. | |||
""" | |||
"""Returns whether a user can see another user's presence.""" | |||
observer_room_ids = await self.store.get_rooms_for_user( | |||
observer_user.to_string() | |||
) | |||
@@ -953,8 +949,7 @@ class PresenceHandler(BasePresenceHandler): | |||
def should_notify(old_state, new_state): | |||
"""Decides if a presence state change should be sent to interested parties. | |||
""" | |||
"""Decides if a presence state change should be sent to interested parties.""" | |||
if old_state == new_state: | |||
return False | |||
@@ -207,7 +207,8 @@ class ProfileHandler(BaseHandler): | |||
# This must be done by the target user himself. | |||
if by_admin: | |||
requester = create_requester( | |||
target_user, authenticated_entity=requester.authenticated_entity, | |||
target_user, | |||
authenticated_entity=requester.authenticated_entity, | |||
) | |||
await self.store.set_profile_displayname( | |||
@@ -49,15 +49,15 @@ class ReceiptsHandler(BaseHandler): | |||
) | |||
else: | |||
hs.get_federation_registry().register_instances_for_edu( | |||
"m.receipt", hs.config.worker.writers.receipts, | |||
"m.receipt", | |||
hs.config.worker.writers.receipts, | |||
) | |||
self.clock = self.hs.get_clock() | |||
self.state = hs.get_state_handler() | |||
async def _received_remote_receipt(self, origin: str, content: JsonDict) -> None: | |||
"""Called when we receive an EDU of type m.receipt from a remote HS. | |||
""" | |||
"""Called when we receive an EDU of type m.receipt from a remote HS.""" | |||
receipts = [] | |||
for room_id, room_values in content.items(): | |||
for receipt_type, users in room_values.items(): | |||
@@ -83,8 +83,7 @@ class ReceiptsHandler(BaseHandler): | |||
await self._handle_new_receipts(receipts) | |||
async def _handle_new_receipts(self, receipts: List[ReadReceipt]) -> bool: | |||
"""Takes a list of receipts, stores them and informs the notifier. | |||
""" | |||
"""Takes a list of receipts, stores them and informs the notifier.""" | |||
min_batch_id = None # type: Optional[int] | |||
max_batch_id = None # type: Optional[int] | |||
@@ -62,8 +62,8 @@ class RegistrationHandler(BaseHandler): | |||
self._register_device_client = RegisterDeviceReplicationServlet.make_client( | |||
hs | |||
) | |||
self._post_registration_client = ReplicationPostRegisterActionsServlet.make_client( | |||
hs | |||
self._post_registration_client = ( | |||
ReplicationPostRegisterActionsServlet.make_client(hs) | |||
) | |||
else: | |||
self.device_handler = hs.get_device_handler() | |||
@@ -189,12 +189,15 @@ class RegistrationHandler(BaseHandler): | |||
self.check_registration_ratelimit(address) | |||
result = await self.spam_checker.check_registration_for_spam( | |||
threepid, localpart, user_agent_ips or [], | |||
threepid, | |||
localpart, | |||
user_agent_ips or [], | |||
) | |||
if result == RegistrationBehaviour.DENY: | |||
logger.info( | |||
"Blocked registration of %r", localpart, | |||
"Blocked registration of %r", | |||
localpart, | |||
) | |||
# We return a 429 to make it not obvious that they've been | |||
# denied. | |||
@@ -203,7 +206,8 @@ class RegistrationHandler(BaseHandler): | |||
shadow_banned = result == RegistrationBehaviour.SHADOW_BAN | |||
if shadow_banned: | |||
logger.info( | |||
"Shadow banning registration of %r", localpart, | |||
"Shadow banning registration of %r", | |||
localpart, | |||
) | |||
# do not check_auth_blocking if the call is coming through the Admin API | |||
@@ -369,7 +373,9 @@ class RegistrationHandler(BaseHandler): | |||
config["room_alias_name"] = room_alias.localpart | |||
info, _ = await room_creation_handler.create_room( | |||
fake_requester, config=config, ratelimit=False, | |||
fake_requester, | |||
config=config, | |||
ratelimit=False, | |||
) | |||
# If the room does not require an invite, but another user | |||
@@ -753,7 +759,10 @@ class RegistrationHandler(BaseHandler): | |||
return | |||
await self._auth_handler.add_threepid( | |||
user_id, threepid["medium"], threepid["address"], threepid["validated_at"], | |||
user_id, | |||
threepid["medium"], | |||
threepid["address"], | |||
threepid["validated_at"], | |||
) | |||
# And we add an email pusher for them by default, but only | |||
@@ -805,5 +814,8 @@ class RegistrationHandler(BaseHandler): | |||
raise | |||
await self._auth_handler.add_threepid( | |||
user_id, threepid["medium"], threepid["address"], threepid["validated_at"], | |||
user_id, | |||
threepid["medium"], | |||
threepid["address"], | |||
threepid["validated_at"], | |||
) |
@@ -198,7 +198,9 @@ class RoomCreationHandler(BaseHandler): | |||
if r is None: | |||
raise NotFoundError("Unknown room id %s" % (old_room_id,)) | |||
new_room_id = await self._generate_room_id( | |||
creator_id=user_id, is_public=r["is_public"], room_version=new_version, | |||
creator_id=user_id, | |||
is_public=r["is_public"], | |||
room_version=new_version, | |||
) | |||
logger.info("Creating new room %s to replace %s", new_room_id, old_room_id) | |||
@@ -236,7 +238,9 @@ class RoomCreationHandler(BaseHandler): | |||
# now send the tombstone | |||
await self.event_creation_handler.handle_new_client_event( | |||
requester=requester, event=tombstone_event, context=tombstone_context, | |||
requester=requester, | |||
event=tombstone_event, | |||
context=tombstone_context, | |||
) | |||
old_room_state = await tombstone_context.get_current_state_ids() | |||
@@ -257,7 +261,10 @@ class RoomCreationHandler(BaseHandler): | |||
# finally, shut down the PLs in the old room, and update them in the new | |||
# room. | |||
await self._update_upgraded_room_pls( | |||
requester, old_room_id, new_room_id, old_room_state, | |||
requester, | |||
old_room_id, | |||
new_room_id, | |||
old_room_state, | |||
) | |||
return new_room_id | |||
@@ -570,7 +577,7 @@ class RoomCreationHandler(BaseHandler): | |||
ratelimit: bool = True, | |||
creator_join_profile: Optional[JsonDict] = None, | |||
) -> Tuple[dict, int]: | |||
""" Creates a new room. | |||
"""Creates a new room. | |||
Args: | |||
requester: | |||
@@ -691,7 +698,9 @@ class RoomCreationHandler(BaseHandler): | |||
is_public = visibility == "public" | |||
room_id = await self._generate_room_id( | |||
creator_id=user_id, is_public=is_public, room_version=room_version, | |||
creator_id=user_id, | |||
is_public=is_public, | |||
room_version=room_version, | |||
) | |||
# Check whether this visibility value is blocked by a third party module | |||
@@ -884,7 +893,10 @@ class RoomCreationHandler(BaseHandler): | |||
_, | |||
last_stream_id, | |||
) = await self.event_creation_handler.create_and_send_nonmember_event( | |||
creator, event, ratelimit=False, ignore_shadow_ban=True, | |||
creator, | |||
event, | |||
ratelimit=False, | |||
ignore_shadow_ban=True, | |||
) | |||
return last_stream_id | |||
@@ -984,7 +996,10 @@ class RoomCreationHandler(BaseHandler): | |||
return last_sent_stream_id | |||
async def _generate_room_id( | |||
self, creator_id: str, is_public: bool, room_version: RoomVersion, | |||
self, | |||
creator_id: str, | |||
is_public: bool, | |||
room_version: RoomVersion, | |||
): | |||
# autogen room IDs and try to create it. We may clash, so just | |||
# try a few times till one goes through, giving up eventually. | |||
@@ -191,7 +191,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): | |||
# do it up front for efficiency.) | |||
if txn_id and requester.access_token_id: | |||
existing_event_id = await self.store.get_event_id_from_transaction_id( | |||
room_id, requester.user.to_string(), requester.access_token_id, txn_id, | |||
room_id, | |||
requester.user.to_string(), | |||
requester.access_token_id, | |||
txn_id, | |||
) | |||
if existing_event_id: | |||
event_pos = await self.store.get_position_for_event(existing_event_id) | |||
@@ -238,7 +241,11 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): | |||
) | |||
result_event = await self.event_creation_handler.handle_new_client_event( | |||
requester, event, context, extra_users=[target], ratelimit=ratelimit, | |||
requester, | |||
event, | |||
context, | |||
extra_users=[target], | |||
ratelimit=ratelimit, | |||
) | |||
if event.membership == Membership.LEAVE: | |||
@@ -583,7 +590,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): | |||
# send the rejection to the inviter's HS (with fallback to | |||
# local event) | |||
return await self.remote_reject_invite( | |||
invite.event_id, txn_id, requester, content, | |||
invite.event_id, | |||
txn_id, | |||
requester, | |||
content, | |||
) | |||
# the inviter was on our server, but has now left. Carry on | |||
@@ -1056,8 +1066,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): | |||
user: UserID, | |||
content: dict, | |||
) -> Tuple[str, int]: | |||
"""Implements RoomMemberHandler._remote_join | |||
""" | |||
"""Implements RoomMemberHandler._remote_join""" | |||
# filter ourselves out of remote_room_hosts: do_invite_join ignores it | |||
# and if it is the only entry we'd like to return a 404 rather than a | |||
# 500. | |||
@@ -1211,7 +1220,10 @@ class RoomMemberMasterHandler(RoomMemberHandler): | |||
event.internal_metadata.out_of_band_membership = True | |||
result_event = await self.event_creation_handler.handle_new_client_event( | |||
requester, event, context, extra_users=[UserID.from_string(target_user)], | |||
requester, | |||
event, | |||
context, | |||
extra_users=[UserID.from_string(target_user)], | |||
) | |||
# we know it was persisted, so must have a stream ordering | |||
assert result_event.internal_metadata.stream_ordering | |||
@@ -1219,8 +1231,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): | |||
return result_event.event_id, result_event.internal_metadata.stream_ordering | |||
async def _user_left_room(self, target: UserID, room_id: str) -> None: | |||
"""Implements RoomMemberHandler._user_left_room | |||
""" | |||
"""Implements RoomMemberHandler._user_left_room""" | |||
user_left_room(self.distributor, target, room_id) | |||
async def forget(self, user: UserID, room_id: str) -> None: | |||
@@ -44,8 +44,7 @@ class RoomMemberWorkerHandler(RoomMemberHandler): | |||
user: UserID, | |||
content: dict, | |||
) -> Tuple[str, int]: | |||
"""Implements RoomMemberHandler._remote_join | |||
""" | |||
"""Implements RoomMemberHandler._remote_join""" | |||
if len(remote_room_hosts) == 0: | |||
raise SynapseError(404, "No known servers") | |||
@@ -80,8 +79,7 @@ class RoomMemberWorkerHandler(RoomMemberHandler): | |||
return ret["event_id"], ret["stream_id"] | |||
async def _user_left_room(self, target: UserID, room_id: str) -> None: | |||
"""Implements RoomMemberHandler._user_left_room | |||
""" | |||
"""Implements RoomMemberHandler._user_left_room""" | |||
await self._notify_change_client( | |||
user_id=target.to_string(), room_id=room_id, change="left" | |||
) |
@@ -121,7 +121,8 @@ class SamlHandler(BaseHandler): | |||
now = self.clock.time_msec() | |||
self._outstanding_requests_dict[reqid] = Saml2SessionData( | |||
creation_time=now, ui_auth_session_id=ui_auth_session_id, | |||
creation_time=now, | |||
ui_auth_session_id=ui_auth_session_id, | |||
) | |||
for key, value in info["headers"]: | |||
@@ -450,7 +451,8 @@ class DefaultSamlMappingProvider: | |||
mxid_source = saml_response.ava[self._mxid_source_attribute][0] | |||
except KeyError: | |||
logger.warning( | |||
"SAML2 response lacks a '%s' attestation", self._mxid_source_attribute, | |||
"SAML2 response lacks a '%s' attestation", | |||
self._mxid_source_attribute, | |||
) | |||
raise SynapseError( | |||
400, "%s not in SAML2 response" % (self._mxid_source_attribute,) | |||
@@ -327,7 +327,8 @@ class SsoHandler: | |||
# Check if we already have a mapping for this user. | |||
previously_registered_user_id = await self._store.get_user_by_external_id( | |||
auth_provider_id, remote_user_id, | |||
auth_provider_id, | |||
remote_user_id, | |||
) | |||
# A match was found, return the user ID. | |||
@@ -416,7 +417,8 @@ class SsoHandler: | |||
with await self._mapping_lock.queue(auth_provider_id): | |||
# first of all, check if we already have a mapping for this user | |||
user_id = await self.get_sso_user_by_remote_user_id( | |||
auth_provider_id, remote_user_id, | |||
auth_provider_id, | |||
remote_user_id, | |||
) | |||
# Check for grandfathering of users. | |||
@@ -461,7 +463,8 @@ class SsoHandler: | |||
) | |||
async def _call_attribute_mapper( | |||
self, sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]], | |||
self, | |||
sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]], | |||
) -> UserAttributes: | |||
"""Call the attribute mapper function in a loop, until we get a unique userid""" | |||
for i in range(self._MAP_USERNAME_RETRIES): | |||
@@ -632,7 +635,8 @@ class SsoHandler: | |||
""" | |||
user_id = await self.get_sso_user_by_remote_user_id( | |||
auth_provider_id, remote_user_id, | |||
auth_provider_id, | |||
remote_user_id, | |||
) | |||
user_id_to_verify = await self._auth_handler.get_session_data( | |||
@@ -671,7 +675,8 @@ class SsoHandler: | |||
# render an error page. | |||
html = self._bad_user_template.render( | |||
server_name=self._server_name, user_id_to_verify=user_id_to_verify, | |||
server_name=self._server_name, | |||
user_id_to_verify=user_id_to_verify, | |||
) | |||
respond_with_html(request, 200, html) | |||
@@ -695,7 +700,9 @@ class SsoHandler: | |||
raise SynapseError(400, "unknown session") | |||
async def check_username_availability( | |||
self, localpart: str, session_id: str, | |||
self, | |||
localpart: str, | |||
session_id: str, | |||
) -> bool: | |||
"""Handle an "is username available" callback check | |||
@@ -833,7 +840,8 @@ class SsoHandler: | |||
) | |||
attributes = UserAttributes( | |||
localpart=session.chosen_localpart, emails=session.emails_to_use, | |||
localpart=session.chosen_localpart, | |||
emails=session.emails_to_use, | |||
) | |||
if session.use_display_name: | |||
@@ -63,8 +63,7 @@ class StatsHandler: | |||
self.clock.call_later(0, self.notify_new_event) | |||
def notify_new_event(self) -> None: | |||
"""Called when there may be more deltas to process | |||
""" | |||
"""Called when there may be more deltas to process""" | |||
if not self.stats_enabled or self._is_processing: | |||
return | |||
@@ -339,8 +339,7 @@ class SyncHandler: | |||
since_token: Optional[StreamToken] = None, | |||
full_state: bool = False, | |||
) -> SyncResult: | |||
"""Get the sync for client needed to match what the server has now. | |||
""" | |||
"""Get the sync for client needed to match what the server has now.""" | |||
return await self.generate_sync_result(sync_config, since_token, full_state) | |||
async def push_rules_for_user(self, user: UserID) -> JsonDict: | |||
@@ -564,7 +563,7 @@ class SyncHandler: | |||
stream_position: StreamToken, | |||
state_filter: StateFilter = StateFilter.all(), | |||
) -> StateMap[str]: | |||
""" Get the room state at a particular stream position | |||
"""Get the room state at a particular stream position | |||
Args: | |||
room_id: room for which to get state | |||
@@ -598,7 +597,7 @@ class SyncHandler: | |||
state: MutableStateMap[EventBase], | |||
now_token: StreamToken, | |||
) -> Optional[JsonDict]: | |||
""" Works out a room summary block for this room, summarising the number | |||
"""Works out a room summary block for this room, summarising the number | |||
of joined members in the room, and providing the 'hero' members if the | |||
room has no name so clients can consistently name rooms. Also adds | |||
state events to 'state' if needed to describe the heroes. | |||
@@ -743,7 +742,7 @@ class SyncHandler: | |||
now_token: StreamToken, | |||
full_state: bool, | |||
) -> MutableStateMap[EventBase]: | |||
""" Works out the difference in state between the start of the timeline | |||
"""Works out the difference in state between the start of the timeline | |||
and the previous sync. | |||
Args: | |||
@@ -820,8 +819,10 @@ class SyncHandler: | |||
) | |||
elif batch.limited: | |||
if batch: | |||
state_at_timeline_start = await self.state_store.get_state_ids_for_event( | |||
batch.events[0].event_id, state_filter=state_filter | |||
state_at_timeline_start = ( | |||
await self.state_store.get_state_ids_for_event( | |||
batch.events[0].event_id, state_filter=state_filter | |||
) | |||
) | |||
else: | |||
# We can get here if the user has ignored the senders of all | |||
@@ -955,8 +956,7 @@ class SyncHandler: | |||
since_token: Optional[StreamToken] = None, | |||
full_state: bool = False, | |||
) -> SyncResult: | |||
"""Generates a sync result. | |||
""" | |||
"""Generates a sync result.""" | |||
# NB: The now_token gets changed by some of the generate_sync_* methods, | |||
# this is due to some of the underlying streams not supporting the ability | |||
# to query up to a given point. | |||
@@ -1030,8 +1030,8 @@ class SyncHandler: | |||
one_time_key_counts = await self.store.count_e2e_one_time_keys( | |||
user_id, device_id | |||
) | |||
unused_fallback_key_types = await self.store.get_e2e_unused_fallback_key_types( | |||
user_id, device_id | |||
unused_fallback_key_types = ( | |||
await self.store.get_e2e_unused_fallback_key_types(user_id, device_id) | |||
) | |||
logger.debug("Fetching group data") | |||
@@ -1176,8 +1176,10 @@ class SyncHandler: | |||
# weren't in the previous sync *or* they left and rejoined. | |||
users_that_have_changed.update(newly_joined_or_invited_users) | |||
user_signatures_changed = await self.store.get_users_whose_signatures_changed( | |||
user_id, since_token.device_list_key | |||
user_signatures_changed = ( | |||
await self.store.get_users_whose_signatures_changed( | |||
user_id, since_token.device_list_key | |||
) | |||
) | |||
users_that_have_changed.update(user_signatures_changed) | |||
@@ -1393,8 +1395,10 @@ class SyncHandler: | |||
logger.debug("no-oping sync") | |||
return set(), set(), set(), set() | |||
ignored_account_data = await self.store.get_global_account_data_by_type_for_user( | |||
AccountDataTypes.IGNORED_USER_LIST, user_id=user_id | |||
ignored_account_data = ( | |||
await self.store.get_global_account_data_by_type_for_user( | |||
AccountDataTypes.IGNORED_USER_LIST, user_id=user_id | |||
) | |||
) | |||
# If there is ignored users account data and it matches the proper type, | |||
@@ -1499,8 +1503,7 @@ class SyncHandler: | |||
async def _get_rooms_changed( | |||
self, sync_result_builder: "SyncResultBuilder", ignored_users: FrozenSet[str] | |||
) -> _RoomChanges: | |||
"""Gets the the changes that have happened since the last sync. | |||
""" | |||
"""Gets the the changes that have happened since the last sync.""" | |||
user_id = sync_result_builder.sync_config.user.to_string() | |||
since_token = sync_result_builder.since_token | |||
now_token = sync_result_builder.now_token | |||
@@ -61,7 +61,8 @@ class FollowerTypingHandler: | |||
if hs.config.worker.writers.typing != hs.get_instance_name(): | |||
hs.get_federation_registry().register_instance_for_edu( | |||
"m.typing", hs.config.worker.writers.typing, | |||
"m.typing", | |||
hs.config.worker.writers.typing, | |||
) | |||
# map room IDs to serial numbers | |||
@@ -76,8 +77,7 @@ class FollowerTypingHandler: | |||
self.clock.looping_call(self._handle_timeouts, 5000) | |||
def _reset(self) -> None: | |||
"""Reset the typing handler's data caches. | |||
""" | |||
"""Reset the typing handler's data caches.""" | |||
# map room IDs to serial numbers | |||
self._room_serials = {} | |||
# map room IDs to sets of users currently typing | |||
@@ -149,8 +149,7 @@ class FollowerTypingHandler: | |||
def process_replication_rows( | |||
self, token: int, rows: List[TypingStream.TypingStreamRow] | |||
) -> None: | |||
"""Should be called whenever we receive updates for typing stream. | |||
""" | |||
"""Should be called whenever we receive updates for typing stream.""" | |||
if self._latest_room_serial > token: | |||
# The master has gone backwards. To prevent inconsistent data, just | |||
@@ -97,8 +97,7 @@ class UserDirectoryHandler(StateDeltasHandler): | |||
return results | |||
def notify_new_event(self) -> None: | |||
"""Called when there may be more deltas to process | |||
""" | |||
"""Called when there may be more deltas to process""" | |||
if not self.update_user_directory: | |||
return | |||
@@ -134,8 +133,7 @@ class UserDirectoryHandler(StateDeltasHandler): | |||
) | |||
async def handle_user_deactivated(self, user_id: str) -> None: | |||
"""Called when a user ID is deactivated | |||
""" | |||
"""Called when a user ID is deactivated""" | |||
# FIXME(#3714): We should probably do this in the same worker as all | |||
# the other changes. | |||
await self.store.remove_from_user_dir(user_id) | |||
@@ -172,8 +170,7 @@ class UserDirectoryHandler(StateDeltasHandler): | |||
await self.store.update_user_directory_stream_pos(max_pos) | |||
async def _handle_deltas(self, deltas: List[Dict[str, Any]]) -> None: | |||
"""Called with the state deltas to process | |||
""" | |||
"""Called with the state deltas to process""" | |||
for delta in deltas: | |||
typ = delta["type"] | |||
state_key = delta["state_key"] | |||
@@ -54,8 +54,7 @@ class QuieterFileBodyProducer(FileBodyProducer): | |||
def get_request_user_agent(request: IRequest, default: str = "") -> str: | |||
"""Return the last User-Agent header, or the given default. | |||
""" | |||
"""Return the last User-Agent header, or the given default.""" | |||
# There could be raw utf-8 bytes in the User-Agent header. | |||
# N.B. if you don't do this, the logger explodes cryptically | |||
@@ -398,7 +398,8 @@ class SimpleHttpClient: | |||
body_producer = None | |||
if data is not None: | |||
body_producer = QuieterFileBodyProducer( | |||
BytesIO(data), cooperator=self._cooperator, | |||
BytesIO(data), | |||
cooperator=self._cooperator, | |||
) | |||
request_deferred = treq.request( | |||
@@ -413,7 +414,9 @@ class SimpleHttpClient: | |||
# we use our own timeout mechanism rather than treq's as a workaround | |||
# for https://twistedmatrix.com/trac/ticket/9534. | |||
request_deferred = timeout_deferred( | |||
request_deferred, 60, self.hs.get_reactor(), | |||
request_deferred, | |||
60, | |||
self.hs.get_reactor(), | |||
) | |||
# turn timeouts into RequestTimedOutErrors | |||
@@ -195,8 +195,7 @@ class MatrixFederationAgent: | |||
@implementer(IAgentEndpointFactory) | |||
class MatrixHostnameEndpointFactory: | |||
"""Factory for MatrixHostnameEndpoint for parsing to an Agent. | |||
""" | |||
"""Factory for MatrixHostnameEndpoint for parsing to an Agent.""" | |||
def __init__( | |||
self, | |||
@@ -261,8 +260,7 @@ class MatrixHostnameEndpoint: | |||
self._srv_resolver = srv_resolver | |||
def connect(self, protocol_factory: IProtocolFactory) -> defer.Deferred: | |||
"""Implements IStreamClientEndpoint interface | |||
""" | |||
"""Implements IStreamClientEndpoint interface""" | |||
return run_in_background(self._do_connect, protocol_factory) | |||
@@ -81,8 +81,7 @@ class WellKnownLookupResult: | |||
class WellKnownResolver: | |||
"""Handles well-known lookups for matrix servers. | |||
""" | |||
"""Handles well-known lookups for matrix servers.""" | |||
def __init__( | |||
self, | |||
@@ -254,7 +254,8 @@ class MatrixFederationHttpClient: | |||
# Use a BlacklistingAgentWrapper to prevent circumventing the IP | |||
# blacklist via IP literals in server names | |||
self.agent = BlacklistingAgentWrapper( | |||
self.agent, ip_blacklist=hs.config.federation_ip_range_blacklist, | |||
self.agent, | |||
ip_blacklist=hs.config.federation_ip_range_blacklist, | |||
) | |||
self.clock = hs.get_clock() | |||
@@ -652,7 +653,7 @@ class MatrixFederationHttpClient: | |||
backoff_on_404: bool = False, | |||
try_trailing_slash_on_400: bool = False, | |||
) -> Union[JsonDict, list]: | |||
""" Sends the specified json data using PUT | |||
"""Sends the specified json data using PUT | |||
Args: | |||
destination: The remote server to send the HTTP request to. | |||
@@ -740,7 +741,7 @@ class MatrixFederationHttpClient: | |||
ignore_backoff: bool = False, | |||
args: Optional[QueryArgs] = None, | |||
) -> Union[JsonDict, list]: | |||
""" Sends the specified json data using POST | |||
"""Sends the specified json data using POST | |||
Args: | |||
destination: The remote server to send the HTTP request to. | |||
@@ -799,7 +800,11 @@ class MatrixFederationHttpClient: | |||
_sec_timeout = self.default_timeout | |||
body = await _handle_json_response( | |||
self.reactor, _sec_timeout, request, response, start_ms, | |||
self.reactor, | |||
_sec_timeout, | |||
request, | |||
response, | |||
start_ms, | |||
) | |||
return body | |||
@@ -813,7 +818,7 @@ class MatrixFederationHttpClient: | |||
ignore_backoff: bool = False, | |||
try_trailing_slash_on_400: bool = False, | |||
) -> Union[JsonDict, list]: | |||
""" GETs some json from the given host homeserver and path | |||
"""GETs some json from the given host homeserver and path | |||
Args: | |||
destination: The remote server to send the HTTP request to. | |||
@@ -994,7 +999,10 @@ class MatrixFederationHttpClient: | |||
except BodyExceededMaxSize: | |||
msg = "Requested file is too large > %r bytes" % (max_size,) | |||
logger.warning( | |||
"{%s} [%s] %s", request.txn_id, request.destination, msg, | |||
"{%s} [%s] %s", | |||
request.txn_id, | |||
request.destination, | |||
msg, | |||
) | |||
raise SynapseError(502, msg, Codes.TOO_LARGE) | |||
except Exception as e: | |||
@@ -213,8 +213,7 @@ class RequestMetrics: | |||
self.update_metrics() | |||
def update_metrics(self): | |||
"""Updates the in flight metrics with values from this request. | |||
""" | |||
"""Updates the in flight metrics with values from this request.""" | |||
new_stats = self.start_context.get_resource_usage() | |||
diff = new_stats - self._request_stats | |||
@@ -76,8 +76,7 @@ HTML_ERROR_TEMPLATE = """<!DOCTYPE html> | |||
def return_json_error(f: failure.Failure, request: SynapseRequest) -> None: | |||
"""Sends a JSON error response to clients. | |||
""" | |||
"""Sends a JSON error response to clients.""" | |||
if f.check(SynapseError): | |||
error_code = f.value.code | |||
@@ -106,12 +105,17 @@ def return_json_error(f: failure.Failure, request: SynapseRequest) -> None: | |||
pass | |||
else: | |||
respond_with_json( | |||
request, error_code, error_dict, send_cors=True, | |||
request, | |||
error_code, | |||
error_dict, | |||
send_cors=True, | |||
) | |||
def return_html_error( | |||
f: failure.Failure, request: Request, error_template: Union[str, jinja2.Template], | |||
f: failure.Failure, | |||
request: Request, | |||
error_template: Union[str, jinja2.Template], | |||
) -> None: | |||
"""Sends an HTML error page corresponding to the given failure. | |||
@@ -189,8 +193,7 @@ ServletCallback = Callable[ | |||
class HttpServer(Protocol): | |||
""" Interface for registering callbacks on a HTTP server | |||
""" | |||
"""Interface for registering callbacks on a HTTP server""" | |||
def register_paths( | |||
self, | |||
@@ -199,7 +202,7 @@ class HttpServer(Protocol): | |||
callback: ServletCallback, | |||
servlet_classname: str, | |||
) -> None: | |||
""" Register a callback that gets fired if we receive a http request | |||
"""Register a callback that gets fired if we receive a http request | |||
with the given method for a path that matches the given regex. | |||
If the regex contains groups these gets passed to the callback via | |||
@@ -235,8 +238,7 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta): | |||
self._extract_context = extract_context | |||
def render(self, request): | |||
""" This gets called by twisted every time someone sends us a request. | |||
""" | |||
"""This gets called by twisted every time someone sends us a request.""" | |||
defer.ensureDeferred(self._async_render_wrapper(request)) | |||
return NOT_DONE_YET | |||
@@ -287,13 +289,18 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta): | |||
@abc.abstractmethod | |||
def _send_response( | |||
self, request: SynapseRequest, code: int, response_object: Any, | |||
self, | |||
request: SynapseRequest, | |||
code: int, | |||
response_object: Any, | |||
) -> None: | |||
raise NotImplementedError() | |||
@abc.abstractmethod | |||
def _send_error_response( | |||
self, f: failure.Failure, request: SynapseRequest, | |||
self, | |||
f: failure.Failure, | |||
request: SynapseRequest, | |||
) -> None: | |||
raise NotImplementedError() | |||
@@ -308,10 +315,12 @@ class DirectServeJsonResource(_AsyncResource): | |||
self.canonical_json = canonical_json | |||
def _send_response( | |||
self, request: Request, code: int, response_object: Any, | |||
self, | |||
request: Request, | |||
code: int, | |||
response_object: Any, | |||
): | |||
"""Implements _AsyncResource._send_response | |||
""" | |||
"""Implements _AsyncResource._send_response""" | |||
# TODO: Only enable CORS for the requests that need it. | |||
respond_with_json( | |||
request, | |||
@@ -322,15 +331,16 @@ class DirectServeJsonResource(_AsyncResource): | |||
) | |||
def _send_error_response( | |||
self, f: failure.Failure, request: SynapseRequest, | |||
self, | |||
f: failure.Failure, | |||
request: SynapseRequest, | |||
) -> None: | |||
"""Implements _AsyncResource._send_error_response | |||
""" | |||
"""Implements _AsyncResource._send_error_response""" | |||
return_json_error(f, request) | |||
class JsonResource(DirectServeJsonResource): | |||
""" This implements the HttpServer interface and provides JSON support for | |||
"""This implements the HttpServer interface and provides JSON support for | |||
Resources. | |||
Register callbacks via register_paths() | |||
@@ -443,10 +453,12 @@ class DirectServeHtmlResource(_AsyncResource): | |||
ERROR_TEMPLATE = HTML_ERROR_TEMPLATE | |||
def _send_response( | |||
self, request: SynapseRequest, code: int, response_object: Any, | |||
self, | |||
request: SynapseRequest, | |||
code: int, | |||
response_object: Any, | |||
): | |||
"""Implements _AsyncResource._send_response | |||
""" | |||
"""Implements _AsyncResource._send_response""" | |||
# We expect to get bytes for us to write | |||
assert isinstance(response_object, bytes) | |||
html_bytes = response_object | |||
@@ -454,10 +466,11 @@ class DirectServeHtmlResource(_AsyncResource): | |||
respond_with_html_bytes(request, 200, html_bytes) | |||
def _send_error_response( | |||
self, f: failure.Failure, request: SynapseRequest, | |||
self, | |||
f: failure.Failure, | |||
request: SynapseRequest, | |||
) -> None: | |||
"""Implements _AsyncResource._send_error_response | |||
""" | |||
"""Implements _AsyncResource._send_error_response""" | |||
return_html_error(f, request, self.ERROR_TEMPLATE) | |||
@@ -534,7 +547,9 @@ class _ByteProducer: | |||
min_chunk_size = 1024 | |||
def __init__( | |||
self, request: Request, iterator: Iterator[bytes], | |||
self, | |||
request: Request, | |||
iterator: Iterator[bytes], | |||
): | |||
self._request = request | |||
self._iterator = iterator | |||
@@ -654,7 +669,10 @@ def respond_with_json( | |||
def respond_with_json_bytes( | |||
request: Request, code: int, json_bytes: bytes, send_cors: bool = False, | |||
request: Request, | |||
code: int, | |||
json_bytes: bytes, | |||
send_cors: bool = False, | |||
): | |||
"""Sends encoded JSON in response to the given request. | |||
@@ -769,7 +787,7 @@ def respond_with_redirect(request: Request, url: bytes) -> None: | |||
def finish_request(request: Request): | |||
""" Finish writing the response to the request. | |||
"""Finish writing the response to the request. | |||
Twisted throws a RuntimeException if the connection closed before the | |||
response was written but doesn't provide a convenient or reliable way to | |||
@@ -258,7 +258,7 @@ def assert_params_in_dict(body, required): | |||
class RestServlet: | |||
""" A Synapse REST Servlet. | |||
"""A Synapse REST Servlet. | |||
An implementing class can either provide its own custom 'register' method, | |||
or use the automatic pattern handling provided by the base class. | |||
@@ -249,8 +249,7 @@ class SynapseRequest(Request): | |||
) | |||
def _finished_processing(self): | |||
"""Log the completion of this request and update the metrics | |||
""" | |||
"""Log the completion of this request and update the metrics""" | |||
assert self.logcontext is not None | |||
usage = self.logcontext.get_resource_usage() | |||
@@ -276,7 +275,8 @@ class SynapseRequest(Request): | |||
# authenticated (e.g. and admin is puppetting a user) then we log both. | |||
if self.requester.user.to_string() != authenticated_entity: | |||
authenticated_entity = "{},{}".format( | |||
authenticated_entity, self.requester.user.to_string(), | |||
authenticated_entity, | |||
self.requester.user.to_string(), | |||
) | |||
elif self.requester is not None: | |||
# This shouldn't happen, but we log it so we don't lose information | |||
@@ -322,8 +322,7 @@ class SynapseRequest(Request): | |||
logger.warning("Failed to stop metrics: %r", e) | |||
def _should_log_request(self) -> bool: | |||
"""Whether we should log at INFO that we processed the request. | |||
""" | |||
"""Whether we should log at INFO that we processed the request.""" | |||
if self.path == b"/health": | |||
return False | |||
@@ -174,7 +174,9 @@ class RemoteHandler(logging.Handler): | |||
# Make a new producer and start it. | |||
self._producer = LogProducer( | |||
buffer=self._buffer, transport=result.transport, format=self.format, | |||
buffer=self._buffer, | |||
transport=result.transport, | |||
format=self.format, | |||
) | |||
result.transport.registerProducer(self._producer, True) | |||
self._producer.resumeProducing() | |||
@@ -60,7 +60,10 @@ def parse_drain_configs( | |||
) | |||
# Either use the default formatter or the tersejson one. | |||
if logging_type in (DrainType.CONSOLE_JSON, DrainType.FILE_JSON,): | |||
if logging_type in ( | |||
DrainType.CONSOLE_JSON, | |||
DrainType.FILE_JSON, | |||
): | |||
formatter = "json" # type: Optional[str] | |||
elif logging_type in ( | |||
DrainType.CONSOLE_JSON_TERSE, | |||
@@ -131,7 +134,9 @@ def parse_drain_configs( | |||
) | |||
def setup_structured_logging(log_config: dict,) -> dict: | |||
def setup_structured_logging( | |||
log_config: dict, | |||
) -> dict: | |||
""" | |||
Convert a legacy structured logging configuration (from Synapse < v1.23.0) | |||
to one compatible with the new standard library handlers. | |||
@@ -338,7 +338,10 @@ class LoggingContext: | |||
if self.previous_context != old_context: | |||
logcontext_error( | |||
"Expected previous context %r, found %r" | |||
% (self.previous_context, old_context,) | |||
% ( | |||
self.previous_context, | |||
old_context, | |||
) | |||
) | |||
return self | |||
@@ -562,7 +565,7 @@ class LoggingContextFilter(logging.Filter): | |||
class PreserveLoggingContext: | |||
"""Context manager which replaces the logging context | |||
The previous logging context is restored on exit.""" | |||
The previous logging context is restored on exit.""" | |||
__slots__ = ["_old_context", "_new_context"] | |||
@@ -585,7 +588,10 @@ class PreserveLoggingContext: | |||
else: | |||
logcontext_error( | |||
"Expected logging context %s but found %s" | |||
% (self._new_context, context,) | |||
% ( | |||
self._new_context, | |||
context, | |||
) | |||
) | |||
@@ -238,8 +238,7 @@ try: | |||
@attr.s(slots=True, frozen=True) | |||
class _WrappedRustReporter: | |||
"""Wrap the reporter to ensure `report_span` never throws. | |||
""" | |||
"""Wrap the reporter to ensure `report_span` never throws.""" | |||
_reporter = attr.ib(type=Reporter, default=attr.Factory(Reporter)) | |||
@@ -326,8 +325,7 @@ def noop_context_manager(*args, **kwargs): | |||
def init_tracer(hs: "HomeServer"): | |||
"""Set the whitelists and initialise the JaegerClient tracer | |||
""" | |||
"""Set the whitelists and initialise the JaegerClient tracer""" | |||
global opentracing | |||
if not hs.config.opentracer_enabled: | |||
# We don't have a tracer | |||
@@ -384,7 +382,7 @@ def whitelisted_homeserver(destination): | |||
Args: | |||
destination (str) | |||
""" | |||
""" | |||
if _homeserver_whitelist: | |||
return _homeserver_whitelist.match(destination) | |||
@@ -43,8 +43,7 @@ def _log_debug_as_f(f, msg, msg_args): | |||
def log_function(f): | |||
""" Function decorator that logs every call to that function. | |||
""" | |||
"""Function decorator that logs every call to that function.""" | |||
func_name = f.__name__ | |||
@wraps(f) | |||
@@ -155,8 +155,7 @@ class InFlightGauge: | |||
self._registrations.setdefault(key, set()).add(callback) | |||
def unregister(self, key, callback): | |||
"""Registers that we've exited a block with labels `key`. | |||
""" | |||
"""Registers that we've exited a block with labels `key`.""" | |||
with self._lock: | |||
self._registrations.setdefault(key, set()).discard(callback) | |||
@@ -402,7 +401,9 @@ class PyPyGCStats: | |||
# Total time spent in GC: 0.073 # s.total_gc_time | |||
pypy_gc_time = CounterMetricFamily( | |||
"pypy_gc_time_seconds_total", "Total time spent in PyPy GC", labels=[], | |||
"pypy_gc_time_seconds_total", | |||
"Total time spent in PyPy GC", | |||
labels=[], | |||
) | |||
pypy_gc_time.add_metric([], s.total_gc_time / 1000) | |||
yield pypy_gc_time | |||
@@ -216,7 +216,7 @@ class MetricsHandler(BaseHTTPRequestHandler): | |||
@classmethod | |||
def factory(cls, registry): | |||
"""Returns a dynamic MetricsHandler class tied | |||
to the passed registry. | |||
to the passed registry. | |||
""" | |||
# This implementation relies on MetricsHandler.registry | |||
# (defined above and defaulted to REGISTRY). | |||
@@ -208,7 +208,8 @@ def run_as_background_process(desc: str, func, *args, bg_start_span=True, **kwar | |||
return await maybe_awaitable(func(*args, **kwargs)) | |||
except Exception: | |||
logger.exception( | |||
"Background process '%s' threw an exception", desc, | |||
"Background process '%s' threw an exception", | |||
desc, | |||
) | |||
finally: | |||
_background_process_in_flight_count.labels(desc).dec() | |||
@@ -249,8 +250,7 @@ class BackgroundProcessLoggingContext(LoggingContext): | |||
self._proc = _BackgroundProcess(name, self) | |||
def start(self, rusage: "Optional[resource._RUsage]"): | |||
"""Log context has started running (again). | |||
""" | |||
"""Log context has started running (again).""" | |||
super().start(rusage) | |||
@@ -261,8 +261,7 @@ class BackgroundProcessLoggingContext(LoggingContext): | |||
_background_processes_active_since_last_scrape.add(self._proc) | |||
def __exit__(self, type, value, traceback) -> None: | |||
"""Log context has finished. | |||
""" | |||
"""Log context has finished.""" | |||
super().__exit__(type, value, traceback) | |||
@@ -275,7 +275,9 @@ class ModuleApi: | |||
redirect them directly if whitelisted). | |||
""" | |||
self._auth_handler._complete_sso_login( | |||
registered_user_id, request, client_redirect_url, | |||
registered_user_id, | |||
request, | |||
client_redirect_url, | |||
) | |||
async def complete_sso_login_async( | |||
@@ -352,7 +354,10 @@ class ModuleApi: | |||
event, | |||
_, | |||
) = await self._hs.get_event_creation_handler().create_and_send_nonmember_event( | |||
requester, event_dict, ratelimit=False, ignore_shadow_ban=True, | |||
requester, | |||
event_dict, | |||
ratelimit=False, | |||
ignore_shadow_ban=True, | |||
) | |||
return event | |||
@@ -75,7 +75,7 @@ def count(func: Callable[[T], bool], it: Iterable[T]) -> int: | |||
class _NotificationListener: | |||
""" This represents a single client connection to the events stream. | |||
"""This represents a single client connection to the events stream. | |||
The events stream handler will have yielded to the deferred, so to | |||
notify the handler it is sufficient to resolve the deferred. | |||
""" | |||
@@ -119,7 +119,10 @@ class _NotifierUserStream: | |||
self.notify_deferred = ObservableDeferred(defer.Deferred()) | |||
def notify( | |||
self, stream_key: str, stream_id: Union[int, RoomStreamToken], time_now_ms: int, | |||
self, | |||
stream_key: str, | |||
stream_id: Union[int, RoomStreamToken], | |||
time_now_ms: int, | |||
): | |||
"""Notify any listeners for this user of a new event from an | |||
event source. | |||
@@ -140,7 +143,7 @@ class _NotifierUserStream: | |||
noify_deferred.callback(self.current_token) | |||
def remove(self, notifier: "Notifier"): | |||
""" Remove this listener from all the indexes in the Notifier | |||
"""Remove this listener from all the indexes in the Notifier | |||
it knows about. | |||
""" | |||
@@ -186,7 +189,7 @@ class _PendingRoomEventEntry: | |||
class Notifier: | |||
""" This class is responsible for notifying any listeners when there are | |||
"""This class is responsible for notifying any listeners when there are | |||
new events available for it. | |||
Primarily used from the /events stream. | |||
@@ -265,8 +268,7 @@ class Notifier: | |||
max_room_stream_token: RoomStreamToken, | |||
extra_users: Collection[UserID] = [], | |||
): | |||
"""Unwraps event and calls `on_new_room_event_args`. | |||
""" | |||
"""Unwraps event and calls `on_new_room_event_args`.""" | |||
self.on_new_room_event_args( | |||
event_pos=event_pos, | |||
room_id=event.room_id, | |||
@@ -341,7 +343,10 @@ class Notifier: | |||
if users or rooms: | |||
self.on_new_event( | |||
"room_key", max_room_stream_token, users=users, rooms=rooms, | |||
"room_key", | |||
max_room_stream_token, | |||
users=users, | |||
rooms=rooms, | |||
) | |||
self._on_updated_room_token(max_room_stream_token) | |||
@@ -392,7 +397,7 @@ class Notifier: | |||
users: Collection[Union[str, UserID]] = [], | |||
rooms: Collection[str] = [], | |||
): | |||
""" Used to inform listeners that something has happened event wise. | |||
"""Used to inform listeners that something has happened event wise. | |||
Will wake up all listeners for the given users and rooms. | |||
""" | |||
@@ -418,7 +423,9 @@ class Notifier: | |||
# Notify appservices | |||
self._notify_app_services_ephemeral( | |||
stream_key, new_token, users, | |||
stream_key, | |||
new_token, | |||
users, | |||
) | |||
def on_new_replication_data(self) -> None: | |||
@@ -502,7 +509,7 @@ class Notifier: | |||
is_guest: bool = False, | |||
explicit_room_id: str = None, | |||
) -> EventStreamResult: | |||
""" For the given user and rooms, return any new events for them. If | |||
"""For the given user and rooms, return any new events for them. If | |||
there are no new events wait for up to `timeout` milliseconds for any | |||
new events to happen before returning. | |||
@@ -651,8 +658,7 @@ class Notifier: | |||
cb() | |||
def notify_remote_server_up(self, server: str): | |||
"""Notify any replication that a remote server has come back up | |||
""" | |||
"""Notify any replication that a remote server has come back up""" | |||
# We call federation_sender directly rather than registering as a | |||
# callback as a) we already have a reference to it and b) it introduces | |||
# circular dependencies. | |||
@@ -144,8 +144,7 @@ class BulkPushRuleEvaluator: | |||
@lru_cache() | |||
def _get_rules_for_room(self, room_id: str) -> "RulesForRoom": | |||
"""Get the current RulesForRoom object for the given room id | |||
""" | |||
"""Get the current RulesForRoom object for the given room id""" | |||
# It's important that RulesForRoom gets added to self._get_rules_for_room.cache | |||
# before any lookup methods get called on it as otherwise there may be | |||
# a race if invalidate_all gets called (which assumes its in the cache) | |||
@@ -252,7 +251,9 @@ class BulkPushRuleEvaluator: | |||
# notified for this event. (This will then get handled when we persist | |||
# the event) | |||
await self.store.add_push_actions_to_staging( | |||
event.event_id, actions_by_user, count_as_unread, | |||
event.event_id, | |||
actions_by_user, | |||
count_as_unread, | |||
) | |||
@@ -116,8 +116,7 @@ class EmailPusher(Pusher): | |||
self._is_processing = True | |||
def _resume_processing(self) -> None: | |||
"""Used by tests to resume processing of events after pausing. | |||
""" | |||
"""Used by tests to resume processing of events after pausing.""" | |||
assert self._is_processing | |||
self._is_processing = False | |||
self._start_processing() | |||
@@ -157,8 +156,10 @@ class EmailPusher(Pusher): | |||
being run. | |||
""" | |||
start = 0 if INCLUDE_ALL_UNREAD_NOTIFS else self.last_stream_ordering | |||
unprocessed = await self.store.get_unread_push_actions_for_user_in_range_for_email( | |||
self.user_id, start, self.max_stream_ordering | |||
unprocessed = ( | |||
await self.store.get_unread_push_actions_for_user_in_range_for_email( | |||
self.user_id, start, self.max_stream_ordering | |||
) | |||
) | |||
soonest_due_at = None # type: Optional[int] | |||
@@ -222,12 +223,14 @@ class EmailPusher(Pusher): | |||
self, last_stream_ordering: int | |||
) -> None: | |||
self.last_stream_ordering = last_stream_ordering | |||
pusher_still_exists = await self.store.update_pusher_last_stream_ordering_and_success( | |||
self.app_id, | |||
self.email, | |||
self.user_id, | |||
last_stream_ordering, | |||
self.clock.time_msec(), | |||
pusher_still_exists = ( | |||
await self.store.update_pusher_last_stream_ordering_and_success( | |||
self.app_id, | |||
self.email, | |||
self.user_id, | |||
last_stream_ordering, | |||
self.clock.time_msec(), | |||
) | |||
) | |||
if not pusher_still_exists: | |||
# The pusher has been deleted while we were processing, so | |||
@@ -298,7 +301,8 @@ class EmailPusher(Pusher): | |||
current_throttle_ms * THROTTLE_MULTIPLIER, THROTTLE_MAX_MS | |||
) | |||
self.throttle_params[room_id] = ThrottleParams( | |||
self.clock.time_msec(), new_throttle_ms, | |||
self.clock.time_msec(), | |||
new_throttle_ms, | |||
) | |||
assert self.pusher_id is not None | |||
await self.store.set_throttle_params( | |||
@@ -176,8 +176,10 @@ class HttpPusher(Pusher): | |||
Never call this directly: use _process which will only allow this to | |||
run once per pusher. | |||
""" | |||
unprocessed = await self.store.get_unread_push_actions_for_user_in_range_for_http( | |||
self.user_id, self.last_stream_ordering, self.max_stream_ordering | |||
unprocessed = ( | |||
await self.store.get_unread_push_actions_for_user_in_range_for_http( | |||
self.user_id, self.last_stream_ordering, self.max_stream_ordering | |||
) | |||
) | |||
logger.info( | |||
@@ -204,12 +206,14 @@ class HttpPusher(Pusher): | |||
http_push_processed_counter.inc() | |||
self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC | |||
self.last_stream_ordering = push_action["stream_ordering"] | |||
pusher_still_exists = await self.store.update_pusher_last_stream_ordering_and_success( | |||
self.app_id, | |||
self.pushkey, | |||
self.user_id, | |||
self.last_stream_ordering, | |||
self.clock.time_msec(), | |||
pusher_still_exists = ( | |||
await self.store.update_pusher_last_stream_ordering_and_success( | |||
self.app_id, | |||
self.pushkey, | |||
self.user_id, | |||
self.last_stream_ordering, | |||
self.clock.time_msec(), | |||
) | |||
) | |||
if not pusher_still_exists: | |||
# The pusher has been deleted while we were processing, so | |||
@@ -290,7 +294,8 @@ class HttpPusher(Pusher): | |||
# for sanity, we only remove the pushkey if it | |||
# was the one we actually sent... | |||
logger.warning( | |||
("Ignoring rejected pushkey %s because we didn't send it"), pk, | |||
("Ignoring rejected pushkey %s because we didn't send it"), | |||
pk, | |||
) | |||
else: | |||
logger.info("Pushkey %s was rejected: removing", pk) | |||
@@ -78,8 +78,7 @@ class PusherPool: | |||
self.pushers = {} # type: Dict[str, Dict[str, Pusher]] | |||
def start(self) -> None: | |||
"""Starts the pushers off in a background process. | |||
""" | |||
"""Starts the pushers off in a background process.""" | |||
if not self._should_start_pushers: | |||
logger.info("Not starting pushers because they are disabled in the config") | |||
return | |||
@@ -297,8 +296,7 @@ class PusherPool: | |||
return pusher | |||
async def _start_pushers(self) -> None: | |||
"""Start all the pushers | |||
""" | |||
"""Start all the pushers""" | |||
pushers = await self.store.get_all_pushers() | |||
# Stagger starting up the pushers so we don't completely drown the | |||
@@ -335,7 +333,8 @@ class PusherPool: | |||
return None | |||
except Exception: | |||
logger.exception( | |||
"Couldn't start pusher id %i: caught Exception", pusher_config.id, | |||
"Couldn't start pusher id %i: caught Exception", | |||
pusher_config.id, | |||
) | |||
return None | |||
@@ -273,7 +273,10 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): | |||
pattern = re.compile("^/_synapse/replication/%s/%s$" % (self.NAME, args)) | |||
http_server.register_paths( | |||
method, [pattern], self._check_auth_and_handle, self.__class__.__name__, | |||
method, | |||
[pattern], | |||
self._check_auth_and_handle, | |||
self.__class__.__name__, | |||
) | |||
def _check_auth_and_handle(self, request, **kwargs): | |||