Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 8 additions & 23 deletions dflash/scripts/_prefill_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ class PrefillConfig:
keep_ratio: float # 0.015..0.125
drafter_gguf: Optional[Path] # drafter weights (Qwen3-0.6B BF16 GGUF)
drafter_tokenizer_id: str # HF repo ID for drafter vocab
skip_park: bool = False # skip park/unpark on >=32 GB GPUs

@property
def enabled(self) -> bool:
Expand Down Expand Up @@ -93,11 +92,6 @@ def add_cli_flags(ap) -> None:
ap.add_argument("--prefill-drafter-tokenizer", default="Qwen/Qwen3-0.6B",
help="HF repo ID for the drafter tokenizer "
"(default Qwen/Qwen3-0.6B).")
ap.add_argument("--prefill-skip-park", action="store_true", default=False,
help="Skip park/unpark/free-drafter on GPUs with enough VRAM "
"to hold target + draft + scorer simultaneously (e.g. "
"RTX 5090 32 GB). Keeps scorer resident for fast "
"subsequent compressions.")


def config_from_args(args) -> PrefillConfig:
Expand All @@ -115,7 +109,6 @@ def config_from_args(args) -> PrefillConfig:
keep_ratio=args.prefill_keep_ratio,
drafter_gguf=args.prefill_drafter,
drafter_tokenizer_id=args.prefill_drafter_tokenizer,
skip_park=getattr(args, 'prefill_skip_park', False),
)


Expand All @@ -128,17 +121,11 @@ def compress_text_via_daemon(
drafter_tokenizer,
cfg: PrefillConfig,
prompt_text: str,
skip_park: bool = False,
) -> str:
"""Run the daemon's compress + memory dance, return the compressed text.

Caller holds the daemon lock for the full duration. After this returns,
the daemon has its target + draft restored and is ready for ``generate``.

When ``skip_park`` is True (e.g. on 32 GB+ GPUs where all three models
fit in VRAM simultaneously), the park/unpark/free-drafter steps are
skipped. The daemon's compress handler will keep the scorer loaded for
subsequent requests, avoiding the ~2s reload penalty per request.
"""
# 1) drafter-tokenize the prompt
drafter_ids = drafter_tokenizer(prompt_text, return_tensors=None,
Expand All @@ -153,23 +140,21 @@ def compress_text_via_daemon(
for t in drafter_ids:
f.write(struct.pack("<i", int(t)))

if not skip_park:
# 3) park target + draft so drafter has VRAM headroom on a 24 GB card
_send_and_ack(daemon_stdin, r_pipe, "park target\n")
_send_and_ack(daemon_stdin, r_pipe, "park draft\n")
# 3) park target + draft so drafter has VRAM headroom on a 24 GB card
_send_and_ack(daemon_stdin, r_pipe, "park target\n")
_send_and_ack(daemon_stdin, r_pipe, "park draft\n")

# 4) compress: drafter loads, FlashPrefill scoring, emit compressed ids, drafter held
keep_x1000 = int(round(cfg.keep_ratio * 1000))
daemon_stdin.write(
f"compress {path} {keep_x1000} {cfg.drafter_gguf}\n".encode("utf-8"))
f"compress\t{path}\t{keep_x1000}\t{cfg.drafter_gguf}\n".encode("utf-8"))

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1: Tab-separated compress output is incompatible with the daemon parser, which only matches compress and parses space-separated args.

Prompt for AI agents
Check if this issue is valid — if so, understand the root cause and fix it. At dflash/scripts/_prefill_hook.py, line 150:

<comment>Tab-separated `compress` output is incompatible with the daemon parser, which only matches `compress ` and parses space-separated args.</comment>

<file context>
@@ -153,23 +140,21 @@ def compress_text_via_daemon(
         keep_x1000 = int(round(cfg.keep_ratio * 1000))
         daemon_stdin.write(
-            f"compress {path} {keep_x1000} {cfg.drafter_gguf}\n".encode("utf-8"))
+            f"compress\t{path}\t{keep_x1000}\t{cfg.drafter_gguf}\n".encode("utf-8"))
         daemon_stdin.flush()
         compressed_ids = _drain_until_sentinel(r_pipe)
</file context>
Suggested change
f"compress\t{path}\t{keep_x1000}\t{cfg.drafter_gguf}\n".encode("utf-8"))
f"compress {path} {keep_x1000} {cfg.drafter_gguf}\n".encode("utf-8"))

daemon_stdin.flush()
compressed_ids = _drain_until_sentinel(r_pipe)

if not skip_park:
# 5) free drafter weights + BSA scratch, then restore target + draft
_send_and_ack(daemon_stdin, r_pipe, "free drafter\n")
_send_and_ack(daemon_stdin, r_pipe, "unpark target\n")
_send_and_ack(daemon_stdin, r_pipe, "unpark draft\n")
# 5) free drafter weights + BSA scratch, then restore target + draft
_send_and_ack(daemon_stdin, r_pipe, "free drafter\n")
_send_and_ack(daemon_stdin, r_pipe, "unpark target\n")
_send_and_ack(daemon_stdin, r_pipe, "unpark draft\n")
finally:
try: os.unlink(path)
except Exception: pass
Expand Down
88 changes: 18 additions & 70 deletions dflash/scripts/bench_he.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,35 +178,33 @@
]


def _find_draft_file(root: Path) -> str | None:
def _find_safetensors(root: Path) -> str | None:
if root.is_file():
return str(root) if root.suffix in (".safetensors", ".gguf") else None
return str(root)
if not root.is_dir():
return None
for st in root.rglob("model.safetensors"):
return str(st)
for gguf in root.rglob("*.gguf"):
return str(gguf)
return None


def _resolve_draft() -> str:
env = os.environ.get("DFLASH_DRAFT")
if env:
found = _find_draft_file(Path(env))
found = _find_safetensors(Path(env))
if found:
return found
raise FileNotFoundError(f"DFLASH_DRAFT does not point to a draft file: {env}")
raise FileNotFoundError(f"DFLASH_DRAFT does not point to model.safetensors: {env}")

for candidate in (_LOCAL_DRAFT_FILE, _LOCAL_DRAFT_ROOT):
found = _find_draft_file(candidate)
found = _find_safetensors(candidate)
if found:
return found

raise FileNotFoundError(
"draft model file not found. Expected one of:\n"
"draft model.safetensors not found. Expected one of:\n"
f" - {_LOCAL_DRAFT_FILE}\n"
"Download it as documented in the README, or set DFLASH_DRAFT to an explicit .safetensors/.gguf file or directory."
"Download it as documented in the README, or set DFLASH_DRAFT to an explicit file or directory."
)


Expand Down Expand Up @@ -261,37 +259,17 @@ def run_test_dflash(prompt_path: Path, n_gen: int, fast_rollback: bool,
print("STDERR:", r.stderr[-2000:])
raise RuntimeError(f"test_dflash exited {r.returncode}")

# Parse output. The target layer-split harness prints both prefill and
# decode lines, so avoid the older "first tok/s wins" regexp there.
# Parse output
out = r.stdout
m_prefill = re.search(
r"\[target-split\] prefill tokens=(\d+) time=(\d+(?:\.\d+)?) s speed=(\d+(?:\.\d+)?) tok/s",
out,
)
m_decode_split = re.search(
r"\[target-split-dflash\] decode tokens=(\d+) time=(\d+(?:\.\d+)?) s speed=(\d+(?:\.\d+)?) tok/s",
out,
)
m_decode_default = re.search(
r"\[dflash\] generated \d+ tokens in \d+(?:\.\d+)? s\s+->\s+(\d+(?:\.\d+)?) tok/s",
out,
)
m_tps = re.search(r"(\d+(?:\.\d+)?)\s+tok/s", out)
m_commit = re.search(r"avg commit/step=(\d+(?:\.\d+)?)", out)
m_accept = re.search(r"accepted=(\d+)/(\d+) \((\d+(?:\.\d+)?)%", out)
m_steps = re.search(r"(\d+) draft steps", out)
if not ((m_decode_split or m_decode_default or m_tps) and m_commit and m_accept and m_steps):
if not (m_tps and m_commit and m_accept and m_steps):
print("STDOUT tail:", out[-2000:])
raise RuntimeError("failed to parse output")
if m_decode_split:
tok_s = float(m_decode_split.group(3))
elif m_decode_default:
tok_s = float(m_decode_default.group(1))
else:
tok_s = float(m_tps.group(1))
return {
"tok_s": tok_s,
"prefill_tok_s": float(m_prefill.group(3)) if m_prefill else None,
"tok_s": float(m_tps.group(1)),
"commit_per_step": float(m_commit.group(1)),
"accepted": int(m_accept.group(1)),
"total_draft_pos": int(m_accept.group(2)),
Expand Down Expand Up @@ -322,26 +300,12 @@ def main():
help="Visible CUDA device id for the target backend")
ap.add_argument("--draft-gpu", type=int, default=None,
help="Visible CUDA device id for the draft backend")
ap.add_argument("--target-gpus", default=None,
help="Comma-separated target GPU ids for the layer-split harness")
ap.add_argument("--target-layer-split", default=None,
help="Comma-separated layer split weights matching --target-gpus")
ap.add_argument("--target-split-load-draft", action="store_true",
help="Load the draft alongside the target layer-split harness")
ap.add_argument("--target-split-dflash", action="store_true",
help="Run chain DFlash decode through the target layer-split harness")
ap.add_argument("--max-ctx", type=int, default=None,
help="Forward --max-ctx=N to test_dflash")
ap.add_argument("--prefill-ubatch", type=int, default=None,
help="Set DFLASH27B_PREFILL_UBATCH for target split prefill")
ap.add_argument("--cuda-visible-devices", default=None,
help="Optional CUDA_VISIBLE_DEVICES override for test_dflash")
ap.add_argument("--target-tokenizer",
default=os.environ.get("DFLASH_TOKENIZER", "Qwen/Qwen3.5-27B"),
default=os.environ.get("DFLASH_TOKENIZER", "Qwen/Qwen3.6-27B"),
help="HuggingFace tokenizer repo for the target. Defaults to "
"$DFLASH_TOKENIZER, then Qwen/Qwen3.5-27B. Override for "
"Qwen3.6 or other variants, e.g. "
"--target-tokenizer Qwen/Qwen3.6-27B")
"$DFLASH_TOKENIZER, then Qwen/Qwen3.6-27B.")
args = ap.parse_args()

# Tokenized prompts are cached at TMPDIR/he_prompt_<slug>_NN.bin so
Expand Down Expand Up @@ -375,8 +339,8 @@ def main():
print(f"[bench] skipping tokenize (reusing {_prompt_path(0, tok_slug).parent})")

print(f"\n[bench] mode={args.mode} n_gen={args.n_gen}")
print(f"{'prompt':28s} {'steps':>6s} {'AL':>6s} {'pct%':>6s} {'prefill':>8s} {'decode':>8s}")
print("-" * 72)
print(f"{'prompt':28s} {'steps':>6s} {'AL':>6s} {'pct%':>6s} {'tok/s':>8s}")
print("-" * 62)

extra_args = []
if args.draft_feature_mirror:
Expand All @@ -385,29 +349,17 @@ def main():
extra_args.append(f"--target-gpu={args.target_gpu}")
if args.draft_gpu is not None:
extra_args.append(f"--draft-gpu={args.draft_gpu}")
if args.target_gpus:
extra_args.append(f"--target-gpus={args.target_gpus}")
if args.target_layer_split:
extra_args.append(f"--target-layer-split={args.target_layer_split}")
if args.target_split_load_draft:
extra_args.append("--target-split-load-draft")
if args.target_split_dflash:
extra_args.append("--target-split-dflash")
if args.max_ctx is not None:
extra_args.append(f"--max-ctx={args.max_ctx}")

extra_env = {}
if args.cuda_visible_devices:
extra_env["CUDA_VISIBLE_DEVICES"] = args.cuda_visible_devices
if args.prefill_ubatch is not None:
extra_env["DFLASH27B_PREFILL_UBATCH"] = str(args.prefill_ubatch)

results = []
for i, (name, _) in enumerate(PROMPTS):
path = _prompt_path(i, tok_slug)
try:
r = run_test_dflash(path, args.n_gen,
fast_rollback=(args.mode == "fast" and not args.target_split_dflash),
fast_rollback=(args.mode == "fast"),
ddtree_budget=args.ddtree_budget,
ddtree_temp=args.ddtree_temp,
ddtree_no_chain_seed=args.ddtree_no_chain_seed,
Expand All @@ -417,10 +369,9 @@ def main():
print(f" [{i:02d}] {name:26s} FAILED: {e}")
continue
results.append((name, r))
prefill_s = f"{r['prefill_tok_s']:8.2f}" if r["prefill_tok_s"] is not None else f"{'n/a':>8s}"
print(
f" {name:26s} {r['steps']:6d} {r['commit_per_step']:6.2f} "
f"{r['pct']:6.1f} {prefill_s} {r['tok_s']:8.2f}"
f"{r['pct']:6.1f} {r['tok_s']:8.2f}"
)

if not results:
Expand All @@ -431,12 +382,9 @@ def main():
mean_al = sum(r["commit_per_step"] for _, r in results) / n
mean_tps = sum(r["tok_s"] for _, r in results) / n
mean_pct = sum(r["pct"] for _, r in results) / n
prefill_vals = [r["prefill_tok_s"] for _, r in results if r["prefill_tok_s"] is not None]
mean_prefill = sum(prefill_vals) / len(prefill_vals) if prefill_vals else None

print("-" * 72)
prefill_s = f"{mean_prefill:8.2f}" if mean_prefill is not None else f"{'n/a':>8s}"
print(f"{'MEAN':28s} {'':6s} {mean_al:6.2f} {mean_pct:6.1f} {prefill_s} {mean_tps:8.2f}")
print("-" * 62)
print(f"{'MEAN':28s} {'':6s} {mean_al:6.2f} {mean_pct:6.1f} {mean_tps:8.2f}")
print()
print(f"commit/step range: {min(r['commit_per_step'] for _,r in results):.2f} - "
f"{max(r['commit_per_step'] for _,r in results):.2f}")
Expand Down
Loading