diff --git a/firebase-ml-modeldownloader/firebase-ml-modeldownloader.gradle b/firebase-ml-modeldownloader/firebase-ml-modeldownloader.gradle index 7b88ae75e91..0c9ef7a4374 100644 --- a/firebase-ml-modeldownloader/firebase-ml-modeldownloader.gradle +++ b/firebase-ml-modeldownloader/firebase-ml-modeldownloader.gradle @@ -64,6 +64,7 @@ dependencies { implementation project(':encoders:firebase-encoders-json') implementation project(':firebase-common') implementation project(':firebase-components') + implementation project(':firebase-annotations') implementation project(':firebase-datatransport') implementation project(':firebase-installations-interop') implementation project(':transport:transport-api') @@ -87,6 +88,7 @@ dependencies { annotationProcessor "com.google.auto.value:auto-value:1.6.5" annotationProcessor project(":encoders:firebase-encoders-processor") + testImplementation(project(":integ-testing")) testImplementation 'androidx.test:core:1.3.0' testImplementation 'com.github.tomakehurst:wiremock-standalone:2.26.3' testImplementation "com.google.truth:truth:$googleTruthVersion" diff --git a/firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/FirebaseModelDownloader.java b/firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/FirebaseModelDownloader.java index 5747560bb50..3309902f98f 100644 --- a/firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/FirebaseModelDownloader.java +++ b/firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/FirebaseModelDownloader.java @@ -13,7 +13,6 @@ // limitations under the License. package com.google.firebase.ml.modeldownloader; -import android.annotation.SuppressLint; import android.os.Build.VERSION_CODES; import android.util.Log; import androidx.annotation.NonNull; @@ -26,6 +25,8 @@ import com.google.android.gms.tasks.Tasks; import com.google.firebase.FirebaseApp; import com.google.firebase.FirebaseOptions; +import com.google.firebase.annotations.concurrent.Background; +import com.google.firebase.annotations.concurrent.Blocking; import com.google.firebase.ml.modeldownloader.internal.CustomModelDownloadService; import com.google.firebase.ml.modeldownloader.internal.FirebaseMlLogEvent.ModelDownloadLogEvent.DownloadStatus; import com.google.firebase.ml.modeldownloader.internal.FirebaseMlLogEvent.ModelDownloadLogEvent.ErrorCode; @@ -36,7 +37,6 @@ import java.io.File; import java.util.Set; import java.util.concurrent.Executor; -import java.util.concurrent.Executors; import javax.inject.Inject; public class FirebaseModelDownloader { @@ -47,35 +47,15 @@ public class FirebaseModelDownloader { private final ModelFileDownloadService fileDownloadService; private final ModelFileManager fileManager; private final CustomModelDownloadService modelDownloadService; - private final Executor executor; + private final Executor bgExecutor; + private final Executor blockingExecutor; private final FirebaseMlLogger eventLogger; private final CustomModel.Factory modelFactory; @Inject - @RequiresApi(api = VERSION_CODES.KITKAT) - // TODO(b/258424267): Migrate to go/firebase-android-executors - @SuppressLint("ThreadPoolCreation") - FirebaseModelDownloader( - FirebaseOptions firebaseOptions, - SharedPreferencesUtil sharedPreferencesUtil, - ModelFileDownloadService fileDownloadService, - CustomModelDownloadService modelDownloadService, - ModelFileManager fileManager, - FirebaseMlLogger eventLogger, - CustomModel.Factory modelFactory) { - this( - firebaseOptions, - sharedPreferencesUtil, - fileDownloadService, - modelDownloadService, - fileManager, - eventLogger, - Executors.newSingleThreadExecutor(), - modelFactory); - } - @VisibleForTesting + @RequiresApi(api = VERSION_CODES.KITKAT) FirebaseModelDownloader( FirebaseOptions firebaseOptions, SharedPreferencesUtil sharedPreferencesUtil, @@ -83,7 +63,8 @@ public class FirebaseModelDownloader { CustomModelDownloadService modelDownloadService, ModelFileManager fileManager, FirebaseMlLogger eventLogger, - Executor executor, + @Background Executor bgExecutor, + @Blocking Executor blockingExecutor, CustomModel.Factory modelFactory) { this.firebaseOptions = firebaseOptions; this.sharedPreferencesUtil = sharedPreferencesUtil; @@ -91,7 +72,8 @@ public class FirebaseModelDownloader { this.modelDownloadService = modelDownloadService; this.fileManager = fileManager; this.eventLogger = eventLogger; - this.executor = executor; + this.bgExecutor = bgExecutor; + this.blockingExecutor = blockingExecutor; this.modelFactory = modelFactory; } @@ -227,7 +209,7 @@ private Task getCompletedLocalCustomModelTask(@NonNull CustomModel if (downloadInProgressTask != null) { return downloadInProgressTask.continueWithTask( - executor, + bgExecutor, downloadTask -> { if (downloadTask.isSuccessful()) { return finishModelDownload(model.getName()); @@ -251,7 +233,7 @@ private Task getCompletedLocalCustomModelTask(@NonNull CustomModel // bad model state - delete all existing model details and return exception return deleteDownloadedModel(model.getName()) .continueWithTask( - executor, + bgExecutor, deletionTask -> Tasks.forException( new FirebaseMlException( @@ -284,7 +266,7 @@ private Task getCustomModelTask( firebaseOptions.getProjectId(), modelName, modelHash); return incomingModelDetails.continueWithTask( - executor, + bgExecutor, incomingModelDetailTask -> { if (incomingModelDetailTask.isSuccessful()) { // null means we have the latest model or we failed to connect. @@ -368,7 +350,7 @@ && new File(currentModel.getLocalFilePath()).exists()) { return fileDownloadService .download(incomingModelDetailTask.getResult(), conditions) .continueWithTask( - executor, + blockingExecutor, downloadTask -> { if (downloadTask.isSuccessful()) { return finishModelDownload(modelName); @@ -401,14 +383,14 @@ private Task retryExpiredUrlDownload( firebaseOptions.getProjectId(), modelName); // no local model - start download. return retryModelDetails.continueWithTask( - executor, + bgExecutor, retryModelDetailTask -> { if (retryModelDetailTask.isSuccessful()) { // start download return fileDownloadService .download(retryModelDetailTask.getResult(), conditions) .continueWithTask( - executor, + bgExecutor, retryDownloadTask -> { if (retryDownloadTask.isSuccessful()) { return finishModelDownload(modelName); @@ -458,7 +440,7 @@ public Task> listDownloadedModels() { fileDownloadService.maybeCheckDownloadingComplete(); TaskCompletionSource> taskCompletionSource = new TaskCompletionSource<>(); - executor.execute( + bgExecutor.execute( () -> taskCompletionSource.setResult(sharedPreferencesUtil.listDownloadedModels())); return taskCompletionSource.getTask(); } @@ -472,7 +454,7 @@ public Task> listDownloadedModels() { public Task deleteDownloadedModel(@NonNull String modelName) { TaskCompletionSource taskCompletionSource = new TaskCompletionSource<>(); - executor.execute( + bgExecutor.execute( () -> { // remove all files associated with this model and then clean up model references. boolean isSuccessful = deleteModelDetails(modelName); diff --git a/firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/FirebaseModelDownloaderRegistrar.java b/firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/FirebaseModelDownloaderRegistrar.java index dcb85a2a864..5f46ec18536 100644 --- a/firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/FirebaseModelDownloaderRegistrar.java +++ b/firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/FirebaseModelDownloaderRegistrar.java @@ -20,13 +20,17 @@ import androidx.annotation.RequiresApi; import com.google.android.datatransport.TransportFactory; import com.google.firebase.FirebaseApp; +import com.google.firebase.annotations.concurrent.Background; +import com.google.firebase.annotations.concurrent.Blocking; import com.google.firebase.components.Component; import com.google.firebase.components.ComponentRegistrar; import com.google.firebase.components.Dependency; +import com.google.firebase.components.Qualified; import com.google.firebase.installations.FirebaseInstallationsApi; import com.google.firebase.platforminfo.LibraryVersionComponent; import java.util.Arrays; import java.util.List; +import java.util.concurrent.Executor; /** * Registrar for setting up Firebase ML Model Downloader's dependency injections in Firebase Android @@ -41,6 +45,8 @@ public class FirebaseModelDownloaderRegistrar implements ComponentRegistrar { @NonNull @RequiresApi(api = VERSION_CODES.KITKAT) public List> getComponents() { + Qualified bgExecutor = Qualified.qualified(Background.class, Executor.class); + Qualified blockingExecutor = Qualified.qualified(Blocking.class, Executor.class); return Arrays.asList( Component.builder(FirebaseModelDownloader.class) .name(LIBRARY_NAME) @@ -48,12 +54,16 @@ public List> getComponents() { .add(Dependency.required(FirebaseApp.class)) .add(Dependency.requiredProvider(FirebaseInstallationsApi.class)) .add(Dependency.requiredProvider(TransportFactory.class)) + .add(Dependency.required(bgExecutor)) + .add(Dependency.required(blockingExecutor)) .factory( c -> DaggerModelDownloaderComponent.builder() .setApplicationContext(c.get(Context.class)) .setFirebaseApp(c.get(FirebaseApp.class)) .setFis(c.getProvider(FirebaseInstallationsApi.class)) + .setBlockingExecutor(c.get(blockingExecutor)) + .setBgExecutor(c.get(bgExecutor)) .setTransportFactory(c.getProvider(TransportFactory.class)) .build() .getModelDownloader()) diff --git a/firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/ModelDownloaderComponent.java b/firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/ModelDownloaderComponent.java index f69036c4a0a..bbf3686a38d 100644 --- a/firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/ModelDownloaderComponent.java +++ b/firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/ModelDownloaderComponent.java @@ -20,12 +20,15 @@ import com.google.android.datatransport.TransportFactory; import com.google.firebase.FirebaseApp; import com.google.firebase.FirebaseOptions; +import com.google.firebase.annotations.concurrent.Background; +import com.google.firebase.annotations.concurrent.Blocking; import com.google.firebase.inject.Provider; import com.google.firebase.installations.FirebaseInstallationsApi; import dagger.BindsInstance; import dagger.Component; import dagger.Module; import dagger.Provides; +import java.util.concurrent.Executor; import javax.inject.Named; import javax.inject.Singleton; @@ -49,11 +52,18 @@ interface Builder { @BindsInstance Builder setTransportFactory(Provider transportFactory); + @BindsInstance + Builder setBlockingExecutor(@Blocking Executor blockingExecutor); + + @BindsInstance + Builder setBgExecutor(@Background Executor bgExecutor); + ModelDownloaderComponent build(); } @Module interface MainModule { + @Provides @Named("persistenceKey") static String persistenceKey(FirebaseApp app) { diff --git a/firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/internal/CustomModelDownloadService.java b/firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/internal/CustomModelDownloadService.java index edab83e1847..3b9f1ec8752 100644 --- a/firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/internal/CustomModelDownloadService.java +++ b/firebase-ml-modeldownloader/src/main/java/com/google/firebase/ml/modeldownloader/internal/CustomModelDownloadService.java @@ -14,7 +14,6 @@ package com.google.firebase.ml.modeldownloader.internal; -import android.annotation.SuppressLint; import android.content.Context; import android.content.pm.PackageManager; import android.text.TextUtils; @@ -28,6 +27,7 @@ import com.google.android.gms.tasks.Task; import com.google.android.gms.tasks.Tasks; import com.google.firebase.FirebaseOptions; +import com.google.firebase.annotations.concurrent.Blocking; import com.google.firebase.inject.Provider; import com.google.firebase.installations.FirebaseInstallationsApi; import com.google.firebase.installations.InstallationTokenResult; @@ -48,8 +48,7 @@ import java.util.Date; import java.util.Locale; import java.util.TimeZone; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Executors; +import java.util.concurrent.Executor; import java.util.zip.GZIPInputStream; import javax.inject.Inject; import org.json.JSONObject; @@ -88,7 +87,6 @@ public class CustomModelDownloadService { @VisibleForTesting static final String DOWNLOAD_MODEL_REGEX = "%s/v1beta2/projects/%s/models/%s:download"; - private final ExecutorService executorService; private final Provider firebaseInstallations; private final FirebaseMlLogger eventLogger; private final String apiKey; @@ -96,21 +94,21 @@ public class CustomModelDownloadService { private final Context context; private final CustomModel.Factory modelFactory; private String downloadHost = FIREBASE_DOWNLOAD_HOST; + private final Executor blockingExecutor; - // TODO(b/258424267): Migrate to go/firebase-android-executors - @SuppressLint("ThreadPoolCreation") @Inject public CustomModelDownloadService( Context context, FirebaseOptions options, Provider installationsApi, FirebaseMlLogger eventLogger, - CustomModel.Factory modelFactory) { + CustomModel.Factory modelFactory, + @Blocking Executor blockingExecutor) { this.context = context; firebaseInstallations = installationsApi; apiKey = options.getApiKey(); fingerprintHashForPackage = getFingerprintHashForPackage(context); - executorService = Executors.newCachedThreadPool(); + this.blockingExecutor = blockingExecutor; this.eventLogger = eventLogger; this.modelFactory = modelFactory; } @@ -119,7 +117,7 @@ public CustomModelDownloadService( CustomModelDownloadService( Context context, Provider firebaseInstallations, - ExecutorService executorService, + Executor blockingExecutor, String apiKey, String fingerprintHashForPackage, String downloadHost, @@ -127,7 +125,7 @@ public CustomModelDownloadService( CustomModel.Factory modelFactory) { this.context = context; this.firebaseInstallations = firebaseInstallations; - this.executorService = executorService; + this.blockingExecutor = blockingExecutor; this.apiKey = apiKey; this.fingerprintHashForPackage = fingerprintHashForPackage; this.downloadHost = downloadHost; @@ -169,64 +167,67 @@ public Task getCustomModelDetails( "Error cannot retrieve model from reading an empty modelName", FirebaseMlException.INVALID_ARGUMENT); - URL url = - new URL(String.format(DOWNLOAD_MODEL_REGEX, downloadHost, projectNumber, modelName)); - HttpURLConnection connection = (HttpURLConnection) url.openConnection(); - connection.setConnectTimeout(CONNECTION_TIME_OUT_MS); - connection.setRequestProperty(ACCEPT_ENCODING_HEADER_KEY, GZIP_CONTENT_ENCODING); - connection.setRequestProperty(CONTENT_TYPE, APPLICATION_JSON); - if (modelHash != null && !modelHash.isEmpty()) { - connection.setRequestProperty(IF_NONE_MATCH_HEADER_KEY, modelHash); - } - Task installationAuthTokenTask = firebaseInstallations.get().getToken(false); return installationAuthTokenTask.continueWithTask( - executorService, + blockingExecutor, (CustomModelTask) -> { - if (!installationAuthTokenTask.isSuccessful()) { - ErrorCode errorCode = ErrorCode.MODEL_INFO_DOWNLOAD_CONNECTION_FAILED; - String errorMessage = "Failed to get model due to authentication error"; - int exceptionCode = FirebaseMlException.UNAUTHENTICATED; - if (installationAuthTokenTask.getException() != null - && (installationAuthTokenTask.getException() instanceof UnknownHostException - || installationAuthTokenTask.getException().getCause() - instanceof UnknownHostException)) { - errorCode = ErrorCode.NO_NETWORK_CONNECTION; - errorMessage = "Failed to retrieve model info due to no internet connection."; - exceptionCode = FirebaseMlException.NO_NETWORK_CONNECTION; + try { + URL url = + new URL( + String.format(DOWNLOAD_MODEL_REGEX, downloadHost, projectNumber, modelName)); + HttpURLConnection connection = (HttpURLConnection) url.openConnection(); + connection.setConnectTimeout(CONNECTION_TIME_OUT_MS); + connection.setRequestProperty(ACCEPT_ENCODING_HEADER_KEY, GZIP_CONTENT_ENCODING); + connection.setRequestProperty(CONTENT_TYPE, APPLICATION_JSON); + if (modelHash != null && !modelHash.isEmpty()) { + connection.setRequestProperty(IF_NONE_MATCH_HEADER_KEY, modelHash); + } + if (!installationAuthTokenTask.isSuccessful()) { + ErrorCode errorCode = ErrorCode.MODEL_INFO_DOWNLOAD_CONNECTION_FAILED; + String errorMessage = "Failed to get model due to authentication error"; + int exceptionCode = FirebaseMlException.UNAUTHENTICATED; + if (installationAuthTokenTask.getException() != null + && (installationAuthTokenTask.getException() instanceof UnknownHostException + || installationAuthTokenTask.getException().getCause() + instanceof UnknownHostException)) { + errorCode = ErrorCode.NO_NETWORK_CONNECTION; + errorMessage = "Failed to retrieve model info due to no internet connection."; + exceptionCode = FirebaseMlException.NO_NETWORK_CONNECTION; + } + eventLogger.logDownloadFailureWithReason( + modelFactory.create(modelName, modelHash != null ? modelHash : "", 0, 0L), + false, + errorCode.getValue()); + return Tasks.forException(new FirebaseMlException(errorMessage, exceptionCode)); } - eventLogger.logDownloadFailureWithReason( - modelFactory.create(modelName, modelHash != null ? modelHash : "", 0, 0L), - false, - errorCode.getValue()); - return Tasks.forException(new FirebaseMlException(errorMessage, exceptionCode)); - } - connection.setRequestProperty( - INSTALLATIONS_AUTH_TOKEN_HEADER, installationAuthTokenTask.getResult().getToken()); - connection.setRequestProperty(API_KEY_HEADER, apiKey); + connection.setRequestProperty( + INSTALLATIONS_AUTH_TOKEN_HEADER, + installationAuthTokenTask.getResult().getToken()); + connection.setRequestProperty(API_KEY_HEADER, apiKey); - // Headers required for Android API Key Restrictions. - connection.setRequestProperty(X_ANDROID_PACKAGE_HEADER, context.getPackageName()); + // Headers required for Android API Key Restrictions. + connection.setRequestProperty(X_ANDROID_PACKAGE_HEADER, context.getPackageName()); - if (fingerprintHashForPackage != null) { - connection.setRequestProperty(X_ANDROID_CERT_HEADER, fingerprintHashForPackage); - } + if (fingerprintHashForPackage != null) { + connection.setRequestProperty(X_ANDROID_CERT_HEADER, fingerprintHashForPackage); + } - return fetchDownloadDetails(modelName, connection); - }); + return fetchDownloadDetails(modelName, connection); + } catch (IOException e) { + eventLogger.logDownloadFailureWithReason( + modelFactory.create(modelName, modelHash, 0, 0L), + false, + ErrorCode.MODEL_INFO_DOWNLOAD_CONNECTION_FAILED.getValue()); - } catch (IOException e) { - eventLogger.logDownloadFailureWithReason( - modelFactory.create(modelName, modelHash, 0, 0L), - false, - ErrorCode.MODEL_INFO_DOWNLOAD_CONNECTION_FAILED.getValue()); + return Tasks.forException( + new FirebaseMlException( + "Error reading custom model from download service: " + e.getMessage(), + FirebaseMlException.INVALID_ARGUMENT)); + } + }); - return Tasks.forException( - new FirebaseMlException( - "Error reading custom model from download service: " + e.getMessage(), - FirebaseMlException.INVALID_ARGUMENT)); } catch (FirebaseMlException e) { return Tasks.forException(e); } diff --git a/firebase-ml-modeldownloader/src/test/java/com/google/firebase/ml/modeldownloader/FirebaseModelDownloaderTest.java b/firebase-ml-modeldownloader/src/test/java/com/google/firebase/ml/modeldownloader/FirebaseModelDownloaderTest.java index afa6d59920d..00b26ca493a 100644 --- a/firebase-ml-modeldownloader/src/test/java/com/google/firebase/ml/modeldownloader/FirebaseModelDownloaderTest.java +++ b/firebase-ml-modeldownloader/src/test/java/com/google/firebase/ml/modeldownloader/FirebaseModelDownloaderTest.java @@ -37,6 +37,7 @@ import com.google.firebase.FirebaseApp; import com.google.firebase.FirebaseOptions; import com.google.firebase.FirebaseOptions.Builder; +import com.google.firebase.concurrent.TestOnlyExecutors; import com.google.firebase.ml.modeldownloader.internal.CustomModelDownloadService; import com.google.firebase.ml.modeldownloader.internal.FirebaseMlLogEvent.ModelDownloadLogEvent.DownloadStatus; import com.google.firebase.ml.modeldownloader.internal.FirebaseMlLogEvent.ModelDownloadLogEvent.ErrorCode; @@ -137,7 +138,8 @@ public void setUp() throws Exception { mockModelDownloadService, mockFileManager, mockEventLogger, - executor, + TestOnlyExecutors.background(), + TestOnlyExecutors.blocking(), modelFactory); setUpTestingFiles(app);