diff --git a/lightllm/common/basemodel/triton_kernel/gen_prefill_params.py b/lightllm/common/basemodel/triton_kernel/gen_prefill_params.py index 8f9172b55..f066766b1 100644 --- a/lightllm/common/basemodel/triton_kernel/gen_prefill_params.py +++ b/lightllm/common/basemodel/triton_kernel/gen_prefill_params.py @@ -19,7 +19,7 @@ def _gen_cumsum_pad0_kernel( for start_index in range(0, size, BLOCK): current_offs = start_index + offs - in_data = tl.load(b_q_seq_len + offs, mask=current_offs < size, other=0) + in_data = tl.load(b_q_seq_len + current_offs, mask=current_offs < size, other=0) in_data = tl.cumsum(in_data) + start_value start_value = tl.max(in_data, 0) tl.store(b1_cu_q_seq_len + current_offs + 1, in_data, mask=current_offs < size) @@ -30,7 +30,7 @@ def _gen_cumsum_pad0_kernel( start_value = tl.cast(0, tl.int64) for start_index in range(0, size, BLOCK): current_offs = start_index + offs - in_data = tl.load(b_kv_seq_len + offs * b_kv_seq_len_stride_0, mask=current_offs < size, other=0) + in_data = tl.load(b_kv_seq_len + current_offs * b_kv_seq_len_stride_0, mask=current_offs < size, other=0) in_data = tl.cumsum(in_data) + start_value start_value = tl.max(in_data, 0) tl.store(b1_cu_kv_seq_len + current_offs + 1, in_data, mask=current_offs < size)