From 4c6397cee8339eed124dc82effc684be39b7669f Mon Sep 17 00:00:00 2001 From: ndr_brt Date: Wed, 3 Jun 2026 12:29:11 +0200 Subject: [PATCH] feat: make TransferProcessProtocolService call fail when update fails --- .../TransferProcessProtocolServiceImpl.java | 151 ++++++++---------- ...ransferProcessProtocolServiceImplTest.java | 38 ++++- 2 files changed, 101 insertions(+), 88 deletions(-) diff --git a/core/control-plane/control-plane-aggregate-services/src/main/java/org/eclipse/edc/connector/controlplane/services/transferprocess/TransferProcessProtocolServiceImpl.java b/core/control-plane/control-plane-aggregate-services/src/main/java/org/eclipse/edc/connector/controlplane/services/transferprocess/TransferProcessProtocolServiceImpl.java index 7f1ad18538b..b4d09893eac 100644 --- a/core/control-plane/control-plane-aggregate-services/src/main/java/org/eclipse/edc/connector/controlplane/services/transferprocess/TransferProcessProtocolServiceImpl.java +++ b/core/control-plane/control-plane-aggregate-services/src/main/java/org/eclipse/edc/connector/controlplane/services/transferprocess/TransferProcessProtocolServiceImpl.java @@ -102,7 +102,6 @@ public TransferProcessProtocolServiceImpl(TransferProcessStore transferProcessSt public ServiceResult notifyRequested(ParticipantContext participantContext, TransferRequestMessage message, TokenRepresentation tokenRepresentation) { return transactionContext.execute(() -> fetchContractAgreement(participantContext, message) .compose(contractAgreement -> verifyRequest(participantContext, tokenRepresentation, message, contractAgreement)) - .compose(context -> validateRequestMessage(message, context)) .compose(context -> requestedAction(participantContext, message, context))); } @@ -111,9 +110,8 @@ public ServiceResult notifyRequested(ParticipantContext partici @NotNull public ServiceResult notifyStarted(ParticipantContext participantContext, TransferStartMessage message, TokenRepresentation tokenRepresentation) { return transactionContext.execute(() -> fetchRequestContext(participantContext, message.getProcessId()) - .compose(context -> verifyRequest(participantContext, tokenRepresentation, message, context.agreement()) - .compose(this::validateCounterParty)) - .compose(context -> onMessageDo(participantContext, message, transferProcess -> startedAction(message, transferProcess))) + .compose(context -> verifyRequest(participantContext, tokenRepresentation, message, context.agreement())) + .compose(context -> onMessageDo(message, transferProcess -> startedAction(message, transferProcess))) ); } @@ -122,9 +120,8 @@ public ServiceResult notifyStarted(ParticipantContext participa @NotNull public ServiceResult notifyCompleted(ParticipantContext participantContext, TransferCompletionMessage message, TokenRepresentation tokenRepresentation) { return transactionContext.execute(() -> fetchRequestContext(participantContext, message.getProcessId()) - .compose(context -> verifyRequest(participantContext, tokenRepresentation, message, context.agreement()) - .compose(this::validateCounterParty)) - .compose(i -> onMessageDo(participantContext, message, transferProcess -> completedAction(message, transferProcess))) + .compose(context -> verifyRequest(participantContext, tokenRepresentation, message, context.agreement())) + .compose(i -> onMessageDo(message, transferProcess -> completedAction(message, transferProcess))) ); } @@ -133,9 +130,8 @@ public ServiceResult notifyCompleted(ParticipantContext partici @NotNull public ServiceResult notifySuspended(ParticipantContext participantContext, TransferSuspensionMessage message, TokenRepresentation tokenRepresentation) { return transactionContext.execute(() -> fetchRequestContext(participantContext, message.getProcessId()) - .compose(context -> verifyRequest(participantContext, tokenRepresentation, message, context.agreement()) - .compose(this::validateCounterParty)) - .compose(i -> onMessageDo(participantContext, message, transferProcess -> suspendedAction(message, transferProcess))) + .compose(context -> verifyRequest(participantContext, tokenRepresentation, message, context.agreement())) + .compose(i -> onMessageDo(message, transferProcess -> suspendedAction(message, transferProcess))) ); } @@ -144,9 +140,8 @@ public ServiceResult notifySuspended(ParticipantContext partici @NotNull public ServiceResult notifyTerminated(ParticipantContext participantContext, TransferTerminationMessage message, TokenRepresentation tokenRepresentation) { return transactionContext.execute(() -> fetchRequestContext(participantContext, message.getProcessId()) - .compose(context -> verifyRequest(participantContext, tokenRepresentation, message, context.agreement()) - .compose(this::validateCounterParty)) - .compose(i -> onMessageDo(participantContext, message, transferProcess -> terminatedAction(message, transferProcess))) + .compose(context -> verifyRequest(participantContext, tokenRepresentation, message, context.agreement())) + .compose(i -> onMessageDo(message, transferProcess -> terminatedAction(message, transferProcess))) ); } @@ -157,30 +152,49 @@ public ServiceResult findById(ParticipantContext participantCon return transactionContext.execute(() -> fetchRequestContext(participantContext, message.getTransferProcessId()) .compose(context -> verifyRequest(participantContext, tokenRepresentation, message, context.agreement()) - .compose(this::validateCounterParty) .map(it -> context.transferProcess())) ); } @NotNull - private ServiceResult requestedAction(ParticipantContext participantContext, TransferRequestMessage message, ClaimTokenContext claimTokenContext) { + private ServiceResult requestedAction(ParticipantContext participantContext, TransferRequestMessage message, ClaimTokenContext context) { + var destination = message.getDataAddress(); + if (destination != null) { + var validDestination = dataAddressValidator.validateDestination(destination); + if (validDestination.failed()) { + return ServiceResult.badRequest(validDestination.getFailureMessages()); + } + } + + var transferType = message.getTransferType(); + var supportedTransferTypes = dataFlowController.transferTypesFor(context.agreement().getAssetId()); + if (!supportedTransferTypes.contains(transferType)) { + return ServiceResult.badRequest("TransferType %s is not supported".formatted(transferType)); + } + + var validationResult = contractValidationService.validateAgreement(context.participantAgent(), context.agreement()); + if (validationResult.failed()) { + return ServiceResult.conflict(format("Cannot process %s because %s", message.getClass().getSimpleName(), "agreement not found or not valid")); + } + var existingTransferProcess = transferProcessStore.findForCorrelationId(message.getConsumerPid()); if (existingTransferProcess != null) { return ServiceResult.success(existingTransferProcess); } - return transferProcessProviderFactory.create(participantContext, message, claimTokenContext.agreement(), claimTokenContext.participantAgent()) + return transferProcessProviderFactory.create(participantContext, message, context.agreement(), context.participantAgent()) .compose(process -> { var dataAddressStorage = message.getDataAddress() == null ? StoreResult.success() : dataAddressStore.store(message.getDataAddress(), process); - return dataAddressStorage.flatMap(ServiceResult::from) - .onSuccess(ignored -> { + return dataAddressStorage + .compose(i -> { process.protocolMessageReceived(message.getId()); - update(process); - observable.invokeForEach(l -> l.initiated(process)); + return update(process); }) + .onSuccess(i -> observable.invokeForEach(l -> l.initiated(process))) + .flatMap(ServiceResult::from) .map(ignored -> process); }); } @@ -276,29 +290,6 @@ private ServiceResult terminatedAction(TransferTerminationMessa } } - private ServiceResult validateRequestMessage(TransferRequestMessage message, ClaimTokenContext context) { - var destination = message.getDataAddress(); - if (destination != null) { - var validDestination = dataAddressValidator.validateDestination(destination); - if (validDestination.failed()) { - return ServiceResult.badRequest(validDestination.getFailureMessages()); - } - } - - var transferType = message.getTransferType(); - var supportedTransferTypes = dataFlowController.transferTypesFor(context.agreement().getAssetId()); - if (!supportedTransferTypes.contains(transferType)) { - return ServiceResult.badRequest("TransferType %s is not supported".formatted(transferType)); - } - - var validationResult = contractValidationService.validateAgreement(context.participantAgent(), context.agreement()); - if (validationResult.failed()) { - return ServiceResult.conflict(format("Cannot process %s because %s", message.getClass().getSimpleName(), "agreement not found or not valid")); - } - - return ServiceResult.success(context); - } - private ServiceResult fetchContractAgreement(ParticipantContext participantContext, TransferRequestMessage message) { return Optional.ofNullable(findAgreement(participantContext, message.getContractId())) .filter(agreement -> participantContext.getParticipantContextId().equals(agreement.getParticipantContextId())) @@ -307,14 +298,20 @@ private ServiceResult fetchContractAgreement(ParticipantConte } private ServiceResult fetchRequestContext(ParticipantContext participantContext, String transferProcessId) { - return findTransferProcessByIdReadOnly(participantContext, transferProcessId) - .compose(transferProcess -> { - var agreement = negotiationStore.findContractAgreement(transferProcess.getContractId()); - if (agreement == null) { - return ServiceResult.notFound(format("No transfer process with id %s found", transferProcess.getId())); - } - return ServiceResult.success(new TransferMessageContext(agreement, transferProcess)); - }); + var transferProcess = transferProcessStore.findById(transferProcessId); + if (transferProcess == null) { + return notFound(transferProcessId); + } + + if (!participantContext.getParticipantContextId().equals(transferProcess.getParticipantContextId())) { + return notFound(transferProcess.getId()); + } + + var agreement = negotiationStore.findContractAgreement(transferProcess.getContractId()); + if (agreement == null) { + return ServiceResult.notFound(format("No transfer process with id %s found", transferProcess.getId())); + } + return ServiceResult.success(new TransferMessageContext(agreement, transferProcess)); } private ServiceResult verifyRequest(ParticipantContext participantContext, TokenRepresentation tokenRepresentation, RemoteMessage message, ContractAgreement contractAgreement) { @@ -322,46 +319,32 @@ private ServiceResult verifyRequest(ParticipantContext partic if (result.failed()) { monitor.debug(() -> "Verification Failed: %s".formatted(result.getFailureDetail())); return ServiceResult.notFound("Not found"); - } else { - return ServiceResult.success(new ClaimTokenContext(result.getContent(), contractAgreement)); } - } - private ServiceResult onMessageDo(ParticipantContext participantContext, TransferRemoteMessage message, - Function> action) { - return findAndLease(participantContext, message) - .compose(transferProcess -> { - if (transferProcess.shouldIgnoreIncomingMessage(message.getId())) { - return transferProcessStore.breakLease(transferProcess).flatMap(ServiceResult::from).map(it -> transferProcess); - } else { - return action.apply(transferProcess).onFailure(f -> transferProcessStore.breakLease(transferProcess)); - } - }); - } - - private ServiceResult validateCounterParty(ClaimTokenContext claimTokenContext) { - var validation = contractValidationService.validateRequest(claimTokenContext.participantAgent(), claimTokenContext.agreement()); + var participantAgent = result.getContent(); + var validation = contractValidationService.validateRequest(participantAgent, contractAgreement); if (validation.failed()) { return ServiceResult.badRequest(validation.getFailureMessages()); } - return ServiceResult.success(); + return ServiceResult.success(new ClaimTokenContext(participantAgent, contractAgreement)); } - // find and lease - write access - private ServiceResult findAndLease(ParticipantContext participantContext, TransferRemoteMessage remoteMessage) { + private ServiceResult onMessageDo(TransferRemoteMessage message, + Function> action) { + return transferProcessStore - .findByIdAndLease(remoteMessage.getProcessId()) + .findByIdAndLease(message.getProcessId()) .flatMap(ServiceResult::from) - .compose(tp -> filterByParticipantContext(participantContext, tp)); - } + .compose(transferProcess -> { + if (transferProcess.shouldIgnoreIncomingMessage(message.getId())) { + transferProcessStore.breakLease(transferProcess); + return ServiceResult.success(transferProcess); + } - private ServiceResult filterByParticipantContext(ParticipantContext participantContext, TransferProcess transferProcess) { - if (participantContext.getParticipantContextId().equals(transferProcess.getParticipantContextId())) { - return ServiceResult.success(transferProcess); - } else { - return notFound(transferProcess.getId()); - } + return action.apply(transferProcess) + .onFailure(f -> transferProcessStore.breakLease(transferProcess)); + }); } private ContractAgreement findAgreement(ParticipantContext participantContext, String contractId) { @@ -373,13 +356,7 @@ private ContractAgreement findAgreement(ParticipantContext participantContext, S } } - private ServiceResult findTransferProcessByIdReadOnly(ParticipantContext participantContext, String id) { - return Optional.ofNullable(transferProcessStore.findById(id)) - .map(tp -> filterByParticipantContext(participantContext, tp)) - .orElseGet(() -> notFound(id)); - } - - private ServiceResult notFound(String transferProcessId) { + private ServiceResult notFound(String transferProcessId) { return ServiceResult.notFound(format("No transfer process with id %s found", transferProcessId)); } diff --git a/core/control-plane/control-plane-aggregate-services/src/test/java/org/eclipse/edc/connector/controlplane/services/transferprocess/TransferProcessProtocolServiceImplTest.java b/core/control-plane/control-plane-aggregate-services/src/test/java/org/eclipse/edc/connector/controlplane/services/transferprocess/TransferProcessProtocolServiceImplTest.java index 6c7969d9ec0..eb8eb4931d8 100644 --- a/core/control-plane/control-plane-aggregate-services/src/test/java/org/eclipse/edc/connector/controlplane/services/transferprocess/TransferProcessProtocolServiceImplTest.java +++ b/core/control-plane/control-plane-aggregate-services/src/test/java/org/eclipse/edc/connector/controlplane/services/transferprocess/TransferProcessProtocolServiceImplTest.java @@ -337,6 +337,7 @@ void validAgreement_shouldInitiateTransfer() { .build(); when(protocolTokenValidator.verify(eq(participantContext), eq(tokenRepresentation), any(), any(), eq(message))).thenReturn(ServiceResult.success(participantAgent)); + when(validationService.validateRequest(any(), isA(ContractAgreement.class))).thenReturn(Result.success()); when(negotiationStore.queryAgreements(any())).thenReturn(Stream.of(contractAgreement())); when(validationService.validateAgreement(any(ParticipantAgent.class), any())).thenReturn(Result.success(null)); when(dataAddressValidator.validateDestination(any())).thenReturn(ValidationResult.success()); @@ -354,6 +355,36 @@ void validAgreement_shouldInitiateTransfer() { verify(transactionContext, atLeastOnce()).execute(any(TransactionContext.ResultTransactionBlock.class)); } + @Test + void shouldFail_whenInvalidRequest() { + var participantAgent = participantAgent(); + var tokenRepresentation = tokenRepresentation(); + var dataAddress = DataAddress.Builder.newInstance().type("any").build(); + var message = TransferRequestMessage.Builder.newInstance() + .consumerPid("consumerPid") + .processId("consumerPid") + .contractId("agreementId") + .protocol("protocol") + .callbackAddress("http://any") + .transferType("transferType") + .dataAddress(dataAddress) + .build(); + + when(protocolTokenValidator.verify(eq(participantContext), eq(tokenRepresentation), any(), any(), eq(message))).thenReturn(ServiceResult.success(participantAgent)); + when(validationService.validateRequest(any(), isA(ContractAgreement.class))).thenReturn(Result.failure("invalid credentials")); + when(negotiationStore.queryAgreements(any())).thenReturn(Stream.of(contractAgreement())); + when(validationService.validateAgreement(any(ParticipantAgent.class), any())).thenReturn(Result.success(null)); + when(dataAddressValidator.validateDestination(any())).thenReturn(ValidationResult.success()); + var transferProcess = transferProcess(INITIAL, "transferProcessId"); + when(transferProcessProviderFactory.create(any(), any(), any(), any())).thenReturn(ServiceResult.success(transferProcess)); + + var result = service.notifyRequested(participantContext, message, tokenRepresentation); + + assertThat(result).isFailed().extracting(ServiceFailure::getReason).isEqualTo(BAD_REQUEST); + verify(store, never()).save(any()); + verifyNoInteractions(dataAddressStore, listener); + } + @Test void shouldFail_whenDataAddressStorageFails() { var participantAgent = participantAgent(); @@ -370,6 +401,7 @@ void shouldFail_whenDataAddressStorageFails() { .build(); when(protocolTokenValidator.verify(eq(participantContext), eq(tokenRepresentation), any(), any(), eq(message))).thenReturn(ServiceResult.success(participantAgent)); + when(validationService.validateRequest(any(), isA(ContractAgreement.class))).thenReturn(Result.success()); when(negotiationStore.queryAgreements(any())).thenReturn(Stream.of(contractAgreement())); when(validationService.validateAgreement(any(ParticipantAgent.class), any())).thenReturn(Result.success(null)); when(dataAddressValidator.validateDestination(any())).thenReturn(ValidationResult.success()); @@ -398,6 +430,7 @@ void shouldFail_whenTransferProcessCreationFails() { .build(); when(protocolTokenValidator.verify(eq(participantContext), eq(tokenRepresentation), any(), any(), eq(message))).thenReturn(ServiceResult.success(participantAgent)); + when(validationService.validateRequest(any(), isA(ContractAgreement.class))).thenReturn(Result.success()); when(negotiationStore.queryAgreements(any())).thenReturn(Stream.of(contractAgreement())); when(validationService.validateAgreement(any(ParticipantAgent.class), any())).thenReturn(Result.success(null)); when(dataAddressValidator.validateDestination(any())).thenReturn(ValidationResult.success()); @@ -428,6 +461,7 @@ void doNothingIfProcessAlreadyExist() { var tokenRepresentation = tokenRepresentation(); when(protocolTokenValidator.verify(eq(participantContext), eq(tokenRepresentation), any(), any(), eq(message))).thenReturn(ServiceResult.success(participantAgent)); + when(validationService.validateRequest(any(), isA(ContractAgreement.class))).thenReturn(Result.success()); when(negotiationStore.queryAgreements(any())).thenReturn(Stream.of(contractAgreement())); when(validationService.validateAgreement(any(ParticipantAgent.class), any())).thenReturn(Result.success(null)); when(dataAddressValidator.validateDestination(any())).thenReturn(ValidationResult.success()); @@ -455,6 +489,7 @@ void invalidAgreement_shouldNotInitiateTransfer() { var tokenRepresentation = tokenRepresentation(); when(protocolTokenValidator.verify(eq(participantContext), eq(tokenRepresentation), any(), any(), eq(message))).thenReturn(ServiceResult.success(participantAgent)); + when(validationService.validateRequest(any(), isA(ContractAgreement.class))).thenReturn(Result.success()); when(negotiationStore.queryAgreements(any())).thenReturn(Stream.of(contractAgreement())); when(validationService.validateAgreement(any(ParticipantAgent.class), any())).thenReturn(Result.failure("error")); when(dataAddressValidator.validateDestination(any())).thenReturn(ValidationResult.success()); @@ -481,6 +516,7 @@ void invalidDestination_shouldNotInitiateTransfer() { when(negotiationStore.queryAgreements(any())).thenReturn(Stream.of(contractAgreement())); when(protocolTokenValidator.verify(eq(participantContext), eq(tokenRepresentation), any(), any(), eq(message))).thenReturn(ServiceResult.success(participantAgent)); + when(validationService.validateRequest(any(), isA(ContractAgreement.class))).thenReturn(Result.success()); when(dataAddressValidator.validateDestination(any())).thenReturn(ValidationResult.failure(violation("invalid data address", "path"))); var result = service.notifyRequested(participantContext, message, tokenRepresentation); @@ -505,6 +541,7 @@ void shouldReturnBadRequest_whenTransferTypeNotSupported() { var contractAgreement = contractAgreementBuilder().assetId("assetId").build(); when(protocolTokenValidator.verify(eq(participantContext), eq(tokenRepresentation), any(), any(), eq(message))).thenReturn(ServiceResult.success(participantAgent)); + when(validationService.validateRequest(any(), isA(ContractAgreement.class))).thenReturn(Result.success()); when(negotiationStore.queryAgreements(any())).thenReturn(Stream.of(contractAgreement)); when(validationService.validateAgreement(any(ParticipantAgent.class), any())).thenReturn(Result.success(contractAgreement)); when(dataAddressValidator.validateDestination(any())).thenReturn(ValidationResult.success()); @@ -1161,7 +1198,6 @@ void notify_shouldReturnConflict_whenFinalState assertThat(result).isFailed().satisfies(failure -> { assertThat(failure.getReason()).isEqualTo(CONFLICT); }); - verify(store).breakLease(transferProcess); verifyNoInteractions(listener); }