Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 19 additions & 3 deletions pymongo/asynchronous/mongo_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -808,6 +808,7 @@ def __init__(
fqdn = None
srv_service_name = keyword_opts.get("srvservicename")
srv_max_hosts = keyword_opts.get("srvmaxhosts")
srv_allowed_hosts_suffix = keyword_opts.get("srvallowedhostssuffix")
if len([h for h in self._host if "/" in h]) > 1:
raise ConfigurationError("host must not contain multiple MongoDB URIs")
for entity in self._host:
Expand Down Expand Up @@ -858,6 +859,8 @@ def __init__(
srv_service_name = opts.get("srvServiceName", common.SRV_SERVICE_NAME)

srv_max_hosts = srv_max_hosts or opts.get("srvmaxhosts")
if srv_allowed_hosts_suffix is None:
srv_allowed_hosts_suffix = opts.get("srvallowedhostssuffix")
opts = self._normalize_and_validate_options(opts, self._seeds)

# Username and password passed as kwargs override user info in URI.
Expand Down Expand Up @@ -895,7 +898,9 @@ def __init__(

self._retry_policy = _RetryPolicy(attempts=self._options.max_adaptive_retries)

self._init_based_on_options(self._seeds, srv_max_hosts, srv_service_name)
self._init_based_on_options(
self._seeds, srv_max_hosts, srv_service_name, srv_allowed_hosts_suffix
)

self._opened = False
self._closed = False
Expand All @@ -913,6 +918,7 @@ async def _resolve_srv(self) -> None:
opts = common._CaseInsensitiveDictionary()
srv_service_name = keyword_opts.get("srvservicename")
srv_max_hosts = keyword_opts.get("srvmaxhosts")
srv_allowed_hosts_suffix = keyword_opts.get("srvallowedhostssuffix")
for entity in self._host:
# A hostname can only include a-z, 0-9, '-' and '.'. If we find a '/'
# it must be a URI,
Expand All @@ -933,6 +939,7 @@ async def _resolve_srv(self) -> None:
connect_timeout=timeout,
srv_service_name=srv_service_name,
srv_max_hosts=srv_max_hosts,
srv_allowed_hosts_suffix=srv_allowed_hosts_suffix,
)
seeds.update(res["nodelist"])
opts = res["options"]
Expand Down Expand Up @@ -965,6 +972,8 @@ async def _resolve_srv(self) -> None:
srv_service_name = opts.get("srvServiceName", common.SRV_SERVICE_NAME)

srv_max_hosts = srv_max_hosts or opts.get("srvmaxhosts")
if srv_allowed_hosts_suffix is None:
srv_allowed_hosts_suffix = opts.get("srvAllowedHostsSuffix")
opts = self._normalize_and_validate_options(opts, seeds)

# Username and password passed as kwargs override user info in URI.
Expand All @@ -974,10 +983,16 @@ async def _resolve_srv(self) -> None:
username, password, self._resolve_srv_info["dbase"], opts, _IS_SYNC
)

self._init_based_on_options(seeds, srv_max_hosts, srv_service_name)
self._init_based_on_options(
seeds, srv_max_hosts, srv_service_name, srv_allowed_hosts_suffix
)

def _init_based_on_options(
self, seeds: Collection[tuple[str, int]], srv_max_hosts: Any, srv_service_name: Any
self,
seeds: Collection[tuple[str, int]],
srv_max_hosts: Any,
srv_service_name: Any,
srv_allowed_hosts_suffix: Any,
) -> None:
self._event_listeners = self._options.pool_options._event_listeners
self._topology_settings = TopologySettings(
Expand All @@ -996,6 +1011,7 @@ def _init_based_on_options(
load_balanced=self._options.load_balanced,
srv_service_name=srv_service_name,
srv_max_hosts=srv_max_hosts,
srv_allowed_hosts_suffix=srv_allowed_hosts_suffix,
server_monitoring_mode=self._options.server_monitoring_mode,
topology_id=self._topology_settings._topology_id if self._topology_settings else None,
)
Expand Down
1 change: 1 addition & 0 deletions pymongo/asynchronous/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,7 @@ async def _get_seedlist(self) -> Optional[list[tuple[str, Any]]]:
self._fqdn,
self._settings.pool_options.connect_timeout,
self._settings.srv_service_name,
srv_allowed_hosts_suffix=self._settings.srv_allowed_hosts_suffix,
)
seedlist, ttl = await resolver.get_hosts_and_min_ttl()
if len(seedlist) == 0:
Expand Down
7 changes: 7 additions & 0 deletions pymongo/asynchronous/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def __init__(
load_balanced: Optional[bool] = None,
srv_service_name: str = common.SRV_SERVICE_NAME,
srv_max_hosts: int = 0,
srv_allowed_hosts_suffix: Optional[str] = None,
server_monitoring_mode: str = common.SERVER_MONITORING_MODE,
topology_id: Optional[ObjectId] = None,
):
Expand Down Expand Up @@ -78,6 +79,7 @@ def __init__(
self._load_balanced = load_balanced
self._srv_service_name = srv_service_name
self._srv_max_hosts = srv_max_hosts or 0
self._srv_allowed_hosts_suffix = srv_allowed_hosts_suffix
self._server_monitoring_mode = server_monitoring_mode
if topology_id is not None:
self._topology_id = topology_id
Expand Down Expand Up @@ -155,6 +157,11 @@ def srv_max_hosts(self) -> int:
"""The srvMaxHosts."""
return self._srv_max_hosts

@property
def srv_allowed_hosts_suffix(self) -> Optional[str]:
"""The srvAllowedHostsSuffix."""
return self._srv_allowed_hosts_suffix

@property
def server_monitoring_mode(self) -> str:
"""The serverMonitoringMode."""
Expand Down
20 changes: 14 additions & 6 deletions pymongo/asynchronous/srv_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,15 @@ def __init__(
connect_timeout: Optional[float],
srv_service_name: str,
srv_max_hosts: int = 0,
srv_allowed_hosts_suffix: Optional[str] = None,
):
self.__fqdn = fqdn
self.__srv = srv_service_name
self.__connect_timeout = connect_timeout or CONNECT_TIMEOUT
self.__srv_max_hosts = srv_max_hosts or 0
self.__srv_allowed_hosts_suffix = (
"." + srv_allowed_hosts_suffix.lower().lstrip(".") if srv_allowed_hosts_suffix else None
) # ensure there's a . at the beginning of the domain
# Validate the fully qualified domain name.
try:
ipaddress.ip_address(fqdn)
Expand Down Expand Up @@ -134,12 +138,16 @@ async def _get_srv_response_and_hosts(
raise ConfigurationError(
"Invalid SRV host: return address is identical to SRV hostname"
)
try:
nlist = srv_host.split(".")[1:][-self.__slen :]
except Exception as exc:
raise ConfigurationError(f"Invalid SRV host: {node[0]}") from exc
if self.__plist != nlist:
raise ConfigurationError(f"Invalid SRV host: {node[0]}")
if self.__srv_allowed_hosts_suffix is not None:
if not srv_host.endswith(self.__srv_allowed_hosts_suffix):
raise ConfigurationError(f"Invalid SRV host: {node[0]}")
else:
try:
nlist = srv_host.split(".")[1:][-self.__slen :]
except Exception as exc:
raise ConfigurationError(f"Invalid SRV host: {node[0]}") from exc
if self.__plist != nlist:
raise ConfigurationError(f"Invalid SRV host: {node[0]}")
if self.__srv_max_hosts:
nodes = random.sample(nodes, min(self.__srv_max_hosts, len(nodes)))
return results, nodes
Expand Down
8 changes: 7 additions & 1 deletion pymongo/asynchronous/uri_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ async def parse_uri(
connect_timeout: Optional[float] = None,
srv_service_name: Optional[str] = None,
srv_max_hosts: Optional[int] = None,
srv_allowed_hosts_suffix: Optional[str] = None,
) -> dict[str, Any]:
"""Parse and validate a MongoDB URI.

Expand Down Expand Up @@ -115,6 +116,7 @@ async def parse_uri(
connect_timeout,
srv_service_name,
srv_max_hosts,
srv_allowed_hosts_suffix,
)
)
result["options"] = _make_options_case_sensitive(result["options"])
Expand All @@ -130,6 +132,7 @@ async def _parse_srv(
connect_timeout: Optional[float] = None,
srv_service_name: Optional[str] = None,
srv_max_hosts: Optional[int] = None,
srv_allowed_hosts_suffix: Optional[str] = None,
) -> dict[str, Any]:
if uri.startswith(SCHEME):
is_srv = False
Expand Down Expand Up @@ -157,14 +160,17 @@ async def _parse_srv(

hosts = unquote_plus(hosts)
srv_max_hosts = srv_max_hosts or options.get("srvMaxHosts")
srv_allowed_hosts_suffix = srv_allowed_hosts_suffix or options.get("srvAllowedHostsSuffix")
if is_srv:
nodes = split_hosts(hosts, default_port=None)
fqdn, port = nodes[0]

# Use the connection timeout. connectTimeoutMS passed as a keyword
# argument overrides the same option passed in the connection string.
connect_timeout = connect_timeout or options.get("connectTimeoutMS")
dns_resolver = _SrvResolver(fqdn, connect_timeout, srv_service_name, srv_max_hosts)
dns_resolver = _SrvResolver(
fqdn, connect_timeout, srv_service_name, srv_max_hosts, srv_allowed_hosts_suffix
)
nodes = await dns_resolver.get_hosts()
dns_options = await dns_resolver.get_options()
if dns_options:
Expand Down
1 change: 1 addition & 0 deletions pymongo/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -721,6 +721,7 @@ def validate_server_monitoring_mode(option: str, value: str) -> str:
"zlibcompressionlevel": validate_zlib_compression_level,
"srvservicename": validate_string,
"srvmaxhosts": validate_non_negative_integer,
"srvallowedhostssuffix": validate_string,
"timeoutms": validate_timeoutms,
"servermonitoringmode": validate_server_monitoring_mode,
"maxadaptiveretries": validate_non_negative_integer,
Expand Down
22 changes: 19 additions & 3 deletions pymongo/synchronous/mongo_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -808,6 +808,7 @@ def __init__(
fqdn = None
srv_service_name = keyword_opts.get("srvservicename")
srv_max_hosts = keyword_opts.get("srvmaxhosts")
srv_allowed_hosts_suffix = keyword_opts.get("srvallowedhostssuffix")
if len([h for h in self._host if "/" in h]) > 1:
raise ConfigurationError("host must not contain multiple MongoDB URIs")
for entity in self._host:
Expand Down Expand Up @@ -858,6 +859,8 @@ def __init__(
srv_service_name = opts.get("srvServiceName", common.SRV_SERVICE_NAME)

srv_max_hosts = srv_max_hosts or opts.get("srvmaxhosts")
if srv_allowed_hosts_suffix is None:
srv_allowed_hosts_suffix = opts.get("srvallowedhostssuffix")
opts = self._normalize_and_validate_options(opts, self._seeds)

# Username and password passed as kwargs override user info in URI.
Expand Down Expand Up @@ -895,7 +898,9 @@ def __init__(

self._retry_policy = _RetryPolicy(attempts=self._options.max_adaptive_retries)

self._init_based_on_options(self._seeds, srv_max_hosts, srv_service_name)
self._init_based_on_options(
self._seeds, srv_max_hosts, srv_service_name, srv_allowed_hosts_suffix
)

self._opened = False
self._closed = False
Expand All @@ -913,6 +918,7 @@ def _resolve_srv(self) -> None:
opts = common._CaseInsensitiveDictionary()
srv_service_name = keyword_opts.get("srvservicename")
srv_max_hosts = keyword_opts.get("srvmaxhosts")
srv_allowed_hosts_suffix = keyword_opts.get("srvallowedhostssuffix")
for entity in self._host:
# A hostname can only include a-z, 0-9, '-' and '.'. If we find a '/'
# it must be a URI,
Expand All @@ -933,6 +939,7 @@ def _resolve_srv(self) -> None:
connect_timeout=timeout,
srv_service_name=srv_service_name,
srv_max_hosts=srv_max_hosts,
srv_allowed_hosts_suffix=srv_allowed_hosts_suffix,
)
seeds.update(res["nodelist"])
opts = res["options"]
Expand Down Expand Up @@ -965,6 +972,8 @@ def _resolve_srv(self) -> None:
srv_service_name = opts.get("srvServiceName", common.SRV_SERVICE_NAME)

srv_max_hosts = srv_max_hosts or opts.get("srvmaxhosts")
if srv_allowed_hosts_suffix is None:
srv_allowed_hosts_suffix = opts.get("srvAllowedHostsSuffix")
opts = self._normalize_and_validate_options(opts, seeds)

# Username and password passed as kwargs override user info in URI.
Expand All @@ -974,10 +983,16 @@ def _resolve_srv(self) -> None:
username, password, self._resolve_srv_info["dbase"], opts, _IS_SYNC
)

self._init_based_on_options(seeds, srv_max_hosts, srv_service_name)
self._init_based_on_options(
seeds, srv_max_hosts, srv_service_name, srv_allowed_hosts_suffix
)

def _init_based_on_options(
self, seeds: Collection[tuple[str, int]], srv_max_hosts: Any, srv_service_name: Any
self,
seeds: Collection[tuple[str, int]],
srv_max_hosts: Any,
srv_service_name: Any,
srv_allowed_hosts_suffix: Any,
) -> None:
self._event_listeners = self._options.pool_options._event_listeners
self._topology_settings = TopologySettings(
Expand All @@ -996,6 +1011,7 @@ def _init_based_on_options(
load_balanced=self._options.load_balanced,
srv_service_name=srv_service_name,
srv_max_hosts=srv_max_hosts,
srv_allowed_hosts_suffix=srv_allowed_hosts_suffix,
server_monitoring_mode=self._options.server_monitoring_mode,
topology_id=self._topology_settings._topology_id if self._topology_settings else None,
)
Expand Down
1 change: 1 addition & 0 deletions pymongo/synchronous/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,7 @@ def _get_seedlist(self) -> Optional[list[tuple[str, Any]]]:
self._fqdn,
self._settings.pool_options.connect_timeout,
self._settings.srv_service_name,
srv_allowed_hosts_suffix=self._settings.srv_allowed_hosts_suffix,
)
seedlist, ttl = resolver.get_hosts_and_min_ttl()
if len(seedlist) == 0:
Expand Down
7 changes: 7 additions & 0 deletions pymongo/synchronous/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def __init__(
load_balanced: Optional[bool] = None,
srv_service_name: str = common.SRV_SERVICE_NAME,
srv_max_hosts: int = 0,
srv_allowed_hosts_suffix: Optional[str] = None,
server_monitoring_mode: str = common.SERVER_MONITORING_MODE,
topology_id: Optional[ObjectId] = None,
):
Expand Down Expand Up @@ -78,6 +79,7 @@ def __init__(
self._load_balanced = load_balanced
self._srv_service_name = srv_service_name
self._srv_max_hosts = srv_max_hosts or 0
self._srv_allowed_hosts_suffix = srv_allowed_hosts_suffix
self._server_monitoring_mode = server_monitoring_mode
if topology_id is not None:
self._topology_id = topology_id
Expand Down Expand Up @@ -155,6 +157,11 @@ def srv_max_hosts(self) -> int:
"""The srvMaxHosts."""
return self._srv_max_hosts

@property
def srv_allowed_hosts_suffix(self) -> Optional[str]:
"""The srvAllowedHostsSuffix."""
return self._srv_allowed_hosts_suffix

@property
def server_monitoring_mode(self) -> str:
"""The serverMonitoringMode."""
Expand Down
20 changes: 14 additions & 6 deletions pymongo/synchronous/srv_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,15 @@ def __init__(
connect_timeout: Optional[float],
srv_service_name: str,
srv_max_hosts: int = 0,
srv_allowed_hosts_suffix: Optional[str] = None,
):
self.__fqdn = fqdn
self.__srv = srv_service_name
self.__connect_timeout = connect_timeout or CONNECT_TIMEOUT
self.__srv_max_hosts = srv_max_hosts or 0
self.__srv_allowed_hosts_suffix = (
"." + srv_allowed_hosts_suffix.lower().lstrip(".") if srv_allowed_hosts_suffix else None
) # ensure there's a . at the beginning of the domain
# Validate the fully qualified domain name.
try:
ipaddress.ip_address(fqdn)
Expand Down Expand Up @@ -134,12 +138,16 @@ def _get_srv_response_and_hosts(
raise ConfigurationError(
"Invalid SRV host: return address is identical to SRV hostname"
)
try:
nlist = srv_host.split(".")[1:][-self.__slen :]
except Exception as exc:
raise ConfigurationError(f"Invalid SRV host: {node[0]}") from exc
if self.__plist != nlist:
raise ConfigurationError(f"Invalid SRV host: {node[0]}")
if self.__srv_allowed_hosts_suffix is not None:
if not srv_host.endswith(self.__srv_allowed_hosts_suffix):
raise ConfigurationError(f"Invalid SRV host: {node[0]}")
else:
try:
nlist = srv_host.split(".")[1:][-self.__slen :]
except Exception as exc:
raise ConfigurationError(f"Invalid SRV host: {node[0]}") from exc
if self.__plist != nlist:
raise ConfigurationError(f"Invalid SRV host: {node[0]}")
if self.__srv_max_hosts:
nodes = random.sample(nodes, min(self.__srv_max_hosts, len(nodes)))
return results, nodes
Expand Down
Loading
Loading