Skip to content

Commit 82c93e5

Browse files
annzimmerrlazo
authored andcommitted
Add public getModelDownloadId method (#2315)
* Add public getModelDownloadId method - can be used to track download progress. * Apply suggestions from code review Co-authored-by: Rodrigo Lazo <[email protected]> Co-authored-by: Rodrigo Lazo <[email protected]>
1 parent 1ee2466 commit 82c93e5

File tree

4 files changed

+48
-3
lines changed

4 files changed

+48
-3
lines changed

firebase-ml-modeldownloader/api.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ package com.google.firebase.ml.modeldownloader {
5757
method @NonNull public static com.google.firebase.ml.modeldownloader.FirebaseModelDownloader getInstance();
5858
method @NonNull public static com.google.firebase.ml.modeldownloader.FirebaseModelDownloader getInstance(@NonNull com.google.firebase.FirebaseApp);
5959
method @NonNull public com.google.android.gms.tasks.Task<com.google.firebase.ml.modeldownloader.CustomModel> getModel(@NonNull String, @NonNull com.google.firebase.ml.modeldownloader.DownloadType, @Nullable com.google.firebase.ml.modeldownloader.CustomModelDownloadConditions);
60+
method public long getModelDownloadId(@NonNull String);
6061
method @NonNull public com.google.android.gms.tasks.Task<java.util.Set<com.google.firebase.ml.modeldownloader.CustomModel>> listDownloadedModels();
6162
method public void setStatsCollectionEnabled(boolean);
6263
}

firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/CustomModelDownloadConditions.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ public static class Builder {
5454
private boolean isWifiRequired = false;
5555
private boolean isDeviceIdleRequired = false;
5656

57-
/** Sets whether charging is required. Only works on Android N and above. */
57+
/** Sets charging as required. Only works on Android N and above. */
5858
@NonNull
5959
@RequiresApi(VERSION_CODES.N)
6060
@TargetApi(VERSION_CODES.N)
@@ -63,15 +63,15 @@ public Builder requireCharging() {
6363
return this;
6464
}
6565

66-
/** Sets whether wifi is required. */
66+
/** Sets wifi as required. */
6767
@NonNull
6868
public Builder requireWifi() {
6969
this.isWifiRequired = true;
7070
return this;
7171
}
7272

7373
/**
74-
* Sets whether device idle is required.
74+
* Sets device idle as required.
7575
*
7676
* <p>Idle mode is a loose definition provided by the system, which means that the device is not
7777
* in use, and has not been in use for some time.

firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/FirebaseModelDownloader.java

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -422,6 +422,31 @@ public void setStatsCollectionEnabled(boolean enabled) {
422422
sharedPreferencesUtil.setCustomModelStatsCollectionEnabled(enabled);
423423
}
424424

425+
/**
426+
* Get the current models' download id (returns background download id when applicable). This id
427+
* can be used to create a progress bar to track file download progress.
428+
*
429+
* <p>If no model exists or there is no download in progress, return 0.
430+
*
431+
* <p>If 0 is returned immediately after starting a download via getModel, then
432+
*
433+
* <ul>
434+
* <li>the enqueuing wasn't needed: the getModel task already completed and/or no background
435+
* update.
436+
* <li>the enqueuing hasn't completed: the download id hasn't generated yet - try again.
437+
* </ul>
438+
*
439+
* @param modelName - model name
440+
* @return id associated with Android Download Manager.
441+
*/
442+
public long getModelDownloadId(@NonNull String modelName) {
443+
CustomModel localModel = sharedPreferencesUtil.getDownloadingCustomModelDetails(modelName);
444+
if (localModel != null) {
445+
return localModel.getDownloadId();
446+
}
447+
return 0;
448+
}
449+
425450
/** Returns the nick name of the {@link FirebaseApp} of this {@link FirebaseModelDownloader} */
426451
@VisibleForTesting
427452
String getApplicationId() {

firebase-ml-modeldownloader/src/test/java/com/google/firebase/ml/modeldownloader/FirebaseModelDownloaderTest.java

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -725,4 +725,23 @@ public void setStatsCollectionEnabled() {
725725
firebaseModelDownloader.setStatsCollectionEnabled(false);
726726
verify(mockPrefs, times(1)).setCustomModelStatsCollectionEnabled(eq(false));
727727
}
728+
729+
@Test
730+
public void getModelDownloadId_noDownload() {
731+
when(mockPrefs.getDownloadingCustomModelDetails(eq(MODEL_NAME))).thenReturn(customModelLoaded);
732+
assertEquals(firebaseModelDownloader.getModelDownloadId(MODEL_NAME), 0);
733+
}
734+
735+
@Test
736+
public void getModelDownloadId_noNamedModel() {
737+
when(mockPrefs.getDownloadingCustomModelDetails(eq(MODEL_NAME))).thenReturn(null);
738+
assertEquals(firebaseModelDownloader.getModelDownloadId(MODEL_NAME), 0);
739+
}
740+
741+
@Test
742+
public void getModelDownloadId_download() {
743+
when(mockPrefs.getDownloadingCustomModelDetails(eq(MODEL_NAME)))
744+
.thenReturn(UPDATE_IN_PROGRESS_CUSTOM_MODEL);
745+
assertEquals(firebaseModelDownloader.getModelDownloadId(MODEL_NAME), DOWNLOAD_ID);
746+
}
728747
}

0 commit comments

Comments
 (0)