Skip to content

Conversation

ggerganov
Copy link
Member

@ggerganov ggerganov commented Oct 2, 2025

target #16440
rel #16117

Initial version of automatic memory offloading to host memory using an extended logic for minimizing the prompt reprocessing. The host-memory prompt cache acts as "extra slots" with which we can calculate prefix similarity and decide to hot-swap them into the llama_context if it would reduce the processing.

Still WIP, but probably should be useable already.

Note: mtmd workarounds are starting to cause some headaches. For example server_tokens is not copyable which complicates the cache logic and makes the prompt caching feature incompatible with mtmd.

Server refactor

  • Replace server_slot members with a single server_task
  • Remove server_slot.n_predict
  • Remove prompt truncation logic (obsolete and not useful anymore)
  • slot.task is now const ptr to reflect that the task parameters should not change when it is passed to the slot

TODOs

  • Set memory limit for the host-memory cache from CLI
  • Clean-up implementation
  • Test with agentic workflows
  • Multi-slot tests
  • Fix progress report

@ggerganov ggerganov force-pushed the gg/prompt-cache-ext branch 2 times, most recently from 0787f03 to 5c0cec4 Compare October 3, 2025 18:49
@tommarques56

This comment was marked as spam.

@ggerganov ggerganov force-pushed the gg/prompt-cache-ext branch from 5c0cec4 to 1440ec5 Compare October 7, 2025 07:40
@ggerganov ggerganov changed the base branch from master to gg/server-checkpoints-improve October 7, 2025 07:41
@github-actions github-actions bot added the python python script changes label Oct 7, 2025
@ggerganov ggerganov mentioned this pull request Oct 7, 2025
3 tasks
@ggerganov ggerganov force-pushed the gg/prompt-cache-ext branch from 9de8392 to cf7dd4b Compare October 7, 2025 15:09
@ggerganov
Copy link
Member Author

Looking for some feedback of how this new logic performs in different use cases. I've been testing it with the llama.vscode agent and it significantly improves the experience since we can now use a single server slot without trashing the prompt cache.

The current implementation should work with any model (dense, MoE, SWA, SSM, etc.). I think the default settings should be good for most use cases, though we'll probably add some options to adjust cache limits if needed.

Pay attention to these new messages in the logs:

image

Interested in testing agentic use cases, such as Claude Code and similar, where we have a single large context with various auxilary calls (keyword extraction, summarization, etc.) interleaved. The expectation is that prompt reprocessing should be significantly reduces in such cases.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
examples python python script changes server
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants