From 1f07ea8aa7c90fe0e0dee680879478d14fc1001b Mon Sep 17 00:00:00 2001 From: Francis Date: Fri, 17 Apr 2026 16:59:50 -0400 Subject: [PATCH] Fix stdout leak in find_ssh_servers when scans fail find_ssh_servers redirected sys.stdout manually and only restored it on the happy path. If asyncio.run() raised during LAN or USB scans, stdout could remain redirected to /dev/null and hide all subsequent output. Use nested context managers (open + redirect_stdout) to guarantee stdout restoration and close the devnull handle on both success and exception. Add three tests covering success and failures at both asyncio.run call sites to prove stdout is always restored. Fixes #40 --- ceti/whaletag.py | 13 ++++---- tests/test_whaletag.py | 69 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 76 insertions(+), 6 deletions(-) create mode 100644 tests/test_whaletag.py diff --git a/ceti/whaletag.py b/ceti/whaletag.py index bd00b55..c27d81e 100644 --- a/ceti/whaletag.py +++ b/ceti/whaletag.py @@ -15,6 +15,7 @@ from argparse import Namespace import asyncio +import contextlib import ipaddress import os import re @@ -48,13 +49,13 @@ def find_ssh_servers(): for gateway_ip in getLANips(): netspec = findssh.netfromaddress(gateway_ip) coro = findssh.get_hosts(netspec, 22, "ssh", 1.0) - sys.stdout = open(os.devnull, "w") - lanhosts = asyncio.run(coro) - sys.stdout = sys.__stdout__ + with open(os.devnull, "w") as devnull: + with contextlib.redirect_stdout(devnull): + lanhosts = asyncio.run(coro) coro = findssh.get_hosts(ipaddress.IPv4Network(DEFAULT_USBGADGET_IPNETWORK), 22, "ssh", 1.0) - sys.stdout = open(os.devnull, "w") - usbhosts = asyncio.run(coro) - sys.stdout = sys.__stdout__ + with open(os.devnull, "w") as devnull: + with contextlib.redirect_stdout(devnull): + usbhosts = asyncio.run(coro) for ip in lanhosts+usbhosts: result.append(str(ip[0])) return result diff --git a/tests/test_whaletag.py b/tests/test_whaletag.py new file mode 100644 index 0000000..6a04f68 --- /dev/null +++ b/tests/test_whaletag.py @@ -0,0 +1,69 @@ +import ipaddress +import sys +from unittest import mock + +import pytest + +from ceti.whaletag import find_ssh_servers + + +class TestFindSshServers: + """Tests for find_ssh_servers() -- issue #40.""" + + @mock.patch("ceti.whaletag.findssh.get_hosts", new_callable=mock.Mock) + @mock.patch("ceti.whaletag.findssh.netfromaddress") + @mock.patch("ceti.whaletag.getLANips") + @mock.patch("ceti.whaletag.asyncio.run") + def test_restores_stdout_on_success( + self, mock_asyncio_run, mock_get_lanips, mock_netfromaddress, mock_get_hosts + ): + mock_get_lanips.return_value = [ipaddress.ip_address("192.168.1.1")] + mock_netfromaddress.return_value = "fake-netspec" + mock_get_hosts.return_value = "fake-coro" + mock_asyncio_run.side_effect = [ + [("192.168.1.10", 22)], + [("192.168.11.2", 22)], + ] + + orig_stdout = sys.stdout + result = find_ssh_servers() + + assert sys.stdout is orig_stdout + assert result == ["192.168.1.10", "192.168.11.2"] + + @mock.patch("ceti.whaletag.findssh.get_hosts", new_callable=mock.Mock) + @mock.patch("ceti.whaletag.findssh.netfromaddress") + @mock.patch("ceti.whaletag.getLANips") + @mock.patch("ceti.whaletag.asyncio.run") + def test_restores_stdout_on_first_run_error( + self, mock_asyncio_run, mock_get_lanips, mock_netfromaddress, mock_get_hosts + ): + mock_get_lanips.return_value = [ipaddress.ip_address("192.168.1.1")] + mock_netfromaddress.return_value = "fake-netspec" + mock_get_hosts.return_value = "fake-coro" + mock_asyncio_run.side_effect = RuntimeError("first run failed") + + orig_stdout = sys.stdout + with pytest.raises(RuntimeError, match="first run failed"): + find_ssh_servers() + assert sys.stdout is orig_stdout + + @mock.patch("ceti.whaletag.findssh.get_hosts", new_callable=mock.Mock) + @mock.patch("ceti.whaletag.findssh.netfromaddress") + @mock.patch("ceti.whaletag.getLANips") + @mock.patch("ceti.whaletag.asyncio.run") + def test_restores_stdout_on_second_run_error( + self, mock_asyncio_run, mock_get_lanips, mock_netfromaddress, mock_get_hosts + ): + mock_get_lanips.return_value = [ipaddress.ip_address("192.168.1.1")] + mock_netfromaddress.return_value = "fake-netspec" + mock_get_hosts.return_value = "fake-coro" + mock_asyncio_run.side_effect = [ + [("192.168.1.10", 22)], + RuntimeError("second run failed"), + ] + + orig_stdout = sys.stdout + with pytest.raises(RuntimeError, match="second run failed"): + find_ssh_servers() + assert sys.stdout is orig_stdout