Skip to content

Commit

Permalink
fix(auth): Fallback to in-memory key value storage if encryption fails (
Browse files Browse the repository at this point in the history
  • Loading branch information
tylerjroach authored Jan 13, 2025
1 parent 7d1603c commit 97ddb8b
Show file tree
Hide file tree
Showing 14 changed files with 299 additions and 56 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ package com.amplifyframework.analytics.pinpoint
import android.content.Context
import aws.sdk.kotlin.services.pinpoint.PinpointClient
import aws.smithy.kotlin.runtime.auth.awscredentials.CredentialsProvider
import com.amplifyframework.core.store.EncryptedKeyValueRepository
import com.amplifyframework.core.store.AmplifyKeyValueRepository
import com.amplifyframework.pinpoint.core.AnalyticsClient
import com.amplifyframework.pinpoint.core.TargetingClient
import com.amplifyframework.pinpoint.core.data.AndroidAppDetails
Expand Down Expand Up @@ -62,7 +62,7 @@ internal class PinpointManager constructor(
Context.MODE_PRIVATE
)

val encryptedStore = EncryptedKeyValueRepository(
val amplifyStore = AmplifyKeyValueRepository(
context,
"${awsPinpointConfiguration.appId}$PINPOINT_SHARED_PREFS_SUFFIX"
)
Expand All @@ -72,7 +72,7 @@ internal class PinpointManager constructor(
targetingClient = TargetingClient(
context,
pinpointClient,
encryptedStore,
amplifyStore,
sharedPrefs,
androidAppDetails,
androidDeviceDetails
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ import kotlinx.serialization.json.Json
internal class AWSCognitoAuthCredentialStore(
val context: Context,
private val authConfiguration: AuthConfiguration,
isPersistenceEnabled: Boolean = true,
keyValueRepoFactory: KeyValueRepositoryFactory = KeyValueRepositoryFactory()
) : AuthCredentialStore {

Expand All @@ -39,7 +38,7 @@ internal class AWSCognitoAuthCredentialStore(
}

private var keyValue: KeyValueRepository =
keyValueRepoFactory.create(context, awsKeyValueStoreIdentifier, isPersistenceEnabled)
keyValueRepoFactory.create(context, awsKeyValueStoreIdentifier)

//region Save Credentials
override fun saveCredential(credential: AmplifyCredential) = keyValue.put(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,22 +21,21 @@ import com.amplifyframework.auth.cognito.data.AWSCognitoLegacyCredentialStore.Co
import com.amplifyframework.auth.cognito.data.AWSCognitoLegacyCredentialStore.Companion.APP_TOKENS_INFO_CACHE
import com.amplifyframework.auth.cognito.data.AWSCognitoLegacyCredentialStore.Companion.AWS_KEY_VALUE_STORE_NAMESPACE_IDENTIFIER
import com.amplifyframework.auth.cognito.data.AWSCognitoLegacyCredentialStore.Companion.AWS_MOBILE_CLIENT_PROVIDER
import com.amplifyframework.core.store.EncryptedKeyValueRepository
import com.amplifyframework.core.store.AmplifyKeyValueRepository
import com.amplifyframework.core.store.InMemoryKeyValueRepository
import com.amplifyframework.core.store.KeyValueRepository

internal class KeyValueRepositoryFactory {
fun create(context: Context, keyValueRepoID: String, persistenceEnabled: Boolean = true): KeyValueRepository {
fun create(context: Context, keyValueRepoID: String): KeyValueRepository {
return when {
keyValueRepoID == awsKeyValueStoreIdentifier -> when {
persistenceEnabled -> EncryptedKeyValueRepository(context, keyValueRepoID)
else -> InMemoryKeyValueRepository()
}
keyValueRepoID == awsKeyValueStoreIdentifier -> AmplifyKeyValueRepository(context, keyValueRepoID)

keyValueRepoID == AWS_KEY_VALUE_STORE_NAMESPACE_IDENTIFIER ||
keyValueRepoID == APP_TOKENS_INFO_CACHE ||
keyValueRepoID == AWS_MOBILE_CLIENT_PROVIDER ||
keyValueRepoID.startsWith(APP_DEVICE_INFO_CACHE) ->
LegacyKeyValueRepository(context, keyValueRepoID)

else -> InMemoryKeyValueRepository()
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,13 @@ import org.junit.Assert
import org.junit.Before
import org.junit.Test
import org.junit.runner.RunWith
import org.mockito.Mock
import org.mockito.Mockito
import org.mockito.Mockito.mock
import org.mockito.Mockito.times
import org.mockito.Mockito.verify
import org.mockito.junit.MockitoJUnitRunner
import org.robolectric.RobolectricTestRunner

@RunWith(MockitoJUnitRunner::class)
@RunWith(RobolectricTestRunner::class)
class AWSCognitoAuthCredentialStoreTest {

companion object {
Expand All @@ -52,17 +52,13 @@ class AWSCognitoAuthCredentialStoreTest {

private val keyValueRepoID: String = "com.amplify.credentialStore"

@Mock
private lateinit var mockConfig: AuthConfiguration
private val mockConfig = mock(AuthConfiguration::class.java)

@Mock
private lateinit var mockContext: Context
private val mockContext = mock(Context::class.java)

@Mock
private lateinit var mockKeyValue: KeyValueRepository
private val mockKeyValue: KeyValueRepository = mock(KeyValueRepository::class.java)

@Mock
private lateinit var mockFactory: KeyValueRepositoryFactory
private val mockFactory = mock(KeyValueRepositoryFactory::class.java)

private lateinit var persistentStore: AWSCognitoAuthCredentialStore

Expand All @@ -71,8 +67,7 @@ class AWSCognitoAuthCredentialStoreTest {
Mockito.`when`(
mockFactory.create(
mockContext,
keyValueRepoID,
true
keyValueRepoID
)
).thenReturn(mockKeyValue)

Expand All @@ -84,7 +79,7 @@ class AWSCognitoAuthCredentialStoreTest {
@Test
fun testSaveCredentialWithUserPool() {
setupUserPoolConfig()
persistentStore = AWSCognitoAuthCredentialStore(mockContext, mockConfig, true, mockFactory)
persistentStore = AWSCognitoAuthCredentialStore(mockContext, mockConfig, mockFactory)
persistentStore.saveCredential(getCredential())
verify(mockKeyValue, times(1))
.put(KEY_WITH_USER_POOL, serialized(getCredential()))
Expand All @@ -93,7 +88,7 @@ class AWSCognitoAuthCredentialStoreTest {
@Test
fun testSaveCredentialWithIdentityPool() {
setupIdentityPoolConfig()
persistentStore = AWSCognitoAuthCredentialStore(mockContext, mockConfig, true, mockFactory)
persistentStore = AWSCognitoAuthCredentialStore(mockContext, mockConfig, mockFactory)

persistentStore.saveCredential(getCredential())

Expand All @@ -105,7 +100,7 @@ class AWSCognitoAuthCredentialStoreTest {
fun testSaveCredentialWithUserAndIdentityPool() {
setupUserPoolConfig()
setupIdentityPoolConfig()
persistentStore = AWSCognitoAuthCredentialStore(mockContext, mockConfig, true, mockFactory)
persistentStore = AWSCognitoAuthCredentialStore(mockContext, mockConfig, mockFactory)

persistentStore.saveCredential(getCredential())

Expand All @@ -117,7 +112,7 @@ class AWSCognitoAuthCredentialStoreTest {
fun testRetrieveCredential() {
setupUserPoolConfig()
setupIdentityPoolConfig()
persistentStore = AWSCognitoAuthCredentialStore(mockContext, mockConfig, true, mockFactory)
persistentStore = AWSCognitoAuthCredentialStore(mockContext, mockConfig, mockFactory)

val actual = persistentStore.retrieveCredential()

Expand All @@ -127,7 +122,7 @@ class AWSCognitoAuthCredentialStoreTest {
@Test
fun testDeleteCredential() {
setupUserPoolConfig()
persistentStore = AWSCognitoAuthCredentialStore(mockContext, mockConfig, true, mockFactory)
persistentStore = AWSCognitoAuthCredentialStore(mockContext, mockConfig, mockFactory)

persistentStore.deleteCredential()

Expand All @@ -136,7 +131,7 @@ class AWSCognitoAuthCredentialStoreTest {

@Test
fun testInMemoryCredentialStore() {
val store = AWSCognitoAuthCredentialStore(mockContext, mockConfig, false)
val store = AWSCognitoAuthCredentialStore(mockContext, mockConfig)

store.saveCredential(getCredential())
assertEquals(getCredential(), store.retrieveCredential())
Expand All @@ -150,7 +145,7 @@ class AWSCognitoAuthCredentialStoreTest {

setupUserPoolConfig()
setupIdentityPoolConfig()
persistentStore = AWSCognitoAuthCredentialStore(mockContext, mockConfig, true, mockFactory)
persistentStore = AWSCognitoAuthCredentialStore(mockContext, mockConfig, mockFactory)
}

private fun setupIdentityPoolConfig() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,28 +117,25 @@ class AWSCognitoLegacyCredentialStoreTest {
`when`(
mockFactory.create(
mockContext,
AWSCognitoLegacyCredentialStore.AWS_KEY_VALUE_STORE_NAMESPACE_IDENTIFIER,
true
AWSCognitoLegacyCredentialStore.AWS_KEY_VALUE_STORE_NAMESPACE_IDENTIFIER
)
).thenReturn(mockKeyValue)

`when`(
mockFactory.create(
mockContext,
AWSCognitoLegacyCredentialStore.APP_TOKENS_INFO_CACHE,
true
AWSCognitoLegacyCredentialStore.APP_TOKENS_INFO_CACHE
)
).thenReturn(mockKeyValue)

`when`(
mockFactory.create(
mockContext,
AWSCognitoLegacyCredentialStore.AWS_MOBILE_CLIENT_PROVIDER,
true
AWSCognitoLegacyCredentialStore.AWS_MOBILE_CLIENT_PROVIDER
)
).thenReturn(mockKeyValue)

`when`(mockFactory.create(mockContext, deviceDetailsCacheKey, true)).thenReturn(mockKeyValue)
`when`(mockFactory.create(mockContext, deviceDetailsCacheKey)).thenReturn(mockKeyValue)
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ class AWSCognitoIdentityPoolOperations(
private val KEY_LOGINS_PROVIDER = "amplify.${identityPool.poolId}.session.loginsProvider"
private val KEY_IDENTITY_ID = "amplify.${identityPool.poolId}.session.identityId"
private val KEY_AWS_CREDENTIALS = "amplify.${identityPool.poolId}.session.credential"
private val awsAuthCredentialStore = AuthCredentialStore(context.applicationContext, pluginKeySanitized, true)
private val awsAuthCredentialStore = AuthCredentialStore(context.applicationContext, pluginKeySanitized)

val cognitoIdentityClient = CognitoClientFactory.createIdentityClient(
identityPool,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,22 +16,16 @@
package com.amplifyframework.auth.plugins.core.data

import android.content.Context
import com.amplifyframework.core.store.EncryptedKeyValueRepository
import com.amplifyframework.core.store.InMemoryKeyValueRepository
import com.amplifyframework.core.store.AmplifyKeyValueRepository
import com.amplifyframework.core.store.KeyValueRepository

internal class AuthCredentialStore(
context: Context,
keyValueStoreIdentifierSuffix: String,
isPersistenceEnabled: Boolean
keyValueStoreIdentifierSuffix: String
) {
private val keyValueStoreIdentifier = "com.amplify.credentialStore.$keyValueStoreIdentifierSuffix"

private val keyValue: KeyValueRepository = if (isPersistenceEnabled) {
EncryptedKeyValueRepository(context, keyValueStoreIdentifier)
} else {
InMemoryKeyValueRepository()
}
private val keyValue: KeyValueRepository = AmplifyKeyValueRepository(context, keyValueStoreIdentifier)

fun put(key: String, value: String) = keyValue.put(key, value)
fun get(key: String): String? = keyValue.get(key)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import android.content.UriMatcher
import android.database.Cursor
import android.net.Uri
import androidx.annotation.VisibleForTesting
import com.amplifyframework.core.store.EncryptedKeyValueRepository
import com.amplifyframework.core.store.AmplifyKeyValueRepository
import com.amplifyframework.logging.cloudwatch.models.CloudWatchLogEvent
import java.util.UUID
import kotlinx.coroutines.CoroutineDispatcher
Expand All @@ -36,8 +36,8 @@ internal class CloudWatchLoggingDatabase(
private val logEventsId = 20
private val passphraseKey = "passphrase"
private val mb = 1024 * 1024
private val encryptedKeyValueRepository: EncryptedKeyValueRepository by lazy {
EncryptedKeyValueRepository(
private val amplifyKeyValueRepository: AmplifyKeyValueRepository by lazy {
AmplifyKeyValueRepository(
context,
"awscloudwatchloggingdb"
)
Expand Down Expand Up @@ -137,7 +137,7 @@ internal class CloudWatchLoggingDatabase(

@VisibleForTesting(otherwise = VisibleForTesting.PRIVATE)
fun getDatabasePassphrase(): String {
return encryptedKeyValueRepository.get(passphraseKey) ?: kotlin.run {
return amplifyKeyValueRepository.get(passphraseKey) ?: kotlin.run {
val passphrase = UUID.randomUUID().toString()
// If the database is restored from backup and the passphrase key is not present,
// this would result in the database file not getting loaded.
Expand All @@ -146,7 +146,7 @@ internal class CloudWatchLoggingDatabase(
if (path.exists()) {
path.delete()
}
encryptedKeyValueRepository.put(passphraseKey, passphrase)
amplifyKeyValueRepository.put(passphraseKey, passphrase)
passphrase
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ import com.amplifyframework.core.Amplify
import com.amplifyframework.core.Consumer
import com.amplifyframework.core.category.CategoryType
import com.amplifyframework.core.configuration.AmplifyOutputsData
import com.amplifyframework.core.store.EncryptedKeyValueRepository
import com.amplifyframework.core.store.AmplifyKeyValueRepository
import com.amplifyframework.core.store.KeyValueRepository
import com.amplifyframework.notifications.pushnotifications.NotificationPayload
import com.amplifyframework.notifications.pushnotifications.PushNotificationResult
Expand Down Expand Up @@ -120,7 +120,7 @@ class AWSPinpointPushNotificationsPlugin : PushNotificationsPlugin<PinpointClien
configuration.appId + AWS_PINPOINT_PUSHNOTIFICATIONS_PREFERENCES_SUFFIX,
Context.MODE_PRIVATE
)
store = EncryptedKeyValueRepository(
store = AmplifyKeyValueRepository(
context,
configuration.appId + AWS_PINPOINT_PUSHNOTIFICATIONS_PREFERENCES_SUFFIX
)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
/*
* Copyright 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at
*
* http://aws.amazon.com/apache2.0
*
* or in the "license" file accompanying this file. This file is distributed
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*/
package com.amplifyframework.core.store

import android.content.Context
import com.amplifyframework.annotations.InternalAmplifyApi

@InternalAmplifyApi
class AmplifyKeyValueRepository(
private val context: Context,
private val sharedPreferencesName: String
) : KeyValueRepository {

// We attempt to get an encrypted persistent repository, but if that fails, use an in memory one instead.
private val repository: KeyValueRepository by lazy {
try {
EncryptedKeyValueRepository(context, sharedPreferencesName).also {
// This attempts to open EncryptedSharedPrefs. If it opens, we are good to use.
it.sharedPreferences
}
} catch (exception: Exception) {
// We crashed attempting to open EncryptedKeyValueRepository, use In-Memory Instead.
InMemoryKeyValueRepositoryProvider.getKeyValueRepository(sharedPreferencesName)
}
}

override fun get(dataKey: String): String? = repository.get(dataKey)

override fun put(dataKey: String, value: String?) = repository.put(dataKey, value)

override fun remove(dataKey: String) = repository.remove(dataKey)

override fun removeAll() = repository.removeAll()
}
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class EncryptedKeyValueRepository @VisibleForTesting constructor(
fileFactory = { dir, fileName -> File(dir, fileName) }
)

private val sharedPreferences by lazy { getOrCreateSharedPreferences() }
internal val sharedPreferences by lazy { getOrCreateSharedPreferences() }

override fun put(dataKey: String, value: String?) = edit { putString(dataKey, value) }
override fun get(dataKey: String): String? = sharedPreferences.getString(dataKey, null)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
/*
* Copyright 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at
*
* http://aws.amazon.com/apache2.0
*
* or in the "license" file accompanying this file. This file is distributed
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*/
package com.amplifyframework.core.store

import java.util.concurrent.ConcurrentHashMap

internal object InMemoryKeyValueRepositoryProvider {
private val inMemoryRepositories = ConcurrentHashMap<String, InMemoryKeyValueRepository>()

@Synchronized
fun getKeyValueRepository(name: String): InMemoryKeyValueRepository {
return inMemoryRepositories.getOrPut(name) { InMemoryKeyValueRepository() }
}
}
Loading

0 comments on commit 97ddb8b

Please sign in to comment.