Skip to content

Commit

Permalink
Visualize sequential data with line plot (#500)
Browse files Browse the repository at this point in the history
  • Loading branch information
lajohn4747 authored Nov 9, 2023
1 parent cea4dac commit 87790f2
Show file tree
Hide file tree
Showing 3 changed files with 442 additions and 2 deletions.
2 changes: 2 additions & 0 deletions sdmetrics/reports/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ class PlotConfig:
DATACEBO_GREEN = '#01E0C9'
DATACEBO_BLUE = '#03AFF1'
BACKGROUND_COLOR = '#F5F5F8'
DATACEBO_DARK_TRANSPARENT = 'rgba(0, 0, 54, 0.25)'
DATACEBO_GREEN_TRANSPARENT = 'rgba(1, 224, 201, 0.25)'
FONT_SIZE = 18


Expand Down
184 changes: 184 additions & 0 deletions sdmetrics/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pandas as pd
import plotly.express as px
import plotly.figure_factory as ff
import plotly.graph_objects as go
import plotly.io as pio
from pandas.api.types import is_datetime64_dtype

Expand Down Expand Up @@ -512,3 +513,186 @@ def get_column_pair_plot(real_data, synthetic_data, column_names, plot_type=None
return _generate_heatmap_plot(all_data, columns)

return _generate_box_plot(all_data, columns)


def _generate_line_plot(real_data, synthetic_data, x_axis, y_axis, marker, annotations=None):
    """Generate a line plot of the real and synthetic data separated by a marker column.

    Args:
        real_data (pandas.DataFrame):
            The real table data.
        synthetic_data (pandas.DataFrame):
            The synthetic table data.
        x_axis (str):
            The column name to be used as the x-axis of the graph.
        y_axis (str):
            The column name to be used as the y-axis of the graph.
        marker (str):
            The column used to define separate line sequences. Its values are
            mapped to colors; ``'Real'`` and ``'Synthetic'`` get the DataCebo
            dark and green colors respectively.
        annotations (None or dict):
            Dict object that describes additional information to be presented
            in the graph. It is forwarded to ``fig.add_annotation``; a falsy
            value adds no annotation.

    Returns:
        plotly.graph_objects._figure.Figure

    Raises:
        ValueError:
            If the ``x_axis`` or ``y_axis`` column of the combined data is
            neither numerical nor datetime.
    """
    all_data = pd.concat([real_data, synthetic_data], axis=0, ignore_index=True)

    # Both axes must be plottable on a continuous scale. The x-axis is checked
    # first so that its error takes precedence when both columns are invalid.
    for column, label in ((x_axis, 'Sequence Index'), (y_axis, 'Column Name')):
        values = all_data[column]
        if not (is_datetime(values) or pd.api.types.is_numeric_dtype(values)):
            raise ValueError(
                f"{label} '{column}' must contain numerical or datetime values only")

    fig = px.line(
        all_data,
        x=x_axis,
        y=y_axis,
        color=marker,
        color_discrete_map={
            'Real': PlotConfig.DATACEBO_DARK,
            'Synthetic': PlotConfig.DATACEBO_GREEN,
        },
    )
    if annotations:
        fig.add_annotation(annotations)

    # 'sequence_index' is the placeholder column name used when the x-axis is
    # just the row position, so show a friendlier axis title for it.
    if x_axis == 'sequence_index':
        fig.update_xaxes(title_text='Sequence Position')

    fig.update_layout(
        title_text=f"Real vs Synthetic Data for column: '{y_axis}'",
        plot_bgcolor=PlotConfig.BACKGROUND_COLOR,
        font={'size': PlotConfig.FONT_SIZE},
    )

    # Add min-max shading. Each frame is checked on its own so that a frame
    # lacking 'min'/'max' columns is skipped instead of raising a KeyError.
    shaded_bands = (
        ('Real', real_data, PlotConfig.DATACEBO_DARK_TRANSPARENT),
        ('Synthetic', synthetic_data, PlotConfig.DATACEBO_GREEN_TRANSPARENT),
    )
    for label, data, color in shaded_bands:
        if 'min' not in data or 'max' not in data:
            continue

        fig.add_trace(
            go.Scatter(
                name=f'{label}-Min',
                x=data[x_axis],
                y=data['min'],
                hoverinfo='skip',
                marker={'color': color},
                showlegend=False,
                mode='lines',
            )
        )
        # The max trace must be added immediately after the matching min
        # trace: fill='tonexty' shades the area between this trace and the
        # one added just before it.
        fig.add_trace(
            go.Scatter(
                name=f'{label}-Max',
                x=data[x_axis],
                y=data['max'],
                hoverinfo='skip',
                marker={'color': color},
                showlegend=False,
                mode='lines',
                fill='tonexty',
                fillcolor=color,
            )
        )

    return fig


def _aggregate_by_sequence_index(data, x_axis, column_name):
    """Collapse ``data`` to one row per ``x_axis`` value with mean/min/max of a column.

    Args:
        data (pandas.DataFrame):
            The table data to aggregate.
        x_axis (str):
            The column to group by.
        column_name (str):
            The column whose mean, min and max are computed per group.

    Returns:
        pandas.DataFrame:
            The aggregated data with a single level of column labels. The mean
            is stored under ``column_name`` and the extremes under ``'min'``
            and ``'max'``, which are the names ``_generate_line_plot`` looks
            for when drawing the min-max shading.
    """
    aggregated = data.groupby(x_axis, as_index=False).agg(
        {
            x_axis: 'first',
            column_name: ['mean', 'min', 'max']
        }
    ).rename(columns={'mean': column_name, 'first': x_axis})

    # The aggregation yields two-level column labels (source column, statistic).
    # After the rename above the statistic level carries the final names, so
    # the outer level can be dropped.
    aggregated.columns = aggregated.columns.droplevel(0)
    return aggregated


def get_column_line_plot(real_data, synthetic_data, column_name, metadata=None):
    """Return a line plot of the real and synthetic data.

    Args:
        real_data (pandas.DataFrame):
            The real table data.
        synthetic_data (pandas.DataFrame):
            The synthetic table data.
        column_name (str):
            The column name to be used as the y-axis of the graph.
        metadata (dict or None):
            TimeSeries metadata dict. If it has a ``'sequence_index'`` entry,
            that column is used as the x-axis; if it also has a
            ``'sequence_key'`` entry, the data is aggregated per x-axis value
            into mean, min and max. If not passed (or if it has no
            ``'sequence_index'``), the graph will use raw indices to build the
            graph and only separate the sequences into real and synthetic
            plots. Defaults to ``None``.

    Returns:
        plotly.graph_objects._figure.Figure
    """
    if metadata is None:
        metadata = {}

    real_column = real_data[column_name]
    synthetic_column = synthetic_data[column_name]

    missing_data_real = get_missing_percentage(real_column)
    missing_data_synthetic = get_missing_percentage(synthetic_column)
    show_missing_values = missing_data_real > 0 or missing_data_synthetic > 0

    # Only annotate the figure when at least one side has missing values.
    annotations = None if not show_missing_values else {
        'xref': 'paper',
        'yref': 'paper',
        'x': 1.0,
        'y': 1.05,
        'showarrow': False,
        'text': (
            f'*Missing Values: Real Data ({missing_data_real}%), '
            f'Synthetic Data ({missing_data_synthetic}%)'
        ),
    }

    # Work on copies so the helper columns added below never leak into the
    # caller's dataframes.
    r_data = real_data.copy()
    s_data = synthetic_data.copy()

    # Check for sequence index to determine the x-axis values
    x_axis = 'sequence_index'
    y_axis = column_name
    if 'sequence_index' in metadata:
        x_axis = metadata['sequence_index']
        if 'sequence_key' in metadata:
            # Only the presence of 'sequence_key' matters here: rows are
            # grouped by the sequence index column, not by the key itself.
            r_data = _aggregate_by_sequence_index(r_data, x_axis, column_name)
            s_data = _aggregate_by_sequence_index(s_data, x_axis, column_name)
    else:
        # No sequence index available: fall back to each table's own index.
        r_data['sequence_index'] = r_data.index
        s_data['sequence_index'] = s_data.index

    # Flag each row with its origin so the plot can draw one line per source.
    marker_name = 'Data'
    r_data[marker_name] = 'Real'
    s_data[marker_name] = 'Synthetic'

    # Generate plot
    return _generate_line_plot(
        real_data=r_data,
        synthetic_data=s_data,
        x_axis=x_axis,
        y_axis=y_axis,
        marker=marker_name,
        annotations=annotations
    )
Loading

0 comments on commit 87790f2

Please sign in to comment.