From 95f4ca459269257cb0704000f0dec19e1b354f04 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Mon, 8 Jun 2026 07:06:09 +0000 Subject: [PATCH] fix prefill_params when prefill num_reqs > 1024 The chunked cumsum loops in _gen_cumsum_pad0_kernel addressed their tl.load calls with `offs` instead of `current_offs`, so the read position did not advance with start_index even though the mask and the tl.store already used current_offs. Address both loads (b_q_seq_len and b_kv_seq_len) with current_offs. --- lightllm/common/basemodel/triton_kernel/gen_prefill_params.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lightllm/common/basemodel/triton_kernel/gen_prefill_params.py b/lightllm/common/basemodel/triton_kernel/gen_prefill_params.py index 8f9172b552..f066766b15 100644 --- a/lightllm/common/basemodel/triton_kernel/gen_prefill_params.py +++ b/lightllm/common/basemodel/triton_kernel/gen_prefill_params.py @@ -19,7 +19,7 @@ def _gen_cumsum_pad0_kernel( for start_index in range(0, size, BLOCK): current_offs = start_index + offs - in_data = tl.load(b_q_seq_len + offs, mask=current_offs < size, other=0) + in_data = tl.load(b_q_seq_len + current_offs, mask=current_offs < size, other=0) in_data = tl.cumsum(in_data) + start_value start_value = tl.max(in_data, 0) tl.store(b1_cu_q_seq_len + current_offs + 1, in_data, mask=current_offs < size) @@ -30,7 +30,7 @@ def _gen_cumsum_pad0_kernel( start_value = tl.cast(0, tl.int64) for start_index in range(0, size, BLOCK): current_offs = start_index + offs - in_data = tl.load(b_kv_seq_len + offs * b_kv_seq_len_stride_0, mask=current_offs < size, other=0) + in_data = tl.load(b_kv_seq_len + current_offs * b_kv_seq_len_stride_0, mask=current_offs < size, other=0) in_data = tl.cumsum(in_data) + start_value start_value = tl.max(in_data, 0) tl.store(b1_cu_kv_seq_len + current_offs + 1, in_data, mask=current_offs < size)