Commit
expose top-p, top-k, and temperature to users; refactor batch processing
mrwyattii committed Nov 7, 2023
1 parent 96edd72 · commit f44e5fe
Showing 3 changed files with 114 additions and 145 deletions.
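
The new generation knobs land as plain keyword arguments on the request path (see make_request in mii/batching/ragged_batching.py below). A minimal usage sketch, assuming the non-persistent mii.pipeline entry point; the model name is a placeholder and the response handling is kept generic, while the kwarg names mirror those popped in make_request:

```python
# Sketch only: the model name is a placeholder; top_p, top_k, and temperature
# are the sampling kwargs this commit exposes through make_request.
import mii

pipe = mii.pipeline("mistralai/Mistral-7B-v0.1")  # hypothetical model
responses = pipe(
    ["DeepSpeed is"],
    max_new_tokens=128,
    top_p=0.9,        # default used when not provided (see make_request)
    top_k=50,         # optional; applied only when passed
    temperature=0.7,  # optional; applied only when passed
)
print(responses)
```
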
mii/batching/generation/samplers.py (4 changes: 2 additions & 2 deletions)
@@ -44,8 +44,8 @@ def __call__(
logits = logits.float()
sampler = Categorical(logits=logits)
next_tokens = sampler.sample()
logprobs = sampler.log_prob(next_tokens)
return next_tokens, logprobs
#logprobs = sampler.log_prob(next_tokens)
return next_tokens #, logprobs


class GreedySampler(BaseGenerationSampler):
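
For context, the sampler change above only drops the log-prob return value; the sampling itself is unchanged. A standalone sketch of that pattern using torch.distributions with placeholder logits (not MII code):

```python
import torch
from torch.distributions import Categorical

logits = torch.randn(4, 32000)        # [batch, vocab] placeholder logits
sampler = Categorical(logits=logits.float())
next_tokens = sampler.sample()        # sampled token ids, shape [batch]
# logprobs = sampler.log_prob(next_tokens)  # computed before this commit, now dropped
print(next_tokens.shape)
```
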
mii/batching/postprocess.py (147 changes: 55 additions & 92 deletions)
@@ -2,98 +2,61 @@
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
import itertools
from collections import defaultdict
from typing import Any, Dict
from typing import TYPE_CHECKING, Any, Dict

import torch

from .generation.logit_processors import (
TopKLogitProcessor,
TopPLogitProcessor,
TemperatureLogitProcessor,
NucleusSamplingLogitProcessor,
)
from .generation.samplers import LogitsSampler, GreedySampler
from .generation.stop_criterion import (
EosGenerationStopCriterion,
NewLineDelimitedStopCriterion,
)

LOGITS_PROCESSORS = {
"TopK": TopKLogitProcessor,
"TopP": TopPLogitProcessor,
"Temperature": TemperatureLogitProcessor,
"NucleusSampling": NucleusSamplingLogitProcessor,
}

SAMPLERS = {"Logits": LogitsSampler, "Greedy": GreedySampler}

STOP_CRITERIA = {
"EosGeneration": EosGenerationStopCriterion,
"NewLineDelimited": NewLineDelimitedStopCriterion,
}

DEFAULT_LOGITS_PROCESSOR = {"name": "TopP", "args": {"top_p": 0.9}}
DEFAULT_SAMPLER = {"name": "Logits"}
DEFAULT_STOP_CRITERION = {"name": "EosGeneration"}


def _create_postprocessor(config: Dict[str,
Any],
classes: Dict[str,
Any],
default_args: Dict[str,
Any] = {}):
assert "name" in config

name = config["name"]
if name not in classes:
raise ValueError(f"Unknown postprocessor {name}")
args = config["args"] if "args" in config else {}
args.update(default_args)
return classes[name](**args)


def _run_batch_postprocess(input_tensor,
requests,
get_processor_fn,
get_result_fn=lambda x: x):
processor_map = {
get_processor_fn(r).get_key(): get_processor_fn(r)
for r in requests
}
processor_indices = defaultdict(list)

for i, r in enumerate(requests):
key = get_processor_fn(r).get_key()
processor_indices[key].append(i)

indice_list = []
outputs_list = []
for key, indices in processor_map.items():
processor = processor_map[key]
indices = processor_indices[key]
input_filtered = input_tensor[indices]
output_filtered = get_result_fn(processor(input_filtered))
indice_list.append(indices)
outputs_list.append(output_filtered)

indice = list(itertools.chain.from_iterable(indice_list))
outputs = torch.cat(outputs_list, dim=0)
return outputs[torch.argsort(torch.tensor(indice))]


def run_batch_logit_processor(input_tensor, requests):
return _run_batch_postprocess(input_tensor, requests, lambda r: r.logit_processor)


def run_batch_sampler(input_tensor, requests):
return _run_batch_postprocess(input_tensor,
requests,
lambda r: r.sampler,
lambda x: x[0])


def run_batch_stop_criterion(input_tensor, requests):
return _run_batch_postprocess(input_tensor, requests, lambda r: r.stop_criterion)
if TYPE_CHECKING:
from mii.batching.ragged_batching import RaggedRequestBatch


def run_batch_processing(input_tensor: torch.Tensor,
requests: "RaggedRequestBatch",
processor_fns: Dict[str,
Any]) -> torch.Tensor:
idx_list = []
output_list = []
for key, process_fn in processor_fns.items():
idx = [i for i, r in enumerate(requests) if key in r.post_processing]
if not idx:
continue
filtered_input = input_tensor[idx]
idx_list.extend(idx)
output_list.append(process_fn(filtered_input))
if not output_list:
return input_tensor
output = torch.cat(output_list, dim=0)
return output[torch.argsort(torch.tensor(idx_list))]


def run_batch_logit_processing(input_logits: torch.Tensor,
requests: "RaggedRequestBatch",
processor_map: Dict[str,
Any]) -> torch.Tensor:
top_k_fns = {k: v for k, v in processor_map.items() if "TopK" in k}
top_p_fns = {k: v for k, v in processor_map.items() if "TopP" in k}
temp_fns = {k: v for k, v in processor_map.items() if "Temp" in k}

# Apply TopK, TopP, and Temperature in sequence
output_logits = input_logits
for fns in (top_k_fns, top_p_fns, temp_fns):
output_logits = run_batch_processing(output_logits, requests, fns)
return output_logits


def run_batch_sampler(input_logits: torch.Tensor,
requests: "RaggedRequestBatch",
processor_map: Dict[str,
Any]) -> torch.Tensor:
sampler_fns = {k: v for k, v in processor_map.items() if "Sampler" in k}
next_tokens = run_batch_processing(input_logits, requests, sampler_fns)
return next_tokens


def run_batch_stop_criterion(next_tokens: torch.Tensor,
requests: "RaggedRequestBatch",
processor_map: Dict[str,
Any]) -> torch.Tensor:
stop_fns = {k: v for k, v in processor_map.items() if "Stop" in k}
done_tokens = run_batch_processing(next_tokens, requests, stop_fns)
return done_tokens
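
The refactor above replaces per-request processor objects with run_batch_processing, which slices the batch by post-processing key, runs each processor on its slice, and restores the original row order with argsort. A self-contained sketch of that gather-and-reorder pattern, with stand-in processor functions rather than the MII classes:

```python
import torch

def batch_process(x, request_keys, processor_fns):
    # request_keys[i] lists the post-processing keys request i opted into,
    # mirroring RaggedRequest.post_processing in the diff below.
    idx_list, out_list = [], []
    for key, fn in processor_fns.items():
        idx = [i for i, keys in enumerate(request_keys) if key in keys]
        if not idx:
            continue
        idx_list.extend(idx)
        out_list.append(fn(x[idx]))      # process only the matching rows
    if not out_list:
        return x
    out = torch.cat(out_list, dim=0)
    return out[torch.argsort(torch.tensor(idx_list))]  # back to input order

logits = torch.randn(3, 8)
keys = [["TopP_0.9"], ["TopK_50"], ["TopP_0.9"]]
fns = {"TopP_0.9": lambda t: t * 2.0, "TopK_50": lambda t: t + 1.0}
print(batch_process(logits, keys, fns).shape)  # torch.Size([3, 8])
```
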
mii/batching/ragged_batching.py (108 changes: 57 additions & 51 deletions)
@@ -20,20 +20,13 @@
from deepspeed.accelerator import get_accelerator
from deepspeed.utils.timer import SynchronizedWallClockTimer

from mii.batching.generation.logit_processors import BaseLogitProcessor
from mii.batching.generation.samplers import BaseGenerationSampler
from mii.batching.generation.stop_criterion import BaseGenerationStopCriterion
from mii.batching.generation.logit_processors import TopPLogitProcessor, TopKLogitProcessor, TemperatureLogitProcessor
from mii.batching.generation.samplers import LogitsSampler
from mii.batching.generation.stop_criterion import EosGenerationStopCriterion
from mii.batching.postprocess import (
_create_postprocessor,
run_batch_logit_processor,
run_batch_logit_processing,
run_batch_sampler,
run_batch_stop_criterion,
DEFAULT_LOGITS_PROCESSOR,
DEFAULT_SAMPLER,
DEFAULT_STOP_CRITERION,
LOGITS_PROCESSORS,
SAMPLERS,
STOP_CRITERIA,
)
from mii.batching.utils import sync_debug, profiler
from mii.constants import GenerationFinishReason, ZMQ_RECV_TIMEOUT
@@ -130,9 +123,7 @@ class RaggedRequest:
max_length: int
max_new_tokens: int
last_in_prompt: bool
logit_processor: BaseLogitProcessor
sampler: BaseGenerationSampler
stop_criterion: BaseGenerationStopCriterion
post_processing: List[object]
stream: bool = False
ignore_eos: bool = False

@@ -308,7 +299,9 @@ def __init__(self, inference_engine, tokenizer, model_config):
self.scheduled_seq_num = 0
self.scheduled_req_blocks = 0

self.logit_processor = run_batch_logit_processor
# TODO: we will need to prune self._post_processors for long running deployments
self._post_processors = {}
self.logit_processor = run_batch_logit_processing
self.sampler = run_batch_sampler
self.stop_criterion = run_batch_stop_criterion

@@ -432,9 +425,15 @@ def _process_logits(
running_requests: RaggedRequestBatch) -> Tuple[torch.Tensor,
torch.Tensor]:
next_token_logits = next_token_logits[:, :self.vocab_size]
next_token_logits = self.logit_processor(next_token_logits, running_requests)
next_tokens = self.sampler(next_token_logits, running_requests)
done_tokens = self.stop_criterion(next_tokens, running_requests)
next_token_logits = self.logit_processor(next_token_logits,
running_requests,
self._post_processors)
next_tokens = self.sampler(next_token_logits,
running_requests,
self._post_processors)
done_tokens = self.stop_criterion(next_tokens,
running_requests,
self._post_processors)
next_tokens = next_tokens.to(torch.device("cpu"), non_blocking=False)
return next_tokens, done_tokens

@@ -536,51 +535,62 @@ def make_request(self,
uid: int,
input_tokens: torch.Tensor,
kwargs: Dict) -> List[RaggedRequest]:
prompt_length = len(input_tokens)
max_length = kwargs.pop("max_length", self.max_length)
max_new_tokens = kwargs.pop("max_new_tokens", max_length - len(input_tokens))
assert max_length > prompt_length, "prompt_length must be less than max_length"
max_new_tokens = kwargs.pop("max_new_tokens", max_length - prompt_length)
stream = kwargs.pop("stream", False)
ignore_eos = kwargs.pop("ignore_eos", False)
# TODO: Add back this check
# if self.policy.get_length(uid) + len(token_ids) >= max_length:
# raise ValueError(f"Session {uid} has reached max length {max_length}.")

postprocess_config = kwargs.pop("postprocess_config", {})
accepted_keys = ("logit_processor", "sampler", "stop_criterion")
for key in postprocess_config.keys():
if key not in accepted_keys:
raise ValueError(
f"Unknown postprocess_config keyword {key}. Accepted keywords are {accepted_keys}"
)
logit_processor = _create_postprocessor(
postprocess_config.get("logit_processor",
DEFAULT_LOGITS_PROCESSOR),
LOGITS_PROCESSORS,
)
sampler = _create_postprocessor(
postprocess_config.get("sampler",
DEFAULT_SAMPLER),
SAMPLERS)
stop_criterion = _create_postprocessor(
postprocess_config.get("stop_criterion",
DEFAULT_STOP_CRITERION),
STOP_CRITERIA,
{"tokenizer": self.tokenizer},
)
post_processing = []

top_p = kwargs.pop("top_p", 0.9)
top_p_name = f"TopP_{top_p}"
if top_p_name not in self._post_processors:
self._post_processors[top_p_name] = TopPLogitProcessor(top_p=top_p)
post_processing.append(top_p_name)

top_k = kwargs.pop("top_k", None)
if top_k is not None:
top_k_name = f"TopK_{top_k}"
if top_k_name not in self._post_processors:
self._post_processors[top_k_name] = TopKLogitProcessor(top_k=top_k)
post_processing.append(top_k_name)

temp = kwargs.pop("temperature", None)
if temp is not None:
temp_name = f"Temp_{temp}"
if temp_name not in self._post_processors:
self._post_processors[temp_name] = TemperatureLogitProcessor(
temperature=temp)
post_processing.append(temp_name)

sampler_name = "Sampler"
if sampler_name not in self._post_processors:
self._post_processors[sampler_name] = LogitsSampler()
post_processing.append(sampler_name)

stop_name = "Stop"
if stop_name not in self._post_processors:
self._post_processors[stop_name] = EosGenerationStopCriterion(
tokenizer=self.tokenizer)
post_processing.append(stop_name)

assert kwargs == {}, f"Unknown keyword arguments {kwargs}"

return [
RaggedRequest(
uid=uid,
input_tokens=input_tokens,
prompt_length=len(input_tokens),
prompt_length=prompt_length,
seq_length=0,
max_length=max_length,
max_new_tokens=max_new_tokens,
last_in_prompt=True,
logit_processor=logit_processor,
sampler=sampler,
stop_criterion=stop_criterion,
post_processing=post_processing,
stream=stream,
ignore_eos=ignore_eos,
)
@@ -631,9 +641,7 @@ def __call__(self, inputs: Union[str, List[str]], **kwargs) -> ResponseBatch:
max_length=None,
max_new_tokens=None,
last_in_prompt=None,
logit_processor=None,
sampler=None,
stop_criterion=None,
post_processing=None,
stream=None,
))

@@ -778,9 +786,7 @@ def destroy_session(self,
max_length=None,
max_new_tokens=None,
last_in_prompt=None,
logit_processor=None,
sampler=None,
stop_criterion=None,
post_processing=None,
stream=None,
))
self.uids.remove(uid)
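
make_request now caches post-processors under value-derived keys (for example "TopP_0.9"), so requests with identical settings share one object and each request carries only its list of keys; the TODO in __init__ notes the cache still needs pruning for long-running deployments. A toy sketch of that keying scheme, with tuples standing in for the processor classes:

```python
_post_processors = {}  # shared cache, keyed by parameter value

def _register(name, factory):
    if name not in _post_processors:
        _post_processors[name] = factory()
    return name

def make_keys(top_p=0.9, top_k=None, temperature=None):
    keys = [_register(f"TopP_{top_p}", lambda: ("TopP", top_p))]
    if top_k is not None:
        keys.append(_register(f"TopK_{top_k}", lambda: ("TopK", top_k)))
    if temperature is not None:
        keys.append(_register(f"Temp_{temperature}", lambda: ("Temp", temperature)))
    return keys

print(make_keys(top_p=0.9, top_k=50))  # ['TopP_0.9', 'TopK_50']
print(make_keys(top_p=0.9))            # reuses the cached 'TopP_0.9' entry
print(list(_post_processors))          # ['TopP_0.9', 'TopK_50']
```
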
