From aeb12cd3694a397f0d513061042e70026b1afb0d Mon Sep 17 00:00:00 2001 From: David Heryanto Date: Thu, 18 Apr 2019 14:55:23 +0800 Subject: [PATCH] Fix BigQuery query template to retrieve training data (#182) * Fix BigQuery query template to retrieve training data * Update expected value BigQuery template test * Use FeatureInfo to create Features in BigQueryDatasetTemplater so it's neater --- .../training/BigQueryDatasetTemplater.java | 29 +++++++------------ .../main/resources/templates/bq_training.tmpl | 10 +++---- .../BigQueryDatasetTemplaterTest.java | 24 +++++++++------ core/src/test/resources/sql/expQuery1.sql | 12 ++++---- core/src/test/resources/sql/expQuery2.sql | 8 ++--- 5 files changed, 39 insertions(+), 44 deletions(-) diff --git a/core/src/main/java/feast/core/training/BigQueryDatasetTemplater.java b/core/src/main/java/feast/core/training/BigQueryDatasetTemplater.java index ea73d89d44..569bba1150 100644 --- a/core/src/main/java/feast/core/training/BigQueryDatasetTemplater.java +++ b/core/src/main/java/feast/core/training/BigQueryDatasetTemplater.java @@ -22,19 +22,15 @@ import feast.core.dao.FeatureInfoRepository; import feast.core.model.FeatureInfo; import feast.core.model.StorageInfo; -import feast.specs.FeatureSpecProto.FeatureSpec; import feast.specs.StorageSpecProto.StorageSpec; +import lombok.Getter; + import java.time.Instant; import java.time.ZoneId; import java.time.format.DateTimeFormatter; import java.time.temporal.ChronoUnit; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.NoSuchElementException; -import java.util.Set; +import java.util.*; import java.util.stream.Collectors; -import lombok.Getter; public class BigQueryDatasetTemplater { private final FeatureInfoRepository featureInfoRepository; @@ -59,10 +55,11 @@ public BigQueryDatasetTemplater( * @param limit limit * @return SQL query for creating training table. */ - public String createQuery( - FeatureSet featureSet, Timestamp startDate, Timestamp endDate, long limit) { + String createQuery(FeatureSet featureSet, Timestamp startDate, Timestamp endDate, long limit) { List featureIds = featureSet.getFeatureIdsList(); List featureInfos = featureInfoRepository.findAllById(featureIds); + Features features = new Features(featureInfos); + if (featureInfos.size() < featureIds.size()) { Set foundFeatureIds = featureInfos.stream().map(FeatureInfo::getId).collect(Collectors.toSet()); @@ -70,9 +67,6 @@ public String createQuery( throw new NoSuchElementException("features not found: " + featureIds); } - String tableId = getBqTableId(featureInfos.get(0)); - Features features = new Features(featureIds, tableId); - String startDateStr = formatDateString(startDate); String endDateStr = formatDateString(endDate); String limitStr = (limit != 0) ? String.valueOf(limit) : null; @@ -90,7 +84,7 @@ private String renderTemplate( return jinjava.render(template, context); } - private String getBqTableId(FeatureInfo featureInfo) { + private static String getBqTableId(FeatureInfo featureInfo) { StorageInfo whStorage = featureInfo.getWarehouseStore(); String type = whStorage.getType(); @@ -117,12 +111,9 @@ static final class Features { final List columns; final String tableId; - public Features(List featureIds, String tableId) { - this.columns = featureIds.stream() - .map(f -> f.replace(".", "_")) - .collect(Collectors.toList()); - this.tableId = tableId; + Features(List featureInfos) { + columns = featureInfos.stream().map(FeatureInfo::getName).collect(Collectors.toList()); + tableId = featureInfos.size() > 0 ? getBqTableId(featureInfos.get(0)) : ""; } } - } diff --git a/core/src/main/resources/templates/bq_training.tmpl b/core/src/main/resources/templates/bq_training.tmpl index 0a7ce6d322..df7c301ae2 100644 --- a/core/src/main/resources/templates/bq_training.tmpl +++ b/core/src/main/resources/templates/bq_training.tmpl @@ -1,11 +1,9 @@ SELECT - {{ feature_set.tableId }}.id, - {{ feature_set.tableId }}.event_timestamp - {% for feature in feature_set.columns -%} - ,{{ feature }} - {%- endfor %} + id, + event_timestamp{%- if feature_set.columns | length > 0 %},{%- endif %} + {{ feature_set.columns | join(',') }} FROM - {{ feature_set.tableId }} + `{{ feature_set.tableId }}` WHERE event_timestamp >= TIMESTAMP("{{ start_date }}") AND event_timestamp <= TIMESTAMP(DATETIME_ADD("{{ end_date }}", INTERVAL 1 DAY)) {% if limit is not none -%} LIMIT {{ limit }} diff --git a/core/src/test/java/feast/core/training/BigQueryDatasetTemplaterTest.java b/core/src/test/java/feast/core/training/BigQueryDatasetTemplaterTest.java index d8b4595d3e..114dcec539 100644 --- a/core/src/test/java/feast/core/training/BigQueryDatasetTemplaterTest.java +++ b/core/src/test/java/feast/core/training/BigQueryDatasetTemplaterTest.java @@ -96,10 +96,11 @@ public void shouldPassCorrectArgumentToTemplateEngine() { Timestamps.fromSeconds(Instant.parse("2019-01-01T00:00:00.00Z").getEpochSecond()); int limit = 100; String featureId = "myentity.feature1"; + String featureName = "feature1"; String tableId = "project.dataset.myentity"; when(featureInfoRespository.findAllById(any(List.class))) - .thenReturn(Collections.singletonList(createFeatureInfo(featureId, tableId))); + .thenReturn(Collections.singletonList(createFeatureInfo(featureId, featureName, tableId))); FeatureSet fs = FeatureSet.newBuilder() @@ -123,7 +124,7 @@ public void shouldPassCorrectArgumentToTemplateEngine() { Features features = (Features) actualContext.get("feature_set"); assertThat(features.getColumns().size(), equalTo(1)); - assertThat(features.getColumns().get(0), equalTo(featureId.replace(".", "_"))); + assertThat(features.getColumns().get(0), equalTo(featureName)); assertThat(features.getTableId(), equalTo(tableId)); } @@ -131,14 +132,17 @@ public void shouldPassCorrectArgumentToTemplateEngine() { public void shouldRenderCorrectQuery1() throws Exception { String tableId1 = "project.dataset.myentity"; String featureId1 = "myentity.feature1"; + String featureName1 = "feature1"; String featureId2 = "myentity.feature2"; + String featureName2 = "feature2"; - FeatureInfo featureInfo1 = createFeatureInfo(featureId1, tableId1); - FeatureInfo featureInfo2 = createFeatureInfo(featureId2, tableId1); + FeatureInfo featureInfo1 = createFeatureInfo(featureId1, featureName1, tableId1); + FeatureInfo featureInfo2 = createFeatureInfo(featureId2, featureName2, tableId1); String tableId2 = "project.dataset.myentity"; String featureId3 = "myentity.feature3"; - FeatureInfo featureInfo3 = createFeatureInfo(featureId3, tableId2); + String featureName3 = "feature3"; + FeatureInfo featureInfo3 = createFeatureInfo(featureId3, featureName3, tableId2); when(featureInfoRespository.findAllById(any(List.class))) .thenReturn(Arrays.asList(featureInfo1, featureInfo2, featureInfo3)); @@ -166,8 +170,9 @@ public void shouldRenderCorrectQuery2() throws Exception { String tableId = "project.dataset.myentity"; String featureId = "myentity.feature1"; + String featureName = "feature1"; - featureInfos.add(createFeatureInfo(featureId, tableId)); + featureInfos.add(createFeatureInfo(featureId, featureName, tableId)); featureIds.add(featureId); when(featureInfoRespository.findAllById(any(List.class))).thenReturn(featureInfos); @@ -197,7 +202,7 @@ private void checkExpectedQuery(String query, String pathToExpQuery) throws Exce assertThat(query, equalTo(expQuery)); } - private FeatureInfo createFeatureInfo(String id, String tableId) { + private FeatureInfo createFeatureInfo(String featureId, String featureName, String tableId) { StorageSpec storageSpec = StorageSpec.newBuilder() .setId("BQ") @@ -209,11 +214,12 @@ private FeatureInfo createFeatureInfo(String id, String tableId) { FeatureSpec fs = FeatureSpec.newBuilder() - .setId(id) + .setId(featureId) + .setName(featureName) .setDataStores(DataStores.newBuilder().setWarehouse(DataStore.newBuilder().setId("BQ"))) .build(); - EntitySpec entitySpec = EntitySpec.newBuilder().setName(id.split("\\.")[0]).build(); + EntitySpec entitySpec = EntitySpec.newBuilder().setName(featureId.split("\\.")[0]).build(); EntityInfo entityInfo = new EntityInfo(entitySpec); return new FeatureInfo(fs, entityInfo, null, storageInfo, null); } diff --git a/core/src/test/resources/sql/expQuery1.sql b/core/src/test/resources/sql/expQuery1.sql index 97d21c3fcb..019f5d9cbd 100644 --- a/core/src/test/resources/sql/expQuery1.sql +++ b/core/src/test/resources/sql/expQuery1.sql @@ -1,11 +1,11 @@ SELECT - project.dataset.myentity.id, - project.dataset.myentity.event_timestamp , - myentity_feature1, - myentity_feature2, - myentity_feature3 + id, + event_timestamp, + feature1, + feature2, + feature3 FROM - project.dataset.myentity + `project.dataset.myentity` WHERE event_timestamp >= TIMESTAMP("2018-01-02") AND event_timestamp <= TIMESTAMP(DATETIME_ADD("2018-01-30", INTERVAL 1 DAY)) LIMIT 100 \ No newline at end of file diff --git a/core/src/test/resources/sql/expQuery2.sql b/core/src/test/resources/sql/expQuery2.sql index 9b9f74a81d..e9b212d20a 100644 --- a/core/src/test/resources/sql/expQuery2.sql +++ b/core/src/test/resources/sql/expQuery2.sql @@ -1,9 +1,9 @@ SELECT - project.dataset.myentity.id, - project.dataset.myentity.event_timestamp , - myentity_feature1 + id, + event_timestamp, + feature1 FROM - project.dataset.myentity + `project.dataset.myentity` WHERE event_timestamp >= TIMESTAMP("2018-01-02") AND event_timestamp <= TIMESTAMP(DATETIME_ADD("2018-01-30", INTERVAL 1 DAY)) LIMIT 1000 \ No newline at end of file