[Feat] return hidden states #3364
This is good to see. But could you update our documentation to demonstrate the usage and add unit tests for your feature?
test/srt/test_srt_engine.py
Outdated
@@ -184,6 +184,28 @@ def test_7_engine_offline_throughput(self):
        result = throughput_test(server_args=server_args, bench_args=bench_args)
        self.assertGreater(result["total_throughput"], 3000)

    def test_8_return_hidden_states(self):
Quite a strange name.
Also, the test should assert that the hidden states are all close to the ones Hugging Face produces.
Yea, all the names in this file are strange. Maybe we should remove the numbers?
I can't quite get the Hugging Face tensor to match. The comparison only passes with at most 1 decode step. When I tried with 2 decode steps, the test fails. The values seem wildly different, so maybe it's not a numerics issue. Will debug further.
Loading safetensors checkpoint shards: 0% Completed | 0/1 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00, 2.25it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00, 2.25it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 23/23 [00:06<00:00, 3.36it/s]
[128000, 15724, 374] [264, 1938, 315]
tensor([[[ 0.9727, 0.1924, 0.6133, ..., -0.8438, -0.2852, 0.0645],
[ 0.6797, 3.2656, 1.9453, ..., -5.3750, -4.0625, 1.9922],
[ 0.7812, 1.3750, 1.3047, ..., -5.0625, -3.6094, -0.2715],
[ 2.8906, 1.3516, 1.9609, ..., -4.0312, -3.0312, -0.3809],
[-0.2949, 4.1250, 0.2578, ..., -3.0469, -5.5938, 0.9219]]],
device='cuda:0', dtype=torch.bfloat16)
===
tensor([[ 1.0078, 0.1855, 0.6328, ..., -0.8750, -0.2812, 0.0601],
[ 0.6953, 3.2500, 1.9375, ..., -5.3750, -4.0312, 1.9453],
[ 0.7656, 1.4062, 1.3203, ..., -5.0938, -3.6250, -0.2715],
[ 0.1807, -0.6289, 0.6953, ..., 0.5977, -0.5742, -0.0889],
[-0.1406, -0.0605, -0.3848, ..., 0.4863, -0.4062, 0.0684]],
dtype=torch.bfloat16)
F
======================================================================
FAIL: test_8_engine_return_hidden_states (test_srt_engine.TestSRTEngine.test_8_engine_return_hidden_states)
----------------------------------------------------------------------
Traceback (most recent call last):
File "/home/ubuntu/sglang/test/srt/test_srt_engine.py", line 238, in test_8_engine_return_hidden_states
self.assertTrue(
AssertionError: False is not true
----------------------------------------------------------------------
Ran 1 test in 18.811s
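To narrow down where the two printouts above part ways, a per-position comparison can help. The following is a minimal sketch in plain Python, not part of the PR: the helper names and the tolerance of 0.1 are assumptions chosen for illustration, and the rows contain only the six values per position that survive in the truncated printout. On these values, the first three positions agree within the tolerance and the last two do not; the printed prompt also has three token ids.

```python
def max_abs_diff_per_row(a, b):
    """Return the largest absolute element-wise difference for each row."""
    if len(a) != len(b):
        raise ValueError("inputs must have the same number of rows")
    return [max(abs(x - y) for x, y in zip(row_a, row_b))
            for row_a, row_b in zip(a, b)]


def first_divergent_row(a, b, atol):
    """Return the index of the first row whose max difference exceeds atol."""
    for index, diff in enumerate(max_abs_diff_per_row(a, b)):
        if diff > atol:
            return index
    return None  # every row is within tolerance


# Values copied from the truncated printout (first and last three per row).
first = [
    [0.9727, 0.1924, 0.6133, -0.8438, -0.2852, 0.0645],
    [0.6797, 3.2656, 1.9453, -5.3750, -4.0625, 1.9922],
    [0.7812, 1.3750, 1.3047, -5.0625, -3.6094, -0.2715],
    [2.8906, 1.3516, 1.9609, -4.0312, -3.0312, -0.3809],
    [-0.2949, 4.1250, 0.2578, -3.0469, -5.5938, 0.9219],
]
second = [
    [1.0078, 0.1855, 0.6328, -0.8750, -0.2812, 0.0601],
    [0.6953, 3.2500, 1.9375, -5.3750, -4.0312, 1.9453],
    [0.7656, 1.4062, 1.3203, -5.0938, -3.6250, -0.2715],
    [0.1807, -0.6289, 0.6953, 0.5977, -0.5742, -0.0889],
    [-0.1406, -0.0605, -0.3848, 0.4863, -0.4062, 0.0684],
]

# atol=0.1 is an assumed tolerance for bfloat16 values, not a project setting.
diffs = max_abs_diff_per_row(first, second)
divergent = first_divergent_row(first, second, atol=0.1)
print("per-row max abs diff:", diffs)
print("first divergent row:", divergent)
```

A check like this only locates the mismatch; it does not say which side is wrong.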
Thanks. I will try to get someone familiar with hidden states to help.
https://github.com/sgl-project/sglang/blob/main/test/srt/models/test_generation_models.py You can check this; it may help.
Motivation
This PR adds the return_hidden_states argument to ServerArgs, which makes the results contain the last-layer hidden states in output["meta_info"]["hidden_states"]. These hidden states are useful, for example, for verifying computations (e.g. https://arxiv.org/abs/2501.16007).
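As a rough illustration of how a caller might read the new field, here is a small sketch. Only the key path output["meta_info"]["hidden_states"] comes from this PR description; the helper name extract_hidden_states, the stand-in output dictionary, and its placeholder values are hypothetical, and the actual layout of the hidden states (per token, per step, nesting) is not specified here.

```python
def extract_hidden_states(output):
    """Return the hidden states stored under output["meta_info"]["hidden_states"].

    Raises KeyError with a hint when the field is absent, for example when
    return_hidden_states was not enabled.
    """
    meta_info = output.get("meta_info", {})
    if "hidden_states" not in meta_info:
        raise KeyError(
            "no hidden_states in meta_info; was return_hidden_states enabled?"
        )
    return meta_info["hidden_states"]


# Hand-built stand-in for one generation result. The numbers are placeholders,
# not real model output, and the nesting is an assumption for this sketch.
example_output = {
    "text": "placeholder",
    "meta_info": {"hidden_states": [[0.0, 0.0], [0.0, 0.0]]},
}

hidden_states = extract_hidden_states(example_output)
print(len(hidden_states), "entries of hidden states")
```

Failing loudly when the key is missing makes a forgotten return_hidden_states setting easier to spot than a silent None.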
Modifications
- Add return_hidden_states to ServerArgs
- Update capture_hidden_mode to accommodate return_hidden_states
- Add return_hidden_states and hidden_states to the necessary dataclasses
Script used to test changes
Checklist