Skip to content

Commit 2314db2

Browse files
Order hostnames so that ones that are visible in DNS go first
1 parent 7bdcd35 commit 2314db2

File tree

3 files changed

+124
-105
lines changed

3 files changed

+124
-105
lines changed

cmapi/cmapi_server/managers/host_identity.py

Lines changed: 50 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,9 @@
99
So:
1010
1. We must choose 1 IP address and 0/1 hostnames as primary (of many)
1111
2. We need to filter out unreliable names
12-
3. We must order IPs/hostnames by source reliability
12+
3. We must order IPs/hostnames by source reliability (DNS > /etc/hosts)
1313
4. There can be very many resolving sources, so we cannot resolve everything ourselves, and must rely on OS resolving
1414
(see /etc/nsswitch.conf, there are local /etc/hosts, DNS, mDNS, systemd-resolved, LDAP, myhostname, etc)
15-
5. But most important sources (DNS and /etc/hosts, recommended by our manual...) must be checked and ordered by our policy
1615
"""
1716
import hashlib
1817
import ipaddress
@@ -21,14 +20,11 @@
2120
from dataclasses import dataclass, field
2221
from datetime import datetime, timezone
2322
from functools import lru_cache
24-
from typing import Optional, Union
25-
26-
import dns.resolver
27-
import dns.reversename
23+
from typing import Dict, List, Optional, Set, Tuple, Union
2824

2925
from cmapi_server.exceptions import CMAPIBasicError, ResolutionError, ResolutionPolicyViolationError
3026
from cmapi_server.managers.network import NetworkManager
31-
from cmapi_server.managers.resolving_sources import get_resolving_source
27+
from cmapi_server.managers.resolving_sources import ResolvingSourceName, get_resolving_source
3228

3329
IPAddress = Union[ipaddress.IPv4Address, ipaddress.IPv6Address]
3430

@@ -42,8 +38,8 @@ class HostIdentity:
4238
# The first, most important IP addr and host name (ordering is done by policy)
4339
primary_ip: str
4440
primary_name: Optional[str] # Name can be missing (but policy can make it required)
45-
ips: list[str] # All IP addrs
46-
names: list[str] # All host names (only those that are visible to other hosts)
41+
ips: List[str] # All IP addrs
42+
names: List[str] # All host names (only those that are visible to other hosts)
4743
unique_key: str # unique id of this host, will be used later for aliases
4844
observed_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
4945

@@ -113,9 +109,9 @@ class ResolutionPolicy:
113109
allow_ipv6: bool = False
114110
require_hostname: bool = False
115111

116-
def filter_addresses(self, addrs: Iterable[str]) -> list[IPAddress]:
112+
def filter_addresses(self, addrs: Iterable[str]) -> List[IPAddress]:
117113
"""Filter out IP addresses that don't match the policy."""
118-
ips: list[IPAddress] = []
114+
ips: List[IPAddress] = []
119115
logger.debug(
120116
'Filtering addresses: %s (allow_private_ips=%s, allow_ipv6=%s)',
121117
list(addrs), self.allow_private_ips, self.allow_ipv6
@@ -155,7 +151,21 @@ def filter_addresses(self, addrs: Iterable[str]) -> list[IPAddress]:
155151
logger.debug('Accept %s', addr)
156152
return ips
157153

158-
def order_addresses(self, ips: Sequence[IPAddress]) -> list[IPAddress]:
154+
def order_hostnames(
155+
self,
156+
names: Sequence[str],
157+
name_sources: Dict[str, ResolvingSourceName],
158+
) -> List[str]:
159+
"""Order hostnames so that DNS-sourced names come first."""
160+
161+
def key(name: str) -> Tuple[int, str]:
162+
src = name_sources.get(name, ResolvingSourceName.OS)
163+
is_dns = 0 if src is ResolvingSourceName.DNS else 1
164+
return (is_dns, name.lower())
165+
166+
return sorted(names, key=key)
167+
168+
def order_addresses(self, ips: Sequence[IPAddress]) -> List[IPAddress]:
159169
"""Order IPs deterministically to choose the primary IP."""
160170
def key(ip: IPAddress) -> tuple[int, int, int, int, int, int, int, int]:
161171
is_global = 0 if ip.is_global else 1
@@ -173,7 +183,7 @@ class HostAddressManager:
173183

174184
def __init__(self, policy: Optional[ResolutionPolicy] = None) -> None:
175185
self._policy = policy if policy is not None else ResolutionPolicy()
176-
self._cache: dict[str, HostIdentity] = {}
186+
self._cache: Dict[str, HostIdentity] = {}
177187

178188
def get_identity(self, target: str) -> HostIdentity:
179189
"""Resolve and normalize a hostname or IP."""
@@ -249,12 +259,15 @@ def _get_identity_from_ip(self, ip: IPAddress, original_input: str) -> HostIdent
249259
if not self._policy.filter_addresses([str(ip)]):
250260
raise ResolutionPolicyViolationError('Input IP address was rejected by policy')
251261

252-
names: list[str] = self._get_names_of_ip(ip)
262+
names: List[str] = self._get_names_of_ip(ip)
253263
if self._policy.require_hostname and not names:
254264
logger.warning('Reject %s: no names found for this IP and policy requires hostname', original_input)
255265
raise ResolutionPolicyViolationError('Policy requires a hostname for the input IP address (no PTR record found).')
256266

257-
return HostIdentity.from_policy(original_input, self._policy, [ip], sorted(names))
267+
name_sources: Dict[str, ResolvingSourceName] = {name: ResolvingSourceName.DNS for name in names}
268+
ordered_names = self._policy.order_hostnames(sorted(names), name_sources)
269+
270+
return HostIdentity.from_policy(original_input, self._policy, [ip], ordered_names)
258271

259272
def _get_identity_from_hostname(self, hostname: str) -> HostIdentity:
260273
normalized = hostname.strip().lower()
@@ -271,13 +284,16 @@ def _get_identity_from_fqdn(self, fqdn: str) -> HostIdentity:
271284
filtered_ips = self._policy.filter_addresses([str(ip) for ip in addrs])
272285

273286
# Resolve each IP back to names and check if there is any that resolved back to passed fqdn
274-
names: set[str] = {fqdn}
287+
names: Set[str] = {fqdn}
288+
name_sources: Dict[str, ResolvingSourceName] = {fqdn: ResolvingSourceName.OS}
275289
roundtrip_found = False
276290
for ip in filtered_ips:
277291
names_of_ip = self._get_names_of_ip(ip)
278292
if fqdn in names_of_ip:
279293
roundtrip_found = True
280-
names.update(names_of_ip)
294+
for n in names_of_ip:
295+
names.add(n)
296+
name_sources[n] = ResolvingSourceName.DNS
281297

282298
if not roundtrip_found:
283299
logger.warning(
@@ -290,30 +306,30 @@ def _get_identity_from_fqdn(self, fqdn: str) -> HostIdentity:
290306
if not filtered_ips:
291307
raise ResolutionPolicyViolationError('All resolved addresses were rejected by policy.')
292308

293-
return HostIdentity.from_policy(fqdn, self._policy, filtered_ips, sorted(names))
309+
ordered_names = self._policy.order_hostnames(sorted(names), name_sources)
310+
311+
return HostIdentity.from_policy(fqdn, self._policy, filtered_ips, ordered_names)
294312

295313
def _get_identity_from_non_fqdn(self, hostname: str) -> HostIdentity:
296314
# Like FQDN version, but we know that nothing will resolve back to passed hostname
297-
candidate_ips: set[str] = set(
298-
NetworkManager.resolve_hostname_to_ips(
299-
hostname,
300-
only_ipv4=not self._policy.allow_ipv6,
301-
exclude_loopback=False,
302-
)
303-
)
315+
os_resolver = get_resolving_source(ResolvingSourceName.OS)
316+
candidate_ips: Set[str] = set(str(ip) for ip in os_resolver.resolve(hostname))
304317

305318
ips = self._policy.filter_addresses(candidate_ips)
306319
if not ips:
307320
raise ResolutionPolicyViolationError('All resolved addresses were rejected by policy.')
308321

309322
# Collect names of IPs
310-
names: set[str] = set()
323+
names: Set[str] = set()
324+
name_sources: Dict[str, ResolvingSourceName] = {}
311325
for ip_text in list(candidate_ips):
312326
ip = _ip_or_none(ip_text)
313327
if ip is None:
314328
logger.error('Invalid IP address: %s', ip_text)
315329
continue
316-
names.update(self._get_names_of_ip(ip))
330+
for n in self._get_names_of_ip(ip):
331+
names.add(n)
332+
name_sources[n] = ResolvingSourceName.DNS
317333

318334
if self._policy.require_hostname and not names:
319335
logger.error(
@@ -323,66 +339,24 @@ def _get_identity_from_non_fqdn(self, hostname: str) -> HostIdentity:
323339
)
324340
raise ResolutionPolicyViolationError('Policy requires a hostname for the input host, but DNS did not return any FQDN names.')
325341

326-
return HostIdentity.from_policy(hostname, self._policy, ips, sorted(names))
342+
ordered_names = self._policy.order_hostnames(sorted(names), name_sources)
327343

328-
def _resolve_dns(self, hostname: str) -> list[IPAddress]:
344+
return HostIdentity.from_policy(hostname, self._policy, ips, ordered_names)
345+
346+
def _resolve_dns(self, hostname: str) -> List[IPAddress]:
329347
"""Resolve the given hostname using DNS and return addresses."""
330-
resolver = get_resolving_source('dns')
348+
resolver = get_resolving_source(ResolvingSourceName.DNS)
331349
return resolver.resolve(hostname)
332350

333-
def _get_names_of_ip(self, ip: IPAddress) -> list[str]:
351+
def _get_names_of_ip(self, ip: IPAddress) -> List[str]:
334352
"""Fetch PTR names for an IP via DNS."""
335353
try:
336-
resolver = get_resolving_source('dns')
354+
resolver = get_resolving_source(ResolvingSourceName.DNS)
337355
return resolver.reverse(ip)
338356
except Exception:
339357
logger.exception('ip-to-name lookup unexpected failure for %s', ip)
340358
raise
341359

342-
# DNS abstraction methods for easier mocking
343-
def _dns_resolve_ipv4(self, hostname: str) -> list[str]:
344-
resolver = dns.resolver.Resolver(configure=True)
345-
results: list[str] = []
346-
for rdata in resolver.resolve(hostname, 'A', raise_on_no_answer=False):
347-
try:
348-
results.append(rdata.to_text())
349-
except Exception:
350-
continue
351-
return results
352-
353-
def _dns_resolve_ipv6(self, hostname: str) -> list[str]:
354-
resolver = dns.resolver.Resolver(configure=True)
355-
results: list[str] = []
356-
for rdata in resolver.resolve(hostname, 'AAAA', raise_on_no_answer=False):
357-
try:
358-
results.append(rdata.to_text())
359-
except Exception:
360-
continue
361-
return results
362-
363-
def _dns_reverse(self, ip_text: str) -> list[str]:
364-
reverse_name = dns.reversename.from_address(ip_text)
365-
answer = dns.resolver.resolve(reverse_name, 'PTR', raise_on_no_answer=False)
366-
names: list[str] = []
367-
for ptr_rdata in answer:
368-
try:
369-
name = str(ptr_rdata.target).rstrip('.').lower()
370-
if name:
371-
names.append(name)
372-
except Exception:
373-
continue
374-
return names
375-
376-
def _contains_private(self, addrs: list[str]) -> bool:
377-
"""Return True if any resolvable address string is a private IP."""
378-
for addr in addrs:
379-
ip = _ip_or_none(addr)
380-
if ip is None:
381-
continue
382-
if ip.is_private:
383-
return True
384-
return False
385-
386360

387361
@lru_cache(maxsize=1) # singleton
388362
def get_host_address_manager() -> 'HostAddressManager':
@@ -395,9 +369,6 @@ def _ip_or_none(val: str) -> Optional[IPAddress]:
395369
except ValueError:
396370
return None
397371

398-
def _is_ip_address(val: str) -> bool:
399-
return _ip_or_none(val) is not None
400-
401372
def _is_fqdn(name: str) -> bool:
402373
"""Return True if the string is a valid FQDN (lower-cased, no trailing dot).
403374

cmapi/cmapi_server/managers/resolving_sources.py

Lines changed: 61 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,47 @@
11
import ipaddress
22
import logging
3+
from enum import Enum
34
from functools import cache
4-
from typing import Union
5+
from typing import Dict, List, Set, Type, Union
56

67
import dns.resolver
78
import dns.reversename
89

10+
from cmapi_server.managers.network import NetworkManager
11+
912
IPAddress = Union[ipaddress.IPv4Address, ipaddress.IPv6Address]
1013

1114
logger = logging.getLogger(__name__)
1215

1316

17+
class ResolvingSourceName(str, Enum):
18+
DNS = 'dns'
19+
OS = 'os'
20+
21+
1422
class ResolvingSource:
1523
"""Base class for name/IP resolution sources"""
1624

17-
def resolve(self, hostname: str) -> list[IPAddress]:
25+
name: ResolvingSourceName
26+
27+
def resolve(self, hostname: str) -> List[IPAddress]:
1828
"""Forward lookup: hostname -> list of IPAddress objects."""
1929
raise NotImplementedError
2030

21-
def reverse(self, ip: IPAddress) -> list[str]:
31+
def reverse(self, ip: IPAddress) -> List[str]:
2232
"""Reverse lookup: IPAddress -> list of normalized hostnames."""
2333
raise NotImplementedError
2434

2535

2636
class DNSResolvingSource(ResolvingSource):
27-
"""DNS-based resolving source"""
37+
"""Use only DNS for resolution"""
38+
39+
name = ResolvingSourceName.DNS
2840

29-
def resolve(self, hostname: str) -> list[IPAddress]:
41+
def resolve(self, hostname: str) -> List[IPAddress]:
3042
resolver = dns.resolver.Resolver(configure=True)
31-
results: list[IPAddress] = []
32-
seen: set[str] = set()
43+
results: List[IPAddress] = []
44+
seen: Set[str] = set()
3345

3446
# A records
3547
try:
@@ -68,15 +80,15 @@ def resolve(self, hostname: str) -> list[IPAddress]:
6880

6981
return results
7082

71-
def reverse(self, ip: IPAddress) -> list[str]:
83+
def reverse(self, ip: IPAddress) -> List[str]:
7284
ip_text = str(ip)
7385
reverse_name = dns.reversename.from_address(ip_text)
7486
try:
7587
answer = dns.resolver.resolve(reverse_name, 'PTR', raise_on_no_answer=False)
7688
except dns.resolver.NoAnswer:
7789
return []
7890

79-
names: list[str] = []
91+
names: List[str] = []
8092
for ptr_rdata in answer:
8193
try:
8294
name = str(ptr_rdata.target).rstrip('.').lower()
@@ -87,8 +99,44 @@ def reverse(self, ip: IPAddress) -> list[str]:
8799
return names
88100

89101

102+
class OSResolvingSource(ResolvingSource):
103+
"""Uses all the sources defined in /etc/nsswitch.conf for resolving"""
104+
105+
name = ResolvingSourceName.OS
106+
107+
def resolve(self, hostname: str) -> List[IPAddress]:
108+
ip_texts = NetworkManager.resolve_hostname_to_ips(
109+
hostname,
110+
only_ipv4=False,
111+
exclude_loopback=False,
112+
)
113+
114+
results: List[IPAddress] = []
115+
seen: set[str] = set()
116+
for ip_text in ip_texts:
117+
if ip_text in seen:
118+
continue
119+
try:
120+
ip = ipaddress.ip_address(ip_text)
121+
except ValueError:
122+
logger.error('OS resolver returned invalid IP address %s for host name %s, skipping', ip_text, hostname)
123+
continue
124+
seen.add(ip_text)
125+
results.append(ip)
126+
return results
127+
128+
def reverse(self, ip: IPAddress) -> List[str]:
129+
return NetworkManager.get_hostnames_by_ip(str(ip))
130+
131+
132+
_RESOLVER_REGISTRY: Dict[ResolvingSourceName, Type[ResolvingSource]] = {
133+
cls.name: cls for cls in (DNSResolvingSource, OSResolvingSource)
134+
}
135+
136+
90137
@cache
91-
def get_resolving_source(name: str) -> ResolvingSource:
92-
if name == 'dns':
93-
return DNSResolvingSource()
94-
raise ValueError(f'Unknown resolving source: {name}')
138+
def get_resolving_source(name: ResolvingSourceName) -> ResolvingSource:
139+
cls = _RESOLVER_REGISTRY.get(name)
140+
if cls is None:
141+
raise ValueError(f'Unknown resolving source: {name}')
142+
return cls()

0 commit comments

Comments
 (0)