Skip to content

Commit

Permalink
3D scatterplots and bug fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
cbbcbail committed Jul 31, 2024
1 parent 3d215c3 commit 8c3f9dd
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 22 deletions.
4 changes: 1 addition & 3 deletions flexibleSubsetSelection/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,8 +159,7 @@ def greedySwap(dataset, lossFunction, subsetSize, minLoss=0,
loss (float): The loss value of the final subset
"""
if verbose:
print(f"Solving for a subset of size {subsetSize} with "
f"{lossFunction.objectives} objective.")
print(f"Solving for a subset of size {subsetSize}.")
iterations = 0

# select random starting subset
Expand Down Expand Up @@ -492,7 +491,6 @@ def optimizeDistribution(dataset, lossFunction, environment, subsetSize,

return z.value.astype(int), problem.value


def sinkhorn(dataset, lossFunction, distanceMatrix, subsetSize, environment, lambdaReg=0.1, verbose=False):
datasetLength = dataset.size[0]

Expand Down
64 changes: 45 additions & 19 deletions flexibleSubsetSelection/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

# Third party
import matplotlib
from matplotlib.colors import to_rgb, to_hex
import matplotlib.pyplot as plt

import numpy as np
import pandas as pd
import seaborn as sns
Expand Down Expand Up @@ -113,6 +115,12 @@ def onPick(event, color):
line.zorder = 3
line._axes.figure.canvas.draw_idle()

def initializePane3D(ax, color):
"""Initialize the color of the background panes with hex color for 3D."""
rgb = to_rgb(color)
ax.xaxis.set_pane_color(to_hex([min(1, c + (1 - c)*0.05) for c in rgb]))
ax.yaxis.set_pane_color(to_hex([max(0, c*0.95) for c in rgb]))
ax.zaxis.set_pane_color(rgb)

# --- Error Markers ------------------------------------------------------------

Expand Down Expand Up @@ -206,25 +214,43 @@ def scatter(ax, color, dataset=None, subset=None, features=(0, 1),

if dataset is None and subset is None:
raise ValueError("no dataset or subset specified")
if dataset is not None:
sns.scatterplot(data = dataset.data,
x = features[0],
y = features[1],
color = color.palette["green"],
ax = ax,
zorder=3,
**parameters)

if subset is not None:
sns.scatterplot(data = subset.data,
x = features[0],
y = features[1],
color = color.palette["darkGreen"],
ax = ax,
zorder=4,
**parameters)

def parallelCoordinates(ax, dataset, color, subset=None, dataLinewidth=0.5,

if len(features) == 3:
if dataset is not None:
ax.scatter(dataset.data[features[0]],
dataset.data[features[1]],
dataset.data[features[2]],
color=color.palette["green"],
zorder=3,
**parameters)

if subset is not None:
ax.scatter(subset.data[features[0]],
subset.data[features[1]],
subset.data[features[2]],
color=color.palette["darkGreen"],
zorder=4,
**parameters)
initializePane3D(ax, color["grey"])
else:
if dataset is not None:
sns.scatterplot(data = dataset.data,
x = features[0],
y = features[1],
color = color.palette["green"],
ax = ax,
zorder=3,
**parameters)
if subset is not None:
sns.scatterplot(data = subset.data,
x = features[0],
y = features[1],
color = color.palette["darkGreen"],
ax = ax,
zorder=4,
**parameters)

def parallelCoordinates(ax, color, dataset=None, subset=None, dataLinewidth=0.5,
subsetLinewidth=1.5, **parameters):
"""
Plot a parallel coordinates chart of dataset on ax
Expand Down
2 changes: 2 additions & 0 deletions flexibleSubsetSelection/sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,8 @@ def __init__(self, data: ArrayLike = None, randTypes: str | list = None,
})
else:
self.data = generate.randomData(randTypes, size, interval, seed)
if features is not None:
self.data.columns = features
self.size = size
else:
raise ValueError("No data or random generation method specified.")
Expand Down

0 comments on commit 8c3f9dd

Please sign in to comment.