瀏覽代碼

Reduce the number of "untyped defs" (#12716)

tags/v1.60.0rc1
David Robertson 2 年之前
committed by GitHub
父節點
當前提交
17e1eb7749
沒有發現已知的金鑰在資料庫的簽署中 GPG 金鑰 ID: 4AEE18F83AFDEB23
共有 16 個檔案被更改,包括 142 行新增69 行删除
  1. +1
    -0
      changelog.d/12716.misc
  2. +24
    -0
      mypy.ini
  3. +1
    -1
      synapse/groups/groups_server.py
  4. +10
    -6
      synapse/http/client.py
  5. +1
    -1
      synapse/http/federation/matrix_federation_agent.py
  6. +2
    -2
      synapse/http/federation/srv_resolver.py
  7. +3
    -3
      synapse/http/federation/well_known_resolver.py
  8. +19
    -12
      synapse/http/matrixfederationclient.py
  9. +5
    -5
      synapse/http/request_metrics.py
  10. +31
    -13
      synapse/storage/database.py
  11. +14
    -10
      synapse/storage/databases/main/metrics.py
  12. +5
    -3
      synapse/storage/databases/main/stream.py
  13. +13
    -8
      synapse/storage/persist_events.py
  14. +1
    -1
      synapse/storage/prepare_database.py
  15. +4
    -2
      synapse/storage/state.py
  16. +8
    -2
      synapse/storage/types.py

+ 1
- 0
changelog.d/12716.misc 查看文件

@@ -0,0 +1 @@
Add type annotations to increase the number of modules passing `disallow-untyped-defs`.

+ 24
- 0
mypy.ini 查看文件

@@ -119,9 +119,18 @@ disallow_untyped_defs = True
[mypy-synapse.federation.transport.client]
disallow_untyped_defs = False

[mypy-synapse.groups.*]
disallow_untyped_defs = True

[mypy-synapse.handlers.*]
disallow_untyped_defs = True

[mypy-synapse.http.federation.*]
disallow_untyped_defs = True

[mypy-synapse.http.request_metrics]
disallow_untyped_defs = True

[mypy-synapse.http.server]
disallow_untyped_defs = True

@@ -196,12 +205,27 @@ disallow_untyped_defs = True
[mypy-synapse.storage.databases.main.state_deltas]
disallow_untyped_defs = True

[mypy-synapse.storage.databases.main.stream]
disallow_untyped_defs = True

[mypy-synapse.storage.databases.main.transactions]
disallow_untyped_defs = True

[mypy-synapse.storage.databases.main.user_erasure_store]
disallow_untyped_defs = True

[mypy-synapse.storage.prepare_database]
disallow_untyped_defs = True

[mypy-synapse.storage.persist_events]
disallow_untyped_defs = True

[mypy-synapse.storage.state]
disallow_untyped_defs = True

[mypy-synapse.storage.types]
disallow_untyped_defs = True

[mypy-synapse.storage.util.*]
disallow_untyped_defs = True



+ 1
- 1
synapse/groups/groups_server.py 查看文件

@@ -934,7 +934,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
# Before deleting the group lets kick everyone out of it
users = await self.store.get_users_in_group(group_id, include_private=True)

async def _kick_user_from_group(user_id):
async def _kick_user_from_group(user_id: str) -> None:
if self.hs.is_mine_id(user_id):
groups_local = self.hs.get_groups_local_handler()
assert isinstance(


+ 10
- 6
synapse/http/client.py 查看文件

@@ -43,8 +43,10 @@ from twisted.internet import defer, error as twisted_error, protocol, ssl
from twisted.internet.address import IPv4Address, IPv6Address
from twisted.internet.interfaces import (
IAddress,
IDelayedCall,
IHostResolution,
IReactorPluggableNameResolver,
IReactorTime,
IResolutionReceiver,
ITCPTransport,
)
@@ -121,13 +123,15 @@ def check_against_blacklist(
_EPSILON = 0.00000001


def _make_scheduler(reactor):
def _make_scheduler(
reactor: IReactorTime,
) -> Callable[[Callable[[], object]], IDelayedCall]:
"""Makes a schedular suitable for a Cooperator using the given reactor.

(This is effectively just a copy from `twisted.internet.task`)
"""

def _scheduler(x):
def _scheduler(x: Callable[[], object]) -> IDelayedCall:
return reactor.callLater(_EPSILON, x)

return _scheduler
@@ -775,7 +779,7 @@ class SimpleHttpClient:
)


def _timeout_to_request_timed_out_error(f: Failure):
def _timeout_to_request_timed_out_error(f: Failure) -> Failure:
if f.check(twisted_error.TimeoutError, twisted_error.ConnectingCancelledError):
# The TCP connection has its own timeout (set by the 'connectTimeout' param
# on the Agent), which raises twisted_error.TimeoutError exception.
@@ -809,7 +813,7 @@ class _DiscardBodyWithMaxSizeProtocol(protocol.Protocol):
def __init__(self, deferred: defer.Deferred):
self.deferred = deferred

def _maybe_fail(self):
def _maybe_fail(self) -> None:
"""
Report a max size exceed error and disconnect the first time this is called.
"""
@@ -933,12 +937,12 @@ class InsecureInterceptableContextFactory(ssl.ContextFactory):
Do not use this since it allows an attacker to intercept your communications.
"""

def __init__(self):
def __init__(self) -> None:
self._context = SSL.Context(SSL.SSLv23_METHOD)
self._context.set_verify(VERIFY_NONE, lambda *_: False)

def getContext(self, hostname=None, port=None):
return self._context

def creatorForNetloc(self, hostname, port):
def creatorForNetloc(self, hostname: bytes, port: int):
return self

+ 1
- 1
synapse/http/federation/matrix_federation_agent.py 查看文件

@@ -239,7 +239,7 @@ class MatrixHostnameEndpointFactory:

self._srv_resolver = srv_resolver

def endpointForURI(self, parsed_uri: URI):
def endpointForURI(self, parsed_uri: URI) -> "MatrixHostnameEndpoint":
return MatrixHostnameEndpoint(
self._reactor,
self._proxy_reactor,


+ 2
- 2
synapse/http/federation/srv_resolver.py 查看文件

@@ -16,7 +16,7 @@
import logging
import random
import time
from typing import Callable, Dict, List
from typing import Any, Callable, Dict, List

import attr

@@ -109,7 +109,7 @@ class SrvResolver:

def __init__(
self,
dns_client=client,
dns_client: Any = client,
cache: Dict[bytes, List[Server]] = SERVER_CACHE,
get_time: Callable[[], float] = time.time,
):


+ 3
- 3
synapse/http/federation/well_known_resolver.py 查看文件

@@ -74,9 +74,9 @@ _well_known_cache: TTLCache[bytes, Optional[bytes]] = TTLCache("well-known")
_had_valid_well_known_cache: TTLCache[bytes, bool] = TTLCache("had-valid-well-known")


@attr.s(slots=True, frozen=True)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class WellKnownLookupResult:
delegated_server = attr.ib()
delegated_server: Optional[bytes]


class WellKnownResolver:
@@ -336,4 +336,4 @@ def _parse_cache_control(headers: Headers) -> Dict[bytes, Optional[bytes]]:
class _FetchWellKnownFailure(Exception):
# True if we didn't get a non-5xx HTTP response, i.e. this may or may not be
# a temporary failure.
temporary = attr.ib()
temporary: bool = attr.ib()

+ 19
- 12
synapse/http/matrixfederationclient.py 查看文件

@@ -23,6 +23,8 @@ from http import HTTPStatus
from io import BytesIO, StringIO
from typing import (
TYPE_CHECKING,
Any,
BinaryIO,
Callable,
Dict,
Generic,
@@ -44,7 +46,7 @@ from typing_extensions import Literal
from twisted.internet import defer
from twisted.internet.error import DNSLookupError
from twisted.internet.interfaces import IReactorTime
from twisted.internet.task import _EPSILON, Cooperator
from twisted.internet.task import Cooperator
from twisted.web.client import ResponseFailed
from twisted.web.http_headers import Headers
from twisted.web.iweb import IBodyProducer, IResponse
@@ -58,11 +60,13 @@ from synapse.api.errors import (
RequestSendFailed,
SynapseError,
)
from synapse.crypto.context_factory import FederationPolicyForHTTPS
from synapse.http import QuieterFileBodyProducer
from synapse.http.client import (
BlacklistingAgentWrapper,
BodyExceededMaxSize,
ByteWriteable,
_make_scheduler,
encode_query_args,
read_body_with_max_size,
)
@@ -181,7 +185,7 @@ class JsonParser(ByteParser[Union[JsonDict, list]]):

CONTENT_TYPE = "application/json"

def __init__(self):
def __init__(self) -> None:
self._buffer = StringIO()
self._binary_wrapper = BinaryIOWrapper(self._buffer)

@@ -299,7 +303,9 @@ async def _handle_response(
class BinaryIOWrapper:
"""A wrapper for a TextIO which converts from bytes on the fly."""

def __init__(self, file: typing.TextIO, encoding="utf-8", errors="strict"):
def __init__(
self, file: typing.TextIO, encoding: str = "utf-8", errors: str = "strict"
):
self.decoder = codecs.getincrementaldecoder(encoding)(errors)
self.file = file

@@ -317,7 +323,11 @@ class MatrixFederationHttpClient:
requests.
"""

def __init__(self, hs: "HomeServer", tls_client_options_factory):
def __init__(
self,
hs: "HomeServer",
tls_client_options_factory: Optional[FederationPolicyForHTTPS],
):
self.hs = hs
self.signing_key = hs.signing_key
self.server_name = hs.hostname
@@ -348,10 +358,7 @@ class MatrixFederationHttpClient:
self.version_string_bytes = hs.version_string.encode("ascii")
self.default_timeout = 60

def schedule(x):
self.reactor.callLater(_EPSILON, x)

self._cooperator = Cooperator(scheduler=schedule)
self._cooperator = Cooperator(scheduler=_make_scheduler(self.reactor))

self._sleeper = AwakenableSleeper(self.reactor)

@@ -364,7 +371,7 @@ class MatrixFederationHttpClient:
self,
request: MatrixFederationRequest,
try_trailing_slash_on_400: bool = False,
**send_request_args,
**send_request_args: Any,
) -> IResponse:
"""Wrapper for _send_request which can optionally retry the request
upon receiving a combination of a 400 HTTP response code and a
@@ -1159,7 +1166,7 @@ class MatrixFederationHttpClient:
self,
destination: str,
path: str,
output_stream,
output_stream: BinaryIO,
args: Optional[QueryParams] = None,
retry_on_dns_fail: bool = True,
max_size: Optional[int] = None,
@@ -1250,10 +1257,10 @@ class MatrixFederationHttpClient:
return length, headers


def _flatten_response_never_received(e):
def _flatten_response_never_received(e: BaseException) -> str:
if hasattr(e, "reasons"):
reasons = ", ".join(
_flatten_response_never_received(f.value) for f in e.reasons
_flatten_response_never_received(f.value) for f in e.reasons # type: ignore[attr-defined]
)

return "%s:[%s]" % (type(e).__name__, reasons)


+ 5
- 5
synapse/http/request_metrics.py 查看文件

@@ -162,7 +162,7 @@ class RequestMetrics:
with _in_flight_requests_lock:
_in_flight_requests.add(self)

def stop(self, time_sec, response_code, sent_bytes):
def stop(self, time_sec: float, response_code: int, sent_bytes: int) -> None:
with _in_flight_requests_lock:
_in_flight_requests.discard(self)

@@ -186,13 +186,13 @@ class RequestMetrics:
)
return

response_code = str(response_code)
response_code_str = str(response_code)

outgoing_responses_counter.labels(self.method, response_code).inc()
outgoing_responses_counter.labels(self.method, response_code_str).inc()

response_count.labels(self.method, self.name, tag).inc()

response_timer.labels(self.method, self.name, tag, response_code).observe(
response_timer.labels(self.method, self.name, tag, response_code_str).observe(
time_sec - self.start_ts
)

@@ -221,7 +221,7 @@ class RequestMetrics:
# flight.
self.update_metrics()

def update_metrics(self):
def update_metrics(self) -> None:
"""Updates the in flight metrics with values from this request."""
if not self.start_context:
logger.error(


+ 31
- 13
synapse/storage/database.py 查看文件

@@ -31,6 +31,7 @@ from typing import (
List,
Optional,
Tuple,
Type,
TypeVar,
cast,
overload,
@@ -41,6 +42,7 @@ from prometheus_client import Histogram
from typing_extensions import Concatenate, Literal, ParamSpec

from twisted.enterprise import adbapi
from twisted.internet.interfaces import IReactorCore

from synapse.api.errors import StoreError
from synapse.config.database import DatabaseConnectionConfig
@@ -92,7 +94,9 @@ UNIQUE_INDEX_BACKGROUND_UPDATES = {


def make_pool(
reactor, db_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine
reactor: IReactorCore,
db_config: DatabaseConnectionConfig,
engine: BaseDatabaseEngine,
) -> adbapi.ConnectionPool:
"""Get the connection pool for the database."""

@@ -101,7 +105,7 @@ def make_pool(
db_args = dict(db_config.config.get("args", {}))
db_args.setdefault("cp_reconnect", True)

def _on_new_connection(conn):
def _on_new_connection(conn: Connection) -> None:
# Ensure we have a logging context so we can correctly track queries,
# etc.
with LoggingContext("db.on_new_connection"):
@@ -157,7 +161,11 @@ class LoggingDatabaseConnection:
default_txn_name: str

def cursor(
self, *, txn_name=None, after_callbacks=None, exception_callbacks=None
self,
*,
txn_name: Optional[str] = None,
after_callbacks: Optional[List["_CallbackListEntry"]] = None,
exception_callbacks: Optional[List["_CallbackListEntry"]] = None,
) -> "LoggingTransaction":
if not txn_name:
txn_name = self.default_txn_name
@@ -183,11 +191,16 @@ class LoggingDatabaseConnection:
self.conn.__enter__()
return self

def __exit__(self, exc_type, exc_value, traceback) -> Optional[bool]:
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[types.TracebackType],
) -> Optional[bool]:
return self.conn.__exit__(exc_type, exc_value, traceback)

# Proxy through any unknown lookups to the DB conn class.
def __getattr__(self, name):
def __getattr__(self, name: str) -> Any:
return getattr(self.conn, name)


@@ -391,17 +404,22 @@ class LoggingTransaction:
def __enter__(self) -> "LoggingTransaction":
return self

def __exit__(self, exc_type, exc_value, traceback):
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[types.TracebackType],
) -> None:
self.close()


class PerformanceCounters:
def __init__(self):
self.current_counters = {}
self.previous_counters = {}
def __init__(self) -> None:
self.current_counters: Dict[str, Tuple[int, float]] = {}
self.previous_counters: Dict[str, Tuple[int, float]] = {}

def update(self, key: str, duration_secs: float) -> None:
count, cum_time = self.current_counters.get(key, (0, 0))
count, cum_time = self.current_counters.get(key, (0, 0.0))
count += 1
cum_time += duration_secs
self.current_counters[key] = (count, cum_time)
@@ -527,7 +545,7 @@ class DatabasePool:
def start_profiling(self) -> None:
self._previous_loop_ts = monotonic_time()

def loop():
def loop() -> None:
curr = self._current_txn_total_time
prev = self._previous_txn_total_time
self._previous_txn_total_time = curr
@@ -1186,7 +1204,7 @@ class DatabasePool:
if lock:
self.engine.lock_table(txn, table)

def _getwhere(key):
def _getwhere(key: str) -> str:
# If the value we're passing in is None (aka NULL), we need to use
# IS, not =, as NULL = NULL equals NULL (False).
if keyvalues[key] is None:
@@ -2258,7 +2276,7 @@ class DatabasePool:
term: Optional[str],
col: str,
retcols: Collection[str],
desc="simple_search_list",
desc: str = "simple_search_list",
) -> Optional[List[Dict[str, Any]]]:
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.


+ 14
- 10
synapse/storage/databases/main/metrics.py 查看文件

@@ -23,6 +23,7 @@ from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main.event_push_actions import (
EventPushActionsWorkerStore,
)
from synapse.storage.types import Cursor

if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -71,7 +72,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
self._last_user_visit_update = self._get_start_of_day()

@wrap_as_background_process("read_forward_extremities")
async def _read_forward_extremities(self):
async def _read_forward_extremities(self) -> None:
def fetch(txn):
txn.execute(
"""
@@ -95,7 +96,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
(x[0] - 1) * x[1] for x in res if x[1]
)

async def count_daily_e2ee_messages(self):
async def count_daily_e2ee_messages(self) -> int:
"""
Returns an estimate of the number of messages sent in the last day.

@@ -115,7 +116,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):

return await self.db_pool.runInteraction("count_e2ee_messages", _count_messages)

async def count_daily_sent_e2ee_messages(self):
async def count_daily_sent_e2ee_messages(self) -> int:
def _count_messages(txn):
# This is good enough as if you have silly characters in your own
# hostname then that's your own fault.
@@ -136,7 +137,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
"count_daily_sent_e2ee_messages", _count_messages
)

async def count_daily_active_e2ee_rooms(self):
async def count_daily_active_e2ee_rooms(self) -> int:
def _count(txn):
sql = """
SELECT COUNT(DISTINCT room_id) FROM events
@@ -151,7 +152,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
"count_daily_active_e2ee_rooms", _count
)

async def count_daily_messages(self):
async def count_daily_messages(self) -> int:
"""
Returns an estimate of the number of messages sent in the last day.

@@ -171,7 +172,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):

return await self.db_pool.runInteraction("count_messages", _count_messages)

async def count_daily_sent_messages(self):
async def count_daily_sent_messages(self) -> int:
def _count_messages(txn):
# This is good enough as if you have silly characters in your own
# hostname then that's your own fault.
@@ -192,7 +193,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
"count_daily_sent_messages", _count_messages
)

async def count_daily_active_rooms(self):
async def count_daily_active_rooms(self) -> int:
def _count(txn):
sql = """
SELECT COUNT(DISTINCT room_id) FROM events
@@ -226,7 +227,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
"count_monthly_users", self._count_users, thirty_days_ago
)

def _count_users(self, txn, time_from):
def _count_users(self, txn: Cursor, time_from: int) -> int:
"""
Returns number of users seen in the past time_from period
"""
@@ -238,7 +239,10 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
) u
"""
txn.execute(sql, (time_from,))
(count,) = txn.fetchone()
# Mypy knows that fetchone() might return None if there are no rows.
# We know better: "SELECT COUNT(...) FROM ..." without any GROUP BY always
# returns exactly one row.
(count,) = txn.fetchone() # type: ignore[misc]
return count

async def count_r30_users(self) -> Dict[str, int]:
@@ -453,7 +457,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
"count_r30v2_users", _count_r30v2_users
)

def _get_start_of_day(self):
def _get_start_of_day(self) -> int:
"""
Returns millisecond unixtime for start of UTC day.
"""


+ 5
- 3
synapse/storage/databases/main/stream.py 查看文件

@@ -798,9 +798,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
self,
txn: LoggingTransaction,
event_id: str,
allow_none=False,
) -> int:
return self.db_pool.simple_select_one_onecol_txn(
allow_none: bool = False,
) -> Optional[int]:
# Type ignore: we pass keyvalues a Dict[str, str]; the function wants
# Dict[str, Any]. I think mypy is unhappy because Dict is invariant?
return self.db_pool.simple_select_one_onecol_txn( # type: ignore[call-overload]
txn=txn,
table="events",
keyvalues={"event_id": event_id},


+ 13
- 8
synapse/storage/persist_events.py 查看文件

@@ -25,6 +25,7 @@ from typing import (
Collection,
Deque,
Dict,
Generator,
Generic,
Iterable,
List,
@@ -207,7 +208,7 @@ class _EventPeristenceQueue(Generic[_PersistResult]):

return res

def _handle_queue(self, room_id):
def _handle_queue(self, room_id: str) -> None:
"""Attempts to handle the queue for a room if not already being handled.

The queue's callback will be invoked with for each item in the queue,
@@ -227,7 +228,7 @@ class _EventPeristenceQueue(Generic[_PersistResult]):

self._currently_persisting_rooms.add(room_id)

async def handle_queue_loop():
async def handle_queue_loop() -> None:
try:
queue = self._get_drainining_queue(room_id)
for item in queue:
@@ -250,15 +251,17 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
with PreserveLoggingContext():
item.deferred.callback(ret)
finally:
queue = self._event_persist_queues.pop(room_id, None)
if queue:
self._event_persist_queues[room_id] = queue
remaining_queue = self._event_persist_queues.pop(room_id, None)
if remaining_queue:
self._event_persist_queues[room_id] = remaining_queue
self._currently_persisting_rooms.discard(room_id)

# set handle_queue_loop off in the background
run_as_background_process("persist_events", handle_queue_loop)

def _get_drainining_queue(self, room_id):
def _get_drainining_queue(
self, room_id: str
) -> Generator[_EventPersistQueueItem, None, None]:
queue = self._event_persist_queues.setdefault(room_id, deque())

try:
@@ -317,7 +320,9 @@ class EventsPersistenceStorage:
for event, ctx in events_and_contexts:
partitioned.setdefault(event.room_id, []).append((event, ctx))

async def enqueue(item):
async def enqueue(
item: Tuple[str, List[Tuple[EventBase, EventContext]]]
) -> Dict[str, str]:
room_id, evs_ctxs = item
return await self._event_persist_queue.add_to_queue(
room_id, evs_ctxs, backfilled=backfilled
@@ -1102,7 +1107,7 @@ class EventsPersistenceStorage:

return False

async def _handle_potentially_left_users(self, user_ids: Set[str]):
async def _handle_potentially_left_users(self, user_ids: Set[str]) -> None:
"""Given a set of remote users check if the server still shares a room with
them. If not then mark those users' device cache as stale.
"""


+ 1
- 1
synapse/storage/prepare_database.py 查看文件

@@ -85,7 +85,7 @@ def prepare_database(
database_engine: BaseDatabaseEngine,
config: Optional[HomeServerConfig],
databases: Collection[str] = ("main", "state"),
):
) -> None:
"""Prepares a physical database for usage. Will either create all necessary tables
or upgrade from an older schema version.



+ 4
- 2
synapse/storage/state.py 查看文件

@@ -62,7 +62,7 @@ class StateFilter:
types: "frozendict[str, Optional[FrozenSet[str]]]"
include_others: bool = False

def __attrs_post_init__(self):
def __attrs_post_init__(self) -> None:
# If `include_others` is set we canonicalise the filter by removing
# wildcards from the types dictionary
if self.include_others:
@@ -138,7 +138,9 @@ class StateFilter:
)

@staticmethod
def freeze(types: Mapping[str, Optional[Collection[str]]], include_others: bool):
def freeze(
types: Mapping[str, Optional[Collection[str]]], include_others: bool
) -> "StateFilter":
"""
Returns a (frozen) StateFilter with the same contents as the parameters
specified here, which can be made of mutable types.


+ 8
- 2
synapse/storage/types.py 查看文件

@@ -11,7 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Iterator, List, Mapping, Optional, Sequence, Tuple, Union
from types import TracebackType
from typing import Any, Iterator, List, Mapping, Optional, Sequence, Tuple, Type, Union

from typing_extensions import Protocol

@@ -86,5 +87,10 @@ class Connection(Protocol):
def __enter__(self) -> "Connection":
...

def __exit__(self, exc_type, exc_value, traceback) -> Optional[bool]:
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> Optional[bool]:
...

Loading…
取消
儲存