Skip to content

Commit 6f157f3

Browse files
Basic host identity manager
1 parent 5716ee0 commit 6f157f3

File tree

7 files changed

+363
-33
lines changed

7 files changed

+363
-33
lines changed

cmapi/cmapi_server/__main__.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,23 +17,25 @@
1717
# TODO: fix dispatcher choose logic because code executing in endpoints.py
1818
# while import process, this cause module logger misconfiguration
1919
from cmapi_server.logging_management import config_cmapi_server_logging
20+
config_cmapi_server_logging()
21+
2022
from tracing.sentry import maybe_init_sentry
2123
from tracing.traceparent_backend import TraceparentBackend
2224
from tracing.tracer import get_tracer
25+
from tracing.trace_tool import register_tracing_tools
2326

24-
config_cmapi_server_logging()
2527
from cmapi_server import helpers
26-
from cmapi_server.constants import CMAPI_CONF_PATH, DEFAULT_MCS_CONF_PATH
27-
from cmapi_server.controllers.dispatcher import dispatcher, jsonify_404, jsonify_error
28+
from cmapi_server.constants import DEFAULT_MCS_CONF_PATH, CMAPI_CONF_PATH
29+
from cmapi_server.controllers.dispatcher import dispatcher, jsonify_error, jsonify_404
2830
from cmapi_server.failover_agent import FailoverAgent
29-
from cmapi_server.invariant_checks import run_invariant_checks
3031
from cmapi_server.managers.application import AppManager
31-
from cmapi_server.managers.certificate import CertificateManager
32+
from cmapi_server.managers.host_identity import get_host_address_manager
3233
from cmapi_server.managers.process import MCSProcessManager
34+
from cmapi_server.managers.certificate import CertificateManager
35+
from cmapi_server.invariant_checks import run_invariant_checks
3336
from failover.node_monitor import NodeMonitor
3437
from mcs_node_control.models.dbrm_socket import SOCK_TIMEOUT, DBRMSocketHandler
3538
from mcs_node_control.models.node_config import NodeConfig
36-
from tracing.trace_tool import register_tracing_tools
3739

3840

3941
def worker(app):
@@ -159,6 +161,9 @@ def stop(self):
159161
logging.error('Invariant checks failed, exiting')
160162
sys.exit(1)
161163

164+
my_identity = get_host_address_manager().get_local_identity()
165+
logging.info('My identity: %s', my_identity)
166+
162167
app = cherrypy.tree.mount(root=None, config=CMAPI_CONF_PATH)
163168
root_config = {
164169
"request.dispatch": dispatcher,

cmapi/cmapi_server/exceptions.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,14 @@ class CEJError(CMAPIBasicError):
2828
"""
2929

3030

31+
class ResolutionError(CMAPIBasicError):
32+
"""Errors related to DNS resolution"""
33+
34+
35+
class ResolutionPolicyViolationError(CMAPIBasicError):
36+
"""Errors where results are rejected by the current resolving policy."""
37+
38+
3139
@contextmanager
3240
def exc_to_cmapi_error(prefix: Optional[str] = None) -> Iterator[None]:
3341
"""Context manager to standardize error wrapping into CMAPIBasicError.
Lines changed: 325 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,325 @@
1+
import hashlib
2+
import ipaddress
3+
import logging
4+
import socket
5+
import time
6+
from collections.abc import Iterable, Sequence
7+
from dataclasses import dataclass, field
8+
from functools import lru_cache
9+
from typing import Optional, Union
10+
11+
import dns.resolver
12+
import dns.reversename
13+
14+
from cmapi_server.exceptions import CMAPIBasicError, ResolutionError, ResolutionPolicyViolationError
15+
from cmapi_server.managers.network import NetworkManager
16+
17+
IPAddress = Union[ipaddress.IPv4Address, ipaddress.IPv6Address]
18+
19+
logger = logging.getLogger(__name__)
20+
21+
22+
@lru_cache(maxsize=1) # singleton
23+
def get_host_address_manager() -> 'HostAddressManager':
24+
return HostAddressManager()
25+
26+
27+
@dataclass(frozen=True)
28+
class ResolutionPolicy:
29+
"""Rules controlling how hosts are resolved, what is/isn't allowed (like IPv6)."""
30+
31+
allow_private_ips: bool = True
32+
allow_ipv6: bool = False
33+
require_hostname: bool = False
34+
35+
def validate_hostname(self, name: str) -> str:
36+
"""Validate and normalize a hostname to FQDN."""
37+
normalized = name.strip().lower()
38+
if not _is_fqdn(normalized):
39+
raise ResolutionPolicyViolationError(f'The name {name} is not a fully qualified domain name.')
40+
return normalized
41+
42+
def filter_addresses(self, addrs: Iterable[str]) -> list[IPAddress]:
43+
"""Filter out IP addresses that don't match the policy."""
44+
ips: list[IPAddress] = []
45+
logger.debug(
46+
'Filtering addresses: %s (allow_private_ips=%s, allow_ipv6=%s)',
47+
list(addrs), self.allow_private_ips, self.allow_ipv6
48+
)
49+
for addr in addrs:
50+
try:
51+
ip = ipaddress.ip_address(addr)
52+
except ValueError:
53+
logger.debug('Skip %s: not a valid IP literal', addr)
54+
continue
55+
56+
# Fixed rejections first
57+
if ip.is_loopback:
58+
logger.debug('Skip %s: loopback address', addr)
59+
continue
60+
if ip.is_link_local:
61+
logger.debug('Skip %s: link-local address', addr)
62+
continue
63+
if ip.is_multicast:
64+
logger.debug('Skip %s: multicast address', addr)
65+
continue
66+
if isinstance(ip, ipaddress.IPv4Address) and int(ip) == int(ipaddress.IPv4Address('255.255.255.255')):
67+
logger.debug('Skip %s: IPv4 broadcast address', addr)
68+
continue
69+
70+
# Policy conditionals
71+
if not self.allow_ipv6 and isinstance(ip, ipaddress.IPv6Address):
72+
logger.debug('Skip %s: IPv6 not allowed by policy', addr)
73+
continue
74+
if not ip.is_global and not ip.is_private:
75+
logger.debug('Skip %s: neither global nor private address', addr)
76+
continue
77+
if ip.is_private and not self.allow_private_ips:
78+
logger.debug('Skip %s: private address but private use disabled', addr)
79+
continue
80+
81+
ips.append(ip)
82+
logger.debug('Accept %s', addr)
83+
return ips
84+
85+
def order_addresses(self, ips: Sequence[IPAddress]) -> list[IPAddress]:
86+
"""Deterministic ordering: global first, then IPv4 before IPv6, then sort by numeric value."""
87+
def key(ip: IPAddress) -> tuple[int, int, int, int, int, int, int, int]:
88+
is_global = 0 if ip.is_global else 1
89+
is_ipv4 = 0 if isinstance(ip, ipaddress.IPv4Address) else 1
90+
return (
91+
is_global,
92+
is_ipv4,
93+
*list(ip.packed)
94+
)
95+
return sorted(ips, key=key)
96+
97+
98+
@dataclass
99+
class HostIdentity:
100+
"""Normalized result of host resolution (after policy is enforced)"""
101+
input: str
102+
addresses: list[str]
103+
names: list[str]
104+
primary_ip: str
105+
primary_name: Optional[str]
106+
unique_key: str # unique id of this host, will be used later for aliases
107+
observed_at: float = field(default_factory=lambda: time.time())
108+
109+
@staticmethod
110+
def from_policy(input: str, policy: ResolutionPolicy, ips: Sequence[IPAddress], names: Sequence[str]) -> 'HostIdentity':
111+
"""Construct a HostIdentity from filtered addresses and names using policy ordering and hashing rules."""
112+
if not ips:
113+
raise ResolutionPolicyViolationError('All resolved addresses were rejected by policy (loopback / link-local / multicast).')
114+
ordered = policy.order_addresses(ips)
115+
addresses = [str(ip) for ip in ordered]
116+
117+
# Calculate host unique key from its addresses (we prefer globals as more stable ones)
118+
globals_sorted = [str(ip) for ip in ordered if ip.is_global]
119+
privates_sorted = [str(ip) for ip in ordered if ip.is_private]
120+
basis = globals_sorted if globals_sorted else privates_sorted
121+
hasher = hashlib.sha256()
122+
for address_text in basis:
123+
hasher.update(address_text.encode('utf-8'))
124+
unique_key = hasher.hexdigest()
125+
126+
primary_ip = addresses[0]
127+
primary_name = names[0] if names else None
128+
return HostIdentity(
129+
input=input,
130+
addresses=addresses,
131+
names=list(names),
132+
primary_ip=primary_ip,
133+
primary_name=primary_name,
134+
unique_key=unique_key,
135+
)
136+
137+
138+
class HostAddressManager:
139+
"""In-memory resolver that performs DNS-only lookups, enforces policy, and caches results."""
140+
141+
def __init__(self, policy: Optional[ResolutionPolicy] = None) -> None:
142+
self._policy = policy if policy is not None else ResolutionPolicy()
143+
self._cache: dict[str, HostIdentity] = {}
144+
145+
@property
146+
def policy(self) -> ResolutionPolicy:
147+
return self._policy
148+
149+
def get_identity(self, target: str) -> HostIdentity:
150+
"""Resolve and normalize a hostname or IP under the current policy."""
151+
if target in self._cache:
152+
return self._cache[target]
153+
154+
target = target.strip()
155+
# If target is IP literal
156+
try:
157+
literal_ip = ipaddress.ip_address(target)
158+
except ValueError:
159+
literal_ip = None
160+
161+
if literal_ip is not None:
162+
# Use the literal IP only; names via PTR if available
163+
candidate_ips: set[str] = {str(literal_ip)}
164+
names_set: set[str] = set(self._reverse_dns_names(literal_ip))
165+
ips = self._policy.filter_addresses(candidate_ips)
166+
if self._policy.require_hostname and not names_set:
167+
logger.debug('Reject %s: no PTR hostname found and policy requires hostname', target)
168+
raise ResolutionPolicyViolationError('Policy requires a hostname for the input IP address (no PTR record found).')
169+
if not ips:
170+
raise ResolutionPolicyViolationError('Input IP address was rejected by policy')
171+
identity = HostIdentity.from_policy(target, self._policy, ips, sorted(names_set))
172+
self._cache[target] = identity
173+
return identity
174+
175+
# Target is hostname: if it's a valid FQDN, use DNS path; otherwise treat as local alias (e.g., localhost)
176+
try:
177+
fqdn = self._policy.validate_hostname(target)
178+
except ResolutionPolicyViolationError:
179+
fqdn = None
180+
if fqdn is None:
181+
# Non-FQDN (e.g., localhost). Seed with system resolver only.
182+
candidate_ips: set[str] = set(
183+
NetworkManager.resolve_hostname_to_ips(
184+
target,
185+
only_ipv4=not self._policy.allow_ipv6,
186+
exclude_loopback=False,
187+
)
188+
)
189+
# Collect names by PTR of candidate IPs and include original token
190+
names_set: set[str] = {target}
191+
for ip_text in list(candidate_ips):
192+
try:
193+
ip_obj = ipaddress.ip_address(ip_text)
194+
except ValueError:
195+
continue
196+
names_set.update(self._reverse_dns_names(ip_obj))
197+
ips = self._policy.filter_addresses(candidate_ips)
198+
if not ips:
199+
raise ResolutionPolicyViolationError('All resolved addresses were rejected by policy.')
200+
identity = HostIdentity.from_policy(target, self._policy, ips, sorted(names_set))
201+
self._cache[target] = identity
202+
return identity
203+
204+
# FQDN path: forward DNS, then PTR of accepted IPs; no further expansion
205+
addrs, _ = self._resolve_dns(fqdn)
206+
if not addrs:
207+
raise ResolutionError(f'Could not resolve {fqdn} to any IP addresses.')
208+
ips = self._policy.filter_addresses([str(ip) for ip in addrs])
209+
names_set: set[str] = {fqdn}
210+
ptr_match = False
211+
for ip in ips:
212+
ptrs = self._reverse_dns_names(ip)
213+
if fqdn in ptrs:
214+
ptr_match = True
215+
names_set.update(ptrs)
216+
if not ptr_match:
217+
raise ResolutionError(f'{fqdn} failed the DNS round-trip check — forward and reverse records do not match.')
218+
if not ips:
219+
raise ResolutionPolicyViolationError('All resolved addresses were rejected by policy.')
220+
identity = HostIdentity.from_policy(fqdn, self._policy, ips, sorted(names_set))
221+
self._cache[target] = identity
222+
return identity
223+
224+
def get_local_identity(self) -> HostIdentity:
225+
"""Return normalized identity of the current host."""
226+
name = socket.getfqdn().strip().lower()
227+
candidates: list[str] = []
228+
if _is_fqdn(name):
229+
candidates.append(name)
230+
try:
231+
local_ips = NetworkManager.get_current_node_ips(
232+
ignore_loopback=False,
233+
only_ipv4=not self._policy.allow_ipv6,
234+
)
235+
except CMAPIBasicError:
236+
logger.exception('Failed to get local IP addresses for identity resolution')
237+
local_ips = []
238+
candidates.extend(local_ips)
239+
seen: set[str] = set()
240+
for candidate in candidates:
241+
if candidate in seen:
242+
continue
243+
seen.add(candidate)
244+
try:
245+
return self.get_identity(candidate)
246+
except CMAPIBasicError:
247+
logger.exception('Local identity candidate failed resolution: %s', candidate)
248+
continue
249+
raise ResolutionPolicyViolationError('Could not determine any acceptable local IP addresses under current policy.')
250+
251+
def _resolve_dns(self, hostname: str) -> tuple[list[IPAddress], list[str]]:
252+
"""Resolve the given hostname using DNS and return (addresses, names)."""
253+
resolver = dns.resolver.Resolver(configure=True)
254+
255+
# IPv4
256+
a_records: list[IPAddress] = []
257+
try:
258+
for a_rdata in resolver.resolve(hostname, 'A', raise_on_no_answer=False):
259+
a_records.append(ipaddress.ip_address(a_rdata.to_text()))
260+
except Exception:
261+
logger.exception('A lookup failed for %s', hostname)
262+
263+
# IPv6
264+
aaaa_records: list[IPAddress] = []
265+
if self._policy.allow_ipv6:
266+
try:
267+
for aaaa_rdata in resolver.resolve(hostname, 'AAAA', raise_on_no_answer=False):
268+
aaaa_records.append(ipaddress.ip_address(aaaa_rdata.to_text()))
269+
except Exception:
270+
logger.exception('AAAA lookup failed for %s', hostname)
271+
272+
names = [hostname]
273+
return a_records + aaaa_records, names
274+
275+
def _reverse_dns_names(self, ip: IPAddress) -> list[str]:
276+
"""Fetch PTR names for an IP via DNS."""
277+
try:
278+
reverse_name = dns.reversename.from_address(str(ip))
279+
answer = dns.resolver.resolve(reverse_name, 'PTR', raise_on_no_answer=False)
280+
names: list[str] = []
281+
for ptr_rdata in answer:
282+
fqdn_name = str(ptr_rdata.target).rstrip('.').lower()
283+
if _is_fqdn(fqdn_name):
284+
names.append(fqdn_name)
285+
return names
286+
except Exception:
287+
logger.exception('PTR lookup failed for %s', ip)
288+
return []
289+
290+
def _contains_private(self, addrs: list[str]) -> bool:
291+
"""Return True if any resolvable address string is a private IP."""
292+
for addr in addrs:
293+
try:
294+
if ipaddress.ip_address(addr).is_private:
295+
return True
296+
except ValueError:
297+
continue
298+
return False
299+
300+
301+
def _is_fqdn(name: str) -> bool:
302+
"""Return True if the string is a valid FQDN (lower-cased, no trailing dot).
303+
304+
Rules for labels (per common DNS practice):
305+
- Each label is 1..63 characters.
306+
- Only ASCII letters, digits, and hyphen are allowed (LDH rule).
307+
- A label cannot start or end with a hyphen.
308+
- The name must contain at least one dot separating labels.
309+
"""
310+
if not name:
311+
return False
312+
if name.endswith('.'):
313+
return False
314+
if '.' not in name:
315+
return False
316+
labels = name.split('.')
317+
for label in labels:
318+
if not label or len(label) > 63:
319+
return False
320+
for ch in label:
321+
if not (ch.isalnum() or ch == '-'):
322+
return False
323+
if label[0] == '-' or label[-1] == '-':
324+
return False
325+
return True

0 commit comments

Comments
 (0)