diff --git a/cmapi/cmapi_server/__main__.py b/cmapi/cmapi_server/__main__.py index c61f63916..ca9a11b19 100644 --- a/cmapi/cmapi_server/__main__.py +++ b/cmapi/cmapi_server/__main__.py @@ -17,24 +17,26 @@ # TODO: fix dispatcher choose logic because code executing in endpoints.py # while import process, this cause module logger misconfiguration from cmapi_server.logging_management import config_cmapi_server_logging +config_cmapi_server_logging() + from tracing.sentry import maybe_init_sentry from tracing.traceparent_backend import TraceparentBackend from tracing.tracer import get_tracer +from tracing.trace_tool import register_tracing_tools -config_cmapi_server_logging() from cmapi_server import helpers -from cmapi_server.constants import CMAPI_CONF_PATH, DEFAULT_MCS_CONF_PATH -from cmapi_server.controllers.dispatcher import dispatcher, jsonify_404, jsonify_error +from cmapi_server.constants import DEFAULT_MCS_CONF_PATH, CMAPI_CONF_PATH +from cmapi_server.controllers.dispatcher import dispatcher, jsonify_error, jsonify_404 from cmapi_server.failover_agent import FailoverAgent -from cmapi_server.invariant_checks import run_invariant_checks from cmapi_server.managers.application import AppManager -from cmapi_server.managers.certificate import CertificateManager +from cmapi_server.managers.host_identity import get_host_address_manager from cmapi_server.managers.process import MCSProcessManager +from cmapi_server.managers.certificate import CertificateManager +from cmapi_server.invariant_checks import run_invariant_checks from failover.node_monitor import NodeMonitor from failover.config import Config from mcs_node_control.models.dbrm_socket import SOCK_TIMEOUT, DBRMSocketHandler from mcs_node_control.models.node_config import NodeConfig -from tracing.trace_tool import register_tracing_tools def worker(app): @@ -161,6 +163,9 @@ def stop(self): logging.error('Invariant checks failed, exiting') sys.exit(1) + my_identity = get_host_address_manager().get_local_identity() + logging.info('My identity: %s', my_identity) + app = cherrypy.tree.mount(root=None, config=CMAPI_CONF_PATH) root_config = { "request.dispatch": dispatcher, diff --git a/cmapi/cmapi_server/exceptions.py b/cmapi/cmapi_server/exceptions.py index 34ebb6702..c13e5272c 100644 --- a/cmapi/cmapi_server/exceptions.py +++ b/cmapi/cmapi_server/exceptions.py @@ -30,6 +30,14 @@ class CEJError(CMAPIBasicError): """ +class ResolutionError(CMAPIBasicError): + """Errors related to DNS resolution""" + + +class ResolutionPolicyViolationError(CMAPIBasicError): + """Errors where results are rejected by the current resolving policy.""" + + @contextmanager def exc_to_cmapi_error(prefix: Optional[str] = None) -> Iterator[None]: """Context manager to standardize error wrapping into CMAPIBasicError. diff --git a/cmapi/cmapi_server/managers/host_identity.py b/cmapi/cmapi_server/managers/host_identity.py new file mode 100644 index 000000000..b25b3afc9 --- /dev/null +++ b/cmapi/cmapi_server/managers/host_identity.py @@ -0,0 +1,438 @@ +import hashlib +import ipaddress +import logging +from collections.abc import Iterable, Sequence +from dataclasses import dataclass, field +from datetime import datetime, timezone +from functools import lru_cache +from typing import Optional, Union + +import dns.resolver +import dns.reversename + +from cmapi_server.exceptions import CMAPIBasicError, ResolutionError, ResolutionPolicyViolationError +from cmapi_server.managers.network import NetworkManager + +IPAddress = Union[ipaddress.IPv4Address, ipaddress.IPv6Address] + +logger = logging.getLogger(__name__) + + +@dataclass +class HostIdentity: + """Host's network identity, IP addrs and hostnames that are visible from other hosts""" + input: str + # The first, most important IP addr and host name (ordering is done by policy) + primary_ip: str + primary_name: Optional[str] # Name can be missing (but policy can make it required) + ips: list[str] # All IP addrs + names: list[str] # All host names (only those that are visible to other hosts) + unique_key: str # unique id of this host, will be used later for aliases + observed_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) + + @staticmethod + def from_policy(input: str, policy: 'ResolutionPolicy', ips: Sequence[IPAddress], names: Sequence[str]) -> 'HostIdentity': + """Construct a HostIdentity from filtered addresses and names using policy ordering.""" + if not ips: + raise ResolutionPolicyViolationError('All resolved addresses were rejected by policy.') + + ordered_ips = policy.order_addresses(ips) + # Calculate host unique key from its addresses (we prefer globals as more stable ones) + globals_ordered = [ip for ip in ordered_ips if ip.is_global] + privates_ordered = [ip for ip in ordered_ips if ip.is_private] + to_hash = globals_ordered if globals_ordered else privates_ordered + hasher = hashlib.sha256() + for ip in to_hash: + hasher.update(str(ip).encode('utf-8')) + unique_key = hasher.hexdigest() + + return HostIdentity( + input=input, + ips=[str(ip) for ip in ordered_ips], + names=list(names), + primary_ip=str(ordered_ips[0]), + primary_name=names[0] if names else None, + unique_key=unique_key, + ) + + @property + def effective_hostname(self) -> str: + """Hostname or primary IP if hostname is not set + + If policy requires host to have a hostname and it doesn't, we won't even get here + """ + return self.primary_name or self.primary_ip + + def __repr__(self) -> str: + parts: list[str] = [ + f'input={self.input!r}', + f'primary_ip={self.primary_ip!r}', + ] + if self.primary_name is not None: + parts.append(f'primary_name={self.primary_name!r}') + # Don't show addrs and hostnames if there is only one addr or hostname + if self.ips != [self.primary_ip]: + parts.append(f'ips={self.ips!r}') + if self.names != [self.primary_name]: + parts.append(f'names={self.names!r}') + parts.append(f'unique_key={self.unique_key[:4]}') + parts.append(f'observed_at={self.observed_at.replace(microsecond=0).isoformat()!r}') + return 'HostIdentity(' + ', '.join(parts) + ')' + + def __str__(self) -> str: + return repr(self) + + +@dataclass(frozen=True) +class ResolutionPolicy: + """Class to represent our requirements for the addresses, how are they resolved, + what must be filtered out, how addresses are ordered, etc. + + It's nice to separate these concerns from the resolving itself. + Also maybe in the future we'll make it configurable by users. + """ + + allow_private_ips: bool = True + allow_ipv6: bool = False + require_hostname: bool = False + + def filter_addresses(self, addrs: Iterable[str]) -> list[IPAddress]: + """Filter out IP addresses that don't match the policy.""" + ips: list[IPAddress] = [] + logger.debug( + 'Filtering addresses: %s (allow_private_ips=%s, allow_ipv6=%s)', + list(addrs), self.allow_private_ips, self.allow_ipv6 + ) + for addr in addrs: + ip = _ip_or_none(addr) + if ip is None: + logger.error('Skip %s: not a valid IP literal', addr) + continue + + # Fixed rejections first + if ip.is_loopback: + logger.debug('Skip %s: loopback address', addr) + continue + if ip.is_link_local: + logger.debug('Skip %s: link-local address', addr) + continue + if ip.is_multicast: + logger.debug('Skip %s: multicast address', addr) + continue + if isinstance(ip, ipaddress.IPv4Address) and int(ip) == int(ipaddress.IPv4Address('255.255.255.255')): + logger.debug('Skip %s: IPv4 broadcast address', addr) + continue + + # Policy conditionals + if not self.allow_ipv6 and isinstance(ip, ipaddress.IPv6Address): + logger.debug('Skip %s: IPv6 not allowed by policy', addr) + continue + if not ip.is_global and not ip.is_private: + logger.debug('Skip %s: neither global nor private address', addr) + continue + if ip.is_private and not self.allow_private_ips: + logger.debug('Skip %s: private address but private use disabled', addr) + continue + + ips.append(ip) + logger.debug('Accept %s', addr) + return ips + + def order_addresses(self, ips: Sequence[IPAddress]) -> list[IPAddress]: + """Order IPs deterministically to choose the primary IP.""" + def key(ip: IPAddress) -> tuple[int, int, int, int, int, int, int, int]: + is_global = 0 if ip.is_global else 1 + is_ipv4 = 0 if isinstance(ip, ipaddress.IPv4Address) else 1 + return ( + is_global, + is_ipv4, + *list(ip.packed) + ) + return sorted(ips, key=key) + + +class HostAddressManager: + """Calculates HostIdentity from passed hostname or IP address.""" + + def __init__(self, policy: Optional[ResolutionPolicy] = None) -> None: + self._policy = policy if policy is not None else ResolutionPolicy() + self._cache: dict[str, HostIdentity] = {} + + def get_identity(self, target: str) -> HostIdentity: + """Resolve and normalize a hostname or IP.""" + if target in self._cache: + return self._cache[target] + + target = target.strip() + + ip = _ip_or_none(target) + if ip is not None: + identity = self._get_identity_from_ip(ip, target) + else: + identity = self._get_identity_from_hostname(target) + + self._cache[target] = identity + return identity + + def get_local_identity(self) -> HostIdentity: + """Calculate HostIdentity for the current host.""" + try: + local_ips = NetworkManager.get_current_node_ips( + ignore_loopback=False, + only_ipv4=not self._policy.allow_ipv6, + ) + except: + logger.exception('Failed to get local IP addresses') + raise + + # Use the first IP that resolves into a normal identity that is not rejected by policy + for ip_text in set(local_ips): + try: + return self.get_identity(ip_text) + except CMAPIBasicError: + logger.debug('Local identity candidate %s failed resolution', ip_text) + continue + raise ResolutionPolicyViolationError('Could not determine any acceptable local IP addresses under current policy.') + + def check_hostname_rev_lookup(self, hostname: str) -> tuple[HostIdentity, bool]: + """Resolve hostname and check that at least one of its IPs resolves back to it. + + Returns identity and boolean, true means that roundtrip check was successful. + """ + identity = self.get_identity(hostname) + + normalized = hostname.strip().lower() + if not normalized: + return identity, False + + for ip_text in identity.ips: + ip = _ip_or_none(ip_text) + if ip is None: + logger.error('Invalid IP address: %s', ip_text) + continue + + try: + reverse_names = self._get_names_of_ip(ip) + except Exception: + logger.exception('Failed to get reverse names for %s', ip) + continue + + for rev_name in reverse_names: + # Count reverse names like "mcs1." as a match for "mcs" + # It is true because of search_domain/ndots (in /etc/resolv.conf) + # Resolvers respect these options and will add them to the hostname + if rev_name == normalized or rev_name.startswith(normalized + '.'): + logger.debug('Roundtrip check passed for %s: %s', hostname, ip) + return identity, True + + logger.warning('Roundtrip check failed for %s', hostname) + return identity, False + + def _get_identity_from_ip(self, ip: IPAddress, original_input: str) -> HostIdentity: + if not self._policy.filter_addresses([str(ip)]): + raise ResolutionPolicyViolationError('Input IP address was rejected by policy') + + names: list[str] = self._get_names_of_ip(ip) + if self._policy.require_hostname and not names: + logger.warning('Reject %s: no names found for this IP and policy requires hostname', original_input) + raise ResolutionPolicyViolationError('Policy requires a hostname for the input IP address (no PTR record found).') + + return HostIdentity.from_policy(original_input, self._policy, [ip], sorted(names)) + + def _get_identity_from_hostname(self, hostname: str) -> HostIdentity: + normalized = hostname.strip().lower() + if _is_fqdn(normalized): + return self._get_identity_from_fqdn(normalized) + else: + return self._get_identity_from_non_fqdn(hostname) + + def _get_identity_from_fqdn(self, fqdn: str) -> HostIdentity: + # Get IPs from hostname (via DNS), filter them, then get names from each IP + addrs = self._resolve_dns(fqdn) + if not addrs: + raise ResolutionError(f'Could not resolve {fqdn} to any IP addresses.') + filtered_ips = self._policy.filter_addresses([str(ip) for ip in addrs]) + + # Resolve each IP back to names and check if there is any that resolved back to passed fqdn + names: set[str] = {fqdn} + roundtrip_found = False + for ip in filtered_ips: + names_of_ip = self._get_names_of_ip(ip) + if fqdn in names_of_ip: + roundtrip_found = True + names.update(names_of_ip) + + if not roundtrip_found: + logger.warning( + 'FQDN %s failed DNS forward/reverse round-trip; IPs: %s', + fqdn, + [str(ip) for ip in filtered_ips], + ) + raise ResolutionError(f'{fqdn} failed the DNS round-trip check — forward and reverse records do not match.') + + if not filtered_ips: + raise ResolutionPolicyViolationError('All resolved addresses were rejected by policy.') + + return HostIdentity.from_policy(fqdn, self._policy, filtered_ips, sorted(names)) + + def _get_identity_from_non_fqdn(self, hostname: str) -> HostIdentity: + # Like FQDN version, but we know that nothing will resolve back to passed hostname + candidate_ips: set[str] = set( + NetworkManager.resolve_hostname_to_ips( + hostname, + only_ipv4=not self._policy.allow_ipv6, + exclude_loopback=False, + ) + ) + + ips = self._policy.filter_addresses(candidate_ips) + if not ips: + raise ResolutionPolicyViolationError('All resolved addresses were rejected by policy.') + + # Collect names of IPs + names: set[str] = set() + for ip_text in list(candidate_ips): + ip = _ip_or_none(ip_text) + if ip is None: + logger.error('Invalid IP address: %s', ip_text) + continue + names.update(self._get_names_of_ip(ip)) + + if self._policy.require_hostname and not names: + logger.error( + 'Non-FQDN name %s does not resolve to any valid hostname; candidate IPs: %s', + hostname, + sorted(candidate_ips), + ) + raise ResolutionPolicyViolationError('Policy requires a hostname for the input host, but DNS did not return any FQDN names.') + + return HostIdentity.from_policy(hostname, self._policy, ips, sorted(names)) + + def _resolve_dns(self, hostname: str) -> list[IPAddress]: + """Resolve the given hostname using DNS and return addresses.""" + ipv4_texts: list[str] = [] + ipv6_texts: list[str] = [] + try: + ipv4_texts = self._dns_resolve_ipv4(hostname) + except dns.resolver.NoAnswer: + logger.warning('IPv4 lookup returned no records for %s', hostname) + ipv4_texts = [] + except Exception: + logger.exception('IPv4 lookup unexpected failure for %s', hostname) + raise + + if self._policy.allow_ipv6: + try: + ipv6_texts = self._dns_resolve_ipv6(hostname) + except dns.resolver.NoAnswer: + logger.warning('IPv6 lookup returned no records for %s', hostname) + ipv6_texts = [] + except Exception: + logger.exception('IPv6 lookup unexpected failure for %s', hostname) + raise + + addrs: list[IPAddress] = [] + for ip_text in ipv4_texts + ipv6_texts: + ip = _ip_or_none(ip_text) + if ip is None: + logger.error('DNS returned invalid IP address %s for host name %s, skipping', ip_text, hostname) + continue + addrs.append(ip) + + return addrs + + def _get_names_of_ip(self, ip: IPAddress) -> list[str]: + """Fetch PTR names for an IP via DNS.""" + try: + return self._dns_reverse(str(ip)) + except dns.resolver.NoAnswer: + logger.warning('ip-to-name lookup returned no records for %s', ip) + return [] + except Exception: + logger.exception('ip-to-name lookup unexpected failure for %s', ip) + raise + + # DNS abstraction methods for easier mocking + def _dns_resolve_ipv4(self, hostname: str) -> list[str]: + resolver = dns.resolver.Resolver(configure=True) + results: list[str] = [] + for rdata in resolver.resolve(hostname, 'A', raise_on_no_answer=False): + try: + results.append(rdata.to_text()) + except Exception: + continue + return results + + def _dns_resolve_ipv6(self, hostname: str) -> list[str]: + resolver = dns.resolver.Resolver(configure=True) + results: list[str] = [] + for rdata in resolver.resolve(hostname, 'AAAA', raise_on_no_answer=False): + try: + results.append(rdata.to_text()) + except Exception: + continue + return results + + def _dns_reverse(self, ip_text: str) -> list[str]: + reverse_name = dns.reversename.from_address(ip_text) + answer = dns.resolver.resolve(reverse_name, 'PTR', raise_on_no_answer=False) + names: list[str] = [] + for ptr_rdata in answer: + try: + name = str(ptr_rdata.target).rstrip('.').lower() + if name: + names.append(name) + except Exception: + continue + return names + + def _contains_private(self, addrs: list[str]) -> bool: + """Return True if any resolvable address string is a private IP.""" + for addr in addrs: + ip = _ip_or_none(addr) + if ip is None: + continue + if ip.is_private: + return True + return False + + +@lru_cache(maxsize=1) # singleton +def get_host_address_manager() -> 'HostAddressManager': + return HostAddressManager() + + +def _ip_or_none(val: str) -> Optional[IPAddress]: + try: + return ipaddress.ip_address(val) + except ValueError: + return None + +def _is_ip_address(val: str) -> bool: + return _ip_or_none(val) is not None + +def _is_fqdn(name: str) -> bool: + """Return True if the string is a valid FQDN (lower-cased, no trailing dot). + + Rules for labels (per common DNS practice): + - Each label is 1..63 characters. + - Only ASCII letters, digits, and hyphen are allowed (LDH rule). + - A label cannot start or end with a hyphen. + - The name must contain at least one dot separating labels. + """ + if not name: + return False + if name.endswith('.'): + return False + if '.' not in name: + return False + labels = name.split('.') + for label in labels: + if not label or len(label) > 63: + return False + for ch in label: + if not (ch.isalnum() or ch == '-'): + return False + if label[0] == '-' or label[-1] == '-': + return False + return True diff --git a/cmapi/cmapi_server/managers/network.py b/cmapi/cmapi_server/managers/network.py index 3d43ee195..65cd18c5d 100644 --- a/cmapi/cmapi_server/managers/network.py +++ b/cmapi/cmapi_server/managers/network.py @@ -3,9 +3,8 @@ import logging import socket import struct -from dataclasses import dataclass from ipaddress import ip_address -from typing import List, Optional, cast +from typing import Optional from cmapi_server.exceptions import CMAPIBasicError @@ -250,71 +249,3 @@ def is_only_loopback_hostname(cls, hostname: str) -> bool: if not ip_address(ip).is_loopback: return False return True - - @classmethod - def resolve_ip_and_hostname(cls, input_str: str) -> tuple[str, Optional[str]]: - """Resolve input string to an (IP, hostname) pair. - - :param input_str: Input which may be an IP address or a hostname - :type input_str: str - :return: A tuple containing (ip, hostname) - :rtype: tuple[str, str] - :raises CMAPIBasicError: if hostname resolution yields no IPs - """ - ip: str = '' - hostname: Optional[str] = None - - if cls.is_ip(input_str): - ip = input_str - hostname = cls.get_hostname(input_str) - else: - hostname = input_str - ip_list = cls.resolve_hostname_to_ips( - input_str, - exclude_loopback=not cls.is_only_loopback_hostname(input_str) - ) - if not ip_list: - raise CMAPIBasicError(f'No IPs found for {hostname!r}') - ip = ip_list[0] - return ip, hostname - - @classmethod - def validate_hostname_fwd_rev(cls, hostname: str) -> None: - """Validate forward and reverse DNS for a hostname. - - Checks that hostname resolves to one or more usable IPs and that at - least one of those IPs reverse-resolves back to the provided hostname - (either an exact match or an FQDN starting with the hostname are accepted). - - :raises CMAPIBasicError: if validation fails - """ - exclude_loopback = not cls.is_only_loopback_hostname(hostname) - ips = cls.resolve_hostname_to_ips( - hostname, - only_ipv4=True, - exclude_loopback=exclude_loopback, - ) - - if not ips: - raise CMAPIBasicError( - f"Hostname {hostname!r} did not resolve to any usable IPs. " - "Please fix DNS or add the host by IP." - ) - - wanted = hostname.rstrip('.').lower() - for ip in ips: - rev_names = cls.get_hostnames_by_ip(ip) - for rev in rev_names: - rev_norm = rev.rstrip('.').lower() - # Accept exact match ("db1" == "db1") or FQDN starting with the short hostname - # e.g. user provided "db1" and PTR returns "db1.example.com" - if rev_norm == wanted or rev_norm.startswith(wanted + '.'): - return - - raise CMAPIBasicError( - 'Forward/reverse DNS check failed: ' - f"hostname {hostname!r} resolved to {ips}, but none of these IPs " - f"reverse-resolve back to {hostname!r}. Consider adding the host by IP, " - 'or fix DNS so that at least one IP has a PTR/record mapping back to ' - 'the provided hostname.' - ) diff --git a/cmapi/cmapi_server/node_manipulation.py b/cmapi/cmapi_server/node_manipulation.py index 10e1819f9..9ed9af2ca 100644 --- a/cmapi/cmapi_server/node_manipulation.py +++ b/cmapi/cmapi_server/node_manipulation.py @@ -7,15 +7,12 @@ import logging import os import shutil -import socket import subprocess import time from typing import Optional import requests from lxml import etree -from mcs_node_control.models.node_config import NodeConfig -from tracing.traced_session import get_traced_session from cmapi_server import helpers from cmapi_server.constants import ( @@ -26,9 +23,12 @@ LOCALHOSTS, MCS_DATA_PATH, ) +from cmapi_server.exceptions import CMAPIBasicError from cmapi_server.managers.application import AppStatefulConfig +from cmapi_server.managers.host_identity import get_host_address_manager from cmapi_server.managers.network import NetworkManager - +from mcs_node_control.models.node_config import NodeConfig +from tracing.traced_session import get_traced_session PMS_NODE_PORT = '8620' EXEMGR_NODE_PORT = '8601' @@ -101,14 +101,27 @@ def add_node( c_root = node_config.get_current_config_root(input_config_filename) logging.info('Adding node %s', node) + + addr_mgr = get_host_address_manager() + host_identity = addr_mgr.get_identity(node) + logging.debug('Resolved %s to %s', node, host_identity) + # If a hostname (not IP) is provided, ensure fwd/rev DNS consistency. # Skip validation for localhost aliases to preserve legacy single-node flows. if not NetworkManager.is_ip(node) and not NetworkManager.is_only_loopback_hostname(node): - NetworkManager.validate_hostname_fwd_rev(node) + _, ok = addr_mgr.check_hostname_rev_lookup(node) + if not ok: + raise CMAPIBasicError( + f'''Forward/reverse DNS check failed: + hostname {node!r} resolved to {host_identity.ips}, but none of these IPs + reverse-resolve back to {node!r}. Consider adding the host by IP, + or fix DNS so that at least one IP has a PTR/record mapping back to + the provided hostname.''' + ) try: if not _replace_localhost(c_root, node): - ip4, _ = NetworkManager.resolve_ip_and_hostname(node) + ip4 = host_identity.primary_ip pm_num = _add_node_to_PMS(c_root, ip4) if not read_replica: @@ -179,7 +192,8 @@ def remove_node( try: active_nodes = helpers.get_active_nodes(input_config_filename) - ip4, _ = NetworkManager.resolve_ip_and_hostname(node) + host_identity = get_host_address_manager().get_identity(node) + ip4 = host_identity.primary_ip if len(active_nodes) > 1: pm_num = _remove_node_from_PMS(c_root, ip4) @@ -284,7 +298,11 @@ def add_dbroot(input_config_filename = None, output_config_filename = None, host else: c_root = node_config.get_current_config_root(config_filename = input_config_filename) - ip4, _ = NetworkManager.resolve_ip_and_hostname(host) if host else (None, None) + if host: + host_identity = get_host_address_manager().get_identity(host) + ip4 = host_identity.primary_ip + else: + ip4 = None try: ret = _add_dbroot(c_root, ip4) except Exception as e: @@ -382,11 +400,9 @@ def _add_active_node(root, node): we replace it with the IP address. ''' - ip4, hostname = NetworkManager.resolve_ip_and_hostname(node) - # If reverse lookup failed, hostname may be None. Use the IP as a - # fallback so removal by hostname also works consistently. - if hostname is None: - hostname = ip4 + host_identity = get_host_address_manager().get_identity(node) + ip4 = host_identity.primary_ip + hostname = host_identity.effective_hostname # Remove both hostname and IP form before adding IP to avoid checking if # some of them is already there. Then we add by IP @@ -419,11 +435,9 @@ def _remove_node(root, node): remove node from DesiredNodes, InactiveNodes, ActiveNodes ''' # Remove both hostname and IPv4 forms - ip4, hostname = NetworkManager.resolve_ip_and_hostname(node) - # If reverse lookup failed, normalize hostname to ip so we always try - # removing both variants from lists. - if hostname is None: - hostname = ip4 + host_identity = get_host_address_manager().get_identity(node) + ip4 = host_identity.primary_ip + hostname = host_identity.effective_hostname for lst in ( root.find("./DesiredNodes"), root.find("./InactiveNodes"), @@ -436,9 +450,9 @@ def _remove_node(root, node): # This moves a node from ActiveNodes to InactiveNodes def _deactivate_node(root, node): """Move node from ActiveNodes to InactiveNodes. Store as IPv4.""" - ip4, hostname = NetworkManager.resolve_ip_and_hostname(node) - if hostname is None: - hostname = ip4 + host_identity = get_host_address_manager().get_identity(node) + ip4 = host_identity.primary_ip + hostname = host_identity.effective_hostname active_nodes = root.find("./ActiveNodes") __remove_helper(active_nodes, hostname) @@ -1035,7 +1049,9 @@ def _add_Module_entries(root, node: str) -> None: # XXXPAT: No guarantee these are the values used in the rest of the system. # TODO: what should we do with complicated network configs where node has # several ips and\or several hostnames - ip4, hostname = NetworkManager.resolve_ip_and_hostname(node) + host_identity = get_host_address_manager().get_identity(node) + ip4 = host_identity.primary_ip + hostname = host_identity.primary_name if hostname is None: logging.warning(f'Could not resolve hostname for {node}, using IP address as hostname') hostname = ip4 @@ -1088,7 +1104,8 @@ def _add_WES(root, pm_num, node): `node` may be a hostname or an IP; we normalize to IPv4 to avoid mismatches when comparing against ModuleIPAddr entries. """ - ip4, _hostname = NetworkManager.resolve_ip_and_hostname(node) + host_identity = get_host_address_manager().get_identity(node) + ip4 = host_identity.primary_ip wes_node = etree.SubElement(root, f"pm{pm_num}_WriteEngineServer") etree.SubElement(wes_node, "IPAddr").text = ip4 etree.SubElement(wes_node, "Port").text = "8630" @@ -1205,16 +1222,8 @@ def _replace_localhost(root: etree.Element, node: str) -> bool: ) return False - # TODO use NetworkManager here - # getaddrinfo returns list of 5-tuples (..., sockaddr) - # use sockaddr to retrieve ip, sockaddr = (address, port) for AF_INET - ipaddr = socket.getaddrinfo(node, 8640, family=socket.AF_INET)[0][-1][0] - # signifies that node is an IP addr already - if ipaddr == node: - # use the primary hostname if given an ip addr - hostname = socket.gethostbyaddr(ipaddr)[0] - else: - hostname = node # use whatever name they gave us + host_identity = get_host_address_manager().get_identity(node) + ipaddr, hostname = host_identity.primary_ip, host_identity.effective_hostname logging.info( f'add_node(): replacing 127.0.0.1/localhost with {ipaddr}/{hostname} ' f'as this node\'s name. Be sure {hostname} resolves to {ipaddr} on ' @@ -1235,11 +1244,11 @@ def _replace_localhost(root: etree.Element, node: str) -> bool: if 'ModuleIPAddr' in n.tag: n.text = ipaddr - logging.info(f"Replaced %s (was %s) with IP %s", path, old_val, ipaddr) + logging.info("Replaced %s (was %s) with IP %s", path, old_val, ipaddr) continue if 'ModuleHostName' in n.tag: n.text = hostname - logging.info(f"Replaced %s (was %s) with hostname %s", path, old_val, hostname) + logging.info("Replaced %s (was %s) with hostname %s", path, old_val, hostname) continue # Generic fields: replace localhost IPs with ipaddr, hostnames with hostname @@ -1248,15 +1257,15 @@ def _replace_localhost(root: etree.Element, node: str) -> bool: new_val = ipaddr if is_local_ip else hostname if is_local_ip: new_val = ipaddr - logging.info(f"Replaced %s (was %s) with IP %s", path, old_val, new_val) + logging.info("Replaced %s (was %s) with IP %s", path, old_val, new_val) else: new_val = hostname - logging.info(f"Replaced %s (was %s) with hostname %s", path, old_val, new_val) + logging.info("Replaced %s (was %s) with hostname %s", path, old_val, new_val) n.text = new_val old_controller = controller_host.text controller_host.text = hostname # keep controllernode as fqdn - logging.info(f"Replaced %s (was %s) with hostname %s", './DBRM_Controller/IPAddr', old_controller, hostname) + logging.info("Replaced %s (was %s) with hostname %s", './DBRM_Controller/IPAddr', old_controller, hostname) return True diff --git a/cmapi/cmapi_server/test/mock_resolution.py b/cmapi/cmapi_server/test/mock_resolution.py index 026c7dc3f..ed648eeee 100644 --- a/cmapi/cmapi_server/test/mock_resolution.py +++ b/cmapi/cmapi_server/test/mock_resolution.py @@ -28,7 +28,7 @@ class MockResolutionBuilder: def __init__(self): # Forward: hostname -> ip self._forward: Dict[str, str] = {} - # Reverse: ip -> (primary_hostname, alias_list) + # Reverse: ip -> (primary_name, alias_list) self._reverse: Dict[str, Tuple[str, List[str]]] = {} # Defaults used when no explicit mapping is provided self._default_ip: Optional[str] = None @@ -73,6 +73,79 @@ def set_default(self, ip: str, hostname: str): self._default_hostname = hostname return self + def build(self): + + @contextmanager + def _ctx(): + patches = [ + # Patch socket-level resolvers (NetworkManager uses these under the hood) + patch('socket.getaddrinfo', side_effect=self._fake_getaddrinfo), + patch('socket.gethostbyname', side_effect=self._fake_gethostbyname), + patch('socket.gethostbyaddr', side_effect=self._fake_gethostbyaddr), + # Patch local identity to be synthetic; avoid real system calls + patch('socket.gethostname', return_value=CUR_HOST_HOSTNAME), + patch('socket.getfqdn', return_value=CUR_HOST_HOSTNAME), + # Patch NetworkManager local IP discovery (it uses psutil or system libs, + # proper mocking would be too complex) + patch('cmapi_server.managers.network.NetworkManager.get_current_node_ips', return_value=[CUR_HOST_IP, DEFAULT_LOCALHOST_IP]), + # Patch HostAddressManager DNS abstraction methods + patch('cmapi_server.managers.host_identity.HostAddressManager._dns_resolve_ipv4', + side_effect=self._fake_dns_resolve_ipv4), + patch('cmapi_server.managers.host_identity.HostAddressManager._dns_resolve_ipv6', + side_effect=self._fake_dns_resolve_ipv6), + patch('cmapi_server.managers.host_identity.HostAddressManager._dns_reverse', + side_effect=self._fake_dns_reverse), + ] + with ExitStack() as stack: + for p in patches: + stack.enter_context(p) + yield + + return _ctx() + + def _fake_getaddrinfo(self, host, port, family=socket.AF_UNSPEC, type=0, proto=0, flags=0): + # Only handle AF_INET calls; otherwise, simulate failure + if family not in (socket.AF_UNSPEC, socket.AF_INET): + raise socket.gaierror + # For localhost, return loopback first and include CUR_HOST_IP as secondary + if host == DEFAULT_LOCALHOST_HOSTNAME: + return [ + (socket.AF_INET, socket.SOCK_STREAM, 6, '', (DEFAULT_LOCALHOST_IP, port)), + (socket.AF_INET, socket.SOCK_STREAM, 6, '', (CUR_HOST_IP, port)), + ] + ip, _ = self._resolve_forward(host) + return [(socket.AF_INET, socket.SOCK_STREAM, 6, '', (ip, port))] + + def _fake_gethostbyname(self, name: str) -> str: + ip, _ = self._resolve_forward(name) + return ip + + def _fake_gethostbyaddr(self, addr: str): + # If no reverse record was set, simulate reverse lookup failure + if addr not in self._reverse: + raise socket.herror + primary, aliases = self._reverse[addr] + return (primary, aliases, [addr]) + + # HostIdentityManager DNS abstraction mocks + def _fake_dns_resolve_ipv4(self, hostname: str) -> List[str]: + # Return mapped IPv4 for provided hostname, or empty list if unknown + ip = self._forward.get(hostname) + return [ip] if ip else [] + + def _fake_dns_resolve_ipv6(self, hostname: str) -> List[str]: + # Keep IPv6 disabled by default in tests for determinism + return [] + + def _fake_dns_reverse(self, ip_text: str) -> List[str]: + # Return PTR names (primary + aliases) from reverse map, lowercase + rec = self._reverse.get(ip_text) + if not rec: + return [] + primary, aliases = rec + names = [primary, *aliases] + return [n.rstrip('.').lower() for n in names if n] + def _resolve_forward(self, host: str) -> Tuple[str, str]: """Resolve hostname or IP to (ip, hostname) using mappings/defaults.""" # If input looks like an IP, return it with reverse or default hostname @@ -101,50 +174,6 @@ def _resolve_forward(self, host: str) -> Tuple[str, str]: # As a last resort, echo back (host, host) return host, host - def build(self): - - def _fake_getaddrinfo(host, port, family=socket.AF_UNSPEC, type=0, proto=0, flags=0): - # Only handle AF_INET calls; otherwise, simulate failure - if family not in (socket.AF_UNSPEC, socket.AF_INET): - raise socket.gaierror - # For localhost, return loopback first and include CUR_HOST_IP as secondary - if host == DEFAULT_LOCALHOST_HOSTNAME: - return [ - (socket.AF_INET, socket.SOCK_STREAM, 6, '', (DEFAULT_LOCALHOST_IP, port)), - (socket.AF_INET, socket.SOCK_STREAM, 6, '', (CUR_HOST_IP, port)), - ] - ip, _ = self._resolve_forward(host) - return [(socket.AF_INET, socket.SOCK_STREAM, 6, '', (ip, port))] - - def _fake_gethostbyname(name: str) -> str: - ip, _ = self._resolve_forward(name) - return ip - - def _fake_gethostbyaddr(addr: str): - # If no reverse record was set, simulate reverse lookup failure - if addr not in self._reverse: - raise socket.herror - primary, aliases = self._reverse[addr] - return (primary, aliases, [addr]) - - @contextmanager - def _ctx(): - patches = [ - # Patch socket-level resolvers (NetworkManager uses these under the hood) - patch('socket.getaddrinfo', side_effect=_fake_getaddrinfo), - patch('socket.gethostbyname', side_effect=_fake_gethostbyname), - patch('socket.gethostbyaddr', side_effect=_fake_gethostbyaddr), - # Patch local identity to be synthetic; avoid real system calls - patch('socket.gethostname', return_value=CUR_HOST_HOSTNAME), - patch('socket.getfqdn', return_value=CUR_HOST_HOSTNAME), - ] - with ExitStack() as stack: - for p in patches: - stack.enter_context(p) - yield - - return _ctx() - def simple_resolution_mock(hostname: str, ip: str): """Return a context manager for simple name/IP resolution mocking. diff --git a/cmapi/cmapi_server/test/test_cluster.py b/cmapi/cmapi_server/test/test_cluster.py index 364461b55..437cab47a 100644 --- a/cmapi/cmapi_server/test/test_cluster.py +++ b/cmapi/cmapi_server/test/test_cluster.py @@ -3,6 +3,7 @@ import socket import subprocess from shutil import copyfile +from unittest.mock import patch import requests @@ -24,6 +25,12 @@ class BaseClusterTestCase(BaseServerTestCase): @classmethod def setUpClass(cls) -> None: copyfile(MCS_CONFIG_FILEPATH, COPY_MCS_CONFIG_FILEPATH) + # Disable real DNS lookups for rev lookup checks, this isn't focus of these tests + cls._rev_check_patcher = patch( + 'cmapi_server.managers.host_identity.HostAddressManager.check_hostname_rev_lookup', + new=lambda manager, hostname: (manager.get_identity(hostname), True), + ) + cls._rev_check_patcher.start() return super().setUpClass() @classmethod @@ -32,6 +39,7 @@ def tearDownClass(cls) -> None: os.remove(os.path.abspath(COPY_MCS_CONFIG_FILEPATH)) MCSProcessManager.stop_node(is_primary=True, use_sudo=False) MCSProcessManager.start_node(is_primary=True, use_sudo=False) + cls._rev_check_patcher.stop() return super().tearDownClass() def setUp(self) -> None: diff --git a/cmapi/cmapi_server/test/test_failover_agent.py b/cmapi/cmapi_server/test/test_failover_agent.py index 9c9bb03e0..a4248e01d 100644 --- a/cmapi/cmapi_server/test/test_failover_agent.py +++ b/cmapi/cmapi_server/test/test_failover_agent.py @@ -1,13 +1,10 @@ import logging -from mcs_node_control.models.node_config import NodeConfig - from cmapi_server.failover_agent import FailoverAgent -from cmapi_server.managers.network import NetworkManager from cmapi_server.node_manipulation import add_node, remove_node -from cmapi_server.test.mock_resolution import simple_resolution_mock, make_local_resolution_builder +from cmapi_server.test.mock_resolution import make_local_resolution_builder from cmapi_server.test.unittest_global import BaseNodeManipTestCase, tmp_mcs_config_filename - +from mcs_node_control.models.node_config import NodeConfig logging.basicConfig(level='DEBUG') diff --git a/cmapi/mcs_node_control/models/node_config.py b/cmapi/mcs_node_control/models/node_config.py index 106f89644..c5a487503 100644 --- a/cmapi/mcs_node_control/models/node_config.py +++ b/cmapi/mcs_node_control/models/node_config.py @@ -4,12 +4,12 @@ import pwd import re import socket +from collections.abc import Iterator from contextlib import contextmanager from os import chown, mkdir, replace from pathlib import Path from shutil import copyfile from typing import Optional -from collections.abc import Iterator from xml.dom import minidom # to pick up pretty printing functionality from lxml import etree @@ -17,8 +17,10 @@ from cmapi_server.constants import ( DEFAULT_MCS_CONF_PATH, DEFAULT_SM_CONF_PATH, + LOCALHOSTS, MCS_MODULE_FILE_PATH, ) +from cmapi_server.managers.host_identity import get_host_address_manager # from cmapi_server.managers.process import MCSProcessManager from mcs_node_control.models.misc import get_dbroots_list, read_module_id @@ -433,7 +435,21 @@ def is_primary_node(self, root=None): root = self.get_current_config_root() primary_address = self.get_dbrm_conn_info(root)['IPAddr'] - return primary_address in self.get_network_addresses_and_names() + + local_identity = get_host_address_manager().get_local_identity() + candidates = set(local_identity.ips) + candidates.update(local_identity.names) + + # HostIdentity ignores loopback/non-DNS-visible names, so add them explicitly + if self.is_single_node(root): + candidates.update(LOCALHOSTS) + + is_primary = primary_address in candidates + module_logger.debug( + 'is_primary: %s, primary_address: %s, local_identity: %s', + is_primary, primary_address, local_identity, + ) + return is_primary def is_single_node(self, root=None): diff --git a/cmapi/pyproject.toml b/cmapi/pyproject.toml index df5e56d23..d02cb9f80 100644 --- a/cmapi/pyproject.toml +++ b/cmapi/pyproject.toml @@ -1,6 +1,15 @@ [tool.ruff] line-length = 100 target-version = "py39" +# Exclude cache and temporary directories +exclude = [ + "__pycache__", +] + +[tool.ruff.format] +quote-style = "single" + +[tool.ruff.lint] # Enable common rule sets select = [ "E", # pycodestyle errors @@ -11,17 +20,8 @@ select = [ "N", # pep8-naming: naming conventions "Q", # flake8-quotes: enforce quote style ] - ignore = [] -# Exclude cache and temporary directories -exclude = [ - "__pycache__", -] - -[tool.ruff.format] -quote-style = "single" - [tool.ruff.lint.isort] known-first-party = ["cmapi_server", "failover", "mcs_node_control", "tracing"] force-single-line = false diff --git a/cmapi/requirements.in b/cmapi/requirements.in index 4c6acc990..d20040e21 100644 --- a/cmapi/requirements.in +++ b/cmapi/requirements.in @@ -20,4 +20,5 @@ pydantic==2.11.7 sentry-sdk==2.34.1 # Invariant checks mr_kot==0.9.2 -mr_kot_fs_validators==0.2.0 \ No newline at end of file +mr_kot_fs_validators==0.2.0 +dnspython==2.7.0 \ No newline at end of file diff --git a/cmapi/requirements.txt b/cmapi/requirements.txt index dcd444069..775fd6490 100644 --- a/cmapi/requirements.txt +++ b/cmapi/requirements.txt @@ -4,23 +4,6 @@ # # dev_tools/piptools.sh compile-all # -aiohttp==3.11.16 -awscli==1.38.28 -CherryPy==18.10.0 -cryptography==43.0.3 -distro==1.9.0 -furl==2.1.4 -gsutil==5.33 -lxml==5.3.2 -psutil==7.0.0 -pyotp==2.9.0 -requests==2.32.3 -# required for CherryPy RoutesDispatcher, -# but CherryPy itself has no such a dependency -Routes==2.5.1 -typer==0.15.2 - -# indirect dependencies aiohappyeyeballs==2.6.1 # via aiohttp aiohttp==3.11.16 @@ -75,6 +58,8 @@ cryptography==43.0.3 # pyopenssl distro==1.9.0 # via -r requirements.in +dnspython==2.7.0 + # via -r requirements.in docutils==0.16 # via awscli fasteners==0.20 diff --git a/cmapi/tracing/trace_tool.py b/cmapi/tracing/trace_tool.py index 83acae0fa..5fbb868bb 100644 --- a/cmapi/tracing/trace_tool.py +++ b/cmapi/tracing/trace_tool.py @@ -98,7 +98,6 @@ def _record_incoming_json_preview(req) -> dict[str, Any]: parsed_json = getattr(req, 'json', None) if parsed_json is None: - logger.debug('request.json is not available') return attrs normalized = json.dumps(parsed_json, ensure_ascii=False, sort_keys=True) if len(normalized) > _PREVIEW_MAX_CHARS: