Browse Source

Send the opentracing span information to appservices (#16227)

tags/v1.93.0rc1
Marcel 8 months ago
committed by GitHub
parent
commit
13e9cad537
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 37 additions and 14 deletions
  1. +1
    -0
      changelog.d/16227.feature
  2. +24
    -8
      synapse/appservice/api.py
  3. +12
    -6
      tests/appservice/test_api.py

+ 1
- 0
changelog.d/16227.feature View File

@@ -0,0 +1 @@
Add span information to requests sent to appservices. Contributed by MTRNord.

+ 24
- 8
synapse/appservice/api.py View File

@@ -40,6 +40,7 @@ from synapse.appservice import (
from synapse.events import EventBase from synapse.events import EventBase
from synapse.events.utils import SerializeEventConfig, serialize_event from synapse.events.utils import SerializeEventConfig, serialize_event
from synapse.http.client import SimpleHttpClient, is_unknown_endpoint from synapse.http.client import SimpleHttpClient, is_unknown_endpoint
from synapse.logging import opentracing
from synapse.types import DeviceListUpdates, JsonDict, ThirdPartyInstanceID from synapse.types import DeviceListUpdates, JsonDict, ThirdPartyInstanceID
from synapse.util.caches.response_cache import ResponseCache from synapse.util.caches.response_cache import ResponseCache


@@ -125,6 +126,17 @@ class ApplicationServiceApi(SimpleHttpClient):
hs.get_clock(), "as_protocol_meta", timeout_ms=HOUR_IN_MS hs.get_clock(), "as_protocol_meta", timeout_ms=HOUR_IN_MS
) )


def _get_headers(self, service: "ApplicationService") -> Dict[bytes, List[bytes]]:
"""This makes sure we have always the auth header and opentracing headers set."""

# This is also ensured before in the functions. However this is needed to please
# the typechecks.
assert service.hs_token is not None

headers = {b"Authorization": [b"Bearer " + service.hs_token.encode("ascii")]}
opentracing.inject_header_dict(headers, check_destination=False)
return headers

async def query_user(self, service: "ApplicationService", user_id: str) -> bool: async def query_user(self, service: "ApplicationService", user_id: str) -> bool:
if service.url is None: if service.url is None:
return False return False
@@ -136,10 +148,11 @@ class ApplicationServiceApi(SimpleHttpClient):
args = None args = None
if self.config.use_appservice_legacy_authorization: if self.config.use_appservice_legacy_authorization:
args = {"access_token": service.hs_token} args = {"access_token": service.hs_token}

response = await self.get_json( response = await self.get_json(
f"{service.url}{APP_SERVICE_PREFIX}/users/{urllib.parse.quote(user_id)}", f"{service.url}{APP_SERVICE_PREFIX}/users/{urllib.parse.quote(user_id)}",
args, args,
headers={"Authorization": [f"Bearer {service.hs_token}"]},
headers=self._get_headers(service),
) )
if response is not None: # just an empty json object if response is not None: # just an empty json object
return True return True
@@ -162,10 +175,11 @@ class ApplicationServiceApi(SimpleHttpClient):
args = None args = None
if self.config.use_appservice_legacy_authorization: if self.config.use_appservice_legacy_authorization:
args = {"access_token": service.hs_token} args = {"access_token": service.hs_token}

response = await self.get_json( response = await self.get_json(
f"{service.url}{APP_SERVICE_PREFIX}/rooms/{urllib.parse.quote(alias)}", f"{service.url}{APP_SERVICE_PREFIX}/rooms/{urllib.parse.quote(alias)}",
args, args,
headers={"Authorization": [f"Bearer {service.hs_token}"]},
headers=self._get_headers(service),
) )
if response is not None: # just an empty json object if response is not None: # just an empty json object
return True return True
@@ -203,10 +217,11 @@ class ApplicationServiceApi(SimpleHttpClient):
**fields, **fields,
b"access_token": service.hs_token, b"access_token": service.hs_token,
} }

response = await self.get_json( response = await self.get_json(
f"{service.url}{APP_SERVICE_PREFIX}/thirdparty/{kind}/{urllib.parse.quote(protocol)}", f"{service.url}{APP_SERVICE_PREFIX}/thirdparty/{kind}/{urllib.parse.quote(protocol)}",
args=args, args=args,
headers={"Authorization": [f"Bearer {service.hs_token}"]},
headers=self._get_headers(service),
) )
if not isinstance(response, list): if not isinstance(response, list):
logger.warning( logger.warning(
@@ -243,10 +258,11 @@ class ApplicationServiceApi(SimpleHttpClient):
args = None args = None
if self.config.use_appservice_legacy_authorization: if self.config.use_appservice_legacy_authorization:
args = {"access_token": service.hs_token} args = {"access_token": service.hs_token}

info = await self.get_json( info = await self.get_json(
f"{service.url}{APP_SERVICE_PREFIX}/thirdparty/protocol/{urllib.parse.quote(protocol)}", f"{service.url}{APP_SERVICE_PREFIX}/thirdparty/protocol/{urllib.parse.quote(protocol)}",
args, args,
headers={"Authorization": [f"Bearer {service.hs_token}"]},
headers=self._get_headers(service),
) )


if not _is_valid_3pe_metadata(info): if not _is_valid_3pe_metadata(info):
@@ -283,7 +299,7 @@ class ApplicationServiceApi(SimpleHttpClient):
await self.post_json_get_json( await self.post_json_get_json(
uri=f"{service.url}{APP_SERVICE_PREFIX}/ping", uri=f"{service.url}{APP_SERVICE_PREFIX}/ping",
post_json={"transaction_id": txn_id}, post_json={"transaction_id": txn_id},
headers={"Authorization": [f"Bearer {service.hs_token}"]},
headers=self._get_headers(service),
) )


async def push_bulk( async def push_bulk(
@@ -364,7 +380,7 @@ class ApplicationServiceApi(SimpleHttpClient):
f"{service.url}{APP_SERVICE_PREFIX}/transactions/{urllib.parse.quote(str(txn_id))}", f"{service.url}{APP_SERVICE_PREFIX}/transactions/{urllib.parse.quote(str(txn_id))}",
json_body=body, json_body=body,
args=args, args=args,
headers={"Authorization": [f"Bearer {service.hs_token}"]},
headers=self._get_headers(service),
) )
if logger.isEnabledFor(logging.DEBUG): if logger.isEnabledFor(logging.DEBUG):
logger.debug( logger.debug(
@@ -437,7 +453,7 @@ class ApplicationServiceApi(SimpleHttpClient):
response = await self.post_json_get_json( response = await self.post_json_get_json(
uri, uri,
body, body,
headers={"Authorization": [f"Bearer {service.hs_token}"]},
headers=self._get_headers(service),
) )
except HttpResponseException as e: except HttpResponseException as e:
# The appservice doesn't support this endpoint. # The appservice doesn't support this endpoint.
@@ -498,7 +514,7 @@ class ApplicationServiceApi(SimpleHttpClient):
response = await self.post_json_get_json( response = await self.post_json_get_json(
uri, uri,
query, query,
headers={"Authorization": [f"Bearer {service.hs_token}"]},
headers=self._get_headers(service),
) )
except HttpResponseException as e: except HttpResponseException as e:
# The appservice doesn't support this endpoint. # The appservice doesn't support this endpoint.


+ 12
- 6
tests/appservice/test_api.py View File

@@ -76,7 +76,7 @@ class ApplicationServiceApiTestCase(unittest.HomeserverTestCase):
headers: Mapping[Union[str, bytes], Sequence[Union[str, bytes]]], headers: Mapping[Union[str, bytes], Sequence[Union[str, bytes]]],
) -> List[JsonDict]: ) -> List[JsonDict]:
# Ensure the access token is passed as a header. # Ensure the access token is passed as a header.
if not headers or not headers.get("Authorization"):
if not headers or not headers.get(b"Authorization"):
raise RuntimeError("Access token not provided") raise RuntimeError("Access token not provided")
# ... and not as a query param # ... and not as a query param
if b"access_token" in args: if b"access_token" in args:
@@ -84,7 +84,9 @@ class ApplicationServiceApiTestCase(unittest.HomeserverTestCase):
"Access token should not be passed as a query param." "Access token should not be passed as a query param."
) )


self.assertEqual(headers.get("Authorization"), [f"Bearer {TOKEN}"])
self.assertEqual(
headers.get(b"Authorization"), [f"Bearer {TOKEN}".encode()]
)
self.request_url = url self.request_url = url
if url == URL_USER: if url == URL_USER:
return SUCCESS_RESULT_USER return SUCCESS_RESULT_USER
@@ -152,11 +154,13 @@ class ApplicationServiceApiTestCase(unittest.HomeserverTestCase):
# Ensure the access token is passed as a both a query param and in the headers. # Ensure the access token is passed as a both a query param and in the headers.
if not args.get(b"access_token"): if not args.get(b"access_token"):
raise RuntimeError("Access token should be provided in query params.") raise RuntimeError("Access token should be provided in query params.")
if not headers or not headers.get("Authorization"):
if not headers or not headers.get(b"Authorization"):
raise RuntimeError("Access token should be provided in auth headers.") raise RuntimeError("Access token should be provided in auth headers.")


self.assertEqual(args.get(b"access_token"), TOKEN) self.assertEqual(args.get(b"access_token"), TOKEN)
self.assertEqual(headers.get("Authorization"), [f"Bearer {TOKEN}"])
self.assertEqual(
headers.get(b"Authorization"), [f"Bearer {TOKEN}".encode()]
)
self.request_url = url self.request_url = url
if url == URL_USER: if url == URL_USER:
return SUCCESS_RESULT_USER return SUCCESS_RESULT_USER
@@ -208,10 +212,12 @@ class ApplicationServiceApiTestCase(unittest.HomeserverTestCase):
headers: Mapping[Union[str, bytes], Sequence[Union[str, bytes]]], headers: Mapping[Union[str, bytes], Sequence[Union[str, bytes]]],
) -> JsonDict: ) -> JsonDict:
# Ensure the access token is passed as both a header and query arg. # Ensure the access token is passed as both a header and query arg.
if not headers.get("Authorization"):
if not headers.get(b"Authorization"):
raise RuntimeError("Access token not provided") raise RuntimeError("Access token not provided")


self.assertEqual(headers.get("Authorization"), [f"Bearer {TOKEN}"])
self.assertEqual(
headers.get(b"Authorization"), [f"Bearer {TOKEN}".encode()]
)
return RESPONSE return RESPONSE


# We assign to a method, which mypy doesn't like. # We assign to a method, which mypy doesn't like.


Loading…
Cancel
Save