diff --git a/CHANGES/10753.bugfix.rst b/CHANGES/10753.bugfix.rst new file mode 100644 index 00000000000..e0f4cfd0dd1 --- /dev/null +++ b/CHANGES/10753.bugfix.rst @@ -0,0 +1 @@ +Widened ``trace_request_ctx`` parameter type from ``Mapping[str, Any] | None`` to ``object`` to allow passing instances of user-defined classes as trace context -- by :user:`nightcityblade`. diff --git a/CHANGES/12106.feature.rst b/CHANGES/12106.feature.rst new file mode 100644 index 00000000000..daa9088eed6 --- /dev/null +++ b/CHANGES/12106.feature.rst @@ -0,0 +1 @@ +Added a ``dns_cache_max_size`` parameter to ``TCPConnector`` to limit the size of the cache -- by :user:`Dreamsorcerer`. diff --git a/aiohttp/client.py b/aiohttp/client.py index 2aa8dfb8acf..4f25cc49b40 100644 --- a/aiohttp/client.py +++ b/aiohttp/client.py @@ -16,7 +16,6 @@ Coroutine, Generator, Iterable, - Mapping, Sequence, ) from contextlib import suppress @@ -195,7 +194,7 @@ class _RequestOptions(TypedDict, total=False): ssl: SSLContext | bool | Fingerprint server_hostname: str | None proxy_headers: LooseHeaders | None - trace_request_ctx: Mapping[str, Any] | None + trace_request_ctx: object read_bufsize: int | None auto_decompress: bool | None max_line_size: int | None @@ -501,7 +500,7 @@ async def _request( ssl: SSLContext | bool | Fingerprint = True, server_hostname: str | None = None, proxy_headers: LooseHeaders | None = None, - trace_request_ctx: Mapping[str, Any] | None = None, + trace_request_ctx: object = None, read_bufsize: int | None = None, auto_decompress: bool | None = None, max_line_size: int | None = None, diff --git a/aiohttp/connector.py b/aiohttp/connector.py index 0a4f79c0154..7abe43dbe03 100644 --- a/aiohttp/connector.py +++ b/aiohttp/connector.py @@ -794,25 +794,33 @@ async def _create_connection( class _DNSCacheTable: - def __init__(self, ttl: float | None = None) -> None: - self._addrs_rr: dict[tuple[str, int], tuple[Iterator[ResolveResult], int]] = {} + def __init__(self, ttl: float | None = None, max_size: int = 1000) -> None: + self._addrs_rr: OrderedDict[ + tuple[str, int], tuple[Iterator[ResolveResult], int] + ] = OrderedDict() self._timestamps: dict[tuple[str, int], float] = {} self._ttl = ttl + self._max_size = max_size def __contains__(self, host: object) -> bool: return host in self._addrs_rr def add(self, key: tuple[str, int], addrs: list[ResolveResult]) -> None: + if key in self._addrs_rr: + self._addrs_rr.move_to_end(key) + self._addrs_rr[key] = (cycle(addrs), len(addrs)) if self._ttl is not None: self._timestamps[key] = monotonic() + if len(self._addrs_rr) > self._max_size: + oldest_key, _ = self._addrs_rr.popitem(last=False) + self._timestamps.pop(oldest_key, None) + def remove(self, key: tuple[str, int]) -> None: self._addrs_rr.pop(key, None) - - if self._ttl is not None: - self._timestamps.pop(key, None) + self._timestamps.pop(key, None) def clear(self) -> None: self._addrs_rr.clear() @@ -823,6 +831,7 @@ def next_addrs(self, key: tuple[str, int]) -> list[ResolveResult]: addrs = list(islice(loop, length)) # Consume one more element to shift internal state of `cycle` next(loop) + self._addrs_rr.move_to_end(key) return addrs def expired(self, key: tuple[str, int]) -> bool: @@ -909,6 +918,7 @@ def __init__( *, use_dns_cache: bool = True, ttl_dns_cache: int | None = 10, + dns_cache_max_size: int = 1000, family: socket.AddressFamily = socket.AddressFamily.AF_UNSPEC, ssl: bool | Fingerprint | SSLContext = True, local_addr: tuple[str, int] | None = None, @@ -949,7 +959,9 @@ def __init__( self._resolver_owner = False self._use_dns_cache = use_dns_cache - self._cached_hosts = _DNSCacheTable(ttl=ttl_dns_cache) + self._cached_hosts = _DNSCacheTable( + ttl=ttl_dns_cache, max_size=dns_cache_max_size + ) self._throttle_dns_futures: dict[tuple[str, int], set[asyncio.Future[None]]] = ( {} ) diff --git a/tests/test_connector.py b/tests/test_connector.py index cb156cef86b..a3fd3626157 100644 --- a/tests/test_connector.py +++ b/tests/test_connector.py @@ -4338,6 +4338,63 @@ def test_next_addrs_single(self, dns_cache_table: _DNSCacheTable) -> None: addrs = dns_cache_table.next_addrs(self.host2) assert addrs == [self.result1] + def test_max_size_eviction(self) -> None: + table = _DNSCacheTable(max_size=2) + + table.add(self.host1, [self.result1]) + table.add(self.host2, [self.result2]) + + host3 = ("example.com", 80) + result3: ResolveResult = { + **self.result1, + "hostname": "example.com", + "host": "1.2.3.4", + } + table.add(host3, [result3]) + + assert len(table._addrs_rr) == 2 + assert self.host1 not in table._addrs_rr + assert host3 in table._addrs_rr + + def test_lru_eviction(self) -> None: + table = _DNSCacheTable(max_size=2) + + table.add(self.host1, [self.result1]) + table.add(self.host2, [self.result2]) + + table.next_addrs(self.host1) + + host3 = ("example.com", 80) + result3: ResolveResult = { + **self.result1, + "hostname": "example.com", + "host": "1.2.3.4", + } + table.add(host3, [result3]) + + assert self.host1 in table._addrs_rr + assert self.host2 not in table._addrs_rr + + def test_lru_eviction_add(self) -> None: + table = _DNSCacheTable(max_size=2) + + table.add(self.host1, [self.result1]) + table.add(self.host2, [self.result2]) + + # Re-add, thus making host1 the most recently used. + table.add(self.host1, [self.result1]) + + host3 = ("example.com", 80) + result3: ResolveResult = { + **self.result1, + "hostname": "example.com", + "host": "1.2.3.4", + } + table.add(host3, [result3]) + + assert self.host1 in table._addrs_rr + assert self.host2 not in table._addrs_rr + async def test_connector_cache_trace_race() -> None: class DummyTracer(Trace): diff --git a/tests/test_tracing.py b/tests/test_tracing.py index 7ee7e6ae6d7..0ec7b442a26 100644 --- a/tests/test_tracing.py +++ b/tests/test_tracing.py @@ -59,6 +59,19 @@ def test_trace_config_ctx_request_ctx(self) -> None: ) assert trace_config_ctx.trace_request_ctx is trace_request_ctx + def test_trace_config_ctx_custom_class(self) -> None: + """Custom class instances should be accepted as trace_request_ctx (#10753).""" + + class MyContext: + def __init__(self, request_id: int) -> None: + self.request_id = request_id + + ctx = MyContext(request_id=42) + trace_config = TraceConfig() + trace_config_ctx = trace_config.trace_config_ctx(trace_request_ctx=ctx) + assert trace_config_ctx.trace_request_ctx is ctx + assert trace_config_ctx.trace_request_ctx.request_id == 42 + def test_freeze(self) -> None: trace_config = TraceConfig() trace_config.freeze()