-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #1 from josemcorderoc/dev
File datasource support with DuckDB, Adds choropleth map providers matplotlib and plotly
- Loading branch information
Showing
14 changed files
with
212 additions
and
65 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,4 +20,5 @@ build/ | |
.git.bfg-report | ||
|
||
# Data | ||
*.jsonl | ||
*.jsonl | ||
dev |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,16 +1,48 @@ | ||
|
||
from typing import Literal, overload | ||
import folium | ||
import geopandas as gpd | ||
import numpy as np | ||
import plotly.graph_objects as go | ||
import matplotlib.figure | ||
import matplotlib.axes | ||
|
||
from prompt2map.types import Map | ||
|
||
@overload
def choropleth_map(
    data: gpd.GeoDataFrame,
    value_column: str,
    title: str,
    provider: Literal["folium"],
    cmap: str = ...,
) -> folium.Map:
    """Typing overload: the ``"folium"`` provider is declared to return a ``folium.Map``.

    ``cmap`` is listed here because the implementation accepts it; without it
    static type checkers reject calls that pass ``cmap=...``.
    """
    ...
|
||
@overload
def choropleth_map(
    data: gpd.GeoDataFrame,
    value_column: str,
    title: str,
    provider: Literal["plotly"],
    cmap: str = ...,
) -> go.Figure:
    """Typing overload: the ``"plotly"`` provider is declared to return a ``go.Figure``.

    ``cmap`` is listed here because the implementation accepts it; without it
    static type checkers reject calls that pass ``cmap=...``.
    """
    ...
|
||
from prompt2map.interfaces.core.map import Map | ||
@overload
def choropleth_map(
    data: gpd.GeoDataFrame,
    value_column: str,
    title: str,
    provider: Literal["matplotlib"],
    cmap: str = ...,
) -> matplotlib.axes.Axes | matplotlib.figure.Figure:
    """Typing overload: the ``"matplotlib"`` provider is declared to return an Axes or a Figure.

    ``cmap`` is listed here because the implementation accepts it; without it
    static type checkers reject calls that pass ``cmap=...``.
    """
    ...
|
||
class ChoroplethMap(Map):
    """Choropleth map built by calling ``explore`` on a GeoDataFrame."""

    def __init__(self, data: gpd.GeoDataFrame, value_column: str, cmap='viridis', legend_title='Legend', title='Choropleth Map', height=500, width=500) -> None:
        """Store the inputs and build the interactive figure right away.

        Args:
            data: GeoDataFrame to draw.
            value_column: Column of ``data`` used to color the geometries.
            cmap: Colormap name handed to ``explore``.
            legend_title: Accepted but not used by this constructor.
            title: Accepted but not used by this constructor.
            height: Accepted but not used by this constructor.
            width: Accepted but not used by this constructor.
        """
        self.value_column = value_column
        self.data = data

        # The figure is created eagerly so that show() and _repr_html_()
        # only have to delegate to it.
        self.fig = self.data.explore(column=value_column, cmap=cmap)

    def show(self) -> None:
        """Open the figure through its ``show_in_browser`` method."""
        figure = self.fig
        figure.show_in_browser()

    def _repr_html_(self):
        """Return the figure's own HTML representation."""
        html = self.fig._repr_html_()
        return html
def choropleth_map(data: gpd.GeoDataFrame, value_column: str, title: str, provider: Literal["folium", "plotly", "matplotlib"], cmap='viridis') -> Map:
    """Draw a choropleth of ``data`` colored by ``value_column``.

    Args:
        data: GeoDataFrame holding the geometries and the column to visualize.
        value_column: Name of the column in ``data`` that drives the fill color.
        title: Title for the map.
        provider: Rendering backend: ``"folium"``, ``"plotly"`` or ``"matplotlib"``.
        cmap: Colormap name. Forwarded to the folium and matplotlib backends;
            the plotly backend keeps its own default color scale.

    Returns:
        The object produced by the selected backend: the result of
        ``data.explore`` for folium, the plotly express figure for plotly, or
        the object returned by ``data.plot`` for matplotlib.

    Raises:
        ValueError: If ``provider`` is not one of the supported names.
    """
    if provider == "folium":
        return data.explore(value_column, cmap=cmap, title=title)

    if provider == "plotly":
        # Imported lazily so plotly express is only loaded when this backend is used.
        import plotly.express as px

        geom_all = data.geometry.union_all()
        minx, miny, maxx, maxy = geom_all.bounds
        # Heuristic zoom derived from the larger side of the bounding box.
        # NOTE(review): the factor 111 presumably converts degrees to km,
        # which assumes lon/lat coordinates -- confirm.
        max_bound = max(abs(maxx - minx), abs(maxy - miny)) * 111
        zoom = 13 - np.log(max_bound)
        center = geom_all.centroid
        fig = px.choropleth_mapbox(
            data,
            geojson=data.to_geo_dict(),
            locations=data.index,
            # Color by the requested column instead of a hard-coded column name.
            color=value_column,
            mapbox_style="carto-positron",
            center={"lat": center.y, "lon": center.x},
            zoom=zoom,
            opacity=0.8,
            title=title,
        )
        # NOTE(review): update_geos targets `geo` subplots; confirm it has any
        # effect on a mapbox figure.
        fig.update_geos(fitbounds="locations")
        fig.update_layout(margin={"r": 0, "t": 0, "l": 0, "b": 0})
        return fig

    if provider == "matplotlib":
        # Extra keyword arguments of GeoDataFrame.plot are forwarded to the
        # matplotlib artists, so the title is set on the returned axes instead
        # of being passed as `title=`.
        ax = data.plot(column=value_column, legend=True, cmap=cmap)
        ax.set_title(title)
        return ax

    raise ValueError("Invalid provider")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
|
||
import logging | ||
from typing import Any | ||
import duckdb | ||
import geopandas as gpd | ||
import pandas as pd | ||
from prompt2map.interfaces.sql.geo_database import GeoDatabase | ||
|
||
class GeoDuckDB(GeoDatabase):
    """GeoDatabase backed by a DuckDB connection loaded from data files.

    Three tables are created on construction: the main geospatial table
    (named by the caller), an ``embeddings`` table and a ``descriptions``
    table. The main table must expose exactly one GEOMETRY column.
    """

    def __init__(self, table_name: str, file_path: str, embeddings_path: str, descriptions_path: str) -> None:
        """Open a DuckDB connection and load the three source files into tables.

        Args:
            table_name: Name given to the main table.
            file_path: Path of the file loaded into the main table. It is also
                read with ``geopandas.read_parquet`` to obtain the CRS.
            embeddings_path: Path of the file loaded into the embeddings table.
            descriptions_path: Path of the file loaded into the descriptions table.

        Raises:
            ValueError: If the main table has zero or several GEOMETRY columns,
                or if the embeddings table has no rows.
        """
        self.table_name = table_name
        self.file_path = file_path
        self.embeddings_path = embeddings_path
        self.descriptions_path = descriptions_path

        self.embeddings_table_name = "embeddings"
        self.descriptions_table_name = "descriptions"
        self.logger = logging.getLogger(__name__)
        self.connection = duckdb.connect()

        self.connection.install_extension("spatial")
        self.connection.load_extension("spatial")

        sources = (
            (self.table_name, self.file_path),
            (self.embeddings_table_name, self.embeddings_path),
            (self.descriptions_table_name, self.descriptions_path),
        )
        for target_table, source_path in sources:
            # Paths are embedded as SQL string literals, so quotes must be escaped.
            self.connection.execute(
                f"CREATE TABLE {target_table} AS SELECT * FROM {self._quote_literal(source_path)}"
            )

        # Validate that the main table has exactly one geometry column.
        metadata = self.get_fields_metadata(self.table_name)
        geometry_columns = metadata[metadata["type"] == "GEOMETRY"]["name"].to_list()

        if not geometry_columns:
            raise ValueError(f"No geometry columns found in table {self.table_name}")
        if len(geometry_columns) > 1:
            raise ValueError(f"Multiple geometry columns found in table {self.table_name}")

        self.geometry_column = geometry_columns[0]
        self.crs = gpd.read_parquet(self.file_path).crs

        # The embedding length is taken from the first stored vector and used
        # to build the fixed-size array type for casts in similarity queries.
        length_rows = self.connection.sql(
            f"""SELECT len(values)
            FROM {self.embeddings_table_name}
            LIMIT 1;"""
        ).fetchall()
        if not length_rows:
            raise ValueError(f"No rows found in table {self.embeddings_table_name}")
        self.embedding_length = length_rows[0][0]

        self.embedding_type = f"DOUBLE[{self.embedding_length}]"

    @staticmethod
    def _quote_literal(value: str) -> str:
        """Return ``value`` as a single-quoted SQL literal with embedded quotes doubled."""
        return "'" + str(value).replace("'", "''") + "'"

    def get_fields_metadata(self, table_name) -> pd.DataFrame:
        """Return the ``pragma_table_info`` rows of ``table_name`` as a DataFrame."""
        return self.connection.sql(
            f"SELECT * FROM pragma_table_info({self._quote_literal(table_name)})"
        ).df()

    def get_schema(self) -> str:
        """Return a ``CREATE TABLE`` style description of the main table.

        Each column becomes one tab-indented line with its name and type; when
        the descriptions table has a non-empty description for the column, it
        is appended as a trailing ``--`` comment.
        """
        # The first fragment is a plain (non-f) string because its `{}`
        # placeholders belong to DuckDB's format(), not to Python.
        # ORDER BY ti.cid pins the column order of the aggregated text.
        variables_block = self.connection.sql("""
            SELECT string_agg(
                CASE
                    WHEN description IS NULL OR description = '' THEN format('\t{} {}', ti.name, ti.type)
                    ELSE format('\t{} {}, -- {}', ti.name, ti.type, d.description)
                END, '\n' ORDER BY ti.cid)""" + f"""
            FROM pragma_table_info({self._quote_literal(self.table_name)}) ti
            LEFT JOIN {self.descriptions_table_name} d ON d.column = ti.name
        """).fetchall()[0][0]

        return f"""CREATE TABLE {self.table_name} (\n{variables_block}\n);"""

    def get_geodata(self, query: str) -> gpd.GeoDataFrame:
        """Run ``query`` and return its rows as a GeoDataFrame.

        The query result must contain the main table's geometry column; it is
        converted to WKT in SQL and parsed back with the CRS read at
        construction time.
        """
        # The local must be named `result`: the follow-up SQL refers to it by
        # that name in its FROM clause.
        result = self.connection.sql(query)
        df = self.connection.sql(
            f"""SELECT
                * EXCLUDE ({self.geometry_column}),
                ST_AsText({self.geometry_column}) AS wkt_geom
            FROM result"""
        ).df()

        geometry = gpd.GeoSeries.from_wkt(df["wkt_geom"], crs=self.crs)
        return gpd.GeoDataFrame(df.drop(columns=["wkt_geom"]), geometry=geometry)

    def get_literals(self, table: str, column: str) -> list[Any]:
        """Return the distinct values of ``column`` in ``table``."""
        rows = self.connection.sql(f"SELECT DISTINCT {column} FROM {table}").fetchall()
        return [row[0] for row in rows]

    def get_most_similar_cosine(self, table: str, column: str, text_embedding: list[float], embedding_suffix: str) -> str:
        """Return the value of ``table.column`` whose embedding is closest to ``text_embedding``.

        Args:
            table: Table holding the candidate values.
            column: Column holding the candidate values; joined against the
                ``text`` column of the embeddings table.
            text_embedding: Query embedding, interpolated into the SQL as a list literal.
            embedding_suffix: Not used by this implementation.

        Raises:
            ValueError: If the query returns no rows.

        NOTE(review): candidates are ordered by ``array_distance`` although the
        method name says cosine -- confirm the intended metric.
        """
        self.logger.info("table: %s, column: %s", table, column)
        query = f"""
            SELECT t.{column}
            FROM {table} t
            LEFT JOIN {self.embeddings_table_name} e ON t.{column} = e.text
            ORDER BY array_distance(e.values::{self.embedding_type}, {text_embedding}::{self.embedding_type})
            LIMIT 1;"""
        self.logger.info("Executing query: %s", query)
        result = self.connection.sql(query).fetchall()
        if not result:
            raise ValueError(f"No similar value found in {table}.{column}")
        return result[0][0]

    def get_most_similar_levenshtein(self, table: str, column: str, text: str) -> str:
        """Not implemented for this backend; always raises ``NotImplementedError``."""
        raise NotImplementedError

    def get_column_type(self, table_name: str, column_name: str) -> str | None:
        """Return the declared type of ``column_name`` in ``table_name``, or None if absent."""
        metadata = self.get_fields_metadata(table_name)
        matches = metadata[metadata["name"] == column_name]
        if matches.empty:
            return None
        return matches.iloc[0]["type"]

    def get_geo_column(self) -> tuple[str, str]:
        """Return ``(main table name, geometry column name)``."""
        return self.table_name, self.geometry_column
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.