Skip to content

Commit

Permalink
Get scatter color from trace; don't use /api/charts for large scatter…
Browse files Browse the repository at this point in the history
… plots (#91)

* WIP: get scatter color from trace

* Don't use /api/charts for large scatter plot thumbnails

* clean up

---------

Co-authored-by: Caleb Kaiser <[email protected]>
  • Loading branch information
dsblank and Caleb Kaiser authored Apr 10, 2023
1 parent 7ab11cb commit 3428c99
Show file tree
Hide file tree
Showing 6 changed files with 119 additions and 43 deletions.
26 changes: 21 additions & 5 deletions backend/kangas/datatypes/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import json

from .base import Asset
from .utils import flatten, get_file_extension, is_valid_file_path
from .utils import flatten, get_color, get_file_extension, is_valid_file_path


class Embedding(Asset):
Expand All @@ -27,6 +27,7 @@ class Embedding(Asset):
def __init__(
self,
embedding=None,
label=None,
file_name=None,
metadata=None,
source=None,
Expand All @@ -36,6 +37,7 @@ def __init__(
if unserialize:
return
if self.source is not None:
# FIXME: this is for images, not others
self._log_metadata(
filename=self.source,
extension=get_file_extension(self.source),
Expand All @@ -44,23 +46,36 @@ def __init__(
self._log_metadata(**metadata)
return

if label:
color = get_color(label)
else:
color = None

self.metadata["label"] = label
self.metadata["color"] = color

if file_name:
if is_valid_file_path(file_name):
with open(file_name, "rb") as io_object:
self.asset_data = io_object.read()
with open(file_name, "r") as io_object:
self.asset_data = json.dumps(
{"vector": io_object.read(), "label": label, "color": color}
)
self.metadata["extension"] = get_file_extension(file_name)
self.metadata["filename"] = file_name
else:
raise ValueError("file not found: %r" % file_name)
else:
self.asset_data = json.dumps(embedding)
self.asset_data = json.dumps(
{"vector": embedding, "label": label, "color": color}
)
if metadata:
self.metadata.update(metadata)

@classmethod
def get_statistics(cls, datagrid, col_name, field_name):
from sklearn.decomposition import IncrementalPCA

# FIXME: compute min and max of eigenspace
minimum = None
maximum = None
avg = None
Expand All @@ -79,7 +94,8 @@ def get_statistics(cls, datagrid, col_name, field_name):
field_name=field_name
)
):
vectors = json.loads(row[1])
embedding = json.loads(row[1])
vectors = embedding["vector"]
vector = flatten(vectors)
# FIXME: could scale them here; leave to user for now
batch.append(vector)
Expand Down
15 changes: 14 additions & 1 deletion backend/kangas/server/flask_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -636,12 +636,25 @@ def get_embeddings_as_pca():
column_value = request.args.get("columnValue")
group_by = request.args.get("groupBy")
where_expr = request.args.get("whereExpr")
# if thumbnail, need these:
thumbnail = request.args.get("thumbnail", "false") == "true"
height = int(request.args.get("height", "116"))
width = int(request.args.get("width", "0"))

if ensure_datagrid_path(dgid):
pca_data = select_pca_data_task.apply(
args=(dgid, asset_id, column_name, column_value, group_by, where_expr)
).get()
return pca_data
if thumbnail:
image = generate_chart_image_task.apply(
args=("scatter", pca_data, width, height)
).get()
response = make_response(image)
response.headers.add("Cache-Control", "max-age=604800")
response.headers.add("Content-type", "image/png")
return response
else:
return pca_data
else:
return error(404)

Expand Down
38 changes: 28 additions & 10 deletions backend/kangas/server/queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -2059,21 +2059,27 @@ def select_pca_data(dgid, asset_id, column_name, column_value, group_by, where_e

pca_eigen_vectors = metadata[column_name]["other"]["pca_eigen_vectors"]
pca_mean = metadata[column_name]["other"]["pca_mean"]
color = get_color(column_name)
default_color = get_color(column_name)

pca = PCA()
pca.components_ = np.array(pca_eigen_vectors)
pca.mean_ = np.array(pca_mean)

traces = []
if asset_id:
asset_data = select_asset(dgid, asset_id)
vector = pca.transform([json.loads(asset_data)])
asset_data_raw = select_asset(dgid, asset_id)
asset_data = json.loads(asset_data_raw)
vector = pca.transform([asset_data["vector"]])
if asset_data["color"]:
color = asset_data["color"]
else:
color = default_color

traces.append(
{
"x": [round(vector[0][0], 3)],
"y": [round(vector[0][1], 3)],
"color": [color],
"type": "scatter",
"mode": "markers",
"marker": {"size": 12, "color": color},
Expand Down Expand Up @@ -2152,21 +2158,29 @@ def select_pca_data(dgid, asset_id, column_name, column_value, group_by, where_e

xs = []
ys = []
colors = []
for asset_data_row in all_asset_data:
asset_data = asset_data_row[0]
# FIXME: can transform all at once
vector = pca.transform([json.loads(asset_data)])
asset_data_raw = asset_data_row[0]
asset_data = json.loads(asset_data_raw)
vector = asset_data["vector"]
if asset_data["color"]:
color = asset_data["color"]
else:
color = default_color

xs.append(round(vector[0][0], 3))
ys.append(round(vector[0][1], 3))
# FIXME: can transform all at once
eigen_vector = pca.transform([vector])
xs.append(round(eigen_vector[0][0], 3))
ys.append(round(eigen_vector[0][1], 3))
colors.append(color)

traces.append(
{
"x": xs,
"y": ys,
"type": "scatter",
"mode": "markers",
"marker": {"size": 3, "color": color},
"marker": {"size": 3, "color": colors},
}
)

Expand Down Expand Up @@ -2460,7 +2474,7 @@ def generate_chart_image(chart_type, data, width, height):
span_y = max_y - min_y

radius = trace["marker"]["size"] / 2
color = trace["marker"]["color"]
colors = trace["marker"]["color"]
margin = 5

total_width = width - margin * 2
Expand All @@ -2476,6 +2490,10 @@ def generate_chart_image(chart_type, data, width, height):
)

for count, [x, y] in enumerate(zip(trace["x"], trace["y"])):
if isinstance(colors, list):
color = colors[count]
else:
color = colors
drawing.ellipse(
[
margin + (total_width * (x - min_x) / span_x) - radius,
Expand Down
5 changes: 0 additions & 5 deletions frontend/app/cells/charts/histogram/Histogram.js
Original file line number Diff line number Diff line change
@@ -1,15 +1,10 @@
import fetchHistogram from "../../../../lib/fetchHistogram"
import HistogramClient from "./HistogramClient";
import classNames from 'classnames/bind';
import styles from '../Charts.module.scss'

const cx = classNames.bind(styles);

const Histogram = async ({ value, expanded, ssr }) => {
const ssrData = ssr ? await fetchHistogram(value, ssr) : false;

return <HistogramClient expanded={expanded} value={value} ssrData={ssrData} />;

}

export default Histogram;
59 changes: 42 additions & 17 deletions frontend/app/cells/embedding/EmbeddingCellClient.js
Original file line number Diff line number Diff line change
Expand Up @@ -88,34 +88,47 @@ const EmbeddingClient = ({ value, expanded, query, columnName, ssrData }) => {
}, [columnName]);


useEffect(() => {
if (!value || ssrData || !query) return;

const queryParams = useMemo(() => {
if (!query?.groupBy) {
fetchEmbeddingsAsPCA({dgid: query?.dgid, timestamp: query?.timestamp, columnName, assetId: value?.assetId}).then(res => {
setResponse(res);
});
return {
dgid: query?.dgid,
timestamp: query?.timestamp,
columnName,
assetId: value?.assetId
};
} else {
fetchEmbeddingsAsPCA({dgid: query?.dgid, timestamp: query?.timestamp, columnName, columnValue: value?.columnValue,
groupBy: value?.groupBy, whereExpr: value?.whereExpr}).then(res => {
setResponse(res);
});
return {
dgid: query?.dgid,
timestamp: query?.timestamp,
columnName,
columnValue: value?.columnValue,
groupBy: value?.groupBy,
whereExpr: value?.whereExpr
};
}
}, [value, query, ssrData]);
}, [query, value, columnName]);

useEffect(() => {
if (ssrData || !queryParams) return;

fetchEmbeddingsAsPCA(queryParams).then(res => {
setResponse(res);
});
}, [ssrData, queryParams]);


const queryString = useMemo(() => {
if (!data) return;
if (!query?.dgid) return;

return new URLSearchParams(
Object.fromEntries(
Object.entries({
chartType: 'scatter',
data: JSON.stringify(data)
...queryParams,
thumbnail: true
}).filter(([k, v]) => typeof(v) !== 'undefined' && v !== null)
)
).toString();
}, [data]);
}, [queryParams]);


if (!data || data?.error) {
Expand All @@ -124,9 +137,21 @@ const EmbeddingClient = ({ value, expanded, query, columnName, ssrData }) => {

if (!expanded) {
if (!query?.groupBy) {
return (<img src={`${config.rootPath}api/charts?${queryString}`} loading="lazy" className={cx(['chart-thumbnail', 'embedding'])} />);
return (
<img
src={`${config.rootPath}api/embeddings-as-pca?${queryString}`}
loading="lazy"
className={cx(['chart-thumbnail', 'embedding'])}
/>
);
} else {
return (<img src={`${config.rootPath}api/charts?${queryString}`} loading="lazy" className={cx(['chart-thumbnail', 'embedding-grouped'])} />);
return (
<img
src={`${config.rootPath}api/embeddings-as-pca?${queryString}`}
loading="lazy"
className={cx(['chart-thumbnail', 'embedding-grouped'])}
/>
);
}
}

Expand Down
19 changes: 14 additions & 5 deletions frontend/pages/api/embeddings-as-pca.js
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import stream, { Stream } from 'stream';
import config from '../../config';

const handler = async (req, res) => {
Expand All @@ -10,12 +11,20 @@ const handler = async (req, res) => {
);

const result = await fetch(
`${config.apiUrl}embeddings-as-pca?${query.toString()}`,
`${config.apiUrl}embeddings-as-pca?${query.toString()}`,
{ next: { revalidate: 10000 } }
);

const json = await result.json();
res.send(json);
}

if (!req.query.thumbnail) {
const json = await result.json();
res.send(json);
} else {
const image = await result.body;
const passthrough = new Stream.PassThrough();
stream.pipeline(image, passthrough, (err) => err ? console.error(err) : null);
res.setHeader('Cache-Control', 'max-age=604800')
passthrough.pipe(res);
}
};

export default handler;

0 comments on commit 3428c99

Please sign in to comment.