Skip to content

Commit

Permalink
reorg code, resolve client import bug
Browse files Browse the repository at this point in the history
  • Loading branch information
mrwyattii committed Nov 8, 2023
1 parent 1286f45 commit e9f3af3
Show file tree
Hide file tree
Showing 14 changed files with 170 additions and 129 deletions.
11 changes: 2 additions & 9 deletions mii/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,8 @@
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
try:
import grpc
from .pipeline import pipeline
from .server import serve
from .client import client
except ImportError as e:
print("Warning: DeepSpeed-FastGen could not be imported:")
print(e)
pass
import grpc
from .api import client, serve, pipeline

from .legacy import MIIServer, MIIClient, mii_query_handle, deploy, terminate, DeploymentType, TaskType, aml_output_path, MIIConfig, ModelConfig, get_supported_models

Expand Down
149 changes: 149 additions & 0 deletions mii/api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
from typing import Optional, Any, Dict, Tuple, Union

import mii
from mii.backend import MIIClient #, MIIServer
from mii.batching import MIIPipeline, MIIAsyncPipeline
from mii.config import get_mii_config, ModelConfig, MIIConfig
from mii.constants import DeploymentType
from mii.errors import UnknownArgument
from mii.modeling.models import load_model
from mii.score import create_score_file
from mii.modeling.tokenizers import load_tokenizer
from mii.utils import import_score_file


def _parse_kwargs_to_model_config(model_name_or_path: str = "",
                                  model_config: Optional[Dict[str, Any]] = None,
                                  **kwargs) -> Tuple[ModelConfig, Dict[str, Any]]:
    """Split keyword arguments into a ModelConfig and leftover kwargs.

    Every kwarg whose name is a ModelConfig field is merged into the config
    dict (and must agree with any value already present there); everything
    else is returned untouched for the caller to interpret.
    """
    config_dict = model_config if model_config is not None else {}

    assert isinstance(config_dict, dict), "model_config must be a dict"

    # An explicit model_name_or_path argument must agree with the config dict.
    if model_name_or_path:
        if "model_name_or_path" in config_dict:
            assert config_dict.get("model_name_or_path") == model_name_or_path, "model_name_or_path in model_config must match model_name_or_path"
        config_dict["model_name_or_path"] = model_name_or_path

    # Route each kwarg: ModelConfig fields go into the config dict,
    # anything else is collected for the caller.
    model_fields = ModelConfig.__dict__["__fields__"]
    leftover_kwargs = {}
    for key, val in kwargs.items():
        if key not in model_fields:
            leftover_kwargs[key] = val
            continue
        if key in config_dict:
            assert config_dict.get(key) == val, f"{key} in model_config must match {key}"
        config_dict[key] = val

    # Materialize the pydantic model and hand back what we did not consume.
    return ModelConfig(**config_dict), leftover_kwargs


def _parse_kwargs_to_mii_config(model_name_or_path: str = "",
                                model_config: Optional[Dict[str, Any]] = None,
                                mii_config: Optional[Dict[str, Any]] = None,
                                **kwargs) -> MIIConfig:
    """Build a MIIConfig from a model name, config dicts, and free kwargs.

    Kwargs matching ModelConfig fields are consumed first; the remainder
    must be MIIConfig fields, otherwise UnknownArgument is raised.
    """
    # Parse all model_config kwargs; whatever remains must belong to MIIConfig.
    model_config, remaining_kwargs = _parse_kwargs_to_model_config(
        model_name_or_path=model_name_or_path,
        model_config=model_config,
        **kwargs)

    if mii_config is None:
        mii_config = {}

    assert isinstance(mii_config, dict), "mii_config must be a dict"

    # Verify that any model_config kwargs match any existing model_config in
    # the mii_config. The nested entry is normally a plain dict, so normalize
    # it to a ModelConfig before comparing -- a dict compared against a
    # pydantic model (with defaults populated) is never equal, which made the
    # original assert fire for every caller that nested a model_config here.
    if "model_config" in mii_config:
        existing_model_config = mii_config["model_config"]
        if isinstance(existing_model_config, dict):
            existing_model_config = ModelConfig(**existing_model_config)
        assert existing_model_config == model_config, "mii_config['model_config'] must match model_config"
    mii_config["model_config"] = model_config

    # Fill mii_config dict with relevant kwargs, raise error on unknown kwargs
    for key, val in remaining_kwargs.items():
        if key in MIIConfig.__dict__["__fields__"]:
            if key in mii_config:
                assert mii_config.get(key) == val, f"{key} in mii_config must match {key}"
            mii_config[key] = val
        else:
            raise UnknownArgument(f"Keyword argument {key} not recognized")

    # Return the MIIConfig object
    return MIIConfig(**mii_config)


def client(model_or_deployment_name: str) -> MIIClient:
    """Create a MIIClient for an existing deployment, looked up by name."""
    deployment_config = get_mii_config(model_or_deployment_name)
    return MIIClient(deployment_config)


def serve(model_name_or_path: str = "",
          model_config: Optional[Dict[str, Any]] = None,
          mii_config: Optional[Dict[str, Any]] = None,
          **kwargs) -> Union[None, MIIClient]:
    """Deploy a model server described by the given configuration.

    LOCAL deployments start the server through the generated score file and
    return a connected MIIClient; AML deployments only generate deployment
    assets and return None.
    """
    resolved = _parse_kwargs_to_mii_config(model_name_or_path=model_name_or_path,
                                           model_config=model_config,
                                           mii_config=mii_config,
                                           **kwargs)

    #MIIServer(mii_config)
    create_score_file(resolved)

    if resolved.deployment_type == DeploymentType.LOCAL:
        # The score file's init() brings the local server up before we connect.
        import_score_file(resolved.deployment_name, DeploymentType.LOCAL).init()
        return MIIClient(mii_config=resolved)

    if resolved.deployment_type == DeploymentType.AML:
        # Only generate the AML assets; the user launches the deployment.
        acr_name = mii.aml_related.utils.get_acr_name()
        mii.aml_related.utils.generate_aml_scripts(
            acr_name=acr_name,
            deployment_name=resolved.deployment_name,
            model_name=resolved.model_config.model,
            task_name=resolved.model_config.task,
            replica_num=resolved.model_config.replica_num,
            instance_type=resolved.instance_type,
            version=resolved.version,
        )
        print(
            f"AML deployment assets at {mii.aml_related.utils.aml_output_path(resolved.deployment_name)}"
        )
        print("Please run 'deploy.sh' to bring your deployment online")


def pipeline(model_name_or_path: str = "",
             model_config: Optional[Dict[str, Any]] = None,
             **kwargs) -> MIIPipeline:
    """Build a synchronous inference pipeline for the given model."""
    parsed_config, unused_kwargs = _parse_kwargs_to_model_config(
        model_name_or_path=model_name_or_path, model_config=model_config, **kwargs)

    # pipeline() accepts ModelConfig fields only; anything left over is an error.
    if unused_kwargs:
        raise UnknownArgument(
            f"Keyword argument(s) {unused_kwargs.keys()} not recognized")

    engine = load_model(parsed_config)
    tokenizer = load_tokenizer(parsed_config)
    return MIIPipeline(inference_engine=engine,
                       tokenizer=tokenizer,
                       model_config=parsed_config)


def async_pipeline(model_config: ModelConfig) -> MIIAsyncPipeline:
    """Build an asynchronous inference pipeline from a ready ModelConfig."""
    engine = load_model(model_config)
    tokenizer = load_tokenizer(model_config)
    return MIIAsyncPipeline(inference_engine=engine,
                            tokenizer=tokenizer,
                            model_config=model_config)
6 changes: 6 additions & 0 deletions mii/backend/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
from .client import MIIClient
from .server import MIIServer
10 changes: 2 additions & 8 deletions mii/client.py → mii/backend/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
import requests
from typing import Dict, Any, Callable

from mii.config import get_mii_config, MIIConfig
from mii.config import MIIConfig
from mii.constants import GRPC_MAX_MSG_SIZE, TaskType
from mii.grpc_related.proto import modelresponse_pb2, modelresponse_pb2_grpc
from mii.task_methods import TASK_METHODS_DICT
from mii.grpc_related.task_methods import TASK_METHODS_DICT


def create_channel(host, port):
Expand Down Expand Up @@ -121,9 +121,3 @@ def destroy_session(self, session_id):
self.task == TaskType.TEXT_GENERATION
), f"Session deletion only available for task '{TaskType.TEXT_GENERATION}'."
self.asyncio_loop.run_until_complete(self.destroy_session_async(session_id))


def client(model_or_deployment_name: str) -> MIIClient:
    """Look up the deployment's MIIConfig by name and return a client for it."""
    mii_config = get_mii_config(model_or_deployment_name)

    return MIIClient(mii_config)
61 changes: 1 addition & 60 deletions mii/server.py → mii/backend/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,72 +9,13 @@
import tempfile
import time
from collections import defaultdict
from typing import Optional, Any, Dict, Union, List
from typing import List

from deepspeed.accelerator import get_accelerator
from deepspeed.runtime.config_utils import DeepSpeedConfigModel

import mii
from mii.client import MIIClient
from mii.config import ModelConfig, MIIConfig, ReplicaConfig
from mii.constants import DeploymentType
from mii.logging import logger
from mii.score import create_score_file
from mii.utils import import_score_file


def serve(model_name_or_path: str = "",
          model_config: Optional[Dict[str, Any]] = None,
          mii_config: Optional[Dict[str, Any]] = None,
          **kwargs) -> Union[None, MIIClient]:
    """Deploy a model server.

    Merges kwargs into the model_config / mii_config dicts (asserting that
    duplicated keys agree), then either deploys locally (returning a
    MIIClient) or generates AML deployment assets (returning None).
    """
    if model_config is None:
        model_config = {}
    if mii_config is None:
        mii_config = {}
    # An explicit model_name_or_path must agree with any value already in model_config.
    if model_name_or_path:
        if "model_name_or_path" in model_config:
            assert model_config.get("model_name_or_path") == model_name_or_path, "model_name_or_path in model_config must match model_name_or_path"
        model_config["model_name_or_path"] = model_name_or_path
    # Route each kwarg to model_config or mii_config based on the pydantic field names.
    for key, val in kwargs.items():
        if key in ModelConfig.__dict__["__fields__"]:
            if key in model_config:
                assert model_config.get(key) == val, f"{key} in model_config must match {key}"
            model_config[key] = val
        elif key in MIIConfig.__dict__["__fields__"]:
            if key in mii_config:
                assert mii_config.get(key) == val, f"{key} in mii_config must match {key}"
            mii_config[key] = val
        else:
            raise ValueError(f"Invalid keyword argument {key}")
    # Nest the model config under mii_config and build the final MIIConfig.
    if "model_config" in mii_config:
        assert mii_config.get("model_config") == model_config, "model_config in mii_config must match model_config"
    mii_config["model_config"] = model_config
    mii_config = MIIConfig(**mii_config)

    #MIIServer(mii_config)
    create_score_file(mii_config)

    # LOCAL: start the server via the generated score file and hand back a client.
    if mii_config.deployment_type == DeploymentType.LOCAL:
        import_score_file(mii_config.deployment_name, DeploymentType.LOCAL).init()
        return MIIClient(mii_config=mii_config)
    # AML: generate deployment scripts only; the user runs deploy.sh themselves.
    if mii_config.deployment_type == DeploymentType.AML:
        acr_name = mii.aml_related.utils.get_acr_name()
        mii.aml_related.utils.generate_aml_scripts(
            acr_name=acr_name,
            deployment_name=mii_config.deployment_name,
            model_name=mii_config.model_config.model,
            task_name=mii_config.model_config.task,
            replica_num=mii_config.model_config.replica_num,
            instance_type=mii_config.instance_type,
            version=mii_config.version,
        )
        print(
            f"AML deployment assets at {mii.aml_related.utils.aml_output_path(mii_config.deployment_name)}"
        )
        print("Please run 'deploy.sh' to bring your deployment online")


def config_to_b64_str(config: DeepSpeedConfigModel) -> str:
Expand Down
2 changes: 1 addition & 1 deletion mii/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@

from mii.constants import DeploymentType, TaskType, ModelProvider
from mii.errors import DeploymentNotFoundError
from mii.modeling.tokenizers import MIITokenizerWrapper
from mii.pydantic_v1 import Field, root_validator
from mii.tokenizers import MIITokenizerWrapper
from mii.utils import generate_deployment_name, get_default_task, import_score_file


Expand Down
4 changes: 4 additions & 0 deletions mii/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,7 @@

class DeploymentNotFoundError(Exception):
    """Raised when a requested deployment cannot be found."""
    pass


class UnknownArgument(Exception):
    """Raised when a keyword argument is not a recognized config field."""
    pass
4 changes: 2 additions & 2 deletions mii/grpc_related/modelresponse_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@
STREAM_RESPONSE_QUEUE_TIMEOUT,
TaskType,
)
from mii.task_methods import TASK_METHODS_DICT
from mii.client import create_channel
from mii.grpc_related.task_methods import TASK_METHODS_DICT
from mii.backend.client import create_channel
from mii.utils import unpack_proto_query_kwargs

from mii.constants import GenerationFinishReason
Expand Down
File renamed without changes.
2 changes: 1 addition & 1 deletion mii/launch/multi_gpu_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from mii.config import ModelConfig
from mii.grpc_related.modelresponse_server import serve_inference, serve_load_balancing
from mii.grpc_related.restful_gateway import RestfulGatewayThread
from mii.pipeline import async_pipeline
from mii.api import async_pipeline


def b64_encoded_config(config_str: str) -> ModelConfig:
Expand Down
File renamed without changes.
File renamed without changes.
46 changes: 0 additions & 46 deletions mii/pipeline.py

This file was deleted.

4 changes: 2 additions & 2 deletions mii/score/score_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,14 @@ def init():
start_server = False

if start_server:
mii.server.MIIServer(mii_config)
mii.backend.MIIServer(mii_config)

global model
model = None

# In AML deployments both the GRPC client and server are used in the same process
if mii.utils.is_aml():
model = mii.client.MIIClient(mii_config=mii_config)
model = mii.backend.MIIClient(mii_config=mii_config)


def run(request):
Expand Down

0 comments on commit e9f3af3

Please sign in to comment.