13
13
// limitations under the License.
14
14
package com .google .firebase .ml .modeldownloader ;
15
15
16
- import android .annotation .SuppressLint ;
17
16
import android .os .Build .VERSION_CODES ;
18
17
import android .util .Log ;
19
18
import androidx .annotation .NonNull ;
26
25
import com .google .android .gms .tasks .Tasks ;
27
26
import com .google .firebase .FirebaseApp ;
28
27
import com .google .firebase .FirebaseOptions ;
28
+ import com .google .firebase .annotations .concurrent .Background ;
29
+ import com .google .firebase .annotations .concurrent .Blocking ;
29
30
import com .google .firebase .ml .modeldownloader .internal .CustomModelDownloadService ;
30
31
import com .google .firebase .ml .modeldownloader .internal .FirebaseMlLogEvent .ModelDownloadLogEvent .DownloadStatus ;
31
32
import com .google .firebase .ml .modeldownloader .internal .FirebaseMlLogEvent .ModelDownloadLogEvent .ErrorCode ;
36
37
import java .io .File ;
37
38
import java .util .Set ;
38
39
import java .util .concurrent .Executor ;
39
- import java .util .concurrent .Executors ;
40
40
import javax .inject .Inject ;
41
41
42
42
public class FirebaseModelDownloader {
@@ -47,34 +47,14 @@ public class FirebaseModelDownloader {
47
47
private final ModelFileDownloadService fileDownloadService ;
48
48
private final ModelFileManager fileManager ;
49
49
private final CustomModelDownloadService modelDownloadService ;
50
- private final Executor executor ;
50
+ private final Executor bgExecutor ;
51
+ private final Executor blockingExecutor ;
51
52
52
53
private final FirebaseMlLogger eventLogger ;
53
54
private final CustomModel .Factory modelFactory ;
54
55
55
56
@ Inject
56
57
@ RequiresApi (api = VERSION_CODES .KITKAT )
57
- // TODO(b/258424267): Migrate to go/firebase-android-executors
58
- @ SuppressLint ("ThreadPoolCreation" )
59
- FirebaseModelDownloader (
60
- FirebaseOptions firebaseOptions ,
61
- SharedPreferencesUtil sharedPreferencesUtil ,
62
- ModelFileDownloadService fileDownloadService ,
63
- CustomModelDownloadService modelDownloadService ,
64
- ModelFileManager fileManager ,
65
- FirebaseMlLogger eventLogger ,
66
- CustomModel .Factory modelFactory ) {
67
- this (
68
- firebaseOptions ,
69
- sharedPreferencesUtil ,
70
- fileDownloadService ,
71
- modelDownloadService ,
72
- fileManager ,
73
- eventLogger ,
74
- Executors .newSingleThreadExecutor (),
75
- modelFactory );
76
- }
77
-
78
58
@ VisibleForTesting
79
59
FirebaseModelDownloader (
80
60
FirebaseOptions firebaseOptions ,
@@ -83,15 +63,17 @@ public class FirebaseModelDownloader {
83
63
CustomModelDownloadService modelDownloadService ,
84
64
ModelFileManager fileManager ,
85
65
FirebaseMlLogger eventLogger ,
86
- Executor executor ,
66
+ @ Background Executor bgExecutor ,
67
+ @ Blocking Executor blockingExecutor ,
87
68
CustomModel .Factory modelFactory ) {
88
69
this .firebaseOptions = firebaseOptions ;
89
70
this .sharedPreferencesUtil = sharedPreferencesUtil ;
90
71
this .fileDownloadService = fileDownloadService ;
91
72
this .modelDownloadService = modelDownloadService ;
92
73
this .fileManager = fileManager ;
93
74
this .eventLogger = eventLogger ;
94
- this .executor = executor ;
75
+ this .bgExecutor = bgExecutor ;
76
+ this .blockingExecutor = blockingExecutor ;
95
77
this .modelFactory = modelFactory ;
96
78
}
97
79
@@ -227,7 +209,7 @@ private Task<CustomModel> getCompletedLocalCustomModelTask(@NonNull CustomModel
227
209
228
210
if (downloadInProgressTask != null ) {
229
211
return downloadInProgressTask .continueWithTask (
230
- executor ,
212
+ bgExecutor ,
231
213
downloadTask -> {
232
214
if (downloadTask .isSuccessful ()) {
233
215
return finishModelDownload (model .getName ());
@@ -251,7 +233,7 @@ private Task<CustomModel> getCompletedLocalCustomModelTask(@NonNull CustomModel
251
233
// bad model state - delete all existing model details and return exception
252
234
return deleteDownloadedModel (model .getName ())
253
235
.continueWithTask (
254
- executor ,
236
+ bgExecutor ,
255
237
deletionTask ->
256
238
Tasks .forException (
257
239
new FirebaseMlException (
@@ -284,7 +266,7 @@ private Task<CustomModel> getCustomModelTask(
284
266
firebaseOptions .getProjectId (), modelName , modelHash );
285
267
286
268
return incomingModelDetails .continueWithTask (
287
- executor ,
269
+ bgExecutor ,
288
270
incomingModelDetailTask -> {
289
271
if (incomingModelDetailTask .isSuccessful ()) {
290
272
// null means we have the latest model or we failed to connect.
@@ -368,7 +350,7 @@ && new File(currentModel.getLocalFilePath()).exists()) {
368
350
return fileDownloadService
369
351
.download (incomingModelDetailTask .getResult (), conditions )
370
352
.continueWithTask (
371
- executor ,
353
+ blockingExecutor ,
372
354
downloadTask -> {
373
355
if (downloadTask .isSuccessful ()) {
374
356
return finishModelDownload (modelName );
@@ -401,14 +383,14 @@ private Task<CustomModel> retryExpiredUrlDownload(
401
383
firebaseOptions .getProjectId (), modelName );
402
384
// no local model - start download.
403
385
return retryModelDetails .continueWithTask (
404
- executor ,
386
+ bgExecutor ,
405
387
retryModelDetailTask -> {
406
388
if (retryModelDetailTask .isSuccessful ()) {
407
389
// start download
408
390
return fileDownloadService
409
391
.download (retryModelDetailTask .getResult (), conditions )
410
392
.continueWithTask (
411
- executor ,
393
+ bgExecutor ,
412
394
retryDownloadTask -> {
413
395
if (retryDownloadTask .isSuccessful ()) {
414
396
return finishModelDownload (modelName );
@@ -458,7 +440,7 @@ public Task<Set<CustomModel>> listDownloadedModels() {
458
440
fileDownloadService .maybeCheckDownloadingComplete ();
459
441
460
442
TaskCompletionSource <Set <CustomModel >> taskCompletionSource = new TaskCompletionSource <>();
461
- executor .execute (
443
+ bgExecutor .execute (
462
444
() -> taskCompletionSource .setResult (sharedPreferencesUtil .listDownloadedModels ()));
463
445
return taskCompletionSource .getTask ();
464
446
}
@@ -472,7 +454,7 @@ public Task<Set<CustomModel>> listDownloadedModels() {
472
454
public Task <Void > deleteDownloadedModel (@ NonNull String modelName ) {
473
455
474
456
TaskCompletionSource <Void > taskCompletionSource = new TaskCompletionSource <>();
475
- executor .execute (
457
+ bgExecutor .execute (
476
458
() -> {
477
459
// remove all files associated with this model and then clean up model references.
478
460
boolean isSuccessful = deleteModelDetails (modelName );
0 commit comments