Skip to content

Commit

Permalink
test.0.3
Browse files Browse the repository at this point in the history
  • Loading branch information
Nweaver412 committed Jan 21, 2025
1 parent ba9eb1d commit b38c6a0
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 18 deletions.
19 changes: 17 additions & 2 deletions component_config/configSchema.json
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
"required": [
"embedColumn",
"#apiKey",
"model"
"model",
"chunkingEnabled"
],
"properties": {
"model": {
Expand Down Expand Up @@ -50,6 +51,20 @@
},
"description": "Choose the output format for the embeddings",
"propertyOrder": 3
},
"chunkingEnabled": {
"type": "boolean",
"title": "Enable Chunking",
"default": false,
"description": "Enable chunking of input data for embedding.",
"propertyOrder": 4
},
"chunkSize": {
"type": "integer",
"title": "Chunk Size",
"minimum": 1,
"description": "Number of rows to process in a single chunk when chunking is enabled.",
"propertyOrder": 5
}
}
}
}
54 changes: 38 additions & 16 deletions src/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def run(self):
reader = csv.DictReader(input_file)
self._process_rows_csv(reader)
except Exception as e:
raise UserException(f"Error occurred during embedding process: {str(e)}")
raise logging.info(f"Error occurred during embedding process: {str(e)}")

def _process_rows_csv(self, reader):

Expand All @@ -35,13 +35,36 @@ def _process_rows_csv(self, reader):
fieldnames = reader.fieldnames + ['embedding']
writer = csv.DictWriter(output_file, fieldnames=fieldnames)
writer.writeheader()

self.row_count = 0
for row in reader:
self.row_count += 1
text = row[self._configuration.embedColumn]
embedding = self.get_embedding(text)
row['embedding'] = embedding if embedding else "[]"
writer.writerow(row)
if self._configuration.chunkingEnabled:
chunk = []
for row in reader:
self.row_count += 1
text = row[self._configuration.embedColumn]
chunk.append(text)

if len(chunk) == self._configuration.chunkSize:
self._process_chunk(chunk, writer, row)
chunk = []

if chunk:
self._process_chunk(chunk, writer, row)
else:
for row in reader:
self.row_count += 1
text = row[self._configuration.embedColumn]
embedding = self.get_embedding([text])[0]
row['embedding'] = embedding if embedding else "[]"
writer.writerow(row)

def _process_chunk(self, chunk, writer, row_template):
embeddings = self.get_embedding(chunk)

for i, embedding in enumerate(embeddings):
row = row_template.copy()
row['embedding'] = embedding if embedding else "[]"
writer.writerow(row)

def init_configuration(self):
self.validate_configuration_parameters(Configuration.get_dataclass_required_parameters())
Expand All @@ -50,19 +73,19 @@ def init_configuration(self):
def init_openai_client(self):
self.client = OpenAI(api_key=self._configuration.pswd_apiKey)


def get_embedding(self, text, model = 'text-embedding-3-small'):
if not text or not isinstance(text, str) or text.strip() == "":
return []
text = text.replace("\n", " ")
return self.client.embeddings.create(input = [text], model=model).data[0].embedding
def get_embedding(self, texts, model=None):
if not texts or not isinstance(texts, list):
return []
texts = [text.replace("\n", " ") for text in texts if isinstance(text, str) and text.strip()]
model = model or self._configuration.model
response = self.client.embeddings.create(input=texts, model=model)
return [item.embedding for item in response.data]

def _get_input_table(self):
if not self.get_input_tables_definitions():
raise UserException("No input table specified. Please provide one input table in the input mapping!")
if len(self.get_input_tables_definitions()) > 1:
raise UserException("Only one input table is supported")
raise logging.info("Only one input table is supported")
return self.get_input_tables_definitions()[0]

def _get_output_table(self):
Expand All @@ -74,7 +97,6 @@ def _get_output_table(self):

return self.create_out_table_definition(out_table_name)


if __name__ == "__main__":
try:
comp = Component()
Expand Down
2 changes: 2 additions & 0 deletions src/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ class Configuration(ConfigurationBase):
pswd_apiKey: str
model: str
destination: Destination
chunkingEnabled: bool = False
chunkSize: int = 1
outputFormat: str = "csv"

def __post_init__(self):
Expand Down

0 comments on commit b38c6a0

Please sign in to comment.