Skip to content

Commit

Permalink
Merge pull request #2 from keboola/fix-running-version
Browse files (browse the repository at this point in the history)
Fix running version
  • Loading branch information
Nweaver412 authored Jan 16, 2025
2 parents d4a0a5f + 10258cb commit ba9eb1d
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 75 deletions.
7 changes: 3 additions & 4 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
dataconf==3.2.0
dataconf==3.3.0
freezegun==1.2.2
keboola.component==1.4.4
keboola.utils==1.1.0
lancedb==0.13.0
mock==4.0.3
openai==1.44.1
pandas==2.2.2
openai==1.59.7
pandas==2.2.3
pyarrow==15.0.0
91 changes: 20 additions & 71 deletions src/component.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
import csv
import logging
import os
import shutil
import zipfile
import lancedb

import pyarrow as pa
import pandas as pd
Expand All @@ -13,6 +9,7 @@
from configuration import Configuration

from openai import OpenAI

class Component(ComponentBase):
def __init__(self):
super().__init__()
Expand All @@ -21,28 +18,18 @@ def __init__(self):

def run(self):
self.init_configuration()
self.init_client()
self.init_openai_client()

try:
input_table = self._get_input_table()
with open(input_table.full_path, 'r', encoding='utf-8') as input_file:
reader = csv.DictReader(input_file)
if self._configuration.outputFormat == 'lance':
lance_dir, table = self._initialize_lance_output(reader.fieldnames)
self._process_rows_lance(reader, table)
elif self._configuration.outputFormat == 'csv':
self._process_rows_csv(reader)
self._process_rows_csv(reader)
except Exception as e:
raise UserException(f"Error occurred during embedding process: {str(e)}")

def _initialize_lance_output(self, fieldnames):
lance_dir = os.path.join(self.tables_out_path, 'lance_db')
os.makedirs(lance_dir, exist_ok=True)
db = lancedb.connect(lance_dir)
schema = self._get_lance_schema(fieldnames)
table = db.create_table("embeddings", schema=schema, mode="overwrite")
return lance_dir, table

def _process_rows_csv(self, reader):

output_table = self._get_output_table()
with open(output_table.full_path, 'w', encoding='utf-8', newline='') as output_file:
fieldnames = reader.fieldnames + ['embedding']
Expand All @@ -53,78 +40,40 @@ def _process_rows_csv(self, reader):
self.row_count += 1
text = row[self._configuration.embedColumn]
embedding = self.get_embedding(text)
row['embedding'] = embedding
row['embedding'] = embedding if embedding else "[]"
writer.writerow(row)

def _process_rows_lance(self, reader, table, lance_dir):
data = []
self.row_count = 0
for row in reader:
self.row_count += 1
text = row[self._configuration.embedColumn]
embedding = self.get_embedding(text)
lance_row = {**row, 'embedding': embedding}
data.append(lance_row)
if self.row_count % 1000 == 0:
table.add(data)
data = []
if data:
table.add(data)
self._finalize_lance_output(lance_dir)



def init_configuration(self):
self.validate_configuration_parameters(Configuration.get_dataclass_required_parameters())
self._configuration: Configuration = Configuration.load_from_dict(self.configuration.parameters)

def init_client(self):
def init_openai_client(self):
self.client = OpenAI(api_key=self._configuration.pswd_apiKey)

def get_embedding(self, text):
try:
response = self.client.embeddings.create(input=[text], model=self._configuration.model)
return response.data[0].embedding
except Exception as e:
raise UserException(f"Error getting embedding: {str(e)}")


def get_embedding(self, text, model = 'text-embedding-3-small'):
if not text or not isinstance(text, str) or text.strip() == "":
return []
text = text.replace("\n", " ")
return self.client.embeddings.create(input = [text], model=model).data[0].embedding


def _get_input_table(self):
if not self.get_input_tables_definitions():
raise UserException("No input table specified. Please provide one input table in the input mapping!")
if len(self.get_input_tables_definitions()) > 1:
raise UserException("Only one input table is supported")
return self.get_input_tables_definitions()[0]
def _get_output_table(self):

def _get_output_table(self):
destination_config = self.configuration.parameters['destination']
if not (out_table_name := destination_config.get("output_table_name")):
out_table_name = f"app-embed-lancedb.csv"
else:
out_table_name = f"{out_table_name}.csv"

return self.create_out_table_definition(out_table_name)

def _get_lance_schema(self, fieldnames):
schema = pa.schema([
(name, pa.string()) for name in fieldnames
] + [('embedding', pa.list_(pa.float32()))])
return schema

def _finalize_lance_output(self, lance_dir):
print("Zipping the Lance directory")
try:
zip_path = os.path.join(self.files_out_path, 'embeddings_lance.zip')
with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
for root, dirs, files in os.walk(lance_dir):
for file in files:
file_path = os.path.join(root, file)
arcname = os.path.relpath(file_path, lance_dir)
zipf.write(file_path, arcname)
print(f"Successfully zipped Lance directory to {zip_path}")
# Remove the original Lance directory
shutil.rmtree(lance_dir)
except Exception as e:
print(f"Error zipping Lance directory: {e}")
raise


if __name__ == "__main__":
try:
Expand All @@ -135,4 +84,4 @@ def _finalize_lance_output(self, lance_dir):
exit(1)
except Exception as exc:
logging.exception(exc)
exit(2)
exit(2)

0 comments on commit ba9eb1d

Please sign in to comment.