From 238b60712d7552e7aba79fe40011c418ffa1012e Mon Sep 17 00:00:00 2001 From: Jakub Zerko Date: Mon, 25 May 2026 10:58:18 +0200 Subject: [PATCH 1/2] chore: add Metro build infrastructure --- build-logic/plugins/build.gradle.kts | 1 + .../src/main/kotlin/AndroidApplicationConventionPlugin.kt | 1 + .../plugins/src/main/kotlin/AndroidLibraryConventionPlugin.kt | 1 + build.gradle.kts | 1 + gradle/libs.versions.toml | 3 +++ 5 files changed, 7 insertions(+) diff --git a/build-logic/plugins/build.gradle.kts b/build-logic/plugins/build.gradle.kts index da308216ea4..8d6465be2fc 100644 --- a/build-logic/plugins/build.gradle.kts +++ b/build-logic/plugins/build.gradle.kts @@ -38,6 +38,7 @@ dependencies { compileOnly(libs.android.gradlePlugin) compileOnly(libs.kotlin.gradlePlugin) compileOnly(libs.kover.gradlePlugin) + compileOnly(libs.metro.gradlePlugin) testImplementation(libs.junit4) testImplementation(kotlin("test")) diff --git a/build-logic/plugins/src/main/kotlin/AndroidApplicationConventionPlugin.kt b/build-logic/plugins/src/main/kotlin/AndroidApplicationConventionPlugin.kt index 380fe40ba8f..7aa01ba2779 100644 --- a/build-logic/plugins/src/main/kotlin/AndroidApplicationConventionPlugin.kt +++ b/build-logic/plugins/src/main/kotlin/AndroidApplicationConventionPlugin.kt @@ -29,6 +29,7 @@ class AndroidApplicationConventionPlugin : Plugin { override fun apply(target: Project): Unit = with(target) { with(pluginManager) { apply("com.android.application") + apply("dev.zacsweers.metro") } extensions.configure { diff --git a/build-logic/plugins/src/main/kotlin/AndroidLibraryConventionPlugin.kt b/build-logic/plugins/src/main/kotlin/AndroidLibraryConventionPlugin.kt index f35a7bd3c84..d87aeb7bd7a 100644 --- a/build-logic/plugins/src/main/kotlin/AndroidLibraryConventionPlugin.kt +++ b/build-logic/plugins/src/main/kotlin/AndroidLibraryConventionPlugin.kt @@ -28,6 +28,7 @@ class AndroidLibraryConventionPlugin : Plugin { override fun apply(target: Project): Unit = with(target) { with(pluginManager) { apply("com.android.library") + apply("dev.zacsweers.metro") } extensions.configure { diff --git a/build.gradle.kts b/build.gradle.kts index d5fbf8590b1..da264968690 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -58,6 +58,7 @@ allprojects { plugins { id(ScriptPlugins.infrastructure) alias(libs.plugins.ksp) apply false // https://github.com/google/dagger/issues/3965 + alias(libs.plugins.metro) apply false alias(libs.plugins.compose.compiler) apply false alias(libs.plugins.cyclonedx) } diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index 362f88416ba..86cc322328f 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -24,6 +24,7 @@ desugaring = "2.1.5" firebaseBOM = "34.7.0" fragment = "1.5.6" resaca = "5.0.2" +metro = "1.1.1" bundlizer = "0.8.0" squareup-javapoet = "1.13.0" visibilityModifiers = "1.1.0" @@ -133,6 +134,7 @@ kotlin-android = { id = "org.jetbrains.kotlin.android", version.ref = "kotlin" } kotlin-serialization = { id = "org.jetbrains.kotlin.plugin.serialization", version.ref = "kotlin" } aboutLibraries = { id = "com.mikepenz.aboutlibraries.plugin.android", version.ref = "aboutLibraries" } ksp = { id = "com.google.devtools.ksp", version.ref = "ksp" } +metro = { id = "dev.zacsweers.metro", version.ref = "metro" } screenshot = { id = "com.android.compose.screenshot", version.ref = "screenshot"} compose-compiler = { id = "org.jetbrains.kotlin.plugin.compose", version.ref = "kotlin" } cyclonedx = { id = "org.cyclonedx.bom", version.ref = "cyclonedx" } @@ -163,6 +165,7 @@ googleGms-gradlePlugin = { module = "com.google.gms:google-services", version.re googleGms-location = { module = "com.google.android.gms:play-services-location", version.ref = "gms-location" } aboutLibraries-gradlePlugin = { module = "com.mikepenz.aboutlibraries.plugin:aboutlibraries-plugin", version.ref = "aboutLibraries" } kover-gradlePlugin = { module = "org.jetbrains.kotlinx:kover-gradle-plugin", version.ref = "kover" } +metro-gradlePlugin = { module = "dev.zacsweers.metro:dev.zacsweers.metro.gradle.plugin", version.ref = "metro" } ktx-serialization = { module = "org.jetbrains.kotlinx:kotlinx-serialization-json", version.ref = "ktx-serialization" } ktx-dateTime = { module = "org.jetbrains.kotlinx:kotlinx-datetime", version.ref = "ktx-dateTime" } ktx-immutableCollections = { module = "org.jetbrains.kotlinx:kotlinx-collections-immutable", version.ref = "ktx-immutableCollections" } From c6d607b0501da506c921f34c662a2ff3d004971f Mon Sep 17 00:00:00 2001 From: Jakub Zerko Date: Mon, 25 May 2026 15:05:26 +0200 Subject: [PATCH 2/2] refactor: prepare Android ViewModel creation boundary --- .../wire/android/di/AssistedViewModelExt.kt | 25 ++++ .../wire/android/navigation/MainNavHost.kt | 10 +- .../login/LoginSavedInputStore.kt | 24 ++++ .../ui/authentication/login/LoginScreen.kt | 16 ++- .../ui/authentication/login/LoginViewModel.kt | 25 +--- .../login/SavedStateLoginSavedInputStore.kt | 52 +++++++ .../email/LoginEmailVerificationCodeScreen.kt | 3 +- .../login/email/LoginEmailViewModel.kt | 28 ++-- .../login/sso/LoginSSOScreen.kt | 25 ++-- .../login/sso/LoginSSOViewModel.kt | 135 +++++++++++++----- .../login/password/NewLoginPasswordScreen.kt | 6 +- .../ui/authentication/LoginViewModelTest.kt | 12 +- .../SavedStateLoginSavedInputStoreTest.kt | 44 ++++++ .../login/email/LoginEmailViewModelTest.kt | 15 +- .../login/sso/LoginSSOViewModelTest.kt | 106 ++++++++++---- 15 files changed, 391 insertions(+), 135 deletions(-) create mode 100644 app/src/main/kotlin/com/wire/android/ui/authentication/login/LoginSavedInputStore.kt create mode 100644 app/src/main/kotlin/com/wire/android/ui/authentication/login/SavedStateLoginSavedInputStore.kt create mode 100644 app/src/test/kotlin/com/wire/android/ui/authentication/login/SavedStateLoginSavedInputStoreTest.kt diff --git a/app/src/main/kotlin/com/wire/android/di/AssistedViewModelExt.kt b/app/src/main/kotlin/com/wire/android/di/AssistedViewModelExt.kt index f239779156e..c3c11f867e6 100644 --- a/app/src/main/kotlin/com/wire/android/di/AssistedViewModelExt.kt +++ b/app/src/main/kotlin/com/wire/android/di/AssistedViewModelExt.kt @@ -19,7 +19,11 @@ package com.wire.android.di import androidx.activity.ComponentActivity import androidx.activity.viewModels +import androidx.compose.runtime.Composable +import androidx.hilt.navigation.compose.hiltViewModel import androidx.lifecycle.ViewModel +import androidx.lifecycle.ViewModelStoreOwner +import androidx.lifecycle.viewmodel.compose.LocalViewModelStoreOwner import dagger.hilt.android.lifecycle.withCreationCallback inline fun ComponentActivity.assistedViewModels( @@ -27,3 +31,24 @@ inline fun ComponentActivity.assistedViewM ) = viewModels(extrasProducer = { defaultViewModelCreationExtras.withCreationCallback { factory -> create(factory) } }) + +@Composable +inline fun wireViewModel( + viewModelStoreOwner: ViewModelStoreOwner = checkNotNull(LocalViewModelStoreOwner.current) { + "No ViewModelStoreOwner was provided via LocalViewModelStoreOwner" + }, + key: String? = null, +): VM = hiltViewModel(viewModelStoreOwner = viewModelStoreOwner, key = key) + +@Composable +inline fun wireViewModel( + viewModelStoreOwner: ViewModelStoreOwner = checkNotNull(LocalViewModelStoreOwner.current) { + "No ViewModelStoreOwner was provided via LocalViewModelStoreOwner" + }, + key: String? = null, + noinline creationCallback: (VMF) -> VM +): VM = hiltViewModel( + viewModelStoreOwner = viewModelStoreOwner, + key = key, + creationCallback = creationCallback +) diff --git a/app/src/main/kotlin/com/wire/android/navigation/MainNavHost.kt b/app/src/main/kotlin/com/wire/android/navigation/MainNavHost.kt index dfe770da4f5..593ebadf21b 100644 --- a/app/src/main/kotlin/com/wire/android/navigation/MainNavHost.kt +++ b/app/src/main/kotlin/com/wire/android/navigation/MainNavHost.kt @@ -31,6 +31,7 @@ import com.ramcosta.composedestinations.DestinationsNavHost import com.ramcosta.composedestinations.generated.app.destinations.ConversationScreenDestination import com.ramcosta.composedestinations.generated.app.destinations.NewLoginPasswordScreenDestination import com.ramcosta.composedestinations.generated.app.destinations.NewLoginVerificationCodeScreenDestination +import com.ramcosta.composedestinations.generated.app.navArgs import com.ramcosta.composedestinations.generated.app.navgraphs.NewConversationGraph import com.ramcosta.composedestinations.generated.app.navgraphs.PersonalToTeamMigrationGraph import com.ramcosta.composedestinations.generated.cells.destinations.SearchScreenDestination @@ -49,7 +50,9 @@ import com.ramcosta.composedestinations.scope.resultRecipient import com.ramcosta.composedestinations.spec.Direction import com.wire.android.feature.cells.ui.CellViewModel import com.wire.android.feature.sketch.model.DrawingCanvasNavBackArgs +import com.wire.android.di.wireViewModel import com.wire.android.navigation.transition.LocalSharedTransitionScope +import com.wire.android.ui.authentication.login.LoginNavArgs import com.wire.android.ui.authentication.login.email.LoginEmailViewModel import com.wire.android.ui.home.conversations.ConversationScreen import com.wire.android.ui.home.newconversation.NewConversationViewModel @@ -94,7 +97,12 @@ fun MainNavHost( val loginPasswordEntry = remember(navBackStackEntry) { navController.getBackStackEntry(NewLoginPasswordScreenDestination.route) } - dependency(hiltViewModel(loginPasswordEntry)) + dependency( + wireViewModel( + viewModelStoreOwner = loginPasswordEntry, + creationCallback = { factory -> factory.create(loginPasswordEntry.navArgs()) } + ) + ) } // 👇 To reuse CellViewModel from the parent screen on SearchScreen diff --git a/app/src/main/kotlin/com/wire/android/ui/authentication/login/LoginSavedInputStore.kt b/app/src/main/kotlin/com/wire/android/ui/authentication/login/LoginSavedInputStore.kt new file mode 100644 index 00000000000..c64aab83185 --- /dev/null +++ b/app/src/main/kotlin/com/wire/android/ui/authentication/login/LoginSavedInputStore.kt @@ -0,0 +1,24 @@ +/* + * Wire + * Copyright (C) 2026 Wire Swiss GmbH + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see http://www.gnu.org/licenses/. + */ + +package com.wire.android.ui.authentication.login + +interface LoginSavedInputStore { + var userIdentifier: String? + var ssoCode: String? +} diff --git a/app/src/main/kotlin/com/wire/android/ui/authentication/login/LoginScreen.kt b/app/src/main/kotlin/com/wire/android/ui/authentication/login/LoginScreen.kt index 6df72b5a708..7c18dacb721 100644 --- a/app/src/main/kotlin/com/wire/android/ui/authentication/login/LoginScreen.kt +++ b/app/src/main/kotlin/com/wire/android/ui/authentication/login/LoginScreen.kt @@ -44,12 +44,12 @@ import androidx.compose.ui.Modifier import androidx.compose.ui.platform.LocalFocusManager import androidx.compose.ui.platform.LocalSoftwareKeyboardController import androidx.compose.ui.res.stringResource -import androidx.hilt.navigation.compose.hiltViewModel import com.ramcosta.composedestinations.generated.app.destinations.E2EIEnrollmentScreenDestination import com.ramcosta.composedestinations.generated.app.destinations.HomeScreenDestination import com.ramcosta.composedestinations.generated.app.destinations.InitialSyncScreenDestination import com.ramcosta.composedestinations.generated.app.destinations.RemoveDeviceScreenDestination import com.wire.android.R +import com.wire.android.di.wireViewModel import com.wire.android.navigation.BackStackMode import com.wire.android.navigation.NavigationCommand import com.wire.android.navigation.Navigator @@ -86,7 +86,9 @@ import kotlinx.coroutines.launch fun LoginScreen( navigator: Navigator, loginNavArgs: LoginNavArgs, - loginEmailViewModel: LoginEmailViewModel = hiltViewModel() + loginEmailViewModel: LoginEmailViewModel = wireViewModel( + creationCallback = { factory -> factory.create(loginNavArgs) } + ) ) { LoginContent( @@ -105,6 +107,7 @@ fun LoginScreen( onRemoveDeviceNeeded = { navigator.navigate(NavigationCommand(RemoveDeviceScreenDestination, BackStackMode.CLEAR_WHOLE)) }, + loginNavArgs = loginNavArgs, loginEmailViewModel = loginEmailViewModel, ssoLoginResult = loginNavArgs.ssoLoginResult, ssoCodeAutoLogin = loginNavArgs.ssoCodeAutoLogin @@ -116,6 +119,7 @@ private fun LoginContent( onBackPressed: () -> Unit, onSuccess: (initialSyncCompleted: Boolean, isE2EIRequired: Boolean) -> Unit, onRemoveDeviceNeeded: () -> Unit, + loginNavArgs: LoginNavArgs, loginEmailViewModel: LoginEmailViewModel, ssoLoginResult: DeepLinkResult.SSOLogin?, ssoCodeAutoLogin: SSOCodeAutoLogin?, @@ -139,6 +143,7 @@ private fun LoginContent( onBackPressed = onBackPressed, onSuccess = onSuccess, onRemoveDeviceNeeded = onRemoveDeviceNeeded, + loginNavArgs = loginNavArgs, loginEmailViewModel = loginEmailViewModel, ssoLoginResult = ssoLoginResult, ssoCodeAutoLogin = ssoCodeAutoLogin @@ -154,6 +159,7 @@ private fun MainLoginContent( onBackPressed: () -> Unit, onSuccess: (initialSyncCompleted: Boolean, isE2EIRequired: Boolean) -> Unit, onRemoveDeviceNeeded: () -> Unit, + loginNavArgs: LoginNavArgs, loginEmailViewModel: LoginEmailViewModel, ssoLoginResult: DeepLinkResult.SSOLogin?, ssoCodeAutoLogin: SSOCodeAutoLogin?, @@ -233,6 +239,7 @@ private fun MainLoginContent( LoginTabItem.SSO -> LoginSSOScreen( onSuccess, onRemoveDeviceNeeded, + loginNavArgs, ssoLoginResult, ssoCodeAutoLogin, ) @@ -264,7 +271,10 @@ private fun PreviewLoginScreen() = WireTheme { onBackPressed = {}, onSuccess = { _, _ -> }, onRemoveDeviceNeeded = {}, - loginEmailViewModel = hiltViewModel(), + loginNavArgs = LoginNavArgs(), + loginEmailViewModel = wireViewModel( + creationCallback = { factory -> factory.create(LoginNavArgs()) } + ), ssoLoginResult = null, ssoCodeAutoLogin = null ) diff --git a/app/src/main/kotlin/com/wire/android/ui/authentication/login/LoginViewModel.kt b/app/src/main/kotlin/com/wire/android/ui/authentication/login/LoginViewModel.kt index 2feb1ea62ff..a687f0cd20a 100644 --- a/app/src/main/kotlin/com/wire/android/ui/authentication/login/LoginViewModel.kt +++ b/app/src/main/kotlin/com/wire/android/ui/authentication/login/LoginViewModel.kt @@ -18,12 +18,9 @@ package com.wire.android.ui.authentication.login -import androidx.lifecycle.SavedStateHandle import androidx.lifecycle.ViewModel import com.wire.android.datastore.UserDataStoreProvider import com.wire.android.di.ClientScopeProvider -import com.wire.android.di.KaliumCoreLogic -import com.ramcosta.composedestinations.generated.app.navArgs import com.wire.kalium.logic.CoreLogic import com.wire.kalium.logic.configuration.server.ServerConfig import com.wire.kalium.logic.data.client.ClientCapability @@ -32,13 +29,10 @@ import com.wire.kalium.logic.feature.auth.AddAuthenticatedUserUseCase import com.wire.kalium.logic.feature.auth.AuthenticationResult import com.wire.kalium.logic.feature.auth.DomainLookupUseCase import com.wire.kalium.logic.feature.client.RegisterClientResult -import dagger.hilt.android.lifecycle.HiltViewModel -import javax.inject.Inject -@HiltViewModel @Suppress("TooManyFunctions") open class LoginViewModel( - savedStateHandle: SavedStateHandle, + loginNavArgs: LoginNavArgs, val clientScopeProviderFactory: ClientScopeProvider.Factory, val userDataStoreProvider: UserDataStoreProvider, val coreLogic: CoreLogic, @@ -46,23 +40,6 @@ open class LoginViewModel( defaultServerConfig: ServerConfig.Links ) : ViewModel() { - @Inject - constructor( - savedStateHandle: SavedStateHandle, - clientScopeProviderFactory: ClientScopeProvider.Factory, - userDataStoreProvider: UserDataStoreProvider, - @KaliumCoreLogic coreLogic: CoreLogic, - defaultServerConfig: ServerConfig.Links - ) : this( - savedStateHandle, - clientScopeProviderFactory, - userDataStoreProvider, - coreLogic, - LoginViewModelExtension(clientScopeProviderFactory, userDataStoreProvider), - defaultServerConfig - ) - - private val loginNavArgs: LoginNavArgs = savedStateHandle.navArgs() val serverConfig: ServerConfig.Links = loginNavArgs.loginPasswordPath?.customServerConfig ?: defaultServerConfig suspend fun registerClient( diff --git a/app/src/main/kotlin/com/wire/android/ui/authentication/login/SavedStateLoginSavedInputStore.kt b/app/src/main/kotlin/com/wire/android/ui/authentication/login/SavedStateLoginSavedInputStore.kt new file mode 100644 index 00000000000..8e2fa94a810 --- /dev/null +++ b/app/src/main/kotlin/com/wire/android/ui/authentication/login/SavedStateLoginSavedInputStore.kt @@ -0,0 +1,52 @@ +/* + * Wire + * Copyright (C) 2026 Wire Swiss GmbH + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see http://www.gnu.org/licenses/. + */ + +package com.wire.android.ui.authentication.login + +import androidx.lifecycle.SavedStateHandle +import dagger.Module +import dagger.Provides +import dagger.hilt.InstallIn +import dagger.hilt.android.components.ViewModelComponent + +private const val USER_IDENTIFIER_SAVED_STATE_KEY = "user_identifier" +private const val SSO_CODE_SAVED_STATE_KEY = "sso_code" + +class SavedStateLoginSavedInputStore( + private val savedStateHandle: SavedStateHandle, +) : LoginSavedInputStore { + override var userIdentifier: String? + get() = savedStateHandle[USER_IDENTIFIER_SAVED_STATE_KEY] + set(value) { + savedStateHandle[USER_IDENTIFIER_SAVED_STATE_KEY] = value + } + + override var ssoCode: String? + get() = savedStateHandle[SSO_CODE_SAVED_STATE_KEY] + set(value) { + savedStateHandle[SSO_CODE_SAVED_STATE_KEY] = value + } +} + +@Module +@InstallIn(ViewModelComponent::class) +object LoginSavedInputStoreModule { + @Provides + fun provideLoginSavedInputStore(savedStateHandle: SavedStateHandle): LoginSavedInputStore = + SavedStateLoginSavedInputStore(savedStateHandle) +} diff --git a/app/src/main/kotlin/com/wire/android/ui/authentication/login/email/LoginEmailVerificationCodeScreen.kt b/app/src/main/kotlin/com/wire/android/ui/authentication/login/email/LoginEmailVerificationCodeScreen.kt index 73a453494bb..d2343846cd9 100644 --- a/app/src/main/kotlin/com/wire/android/ui/authentication/login/email/LoginEmailVerificationCodeScreen.kt +++ b/app/src/main/kotlin/com/wire/android/ui/authentication/login/email/LoginEmailVerificationCodeScreen.kt @@ -20,7 +20,6 @@ package com.wire.android.ui.authentication.login.email import androidx.compose.foundation.text.input.TextFieldState import androidx.compose.runtime.Composable -import androidx.hilt.navigation.compose.hiltViewModel import com.wire.android.ui.authentication.login.LoginState import com.wire.android.ui.authentication.verificationcode.VerificationCodeScreenContent import com.wire.android.ui.authentication.verificationcode.VerificationCodeState @@ -29,7 +28,7 @@ import com.wire.android.util.ui.PreviewMultipleThemes @Composable fun LoginEmailVerificationCodeScreen( - viewModel: LoginEmailViewModel = hiltViewModel() + viewModel: LoginEmailViewModel ) = VerificationCodeScreenContent( viewModel.secondFactorVerificationCodeTextState, viewModel.secondFactorVerificationCodeState, diff --git a/app/src/main/kotlin/com/wire/android/ui/authentication/login/email/LoginEmailViewModel.kt b/app/src/main/kotlin/com/wire/android/ui/authentication/login/email/LoginEmailViewModel.kt index 11ab4c2bfd7..a2474d2ba9f 100644 --- a/app/src/main/kotlin/com/wire/android/ui/authentication/login/email/LoginEmailViewModel.kt +++ b/app/src/main/kotlin/com/wire/android/ui/authentication/login/email/LoginEmailViewModel.kt @@ -25,21 +25,21 @@ import androidx.compose.foundation.text.input.setTextAndPlaceCursorAtEnd import androidx.compose.runtime.getValue import androidx.compose.runtime.mutableStateOf import androidx.compose.runtime.setValue -import androidx.lifecycle.SavedStateHandle import androidx.lifecycle.viewModelScope import com.wire.android.datastore.UserDataStoreProvider import com.wire.android.di.ClientScopeProvider import com.wire.android.di.DefaultWebSocketEnabledByDefault import com.wire.android.di.KaliumCoreLogic import com.wire.android.ui.authentication.login.LoginNavArgs +import com.wire.android.ui.authentication.login.LoginSavedInputStore import com.wire.android.ui.authentication.login.LoginState import com.wire.android.ui.authentication.login.LoginViewModel +import com.wire.android.ui.authentication.login.LoginViewModelExtension import com.wire.android.ui.authentication.login.PreFilledUserIdentifierType import com.wire.android.ui.authentication.login.isProxyAuthRequired import com.wire.android.ui.authentication.login.toLoginError import com.wire.android.ui.authentication.verificationcode.VerificationCodeState import com.wire.android.ui.common.textfield.textAsFlow -import com.ramcosta.composedestinations.generated.app.navArgs import com.wire.android.util.EMPTY import com.wire.android.util.dispatchers.DispatcherProvider import com.wire.android.util.ui.CountdownTimer @@ -58,6 +58,9 @@ import com.wire.kalium.logic.feature.auth.autoVersioningAuth.AutoVersionAuthScop import com.wire.kalium.logic.feature.auth.verification.RequestSecondFactorVerificationCodeUseCase import com.wire.kalium.logic.feature.client.RegisterClientResult import com.wire.kalium.logic.feature.session.CurrentSessionResult +import dagger.assisted.Assisted +import dagger.assisted.AssistedFactory +import dagger.assisted.AssistedInject import dagger.hilt.android.lifecycle.HiltViewModel import kotlinx.coroutines.Job import kotlinx.coroutines.flow.MutableStateFlow @@ -68,14 +71,14 @@ import kotlinx.coroutines.flow.onEach import kotlinx.coroutines.flow.update import kotlinx.coroutines.launch import kotlinx.coroutines.withContext -import javax.inject.Inject @Suppress("LongParameterList", "ComplexMethod", "TooManyFunctions") -@HiltViewModel -class LoginEmailViewModel @Inject constructor( +@HiltViewModel(assistedFactory = LoginEmailViewModel.Factory::class) +class LoginEmailViewModel @AssistedInject constructor( + @Assisted val loginNavArgs: LoginNavArgs, private val addAuthenticatedUser: AddAuthenticatedUserUseCase, clientScopeProviderFactory: ClientScopeProvider.Factory, - private val savedStateHandle: SavedStateHandle, + private val savedInputStore: LoginSavedInputStore, userDataStoreProvider: UserDataStoreProvider, @KaliumCoreLogic coreLogic: CoreLogic, private val resendCodeTimer: CountdownTimer, @@ -83,13 +86,13 @@ class LoginEmailViewModel @Inject constructor( defaultServerConfig: ServerConfig.Links, @DefaultWebSocketEnabledByDefault private val defaultWebSocketEnabledByDefault: Boolean, ) : LoginViewModel( - savedStateHandle, + loginNavArgs, clientScopeProviderFactory, userDataStoreProvider, coreLogic, + LoginViewModelExtension(clientScopeProviderFactory, userDataStoreProvider), defaultServerConfig ) { - val loginNavArgs: LoginNavArgs = savedStateHandle.navArgs() private val preFilledUserIdentifier: PreFilledUserIdentifierType = loginNavArgs.userHandle ?: PreFilledUserIdentifierType.None val userIdentifierTextState: TextFieldState = TextFieldState() @@ -105,18 +108,23 @@ class LoginEmailViewModel @Inject constructor( @VisibleForTesting internal val loginJobData = MutableStateFlow(null) + @AssistedFactory + interface Factory { + fun create(args: LoginNavArgs): LoginEmailViewModel + } + init { userIdentifierTextState.setTextAndPlaceCursorAtEnd( if (preFilledUserIdentifier is PreFilledUserIdentifierType.PreFilled) { preFilledUserIdentifier.userIdentifier } else { - savedStateHandle[USER_IDENTIFIER_SAVED_STATE_KEY] ?: String.EMPTY + savedInputStore.userIdentifier ?: String.EMPTY } ) viewModelScope.launch { combine( userIdentifierTextState.textAsFlow().distinctUntilChanged().onEach { - savedStateHandle[USER_IDENTIFIER_SAVED_STATE_KEY] = it.toString() + savedInputStore.userIdentifier = it.toString() }, passwordTextState.textAsFlow(), proxyIdentifierTextState.textAsFlow(), diff --git a/app/src/main/kotlin/com/wire/android/ui/authentication/login/sso/LoginSSOScreen.kt b/app/src/main/kotlin/com/wire/android/ui/authentication/login/sso/LoginSSOScreen.kt index 98b46e9579b..5224a06f295 100644 --- a/app/src/main/kotlin/com/wire/android/ui/authentication/login/sso/LoginSSOScreen.kt +++ b/app/src/main/kotlin/com/wire/android/ui/authentication/login/sso/LoginSSOScreen.kt @@ -29,7 +29,6 @@ import androidx.compose.foundation.layout.padding import androidx.compose.foundation.rememberScrollState import androidx.compose.foundation.text.KeyboardOptions import androidx.compose.foundation.text.input.TextFieldState -import androidx.compose.foundation.text.input.setTextAndPlaceCursorAtEnd import androidx.compose.foundation.verticalScroll import androidx.compose.material3.MaterialTheme import androidx.compose.runtime.Composable @@ -42,9 +41,11 @@ import androidx.compose.ui.platform.testTag import androidx.compose.ui.res.stringResource import androidx.compose.ui.text.input.ImeAction import androidx.compose.ui.text.input.KeyboardType -import androidx.hilt.navigation.compose.hiltViewModel import com.wire.android.R +import com.wire.android.di.wireViewModel import com.wire.android.ui.authentication.login.LoginErrorDialog +import com.wire.android.ui.authentication.login.LoginNavArgs +import com.wire.android.ui.authentication.login.SSOCodeAutoLogin import com.wire.android.ui.authentication.login.LoginState import com.wire.android.ui.authentication.login.toLoginDialogErrorData import com.wire.android.ui.common.button.WireButtonState @@ -64,9 +65,12 @@ import kotlinx.coroutines.flow.onEach fun LoginSSOScreen( onSuccess: (initialSyncCompleted: Boolean, isE2EIRequired: Boolean) -> Unit, onRemoveDeviceNeeded: () -> Unit, + loginNavArgs: LoginNavArgs, ssoLoginResult: DeepLinkResult.SSOLogin?, - ssoCodeAutoLogin: com.wire.android.ui.authentication.login.SSOCodeAutoLogin?, - loginSSOViewModel: LoginSSOViewModel = hiltViewModel(), + ssoCodeAutoLogin: SSOCodeAutoLogin?, + loginSSOViewModel: LoginSSOViewModel = wireViewModel( + creationCallback = { factory -> factory.create(loginNavArgs) } + ), scrollState: ScrollState = rememberScrollState() ) { val scope = rememberCoroutineScope() @@ -81,13 +85,12 @@ fun LoginSSOScreen( // Handle SSO code auto-login from intent parameter LaunchedEffect(ssoCodeAutoLogin) { ssoCodeAutoLogin?.let { - // Pre-fill the SSO code - loginSSOViewModel.ssoTextState.setTextAndPlaceCursorAtEnd(it.ssoCode) - - // Auto-initiate login if flag is set - if (it.autoInitiateLogin) { - loginSSOViewModel.login() - } + loginSSOViewModel.handleSSOCodeAutoLogin( + ssoCode = it.ssoCode, + autoInitiateLogin = it.autoInitiateLogin, + nomadServiceUrl = it.nomadServiceUrl, + cookieLabel = it.cookieLabel, + ) } } LoginSSOContent( diff --git a/app/src/main/kotlin/com/wire/android/ui/authentication/login/sso/LoginSSOViewModel.kt b/app/src/main/kotlin/com/wire/android/ui/authentication/login/sso/LoginSSOViewModel.kt index 2e755ae0335..391ce34f65d 100644 --- a/app/src/main/kotlin/com/wire/android/ui/authentication/login/sso/LoginSSOViewModel.kt +++ b/app/src/main/kotlin/com/wire/android/ui/authentication/login/sso/LoginSSOViewModel.kt @@ -18,15 +18,14 @@ package com.wire.android.ui.authentication.login.sso +import android.database.sqlite.SQLiteException import androidx.annotation.VisibleForTesting import androidx.compose.foundation.text.input.TextFieldState import androidx.compose.foundation.text.input.setTextAndPlaceCursorAtEnd import androidx.compose.runtime.getValue import androidx.compose.runtime.mutableStateOf import androidx.compose.runtime.setValue -import androidx.lifecycle.SavedStateHandle import androidx.lifecycle.viewModelScope -import com.ramcosta.composedestinations.generated.app.navArgs import com.wire.android.appLogger import com.wire.android.config.DefaultServerConfig import com.wire.android.datastore.UserDataStoreProvider @@ -34,8 +33,10 @@ import com.wire.android.di.ClientScopeProvider import com.wire.android.di.DefaultWebSocketEnabledByDefault import com.wire.android.di.KaliumCoreLogic import com.wire.android.ui.authentication.login.LoginNavArgs +import com.wire.android.ui.authentication.login.LoginSavedInputStore import com.wire.android.ui.authentication.login.LoginState import com.wire.android.ui.authentication.login.LoginViewModel +import com.wire.android.ui.authentication.login.LoginViewModelExtension import com.wire.android.ui.authentication.login.toLoginError import com.wire.android.ui.common.dialogs.CustomServerDetailsDialogState import com.wire.android.ui.common.textfield.textAsFlow @@ -57,43 +58,58 @@ import com.wire.kalium.logic.feature.auth.sso.SSOLoginSessionResult import com.wire.kalium.logic.feature.backup.RestoreCryptoStateResult import com.wire.kalium.logic.feature.client.RegisterClientResult import com.wire.kalium.logic.feature.session.DoesValidSessionExistResult +import dagger.assisted.Assisted +import dagger.assisted.AssistedFactory +import dagger.assisted.AssistedInject import dagger.hilt.android.lifecycle.HiltViewModel +import java.io.IOException import kotlinx.coroutines.CancellationException import kotlinx.coroutines.flow.MutableSharedFlow import kotlinx.coroutines.flow.collectLatest import kotlinx.coroutines.flow.distinctUntilChanged import kotlinx.coroutines.launch import kotlinx.coroutines.withContext -import android.database.sqlite.SQLiteException -import java.io.IOException -import javax.inject.Inject @Suppress("LongParameterList", "TooManyFunctions") -@HiltViewModel -class LoginSSOViewModel( - private val savedStateHandle: SavedStateHandle, - val addAuthenticatedUser: AddAuthenticatedUserUseCase, - private val validateEmailUseCase: ValidateEmailUseCase, - coreLogic: CoreLogic, - clientScopeProviderFactory: ClientScopeProvider.Factory, - userDataStoreProvider: UserDataStoreProvider, - private val ssoExtension: LoginSSOViewModelExtension, - serverConfig: ServerConfig.Links, - private val dispatchers: DispatcherProvider, -) : LoginViewModel( - savedStateHandle, - clientScopeProviderFactory, - userDataStoreProvider, - coreLogic, - serverConfig -) { - private val loginNavArgs: LoginNavArgs = savedStateHandle.navArgs() - private var pendingNomadServiceUrl: String? = loginNavArgs.ssoCodeAutoLogin?.nomadServiceUrl - private var pendingCookieLabel: String? = loginNavArgs.ssoCodeAutoLogin?.cookieLabel - - @Inject +@HiltViewModel(assistedFactory = LoginSSOViewModel.Factory::class) +class LoginSSOViewModel : LoginViewModel { + private val savedInputStore: LoginSavedInputStore + val addAuthenticatedUser: AddAuthenticatedUserUseCase + private val validateEmailUseCase: ValidateEmailUseCase + private val ssoExtension: LoginSSOViewModelExtension + private val dispatchers: DispatcherProvider + + private var pendingNomadServiceUrl: String? = null + private var pendingCookieLabel: String? = null + + constructor( + loginNavArgs: LoginNavArgs, + savedInputStore: LoginSavedInputStore, + addAuthenticatedUser: AddAuthenticatedUserUseCase, + validateEmailUseCase: ValidateEmailUseCase, + coreLogic: CoreLogic, + clientScopeProviderFactory: ClientScopeProvider.Factory, + userDataStoreProvider: UserDataStoreProvider, + serverConfig: ServerConfig.Links, + ssoExtension: LoginSSOViewModelExtension, + dispatchers: DispatcherProvider, + ) : this( + loginNavArgs, + savedInputStore, + addAuthenticatedUser, + validateEmailUseCase, + coreLogic, + clientScopeProviderFactory, + userDataStoreProvider, + ssoExtension, + serverConfig, + dispatchers, + ) + + @AssistedInject constructor( - savedStateHandle: SavedStateHandle, + @Assisted loginNavArgs: LoginNavArgs, + savedInputStore: LoginSavedInputStore, addAuthenticatedUser: AddAuthenticatedUserUseCase, validateEmailUseCase: ValidateEmailUseCase, @KaliumCoreLogic coreLogic: CoreLogic, @@ -103,30 +119,65 @@ class LoginSSOViewModel( @DefaultWebSocketEnabledByDefault defaultWebSocketEnabledByDefault: Boolean, dispatchers: DispatcherProvider, ) : this( - savedStateHandle, + loginNavArgs, + savedInputStore, addAuthenticatedUser, validateEmailUseCase, coreLogic, clientScopeProviderFactory, userDataStoreProvider, - LoginSSOViewModelExtension(addAuthenticatedUser, coreLogic, defaultWebSocketEnabledByDefault), serverConfig, + LoginSSOViewModelExtension(addAuthenticatedUser, coreLogic, defaultWebSocketEnabledByDefault), dispatchers, ) + private constructor( + loginNavArgs: LoginNavArgs, + savedInputStore: LoginSavedInputStore, + addAuthenticatedUser: AddAuthenticatedUserUseCase, + validateEmailUseCase: ValidateEmailUseCase, + coreLogic: CoreLogic, + clientScopeProviderFactory: ClientScopeProvider.Factory, + userDataStoreProvider: UserDataStoreProvider, + ssoExtension: LoginSSOViewModelExtension, + serverConfig: ServerConfig.Links, + dispatchers: DispatcherProvider, + ) : super( + loginNavArgs, + clientScopeProviderFactory, + userDataStoreProvider, + coreLogic, + LoginViewModelExtension(clientScopeProviderFactory, userDataStoreProvider), + serverConfig + ) { + this.savedInputStore = savedInputStore + this.addAuthenticatedUser = addAuthenticatedUser + this.validateEmailUseCase = validateEmailUseCase + this.ssoExtension = ssoExtension + this.dispatchers = dispatchers + pendingNomadServiceUrl = loginNavArgs.ssoCodeAutoLogin?.nomadServiceUrl + pendingCookieLabel = loginNavArgs.ssoCodeAutoLogin?.cookieLabel + observeSSOCodeInput() + } + var openWebUrl = MutableSharedFlow>() val ssoTextState: TextFieldState = TextFieldState() var loginState: LoginSSOState by mutableStateOf(LoginSSOState()) - init { - ssoTextState.setTextAndPlaceCursorAtEnd(savedStateHandle[SSO_CODE_SAVED_STATE_KEY] ?: String.EMPTY) + @AssistedFactory + interface Factory { + fun create(args: LoginNavArgs): LoginSSOViewModel + } + + private fun observeSSOCodeInput() { + ssoTextState.setTextAndPlaceCursorAtEnd(savedInputStore.ssoCode ?: String.EMPTY) viewModelScope.launch { ssoTextState.textAsFlow().distinctUntilChanged().collectLatest { if (loginState.flowState != LoginState.Loading) { updateSSOFlowState(LoginState.Default) } - savedStateHandle[SSO_CODE_SAVED_STATE_KEY] = it.toString() + savedInputStore.ssoCode = it.toString() } } } @@ -184,6 +235,21 @@ class LoginSSOViewModel( } } + fun handleSSOCodeAutoLogin( + ssoCode: String, + autoInitiateLogin: Boolean, + nomadServiceUrl: String?, + cookieLabel: String?, + ) { + pendingNomadServiceUrl = nomadServiceUrl + pendingCookieLabel = cookieLabel + ssoTextState.setTextAndPlaceCursorAtEnd(ssoCode) + + if (autoInitiateLogin) { + login() + } + } + @VisibleForTesting fun domainLookupFlow() { viewModelScope.launch { @@ -374,7 +440,6 @@ class LoginSSOViewModel( } companion object { - const val SSO_CODE_SAVED_STATE_KEY = "sso_code" private const val TAG = "[LoginSSOViewModel]" } diff --git a/app/src/main/kotlin/com/wire/android/ui/newauthentication/login/password/NewLoginPasswordScreen.kt b/app/src/main/kotlin/com/wire/android/ui/newauthentication/login/password/NewLoginPasswordScreen.kt index e30f391b07f..b19e0f0c2f1 100644 --- a/app/src/main/kotlin/com/wire/android/ui/newauthentication/login/password/NewLoginPasswordScreen.kt +++ b/app/src/main/kotlin/com/wire/android/ui/newauthentication/login/password/NewLoginPasswordScreen.kt @@ -46,10 +46,10 @@ import androidx.compose.ui.semantics.testTagsAsResourceId import androidx.compose.ui.text.input.ImeAction import androidx.compose.ui.text.style.TextAlign import androidx.compose.ui.text.style.TextDecoration -import androidx.hilt.navigation.compose.hiltViewModel import com.wire.android.BuildConfig import com.wire.android.BuildConfig.ENABLE_NEW_REGISTRATION import com.wire.android.R +import com.wire.android.di.wireViewModel import com.wire.android.navigation.BackStackMode import com.wire.android.navigation.NavigationCommand import com.wire.android.navigation.Navigator @@ -105,7 +105,9 @@ import com.wire.kalium.logic.configuration.server.ServerConfig fun NewLoginPasswordScreen( navigator: Navigator, navArgs: LoginNavArgs, - loginEmailViewModel: LoginEmailViewModel = hiltViewModel() + loginEmailViewModel: LoginEmailViewModel = wireViewModel( + creationCallback = { factory -> factory.create(navArgs) } + ) ) { clearAutofillTree() LoginStateNavigationAndDialogs(loginEmailViewModel, navigator) diff --git a/app/src/test/kotlin/com/wire/android/ui/authentication/LoginViewModelTest.kt b/app/src/test/kotlin/com/wire/android/ui/authentication/LoginViewModelTest.kt index 8a391b7327d..537a55a2d55 100644 --- a/app/src/test/kotlin/com/wire/android/ui/authentication/LoginViewModelTest.kt +++ b/app/src/test/kotlin/com/wire/android/ui/authentication/LoginViewModelTest.kt @@ -18,14 +18,13 @@ package com.wire.android.ui.authentication -import androidx.lifecycle.SavedStateHandle import com.wire.android.config.CoroutineTestExtension import com.wire.android.datastore.UserDataStoreProvider import com.wire.android.di.ClientScopeProvider import com.wire.android.ui.authentication.login.LoginNavArgs import com.wire.android.ui.authentication.login.LoginPasswordPath import com.wire.android.ui.authentication.login.LoginViewModel -import com.ramcosta.composedestinations.generated.app.navArgs +import com.wire.android.ui.authentication.login.LoginViewModelExtension import com.wire.kalium.logic.CoreLogic import com.wire.kalium.logic.configuration.server.ServerConfig import com.wire.kalium.logic.data.id.QualifiedID @@ -47,9 +46,6 @@ class LoginViewModelTest { @MockK private lateinit var qualifiedIdMapper: QualifiedIdMapper - @MockK - private lateinit var savedStateHandle: SavedStateHandle - @MockK private lateinit var userDataStoreProvider: UserDataStoreProvider @@ -62,14 +58,12 @@ class LoginViewModelTest { fun setup() { MockKAnnotations.init(this) every { qualifiedIdMapper.fromStringToQualifiedID(any()) } returns QualifiedID("", "") - every { savedStateHandle.navArgs() } returns LoginNavArgs( - loginPasswordPath = LoginPasswordPath(ServerConfig.STAGING) - ) loginViewModel = LoginViewModel( - savedStateHandle, + LoginNavArgs(loginPasswordPath = LoginPasswordPath(ServerConfig.STAGING)), clientScopeProviderFactory, userDataStoreProvider, coreLogic, + LoginViewModelExtension(clientScopeProviderFactory, userDataStoreProvider), ServerConfig.STAGING ) } diff --git a/app/src/test/kotlin/com/wire/android/ui/authentication/login/SavedStateLoginSavedInputStoreTest.kt b/app/src/test/kotlin/com/wire/android/ui/authentication/login/SavedStateLoginSavedInputStoreTest.kt new file mode 100644 index 00000000000..6d05abc1085 --- /dev/null +++ b/app/src/test/kotlin/com/wire/android/ui/authentication/login/SavedStateLoginSavedInputStoreTest.kt @@ -0,0 +1,44 @@ +/* + * Wire + * Copyright (C) 2026 Wire Swiss GmbH + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see http://www.gnu.org/licenses/. + */ + +package com.wire.android.ui.authentication.login + +import androidx.lifecycle.SavedStateHandle +import com.wire.android.assertions.shouldBeEqualTo +import org.junit.jupiter.api.Test + +class SavedStateLoginSavedInputStoreTest { + + @Test + fun givenUserIdentifierIsSet_whenReadingItBack_thenValueIsReturned() { + val store = SavedStateLoginSavedInputStore(SavedStateHandle()) + + store.userIdentifier = "user@example.com" + + store.userIdentifier shouldBeEqualTo "user@example.com" + } + + @Test + fun givenSsoCodeIsSet_whenReadingItBack_thenValueIsReturned() { + val store = SavedStateLoginSavedInputStore(SavedStateHandle()) + + store.ssoCode = "wire-sso-code" + + store.ssoCode shouldBeEqualTo "wire-sso-code" + } +} diff --git a/app/src/test/kotlin/com/wire/android/ui/authentication/login/email/LoginEmailViewModelTest.kt b/app/src/test/kotlin/com/wire/android/ui/authentication/login/email/LoginEmailViewModelTest.kt index 727e11dd517..dfeea04efd9 100644 --- a/app/src/test/kotlin/com/wire/android/ui/authentication/login/email/LoginEmailViewModelTest.kt +++ b/app/src/test/kotlin/com/wire/android/ui/authentication/login/email/LoginEmailViewModelTest.kt @@ -21,7 +21,6 @@ package com.wire.android.ui.authentication.login.email import androidx.compose.foundation.text.input.setTextAndPlaceCursorAtEnd -import androidx.lifecycle.SavedStateHandle import app.cash.turbine.test import com.wire.android.assertions.shouldBeEqualTo import com.wire.android.assertions.shouldBeInstanceOf @@ -37,8 +36,8 @@ import com.wire.android.di.ClientScopeProvider import com.wire.android.framework.TestClient import com.wire.android.ui.authentication.login.LoginNavArgs import com.wire.android.ui.authentication.login.LoginPasswordPath +import com.wire.android.ui.authentication.login.LoginSavedInputStore import com.wire.android.ui.authentication.login.LoginState -import com.ramcosta.composedestinations.generated.app.navArgs import com.wire.android.util.EMPTY import com.wire.android.util.newServerConfig import com.wire.android.util.ui.CountdownTimer @@ -815,7 +814,7 @@ class LoginEmailViewModelTest { internal lateinit var getOrRegisterClientUseCase: GetOrRegisterClientUseCase @MockK - internal lateinit var savedStateHandle: SavedStateHandle + internal lateinit var savedInputStore: LoginSavedInputStore @MockK internal lateinit var qualifiedIdMapper: QualifiedIdMapper @@ -853,17 +852,14 @@ class LoginEmailViewModelTest { init { MockKAnnotations.init(this, relaxUnitFun = true) mockUri() - every { savedStateHandle.get(any()) } returns null + every { savedInputStore.userIdentifier } returns null every { qualifiedIdMapper.fromStringToQualifiedID(any()) } returns USER_ID - every { savedStateHandle.set(any(), any()) } returns Unit + every { savedInputStore.userIdentifier = any() } returns Unit every { coreLogic.getGlobalScope().validateEmailUseCase } returns validateEmailUseCase every { coreLogic.getSessionScope(any()).users } returns userScope every { userScope.persistSelfUserEmail } returns persistSelfUserEmailUseCase every { clientScopeProviderFactory.create(any()).clientScope } returns clientScope every { clientScope.getOrRegister } returns getOrRegisterClientUseCase - every { savedStateHandle.navArgs() } returns LoginNavArgs( - loginPasswordPath = LoginPasswordPath(newServerConfig(1).links) - ) coEvery { autoVersionAuthScopeUseCase(any()) } returns AutoVersionAuthScopeUseCase.Result.Success(authenticationScope) every { authenticationScope.login } returns loginUseCase every { authenticationScope.requestSecondFactorVerificationCode } returns requestSecondFactorCodeUseCase @@ -877,9 +873,10 @@ class LoginEmailViewModelTest { } fun arrange() = this to LoginEmailViewModel( + LoginNavArgs(loginPasswordPath = LoginPasswordPath(newServerConfig(1).links)), addAuthenticatedUserUseCase, clientScopeProviderFactory, - savedStateHandle, + savedInputStore, userDataStoreProvider, coreLogic, countdownTimer, diff --git a/app/src/test/kotlin/com/wire/android/ui/authentication/login/sso/LoginSSOViewModelTest.kt b/app/src/test/kotlin/com/wire/android/ui/authentication/login/sso/LoginSSOViewModelTest.kt index b56397c214c..7b80ae9fc5e 100644 --- a/app/src/test/kotlin/com/wire/android/ui/authentication/login/sso/LoginSSOViewModelTest.kt +++ b/app/src/test/kotlin/com/wire/android/ui/authentication/login/sso/LoginSSOViewModelTest.kt @@ -19,9 +19,7 @@ package com.wire.android.ui.authentication.login.sso import androidx.compose.foundation.text.input.setTextAndPlaceCursorAtEnd -import androidx.lifecycle.SavedStateHandle import app.cash.turbine.test -import com.ramcosta.composedestinations.generated.app.navArgs import com.wire.android.assertions.shouldBeEqualTo import com.wire.android.assertions.shouldBeInstanceOf import com.wire.android.assertions.shouldNotBeInstanceOf @@ -29,13 +27,13 @@ import com.wire.android.config.CoroutineTestExtension import com.wire.android.config.NavigationTestExtension import com.wire.android.config.SnapshotExtension import com.wire.android.config.TestDispatcherProvider -import com.wire.android.config.mockUri import com.wire.android.datastore.UserDataStoreProvider import com.wire.android.di.ClientScopeProvider import com.wire.android.framework.TestClient import com.wire.android.framework.TestUser import com.wire.android.ui.authentication.login.LoginNavArgs import com.wire.android.ui.authentication.login.LoginPasswordPath +import com.wire.android.ui.authentication.login.LoginSavedInputStore import com.wire.android.ui.authentication.login.LoginState import com.wire.android.ui.authentication.login.SSOCodeAutoLogin import com.wire.android.ui.common.dialogs.CustomServerDetailsDialogState @@ -204,6 +202,49 @@ class LoginSSOViewModelTest { } } + @Test + fun `given auto login params, when handled, then sso code is prefilled without reading navigation args`() { + val expectedSSOCode = "wire-fd994b20-b9af-11ec-ae36-00163e9b33ca" + val (_, loginViewModel) = Arrangement().arrange() + + loginViewModel.handleSSOCodeAutoLogin( + ssoCode = expectedSSOCode, + autoInitiateLogin = false, + nomadServiceUrl = "https://nomad.example.com/service", + cookieLabel = "shared-device" + ) + + loginViewModel.ssoTextState.text.toString() shouldBeEqualTo expectedSSOCode + } + + @Test + fun `given auto login params with auto initiate, when handled, then login starts with cookie label`() = runTest { + val expectedSSOCode = "wire-fd994b20-b9af-11ec-ae36-00163e9b33ca" + val (arrangement, loginViewModel) = Arrangement() + .withValidateEmailReturning(false) + .withInitiateSSO(expectedSSOCode) + .arrange() + + loginViewModel.handleSSOCodeAutoLogin( + ssoCode = expectedSSOCode, + autoInitiateLogin = true, + nomadServiceUrl = "https://nomad.example.com/service", + cookieLabel = "shared-device" + ) + advanceUntilIdle() + + coVerify(exactly = 1) { + arrangement.ssoExtension.initiateSSO( + eq(SERVER_CONFIG.links), + eq(expectedSSOCode), + eq("shared-device"), + any(), + any(), + any() + ) + } + } + @Test fun `given sso code and button is clicked, when login returns InvalidCode error, then InvalidCodeError is passed`() = runTest { val expectedSSOCode = "wire-fd994b20-b9af-11ec-ae36-00163e9b33ca" @@ -1018,7 +1059,7 @@ class LoginSSOViewModelTest { private class Arrangement { @MockK - lateinit var savedStateHandle: SavedStateHandle + lateinit var savedInputStore: LoginSavedInputStore @MockK lateinit var ssoInitiateLoginUseCase: SSOInitiateLoginUseCase @@ -1076,15 +1117,10 @@ class LoginSSOViewModelTest { init { MockKAnnotations.init(this) - mockUri() - every { savedStateHandle.get(any()) } returns null - every { savedStateHandle.set(any(), any()) } returns Unit + every { savedInputStore.ssoCode } returns null + every { savedInputStore.ssoCode = any() } returns Unit every { clientScopeProviderFactory.create(any()).clientScope } returns clientScope every { clientScope.getOrRegister } returns getOrRegisterClientUseCase - every { savedStateHandle.navArgs() } returns LoginNavArgs( - loginPasswordPath = LoginPasswordPath(SERVER_CONFIG.links) - ) - coEvery { autoVersionAuthScopeUseCase(null) } returns AutoVersionAuthScopeUseCase.Result.Success( @@ -1179,28 +1215,40 @@ class LoginSSOViewModelTest { coEvery { doesValidSessionExistUseCase(any()) } returns DoesValidSessionExistResult.Success(valid) } + private var ssoCodeAutoLogin: SSOCodeAutoLogin? = null + fun withNomadAutoLogin(nomadServiceUrl: String) = apply { - every { savedStateHandle.navArgs() } returns LoginNavArgs( - loginPasswordPath = LoginPasswordPath(SERVER_CONFIG.links), - ssoCodeAutoLogin = SSOCodeAutoLogin( - ssoCode = "wire-sso-code", - nomadServiceUrl = nomadServiceUrl, - cookieLabel = "shared-device" - ) + ssoCodeAutoLogin = SSOCodeAutoLogin( + ssoCode = "wire-sso-code", + autoInitiateLogin = false, + nomadServiceUrl = nomadServiceUrl, + cookieLabel = "shared-device" ) } - fun arrange() = this to LoginSSOViewModel( - savedStateHandle = savedStateHandle, - addAuthenticatedUser = addAuthenticatedUserUseCase, - validateEmailUseCase = validateEmailUseCase, - coreLogic = coreLogic, - clientScopeProviderFactory = clientScopeProviderFactory, - userDataStoreProvider = userDataStoreProvider, - serverConfig = SERVER_CONFIG.links, - ssoExtension = ssoExtension, - dispatchers = TestDispatcherProvider(), - ) + fun arrange(): Pair { + val viewModel = LoginSSOViewModel( + loginNavArgs = LoginNavArgs(loginPasswordPath = LoginPasswordPath(SERVER_CONFIG.links)), + savedInputStore = savedInputStore, + addAuthenticatedUser = addAuthenticatedUserUseCase, + validateEmailUseCase = validateEmailUseCase, + coreLogic = coreLogic, + clientScopeProviderFactory = clientScopeProviderFactory, + userDataStoreProvider = userDataStoreProvider, + serverConfig = SERVER_CONFIG.links, + ssoExtension = ssoExtension, + dispatchers = TestDispatcherProvider(), + ) + ssoCodeAutoLogin?.let { + viewModel.handleSSOCodeAutoLogin( + ssoCode = it.ssoCode, + autoInitiateLogin = it.autoInitiateLogin, + nomadServiceUrl = it.nomadServiceUrl, + cookieLabel = it.cookieLabel, + ) + } + return this to viewModel + } } companion object {