Skip to content

Add implementation for listDownloadedModels. #2154

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Nov 11, 2020
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -55,5 +55,6 @@ dependencies {
testImplementation 'androidx.test:core:1.3.0'
testImplementation 'com.google.truth:truth:1.0.1'
testImplementation 'junit:junit:4.13'
testImplementation 'org.mockito:mockito-core:3.3.3'
testImplementation "org.robolectric:robolectric:$robolectricVersion"
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,34 @@
import androidx.annotation.VisibleForTesting;
import com.google.android.gms.common.internal.Preconditions;
import com.google.android.gms.tasks.Task;
import com.google.android.gms.tasks.TaskCompletionSource;
import com.google.firebase.FirebaseApp;
import com.google.firebase.FirebaseOptions;
import com.google.firebase.ml.modeldownloader.internal.SharedPreferencesUtil;
import java.util.Set;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;

public class FirebaseModelDownloader {

private final FirebaseOptions firebaseOptions;
private final SharedPreferencesUtil sharedPreferencesUtil;
private final Executor executor;

FirebaseModelDownloader(FirebaseOptions firebaseOptions) {
FirebaseModelDownloader(FirebaseApp firebaseApp) {
this.firebaseOptions = firebaseApp.getOptions();
this.sharedPreferencesUtil = new SharedPreferencesUtil(firebaseApp);
this.executor = Executors.newCachedThreadPool();
}

@VisibleForTesting
FirebaseModelDownloader(
FirebaseOptions firebaseOptions,
SharedPreferencesUtil sharedPreferencesUtil,
Executor executor) {
this.firebaseOptions = firebaseOptions;
this.sharedPreferencesUtil = sharedPreferencesUtil;
this.executor = executor;
}

/**
Expand Down Expand Up @@ -84,7 +102,10 @@ public Task<CustomModel> getModel(
/** @return The set of all models that are downloaded to this device. */
@NonNull
public Task<Set<CustomModel>> listDownloadedModels() {
throw new UnsupportedOperationException("Not yet implemented.");
TaskCompletionSource<Set<CustomModel>> taskCompletionSource = new TaskCompletionSource<>();
executor.execute(
() -> taskCompletionSource.setResult(sharedPreferencesUtil.listDownloadedModels()));
return taskCompletionSource.getTask();
}

/*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

import androidx.annotation.NonNull;
import com.google.firebase.FirebaseApp;
import com.google.firebase.FirebaseOptions;
import com.google.firebase.components.Component;
import com.google.firebase.components.ComponentRegistrar;
import com.google.firebase.components.Dependency;
Expand All @@ -38,8 +37,7 @@ public List<Component<?>> getComponents() {
return Arrays.asList(
Component.builder(FirebaseModelDownloader.class)
.add(Dependency.required(FirebaseApp.class))
.add(Dependency.required(FirebaseOptions.class))
.factory(c -> new FirebaseModelDownloader(c.get(FirebaseOptions.class)))
.factory(c -> new FirebaseModelDownloader(c.get(FirebaseApp.class)))
.build(),
LibraryVersionComponent.create("firebase-ml-modeldownloader", BuildConfig.VERSION_NAME));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@
import androidx.annotation.VisibleForTesting;
import com.google.firebase.FirebaseApp;
import com.google.firebase.ml.modeldownloader.CustomModel;
import java.util.HashSet;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/** @hide */
public class SharedPreferencesUtil {
Expand All @@ -33,6 +37,8 @@ public class SharedPreferencesUtil {
// local model details
private static final String LOCAL_MODEL_HASH_PATTERN = "current_model_hash_%s_%s";
private static final String LOCAL_MODEL_FILE_PATH_PATTERN = "current_model_path_%s_%s";
private static final String LOCAL_MODEL_FILE_PATH_MATCHER = "current_model_path_(.*?)_([^/]+)/?";

private static final String LOCAL_MODEL_FILE_SIZE_PATTERN = "current_model_size_%s_%s";
// details about model during download.
private static final String DOWNLOADING_MODEL_HASH_PATTERN = "downloading_model_hash_%s_%s";
Expand Down Expand Up @@ -190,6 +196,41 @@ public synchronized void clearModelDetails(@NonNull String modelName, boolean cl
.commit();
}

public synchronized Set<CustomModel> listDownloadedModels() {
Set<CustomModel> customModels = new HashSet<>();
Set<String> keySet = getSharedPreferences().getAll().keySet();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since you're reading all the keys at once and not dealing with the sharedpreferences anymore, it may not be necessary to sync the whole method.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The second part (added todo) will need to coordinate with android download manager, so I'll need the sync when I add that.


for (String key : keySet) {
// if a local file path is present - get model details.
Matcher matcher = Pattern.compile(LOCAL_MODEL_FILE_PATH_MATCHER).matcher(key);
if (matcher.find()) {
String modelName = matcher.group(matcher.groupCount());
CustomModel extractModel = getCustomModelDetails(modelName);
if (extractModel != null) {
customModels.add(extractModel);
}
} else {
matcher = Pattern.compile(DOWNLOADING_MODEL_ID_PATTERN).matcher(key);
if (matcher.find()) {
String modelName = matcher.group(matcher.groupCount());
CustomModel extractModel = isDownloadCompleted(modelName);
if (extractModel != null) {
customModels.add(extractModel);
}
}
}
}
return customModels;
}

synchronized CustomModel isDownloadCompleted(String modelName) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the is prefix in the name suggests a boolean response. What about using maybe, like maybeGetUpdatedModel? It's a bit more verbose.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

CustomModel downloadModel = getCustomModelDetails(modelName);
// TODO(annz) check here if download currently in progress have completed.
// if yes, then complete file relocation and return the updated model, otherwise return null

return null;
}

/**
* Clears all stored data related to a custom model download.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,34 +14,72 @@

package com.google.firebase.ml.modeldownloader;

import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThrows;
import static org.mockito.Mockito.when;

import androidx.test.core.app.ApplicationProvider;
import com.google.android.gms.tasks.Task;
import com.google.firebase.FirebaseApp;
import com.google.firebase.FirebaseOptions;
import com.google.firebase.FirebaseOptions.Builder;
import com.google.firebase.ml.modeldownloader.internal.SharedPreferencesUtil;
import java.util.Collections;
import java.util.Set;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.robolectric.RobolectricTestRunner;

@RunWith(RobolectricTestRunner.class)
public class FirebaseModelDownloaderTest {

public static final String TEST_PROJECT_ID = "777777777777";
public static final FirebaseOptions FIREBASE_OPTIONS =
new Builder()
.setApplicationId("1:123456789:android:abcdef")
.setProjectId(TEST_PROJECT_ID)
.build();
public static final String MODEL_NAME = "MODEL_NAME_1";
public static final CustomModelDownloadConditions DEFAULT_DOWNLOAD_CONDITIONS =
new CustomModelDownloadConditions.Builder().build();

public static final String MODEL_HASH = "dsf324";
// TODO replace with uploaded model.
CustomModel CUSTOM_MODEL = new CustomModel(MODEL_NAME, 0, 100, MODEL_HASH);

FirebaseModelDownloader firebaseModelDownloader;
@Mock SharedPreferencesUtil mockPrefs;

ExecutorService executor;

@Before
public void setUp() {
MockitoAnnotations.initMocks(this);
FirebaseApp.clearInstancesForTest();
// default app
FirebaseApp.initializeApp(
ApplicationProvider.getApplicationContext(),
new FirebaseOptions.Builder()
.setApplicationId("1:123456789:android:abcdef")
.setProjectId(TEST_PROJECT_ID)
.build());
FirebaseApp.initializeApp(ApplicationProvider.getApplicationContext(), FIREBASE_OPTIONS);

executor = new ThreadPoolExecutor(0, 1, 30L, TimeUnit.SECONDS, new LinkedBlockingQueue<>());
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You could use a singleThread executor like in

That way the test shouldn't finish before running the tasks.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

firebaseModelDownloader = new FirebaseModelDownloader(FIREBASE_OPTIONS, mockPrefs, executor);
}

@After
public void cleanUp() {
try {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

AFAIK this is no longer needed when using the single thread executor.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks - going too fast.

executor.awaitTermination(250, TimeUnit.MILLISECONDS);
} catch (InterruptedException e) {
// do nothing.
}
}

@Test
Expand All @@ -54,10 +92,32 @@ public void getModel_unimplemented() {
}

@Test
public void listDownloadedModels_unimplemented() {
assertThrows(
UnsupportedOperationException.class,
() -> FirebaseModelDownloader.getInstance().listDownloadedModels());
public void listDownloadedModels_returnsEmptyModelList()
throws ExecutionException, InterruptedException {
when(mockPrefs.listDownloadedModels()).thenReturn(Collections.emptySet());
TestOnCompleteListener<Set<CustomModel>> onCompleteListener = new TestOnCompleteListener<>();
Task<Set<CustomModel>> task = firebaseModelDownloader.listDownloadedModels();
task.addOnCompleteListener(executor, onCompleteListener);
Set<CustomModel> customModelSet = onCompleteListener.await();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if Tasks.await(task) could be used instead here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had tried that got - "java.lang.IllegalStateException: Must not be called on the main application thread", found this solution was used elsewhere.


assertThat(task.isComplete()).isTrue();
assertEquals(customModelSet, Collections.EMPTY_SET);
}

@Test
public void listDownloadedModels_returnsModelList()
throws ExecutionException, InterruptedException {
when(mockPrefs.listDownloadedModels()).thenReturn(Collections.singleton(CUSTOM_MODEL));

TestOnCompleteListener<Set<CustomModel>> onCompleteListener = new TestOnCompleteListener<>();
Task<Set<CustomModel>> task = firebaseModelDownloader.listDownloadedModels();
task.addOnCompleteListener(executor, onCompleteListener);
Set<CustomModel> customModelSet = onCompleteListener.await();

assertThat(task.isComplete()).isTrue();
assertEquals(customModelSet, Collections.singleton(CUSTOM_MODEL));

executor.awaitTermination(500, TimeUnit.MILLISECONDS);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can probably get rid of this too.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

}

@Test
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
// Copyright 2020 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.google.firebase.ml.modeldownloader;

import androidx.annotation.NonNull;
import com.google.android.gms.tasks.OnCompleteListener;
import com.google.android.gms.tasks.Task;
import java.io.IOException;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;

/**
* Helper listener that works around a limitation of the Tasks API where await() cannot be called on
* the main thread. This listener works around it by running itself on a different thread, thus
* allowing the main thread to be woken up when the Tasks complete.
*/
public class TestOnCompleteListener<TResult> implements OnCompleteListener<TResult> {
private static final long TIMEOUT_MS = 5000;
private final CountDownLatch latch = new CountDownLatch(1);
private Task<TResult> task;
private volatile TResult result;
private volatile Exception exception;
private volatile boolean successful;

@Override
public void onComplete(@NonNull Task<TResult> task) {
this.task = task;
successful = task.isSuccessful();
if (successful) {
result = task.getResult();
} else {
exception = task.getException();
}
latch.countDown();
}

/** Blocks until the {@link #onComplete} is called. */
public TResult await() throws InterruptedException, ExecutionException {
if (!latch.await(TIMEOUT_MS, TimeUnit.MILLISECONDS)) {
throw new InterruptedException("timed out waiting for result");
}
if (successful) {
return result;
} else {
if (exception instanceof InterruptedException) {
throw (InterruptedException) exception;
}
// todo(annz) add firebase ml exception handling here.
if (exception instanceof IOException) {
throw new ExecutionException(exception);
}
throw new IllegalStateException("got an unexpected exception type", exception);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,13 @@
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;

import androidx.test.core.app.ApplicationProvider;
import com.google.firebase.FirebaseApp;
import com.google.firebase.FirebaseOptions;
import com.google.firebase.ml.modeldownloader.CustomModel;
import java.util.Set;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
Expand Down Expand Up @@ -118,4 +120,43 @@ public void clearDownloadingModelDetails_keepsLocalModel() throws IllegalArgumen
retrievedModel = sharedPreferencesUtil.getCustomModelDetails(MODEL_NAME);
assertEquals(retrievedModel, CUSTOM_MODEL_DOWNLOAD_COMPLETE);
}

@Test
public void listDownloadedModels_localModelFound() throws IllegalArgumentException {
sharedPreferencesUtil.setUploadedCustomModelDetails(CUSTOM_MODEL_DOWNLOAD_COMPLETE);
Set<CustomModel> retrievedModel = sharedPreferencesUtil.listDownloadedModels();
assertEquals(retrievedModel.size(), 1);
assertEquals(retrievedModel.iterator().next(), CUSTOM_MODEL_DOWNLOAD_COMPLETE);
}

@Test
public void listDownloadedModels_downloadingModelNotFound() throws IllegalArgumentException {
sharedPreferencesUtil.setDownloadingCustomModelDetails(CUSTOM_MODEL_DOWNLOADING);
assertEquals(sharedPreferencesUtil.listDownloadedModels().size(), 0);
}

@Test
public void listDownloadedModels_noModels() throws IllegalArgumentException {
assertEquals(sharedPreferencesUtil.listDownloadedModels().size(), 0);
}

@Test
public void listDownloadedModels_multipleModels() throws IllegalArgumentException {
sharedPreferencesUtil.setUploadedCustomModelDetails(CUSTOM_MODEL_DOWNLOAD_COMPLETE);

CustomModel model2 =
new CustomModel(MODEL_NAME + "2", 0, 102, MODEL_HASH + "2", "file/path/store/ModelName2/1");
sharedPreferencesUtil.setUploadedCustomModelDetails(model2);

CustomModel model3 =
new CustomModel(MODEL_NAME + "3", 0, 103, MODEL_HASH + "3", "file/path/store/ModelName3/1");

sharedPreferencesUtil.setUploadedCustomModelDetails(model3);

Set<CustomModel> retrievedModel = sharedPreferencesUtil.listDownloadedModels();
assertEquals(retrievedModel.size(), 3);
assertTrue(retrievedModel.contains(CUSTOM_MODEL_DOWNLOAD_COMPLETE));
assertTrue(retrievedModel.contains(model2));
assertTrue(retrievedModel.contains(model3));
}
}