Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feat] return hidden states #3364

Open
wants to merge 19 commits into
base: main
Choose a base branch
from

Conversation

Jackmin801
Copy link

@Jackmin801 Jackmin801 commented Feb 7, 2025

Motivation

This PR intends to add the return_hidden_states argument to ServerArgs which makes the results contain the last layer hidden states in output["meta_info"]["hidden_states"].
These hidden states are useful for example for verifying computations. (e.g. https://arxiv.org/abs/2501.16007)

Modifications

  • Add return_hidden_states to ServerArgs
  • Changed the logic to determine capture_hidden_mode to accomodate return_hidden_states
  • Modify scheduler process_batch_results to save the hidden state to the Req
  • Add return_hidden_states and hidden_states to necessary dataclasses

Script used to test changes

# launch the offline engine
import asyncio
from transformers import AutoTokenizer
import sglang as sgl

def main():
    MODEL_NAME = "meta-llama/Meta-Llama-3.1-8B-Instruct"
    llm = sgl.Engine(
        model_path=MODEL_NAME,
        skip_tokenizer_init=True,
        disable_cuda_graph=False,
        return_hidden_states=False,
    )
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]

    sampling_params = {"temperature": 0.8, "top_p": 0.95, "max_new_tokens": 10}

    input_ids = tokenizer(prompts).input_ids
    #outputs = llm.generate(input_ids=input_ids, sampling_params=sampling_params)
    outputs = llm.generate(prompts, sampling_params=sampling_params)
    for input_id, output in zip(input_ids, outputs):
        print("===============================")
        print(input_id)
        print(output)
        print()
        if "token_ids" in output:
            print(input_id, output["token_ids"], len(input_id), len(output["token_ids"]))
        else:
            print(output['text'])
        if "hidden_states" in output["meta_info"]:
            print(
                [i.shape for i in output["meta_info"]["hidden_states"]],
                len(output["meta_info"]["hidden_states"]),
            )

if __name__ == "__main__":
    main()

Checklist

  • Format your code according to the Code Formatting with Pre-Commit.
  • Add unit tests as outlined in the Running Unit Tests.
  • Update documentation / docstrings / example tutorials as needed, according to Writing Documentation.
  • Provide throughput / latency benchmark results and accuracy evaluation results as needed, according to Benchmark and Profiling and Accuracy Results.
  • For reviewers: If you haven't made any contributions to this PR and are only assisting with merging the main branch, please remove yourself as a co-author when merging the PR.

@zhaochenyang20
Copy link
Collaborator

This is good to see. But could change our documents to demonstrate the usage and add unit tests to your feature?

docs/backend/server_arguments.md Outdated Show resolved Hide resolved
@@ -184,6 +184,28 @@ def test_7_engine_offline_throughput(self):
result = throughput_test(server_args=server_args, bench_args=bench_args)
self.assertGreater(result["total_throughput"], 3000)

def test_8_return_hidden_states(self):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

quite a strange name.

Also, should assert that the hidden state is all close with hugging face.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yea, all the names in this file are strange. Maybe we should remove the numbers?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I cant quite get the huggingface tensor to be similar. It only works when there is less than or equal to 1 decode. When I tried with 2 decodes, the test fails. The values seem wildly different, so maybe its not about numerics. Will debug further

Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00,  2.25it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00,  2.25it/s]

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 23/23 [00:06<00:00,  3.36it/s]
[128000, 15724, 374] [264, 1938, 315]
tensor([[[ 0.9727,  0.1924,  0.6133,  ..., -0.8438, -0.2852,  0.0645],
         [ 0.6797,  3.2656,  1.9453,  ..., -5.3750, -4.0625,  1.9922],
         [ 0.7812,  1.3750,  1.3047,  ..., -5.0625, -3.6094, -0.2715],
         [ 2.8906,  1.3516,  1.9609,  ..., -4.0312, -3.0312, -0.3809],
         [-0.2949,  4.1250,  0.2578,  ..., -3.0469, -5.5938,  0.9219]]],
       device='cuda:0', dtype=torch.bfloat16)
===
tensor([[ 1.0078,  0.1855,  0.6328,  ..., -0.8750, -0.2812,  0.0601],
        [ 0.6953,  3.2500,  1.9375,  ..., -5.3750, -4.0312,  1.9453],
        [ 0.7656,  1.4062,  1.3203,  ..., -5.0938, -3.6250, -0.2715],
        [ 0.1807, -0.6289,  0.6953,  ...,  0.5977, -0.5742, -0.0889],
        [-0.1406, -0.0605, -0.3848,  ...,  0.4863, -0.4062,  0.0684]],
       dtype=torch.bfloat16)
F
======================================================================
FAIL: test_8_engine_return_hidden_states (test_srt_engine.TestSRTEngine.test_8_engine_return_hidden_states)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/ubuntu/sglang/test/srt/test_srt_engine.py", line 238, in test_8_engine_return_hidden_states
    self.assertTrue(
AssertionError: False is not true

----------------------------------------------------------------------
Ran 1 test in 18.811s

@zhaochenyang20
Copy link
Collaborator

Thanks. I will try to get some one familiar with hidden state to help.

@zhaochenyang20
Copy link
Collaborator

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants