from models.model_loader import load_model
from transformers import GenerationConfig
+ import google.generativeai as genai
+ from dotenv import load_dotenv
from logger import get_logger
+ from openai import OpenAI
from PIL import Image
import torch
+ import base64
import os
+ import io
+

logger = get_logger(__name__)

+ # Function to encode the image
+ def encode_image(image_path):
+     with open(image_path, "rb") as image_file:
+         return base64.b64encode(image_file.read()).decode('utf-8')
+
def generate_response(images, query, session_id, resized_height=280, resized_width=280, model_choice='qwen'):
    """
    Generates a response using the selected model based on the query and images.
@@ -56,18 +67,83 @@ def generate_response(images, query, session_id, resized_height=280, resized_wid
        )
        logger.info("Response generated using Qwen model.")
        return output_text[0]
+
    elif model_choice == 'gemini':
-         from models.gemini_responder import generate_gemini_response
-         model, processor = load_model('gemini')
-         response = generate_gemini_response(images, query, model, processor)
-         logger.info("Response generated using Gemini model.")
-         return response
+
+         model, _ = load_model('gemini')
+
+         try:
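+             # Build one multimodal prompt for Gemini: the text query followed by the
+             # PIL images (generate_content accepts such mixed lists).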
+             content = []
+             content.append(query)  # Add the text query first
+
+             for img_path in images:
+                 full_path = os.path.join('static', img_path)
+                 if os.path.exists(full_path):
+                     try:
+                         img = Image.open(full_path)
+                         content.append(img)
+                     except Exception as e:
+                         logger.error(f"Error opening image {full_path}: {e}")
+                 else:
+                     logger.warning(f"Image file not found: {full_path}")
+
+             if len(content) == 1:  # Only text, no images
+                 return "No images could be loaded for analysis."
+
+             response = model.generate_content(content)
+
+             if response.text:
+                 generated_text = response.text
+                 logger.info("Response generated using Gemini model.")
+                 return generated_text
+             else:
+                 return "The Gemini model did not generate any text response."
+
+         except Exception as e:
+             logger.error(f"Error in Gemini processing: {str(e)}", exc_info=True)
+             return f"An error occurred while processing the images: {str(e)}"
+
    elif model_choice == 'gpt4':
-         from models.gpt4_responder import generate_gpt4_response
-         model, _ = load_model('gpt4')
-         response = generate_gpt4_response(images, query, model)
-         logger.info("Response generated using GPT-4 model.")
-         return response
+         api_key = os.getenv("OPENAI_API_KEY")
+         client = OpenAI(api_key=api_key)
+
+         try:
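+             # Chat Completions multimodal content: the text part first, then one
+             # base64 data-URL image_url part per image found on disk.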
+             content = [{"type": "text", "text": query}]
+
+             for img_path in images:
+                 full_path = os.path.join('static', img_path)
+                 if os.path.exists(full_path):
+                     base64_image = encode_image(full_path)
+                     content.append({
+                         "type": "image_url",
+                         "image_url": {
+                             "url": f"data:image/jpeg;base64,{base64_image}"
+                         }
+                     })
+                 else:
+                     logger.warning(f"Image file not found: {full_path}")
+
+             if len(content) == 1:  # Only text, no images
+                 return "No images could be loaded for analysis."
+
+             response = client.chat.completions.create(
+                 model="gpt-4o",  # Make sure to use the correct model name
+                 messages=[
+                     {
+                         "role": "user",
+                         "content": content
+                     }
+                 ],
+                 max_tokens=1024
+             )
+
+             generated_text = response.choices[0].message.content
+             logger.info("Response generated using GPT-4 model.")
+             return generated_text
+
+         except Exception as e:
+             logger.error(f"Error in GPT-4 processing: {str(e)}", exc_info=True)
+             return f"An error occurred while processing the images: {str(e)}"

    elif model_choice == 'llama-vision':
        # Load model, processor, and device
@@ -98,20 +174,22 @@ def generate_response(images, query, session_id, resized_height=280, resized_wid

        model, sampling_params, device = load_model('pixtral')

-         image_urls = []
-         for img in images:
-             # Convert PIL Image to base64
-             buffered = io.BytesIO()
-             img.save(buffered, format="PNG")
-             img_str = base64.b64encode(buffered.getvalue()).decode()
-             image_urls.append(f"data:image/png;base64,{img_str}")

+         def image_to_data_url(image_path):
+
+             image_path = os.path.join('static', image_path)
+
+             with open(image_path, "rb") as image_file:
+                 encoded_string = base64.b64encode(image_file.read()).decode('utf-8')
+             ext = os.path.splitext(image_path)[1][1:]  # Get the file extension
+             return f"data:image/{ext};base64,{encoded_string}"
+
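+         # Build an OpenAI-style chat message: the text query plus image_url parts
+         # whose URLs are the base64 data URLs produced by image_to_data_url above.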
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": query},
-                     *[{"type": "image_url", "image_url": {"url": url}} for url in image_urls]
+                     *[{"type": "image_url", "image_url": {"url": image_to_data_url(img_path)}} for i, img_path in enumerate(images) if i < 1]  # Only the first image is sent for now
                ]
            },
        ]
@@ -120,10 +198,10 @@ def generate_response(images, query, session_id, resized_height=280, resized_wid
        return outputs[0].outputs[0].text

    elif model_choice == "molmo":
-
        model, processor, device = load_model('molmo')
+         model = model.half()  # Convert model to half precision
        pil_images = []
-         for img_path in images:
+         for img_path in images[:1]:  # Process only the first image for now
            full_path = os.path.join('static', img_path)
            if os.path.exists(full_path):
                try:
@@ -138,53 +216,40 @@ def generate_response(images, query, session_id, resized_height=280, resized_wid
            return "No images could be loaded for analysis."

        try:
-             # Log the types and shapes of the images
-             logger.info(f"Number of images: {len(pil_images)}")
-             logger.info(f"Image types: {[type(img) for img in pil_images]}")
-             logger.info(f"Image sizes: {[img.size for img in pil_images]}")
-
            # Process the images and text
            inputs = processor.process(
                images=pil_images,
                text=query
            )

-             # Log the keys and shapes of the inputs
-             logger.info(f"Input keys: {inputs.keys()}")
-             for k, v in inputs.items():
-                 if isinstance(v, torch.Tensor):
-                     logger.info(f"Input '{k}' shape: {v.shape}, dtype: {v.dtype}, device: {v.device}")
-                 else:
-                     logger.info(f"Input '{k}' type: {type(v)}")
-
            # Move inputs to the correct device and make a batch of size 1
-             inputs = {k: v.to(model.device).unsqueeze(0) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
-
-             # Log the updated shapes after moving to device
-             for k, v in inputs.items():
-                 if isinstance(v, torch.Tensor):
-                     logger.info(f"Updated input '{k}' shape: {v.shape}, dtype: {v.dtype}, device: {v.device}")
+             # Convert float tensors to half precision, but keep integer tensors as they are
+             inputs = {k: (v.to(device).unsqueeze(0).half() if v.dtype in [torch.float32, torch.float64] else
+                           v.to(device).unsqueeze(0))
+                       if isinstance(v, torch.Tensor) else v
+                       for k, v in inputs.items()}
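+             # (input_ids and other index tensors must stay integral; only floating-point
+             # tensors are cast to match the half-precision model weights.)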

            # Generate output
-             output = model.generate_from_batch(
-                 inputs,
-                 GenerationConfig(max_new_tokens=200, stop_strings="<|endoftext|>"),
-                 tokenizer=processor.tokenizer
-             )
+             with torch.no_grad():  # Disable gradient calculation
+                 output = model.generate_from_batch(
+                     inputs,
+                     GenerationConfig(max_new_tokens=200, stop_strings="<|endoftext|>"),
+                     tokenizer=processor.tokenizer
+                 )

            # Only get generated tokens; decode them to text
            generated_tokens = output[0, inputs['input_ids'].size(1):]
            generated_text = processor.tokenizer.decode(generated_tokens, skip_special_tokens=True)

+             return generated_text
+
        except Exception as e:
            logger.error(f"Error in Molmo processing: {str(e)}", exc_info=True)
            return f"An error occurred while processing the images: {str(e)}"
        finally:
            # Close the opened images to free up resources
            for img in pil_images:
-                 img.close()
-
-         return generated_text
+                 img.close()
    else:
        logger.error(f"Invalid model choice: {model_choice}")
        return "Invalid model selected."