Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 38 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,42 @@
cmake_minimum_required(VERSION 3.10)

# ---------------------------------------------------------------------------
# Project version detection
#
# The version is taken from `git describe --tags --long --match "v*"`, whose
# output has the form `v<X.Y.Z>-<commits since tag>-g<abbreviated hash>`:
#   * exactly on a tag        -> DETECTED_VERSION = X.Y.Z
#   * N commits past a tag    -> DETECTED_VERSION = X.Y.Z+git<hash>
#   * git or a matching tag
#     is unavailable          -> DETECTED_VERSION = 0.0.0 (with a warning)
#
# project() is declared once, after detection, with the numeric X.Y.Z part
# (NUMERIC_VERSION); DETECTED_VERSION keeps the full string.
# ---------------------------------------------------------------------------
find_package(Git QUIET)

if(GIT_FOUND)
  execute_process(
    COMMAND "${GIT_EXECUTABLE}" describe --tags --long --match "v*"
    WORKING_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}"
    RESULT_VARIABLE GIT_DESCRIBE_RESULT
    OUTPUT_VARIABLE GIT_DESCRIBE_OUTPUT
    ERROR_QUIET
    OUTPUT_STRIP_TRAILING_WHITESPACE
  )
  # Only parse the output when git itself reported success; a failed run
  # (not a repository, no matching tag, ...) falls through to the fallback.
  if(GIT_DESCRIBE_RESULT EQUAL 0)
    if("${GIT_DESCRIBE_OUTPUT}" MATCHES "^v([0-9]+\\.[0-9]+\\.[0-9]+)-([0-9]+)-g([0-9a-f]+)$")
      set(GIT_VERSION "${CMAKE_MATCH_1}")
      set(GIT_DISTANCE "${CMAKE_MATCH_2}")
      set(GIT_HASH "${CMAKE_MATCH_3}")
      if("${GIT_DISTANCE}" STREQUAL "0")
        # HEAD is the tagged commit itself: use the clean release version.
        set(DETECTED_VERSION "${GIT_VERSION}")
      else()
        # HEAD is past the tag: mark the build with the commit hash.
        set(DETECTED_VERSION "${GIT_VERSION}+git${GIT_HASH}")
      endif()
      message(STATUS "Version from git tag: ${DETECTED_VERSION}")
    endif()
  endif()
endif()

# An undefined variable expands to the empty string, so this single test
# covers both "not defined" and "defined but empty".
if("${DETECTED_VERSION}" STREQUAL "")
  set(DETECTED_VERSION "0.0.0")
  message(WARNING "Could not detect version from git tag, using fallback: ${DETECTED_VERSION}")
endif()

# Strip +gitXXXXXXX suffix for CMake project VERSION (must be numeric X.Y.Z)
if("${DETECTED_VERSION}" MATCHES "^([0-9]+\\.[0-9]+\\.[0-9]+)")
  set(NUMERIC_VERSION "${CMAKE_MATCH_1}")
else()
  set(NUMERIC_VERSION "0.0.0")
endif()

project(MainProject VERSION "${NUMERIC_VERSION}")
message(STATUS "Project version: ${PROJECT_VERSION} (full: ${DETECTED_VERSION})")

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
Expand Down
1 change: 0 additions & 1 deletion VERSION

This file was deleted.

42 changes: 42 additions & 0 deletions build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,54 @@ for arg in "$@"; do
BUILD_TYPE="release"
shift
;;
--clean)
BUILD_TYPE="clean"
shift
;;
*)
# Unknown option
;;
esac
done

# Handle clean
# When BUILD_TYPE is "clean", delete every artifact listed below and exit 0
# without building anything. All paths are relative to the current directory.
if [ "$BUILD_TYPE" = "clean" ]; then
    echo "=== Cleaning all build artifacts ==="

    # Remove CMake build directory
    if [ -d "build" ]; then
        rm -rf build
        echo "Removed build/"
    fi

    # Remove compiled .so files in package directory.
    # Guarded so a missing flexkv/ does not make find report an error.
    if [ -d "flexkv" ]; then
        find flexkv -name "*.so" -type f -delete -print | sed 's/^/Removed /'
    fi

    # Remove copied libs directory
    if [ -d "flexkv/lib" ]; then
        rm -rf flexkv/lib
        echo "Removed flexkv/lib/"
    fi

    # Remove Python build artifacts.
    # -prune stops find from descending into a directory that the loop is
    # about to delete; IFS= / -r keep paths with leading or trailing
    # whitespace or backslashes intact.
    find . -maxdepth 2 -name "*.egg-info" -type d -prune -print | while IFS= read -r d; do
        rm -rf -- "$d"
        echo "Removed $d"
    done
    # Only remove top-level dist/ (Python build output), not csrc/dist/ source directory
    if [ -d "dist" ]; then
        rm -rf dist
        echo "Removed dist/"
    fi
    find . -name "__pycache__" -type d -prune -print | while IFS= read -r d; do
        rm -rf -- "$d"
        echo "Removed $d"
    done

    echo "=== Clean completed ==="
    exit 0
fi

echo "=== Building in ${BUILD_TYPE} mode ==="

# Install submodules
Expand Down
6 changes: 3 additions & 3 deletions flexkv/cache/cache_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,7 @@ def __init__(self, cache_config: CacheConfig, model_config: ModelConfig, redis_m

if cache_config.enable_cpu:
if cache_config.enable_p2p_cpu:
self.cpu_cache_engine = HierarchyLRCacheEngine.from_cache_config(cache_config, self.node_id, DeviceType.CPU, meta=self.redis_meta) #TODO
self.cpu_cache_engine = HierarchyLRCacheEngine.from_cache_config(cache_config, self.node_id, DeviceType.CPU, meta=self.redis_meta, pp_rank=self.model_config.pp_rank, pp_size=self.model_config.pp_size) #TODO
elif self.index_accel:
self.cpu_cache_engine = CacheEngineAccel(DeviceType.CPU,
cache_config.num_cpu_blocks,
Expand All @@ -420,7 +420,7 @@ def __init__(self, cache_config: CacheConfig, model_config: ModelConfig, redis_m
self.cache_engines[DeviceType.CPU] = self.cpu_cache_engine
if cache_config.enable_ssd:
if cache_config.enable_p2p_ssd:
self.ssd_cache_engine = HierarchyLRCacheEngine.from_cache_config(cache_config, self.node_id, DeviceType.SSD, meta=self.redis_meta) #TODO
self.ssd_cache_engine = HierarchyLRCacheEngine.from_cache_config(cache_config, self.node_id, DeviceType.SSD, meta=self.redis_meta, pp_rank=self.model_config.pp_rank, pp_size=self.model_config.pp_size) #TODO
elif self.index_accel:
self.ssd_cache_engine = CacheEngineAccel(DeviceType.SSD,
cache_config.num_ssd_blocks,
Expand All @@ -445,7 +445,7 @@ def __init__(self, cache_config: CacheConfig, model_config: ModelConfig, redis_m
if cache_config.enable_remote:
if cache_config.enable_kv_sharing:
# Build PCFSCacheEngine from CacheConfig directly (replacing RemotePCFSCacheEngine) TODO
self.remote_cache_engine = HierarchyLRCacheEngine.from_cache_config(cache_config, self.node_id, DeviceType.REMOTE, meta=self.redis_meta)
self.remote_cache_engine = HierarchyLRCacheEngine.from_cache_config(cache_config, self.node_id, DeviceType.REMOTE, meta=self.redis_meta, pp_rank=self.model_config.pp_rank, pp_size=self.model_config.pp_size)
elif self.index_accel:
self.remote_cache_engine = CacheEngineAccel(DeviceType.REMOTE,
cache_config.num_remote_blocks,
Expand Down
34 changes: 23 additions & 11 deletions flexkv/cache/hie_cache_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@ def __init__(self,
evict_start_threshold: float = 1.0,
hit_reward_seconds: int = 0,
eviction_policy: str = "lru",
meta: Optional[RedisMeta] = None) -> None:
meta: Optional[RedisMeta] = None,
pp_rank: int = 0,
pp_size: int = 1) -> None:
if num_total_blocks <= 0:
raise ValueError(f"Invalid num_total_blocks: {num_total_blocks}")
if tokens_per_block <= 0 or (tokens_per_block & (tokens_per_block - 1)) != 0:
Expand Down Expand Up @@ -90,6 +92,8 @@ def __init__(self,
self.num_total_blocks = num_total_blocks
self.evict_ratio = evict_ratio
self.evict_start_threshold = evict_start_threshold
self.pp_rank = pp_rank
self.pp_size = pp_size

# cumulative statistics: for analyzing distributed KV reuse benefits
self._stats_total_queried_tokens = 0 # total tokens queried
Expand All @@ -102,17 +106,22 @@ def start(self) -> None:
if self._meta is None:
raise ValueError("RedisMeta is not provided; ensure from_cache_config stores it or pass it to start().")
#TODO can we use like this to distinguish the different tree pairs?
# Determine base block key prefix by device type
if self.device_type == DeviceType.REMOTE:
local_ch_block_key = "PCFSB"
remote_ch_block_key = "PCFSB"
base_key = "PCFSB"
elif self.device_type == DeviceType.CPU:
local_ch_block_key = "CPUB"
remote_ch_block_key = "CPUB"
base_key = "CPUB"
elif self.device_type == DeviceType.SSD:
local_ch_block_key = "SSDB"
remote_ch_block_key = "SSDB"
base_key = "SSDB"
else:
raise ValueError(f"Invalid device type: {self.device_type}")

if self.pp_size > 1:
local_ch_block_key = f"{base_key}:pp{self.pp_rank}"
remote_ch_block_key = f"{base_key}:pp{self.pp_rank}"
else:
local_ch_block_key = base_key
remote_ch_block_key = base_key
self.remote_ch = self._meta.get_redis_meta_channel(remote_ch_block_key)
self.local_ch = self._meta.get_redis_meta_channel(local_ch_block_key)
# Load and store mapping of node_id -> file_nodeids from Redis
Expand Down Expand Up @@ -443,7 +452,7 @@ def recycle(self, physical_blocks: np.ndarray) -> None:

#TODO pcfs may not work now
@classmethod
def pcfs_ce_from_cache_config(cls, cache_config: "CacheConfig", node_id: int, meta: Optional[RedisMeta] = None) -> "HierarchyLRCacheEngine":
def pcfs_ce_from_cache_config(cls, cache_config: "CacheConfig", node_id: int, meta: Optional[RedisMeta] = None, pp_rank: int = 0, pp_size: int = 1) -> "HierarchyLRCacheEngine":
"""Create a PCFSCacheEngine from CacheConfig.

This replaces RemotePCFSCacheEngine. It wires both local and remote
Expand Down Expand Up @@ -522,14 +531,16 @@ def pcfs_ce_from_cache_config(cls, cache_config: "CacheConfig", node_id: int, me
local_safety_ttl_ms=int(GLOBAL_CONFIG_FROM_ENV.safety_ttl_ms),
eviction_policy=GLOBAL_CONFIG_FROM_ENV.eviction_policy,
meta=meta,
pp_rank=pp_rank,
pp_size=pp_size,
)

#TODO is this enough for peercpu and peerssd?
@classmethod
def from_cache_config(cls, cache_config: "CacheConfig", node_id: int, device_type: DeviceType, meta: Optional[RedisMeta] = None) -> "HierarchyLRCacheEngine":
def from_cache_config(cls, cache_config: "CacheConfig", node_id: int, device_type: DeviceType, meta: Optional[RedisMeta] = None, pp_rank: int = 0, pp_size: int = 1) -> "HierarchyLRCacheEngine":

if device_type == DeviceType.REMOTE:
return cls.pcfs_ce_from_cache_config(cache_config, node_id, meta)
return cls.pcfs_ce_from_cache_config(cache_config, node_id, meta, pp_rank=pp_rank, pp_size=pp_size)
else:
# select correct blocks configuration based on device_type
if device_type == DeviceType.CPU:
Expand Down Expand Up @@ -563,6 +574,7 @@ def from_cache_config(cls, cache_config: "CacheConfig", node_id: int, device_typ
hit_reward_seconds=int(GLOBAL_CONFIG_FROM_ENV.hit_reward_seconds),
eviction_policy=GLOBAL_CONFIG_FROM_ENV.eviction_policy,
meta=meta,
pp_rank=pp_rank,
pp_size=pp_size,
)
raise ValueError("Invalid device type: {cache_config.device_type}")

56 changes: 40 additions & 16 deletions flexkv/cache/redis_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,9 @@ def register_node(self) -> Optional[int]:
"local_ip": self.local_ip, # Keep for backward compatibility
"uuid": self.uuid,
"status": "active",
"timestamp": str(int(time.time()))
"timestamp": str(int(time.time())),
"pp_rank": str(getattr(self, 'pp_rank', 0)),
"pp_size": str(getattr(self, 'pp_size', 1)),
})

# Publish node update event
Expand Down Expand Up @@ -504,13 +506,14 @@ def add_node_ids(self, node_ids: Iterable[Union[int, str]]) -> int:
# rpush returns the new length of the list
return int(r.rpush(f"pcfs:{nid}", *values))

def regist_buffer(self, mrs: Iterable[object]) -> int:
def regist_buffer(self, mrs: Iterable[object], pp_rank: int = 0, pp_size: int = 1) -> int:
"""Register RDMA memory regions in Redis.

Each element in mrs can be one of:
- dict with keys {"buffer_ptr": ..., "buffer_size": ...}
- tuple/list (buffer_ptr, buffer_size)
Stored as hash: key = buffer:<node_id>:<buffer_ptr>, field "buffer_size" = <buffer_size>.
Stored as hash: key = buffer:<node_id>[:pp<pp_rank>]:<buffer_ptr>, field "buffer_size" = <buffer_size>.
When pp_size > 1, pp_rank is included in the key for isolation.
Returns the number of regions processed.
"""
nid = self.get_node_id()
Expand All @@ -527,53 +530,68 @@ def regist_buffer(self, mrs: Iterable[object]) -> int:
continue
if ptr is None or size is None:
continue
key = f"buffer:{nid}:{int(ptr)}"
if pp_size > 1:
key = f"buffer:{nid}:pp{pp_rank}:{int(ptr)}"
else:
key = f"buffer:{nid}:{int(ptr)}"
pipe.hset(key, mapping={"buffer_size": int(size)})
processed += 1
if processed:
pipe.execute()
return processed

def unregist_buffer(self, buffer_ptr: Union[int, str], pp_rank: int = 0, pp_size: int = 1) -> bool:
    """Unregister a previously registered RDMA memory region by buffer_ptr.

    Deletes key buffer:<node_id>[:pp<pp_rank>]:<buffer_ptr>; the pp<pp_rank>
    segment is part of the key only when pp_size > 1.

    Args:
        buffer_ptr: Buffer address; converted with int() before building the key.
        pp_rank: Rank used in the key when pp_size > 1.
        pp_size: When <= 1 the key carries no pp segment.

    Returns:
        True if the key existed and was deleted, otherwise False.
    """
    nid = self.get_node_id()
    if pp_size > 1:
        key = f"buffer:{nid}:pp{pp_rank}:{int(buffer_ptr)}"
    else:
        key = f"buffer:{nid}:{int(buffer_ptr)}"
    # delete() reports how many keys it removed, so a separate exists() call
    # (an extra round-trip, and a check-then-act race) is not needed.
    return bool(self._client().delete(key))

def regist_node_meta(self, node_id: int, addr: str, zmq_addr: str, cpu_buffer_ptr: int, ssd_buffer_ptr: int, pp_rank: int = 0, pp_size: int = 1) -> None:
    """Register node meta information as a Redis hash.

    Key: meta:<node_id>[:pp<pp_rank>]
    When pp_size > 1, pp_rank is included in the key for PP rank isolation.
    Fields: node_id (int), addr (str), zmq_addr (str), cpu_buffer_ptr (int),
    ssd_buffer_ptr (int), pp_rank (int), pp_size (int)
    """
    r = self._client()
    if pp_size > 1:
        key = f"meta:{int(node_id)}:pp{pp_rank}"
    else:
        key = f"meta:{int(node_id)}"
    # Every value is coerced explicitly so the stored hash has a stable shape
    # regardless of what the caller passed in.
    r.hset(key, mapping={
        "node_id": int(node_id),
        "addr": str(addr),
        "zmq_addr": str(zmq_addr),
        "cpu_buffer_ptr": int(cpu_buffer_ptr),
        "ssd_buffer_ptr": int(ssd_buffer_ptr),
        "pp_rank": int(pp_rank),
        "pp_size": int(pp_size),
    })

def get_node_meta(self, node_id: int) -> dict:
def get_node_meta(self, node_id: int, pp_rank: int = 0, pp_size: int = 1) -> dict:
"""Get node meta information from Redis.

Reads key meta:<node_id> and returns a dict with fields:
Reads key meta:<node_id>[:pp<pp_rank>] and returns a dict with fields:
node_id (int), addr (str), cpu_buffer_ptr (int), ssd_buffer_ptr (int).
Returns empty dict if the key does not exist.
"""
r = self._client()
key = f"meta:{int(node_id)}"
if pp_size > 1:
key = f"meta:{int(node_id)}:pp{pp_rank}"
else:
key = f"meta:{int(node_id)}"
data = r.hgetall(key)
if not data:
return {}
Expand All @@ -588,10 +606,16 @@ def get_node_meta(self, node_id: int) -> dict:
out["ssd_buffer_ptr"] = int(sb) if sb is not None and sb != "" else 0
return out

def unregist_node_meta(self, node_id: int, pp_rank: int = 0, pp_size: int = 1) -> bool:
    """Unregister node meta by node_id. Returns True if deleted.

    Deletes key meta:<node_id>[:pp<pp_rank>].
    When pp_size > 1, only deletes the key for the specified pp_rank.
    """
    r = self._client()
    if pp_size > 1:
        key = f"meta:{int(node_id)}:pp{pp_rank}"
    else:
        key = f"meta:{int(node_id)}"
    # delete() returns the number of keys removed; truthy means it existed.
    return bool(r.delete(key))


Expand Down
16 changes: 16 additions & 0 deletions flexkv/common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,17 @@
from flexkv.common.storage import KVCacheLayout, KVCacheLayoutType
from flexkv.common.debug import flexkv_logger


@dataclass
class IndexerCacheConfig:
    """Cache settings specific to the indexer; carried inside CacheConfig."""

    # --- Indexer head layout ---
    # qk_rope_head_dim for DSA/NSA models
    head_size: int = 0
    # typically 1 for MLA-style indexer
    num_kv_heads: int = 1
    # indexer storage dtype (fp8 quantized)
    dtype: torch.dtype = torch.uint8
    page_size: int = 1


@dataclass
class ModelConfig:
num_layers: int = 1
Expand All @@ -22,6 +33,8 @@ class ModelConfig:
# parallel configs
tp_size: int = 1
dp_size: int = 1
pp_size: int = 1
pp_rank: int = 0

@property
def token_size_in_bytes(self) -> int:
Expand All @@ -46,6 +59,9 @@ class CacheConfig:
num_tmp_cpu_blocks: int = 500 # only used when distributed ssd p2p, it controls the number blocks of temp cpu buffer which used for copy data from ssd to cpu


# Indexer configuration
indexer: Optional[IndexerCacheConfig] = None

# mempool capacity configs
num_cpu_blocks: int = 1000000
num_ssd_blocks: int = 10000000
Expand Down
7 changes: 7 additions & 0 deletions flexkv/common/transfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
class CompletedOp:
graph_id: int
op_id: int
transfer_type: Optional[str] = None
num_blocks: int = 0
num_bytes: int = 0

def is_graph_completed(self) -> bool:
return self.op_id == -1
Expand Down Expand Up @@ -96,6 +99,10 @@ class TransferOp:
remote_node_ids: Optional[np.ndarray] = None
# used for distributed cpu and ssd
src_block_node_ids: Optional[np.ndarray] = None
# pending_count tracks how many workers (main KV + indexer) have not yet completed this op.
# Initialized to 1; incremented before submitting to indexer worker.
# _scheduler_loop decrements it on each worker completion; finalization happens only when it reaches 0.
pending_count: int = 1

def __post_init__(self) -> None:
if self.transfer_type != TransferType.VIRTUAL and \
Expand Down
Loading