12 changes: 6 additions & 6 deletions README.md
@@ -29,7 +29,7 @@ Matrix is a library for fast, scalable, and easy-to-use LLM-generation engine, f

Matrix runs on top of a [Ray](https://github.com/ray-project/ray) cluster for scalability. Cluster resources are acquired from [Slurm](https://slurm.schedmd.com/documentation.html) or locally through [submitit](https://github.com/facebookincubator/submitit). Matrix has the following main features:

**Large scale inference** for maintstream opensourced and proprietary LLMs
**Large scale inference** for mainstream opensourced and proprietary LLMs
- Hugging Face LLMs via seamless integration with [vLLM](https://github.com/vllm-project/vllm) and [SGLang](https://github.com/sgl-project/sglang). Native multi-node inference support.
- Azure OpenAI, SageMaker, Gemini models with Proxy server

@@ -97,7 +97,7 @@ matrix deploy_applications --applications "[{'model_name': 'meta-llama/Llama-3.1
matrix check_health --app_name 8B
```

- Shudown ray cluster
- Shutdown ray cluster
```bash
matrix stop_cluster
```
@@ -138,7 +138,7 @@ vLLM Engine [Arguments](https://docs.vllm.ai/en/latest/serving/engine_args.html)
* `name`: the default app_name.
* `model_size`: template to apply when the model is loaded from a directory, such as 8B, 70B, or 405B; templates are defined in the llm_config.py file.
* `max_ongoing_requests`: the max concurrent requests to each replica.
* `min_replia` and `max_replica`: the num of replicas ranges auto-scaled based on num of Ray workers.
* `min_replica` and `max_replica`: the num of replicas ranges auto-scaled based on num of Ray workers.
* `use_grpc`: enable grpc by adding `{'use_grpc': 'true'}`.
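
For illustration only, here is one way these options might be combined in a single deployment command; the dictionary keys are taken from the parameter list above and the model name is reused from earlier examples, but treat the exact combination as an assumption rather than verified CLI behavior.

```bash
# Hypothetical invocation combining the parameters documented above.
matrix deploy_applications --applications "[{'model_name': 'meta-llama/Llama-3.1-8B-Instruct', 'name': '8B', 'min_replica': 1, 'max_replica': 4, 'max_ongoing_requests': 64, 'use_grpc': 'true'}]"
```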

### OpenAI Azure Model
@@ -152,7 +152,7 @@ matrix deploy_applications --applications "[{'api_version': \"$AZURE_API_VERSION
- Note: no GPU is required; in start_workers, you can add `--slurm "{'gpus_per_node': 0}"`

```bash
matrix deploy_applications --applications "[{'app_type': 'gemini', 'name': "gemini", 'api_key': \"$GOOGLE_API_KEY\", 'model_name': 'gemini-2.0-flash'}]"
matrix deploy_applications --applications "[{'app_type': 'gemini', 'name': 'gemini', 'api_key': \"$GOOGLE_API_KEY\", 'model_name': 'gemini-2.0-flash'}]"
```

### Deepseek R1
@@ -161,7 +161,7 @@ vLLM >=0.8.3 supports DS R1. An alternative backend is sglang.
# install sglang
pip install fair-matrix[sglang_045]

matrix deploy_applications --applications "[{'model_name': 'deepseek-ai/DeepSeek-R1', 'pipeline-parallel-size': 2, 'app_type': sglang_llm, 'name': 'r1'}]"
matrix deploy_applications --applications "[{'model_name': 'deepseek-ai/DeepSeek-R1', 'pipeline-parallel-size': 2, 'app_type': 'sglang_llm', 'name': 'r1'}]"
```
### Llama 4
```bash
@@ -326,7 +326,7 @@ python -m matrix.data_pipeline.generate.vllm_generate $ray_head:$client_server_p

## Peer-to-peer

Peer-to-peer framework avoids the single orchestration botttleneck and supports diverse synthetic data generaion tasks. More details are in [here](matrix/agents/README.md).
Peer-to-peer framework avoids the single orchestration bottleneck and supports diverse synthetic data generation tasks. More details are in [here](matrix/agents/README.md).

---

11 changes: 11 additions & 0 deletions matrix/app_server/app_api.py
@@ -27,6 +27,7 @@
get_app_type,
get_yaml_for_deployment,
is_sglang_app,
sort_apps_by_gpu_requirements,
write_yaml_file,
)
from matrix.client.endpoint_cache import EndpointCache
@@ -102,6 +103,16 @@ def deploy(
f"Invalid action '{action}', expected one of {[a.value for a in Action]}"
)
if action in [Action.ADD, Action.REPLACE]:
# Sort applications by GPU requirements (largest first) to minimize fragmentation
# This helps prevent resource fragmentation by deploying models requiring more GPUs
# first, leaving smaller contiguous blocks for models requiring fewer GPUs.
# See issue #111: https://github.com/facebookresearch/matrix/issues/111
if applications and len(applications) > 1:
logger.info(
"Sorting applications by GPU requirements (largest first) to minimize fragmentation"
)
applications = sort_apps_by_gpu_requirements(applications)
Contributor comment: not sure if sorting makes a difference?

An existing deployment could partially occupy a node, which prevents a later deployment from getting resources without removing the old one. Not sure whether Ray also does a poor job of compaction for multiple models in the same deployment.
for app in applications or []:
if str(app.get("model_name", "")).startswith("s3://"):
cache_dir = os.environ.get(
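To make the reviewer's question above concrete, the following toy sketch (not part of the PR, and not Ray's actual scheduler) shows how placement order can matter when small replicas are spread across nodes before a large model is placed; the 8-GPU node size, the most-free-GPUs placement policy, and the replica mix are all illustrative assumptions.

```python
# Toy illustration only: two 8-GPU nodes, each replica placed on the node with
# the most free GPUs (a SPREAD-like policy). Not Ray's real placement logic.
from typing import List, Optional


def place(replica_gpus: List[int], nodes: Optional[List[int]] = None) -> List[int]:
    """Return per-node free GPUs after placing each replica; raise if one cannot fit."""
    free = list(nodes) if nodes is not None else [8, 8]
    for need in replica_gpus:
        target = max(range(len(free)), key=lambda i: free[i])  # node with most free GPUs
        if free[target] < need:
            raise RuntimeError(f"no node has {need} free GPUs left")
        free[target] -= need
    return free


small_first = [1] * 8 + [8]                     # eight 1-GPU replicas, then one 8-GPU model
large_first = sorted(small_first, reverse=True)

print(place(large_first))                       # [0, 0]: the 8-GPU model gets a whole node
try:
    place(small_first)
except RuntimeError as err:
    print(err)                                  # 1-GPU replicas split 4/4; no node can host the 8-GPU model
```

Under the largest-first order the 8-GPU model claims a whole node before the small replicas spread out, which is the effect the sorting above is aiming for.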
94 changes: 94 additions & 0 deletions matrix/app_server/deploy_utils.py
@@ -289,6 +289,100 @@ def update_vllm_app_params(app: Dict[str, Union[str, int]]):
return app


def get_gpu_requirements_per_replica(app: Dict[str, Union[str, int]]) -> int:
"""
Calculate the number of GPUs required per replica for an application.

For LLM models, this is determined by tensor_parallel_size.
For other app types, returns 0 (no GPU requirement) or 1 if specified.

Args:
app: Application configuration dictionary

Returns:
Number of GPUs required per replica
"""
app_type = app.get("app_type", "llm")

# For LLM models, GPU requirement is tensor_parallel_size
if app_type in ["llm", "sglang_llm", "fastgen"]:
# Get tensor_parallel_size from app config (check both formats)
tensor_parallel = app.get("tensor-parallel-size") or app.get("tensor_parallel_size")
if tensor_parallel is None:
# Try to get from model defaults
model_name = str(app.get("model_name", ""))
default_params = llm_model_default_parameters.get(model_name)
if default_params:
tensor_parallel = default_params.get("tensor-parallel-size", 1)
else:
# Default to 1 if we can't determine
tensor_parallel = 1
return int(tensor_parallel)

# For vision models, typically 1 GPU
elif app_type in ["perception_encoder", "optical_flow"]:
return 1

# For other app types (code, container, proxies), no GPU requirement
else:
return 0


def sort_apps_by_gpu_requirements(
applications: List[Dict[str, Union[str, int]]]
) -> List[Dict[str, Union[str, int]]]:
"""
Sort applications by GPU requirements (largest first) to minimize fragmentation.

This helps prevent resource fragmentation by deploying models requiring more GPUs
first, leaving smaller contiguous blocks for models requiring fewer GPUs.
This addresses issue #111: https://github.com/facebookresearch/matrix/issues/111

Args:
applications: List of application configuration dictionaries

Returns:
Sorted list of applications (largest GPU requirement first)
"""
    # Calculate GPU requirements for each app without mutating the configs.
    # get_gpu_requirements_per_replica already falls back to the model defaults
    # from llm_config when tensor_parallel_size is not specified explicitly.
    apps_with_gpus = [
        (get_gpu_requirements_per_replica(app), app) for app in applications
    ]

# Sort by GPU requirements (descending), then by min_replica (descending)
# This ensures largest models are deployed first
sorted_apps = sorted(
apps_with_gpus,
key=lambda x: (
-x[0], # GPU requirement (negative for descending)
-x[1].get("min_replica", 1), # min_replica (negative for descending)
),
)

return [app for _, app in sorted_apps]


def is_sglang_app(app):
if "deployments" in app:
return "sglang" in app["deployments"][0]["name"].lower()
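As a quick usage sketch of the two new helpers (the app dictionaries below are made-up examples, not real model configs):

```python
# Minimal usage sketch; assumes the helpers are importable as in the diff above.
from matrix.app_server.deploy_utils import (
    get_gpu_requirements_per_replica,
    sort_apps_by_gpu_requirements,
)

apps = [
    {"app_type": "code", "name": "code"},                                     # 0 GPUs
    {"app_type": "llm", "model_name": "m-small", "tensor-parallel-size": 1},  # 1 GPU
    {"app_type": "llm", "model_name": "m-large", "tensor-parallel-size": 8},  # 8 GPUs
]

print([get_gpu_requirements_per_replica(a) for a in apps])  # [0, 1, 8]
print([a.get("name", a.get("model_name")) for a in sort_apps_by_gpu_requirements(apps)])
# ['m-large', 'm-small', 'code']
```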
168 changes: 168 additions & 0 deletions tests/unit/app_server/test_resource_fragmentation.py
@@ -0,0 +1,168 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
Tests for resource fragmentation fix (Issue #111).

This test verifies that applications are sorted by GPU requirements
to minimize resource fragmentation when deploying multiple models
with different GPU requirements.
"""

import pytest

from matrix.app_server.deploy_utils import (
get_gpu_requirements_per_replica,
sort_apps_by_gpu_requirements,
)


def test_get_gpu_requirements_per_replica():
"""Test GPU requirement calculation for different app types."""
# LLM model with explicit tensor_parallel_size
app1 = {
"app_type": "llm",
"model_name": "test-model",
"tensor-parallel-size": 4,
}
assert get_gpu_requirements_per_replica(app1) == 4

# LLM model with tensor_parallel_size (underscore format)
app2 = {
"app_type": "llm",
"model_name": "test-model",
"tensor_parallel_size": 2,
}
assert get_gpu_requirements_per_replica(app2) == 2

# LLM model using default from llm_config
app3 = {
"app_type": "llm",
"model_name": "meta-llama/Llama-3.1-70B-Instruct",
}
# Llama-3.1-70B-Instruct has tensor-parallel-size: 4 in defaults
assert get_gpu_requirements_per_replica(app3) == 4

# Vision model
app4 = {"app_type": "perception_encoder", "model_name": "test-vision"}
assert get_gpu_requirements_per_replica(app4) == 1

# Code execution (no GPU)
app5 = {"app_type": "code", "name": "code"}
assert get_gpu_requirements_per_replica(app5) == 0

# Container (no GPU)
app6 = {"app_type": "container", "name": "container"}
assert get_gpu_requirements_per_replica(app6) == 0


def test_sort_apps_by_gpu_requirements():
"""Test that apps are sorted by GPU requirements (largest first)."""
# Create apps with different GPU requirements
apps = [
{
"app_type": "llm",
"model_name": "model-1gpu",
"tensor-parallel-size": 1,
"name": "model-a",
"min_replica": 14,
},
{
"app_type": "llm",
"model_name": "model-2gpu",
"tensor-parallel-size": 2,
"name": "model-b",
"min_replica": 1,
},
{
"app_type": "llm",
"model_name": "model-4gpu",
"tensor-parallel-size": 4,
"name": "model-c",
"min_replica": 1,
},
{
"app_type": "code",
"name": "code-app",
},
]

sorted_apps = sort_apps_by_gpu_requirements(apps)

# Verify sorting: largest GPU requirement first
assert sorted_apps[0]["name"] == "model-c" # 4 GPUs
assert sorted_apps[1]["name"] == "model-b" # 2 GPUs
assert sorted_apps[2]["name"] == "model-a" # 1 GPU
assert sorted_apps[3]["name"] == "code-app" # 0 GPUs

# Verify all apps are present
assert len(sorted_apps) == len(apps)


def test_sort_apps_same_gpu_requirements():
"""Test sorting when apps have same GPU requirements."""
apps = [
{
"app_type": "llm",
"model_name": "model-a",
"tensor-parallel-size": 2,
"name": "model-a",
"min_replica": 1,
},
{
"app_type": "llm",
"model_name": "model-b",
"tensor-parallel-size": 2,
"name": "model-b",
"min_replica": 5,
},
]

sorted_apps = sort_apps_by_gpu_requirements(apps)

# When GPU requirements are equal, sort by min_replica (descending)
assert sorted_apps[0]["name"] == "model-b" # min_replica: 5
assert sorted_apps[1]["name"] == "model-a" # min_replica: 1


def test_sort_apps_with_defaults():
"""Test sorting when apps use default tensor_parallel_size from model config."""
apps = [
{
"app_type": "llm",
"model_name": "meta-llama/Llama-3.1-8B-Instruct", # Default: 1 GPU
"name": "model-8b",
},
{
"app_type": "llm",
"model_name": "meta-llama/Llama-3.1-70B-Instruct", # Default: 4 GPUs
"name": "model-70b",
},
]

sorted_apps = sort_apps_by_gpu_requirements(apps)

# 70B model (4 GPUs) should come before 8B model (1 GPU)
assert sorted_apps[0]["name"] == "model-70b"
assert sorted_apps[1]["name"] == "model-8b"


def test_sort_single_app():
"""Test that sorting a single app returns it unchanged."""
apps = [
{
"app_type": "llm",
"model_name": "test-model",
"tensor-parallel-size": 2,
"name": "single-app",
}
]

sorted_apps = sort_apps_by_gpu_requirements(apps)

assert len(sorted_apps) == 1
assert sorted_apps[0]["name"] == "single-app"
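
To run just these tests locally, assuming pytest is available in the environment:

```bash
pytest tests/unit/app_server/test_resource_fragmentation.py -v
```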