Skip to content

Implement BloomFilter class #4524

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Jan 20, 2023
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.google.firebase.firestore.remote;

import android.util.Base64;
import androidx.annotation.NonNull;
import androidx.annotation.VisibleForTesting;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;

/**
 * A read-only bloom filter backed by a byte-array bitmap, supporting probabilistic membership
 * checks on strings.
 *
 * <p>Bit {@code n} of the filter is stored in {@code bitmap[n / 8]} at bit position {@code n % 8}.
 * The filter holds {@code bitmap.length * 8 - padding} bits; bits at or beyond that index are never
 * consulted.
 */
public class BloomFilter {
  /** Number of usable bits in {@link #bitmap}, i.e. {@code bitmap.length * 8 - padding}. */
  private final int size;

  private final byte[] bitmap;

  /** Number of hash probes performed per membership check. */
  private final int hashCount;

  /**
   * Creates a bloom filter over the given bitmap.
   *
   * @param bitmap the raw filter bits; stored by reference, not copied
   * @param padding the number of unused bits in the last byte of {@code bitmap}; must be in {@code
   *     [0, 8)}, and must be 0 when {@code bitmap} is empty
   * @param hashCount the number of hash probes per lookup; must be positive for a non-empty bitmap
   *     and non-negative for an empty one
   * @throws IllegalArgumentException if {@code padding} or {@code hashCount} is out of range
   */
  public BloomFilter(@NonNull byte[] bitmap, int padding, int hashCount) {
    if (padding < 0 || padding >= 8) {
      throw new IllegalArgumentException("Invalid padding: " + padding);
    }

    if (bitmap.length > 0) {
      // Only empty bloom filter can have 0 hash count.
      if (hashCount <= 0) {
        throw new IllegalArgumentException("Invalid hash count: " + hashCount);
      }
    } else {
      if (hashCount < 0) {
        throw new IllegalArgumentException("Invalid hash count: " + hashCount);
      }

      // Empty bloom filter should have 0 padding.
      if (padding != 0) {
        throw new IllegalArgumentException("Invalid padding when bitmap length is 0: " + padding);
      }
    }

    this.bitmap = bitmap;
    this.hashCount = hashCount;
    this.size = bitmap.length * 8 - padding;
  }

  /** Returns whether this bloom filter has zero usable bits. */
  @VisibleForTesting
  boolean isEmpty() {
    return this.size == 0;
  }

  /**
   * Checks whether the given string may have been inserted into this filter.
   *
   * @param value the string to look up
   * @return {@code false} if any probed bit is unset, if the filter is empty, or if {@code value}
   *     is the empty string; {@code true} if every probed bit is set
   * @throws RuntimeException if the MD5 digest is unavailable or does not yield 16 bytes
   */
  public boolean mightContain(@NonNull String value) {
    // Empty bitmap or empty value should always return false on membership check.
    if (this.isEmpty() || value.isEmpty()) {
      return false;
    }

    byte[] md5HashedValue = md5Hash(value);
    if (md5HashedValue.length != 16) {
      throw new RuntimeException(
          "Invalid md5HashedValue.length: " + md5HashedValue.length + " (expected 16)");
    }

    // Split the 128-bit digest into two 64-bit halves that seed every probe.
    long hash1 = getLongLittleEndian(md5HashedValue, 0);
    long hash2 = getLongLittleEndian(md5HashedValue, 8);

    for (int i = 0; i < this.hashCount; i++) {
      int index = this.getBitIndex(hash1, hash2, i);
      if (!this.isBitSet(index)) {
        return false;
      }
    }
    return true;
  }

  /**
   * Returns the MD5 digest of the UTF-8 encoding of {@code value}.
   *
   * <p>The charset is fixed to UTF-8 so that the digest of non-ASCII input does not depend on the
   * platform default charset.
   *
   * @throws RuntimeException if no MD5 {@link MessageDigest} provider is available
   */
  @NonNull
  public static byte[] md5Hash(@NonNull String value) {
    MessageDigest digest;
    try {
      digest = MessageDigest.getInstance("MD5");
    } catch (NoSuchAlgorithmException e) {
      throw new RuntimeException("Missing MD5 MessageDigest provider.", e);
    }
    return digest.digest(value.getBytes(StandardCharsets.UTF_8));
  }

  /**
   * Interprets up to 8 bytes starting at {@code offset} as a little-endian two's complement long.
   *
   * <p>If fewer than 8 bytes remain after {@code offset}, only the available bytes are read and the
   * missing high-order bytes are treated as zero.
   *
   * @param bytes the source array
   * @param offset index of the least significant byte
   */
  public static long getLongLittleEndian(@NonNull byte[] bytes, int offset) {
    long result = 0;
    // The bound must account for offset; checking only `i < bytes.length` would still let
    // `bytes[offset + i]` run past the end of the array when offset > 0.
    for (int i = 0; i < 8 && offset + i < bytes.length; i++) {
      result |= (bytes[offset + i] & 0xFFL) << (i * 8);
    }
    return result;
  }

  // Calculate the ith hash value based on the hashed 64bit integers,
  // and calculate its corresponding bit index in the bitmap to be checked.
  private int getBitIndex(long hash1, long hash2, int index) {
    // Calculate hashed value h(i) = h1 + (i * h2); the sum is reduced modulo size as an unsigned
    // 64-bit quantity, so the result is always in [0, size).
    long combinedHash = hash1 + (hash2 * index);
    long mod = UnsignedLong.remainder(combinedHash, this.size);
    return (int) mod;
  }

  // Return whether the bit on the given index in the bitmap is set to 1.
  private boolean isBitSet(int index) {
    // To retrieve bit n, calculate: (bitmap[n / 8] & (0x01 << (n % 8))).
    byte byteAtIndex = this.bitmap[(index / 8)];
    int offset = index % 8;
    return (byteAtIndex & (0x01 << offset)) != 0;
  }

  /** Returns a debug string with the hash count, bit size and Base64-encoded bitmap. */
  @Override
  public String toString() {
    return "BloomFilter{"
        + "hashCount="
        + hashCount
        + ", size="
        + size
        + ", bitmap="
        + Base64.encodeToString(bitmap, Base64.NO_WRAP)
        + '}';
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
/*
* Copyright (C) 2008 The Guava Authors
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
* in compliance with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License
* is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
* or implied. See the License for the specific language governing permissions and limitations under
* the License.
*/
package com.google.firebase.firestore.remote;

/** Static helpers that treat {@code long} values as unsigned 64-bit quantities. */
public class UnsignedLong {

  /**
   * Computes {@code dividend % divisor} with both operands interpreted as unsigned 64-bit values.
   *
   * @param dividend the dividend (numerator)
   * @param divisor the divisor (denominator)
   * @return the unsigned remainder
   * @throws ArithmeticException if divisor is 0
   */
  public static long remainder(long dividend, long divisor) {
    if (divisor < 0) {
      // A negative divisor is >= 2^63 when read as unsigned, so the quotient is either 0 or 1.
      return compare(dividend, divisor) < 0 ? dividend : dividend - divisor;
    }

    if (dividend >= 0) {
      // Both operands fit in the signed range, so plain signed modulus is already correct.
      return dividend % divisor;
    }

    // The dividend is >= 2^63 as unsigned. Halve it (unsigned shift), divide, and double the
    // result to estimate the quotient; the estimate is either exact or one too small, so at most
    // one subtraction of the divisor is needed to fix up the remainder.
    long estimatedQuotient = ((dividend >>> 1) / divisor) << 1;
    long candidate = dividend - estimatedQuotient * divisor;
    if (compare(candidate, divisor) >= 0) {
      candidate -= divisor;
    }
    return candidate;
  }

  /**
   * Compares two {@code long} values as unsigned numbers in the range {@code 0} to {@code 2^64 - 1}
   * inclusive.
   *
   * @param a the first unsigned {@code long} to compare
   * @param b the second unsigned {@code long} to compare
   * @return a negative value if {@code a} is less than {@code b}; a positive value if {@code a} is
   *     greater than {@code b}; or zero if they are equal
   */
  public static int compare(long a, long b) {
    // Toggling the sign bit maps unsigned ordering onto signed ordering.
    long shiftedA = a ^ Long.MIN_VALUE;
    long shiftedB = b ^ Long.MIN_VALUE;
    return Long.compare(shiftedA, shiftedB);
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.google.firebase.firestore.remote;

import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertThrows;
import static org.junit.Assert.assertTrue;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.util.Base64;
import java.util.HashMap;
import java.util.stream.Stream;
import org.json.JSONObject;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.robolectric.RobolectricTestRunner;
import org.robolectric.annotation.Config;

/** Unit tests for {@link BloomFilter}: constructor validation and membership checks. */
@RunWith(RobolectricTestRunner.class)
@Config(manifest = Config.NONE)
public class BloomFilterTest {

  @Test
  public void testEmptyBloomFilter() {
    BloomFilter bloomFilter = new BloomFilter(new byte[0], 0, 0);
    assertTrue(bloomFilter.isEmpty());
  }

  @Test
  public void testEmptyBloomFilterThrowException() {
    IllegalArgumentException paddingException =
        assertThrows(IllegalArgumentException.class, () -> new BloomFilter(new byte[0], 1, 0));
    assertThat(paddingException)
        .hasMessageThat()
        .contains("Invalid padding when bitmap length is 0: 1");
    IllegalArgumentException hashCountException =
        assertThrows(IllegalArgumentException.class, () -> new BloomFilter(new byte[0], 0, -1));
    assertThat(hashCountException).hasMessageThat().contains("Invalid hash count: -1");
  }

  @Test
  public void testNonEmptyBloomFilter() {
    BloomFilter bloomFilter1 = new BloomFilter(new byte[1], 0, 1);
    assertFalse(bloomFilter1.isEmpty());
    BloomFilter bloomFilter2 = new BloomFilter(new byte[1], 7, 1);
    assertFalse(bloomFilter2.isEmpty());
  }

  @Test
  public void testNonEmptyBloomFilterThrowException() {
    IllegalArgumentException negativePaddingException =
        assertThrows(IllegalArgumentException.class, () -> new BloomFilter(new byte[1], -1, 1));
    assertThat(negativePaddingException).hasMessageThat().contains("Invalid padding: -1");
    IllegalArgumentException overflowPaddingException =
        assertThrows(IllegalArgumentException.class, () -> new BloomFilter(new byte[1], 8, 1));
    assertThat(overflowPaddingException).hasMessageThat().contains("Invalid padding: 8");

    IllegalArgumentException negativeHashCountException =
        assertThrows(IllegalArgumentException.class, () -> new BloomFilter(new byte[1], 1, -1));
    assertThat(negativeHashCountException).hasMessageThat().contains("Invalid hash count: -1");
    IllegalArgumentException zeroHashCountException =
        assertThrows(IllegalArgumentException.class, () -> new BloomFilter(new byte[1], 1, 0));
    assertThat(zeroHashCountException).hasMessageThat().contains("Invalid hash count: 0");
  }

  @Test
  public void testBloomFilterProcessNonStandardCharacters() {
    // A non-empty BloomFilter object with 1 insertion : "ÀÒ∑"
    BloomFilter bloomFilter = new BloomFilter(new byte[] {(byte) 237, 5}, 5, 8);
    assertTrue(bloomFilter.mightContain("ÀÒ∑"));
    assertFalse(bloomFilter.mightContain("Ò∑À"));
  }

  @Test
  public void testEmptyBloomFilterMightContainAlwaysReturnFalse() {
    BloomFilter bloomFilter = new BloomFilter(new byte[0], 0, 0);
    assertFalse(bloomFilter.mightContain("abc"));
  }

  @Test
  public void testBloomFilterMightContainOnEmptyStringAlwaysReturnFalse() {
    BloomFilter emptyBloomFilter = new BloomFilter(new byte[0], 0, 0);
    BloomFilter nonEmptyBloomFilter =
        new BloomFilter(new byte[] {(byte) 255, (byte) 255, (byte) 255}, 1, 16);

    assertFalse(emptyBloomFilter.mightContain(""));
    assertFalse(nonEmptyBloomFilter.mightContain(""));
  }

  /**
   * Golden tests are generated by backend based on inserting n number of document paths into a
   * bloom filter.
   *
   * <p>Full document path is generated by concatenating documentPrefix and number n, eg,
   * projects/project-1/databases/database-1/documents/coll/doc12.
   *
   * <p>The test result is generated by checking the membership of documents from documentPrefix+0
   * to documentPrefix+2n. The membership results from 0 to n is expected to be true, and the
   * membership results from n to 2n is expected to be false with some false positive results.
   */
  @Test
  public void testBloomFilterGoldenTest() throws Exception {
    String documentPrefix = "projects/project-1/databases/database-1/documents/coll/doc";

    // Import the golden test files for bloom filter
    HashMap<String, JSONObject> parsedSpecFiles = new HashMap<>();
    File jsonDir = new File("src/test/resources/bloom_filter_golden_test_data");
    File[] jsonFiles = jsonDir.listFiles();
    // Use a JUnit assertion rather than the `assert` keyword so the check also runs when JVM
    // assertions are disabled.
    assertTrue("Golden test data directory not found: " + jsonDir, jsonFiles != null);
    for (File file : jsonFiles) {
      if (!file.toString().endsWith(".json")) {
        continue;
      }

      // Read the files into a map. try-with-resources closes the reader even if parsing fails.
      StringBuilder builder = new StringBuilder();
      try (BufferedReader reader = new BufferedReader(new FileReader(file))) {
        Stream<String> lines = reader.lines();
        lines.forEach(builder::append);
      }
      String json = builder.toString();
      JSONObject fileJSON = new JSONObject(json);
      parsedSpecFiles.put(file.getName(), fileJSON);
    }

    // Loop and test the files
    for (String fileName : parsedSpecFiles.keySet()) {
      if (fileName.contains("membership_test_result")) {
        continue;
      }

      // Read test data and instantiate a BloomFilter object
      JSONObject fileJSON = parsedSpecFiles.get(fileName);
      assertTrue("Missing parsed JSON for file: " + fileName, fileJSON != null);
      JSONObject bits = fileJSON.getJSONObject("bits");
      String bitmap = bits.getString("bitmap");
      int padding = bits.getInt("padding");
      int hashCount = fileJSON.getInt("hashCount");
      BloomFilter bloomFilter =
          new BloomFilter(Base64.getDecoder().decode(bitmap), padding, hashCount);

      // Find corresponding membership test result.
      String resultFileName = fileName.replace("bloom_filter_proto", "membership_test_result");
      JSONObject resultJSON = parsedSpecFiles.get(resultFileName);
      assertTrue("Missing membership test result file: " + resultFileName, resultJSON != null);
      String membershipTestResults = resultJSON.getString("membershipTestResults");

      // Run and compare mightContain result with the expectation.
      for (int i = 0; i < membershipTestResults.length(); i++) {
        boolean expectedMembershipResult = membershipTestResults.charAt(i) == '1';
        boolean mightContain = bloomFilter.mightContain(documentPrefix + i);
        // JUnit's signature is (message, expected, actual); keep that order so failure output
        // labels the two values correctly.
        assertEquals(
            "MightContain result doesn't match the expectation. File: "
                + fileName
                + ". Document: "
                + documentPrefix
                + i,
            expectedMembershipResult,
            mightContain);
      }
    }
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{ "bits": { "bitmap": "RswZ", "padding": 1 }, "hashCount": 16 }
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"membershipTestResults" : "10"}
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"bits":{"bitmap":"mwE=","padding":5},"hashCount":8}
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"membershipTestResults" : "10"}
Loading