Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGES/10753.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Widened ``trace_request_ctx`` parameter type from ``Mapping[str, Any] | None`` to ``object`` to allow passing instances of user-defined classes as trace context -- by :user:`nightcityblade`.
1 change: 1 addition & 0 deletions CHANGES/12106.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Added a ``dns_cache_max_size`` parameter to ``TCPConnector`` to limit the size of the cache -- by :user:`Dreamsorcerer`.
5 changes: 2 additions & 3 deletions aiohttp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
Coroutine,
Generator,
Iterable,
Mapping,
Sequence,
)
from contextlib import suppress
Expand Down Expand Up @@ -195,7 +194,7 @@ class _RequestOptions(TypedDict, total=False):
ssl: SSLContext | bool | Fingerprint
server_hostname: str | None
proxy_headers: LooseHeaders | None
trace_request_ctx: Mapping[str, Any] | None
trace_request_ctx: object
read_bufsize: int | None
auto_decompress: bool | None
max_line_size: int | None
Expand Down Expand Up @@ -501,7 +500,7 @@ async def _request(
ssl: SSLContext | bool | Fingerprint = True,
server_hostname: str | None = None,
proxy_headers: LooseHeaders | None = None,
trace_request_ctx: Mapping[str, Any] | None = None,
trace_request_ctx: object = None,
read_bufsize: int | None = None,
auto_decompress: bool | None = None,
max_line_size: int | None = None,
Expand Down
24 changes: 18 additions & 6 deletions aiohttp/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -794,25 +794,33 @@ async def _create_connection(


class _DNSCacheTable:
def __init__(self, ttl: float | None = None) -> None:
self._addrs_rr: dict[tuple[str, int], tuple[Iterator[ResolveResult], int]] = {}
def __init__(self, ttl: float | None = None, max_size: int = 1000) -> None:
self._addrs_rr: OrderedDict[
tuple[str, int], tuple[Iterator[ResolveResult], int]
] = OrderedDict()
self._timestamps: dict[tuple[str, int], float] = {}
self._ttl = ttl
self._max_size = max_size

def __contains__(self, host: object) -> bool:
return host in self._addrs_rr

def add(self, key: tuple[str, int], addrs: list[ResolveResult]) -> None:
if key in self._addrs_rr:
self._addrs_rr.move_to_end(key)

self._addrs_rr[key] = (cycle(addrs), len(addrs))

if self._ttl is not None:
self._timestamps[key] = monotonic()

if len(self._addrs_rr) > self._max_size:
oldest_key, _ = self._addrs_rr.popitem(last=False)
self._timestamps.pop(oldest_key, None)

def remove(self, key: tuple[str, int]) -> None:
self._addrs_rr.pop(key, None)

if self._ttl is not None:
self._timestamps.pop(key, None)
self._timestamps.pop(key, None)

def clear(self) -> None:
self._addrs_rr.clear()
Expand All @@ -823,6 +831,7 @@ def next_addrs(self, key: tuple[str, int]) -> list[ResolveResult]:
addrs = list(islice(loop, length))
# Consume one more element to shift internal state of `cycle`
next(loop)
self._addrs_rr.move_to_end(key)
return addrs

def expired(self, key: tuple[str, int]) -> bool:
Expand Down Expand Up @@ -909,6 +918,7 @@ def __init__(
*,
use_dns_cache: bool = True,
ttl_dns_cache: int | None = 10,
dns_cache_max_size: int = 1000,
family: socket.AddressFamily = socket.AddressFamily.AF_UNSPEC,
ssl: bool | Fingerprint | SSLContext = True,
local_addr: tuple[str, int] | None = None,
Expand Down Expand Up @@ -949,7 +959,9 @@ def __init__(
self._resolver_owner = False

self._use_dns_cache = use_dns_cache
self._cached_hosts = _DNSCacheTable(ttl=ttl_dns_cache)
self._cached_hosts = _DNSCacheTable(
ttl=ttl_dns_cache, max_size=dns_cache_max_size
)
self._throttle_dns_futures: dict[tuple[str, int], set[asyncio.Future[None]]] = (
{}
)
Expand Down
57 changes: 57 additions & 0 deletions tests/test_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -4338,6 +4338,63 @@ def test_next_addrs_single(self, dns_cache_table: _DNSCacheTable) -> None:
addrs = dns_cache_table.next_addrs(self.host2)
assert addrs == [self.result1]

def test_max_size_eviction(self) -> None:
table = _DNSCacheTable(max_size=2)

table.add(self.host1, [self.result1])
table.add(self.host2, [self.result2])

host3 = ("example.com", 80)
result3: ResolveResult = {
**self.result1,
"hostname": "example.com",
"host": "1.2.3.4",
}
table.add(host3, [result3])

assert len(table._addrs_rr) == 2
assert self.host1 not in table._addrs_rr
assert host3 in table._addrs_rr

def test_lru_eviction(self) -> None:
table = _DNSCacheTable(max_size=2)

table.add(self.host1, [self.result1])
table.add(self.host2, [self.result2])

table.next_addrs(self.host1)

host3 = ("example.com", 80)
result3: ResolveResult = {
**self.result1,
"hostname": "example.com",
"host": "1.2.3.4",
}
table.add(host3, [result3])

assert self.host1 in table._addrs_rr
assert self.host2 not in table._addrs_rr

def test_lru_eviction_add(self) -> None:
table = _DNSCacheTable(max_size=2)

table.add(self.host1, [self.result1])
table.add(self.host2, [self.result2])

# Re-add, thus making host1 the most recently used.
table.add(self.host1, [self.result1])

host3 = ("example.com", 80)
result3: ResolveResult = {
**self.result1,
"hostname": "example.com",
"host": "1.2.3.4",
}
table.add(host3, [result3])

assert self.host1 in table._addrs_rr
assert self.host2 not in table._addrs_rr


async def test_connector_cache_trace_race() -> None:
class DummyTracer(Trace):
Expand Down
13 changes: 13 additions & 0 deletions tests/test_tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,19 @@ def test_trace_config_ctx_request_ctx(self) -> None:
)
assert trace_config_ctx.trace_request_ctx is trace_request_ctx

def test_trace_config_ctx_custom_class(self) -> None:
"""Custom class instances should be accepted as trace_request_ctx (#10753)."""

class MyContext:
def __init__(self, request_id: int) -> None:
self.request_id = request_id

ctx = MyContext(request_id=42)
trace_config = TraceConfig()
trace_config_ctx = trace_config.trace_config_ctx(trace_request_ctx=ctx)
assert trace_config_ctx.trace_request_ctx is ctx
assert trace_config_ctx.trace_request_ctx.request_id == 42

def test_freeze(self) -> None:
trace_config = TraceConfig()
trace_config.freeze()
Expand Down
Loading