Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 0 additions & 8 deletions src/guidellm/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,6 @@ def benchmark():
"--backend-kwargs",
"--backend-args", # legacy alias
"backend_kwargs",
callback=cli_tools.parse_json,
default=BenchmarkGenerativeTextArgs.get_default("backend_kwargs"),
help="JSON string of arguments to pass to the backend.",
)
Expand Down Expand Up @@ -205,14 +204,12 @@ def benchmark():
@click.option(
"--processor-args",
default=BenchmarkGenerativeTextArgs.get_default("processor_args"),
callback=cli_tools.parse_json,
help="JSON string of arguments to pass to the processor constructor.",
)
@click.option(
"--data-args",
multiple=True,
default=BenchmarkGenerativeTextArgs.get_default("data_args"),
callback=cli_tools.parse_json,
help="JSON string of arguments to pass to dataset creation.",
)
@click.option(
Expand All @@ -227,7 +224,6 @@ def benchmark():
@click.option(
"--data-column-mapper",
default=BenchmarkGenerativeTextArgs.get_default("data_column_mapper"),
callback=cli_tools.parse_json,
help="JSON string of column mappings to apply to the dataset.",
)
@click.option(
Expand All @@ -245,7 +241,6 @@ def benchmark():
@click.option(
"--dataloader-kwargs",
default=BenchmarkGenerativeTextArgs.get_default("dataloader_kwargs"),
callback=cli_tools.parse_json,
help="JSON string of arguments to pass to the dataloader constructor.",
)
@click.option(
Expand Down Expand Up @@ -303,7 +298,6 @@ def benchmark():
"--warmup-percent", # legacy alias
"warmup",
default=BenchmarkGenerativeTextArgs.get_default("warmup"),
callback=cli_tools.parse_json,
help=(
"Warmup specification: int, float, or dict as string "
"(json or key=value). "
Expand All @@ -318,7 +312,6 @@ def benchmark():
"--cooldown-percent", # legacy alias
"cooldown",
default=BenchmarkGenerativeTextArgs.get_default("cooldown"),
callback=cli_tools.parse_json,
help=(
"Cooldown specification: int, float, or dict as string "
"(json or key=value). "
Expand Down Expand Up @@ -387,7 +380,6 @@ def benchmark():
@click.option(
"--over-saturation",
"over_saturation",
callback=cli_tools.parse_json,
default=None,
help=(
"Enable over-saturation detection. "
Expand Down
45 changes: 45 additions & 0 deletions src/guidellm/benchmark/schemas/generative/entrypoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from guidellm.data import DatasetPreprocessor, RequestFormatter
from guidellm.scheduler import StrategyType
from guidellm.schemas import StandardBaseModel
from guidellm.utils import arg_string

__all__ = ["BenchmarkGenerativeTextArgs"]

Expand Down Expand Up @@ -312,6 +313,50 @@ def single_to_list(
else:
raise

@field_validator(
    "backend_kwargs",
    "processor_args",
    "data_args",
    "data_column_mapper",
    "data_request_formatter",
    "dataloader_kwargs",
    "warmup",
    "cooldown",
    "over_saturation",
    mode="wrap",
)
@classmethod
def parse_config_str(
    cls,
    value: Any,
    handler: ValidatorFunctionWrapHandler,
) -> Any:
    """
    Decode string inputs for the configuration fields before validation.

    Non-string values are handed to the wrapped validator untouched. String
    values are first loaded as YAML; if YAML loading fails or yields the
    very same string, ``arg_string.loads`` is tried instead. When that also
    fails, the raw string is validated as-is.

    :param value: Raw input for one of the decorated fields
    :param handler: Wrapped validator that performs the field's validation
    :return: Result of the wrapped validator for the decoded (or raw) value
    :raises ValidationError: If the raw string is rejected by the wrapped
        validator after arg_string parsing failed; chained from the
        arg_string parse error
    """
    # Only strings need decoding; anything else goes straight to validation.
    if not isinstance(value, str):
        return handler(value)

    try:
        decoded = yaml.safe_load(value)
    except yaml.YAMLError:
        decoded = value

    # YAML produced something other than the input string: validate that.
    if decoded != value:
        return handler(decoded)

    # YAML left the string unchanged, so fall back to arg_string parsing.
    try:
        decoded = arg_string.loads(value)
    except arg_string.ArgStringParseError as parse_error:
        # Last resort: let the field validate the original string itself.
        try:
            return handler(value)
        except ValidationError as validation_error:
            # Surface the validation failure with the parse error as cause.
            raise validation_error from parse_error

    return handler(decoded)

@field_serializer("backend")
def serialize_backend(self, backend: BackendType | Backend) -> str:
"""Serialize backend to type string."""
Expand Down
58 changes: 20 additions & 38 deletions src/guidellm/data/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from pydantic import ValidationError

from guidellm.data.schemas import DataConfig, DataNotSupportedError
from guidellm.utils import arg_string

ConfigT = TypeVar("ConfigT", bound=DataConfig)

Expand Down Expand Up @@ -48,9 +49,7 @@ def _load_config_file(data: Any, config_class: type[ConfigT]) -> ConfigT | None:

if Path(data).is_file() and data_path.suffix.lower() == ".json":
try:
return config_class.model_validate_json(
data_path.read_text()
)
return config_class.model_validate_json(data_path.read_text())
except Exception as err: # noqa: BLE001
error = err

Expand All @@ -60,9 +59,7 @@ def _load_config_file(data: Any, config_class: type[ConfigT]) -> ConfigT | None:
".config",
}:
try:
return config_class.model_validate(
yaml.safe_load(data_path.read_text())
)
return config_class.model_validate(yaml.safe_load(data_path.read_text()))
except Exception as err: # noqa: BLE001
error = err

Expand All @@ -82,39 +79,24 @@ def _load_config_str(data: str, config_class: type[ConfigT]) -> ConfigT | None:
if not isinstance(data, str):
return None

data_str = data.strip()
error = None

if (data_str.startswith("{") and data_str.endswith("}")) or (
data_str.startswith("[") and data_str.endswith("]")
):
try:
return config_class.model_validate_json(data_str)
except Exception as err: # noqa: BLE001
error = err

if data_str.count("=") > 1:
# key=value pairs separated by commas
try:
config_dict = {}
items = data_str.split(",")
for item in items:
key, value = item.split("=")
config_dict[key.strip()] = (
int(value.strip())
if value.strip().isnumeric()
else value.strip()
)

return config_class.model_validate(config_dict)
except Exception as err: # noqa: BLE001
error = err

err_message = (
f"Unsupported string data for {config_class.__name__}, "
f"expected JSON or key-value pairs, got {data}"
)
if error is not None:
err_message += f" with error: {error}"
raise DataNotSupportedError(err_message) from error
raise DataNotSupportedError(err_message)

try:
data_parsed = yaml.safe_load(data)
except yaml.YAMLError:
data_parsed = data

# If no change from YAML parsing, try arg_string parsing
if data_parsed == data:
try:
data_parsed = arg_string.loads(data_parsed)
except arg_string.ArgStringParseError as e:
raise DataNotSupportedError(err_message) from e

try:
return config_class.model_validate(data_parsed)
except ValidationError as err:
raise DataNotSupportedError(err_message) from err
2 changes: 2 additions & 0 deletions src/guidellm/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from . import arg_string
from .auto_importer import AutoImporterMixin
from .console import Colors, Console, ConsoleUpdateStep, StatusIcons, StatusStyles
from .default_group import DefaultGroupHandler
Expand Down Expand Up @@ -77,6 +78,7 @@
"StatusStyles",
"ThreadSafeSingletonMixin",
"all_defined",
"arg_string",
"camelize_str",
"check_load_processor",
"clean_text",
Expand Down
Loading
Loading