Skip to content

Commit 1a32948

Browse files
committed
VectorValue type support for Firestore.
1 parent c928402 commit 1a32948

File tree

14 files changed

+662
-56
lines changed

14 files changed

+662
-56
lines changed

firebase-firestore/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ it on behalf.
113113

114114
Run below to format Java code:
115115
```bash
116-
./gradlew :firebase-firestore:googleJavaFormat
116+
./gradlew :firebase-firestore:spotlessApply
117117
```
118118

119119
See [here](../README.md#code-formatting) if you want to be able to format code

firebase-firestore/src/androidTest/java/com/google/firebase/firestore/NumericTransformsTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434

3535
@RunWith(AndroidJUnit4.class)
3636
public class NumericTransformsTest {
37-
private static final double DOUBLE_EPSILON = 0.000001;
37+
public static final double DOUBLE_EPSILON = 0.000001;
3838

3939
// A document reference to read and write to.
4040
private DocumentReference docRef;

firebase-firestore/src/androidTest/java/com/google/firebase/firestore/VectorTest.java

Lines changed: 360 additions & 0 deletions
Large diffs are not rendered by default.

firebase-firestore/src/main/java/com/google/firebase/firestore/DocumentSnapshot.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -484,6 +484,18 @@ public DocumentReference getReference() {
484484
return new DocumentReference(key, firestore);
485485
}
486486

487+
/**
488+
* Returns the value of the field as a VectorValue.
489+
*
490+
* @param field The path to the field.
491+
* @throws RuntimeException if the value is not a VectorValue.
492+
* @return The value of the field.
493+
*/
494+
@Nullable
495+
public VectorValue getVectorValue(@NonNull String field) {
496+
return (VectorValue) get(field);
497+
}
498+
487499
@Nullable
488500
private <T> T getTypedValue(String field, Class<T> clazz) {
489501
checkNotNull(field, "Provided field must not be null.");

firebase-firestore/src/main/java/com/google/firebase/firestore/FieldValue.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,4 +182,15 @@ public static FieldValue increment(long l) {
182182
public static FieldValue increment(double l) {
183183
return new NumericIncrementFieldValue(l);
184184
}
185+
186+
/**
187+
* Creates a new {@link VectorValue} constructed with a copy of the given array of doubles.
188+
*
189+
* @param values Create a {@link VectorValue} instance with a copy of this array of doubles.
190+
* @return A new {@link VectorValue} constructed with a copy of the given array of doubles.
191+
*/
192+
@NonNull
193+
public static VectorValue vector(@NonNull double[] values) {
194+
return new VectorValue(values);
195+
}
185196
}

firebase-firestore/src/main/java/com/google/firebase/firestore/UserDataReader.java

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
import com.google.firebase.firestore.model.DatabaseId;
3333
import com.google.firebase.firestore.model.FieldPath;
3434
import com.google.firebase.firestore.model.ObjectValue;
35+
import com.google.firebase.firestore.model.Values;
3536
import com.google.firebase.firestore.model.mutation.ArrayTransformOperation;
3637
import com.google.firebase.firestore.model.mutation.FieldMask;
3738
import com.google.firebase.firestore.model.mutation.NumericIncrementTransformOperation;
@@ -45,11 +46,13 @@
4546
import com.google.protobuf.NullValue;
4647
import com.google.type.LatLng;
4748
import java.util.ArrayList;
49+
import java.util.Arrays;
4850
import java.util.Date;
4951
import java.util.Iterator;
5052
import java.util.List;
5153
import java.util.Map;
5254
import java.util.Map.Entry;
55+
import java.util.stream.Collectors;
5356

5457
/**
5558
* Helper for parsing raw user input (provided via the API) into internal model classes.
@@ -440,13 +443,26 @@ private Value parseScalarValue(Object input, ParseContext context) {
440443
databaseId.getDatabaseId(),
441444
((DocumentReference) input).getPath()))
442445
.build();
446+
} else if (input instanceof VectorValue) {
447+
return parseVectorValue(((VectorValue) input), context);
443448
} else if (input.getClass().isArray()) {
444449
throw context.createError("Arrays are not supported; use a List instead");
445450
} else {
446451
throw context.createError("Unsupported type: " + Util.typeName(input));
447452
}
448453
}
449454

455+
private Value parseVectorValue(VectorValue vector, ParseContext context) {
456+
MapValue.Builder mapBuilder = MapValue.newBuilder();
457+
458+
mapBuilder.putFields(Values.TYPE_KEY, Values.VECTOR_VALUE_TYPE);
459+
mapBuilder.putFields(
460+
Values.VECTOR_MAP_VECTORS_KEY,
461+
parseData(Arrays.stream(vector.toArray()).boxed().collect(Collectors.toList()), context));
462+
463+
return Value.newBuilder().setMapValue(mapBuilder).build();
464+
}
465+
450466
private Value parseTimestamp(Timestamp timestamp) {
451467
// Firestore backend truncates precision down to microseconds. To ensure offline mode works
452468
// the same with regards to truncation, perform the truncation immediately without waiting for

firebase-firestore/src/main/java/com/google/firebase/firestore/UserDataWriter.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,15 @@
2727
import static com.google.firebase.firestore.model.Values.TYPE_ORDER_SERVER_TIMESTAMP;
2828
import static com.google.firebase.firestore.model.Values.TYPE_ORDER_STRING;
2929
import static com.google.firebase.firestore.model.Values.TYPE_ORDER_TIMESTAMP;
30+
import static com.google.firebase.firestore.model.Values.TYPE_ORDER_VECTOR;
3031
import static com.google.firebase.firestore.model.Values.typeOrder;
3132
import static com.google.firebase.firestore.util.Assert.fail;
3233

3334
import androidx.annotation.RestrictTo;
3435
import com.google.firebase.Timestamp;
3536
import com.google.firebase.firestore.model.DatabaseId;
3637
import com.google.firebase.firestore.model.DocumentKey;
38+
import com.google.firebase.firestore.model.Values;
3739
import com.google.firebase.firestore.util.Logger;
3840
import com.google.firestore.v1.ArrayValue;
3941
import com.google.firestore.v1.Value;
@@ -86,6 +88,8 @@ public Object convertValue(Value value) {
8688
case TYPE_ORDER_GEOPOINT:
8789
return new GeoPoint(
8890
value.getGeoPointValue().getLatitude(), value.getGeoPointValue().getLongitude());
91+
case TYPE_ORDER_VECTOR:
92+
return convertVectorValue(value.getMapValue().getFieldsMap());
8993
default:
9094
throw fail("Unknown value type: " + value.getValueTypeCase());
9195
}
@@ -99,6 +103,14 @@ Map<String, Object> convertObject(Map<String, Value> mapValue) {
99103
return result;
100104
}
101105

106+
VectorValue convertVectorValue(Map<String, Value> mapValue) {
107+
double[] values =
108+
mapValue.get(Values.VECTOR_MAP_VECTORS_KEY).getArrayValue().getValuesList().stream()
109+
.mapToDouble(val -> val.getDoubleValue())
110+
.toArray();
111+
return new VectorValue(values);
112+
}
113+
102114
private Object convertServerTimestamp(Value serverTimestampValue) {
103115
switch (serverTimestampBehavior) {
104116
case PREVIOUS:
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
package com.google.firebase.firestore;
2+
3+
import androidx.annotation.NonNull;
4+
import androidx.annotation.Nullable;
5+
import java.io.Serializable;
6+
import java.util.Arrays;
7+
8+
public class VectorValue implements Serializable {
9+
private final double[] values;
10+
11+
VectorValue(@Nullable double[] values) {
12+
if (values == null) this.values = new double[] {};
13+
else this.values = values.clone();
14+
}
15+
16+
/**
17+
* Returns a representation of the vector as an array of doubles.
18+
*
19+
* @return A representation of the vector as an array of doubles
20+
*/
21+
@NonNull
22+
public double[] toArray() {
23+
return this.values.clone();
24+
}
25+
26+
/**
27+
* Returns true if this VectorValue is equal to the provided object.
28+
*
29+
* @param obj The object to compare against.
30+
* @return Whether this VectorValue is equal to the provided object.
31+
*/
32+
@Override
33+
public boolean equals(Object obj) {
34+
if (this == obj) {
35+
return true;
36+
}
37+
if (obj == null || getClass() != obj.getClass()) {
38+
return false;
39+
}
40+
VectorValue otherArray = (VectorValue) obj;
41+
return Arrays.equals(this.values, otherArray.values);
42+
}
43+
44+
@Override
45+
public int hashCode() {
46+
return Arrays.hashCode(values);
47+
}
48+
}

firebase-firestore/src/main/java/com/google/firebase/firestore/core/Target.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@ private Pair<Value, Boolean> getAscendingBound(
246246
switch (fieldFilter.getOperator()) {
247247
case LESS_THAN:
248248
case LESS_THAN_OR_EQUAL:
249-
filterValue = Values.getLowerBound(fieldFilter.getValue().getValueTypeCase());
249+
filterValue = Values.getLowerBound(fieldFilter.getValue());
250250
break;
251251
case EQUAL:
252252
case IN:
@@ -311,7 +311,7 @@ private Pair<Value, Boolean> getDescendingBound(
311311
switch (fieldFilter.getOperator()) {
312312
case GREATER_THAN_OR_EQUAL:
313313
case GREATER_THAN:
314-
filterValue = Values.getUpperBound(fieldFilter.getValue().getValueTypeCase());
314+
filterValue = Values.getUpperBound(fieldFilter.getValue());
315315
filterInclusive = false;
316316
break;
317317
case EQUAL:

firebase-firestore/src/main/java/com/google/firebase/firestore/index/FirestoreIndexValueWriter.java

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ public class FirestoreIndexValueWriter {
4141
public static final int INDEX_TYPE_REFERENCE = 37;
4242
public static final int INDEX_TYPE_GEOPOINT = 45;
4343
public static final int INDEX_TYPE_ARRAY = 50;
44+
public static final int INDEX_TYPE_VECTOR = 53;
4445
public static final int INDEX_TYPE_MAP = 55;
4546
public static final int INDEX_TYPE_REFERENCE_SEGMENT = 60;
4647

@@ -114,6 +115,9 @@ private void writeIndexValueAux(Value indexValue, DirectionalIndexByteEncoder en
114115
if (Values.isMaxValue(indexValue)) {
115116
writeValueTypeLabel(encoder, Integer.MAX_VALUE);
116117
break;
118+
} else if (Values.isVectorValue(indexValue)) {
119+
writeIndexVector(indexValue.getMapValue(), encoder);
120+
break;
117121
}
118122
writeIndexMap(indexValue.getMapValue(), encoder);
119123
writeTruncationMarker(encoder);
@@ -139,6 +143,22 @@ private void writeUnlabeledIndexString(
139143
}
140144

141145
private void writeIndexMap(MapValue mapIndexValue, DirectionalIndexByteEncoder encoder) {
146+
writeValueTypeLabel(encoder, INDEX_TYPE_VECTOR);
147+
148+
Map<String, Value> map = mapIndexValue.getFieldsMap();
149+
150+
// Vectors sort first by length
151+
String key = Values.VECTOR_MAP_VECTORS_KEY;
152+
int length = map.get(key).getArrayValue().getValuesCount();
153+
this.writeValueTypeLabel(encoder, INDEX_TYPE_VECTOR);
154+
encoder.writeLong(length);
155+
156+
// Vectors then sort by position value
157+
this.writeIndexString(key, encoder);
158+
this.writeIndexValueAux(map.get(key), encoder);
159+
}
160+
161+
private void writeIndexVector(MapValue mapIndexValue, DirectionalIndexByteEncoder encoder) {
142162
writeValueTypeLabel(encoder, INDEX_TYPE_MAP);
143163
for (Map.Entry<String, Value> entry : mapIndexValue.getFieldsMap().entrySet()) {
144164
String key = entry.getKey();

0 commit comments

Comments
 (0)