Non puoi selezionare più di 25 argomenti Gli argomenti devono iniziare con una lettera o un numero, possono includere trattini ('-') e possono essere lunghi fino a 35 caratteri.
 
 
 
 
 
 

303 righe
9.5 KiB

  1. #!/usr/bin/env python
  2. # Copyright 2022-2023 The Matrix.org Foundation C.I.C.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import argparse
  16. import logging
  17. import re
  18. from collections import defaultdict
  19. from dataclasses import dataclass
  20. from typing import Dict, Iterable, Optional, Pattern, Set, Tuple
  21. import yaml
  22. from synapse.config.homeserver import HomeServerConfig
  23. from synapse.federation.transport.server import (
  24. TransportLayerServer,
  25. register_servlets as register_federation_servlets,
  26. )
  27. from synapse.http.server import HttpServer, ServletCallback
  28. from synapse.rest import ClientRestResource
  29. from synapse.rest.key.v2 import RemoteKey
  30. from synapse.server import HomeServer
  31. from synapse.storage import DataStore
  32. logger = logging.getLogger("generate_workers_map")
  33. class MockHomeserver(HomeServer):
  34. DATASTORE_CLASS = DataStore # type: ignore
  35. def __init__(self, config: HomeServerConfig, worker_app: Optional[str]) -> None:
  36. super().__init__(config.server.server_name, config=config)
  37. self.config.worker.worker_app = worker_app
  38. GROUP_PATTERN = re.compile(r"\(\?P<[^>]+?>(.+?)\)")
  39. @dataclass
  40. class EndpointDescription:
  41. """
  42. Describes an endpoint and how it should be routed.
  43. """
  44. # The servlet class that handles this endpoint
  45. servlet_class: object
  46. # The category of this endpoint. Is read from the `CATEGORY` constant in the servlet
  47. # class.
  48. category: Optional[str]
  49. # TODO:
  50. # - does it need to be routed based on a stream writer config?
  51. # - does it benefit from any optimised, but optional, routing?
  52. # - what 'opinionated synapse worker class' (event_creator, synchrotron, etc) does
  53. # it go in?
  54. class EnumerationResource(HttpServer):
  55. """
  56. Accepts servlet registrations for the purposes of building up a description of
  57. all endpoints.
  58. """
  59. def __init__(self, is_worker: bool) -> None:
  60. self.registrations: Dict[Tuple[str, str], EndpointDescription] = {}
  61. self._is_worker = is_worker
  62. def register_paths(
  63. self,
  64. method: str,
  65. path_patterns: Iterable[Pattern],
  66. callback: ServletCallback,
  67. servlet_classname: str,
  68. ) -> None:
  69. # federation servlet callbacks are wrapped, so unwrap them.
  70. callback = getattr(callback, "__wrapped__", callback)
  71. # fish out the servlet class
  72. servlet_class = callback.__self__.__class__ # type: ignore
  73. if self._is_worker and method in getattr(
  74. servlet_class, "WORKERS_DENIED_METHODS", ()
  75. ):
  76. # This endpoint would cause an error if called on a worker, so pretend it
  77. # was never registered!
  78. return
  79. sd = EndpointDescription(
  80. servlet_class=servlet_class,
  81. category=getattr(servlet_class, "CATEGORY", None),
  82. )
  83. for pat in path_patterns:
  84. self.registrations[(method, pat.pattern)] = sd
  85. def get_registered_paths_for_hs(
  86. hs: HomeServer,
  87. ) -> Dict[Tuple[str, str], EndpointDescription]:
  88. """
  89. Given a homeserver, get all registered endpoints and their descriptions.
  90. """
  91. enumerator = EnumerationResource(is_worker=hs.config.worker.worker_app is not None)
  92. ClientRestResource.register_servlets(enumerator, hs)
  93. federation_server = TransportLayerServer(hs)
  94. # we can't use `federation_server.register_servlets` but this line does the
  95. # same thing, only it uses this enumerator
  96. register_federation_servlets(
  97. federation_server.hs,
  98. resource=enumerator,
  99. ratelimiter=federation_server.ratelimiter,
  100. authenticator=federation_server.authenticator,
  101. servlet_groups=federation_server.servlet_groups,
  102. )
  103. # the key server endpoints are separate again
  104. RemoteKey(hs).register(enumerator)
  105. return enumerator.registrations
  106. def get_registered_paths_for_default(
  107. worker_app: Optional[str], base_config: HomeServerConfig
  108. ) -> Dict[Tuple[str, str], EndpointDescription]:
  109. """
  110. Given the name of a worker application and a base homeserver configuration,
  111. returns:
  112. Dict from (method, path) to EndpointDescription
  113. TODO Don't require passing in a config
  114. """
  115. hs = MockHomeserver(base_config, worker_app)
  116. # TODO We only do this to avoid an error, but don't need the database etc
  117. hs.setup()
  118. return get_registered_paths_for_hs(hs)
  119. def elide_http_methods_if_unconflicting(
  120. registrations: Dict[Tuple[str, str], EndpointDescription],
  121. all_possible_registrations: Dict[Tuple[str, str], EndpointDescription],
  122. ) -> Dict[Tuple[str, str], EndpointDescription]:
  123. """
  124. Elides HTTP methods (by replacing them with `*`) if all possible registered methods
  125. can be handled by the worker whose registration map is `registrations`.
  126. i.e. the only endpoints left with methods (other than `*`) should be the ones where
  127. the worker can't handle all possible methods for that path.
  128. """
  129. def paths_to_methods_dict(
  130. methods_and_paths: Iterable[Tuple[str, str]]
  131. ) -> Dict[str, Set[str]]:
  132. """
  133. Given (method, path) pairs, produces a dict from path to set of methods
  134. available at that path.
  135. """
  136. result: Dict[str, Set[str]] = {}
  137. for method, path in methods_and_paths:
  138. result.setdefault(path, set()).add(method)
  139. return result
  140. all_possible_reg_methods = paths_to_methods_dict(all_possible_registrations)
  141. reg_methods = paths_to_methods_dict(registrations)
  142. output = {}
  143. for path, handleable_methods in reg_methods.items():
  144. if handleable_methods == all_possible_reg_methods[path]:
  145. any_method = next(iter(handleable_methods))
  146. # TODO This assumes that all methods have the same servlet.
  147. # I suppose that's possibly dubious?
  148. output[("*", path)] = registrations[(any_method, path)]
  149. else:
  150. for method in handleable_methods:
  151. output[(method, path)] = registrations[(method, path)]
  152. return output
  153. def simplify_path_regexes(
  154. registrations: Dict[Tuple[str, str], EndpointDescription]
  155. ) -> Dict[Tuple[str, str], EndpointDescription]:
  156. """
  157. Simplify all the path regexes for the dict of endpoint descriptions,
  158. so that we don't use the Python-specific regex extensions
  159. (and also to remove needlessly specific detail).
  160. """
  161. def simplify_path_regex(path: str) -> str:
  162. """
  163. Given a regex pattern, replaces all named capturing groups (e.g. `(?P<blah>xyz)`)
  164. with a simpler version available in more common regex dialects (e.g. `.*`).
  165. """
  166. # TODO it's hard to choose between these two;
  167. # `.*` is a vague simplification
  168. # return GROUP_PATTERN.sub(r"\1", path)
  169. return GROUP_PATTERN.sub(r".*", path)
  170. return {(m, simplify_path_regex(p)): v for (m, p), v in registrations.items()}
  171. def main() -> None:
  172. parser = argparse.ArgumentParser(
  173. description=(
  174. "Updates a synapse database to the latest schema and optionally runs background updates"
  175. " on it."
  176. )
  177. )
  178. parser.add_argument("-v", action="store_true")
  179. parser.add_argument(
  180. "--config-path",
  181. type=argparse.FileType("r"),
  182. required=True,
  183. help="Synapse configuration file",
  184. )
  185. args = parser.parse_args()
  186. # TODO
  187. # logging.basicConfig(**logging_config)
  188. # Load, process and sanity-check the config.
  189. hs_config = yaml.safe_load(args.config_path)
  190. config = HomeServerConfig()
  191. config.parse_config_dict(hs_config, "", "")
  192. master_paths = get_registered_paths_for_default(None, config)
  193. worker_paths = get_registered_paths_for_default(
  194. "synapse.app.generic_worker", config
  195. )
  196. all_paths = {**master_paths, **worker_paths}
  197. elided_worker_paths = elide_http_methods_if_unconflicting(worker_paths, all_paths)
  198. elide_http_methods_if_unconflicting(master_paths, all_paths)
  199. # TODO SSO endpoints (pick_idp etc) NOT REGISTERED BY THIS SCRIPT
  200. categories_to_methods_and_paths: Dict[
  201. Optional[str], Dict[Tuple[str, str], EndpointDescription]
  202. ] = defaultdict(dict)
  203. for (method, path), desc in elided_worker_paths.items():
  204. categories_to_methods_and_paths[desc.category][method, path] = desc
  205. for category, contents in categories_to_methods_and_paths.items():
  206. print_category(category, contents)
  207. def print_category(
  208. category_name: Optional[str],
  209. elided_worker_paths: Dict[Tuple[str, str], EndpointDescription],
  210. ) -> None:
  211. """
  212. Prints out a category, in documentation page style.
  213. Example:
  214. ```
  215. # Category name
  216. /path/xyz
  217. GET /path/abc
  218. ```
  219. """
  220. if category_name:
  221. print(f"# {category_name}")
  222. else:
  223. print("# (Uncategorised requests)")
  224. for ln in sorted(
  225. p for m, p in simplify_path_regexes(elided_worker_paths) if m == "*"
  226. ):
  227. print(ln)
  228. print()
  229. for ln in sorted(
  230. f"{m:6} {p}" for m, p in simplify_path_regexes(elided_worker_paths) if m != "*"
  231. ):
  232. print(ln)
  233. print()
  234. if __name__ == "__main__":
  235. main()