Skip to content

Commit

Permalink
feat: use dynamic batching param (#6203)
Browse files Browse the repository at this point in the history
  • Loading branch information
JoanFM authored Sep 23, 2024
1 parent 434d09e commit 338ac3f
Show file tree
Hide file tree
Showing 9 changed files with 75 additions and 28 deletions.
2 changes: 1 addition & 1 deletion extra-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ mock: test
requests-mock: test
pytest-custom_exit_code: test
black==24.3.0: test
kubernetes>=18.20.0: test
kubernetes>=18.20.0,<31.0.0: test
pytest-kind==22.11.1: test
pytest-lazy-fixture: test
torch: cicd
Expand Down
3 changes: 1 addition & 2 deletions jina/clients/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def __init__(
os.unsetenv('http_proxy')
os.unsetenv('https_proxy')
self._inputs = None
self._inputs_length = None
self._setup_instrumentation(
name=(
self.args.name
Expand Down Expand Up @@ -144,8 +145,6 @@ def _get_requests(
else:
total_docs = None

self._inputs_length = None

if total_docs:
self._inputs_length = max(1, total_docs / _kwargs['request_size'])

Expand Down
3 changes: 3 additions & 0 deletions jina/serve/executors/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,7 @@ def dynamic_batching(
flush_all: bool = False,
custom_metric: Optional[Callable[['DocumentArray'], Union[float, int]]] = None,
use_custom_metric: bool = False,
use_dynamic_batching: bool = True,
):
"""
`@dynamic_batching` defines the dynamic batching behavior of an Executor.
Expand All @@ -438,6 +439,7 @@ def dynamic_batching(
If this is true, `preferred_batch_size` is used as a trigger mechanism.
:param custom_metric: Potential lambda function to measure the "weight" of each request.
:param use_custom_metric: Determines if we need to use the `custom_metric` to determine preferred_batch_size.
:param use_dynamic_batching: Determines if we should apply dynamic batching for this method.
:return: decorated function
"""

Expand Down Expand Up @@ -486,6 +488,7 @@ def _inject_owner_attrs(self, owner, name):
owner.dynamic_batching[fn_name]['flush_all'] = flush_all
owner.dynamic_batching[fn_name]['use_custom_metric'] = use_custom_metric
owner.dynamic_batching[fn_name]['custom_metric'] = custom_metric
owner.dynamic_batching[fn_name]['use_dynamic_batching'] = use_dynamic_batching
setattr(owner, name, self.fn)

def __set_name__(self, owner, name):
Expand Down
5 changes: 3 additions & 2 deletions jina/serve/runtimes/worker/batch_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def __init__(
timeout: int = 10_000,
custom_metric: Optional[Callable[['DocumentArray'], Union[int, float]]] = None,
use_custom_metric: bool = False,
**kwargs,
) -> None:
# To keep old user behavior, we use data lock when flush_all is true and no allow_concurrent
self.func = func
Expand Down Expand Up @@ -285,7 +286,8 @@ def batch(iterable_1, iterable_2, n: Optional[int] = 1, iterable_metrics: Option
sum_from_previous_first_req_idx = 0
for docs_inner_batch, req_idxs in batch(
big_doc_in_batch, requests_idxs_in_batch,
self._preferred_batch_size if not self._flush_all else None, docs_metrics_in_batch if self._custom_metric is not None else None
self._preferred_batch_size if not self._flush_all else None,
docs_metrics_in_batch if self._custom_metric is not None else None
):
involved_requests_min_indx = req_idxs[0]
involved_requests_max_indx = req_idxs[-1]
Expand Down Expand Up @@ -360,7 +362,6 @@ def batch(iterable_1, iterable_2, n: Optional[int] = 1, iterable_metrics: Option
requests_completed_in_batch,
)


async def close(self):
"""Closes the batch queue by flushing pending requests."""
if not self._is_closed:
Expand Down
9 changes: 5 additions & 4 deletions jina/serve/runtimes/worker/request_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,10 +275,11 @@ def _init_batchqueue_dict(self):
)
raise Exception(error_msg)

if key.startswith('/'):
dbatch_endpoints.append((key, dbatch_config))
else:
dbatch_functions.append((key, dbatch_config))
if dbatch_config.get("use_dynamic_batching", True):
if key.startswith('/'):
dbatch_endpoints.append((key, dbatch_config))
else:
dbatch_functions.append((key, dbatch_config))

# Specific endpoint configs take precedence over function configs
for endpoint, dbatch_config in dbatch_endpoints:
Expand Down
49 changes: 46 additions & 3 deletions tests/integration/dynamic_batching/test_dynamic_batching.py
Original file line number Diff line number Diff line change
Expand Up @@ -706,8 +706,8 @@ def foo(self, docs, **kwargs):


@pytest.mark.asyncio
@pytest.mark.parametrize('use_custom_metric', [True])
@pytest.mark.parametrize('flush_all', [True])
@pytest.mark.parametrize('use_custom_metric', [True, False])
@pytest.mark.parametrize('flush_all', [True, False])
async def test_dynamic_batching_custom_metric(use_custom_metric, flush_all):
class DynCustomBatchProcessor(Executor):

Expand All @@ -719,7 +719,9 @@ def foo(self, docs, **kwargs):
for doc in docs:
doc.text = f"{total_len}"

depl = Deployment(uses=DynCustomBatchProcessor, uses_dynamic_batching={'foo': {"preferred_batch_size": 10, "timeout": 2000, "use_custom_metric": use_custom_metric, "flush_all": flush_all}})
depl = Deployment(uses=DynCustomBatchProcessor, uses_dynamic_batching={
'foo': {"preferred_batch_size": 10, "timeout": 2000, "use_custom_metric": use_custom_metric,
"flush_all": flush_all}})
da = DocumentArray([Document(text='aaaaa') for i in range(50)])
with depl:
cl = Client(protocol=depl.protocol, port=depl.port, asyncio=True)
Expand All @@ -733,3 +735,44 @@ def foo(self, docs, **kwargs):
):
res.extend(r)
assert len(res) == 50 # 1 request per input


@pytest.mark.asyncio
@pytest.mark.parametrize('use_dynamic_batching', [True, False])
async def test_use_dynamic_batching(use_dynamic_batching):
class UseDynBatchProcessor(Executor):

@dynamic_batching(preferred_batch_size=10)
@requests(on='/foo')
def foo(self, docs, **kwargs):
print(f'len docs {len(docs)}')
for doc in docs:
doc.text = f"{len(docs)}"

depl = Deployment(uses=UseDynBatchProcessor, uses_dynamic_batching={
'foo': {"preferred_batch_size": 10, "timeout": 2000, "use_dynamic_batching": use_dynamic_batching,
"flush_all": False}})
da = DocumentArray([Document(text='aaaaa') for _ in range(50)])
with depl:
cl = Client(protocol=depl.protocol, port=depl.port, asyncio=True)
res = []
async for r in cl.post(
on='/foo',
inputs=da,
request_size=1,
continue_on_error=True,
results_in_order=True,
):
res.extend(r)
assert len(res) == 50 # 1 request per input
for doc in res:
num_10 = 0
if doc.text == "10":
num_10 += 1
if not use_dynamic_batching:
assert doc.text == "1"

if use_dynamic_batching:
assert num_10 > 0
else:
assert num_10 == 0
19 changes: 10 additions & 9 deletions tests/k8s/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,14 @@ def __init__(self, kind_cluster: KindCluster, logger: JinaLogger) -> None:
self._loaded_images = set()

def _linkerd_install_cmd(
self, kind_cluster: KindCluster, cmd, tool_name: str
self, kind_cluster: KindCluster, cmd, tool_name: str
) -> None:
self._log.info(f'Installing {tool_name} to Cluster...')
kube_out = subprocess.check_output(
(str(kind_cluster.kubectl_path), 'version'),
env=os.environ,
)
self._log.info(f'kuberbetes versions: {kube_out}')
self._log.info(f'kubernetes versions: {kube_out}')

# since we need to pipe to commands and the linkerd output can bee too long
# there is a risk of deadlock and hanging tests: https://docs.python.org/3/library/subprocess.html#popen-objects
Expand Down Expand Up @@ -86,7 +86,7 @@ def _install_linkerd(self, kind_cluster: KindCluster) -> None:
print(f'linkerd check yields {out.decode() if out else "nothing"}')
except subprocess.CalledProcessError as e:
print(
f'linkerd check failed with error code { e.returncode } and output { e.output }, and stderr { e.stderr }'
f'linkerd check failed with error code {e.returncode} and output {e.output}, and stderr {e.stderr}'
)
raise

Expand Down Expand Up @@ -125,16 +125,17 @@ def install_linkerd_smi(self) -> None:
print(f'linkerd check yields {out.decode() if out else "nothing"}')
except subprocess.CalledProcessError as e:
print(
f'linkerd check failed with error code { e.returncode } and output { e.output }'
f'linkerd check failed with error code {e.returncode} and output {e.output}, and stderr {e.stderr}'
)
raise

def _set_kube_config(self):
self._log.info(f'Setting KUBECONFIG to {self._kube_config_path}')
os.environ['KUBECONFIG'] = self._kube_config_path
load_cluster_config()

def load_docker_images(
self, images: List[str], image_tag_map: Dict[str, str]
self, images: List[str], image_tag_map: Dict[str, str]
) -> None:
for image in images:
full_image_name = image + ':' + image_tag_map[image]
Expand Down Expand Up @@ -213,9 +214,9 @@ def load_cluster_config() -> None:

@pytest.fixture
def docker_images(
request: FixtureRequest,
image_name_tag_map: Dict[str, str],
k8s_cluster: KindClusterWrapper,
request: FixtureRequest,
image_name_tag_map: Dict[str, str],
k8s_cluster: KindClusterWrapper,
) -> List[str]:
image_names: List[str] = request.param
k8s_cluster.load_docker_images(image_names, image_name_tag_map)
Expand All @@ -227,7 +228,7 @@ def docker_images(

@contextlib.contextmanager
def shell_portforward(
kubectl_path, pod_or_service, port1, port2, namespace, waiting: float = 1
kubectl_path, pod_or_service, port1, port2, namespace, waiting: float = 1
):
try:
proc = subprocess.Popen(
Expand Down
1 change: 0 additions & 1 deletion tests/k8s/test_k8s_deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from jina.serve.runtimes.servers import BaseServer

from jina import Deployment, Client
from jina.helper import random_port
from tests.k8s.conftest import shell_portforward

cluster.KIND_VERSION = 'v0.11.1'
Expand Down
12 changes: 6 additions & 6 deletions tests/unit/serve/executors/test_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -614,15 +614,15 @@ class C(B):
[
(
dict(preferred_batch_size=4, timeout=5_000),
dict(preferred_batch_size=4, timeout=5_000, flush_all=False, use_custom_metric=False, custom_metric=None),
dict(preferred_batch_size=4, timeout=5_000, flush_all=False, use_custom_metric=False, custom_metric=None, use_dynamic_batching=True),
),
(
dict(preferred_batch_size=4, timeout=5_000, flush_all=True),
dict(preferred_batch_size=4, timeout=5_000, flush_all=True, use_custom_metric=False, custom_metric=None),
dict(preferred_batch_size=4, timeout=5_000, flush_all=True, use_custom_metric=False, custom_metric=None, use_dynamic_batching=True),
),
(
dict(preferred_batch_size=4),
dict(preferred_batch_size=4, timeout=10_000, flush_all=False, use_custom_metric=False, custom_metric=None),
dict(preferred_batch_size=4, timeout=10_000, flush_all=False, use_custom_metric=False, custom_metric=None, use_dynamic_batching=True),
),
],
)
Expand All @@ -641,15 +641,15 @@ def foo(self, docs, **kwargs):
[
(
dict(preferred_batch_size=4, timeout=5_000),
dict(preferred_batch_size=4, timeout=5_000, flush_all=False, use_custom_metric=False, custom_metric=None),
dict(preferred_batch_size=4, timeout=5_000, flush_all=False, use_custom_metric=False, custom_metric=None, use_dynamic_batching=True),
),
(
dict(preferred_batch_size=4, timeout=5_000, flush_all=True),
dict(preferred_batch_size=4, timeout=5_000, flush_all=True, use_custom_metric=False, custom_metric=None),
dict(preferred_batch_size=4, timeout=5_000, flush_all=True, use_custom_metric=False, custom_metric=None, use_dynamic_batching=True),
),
(
dict(preferred_batch_size=4),
dict(preferred_batch_size=4, timeout=10_000, flush_all=False, use_custom_metric=False, custom_metric=None),
dict(preferred_batch_size=4, timeout=10_000, flush_all=False, use_custom_metric=False, custom_metric=None, use_dynamic_batching=True),
),
],
)
Expand Down

0 comments on commit 338ac3f

Please sign in to comment.