Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 21 additions & 5 deletions rialto/loader/pyspark_feature_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from datetime import date
from typing import Dict, List, Union

from loguru import logger
from pyspark.sql import DataFrame, SparkSession

from rialto.common import TableReader
Expand Down Expand Up @@ -54,13 +55,25 @@ def __init__(

if isinstance(feature_schema, str):
feature_schema = [feature_schema]
self.feature_schemas = {}
for schema in feature_schema:
self.feature_schemas[schema] = None

self.feature_schemas = feature_schema
self.date_col = date_column
self.metadata = MetadataManager(spark, metadata_schema)

KeyMap = namedtuple("KeyMap", ["df", "key"])

def fetch_schema_tables(self, schema) -> List[str]:
    """
    List the names of every table registered under the given schema.

    :param schema: schema name to inspect via the Spark catalog
    :return: list of table names
    """
    logger.info(f"Fetching tables in schema {schema}")
    catalog_entries = self.spark.catalog.listTables(schema)
    return [entry.name for entry in catalog_entries]

def read_group(self, group: str, information_date: date) -> DataFrame:
"""
Read a feature group by getting the latest partition by date
Expand All @@ -69,12 +82,15 @@ def read_group(self, group: str, information_date: date) -> DataFrame:
:param information_date: partition date
:return: dataframe
"""
selected = []
for schema, tables in self.feature_schemas.items():
if tables is None:
self.feature_schemas[schema] = self.fetch_schema_tables(schema)

for schema in self.feature_schemas:
tables = self.spark.catalog.listTables(schema)
if any(table.name == group for table in tables):
selected = []
for schema, tables in self.feature_schemas.items():
if group in tables:
selected.append(schema)

if len(selected) > 1:
raise ValueError(f"Multiple feature schemas contain table {group}: {selected}.")
elif len(selected) == 0:
Expand Down
21 changes: 17 additions & 4 deletions rialto/metadata/metadata_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from typing import List

from delta.tables import DeltaTable
from loguru import logger

from rialto.metadata.data_classes.feature_metadata import FeatureMetadata
from rialto.metadata.data_classes.group_metadata import GroupMetadata
Expand All @@ -34,11 +35,22 @@ def __init__(self, session, schema_path: str = None):
self.groups = None
self.features = None

self.loaded = False

def _load_metadata(self):
    """
    Lazily read and cache the metadata tables.

    On first call, reads the groups and features tables and caches the
    resulting dataframes; sets ``self.loaded`` so subsequent calls are no-ops.
    (The diff residue of the old per-attribute ``is None`` checks is removed;
    ``self.loaded`` is the single source of truth for load state.)
    """
    if not self.loaded:
        logger.info(f"Loading metadata from {self.groups_path} and {self.features_path}")
        # .cache() keeps the small metadata tables in memory for repeated lookups
        self.groups = self.spark.read.table(self.groups_path).cache()
        self.features = self.spark.read.table(self.features_path).cache()
        self.loaded = True

def _reload_metadata(self):
    """
    Refresh the cached metadata dataframes from storage.

    Unpersists the currently cached groups and features dataframes, then
    re-reads each table and caches the fresh copy.
    """
    logger.info(f"Loading metadata from {self.groups_path} and {self.features_path}")

    # Same unpersist-then-reread sequence for both metadata tables,
    # in the order: groups first, then features.
    for attr_name, table_path in (("groups", self.groups_path), ("features", self.features_path)):
        getattr(self, attr_name).unpersist()
        setattr(self, attr_name, self.spark.read.table(table_path).cache())

def _fetch_group_by_name(self, group_name: str) -> GroupMetadata:
group = self.groups.filter(self.groups.group_name == group_name).collect()
Expand Down Expand Up @@ -92,6 +104,7 @@ def update(
self._load_metadata()
self._add_group(group_md)
self._add_features(features_md, group_md.name)
self._reload_metadata()

def get_feature(self, group_name: str, feature_name: str) -> FeatureMetadata:
"""
Expand Down
1 change: 1 addition & 0 deletions tests/metadata/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,4 +53,5 @@ def mdc(spark):
mdc = MetadataManager(spark)
mdc.groups = spark.createDataFrame(group_base, group_schema)
mdc.features = spark.createDataFrame(feature_base, feature_schema)
mdc.loaded = True
return mdc