Skip to content

feat: add multi-platform support with ascend and maca#1335

Open
zhangts20 wants to merge 18 commits into
ModelTC:mainfrom
zhangts20:refactor_platform_dev
Open

feat: add multi-platform support with ascend and maca#1335
zhangts20 wants to merge 18 commits into
ModelTC:mainfrom
zhangts20:refactor_platform_dev

Conversation

@zhangts20

Copy link
Copy Markdown

No description provided.

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a unified platform abstraction layer (lightllm/platform) to support multiple hardware backends (CUDA, Ascend/NPU, MUSA, MACA) and implements a new Paged FlashAttention-3 backend (PagedFa3AttBackend). It refactors memory management, attention states, and multimodal models to use device-agnostic target devices and platform-specific runtimes. Feedback on these changes highlights several critical issues: returning CPU tensors instead of device tensors in ReqManager's paged allocation, using a fixed-size attention mask in PagedFa3PrefillAttState that risks out-of-bounds errors, allocating softmax_lse with an incorrect size of 1 during NPU graph capture, omitting data type conversion in BaseLayerWeight._to_device, and missing pinned memory allocation for CPU reference tensors in TpPartBaseModel which introduces synchronization overhead.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

mask = ~need_new_page_mask
if mask.any():
token_idxs[mask] = b_last_mem_index[mask] + 1
return token_idxs

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

In the second branch of _alloc_paged_mem_indices, token_idxs is created and computed on the CPU. However, alloc_mem_indices is expected to return a device tensor (as it does in the first branch and when page_size == 1). Returning a CPU tensor here will cause a device mismatch crash during the model forward pass or when asserting model_input.mem_indexes.device == self.target_device. The returned tensor must be moved to the target device.

Suggested change
return token_idxs
return token_idxs.to(device=self.mem_manager.target_device)

Comment on lines +68 to +69
if self.atten_mask is None:
self.atten_mask = torch.triu(torch.ones([2048, 2048]), diagonal=1).to(dtype=torch.int8, device=self.infer_state.input_ids.device)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The attention mask self.atten_mask is initialized with a fixed size of [2048, 2048]. If the input sequence length (max_kv_seq_len) exceeds 2048, this will cause out-of-bounds errors or incorrect attention computation. It should be dynamically sized or resized when max_kv_seq_len exceeds the current mask size. Additionally, allocating the tensor directly on the target device avoids unnecessary host-to-device copy overhead.

Suggested change
if self.atten_mask is None:
self.atten_mask = torch.triu(torch.ones([2048, 2048]), diagonal=1).to(dtype=torch.int8, device=self.infer_state.input_ids.device)
if self.atten_mask is None or self.atten_mask.shape[0] < self.infer_state.max_kv_seq_len:
mask_size = max(2048, self.infer_state.max_kv_seq_len)
self.atten_mask = torch.triu(
torch.ones([mask_size, mask_size], dtype=torch.int8, device=self.infer_state.input_ids.device),
diagonal=1,
)

v = v.view(-1, self.backend.page_size, N_KV * HEAD_DIM)

output = torch.empty_like(q)
softmax_lse = torch.empty(1, dtype=torch.float16, device=q.device)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

During NPU graph capture, tensors must be pre-allocated with their exact shapes. Allocating softmax_lse with size 1 and passing it to .out will cause out-of-bounds memory writes during graph replay, leading to silent memory corruption or crashes. It should be allocated with the correct shape (e.g., [q.shape[0], q.shape[1], 1]) and dtype (typically torch.float32).

Suggested change
softmax_lse = torch.empty(1, dtype=torch.float16, device=q.device)
softmax_lse = torch.empty((q.shape[0], q.shape[1], 1), dtype=torch.float32, device=q.device)

Comment on lines +45 to +46
def _to_device(self, cpu_tensor: torch.Tensor) -> torch.Tensor:
return cpu_tensor.contiguous().to(device=self.target_device)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The new _to_device method does not convert the tensor to the target data type (self.data_type_), unlike the original _cuda method. This can lead to weight tensors remaining in float32 (or whatever their source format was), causing runtime type mismatches during forward pass or doubling the memory usage.

Suggested change
def _to_device(self, cpu_tensor: torch.Tensor) -> torch.Tensor:
return cpu_tensor.contiguous().to(device=self.target_device)
def _to_device(self, cpu_tensor: torch.Tensor) -> torch.Tensor:
return cpu_tensor.contiguous().to(device=self.target_device, dtype=self.data_type_)

Comment on lines +25 to +26
torch.empty(shared_len, dtype=torch.int32).to(get_current_device_id()),
torch.empty(shared_len, dtype=torch.int32).to(get_current_device_id()),

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Allocating tensors on CPU and then moving them to the device using .to(get_current_device_id()) is inefficient and causes unnecessary host-to-device copies. It is much better to allocate them directly on the target device using device=self.model.target_device.

Suggested change
torch.empty(shared_len, dtype=torch.int32).to(get_current_device_id()),
torch.empty(shared_len, dtype=torch.int32).to(get_current_device_id()),
torch.empty(shared_len, dtype=torch.int32, device=self.model.target_device),
torch.empty(shared_len, dtype=torch.int32, device=self.model.target_device),

Comment on lines +139 to +140
self.b1_cu_q_seq_len_cpu_ref = torch.zeros(self.graph_max_batch_size, dtype=torch.int32)
self.b_cu_kv_seq_len_cpu_ref = torch.zeros(self.graph_max_batch_size, dtype=torch.int32)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The CPU reference tensors self.b1_cu_q_seq_len_cpu_ref and self.b_cu_kv_seq_len_cpu_ref are used for asynchronous copying and updating attention parameters on NPU streams. To avoid host-device synchronization overhead and enable fast asynchronous transfers, these tensors should be allocated with pin_memory=True.

Suggested change
self.b1_cu_q_seq_len_cpu_ref = torch.zeros(self.graph_max_batch_size, dtype=torch.int32)
self.b_cu_kv_seq_len_cpu_ref = torch.zeros(self.graph_max_batch_size, dtype=torch.int32)
self.b1_cu_q_seq_len_cpu_ref = torch.zeros(self.graph_max_batch_size, dtype=torch.int32, pin_memory=True)
self.b_cu_kv_seq_len_cpu_ref = torch.zeros(self.graph_max_batch_size, dtype=torch.int32, pin_memory=True)

total_pages_needed, p_token_len = self._expand_by_page_size(b_token_len, page_size)
paged_token_idxs = self.mem_manager.alloc(total_pages_needed * page_size)
pages = paged_token_idxs.view(-1, page_size)
mask = torch.arange(page_size, device=p_token_len.device) < p_token_len.unsqueeze(1)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The boolean mask mask is created on the CPU because p_token_len is on the CPU. Indexing the device tensor pages with a CPU mask causes synchronous host-device transfers and slow CPU-GPU synchronization. The mask should be created directly on the device of pages.

Suggested change
mask = torch.arange(page_size, device=p_token_len.device) < p_token_len.unsqueeze(1)
mask = torch.arange(page_size, device=pages.device) < p_token_len.to(device=pages.device).unsqueeze(1)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant