You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

300 lines
11 KiB

  1. # Copyright 2014-2016 OpenMarket Ltd
  2. # Copyright 2019 New Vector Ltd.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import itertools
  16. import json
  17. import logging
  18. from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple
  19. from signedjson.key import decode_verify_key_bytes
  20. from unpaddedbase64 import decode_base64
  21. from synapse.storage._base import SQLBaseStore
  22. from synapse.storage.database import LoggingTransaction
  23. from synapse.storage.keys import FetchKeyResult
  24. from synapse.storage.types import Cursor
  25. from synapse.util.caches.descriptors import cached, cachedList
  26. from synapse.util.iterutils import batch_iter
  27. logger = logging.getLogger(__name__)
  28. db_binary_type = memoryview
  29. class KeyStore(SQLBaseStore):
  30. """Persistence for signature verification keys"""
  31. @cached()
  32. def _get_server_signature_key(
  33. self, server_name_and_key_id: Tuple[str, str]
  34. ) -> FetchKeyResult:
  35. raise NotImplementedError()
  36. @cachedList(
  37. cached_method_name="_get_server_signature_key",
  38. list_name="server_name_and_key_ids",
  39. )
  40. async def get_server_signature_keys(
  41. self, server_name_and_key_ids: Iterable[Tuple[str, str]]
  42. ) -> Dict[Tuple[str, str], FetchKeyResult]:
  43. """
  44. Args:
  45. server_name_and_key_ids:
  46. iterable of (server_name, key-id) tuples to fetch keys for
  47. Returns:
  48. A map from (server_name, key_id) -> FetchKeyResult, or None if the
  49. key is unknown
  50. """
  51. keys = {}
  52. def _get_keys(txn: Cursor, batch: Tuple[Tuple[str, str], ...]) -> None:
  53. """Processes a batch of keys to fetch, and adds the result to `keys`."""
  54. # batch_iter always returns tuples so it's safe to do len(batch)
  55. sql = """
  56. SELECT server_name, key_id, verify_key, ts_valid_until_ms
  57. FROM server_signature_keys WHERE 1=0
  58. """ + " OR (server_name=? AND key_id=?)" * len(
  59. batch
  60. )
  61. txn.execute(sql, tuple(itertools.chain.from_iterable(batch)))
  62. for row in txn:
  63. server_name, key_id, key_bytes, ts_valid_until_ms = row
  64. if ts_valid_until_ms is None:
  65. # Old keys may be stored with a ts_valid_until_ms of null,
  66. # in which case we treat this as if it was set to `0`, i.e.
  67. # it won't match key requests that define a minimum
  68. # `ts_valid_until_ms`.
  69. ts_valid_until_ms = 0
  70. keys[(server_name, key_id)] = FetchKeyResult(
  71. verify_key=decode_verify_key_bytes(key_id, bytes(key_bytes)),
  72. valid_until_ts=ts_valid_until_ms,
  73. )
  74. def _txn(txn: Cursor) -> Dict[Tuple[str, str], FetchKeyResult]:
  75. for batch in batch_iter(server_name_and_key_ids, 50):
  76. _get_keys(txn, batch)
  77. return keys
  78. return await self.db_pool.runInteraction("get_server_signature_keys", _txn)
  79. async def store_server_signature_keys(
  80. self,
  81. from_server: str,
  82. ts_added_ms: int,
  83. verify_keys: Mapping[Tuple[str, str], FetchKeyResult],
  84. ) -> None:
  85. """Stores NACL verification keys for remote servers.
  86. Args:
  87. from_server: Where the verification keys were looked up
  88. ts_added_ms: The time to record that the key was added
  89. verify_keys:
  90. keys to be stored. Each entry is a triplet of
  91. (server_name, key_id, key).
  92. """
  93. key_values = []
  94. value_values = []
  95. invalidations = []
  96. for (server_name, key_id), fetch_result in verify_keys.items():
  97. key_values.append((server_name, key_id))
  98. value_values.append(
  99. (
  100. from_server,
  101. ts_added_ms,
  102. fetch_result.valid_until_ts,
  103. db_binary_type(fetch_result.verify_key.encode()),
  104. )
  105. )
  106. # invalidate takes a tuple corresponding to the params of
  107. # _get_server_signature_key. _get_server_signature_key only takes one
  108. # param, which is itself the 2-tuple (server_name, key_id).
  109. invalidations.append((server_name, key_id))
  110. await self.db_pool.simple_upsert_many(
  111. table="server_signature_keys",
  112. key_names=("server_name", "key_id"),
  113. key_values=key_values,
  114. value_names=(
  115. "from_server",
  116. "ts_added_ms",
  117. "ts_valid_until_ms",
  118. "verify_key",
  119. ),
  120. value_values=value_values,
  121. desc="store_server_signature_keys",
  122. )
  123. invalidate = self._get_server_signature_key.invalidate
  124. for i in invalidations:
  125. invalidate((i,))
  126. async def store_server_keys_json(
  127. self,
  128. server_name: str,
  129. key_id: str,
  130. from_server: str,
  131. ts_now_ms: int,
  132. ts_expires_ms: int,
  133. key_json_bytes: bytes,
  134. ) -> None:
  135. """Stores the JSON bytes for a set of keys from a server
  136. The JSON should be signed by the originating server, the intermediate
  137. server, and by this server. Updates the value for the
  138. (server_name, key_id, from_server) triplet if one already existed.
  139. Args:
  140. server_name: The name of the server.
  141. key_id: The identifier of the key this JSON is for.
  142. from_server: The server this JSON was fetched from.
  143. ts_now_ms: The time now in milliseconds.
  144. ts_valid_until_ms: The time when this json stops being valid.
  145. key_json_bytes: The encoded JSON.
  146. """
  147. await self.db_pool.simple_upsert(
  148. table="server_keys_json",
  149. keyvalues={
  150. "server_name": server_name,
  151. "key_id": key_id,
  152. "from_server": from_server,
  153. },
  154. values={
  155. "server_name": server_name,
  156. "key_id": key_id,
  157. "from_server": from_server,
  158. "ts_added_ms": ts_now_ms,
  159. "ts_valid_until_ms": ts_expires_ms,
  160. "key_json": db_binary_type(key_json_bytes),
  161. },
  162. desc="store_server_keys_json",
  163. )
  164. # invalidate takes a tuple corresponding to the params of
  165. # _get_server_keys_json. _get_server_keys_json only takes one
  166. # param, which is itself the 2-tuple (server_name, key_id).
  167. self._get_server_keys_json.invalidate(((server_name, key_id),))
  168. @cached()
  169. def _get_server_keys_json(
  170. self, server_name_and_key_id: Tuple[str, str]
  171. ) -> FetchKeyResult:
  172. raise NotImplementedError()
  173. @cachedList(
  174. cached_method_name="_get_server_keys_json", list_name="server_name_and_key_ids"
  175. )
  176. async def get_server_keys_json(
  177. self, server_name_and_key_ids: Iterable[Tuple[str, str]]
  178. ) -> Dict[Tuple[str, str], FetchKeyResult]:
  179. """
  180. Args:
  181. server_name_and_key_ids:
  182. iterable of (server_name, key-id) tuples to fetch keys for
  183. Returns:
  184. A map from (server_name, key_id) -> FetchKeyResult, or None if the
  185. key is unknown
  186. """
  187. keys = {}
  188. def _get_keys(txn: Cursor, batch: Tuple[Tuple[str, str], ...]) -> None:
  189. """Processes a batch of keys to fetch, and adds the result to `keys`."""
  190. # batch_iter always returns tuples so it's safe to do len(batch)
  191. sql = """
  192. SELECT server_name, key_id, key_json, ts_valid_until_ms
  193. FROM server_keys_json WHERE 1=0
  194. """ + " OR (server_name=? AND key_id=?)" * len(
  195. batch
  196. )
  197. txn.execute(sql, tuple(itertools.chain.from_iterable(batch)))
  198. for server_name, key_id, key_json_bytes, ts_valid_until_ms in txn:
  199. if ts_valid_until_ms is None:
  200. # Old keys may be stored with a ts_valid_until_ms of null,
  201. # in which case we treat this as if it was set to `0`, i.e.
  202. # it won't match key requests that define a minimum
  203. # `ts_valid_until_ms`.
  204. ts_valid_until_ms = 0
  205. # The entire signed JSON response is stored in server_keys_json,
  206. # fetch out the bits needed.
  207. key_json = json.loads(bytes(key_json_bytes))
  208. key_base64 = key_json["verify_keys"][key_id]["key"]
  209. keys[(server_name, key_id)] = FetchKeyResult(
  210. verify_key=decode_verify_key_bytes(
  211. key_id, decode_base64(key_base64)
  212. ),
  213. valid_until_ts=ts_valid_until_ms,
  214. )
  215. def _txn(txn: Cursor) -> Dict[Tuple[str, str], FetchKeyResult]:
  216. for batch in batch_iter(server_name_and_key_ids, 50):
  217. _get_keys(txn, batch)
  218. return keys
  219. return await self.db_pool.runInteraction("get_server_keys_json", _txn)
  220. async def get_server_keys_json_for_remote(
  221. self, server_keys: Iterable[Tuple[str, Optional[str], Optional[str]]]
  222. ) -> Dict[Tuple[str, Optional[str], Optional[str]], List[Dict[str, Any]]]:
  223. """Retrieve the key json for a list of server_keys and key ids.
  224. If no keys are found for a given server, key_id and source then
  225. that server, key_id, and source triplet entry will be an empty list.
  226. The JSON is returned as a byte array so that it can be efficiently
  227. used in an HTTP response.
  228. Args:
  229. server_keys: List of (server_name, key_id, source) triplets.
  230. Returns:
  231. A mapping from (server_name, key_id, source) triplets to a list of dicts
  232. """
  233. def _get_server_keys_json_txn(
  234. txn: LoggingTransaction,
  235. ) -> Dict[Tuple[str, Optional[str], Optional[str]], List[Dict[str, Any]]]:
  236. results = {}
  237. for server_name, key_id, from_server in server_keys:
  238. keyvalues = {"server_name": server_name}
  239. if key_id is not None:
  240. keyvalues["key_id"] = key_id
  241. if from_server is not None:
  242. keyvalues["from_server"] = from_server
  243. rows = self.db_pool.simple_select_list_txn(
  244. txn,
  245. "server_keys_json",
  246. keyvalues=keyvalues,
  247. retcols=(
  248. "key_id",
  249. "from_server",
  250. "ts_added_ms",
  251. "ts_valid_until_ms",
  252. "key_json",
  253. ),
  254. )
  255. results[(server_name, key_id, from_server)] = rows
  256. return results
  257. return await self.db_pool.runInteraction(
  258. "get_server_keys_json", _get_server_keys_json_txn
  259. )