diff --git a/firebase-ml-modeldownloader/CHANGELOG.md b/firebase-ml-modeldownloader/CHANGELOG.md index e2773e0ea49..9a6b9c8ea6f 100644 --- a/firebase-ml-modeldownloader/CHANGELOG.md +++ b/firebase-ml-modeldownloader/CHANGELOG.md @@ -1,4 +1,5 @@ # Unreleased +- [changed] Internal infrastructure improvements. # 24.1.1 * [fixed] Fixed an issue where `FirebaseModelDownloader.getModel` was throwing diff --git a/firebase-ml-modeldownloader/firebase-ml-modeldownloader.gradle b/firebase-ml-modeldownloader/firebase-ml-modeldownloader.gradle index efdc6e389c8..7b88ae75e91 100644 --- a/firebase-ml-modeldownloader/firebase-ml-modeldownloader.gradle +++ b/firebase-ml-modeldownloader/firebase-ml-modeldownloader.gradle @@ -14,6 +14,7 @@ plugins { id 'firebase-library' + id 'firebase-vendor' id 'com.google.protobuf' } @@ -76,6 +77,12 @@ dependencies { implementation 'com.google.auto.service:auto-service-annotations:1.0-rc6' implementation 'javax.inject:javax.inject:1' + implementation 'javax.inject:javax.inject:1' + vendor ('com.google.dagger:dagger:2.43.2') { + exclude group: "javax.inject", module: "javax.inject" + } + annotationProcessor 'com.google.dagger:dagger-compiler:2.43.2' + compileOnly "com.google.auto.value:auto-value-annotations:1.6.6" annotationProcessor "com.google.auto.value:auto-value:1.6.5" annotationProcessor project(":encoders:firebase-encoders-processor") diff --git a/firebase-ml-modeldownloader/ktx/ktx.gradle b/firebase-ml-modeldownloader/ktx/ktx.gradle index 396d946788e..849310f14cb 100644 --- a/firebase-ml-modeldownloader/ktx/ktx.gradle +++ b/firebase-ml-modeldownloader/ktx/ktx.gradle @@ -49,6 +49,9 @@ dependencies { implementation project(':firebase-ml-modeldownloader') testImplementation "org.robolectric:robolectric:$robolectricVersion" - testImplementation 'junit:junit:4.12' + testImplementation 'junit:junit:4.13.2' testImplementation "com.google.truth:truth:$googleTruthVersion" + testImplementation 'org.mockito:mockito-core:3.6.0' + testImplementation 'androidx.test:runner:1.5.1' + testImplementation 'androidx.test.ext:junit:1.1.4' } diff --git a/firebase-ml-modeldownloader/ktx/src/test/kotlin/com/google/firebase/ml/modeldownloader/ktx/ModelDownloaderTests.kt b/firebase-ml-modeldownloader/ktx/src/test/kotlin/com/google/firebase/ml/modeldownloader/ktx/ModelDownloaderTests.kt index cb049e7bc90..b2d8be7eab9 100644 --- a/firebase-ml-modeldownloader/ktx/src/test/kotlin/com/google/firebase/ml/modeldownloader/ktx/ModelDownloaderTests.kt +++ b/firebase-ml-modeldownloader/ktx/src/test/kotlin/com/google/firebase/ml/modeldownloader/ktx/ModelDownloaderTests.kt @@ -14,13 +14,14 @@ package com.google.firebase.ml.modeldownloader.ktx +import androidx.test.core.app.ApplicationProvider +import androidx.test.ext.junit.runners.AndroidJUnit4 import com.google.common.truth.Truth.assertThat import com.google.firebase.FirebaseApp import com.google.firebase.FirebaseOptions import com.google.firebase.ktx.Firebase import com.google.firebase.ktx.app import com.google.firebase.ktx.initialize -import com.google.firebase.ml.modeldownloader.CustomModel import com.google.firebase.ml.modeldownloader.FirebaseModelDownloader import com.google.firebase.platforminfo.UserAgentPublisher import org.junit.After @@ -28,7 +29,6 @@ import org.junit.Before import org.junit.Test import org.junit.runner.RunWith import org.robolectric.RobolectricTestRunner -import org.robolectric.RuntimeEnvironment const val APP_ID = "1:1234567890:android:321abc456def7890" const val API_KEY = "AIzaSyDOCAbC123dEf456GhI789jKl012-MnO" @@ -39,7 +39,7 @@ abstract class BaseTestCase { @Before fun setUp() { Firebase.initialize( - RuntimeEnvironment.application, + ApplicationProvider.getApplicationContext(), FirebaseOptions.Builder() .setApplicationId(APP_ID) .setApiKey(API_KEY) @@ -47,7 +47,7 @@ abstract class BaseTestCase { .build() ) Firebase.initialize( - RuntimeEnvironment.application, + ApplicationProvider.getApplicationContext(), FirebaseOptions.Builder() .setApplicationId(APP_ID) .setApiKey(API_KEY) @@ -63,7 +63,7 @@ abstract class BaseTestCase { } } -@RunWith(RobolectricTestRunner::class) +@RunWith(AndroidJUnit4::class) class ModelDownloaderTests : BaseTestCase() { @Test @@ -92,14 +92,17 @@ class ModelDownloaderTests : BaseTestCase() { @Test fun `CustomModel destructuring declarations work`() { + val app = Firebase.app(EXISTING_APP) + val modelName = "myModel" val modelHash = "someHash" val fileSize = 200L val downloadId = 258L - val customModel = CustomModel(modelName, modelHash, fileSize, downloadId) + val customModel = + Firebase.modelDownloader(app).modelFactory.create(modelName, modelHash, fileSize, downloadId) - val (file, size, id, hash, name) = customModel + val (_, size, id, hash, name) = customModel assertThat(name).isEqualTo(customModel.name) assertThat(hash).isEqualTo(customModel.modelHash) diff --git a/firebase-ml-modeldownloader/ml-data-collection-tests/src/test/java/com/google/firebase/ml_data_collection_tests/MlDataCollectionTestUtil.java b/firebase-ml-modeldownloader/ml-data-collection-tests/src/test/java/com/google/firebase/ml_data_collection_tests/MlDataCollectionTestUtil.java index cca82b8e6f0..2b9788e8d84 100644 --- a/firebase-ml-modeldownloader/ml-data-collection-tests/src/test/java/com/google/firebase/ml_data_collection_tests/MlDataCollectionTestUtil.java +++ b/firebase-ml-modeldownloader/ml-data-collection-tests/src/test/java/com/google/firebase/ml_data_collection_tests/MlDataCollectionTestUtil.java @@ -14,11 +14,10 @@ package com.google.firebase.ml_data_collection_tests; -import android.content.Context; -import android.content.SharedPreferences; import androidx.test.core.app.ApplicationProvider; import com.google.firebase.FirebaseApp; import com.google.firebase.FirebaseOptions; +import com.google.firebase.ml.modeldownloader.FirebaseModelDownloader; import com.google.firebase.ml.modeldownloader.internal.SharedPreferencesUtil; import java.util.function.Consumer; @@ -47,12 +46,8 @@ static void withApp(String name, Consumer callable) { } static SharedPreferencesUtil getSharedPreferencesUtil(FirebaseApp app) { - return new SharedPreferencesUtil(app); - } - - static SharedPreferences getSharedPreferences(FirebaseApp app) { - return app.getApplicationContext() - .getSharedPreferences(SharedPreferencesUtil.PREFERENCES_PACKAGE_NAME, Context.MODE_PRIVATE); + return new SharedPreferencesUtil( + app, FirebaseModelDownloader.getInstance(app).getModelFactory()); } static void setSharedPreferencesTo(FirebaseApp app, Boolean enabled) { diff --git a/firebase-ml-modeldownloader/src/androidTest/java/com/google/firebase/ml/modeldownloader/TestGetModelLocal.java b/firebase-ml-modeldownloader/src/androidTest/java/com/google/firebase/ml/modeldownloader/TestGetModelLocal.java index 1febf4a9785..8fa23f00bb5 100644 --- a/firebase-ml-modeldownloader/src/androidTest/java/com/google/firebase/ml/modeldownloader/TestGetModelLocal.java +++ b/firebase-ml-modeldownloader/src/androidTest/java/com/google/firebase/ml/modeldownloader/TestGetModelLocal.java @@ -45,20 +45,25 @@ public class TestGetModelLocal { private static final String MODEL_NAME_LOCAL = "getLocalModel"; private static final String MODEL_NAME_LOCAL_2 = "getLocalModel2"; private static final String MODEL_HASH = "origHash324"; - private final CustomModel SETUP_LOADED_LOCAL_MODEL = - new CustomModel(MODEL_NAME_LOCAL, MODEL_HASH, 100, 0); private FirebaseApp app; private File firstDeviceModelFile; private File firstLoadTempModelFile; + private CustomModel.Factory modelFactory; + private CustomModel setupLoadedLocalModel; + @Before public void before() { app = FirebaseApp.initializeApp(ApplicationProvider.getApplicationContext()); app.setDataCollectionDefaultEnabled(Boolean.FALSE); FirebaseModelDownloader firebaseModelDownloader = FirebaseModelDownloader.getInstance(app); - SharedPreferencesUtil sharedPreferencesUtil = new SharedPreferencesUtil(app); + modelFactory = firebaseModelDownloader.getModelFactory(); + + setupLoadedLocalModel = modelFactory.create(MODEL_NAME_LOCAL, MODEL_HASH, 100, 0); + + SharedPreferencesUtil sharedPreferencesUtil = new SharedPreferencesUtil(app, modelFactory); // reset shared preferences and downloads for models used by this test. firebaseModelDownloader.deleteDownloadedModel(MODEL_NAME_LOCAL); firebaseModelDownloader.deleteDownloadedModel(MODEL_NAME_LOCAL_2); @@ -79,7 +84,11 @@ public void teardown() { } private void setUpLoadedLocalModelWithFile() throws Exception { - ModelFileManager fileManager = ModelFileManager.getInstance(); + ModelFileManager fileManager = + new ModelFileManager( + app.getApplicationContext(), + app.getPersistenceKey(), + new SharedPreferencesUtil(app, modelFactory)); final File testDir = new File(app.getApplicationContext().getNoBackupFilesDir(), "tmpModels"); testDir.mkdirs(); // make sure the directory is empty. Doesn't recurse into subdirs, but that's OK since @@ -105,14 +114,14 @@ private void setUpLoadedLocalModelWithFile() throws Exception { ParcelFileDescriptor fd = ParcelFileDescriptor.open(firstLoadTempModelFile, ParcelFileDescriptor.MODE_READ_ONLY); - firstDeviceModelFile = fileManager.moveModelToDestinationFolder(SETUP_LOADED_LOCAL_MODEL, fd); + firstDeviceModelFile = fileManager.moveModelToDestinationFolder(setupLoadedLocalModel, fd); assertEquals(firstDeviceModelFile, new File(expectedDestinationFolder + "/0")); assertTrue(firstDeviceModelFile.exists()); fd.close(); fakePreloadedCustomModel( MODEL_NAME_LOCAL, - SETUP_LOADED_LOCAL_MODEL.getModelHash(), + setupLoadedLocalModel.getModelHash(), 99, expectedDestinationFolder + "/0"); } @@ -228,9 +237,9 @@ public void localModel_preloadedDoNotFetchUpdate() throws Exception { } private void fakePreloadedCustomModel(String modelName, String hash, long size, String filePath) { - SharedPreferencesUtil sharedPreferencesUtil = new SharedPreferencesUtil(app); + SharedPreferencesUtil sharedPreferencesUtil = new SharedPreferencesUtil(app, modelFactory); sharedPreferencesUtil.setLoadedCustomModelDetails( - new CustomModel(modelName, hash, size, 0L, filePath)); + modelFactory.create(modelName, hash, size, 0L, filePath)); } private Set getDownloadedModelList() diff --git a/firebase-ml-modeldownloader/src/androidTest/java/com/google/firebase/ml/modeldownloader/TestPublicApi.java b/firebase-ml-modeldownloader/src/androidTest/java/com/google/firebase/ml/modeldownloader/TestPublicApi.java index 52f58a32137..f1dfbf2a841 100644 --- a/firebase-ml-modeldownloader/src/androidTest/java/com/google/firebase/ml/modeldownloader/TestPublicApi.java +++ b/firebase-ml-modeldownloader/src/androidTest/java/com/google/firebase/ml/modeldownloader/TestPublicApi.java @@ -45,7 +45,8 @@ public void before() { app.setDataCollectionDefaultEnabled(Boolean.FALSE); FirebaseModelDownloader firebaseModelDownloader = FirebaseModelDownloader.getInstance(app); - SharedPreferencesUtil sharedPreferencesUtil = new SharedPreferencesUtil(app); + SharedPreferencesUtil sharedPreferencesUtil = + new SharedPreferencesUtil(app, firebaseModelDownloader.getModelFactory()); // reset shared preferences and downloads for models used by this test. firebaseModelDownloader.deleteDownloadedModel(MODEL_NAME_LOCAL); firebaseModelDownloader.deleteDownloadedModel(MODEL_NAME_LOCAL_2); diff --git a/firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/CustomModel.java b/firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/CustomModel.java index 2395161ed98..16d9c32c6f0 100644 --- a/firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/CustomModel.java +++ b/firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/CustomModel.java @@ -16,9 +16,13 @@ import androidx.annotation.NonNull; import androidx.annotation.Nullable; +import androidx.annotation.RestrictTo; import androidx.annotation.VisibleForTesting; import com.google.android.gms.common.internal.Objects; import com.google.firebase.ml.modeldownloader.internal.ModelFileDownloadService; +import dagger.assisted.Assisted; +import dagger.assisted.AssistedFactory; +import dagger.assisted.AssistedInject; import java.io.File; /** @@ -27,6 +31,7 @@ * downloaded, the original model file will be removed once it is safe to do so. */ public class CustomModel { + private final ModelFileDownloadService fileDownloadService; private final String name; private final long downloadId; private final long fileSize; @@ -35,57 +40,31 @@ public class CustomModel { private final String downloadUrl; private final long downloadUrlExpiry; - /** - * Use when creating a custom model while the initial download is still in progress. - * - * @param name Model name. - * @param modelHash Model hash. - * @param fileSize Model file size. - * @param downloadId Android Download Manger - download ID. - * @hide - */ - public CustomModel( - @NonNull String name, @NonNull String modelHash, long fileSize, long downloadId) { - this(name, modelHash, fileSize, downloadId, "", "", 0); - } + /** @hide */ + @AssistedFactory + public interface Factory { + CustomModel create( + @Assisted("name") String name, + @Assisted("modelHash") String modelHash, + @Assisted("fileSize") long fileSize, + @Assisted("downloadId") long downloadId, + @Assisted("localFilePath") String localFilePath, + @Assisted("downloadUrl") String downloadUrl, + @Assisted("downloadUrlExpiry") long downloadUrlExpiry); - /** - * Use when creating a custom model from a stored model with a new download in the background. - * - * @param name Model name. - * @param modelHash Model hash. - * @param fileSize Model file size. - * @param downloadId Android Download Manger - download ID. - * @hide - */ - public CustomModel( - @NonNull String name, - @NonNull String modelHash, - long fileSize, - long downloadId, - String localFilePath) { - this(name, modelHash, fileSize, downloadId, localFilePath, "", 0); - } + default CustomModel create(String name, String modelHash, long fileSize, long downloadId) { + return create(name, modelHash, fileSize, downloadId, "", "", 0); + } - /** - * Use when creating a custom model from a download service response. Download URL and download - * URL expiry should go together. These will not be stored in user preferences as this is a - * temporary step towards setting the actual download ID. - * - * @param name Model name. - * @param modelHash Model hash. - * @param fileSize Model file size. - * @param downloadUrl Download URL path - * @param downloadUrlExpiry Time download URL path expires. - * @hide - */ - public CustomModel( - @NonNull String name, - @NonNull String modelHash, - long fileSize, - String downloadUrl, - long downloadUrlExpiry) { - this(name, modelHash, fileSize, 0, "", downloadUrl, downloadUrlExpiry); + default CustomModel create( + String name, String modelHash, long fileSize, long downloadId, String localFilePath) { + return create(name, modelHash, fileSize, downloadId, localFilePath, "", 0); + } + + default CustomModel create( + String name, String modelHash, long fileSize, String downloadUrl, long downloadUrlExpiry) { + return create(name, modelHash, fileSize, 0, "", downloadUrl, downloadUrlExpiry); + } } /** @@ -100,14 +79,19 @@ public CustomModel( * @param downloadUrlExpiry Expiry time of download URL link. * @hide */ - private CustomModel( - @NonNull String name, - @NonNull String modelHash, - long fileSize, - long downloadId, - @Nullable String localFilePath, - @Nullable String downloadUrl, - long downloadUrlExpiry) { + @AssistedInject + @VisibleForTesting + @RestrictTo(RestrictTo.Scope.LIBRARY) + public CustomModel( + ModelFileDownloadService fileDownloadService, + @Assisted("name") String name, + @Assisted("modelHash") String modelHash, + @Assisted("fileSize") long fileSize, + @Assisted("downloadId") long downloadId, + @Assisted("localFilePath") String localFilePath, + @Assisted("downloadUrl") String downloadUrl, + @Assisted("downloadUrlExpiry") long downloadUrlExpiry) { + this.fileDownloadService = fileDownloadService; this.modelHash = modelHash; this.name = name; this.fileSize = fileSize; @@ -137,7 +121,7 @@ public String getName() { */ @Nullable public File getFile() { - return getFile(ModelFileDownloadService.getInstance()); + return getFile(fileDownloadService); } /** diff --git a/firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/FirebaseModelDownloader.java b/firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/FirebaseModelDownloader.java index ac67275d045..5747560bb50 100644 --- a/firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/FirebaseModelDownloader.java +++ b/firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/FirebaseModelDownloader.java @@ -26,7 +26,6 @@ import com.google.android.gms.tasks.Tasks; import com.google.firebase.FirebaseApp; import com.google.firebase.FirebaseOptions; -import com.google.firebase.installations.FirebaseInstallationsApi; import com.google.firebase.ml.modeldownloader.internal.CustomModelDownloadService; import com.google.firebase.ml.modeldownloader.internal.FirebaseMlLogEvent.ModelDownloadLogEvent.DownloadStatus; import com.google.firebase.ml.modeldownloader.internal.FirebaseMlLogEvent.ModelDownloadLogEvent.ErrorCode; @@ -38,6 +37,7 @@ import java.util.Set; import java.util.concurrent.Executor; import java.util.concurrent.Executors; +import javax.inject.Inject; public class FirebaseModelDownloader { @@ -50,21 +50,29 @@ public class FirebaseModelDownloader { private final Executor executor; private final FirebaseMlLogger eventLogger; + private final CustomModel.Factory modelFactory; + @Inject @RequiresApi(api = VERSION_CODES.KITKAT) // TODO(b/258424267): Migrate to go/firebase-android-executors @SuppressLint("ThreadPoolCreation") FirebaseModelDownloader( - FirebaseApp firebaseApp, FirebaseInstallationsApi firebaseInstallationsApi) { - this.firebaseOptions = firebaseApp.getOptions(); - this.sharedPreferencesUtil = new SharedPreferencesUtil(firebaseApp); - this.eventLogger = FirebaseMlLogger.getInstance(firebaseApp); - this.fileDownloadService = new ModelFileDownloadService(firebaseApp); - this.modelDownloadService = - new CustomModelDownloadService(firebaseApp, firebaseInstallationsApi); - - this.executor = Executors.newSingleThreadExecutor(); - fileManager = ModelFileManager.getInstance(); + FirebaseOptions firebaseOptions, + SharedPreferencesUtil sharedPreferencesUtil, + ModelFileDownloadService fileDownloadService, + CustomModelDownloadService modelDownloadService, + ModelFileManager fileManager, + FirebaseMlLogger eventLogger, + CustomModel.Factory modelFactory) { + this( + firebaseOptions, + sharedPreferencesUtil, + fileDownloadService, + modelDownloadService, + fileManager, + eventLogger, + Executors.newSingleThreadExecutor(), + modelFactory); } @VisibleForTesting @@ -75,7 +83,8 @@ public class FirebaseModelDownloader { CustomModelDownloadService modelDownloadService, ModelFileManager fileManager, FirebaseMlLogger eventLogger, - Executor executor) { + Executor executor, + CustomModel.Factory modelFactory) { this.firebaseOptions = firebaseOptions; this.sharedPreferencesUtil = sharedPreferencesUtil; this.fileDownloadService = fileDownloadService; @@ -83,6 +92,7 @@ public class FirebaseModelDownloader { this.fileManager = fileManager; this.eventLogger = eventLogger; this.executor = executor; + this.modelFactory = modelFactory; } /** @@ -539,4 +549,10 @@ public Task getModelDownloadId( String getApplicationId() { return firebaseOptions.getApplicationId(); } + + /** @hide */ + @VisibleForTesting + public CustomModel.Factory getModelFactory() { + return modelFactory; + } } diff --git a/firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/FirebaseModelDownloaderRegistrar.java b/firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/FirebaseModelDownloaderRegistrar.java index cb92e0950ac..dcb85a2a864 100644 --- a/firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/FirebaseModelDownloaderRegistrar.java +++ b/firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/FirebaseModelDownloaderRegistrar.java @@ -14,6 +14,7 @@ package com.google.firebase.ml.modeldownloader; +import android.content.Context; import android.os.Build.VERSION_CODES; import androidx.annotation.NonNull; import androidx.annotation.RequiresApi; @@ -23,11 +24,6 @@ import com.google.firebase.components.ComponentRegistrar; import com.google.firebase.components.Dependency; import com.google.firebase.installations.FirebaseInstallationsApi; -import com.google.firebase.ml.modeldownloader.internal.CustomModelDownloadService; -import com.google.firebase.ml.modeldownloader.internal.FirebaseMlLogger; -import com.google.firebase.ml.modeldownloader.internal.ModelFileDownloadService; -import com.google.firebase.ml.modeldownloader.internal.ModelFileManager; -import com.google.firebase.ml.modeldownloader.internal.SharedPreferencesUtil; import com.google.firebase.platforminfo.LibraryVersionComponent; import java.util.Arrays; import java.util.List; @@ -48,43 +44,19 @@ public List> getComponents() { return Arrays.asList( Component.builder(FirebaseModelDownloader.class) .name(LIBRARY_NAME) + .add(Dependency.required(Context.class)) .add(Dependency.required(FirebaseApp.class)) - .add(Dependency.required(FirebaseInstallationsApi.class)) + .add(Dependency.requiredProvider(FirebaseInstallationsApi.class)) + .add(Dependency.requiredProvider(TransportFactory.class)) .factory( c -> - new FirebaseModelDownloader( - c.get(FirebaseApp.class), c.get(FirebaseInstallationsApi.class))) - .build(), - Component.builder(SharedPreferencesUtil.class) - .add(Dependency.required(FirebaseApp.class)) - .factory(c -> new SharedPreferencesUtil(c.get(FirebaseApp.class))) - .build(), - Component.builder(FirebaseMlLogger.class) - .add(Dependency.required(FirebaseApp.class)) - .add(Dependency.required(TransportFactory.class)) - .add(Dependency.required(SharedPreferencesUtil.class)) - .factory( - c -> - new FirebaseMlLogger( - c.get(FirebaseApp.class), - c.get(SharedPreferencesUtil.class), - c.get(TransportFactory.class))) - .build(), - Component.builder(ModelFileManager.class) - .add(Dependency.required(FirebaseApp.class)) - .factory(c -> new ModelFileManager(c.get(FirebaseApp.class))) - .build(), - Component.builder(ModelFileDownloadService.class) - .add(Dependency.required(FirebaseApp.class)) - .factory(c -> new ModelFileDownloadService(c.get(FirebaseApp.class))) - .build(), - Component.builder(CustomModelDownloadService.class) - .add(Dependency.required(FirebaseApp.class)) - .add(Dependency.required(FirebaseInstallationsApi.class)) - .factory( - c -> - new CustomModelDownloadService( - c.get(FirebaseApp.class), c.get(FirebaseInstallationsApi.class))) + DaggerModelDownloaderComponent.builder() + .setApplicationContext(c.get(Context.class)) + .setFirebaseApp(c.get(FirebaseApp.class)) + .setFis(c.getProvider(FirebaseInstallationsApi.class)) + .setTransportFactory(c.getProvider(TransportFactory.class)) + .build() + .getModelDownloader()) .build(), LibraryVersionComponent.create(LIBRARY_NAME, BuildConfig.VERSION_NAME)); } diff --git a/firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/ModelDownloaderComponent.java b/firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/ModelDownloaderComponent.java new file mode 100644 index 00000000000..f69036c4a0a --- /dev/null +++ b/firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/ModelDownloaderComponent.java @@ -0,0 +1,88 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.firebase.ml.modeldownloader; + +import android.content.Context; +import android.content.pm.PackageManager; +import android.util.Log; +import com.google.android.datatransport.TransportFactory; +import com.google.firebase.FirebaseApp; +import com.google.firebase.FirebaseOptions; +import com.google.firebase.inject.Provider; +import com.google.firebase.installations.FirebaseInstallationsApi; +import dagger.BindsInstance; +import dagger.Component; +import dagger.Module; +import dagger.Provides; +import javax.inject.Named; +import javax.inject.Singleton; + +/** @hide */ +@Component(modules = ModelDownloaderComponent.MainModule.class) +@Singleton +interface ModelDownloaderComponent { + FirebaseModelDownloader getModelDownloader(); + + @Component.Builder + interface Builder { + @BindsInstance + Builder setApplicationContext(Context context); + + @BindsInstance + Builder setFirebaseApp(FirebaseApp app); + + @BindsInstance + Builder setFis(Provider fis); + + @BindsInstance + Builder setTransportFactory(Provider transportFactory); + + ModelDownloaderComponent build(); + } + + @Module + interface MainModule { + @Provides + @Named("persistenceKey") + static String persistenceKey(FirebaseApp app) { + return app.getPersistenceKey(); + } + + @Provides + @Named("appPackageName") + static String appPackageName(Context applicationContext) { + return applicationContext.getPackageName(); + } + + @Provides + static FirebaseOptions firebaseOptions(FirebaseApp app) { + return app.getOptions(); + } + + @Provides + @Singleton + @Named("appVersionCode") + static String appVersionCode( + Context applicationContext, @Named("appPackageName") String appPackageName) { + try { + return String.valueOf( + applicationContext.getPackageManager().getPackageInfo(appPackageName, 0).versionCode); + } catch (PackageManager.NameNotFoundException e) { + Log.e("ModelDownloader", "Exception thrown when trying to get app version " + e); + } + return ""; + } + } +} diff --git a/firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/internal/CustomModelDownloadService.java b/firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/internal/CustomModelDownloadService.java index b046a6a3778..edab83e1847 100644 --- a/firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/internal/CustomModelDownloadService.java +++ b/firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/internal/CustomModelDownloadService.java @@ -27,7 +27,8 @@ import com.google.android.gms.common.util.VisibleForTesting; import com.google.android.gms.tasks.Task; import com.google.android.gms.tasks.Tasks; -import com.google.firebase.FirebaseApp; +import com.google.firebase.FirebaseOptions; +import com.google.firebase.inject.Provider; import com.google.firebase.installations.FirebaseInstallationsApi; import com.google.firebase.installations.InstallationTokenResult; import com.google.firebase.ml.modeldownloader.CustomModel; @@ -50,6 +51,7 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.zip.GZIPInputStream; +import javax.inject.Inject; import org.json.JSONObject; /** @@ -87,34 +89,42 @@ public class CustomModelDownloadService { static final String DOWNLOAD_MODEL_REGEX = "%s/v1beta2/projects/%s/models/%s:download"; private final ExecutorService executorService; - private final FirebaseInstallationsApi firebaseInstallations; + private final Provider firebaseInstallations; private final FirebaseMlLogger eventLogger; private final String apiKey; @Nullable private final String fingerprintHashForPackage; private final Context context; + private final CustomModel.Factory modelFactory; private String downloadHost = FIREBASE_DOWNLOAD_HOST; // TODO(b/258424267): Migrate to go/firebase-android-executors @SuppressLint("ThreadPoolCreation") + @Inject public CustomModelDownloadService( - FirebaseApp firebaseApp, FirebaseInstallationsApi installationsApi) { - context = firebaseApp.getApplicationContext(); + Context context, + FirebaseOptions options, + Provider installationsApi, + FirebaseMlLogger eventLogger, + CustomModel.Factory modelFactory) { + this.context = context; firebaseInstallations = installationsApi; - apiKey = firebaseApp.getOptions().getApiKey(); + apiKey = options.getApiKey(); fingerprintHashForPackage = getFingerprintHashForPackage(context); executorService = Executors.newCachedThreadPool(); - this.eventLogger = FirebaseMlLogger.getInstance(); + this.eventLogger = eventLogger; + this.modelFactory = modelFactory; } @VisibleForTesting CustomModelDownloadService( Context context, - FirebaseInstallationsApi firebaseInstallations, + Provider firebaseInstallations, ExecutorService executorService, String apiKey, String fingerprintHashForPackage, String downloadHost, - FirebaseMlLogger eventLogger) { + FirebaseMlLogger eventLogger, + CustomModel.Factory modelFactory) { this.context = context; this.firebaseInstallations = firebaseInstallations; this.executorService = executorService; @@ -122,6 +132,7 @@ public CustomModelDownloadService( this.fingerprintHashForPackage = fingerprintHashForPackage; this.downloadHost = downloadHost; this.eventLogger = eventLogger; + this.modelFactory = modelFactory; } /** @@ -169,7 +180,7 @@ public Task getCustomModelDetails( } Task installationAuthTokenTask = - firebaseInstallations.getToken(false); + firebaseInstallations.get().getToken(false); return installationAuthTokenTask.continueWithTask( executorService, (CustomModelTask) -> { @@ -186,7 +197,7 @@ public Task getCustomModelDetails( exceptionCode = FirebaseMlException.NO_NETWORK_CONNECTION; } eventLogger.logDownloadFailureWithReason( - new CustomModel(modelName, modelHash != null ? modelHash : "", 0, 0L), + modelFactory.create(modelName, modelHash != null ? modelHash : "", 0, 0L), false, errorCode.getValue()); return Tasks.forException(new FirebaseMlException(errorMessage, exceptionCode)); @@ -208,7 +219,7 @@ public Task getCustomModelDetails( } catch (IOException e) { eventLogger.logDownloadFailureWithReason( - new CustomModel(modelName, modelHash, 0, 0L), + modelFactory.create(modelName, modelHash, 0, 0L), false, ErrorCode.MODEL_INFO_DOWNLOAD_CONNECTION_FAILED.getValue()); @@ -314,7 +325,7 @@ private Task fetchDownloadDetails(String modelName, HttpURLConnecti errorMessage = "Failed to retrieve model info due to no internet connection."; exceptionCode = FirebaseMlException.NO_NETWORK_CONNECTION; } - eventLogger.logModelInfoRetrieverFailure(new CustomModel(modelName, "", 0, 0), errorCode); + eventLogger.logModelInfoRetrieverFailure(modelFactory.create(modelName, "", 0, 0), errorCode); return Tasks.forException(new FirebaseMlException(errorMessage, exceptionCode)); } } @@ -322,7 +333,7 @@ private Task fetchDownloadDetails(String modelName, HttpURLConnecti private Task setAndLogException( String modelName, int httpResponseCode, String errorMessage, @Code int invalidArgument) { eventLogger.logModelInfoRetrieverFailure( - new CustomModel(modelName, "", 0, 0), + modelFactory.create(modelName, "", 0, 0), ErrorCode.MODEL_INFO_DOWNLOAD_UNSUCCESSFUL_HTTP_STATUS, httpResponseCode); return Tasks.forException(new FirebaseMlException(errorMessage, invalidArgument)); @@ -341,7 +352,7 @@ private Task readCustomModelResponse( if (modelHash == null || modelHash.isEmpty()) { eventLogger.logDownloadFailureWithReason( - new CustomModel(modelName, modelHash, 0, 0L), + modelFactory.create(modelName, modelHash, 0, 0L), false, ErrorCode.MODEL_INFO_DOWNLOAD_CONNECTION_FAILED.getValue()); return Tasks.forException( @@ -375,12 +386,13 @@ private Task readCustomModelResponse( inputStream.close(); if (!downloadUrl.isEmpty() && expireTime > 0L) { - CustomModel model = new CustomModel(modelName, modelHash, fileSize, downloadUrl, expireTime); + CustomModel model = + modelFactory.create(modelName, modelHash, fileSize, downloadUrl, expireTime); eventLogger.logModelInfoRetrieverSuccess(model); return Tasks.forResult(model); } eventLogger.logDownloadFailureWithReason( - new CustomModel(modelName, modelHash, 0, 0L), + modelFactory.create(modelName, modelHash, 0, 0L), false, ErrorCode.MODEL_INFO_DOWNLOAD_CONNECTION_FAILED.getValue()); return Tasks.forException( diff --git a/firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/internal/DataTransportMlEventSender.java b/firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/internal/DataTransportMlEventSender.java index 8dbae13bd29..ca6ec9d65c1 100644 --- a/firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/internal/DataTransportMlEventSender.java +++ b/firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/internal/DataTransportMlEventSender.java @@ -19,6 +19,10 @@ import com.google.android.datatransport.Event; import com.google.android.datatransport.Transport; import com.google.android.datatransport.TransportFactory; +import com.google.firebase.components.Lazy; +import com.google.firebase.inject.Provider; +import javax.inject.Inject; +import javax.inject.Singleton; /** * This class is responsible for sending Firebase ML Log Events to Firebase through Google @@ -28,26 +32,26 @@ * * @hide */ +@Singleton public class DataTransportMlEventSender { private static final String FIREBASE_ML_LOG_SDK_NAME = "FIREBASE_ML_LOG_SDK"; - private final Transport transport; + private final Provider> transport; - @NonNull - public static DataTransportMlEventSender create(TransportFactory transportFactory) { - final Transport transport = - transportFactory.getTransport( - FIREBASE_ML_LOG_SDK_NAME, - FirebaseMlLogEvent.class, - Encoding.of("json"), - FirebaseMlLogEvent.getFirebaseMlJsonTransformer()); - return new DataTransportMlEventSender(transport); - } - - DataTransportMlEventSender(Transport transport) { - this.transport = transport; + @Inject + DataTransportMlEventSender(Provider transportFactory) { + this.transport = + new Lazy<>( + () -> + transportFactory + .get() + .getTransport( + FIREBASE_ML_LOG_SDK_NAME, + FirebaseMlLogEvent.class, + Encoding.of("json"), + FirebaseMlLogEvent.getFirebaseMlJsonTransformer())); } public void sendEvent(@NonNull FirebaseMlLogEvent firebaseMlLogEvent) { - transport.send(Event.ofData(firebaseMlLogEvent)); + transport.get().send(Event.ofData(firebaseMlLogEvent)); } } diff --git a/firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/internal/FirebaseMlLogger.java b/firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/internal/FirebaseMlLogger.java index b68b61bf829..651fe901370 100644 --- a/firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/internal/FirebaseMlLogger.java +++ b/firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/internal/FirebaseMlLogger.java @@ -14,15 +14,11 @@ package com.google.firebase.ml.modeldownloader.internal; -import android.content.pm.PackageInfo; -import android.content.pm.PackageManager.NameNotFoundException; import android.os.SystemClock; import android.util.Log; import androidx.annotation.NonNull; -import androidx.annotation.VisibleForTesting; import androidx.annotation.WorkerThread; -import com.google.android.datatransport.TransportFactory; -import com.google.firebase.FirebaseApp; +import com.google.firebase.FirebaseOptions; import com.google.firebase.ml.modeldownloader.BuildConfig; import com.google.firebase.ml.modeldownloader.CustomModel; import com.google.firebase.ml.modeldownloader.internal.FirebaseMlLogEvent.DeleteModelLogEvent; @@ -33,6 +29,10 @@ import com.google.firebase.ml.modeldownloader.internal.FirebaseMlLogEvent.ModelDownloadLogEvent.ModelOptions; import com.google.firebase.ml.modeldownloader.internal.FirebaseMlLogEvent.ModelDownloadLogEvent.ModelOptions.ModelInfo; import com.google.firebase.ml.modeldownloader.internal.FirebaseMlLogEvent.SystemInfo; +import javax.inject.Inject; +import javax.inject.Named; +import javax.inject.Provider; +import javax.inject.Singleton; /** * Logging class for Firebase ML Event logging. @@ -40,62 +40,35 @@ * @hide */ @WorkerThread +@Singleton public class FirebaseMlLogger { public static final int NO_FAILURE_VALUE = 0; private static final String TAG = "FirebaseMlLogger"; private final SharedPreferencesUtil sharedPreferencesUtil; private final DataTransportMlEventSender eventSender; - private final FirebaseApp firebaseApp; + private final FirebaseOptions firebaseOptions; - private final String appPackageName; - private final String appVersion; + private final Provider appPackageName; + private final Provider appVersionCode; private final String firebaseProjectId; private final String apiKey; + @Inject public FirebaseMlLogger( - @NonNull FirebaseApp firebaseApp, - @NonNull SharedPreferencesUtil sharedPreferencesUtil, - @NonNull TransportFactory transportFactory) { - this.firebaseApp = firebaseApp; - this.sharedPreferencesUtil = sharedPreferencesUtil; - this.eventSender = DataTransportMlEventSender.create(transportFactory); - - this.firebaseProjectId = getProjectId(); - this.apiKey = getApiKey(); - this.appPackageName = firebaseApp.getApplicationContext().getPackageName(); - this.appVersion = getAppVersion(); - } - - @VisibleForTesting - FirebaseMlLogger( - @NonNull FirebaseApp firebaseApp, - @NonNull SharedPreferencesUtil sharedPreferencesUtil, - @NonNull DataTransportMlEventSender eventSender) { - this.firebaseApp = firebaseApp; + FirebaseOptions options, + SharedPreferencesUtil sharedPreferencesUtil, + DataTransportMlEventSender eventSender, + @Named("appPackageName") Provider appPackageName, + @Named("appVersionCode") Provider appVersionCode) { + this.firebaseOptions = options; this.sharedPreferencesUtil = sharedPreferencesUtil; this.eventSender = eventSender; this.firebaseProjectId = getProjectId(); this.apiKey = getApiKey(); - this.appPackageName = firebaseApp.getApplicationContext().getPackageName(); - this.appVersion = getAppVersion(); - } - - /** - * Get FirebaseMlLogger instance using the firebase app returned by {@link - * FirebaseApp#getInstance()} - * - * @return FirebaseMlLogger - */ - @NonNull - public static FirebaseMlLogger getInstance() { - return FirebaseApp.getInstance().get(FirebaseMlLogger.class); - } - - @NonNull - public static FirebaseMlLogger getInstance(@NonNull FirebaseApp app) { - return app.get(FirebaseMlLogger.class); + this.appPackageName = appPackageName; + this.appVersionCode = appVersionCode; } void logModelInfoRetrieverFailure(CustomModel model, ErrorCode errorCode) { @@ -256,33 +229,15 @@ private void logDownloadEvent( private SystemInfo getSystemInfo() { return SystemInfo.builder() .setFirebaseProjectId(firebaseProjectId) - .setAppId(appPackageName) - .setAppVersion(appVersion) + .setAppId(appPackageName.get()) + .setAppVersion(appVersionCode.get()) .setApiKey(apiKey) .setMlSdkVersion(BuildConfig.VERSION_NAME) .build(); } - private String getAppVersion() { - String version = ""; - try { - PackageInfo packageInfo = - firebaseApp - .getApplicationContext() - .getPackageManager() - .getPackageInfo(firebaseApp.getApplicationContext().getPackageName(), 0); - version = String.valueOf(packageInfo.versionCode); - } catch (NameNotFoundException e) { - Log.e(TAG, "Exception thrown when trying to get app version " + e); - } - return version; - } - private String getProjectId() { - if (firebaseApp == null) { - return ""; - } - String projectId = firebaseApp.getOptions().getProjectId(); + String projectId = firebaseOptions.getProjectId(); if (projectId == null) { return ""; } @@ -290,10 +245,6 @@ private String getProjectId() { } private String getApiKey() { - if (firebaseApp == null) { - return ""; - } - String key = firebaseApp.getOptions().getApiKey(); - return key == null ? "" : key; + return firebaseOptions.getApiKey(); } } diff --git a/firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/internal/ModelFileDownloadService.java b/firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/internal/ModelFileDownloadService.java index ba5ae50043d..2af5578df46 100644 --- a/firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/internal/ModelFileDownloadService.java +++ b/firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/internal/ModelFileDownloadService.java @@ -36,7 +36,6 @@ import com.google.android.gms.tasks.Task; import com.google.android.gms.tasks.TaskCompletionSource; import com.google.android.gms.tasks.Tasks; -import com.google.firebase.FirebaseApp; import com.google.firebase.ml.modeldownloader.CustomModel; import com.google.firebase.ml.modeldownloader.CustomModelDownloadConditions; import com.google.firebase.ml.modeldownloader.FirebaseMlException; @@ -47,6 +46,7 @@ import java.util.Date; import java.util.regex.Matcher; import java.util.regex.Pattern; +import javax.inject.Inject; /** * Calls the Android Download service to copy the model file to device (temp location) and then @@ -65,6 +65,7 @@ public class ModelFileDownloadService { private final ModelFileManager fileManager; private final SharedPreferencesUtil sharedPreferencesUtil; private final FirebaseMlLogger eventLogger; + private final CustomModel.Factory modelFactory; private boolean isInitialLoad; @@ -82,40 +83,39 @@ public class ModelFileDownloadService { private CustomModelDownloadConditions downloadConditions = new CustomModelDownloadConditions.Builder().build(); - public ModelFileDownloadService(@NonNull FirebaseApp firebaseApp) { - this.context = firebaseApp.getApplicationContext(); - downloadManager = (DownloadManager) context.getSystemService(Context.DOWNLOAD_SERVICE); - this.fileManager = ModelFileManager.getInstance(firebaseApp); - this.sharedPreferencesUtil = new SharedPreferencesUtil(firebaseApp); - this.isInitialLoad = true; - this.eventLogger = FirebaseMlLogger.getInstance(); + @Inject + public ModelFileDownloadService( + Context context, + FirebaseMlLogger eventLogger, + ModelFileManager modelFileManager, + SharedPreferencesUtil sharedPreferencesUtil, + CustomModel.Factory modelFactory) { + this( + context, + (DownloadManager) context.getSystemService(Context.DOWNLOAD_SERVICE), + modelFileManager, + sharedPreferencesUtil, + eventLogger, + true, + modelFactory); } @VisibleForTesting ModelFileDownloadService( - @NonNull FirebaseApp firebaseApp, + Context context, DownloadManager downloadManager, ModelFileManager fileManager, SharedPreferencesUtil sharedPreferencesUtil, FirebaseMlLogger eventLogger, - boolean isInitialLoad) { - this.context = firebaseApp.getApplicationContext(); + boolean isInitialLoad, + CustomModel.Factory modelFactory) { + this.context = context; this.downloadManager = downloadManager; this.fileManager = fileManager; this.sharedPreferencesUtil = sharedPreferencesUtil; this.eventLogger = eventLogger; this.isInitialLoad = isInitialLoad; - } - - /** - * Get ModelFileDownloadService instance using the firebase app returned by {@link - * FirebaseApp#getInstance()} - * - * @return ModelFileDownloadService - */ - @NonNull - public static ModelFileDownloadService getInstance() { - return FirebaseApp.getInstance().get(ModelFileDownloadService.class); + this.modelFactory = modelFactory; } public Task download( @@ -291,7 +291,7 @@ synchronized Long scheduleModelDownload(@NonNull CustomModel customModel) // update the custom model to store the download id - do not lose current local file - in case // this is a background update. CustomModel model = - new CustomModel( + modelFactory.create( customModel.getName(), customModel.getModelHash(), customModel.getSize(), @@ -429,7 +429,7 @@ public File loadNewlyDownloadedModelFile(CustomModel model) { + newModelFile.getParent()); // Successfully moved, update share preferences sharedPreferencesUtil.setLoadedCustomModelDetails( - new CustomModel( + modelFactory.create( model.getName(), model.getModelHash(), model.getSize(), 0, newModelFile.getPath())); maybeCleanUpOldModels(); diff --git a/firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/internal/ModelFileManager.java b/firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/internal/ModelFileManager.java index aa349ad9819..b318cdc5564 100644 --- a/firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/internal/ModelFileManager.java +++ b/firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/internal/ModelFileManager.java @@ -23,48 +23,39 @@ import androidx.annotation.Nullable; import androidx.annotation.VisibleForTesting; import androidx.annotation.WorkerThread; -import com.google.firebase.FirebaseApp; import com.google.firebase.ml.modeldownloader.CustomModel; import com.google.firebase.ml.modeldownloader.FirebaseMlException; import java.io.File; import java.io.FileInputStream; import java.io.FileOutputStream; import java.io.IOException; +import javax.inject.Inject; +import javax.inject.Named; +import javax.inject.Singleton; /** * Model File Manager is used to move the downloaded file to the appropriate locations. * * @hide */ +@Singleton public class ModelFileManager { public static final String CUSTOM_MODEL_ROOT_PATH = "com.google.firebase.ml.custom.models"; private static final String TAG = "FirebaseModelFileManage"; private static final int INVALID_INDEX = -1; private final Context context; - private final FirebaseApp firebaseApp; + private final String persistenceKey; private final SharedPreferencesUtil sharedPreferencesUtil; - public ModelFileManager(@NonNull FirebaseApp firebaseApp) { - this.context = firebaseApp.getApplicationContext(); - this.firebaseApp = firebaseApp; - this.sharedPreferencesUtil = new SharedPreferencesUtil(firebaseApp); - } - - /** - * Get ModelFileDownloadService instance using the firebase app returned by {@link - * FirebaseApp#getInstance()} - * - * @return ModelFileDownloadService - */ - @NonNull - public static ModelFileManager getInstance() { - return FirebaseApp.getInstance().get(ModelFileManager.class); - } - - @NonNull - public static ModelFileManager getInstance(@NonNull FirebaseApp app) { - return app.get(ModelFileManager.class); + @Inject + public ModelFileManager( + Context applicationContext, + @Named("persistenceKey") String persistenceKey, + SharedPreferencesUtil sharedPreferencesUtil) { + this.context = applicationContext; + this.persistenceKey = persistenceKey; + this.sharedPreferencesUtil = sharedPreferencesUtil; } void deleteNonLatestCustomModels() throws FirebaseMlException { @@ -97,7 +88,7 @@ private File getModelDirUnsafe(@NonNull String modelName) { } else { root = context.getApplicationContext().getDir(modelTypeSpecificRoot, Context.MODE_PRIVATE); } - File firebaseAppDir = new File(root, firebaseApp.getPersistenceKey()); + File firebaseAppDir = new File(root, persistenceKey); return new File(firebaseAppDir, modelName); } diff --git a/firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/internal/SharedPreferencesUtil.java b/firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/internal/SharedPreferencesUtil.java index 4559fbcb165..e5af8131291 100644 --- a/firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/internal/SharedPreferencesUtil.java +++ b/firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/internal/SharedPreferencesUtil.java @@ -29,8 +29,11 @@ import java.util.Set; import java.util.regex.Matcher; import java.util.regex.Pattern; +import javax.inject.Inject; +import javax.inject.Singleton; /** @hide */ +@Singleton public class SharedPreferencesUtil { public static final String FIREBASE_MODELDOWNLOADER_COLLECTION_ENABLED = @@ -58,10 +61,13 @@ public class SharedPreferencesUtil { private final String persistenceKey; private final FirebaseApp firebaseApp; + private final CustomModel.Factory modelFactory; - public SharedPreferencesUtil(FirebaseApp firebaseApp) { + @Inject + public SharedPreferencesUtil(FirebaseApp firebaseApp, CustomModel.Factory modelFactory) { this.firebaseApp = firebaseApp; this.persistenceKey = firebaseApp.getPersistenceKey(); + this.modelFactory = modelFactory; } /** @@ -99,7 +105,7 @@ public synchronized CustomModel getCustomModelDetails(@NonNull String modelName) getSharedPreferences() .getLong(String.format(DOWNLOADING_MODEL_ID_PATTERN, persistenceKey, modelName), 0); - return new CustomModel(modelName, modelHash, fileSize, id, filePath); + return modelFactory.create(modelName, modelHash, fileSize, id, filePath); } /** @@ -130,7 +136,7 @@ public synchronized CustomModel getDownloadingCustomModelDetails(@NonNull String getSharedPreferences() .getLong(String.format(DOWNLOADING_MODEL_ID_PATTERN, persistenceKey, modelName), 0); - return new CustomModel(modelName, modelHash, fileSize, id); + return modelFactory.create(modelName, modelHash, fileSize, id); } /** diff --git a/firebase-ml-modeldownloader/src/test/java/com/google/firebase/ml/modeldownloader/CustomModelTest.java b/firebase-ml-modeldownloader/src/test/java/com/google/firebase/ml/modeldownloader/CustomModelTest.java index c10b6ec9d75..d4abab80f06 100644 --- a/firebase-ml-modeldownloader/src/test/java/com/google/firebase/ml/modeldownloader/CustomModelTest.java +++ b/firebase-ml-modeldownloader/src/test/java/com/google/firebase/ml/modeldownloader/CustomModelTest.java @@ -19,12 +19,13 @@ import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import android.content.Context; import androidx.test.core.app.ApplicationProvider; -import com.google.firebase.FirebaseApp; import com.google.firebase.FirebaseOptions; import com.google.firebase.FirebaseOptions.Builder; import com.google.firebase.ml.modeldownloader.internal.ModelFileDownloadService; @@ -34,8 +35,6 @@ import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; -import org.mockito.Mock; -import org.mockito.MockitoAnnotations; import org.robolectric.RobolectricTestRunner; @RunWith(RobolectricTestRunner.class) @@ -52,31 +51,42 @@ public class CustomModelTest { .build(); private static final long URL_EXPIRATION = 604800L; - private final CustomModel CUSTOM_MODEL = new CustomModel(MODEL_NAME, MODEL_HASH, 100, 0); - private final CustomModel CUSTOM_MODEL_URL = - new CustomModel(MODEL_NAME, MODEL_HASH, 100, MODEL_URL, URL_EXPIRATION); - private final CustomModel CUSTOM_MODEL_BADFILE = - new CustomModel(MODEL_NAME, MODEL_HASH, 100, 0, "tmp/some/bad/filepath/model.tflite"); + private CustomModel CUSTOM_MODEL; + private CustomModel CUSTOM_MODEL_URL; + private CustomModel CUSTOM_MODEL_BADFILE; private File testModelFile; private File testModelFile2; private CustomModel customModelWithFile; - @Mock private ModelFileDownloadService fileDownloadService; + private final ModelFileDownloadService fileDownloadService = mock(ModelFileDownloadService.class); + + private final CustomModel.Factory modelFactory = + (name, modelHash, fileSize, downloadId, localFilePath, downloadUrl, downloadUrlExpiry) -> + new CustomModel( + fileDownloadService, + name, + modelHash, + fileSize, + downloadId, + localFilePath, + downloadUrl, + downloadUrlExpiry); @Before public void setUp() throws IOException { - MockitoAnnotations.initMocks(this); - FirebaseApp.clearInstancesForTest(); - // default app - FirebaseApp app = - FirebaseApp.initializeApp(ApplicationProvider.getApplicationContext(), FIREBASE_OPTIONS); - setUpTestingFiles(app); - customModelWithFile = new CustomModel(MODEL_NAME, MODEL_HASH, 100, 0, testModelFile.getPath()); + CUSTOM_MODEL = modelFactory.create(MODEL_NAME, MODEL_HASH, 100, 0); + CUSTOM_MODEL_URL = modelFactory.create(MODEL_NAME, MODEL_HASH, 100, MODEL_URL, URL_EXPIRATION); + CUSTOM_MODEL_BADFILE = + modelFactory.create(MODEL_NAME, MODEL_HASH, 100, 0, "tmp/some/bad/filepath/model.tflite"); + + setUpTestingFiles(ApplicationProvider.getApplicationContext()); + customModelWithFile = + modelFactory.create(MODEL_NAME, MODEL_HASH, 100, 0, testModelFile.getPath()); } - private void setUpTestingFiles(FirebaseApp app) throws IOException { - final File testDir = new File(app.getApplicationContext().getNoBackupFilesDir(), "tmpModels"); + private void setUpTestingFiles(Context context) throws IOException { + final File testDir = new File(context.getNoBackupFilesDir(), "tmpModels"); testDir.mkdirs(); // make sure the directory is empty. Doesn't recurse into subdirs, but that's OK since // we're only using this directory for this test and we won't create any subdirs. @@ -168,36 +178,40 @@ public void customModel_getDownloadUrlExpiry() { @Test public void customModel_equals() { // downloading models - assertEquals(CUSTOM_MODEL, new CustomModel(MODEL_NAME, MODEL_HASH, 100, 0)); - assertNotEquals(CUSTOM_MODEL, new CustomModel(MODEL_NAME, MODEL_HASH, 101, 0)); - assertNotEquals(CUSTOM_MODEL, new CustomModel(MODEL_NAME, MODEL_HASH, 100, 101)); + assertEquals(CUSTOM_MODEL, modelFactory.create(MODEL_NAME, MODEL_HASH, 100, 0)); + assertNotEquals(CUSTOM_MODEL, modelFactory.create(MODEL_NAME, MODEL_HASH, 101, 0)); + assertNotEquals(CUSTOM_MODEL, modelFactory.create(MODEL_NAME, MODEL_HASH, 100, 101)); // get model details models assertEquals( - CUSTOM_MODEL_URL, new CustomModel(MODEL_NAME, MODEL_HASH, 100, MODEL_URL, URL_EXPIRATION)); + CUSTOM_MODEL_URL, + modelFactory.create(MODEL_NAME, MODEL_HASH, 100, MODEL_URL, URL_EXPIRATION)); assertNotEquals( - CUSTOM_MODEL_URL, new CustomModel(MODEL_NAME, MODEL_HASH, 101, MODEL_URL, URL_EXPIRATION)); + CUSTOM_MODEL_URL, + modelFactory.create(MODEL_NAME, MODEL_HASH, 101, MODEL_URL, URL_EXPIRATION)); assertNotEquals( CUSTOM_MODEL_URL, - new CustomModel(MODEL_NAME, MODEL_HASH, 100, MODEL_URL, URL_EXPIRATION + 10L)); + modelFactory.create(MODEL_NAME, MODEL_HASH, 100, MODEL_URL, URL_EXPIRATION + 10L)); } @Test public void customModel_hashCode() { assertEquals( - CUSTOM_MODEL.hashCode(), new CustomModel(MODEL_NAME, MODEL_HASH, 100, 0).hashCode()); + CUSTOM_MODEL.hashCode(), modelFactory.create(MODEL_NAME, MODEL_HASH, 100, 0).hashCode()); assertNotEquals( - CUSTOM_MODEL.hashCode(), new CustomModel(MODEL_NAME, MODEL_HASH, 101, 0).hashCode()); + CUSTOM_MODEL.hashCode(), modelFactory.create(MODEL_NAME, MODEL_HASH, 101, 0).hashCode()); assertNotEquals( - CUSTOM_MODEL.hashCode(), new CustomModel(MODEL_NAME, MODEL_HASH, 100, 101).hashCode()); + CUSTOM_MODEL.hashCode(), modelFactory.create(MODEL_NAME, MODEL_HASH, 100, 101).hashCode()); assertEquals( CUSTOM_MODEL_URL.hashCode(), - new CustomModel(MODEL_NAME, MODEL_HASH, 100, MODEL_URL, URL_EXPIRATION).hashCode()); + modelFactory.create(MODEL_NAME, MODEL_HASH, 100, MODEL_URL, URL_EXPIRATION).hashCode()); assertNotEquals( CUSTOM_MODEL_URL.hashCode(), - new CustomModel(MODEL_NAME, MODEL_HASH, 101, MODEL_URL, URL_EXPIRATION).hashCode()); + modelFactory.create(MODEL_NAME, MODEL_HASH, 101, MODEL_URL, URL_EXPIRATION).hashCode()); assertNotEquals( CUSTOM_MODEL_URL.hashCode(), - new CustomModel(MODEL_NAME, MODEL_HASH, 100, MODEL_URL, URL_EXPIRATION + 10L).hashCode()); + modelFactory + .create(MODEL_NAME, MODEL_HASH, 100, MODEL_URL, URL_EXPIRATION + 10L) + .hashCode()); } } diff --git a/firebase-ml-modeldownloader/src/test/java/com/google/firebase/ml/modeldownloader/FirebaseModelDownloaderTest.java b/firebase-ml-modeldownloader/src/test/java/com/google/firebase/ml/modeldownloader/FirebaseModelDownloaderTest.java index cc76d7a0a00..afa6d59920d 100644 --- a/firebase-ml-modeldownloader/src/test/java/com/google/firebase/ml/modeldownloader/FirebaseModelDownloaderTest.java +++ b/firebase-ml-modeldownloader/src/test/java/com/google/firebase/ml/modeldownloader/FirebaseModelDownloaderTest.java @@ -23,6 +23,7 @@ import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.doNothing; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -52,8 +53,6 @@ import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; -import org.mockito.Mock; -import org.mockito.MockitoAnnotations; import org.robolectric.RobolectricTestRunner; @RunWith(RobolectricTestRunner.class) @@ -78,21 +77,32 @@ public class FirebaseModelDownloaderTest { private static final CustomModelDownloadConditions DOWNLOAD_CONDITIONS = new CustomModelDownloadConditions.Builder().requireWifi().build(); - private final CustomModel CUSTOM_MODEL = new CustomModel(MODEL_NAME, MODEL_HASH, 100, 0); - private final CustomModel ORIG_CUSTOM_MODEL_URL = - new CustomModel(MODEL_NAME, MODEL_HASH, 100, MODEL_URL, URL_EXPIRATION + 10L); - private final CustomModel UPDATE_CUSTOM_MODEL_URL = - new CustomModel(MODEL_NAME, UPDATE_MODEL_HASH, 100, MODEL_URL, URL_EXPIRATION + 10L); - private final CustomModel UPDATE_IN_PROGRESS_CUSTOM_MODEL = - new CustomModel(MODEL_NAME, UPDATE_MODEL_HASH, 100, DOWNLOAD_ID); + private CustomModel CUSTOM_MODEL; + private CustomModel ORIG_CUSTOM_MODEL_URL; + private CustomModel UPDATE_CUSTOM_MODEL_URL; + private CustomModel UPDATE_IN_PROGRESS_CUSTOM_MODEL; private CustomModel customModelUpdateLoaded; private CustomModel customModelLoaded; - private @Mock SharedPreferencesUtil mockPrefs; - private @Mock ModelFileDownloadService mockFileDownloadService; - private @Mock CustomModelDownloadService mockModelDownloadService; - private @Mock ModelFileManager mockFileManager; - private @Mock FirebaseMlLogger mockEventLogger; + private final SharedPreferencesUtil mockPrefs = mock(SharedPreferencesUtil.class); + private final ModelFileDownloadService mockFileDownloadService = + mock(ModelFileDownloadService.class); + private final CustomModelDownloadService mockModelDownloadService = + mock(CustomModelDownloadService.class); + private final ModelFileManager mockFileManager = mock(ModelFileManager.class); + private final FirebaseMlLogger mockEventLogger = mock(FirebaseMlLogger.class); + + private final CustomModel.Factory modelFactory = + (name, modelHash, fileSize, downloadId, localFilePath, downloadUrl, downloadUrlExpiry) -> + new CustomModel( + mockFileDownloadService, + name, + modelHash, + fileSize, + downloadId, + localFilePath, + downloadUrl, + downloadUrlExpiry); private FirebaseModelDownloader firebaseModelDownloader; private ExecutorService executor; @@ -107,7 +117,14 @@ public class FirebaseModelDownloaderTest { @Before public void setUp() throws Exception { - MockitoAnnotations.initMocks(this); + CUSTOM_MODEL = modelFactory.create(MODEL_NAME, MODEL_HASH, 100, 0); + ORIG_CUSTOM_MODEL_URL = + modelFactory.create(MODEL_NAME, MODEL_HASH, 100, MODEL_URL, URL_EXPIRATION + 10L); + UPDATE_CUSTOM_MODEL_URL = + modelFactory.create(MODEL_NAME, UPDATE_MODEL_HASH, 100, MODEL_URL, URL_EXPIRATION + 10L); + UPDATE_IN_PROGRESS_CUSTOM_MODEL = + modelFactory.create(MODEL_NAME, UPDATE_MODEL_HASH, 100, DOWNLOAD_ID); + FirebaseApp.clearInstancesForTest(); FirebaseApp app = FirebaseApp.initializeApp(ApplicationProvider.getApplicationContext(), FIREBASE_OPTIONS); @@ -120,7 +137,8 @@ public void setUp() throws Exception { mockModelDownloadService, mockFileManager, mockEventLogger, - executor); + executor, + modelFactory); setUpTestingFiles(app); doNothing().when(mockEventLogger).logDownloadEventWithExactDownloadTime(any(), any(), any()); @@ -131,7 +149,11 @@ public void setUp() throws Exception { } private void setUpTestingFiles(FirebaseApp app) throws Exception { - fileManager = new ModelFileManager(app); + fileManager = + new ModelFileManager( + app.getApplicationContext(), + app.getPersistenceKey(), + new SharedPreferencesUtil(app, modelFactory)); final File testDir = new File(app.getApplicationContext().getNoBackupFilesDir(), "tmpModels"); testDir.mkdirs(); // make sure the directory is empty. Doesn't recurse into subdirs, but that's OK since @@ -172,9 +194,9 @@ private void setUpTestingFiles(FirebaseApp app) throws Exception { fd2.close(); customModelLoaded = - new CustomModel(MODEL_NAME, MODEL_HASH, 100, 0, expectedDestinationFolder + "/0"); + modelFactory.create(MODEL_NAME, MODEL_HASH, 100, 0, expectedDestinationFolder + "/0"); customModelUpdateLoaded = - new CustomModel(MODEL_NAME, MODEL_HASH, 100, 0, expectedDestinationFolder + "/1"); + modelFactory.create(MODEL_NAME, MODEL_HASH, 100, 0, expectedDestinationFolder + "/1"); } @After @@ -210,7 +232,8 @@ public void getModel_latestModel_localExists_noUpdate() throws Exception { public void getModel_latestModel_localExists_noUpdate_MissingFile() throws Exception { // model with missing file. CustomModel missingFileModel = - new CustomModel(MODEL_NAME, UPDATE_MODEL_HASH, 100, 0, expectedDestinationFolder + "/4"); + modelFactory.create( + MODEL_NAME, UPDATE_MODEL_HASH, 100, 0, expectedDestinationFolder + "/4"); when(mockPrefs.getCustomModelDetails(eq(MODEL_NAME))) .thenReturn(missingFileModel) .thenReturn(customModelUpdateLoaded); @@ -240,7 +263,7 @@ public void getModel_latestModel_localExists_noUpdate_MissingFile() throws Excep @Test public void getModel_latestModel_localExists_noUpdate_MissingDownloadId() throws Exception { - CustomModel badLocalModel = new CustomModel(MODEL_NAME, UPDATE_MODEL_HASH, 100, 0); + CustomModel badLocalModel = modelFactory.create(MODEL_NAME, UPDATE_MODEL_HASH, 100, 0); when(mockPrefs.getCustomModelDetails(eq(MODEL_NAME))) .thenReturn(badLocalModel) // getlocalModelDetails 1 .thenReturn(null) // getCustomModelTask 1 @@ -275,7 +298,7 @@ public void getModel_latestModel_localExists_noUpdate_MissingDownloadId() throws @Test public void getModel_latestModel_localExists_noUpdate_inProgress() throws Exception { // model with no file yet. - CustomModel inProgressLocalModel = new CustomModel(MODEL_NAME, UPDATE_MODEL_HASH, 100, 88); + CustomModel inProgressLocalModel = modelFactory.create(MODEL_NAME, UPDATE_MODEL_HASH, 100, 88); when(mockPrefs.getCustomModelDetails(eq(MODEL_NAME))) .thenReturn(inProgressLocalModel) // getlocalModelDetails 1 .thenReturn(inProgressLocalModel) // getCustomModelTask 1 @@ -356,7 +379,7 @@ public void getModel_latestModel_localExists_UpdateFound() throws Exception { @Test public void getModel_latestModel_localExists_DownloadInProgress() throws Exception { CustomModel customModelLoadedWithDownload = - new CustomModel(MODEL_NAME, MODEL_HASH, 100, 99, expectedDestinationFolder + "/0"); + modelFactory.create(MODEL_NAME, MODEL_HASH, 100, 99, expectedDestinationFolder + "/0"); when(mockPrefs.getCustomModelDetails(eq(MODEL_NAME))).thenReturn(customModelLoadedWithDownload); when(mockPrefs.getDownloadingCustomModelDetails(eq(MODEL_NAME))) @@ -423,7 +446,6 @@ public void getModel_latestModel_noLocalModel_modelDownloadServiceFails() throws } verify(mockPrefs, times(2)).getCustomModelDetails(eq(MODEL_NAME)); - verify(mockFileDownloadService, never()).loadNewlyDownloadedModelFile(any()); assertThat(task.isComplete()).isTrue(); assertThat(task.isSuccessful()).isFalse(); assertTrue(task.getException().getMessage().contains("bad state")); diff --git a/firebase-ml-modeldownloader/src/test/java/com/google/firebase/ml/modeldownloader/internal/CustomModelDownloadServiceTest.java b/firebase-ml-modeldownloader/src/test/java/com/google/firebase/ml/modeldownloader/internal/CustomModelDownloadServiceTest.java index 12da8853d2d..6f5a3634b9b 100644 --- a/firebase-ml-modeldownloader/src/test/java/com/google/firebase/ml/modeldownloader/internal/CustomModelDownloadServiceTest.java +++ b/firebase-ml-modeldownloader/src/test/java/com/google/firebase/ml/modeldownloader/internal/CustomModelDownloadServiceTest.java @@ -121,6 +121,20 @@ public Builder toBuilder() { private FirebaseInstallationsApi installationsApiMock; @Mock private FirebaseMlLogger mockEventLogger; + private final ModelFileDownloadService modelFileDownloadService = + mock(ModelFileDownloadService.class); + private final CustomModel.Factory modelFactory = + (name, modelHash, fileSize, downloadId, localFilePath, downloadUrl, downloadUrlExpiry) -> + new CustomModel( + modelFileDownloadService, + name, + modelHash, + fileSize, + downloadId, + localFilePath, + downloadUrl, + downloadUrlExpiry); + @Before public void setUp() { MockitoAnnotations.initMocks(this); @@ -172,18 +186,20 @@ public void downloadService_noHashSuccess() { CustomModelDownloadService service = new CustomModelDownloadService( ApplicationProvider.getApplicationContext(), - installationsApiMock, + () -> installationsApiMock, directExecutor, API_KEY, PACKAGE_FINGERPRINT_HASH, TEST_ENDPOINT, - mockEventLogger); + mockEventLogger, + modelFactory); Task modelTask = service.getNewDownloadUrlWithExpiry(PROJECT_ID, MODEL_NAME); Assert.assertEquals( modelTask.getResult(), - new CustomModel(MODEL_NAME, MODEL_HASH, FILE_SIZE, DOWNLOAD_URI, TEST_EXPIRATION_IN_MS)); + modelFactory.create( + MODEL_NAME, MODEL_HASH, FILE_SIZE, DOWNLOAD_URI, TEST_EXPIRATION_IN_MS)); WireMock.verify( getRequestedFor(urlEqualTo(downloadPath)) @@ -220,18 +236,20 @@ public void downloadService_fingerPrintHashNull_NoCertHeader() { CustomModelDownloadService service = new CustomModelDownloadService( ApplicationProvider.getApplicationContext(), - installationsApiMock, + () -> installationsApiMock, directExecutor, API_KEY, null, TEST_ENDPOINT, - mockEventLogger); + mockEventLogger, + modelFactory); Task modelTask = service.getNewDownloadUrlWithExpiry(PROJECT_ID, MODEL_NAME); Assert.assertEquals( modelTask.getResult(), - new CustomModel(MODEL_NAME, MODEL_HASH, FILE_SIZE, DOWNLOAD_URI, TEST_EXPIRATION_IN_MS)); + modelFactory.create( + MODEL_NAME, MODEL_HASH, FILE_SIZE, DOWNLOAD_URI, TEST_EXPIRATION_IN_MS)); WireMock.verify( getRequestedFor(urlEqualTo(downloadPath)) @@ -266,18 +284,20 @@ public void downloadService_withHashSuccess_noMatch() { CustomModelDownloadService service = new CustomModelDownloadService( ApplicationProvider.getApplicationContext(), - installationsApiMock, + () -> installationsApiMock, directExecutor, API_KEY, PACKAGE_FINGERPRINT_HASH, TEST_ENDPOINT, - mockEventLogger); + mockEventLogger, + modelFactory); Task modelTask = service.getCustomModelDetails(PROJECT_ID, MODEL_NAME, MODEL_HASH); Assert.assertEquals( modelTask.getResult(), - new CustomModel(MODEL_NAME, MODEL_HASH, FILE_SIZE, DOWNLOAD_URI, TEST_EXPIRATION_IN_MS)); + modelFactory.create( + MODEL_NAME, MODEL_HASH, FILE_SIZE, DOWNLOAD_URI, TEST_EXPIRATION_IN_MS)); WireMock.verify( getRequestedFor(urlEqualTo(downloadPath)) @@ -314,12 +334,13 @@ public void downloadService_withHashSuccess_match() { CustomModelDownloadService service = new CustomModelDownloadService( ApplicationProvider.getApplicationContext(), - installationsApiMock, + () -> installationsApiMock, directExecutor, API_KEY, PACKAGE_FINGERPRINT_HASH, TEST_ENDPOINT, - mockEventLogger); + mockEventLogger, + modelFactory); Task modelTask = service.getCustomModelDetails(PROJECT_ID, MODEL_NAME, MODEL_HASH); @@ -363,12 +384,13 @@ public void downloadService_modelNotFound() { CustomModelDownloadService service = new CustomModelDownloadService( ApplicationProvider.getApplicationContext(), - installationsApiMock, + () -> installationsApiMock, directExecutor, API_KEY, PACKAGE_FINGERPRINT_HASH, TEST_ENDPOINT, - mockEventLogger); + mockEventLogger, + modelFactory); Task modelTask = service.getCustomModelDetails(PROJECT_ID, MODEL_NAME, MODEL_HASH); @@ -415,12 +437,13 @@ public void downloadService_badRequest() { CustomModelDownloadService service = new CustomModelDownloadService( ApplicationProvider.getApplicationContext(), - installationsApiMock, + () -> installationsApiMock, directExecutor, API_KEY, PACKAGE_FINGERPRINT_HASH, TEST_ENDPOINT, - mockEventLogger); + mockEventLogger, + modelFactory); Task modelTask = service.getCustomModelDetails(PROJECT_ID, MODEL_NAME, MODEL_HASH); @@ -473,12 +496,13 @@ public void downloadService_forbidden() { CustomModelDownloadService service = new CustomModelDownloadService( ApplicationProvider.getApplicationContext(), - installationsApiMock, + () -> installationsApiMock, directExecutor, API_KEY, PACKAGE_FINGERPRINT_HASH, TEST_ENDPOINT, - mockEventLogger); + mockEventLogger, + modelFactory); Task modelTask = service.getCustomModelDetails(PROJECT_ID, MODEL_NAME, MODEL_HASH); @@ -531,12 +555,13 @@ public void downloadService_internalError() { CustomModelDownloadService service = new CustomModelDownloadService( ApplicationProvider.getApplicationContext(), - installationsApiMock, + () -> installationsApiMock, directExecutor, API_KEY, PACKAGE_FINGERPRINT_HASH, TEST_ENDPOINT, - mockEventLogger); + mockEventLogger, + modelFactory); Task modelTask = service.getCustomModelDetails(PROJECT_ID, MODEL_NAME, MODEL_HASH); @@ -588,12 +613,13 @@ public void downloadService_tooManyRequest() { CustomModelDownloadService service = new CustomModelDownloadService( ApplicationProvider.getApplicationContext(), - installationsApiMock, + () -> installationsApiMock, directExecutor, API_KEY, PACKAGE_FINGERPRINT_HASH, TEST_ENDPOINT, - mockEventLogger); + mockEventLogger, + modelFactory); Task modelTask = service.getCustomModelDetails(PROJECT_ID, MODEL_NAME, MODEL_HASH); @@ -643,12 +669,13 @@ public void downloadService_authenticationIssue() { CustomModelDownloadService service = new CustomModelDownloadService( ApplicationProvider.getApplicationContext(), - installationsApiMock, + () -> installationsApiMock, directExecutor, API_KEY, PACKAGE_FINGERPRINT_HASH, TEST_ENDPOINT, - mockEventLogger); + mockEventLogger, + modelFactory); Task modelTask = service.getCustomModelDetails(PROJECT_ID, MODEL_NAME, MODEL_HASH); @@ -690,12 +717,13 @@ public void downloadService_unauthenticatedToken() { CustomModelDownloadService service = new CustomModelDownloadService( ApplicationProvider.getApplicationContext(), - installationsApiMock, + () -> installationsApiMock, directExecutor, API_KEY, PACKAGE_FINGERPRINT_HASH, TEST_ENDPOINT, - mockEventLogger); + mockEventLogger, + modelFactory); Task modelTask = service.getCustomModelDetails(PROJECT_ID, MODEL_NAME, MODEL_HASH); @@ -718,12 +746,13 @@ public void downloadService_nullModelHashPassedUnauthenticatedToken() { CustomModelDownloadService service = new CustomModelDownloadService( ApplicationProvider.getApplicationContext(), - installationsApiMock, + () -> installationsApiMock, directExecutor, API_KEY, PACKAGE_FINGERPRINT_HASH, TEST_ENDPOINT, - mockEventLogger); + mockEventLogger, + modelFactory); Task modelTask = service.getCustomModelDetails(PROJECT_ID, MODEL_NAME, null); @@ -747,12 +776,13 @@ public void downloadService_malFormedUrl() { CustomModelDownloadService service = new CustomModelDownloadService( ApplicationProvider.getApplicationContext(), - installationsApiMock, + () -> installationsApiMock, directExecutor, API_KEY, PACKAGE_FINGERPRINT_HASH, "https7://localhost:8989/barUrl", - mockEventLogger); + mockEventLogger, + modelFactory); Task modelTask = service.getCustomModelDetails(PROJECT_ID, MODEL_NAME, MODEL_HASH); @@ -775,12 +805,13 @@ public void downloadService_unauthenticatedToken_noNetworkConnection() { CustomModelDownloadService service = new CustomModelDownloadService( ApplicationProvider.getApplicationContext(), - installationsApiMock, + () -> installationsApiMock, directExecutor, API_KEY, PACKAGE_FINGERPRINT_HASH, TEST_ENDPOINT, - mockEventLogger); + mockEventLogger, + modelFactory); Task modelTask = service.getCustomModelDetails(PROJECT_ID, MODEL_NAME, MODEL_HASH); diff --git a/firebase-ml-modeldownloader/src/test/java/com/google/firebase/ml/modeldownloader/internal/DataTransportMlEventSenderTest.java b/firebase-ml-modeldownloader/src/test/java/com/google/firebase/ml/modeldownloader/internal/DataTransportMlEventSenderTest.java index 4ca7f441664..576280c8c87 100644 --- a/firebase-ml-modeldownloader/src/test/java/com/google/firebase/ml/modeldownloader/internal/DataTransportMlEventSenderTest.java +++ b/firebase-ml-modeldownloader/src/test/java/com/google/firebase/ml/modeldownloader/internal/DataTransportMlEventSenderTest.java @@ -18,13 +18,16 @@ import static org.mockito.Mockito.doNothing; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; import com.google.android.datatransport.Transport; +import com.google.android.datatransport.TransportFactory; import com.google.firebase.ml.modeldownloader.internal.FirebaseMlLogEvent.EventName; import com.google.firebase.ml.modeldownloader.internal.FirebaseMlLogEvent.SystemInfo; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; +import org.mockito.ArgumentMatchers; import org.mockito.Mock; import org.mockito.MockitoAnnotations; import org.robolectric.RobolectricTestRunner; @@ -32,6 +35,7 @@ @RunWith(RobolectricTestRunner.class) public class DataTransportMlEventSenderTest { + @Mock private TransportFactory mockTransportFactory; @Mock private Transport mockTransport; private DataTransportMlEventSender statsSender; @@ -48,7 +52,10 @@ public class DataTransportMlEventSenderTest { @Before public void setup() { MockitoAnnotations.initMocks(this); - statsSender = new DataTransportMlEventSender(mockTransport); + when(mockTransportFactory.getTransport( + any(), ArgumentMatchers.>any(), any(), any())) + .thenReturn(mockTransport); + statsSender = new DataTransportMlEventSender(() -> mockTransportFactory); } @Test diff --git a/firebase-ml-modeldownloader/src/test/java/com/google/firebase/ml/modeldownloader/internal/FirebaseMlLoggerTest.java b/firebase-ml-modeldownloader/src/test/java/com/google/firebase/ml/modeldownloader/internal/FirebaseMlLoggerTest.java index 7903278ccfe..501f17df49a 100644 --- a/firebase-ml-modeldownloader/src/test/java/com/google/firebase/ml/modeldownloader/internal/FirebaseMlLoggerTest.java +++ b/firebase-ml-modeldownloader/src/test/java/com/google/firebase/ml/modeldownloader/internal/FirebaseMlLoggerTest.java @@ -18,6 +18,7 @@ import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.doNothing; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.timeout; import static org.mockito.Mockito.times; @@ -33,6 +34,7 @@ import com.google.firebase.ml.modeldownloader.BuildConfig; import com.google.firebase.ml.modeldownloader.CustomModel; import com.google.firebase.ml.modeldownloader.FirebaseMlException; +import com.google.firebase.ml.modeldownloader.FirebaseModelDownloader; import com.google.firebase.ml.modeldownloader.internal.FirebaseMlLogEvent.DeleteModelLogEvent; import com.google.firebase.ml.modeldownloader.internal.FirebaseMlLogEvent.EventName; import com.google.firebase.ml.modeldownloader.internal.FirebaseMlLogEvent.ModelDownloadLogEvent; @@ -45,9 +47,7 @@ import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; -import org.mockito.Mock; import org.mockito.Mockito; -import org.mockito.MockitoAnnotations; import org.robolectric.RobolectricTestRunner; @RunWith(RobolectricTestRunner.class) @@ -69,26 +69,42 @@ public class FirebaseMlLoggerTest { private static final String MODEL_HASH = "dsf324"; private static final long SYSTEM_TIME = 2000; private static final Long DOWNLOAD_ID = 987923L; - private static final CustomModel CUSTOM_MODEL_DOWNLOADING = - new CustomModel(MODEL_NAME, MODEL_HASH, 100, DOWNLOAD_ID); + private CustomModel CUSTOM_MODEL_DOWNLOADING; private static final ModelOptions MODEL_OPTIONS = ModelOptions.builder() .setModelInfo(ModelInfo.builder().setName(MODEL_NAME).setHash(MODEL_HASH).build()) .build(); - @Mock private SharedPreferencesUtil mockSharedPreferencesUtil; - @Mock private DataTransportMlEventSender mockStatsSender; + private final SharedPreferencesUtil mockSharedPreferencesUtil = mock(SharedPreferencesUtil.class); + private final DataTransportMlEventSender mockStatsSender = mock(DataTransportMlEventSender.class); private FirebaseMlLogger mlLogger; + private CustomModel.Factory modelFactory; @Before public void setUp() throws NameNotFoundException { - MockitoAnnotations.initMocks(this); FirebaseApp.clearInstancesForTest(); FirebaseApp app = FirebaseApp.initializeApp(ApplicationProvider.getApplicationContext(), FIREBASE_OPTIONS); - mlLogger = new FirebaseMlLogger(app, mockSharedPreferencesUtil, mockStatsSender); + modelFactory = FirebaseModelDownloader.getInstance(app).getModelFactory(); + mlLogger = + new FirebaseMlLogger( + FIREBASE_OPTIONS, + mockSharedPreferencesUtil, + mockStatsSender, + () -> ApplicationProvider.getApplicationContext().getPackageName(), + () -> { + try { + return String.valueOf( + app.getApplicationContext() + .getPackageManager() + .getPackageInfo(app.getApplicationContext().getPackageName(), 0) + .versionCode); + } catch (NameNotFoundException e) { + return ""; + } + }); systemInfo = SystemInfo.builder() .setFirebaseProjectId(TEST_PROJECT_ID) @@ -109,6 +125,7 @@ public void setUp() throws NameNotFoundException { doNothing().when(mockSharedPreferencesUtil).setModelDownloadCompleteTimeMs(any(), anyLong()); when(mockSharedPreferencesUtil.getCustomModelStatsCollectionFlag()).thenReturn(true); SystemClock.setCurrentTimeMillis(SYSTEM_TIME + 500); + CUSTOM_MODEL_DOWNLOADING = modelFactory.create(MODEL_NAME, MODEL_HASH, 100, DOWNLOAD_ID); } @Test diff --git a/firebase-ml-modeldownloader/src/test/java/com/google/firebase/ml/modeldownloader/internal/ModelFileDownloadServiceTest.java b/firebase-ml-modeldownloader/src/test/java/com/google/firebase/ml/modeldownloader/internal/ModelFileDownloadServiceTest.java index 90e6eed211c..6f3c186bdfd 100644 --- a/firebase-ml-modeldownloader/src/test/java/com/google/firebase/ml/modeldownloader/internal/ModelFileDownloadServiceTest.java +++ b/firebase-ml-modeldownloader/src/test/java/com/google/firebase/ml/modeldownloader/internal/ModelFileDownloadServiceTest.java @@ -26,6 +26,7 @@ import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -54,12 +55,11 @@ import java.util.Date; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; +import java.util.concurrent.atomic.AtomicReference; import org.junit.After; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; -import org.mockito.Mock; -import org.mockito.MockitoAnnotations; import org.robolectric.RobolectricTestRunner; import org.robolectric.annotation.LooperMode; @@ -83,16 +83,11 @@ public class ModelFileDownloadServiceTest { private static final long URL_EXPIRATION_FUTURE = (new Date()).getTime() + 600000; private static final Long DOWNLOAD_ID = 987923L; - private static final CustomModel CUSTOM_MODEL_PREVIOUS_LOADED = - new CustomModel(MODEL_NAME, MODEL_HASH + "2", 105, 0, "FakeFile/path.tflite"); - private static final CustomModel CUSTOM_MODEL_NO_URL = - new CustomModel(MODEL_NAME, MODEL_HASH, 100, 0); - private static final CustomModel CUSTOM_MODEL_URL = - new CustomModel(MODEL_NAME, MODEL_HASH, 100, MODEL_URL, URL_EXPIRATION_FUTURE); - private static final CustomModel CUSTOM_MODEL_EXPIRED_URL = - new CustomModel(MODEL_NAME, MODEL_HASH, 100, MODEL_URL, URL_EXPIRATION_OLD); - private static final CustomModel CUSTOM_MODEL_DOWNLOADING = - new CustomModel(MODEL_NAME, MODEL_HASH, 100, DOWNLOAD_ID); + private CustomModel CUSTOM_MODEL_PREVIOUS_LOADED; + private CustomModel CUSTOM_MODEL_NO_URL; + private CustomModel CUSTOM_MODEL_URL; + private CustomModel CUSTOM_MODEL_EXPIRED_URL; + private CustomModel CUSTOM_MODEL_DOWNLOADING; CustomModel customModelDownloadComplete; private static final CustomModelDownloadConditions DOWNLOAD_CONDITIONS_CHARGING_IDLE = @@ -101,52 +96,77 @@ public class ModelFileDownloadServiceTest { File testTempModelFile; File testAppModelFile; + private final DownloadManager mockDownloadManager = mock(DownloadManager.class); + private final ModelFileManager mockFileManager = mock(ModelFileManager.class); + private final FirebaseMlLogger mockStatsLogger = mock(FirebaseMlLogger.class); + + private final ExecutorService executor = Executors.newSingleThreadExecutor(); + private ModelFileDownloadService modelFileDownloadService; private ModelFileDownloadService modelFileDownloadServiceInitialLoad; private SharedPreferencesUtil sharedPreferencesUtil; - @Mock DownloadManager mockDownloadManager; - @Mock ModelFileManager mockFileManager; - @Mock FirebaseMlLogger mockStatsLogger; + private CustomModel.Factory modelFactory; - ExecutorService executor; private MatrixCursor matrixCursor; FirebaseApp app; @Before public void setUp() throws IOException { - MockitoAnnotations.initMocks(this); FirebaseApp.clearInstancesForTest(); app = FirebaseApp.initializeApp(ApplicationProvider.getApplicationContext(), FIREBASE_OPTIONS); - executor = Executors.newSingleThreadExecutor(); - sharedPreferencesUtil = new SharedPreferencesUtil(app); + AtomicReference serviceRef = new AtomicReference<>(); + modelFactory = + (name, modelHash, fileSize, downloadId, localFilePath, downloadUrl, downloadUrlExpiry) -> + new CustomModel( + serviceRef.get(), + name, + modelHash, + fileSize, + downloadId, + localFilePath, + downloadUrl, + downloadUrlExpiry); + sharedPreferencesUtil = new SharedPreferencesUtil(app, modelFactory); sharedPreferencesUtil.clearModelDetails(MODEL_NAME); modelFileDownloadService = new ModelFileDownloadService( - app, + ApplicationProvider.getApplicationContext(), mockDownloadManager, mockFileManager, sharedPreferencesUtil, mockStatsLogger, - false); + false, + modelFactory); + serviceRef.set(modelFileDownloadService); modelFileDownloadServiceInitialLoad = new ModelFileDownloadService( - app, + ApplicationProvider.getApplicationContext(), mockDownloadManager, mockFileManager, sharedPreferencesUtil, mockStatsLogger, - true); + true, + modelFactory); + + CUSTOM_MODEL_PREVIOUS_LOADED = + modelFactory.create(MODEL_NAME, MODEL_HASH + "2", 105, 0, "FakeFile/path.tflite"); + CUSTOM_MODEL_NO_URL = modelFactory.create(MODEL_NAME, MODEL_HASH, 100, 0); + CUSTOM_MODEL_URL = + modelFactory.create(MODEL_NAME, MODEL_HASH, 100, MODEL_URL, URL_EXPIRATION_FUTURE); + CUSTOM_MODEL_EXPIRED_URL = + modelFactory.create(MODEL_NAME, MODEL_HASH, 100, MODEL_URL, URL_EXPIRATION_OLD); + CUSTOM_MODEL_DOWNLOADING = modelFactory.create(MODEL_NAME, MODEL_HASH, 100, DOWNLOAD_ID); matrixCursor = new MatrixCursor(new String[] {DownloadManager.COLUMN_STATUS}); testTempModelFile = File.createTempFile("fakeTempFile", ".tflite"); testAppModelFile = File.createTempFile("fakeAppFile", ".tflite"); customModelDownloadComplete = - new CustomModel(MODEL_NAME, MODEL_HASH, 100, 0, testAppModelFile.getPath()); + modelFactory.create(MODEL_NAME, MODEL_HASH, 100, 0, testAppModelFile.getPath()); } @After @@ -412,7 +432,7 @@ public void ensureModelDownloaded_downloadFailed_urlExpiry() { when(mockDownloadManager.query(any())).thenReturn(matrixCursor); CustomModel justAboutToExpireModel = - new CustomModel(MODEL_NAME, MODEL_HASH, 100, MODEL_URL, (new Date()).getTime() + 3); + modelFactory.create(MODEL_NAME, MODEL_HASH, 100, MODEL_URL, (new Date()).getTime() + 3); TestOnCompleteListener onCompleteListener = new TestOnCompleteListener<>(); Task task = modelFileDownloadService.ensureModelDownloaded(justAboutToExpireModel); @@ -611,7 +631,7 @@ public void ensureModelDownloaded_alreadyInProgess_UrlExpired() throws Exception // set up the first request CustomModel justAboutToExpireModel = - new CustomModel(MODEL_NAME, MODEL_HASH, 100, MODEL_URL, (new Date()).getTime() + 30); + modelFactory.create(MODEL_NAME, MODEL_HASH, 100, MODEL_URL, (new Date()).getTime() + 30); TestOnCompleteListener onCompleteListener = new TestOnCompleteListener<>(); Task task = modelFileDownloadService.ensureModelDownloaded(justAboutToExpireModel); task.addOnCompleteListener(executor, onCompleteListener); @@ -820,7 +840,8 @@ public void maybeCheckDownloadingComplete_downloadInprogress() { public void maybeCheckDownloadingComplete_multipleDownloads() throws Exception { sharedPreferencesUtil.setDownloadingCustomModelDetails(CUSTOM_MODEL_DOWNLOADING); String secondModelName = "secondModelName"; - CustomModel downloading2 = new CustomModel(secondModelName, MODEL_HASH, 100, DOWNLOAD_ID + 1); + CustomModel downloading2 = + modelFactory.create(secondModelName, MODEL_HASH, 100, DOWNLOAD_ID + 1); sharedPreferencesUtil.setDownloadingCustomModelDetails(downloading2); assertNull(modelFileDownloadService.getDownloadingModelStatusCode(0L)); @@ -839,7 +860,7 @@ public void maybeCheckDownloadingComplete_multipleDownloads() throws Exception { sharedPreferencesUtil.getCustomModelDetails(MODEL_NAME), customModelDownloadComplete); assertEquals( sharedPreferencesUtil.getCustomModelDetails(secondModelName), - new CustomModel(secondModelName, MODEL_HASH, 100, 0, testAppModelFile.getPath())); + modelFactory.create(secondModelName, MODEL_HASH, 100, 0, testAppModelFile.getPath())); verify(mockDownloadManager, times(5)).query(any()); verify(mockDownloadManager, times(2)).remove(anyLong()); } diff --git a/firebase-ml-modeldownloader/src/test/java/com/google/firebase/ml/modeldownloader/internal/ModelFileManagerTest.java b/firebase-ml-modeldownloader/src/test/java/com/google/firebase/ml/modeldownloader/internal/ModelFileManagerTest.java index 8dac9c6dc4c..0c70673a63c 100644 --- a/firebase-ml-modeldownloader/src/test/java/com/google/firebase/ml/modeldownloader/internal/ModelFileManagerTest.java +++ b/firebase-ml-modeldownloader/src/test/java/com/google/firebase/ml/modeldownloader/internal/ModelFileManagerTest.java @@ -25,6 +25,7 @@ import com.google.firebase.FirebaseOptions.Builder; import com.google.firebase.ml.modeldownloader.CustomModel; import com.google.firebase.ml.modeldownloader.FirebaseMlException; +import com.google.firebase.ml.modeldownloader.FirebaseModelDownloader; import java.io.File; import java.io.FileNotFoundException; import java.io.IOException; @@ -51,8 +52,8 @@ public class ModelFileManagerTest { public static final String MODEL_NAME_2 = "MODEL_NAME_2"; public static final String MODEL_HASH_2 = "hash2"; - final CustomModel CUSTOM_MODEL_NO_FILE = new CustomModel(MODEL_NAME, MODEL_HASH, 100, 0); - final CustomModel CUSTOM_MODEL_NO_FILE_2 = new CustomModel(MODEL_NAME_2, MODEL_HASH_2, 101, 0); + private CustomModel CUSTOM_MODEL_NO_FILE; + private CustomModel CUSTOM_MODEL_NO_FILE_2; private File testModelFile; private File testModelFile2; @@ -60,7 +61,8 @@ public class ModelFileManagerTest { ModelFileManager fileManager; FirebaseApp app; private SharedPreferencesUtil sharedPreferencesUtil; - String modelDestinationFolder; + private String modelDestinationFolder; + private CustomModel.Factory modelFactory; @Before public void setUp() throws IOException { @@ -68,10 +70,18 @@ public void setUp() throws IOException { FirebaseApp.clearInstancesForTest(); app = FirebaseApp.initializeApp(ApplicationProvider.getApplicationContext(), FIREBASE_OPTIONS); - sharedPreferencesUtil = new SharedPreferencesUtil(app); - fileManager = new ModelFileManager(app); + modelFactory = FirebaseModelDownloader.getInstance(app).getModelFactory(); + + sharedPreferencesUtil = new SharedPreferencesUtil(app, modelFactory); + fileManager = + new ModelFileManager( + ApplicationProvider.getApplicationContext(), + app.getPersistenceKey(), + sharedPreferencesUtil); modelDestinationFolder = setUpTestingFiles(app, MODEL_NAME); + CUSTOM_MODEL_NO_FILE = modelFactory.create(MODEL_NAME, MODEL_HASH, 100, 0); + CUSTOM_MODEL_NO_FILE_2 = modelFactory.create(MODEL_NAME_2, MODEL_HASH_2, 101, 0); } private String setUpTestingFiles(FirebaseApp app, String modelName) throws IOException { @@ -181,7 +191,7 @@ public void deleteNonLatestCustomModels_fileToDelete() MoveFileToDestination(modelDestinationFolder, testModelFile2, CUSTOM_MODEL_NO_FILE, 1); sharedPreferencesUtil.setLoadedCustomModelDetails( - new CustomModel(MODEL_NAME, MODEL_HASH, 100, 0, modelDestinationFolder + "/1")); + modelFactory.create(MODEL_NAME, MODEL_HASH, 100, 0, modelDestinationFolder + "/1")); fileManager.deleteNonLatestCustomModels(); assertFalse(new File(modelDestinationFolder + "/0").exists()); @@ -199,11 +209,11 @@ public void deleteNonLatestCustomModels_whenModelOnDiskButNotInPreferences() MoveFileToDestination(modelDestinationFolder2, testModelFile2, CUSTOM_MODEL_NO_FILE_2, 0); sharedPreferencesUtil.setLoadedCustomModelDetails( - new CustomModel(MODEL_NAME, MODEL_HASH, 100, 0, modelDestinationFolder + "/0")); + modelFactory.create(MODEL_NAME, MODEL_HASH, 100, 0, modelDestinationFolder + "/0")); // Download in progress, hence file path is not present sharedPreferencesUtil.setLoadedCustomModelDetails( - new CustomModel(MODEL_NAME_2, MODEL_HASH_2, 100, 0)); + modelFactory.create(MODEL_NAME_2, MODEL_HASH_2, 100, 0)); fileManager.deleteNonLatestCustomModels(); @@ -217,7 +227,7 @@ public void deleteNonLatestCustomModels_noFileToDelete() MoveFileToDestination(modelDestinationFolder, testModelFile, CUSTOM_MODEL_NO_FILE, 0); sharedPreferencesUtil.setLoadedCustomModelDetails( - new CustomModel(MODEL_NAME, MODEL_HASH, 100, 0, modelDestinationFolder + "/0")); + modelFactory.create(MODEL_NAME, MODEL_HASH, 100, 0, modelDestinationFolder + "/0")); fileManager.deleteNonLatestCustomModels(); assertTrue(new File(modelDestinationFolder + "/0").exists()); @@ -230,14 +240,14 @@ public void deleteNonLatestCustomModels_multipleNamedModels() MoveFileToDestination(modelDestinationFolder, testModelFile2, CUSTOM_MODEL_NO_FILE, 1); sharedPreferencesUtil.setLoadedCustomModelDetails( - new CustomModel(MODEL_NAME, MODEL_HASH, 100, 0, modelDestinationFolder + "/1")); + modelFactory.create(MODEL_NAME, MODEL_HASH, 100, 0, modelDestinationFolder + "/1")); String modelDestinationFolder2 = setUpTestingFiles(app, MODEL_NAME_2); MoveFileToDestination(modelDestinationFolder2, testModelFile, CUSTOM_MODEL_NO_FILE_2, 0); MoveFileToDestination(modelDestinationFolder2, testModelFile2, CUSTOM_MODEL_NO_FILE_2, 1); sharedPreferencesUtil.setLoadedCustomModelDetails( - new CustomModel(MODEL_NAME_2, MODEL_HASH_2, 101, 0, modelDestinationFolder2 + "/1")); + modelFactory.create(MODEL_NAME_2, MODEL_HASH_2, 101, 0, modelDestinationFolder2 + "/1")); fileManager.deleteNonLatestCustomModels(); @@ -297,14 +307,14 @@ public void deleteOldModels_multipleNamedModels() throws FirebaseMlException, IO MoveFileToDestination(modelDestinationFolder, testModelFile2, CUSTOM_MODEL_NO_FILE, 1); sharedPreferencesUtil.setLoadedCustomModelDetails( - new CustomModel(MODEL_NAME, MODEL_HASH, 100, 0, modelDestinationFolder + "/1")); + modelFactory.create(MODEL_NAME, MODEL_HASH, 100, 0, modelDestinationFolder + "/1")); String modelDestinationFolder2 = setUpTestingFiles(app, MODEL_NAME_2); MoveFileToDestination(modelDestinationFolder2, testModelFile, CUSTOM_MODEL_NO_FILE_2, 0); MoveFileToDestination(modelDestinationFolder2, testModelFile2, CUSTOM_MODEL_NO_FILE_2, 1); sharedPreferencesUtil.setLoadedCustomModelDetails( - new CustomModel(MODEL_NAME_2, MODEL_HASH_2, 101, 0, modelDestinationFolder2 + "/1")); + modelFactory.create(MODEL_NAME_2, MODEL_HASH_2, 101, 0, modelDestinationFolder2 + "/1")); fileManager.deleteOldModels(MODEL_NAME, modelDestinationFolder + "/1"); diff --git a/firebase-ml-modeldownloader/src/test/java/com/google/firebase/ml/modeldownloader/internal/SharedPreferencesUtilTest.java b/firebase-ml-modeldownloader/src/test/java/com/google/firebase/ml/modeldownloader/internal/SharedPreferencesUtilTest.java index 5bd1722b603..22c4fb442f4 100644 --- a/firebase-ml-modeldownloader/src/test/java/com/google/firebase/ml/modeldownloader/internal/SharedPreferencesUtilTest.java +++ b/firebase-ml-modeldownloader/src/test/java/com/google/firebase/ml/modeldownloader/internal/SharedPreferencesUtilTest.java @@ -25,6 +25,7 @@ import com.google.firebase.FirebaseApp; import com.google.firebase.FirebaseOptions; import com.google.firebase.ml.modeldownloader.CustomModel; +import com.google.firebase.ml.modeldownloader.FirebaseModelDownloader; import java.util.Set; import org.junit.Before; import org.junit.Test; @@ -39,14 +40,12 @@ public class SharedPreferencesUtilTest { private static final String TEST_PROJECT_ID = "777777777777"; private static final String MODEL_NAME = "ModelName"; private static final String MODEL_HASH = "dsf324"; - private static final CustomModel CUSTOM_MODEL_DOWNLOAD_COMPLETE = - new CustomModel(MODEL_NAME, MODEL_HASH, 100, 0, "file/path/store/ModelName/1"); - private static final CustomModel CUSTOM_MODEL_UPDATE_IN_BACKGROUND = - new CustomModel(MODEL_NAME, MODEL_HASH, 100, 986, "file/path/store/ModelName/1"); - private static final CustomModel CUSTOM_MODEL_DOWNLOADING = - new CustomModel(MODEL_NAME, MODEL_HASH, 100, 986); + private CustomModel CUSTOM_MODEL_DOWNLOAD_COMPLETE; + private CustomModel CUSTOM_MODEL_UPDATE_IN_BACKGROUND; + private CustomModel CUSTOM_MODEL_DOWNLOADING; private SharedPreferencesUtil sharedPreferencesUtil; private FirebaseApp app; + private CustomModel.Factory modelFactory; @Before public void setUp() { @@ -60,10 +59,17 @@ public void setUp() { .setProjectId(TEST_PROJECT_ID) .build()); + modelFactory = FirebaseModelDownloader.getInstance(app).getModelFactory(); + app.setDataCollectionDefaultEnabled(Boolean.TRUE); // default sharedPreferenceUtil - sharedPreferencesUtil = new SharedPreferencesUtil(app); + sharedPreferencesUtil = new SharedPreferencesUtil(app, modelFactory); assertNotNull(sharedPreferencesUtil); + CUSTOM_MODEL_DOWNLOAD_COMPLETE = + modelFactory.create(MODEL_NAME, MODEL_HASH, 100, 0, "file/path/store/ModelName/1"); + CUSTOM_MODEL_UPDATE_IN_BACKGROUND = + modelFactory.create(MODEL_NAME, MODEL_HASH, 100, 986, "file/path/store/ModelName/1"); + CUSTOM_MODEL_DOWNLOADING = modelFactory.create(MODEL_NAME, MODEL_HASH, 100, 986); } @Test @@ -158,11 +164,13 @@ public void listDownloadedModels_multipleModels() { sharedPreferencesUtil.setLoadedCustomModelDetails(CUSTOM_MODEL_DOWNLOAD_COMPLETE); CustomModel model2 = - new CustomModel(MODEL_NAME + "2", MODEL_HASH + "2", 102, 0, "file/path/store/ModelName2/1"); + modelFactory.create( + MODEL_NAME + "2", MODEL_HASH + "2", 102, 0, "file/path/store/ModelName2/1"); sharedPreferencesUtil.setLoadedCustomModelDetails(model2); CustomModel model3 = - new CustomModel(MODEL_NAME + "3", MODEL_HASH + "3", 103, 0, "file/path/store/ModelName3/1"); + modelFactory.create( + MODEL_NAME + "3", MODEL_HASH + "3", 103, 0, "file/path/store/ModelName3/1"); sharedPreferencesUtil.setLoadedCustomModelDetails(model3); @@ -185,7 +193,7 @@ public void getCustomModelStatsCollectionFlag_defaultFirebaseAppTrue() { public void getCustomModelStatsCollectionFlag_defaultFirebaseAppFalse() { app.setDataCollectionDefaultEnabled(Boolean.FALSE); // default sharedPreferenceUtil - SharedPreferencesUtil disableLogUtil = new SharedPreferencesUtil(app); + SharedPreferencesUtil disableLogUtil = new SharedPreferencesUtil(app, modelFactory); assertEquals( disableLogUtil.getCustomModelStatsCollectionFlag(), app.isDataCollectionDefaultEnabled()); assertFalse(disableLogUtil.getCustomModelStatsCollectionFlag()); @@ -195,7 +203,7 @@ public void getCustomModelStatsCollectionFlag_defaultFirebaseAppFalse() { public void getCustomModelStatsCollectionFlag_overrideFirebaseAppFalse() { app.setDataCollectionDefaultEnabled(Boolean.FALSE); // default sharedPreferenceUtil - SharedPreferencesUtil sharedPreferencesUtil2 = new SharedPreferencesUtil(app); + SharedPreferencesUtil sharedPreferencesUtil2 = new SharedPreferencesUtil(app, modelFactory); sharedPreferencesUtil2.setCustomModelStatsCollectionEnabled(true); assertEquals(sharedPreferencesUtil2.getCustomModelStatsCollectionFlag(), true); assertTrue(sharedPreferencesUtil2.getCustomModelStatsCollectionFlag());