Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions ceti/whaletag.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from argparse import Namespace
import asyncio
import contextlib
import ipaddress
import os
import re
Expand Down Expand Up @@ -48,13 +49,13 @@ def find_ssh_servers():
for gateway_ip in getLANips():
netspec = findssh.netfromaddress(gateway_ip)
coro = findssh.get_hosts(netspec, 22, "ssh", 1.0)
sys.stdout = open(os.devnull, "w")
lanhosts = asyncio.run(coro)
sys.stdout = sys.__stdout__
with open(os.devnull, "w") as devnull:
with contextlib.redirect_stdout(devnull):
lanhosts = asyncio.run(coro)
coro = findssh.get_hosts(ipaddress.IPv4Network(DEFAULT_USBGADGET_IPNETWORK), 22, "ssh", 1.0)
sys.stdout = open(os.devnull, "w")
usbhosts = asyncio.run(coro)
sys.stdout = sys.__stdout__
with open(os.devnull, "w") as devnull:
with contextlib.redirect_stdout(devnull):
usbhosts = asyncio.run(coro)
for ip in lanhosts+usbhosts:
result.append(str(ip[0]))
return result
Expand Down
69 changes: 69 additions & 0 deletions tests/test_whaletag.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import ipaddress
import sys
from unittest import mock

import pytest

from ceti.whaletag import find_ssh_servers


class TestFindSshServers:
"""Tests for find_ssh_servers() -- issue #40."""

@mock.patch("ceti.whaletag.findssh.get_hosts", new_callable=mock.Mock)
@mock.patch("ceti.whaletag.findssh.netfromaddress")
@mock.patch("ceti.whaletag.getLANips")
@mock.patch("ceti.whaletag.asyncio.run")
def test_restores_stdout_on_success(
self, mock_asyncio_run, mock_get_lanips, mock_netfromaddress, mock_get_hosts
):
mock_get_lanips.return_value = [ipaddress.ip_address("192.168.1.1")]
mock_netfromaddress.return_value = "fake-netspec"
mock_get_hosts.return_value = "fake-coro"
mock_asyncio_run.side_effect = [
[("192.168.1.10", 22)],
[("192.168.11.2", 22)],
]

orig_stdout = sys.stdout
result = find_ssh_servers()

assert sys.stdout is orig_stdout
assert result == ["192.168.1.10", "192.168.11.2"]

@mock.patch("ceti.whaletag.findssh.get_hosts", new_callable=mock.Mock)
@mock.patch("ceti.whaletag.findssh.netfromaddress")
@mock.patch("ceti.whaletag.getLANips")
@mock.patch("ceti.whaletag.asyncio.run")
def test_restores_stdout_on_first_run_error(
self, mock_asyncio_run, mock_get_lanips, mock_netfromaddress, mock_get_hosts
):
mock_get_lanips.return_value = [ipaddress.ip_address("192.168.1.1")]
mock_netfromaddress.return_value = "fake-netspec"
mock_get_hosts.return_value = "fake-coro"
mock_asyncio_run.side_effect = RuntimeError("first run failed")

orig_stdout = sys.stdout
with pytest.raises(RuntimeError, match="first run failed"):
find_ssh_servers()
assert sys.stdout is orig_stdout

@mock.patch("ceti.whaletag.findssh.get_hosts", new_callable=mock.Mock)
@mock.patch("ceti.whaletag.findssh.netfromaddress")
@mock.patch("ceti.whaletag.getLANips")
@mock.patch("ceti.whaletag.asyncio.run")
def test_restores_stdout_on_second_run_error(
self, mock_asyncio_run, mock_get_lanips, mock_netfromaddress, mock_get_hosts
):
mock_get_lanips.return_value = [ipaddress.ip_address("192.168.1.1")]
mock_netfromaddress.return_value = "fake-netspec"
mock_get_hosts.return_value = "fake-coro"
mock_asyncio_run.side_effect = [
[("192.168.1.10", 22)],
RuntimeError("second run failed"),
]

orig_stdout = sys.stdout
with pytest.raises(RuntimeError, match="second run failed"):
find_ssh_servers()
assert sys.stdout is orig_stdout