141 changes: 105 additions & 36 deletions megatron/core/inference/contexts/dynamic_context.py
@@ -758,6 +758,26 @@ def initialize_all_tensors(self) -> None:
self.token_to_position_in_request = torch.empty_like(self.token_to_input_ids)
self.token_to_local_position_within_kv_block = torch.empty_like(self.token_to_input_ids)

# Static tensor addresses of active slices to enable fast inference kernels.
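# These buffers are allocated once so their device addresses never change;
# build_active_slices() refreshes their contents in place with copy_() on every
# step, which keeps them safe to read from inside a captured CUDA graph.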
self.active_request_metadata: Dict[str, Tensor] = {}
for label, _, on_gpu in self.request_metadata_types:
if on_gpu:
tensor = torch.empty_like(self.request_metadata[label])
else:
tensor = torch.empty_like(
self.request_metadata[label], device="cpu", pin_memory=True
)
self.active_request_metadata[label] = tensor

self.active_request_ids = torch.empty_like(self.request_ids, dtype=torch.int64)
self.active_request_query_lengths = torch.empty_like(self.request_query_lengths)
self.active_request_output_lengths = torch.empty_like(self.request_output_lengths)
self.active_request_kv_length_offsets = torch.empty_like(self.request_kv_length_offsets)
self.active_request_to_kv_block_ids = torch.empty_like(self.request_to_kv_block_ids)

self.active_sequence_lengths = torch.empty_like(self.request_query_lengths)
self.active_request_last_token_idxs = torch.empty_like(self.request_query_lengths)

# NOTE: Need to build this outside the UVM / TMS context to avoid IMA.
if self.is_hybrid_model:
self.mamba_metadata = MambaMetadata(
@@ -938,20 +958,69 @@ def cu_kv_lengths(self) -> Tuple[Tensor, Tensor, int]:
self.active_attn_metadata["mha_metadata"].state_data["max_seqlen_k"],
)

def get_active_sequence_lengths(self) -> Tensor:
"""Total sequence length (query + key) for active requests."""
lengths = self.request_kv_length_offsets + self.request_query_lengths
lengths = lengths[self.paused_request_count : self.total_request_count]
return lengths

def get_max_sequence_lengths(self) -> Tensor:
"""Maximum sequence length for active requests."""
return self.request_output_lengths[self.paused_request_count : self.total_request_count]

def get_active_request_count(self):
"""Returns the current number of active requests."""
return self.total_request_count - self.paused_request_count

def build_active_slices(self, batch_size: int):
"""Build the active slices of specific tensors. This is run on every forward step.

If the context is reordered to active -> paused -> finished, this can be graphed.
"""
padded_slice = slice(self.paused_request_count, self.paused_request_count + batch_size)

# All request metadata needs to be sliced.
for label, _, _ in self.request_metadata_types:
self.active_request_metadata[label][:batch_size].copy_(
self.request_metadata[label][padded_slice], non_blocking=True
)

# The following tensor slices are used in various kernels.
self.active_request_ids[:batch_size].copy_(self.request_ids[padded_slice])
self.active_request_query_lengths[:batch_size].copy_(
self.request_query_lengths[padded_slice]
)
self.active_request_output_lengths[:batch_size].copy_(
self.request_output_lengths[padded_slice]
)
self.active_request_kv_length_offsets[:batch_size].copy_(
self.request_kv_length_offsets[padded_slice]
)
self.active_request_to_kv_block_ids[:batch_size].copy_(
self.request_to_kv_block_ids[padded_slice]
)

self.active_sequence_lengths[:batch_size].copy_(
(self.active_request_query_lengths + self.active_request_kv_length_offsets)[:batch_size]
)
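# Last-token index of each request within the packed token dimension: for query
# lengths [3, 1, 1], the cumulative sums are [3, 4, 5], so the last tokens sit at
# positions [2, 3, 4].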
graph_scratch_space = torch.cumsum(self.active_request_query_lengths[:batch_size], dim=0)
self.active_request_last_token_idxs[:batch_size].copy_(graph_scratch_space - 1)

def pad_active_slices(self):
"""Pad the active slices of specific tensors."""
# Some tensors need to be padded at the token level.
padding_token_slice = slice(self.active_token_count, self.padded_active_token_count)

self.token_to_block_idx[padding_token_slice] = self.kv_block_allocator.dummy_block_idx
self.token_to_local_position_within_kv_block[padding_token_slice] = 0
self.token_to_position_in_request[padding_token_slice] = 0

# Other tensors need to be padded at the request level.
padding_request_slice = slice(
self.total_request_count - self.paused_request_count,
self.padded_active_request_count,
)

def append_key_value_cache(self, layer_number: int, key: Tensor, value: Tensor) -> None:
"""Append to KV cache.

@@ -1566,33 +1635,13 @@ def initialize_attention_state(
prefill_req_count=padded_prefill_req_count,
decode_req_count=padded_decode_req_count,
)
self.padded_active_token_count = self.padded_batch_dimensions.token_count
self.padded_active_request_count = self.padded_batch_dimensions.req_count
self.padding_slice = slice(self.active_token_count, self.padded_active_token_count)

# Update token position indexes.
self.token_to_block_idx[self.active_token_count : self.padded_active_token_count] = (
self.kv_block_allocator.dummy_block_idx
)
self.token_to_local_position_within_kv_block[
self.active_token_count : self.padded_active_token_count
] = 0
self.token_to_position_in_request[
self.active_token_count : self.padded_active_token_count
] = 0

self.active_attn_metadata = (
self.graph_attn_metadata # type: ignore[assignment]
if self.using_cuda_graph_this_step()
else self.non_graph_attn_metadata # type: ignore[assignment]
)

# Update cu_query_seq_lengths, max_seqlen_q.
active_slice = slice(self.paused_request_count, self.total_request_count)
query_lengths_view = self.request_query_lengths[active_slice]
request_kv_length_offsets_view = self.request_kv_length_offsets[active_slice]
request_to_kv_block_ids_view = self.request_to_kv_block_ids[active_slice]

attn_dimensions = batch_dimensions
if self.using_cuda_graph_this_step():
# Treat some decode requests as prefill requests to fit the cuda graph batch dimension.
@@ -1606,17 +1655,38 @@
decode_req_count=adjusted_decode_req_count,
)

self.padded_active_token_count = self.padded_batch_dimensions.token_count
self.padded_active_request_count = self.padded_batch_dimensions.req_count
self.padding_slice = slice(self.active_token_count, self.padded_active_token_count)

import sys
_rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
_dbg_f = open(f"/tmp/tde_debug_rank{_rank}.log", "a")
def _dbg(msg):
_dbg_f.write(f"[ATTN] {msg}\n"); _dbg_f.flush()
print(f"[rank{_rank}] [ATTN] {msg}", flush=True, file=sys.stderr)

_dbg(f"build_active_slices start (padded_req={self.padded_active_request_count}, padded_tok={self.padded_active_token_count}, paused={self.paused_request_count}, total={self.total_request_count})")
self.build_active_slices(self.padded_active_request_count)
_dbg("build_active_slices done")
self.pad_active_slices()
_dbg("pad_active_slices done")

batch_size = self.total_request_count - self.paused_request_count
assert self.active_attn_metadata is not None
_dbg(f"mha_metadata.update start (batch_size={batch_size})")
self.active_attn_metadata["mha_metadata"].update(
request_query_lengths=query_lengths_view,
request_kv_length_offsets=request_kv_length_offsets_view,
request_to_kv_block_ids=request_to_kv_block_ids_view,
request_query_lengths=self.active_request_query_lengths[:batch_size],
request_kv_length_offsets=self.active_request_kv_length_offsets[:batch_size],
request_to_kv_block_ids=self.active_request_to_kv_block_ids[:batch_size],
batch_dimensions=attn_dimensions,
padded_batch_dimensions=self.padded_batch_dimensions,
num_speculative_tokens=self.num_speculative_tokens,
)
_dbg("mha_metadata.update done")

if self.is_hybrid_model:
active_slice = slice(self.paused_request_count, self.total_request_count)
active_mamba_indices_view = self.mamba_metadata.request_to_mamba_state_idx[active_slice]
token_to_request_idx_view = self.token_to_request_idx[: self.active_token_count]
cu_seqlens = self.active_attn_metadata["mha_metadata"].state_data[
@@ -1774,9 +1844,8 @@ def last_token_logits(self, logits: Tensor) -> Tensor:
f"logits.size(1) ({tuple(logits.shape)}) != "
f"padded_active_token_count ({self.padded_active_token_count})."
)
logits_2d = logits.squeeze(0)
last_token_idxs = torch.cumsum(query_lengths, dim=0) - 1
return logits_2d[last_token_idxs, :]
active_request_count = self.total_request_count - self.paused_request_count
return logits.squeeze(0)[self.active_request_last_token_idxs[:active_request_count], :]

def _compute_prefix_match(
self, req: DynamicInferenceRequest, prefill_chunk_length: int
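The active_* buffers above follow the standard CUDA-graph pattern: allocate once, refresh in place, replay. A minimal standalone sketch of that pattern (illustrative only, not part of this diff; assumes a CUDA device is available):

import torch

def replay_with_static_buffers():
    # Buffers are allocated once, so their device addresses never change.
    static_in = torch.zeros(8, device="cuda")
    static_out = torch.zeros(8, device="cuda")

    # Warm up on a side stream before capture, per the usual CUDA graph recipe.
    side = torch.cuda.Stream()
    side.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side):
        static_out.copy_(static_in * 2)
    torch.cuda.current_stream().wait_stream(side)

    # Capture the work once against the static addresses.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_out.copy_(static_in * 2)

    # Each step: refresh the inputs in place (as build_active_slices does), then replay.
    static_in.copy_(torch.arange(8, device="cuda", dtype=torch.float32))
    graph.replay()
    return static_out  # holds [0, 2, 4, ..., 14]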
25 changes: 25 additions & 0 deletions megatron/core/inference/engines/dynamic_engine.py
@@ -334,7 +334,16 @@ def create_cuda_graphs(self, reset_context: bool = True):
reset_context (bool): Whether to reset the context after building cuda graphs.
"""

import sys
_rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
_dbg_f = open(f"/tmp/tde_debug_rank{_rank}.log", "a")
def _dbg(msg):
_dbg_f.write(f"[CG] {msg}\n"); _dbg_f.flush()
print(f"[rank{_rank}] [CG] {msg}", flush=True, file=sys.stderr)

_dbg(f"create_cuda_graphs start (impl={self.cuda_graph_impl})")
if self.cuda_graph_impl != "local":
_dbg("skipping (not local)")
return

if (
@@ -393,12 +402,15 @@ def create_cuda_graphs(self, reset_context: bool = True):
)

tbar = enumerate(context.cuda_graph_batch_dimensions_list)
_dbg(f"warmup loop start ({len(context.cuda_graph_batch_dimensions_list)} graphs)")
if HAVE_TQDM:
tbar = tqdm(tbar, total=len(context.cuda_graph_batch_dimensions_list))
for tbar_idx, cuda_graph_batch_dimension in tbar:
_dbg(f"warmup iter {tbar_idx}: context_init start ({cuda_graph_batch_dimension})")
input_ids, position_ids = self.controller._dynamic_step_context_init(
construct_graph_dimensions=cuda_graph_batch_dimension
)
_dbg(f"warmup iter {tbar_idx}: context_init done")
# Progress.
tbar_str = f"cuda graph warmup - {cuda_graph_batch_dimension}"
if HAVE_TQDM:
@@ -1630,12 +1642,23 @@ async def async_forward(self) -> Tuple[Dict, Dict, float]:
step_time (float): How long this step took.
"""

import sys
_rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
_dbg_f = open(f"/tmp/tde_debug_rank{_rank}.log", "a")
def _dbg(msg):
_dbg_f.write(f"[ENGINE] {msg}\n"); _dbg_f.flush()
print(f"[rank{_rank}] [ENGINE] {msg}", flush=True, file=sys.stderr)

_dbg("async_forward enter")

# If suspended, no stepping.
if self.state in (EngineState.SUSPENDED, EngineState.SUSPENDING):
raise EngineSuspendedError(self.context.step_count)

# schedule requests
_dbg("schedule_waiting_requests start")
self.schedule_waiting_requests()
_dbg(f"schedule_waiting_requests done (total={self.context.total_request_count}, paused={self.context.paused_request_count}, tokens={self.context.active_token_count})")

# Saving pre-step state, for printing output below.
is_decode_only = self.context.is_decode_only()
Expand All @@ -1654,7 +1677,9 @@ async def async_forward(self) -> Tuple[Dict, Dict, float]:
self.is_decode_only = is_decode_only

self.step_start_event.record()
_dbg("async_generate_output_tokens_dynamic_batch start")
result = await self.controller.async_generate_output_tokens_dynamic_batch()
_dbg("async_generate_output_tokens_dynamic_batch done")
self.step_end_event.record()
self.step_end_event.synchronize()
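# Event.elapsed_time() reports milliseconds; the division by 1e3 below converts to seconds.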
step_time = self.step_start_event.elapsed_time(self.step_end_event) / 1e3
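The same per-rank debug logger is pasted into create_cuda_graphs and async_forward (and into initialize_attention_state in the context); if it is kept beyond debugging, it could be expressed once as a small factory. A hypothetical sketch, not part of this diff:

import sys
import torch

def make_debug_logger(tag: str, path_template: str = "/tmp/tde_debug_rank{rank}.log"):
    """Return a callable that appends tagged messages to a per-rank file and stderr."""
    rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
    log_file = open(path_template.format(rank=rank), "a")

    def _dbg(msg: str) -> None:
        log_file.write(f"[{tag}] {msg}\n")
        log_file.flush()
        print(f"[rank{rank}] [{tag}] {msg}", flush=True, file=sys.stderr)

    return _dbg

# Usage inside e.g. async_forward():
#     _dbg = make_debug_logger("ENGINE")
#     _dbg("async_forward enter")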