Skip to content

Commit

Permalink
Merge pull request #290 from IABTechLab/wzh-uid2-3572-small-change-in…
Browse files Browse the repository at this point in the history
…-s3keyprovider

improve s3 key manger, add key selection logics
  • Loading branch information
lizk886 authored Aug 1, 2024
2 parents a928bcd + e7bffd0 commit 0c008f0
Show file tree
Hide file tree
Showing 4 changed files with 124 additions and 72 deletions.
2 changes: 1 addition & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

<groupId>com.uid2</groupId>
<artifactId>uid2-shared</artifactId>
<version>7.16.0</version>
<version>7.16.10-alpha-131-SNAPSHOT</version>
<name>${project.groupId}:${project.artifactId}</name>
<description>Library for all the shared uid2 operations</description>
<url>https://github.com/IABTechLab/uid2docs</url>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,25 +5,26 @@
import com.uid2.shared.store.ScopedStoreReader;
import com.uid2.shared.store.parser.S3KeyParser;
import com.uid2.shared.store.scope.StoreScope;
import com.uid2.shared.model.S3Key;
import io.vertx.core.json.JsonObject;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.ArrayList;
import java.util.Set;
import java.util.HashSet;
import java.util.Map;
import java.util.List;
import java.util.HashMap;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Comparator;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.stream.Collectors;

import com.uid2.shared.model.S3Key;
import java.time.Instant;

public class RotatingS3KeyProvider implements StoreReader<Map<Integer, S3Key>> {
ScopedStoreReader<Map<Integer, S3Key>> reader;

private static final Logger LOGGER = LoggerFactory.getLogger(RotatingS3KeyProvider.class);
public final Map<Integer, List<S3Key>> siteToKeysMap = new HashMap<>();
public Map<Integer, List<S3Key>> siteToKeysMap = new HashMap<>();

public RotatingS3KeyProvider(DownloadCloudStorage fileStreamProvider, StoreScope scope) {
this.reader = new ScopedStoreReader<>(fileStreamProvider, scope, new S3KeyParser(), "s3encryption_keys");
Expand All @@ -46,7 +47,9 @@ public long getVersion(JsonObject metadata) {

@Override
public long loadContent(JsonObject metadata) throws Exception {
return reader.loadContent(metadata, "s3encryption_keys");
long result = reader.loadContent(metadata, "s3encryption_keys");
updateSiteToKeysMapping();
return result;
}

@Override
Expand All @@ -58,9 +61,11 @@ public Map<Integer, S3Key> getAll() {
public void updateSiteToKeysMapping() {
Map<Integer, S3Key> allKeys = getAll();
siteToKeysMap.clear();
for (S3Key key : allKeys.values()) {
siteToKeysMap.computeIfAbsent(key.getSiteId(), k -> new ArrayList<>()).add(key);
}
allKeys.values().forEach(key ->
this.siteToKeysMap
.computeIfAbsent(key.getSiteId(), k -> new ArrayList<>())
.add(key)
);
LOGGER.info("Updated site-to-keys mapping for {} sites", siteToKeysMap.size());
}

Expand All @@ -77,28 +82,28 @@ public int getTotalSites() {
return siteToKeysMap.size();
}

public List<S3Key> getKeys(int siteId) {
//for s3 encryption keys retrieval
return siteToKeysMap.getOrDefault(siteId, new ArrayList<>());
}

public Collection<S3Key> getKeysForSite(Integer siteId) {
Map<Integer, S3Key> allKeys = getAll();
return allKeys.values().stream()
.filter(key -> key.getSiteId()==(siteId))
.collect(Collectors.toList());
}

public S3Key getEncryptionKeyForSite(Integer siteId) {
public S3Key getEncryptionKeyForSite(Integer siteId) {
//get the youngest activated key
Collection<S3Key> keys = getKeysForSite(siteId);
if (keys.isEmpty()) {
long now = Instant.now().getEpochSecond();
if (keys.isEmpty()) {
throw new IllegalStateException("No S3 keys available for encryption for site ID: " + siteId);
} else {
Map<Integer, S3Key> allKeys = getAll();
S3Key largestKey = null;
for (S3Key key : allKeys.values()) {
if (key.getSiteId() == siteId) {
if (largestKey == null || key.getId() > largestKey.getId()) {
largestKey = key;
}
}
}
return largestKey;
}
return keys.stream()
.filter(key -> key.getActivates() <= now)
.max(Comparator.comparingLong(S3Key::getCreated))
.orElseThrow(() -> new IllegalStateException("No active keys found for site ID: " + siteId));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -129,9 +129,6 @@ public synchronized void refresh() throws Exception {
final long version = this.versionedStore.getVersion(metadata);
if (version > this.latestVersion.get()) {
long entryCount = this.versionedStore.loadContent(metadata);
if (this.versionedStore instanceof RotatingS3KeyProvider) {
((RotatingS3KeyProvider) this.versionedStore).updateSiteToKeysMapping();
}
this.latestVersion.set(version);
this.latestEntryCount.set(entryCount);
LOGGER.info("Successfully loaded " + this.storeName + " version " + version);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,14 @@
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;

import java.util.Collection;
import java.util.HashMap;
import java.util.Map;
import java.time.Instant;
import java.util.Set;
import java.util.Map;
import java.util.List;
import java.util.HashMap;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;

import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.Mockito.*;
Expand All @@ -33,6 +37,8 @@ public class RotatingS3KeyProviderTest {

private RotatingS3KeyProvider rotatingS3KeyProvider;

private static final long CURRENT_TIME = Instant.now().getEpochSecond();

@BeforeEach
void setUp() {
MockitoAnnotations.openMocks(this);
Expand Down Expand Up @@ -192,25 +198,6 @@ void testGetKeysForSite() {
assertTrue(noKeysRetrieved.isEmpty());
}

@Test
void testGetEncryptionKeyForSite() {
Map<Integer, S3Key> existingKeys = new HashMap<>();
S3Key key1 = new S3Key(1, 123, 1687635529, 1687808329, "S3keySecretByteHere1");
S3Key key2 = new S3Key(2, 123, 1687808429, 1687808329, "S3keySecretByteHere2");
S3Key key3 = new S3Key(3, 123, 1687635529, 1687808329, "S3keySecretByteHere3");
existingKeys.put(1, key1);
existingKeys.put(2, key2);
existingKeys.put(3, key3);
when(reader.getSnapshot()).thenReturn(existingKeys);

S3Key retrievedKey = rotatingS3KeyProvider.getEncryptionKeyForSite(123);
assertNotNull(retrievedKey);
assertEquals(key3, retrievedKey);

when(reader.getSnapshot()).thenReturn(new HashMap<>());
assertThrows(IllegalStateException.class, () -> rotatingS3KeyProvider.getEncryptionKeyForSite(456));
}

@Test
void testGetAllWithSingleKey() {
Map<Integer, S3Key> existingKeys = new HashMap<>();
Expand All @@ -226,7 +213,7 @@ void testGetAllWithSingleKey() {
@Test
void testGetEncryptionKeyForSiteWithSingleKey() {
Map<Integer, S3Key> existingKeys = new HashMap<>();
S3Key singleKey = new S3Key(1, 123, 1687635529, 1687808329, "S3keySecretByteHere1");
S3Key singleKey = new S3Key(1, 123, CURRENT_TIME - 1000, CURRENT_TIME + 1000, "S3keySecretByteHere1");
existingKeys.put(1, singleKey);
when(reader.getSnapshot()).thenReturn(existingKeys);

Expand Down Expand Up @@ -264,28 +251,17 @@ void testGetKeysForSiteWithMultipleKeys() {
assertTrue(retrievedKeys.contains(key2));
}

@Test
void testGetEncryptionKeyForSiteWithMultipleKeys() {
Map<Integer, S3Key> existingKeys = new HashMap<>();
S3Key key1 = new S3Key(1, 123, 1687635529, 1687808329, "S3keySecretByteHere1");
S3Key key2 = new S3Key(2, 123, 1687808429, 1687808329, "S3keySecretByteHere2");
existingKeys.put(1, key1);
existingKeys.put(2, key2);
when(reader.getSnapshot()).thenReturn(existingKeys);

S3Key retrievedKey = rotatingS3KeyProvider.getEncryptionKeyForSite(123);
assertNotNull(retrievedKey);
assertEquals(key2, retrievedKey);
}

@Test
void testGetEncryptionKeyForNonExistentSite() {
Map<Integer, S3Key> existingKeys = new HashMap<>();
S3Key key1 = new S3Key(1, 123, 1687635529, 1687808329, "S3keySecretByteHere1");
S3Key key1 = new S3Key(1, 123, CURRENT_TIME - 1000, CURRENT_TIME + 1000, "S3keySecretByteHere1");
existingKeys.put(1, key1);
when(reader.getSnapshot()).thenReturn(existingKeys);

assertThrows(IllegalStateException.class, () -> rotatingS3KeyProvider.getEncryptionKeyForSite(456));
IllegalStateException exception = assertThrows(IllegalStateException.class,
() -> rotatingS3KeyProvider.getEncryptionKeyForSite(456));

assertEquals("No S3 keys available for encryption for site ID: 456", exception.getMessage());
}

@Test
Expand Down Expand Up @@ -323,12 +299,16 @@ void testGetKeysForSiteWithEmptyMap() {
@Test
void testGetEncryptionKeyForSiteWithMultipleKeysAndNonExistentSite() {
Map<Integer, S3Key> existingKeys = new HashMap<>();
S3Key key1 = new S3Key(1, 123, 1687635529, 1687808329, "S3keySecretByteHere1");
S3Key key2 = new S3Key(2, 123, 1687808429, 1687808329, "S3keySecretByteHere2");
S3Key key1 = new S3Key(1, 123, CURRENT_TIME - 2000, CURRENT_TIME + 1000, "S3keySecretByteHere1");
S3Key key2 = new S3Key(2, 123, CURRENT_TIME - 1000, CURRENT_TIME + 2000, "S3keySecretByteHere2");
existingKeys.put(1, key1);
existingKeys.put(2, key2);
when(reader.getSnapshot()).thenReturn(existingKeys);
assertThrows(IllegalStateException.class, () -> rotatingS3KeyProvider.getEncryptionKeyForSite(456));

IllegalStateException exception = assertThrows(IllegalStateException.class,
() -> rotatingS3KeyProvider.getEncryptionKeyForSite(456));

assertEquals("No S3 keys available for encryption for site ID: 456", exception.getMessage());
}

@Test
Expand Down Expand Up @@ -420,4 +400,74 @@ void testGetTotalSites() {
int totalSites = rotatingS3KeyProvider.getTotalSites();
assertEquals(2, totalSites);
}

@Test
void testGetKeysForSiteFromMap() {
S3Key key1 = new S3Key(1, 100, 1687635529, 1687808329, "secret1");
S3Key key2 = new S3Key(2, 100, 1687808429, 1687981229, "secret2");
S3Key key3 = new S3Key(3, 200, 1687981329, 1688154129, "secret3");

Map<Integer, List<S3Key>> testMap = new HashMap<>();
testMap.put(100, Arrays.asList(key1, key2));
testMap.put(200, Collections.singletonList(key3));

rotatingS3KeyProvider.siteToKeysMap = testMap;

List<S3Key> result1 = rotatingS3KeyProvider.getKeys(100);
assertEquals(2, result1.size());
assertTrue(result1.contains(key1));
assertTrue(result1.contains(key2));

List<S3Key> result2 = rotatingS3KeyProvider.getKeys(200);
assertEquals(1, result2.size());
assertTrue(result2.contains(key3));

List<S3Key> result3 = rotatingS3KeyProvider.getKeys(300);
assertTrue(result3.isEmpty());
}

@Test
void testGetKeysForSiteFromMapWithEmptyMap() {
rotatingS3KeyProvider.siteToKeysMap = new HashMap<>();

List<S3Key> result = rotatingS3KeyProvider.getKeys(100);
assertTrue(result.isEmpty());
}

@Test
void testGetKeysForSiteFromMapWithNullMap() {
rotatingS3KeyProvider.siteToKeysMap = null;

assertThrows(NullPointerException.class, () -> rotatingS3KeyProvider.getKeys(100));
}

@Test
void testGetEncryptionKeyForSite() {
Map<Integer, S3Key> existingKeys = new HashMap<>();
S3Key key1 = new S3Key(1, 123, CURRENT_TIME - 3000, 1687808329, "S3keySecretByteHere1");
S3Key key2 = new S3Key(2, 123, CURRENT_TIME - 2000, 1687981229, "S3keySecretByteHere2");
S3Key key3 = new S3Key(3, 123, CURRENT_TIME - 1000, 1688154129, "S3keySecretByteHere3");
S3Key key4 = new S3Key(4, 123, CURRENT_TIME + 1000, 1688327029, "S3keySecretByteHere4"); // Future key
existingKeys.put(1, key1);
existingKeys.put(2, key2);
existingKeys.put(3, key3);
existingKeys.put(4, key4);
when(reader.getSnapshot()).thenReturn(existingKeys);

S3Key retrievedKey = rotatingS3KeyProvider.getEncryptionKeyForSite(123);
assertNotNull(retrievedKey);
assertEquals(key3, retrievedKey); // Should return the most recent active key
}

@Test
void testGetEncryptionKeyForSiteWithNoActiveKeys() {
Map<Integer, S3Key> existingKeys = new HashMap<>();
S3Key key1 = new S3Key(1, 123, CURRENT_TIME + 1000, 1687808329, "S3keySecretByteHere1");
S3Key key2 = new S3Key(2, 123, CURRENT_TIME + 2000, 1687981229, "S3keySecretByteHere2");
existingKeys.put(1, key1);
existingKeys.put(2, key2);
when(reader.getSnapshot()).thenReturn(existingKeys);

assertThrows(IllegalStateException.class, () -> rotatingS3KeyProvider.getEncryptionKeyForSite(123));
}
}

0 comments on commit 0c008f0

Please sign in to comment.