Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import javasabr.rlib.network.impl.DefaultBufferAllocator;
import javasabr.rlib.network.impl.DefaultConnection;
import javasabr.rlib.network.impl.StringDataConnection;
import javasabr.rlib.network.impl.StringDataMtlsServerConnection;
import javasabr.rlib.network.impl.StringDataSslConnection;
import javasabr.rlib.network.packet.impl.DefaultReadableNetworkPacket;
import javasabr.rlib.network.packet.registry.ReadableNetworkPacketRegistry;
Expand Down Expand Up @@ -140,7 +141,11 @@ public static ClientNetwork<StringDataSslConnection> stringDataSslClientNetwork(
SSLContext sslContext) {
return clientNetwork(
networkConfig,
(network, channel) -> new StringDataSslConnection(network, channel, bufferAllocator, sslContext, true));
(network, channel) -> {
StringDataSslConnection connection = new StringDataSslConnection(network, channel, bufferAllocator, sslContext, true);
connection.beginHandshake();
return connection;
});
}

/**
Expand Down Expand Up @@ -196,7 +201,11 @@ public static ServerNetwork<StringDataSslConnection> stringDataSslServerNetwork(
SSLContext sslContext) {
return serverNetwork(
networkConfig,
(network, channel) -> new StringDataSslConnection(network, channel, bufferAllocator, sslContext, false));
(network, channel) -> {
StringDataSslConnection connection = new StringDataSslConnection(network, channel, bufferAllocator, sslContext, false);
connection.beginHandshake();
return connection;
});
}

/**
Expand Down Expand Up @@ -231,4 +240,26 @@ public static ServerNetwork<DefaultConnection> defaultServerNetwork(
networkConfig,
(network, channel) -> new DefaultConnection(network, channel, bufferAllocator, packetRegistry));
}

/**
* Create string packet based asynchronous Mutual TLS server network.
*
* @param networkConfig the server network configuration
* @param bufferAllocator the buffer allocator
* @param sslContext SSL context
* @return a new mTLS server network
* @since 10.0.0
*/
public static ServerNetwork<StringDataMtlsServerConnection> stringDataMtlsServerNetwork(
ServerNetworkConfig networkConfig,
BufferAllocator bufferAllocator,
SSLContext sslContext) {
return serverNetwork(
networkConfig,
(network, channel) -> {
StringDataMtlsServerConnection connection = new StringDataMtlsServerConnection(network, channel, bufferAllocator, sslContext);
connection.beginHandshake();
return connection;
});
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package javasabr.rlib.network.exception;

public class ConnectionClosedException extends NetworkException {

public ConnectionClosedException(String remoteAddress) {
super("Connection closed: %s".formatted(remoteAddress));
}

public ConnectionClosedException(String remoteAddress, Throwable cause) {
super("Connection closed: %s".formatted(remoteAddress), cause);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import javasabr.rlib.network.Connection;
import javasabr.rlib.network.Network;
import javasabr.rlib.network.UnsafeConnection;
import javasabr.rlib.network.exception.ConnectionClosedException;
import javasabr.rlib.network.packet.NetworkPacketReader;
import javasabr.rlib.network.packet.NetworkPacketWriter;
import javasabr.rlib.network.packet.ReadableNetworkPacket;
Expand Down Expand Up @@ -64,6 +65,7 @@ public WritablePacketWithFeedback(CompletableFuture<Boolean> attachment, Writabl

final MutableArray<BiConsumer<C, ? super ReadableNetworkPacket<C>>> validPacketSubscribers;
final MutableArray<BiConsumer<C, ? super ReadableNetworkPacket<C>>> invalidPacketSubscribers;
final MutableArray<FluxSink<?>> activeSinks;

final int maxPacketsByRead;

Expand All @@ -84,6 +86,7 @@ public AbstractConnection(
this.closed = new AtomicBoolean(false);
this.validPacketSubscribers = ArrayFactory.copyOnModifyArray(BiConsumer.class);
this.invalidPacketSubscribers = ArrayFactory.copyOnModifyArray(BiConsumer.class);
this.activeSinks = ArrayFactory.stampedLockBasedArray(FluxSink.class);
this.remoteAddress = String.valueOf(NetworkUtils.getRemoteAddress(channel));
}

Expand Down Expand Up @@ -134,10 +137,12 @@ protected void registerFluxOnReceivedEvents(

validPacketSubscribers.add(validListener);
invalidPacketSubscribers.add(invalidListener);
activeSinks.add(sink);

sink.onDispose(() -> {
validPacketSubscribers.remove(validListener);
validPacketSubscribers.remove(invalidListener);
activeSinks.remove(sink);
});

network.inNetworkThread(() -> packetReader().startRead());
Expand All @@ -146,14 +151,22 @@ protected void registerFluxOnReceivedEvents(
protected void registerFluxOnReceivedValidPackets(FluxSink<? super ReadableNetworkPacket<C>> sink) {
BiConsumer<C, ReadableNetworkPacket<C>> listener = (connection, packet) -> sink.next(packet);
validPacketSubscribers.add(listener);
sink.onDispose(() -> validPacketSubscribers.remove(listener));
activeSinks.add(sink);
sink.onDispose(() -> {
validPacketSubscribers.remove(listener);
activeSinks.remove(sink);
});
network.inNetworkThread(() -> packetReader().startRead());
}

protected void registerFluxOnReceivedInvalidPackets(FluxSink<? super ReadableNetworkPacket<C>> sink) {
BiConsumer<C, ReadableNetworkPacket<C>> listener = (connection, packet) -> sink.next(packet);
invalidPacketSubscribers.add(listener);
sink.onDispose(() -> invalidPacketSubscribers.remove(listener));
activeSinks.add(sink);
sink.onDispose(() -> {
invalidPacketSubscribers.remove(listener);
activeSinks.remove(sink);
});
network.inNetworkThread(() -> packetReader().startRead());
}

Expand Down Expand Up @@ -184,6 +197,24 @@ protected void doClose() {
clearWaitPackets();
packetReader().close();
packetWriter().close();
notifySinksOnError();
}

protected void notifySinksOnError() {
if (activeSinks.isEmpty()) {
return;
}
ConnectionClosedException error = new ConnectionClosedException(remoteAddress);
activeSinks
.iterations()
.forEach(error, (sink, exc) -> {
try {
sink.error(exc);
} catch (RuntimeException e) {
log.error(e.getMessage(), "Failed to notify sink of connection closure: "::formatted);
}
});
activeSinks.clear();
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ public AbstractSslConnection(
super(network, channel, bufferAllocator, maxPacketsByRead);
this.sslEngine = sslContext.createSSLEngine();
this.sslEngine.setUseClientMode(clientMode);
}

public void beginHandshake() {
try {
this.sslEngine.beginHandshake();
} catch (SSLException e) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package javasabr.rlib.network.impl;

import javasabr.rlib.network.BufferAllocator;
import javasabr.rlib.network.Network;
import javasabr.rlib.network.packet.impl.StringReadableNetworkPacket;

import javax.net.ssl.SSLContext;
import java.nio.channels.AsynchronousSocketChannel;

/**
Comment thread
crazyrokr marked this conversation as resolved.
* @author crazyrokr
*/
public class StringDataMtlsServerConnection extends DefaultDataSslConnection<StringDataMtlsServerConnection> {

public StringDataMtlsServerConnection(
Network<StringDataMtlsServerConnection> network,
AsynchronousSocketChannel channel,
BufferAllocator bufferAllocator,
SSLContext sslContext) {
super(network, channel, bufferAllocator, sslContext, 100, 2, false);
sslEngine.setNeedClientAuth(true);
}

@Override
protected StringReadableNetworkPacket<StringDataMtlsServerConnection> createReadablePacket() {
return new StringReadableNetworkPacket<>();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -461,10 +461,14 @@ protected void handleFailedReceiving(Throwable exception, ByteBuffer readingBuff
retryReadLater();
}
}
case AsynchronousCloseException ex ->
log.info(remoteAddress(), "[%s] Connection was closed"::formatted);
case ClosedChannelException ex ->
log.info(remoteAddress(), "[%s] Connection was closed"::formatted);
case AsynchronousCloseException ex -> {
log.info(remoteAddress(), "[%s] Connection was closed"::formatted);
connection.close();
}
case ClosedChannelException ex -> {
log.info(remoteAddress(), "[%s] Connection was closed"::formatted);
connection.close();
}
default -> {
log.error(exception);
connection.close();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ protected AbstractSslNetworkPacketReader(
protected void handleReceivedData(int receivedBytes, ByteBuffer readingBuffer) {
if (receivedBytes == -1) {
doHandshake(sslNetworkBuffer(), -1);
connection.close();
return;
}
super.handleReceivedData(receivedBytes, readingBuffer);
Expand Down Expand Up @@ -159,6 +160,9 @@ protected int doHandshake(ByteBuffer networkBuffer, int receivedBytes) {
case NEED_WRAP: {
log.debug(remoteAddress, "[%s] Send command to wrap data"::formatted);
packetWriter.accept(SslWrapRequestNetworkPacket.getInstance());
if (networkBuffer.hasRemaining()) {
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed

return decryptAndRead(networkBuffer);
}
NetworkUtils.cleanNetworkBuffer(networkBuffer);
return SKIP_READ_PACKETS;
}
Expand Down Expand Up @@ -203,6 +207,10 @@ protected int decryptAndRead(ByteBuffer receivedBuffer) {
}
switch (result.getStatus()) {
case OK: {
if (result.bytesConsumed() == 0 && result.bytesProduced() == 0) {
log.debug(remoteAddress, "[%s] No progress during decryption, stop processing"::formatted);
return SKIP_READ_PACKETS;
}
sslDataBuffer.flip();
logDataAfterDecrypt(remoteAddress, sslDataBuffer);
total += readPackets(sslDataBuffer, sslDataPendingBuffer);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ protected ByteBuffer doHandshake(HandshakeStatus handshakeStatus) {
break;
}
case NEED_UNWRAP: {
break;
return EMPTY_BUFFER;
}
default: {
throw new IllegalStateException("Invalid SSL status:" + handshakeStatus);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package javasabr.rlib.network;

import static org.assertj.core.api.Assertions.assertThat;

import java.net.InetSocketAddress;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import javasabr.rlib.network.exception.ConnectionClosedException;
import javasabr.rlib.network.impl.AbstractConnection;
import javasabr.rlib.network.impl.DefaultConnection;
import javasabr.rlib.network.packet.impl.DefaultReadableNetworkPacket;
import javasabr.rlib.network.packet.registry.ReadableNetworkPacketRegistry;
import org.junit.jupiter.api.Test;

public class ConnectionCloseTest extends BaseNetworkTest {

@Test
void shouldPropagateConnectionCloseToClient() throws InterruptedException {
// given
var packetRegistry = ReadableNetworkPacketRegistry.of(
DefaultReadableNetworkPacket.class,
DefaultConnection.class,
DefaultNetworkTest.ServerPackets.RequestEchoMessage.class,
DefaultNetworkTest.ServerPackets.RequestServerTime.class);
var serverNetwork = NetworkFactory.defaultServerNetwork(packetRegistry);
InetSocketAddress serverAddress = serverNetwork.start();
serverNetwork.onAccept(AbstractConnection::close);
var clientNetwork = NetworkFactory.defaultClientNetwork(packetRegistry);
CountDownLatch closeLatch = new CountDownLatch(1);

// when
clientNetwork
.connectReactive(serverAddress)
.flatMapMany(AbstractConnection::receivedEvents)
.doOnError(e -> {
if (e instanceof ConnectionClosedException) {
closeLatch.countDown();
}
})
.subscribe();

// then
assertThat(closeLatch.await(5000, TimeUnit.MILLISECONDS))
.as("Client should be notified that connection is closed")
.isTrue();
clientNetwork.shutdown();
serverNetwork.shutdown();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,15 @@
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.IntStream;
import javasabr.rlib.common.util.ObjectUtils;
import javasabr.rlib.common.util.StringUtils;
import javasabr.rlib.common.util.Utils;
import javasabr.rlib.network.client.ClientNetwork;
import javasabr.rlib.network.exception.ConnectionClosedException;
import javasabr.rlib.network.impl.DefaultBufferAllocator;
import javasabr.rlib.network.impl.StringDataMtlsServerConnection;
import javasabr.rlib.network.impl.StringDataSslConnection;
import javasabr.rlib.network.packet.ReadableNetworkPacket;
import javasabr.rlib.network.packet.impl.StringReadableNetworkPacket;
Expand Down Expand Up @@ -328,6 +331,63 @@ void shouldReceiveManyPacketsFromSmallToBigSize() {
}
}

@Test
@SneakyThrows
void shouldRejectClientWithoutCertificateWithinMutualTls() {
InputStream serverKeystoreFile = StringSslNetworkTest.class.getResourceAsStream("/ssl/rlib_test_cert.p12");
SSLContext serverSslContext = NetworkUtils.createSslContext(serverKeystoreFile, "test");
ServerNetworkConfig serverConfig = ServerNetworkConfig.SimpleServerNetworkConfig.builder().build();
BufferAllocator bufferAllocator = new DefaultBufferAllocator(serverConfig);

ServerNetwork<StringDataMtlsServerConnection> serverNetwork =
NetworkFactory.stringDataMtlsServerNetwork(serverConfig, bufferAllocator, serverSslContext);

InetSocketAddress serverAddress = serverNetwork.start();
CountDownLatch dataReceivedByServer = new CountDownLatch(1);

serverNetwork
.accepted()
.flatMap(Connection::receivedEvents)
.subscribe(event -> dataReceivedByServer.countDown());

SSLContext clientWithoutCertContext = NetworkUtils.createAllTrustedClientSslContext();
ClientNetwork<StringDataSslConnection> clientNetwork = NetworkFactory.stringDataSslClientNetwork(
NetworkConfig.DEFAULT_CLIENT,
new DefaultBufferAllocator(NetworkConfig.DEFAULT_CLIENT),
clientWithoutCertContext);

AtomicReference<Throwable> connectionError = new AtomicReference<>();
CountDownLatch errorReceived = new CountDownLatch(1);

try {
clientNetwork
.connectReactive(serverAddress)
.doOnNext(connection -> connection.sendInBackground(new StringWritableNetworkPacket<>("no cert")))
.flatMapMany(Connection::receivedEvents)
.subscribe(
event -> {},
ex -> {
connectionError.set(ex);
errorReceived.countDown();
});

assertThat(errorReceived.await(5, TimeUnit.SECONDS))
.as("Client subscriber must receive an error when the server closes the mTLS connection.")
.isTrue();

assertThat(connectionError.get())
.as("Client must receive ConnectionClosedException, not a timeout.")
.isInstanceOf(ConnectionClosedException.class);

assertThat(dataReceivedByServer.getCount())
.as("Server must not receive data from an unauthenticated client.")
.isEqualTo(1);
} finally {
serverNetwork.shutdown();
clientNetwork.shutdown();
}
}

private static StringWritableNetworkPacket<StringDataSslConnection> newMessage(int minMessageLength, int maxMessageLength) {
return new StringWritableNetworkPacket<>(StringUtils.generate(minMessageLength, maxMessageLength));
}
Expand Down
Loading
Loading