Skip to content

Commit

Permalink
Merge pull request #1 from josemcorderoc/dev
Browse files Browse the repository at this point in the history
File datasource support with DuckDB, Adds choropleth map providers matplotlib and plotly
  • Loading branch information
josemcorderoc authored Sep 22, 2024
2 parents a5c401b + 20a7270 commit 3a8d1b9
Show file tree
Hide file tree
Showing 14 changed files with 212 additions and 65 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,5 @@ build/
.git.bfg-report

# Data
*.jsonl
*.jsonl
dev
10 changes: 10 additions & 0 deletions prompt2map/application/core/prompt2map.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from prompt2map.application.retrievers.sql_geo_retriever import SQLGeoRetriever
from prompt2map.interfaces.core.geo_retriever import GeoRetriever
from prompt2map.interfaces.core.map_generator import MapGenerator
from prompt2map.providers.geoduckdb import GeoDuckDB
from prompt2map.providers.openai import OpenAIProvider
from prompt2map.providers.postgres_db import PostgresDB
from prompt2map.types import Map
Expand Down Expand Up @@ -35,6 +36,15 @@ def from_postgis(cls, db_name: str, db_user: str, db_password: str, db_host: str
sql_retrievier = SQLGeoRetriever(db, sql_query_processor=query_processor)
openai_generator = OpenAIMapGenerator(openai_provider)
return cls(retriever=sql_retrievier, generator=openai_generator)

@classmethod
def from_file(cls, table_name: str, file_path: str, embeddings_path: str, descriptions_path: str) -> Self:
    """Build an instance whose data source is a file loaded through GeoDuckDB.

    All four arguments are forwarded unchanged to ``GeoDuckDB``. The same
    OpenAI provider instance is shared by the SQL query processor and the
    map generator.
    """
    file_db = GeoDuckDB(table_name, file_path, embeddings_path, descriptions_path)
    llm = OpenAIProvider()
    return cls(
        retriever=SQLGeoRetriever(
            file_db,
            sql_query_processor=SQLQueryProcessor(file_db, llm),
        ),
        generator=OpenAIMapGenerator(llm),
    )


def generate_map(prompt: str, data_source: str | gpd.GeoDataFrame = None):
Expand Down
6 changes: 3 additions & 3 deletions prompt2map/application/generators/openai_map_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@
from typing import Any, Callable, Optional
import geopandas as gpd

from prompt2map.application.maps.choropleth_map import choropleth_map
from prompt2map.application.maps.bar_chart_map import BarChartMap
from prompt2map.application.maps.choropleth_map import ChoroplethMap
from prompt2map.interfaces.core.map import Map
from prompt2map.interfaces.core.map_generator import MapGenerator
from prompt2map.providers.openai import OpenAIProvider
from prompt2map.types import Map

def get_available_tools(data: gpd.GeoDataFrame) -> list[dict[str, Any]]:
return [
Expand Down Expand Up @@ -60,7 +60,7 @@ def get_available_tools(data: gpd.GeoDataFrame) -> list[dict[str, Any]]:

def create_choropleth_map(data: gpd.GeoDataFrame, title: str, value_column: str) -> Map:
    """Return a folium choropleth of ``data`` coloured by ``value_column``.

    Delegates to ``choropleth_map`` with the ``"folium"`` provider. Note the
    argument order differs: this function takes ``title`` before
    ``value_column``, while ``choropleth_map`` takes them the other way round.
    """
    # TODO check if any processing is needed
    return choropleth_map(data, value_column, title, "folium")

def create_bar_chart_map(data: gpd.GeoDataFrame, value_columns: list[str]) -> Map:
    """Wrap ``data`` in a ``BarChartMap`` over the given value columns."""
    bar_chart = BarChartMap(data=data, value_columns=value_columns)
    return bar_chart
Expand Down
4 changes: 2 additions & 2 deletions prompt2map/application/maps/bar_chart_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import geopandas as gpd
import plotly.express as px

from prompt2map.interfaces.core.map import Map
from prompt2map.types import Map


def average_bounding_boxes(s: gpd.GeoSeries) -> tuple[float, float]:
Expand Down Expand Up @@ -97,7 +97,7 @@ def plot_polygons(polygons):
ax.set_aspect('equal')
plt.show()

class BarChartMap(Map):
class BarChartMap:
def __init__(self, data: gpd.GeoDataFrame, value_columns: list[str], height=500, width=500, colors: Optional[list[str]] = None) -> None:
self.data = data
self.value_columns = value_columns
Expand Down
58 changes: 45 additions & 13 deletions prompt2map/application/maps/choropleth_map.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,48 @@

from typing import Literal, overload
import folium
import geopandas as gpd
import numpy as np
import plotly.graph_objects as go
import matplotlib.figure
import matplotlib.axes

from prompt2map.types import Map

# Typing-only overload: with provider="folium" the result is a folium.Map.
# NOTE(review): the overloads do not declare the `cmap` keyword that the
# implementation accepts, so type checkers will reject calls passing it.
@overload
def choropleth_map(data: gpd.GeoDataFrame, value_column: str, title: str, provider: Literal["folium"]) -> folium.Map: ...

# Typing-only overload: with provider="plotly" the result is a plotly Figure.
@overload
def choropleth_map(data: gpd.GeoDataFrame, value_column: str, title: str, provider: Literal["plotly"]) -> go.Figure: ...

from prompt2map.interfaces.core.map import Map
# Typing-only overload: with provider="matplotlib" the result is a matplotlib
# Axes or Figure.
@overload
def choropleth_map(data: gpd.GeoDataFrame, value_column: str, title: str, provider: Literal["matplotlib"]) -> matplotlib.axes.Axes | matplotlib.figure.Figure: ...

class ChoroplethMap(Map):
    """Interactive choropleth built with ``GeoDataFrame.explore``.

    Only ``data``, ``value_column`` and ``cmap`` are used; the remaining
    constructor arguments are accepted but not read.
    """

    def __init__(self, data: gpd.GeoDataFrame, value_column: str, cmap='viridis', legend_title='Legend', title='Choropleth Map', height=500, width=500) -> None:
        """Store the inputs and render the map coloured by ``value_column``."""
        self.data, self.value_column = data, value_column
        self.fig = self.data.explore(column=value_column, cmap=cmap)

    def show(self) -> None:
        """Open the rendered map in the browser."""
        self.fig.show_in_browser()

    def _repr_html_(self):
        """Return the HTML representation of the rendered map."""
        return self.fig._repr_html_()
def choropleth_map(data: gpd.GeoDataFrame, value_column: str, title: str, provider: Literal["folium", "plotly", "matplotlib"], cmap='viridis') -> Map:
    """Draw a choropleth of ``data`` coloured by ``value_column``.

    Args:
        data: Geometries plus the attribute column to colour by.
        value_column: Name of the column in ``data`` that drives the colour.
        title: Title for the map.
        provider: Plotting backend: ``"folium"``, ``"plotly"`` or
            ``"matplotlib"``.
        cmap: Colour map name; applied by the folium and matplotlib backends.

    Returns:
        The backend's native object: the result of ``GeoDataFrame.explore``
        for folium, a plotly figure for plotly, and the matplotlib axes
        returned by ``GeoDataFrame.plot`` for matplotlib.

    Raises:
        ValueError: If ``provider`` is not one of the supported backends.
    """
    if provider == "folium":
        # NOTE(review): `title` is forwarded to explore() as an extra keyword;
        # confirm that folium actually renders it.
        return data.explore(value_column, cmap=cmap, title=title)

    if provider == "plotly":
        import plotly.express as px

        geom_all = data.geometry.union_all()
        minx, miny, maxx, maxy = geom_all.bounds
        # Presumably converts the extent from degrees to km (~111 km per
        # degree), i.e. assumes lon/lat coordinates -- TODO confirm.
        max_bound = max(abs(maxx - minx), abs(maxy - miny)) * 111
        # Heuristic: larger extents get a smaller zoom level.
        zoom = 13 - np.log(max_bound)
        fig = px.choropleth_mapbox(
            data,
            geojson=data.to_geo_dict(),
            locations=data.index,
            # Colour by the requested column (was hard-coded to one dataset's
            # column name, which ignored `value_column`).
            color=value_column,
            mapbox_style="carto-positron",
            center={"lat": geom_all.centroid.y, "lon": geom_all.centroid.x},
            zoom=zoom,
            opacity=0.8,
            title=title,
        )
        fig.update_geos(fitbounds="locations")
        fig.update_layout(margin={"r": 0, "t": 0, "l": 0, "b": 0})
        return fig

    if provider == "matplotlib":
        # GeoDataFrame.plot forwards unknown keywords to the matplotlib
        # collection, which has no `title` property, so the title is set on
        # the returned axes instead.
        ax = data.plot(column=value_column, legend=True, cmap=cmap)
        ax.set_title(title)
        return ax

    raise ValueError("Invalid provider")
5 changes: 4 additions & 1 deletion prompt2map/application/prompt2sql/sql_query_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,10 @@ def replace_literals(self, query: str) -> str:
self.logger.info(f"Query literals: {query_literals}")
for (table_name, column_name), literal in query_literals.items():
col_type = self.db.get_column_type(table_name, column_name)
if col_type == "text":
if col_type is None:
self.logger.warning(f"Column {column_name} in table {table_name} does not exist.")
continue
elif col_type.lower() in ["text", "varchar"]:
text_embedding = self.embedding.get_embedding(str(literal)).tolist()
# most_similar_literal = self.db.get_most_similar_levenshtein(table_name, column_name, str(literal))
most_similar_literal = self.db.get_most_similar_cosine(table_name, column_name, text_embedding, "emb_openai_small")
Expand Down
38 changes: 14 additions & 24 deletions prompt2map/application/prompt2sql/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,44 +11,34 @@ def is_read_only_query(query: str) -> bool:
return expression is None


def to_geospatial_query(query: str, geom_table: str, geom_col: str, agg_function: str) -> str:
    """Add the geometry column of ``geom_table`` to the SELECT list of ``query``.

    Without a GROUP BY clause the bare ``geom_col`` is appended to the SELECT
    expressions. With a GROUP BY clause the geometry is aggregated with
    ``agg_function`` and exposed under the alias ``geom_col``.

    Args:
        query: SQL text to rewrite.
        geom_table: Name of the table holding the geometry; must appear in
            the query.
        geom_col: Name of the geometry column in ``geom_table``.
        agg_function: SQL aggregate applied to the geometry when the query
            groups rows.

    Returns:
        The rewritten SQL text.

    Raises:
        ValueError: If the query cannot be parsed, does not reference
            ``geom_table``, or has no SELECT clause.
    """
    parsed_query = parse_one(query)
    # Validate before dereferencing; the check used to run after find_all().
    if parsed_query is None:
        raise ValueError(f"Query {query} could not be parsed.")

    tables = list(parsed_query.find_all(exp.Table))
    # Maps alias -> table name ("" is the alias of an unaliased table).
    table_alias = bidict({t.alias: t.name for t in tables})
    table_names = [t.name for t in tables]

    if geom_table not in table_names:
        raise ValueError(f"Geotable {geom_table} not found in query.")

    select = parsed_query.find(exp.Select)
    group_by = parsed_query.find(exp.Group)

    if select is None:
        raise ValueError(f"Query {query} does not contain a SELECT clause.")

    if group_by is None:
        select.expressions.append(exp.Column(this=exp.Identifier(this=geom_col)))
        return parsed_query.sql()

    # NOTE(review): this list excludes geom_table by construction, so the
    # membership test below is always true -- confirm the intended condition.
    group_by_tables = [
        table_alias[e.table]
        for e in group_by.expressions
        if isinstance(e, exp.Column)
        and e.table in table_alias
        and table_alias[e.table] != geom_table
    ]
    if geom_table not in group_by_tables:
        # Qualify only the aggregate's argument. The output alias must stay
        # the bare column name, because "AS alias.col" is not valid SQL.
        qualified_col = geom_col
        if geom_table in table_alias.values() and table_alias.inv[geom_table] != "":
            qualified_col = f"{table_alias.inv[geom_table]}.{geom_col}"
        select.expressions.append(f"{agg_function}({qualified_col}) AS {geom_col}")

    return parsed_query.sql()
16 changes: 10 additions & 6 deletions prompt2map/application/retrievers/sql_geo_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
class SQLGeoRetriever(GeoRetriever):
def __init__(self, db: GeoDatabase, prompt2sql: Optional[Prompt2SQL] = None,
sql_query_processor: Optional[SQLQueryProcessor] = None,
test_db: Optional[GeoDatabase] = None, db_schema: Optional[str] = None) -> None:
test_db: Optional[GeoDatabase] = None, db_schema: Optional[str] = None,
agg_function_sql: str = "ST_Union_Agg") -> None:
self.logger = logging.getLogger(self.__class__.__name__)
self.db = db
self.test_db = test_db
Expand All @@ -24,6 +25,7 @@ def __init__(self, db: GeoDatabase, prompt2sql: Optional[Prompt2SQL] = None,
self.db_schema = db_schema if db_schema else self.db.get_schema()
prompt2sql = LLMPrompt2SQL(self.openai_provider, self.db_schema)
self.prompt2sql = prompt2sql
self.agg_function_sql = agg_function_sql



Expand All @@ -34,26 +36,28 @@ def retrieve(self, query: str) -> gpd.GeoDataFrame:

# validate read only
if not is_read_only_query(sql_query):
raise ValueError(f"Query {sql_query} is not a read-only query.")
self.logger.info(f"Query {sql_query} is a read-only query.")
raise ValueError(f"Query is not a read-only query.")
self.logger.info(f"Query is a read-only query.")

# replace literals
if self.sql_query_processor:
sql_query = self.sql_query_processor.replace_literals(sql_query)
self.logger.info(f"Replaced literals in query. New query:\n{sql_query}")
self.logger.info(f"Replaced literals in query. New query:\n{sql_query}")

# add spatial columns
sql_query = to_geospatial_query(sql_query, {"comuna": "geom"})
geotable_name, geocolumn_name = self.db.get_geo_column()
sql_query = to_geospatial_query(sql_query, geotable_name, geocolumn_name, self.agg_function_sql)
self.logger.info(f"Added spatial columns to query. New query:\n{sql_query}")


# run in test database
if self.test_db:
self.test_db.get_geodata(sql_query)
self.logger.info(f"Query {sql_query} ran in test database.")
self.logger.info(f"Query {sql_query} ran in test database.")

# run in production environment
data = self.db.get_geodata(sql_query)
self.logger.info(f"Query {sql_query} ran in production database.")

self.sql_query = sql_query
return data
6 changes: 3 additions & 3 deletions prompt2map/interfaces/sql/geo_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,6 @@
import geopandas as gpd

class GeoDatabase(Protocol):
def run_query(self, query: str) -> list[dict]:
...

def get_schema(self) -> str:
...

Expand All @@ -22,4 +19,7 @@ def get_most_similar_levenshtein(self, table: str, column: str, text: str) -> st
...

def get_column_type(self, table_name: str, column_name: str) -> Optional[str]:
...

def get_geo_column(self) -> tuple[str, str]:
...
108 changes: 108 additions & 0 deletions prompt2map/providers/geoduckdb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@

import logging
from typing import Any
import duckdb
import geopandas as gpd
import pandas as pd
from prompt2map.interfaces.sql.geo_database import GeoDatabase

class GeoDuckDB(GeoDatabase):
    """GeoDatabase implementation backed by DuckDB tables loaded from files.

    Three tables are created on a fresh ``duckdb.connect()`` connection: the
    main geospatial table (``table_name``), an ``embeddings`` table and a
    ``descriptions`` table. The main table must contain exactly one GEOMETRY
    column.
    """

    def __init__(self, table_name: str, file_path: str, embeddings_path: str, descriptions_path: str) -> None:
        """Load the three files into DuckDB and inspect the main table.

        Args:
            table_name: Name given to the main table.
            file_path: File loaded into the main table. It is also read with
                ``geopandas.read_parquet`` to obtain the CRS.
            embeddings_path: File loaded into the ``embeddings`` table; the
                code reads its ``text`` and ``values`` columns.
            descriptions_path: File loaded into the ``descriptions`` table;
                the code reads its ``column`` and ``description`` columns.

        Raises:
            ValueError: If the main table has zero or several GEOMETRY
                columns, or if the embeddings table is empty.
        """
        self.table_name = table_name
        self.file_path = file_path
        self.embeddings_path = embeddings_path
        self.descriptions_path = descriptions_path

        self.embeddings_table_name = "embeddings"
        self.descriptions_table_name = "descriptions"
        self.logger = logging.getLogger(__name__)
        self.connection = duckdb.connect()

        self.connection.install_extension("spatial")
        self.connection.load_extension("spatial")

        sources = (
            (self.table_name, self.file_path),
            (self.embeddings_table_name, self.embeddings_path),
            (self.descriptions_table_name, self.descriptions_path),
        )
        for name, path in sources:
            self.connection.execute(
                f"CREATE TABLE {name} AS SELECT * FROM {self._quote_literal(path)}"
            )

        # validate that the main table has a geometry column
        metadata = self.get_fields_metadata(self.table_name)
        geometry_columns = metadata[metadata["type"] == "GEOMETRY"]["name"].to_list()

        if len(geometry_columns) == 0:
            raise ValueError(f"No geometry columns found in table {self.table_name}")
        elif len(geometry_columns) > 1:
            raise ValueError(f"Multiple geometry columns found in table {self.table_name}")

        self.geometry_column = geometry_columns[0]
        self.crs = gpd.read_parquet(self.file_path).crs

        length_rows = self.connection.sql(f"""SELECT len(values)
            FROM {self.embeddings_table_name}
            LIMIT 1;""").fetchall()
        if not length_rows:
            # Previously surfaced as a bare IndexError on the empty result.
            raise ValueError(f"Embeddings table {self.embeddings_table_name} is empty")
        self.embedding_length = length_rows[0][0]

        # Fixed-size array type used to cast both sides of the distance call.
        self.embedding_type = f"DOUBLE[{self.embedding_length}]"

    @staticmethod
    def _quote_literal(value: str) -> str:
        """Return ``value`` as a single-quoted SQL string literal."""
        return "'" + value.replace("'", "''") + "'"

    def get_fields_metadata(self, table_name) -> pd.DataFrame:
        """Return ``pragma_table_info`` for ``table_name`` as a DataFrame."""
        return self.connection.sql(
            f"SELECT * FROM pragma_table_info('{table_name}')"
        ).df()

    def get_schema(self) -> str:
        """Return a ``CREATE TABLE`` statement describing the main table.

        Each column is listed with its type; columns that have a non-empty
        entry in the descriptions table get it appended as a SQL comment.
        """
        # The first part is deliberately not an f-string: its `{}` are
        # placeholders for DuckDB's format() function.
        variables_block = self.connection.sql("""
            SELECT string_agg(
                CASE
                    WHEN description IS NULL OR description = '' THEN format('\t{} {}', ti.name, ti.type)
                    ELSE format('\t{} {}, -- {}', ti.name, ti.type, d.description)
                END, '\n')""" + f"""
            FROM pragma_table_info('{self.table_name}') ti
            LEFT JOIN {self.descriptions_table_name} d ON d.column = ti.name
            """).fetchall()[0][0]

        create_table_statement = f"""CREATE TABLE {self.table_name} (\n{variables_block}\n);"""
        return create_table_statement

    def get_geodata(self, query: str) -> gpd.GeoDataFrame:
        """Run ``query`` and return its rows as a GeoDataFrame.

        The query result must contain the main table's geometry column under
        its original name. The geometry is round-tripped through WKT and
        tagged with the CRS read at construction time.
        """
        # `result` is referenced by name in the SQL below: DuckDB resolves it
        # to this local relation, so the variable is not unused.
        result = self.connection.sql(query)
        df = self.connection.sql(f"""SELECT
            * EXCLUDE ({self.geometry_column}),
            ST_AsText({self.geometry_column}) AS wkt_geom
            FROM result""").df()

        geometry = gpd.GeoSeries.from_wkt(df['wkt_geom'], crs=self.crs)
        gdf = gpd.GeoDataFrame(df.drop(columns=["wkt_geom"]), geometry=geometry)
        return gdf

    def get_literals(self, table: str, column: str) -> list[Any]:
        """Return the distinct values of ``column`` in ``table``."""
        rows = self.connection.sql(f"SELECT DISTINCT {column} FROM {table}").fetchall()
        return [row[0] for row in rows]

    def get_most_similar_cosine(self, table: str, column: str, text_embedding: list[float], embedding_suffix: str) -> str:
        """Return the value of ``table.column`` whose embedding is nearest.

        Values are joined to the embeddings table on its ``text`` column and
        ordered by ``array_distance`` to ``text_embedding``.
        ``embedding_suffix`` is accepted but not used here.

        NOTE(review): despite the method name, the ordering uses
        ``array_distance`` rather than a cosine function -- confirm that this
        ranks identically for the embeddings in use.

        Raises:
            ValueError: If the query returns no rows.
        """
        self.logger.info("table: %s, column: %s", table, column)
        query = f"""
            SELECT t.{column}
            FROM {table} t
            LEFT JOIN {self.embeddings_table_name} e ON t.{column} = e.text
            ORDER BY array_distance(e.values::{self.embedding_type}, {text_embedding}::{self.embedding_type})
            LIMIT 1;"""
        self.logger.info("Executing query: %s", query)
        result = self.connection.sql(query).fetchall()
        if len(result) == 0:
            raise ValueError(f"No similar value found in {table}.{column}")
        return result[0][0]

    def get_most_similar_levenshtein(self, table: str, column: str, text: str) -> str:
        """Not supported by this backend."""
        raise NotImplementedError

    def get_column_type(self, table_name: str, column_name: str) -> str | None:
        """Return the type of ``column_name``, or ``None`` if it does not exist."""
        metadata = self.get_fields_metadata(table_name)
        matches = metadata[metadata["name"] == column_name]
        if len(matches) == 0:
            return None
        return matches.iloc[0]["type"]

    def get_geo_column(self) -> tuple[str, str]:
        """Return ``(main table name, geometry column name)``."""
        return self.table_name, self.geometry_column

2 changes: 1 addition & 1 deletion prompt2map/providers/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def get_embedding(self, text: str) -> np.ndarray:


def send_batch_embedding(self, requests, input_file_name: str) -> str:
# Write the requests to a jsonl file7
# Write the requests to a jsonl file
with jsonlines.open(input_file_name, mode='w') as writer:
writer.write_all(requests)

Expand Down
Loading

0 comments on commit 3a8d1b9

Please sign in to comment.