Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions sdk/webpubsub/azure-messaging-webpubsub-client/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

### Bugs Fixed

- Fixed a race condition where Web PubSub client send operations could miss fast ack responses.

### Other Changes

## 1.1.7 (2026-01-29)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ public Mono<WebPubSubResult> joinGroup(String group, Long ackId) {
if (ackId == null) {
ackId = nextAckId();
}
return sendMessage(new JoinGroupMessage().setGroup(group).setAckId(ackId)).then(waitForAckMessage(ackId))
return sendMessageAndWaitForAck(new JoinGroupMessage().setGroup(group).setAckId(ackId), ackId)
.retryWhen(sendMessageRetrySpec)
.map(result -> {
groups.compute(group,
Expand Down Expand Up @@ -346,7 +346,7 @@ public Mono<WebPubSubResult> leaveGroup(String group, Long ackId) {
if (ackId == null) {
ackId = nextAckId();
}
return sendMessage(new LeaveGroupMessage().setGroup(group).setAckId(ackId)).then(waitForAckMessage(ackId))
return sendMessageAndWaitForAck(new LeaveGroupMessage().setGroup(group).setAckId(ackId), ackId)
.retryWhen(sendMessageRetrySpec)
.map(result -> {
groups.compute(group,
Expand Down Expand Up @@ -414,8 +414,7 @@ public Mono<WebPubSubResult> sendToGroup(String group, BinaryData content, WebPu
.setAckId(ackId)
.setNoEcho(options.isEchoDisabled());

Mono<Void> sendMessageMono = sendMessage(message);
Mono<WebPubSubResult> responseMono = sendMessageMono.then(waitForAckMessage(ackId));
Mono<WebPubSubResult> responseMono = sendMessageAndWaitForAck(message, ackId);
return responseMono.retryWhen(sendMessageRetrySpec);
}

Expand Down Expand Up @@ -454,8 +453,7 @@ public Mono<WebPubSubResult> sendEvent(String eventName, BinaryData content, Web
.setDataType(dataFormat.toString())
.setAckId(ackId);

Mono<Void> sendMessageMono = sendMessage(message);
Mono<WebPubSubResult> responseMono = sendMessageMono.then(waitForAckMessage(ackId));
Mono<WebPubSubResult> responseMono = sendMessageAndWaitForAck(message, ackId);
return responseMono.retryWhen(sendMessageRetrySpec);
}

Expand Down Expand Up @@ -514,13 +512,7 @@ public Flux<RejoinGroupFailedEvent> receiveRejoinGroupFailedEvents() {
}

private long nextAckId() {
return ackId.getAndUpdate(value -> {
// keep positive
if (++value < 0) {
value = 0;
}
return value;
});
return ackId.updateAndGet(value -> value == Long.MAX_VALUE ? 1 : Math.max(0, value) + 1);
}

private Flux<AckMessage> receiveAckMessages() {
Expand All @@ -540,6 +532,11 @@ private Mono<Void> sendMessage(WebPubSubMessage message) {
}));
}

private Mono<WebPubSubResult> sendMessageAndWaitForAck(WebPubSubMessage message, Long ackId) {
return Mono.defer(() -> Mono.zip(waitForAckMessage(ackId), sendMessage(message).thenReturn(true))
.map(tuple -> tuple.getT1()));
}

private Mono<Void> checkStateBeforeSend() {
return Mono.defer(() -> {
WebPubSubClientState state = clientState.get();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -163,20 +163,33 @@ private boolean processMessage(ChannelHandlerContext context, WebSocketFrame web
}

private void publishBuffer() {
final ByteBuffer[] nioBuffers = compositeByteBuf.nioBuffers();

if (nioBuffers.length == 0) {
if (compositeByteBuf == null) {
return;
}

if (compositeByteBuf.refCnt() == 0) {
compositeByteBuf = null;
return;
}
try {
if (compositeByteBuf.readableBytes() == 0) {
return;
}

final ByteBuffer[] nioBuffers = compositeByteBuf.nioBuffers();

if (nioBuffers.length == 0) {
return;
}

final BinaryData data = BinaryData.fromListByteBuffer(Arrays.asList(nioBuffers));
final String collected = data.toString();
final WebPubSubMessage deserialized = messageDecoder.decode(collected);

messageHandler.accept(deserialized);
} finally {
release(compositeByteBuf);
compositeByteBuf = null;
}
}

Expand All @@ -198,9 +211,8 @@ CloseWebSocketFrame getServerCloseWebSocketFrame() {
}

private static void release(CompositeByteBuf buffer) {
if (buffer.refCnt() > 0) {
if (buffer != null && buffer.refCnt() > 0) {
buffer.release();
buffer.clear();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,10 @@ public void testTwoClients() throws InterruptedException {

@Test
@LiveOnly
public void testClientCloseable() {
public void testClientCloseable() throws InterruptedException {
CountDownLatch connectedLatch = new CountDownLatch(1);
CountDownLatch stoppedLatch = new CountDownLatch(1);
CountDownLatch disconnectedLatch = new CountDownLatch(1);
AtomicBoolean stoppedEventReceived = new AtomicBoolean(false);
AtomicBoolean disconnectedEventReceived = new AtomicBoolean(false);

Expand All @@ -105,16 +106,20 @@ public void testClientCloseable() {
stoppedLatch.countDown();
});
client.addOnConnectedEventHandler(connectedEvent -> connectedLatch.countDown());
client.addOnDisconnectedEventHandler(disconnectedEvent -> disconnectedEventReceived.set(true));
client.addOnDisconnectedEventHandler(disconnectedEvent -> {
disconnectedEventReceived.set(true);
disconnectedLatch.countDown();
});

client.start();

connectedLatch.countDown();
Assertions.assertTrue(connectedLatch.await(10, TimeUnit.SECONDS));

// stop not called explicitly
}

stoppedLatch.countDown();
Assertions.assertTrue(stoppedLatch.await(10, TimeUnit.SECONDS));
Assertions.assertTrue(disconnectedLatch.await(10, TimeUnit.SECONDS));

// verify client stopped via Closeable
Assertions.assertTrue(stoppedEventReceived.get());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,23 +4,28 @@
package com.azure.messaging.webpubsub.client;

import com.azure.messaging.webpubsub.client.implementation.WebPubSubClientState;
import com.azure.messaging.webpubsub.client.implementation.models.AckMessage;
import com.azure.messaging.webpubsub.client.implementation.models.ConnectedMessage;
import com.azure.messaging.webpubsub.client.implementation.models.WebPubSubMessage;
import com.azure.messaging.webpubsub.client.implementation.models.WebPubSubMessageAck;
import com.azure.messaging.webpubsub.client.implementation.websocket.SendResult;
import com.azure.messaging.webpubsub.client.implementation.websocket.WebSocketClient;
import com.azure.messaging.webpubsub.client.implementation.websocket.WebSocketSession;
import com.azure.messaging.webpubsub.client.models.ConnectFailedException;
import com.azure.messaging.webpubsub.client.models.ConnectedEvent;
import com.azure.messaging.webpubsub.client.models.WebPubSubResult;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Schedulers;

import java.time.Duration;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;

public class MockClientTests {
Expand Down Expand Up @@ -71,6 +76,32 @@ public void testConnect() throws InterruptedException {
Assertions.assertEquals(1, events.size());
}

@Test
public void testGeneratedAckIdStartsAtOne() {
List<Long> ackIds = new ArrayList<>();
AtomicReference<Consumer<WebPubSubMessage>> messageHandlerReference = new AtomicReference<>();

WebSocketClient mockWsClient = (cec, path, loggerReference, messageHandler, openHandler, closeHandler) -> {
messageHandlerReference.set(messageHandler);
WebSocketSession mockWsSession = new MockWebSocketSession(true, messageHandlerReference, ackIds);
openHandler.accept(mockWsSession);
messageHandler.accept(new ConnectedMessage("mock_connection_id"));
return mockWsSession;
};

WebPubSubClientBuilder builder = new WebPubSubClientBuilder();
builder.webSocketClient = mockWsClient;
WebPubSubClient client = builder.clientAccessUrl("mock").buildClient();

client.start();
WebPubSubResult joinResult = client.joinGroup("group");
WebPubSubResult sendResult = client.sendToGroup("group", "message");

Assertions.assertEquals(1L, joinResult.getAckId());
Assertions.assertEquals(2L, sendResult.getAckId());
Assertions.assertIterableEquals(Arrays.asList(1L, 2L), ackIds);
}

private static void sendConnectedEvent(Consumer<WebPubSubMessage> messageHandler) {
Mono.delay(SMALL_DELAY)
.then(Mono.fromRunnable(() -> messageHandler.accept(new ConnectedMessage("mock_connection_id")))
Expand All @@ -79,14 +110,34 @@ private static void sendConnectedEvent(Consumer<WebPubSubMessage> messageHandler
}

private static final class MockWebSocketSession implements WebSocketSession {
private final boolean open;
private final AtomicReference<Consumer<WebPubSubMessage>> messageHandlerReference;
private final List<Long> ackIds;

private MockWebSocketSession() {
this(false, null, null);
}

private MockWebSocketSession(boolean open, AtomicReference<Consumer<WebPubSubMessage>> messageHandlerReference,
List<Long> ackIds) {
this.open = open;
this.messageHandlerReference = messageHandlerReference;
this.ackIds = ackIds;
}

@Override
public boolean isOpen() {
return false;
return open;
}

@Override
public void sendObjectAsync(Object data, Consumer<SendResult> handler) {

if (data instanceof WebPubSubMessageAck) {
long ackId = ((WebPubSubMessageAck) data).getAckId();
ackIds.add(ackId);
messageHandlerReference.get().accept(new AckMessage().setAckId(ackId).setSuccess(true));
}
handler.accept(new SendResult());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,31 @@
*/
public class TestBase extends TestProxyTestBase {

private static volatile WebPubSubServiceClient serviceClient;

private static WebPubSubServiceClient getServiceClient() {
if (serviceClient == null) {
synchronized (TestBase.class) {
if (serviceClient == null) {
Configuration configuration = Configuration.getGlobalConfiguration();

serviceClient
= new WebPubSubServiceClientBuilder().endpoint(configuration.get("WEB_PUB_SUB_ENDPOINT"))
.credential(TestUtils.getIdentityTestCredential(TestMode.LIVE))
.hub("hub1")
.buildClient();
}
}
}
return serviceClient;
}

protected static WebPubSubClientBuilder getClientBuilder() {
return getClientBuilder("user1");
}

protected static WebPubSubClientBuilder getClientBuilder(String userId) {
WebPubSubServiceClient client = new WebPubSubServiceClientBuilder()
.endpoint(Configuration.getGlobalConfiguration().get("WEB_PUB_SUB_ENDPOINT"))
.credential(TestUtils.getIdentityTestCredential(TestMode.LIVE))
.hub("hub1")
.buildClient();
WebPubSubServiceClient client = getServiceClient();

// client builder
return new WebPubSubClientBuilder().credential(new WebPubSubClientCredential(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,7 @@ private static TokenCredential getIdentityTestCredentialHelper() {
Configuration config = Configuration.getGlobalConfiguration();

ChainedTokenCredentialBuilder builder
= new ChainedTokenCredentialBuilder().addLast(new EnvironmentCredentialBuilder().build())
.addLast(new AzureCliCredentialBuilder().build())
.addLast(new AzureDeveloperCliCredentialBuilder().build());
= new ChainedTokenCredentialBuilder().addLast(new EnvironmentCredentialBuilder().build());

String serviceConnectionId = config.get("AZURESUBSCRIPTION_SERVICE_CONNECTION_ID");
String clientId = config.get("AZURESUBSCRIPTION_CLIENT_ID");
Expand All @@ -59,7 +57,9 @@ private static TokenCredential getIdentityTestCredentialHelper() {
builder.addLast(trc -> azurePipelinesCredential.getToken(trc).subscribeOn(Schedulers.boundedElastic()));
}

builder.addLast(new AzurePowerShellCredentialBuilder().build());
builder.addLast(new AzurePowerShellCredentialBuilder().build())
.addLast(new AzureCliCredentialBuilder().build())
.addLast(new AzureDeveloperCliCredentialBuilder().build());

return builder.build();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,12 @@ final class TestUtils {

static String getSocketIOEndpoint() {
return Configuration.getGlobalConfiguration()
.get("WEB_PUB_SUB_SOCKETIO_ENDPOINT", "http://testsocketioendpoint.webpubsubdev.azure.com");
.get("WEB_PUB_SUB_SOCKETIO_ENDPOINT", "https://testsocketioendpoint.webpubsubdev.azure.com");
}

static String getEndpoint() {
return Configuration.getGlobalConfiguration()
.get("WEB_PUB_SUB_ENDPOINT", "http://testendpoint.webpubsubdev.azure.com");
.get("WEB_PUB_SUB_ENDPOINT", "https://testendpoint.webpubsubdev.azure.com");
}

static String getConnectionString() {
Expand Down
Loading