Skip to content

Commit 585ad41

Browse files
authored
Revert "Support Vertex Flex API in GeminiModelHandler (#36982)" (#37051)
This reverts commit 72e84ef.
1 parent 2cf0930 commit 585ad41

File tree

2 files changed

+2
-43
lines changed

2 files changed

+2
-43
lines changed

sdks/python/apache_beam/ml/inference/gemini_inference.py

Lines changed: 2 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525

2626
from google import genai
2727
from google.genai import errors
28-
from google.genai.types import HttpOptions
2928
from google.genai.types import Part
3029
from PIL.Image import Image
3130

@@ -109,7 +108,6 @@ def __init__(
109108
api_key: Optional[str] = None,
110109
project: Optional[str] = None,
111110
location: Optional[str] = None,
112-
use_vertex_flex_api: Optional[bool] = False,
113111
*,
114112
min_batch_size: Optional[int] = None,
115113
max_batch_size: Optional[int] = None,
@@ -139,7 +137,6 @@ def __init__(
139137
location: the GCP location (region) to use for Vertex AI requests. Setting this
140138
parameter routes requests to Vertex AI. If this parameter is provided,
141139
project must also be provided and api_key should not be set.
142-
use_vertex_flex_api: if true, use the Vertex Flex API.
143140
min_batch_size: optional. the minimum batch size to use when batching
144141
inputs.
145142
max_batch_size: optional. the maximum batch size to use when batching
@@ -172,8 +169,6 @@ def __init__(
172169
self.location = location
173170
self.use_vertex = True
174171

175-
self.use_vertex_flex_api = use_vertex_flex_api
176-
177172
super().__init__(
178173
namespace='GeminiModelHandler',
179174
retry_filter=_retry_on_appropriate_service_error,
@@ -185,19 +180,8 @@ def create_client(self) -> genai.Client:
185180
provided when the GeminiModelHandler class is instantiated.
186181
"""
187182
if self.use_vertex:
188-
if self.use_vertex_flex_api:
189-
return genai.Client(
190-
vertexai=True,
191-
project=self.project,
192-
location=self.location,
193-
http_options=HttpOptions(
194-
api_version="v1",
195-
headers={"X-Vertex-AI-LLM-Request-Type": "flex"},
196-
# The timeout is specified in milliseconds.
197-
timeout=600000))
198-
else:
199-
return genai.Client(
200-
vertexai=True, project=self.project, location=self.location)
183+
return genai.Client(
184+
vertexai=True, project=self.project, location=self.location)
201185
return genai.Client(api_key=self.api_key)
202186

203187
def request(

sdks/python/apache_beam/ml/inference/gemini_inference_test.py

Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
# pytype: skip-file
1818

1919
import unittest
20-
from unittest import mock
2120

2221
try:
2322
from google.genai import errors
@@ -82,29 +81,5 @@ def test_missing_all_params(self):
8281
)
8382

8483

85-
@mock.patch('apache_beam.ml.inference.gemini_inference.genai.Client')
86-
@mock.patch('apache_beam.ml.inference.gemini_inference.HttpOptions')
87-
class TestGeminiModelHandler(unittest.TestCase):
88-
def test_create_client_with_flex_api(
89-
self, mock_http_options, mock_genai_client):
90-
handler = GeminiModelHandler(
91-
model_name="gemini-pro",
92-
request_fn=generate_from_string,
93-
project="test-project",
94-
location="us-central1",
95-
use_vertex_flex_api=True)
96-
handler.create_client()
97-
mock_http_options.assert_called_with(
98-
api_version="v1",
99-
headers={"X-Vertex-AI-LLM-Request-Type": "flex"},
100-
timeout=600000,
101-
)
102-
mock_genai_client.assert_called_with(
103-
vertexai=True,
104-
project="test-project",
105-
location="us-central1",
106-
http_options=mock_http_options.return_value)
107-
108-
10984
if __name__ == '__main__':
11085
unittest.main()

0 commit comments

Comments
 (0)