-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathsql2gpt.py
68 lines (52 loc) · 2.12 KB
/
sql2gpt.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
from typing import Union
import fire
from sqlalchemy import create_engine, MetaData, Table, inspect
import config
from util import get_db_type, is_uri
def get_schemas(database_url):
    """Describe every table of a database as a one-line schema string.

    Creates a SQLAlchemy engine for ``database_url``, asks the inspector for
    the table names, reflects each table and renders it as
    ``table_name(column_name column_type, ...)``. Progress messages are
    printed to stdout along the way.

    Args:
        database_url: Value handed to ``sqlalchemy.create_engine``.

    Returns:
        A list with one schema string per table, in the order the inspector
        reports the table names.
    """
    print("Connecting to database...")
    engine = create_engine(database_url)
    try:
        meta = MetaData()
        print("Inspecting database...")
        inspector = inspect(engine)
        schemas = []
        for table_name in inspector.get_table_names():
            # Reflect the table so its column names and types are available.
            table = Table(table_name, meta, autoload_with=engine)
            columns = ", ".join(
                f"{column.name} {column.type}" for column in table.columns
            )
            schemas.append(f"{table_name}({columns})")
    finally:
        # Only plain strings leave this function, so the engine's connection
        # pool can be released here, even when reflection fails part-way.
        engine.dispose()
    print("Done")
    print('-'*72)
    return schemas
class SQL2GPT:
    """Command set that prints database schemas and GPT prompts built from them.

    The class is handed to ``fire.Fire`` in ``main``, so its public methods
    act as the sub-commands of the command-line tool.
    """

    def __init__(self):
        """Load the saved configuration and set up the prompt template."""
        # Used below via ``.get(name, None)`` to look up a stored database URI.
        self.config = config.load()
        # Two positional placeholders: the database type and the schema text.
        self.prompt = (
            "I have a {} SQL database. I am going to give you the schema in "
            "the following format: table_name(column_name column_type, "
            "column_name column_type, ...), followed by END. After END, you "
            "will be given instructions on how to use the schema information. "
            "Here is the schema:\n\n{}\nEND"
        )

    def get_uri(self, i) -> Union[str, None]:
        """Resolve ``i`` to a database URI.

        Args:
            i: Either a database URI (as judged by ``is_uri``) or a name to
                look up in the loaded configuration.

        Returns:
            The URI, or ``None`` (after printing a hint) when ``i`` is not a
            URI and the configuration has no non-empty entry for it.
        """
        database_url = i
        if not is_uri(database_url):
            database_url = self.config.get(database_url, None)
            if not database_url:
                print(
                    "Database not found in URL. Add with `sql2gpt add <name> <database_uri>`.")
                return None
        return database_url

    def print_schema(self, database_url):
        """Print the schema of every table, one table per line.

        Args:
            database_url: A database URI or a configured database name.
        """
        database_url = self.get_uri(database_url)
        if database_url is None:
            # get_uri already reported the problem; stop instead of passing
            # None on to get_schemas and failing with an unrelated traceback.
            return
        schemas = get_schemas(database_url)
        print("\n".join(schemas))

    def get_prompt(self, database_url):
        """Print the prompt template filled with the database type and schema.

        Args:
            database_url: A database URI or a configured database name.
        """
        database_url = self.get_uri(database_url)
        if database_url is None:
            # get_uri already reported the problem; nothing to build a prompt
            # from.
            return
        db_type = get_db_type(database_url)
        schemas = get_schemas(database_url)
        print(self.prompt.format(db_type, "\n".join(schemas)))

    def add(self, name: str, database_uri: str):
        """Store ``database_uri`` under ``name`` via ``config.add``.

        Args:
            name: Name to register the database under.
            database_uri: URI to associate with that name.
        """
        config.add(name, database_uri)
        print("Database added successfully.")
def main():
    """Start the command-line interface by handing ``SQL2GPT`` to ``fire.Fire``."""
    fire.Fire(component=SQL2GPT)
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()