
Commit 546f1ca

Remove unused params in attn
Signed-off-by: yizhang-nv <[email protected]>
1 parent: 04b1126

File tree

15 files changed: +110 / -141 lines


cpp/tensorrt_llm/common/attentionOp.h

Lines changed: 0 additions & 2 deletions
@@ -127,7 +127,6 @@ class AttentionOp
 public:
     // Attention packed mask input (used by context FMHA).
     uint32_t const* attention_packed_mask = nullptr;
-    kernels::KVBlockArray::DataType* host_block_offsets = nullptr;
     int32_t batch_size = 0;
     float2 const* mrope_rotary_cos_sin = nullptr;

@@ -182,7 +181,6 @@ class AttentionOp
     ss << "context_buf_sf: " << this->context_buf_sf << std::endl;
     ss << "key_value_cache: " << (half*) this->key_value_cache << std::endl;
     ss << "block_offsets: " << this->block_offsets << std::endl;
-    ss << "host_block_offsets: " << this->host_block_offsets << std::endl;
     ss << "host_primary_pool_pointer: " << this->host_primary_pool_pointer << std::endl;
     ss << "host_secondary_pool_pointer: " << this->host_secondary_pool_pointer << std::endl;
     ss << "batch_size: " << this->batch_size << std::endl;

cpp/tensorrt_llm/nanobind/thop/bindings.cpp

Lines changed: 13 additions & 13 deletions
@@ -42,19 +42,19 @@ void initBindings(nb::module_& m)
        nb::arg("output_sf") = std::nullopt, nb::arg("workspace_") = std::nullopt, nb::arg("sequence_length"),
        nb::arg("host_past_key_value_lengths"), nb::arg("host_total_kv_lens"), nb::arg("context_lengths"),
        nb::arg("host_context_lengths"), nb::arg("host_request_types"),
-       nb::arg("kv_cache_block_offsets") = std::nullopt, nb::arg("host_kv_cache_block_offsets") = std::nullopt,
-       nb::arg("host_kv_cache_pool_pointers") = std::nullopt, nb::arg("host_kv_cache_pool_mapping") = std::nullopt,
-       nb::arg("cache_indirection") = std::nullopt, nb::arg("kv_scale_orig_quant") = std::nullopt,
-       nb::arg("kv_scale_quant_orig") = std::nullopt, nb::arg("out_scale") = std::nullopt,
-       nb::arg("rotary_inv_freq") = std::nullopt, nb::arg("rotary_cos_sin") = std::nullopt,
-       nb::arg("latent_cache") = std::nullopt, nb::arg("q_pe") = std::nullopt,
-       nb::arg("block_ids_per_seq") = std::nullopt, nb::arg("attention_sinks") = std::nullopt, nb::arg("is_fused_qkv"),
-       nb::arg("update_kv_cache"), nb::arg("predicted_tokens_per_seq"), nb::arg("layer_idx"), nb::arg("num_heads"),
-       nb::arg("num_kv_heads"), nb::arg("head_size"), nb::arg("tokens_per_block") = std::nullopt,
-       nb::arg("max_num_requests"), nb::arg("max_context_length"), nb::arg("attention_window_size"),
-       nb::arg("sink_token_length"), nb::arg("beam_width"), nb::arg("mask_type"), nb::arg("quant_mode"),
-       nb::arg("q_scaling"), nb::arg("position_embedding_type"), nb::arg("rotary_embedding_dim"),
-       nb::arg("rotary_embedding_base"), nb::arg("rotary_embedding_scale_type"), nb::arg("rotary_embedding_scales"),
+       nb::arg("kv_cache_block_offsets") = std::nullopt, nb::arg("host_kv_cache_pool_pointers") = std::nullopt,
+       nb::arg("host_kv_cache_pool_mapping") = std::nullopt, nb::arg("cache_indirection") = std::nullopt,
+       nb::arg("kv_scale_orig_quant") = std::nullopt, nb::arg("kv_scale_quant_orig") = std::nullopt,
+       nb::arg("out_scale") = std::nullopt, nb::arg("rotary_inv_freq") = std::nullopt,
+       nb::arg("rotary_cos_sin") = std::nullopt, nb::arg("latent_cache") = std::nullopt,
+       nb::arg("q_pe") = std::nullopt, nb::arg("block_ids_per_seq") = std::nullopt,
+       nb::arg("attention_sinks") = std::nullopt, nb::arg("is_fused_qkv"), nb::arg("update_kv_cache"),
+       nb::arg("predicted_tokens_per_seq"), nb::arg("layer_idx"), nb::arg("num_heads"), nb::arg("num_kv_heads"),
+       nb::arg("head_size"), nb::arg("tokens_per_block") = std::nullopt, nb::arg("max_num_requests"),
+       nb::arg("max_context_length"), nb::arg("attention_window_size"), nb::arg("sink_token_length"),
+       nb::arg("beam_width"), nb::arg("mask_type"), nb::arg("quant_mode"), nb::arg("q_scaling"),
+       nb::arg("position_embedding_type"), nb::arg("rotary_embedding_dim"), nb::arg("rotary_embedding_base"),
+       nb::arg("rotary_embedding_scale_type"), nb::arg("rotary_embedding_scales"),
        nb::arg("rotary_embedding_max_position_info"), nb::arg("use_paged_context_fmha"),
        nb::arg("attention_input_type") = std::nullopt, nb::arg("is_mla_enable"),
        nb::arg("chunked_prefill_buffer_batch_size") = std::nullopt, nb::arg("q_lora_rank") = std::nullopt,

cpp/tensorrt_llm/plugins/gptAttentionPlugin/gptAttentionPlugin.cpp

Lines changed: 0 additions & 6 deletions
@@ -858,7 +858,6 @@ int GPTAttentionPlugin::enqueueSome(int32_t seqIdxBeg, int32_t localNbSeq, int32

    int max_blocks_per_sequence = 0;
    kernels::KVBlockArray::DataType* block_offsets = nullptr;
-   kernels::KVBlockArray::DataType* host_block_offsets = nullptr;
    void* host_primary_pool_pointer = nullptr;
    void* host_secondary_pool_pointer = nullptr;
    if (useKVCache() && mPagedKVCache)

@@ -882,10 +881,6 @@ int GPTAttentionPlugin::enqueueSome(int32_t seqIdxBeg, int32_t localNbSeq, int32
            = reinterpret_cast<kernels::KVBlockArray::DataType*>(inputs[getIdx(IdxEntry::KV_CACHE_BLOCK_OFFSETS)])
            + poolOffset + seqOffset;

-       host_block_offsets
-           = reinterpret_cast<kernels::KVBlockArray::DataType*>(inputs[getIdx(IdxEntry::HOST_KV_CACHE_BLOCK_OFFSETS)])
-           + poolOffset + seqOffset;
-
        auto const* const typed_host_pool_pointers
            = static_cast<char* const*>(inputs[getIdx(IdxEntry::HOST_KV_CACHE_POOL_POINTERS)]);

@@ -1046,7 +1041,6 @@ int GPTAttentionPlugin::enqueueSome(int32_t seqIdxBeg, int32_t localNbSeq, int32
    common_enqueue_params.max_past_kv_length = max_context_kv_len;
    EnqueueContextParams<T> enqueue_params{common_enqueue_params};
    enqueue_params.attention_packed_mask = attention_packed_mask;
-   enqueue_params.host_block_offsets = host_block_offsets;
    enqueue_params.batch_size = batch_size;
    enqueue_params.mrope_rotary_cos_sin = mrope_rotary_cos_sin;
    enqueue_params.total_kv_len = enqueue_params.num_tokens;
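
The deleted block in the middle hunk computed a host-side copy of the block-offsets pointer using the same offset arithmetic as the device-side pointer kept just above it; with no consumer left, only the device-side computation remains. A minimal sketch of that arithmetic (simplified types and a hypothetical helper, not the plugin's actual code):

```cpp
// Minimal sketch (simplified types, hypothetical helper): a per-pool, per-sequence slice
// of a block-offsets tensor is reached by casting the type-erased plugin input to its
// element type and advancing by the pool and sequence offsets, mirroring the arithmetic
// in the hunk above.
#include <cstddef>
#include <cstdint>

using BlockOffsetType = int32_t; // stand-in for kernels::KVBlockArray::DataType

BlockOffsetType const* blockOffsetsForSequence(
    void const* const* inputs, int inputIdx, std::ptrdiff_t poolOffset, std::ptrdiff_t seqOffset)
{
    // inputs[] holds type-erased tensor pointers; pick the block-offsets tensor and
    // advance to the slice belonging to this pool and request.
    auto const* base = reinterpret_cast<BlockOffsetType const*>(inputs[inputIdx]);
    return base + poolOffset + seqOffset;
}
```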

cpp/tensorrt_llm/plugins/gptAttentionPlugin/gptAttentionPlugin.h

Lines changed: 1 addition & 2 deletions
@@ -55,8 +55,7 @@ namespace tensorrt_llm::plugins
 // all elements must be identical.
 // 8. past_key_value_pool [batch_size, 2, local_num_kv_heads, max_seq_len, head_size] or
 //    block_offsets [batch_size, 2, max_blocks_per_seq] if paged kv cache (optional)
-// 8.1 host_block_offsets [batch_size, 2, max_blocks_per_seq] if paged kv cache (optional)
-// 8.2 host_pool_pointers [2] if paged kv cache (optional)
+// 8.1 host_pool_pointers [2] if paged kv cache (optional)
 // 9. kv_cache_quantization_scale [1] (optional)
 // 10. kv_cache_dequantization_scale [1] (optional)
 // 11. attention_output_quantization_scale [1] (on device, optional)

cpp/tensorrt_llm/pybind/thop/bindings.cpp

Lines changed: 13 additions & 13 deletions
@@ -42,19 +42,19 @@ void initBindings(pybind11::module_& m)
        py::arg("output_sf") = std::nullopt, py::arg("workspace_") = std::nullopt, py::arg("sequence_length"),
        py::arg("host_past_key_value_lengths"), py::arg("host_total_kv_lens"), py::arg("context_lengths"),
        py::arg("host_context_lengths"), py::arg("host_request_types"),
-       py::arg("kv_cache_block_offsets") = std::nullopt, py::arg("host_kv_cache_block_offsets") = std::nullopt,
-       py::arg("host_kv_cache_pool_pointers") = std::nullopt, py::arg("host_kv_cache_pool_mapping") = std::nullopt,
-       py::arg("cache_indirection") = std::nullopt, py::arg("kv_scale_orig_quant") = std::nullopt,
-       py::arg("kv_scale_quant_orig") = std::nullopt, py::arg("out_scale") = std::nullopt,
-       py::arg("rotary_inv_freq") = std::nullopt, py::arg("rotary_cos_sin") = std::nullopt,
-       py::arg("latent_cache") = std::nullopt, py::arg("q_pe") = std::nullopt,
-       py::arg("block_ids_per_seq") = std::nullopt, py::arg("attention_sinks") = std::nullopt, py::arg("is_fused_qkv"),
-       py::arg("update_kv_cache"), py::arg("predicted_tokens_per_seq"), py::arg("layer_idx"), py::arg("num_heads"),
-       py::arg("num_kv_heads"), py::arg("head_size"), py::arg("tokens_per_block") = std::nullopt,
-       py::arg("max_num_requests"), py::arg("max_context_length"), py::arg("attention_window_size"),
-       py::arg("sink_token_length"), py::arg("beam_width"), py::arg("mask_type"), py::arg("quant_mode"),
-       py::arg("q_scaling"), py::arg("position_embedding_type"), py::arg("rotary_embedding_dim"),
-       py::arg("rotary_embedding_base"), py::arg("rotary_embedding_scale_type"), py::arg("rotary_embedding_scales"),
+       py::arg("kv_cache_block_offsets") = std::nullopt, py::arg("host_kv_cache_pool_pointers") = std::nullopt,
+       py::arg("host_kv_cache_pool_mapping") = std::nullopt, py::arg("cache_indirection") = std::nullopt,
+       py::arg("kv_scale_orig_quant") = std::nullopt, py::arg("kv_scale_quant_orig") = std::nullopt,
+       py::arg("out_scale") = std::nullopt, py::arg("rotary_inv_freq") = std::nullopt,
+       py::arg("rotary_cos_sin") = std::nullopt, py::arg("latent_cache") = std::nullopt,
+       py::arg("q_pe") = std::nullopt, py::arg("block_ids_per_seq") = std::nullopt,
+       py::arg("attention_sinks") = std::nullopt, py::arg("is_fused_qkv"), py::arg("update_kv_cache"),
+       py::arg("predicted_tokens_per_seq"), py::arg("layer_idx"), py::arg("num_heads"), py::arg("num_kv_heads"),
+       py::arg("head_size"), py::arg("tokens_per_block") = std::nullopt, py::arg("max_num_requests"),
+       py::arg("max_context_length"), py::arg("attention_window_size"), py::arg("sink_token_length"),
+       py::arg("beam_width"), py::arg("mask_type"), py::arg("quant_mode"), py::arg("q_scaling"),
+       py::arg("position_embedding_type"), py::arg("rotary_embedding_dim"), py::arg("rotary_embedding_base"),
+       py::arg("rotary_embedding_scale_type"), py::arg("rotary_embedding_scales"),
        py::arg("rotary_embedding_max_position_info"), py::arg("use_paged_context_fmha"),
        py::arg("attention_input_type") = std::nullopt, py::arg("is_mla_enable"),
        py::arg("chunked_prefill_buffer_batch_size") = std::nullopt, py::arg("q_lora_rank") = std::nullopt,
