You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 

303 lines
11 KiB

  1. # Copyright 2015, 2016 OpenMarket Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
import logging
import re
from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Set, Tuple

from synapse._pydantic_compat import HAS_PYDANTIC_V2

if TYPE_CHECKING or HAS_PYDANTIC_V2:
    from pydantic.v1 import Extra, StrictInt, StrictStr
else:
    from pydantic import StrictInt, StrictStr, Extra

from signedjson.sign import sign_json

from twisted.web.server import Request

from synapse.crypto.keyring import ServerKeyFetcher
from synapse.http.server import HttpServer
from synapse.http.servlet import (
    RestServlet,
    parse_and_validate_json_object_from_request,
    parse_integer,
)
from synapse.rest.models import RequestBodyModel
from synapse.storage.keys import FetchKeyResultForRemote
from synapse.types import JsonDict
from synapse.util import json_decoder
from synapse.util.async_helpers import yieldable_gather_results

if TYPE_CHECKING:
    from synapse.server import HomeServer
  38. logger = logging.getLogger(__name__)
  39. class _KeyQueryCriteriaDataModel(RequestBodyModel):
  40. class Config:
  41. extra = Extra.allow
  42. minimum_valid_until_ts: Optional[StrictInt]
  43. class RemoteKey(RestServlet):
  44. """HTTP resource for retrieving the TLS certificate and NACL signature
  45. verification keys for a collection of servers. Checks that the reported
  46. X.509 TLS certificate matches the one used in the HTTPS connection. Checks
  47. that the NACL signature for the remote server is valid. Returns a dict of
  48. JSON signed by both the remote server and by this server.
  49. Supports individual GET APIs and a bulk query POST API.
  50. Requests:
  51. GET /_matrix/key/v2/query/remote.server.example.com HTTP/1.1
  52. GET /_matrix/key/v2/query/remote.server.example.com/a.key.id HTTP/1.1
  53. POST /_matrix/v2/query HTTP/1.1
  54. Content-Type: application/json
  55. {
  56. "server_keys": {
  57. "remote.server.example.com": {
  58. "a.key.id": {
  59. "minimum_valid_until_ts": 1234567890123
  60. }
  61. }
  62. }
  63. }
  64. Response:
  65. HTTP/1.1 200 OK
  66. Content-Type: application/json
  67. {
  68. "server_keys": [
  69. {
  70. "server_name": "remote.server.example.com"
  71. "valid_until_ts": # posix timestamp
  72. "verify_keys": {
  73. "a.key.id": { # The identifier for a key.
  74. key: "" # base64 encoded verification key.
  75. }
  76. }
  77. "old_verify_keys": {
  78. "an.old.key.id": { # The identifier for an old key.
  79. key: "", # base64 encoded key
  80. "expired_ts": 0, # when the key stop being used.
  81. }
  82. }
  83. "signatures": {
  84. "remote.server.example.com": {...}
  85. "this.server.example.com": {...}
  86. }
  87. }
  88. ]
  89. }
  90. """
  91. CATEGORY = "Federation requests"
  92. class PostBody(RequestBodyModel):
  93. server_keys: Dict[StrictStr, Dict[StrictStr, _KeyQueryCriteriaDataModel]]
  94. def __init__(self, hs: "HomeServer"):
  95. self.fetcher = ServerKeyFetcher(hs)
  96. self.store = hs.get_datastores().main
  97. self.clock = hs.get_clock()
  98. self.federation_domain_whitelist = (
  99. hs.config.federation.federation_domain_whitelist
  100. )
  101. self.config = hs.config
  102. def register(self, http_server: HttpServer) -> None:
  103. http_server.register_paths(
  104. "GET",
  105. (
  106. re.compile(
  107. "^/_matrix/key/v2/query/(?P<server>[^/]*)(/(?P<key_id>[^/]*))?$"
  108. ),
  109. ),
  110. self.on_GET,
  111. self.__class__.__name__,
  112. )
  113. http_server.register_paths(
  114. "POST",
  115. (re.compile("^/_matrix/key/v2/query$"),),
  116. self.on_POST,
  117. self.__class__.__name__,
  118. )
  119. async def on_GET(
  120. self, request: Request, server: str, key_id: Optional[str] = None
  121. ) -> Tuple[int, JsonDict]:
  122. if server and key_id:
  123. # Matrix 1.6 drops support for passing the key_id, this is incompatible
  124. # with earlier versions and is allowed in order to support both.
  125. # A warning is issued to help determine when it is safe to drop this.
  126. logger.warning(
  127. "Request for remote server key with deprecated key ID (logging to determine usage level for future removal): %s / %s",
  128. server,
  129. key_id,
  130. )
  131. minimum_valid_until_ts = parse_integer(request, "minimum_valid_until_ts")
  132. query = {
  133. server: {
  134. key_id: _KeyQueryCriteriaDataModel(
  135. minimum_valid_until_ts=minimum_valid_until_ts
  136. )
  137. }
  138. }
  139. else:
  140. query = {server: {}}
  141. return 200, await self.query_keys(query, query_remote_on_cache_miss=True)
  142. async def on_POST(self, request: Request) -> Tuple[int, JsonDict]:
  143. content = parse_and_validate_json_object_from_request(request, self.PostBody)
  144. query = content.server_keys
  145. return 200, await self.query_keys(query, query_remote_on_cache_miss=True)
  146. async def query_keys(
  147. self,
  148. query: Dict[str, Dict[str, _KeyQueryCriteriaDataModel]],
  149. query_remote_on_cache_miss: bool = False,
  150. ) -> JsonDict:
  151. logger.info("Handling query for keys %r", query)
  152. server_keys: Dict[Tuple[str, str], Optional[FetchKeyResultForRemote]] = {}
  153. for server_name, key_ids in query.items():
  154. if key_ids:
  155. results: Mapping[
  156. str, Optional[FetchKeyResultForRemote]
  157. ] = await self.store.get_server_keys_json_for_remote(
  158. server_name, key_ids
  159. )
  160. else:
  161. results = await self.store.get_all_server_keys_json_for_remote(
  162. server_name
  163. )
  164. server_keys.update(
  165. ((server_name, key_id), res) for key_id, res in results.items()
  166. )
  167. json_results: Set[bytes] = set()
  168. time_now_ms = self.clock.time_msec()
  169. # Map server_name->key_id->int. Note that the value of the int is unused.
  170. # XXX: why don't we just use a set?
  171. cache_misses: Dict[str, Dict[str, int]] = {}
  172. for (server_name, key_id), key_result in server_keys.items():
  173. if not query[server_name]:
  174. # all keys were requested. Just return what we have without worrying
  175. # about validity
  176. if key_result:
  177. json_results.add(key_result.key_json)
  178. continue
  179. miss = False
  180. if key_result is None:
  181. miss = True
  182. else:
  183. ts_added_ms = key_result.added_ts
  184. ts_valid_until_ms = key_result.valid_until_ts
  185. req_key = query.get(server_name, {}).get(
  186. key_id, _KeyQueryCriteriaDataModel(minimum_valid_until_ts=None)
  187. )
  188. req_valid_until = req_key.minimum_valid_until_ts
  189. if req_valid_until is not None:
  190. if ts_valid_until_ms < req_valid_until:
  191. logger.debug(
  192. "Cached response for %r/%r is older than requested"
  193. ": valid_until (%r) < minimum_valid_until (%r)",
  194. server_name,
  195. key_id,
  196. ts_valid_until_ms,
  197. req_valid_until,
  198. )
  199. miss = True
  200. else:
  201. logger.debug(
  202. "Cached response for %r/%r is newer than requested"
  203. ": valid_until (%r) >= minimum_valid_until (%r)",
  204. server_name,
  205. key_id,
  206. ts_valid_until_ms,
  207. req_valid_until,
  208. )
  209. elif (ts_added_ms + ts_valid_until_ms) / 2 < time_now_ms:
  210. logger.debug(
  211. "Cached response for %r/%r is too old"
  212. ": (added (%r) + valid_until (%r)) / 2 < now (%r)",
  213. server_name,
  214. key_id,
  215. ts_added_ms,
  216. ts_valid_until_ms,
  217. time_now_ms,
  218. )
  219. # We more than half way through the lifetime of the
  220. # response. We should fetch a fresh copy.
  221. miss = True
  222. else:
  223. logger.debug(
  224. "Cached response for %r/%r is still valid"
  225. ": (added (%r) + valid_until (%r)) / 2 < now (%r)",
  226. server_name,
  227. key_id,
  228. ts_added_ms,
  229. ts_valid_until_ms,
  230. time_now_ms,
  231. )
  232. json_results.add(key_result.key_json)
  233. if miss and query_remote_on_cache_miss:
  234. # only bother attempting to fetch keys from servers on our whitelist
  235. if (
  236. self.federation_domain_whitelist is None
  237. or server_name in self.federation_domain_whitelist
  238. ):
  239. cache_misses.setdefault(server_name, {})[key_id] = 0
  240. # If there is a cache miss, request the missing keys, then recurse (and
  241. # ensure the result is sent).
  242. if cache_misses:
  243. await yieldable_gather_results(
  244. lambda t: self.fetcher.get_keys(*t),
  245. (
  246. (server_name, list(keys), 0)
  247. for server_name, keys in cache_misses.items()
  248. ),
  249. )
  250. return await self.query_keys(query, query_remote_on_cache_miss=False)
  251. else:
  252. signed_keys = []
  253. for key_json_raw in json_results:
  254. key_json = json_decoder.decode(key_json_raw.decode("utf-8"))
  255. for signing_key in self.config.key.key_server_signing_keys:
  256. key_json = sign_json(
  257. key_json, self.config.server.server_name, signing_key
  258. )
  259. signed_keys.append(key_json)
  260. return {"server_keys": signed_keys}