@@ -0,0 +1 @@ | |||
Fixup `synapse.rest` to pass mypy. |
@@ -66,3 +66,12 @@ ignore_missing_imports = True | |||
[mypy-sentry_sdk] | |||
ignore_missing_imports = True | |||
[mypy-PIL.*] | |||
ignore_missing_imports = True | |||
[mypy-lxml] | |||
ignore_missing_imports = True | |||
[mypy-jwt.*] | |||
ignore_missing_imports = True |
@@ -338,21 +338,22 @@ class UserRegisterServlet(RestServlet): | |||
got_mac = body["mac"] | |||
want_mac = hmac.new( | |||
want_mac_builder = hmac.new( | |||
key=self.hs.config.registration_shared_secret.encode(), | |||
digestmod=hashlib.sha1, | |||
) | |||
want_mac.update(nonce.encode("utf8")) | |||
want_mac.update(b"\x00") | |||
want_mac.update(username) | |||
want_mac.update(b"\x00") | |||
want_mac.update(password) | |||
want_mac.update(b"\x00") | |||
want_mac.update(b"admin" if admin else b"notadmin") | |||
want_mac_builder.update(nonce.encode("utf8")) | |||
want_mac_builder.update(b"\x00") | |||
want_mac_builder.update(username) | |||
want_mac_builder.update(b"\x00") | |||
want_mac_builder.update(password) | |||
want_mac_builder.update(b"\x00") | |||
want_mac_builder.update(b"admin" if admin else b"notadmin") | |||
if user_type: | |||
want_mac.update(b"\x00") | |||
want_mac.update(user_type.encode("utf8")) | |||
want_mac = want_mac.hexdigest() | |||
want_mac_builder.update(b"\x00") | |||
want_mac_builder.update(user_type.encode("utf8")) | |||
want_mac = want_mac_builder.hexdigest() | |||
if not hmac.compare_digest(want_mac.encode("ascii"), got_mac.encode("ascii")): | |||
raise SynapseError(403, "HMAC incorrect") | |||
@@ -514,7 +514,7 @@ class CasTicketServlet(RestServlet): | |||
if user is None: | |||
raise Exception("CAS response does not contain user") | |||
except Exception: | |||
logger.error("Error parsing CAS response", exc_info=1) | |||
logger.exception("Error parsing CAS response") | |||
raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED) | |||
if not success: | |||
raise LoginError( | |||
@@ -16,6 +16,7 @@ | |||
""" This module contains REST servlets to do with rooms: /rooms/<paths> """ | |||
import logging | |||
from typing import List, Optional | |||
from six.moves.urllib import parse as urlparse | |||
@@ -207,7 +208,7 @@ class RoomStateEventRestServlet(TransactionRestServlet): | |||
requester, event_dict, txn_id=txn_id | |||
) | |||
ret = {} | |||
ret = {} # type: dict | |||
if event: | |||
set_tag("event_id", event.event_id) | |||
ret = {"event_id": event.event_id} | |||
@@ -285,7 +286,7 @@ class JoinRoomAliasServlet(TransactionRestServlet): | |||
try: | |||
remote_room_hosts = [ | |||
x.decode("ascii") for x in request.args[b"server_name"] | |||
] | |||
] # type: Optional[List[str]] | |||
except Exception: | |||
remote_room_hosts = None | |||
elif RoomAlias.is_valid(room_identifier): | |||
@@ -375,7 +376,7 @@ class PublicRoomListRestServlet(TransactionRestServlet): | |||
server = parse_string(request, "server", default=None) | |||
content = parse_json_object_from_request(request) | |||
limit = int(content.get("limit", 100)) | |||
limit = int(content.get("limit", 100)) # type: Optional[int] | |||
since_token = content.get("since", None) | |||
search_filter = content.get("filter", None) | |||
@@ -504,11 +505,16 @@ class RoomMessageListRestServlet(RestServlet): | |||
filter_bytes = parse_string(request, b"filter", encoding=None) | |||
if filter_bytes: | |||
filter_json = urlparse.unquote(filter_bytes.decode("UTF-8")) | |||
event_filter = Filter(json.loads(filter_json)) | |||
if event_filter.filter_json.get("event_format", "client") == "federation": | |||
event_filter = Filter(json.loads(filter_json)) # type: Optional[Filter] | |||
if ( | |||
event_filter | |||
and event_filter.filter_json.get("event_format", "client") | |||
== "federation" | |||
): | |||
as_client_event = False | |||
else: | |||
event_filter = None | |||
msgs = await self.pagination_handler.get_messages( | |||
room_id=room_id, | |||
requester=requester, | |||
@@ -611,7 +617,7 @@ class RoomEventContextServlet(RestServlet): | |||
filter_bytes = parse_string(request, "filter") | |||
if filter_bytes: | |||
filter_json = urlparse.unquote(filter_bytes) | |||
event_filter = Filter(json.loads(filter_json)) | |||
event_filter = Filter(json.loads(filter_json)) # type: Optional[Filter] | |||
else: | |||
event_filter = None | |||
@@ -21,6 +21,7 @@ from typing import List, Union | |||
from six import string_types | |||
import synapse | |||
import synapse.api.auth | |||
import synapse.types | |||
from synapse.api.constants import LoginType | |||
from synapse.api.errors import ( | |||
@@ -405,7 +406,7 @@ class RegisterRestServlet(RestServlet): | |||
return ret | |||
elif kind != b"user": | |||
raise UnrecognizedRequestError( | |||
"Do not understand membership kind: %s" % (kind,) | |||
"Do not understand membership kind: %s" % (kind.decode("utf8"),) | |||
) | |||
# we do basic sanity checks here because the auth layer will store these | |||
@@ -14,6 +14,7 @@ | |||
# limitations under the License. | |||
import logging | |||
from typing import Tuple | |||
from synapse.http import servlet | |||
from synapse.http.servlet import parse_json_object_from_request | |||
@@ -60,7 +61,7 @@ class SendToDeviceRestServlet(servlet.RestServlet): | |||
sender_user_id, message_type, content["messages"] | |||
) | |||
response = (200, {}) | |||
response = (200, {}) # type: Tuple[int, dict] | |||
return response | |||
@@ -13,6 +13,7 @@ | |||
# limitations under the License. | |||
import logging | |||
from typing import Dict, Set | |||
from canonicaljson import encode_canonical_json, json | |||
from signedjson.sign import sign_json | |||
@@ -103,7 +104,7 @@ class RemoteKey(DirectServeResource): | |||
async def _async_render_GET(self, request): | |||
if len(request.postpath) == 1: | |||
(server,) = request.postpath | |||
query = {server.decode("ascii"): {}} | |||
query = {server.decode("ascii"): {}} # type: dict | |||
elif len(request.postpath) == 2: | |||
server, key_id = request.postpath | |||
minimum_valid_until_ts = parse_integer(request, "minimum_valid_until_ts") | |||
@@ -148,7 +149,7 @@ class RemoteKey(DirectServeResource): | |||
time_now_ms = self.clock.time_msec() | |||
cache_misses = dict() | |||
cache_misses = dict() # type: Dict[str, Set[str]] | |||
for (server_name, key_id, from_server), results in cached.items(): | |||
results = [(result["ts_added_ms"], result) for result in results] | |||
@@ -18,6 +18,7 @@ import errno | |||
import logging | |||
import os | |||
import shutil | |||
from typing import Dict, Tuple | |||
from six import iteritems | |||
@@ -605,7 +606,7 @@ class MediaRepository(object): | |||
# We deduplicate the thumbnail sizes by ignoring the cropped versions if | |||
# they have the same dimensions of a scaled one. | |||
thumbnails = {} | |||
thumbnails = {} # type: Dict[Tuple[int, int, str], str] | |||
for r_width, r_height, r_method, r_type in requirements: | |||
if r_method == "crop": | |||
thumbnails.setdefault((r_width, r_height, r_type), r_method) | |||
@@ -23,6 +23,7 @@ import re | |||
import shutil | |||
import sys | |||
import traceback | |||
from typing import Dict, Optional | |||
import six | |||
from six import string_types | |||
@@ -237,8 +238,8 @@ class PreviewUrlResource(DirectServeResource): | |||
# If we don't find a match, we'll look at the HTTP Content-Type, and | |||
# if that doesn't exist, we'll fall back to UTF-8. | |||
if not encoding: | |||
match = _content_type_match.match(media_info["media_type"]) | |||
encoding = match.group(1) if match else "utf-8" | |||
content_match = _content_type_match.match(media_info["media_type"]) | |||
encoding = content_match.group(1) if content_match else "utf-8" | |||
og = decode_and_calc_og(body, media_info["uri"], encoding) | |||
@@ -518,7 +519,7 @@ def _calc_og(tree, media_uri): | |||
# "og:video:height" : "720", | |||
# "og:video:secure_url": "https://www.youtube.com/v/LXDBoHyjmtw?version=3", | |||
og = {} | |||
og = {} # type: Dict[str, Optional[str]] | |||
for tag in tree.xpath("//*/meta[starts-with(@property, 'og:')]"): | |||
if "content" in tag.attrib: | |||
# if we've got more than 50 tags, someone is taking the piss | |||
@@ -296,8 +296,8 @@ class ThumbnailResource(DirectServeResource): | |||
d_h = desired_height | |||
if desired_method.lower() == "crop": | |||
info_list = [] | |||
info_list2 = [] | |||
crop_info_list = [] | |||
crop_info_list2 = [] | |||
for info in thumbnail_infos: | |||
t_w = info["thumbnail_width"] | |||
t_h = info["thumbnail_height"] | |||
@@ -309,7 +309,7 @@ class ThumbnailResource(DirectServeResource): | |||
type_quality = desired_type != info["thumbnail_type"] | |||
length_quality = info["thumbnail_length"] | |||
if t_w >= d_w or t_h >= d_h: | |||
info_list.append( | |||
crop_info_list.append( | |||
( | |||
aspect_quality, | |||
min_quality, | |||
@@ -320,7 +320,7 @@ class ThumbnailResource(DirectServeResource): | |||
) | |||
) | |||
else: | |||
info_list2.append( | |||
crop_info_list2.append( | |||
( | |||
aspect_quality, | |||
min_quality, | |||
@@ -330,10 +330,10 @@ class ThumbnailResource(DirectServeResource): | |||
info, | |||
) | |||
) | |||
if info_list: | |||
return min(info_list)[-1] | |||
if crop_info_list: | |||
return min(crop_info_list2)[-1] | |||
else: | |||
return min(info_list2)[-1] | |||
return min(crop_info_list2)[-1] | |||
else: | |||
info_list = [] | |||
info_list2 = [] | |||
@@ -183,8 +183,7 @@ commands = mypy \ | |||
synapse/logging/ \ | |||
synapse/module_api \ | |||
synapse/replication \ | |||
synapse/rest/consent \ | |||
synapse/rest/saml2 \ | |||
synapse/rest \ | |||
synapse/spam_checker_api \ | |||
synapse/storage/engines \ | |||
synapse/streams | |||