diff --git a/ceti/whaletag.py b/ceti/whaletag.py index bd00b55..c27d81e 100644 --- a/ceti/whaletag.py +++ b/ceti/whaletag.py @@ -15,6 +15,7 @@ from argparse import Namespace import asyncio +import contextlib import ipaddress import os import re @@ -48,13 +49,13 @@ def find_ssh_servers(): for gateway_ip in getLANips(): netspec = findssh.netfromaddress(gateway_ip) coro = findssh.get_hosts(netspec, 22, "ssh", 1.0) - sys.stdout = open(os.devnull, "w") - lanhosts = asyncio.run(coro) - sys.stdout = sys.__stdout__ + with open(os.devnull, "w") as devnull: + with contextlib.redirect_stdout(devnull): + lanhosts = asyncio.run(coro) coro = findssh.get_hosts(ipaddress.IPv4Network(DEFAULT_USBGADGET_IPNETWORK), 22, "ssh", 1.0) - sys.stdout = open(os.devnull, "w") - usbhosts = asyncio.run(coro) - sys.stdout = sys.__stdout__ + with open(os.devnull, "w") as devnull: + with contextlib.redirect_stdout(devnull): + usbhosts = asyncio.run(coro) for ip in lanhosts+usbhosts: result.append(str(ip[0])) return result diff --git a/tests/test_whaletag.py b/tests/test_whaletag.py new file mode 100644 index 0000000..6a04f68 --- /dev/null +++ b/tests/test_whaletag.py @@ -0,0 +1,69 @@ +import ipaddress +import sys +from unittest import mock + +import pytest + +from ceti.whaletag import find_ssh_servers + + +class TestFindSshServers: + """Tests for find_ssh_servers() -- issue #40.""" + + @mock.patch("ceti.whaletag.findssh.get_hosts", new_callable=mock.Mock) + @mock.patch("ceti.whaletag.findssh.netfromaddress") + @mock.patch("ceti.whaletag.getLANips") + @mock.patch("ceti.whaletag.asyncio.run") + def test_restores_stdout_on_success( + self, mock_asyncio_run, mock_get_lanips, mock_netfromaddress, mock_get_hosts + ): + mock_get_lanips.return_value = [ipaddress.ip_address("192.168.1.1")] + mock_netfromaddress.return_value = "fake-netspec" + mock_get_hosts.return_value = "fake-coro" + mock_asyncio_run.side_effect = [ + [("192.168.1.10", 22)], + [("192.168.11.2", 22)], + ] + + orig_stdout = sys.stdout + result = find_ssh_servers() + + assert sys.stdout is orig_stdout + assert result == ["192.168.1.10", "192.168.11.2"] + + @mock.patch("ceti.whaletag.findssh.get_hosts", new_callable=mock.Mock) + @mock.patch("ceti.whaletag.findssh.netfromaddress") + @mock.patch("ceti.whaletag.getLANips") + @mock.patch("ceti.whaletag.asyncio.run") + def test_restores_stdout_on_first_run_error( + self, mock_asyncio_run, mock_get_lanips, mock_netfromaddress, mock_get_hosts + ): + mock_get_lanips.return_value = [ipaddress.ip_address("192.168.1.1")] + mock_netfromaddress.return_value = "fake-netspec" + mock_get_hosts.return_value = "fake-coro" + mock_asyncio_run.side_effect = RuntimeError("first run failed") + + orig_stdout = sys.stdout + with pytest.raises(RuntimeError, match="first run failed"): + find_ssh_servers() + assert sys.stdout is orig_stdout + + @mock.patch("ceti.whaletag.findssh.get_hosts", new_callable=mock.Mock) + @mock.patch("ceti.whaletag.findssh.netfromaddress") + @mock.patch("ceti.whaletag.getLANips") + @mock.patch("ceti.whaletag.asyncio.run") + def test_restores_stdout_on_second_run_error( + self, mock_asyncio_run, mock_get_lanips, mock_netfromaddress, mock_get_hosts + ): + mock_get_lanips.return_value = [ipaddress.ip_address("192.168.1.1")] + mock_netfromaddress.return_value = "fake-netspec" + mock_get_hosts.return_value = "fake-coro" + mock_asyncio_run.side_effect = [ + [("192.168.1.10", 22)], + RuntimeError("second run failed"), + ] + + orig_stdout = sys.stdout + with pytest.raises(RuntimeError, match="second run failed"): + find_ssh_servers() + assert sys.stdout is orig_stdout