Skip to content

Commit 4eca657

Browse files
authored
Add implementation for listDownloadedModels. (#2154)
* Add implementation for listModels. * Update to background thread execution for listDownloadModels call. * Update to background thread execution for listDownloadModels call. * update executor usages in unit tests.
1 parent 6e226ea commit 4eca657

File tree

7 files changed

+231
-15
lines changed

7 files changed

+231
-15
lines changed

firebase-ml-modeldownloader/firebase-ml-modeldownloader.gradle

+1
Original file line numberDiff line numberDiff line change
@@ -55,5 +55,6 @@ dependencies {
5555
testImplementation 'androidx.test:core:1.3.0'
5656
testImplementation 'com.google.truth:truth:1.0.1'
5757
testImplementation 'junit:junit:4.13'
58+
testImplementation 'org.mockito:mockito-core:3.3.3'
5859
testImplementation "org.robolectric:robolectric:$robolectricVersion"
5960
}

firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/FirebaseModelDownloader.java

+23-2
Original file line numberDiff line numberDiff line change
@@ -18,16 +18,34 @@
1818
import androidx.annotation.VisibleForTesting;
1919
import com.google.android.gms.common.internal.Preconditions;
2020
import com.google.android.gms.tasks.Task;
21+
import com.google.android.gms.tasks.TaskCompletionSource;
2122
import com.google.firebase.FirebaseApp;
2223
import com.google.firebase.FirebaseOptions;
24+
import com.google.firebase.ml.modeldownloader.internal.SharedPreferencesUtil;
2325
import java.util.Set;
26+
import java.util.concurrent.Executor;
27+
import java.util.concurrent.Executors;
2428

2529
public class FirebaseModelDownloader {
2630

2731
private final FirebaseOptions firebaseOptions;
32+
private final SharedPreferencesUtil sharedPreferencesUtil;
33+
private final Executor executor;
2834

29-
FirebaseModelDownloader(FirebaseOptions firebaseOptions) {
35+
FirebaseModelDownloader(FirebaseApp firebaseApp) {
36+
this.firebaseOptions = firebaseApp.getOptions();
37+
this.sharedPreferencesUtil = new SharedPreferencesUtil(firebaseApp);
38+
this.executor = Executors.newCachedThreadPool();
39+
}
40+
41+
@VisibleForTesting
42+
FirebaseModelDownloader(
43+
FirebaseOptions firebaseOptions,
44+
SharedPreferencesUtil sharedPreferencesUtil,
45+
Executor executor) {
3046
this.firebaseOptions = firebaseOptions;
47+
this.sharedPreferencesUtil = sharedPreferencesUtil;
48+
this.executor = executor;
3149
}
3250

3351
/**
@@ -84,7 +102,10 @@ public Task<CustomModel> getModel(
84102
/** @return The set of all models that are downloaded to this device. */
85103
@NonNull
86104
public Task<Set<CustomModel>> listDownloadedModels() {
87-
throw new UnsupportedOperationException("Not yet implemented.");
105+
TaskCompletionSource<Set<CustomModel>> taskCompletionSource = new TaskCompletionSource<>();
106+
executor.execute(
107+
() -> taskCompletionSource.setResult(sharedPreferencesUtil.listDownloadedModels()));
108+
return taskCompletionSource.getTask();
88109
}
89110

90111
/*

firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/FirebaseModelDownloaderRegistrar.java

+1-3
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616

1717
import androidx.annotation.NonNull;
1818
import com.google.firebase.FirebaseApp;
19-
import com.google.firebase.FirebaseOptions;
2019
import com.google.firebase.components.Component;
2120
import com.google.firebase.components.ComponentRegistrar;
2221
import com.google.firebase.components.Dependency;
@@ -38,8 +37,7 @@ public List<Component<?>> getComponents() {
3837
return Arrays.asList(
3938
Component.builder(FirebaseModelDownloader.class)
4039
.add(Dependency.required(FirebaseApp.class))
41-
.add(Dependency.required(FirebaseOptions.class))
42-
.factory(c -> new FirebaseModelDownloader(c.get(FirebaseOptions.class)))
40+
.factory(c -> new FirebaseModelDownloader(c.get(FirebaseApp.class)))
4341
.build(),
4442
LibraryVersionComponent.create("firebase-ml-modeldownloader", BuildConfig.VERSION_NAME));
4543
}

firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/internal/SharedPreferencesUtil.java

+41
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@
2323
import androidx.annotation.VisibleForTesting;
2424
import com.google.firebase.FirebaseApp;
2525
import com.google.firebase.ml.modeldownloader.CustomModel;
26+
import java.util.HashSet;
27+
import java.util.Set;
28+
import java.util.regex.Matcher;
29+
import java.util.regex.Pattern;
2630

2731
/** @hide */
2832
public class SharedPreferencesUtil {
@@ -33,6 +37,8 @@ public class SharedPreferencesUtil {
3337
// local model details
3438
private static final String LOCAL_MODEL_HASH_PATTERN = "current_model_hash_%s_%s";
3539
private static final String LOCAL_MODEL_FILE_PATH_PATTERN = "current_model_path_%s_%s";
40+
private static final String LOCAL_MODEL_FILE_PATH_MATCHER = "current_model_path_(.*?)_([^/]+)/?";
41+
3642
private static final String LOCAL_MODEL_FILE_SIZE_PATTERN = "current_model_size_%s_%s";
3743
// details about model during download.
3844
private static final String DOWNLOADING_MODEL_HASH_PATTERN = "downloading_model_hash_%s_%s";
@@ -190,6 +196,41 @@ public synchronized void clearModelDetails(@NonNull String modelName, boolean cl
190196
.commit();
191197
}
192198

199+
public synchronized Set<CustomModel> listDownloadedModels() {
200+
Set<CustomModel> customModels = new HashSet<>();
201+
Set<String> keySet = getSharedPreferences().getAll().keySet();
202+
203+
for (String key : keySet) {
204+
// if a local file path is present - get model details.
205+
Matcher matcher = Pattern.compile(LOCAL_MODEL_FILE_PATH_MATCHER).matcher(key);
206+
if (matcher.find()) {
207+
String modelName = matcher.group(matcher.groupCount());
208+
CustomModel extractModel = getCustomModelDetails(modelName);
209+
if (extractModel != null) {
210+
customModels.add(extractModel);
211+
}
212+
} else {
213+
matcher = Pattern.compile(DOWNLOADING_MODEL_ID_PATTERN).matcher(key);
214+
if (matcher.find()) {
215+
String modelName = matcher.group(matcher.groupCount());
216+
CustomModel extractModel = maybeGetUpdatedModel(modelName);
217+
if (extractModel != null) {
218+
customModels.add(extractModel);
219+
}
220+
}
221+
}
222+
}
223+
return customModels;
224+
}
225+
226+
synchronized CustomModel maybeGetUpdatedModel(String modelName) {
227+
CustomModel downloadModel = getCustomModelDetails(modelName);
228+
// TODO(annz) check here if download currently in progress have completed.
229+
// if yes, then complete file relocation and return the updated model, otherwise return null
230+
231+
return null;
232+
}
233+
193234
/**
194235
* Clears all stored data related to a custom model download.
195236
*

firebase-ml-modeldownloader/src/test/java/com/google/firebase/ml/modeldownloader/FirebaseModelDownloaderTest.java

+56-10
Original file line numberDiff line numberDiff line change
@@ -14,34 +14,60 @@
1414

1515
package com.google.firebase.ml.modeldownloader;
1616

17+
import static com.google.common.truth.Truth.assertThat;
18+
import static org.junit.Assert.assertEquals;
1719
import static org.junit.Assert.assertThrows;
20+
import static org.mockito.Mockito.when;
1821

1922
import androidx.test.core.app.ApplicationProvider;
23+
import com.google.android.gms.tasks.Task;
2024
import com.google.firebase.FirebaseApp;
2125
import com.google.firebase.FirebaseOptions;
26+
import com.google.firebase.FirebaseOptions.Builder;
27+
import com.google.firebase.ml.modeldownloader.internal.SharedPreferencesUtil;
28+
import java.util.Collections;
29+
import java.util.Set;
30+
import java.util.concurrent.ExecutionException;
31+
import java.util.concurrent.ExecutorService;
32+
import java.util.concurrent.Executors;
2233
import org.junit.Before;
2334
import org.junit.Test;
2435
import org.junit.runner.RunWith;
36+
import org.mockito.Mock;
37+
import org.mockito.MockitoAnnotations;
2538
import org.robolectric.RobolectricTestRunner;
2639

2740
@RunWith(RobolectricTestRunner.class)
2841
public class FirebaseModelDownloaderTest {
2942

3043
public static final String TEST_PROJECT_ID = "777777777777";
44+
public static final FirebaseOptions FIREBASE_OPTIONS =
45+
new Builder()
46+
.setApplicationId("1:123456789:android:abcdef")
47+
.setProjectId(TEST_PROJECT_ID)
48+
.build();
3149
public static final String MODEL_NAME = "MODEL_NAME_1";
3250
public static final CustomModelDownloadConditions DEFAULT_DOWNLOAD_CONDITIONS =
3351
new CustomModelDownloadConditions.Builder().build();
3452

53+
public static final String MODEL_HASH = "dsf324";
54+
// TODO replace with uploaded model.
55+
CustomModel CUSTOM_MODEL = new CustomModel(MODEL_NAME, 0, 100, MODEL_HASH);
56+
57+
FirebaseModelDownloader firebaseModelDownloader;
58+
@Mock SharedPreferencesUtil mockPrefs;
59+
60+
ExecutorService executor;
61+
3562
@Before
3663
public void setUp() {
64+
MockitoAnnotations.initMocks(this);
3765
FirebaseApp.clearInstancesForTest();
3866
// default app
39-
FirebaseApp.initializeApp(
40-
ApplicationProvider.getApplicationContext(),
41-
new FirebaseOptions.Builder()
42-
.setApplicationId("1:123456789:android:abcdef")
43-
.setProjectId(TEST_PROJECT_ID)
44-
.build());
67+
FirebaseApp.initializeApp(ApplicationProvider.getApplicationContext(), FIREBASE_OPTIONS);
68+
69+
executor = Executors.newSingleThreadExecutor();
70+
firebaseModelDownloader = new FirebaseModelDownloader(FIREBASE_OPTIONS, mockPrefs, executor);
4571
}
4672

4773
@Test
@@ -54,10 +80,30 @@ public void getModel_unimplemented() {
5480
}
5581

5682
@Test
57-
public void listDownloadedModels_unimplemented() {
58-
assertThrows(
59-
UnsupportedOperationException.class,
60-
() -> FirebaseModelDownloader.getInstance().listDownloadedModels());
83+
public void listDownloadedModels_returnsEmptyModelList()
84+
throws ExecutionException, InterruptedException {
85+
when(mockPrefs.listDownloadedModels()).thenReturn(Collections.emptySet());
86+
TestOnCompleteListener<Set<CustomModel>> onCompleteListener = new TestOnCompleteListener<>();
87+
Task<Set<CustomModel>> task = firebaseModelDownloader.listDownloadedModels();
88+
task.addOnCompleteListener(executor, onCompleteListener);
89+
Set<CustomModel> customModelSet = onCompleteListener.await();
90+
91+
assertThat(task.isComplete()).isTrue();
92+
assertEquals(customModelSet, Collections.EMPTY_SET);
93+
}
94+
95+
@Test
96+
public void listDownloadedModels_returnsModelList()
97+
throws ExecutionException, InterruptedException {
98+
when(mockPrefs.listDownloadedModels()).thenReturn(Collections.singleton(CUSTOM_MODEL));
99+
100+
TestOnCompleteListener<Set<CustomModel>> onCompleteListener = new TestOnCompleteListener<>();
101+
Task<Set<CustomModel>> task = firebaseModelDownloader.listDownloadedModels();
102+
task.addOnCompleteListener(executor, onCompleteListener);
103+
Set<CustomModel> customModelSet = onCompleteListener.await();
104+
105+
assertThat(task.isComplete()).isTrue();
106+
assertEquals(customModelSet, Collections.singleton(CUSTOM_MODEL));
61107
}
62108

63109
@Test
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
// Copyright 2020 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package com.google.firebase.ml.modeldownloader;
16+
17+
import androidx.annotation.NonNull;
18+
import com.google.android.gms.tasks.OnCompleteListener;
19+
import com.google.android.gms.tasks.Task;
20+
import java.io.IOException;
21+
import java.util.concurrent.CountDownLatch;
22+
import java.util.concurrent.ExecutionException;
23+
import java.util.concurrent.TimeUnit;
24+
25+
/**
26+
* Helper listener that works around a limitation of the Tasks API where await() cannot be called on
27+
* the main thread. This listener works around it by running itself on a different thread, thus
28+
* allowing the main thread to be woken up when the Tasks complete.
29+
*/
30+
public class TestOnCompleteListener<TResult> implements OnCompleteListener<TResult> {
31+
private static final long TIMEOUT_MS = 5000;
32+
private final CountDownLatch latch = new CountDownLatch(1);
33+
private Task<TResult> task;
34+
private volatile TResult result;
35+
private volatile Exception exception;
36+
private volatile boolean successful;
37+
38+
@Override
39+
public void onComplete(@NonNull Task<TResult> task) {
40+
this.task = task;
41+
successful = task.isSuccessful();
42+
if (successful) {
43+
result = task.getResult();
44+
} else {
45+
exception = task.getException();
46+
}
47+
latch.countDown();
48+
}
49+
50+
/** Blocks until the {@link #onComplete} is called. */
51+
public TResult await() throws InterruptedException, ExecutionException {
52+
if (!latch.await(TIMEOUT_MS, TimeUnit.MILLISECONDS)) {
53+
throw new InterruptedException("timed out waiting for result");
54+
}
55+
if (successful) {
56+
return result;
57+
} else {
58+
if (exception instanceof InterruptedException) {
59+
throw (InterruptedException) exception;
60+
}
61+
// todo(annz) add firebase ml exception handling here.
62+
if (exception instanceof IOException) {
63+
throw new ExecutionException(exception);
64+
}
65+
throw new IllegalStateException("got an unexpected exception type", exception);
66+
}
67+
}
68+
}

firebase-ml-modeldownloader/src/test/java/com/google/firebase/ml/modeldownloader/internal/SharedPreferencesUtilTest.java

+41
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,13 @@
1717
import static org.junit.Assert.assertEquals;
1818
import static org.junit.Assert.assertNotNull;
1919
import static org.junit.Assert.assertNull;
20+
import static org.junit.Assert.assertTrue;
2021

2122
import androidx.test.core.app.ApplicationProvider;
2223
import com.google.firebase.FirebaseApp;
2324
import com.google.firebase.FirebaseOptions;
2425
import com.google.firebase.ml.modeldownloader.CustomModel;
26+
import java.util.Set;
2527
import org.junit.Before;
2628
import org.junit.Test;
2729
import org.junit.runner.RunWith;
@@ -118,4 +120,43 @@ public void clearDownloadingModelDetails_keepsLocalModel() throws IllegalArgumen
118120
retrievedModel = sharedPreferencesUtil.getCustomModelDetails(MODEL_NAME);
119121
assertEquals(retrievedModel, CUSTOM_MODEL_DOWNLOAD_COMPLETE);
120122
}
123+
124+
@Test
125+
public void listDownloadedModels_localModelFound() throws IllegalArgumentException {
126+
sharedPreferencesUtil.setUploadedCustomModelDetails(CUSTOM_MODEL_DOWNLOAD_COMPLETE);
127+
Set<CustomModel> retrievedModel = sharedPreferencesUtil.listDownloadedModels();
128+
assertEquals(retrievedModel.size(), 1);
129+
assertEquals(retrievedModel.iterator().next(), CUSTOM_MODEL_DOWNLOAD_COMPLETE);
130+
}
131+
132+
@Test
133+
public void listDownloadedModels_downloadingModelNotFound() throws IllegalArgumentException {
134+
sharedPreferencesUtil.setDownloadingCustomModelDetails(CUSTOM_MODEL_DOWNLOADING);
135+
assertEquals(sharedPreferencesUtil.listDownloadedModels().size(), 0);
136+
}
137+
138+
@Test
139+
public void listDownloadedModels_noModels() throws IllegalArgumentException {
140+
assertEquals(sharedPreferencesUtil.listDownloadedModels().size(), 0);
141+
}
142+
143+
@Test
144+
public void listDownloadedModels_multipleModels() throws IllegalArgumentException {
145+
sharedPreferencesUtil.setUploadedCustomModelDetails(CUSTOM_MODEL_DOWNLOAD_COMPLETE);
146+
147+
CustomModel model2 =
148+
new CustomModel(MODEL_NAME + "2", 0, 102, MODEL_HASH + "2", "file/path/store/ModelName2/1");
149+
sharedPreferencesUtil.setUploadedCustomModelDetails(model2);
150+
151+
CustomModel model3 =
152+
new CustomModel(MODEL_NAME + "3", 0, 103, MODEL_HASH + "3", "file/path/store/ModelName3/1");
153+
154+
sharedPreferencesUtil.setUploadedCustomModelDetails(model3);
155+
156+
Set<CustomModel> retrievedModel = sharedPreferencesUtil.listDownloadedModels();
157+
assertEquals(retrievedModel.size(), 3);
158+
assertTrue(retrievedModel.contains(CUSTOM_MODEL_DOWNLOAD_COMPLETE));
159+
assertTrue(retrievedModel.contains(model2));
160+
assertTrue(retrievedModel.contains(model3));
161+
}
121162
}

0 commit comments

Comments
 (0)