Skip to content

Migrate ml to dagger DI. #4370

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Nov 28, 2022
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions firebase-ml-modeldownloader/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Unreleased
- [changed] Internal infrastructure improvements.

# 24.1.1
* [fixed] Fixed an issue where `FirebaseModelDownloader.getModel` was throwing
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

plugins {
id 'firebase-library'
id 'firebase-vendor'
id 'com.google.protobuf'
}

Expand Down Expand Up @@ -76,6 +77,12 @@ dependencies {
implementation 'com.google.auto.service:auto-service-annotations:1.0-rc6'
implementation 'javax.inject:javax.inject:1'

implementation 'javax.inject:javax.inject:1'
vendor ('com.google.dagger:dagger:2.43.2') {
exclude group: "javax.inject", module: "javax.inject"
}
annotationProcessor 'com.google.dagger:dagger-compiler:2.43.2'

compileOnly "com.google.auto.value:auto-value-annotations:1.6.6"
annotationProcessor "com.google.auto.value:auto-value:1.6.5"
annotationProcessor project(":encoders:firebase-encoders-processor")
Expand Down
5 changes: 4 additions & 1 deletion firebase-ml-modeldownloader/ktx/ktx.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ dependencies {
implementation project(':firebase-ml-modeldownloader')

testImplementation "org.robolectric:robolectric:$robolectricVersion"
testImplementation 'junit:junit:4.12'
testImplementation 'junit:junit:4.13.2'
testImplementation "com.google.truth:truth:$googleTruthVersion"
testImplementation 'org.mockito:mockito-core:3.6.0'
testImplementation 'androidx.test:runner:1.5.1'
testImplementation 'androidx.test.ext:junit:1.1.4'
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,21 @@

package com.google.firebase.ml.modeldownloader.ktx

import androidx.test.core.app.ApplicationProvider
import androidx.test.ext.junit.runners.AndroidJUnit4
import com.google.common.truth.Truth.assertThat
import com.google.firebase.FirebaseApp
import com.google.firebase.FirebaseOptions
import com.google.firebase.ktx.Firebase
import com.google.firebase.ktx.app
import com.google.firebase.ktx.initialize
import com.google.firebase.ml.modeldownloader.CustomModel
import com.google.firebase.ml.modeldownloader.FirebaseModelDownloader
import com.google.firebase.platforminfo.UserAgentPublisher
import org.junit.After
import org.junit.Before
import org.junit.Test
import org.junit.runner.RunWith
import org.robolectric.RobolectricTestRunner
import org.robolectric.RuntimeEnvironment

const val APP_ID = "1:1234567890:android:321abc456def7890"
const val API_KEY = "AIzaSyDOCAbC123dEf456GhI789jKl012-MnO"
Expand All @@ -39,15 +39,15 @@ abstract class BaseTestCase {
@Before
fun setUp() {
Firebase.initialize(
RuntimeEnvironment.application,
ApplicationProvider.getApplicationContext(),
FirebaseOptions.Builder()
.setApplicationId(APP_ID)
.setApiKey(API_KEY)
.setProjectId("123")
.build()
)
Firebase.initialize(
RuntimeEnvironment.application,
ApplicationProvider.getApplicationContext(),
FirebaseOptions.Builder()
.setApplicationId(APP_ID)
.setApiKey(API_KEY)
Expand All @@ -63,7 +63,7 @@ abstract class BaseTestCase {
}
}

@RunWith(RobolectricTestRunner::class)
@RunWith(AndroidJUnit4::class)
class ModelDownloaderTests : BaseTestCase() {

@Test
Expand Down Expand Up @@ -92,14 +92,17 @@ class ModelDownloaderTests : BaseTestCase() {

@Test
fun `CustomModel destructuring declarations work`() {
val app = Firebase.app(EXISTING_APP)

val modelName = "myModel"
val modelHash = "someHash"
val fileSize = 200L
val downloadId = 258L

val customModel = CustomModel(modelName, modelHash, fileSize, downloadId)
val customModel =
Firebase.modelDownloader(app).modelFactory.create(modelName, modelHash, fileSize, downloadId)

val (file, size, id, hash, name) = customModel
val (_, size, id, hash, name) = customModel

assertThat(name).isEqualTo(customModel.name)
assertThat(hash).isEqualTo(customModel.modelHash)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,10 @@

package com.google.firebase.ml_data_collection_tests;

import android.content.Context;
import android.content.SharedPreferences;
import androidx.test.core.app.ApplicationProvider;
import com.google.firebase.FirebaseApp;
import com.google.firebase.FirebaseOptions;
import com.google.firebase.ml.modeldownloader.FirebaseModelDownloader;
import com.google.firebase.ml.modeldownloader.internal.SharedPreferencesUtil;
import java.util.function.Consumer;

Expand Down Expand Up @@ -47,12 +46,8 @@ static void withApp(String name, Consumer<FirebaseApp> callable) {
}

static SharedPreferencesUtil getSharedPreferencesUtil(FirebaseApp app) {
return new SharedPreferencesUtil(app);
}

static SharedPreferences getSharedPreferences(FirebaseApp app) {
return app.getApplicationContext()
.getSharedPreferences(SharedPreferencesUtil.PREFERENCES_PACKAGE_NAME, Context.MODE_PRIVATE);
return new SharedPreferencesUtil(
app, FirebaseModelDownloader.getInstance(app).getModelFactory());
}

static void setSharedPreferencesTo(FirebaseApp app, Boolean enabled) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,20 +45,25 @@ public class TestGetModelLocal {
private static final String MODEL_NAME_LOCAL = "getLocalModel";
private static final String MODEL_NAME_LOCAL_2 = "getLocalModel2";
private static final String MODEL_HASH = "origHash324";
private final CustomModel SETUP_LOADED_LOCAL_MODEL =
new CustomModel(MODEL_NAME_LOCAL, MODEL_HASH, 100, 0);

private FirebaseApp app;
private File firstDeviceModelFile;
private File firstLoadTempModelFile;

private CustomModel.Factory modelFactory;
private CustomModel setupLoadedLocalModel;

@Before
public void before() {
app = FirebaseApp.initializeApp(ApplicationProvider.getApplicationContext());
app.setDataCollectionDefaultEnabled(Boolean.FALSE);
FirebaseModelDownloader firebaseModelDownloader = FirebaseModelDownloader.getInstance(app);

SharedPreferencesUtil sharedPreferencesUtil = new SharedPreferencesUtil(app);
modelFactory = firebaseModelDownloader.getModelFactory();

setupLoadedLocalModel = modelFactory.create(MODEL_NAME_LOCAL, MODEL_HASH, 100, 0);

SharedPreferencesUtil sharedPreferencesUtil = new SharedPreferencesUtil(app, modelFactory);
// reset shared preferences and downloads for models used by this test.
firebaseModelDownloader.deleteDownloadedModel(MODEL_NAME_LOCAL);
firebaseModelDownloader.deleteDownloadedModel(MODEL_NAME_LOCAL_2);
Expand All @@ -79,7 +84,11 @@ public void teardown() {
}

private void setUpLoadedLocalModelWithFile() throws Exception {
ModelFileManager fileManager = ModelFileManager.getInstance();
ModelFileManager fileManager =
new ModelFileManager(
app.getApplicationContext(),
app.getPersistenceKey(),
new SharedPreferencesUtil(app, modelFactory));
final File testDir = new File(app.getApplicationContext().getNoBackupFilesDir(), "tmpModels");
testDir.mkdirs();
// make sure the directory is empty. Doesn't recurse into subdirs, but that's OK since
Expand All @@ -105,14 +114,14 @@ private void setUpLoadedLocalModelWithFile() throws Exception {
ParcelFileDescriptor fd =
ParcelFileDescriptor.open(firstLoadTempModelFile, ParcelFileDescriptor.MODE_READ_ONLY);

firstDeviceModelFile = fileManager.moveModelToDestinationFolder(SETUP_LOADED_LOCAL_MODEL, fd);
firstDeviceModelFile = fileManager.moveModelToDestinationFolder(setupLoadedLocalModel, fd);
assertEquals(firstDeviceModelFile, new File(expectedDestinationFolder + "/0"));
assertTrue(firstDeviceModelFile.exists());
fd.close();

fakePreloadedCustomModel(
MODEL_NAME_LOCAL,
SETUP_LOADED_LOCAL_MODEL.getModelHash(),
setupLoadedLocalModel.getModelHash(),
99,
expectedDestinationFolder + "/0");
}
Expand Down Expand Up @@ -228,9 +237,9 @@ public void localModel_preloadedDoNotFetchUpdate() throws Exception {
}

private void fakePreloadedCustomModel(String modelName, String hash, long size, String filePath) {
SharedPreferencesUtil sharedPreferencesUtil = new SharedPreferencesUtil(app);
SharedPreferencesUtil sharedPreferencesUtil = new SharedPreferencesUtil(app, modelFactory);
sharedPreferencesUtil.setLoadedCustomModelDetails(
new CustomModel(modelName, hash, size, 0L, filePath));
modelFactory.create(modelName, hash, size, 0L, filePath));
}

private Set<CustomModel> getDownloadedModelList()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ public void before() {
app.setDataCollectionDefaultEnabled(Boolean.FALSE);
FirebaseModelDownloader firebaseModelDownloader = FirebaseModelDownloader.getInstance(app);

SharedPreferencesUtil sharedPreferencesUtil = new SharedPreferencesUtil(app);
SharedPreferencesUtil sharedPreferencesUtil =
new SharedPreferencesUtil(app, firebaseModelDownloader.getModelFactory());
// reset shared preferences and downloads for models used by this test.
firebaseModelDownloader.deleteDownloadedModel(MODEL_NAME_LOCAL);
firebaseModelDownloader.deleteDownloadedModel(MODEL_NAME_LOCAL_2);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,13 @@

import androidx.annotation.NonNull;
import androidx.annotation.Nullable;
import androidx.annotation.RestrictTo;
import androidx.annotation.VisibleForTesting;
import com.google.android.gms.common.internal.Objects;
import com.google.firebase.ml.modeldownloader.internal.ModelFileDownloadService;
import dagger.assisted.Assisted;
import dagger.assisted.AssistedFactory;
import dagger.assisted.AssistedInject;
import java.io.File;

/**
Expand All @@ -27,6 +31,7 @@
* downloaded, the original model file will be removed once it is safe to do so.
*/
public class CustomModel {
private final ModelFileDownloadService fileDownloadService;
private final String name;
private final long downloadId;
private final long fileSize;
Expand All @@ -35,57 +40,31 @@ public class CustomModel {
private final String downloadUrl;
private final long downloadUrlExpiry;

/**
* Use when creating a custom model while the initial download is still in progress.
*
* @param name Model name.
* @param modelHash Model hash.
* @param fileSize Model file size.
* @param downloadId Android Download Manger - download ID.
* @hide
*/
public CustomModel(
@NonNull String name, @NonNull String modelHash, long fileSize, long downloadId) {
this(name, modelHash, fileSize, downloadId, "", "", 0);
}
/** @hide */
@AssistedFactory
public interface Factory {
CustomModel create(
@Assisted("name") String name,
@Assisted("modelHash") String modelHash,
@Assisted("fileSize") long fileSize,
@Assisted("downloadId") long downloadId,
@Assisted("localFilePath") String localFilePath,
@Assisted("downloadUrl") String downloadUrl,
@Assisted("downloadUrlExpiry") long downloadUrlExpiry);

/**
* Use when creating a custom model from a stored model with a new download in the background.
*
* @param name Model name.
* @param modelHash Model hash.
* @param fileSize Model file size.
* @param downloadId Android Download Manger - download ID.
* @hide
*/
public CustomModel(
@NonNull String name,
@NonNull String modelHash,
long fileSize,
long downloadId,
String localFilePath) {
this(name, modelHash, fileSize, downloadId, localFilePath, "", 0);
}
default CustomModel create(String name, String modelHash, long fileSize, long downloadId) {
return create(name, modelHash, fileSize, downloadId, "", "", 0);
}

/**
* Use when creating a custom model from a download service response. Download URL and download
* URL expiry should go together. These will not be stored in user preferences as this is a
* temporary step towards setting the actual download ID.
*
* @param name Model name.
* @param modelHash Model hash.
* @param fileSize Model file size.
* @param downloadUrl Download URL path
* @param downloadUrlExpiry Time download URL path expires.
* @hide
*/
public CustomModel(
@NonNull String name,
@NonNull String modelHash,
long fileSize,
String downloadUrl,
long downloadUrlExpiry) {
this(name, modelHash, fileSize, 0, "", downloadUrl, downloadUrlExpiry);
default CustomModel create(
String name, String modelHash, long fileSize, long downloadId, String localFilePath) {
return create(name, modelHash, fileSize, downloadId, localFilePath, "", 0);
}

default CustomModel create(
String name, String modelHash, long fileSize, String downloadUrl, long downloadUrlExpiry) {
return create(name, modelHash, fileSize, 0, "", downloadUrl, downloadUrlExpiry);
}
}

/**
Expand All @@ -100,14 +79,19 @@ public CustomModel(
* @param downloadUrlExpiry Expiry time of download URL link.
* @hide
*/
private CustomModel(
@NonNull String name,
@NonNull String modelHash,
long fileSize,
long downloadId,
@Nullable String localFilePath,
@Nullable String downloadUrl,
long downloadUrlExpiry) {
@AssistedInject
@VisibleForTesting
@RestrictTo(RestrictTo.Scope.LIBRARY)
public CustomModel(
ModelFileDownloadService fileDownloadService,
@Assisted("name") String name,
@Assisted("modelHash") String modelHash,
@Assisted("fileSize") long fileSize,
@Assisted("downloadId") long downloadId,
@Assisted("localFilePath") String localFilePath,
@Assisted("downloadUrl") String downloadUrl,
@Assisted("downloadUrlExpiry") long downloadUrlExpiry) {
this.fileDownloadService = fileDownloadService;
this.modelHash = modelHash;
this.name = name;
this.fileSize = fileSize;
Expand Down Expand Up @@ -137,7 +121,7 @@ public String getName() {
*/
@Nullable
public File getFile() {
return getFile(ModelFileDownloadService.getInstance());
return getFile(fileDownloadService);
}

/**
Expand Down
Loading