diff --git a/firebase-ml-modeldownloader/firebase-ml-modeldownloader.gradle b/firebase-ml-modeldownloader/firebase-ml-modeldownloader.gradle index 4b2187f8830..305e73befae 100644 --- a/firebase-ml-modeldownloader/firebase-ml-modeldownloader.gradle +++ b/firebase-ml-modeldownloader/firebase-ml-modeldownloader.gradle @@ -55,5 +55,6 @@ dependencies { testImplementation 'androidx.test:core:1.3.0' testImplementation 'com.google.truth:truth:1.0.1' testImplementation 'junit:junit:4.13' + testImplementation 'org.mockito:mockito-core:3.3.3' testImplementation "org.robolectric:robolectric:$robolectricVersion" } \ No newline at end of file diff --git a/firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/FirebaseModelDownloader.java b/firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/FirebaseModelDownloader.java index 4e0fa3fad75..21499578aac 100644 --- a/firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/FirebaseModelDownloader.java +++ b/firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/FirebaseModelDownloader.java @@ -18,16 +18,34 @@ import androidx.annotation.VisibleForTesting; import com.google.android.gms.common.internal.Preconditions; import com.google.android.gms.tasks.Task; +import com.google.android.gms.tasks.TaskCompletionSource; import com.google.firebase.FirebaseApp; import com.google.firebase.FirebaseOptions; +import com.google.firebase.ml.modeldownloader.internal.SharedPreferencesUtil; import java.util.Set; +import java.util.concurrent.Executor; +import java.util.concurrent.Executors; public class FirebaseModelDownloader { private final FirebaseOptions firebaseOptions; + private final SharedPreferencesUtil sharedPreferencesUtil; + private final Executor executor; - FirebaseModelDownloader(FirebaseOptions firebaseOptions) { + FirebaseModelDownloader(FirebaseApp firebaseApp) { + this.firebaseOptions = firebaseApp.getOptions(); + this.sharedPreferencesUtil = new SharedPreferencesUtil(firebaseApp); + this.executor = Executors.newCachedThreadPool(); + } + + @VisibleForTesting + FirebaseModelDownloader( + FirebaseOptions firebaseOptions, + SharedPreferencesUtil sharedPreferencesUtil, + Executor executor) { this.firebaseOptions = firebaseOptions; + this.sharedPreferencesUtil = sharedPreferencesUtil; + this.executor = executor; } /** @@ -84,7 +102,10 @@ public Task getModel( /** @return The set of all models that are downloaded to this device. */ @NonNull public Task> listDownloadedModels() { - throw new UnsupportedOperationException("Not yet implemented."); + TaskCompletionSource> taskCompletionSource = new TaskCompletionSource<>(); + executor.execute( + () -> taskCompletionSource.setResult(sharedPreferencesUtil.listDownloadedModels())); + return taskCompletionSource.getTask(); } /* diff --git a/firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/FirebaseModelDownloaderRegistrar.java b/firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/FirebaseModelDownloaderRegistrar.java index 481dea140c2..a0aa70ab83b 100644 --- a/firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/FirebaseModelDownloaderRegistrar.java +++ b/firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/FirebaseModelDownloaderRegistrar.java @@ -16,7 +16,6 @@ import androidx.annotation.NonNull; import com.google.firebase.FirebaseApp; -import com.google.firebase.FirebaseOptions; import com.google.firebase.components.Component; import com.google.firebase.components.ComponentRegistrar; import com.google.firebase.components.Dependency; @@ -38,8 +37,7 @@ public List> getComponents() { return Arrays.asList( Component.builder(FirebaseModelDownloader.class) .add(Dependency.required(FirebaseApp.class)) - .add(Dependency.required(FirebaseOptions.class)) - .factory(c -> new FirebaseModelDownloader(c.get(FirebaseOptions.class))) + .factory(c -> new FirebaseModelDownloader(c.get(FirebaseApp.class))) .build(), LibraryVersionComponent.create("firebase-ml-modeldownloader", BuildConfig.VERSION_NAME)); } diff --git a/firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/internal/SharedPreferencesUtil.java b/firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/internal/SharedPreferencesUtil.java index b1b90660ce3..b6b7c56663a 100644 --- a/firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/internal/SharedPreferencesUtil.java +++ b/firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/internal/SharedPreferencesUtil.java @@ -23,6 +23,10 @@ import androidx.annotation.VisibleForTesting; import com.google.firebase.FirebaseApp; import com.google.firebase.ml.modeldownloader.CustomModel; +import java.util.HashSet; +import java.util.Set; +import java.util.regex.Matcher; +import java.util.regex.Pattern; /** @hide */ public class SharedPreferencesUtil { @@ -33,6 +37,8 @@ public class SharedPreferencesUtil { // local model details private static final String LOCAL_MODEL_HASH_PATTERN = "current_model_hash_%s_%s"; private static final String LOCAL_MODEL_FILE_PATH_PATTERN = "current_model_path_%s_%s"; + private static final String LOCAL_MODEL_FILE_PATH_MATCHER = "current_model_path_(.*?)_([^/]+)/?"; + private static final String LOCAL_MODEL_FILE_SIZE_PATTERN = "current_model_size_%s_%s"; // details about model during download. private static final String DOWNLOADING_MODEL_HASH_PATTERN = "downloading_model_hash_%s_%s"; @@ -190,6 +196,41 @@ public synchronized void clearModelDetails(@NonNull String modelName, boolean cl .commit(); } + public synchronized Set listDownloadedModels() { + Set customModels = new HashSet<>(); + Set keySet = getSharedPreferences().getAll().keySet(); + + for (String key : keySet) { + // if a local file path is present - get model details. + Matcher matcher = Pattern.compile(LOCAL_MODEL_FILE_PATH_MATCHER).matcher(key); + if (matcher.find()) { + String modelName = matcher.group(matcher.groupCount()); + CustomModel extractModel = getCustomModelDetails(modelName); + if (extractModel != null) { + customModels.add(extractModel); + } + } else { + matcher = Pattern.compile(DOWNLOADING_MODEL_ID_PATTERN).matcher(key); + if (matcher.find()) { + String modelName = matcher.group(matcher.groupCount()); + CustomModel extractModel = maybeGetUpdatedModel(modelName); + if (extractModel != null) { + customModels.add(extractModel); + } + } + } + } + return customModels; + } + + synchronized CustomModel maybeGetUpdatedModel(String modelName) { + CustomModel downloadModel = getCustomModelDetails(modelName); + // TODO(annz) check here if download currently in progress have completed. + // if yes, then complete file relocation and return the updated model, otherwise return null + + return null; + } + /** * Clears all stored data related to a custom model download. * diff --git a/firebase-ml-modeldownloader/src/test/java/com/google/firebase/ml/modeldownloader/FirebaseModelDownloaderTest.java b/firebase-ml-modeldownloader/src/test/java/com/google/firebase/ml/modeldownloader/FirebaseModelDownloaderTest.java index 05d14f2afea..be3a1c214ce 100644 --- a/firebase-ml-modeldownloader/src/test/java/com/google/firebase/ml/modeldownloader/FirebaseModelDownloaderTest.java +++ b/firebase-ml-modeldownloader/src/test/java/com/google/firebase/ml/modeldownloader/FirebaseModelDownloaderTest.java @@ -14,34 +14,60 @@ package com.google.firebase.ml.modeldownloader; +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertThrows; +import static org.mockito.Mockito.when; import androidx.test.core.app.ApplicationProvider; +import com.google.android.gms.tasks.Task; import com.google.firebase.FirebaseApp; import com.google.firebase.FirebaseOptions; +import com.google.firebase.FirebaseOptions.Builder; +import com.google.firebase.ml.modeldownloader.internal.SharedPreferencesUtil; +import java.util.Collections; +import java.util.Set; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; import org.robolectric.RobolectricTestRunner; @RunWith(RobolectricTestRunner.class) public class FirebaseModelDownloaderTest { public static final String TEST_PROJECT_ID = "777777777777"; + public static final FirebaseOptions FIREBASE_OPTIONS = + new Builder() + .setApplicationId("1:123456789:android:abcdef") + .setProjectId(TEST_PROJECT_ID) + .build(); public static final String MODEL_NAME = "MODEL_NAME_1"; public static final CustomModelDownloadConditions DEFAULT_DOWNLOAD_CONDITIONS = new CustomModelDownloadConditions.Builder().build(); + public static final String MODEL_HASH = "dsf324"; + // TODO replace with uploaded model. + CustomModel CUSTOM_MODEL = new CustomModel(MODEL_NAME, 0, 100, MODEL_HASH); + + FirebaseModelDownloader firebaseModelDownloader; + @Mock SharedPreferencesUtil mockPrefs; + + ExecutorService executor; + @Before public void setUp() { + MockitoAnnotations.initMocks(this); FirebaseApp.clearInstancesForTest(); // default app - FirebaseApp.initializeApp( - ApplicationProvider.getApplicationContext(), - new FirebaseOptions.Builder() - .setApplicationId("1:123456789:android:abcdef") - .setProjectId(TEST_PROJECT_ID) - .build()); + FirebaseApp.initializeApp(ApplicationProvider.getApplicationContext(), FIREBASE_OPTIONS); + + executor = Executors.newSingleThreadExecutor(); + firebaseModelDownloader = new FirebaseModelDownloader(FIREBASE_OPTIONS, mockPrefs, executor); } @Test @@ -54,10 +80,30 @@ public void getModel_unimplemented() { } @Test - public void listDownloadedModels_unimplemented() { - assertThrows( - UnsupportedOperationException.class, - () -> FirebaseModelDownloader.getInstance().listDownloadedModels()); + public void listDownloadedModels_returnsEmptyModelList() + throws ExecutionException, InterruptedException { + when(mockPrefs.listDownloadedModels()).thenReturn(Collections.emptySet()); + TestOnCompleteListener> onCompleteListener = new TestOnCompleteListener<>(); + Task> task = firebaseModelDownloader.listDownloadedModels(); + task.addOnCompleteListener(executor, onCompleteListener); + Set customModelSet = onCompleteListener.await(); + + assertThat(task.isComplete()).isTrue(); + assertEquals(customModelSet, Collections.EMPTY_SET); + } + + @Test + public void listDownloadedModels_returnsModelList() + throws ExecutionException, InterruptedException { + when(mockPrefs.listDownloadedModels()).thenReturn(Collections.singleton(CUSTOM_MODEL)); + + TestOnCompleteListener> onCompleteListener = new TestOnCompleteListener<>(); + Task> task = firebaseModelDownloader.listDownloadedModels(); + task.addOnCompleteListener(executor, onCompleteListener); + Set customModelSet = onCompleteListener.await(); + + assertThat(task.isComplete()).isTrue(); + assertEquals(customModelSet, Collections.singleton(CUSTOM_MODEL)); } @Test diff --git a/firebase-ml-modeldownloader/src/test/java/com/google/firebase/ml/modeldownloader/TestOnCompleteListener.java b/firebase-ml-modeldownloader/src/test/java/com/google/firebase/ml/modeldownloader/TestOnCompleteListener.java new file mode 100644 index 00000000000..53ab6d00879 --- /dev/null +++ b/firebase-ml-modeldownloader/src/test/java/com/google/firebase/ml/modeldownloader/TestOnCompleteListener.java @@ -0,0 +1,68 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.firebase.ml.modeldownloader; + +import androidx.annotation.NonNull; +import com.google.android.gms.tasks.OnCompleteListener; +import com.google.android.gms.tasks.Task; +import java.io.IOException; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; + +/** + * Helper listener that works around a limitation of the Tasks API where await() cannot be called on + * the main thread. This listener works around it by running itself on a different thread, thus + * allowing the main thread to be woken up when the Tasks complete. + */ +public class TestOnCompleteListener implements OnCompleteListener { + private static final long TIMEOUT_MS = 5000; + private final CountDownLatch latch = new CountDownLatch(1); + private Task task; + private volatile TResult result; + private volatile Exception exception; + private volatile boolean successful; + + @Override + public void onComplete(@NonNull Task task) { + this.task = task; + successful = task.isSuccessful(); + if (successful) { + result = task.getResult(); + } else { + exception = task.getException(); + } + latch.countDown(); + } + + /** Blocks until the {@link #onComplete} is called. */ + public TResult await() throws InterruptedException, ExecutionException { + if (!latch.await(TIMEOUT_MS, TimeUnit.MILLISECONDS)) { + throw new InterruptedException("timed out waiting for result"); + } + if (successful) { + return result; + } else { + if (exception instanceof InterruptedException) { + throw (InterruptedException) exception; + } + // todo(annz) add firebase ml exception handling here. + if (exception instanceof IOException) { + throw new ExecutionException(exception); + } + throw new IllegalStateException("got an unexpected exception type", exception); + } + } +} diff --git a/firebase-ml-modeldownloader/src/test/java/com/google/firebase/ml/modeldownloader/internal/SharedPreferencesUtilTest.java b/firebase-ml-modeldownloader/src/test/java/com/google/firebase/ml/modeldownloader/internal/SharedPreferencesUtilTest.java index 05c7314fa30..60a62163e12 100644 --- a/firebase-ml-modeldownloader/src/test/java/com/google/firebase/ml/modeldownloader/internal/SharedPreferencesUtilTest.java +++ b/firebase-ml-modeldownloader/src/test/java/com/google/firebase/ml/modeldownloader/internal/SharedPreferencesUtilTest.java @@ -17,11 +17,13 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; import androidx.test.core.app.ApplicationProvider; import com.google.firebase.FirebaseApp; import com.google.firebase.FirebaseOptions; import com.google.firebase.ml.modeldownloader.CustomModel; +import java.util.Set; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -118,4 +120,43 @@ public void clearDownloadingModelDetails_keepsLocalModel() throws IllegalArgumen retrievedModel = sharedPreferencesUtil.getCustomModelDetails(MODEL_NAME); assertEquals(retrievedModel, CUSTOM_MODEL_DOWNLOAD_COMPLETE); } + + @Test + public void listDownloadedModels_localModelFound() throws IllegalArgumentException { + sharedPreferencesUtil.setUploadedCustomModelDetails(CUSTOM_MODEL_DOWNLOAD_COMPLETE); + Set retrievedModel = sharedPreferencesUtil.listDownloadedModels(); + assertEquals(retrievedModel.size(), 1); + assertEquals(retrievedModel.iterator().next(), CUSTOM_MODEL_DOWNLOAD_COMPLETE); + } + + @Test + public void listDownloadedModels_downloadingModelNotFound() throws IllegalArgumentException { + sharedPreferencesUtil.setDownloadingCustomModelDetails(CUSTOM_MODEL_DOWNLOADING); + assertEquals(sharedPreferencesUtil.listDownloadedModels().size(), 0); + } + + @Test + public void listDownloadedModels_noModels() throws IllegalArgumentException { + assertEquals(sharedPreferencesUtil.listDownloadedModels().size(), 0); + } + + @Test + public void listDownloadedModels_multipleModels() throws IllegalArgumentException { + sharedPreferencesUtil.setUploadedCustomModelDetails(CUSTOM_MODEL_DOWNLOAD_COMPLETE); + + CustomModel model2 = + new CustomModel(MODEL_NAME + "2", 0, 102, MODEL_HASH + "2", "file/path/store/ModelName2/1"); + sharedPreferencesUtil.setUploadedCustomModelDetails(model2); + + CustomModel model3 = + new CustomModel(MODEL_NAME + "3", 0, 103, MODEL_HASH + "3", "file/path/store/ModelName3/1"); + + sharedPreferencesUtil.setUploadedCustomModelDetails(model3); + + Set retrievedModel = sharedPreferencesUtil.listDownloadedModels(); + assertEquals(retrievedModel.size(), 3); + assertTrue(retrievedModel.contains(CUSTOM_MODEL_DOWNLOAD_COMPLETE)); + assertTrue(retrievedModel.contains(model2)); + assertTrue(retrievedModel.contains(model3)); + } }