Ви не можете вибрати більше 25 тем. Теми мають розпочинатися з літери або цифри, можуть містити дефіси (-) і не повинні перевищувати 35 символів.
 
 
 
 
 
 

258 рядків
8.0 KiB

  1. # Copyright 2014-2016 OpenMarket Ltd
  2. # Copyright 2020 The Matrix.org Foundation C.I.C.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import itertools
  16. import re
  17. import secrets
  18. import string
  19. from typing import Any, Iterable, Optional, Tuple
  20. from netaddr import valid_ipv6
  21. from synapse.api.errors import Codes, SynapseError
  22. _string_with_symbols = string.digits + string.ascii_letters + ".,;:^&*-_+=#~@"
  23. # https://matrix.org/docs/spec/client_server/r0.6.0#post-matrix-client-r0-register-email-requesttoken
  24. CLIENT_SECRET_REGEX = re.compile(r"^[0-9a-zA-Z\.=_\-]+$")
  25. # https://matrix.org/docs/spec/client_server/r0.6.1#matrix-content-mxc-uris,
  26. # together with https://github.com/matrix-org/matrix-doc/issues/2177 which basically
  27. # says "there is no grammar for media ids"
  28. #
  29. # The server_name part of this is purposely lax: use parse_and_validate_mxc for
  30. # additional validation.
  31. #
  32. MXC_REGEX = re.compile("^mxc://([^/]+)/([^/#?]+)$")
  33. def random_string(length: int) -> str:
  34. """Generate a cryptographically secure string of random letters.
  35. Drawn from the characters: `a-z` and `A-Z`
  36. """
  37. return "".join(secrets.choice(string.ascii_letters) for _ in range(length))
  38. def random_string_with_symbols(length: int) -> str:
  39. """Generate a cryptographically secure string of random letters/numbers/symbols.
  40. Drawn from the characters: `a-z`, `A-Z`, `0-9`, and `.,;:^&*-_+=#~@`
  41. """
  42. return "".join(secrets.choice(_string_with_symbols) for _ in range(length))
  43. def is_ascii(s: bytes) -> bool:
  44. try:
  45. s.decode("ascii").encode("ascii")
  46. except UnicodeError:
  47. return False
  48. return True
  49. def assert_valid_client_secret(client_secret: str) -> None:
  50. """Validate that a given string matches the client_secret defined by the spec"""
  51. if (
  52. len(client_secret) <= 0
  53. or len(client_secret) > 255
  54. or CLIENT_SECRET_REGEX.match(client_secret) is None
  55. ):
  56. raise SynapseError(
  57. 400, "Invalid client_secret parameter", errcode=Codes.INVALID_PARAM
  58. )
  59. def parse_server_name(server_name: str) -> Tuple[str, Optional[int]]:
  60. """Split a server name into host/port parts.
  61. Args:
  62. server_name: server name to parse
  63. Returns:
  64. host/port parts.
  65. Raises:
  66. ValueError if the server name could not be parsed.
  67. """
  68. try:
  69. if server_name and server_name[-1] == "]":
  70. # ipv6 literal, hopefully
  71. return server_name, None
  72. domain_port = server_name.rsplit(":", 1)
  73. domain = domain_port[0]
  74. port = int(domain_port[1]) if domain_port[1:] else None
  75. return domain, port
  76. except Exception:
  77. raise ValueError("Invalid server name '%s'" % server_name)
  78. # An approximation of the domain name syntax in RFC 1035, section 2.3.1.
  79. # NB: "\Z" is not equivalent to "$".
  80. # The latter will match the position before a "\n" at the end of a string.
  81. VALID_HOST_REGEX = re.compile("\\A[0-9a-zA-Z-]+(?:\\.[0-9a-zA-Z-]+)*\\Z")
  82. def parse_and_validate_server_name(server_name: str) -> Tuple[str, Optional[int]]:
  83. """Split a server name into host/port parts and do some basic validation.
  84. Args:
  85. server_name: server name to parse
  86. Returns:
  87. host/port parts.
  88. Raises:
  89. ValueError if the server name could not be parsed.
  90. """
  91. host, port = parse_server_name(server_name)
  92. # these tests don't need to be bulletproof as we'll find out soon enough
  93. # if somebody is giving us invalid data. What we *do* need is to be sure
  94. # that nobody is sneaking IP literals in that look like hostnames, etc.
  95. # look for ipv6 literals
  96. if host and host[0] == "[":
  97. if host[-1] != "]":
  98. raise ValueError("Mismatched [...] in server name '%s'" % (server_name,))
  99. # valid_ipv6 raises when given an empty string
  100. ipv6_address = host[1:-1]
  101. if not ipv6_address or not valid_ipv6(ipv6_address):
  102. raise ValueError(
  103. "Server name '%s' is not a valid IPv6 address" % (server_name,)
  104. )
  105. elif not VALID_HOST_REGEX.match(host):
  106. raise ValueError("Server name '%s' has an invalid format" % (server_name,))
  107. return host, port
  108. def valid_id_server_location(id_server: str) -> bool:
  109. """Check whether an identity server location, such as the one passed as the
  110. `id_server` parameter to `/_matrix/client/r0/account/3pid/bind`, is valid.
  111. A valid identity server location consists of a valid hostname and optional
  112. port number, optionally followed by any number of `/` delimited path
  113. components, without any fragment or query string parts.
  114. Args:
  115. id_server: identity server location string to validate
  116. Returns:
  117. True if valid, False otherwise.
  118. """
  119. components = id_server.split("/", 1)
  120. host = components[0]
  121. try:
  122. parse_and_validate_server_name(host)
  123. except ValueError:
  124. return False
  125. if len(components) < 2:
  126. # no path
  127. return True
  128. path = components[1]
  129. return "#" not in path and "?" not in path
  130. def parse_and_validate_mxc_uri(mxc: str) -> Tuple[str, Optional[int], str]:
  131. """Parse the given string as an MXC URI
  132. Checks that the "server name" part is a valid server name
  133. Args:
  134. mxc: the (alleged) MXC URI to be checked
  135. Returns:
  136. hostname, port, media id
  137. Raises:
  138. ValueError if the URI cannot be parsed
  139. """
  140. m = MXC_REGEX.match(mxc)
  141. if not m:
  142. raise ValueError("mxc URI %r did not match expected format" % (mxc,))
  143. server_name = m.group(1)
  144. media_id = m.group(2)
  145. host, port = parse_and_validate_server_name(server_name)
  146. return host, port, media_id
  147. def shortstr(iterable: Iterable, maxitems: int = 5) -> str:
  148. """If iterable has maxitems or fewer, return the stringification of a list
  149. containing those items.
  150. Otherwise, return the stringification of a list with the first maxitems items,
  151. followed by "...".
  152. Args:
  153. iterable: iterable to truncate
  154. maxitems: number of items to return before truncating
  155. """
  156. items = list(itertools.islice(iterable, maxitems + 1))
  157. if len(items) <= maxitems:
  158. return str(items)
  159. return "[" + ", ".join(repr(r) for r in items[:maxitems]) + ", ...]"
  160. def strtobool(val: str) -> bool:
  161. """Convert a string representation of truth to True or False
  162. True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values
  163. are 'n', 'no', 'f', 'false', 'off', and '0'. Raises ValueError if
  164. 'val' is anything else.
  165. This is lifted from distutils.util.strtobool, with the exception that it actually
  166. returns a bool, rather than an int.
  167. """
  168. val = val.lower()
  169. if val in ("y", "yes", "t", "true", "on", "1"):
  170. return True
  171. elif val in ("n", "no", "f", "false", "off", "0"):
  172. return False
  173. else:
  174. raise ValueError("invalid truth value %r" % (val,))
  175. _BASE62 = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
  176. def base62_encode(num: int, minwidth: int = 1) -> str:
  177. """Encode a number using base62
  178. Args:
  179. num: number to be encoded
  180. minwidth: width to pad to, if the number is small
  181. """
  182. res = ""
  183. while num:
  184. num, rem = divmod(num, 62)
  185. res = _BASE62[rem] + res
  186. # pad to minimum width
  187. pad = "0" * (minwidth - len(res))
  188. return pad + res
  189. def non_null_str_or_none(val: Any) -> Optional[str]:
  190. """Check that the arg is a string containing no null (U+0000) codepoints.
  191. If so, returns the given string unmodified; otherwise, returns None.
  192. """
  193. return val if isinstance(val, str) and "\u0000" not in val else None