Skip to content

update mlkitdownloader to use executors #4382

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Dec 15, 2022
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ dependencies {
implementation project(':encoders:firebase-encoders-json')
implementation project(':firebase-common')
implementation project(':firebase-components')
implementation project(':firebase-annotations')
implementation project(':firebase-datatransport')
implementation project(':firebase-installations-interop')
implementation project(':transport:transport-api')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
// limitations under the License.
package com.google.firebase.ml.modeldownloader;

import android.annotation.SuppressLint;
import android.os.Build.VERSION_CODES;
import android.util.Log;
import androidx.annotation.NonNull;
Expand All @@ -26,6 +25,8 @@
import com.google.android.gms.tasks.Tasks;
import com.google.firebase.FirebaseApp;
import com.google.firebase.FirebaseOptions;
import com.google.firebase.annotations.concurrent.Background;
import com.google.firebase.annotations.concurrent.Blocking;
import com.google.firebase.ml.modeldownloader.internal.CustomModelDownloadService;
import com.google.firebase.ml.modeldownloader.internal.FirebaseMlLogEvent.ModelDownloadLogEvent.DownloadStatus;
import com.google.firebase.ml.modeldownloader.internal.FirebaseMlLogEvent.ModelDownloadLogEvent.ErrorCode;
Expand All @@ -36,7 +37,6 @@
import java.io.File;
import java.util.Set;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
import javax.inject.Inject;

public class FirebaseModelDownloader {
Expand All @@ -47,51 +47,33 @@ public class FirebaseModelDownloader {
private final ModelFileDownloadService fileDownloadService;
private final ModelFileManager fileManager;
private final CustomModelDownloadService modelDownloadService;
private final Executor executor;
private final Executor bgExecutor;
private final Executor blockingExecutor;

private final FirebaseMlLogger eventLogger;
private final CustomModel.Factory modelFactory;

@Inject
@RequiresApi(api = VERSION_CODES.KITKAT)
// TODO(b/258424267): Migrate to go/firebase-android-executors
@SuppressLint("ThreadPoolCreation")
FirebaseModelDownloader(
FirebaseOptions firebaseOptions,
SharedPreferencesUtil sharedPreferencesUtil,
ModelFileDownloadService fileDownloadService,
CustomModelDownloadService modelDownloadService,
ModelFileManager fileManager,
FirebaseMlLogger eventLogger,
CustomModel.Factory modelFactory) {
this(
firebaseOptions,
sharedPreferencesUtil,
fileDownloadService,
modelDownloadService,
fileManager,
eventLogger,
Executors.newSingleThreadExecutor(),
modelFactory);
}

@VisibleForTesting
@RequiresApi(api = VERSION_CODES.KITKAT)
FirebaseModelDownloader(
FirebaseOptions firebaseOptions,
SharedPreferencesUtil sharedPreferencesUtil,
ModelFileDownloadService fileDownloadService,
CustomModelDownloadService modelDownloadService,
ModelFileManager fileManager,
FirebaseMlLogger eventLogger,
Executor executor,
@Background Executor bgExecutor,
@Blocking Executor blockingExecutor,
CustomModel.Factory modelFactory) {
this.firebaseOptions = firebaseOptions;
this.sharedPreferencesUtil = sharedPreferencesUtil;
this.fileDownloadService = fileDownloadService;
this.modelDownloadService = modelDownloadService;
this.fileManager = fileManager;
this.eventLogger = eventLogger;
this.executor = executor;
this.bgExecutor = bgExecutor;
this.blockingExecutor = blockingExecutor;
this.modelFactory = modelFactory;
}

Expand Down Expand Up @@ -227,7 +209,7 @@ private Task<CustomModel> getCompletedLocalCustomModelTask(@NonNull CustomModel

if (downloadInProgressTask != null) {
return downloadInProgressTask.continueWithTask(
executor,
bgExecutor,
downloadTask -> {
if (downloadTask.isSuccessful()) {
return finishModelDownload(model.getName());
Expand All @@ -251,7 +233,7 @@ private Task<CustomModel> getCompletedLocalCustomModelTask(@NonNull CustomModel
// bad model state - delete all existing model details and return exception
return deleteDownloadedModel(model.getName())
.continueWithTask(
executor,
bgExecutor,
deletionTask ->
Tasks.forException(
new FirebaseMlException(
Expand Down Expand Up @@ -284,7 +266,7 @@ private Task<CustomModel> getCustomModelTask(
firebaseOptions.getProjectId(), modelName, modelHash);

return incomingModelDetails.continueWithTask(
executor,
bgExecutor,
incomingModelDetailTask -> {
if (incomingModelDetailTask.isSuccessful()) {
// null means we have the latest model or we failed to connect.
Expand Down Expand Up @@ -368,7 +350,7 @@ && new File(currentModel.getLocalFilePath()).exists()) {
return fileDownloadService
.download(incomingModelDetailTask.getResult(), conditions)
.continueWithTask(
executor,
blockingExecutor,
downloadTask -> {
if (downloadTask.isSuccessful()) {
return finishModelDownload(modelName);
Expand Down Expand Up @@ -401,14 +383,14 @@ private Task<CustomModel> retryExpiredUrlDownload(
firebaseOptions.getProjectId(), modelName);
// no local model - start download.
return retryModelDetails.continueWithTask(
executor,
bgExecutor,
retryModelDetailTask -> {
if (retryModelDetailTask.isSuccessful()) {
// start download
return fileDownloadService
.download(retryModelDetailTask.getResult(), conditions)
.continueWithTask(
executor,
bgExecutor,
retryDownloadTask -> {
if (retryDownloadTask.isSuccessful()) {
return finishModelDownload(modelName);
Expand Down Expand Up @@ -458,7 +440,7 @@ public Task<Set<CustomModel>> listDownloadedModels() {
fileDownloadService.maybeCheckDownloadingComplete();

TaskCompletionSource<Set<CustomModel>> taskCompletionSource = new TaskCompletionSource<>();
executor.execute(
bgExecutor.execute(
() -> taskCompletionSource.setResult(sharedPreferencesUtil.listDownloadedModels()));
return taskCompletionSource.getTask();
}
Expand All @@ -472,7 +454,7 @@ public Task<Set<CustomModel>> listDownloadedModels() {
public Task<Void> deleteDownloadedModel(@NonNull String modelName) {

TaskCompletionSource<Void> taskCompletionSource = new TaskCompletionSource<>();
executor.execute(
bgExecutor.execute(
() -> {
// remove all files associated with this model and then clean up model references.
boolean isSuccessful = deleteModelDetails(modelName);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,17 @@
import androidx.annotation.RequiresApi;
import com.google.android.datatransport.TransportFactory;
import com.google.firebase.FirebaseApp;
import com.google.firebase.annotations.concurrent.Background;
import com.google.firebase.annotations.concurrent.Blocking;
import com.google.firebase.components.Component;
import com.google.firebase.components.ComponentRegistrar;
import com.google.firebase.components.Dependency;
import com.google.firebase.components.Qualified;
import com.google.firebase.installations.FirebaseInstallationsApi;
import com.google.firebase.platforminfo.LibraryVersionComponent;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.Executor;

/**
* Registrar for setting up Firebase ML Model Downloader's dependency injections in Firebase Android
Expand All @@ -41,19 +45,25 @@ public class FirebaseModelDownloaderRegistrar implements ComponentRegistrar {
@NonNull
@RequiresApi(api = VERSION_CODES.KITKAT)
public List<Component<?>> getComponents() {
Qualified<Executor> bgExecutor = Qualified.qualified(Background.class, Executor.class);
Qualified<Executor> blockingExecutor = Qualified.qualified(Blocking.class, Executor.class);
return Arrays.asList(
Component.builder(FirebaseModelDownloader.class)
.name(LIBRARY_NAME)
.add(Dependency.required(Context.class))
.add(Dependency.required(FirebaseApp.class))
.add(Dependency.requiredProvider(FirebaseInstallationsApi.class))
.add(Dependency.requiredProvider(TransportFactory.class))
.add(Dependency.required(bgExecutor))
.add(Dependency.required(blockingExecutor))
.factory(
c ->
DaggerModelDownloaderComponent.builder()
.setApplicationContext(c.get(Context.class))
.setFirebaseApp(c.get(FirebaseApp.class))
.setFis(c.getProvider(FirebaseInstallationsApi.class))
.setBlockingExecutor(c.get(blockingExecutor))
.setBgExecutor(c.get(bgExecutor))
.setTransportFactory(c.getProvider(TransportFactory.class))
.build()
.getModelDownloader())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,15 @@
import com.google.android.datatransport.TransportFactory;
import com.google.firebase.FirebaseApp;
import com.google.firebase.FirebaseOptions;
import com.google.firebase.annotations.concurrent.Background;
import com.google.firebase.annotations.concurrent.Blocking;
import com.google.firebase.inject.Provider;
import com.google.firebase.installations.FirebaseInstallationsApi;
import dagger.BindsInstance;
import dagger.Component;
import dagger.Module;
import dagger.Provides;
import java.util.concurrent.Executor;
import javax.inject.Named;
import javax.inject.Singleton;

Expand All @@ -49,11 +52,18 @@ interface Builder {
@BindsInstance
Builder setTransportFactory(Provider<TransportFactory> transportFactory);

@BindsInstance
Builder setBlockingExecutor(@Blocking Executor blockingExecutor);

@BindsInstance
Builder setBgExecutor(@Background Executor bgExecutor);

ModelDownloaderComponent build();
}

@Module
interface MainModule {

@Provides
@Named("persistenceKey")
static String persistenceKey(FirebaseApp app) {
Expand Down
Loading