
Commit 0ba3c76

updated support for gpt-4, pixtral, gemini and molmo
1 parent 51647db commit 0ba3c76


3 files changed, +138 -68 lines changed


models/model_loader.py

+23 -20
@@ -4,9 +4,15 @@
 import torch
 from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
 from transformers import MllamaForConditionalGeneration
+from vllm.sampling_params import SamplingParams
 from transformers import AutoModelForCausalLM
+import google.generativeai as genai
 from vllm import LLM
-from vllm.sampling_params import SamplingParams
+
+from dotenv import load_dotenv
+
+# Load environment variables from .env file
+load_dotenv()
 
 from logger import get_logger
 
@@ -21,8 +27,8 @@ def detect_device():
     """
     if torch.cuda.is_available():
         return 'cuda'
-    # elif torch.backends.mps.is_available():
-    #     return 'mps'
+    elif torch.backends.mps.is_available():
+        return 'mps'
     else:
         return 'cpu'
 
@@ -51,21 +57,13 @@ def load_model(model_choice):
 
     elif model_choice == 'gemini':
         # Load Gemini model
-        import genai
-        genai.api_key = os.environ.get('GENAI_API_KEY')
-        model = genai.GenerativeModel(model_name="gemini-1.5-pro")
-        processor = None
-        _model_cache[model_choice] = (model, processor)
-        logger.info("Gemini model loaded and cached.")
-        return _model_cache[model_choice]
+        api_key = os.getenv("GOOGLE_API_KEY")
+        if not api_key:
+            raise ValueError("GOOGLE_API_KEY not found in .env file")
+        genai.configure(api_key=api_key)
+        model = genai.GenerativeModel('gemini-1.5-flash-002') # Use the appropriate model name
+        return model, None
 
-    elif model_choice == 'gpt4':
-        # Load OpenAI GPT-4 model
-        import openai
-        openai.api_key = os.environ.get('OPENAI_API_KEY')
-        _model_cache[model_choice] = (None, None)
-        logger.info("GPT-4 model ready and cached.")
-        return _model_cache[model_choice]
 
     elif model_choice == 'llama-vision':
         # Load Llama-Vision model
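
Note: with the switch to python-dotenv, API keys are now read from a .env file at the project root (loaded by load_dotenv() above) rather than from pre-exported environment variables. A minimal sketch of such a file, with placeholder values; GOOGLE_API_KEY is read here, and OPENAI_API_KEY is read in models/responder.py below:

GOOGLE_API_KEY=<your Google AI Studio key>
OPENAI_API_KEY=<your OpenAI key>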
@@ -85,21 +83,26 @@ def load_model(model_choice):
 
     elif model_choice == "pixtral":
         device = detect_device()
-        model = LLM(model="mistralai/Pixtral-12B-2409", tokenizer_mode="mistral")
+        model = LLM(model="mistralai/Pixtral-12B-2409",
+                    tokenizer_mode="mistral",
+                    gpu_memory_utilization=0.8, # Increase GPU memory utilization
+                    max_model_len=8192, # Decrease max model length
+                    dtype="float16", # Use half precision to save memory
+                    trust_remote_code=True)
         sampling_params = SamplingParams(max_tokens=1024)
         _model_cache[model_choice] = (model, sampling_params, device)
         return _model_cache[model_choice]
 
     elif model_choice == "molmo":
         device = detect_device()
         processor = AutoProcessor.from_pretrained(
-            'allenai/Molmo-7B-D-0924',
+            'allenai/MolmoE-1B-0924',
            trust_remote_code=True,
            torch_dtype='auto',
            device_map='auto'
         )
         model = AutoModelForCausalLM.from_pretrained(
-            'allenai/Molmo-7B-D-0924',
+            'allenai/MolmoE-1B-0924',
            trust_remote_code=True,
            torch_dtype='auto',
            device_map='auto'
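
Note: the return shape of load_model() now differs per backend. A minimal sketch of how callers unpack it, based on the return values visible in this diff (the Gemini branch returns directly and, unlike Pixtral, is not stored in _model_cache):

gemini_model, _ = load_model('gemini')                           # (GenerativeModel, None)
pixtral_model, sampling_params, device = load_model('pixtral')   # cached (model, sampling_params, device)
molmo_model, molmo_processor, device = load_model('molmo')       # (model, processor, device), as unpacked in responder.py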

models/responder.py

+112 -47
@@ -2,13 +2,24 @@
 
 from models.model_loader import load_model
 from transformers import GenerationConfig
+import google.generativeai as genai
+from dotenv import load_dotenv
 from logger import get_logger
+from openai import OpenAI
 from PIL import Image
 import torch
+import base64
 import os
+import io
+
 
 logger = get_logger(__name__)
 
+# Function to encode the image
+def encode_image(image_path):
+    with open(image_path, "rb") as image_file:
+        return base64.b64encode(image_file.read()).decode('utf-8')
+
 def generate_response(images, query, session_id, resized_height=280, resized_width=280, model_choice='qwen'):
     """
     Generates a response using the selected model based on the query and images.
5667
)
5768
logger.info("Response generated using Qwen model.")
5869
return output_text[0]
70+
5971
elif model_choice == 'gemini':
60-
from models.gemini_responder import generate_gemini_response
61-
model, processor = load_model('gemini')
62-
response = generate_gemini_response(images, query, model, processor)
63-
logger.info("Response generated using Gemini model.")
64-
return response
72+
73+
model, _ = load_model('gemini')
74+
75+
try:
76+
content = []
77+
content.append(query) # Add the text query first
78+
79+
for img_path in images:
80+
full_path = os.path.join('static', img_path)
81+
if os.path.exists(full_path):
82+
try:
83+
img = Image.open(full_path)
84+
content.append(img)
85+
except Exception as e:
86+
logger.error(f"Error opening image {full_path}: {e}")
87+
else:
88+
logger.warning(f"Image file not found: {full_path}")
89+
90+
if len(content) == 1: # Only text, no images
91+
return "No images could be loaded for analysis."
92+
93+
response = model.generate_content(content)
94+
95+
if response.text:
96+
generated_text = response.text
97+
logger.info("Response generated using Gemini model.")
98+
return generated_text
99+
else:
100+
return "The Gemini model did not generate any text response."
101+
102+
except Exception as e:
103+
logger.error(f"Error in Gemini processing: {str(e)}", exc_info=True)
104+
return f"An error occurred while processing the images: {str(e)}"
105+
65106
elif model_choice == 'gpt4':
66-
from models.gpt4_responder import generate_gpt4_response
67-
model, _ = load_model('gpt4')
68-
response = generate_gpt4_response(images, query, model)
69-
logger.info("Response generated using GPT-4 model.")
70-
return response
107+
api_key = os.getenv("OPENAI_API_KEY")
108+
client = OpenAI(api_key=api_key)
109+
110+
try:
111+
content = [{"type": "text", "text": query}]
112+
113+
for img_path in images:
114+
full_path = os.path.join('static', img_path)
115+
if os.path.exists(full_path):
116+
base64_image = encode_image(full_path)
117+
content.append({
118+
"type": "image_url",
119+
"image_url": {
120+
"url": f"data:image/jpeg;base64,{base64_image}"
121+
}
122+
})
123+
else:
124+
logger.warning(f"Image file not found: {full_path}")
125+
126+
if len(content) == 1: # Only text, no images
127+
return "No images could be loaded for analysis."
128+
129+
response = client.chat.completions.create(
130+
model="gpt-4o", # Make sure to use the correct model name
131+
messages=[
132+
{
133+
"role": "user",
134+
"content": content
135+
}
136+
],
137+
max_tokens=1024
138+
)
139+
140+
generated_text = response.choices[0].message.content
141+
logger.info("Response generated using GPT-4 model.")
142+
return generated_text
143+
144+
except Exception as e:
145+
logger.error(f"Error in GPT-4 processing: {str(e)}", exc_info=True)
146+
return f"An error occurred while processing the images: {str(e)}"
71147

72148
elif model_choice == 'llama-vision':
73149
# Load model, processor, and device
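
Both new branches resolve image paths relative to the static/ directory before encoding them. A minimal call sketch using the signature from the first hunk; the paths, query, and session id are hypothetical example values:

answer = generate_response(
    images=['uploads/doc1/page_1.png'],   # resolved as static/uploads/doc1/page_1.png
    query='Summarize the table on this page.',
    session_id='demo-session',
    model_choice='gpt4'                   # or 'gemini'
)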
@@ -98,20 +174,22 @@ def generate_response(images, query, session_id, resized_height=280, resized_wid
 
         model, sampling_params, device = load_model('pixtral')
 
-        image_urls = []
-        for img in images:
-            # Convert PIL Image to base64
-            buffered = io.BytesIO()
-            img.save(buffered, format="PNG")
-            img_str = base64.b64encode(buffered.getvalue()).decode()
-            image_urls.append(f"data:image/png;base64,{img_str}")
 
+        def image_to_data_url(image_path):
+
+            image_path = os.path.join('static', image_path)
+
+            with open(image_path, "rb") as image_file:
+                encoded_string = base64.b64encode(image_file.read()).decode('utf-8')
+                ext = os.path.splitext(image_path)[1][1:] # Get the file extension
+                return f"data:image/{ext};base64,{encoded_string}"
+
         messages = [
             {
                 "role": "user",
                 "content": [
                     {"type": "text", "text": query},
-                    *[{"type": "image_url", "image_url": {"url": url}} for url in image_urls]
+                    *[{"type": "image_url", "image_url": {"url": image_to_data_url(img_path)}} for i, img_path in enumerate(images) if i<1]
                 ]
             },
         ]
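
The next hunk returns outputs[0].outputs[0].text, so between these two hunks the messages list is presumably passed to vLLM's chat API; that line is unchanged and therefore not shown in the diff. A hedged sketch of that assumed step:

# Assumed unchanged step, not part of this diff: send the chat messages to the Pixtral LLM.
outputs = model.chat(messages, sampling_params=sampling_params)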
@@ -120,10 +198,10 @@ def generate_response(images, query, session_id, resized_height=280, resized_wid
         return outputs[0].outputs[0].text
 
     elif model_choice == "molmo":
-
         model, processor, device = load_model('molmo')
+        model = model.half() # Convert model to half precision
         pil_images = []
-        for img_path in images:
+        for img_path in images[:1]: # Process only the first image for now
             full_path = os.path.join('static', img_path)
             if os.path.exists(full_path):
                 try:
@@ -138,53 +216,40 @@ def generate_response(images, query, session_id, resized_height=280, resized_wid
             return "No images could be loaded for analysis."
 
         try:
-            # Log the types and shapes of the images
-            logger.info(f"Number of images: {len(pil_images)}")
-            logger.info(f"Image types: {[type(img) for img in pil_images]}")
-            logger.info(f"Image sizes: {[img.size for img in pil_images]}")
-
             # Process the images and text
             inputs = processor.process(
                 images=pil_images,
                 text=query
             )
 
-            # Log the keys and shapes of the inputs
-            logger.info(f"Input keys: {inputs.keys()}")
-            for k, v in inputs.items():
-                if isinstance(v, torch.Tensor):
-                    logger.info(f"Input '{k}' shape: {v.shape}, dtype: {v.dtype}, device: {v.device}")
-                else:
-                    logger.info(f"Input '{k}' type: {type(v)}")
-
             # Move inputs to the correct device and make a batch of size 1
-            inputs = {k: v.to(model.device).unsqueeze(0) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
-
-            # Log the updated shapes after moving to device
-            for k, v in inputs.items():
-                if isinstance(v, torch.Tensor):
-                    logger.info(f"Updated input '{k}' shape: {v.shape}, dtype: {v.dtype}, device: {v.device}")
+            # Convert float tensors to half precision, but keep integer tensors as they are
+            inputs = {k: (v.to(device).unsqueeze(0).half() if v.dtype in [torch.float32, torch.float64] else
+                          v.to(device).unsqueeze(0))
+                      if isinstance(v, torch.Tensor) else v
+                      for k, v in inputs.items()}
 
             # Generate output
-            output = model.generate_from_batch(
-                inputs,
-                GenerationConfig(max_new_tokens=200, stop_strings="<|endoftext|>"),
-                tokenizer=processor.tokenizer
-            )
+            with torch.no_grad(): # Disable gradient calculation
+                output = model.generate_from_batch(
+                    inputs,
+                    GenerationConfig(max_new_tokens=200, stop_strings="<|endoftext|>"),
+                    tokenizer=processor.tokenizer
+                )
 
             # Only get generated tokens; decode them to text
             generated_tokens = output[0, inputs['input_ids'].size(1):]
             generated_text = processor.tokenizer.decode(generated_tokens, skip_special_tokens=True)
 
+            return generated_text
+
         except Exception as e:
             logger.error(f"Error in Molmo processing: {str(e)}", exc_info=True)
             return f"An error occurred while processing the images: {str(e)}"
         finally:
             # Close the opened images to free up resources
             for img in pil_images:
-                img.close()
-
-        return generated_text
+                img.close()
     else:
         logger.error(f"Invalid model choice: {model_choice}")
         return "Invalid model selected."

requirements.txt

+3 -1
@@ -8,4 +8,6 @@ docx2pdf
 qwen-vl-utils
 vllm>=0.6.1.post1
 mistral_common>=1.4.1
-einops
+einops
+mistral_common[opencv]
+mistral_common
