From 2fc644c672d7e1b782c3f3635b0eb3df744737db Mon Sep 17 00:00:00 2001
From: fxmarty <9808326+fxmarty@users.noreply.github.com>
Date: Mon, 17 Jun 2024 14:35:38 +0000
Subject: [PATCH 1/9] add dockerfile

---
 Dockerfile-rocm                       | 130 ++++++++++++++++++++++++++
 backends/python/Makefile-flash-att-v2 |  21 +++++
 2 files changed, 151 insertions(+)
 create mode 100644 Dockerfile-rocm
 create mode 100644 backends/python/Makefile-flash-att-v2

diff --git a/Dockerfile-rocm b/Dockerfile-rocm
new file mode 100644
index 00000000..f9e56631
--- /dev/null
+++ b/Dockerfile-rocm
@@ -0,0 +1,130 @@
+FROM rocm/dev-ubuntu-22.04:6.0.2 AS base-builder
+
+ENV SCCACHE=0.5.4
+ENV RUSTC_WRAPPER=/usr/local/bin/sccache
+ENV PATH="/root/.cargo/bin:${PATH}"
+
+RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
+    curl \
+    libssl-dev \
+    pkg-config \
+    && rm -rf /var/lib/apt/lists/*
+
+# Download and configure sccache
+RUN curl -fsSL https://github.com/mozilla/sccache/releases/download/v$SCCACHE/sccache-v$SCCACHE-x86_64-unknown-linux-musl.tar.gz | tar -xzv --strip-components=1 -C /usr/local/bin sccache-v$SCCACHE-x86_64-unknown-linux-musl/sccache && \
+    chmod +x /usr/local/bin/sccache
+
+RUN curl https://sh.rustup.rs -sSf | bash -s -- -y
+RUN cargo install cargo-chef --locked
+
+FROM base-builder AS planner
+
+WORKDIR /usr/src
+
+COPY backends backends
+COPY core core
+COPY router router
+COPY Cargo.toml ./
+COPY Cargo.lock ./
+
+RUN cargo chef prepare --recipe-path recipe.json
+
+FROM base-builder AS builder
+
+ARG CUDA_COMPUTE_CAP=80
+ARG GIT_SHA
+ARG DOCKER_LABEL
+
+# sccache specific variables
+ARG ACTIONS_CACHE_URL
+ARG ACTIONS_RUNTIME_TOKEN
+ARG SCCACHE_GHA_ENABLED
+
+WORKDIR /usr/src
+
+COPY --from=planner /usr/src/recipe.json recipe.json
+
+RUN cargo chef cook --release --features python --no-default-features --recipe-path recipe.json && sccache -s
+
+COPY backends backends
+COPY core core
+COPY router router
+COPY Cargo.toml ./
+COPY Cargo.lock ./
+
+RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
+    unzip \
+    && rm -rf /var/lib/apt/lists/*
+
+RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
+    curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \
+    unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \
+    unzip -o $PROTOC_ZIP -d /usr/local 'include/*' && \
+    rm -f $PROTOC_ZIP
+
+COPY proto proto
+
+FROM builder as http-builder
+
+RUN cargo build --release --bin text-embeddings-router -F python -F http --no-default-features && sccache -s
+
+FROM builder as grpc-builder
+
+RUN cargo build --release --bin text-embeddings-router -F python -F grpc --no-default-features && sccache -s
+
+FROM rocm/dev-ubuntu-22.04:6.0.2 as base
+
+RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
+    git \
+    python3-dev \
+    rocthrust-dev \
+    hipsparse-dev \
+    hipblas-dev \
+    hipblaslt-dev \
+    rocblas-dev \
+    hiprand-dev \
+    rocrand-dev \
+    && rm -rf /var/lib/apt/lists/*
+
+
+# Keep in sync with `server/pyproject.toml`
+ARG MAMBA_VERSION=23.1.0-1
+ARG PYTORCH_VERSION='2.3.0'
+ARG ROCM_VERSION='6.0.2'
+ARG PYTHON_VERSION='3.10.10'
+# Automatically set by buildx
+ARG TARGETPLATFORM
+ENV PATH /opt/conda/bin:$PATH
+
+RUN curl -fsSL -v -o ~/mambaforge.sh -O "https://github.com/conda-forge/miniforge/releases/download/${MAMBA_VERSION}/Mambaforge-${MAMBA_VERSION}-Linux-x86_64.sh"
+RUN chmod +x ~/mambaforge.sh && \
+    bash ~/mambaforge.sh -b -p /opt/conda && \
+    mamba init && \
+    rm
~/mambaforge.sh + +# Install flash-attention, torch dependencies +RUN pip install numpy einops ninja --no-cache-dir + +RUN pip install torch --index-url https://download.pytorch.org/whl/rocm6.0 + +ARG DEFAULT_USE_FLASH_ATTENTION=True +COPY backends/python/Makefile-flash-att-v2 Makefile-flash-att-v2 +RUN make -f Makefile-flash-att-v2 install-flash-attention-v2-rocm + +ENV HUGGINGFACE_HUB_CACHE=/data \ + PORT=80 \ + USE_FLASH_ATTENTION=$DEFAULT_USE_FLASH_ATTENTION + +FROM base as grpc + +COPY --from=grpc-builder /usr/src/target/release/text-embeddings-router /usr/local/bin/text-embeddings-router + +ENTRYPOINT ["text-embeddings-router"] +CMD ["--json-output"] + +FROM base + +COPY --from=http-builder /usr/src/target/release/text-embeddings-router /usr/local/bin/text-embeddings-router + +ENTRYPOINT ["text-embeddings-router"] +CMD ["--json-output"] diff --git a/backends/python/Makefile-flash-att-v2 b/backends/python/Makefile-flash-att-v2 new file mode 100644 index 00000000..ba90a74d --- /dev/null +++ b/backends/python/Makefile-flash-att-v2 @@ -0,0 +1,21 @@ +flash_att_v2_commit_cuda := v2.5.9.post1 +flash_att_v2_commit_rocm := 2554f490101742ccdc56620a938f847f61754be6 + +build-flash-attention-v2-cuda: + pip install -U packaging wheel + pip install flash-attn==$(flash_att_v2_commit_cuda) + +install-flash-attention-v2-cuda: build-flash-attention-v2-cuda + echo "Flash v2 installed" + +build-flash-attention-v2-rocm: + if [ ! -d 'flash-attention-v2' ]; then \ + pip install -U packaging ninja --no-cache-dir && \ + git clone https://github.com/ROCm/flash-attention.git flash-attention-v2 && \ + cd flash-attention-v2 && git fetch && git checkout $(flash_att_v2_commit_rocm) && \ + git submodule update --init --recursive && GPU_ARCHS="gfx90a;gfx942" PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py build; \ + fi + +install-flash-attention-v2-rocm: build-flash-attention-v2-rocm + cd flash-attention-v2 && \ + GPU_ARCHS="gfx90a;gfx942" PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py install From 37d29316243298970ba2d3e889f0757dc2d74e7a Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Tue, 18 Jun 2024 13:05:09 +0000 Subject: [PATCH 2/9] working cls pooling --- backends/python/server/pyproject.toml | 7 +- backends/python/server/requirements.txt | 12 --- .../server/text_embeddings_server/cli.py | 3 +- .../text_embeddings_server/models/__init__.py | 18 ++-- .../models/flash_bert.py | 59 ++++-------- .../server/text_embeddings_server/server.py | 3 +- .../utils/flash_attn.py | 92 ------------------- router/src/lib.rs | 1 + 8 files changed, 38 insertions(+), 157 deletions(-) delete mode 100644 backends/python/server/text_embeddings_server/utils/flash_attn.py diff --git a/backends/python/server/pyproject.toml b/backends/python/server/pyproject.toml index 96fcaf9e..839ff27a 100644 --- a/backends/python/server/pyproject.toml +++ b/backends/python/server/pyproject.toml @@ -20,7 +20,7 @@ loguru = "^0.6.0" opentelemetry-api = "^1.15.0" opentelemetry-exporter-otlp = "^1.15.0" opentelemetry-instrumentation-grpc = "^0.36b0" -torch = { version = "^2.0.1" } +torch = { version = "==2.3.1" } [tool.poetry.extras] @@ -33,6 +33,11 @@ name = "pytorch-gpu-src" url = "https://download.pytorch.org/whl/cu118" priority = "explicit" +[[tool.poetry.source]] +name = "pytorch-gpu-src-rocm" +url = "https://download.pytorch.org/whl/rocm6.0" +priority = "explicit" + [tool.pytest.ini_options] markers = ["private: marks tests as requiring an admin hf token (deselect with '-m \"not private\"')"] diff 
--git a/backends/python/server/requirements.txt b/backends/python/server/requirements.txt index 89ca314d..2d089e41 100644 --- a/backends/python/server/requirements.txt +++ b/backends/python/server/requirements.txt @@ -4,20 +4,13 @@ charset-normalizer==3.2.0 ; python_version >= "3.9" and python_version < "3.13" click==8.1.7 ; python_version >= "3.9" and python_version < "3.13" colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows") deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13" -filelock==3.12.3 ; python_version >= "3.9" and python_version < "3.13" -fsspec==2023.9.0 ; python_version >= "3.9" and python_version < "3.13" googleapis-common-protos==1.60.0 ; python_version >= "3.9" and python_version < "3.13" grpc-interceptor==0.15.3 ; python_version >= "3.9" and python_version < "3.13" grpcio-reflection==1.58.0 ; python_version >= "3.9" and python_version < "3.13" grpcio-status==1.58.0 ; python_version >= "3.9" and python_version < "3.13" grpcio==1.58.0 ; python_version >= "3.9" and python_version < "3.13" -huggingface-hub==0.16.4 ; python_version >= "3.9" and python_version < "3.13" idna==3.4 ; python_version >= "3.9" and python_version < "3.13" -jinja2==3.1.2 ; python_version >= "3.9" and python_version < "3.13" loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13" -markupsafe==2.1.3 ; python_version >= "3.9" and python_version < "3.13" -mpmath==1.3.0 ; python_version >= "3.9" and python_version < "3.13" -networkx==3.1 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13" @@ -27,15 +20,10 @@ opentelemetry-instrumentation==0.36b0 ; python_version >= "3.9" and python_versi opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13" -packaging==23.1 ; python_version >= "3.9" and python_version < "3.13" protobuf==4.24.3 ; python_version >= "3.9" and python_version < "3.13" -pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13" requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13" safetensors==0.3.3 ; python_version >= "3.9" and python_version < "3.13" setuptools==68.2.0 ; python_version >= "3.9" and python_version < "3.13" -sympy==1.12 ; python_version >= "3.9" and python_version < "3.13" -torch==2.0.1 ; python_version >= "3.9" and python_version < "3.13" -tqdm==4.66.1 ; python_version >= "3.9" and python_version < "3.13" typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13" typing-extensions==4.7.1 ; python_version >= "3.9" and python_version < "3.13" urllib3==2.0.4 ; python_version >= "3.9" and python_version < "3.13" diff --git a/backends/python/server/text_embeddings_server/cli.py b/backends/python/server/text_embeddings_server/cli.py index 4c627515..70e60d80 100644 --- a/backends/python/server/text_embeddings_server/cli.py +++ b/backends/python/server/text_embeddings_server/cli.py @@ -23,6 +23,7 @@ def serve( logger_level: str = "INFO", json_output: bool = False, otlp_endpoint: Optional[str] = None, + pooling_mode: Optional[str] = None, ): # Remove default 
handler logger.remove() @@ -47,7 +48,7 @@ def serve( # Downgrade enum into str for easier management later on dtype = None if dtype is None else dtype.value - server.serve(model_path, dtype, uds_path) + server.serve(model_path, dtype, uds_path, pooling_mode) if __name__ == "__main__": diff --git a/backends/python/server/text_embeddings_server/models/__init__.py b/backends/python/server/text_embeddings_server/models/__init__.py index 47867187..7f480b33 100644 --- a/backends/python/server/text_embeddings_server/models/__init__.py +++ b/backends/python/server/text_embeddings_server/models/__init__.py @@ -15,17 +15,19 @@ torch.set_grad_enabled(False) FLASH_ATTENTION = True -try: - from text_embeddings_server.models.flash_bert import FlashBert -except ImportError as e: - logger.warning(f"Could not import Flash Attention enabled models: {e}") - FLASH_ATTENTION = False +# try: +from text_embeddings_server.models.flash_bert import FlashBert +# except ImportError as e: +# logger.warning(f"Could not import Flash Attention enabled models: {e}") +# FLASH_ATTENTION = False if FLASH_ATTENTION: __all__.append(FlashBert) -def get_model(model_path: Path, dtype: Optional[str]): +class + +def get_model(model_path: Path, dtype: Optional[str], pooling_mode: str): if dtype == "float32": dtype = torch.float32 elif dtype == "float16": @@ -52,8 +54,8 @@ def get_model(model_path: Path, dtype: Optional[str]): and dtype in [torch.float16, torch.bfloat16] and FLASH_ATTENTION ): - return FlashBert(model_path, device, dtype) + return FlashBert(model_path, device, dtype, pooling_mode) else: - return DefaultModel(model_path, device, dtype) + return DefaultModel(model_path, device, dtype, pooling_mode) raise NotImplementedError diff --git a/backends/python/server/text_embeddings_server/models/flash_bert.py b/backends/python/server/text_embeddings_server/models/flash_bert.py index 50b8d70d..67176f2c 100644 --- a/backends/python/server/text_embeddings_server/models/flash_bert.py +++ b/backends/python/server/text_embeddings_server/models/flash_bert.py @@ -8,46 +8,15 @@ from transformers.models.bert import BertConfig from opentelemetry import trace -# Flash attention imports -import dropout_layer_norm - from text_embeddings_server.models import Model from text_embeddings_server.models.types import FlashBatch, Embedding -from text_embeddings_server.utils.flash_attn import attention +from text_embeddings_server.layers.attention import attention +from text_embeddings_server.layers.layernorm import FastLayerNorm +from loguru import logger tracer = trace.get_tracer(__name__) -class FastLayerNorm: - def __init__(self, prefix, handle, device, dtype, config: BertConfig): - self.weight = handle.get_tensor(f"{prefix}.weight").to(dtype).to(device) - self.bias = handle.get_tensor(f"{prefix}.bias").to(dtype).to(device) - self.variance_epsilon = config.layer_norm_eps - - def forward(self, hidden_states, residual=None): - normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd( - hidden_states, - residual, - self.weight, - self.bias, - None, - None, - None, - None, - 0.0, - self.variance_epsilon, - 1.0, - 0, - None, - False, - False, - ) - if res is None: - res = hidden_states - - return normed_hidden_states, res - - class BertEmbeddings: def __init__(self, prefix, handle, device, dtype, config: BertConfig): self.word_embeddings_weight = ( @@ -217,7 +186,7 @@ def forward(self, input_ids, token_type_ids, position_ids, cu_seqlens, max_s): embeddings = self.embeddings.forward(input_ids, token_type_ids, position_ids) 
encoder_outputs = self.encoder.forward(embeddings, cu_seqlens, max_s) - return encoder_outputs[cu_seqlens[:-1]] + return encoder_outputs class FlashBert(Model): @@ -236,6 +205,7 @@ def batch_type(self) -> Type[FlashBatch]: @tracer.start_as_current_span("embed") def embed(self, batch: FlashBatch) -> List[Embedding]: + logger.info(f"batch.input_ids {batch.input_ids}") embedding = self.model.forward( input_ids=batch.input_ids, token_type_ids=batch.token_type_ids, @@ -243,11 +213,16 @@ def embed(self, batch: FlashBatch) -> List[Embedding]: cu_seqlens=batch.cu_seqlens, max_s=batch.max_s, ) - cpu_results = embedding.view(-1).tolist() - return [ - Embedding( - values=cpu_results[i * self.hidden_size : (i + 1) * self.hidden_size] - ) - for i in range(len(batch)) - ] + if True: + embedding = embedding[batch.cu_seqlens[:-1]] + logger.info(f"embedding {embedding.shape}") + cpu_results = embedding.view(-1).tolist() + + return [ + Embedding( + values=cpu_results[i * self.hidden_size : (i + 1) * self.hidden_size] + ) + for i in range(len(batch)) + ] + elif diff --git a/backends/python/server/text_embeddings_server/server.py b/backends/python/server/text_embeddings_server/server.py index d0a43ace..2c99cf79 100644 --- a/backends/python/server/text_embeddings_server/server.py +++ b/backends/python/server/text_embeddings_server/server.py @@ -37,6 +37,7 @@ def serve( model_path: Path, dtype: Optional[str], uds_path: Path, + pooling_mode: Optional[str], ): async def serve_inner( model_path: Path, @@ -45,7 +46,7 @@ async def serve_inner( unix_socket = f"unix://{uds_path}" try: - model = get_model(model_path, dtype) + model = get_model(model_path, dtype, pooling_mode) except Exception: logger.exception("Error when initializing model") raise diff --git a/backends/python/server/text_embeddings_server/utils/flash_attn.py b/backends/python/server/text_embeddings_server/utils/flash_attn.py deleted file mode 100644 index 1d325351..00000000 --- a/backends/python/server/text_embeddings_server/utils/flash_attn.py +++ /dev/null @@ -1,92 +0,0 @@ -import os -import torch - -from loguru import logger - -if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false": - raise ImportError("`USE_FLASH_ATTENTION` is false.") - -if not torch.cuda.is_available(): - raise ImportError("CUDA is not available") - -major, minor = torch.cuda.get_device_capability() -is_sm75 = major == 7 and minor == 5 -is_sm8x = major == 8 and minor >= 0 -is_sm90 = major == 9 and minor == 0 - -HAS_FLASH_ATTN = False -HAS_FLASH_ATTN_V2 = False -try: - try: - import flash_attn_2_cuda - except ImportError: - raise ImportError( - "Flash Attention V2 is not installed.\n" - "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) " - "or install flash attention v2 with `cd server && make install install-flash-attention-v2`" - ) - if not (is_sm8x or is_sm90): - raise ImportError( - f"GPU with CUDA capability {major} {minor} is not supported for " - "Flash Attention V2" - ) - HAS_FLASH_ATTN_V2 = True -except ImportError as e: - try: - import flash_attn_cuda - except ImportError: - raise ImportError( - "Flash Attention is not installed.\n" - "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) " - "or install flash attention with `cd server && make install install-flash-attention`" - ) from e - - if not (is_sm75 or is_sm8x or is_sm90): - raise ImportError( - f"GPU with CUDA capability {major} {minor} is not supported" - ) from e - logger.warning(f"Unable to use Flash Attention V2: {e}") - HAS_FLASH_ATTN = 
True - - -def attention(q, k, v, out, cu_seqlens, max_s, softmax_scale, is_causal=False): - if HAS_FLASH_ATTN_V2: - return flash_attn_2_cuda.varlen_fwd( - q, - k, - v, - out, - cu_seqlens, - cu_seqlens, - max_s, - max_s, - 0.0, - softmax_scale, - False, - is_causal, - -1, - -1, - False, - None, - ) - - if HAS_FLASH_ATTN: - return flash_attn_cuda.fwd( - q, - k, - v, - out, - cu_seqlens, - cu_seqlens, - max_s, - max_s, - 0.0, - softmax_scale, - False, - is_causal, - False, - 0, - None, - ) - - raise NotImplementedError("flash attention is not installed") diff --git a/router/src/lib.rs b/router/src/lib.rs index 3801af8a..e387b3cb 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -198,6 +198,7 @@ pub async fn run( backend_model_type, uds_path.unwrap_or("/tmp/text-embeddings-inference-server".to_string()), otlp_endpoint.clone(), + pooling.to_string(), ) .context("Could not create backend")?; backend From 8584b6d4480794625eb62e1e2e3af80538a26f2a Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Tue, 18 Jun 2024 13:07:45 +0000 Subject: [PATCH 3/9] add layers --- .../text_embeddings_server/layers/__init__.py | 0 .../layers/attention/__init__.py | 11 +++ .../layers/attention/cuda.py | 92 +++++++++++++++++++ .../layers/attention/rocm.py | 45 +++++++++ .../layers/layernorm.py | 54 +++++++++++ .../utils/import_utils.py | 12 +++ 6 files changed, 214 insertions(+) create mode 100644 backends/python/server/text_embeddings_server/layers/__init__.py create mode 100644 backends/python/server/text_embeddings_server/layers/attention/__init__.py create mode 100644 backends/python/server/text_embeddings_server/layers/attention/cuda.py create mode 100644 backends/python/server/text_embeddings_server/layers/attention/rocm.py create mode 100644 backends/python/server/text_embeddings_server/layers/layernorm.py create mode 100644 backends/python/server/text_embeddings_server/utils/import_utils.py diff --git a/backends/python/server/text_embeddings_server/layers/__init__.py b/backends/python/server/text_embeddings_server/layers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/backends/python/server/text_embeddings_server/layers/attention/__init__.py b/backends/python/server/text_embeddings_server/layers/attention/__init__.py new file mode 100644 index 00000000..9cce5d34 --- /dev/null +++ b/backends/python/server/text_embeddings_server/layers/attention/__init__.py @@ -0,0 +1,11 @@ +from text_embeddings_server.utils.import_utils import SYSTEM +import os + +if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false": + raise ImportError("`USE_FLASH_ATTENTION` is false.") +if SYSTEM == "cuda": + from .cuda import attention +elif SYSTEM == "rocm": + from .rocm import attention +else: + raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention") diff --git a/backends/python/server/text_embeddings_server/layers/attention/cuda.py b/backends/python/server/text_embeddings_server/layers/attention/cuda.py new file mode 100644 index 00000000..1d325351 --- /dev/null +++ b/backends/python/server/text_embeddings_server/layers/attention/cuda.py @@ -0,0 +1,92 @@ +import os +import torch + +from loguru import logger + +if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false": + raise ImportError("`USE_FLASH_ATTENTION` is false.") + +if not torch.cuda.is_available(): + raise ImportError("CUDA is not available") + +major, minor = torch.cuda.get_device_capability() +is_sm75 = major == 7 and minor == 5 +is_sm8x = major == 8 and minor >= 0 +is_sm90 = 
major == 9 and minor == 0 + +HAS_FLASH_ATTN = False +HAS_FLASH_ATTN_V2 = False +try: + try: + import flash_attn_2_cuda + except ImportError: + raise ImportError( + "Flash Attention V2 is not installed.\n" + "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) " + "or install flash attention v2 with `cd server && make install install-flash-attention-v2`" + ) + if not (is_sm8x or is_sm90): + raise ImportError( + f"GPU with CUDA capability {major} {minor} is not supported for " + "Flash Attention V2" + ) + HAS_FLASH_ATTN_V2 = True +except ImportError as e: + try: + import flash_attn_cuda + except ImportError: + raise ImportError( + "Flash Attention is not installed.\n" + "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) " + "or install flash attention with `cd server && make install install-flash-attention`" + ) from e + + if not (is_sm75 or is_sm8x or is_sm90): + raise ImportError( + f"GPU with CUDA capability {major} {minor} is not supported" + ) from e + logger.warning(f"Unable to use Flash Attention V2: {e}") + HAS_FLASH_ATTN = True + + +def attention(q, k, v, out, cu_seqlens, max_s, softmax_scale, is_causal=False): + if HAS_FLASH_ATTN_V2: + return flash_attn_2_cuda.varlen_fwd( + q, + k, + v, + out, + cu_seqlens, + cu_seqlens, + max_s, + max_s, + 0.0, + softmax_scale, + False, + is_causal, + -1, + -1, + False, + None, + ) + + if HAS_FLASH_ATTN: + return flash_attn_cuda.fwd( + q, + k, + v, + out, + cu_seqlens, + cu_seqlens, + max_s, + max_s, + 0.0, + softmax_scale, + False, + is_causal, + False, + 0, + None, + ) + + raise NotImplementedError("flash attention is not installed") diff --git a/backends/python/server/text_embeddings_server/layers/attention/rocm.py b/backends/python/server/text_embeddings_server/layers/attention/rocm.py new file mode 100644 index 00000000..365e5451 --- /dev/null +++ b/backends/python/server/text_embeddings_server/layers/attention/rocm.py @@ -0,0 +1,45 @@ +import os +import torch +from text_embeddings_server.utils.import_utils import SYSTEM +from loguru import logger + +major, minor = torch.cuda.get_device_capability() +is_sm75 = major == 7 and minor == 5 + +if SYSTEM == "rocm": + try: + import flash_attn_2_cuda + + logger.info("ROCm: using Flash Attention 2 Composable Kernel implementation.") + except ImportError as e: + if major >= 8 or is_sm75: + architecture_suffix = f"-{SYSTEM}" + raise ImportError(f"Flash Attention V2 is not installed. 
{e}") + else: + for idx in range(torch.cuda.device_count()): + name = torch.cuda.get_device_name(idx) + if "MI210" not in name and "MI250" not in name and "MI300" not in name: + raise ImportError( + f"AMD GPU {torch.cuda.get_device_name(idx)} does not support flash-attention" + ) + raise ImportError( + f"AMD GPU with ROCm capability {major} {minor} is not supported" + ) from e + +def attention(q, k, v, out, cu_seqlens, max_s, softmax_scale, is_causal=False): + return flash_attn_2_cuda.varlen_fwd( + q, + k, + v, + out, + cu_seqlens, + cu_seqlens, + max_s, + max_s, + 0.0, + softmax_scale, + False, + is_causal, + False, + None, + ) \ No newline at end of file diff --git a/backends/python/server/text_embeddings_server/layers/layernorm.py b/backends/python/server/text_embeddings_server/layers/layernorm.py new file mode 100644 index 00000000..abd9e676 --- /dev/null +++ b/backends/python/server/text_embeddings_server/layers/layernorm.py @@ -0,0 +1,54 @@ +import torch +from text_embeddings_server.utils.import_utils import SYSTEM + +from transformers.models.bert import BertConfig + +if SYSTEM == "cuda": + import dropout_layer_norm + + class FastLayerNorm: + def __init__(self, prefix, handle, device, dtype, config: BertConfig): + self.weight = handle.get_tensor(f"{prefix}.weight").to(dtype).to(device) + self.bias = handle.get_tensor(f"{prefix}.bias").to(dtype).to(device) + self.variance_epsilon = config.layer_norm_eps + + def forward(self, hidden_states, residual=None): + normed_hidden_states, residual, *rest = dropout_layer_norm.dropout_add_ln_fwd( + hidden_states, + residual, + self.weight, + self.bias, + None, + None, + None, + None, + 0.0, + self.variance_epsilon, + 1.0, + 0, + None, + False, + False, + ) + if residual is None: + residual = hidden_states + + return normed_hidden_states, residual + +elif SYSTEM == "rocm": + class FastLayerNorm: + def __init__(self, prefix, handle, device, dtype, config: BertConfig): + self.weight = handle.get_tensor(f"{prefix}.weight").to(dtype).to(device) + self.bias = handle.get_tensor(f"{prefix}.bias").to(dtype).to(device) + self.variance_epsilon = config.layer_norm_eps + + def forward(self, hidden_states, residual=None): + if residual is not None: + hidden_states += residual + residual = hidden_states + + hidden_states = torch.nn.functional.layer_norm(hidden_states, self.weight.shape, self.weight, self.bias, eps=self.variance_epsilon) + + return hidden_states, residual +else: + raise ValueError("System not recognized") \ No newline at end of file diff --git a/backends/python/server/text_embeddings_server/utils/import_utils.py b/backends/python/server/text_embeddings_server/utils/import_utils.py new file mode 100644 index 00000000..83394eaa --- /dev/null +++ b/backends/python/server/text_embeddings_server/utils/import_utils.py @@ -0,0 +1,12 @@ +import torch +from loguru import logger + +SYSTEM = None +if torch.version.hip is not None: + SYSTEM = "rocm" +elif torch.version.cuda is not None and torch.cuda.is_available(): + SYSTEM = "cuda" +else: + SYSTEM = "cpu" + +logger.info(f"Python backend: detected system {SYSTEM}") From 2a2993a38655132fe4c9367a40b5273f783adc90 Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Wed, 19 Jun 2024 08:48:26 +0000 Subject: [PATCH 4/9] support mean pooling in python backend --- .../text_embeddings_server/layers/pooling.py | 22 +++++++++++++++++++ .../text_embeddings_server/models/__init__.py | 2 -- .../models/default_model.py | 4 +++- .../models/flash_bert.py | 22 ++++++++++++++----- 
backends/python/src/lib.rs | 6 +++-- backends/python/src/management.rs | 4 ++++ backends/src/lib.rs | 4 ++++ router/src/lib.rs | 17 +++++++++----- 8 files changed, 64 insertions(+), 17 deletions(-) create mode 100644 backends/python/server/text_embeddings_server/layers/pooling.py diff --git a/backends/python/server/text_embeddings_server/layers/pooling.py b/backends/python/server/text_embeddings_server/layers/pooling.py new file mode 100644 index 00000000..1bccbc57 --- /dev/null +++ b/backends/python/server/text_embeddings_server/layers/pooling.py @@ -0,0 +1,22 @@ +import torch +from flash_attn.bert_padding import pad_input + +from loguru import logger + +def mean_pooling(embedding, cu_seqlens, max_s): + # Ideally, rust would pass `indices` to the FlashBatch. + seqlens = cu_seqlens[1:].clone() + seqlens[0] = cu_seqlens[1] + seqlens[1:] -= cu_seqlens[1:-1] + batch_size = len(seqlens) + + # Example: indices = [0, 1, 2, 3, 7, 8, 9, 10, 11, 12, 13] + mask = torch.zeros(batch_size, max_s, dtype=torch.int32, device=cu_seqlens.device) + mask[torch.arange(max_s) < seqlens[:, None].cpu()] = 1 + indices = torch.nonzero(mask.flatten(), as_tuple=False).flatten() + + embedding_padded = pad_input(embedding, indices, batch_size, max_s) + + sum_embeddings = torch.sum(embedding_padded, 1) + + return sum_embeddings / seqlens[:, None] \ No newline at end of file diff --git a/backends/python/server/text_embeddings_server/models/__init__.py b/backends/python/server/text_embeddings_server/models/__init__.py index 7f480b33..c606efc9 100644 --- a/backends/python/server/text_embeddings_server/models/__init__.py +++ b/backends/python/server/text_embeddings_server/models/__init__.py @@ -25,8 +25,6 @@ __all__.append(FlashBert) -class - def get_model(model_path: Path, dtype: Optional[str], pooling_mode: str): if dtype == "float32": dtype = torch.float32 diff --git a/backends/python/server/text_embeddings_server/models/default_model.py b/backends/python/server/text_embeddings_server/models/default_model.py index dc39fdc8..17ad4589 100644 --- a/backends/python/server/text_embeddings_server/models/default_model.py +++ b/backends/python/server/text_embeddings_server/models/default_model.py @@ -8,14 +8,16 @@ from text_embeddings_server.models import Model from text_embeddings_server.models.types import PaddedBatch, Embedding +from typing import Optional tracer = trace.get_tracer(__name__) class DefaultModel(Model): - def __init__(self, model_path: Path, device: torch.device, dtype: torch.dtype): + def __init__(self, model_path: Path, device: torch.device, dtype: torch.dtype, pooling_mode: Optional[str]): model = AutoModel.from_pretrained(model_path).to(dtype).to(device) self.hidden_size = model.config.hidden_size + self.pooling_mode = pooling_mode self.has_position_ids = ( inspect.signature(model.forward).parameters.get("position_ids", None) diff --git a/backends/python/server/text_embeddings_server/models/flash_bert.py b/backends/python/server/text_embeddings_server/models/flash_bert.py index 67176f2c..60be0002 100644 --- a/backends/python/server/text_embeddings_server/models/flash_bert.py +++ b/backends/python/server/text_embeddings_server/models/flash_bert.py @@ -12,7 +12,8 @@ from text_embeddings_server.models.types import FlashBatch, Embedding from text_embeddings_server.layers.attention import attention from text_embeddings_server.layers.layernorm import FastLayerNorm -from loguru import logger +from text_embeddings_server.layers.pooling import mean_pooling +from typing import Optional tracer = 
trace.get_tracer(__name__) @@ -190,12 +191,13 @@ def forward(self, input_ids, token_type_ids, position_ids, cu_seqlens, max_s): class FlashBert(Model): - def __init__(self, model_path: Path, device: torch.device, dtype: torch.dtype): + def __init__(self, model_path: Path, device: torch.device, dtype: torch.dtype, pooling_mode: Optional[str]): config = BertConfig.from_pretrained(model_path) with safe_open(model_path / "model.safetensors", framework="pt") as f: model = FlashBertModel(f, device, dtype, config) self.hidden_size = config.hidden_size + self.pooling_mode = pooling_mode super(FlashBert, self).__init__(model=model, dtype=dtype, device=device) @@ -205,7 +207,6 @@ def batch_type(self) -> Type[FlashBatch]: @tracer.start_as_current_span("embed") def embed(self, batch: FlashBatch) -> List[Embedding]: - logger.info(f"batch.input_ids {batch.input_ids}") embedding = self.model.forward( input_ids=batch.input_ids, token_type_ids=batch.token_type_ids, @@ -214,9 +215,8 @@ def embed(self, batch: FlashBatch) -> List[Embedding]: max_s=batch.max_s, ) - if True: + if self.pooling_mode == "cls": embedding = embedding[batch.cu_seqlens[:-1]] - logger.info(f"embedding {embedding.shape}") cpu_results = embedding.view(-1).tolist() return [ @@ -225,4 +225,14 @@ def embed(self, batch: FlashBatch) -> List[Embedding]: ) for i in range(len(batch)) ] - elif + elif self.pooling_mode == "mean": + res = mean_pooling(embedding, batch.cu_seqlens, batch.max_s) + return [ + Embedding( + values=res[i] + ) + for i in range(len(batch)) + ] + + else: + raise NotImplementedError(f"Pooling {self.pooling_mode} is not implemented in the python backend") \ No newline at end of file diff --git a/backends/python/src/lib.rs b/backends/python/src/lib.rs index 195f1d37..ef33b7d2 100644 --- a/backends/python/src/lib.rs +++ b/backends/python/src/lib.rs @@ -23,6 +23,7 @@ impl PythonBackend { uds_path: String, otlp_endpoint: Option, otlp_service_name: String, + pooling_mode: String, ) -> Result { match model_type { ModelType::Classifier => { @@ -31,8 +32,8 @@ impl PythonBackend { )) } ModelType::Embedding(pool) => { - if pool != Pool::Cls { - return Err(BackendError::Start(format!("{pool:?} is not supported"))); + if pool != Pool::Cls && pool != Pool::Mean { + return Err(BackendError::Start(format!("{pool:?} is not supported in the TEI Python backend. 
Please open an issue."))); } pool } @@ -44,6 +45,7 @@ impl PythonBackend { &uds_path, otlp_endpoint, otlp_service_name, + pooling_mode, )?; let tokio_runtime = tokio::runtime::Builder::new_current_thread() .enable_all() diff --git a/backends/python/src/management.rs b/backends/python/src/management.rs index 911c6984..2044a3e0 100644 --- a/backends/python/src/management.rs +++ b/backends/python/src/management.rs @@ -22,6 +22,7 @@ impl BackendProcess { uds_path: &str, otlp_endpoint: Option, otlp_service_name: String, + pooling_mode: String, ) -> Result { // Get UDS path let uds = Path::new(uds_path); @@ -52,6 +53,9 @@ impl BackendProcess { python_server_args.push("--otlp-service-name".to_owned()); python_server_args.push(otlp_service_name); + python_server_args.push("--pooling-mode".to_owned()); + python_server_args.push(pooling_mode); + // Copy current process env let envs: Vec<(OsString, OsString)> = env::vars_os().collect(); diff --git a/backends/src/lib.rs b/backends/src/lib.rs index d332b4a7..db27cddc 100644 --- a/backends/src/lib.rs +++ b/backends/src/lib.rs @@ -39,6 +39,7 @@ impl Backend { uds_path: String, otlp_endpoint: Option, otlp_service_name: String, + pooling_mode: String, ) -> Result { let (backend_sender, backend_receiver) = mpsc::unbounded_channel(); @@ -49,6 +50,7 @@ impl Backend { uds_path, otlp_endpoint, otlp_service_name, + pooling_mode, )?; let padded_model = backend.is_padded(); let max_batch_size = backend.max_batch_size(); @@ -138,6 +140,7 @@ fn init_backend( uds_path: String, otlp_endpoint: Option, otlp_service_name: String, + pooling_mode: String, ) -> Result, BackendError> { if cfg!(feature = "candle") { #[cfg(feature = "candle")] @@ -158,6 +161,7 @@ fn init_backend( uds_path, otlp_endpoint, otlp_service_name, + pooling_mode, ) }) .join() diff --git a/router/src/lib.rs b/router/src/lib.rs index 2f2fec29..03f8fc41 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -105,7 +105,7 @@ pub async fn run( serde_json::from_str(&config).context("Failed to parse `config.json`")?; // Set model type from config - let backend_model_type = get_backend_model_type(&config, &model_root, pooling)?; + let backend_model_type = get_backend_model_type(&config, &model_root, &pooling)?; // Info model type let model_type = match &backend_model_type { @@ -191,6 +191,11 @@ pub async fn run( } }); + let pooling_str = match pooling { + Some(pool) => pool.to_string(), + None => "none".to_string(), + }; + // Create backend tracing::info!("Starting model backend"); let backend = text_embeddings_backend::Backend::new( @@ -200,7 +205,7 @@ pub async fn run( uds_path.unwrap_or("/tmp/text-embeddings-inference-server".to_string()), otlp_endpoint.clone(), otlp_service_name.clone(), - pooling.to_string(), + pooling_str, ) .context("Could not create backend")?; backend @@ -307,10 +312,10 @@ pub async fn run( fn get_backend_model_type( config: &ModelConfig, model_root: &Path, - pooling: Option, + pooling: &Option, ) -> Result { for arch in &config.architectures { - if Some(text_embeddings_backend::Pool::Splade) == pooling && arch.ends_with("MaskedLM") { + if Some(text_embeddings_backend::Pool::Splade) == *pooling && arch.ends_with("MaskedLM") { return Ok(text_embeddings_backend::ModelType::Embedding( text_embeddings_backend::Pool::Splade, )); @@ -324,7 +329,7 @@ fn get_backend_model_type( } } - if Some(text_embeddings_backend::Pool::Splade) == pooling { + if Some(text_embeddings_backend::Pool::Splade) == *pooling { return Err(anyhow!( "Splade pooling is not supported: model is not a ForMaskedLM 
model" )); @@ -332,7 +337,7 @@ fn get_backend_model_type( // Set pooling let pool = match pooling { - Some(pool) => pool, + Some(pool) => pool.clone(), None => { // Load pooling config let config_path = model_root.join("1_Pooling/config.json"); From 36b3a72dadd7973059e796141dbc1525e8b5c909 Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Wed, 19 Jun 2024 09:06:46 +0000 Subject: [PATCH 5/9] fix dockerfile and install --- Dockerfile-rocm | 5 +++++ backends/python/server/pyproject.toml | 3 ++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/Dockerfile-rocm b/Dockerfile-rocm index f9e56631..152fa0a0 100644 --- a/Dockerfile-rocm +++ b/Dockerfile-rocm @@ -111,6 +111,11 @@ ARG DEFAULT_USE_FLASH_ATTENTION=True COPY backends/python/Makefile-flash-att-v2 Makefile-flash-att-v2 RUN make -f Makefile-flash-att-v2 install-flash-attention-v2-rocm +# Install python backend +COPY backends/python/server /tei_backends/python/server +COPY backends/proto tei_backends/proto +RUN make -C /tei_backends/python/server install + ENV HUGGINGFACE_HUB_CACHE=/data \ PORT=80 \ USE_FLASH_ATTENTION=$DEFAULT_USE_FLASH_ATTENTION diff --git a/backends/python/server/pyproject.toml b/backends/python/server/pyproject.toml index 839ff27a..8fbc0008 100644 --- a/backends/python/server/pyproject.toml +++ b/backends/python/server/pyproject.toml @@ -15,12 +15,13 @@ grpcio-status = "^1.51.1" grpcio-reflection = "^1.51.1" grpc-interceptor = "^0.15.0" typer = "^0.6.1" -safetensors = "^0.3.2" +safetensors = "^0.4.0" loguru = "^0.6.0" opentelemetry-api = "^1.15.0" opentelemetry-exporter-otlp = "^1.15.0" opentelemetry-instrumentation-grpc = "^0.36b0" torch = { version = "==2.3.1" } +transformers = { version = "^4.39.0"} [tool.poetry.extras] From a8c02db493575a2cb95e7aeb20a586e71389f1c5 Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Wed, 19 Jun 2024 11:32:50 +0000 Subject: [PATCH 6/9] add tests --- .gitignore | 1 + .../layers/attention/__init__.py | 5 +- .../models/flash_bert.py | 1 - router/src/lib.rs | 14 +-- tests/__init__.py | 0 tests/assets/default_bert.pt | 0 tests/assets/flash_bert.pt | 0 tests/conftest.py | 113 ++++++++++++++++++ tests/pytest.ini | 2 + tests/test_default_model.py | 28 +++++ tests/test_flash_bert.py | 28 +++++ 11 files changed, 183 insertions(+), 9 deletions(-) create mode 100644 tests/__init__.py create mode 100644 tests/assets/default_bert.pt create mode 100644 tests/assets/flash_bert.pt create mode 100644 tests/conftest.py create mode 100644 tests/pytest.ini create mode 100644 tests/test_default_model.py create mode 100644 tests/test_flash_bert.py diff --git a/.gitignore b/.gitignore index ee44a963..6862c2f1 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ .idea target +__pycache__/ diff --git a/backends/python/server/text_embeddings_server/layers/attention/__init__.py b/backends/python/server/text_embeddings_server/layers/attention/__init__.py index 9cce5d34..42aac2bd 100644 --- a/backends/python/server/text_embeddings_server/layers/attention/__init__.py +++ b/backends/python/server/text_embeddings_server/layers/attention/__init__.py @@ -2,7 +2,10 @@ import os if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false": - raise ImportError("`USE_FLASH_ATTENTION` is false.") + class Attention: + def __getattr__(self, name): + raise RuntimeError(f"TEI is used with USE_FLASH_ATTENTION=false, accessing `attention` is prohibited") + attention = Attention() if SYSTEM == "cuda": from .cuda import attention elif 
SYSTEM == "rocm": diff --git a/backends/python/server/text_embeddings_server/models/flash_bert.py b/backends/python/server/text_embeddings_server/models/flash_bert.py index 60be0002..6ebb70d4 100644 --- a/backends/python/server/text_embeddings_server/models/flash_bert.py +++ b/backends/python/server/text_embeddings_server/models/flash_bert.py @@ -233,6 +233,5 @@ def embed(self, batch: FlashBatch) -> List[Embedding]: ) for i in range(len(batch)) ] - else: raise NotImplementedError(f"Pooling {self.pooling_mode} is not implemented in the python backend") \ No newline at end of file diff --git a/router/src/lib.rs b/router/src/lib.rs index 03f8fc41..14f1dfb3 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -105,7 +105,7 @@ pub async fn run( serde_json::from_str(&config).context("Failed to parse `config.json`")?; // Set model type from config - let backend_model_type = get_backend_model_type(&config, &model_root, &pooling)?; + let (backend_model_type, inferred_pooling) = get_backend_model_type(&config, &model_root, &pooling)?; // Info model type let model_type = match &backend_model_type { @@ -191,7 +191,7 @@ pub async fn run( } }); - let pooling_str = match pooling { + let pooling_str = match inferred_pooling { Some(pool) => pool.to_string(), None => "none".to_string(), }; @@ -313,19 +313,19 @@ fn get_backend_model_type( config: &ModelConfig, model_root: &Path, pooling: &Option, -) -> Result { +) -> Result<(text_embeddings_backend::ModelType, Option)> { for arch in &config.architectures { if Some(text_embeddings_backend::Pool::Splade) == *pooling && arch.ends_with("MaskedLM") { - return Ok(text_embeddings_backend::ModelType::Embedding( + return Ok((text_embeddings_backend::ModelType::Embedding( text_embeddings_backend::Pool::Splade, - )); + ), Some(text_embeddings_backend::Pool::Splade))); } else if arch.ends_with("Classification") { if pooling.is_some() { tracing::warn!( "`--pooling` arg is set but model is a classifier. Ignoring `--pooling` arg." 
                );
            }
-            return Ok(text_embeddings_backend::ModelType::Classifier);
+            return Ok((text_embeddings_backend::ModelType::Classifier, None));
        }
    }
@@ -353,7 +353,7 @@ fn get_backend_model_type(
            }
        }
    };
-    Ok(text_embeddings_backend::ModelType::Embedding(pool))
+    Ok((text_embeddings_backend::ModelType::Embedding(pool.clone()), Some(pool)))
 }
 
 #[derive(Debug, Deserialize)]
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/assets/default_bert.pt b/tests/assets/default_bert.pt
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/assets/flash_bert.pt b/tests/assets/flash_bert.pt
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/conftest.py b/tests/conftest.py
new file mode 100644
index 00000000..6d8ed997
--- /dev/null
+++ b/tests/conftest.py
@@ -0,0 +1,113 @@
+import pytest
+import asyncio
+import contextlib
+import random
+import os
+import tempfile
+import subprocess
+import shutil
+import sys
+from typing import Optional
+from aiohttp import ClientConnectorError, ClientOSError, ServerDisconnectedError
+import requests
+import time
+from requests.exceptions import ConnectionError as RequestsConnectionError
+
+@pytest.fixture(scope="module")
+def event_loop():
+    loop = asyncio.get_event_loop()
+    yield loop
+    loop.close()
+
+class ProcessLauncherHandle:
+    def __init__(self, process, port: int):
+        self.port = port
+        self.process = process
+
+    def _inner_health(self) -> bool:
+        return self.process.poll() is None
+
+    def health(self, timeout: int = 60):
+        assert timeout > 0
+        for _ in range(timeout):
+            if not self._inner_health():
+                raise RuntimeError("Launcher crashed")
+
+            try:
+                url = f"http://0.0.0.0:{self.port}/health"
+                headers = {"Content-Type": "application/json"}
+
+                response = requests.post(url, headers=headers)
+                return
+            except (ClientConnectorError, ClientOSError, ServerDisconnectedError, RequestsConnectionError) as e:
+                print("Connecting")
+                time.sleep(1)
+        raise RuntimeError("Health check failed")
+
+@pytest.fixture(scope="module")
+def launcher(event_loop):
+    @contextlib.contextmanager
+    def local_launcher(
+        model_id: str,
+        trust_remote_code: bool = False,
+        use_flash_attention: bool = True,
+        dtype: Optional[str] = None,
+        revision: Optional[str] = None,
+        pooling: Optional[str] = None,
+    ):
+        port = random.randint(8000, 10_000)
+        shard_uds_path = (
+            f"/tmp/tei-tests-{model_id.split('/')[-1]}-server"
+        )
+
+        args = [
+            "text-embeddings-router",
+            "--model-id",
+            model_id,
+            "--port",
+            str(port),
+            "--uds-path",
+            shard_uds_path,
+        ]
+
+        env = os.environ
+
+        if dtype is not None:
+            args.append("--dtype")
+            args.append(dtype)
+        if revision is not None:
+            args.append("--revision")
+            args.append(revision)
+        if trust_remote_code:
+            args.append("--trust-remote-code")
+        if pooling:
+            args.append("--pooling")
+            args.append(str(pooling))
+
+        env["LOG_LEVEL"] = "debug"
+
+        if not use_flash_attention:
+            env["USE_FLASH_ATTENTION"] = "false"
+
+        with tempfile.TemporaryFile("w+") as tmp:
+            # We'll output stdout/stderr to a temporary file. Using a pipe
+            # causes the process to block until stdout is read.
+ print("call subprocess.Popen, with args", args) + with subprocess.Popen( + args, + stdout=tmp, + stderr=subprocess.STDOUT, + env=env, + ) as process: + yield ProcessLauncherHandle(process, port) + + process.terminate() + process.wait(60) + + tmp.seek(0) + shutil.copyfileobj(tmp, sys.stderr) + + if not use_flash_attention: + del env["USE_FLASH_ATTENTION"] + + return local_launcher \ No newline at end of file diff --git a/tests/pytest.ini b/tests/pytest.ini new file mode 100644 index 00000000..2f4c80e3 --- /dev/null +++ b/tests/pytest.ini @@ -0,0 +1,2 @@ +[pytest] +asyncio_mode = auto diff --git a/tests/test_default_model.py b/tests/test_default_model.py new file mode 100644 index 00000000..f8ab25fa --- /dev/null +++ b/tests/test_default_model.py @@ -0,0 +1,28 @@ +import pytest +import requests +import json +import torch + +@pytest.fixture(scope="module") +def default_model_handle(launcher): + with launcher("sentence-transformers/all-MiniLM-L6-v2", use_flash_attention=False) as handle: + yield handle + +@pytest.fixture(scope="module") +async def default_model(default_model_handle): + default_model_handle.health(300) + return default_model_handle + +@pytest.mark.asyncio +@pytest.mark.private +async def test_single_query(default_model): + url = f"http://0.0.0.0:{default_model.port}/embed" + data = {"inputs": "What is Deep Learning?"} + headers = {"Content-Type": "application/json"} + + response = requests.post(url, json=data, headers=headers) + + embedding = torch.Tensor(json.loads(response.text)) + # reference_embedding = torch.load("assets/default_model.pt") + + # assert torch.allclose(embedding, reference_embedding) \ No newline at end of file diff --git a/tests/test_flash_bert.py b/tests/test_flash_bert.py new file mode 100644 index 00000000..38df22e3 --- /dev/null +++ b/tests/test_flash_bert.py @@ -0,0 +1,28 @@ +import pytest +import requests +import json +import torch + +@pytest.fixture(scope="module") +def default_model_handle(launcher): + with launcher("sentence-transformers/all-MiniLM-L6-v2", use_flash_attention=True) as handle: + yield handle + +@pytest.fixture(scope="module") +async def default_model(default_model_handle): + default_model_handle.health(300) + return default_model_handle + +@pytest.mark.asyncio +@pytest.mark.private +async def test_single_query(default_model): + url = f"http://0.0.0.0:{default_model.port}/embed" + data = {"inputs": "What is Deep Learning?"} + headers = {"Content-Type": "application/json"} + + response = requests.post(url, json=data, headers=headers) + + embedding = torch.Tensor(json.loads(response.text)) + # reference_embedding = torch.load("assets/default_model.pt") + + # assert torch.allclose(embedding, reference_embedding) \ No newline at end of file From 35cc5b8c538c9be536ec451d4d379b086d845913 Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Wed, 19 Jun 2024 12:02:54 +0000 Subject: [PATCH 7/9] tests instructions --- tests/README.md | 11 +++++++++++ tests/requirements.txt | 3 +++ 2 files changed, 14 insertions(+) create mode 100644 tests/README.md create mode 100644 tests/requirements.txt diff --git a/tests/README.md b/tests/README.md new file mode 100644 index 00000000..e5492ef9 --- /dev/null +++ b/tests/README.md @@ -0,0 +1,11 @@ +## Testing + +To run the tests, install from within docker with `--entrypoint "/bin/bash"` the requirements +``` +pip install -r requirements.txt +``` + +and mounting a volume for the tests, they can be run from within the container with +``` +pytest tests/ -s -vvvvv +``` \ 
No newline at end of file diff --git a/tests/requirements.txt b/tests/requirements.txt new file mode 100644 index 00000000..b1ee0f58 --- /dev/null +++ b/tests/requirements.txt @@ -0,0 +1,3 @@ +pytest +pytest-asyncio +aiohttp \ No newline at end of file From 309d25560bb41acf9d3c960cd21012bf2009cfea Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Wed, 19 Jun 2024 12:10:19 +0000 Subject: [PATCH 8/9] add rocm image builder --- .github/workflows/build_rocm.yaml | 134 ++++++++++++++++++++++++++++++ 1 file changed, 134 insertions(+) create mode 100644 .github/workflows/build_rocm.yaml diff --git a/.github/workflows/build_rocm.yaml b/.github/workflows/build_rocm.yaml new file mode 100644 index 00000000..8a9fde49 --- /dev/null +++ b/.github/workflows/build_rocm.yaml @@ -0,0 +1,134 @@ + name: Build and push AMD ROCm docker image to registry + + on: + workflow_dispatch: + push: + branches: + - 'main' + tags: + - 'v*' + pull_request: + paths: + - ".github/workflows/build.yaml" +# - "integration-tests/**" + - "backends/**" + - "core/**" + - "router/**" + - "Cargo.lock" + - "rust-toolchain.toml" + - "Dockerfile" + branches: + - 'main' + + jobs: + build-and-push-image: + concurrency: + group: ${{ github.workflow }}-${{ github.job }}-rocm-${{ github.head_ref || github.run_id }} + cancel-in-progress: true + runs-on: [self-hosted, intel-cpu, 32-cpu, 256-ram, ci] + permissions: + contents: write + packages: write + # This is used to complete the identity challenge + # with sigstore/fulcio when running outside of PRs. + id-token: write + security-events: write + steps: + - name: Checkout repository + uses: actions/checkout@v3 + - name: Initialize Docker Buildx + uses: docker/setup-buildx-action@v2.0.0 + with: + install: true + - name: Configure sccache + uses: actions/github-script@v6 + with: + script: | + core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || ''); + core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); + - name: Inject slug/short variables + uses: rlespinasse/github-slug-action@v4.4.1 + - name: Tailscale + uses: huggingface/tailscale-action@v1 + with: + authkey: ${{ secrets.TAILSCALE_AUTHKEY }} + - name: Login to GitHub Container Registry + if: github.event_name != 'pull_request' + uses: docker/login-action@v2 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + - name: Login to internal Container Registry + uses: docker/login-action@v2.1.0 + with: + username: ${{ secrets.TAILSCALE_DOCKER_USERNAME }} + password: ${{ secrets.TAILSCALE_DOCKER_PASSWORD }} + registry: registry.internal.huggingface.tech + - name: Extract metadata (tags, labels) for Docker + id: meta-rocm + uses: docker/metadata-action@v4.3.0 + with: + images: | + registry.internal.huggingface.tech/api-inference/text-embeddings-inference + ghcr.io/huggingface/text-embeddings-inference + flavor: | + latest=false + tags: | + type=semver,pattern=rocm-{{version}} + type=semver,pattern=rocm-{{major}}.{{minor}} + type=raw,value=rocm-latest + type=raw,value=rocm-sha-${{ env.GITHUB_SHA_SHORT }} + + - name: Build and push Docker image + id: build-and-push-rocm + uses: docker/build-push-action@v4 + with: + context: . 
+          file: Dockerfile-rocm
+          push: ${{ github.event_name != 'pull_request' }}
+          platforms: 'linux/amd64'
+          build-args: |
+            SCCACHE_GHA_ENABLED=on
+            ACTIONS_CACHE_URL=${{ env.ACTIONS_CACHE_URL }}
+            ACTIONS_RUNTIME_TOKEN=${{ env.ACTIONS_RUNTIME_TOKEN }}
+            GIT_SHA=${{ env.GITHUB_SHA }}
+            DOCKER_LABEL=sha-${{ env.GITHUB_SHA_SHORT }}
+          tags: ${{ steps.meta-rocm.outputs.tags }}
+          labels: ${{ steps.meta-rocm.outputs.labels }}
+          cache-from: type=registry,ref=registry.internal.huggingface.tech/api-inference/text-embeddings-inference:cache-rocm,mode=max
+          cache-to: type=registry,ref=registry.internal.huggingface.tech/api-inference/text-embeddings-inference:cache-rocm,mode=max
+
+      - name: Extract metadata (tags, labels) for Docker
+        id: meta-rocm-grpc
+        uses: docker/metadata-action@v4.3.0
+        with:
+          images: |
+            registry.internal.huggingface.tech/api-inference/text-embeddings-inference
+            ghcr.io/huggingface/text-embeddings-inference
+          flavor: |
+            latest=false
+          tags: |
+            type=semver,pattern=rocm-{{version}}-grpc
+            type=semver,pattern=rocm-{{major}}.{{minor}}-grpc
+            type=raw,value=rocm-latest-grpc
+            type=raw,value=rocm-sha-${{ env.GITHUB_SHA_SHORT }}-grpc
+
+      - name: Build and push Docker image
+        id: build-and-push-rocm-grpc
+        uses: docker/build-push-action@v4
+        with:
+          context: .
+          target: grpc
+          file: Dockerfile-rocm
+          push: ${{ github.event_name != 'pull_request' }}
+          platforms: 'linux/amd64'
+          build-args: |
+            SCCACHE_GHA_ENABLED=on
+            ACTIONS_CACHE_URL=${{ env.ACTIONS_CACHE_URL }}
+            ACTIONS_RUNTIME_TOKEN=${{ env.ACTIONS_RUNTIME_TOKEN }}
+            GIT_SHA=${{ env.GITHUB_SHA }}
+            DOCKER_LABEL=sha-${{ env.GITHUB_SHA_SHORT }}
+          tags: ${{ steps.meta-rocm-grpc.outputs.tags }}
+          labels: ${{ steps.meta-rocm-grpc.outputs.labels }}
+          cache-from: type=registry,ref=registry.internal.huggingface.tech/api-inference/text-embeddings-inference:cache-rocm,mode=max

From ae3da109356c830fd5b66bd31a88487ffc002521 Mon Sep 17 00:00:00 2001
From: Felix Marty <9808326+fxmarty@users.noreply.github.com>
Date: Wed, 19 Jun 2024 12:48:26 +0000
Subject: [PATCH 9/9] add reference tensors

---
 tests/README.md                               | 23 +++++++++++
 tests/assets/default_bert.pt                  |  0
 tests/assets/flash_bert.pt                    |  0
 ...ence-transformers-all-MiniLM-L6-v2_inp1.pt | Bin 0 -> 3024 bytes
 ...sformers-all-MiniLM-L6-v2_inp1_no_flash.pt | Bin 0 -> 3069 bytes
 ...ence-transformers-all-MiniLM-L6-v2_inp3.pt | Bin 0 -> 6096 bytes
 ...sformers-all-MiniLM-L6-v2_inp3_no_flash.pt | Bin 0 -> 6141 bytes
 tests/collect.py                              | 37 ++++++++++++++++++
 tests/test_default_model.py                   |  4 ++--
 tests/test_flash_bert.py                      |  4 ++--
 10 files changed, 64 insertions(+), 4 deletions(-)
 delete mode 100644 tests/assets/default_bert.pt
 delete mode 100644 tests/assets/flash_bert.pt
 create mode 100644 tests/assets/sentence-transformers-all-MiniLM-L6-v2_inp1.pt
 create mode 100644 tests/assets/sentence-transformers-all-MiniLM-L6-v2_inp1_no_flash.pt
 create mode 100644 tests/assets/sentence-transformers-all-MiniLM-L6-v2_inp3.pt
 create mode 100644 tests/assets/sentence-transformers-all-MiniLM-L6-v2_inp3_no_flash.pt
 create mode 100644 tests/collect.py

diff --git a/tests/README.md b/tests/README.md
index e5492ef9..c4ff5d0b 100644
--- a/tests/README.md
+++ b/tests/README.md
@@ -8,4 +8,27 @@ pip install -r requirements.txt
 and mounting a volume for the tests, they can be run from within the container with
 ```
 pytest tests/ -s -vvvvv
+```
+
+## Reference outputs
+
+For example, collecting the reference on an RTX 4090 on Candle backend:
+```
+docker run --rm -it --gpus all --net host --entrypoint "/bin/bash" -v $(pwd):/tei ghcr.io/huggingface/text-embeddings-inference:89-1.2.3
+```
+and
+```
+text-embeddings-router --model-id sentence-transformers/all-MiniLM-L6-v2
+```
+
+and then
+```
+python collect.py --model-id sentence-transformers/all-MiniLM-L6-v2 --n_inp 1 --flash
+python collect.py --model-id sentence-transformers/all-MiniLM-L6-v2 --n_inp 3 --flash
+```
+
+Restart the server with `USE_FLASH_ATTENTION=false`, and
+```
+python collect.py --model-id sentence-transformers/all-MiniLM-L6-v2 --n_inp 1
+python collect.py --model-id sentence-transformers/all-MiniLM-L6-v2 --n_inp 3
+```
\ No newline at end of file
diff --git a/tests/assets/default_bert.pt b/tests/assets/default_bert.pt
deleted file mode 100644
index e69de29b..00000000
diff --git a/tests/assets/flash_bert.pt b/tests/assets/flash_bert.pt
deleted file mode 100644
index e69de29b..00000000
diff --git a/tests/assets/sentence-transformers-all-MiniLM-L6-v2_inp1.pt b/tests/assets/sentence-transformers-all-MiniLM-L6-v2_inp1.pt
new file mode 100644
index 0000000000000000000000000000000000000000..aaf95a92819bf701ce3a6e37c1c3870523216604
GIT binary patch
[... literal 3024: base85-encoded binary data omitted ...]
diff --git a/tests/assets/sentence-transformers-all-MiniLM-L6-v2_inp1_no_flash.pt b/tests/assets/sentence-transformers-all-MiniLM-L6-v2_inp1_no_flash.pt
new file mode 100644
index 0000000000000000000000000000000000000000..d986e33279c76e2005fdec89b65e617f1787e992
GIT binary patch
[... literal 3069: base85-encoded binary data omitted ...]
z^y4)a_-Kn10=gd|2TpsUZ^9?!wla|3vFwKEu}^96XQxP6T^?MY8-atvZh>sqYJ4(F zO*G3?WY@l0dSB5GXY{C~hnoLUVVr!Hd^M~Y&Tb5bhhDz)+}VwIdhcd3DY};EU2f63 z-`CLRVI#4~V<#*+dF-Ar{8VK?iT?s!hJbNWG4 zbXSo#zvz^a_?}z=-Z51;DLNUFRJpLKYn||0q89-{-!gVAxj0hJp?>P|CMFW+q&ImbOdrJ_82i7Nvu;V<9Sh1H}=v7URmQSL`ZOVjyO)I8x=PMx2su6P) z9F0q|!Kx1Vv~;rz_WrSwI_5_r6B_rp-l-xdDiSA$X5yyh-}AGoy`lN{6jTnYp~lI> zNyNz{tgTH4qvvPzl#85%2dtq_4ZTA@s^MYD{E@^UDaEEapFn}-CkU!4#-&4TsQ8YU za%8!U;%`JW3b$vZP_zFw(tTPPLHl1w{)jAL&eOl+=tu97hhBSWg3yEZoEQP`9{!rw zOz6YcPqL%s?+<|5S)Fldi2^#5+Y&aZVWS@3ss4c`?cYz;DL<0tz#~`@y%)2~V)01- zdZ8dpj@Rd&pmn;L_+07=?7a{lP)AsdNyN%L!g592yA;BshvT8qE*pz?mBEExqv(_H zQgSW7jOpqC%qi#;+E}fk?-b4BpSi^m76&o+_~@`3oxjM3RO_~qK^~ouX&P@WyaYFc zaxmB09d6V(;?oVo6?x&y@q^G!plY_kzVd;{*2063CdJhz!~%!!1?Dq)Hb_S6C+?$S z{CyMb@o%bW;8W6wq34dmx6k^*##DnM_WCy5UKIu`-ne<-d2(Sv4PiMDL_dZ5eMp{% zT@J$PIR!ipGMEZ@D8hH?*R&kDyD@7Z6}ioOVVuwYMbT-m)d?$>KEbHW``7{9B6d5|p4as&^9H_jOT zAAxhsQfT=@g8pYY_%~jnrA{88TWsKax~=9#PCLsxAk!*knq@jVOCFE@g6s*N4pAQ; zr&a&V!=ZVz>GZ3|sJI8S?kXJZwgC^E-VSkV7xNACi%81aYrx_{M4hd5Edu6CeE0E6 zq~?cwVrC3IF9cvt)Ja<4Fdk~gT*Jy*%r%Vwr)e4!G z$&M#Gkm+{F_I1l}i?jW{({-0=rF$4LDL;jdbFbjwFWrE7n-(P|5|0fxM6K{_AB3GX zpiRH0Xy%qPEo`A>u{z$ZeiN**o(Y->pC7z5U<%w?IEH6u0@(JDLE(gX#O2tp!tu>_ ziL&DkWL|(Rs)OL=rX}Zp)uMOBHM%dZ5ZKvA2PpQ@g&Vy1^)dILcTYEv`}IY3HVMjM z6R|ie0a-pmOZq!QL(gth)a_))aFX7~3z(M#mTO|YESt<f{ry` z2t}T`@~ zR7$u6d0=o*NMJ}vV6cBk;E*9fe!(GrayESBfg!SY)6I@YVw;M-tbZZB4ABl0@X zs|p}(fMQv1+pbN$(B11uuVzuz2I<@OBK>cD#j5P=yk7vgTFYKS*}nMp-|N5@A#vPH h$JQI{Ik8*j)^%w8Ew;t6DqFEHWCpP1#PzLf{{Z-0Ya{>w literal 0 HcmV?d00001 diff --git a/tests/assets/sentence-transformers-all-MiniLM-L6-v2_inp3.pt b/tests/assets/sentence-transformers-all-MiniLM-L6-v2_inp3.pt new file mode 100644 index 0000000000000000000000000000000000000000..bea6dca1e82781e1446019cb8d5c37a1876e9a7b GIT binary patch literal 6096 zcmbVQ2{=~U*EeJ?rH~=fASFX7!r4orq@cm&3o3Wh`R(Q)GEXM)-yZVQs9)Pq6jhv_~Efm|>21%chR{ z#YMp3_p(r#afY0@H3}W#HWBak@pN`f0bCsZnGUkPNM3cu!kfe?c--JI#Fy>B^sDZq zkM<05wd65XpQ?##n`>#~hP;Xd>C5EC+%Itb@;;L=z%h`lEO1 zHEy0p2yv`WfsJa{vCPdI4Ch5bi_%lRvqhH#hG!6&M}heLL_4v#@S6Kva|te9Z>P@| z)<9L-4v2gm$<8tr7L-a`IG;(xd&%oa?_0{`nbc{t+ntOpiv96mcQUSP%tm&m@@XG5 zFlEwAdZ4F?pDf*lT|*XNM(8XkKR$)uzC)VS#CifdFMs)h5|zC_iIn&)X7K{Z=yDxk zTrk^f0yl7f`~{M?`wI1ae2TCf(cpn@bn?Ped}&2CO>3xtt#8{gpjL#cW%a^`H#XCe z!5aASN*(P#zy_JnL`wWx%VVxh!K(HM#>rhS(OC~}cPvJ;F%2};)Q-4+_Cni_%V67^ z-Sl$s7!vBalRkG+wOGHAgVk*_NvZq^A*b-oyBSzaJUqP@9h9L5;yK8;nanYc_^#!C zLjFi@XCYZ2_kv^wt>l#@cBArqZE|bhDcTjGNJlN50vDo6ImQdur6onH+J?aSjq;e* zU=C}?OA^aU32I&JkD8Ak(jfIx8nd#AEL&WKz1N;XB{v7W-O)75DhP|?St z_>3J@2wcZqcbP*{b~ca!@)-hFT4iBKWQ}r3#oRuOOIg&|brU+Dti?(Fbik)o2E8X* znV-<#ij$=dLN8kh92PYUnJsk7eKTJuS&1bB?(rSFvju*rey>6*@XzW7Z+7{^Bz%d# zeLV}ibT#1fz>VfX7xORye6D=dXXTm4!3h6M=a6#j2zx@)S;^~0_o$R%cRvklRVT3fS{0S z>UumD8#n$!-^ZMztUjoG&-JP-b= zhZsKX?M2EZjH;m~Z6IG9(&>KA7pN3ac7L&S4U(9=XpLTNFz{V+#4m@5+ues`S8WUf`wkadv zMD!pude_6}y2*D3Q+=nP#jDhe^Nnlu6N`0C=aYln3x?lDJ6K@N!J5xcLbhbe&tJjyL zu_C5Glt66?*47RH>(hJ5fSerqrMEcFb3X|pYp2rD##Ma2c09%}zD#>OOsJOBIjRw_ z4$AVIXj1L~$bxyuxC77piE#JZGT0>HcL-JwfCCYM0 zYu?zRsg=5^*(iT%=^9BTowtMfx>{`Pa>K}iTSR-CDLC9VL(lv(!0gIr8?G1l;P(Wl z(wz8Ev=krDnPxOqq;lf$ICC+0O|V7lOAo2ZZc$*`B8FoOxEH<5L9f~t;ulZnKk1Kv z!wS#IzAhPbYuQ0rUb*ibmBe8GA^7^v0o`2!Xxnxxc5mmNT5N@prRmUleh$4hSJvX! 
zm!)R6H>u&tQxB;9CE3bGy`dP?TNL~iQei@nJeGEA;(q1TV8uD%NT(K>V|WavAIPTx z(@IFL{%1NPCXUL7E~9#%63zXWx>T_I31`Gl`iJ?Lf@oM3y@9g$WVqN#U^;=hn-7rn zR)e@vPXdIT)A`yBoaOW~s;a*drYQ}vF#701W3#i!PK(Ec)d0B_AcN^TeJhzZF-_Ef z6b!qCH=c(6`1{LzF)X8~WeUM}-4R4)Yq6jgcu=7WwS~>(Zkr)6TVOm^_8*Xr>1Nl# z{?cX?cy)SJMp!M8IXd#NW9V!Yux|J$QF*MrkY?4T(bv18fbP#l0ndt+MIfc83AYN9 zF(zp}vU)&On|z*Ijp!S)R5%m0zT^f4J7+-2udbvq_F2XHuLJ3TFRiGt#}%Un=97bd zGll)+LB$c`{UsHRp6sG^mTlB0J`;Y~Vg!{lCIPE;d|W>Yx2E6cueyGvW2420e%uZ| z_Spd%-Sd%S9N}-_t4Yn9Uy)l4z;JOX*6;DRJVfRjc0K=v)mIIulx35HnF{damug;n z@LA?p6rgN?~#AqTKNK?}0KJf}UXg_t~iBZPJk z{?G<@@~w0bve&5KU=Nq1ZNa%q8b3DHa~G^Mkm(InZ;qu5JO8ZcDFi5nQPw;7zK4s5 zN9#MD=>Vn%NTEnT95GCI$5~h|LG~V=75Vc13IpgwgIxZ;{dxRNt%S_A*2E7pr{RL< zhpF{x4b&by30a*8xyRc1nZ&2LnN&HY5LW-Ncf18Kev$D*WNc4@z>Uz)Sic|x;$PE? zr{B@YCn})-Xg*Y%XQri^?gynSoWa z7Se-<=Yit6XjJW7#?C$nd*#FcuQ?h+R&{V1!w=wO&WW^?bn^uV#y~R2!`7+uVae$v z)c&eMZVo#}?PA^WUZXfvY+H-Jj@&NT!-^@YB*Q8kzgn?=2H`c&-;>ojU zjPJfyo;a)opI&od-I<>1b=FAN~TeKecC#@;poMsEm9+ycC0(-SodUjDBR+b&0HuDOo z5UYExDb~#}A!4!5iC1J_i~X6hIKk*O&p7?j|K(iWAo6)5X$UGI-)!~JzW)reXwvSA zjzSkqX{tBRZ|cLXKlGWOb+m>^>INap6E^G40lB3|!SO)}<@a3{FvF<=N{_zPgd^Tz zDC8q!PBb6rHSmYuBfjIPppQ=b?J!GE8<|Gvf;~6r(U2a(xP!iZPm(E@#ev!L3_sE0 z?~|@w>fo2N3VaegK=y$JZkXOfDh6a>Thuf1VAB?&BjH0{RY~){om%%n#nCMZKtF5xuWFJE1+brPg&0J=8ynr>${Wm z(i&<$)JYMNPI%yB6E`e&62pr_djr##MP2PYuIf5Rd{wqUsDd99&z(fFy5(r`+_Ri* z(*k69s<3UkPaNISs82!@Eos)L4KHt!1Q}Nl?|+Nx>dV5TR5hIAa0$~*^EkzGqWEFG z_>cZiUP%Fj=|j4SA7vbnsPVa+wc%w5SrtrJUBR#yafIz7ng+pi<@FjY^|->tg$=Jv zOYwrsXJw&Z<{(TmoMa~BaD&xw64+c%z>67&=-vlLC(?txwDwbpth2(|1#Ll_ zOkZNS$eb1oPoeEe;p9QIG;~RKkhU%3iAioXrk6_-(NDk7n)mL+{eC}W^+O*>p1@4E zWXiZAw;`5l5uQ)$Z%turZkQWK52*UX6RS32xjl)D z`0R{PMHlGjoK|Q)T8On4FX$m-aiV)HADaeWfRNRO_-6h-bQaH`I$IK%){@A|8?oFB z#bjWg!_iJRny{HrCA&aSoajOKoL8hBV?WXLx#ln}XetW5j`_o8y#dhbvzcRcf{PE7 z!{AeWsof?mQaPcT?b$(ncO6HD1-_QA1MhF)IITkm8UMV-@L}k@SrmlHlcVU1DI)#Z zmqB*@`{Zeyh~`gCo>s6?BOA}zSgHTZ=;_BWY@DospnjjP{?Yb<*(-pts@ z)MUa~lW}8A1(f}ETsbh3ArvkLJ5d4bpOe?Wjv2=b5B|3PKXcfBM)>P|(D-lO_1WO` zXCnMLy!>Z~zY4+VZy|2~WQgC#j{o%k*EEm&n|}wzpF6@I;`(QRzh=nlZvk%nWPtx& zI}Ub|l6rsm9W1W@X9S5qUj5JgqcF;5Yn9UPceYO$fQtkV{QbS~*uhRh81;%suphzx HciaB}S|>Pe literal 0 HcmV?d00001 diff --git a/tests/assets/sentence-transformers-all-MiniLM-L6-v2_inp3_no_flash.pt b/tests/assets/sentence-transformers-all-MiniLM-L6-v2_inp3_no_flash.pt new file mode 100644 index 0000000000000000000000000000000000000000..7bf5187911d6098a4ae44bd9cea3f5b5d8d0ff2b GIT binary patch literal 6141 zcmbtY3s_ER*M2*n6ggK2Nuole!+uthgmNfKN;;@idUbrG59uTnQAs0_)8vqdln(n@ zX;jlV53x!f+c>zK7NLhaPQ!V)gj?){lX&*y~Sd~g@M6=3l7S{eFwGtvJM6vh?>qDu+ogmtuRSgvl(IVP~Jv$S~Q3 zv9gjdIY)buz0AT`sjXr1?590DDKh+CsO}9ph8tAPqDRD{VU2st3$Y8ts?HxI6_Jv|IB44Y^LU15mj&d zNV`2WspEkA)LmsX*8Ow~-|osL(b6OF!{b`&T-Sqhv<&&$A;%RxqKDf)U{r_$yUh&!ze6do0MM=qg`zW zDU5i-F)Tn6I;hC08fsEA;F8@2mWwSkol~@&xNZl&EsZ80>V}deQ3>u#Pr*@{19AMl z6tpZljCrDQWkmxFNpj~+DVoWzZDZwlP7{WZi^=?0bkbe`x~-=-z6R?={dBHoQVLhTC6 zA;;hqx)zJ*!|k%@bn&?cSRQW>#TMXBY#1Q{`TxYQX?0B*o z8y`2(E!%_|_+aL%PzdPQM>n`mCcQa(>6UN}%e%S+Y|IytVuiC9SiKF-Z+ZwqZXwrY zsql;mJW^Iuu0i<~uc7Zx-}js#?RPJbpYmRi6F%$tsV(1Na-%MJ9(kSyuNXo@t>%DM zdl9#WAI4R_m#1m-$AV&v8miY@Kv0G}nKnwE2EGnN#uxpxvzXrev5gdZRN@$?^Z06- z16Ef(<=v&Hpr*?AblUwWw00T=Rh2`)>HZU<*VsxNBNNHv-fEKZb~h}xS`G)7M&s_# z3#`@-bd&yBvMl`qh)FPQKa-OUqPnimX1R2|+YKbd}h4XA^c0 zYkL!M;jRr#*HtukTM1nl+6asnY^*y4e&dq4lC}fH{&3S5y_e^3jFW@Fe2L*7 z)3G+Whhw^pyv zoC#!9%280A|N2XP9Y5w$Fyq`00d_kSf63b5prb6J*QuJucqGMZA@v`o4tUaH_rCdM3#b+uJWtHdLjF=?+ ztQZe_&EzrhmwS|7y$U7`xqwrrTqcXoY=hjS3q;-ZG~QS-1AiE!h&S`)xUjD6#NP5C 
zg{%x1pwIv+1!|=0O+18j?}QsGF4CSRZ?y4VfWyZ4z;ET2koLTsJZjku4QcgoZqbAC zmdhgYzF#`-tsMm2{c}k5g(K9xN*c4n&VdBy(O1VV@q^8`qoMItnzF)#GAuOxiWcmz z-bSC+tU&njh#nLU~27$Cq5&dxA8Jl-608J??_|Ovyo_$3m-*^Emk0@dBE6AiJ zYcZ#<7=s5Hz-LA2iV9yhFzKI-ZFwv3NwWePi+>{}n-#Fmtq|gVoK3^uZh_ePxiHv8 z2g}RMvG1Z0w_m%3Ub=XOSSI{xp{2coYD!h&NUWt8yiWd`f(!g)1#(| zZmD(PBhE&~O?h#Q3T6C4ud*^c+(B`;vzA5f@b5tAKhesH;M1ZC1>GqyAU~cQ)x1Kp zoU$n60k39B%nTd@sZ9SU%J>3@sB0Mi#skyj?-3Rof{&V`bL1sp_)9GkW9YA{LwPgR zJ+v||44))T=U9Eq%c5n#A#*hhlbMH1FElYj3QqfPCvw&XT+AE`fp3`OHILtJH6GZ0 za;{4W#qTocOy361duIi?T$Br;mmIL^Y!H3NTVq&D2bZX`8#eCGf`Rf*^nru2<(jHC z^9Co)FVAk+WrL-OqSt`sz)?!43>jOEphu9W$Tuvmn7fH1A2MZf+P1xGGos;eP6IPxxz zB;jA)zq5;u!SZ%}YPXUIe3R`rFObWzYhl5SiBKE;kXV*XhOMgMuy(afh1B>gJU#6e zbao}9z-zZ*8PN_oP71o&J7$C<-aZ`%tDEI2=G`o$8cyF(+l*M4ReS9su|4B`)EnPOKOpQ0P5U~A$RyJ=uz8`EFYY|shL;P`zX|# zXf55t9}fPOzJJ%v9d}xf=fxX{1b@J_#{lANE^^a?x>5EHqDgEDta-Q*HKeNng)u2aYxT=IYckV{!C;$8;IcH zoB4^9Q`rHXEo-r*O-s-OH+HcM*4~_si@$kEuDM$jjx}EmttmDTnWGP82S!6x%nRBv zxDc<7-U19CugboQiu6X}T?ydLsD*-#A*@Rg4~@9Z*;|gp?=qDEOQuqbe!Kau<|lA+ zX#`CgvyCUE7s!X6PTuCiAqetUKmlv2ZVzV`5P(Tz6oAF$oofeDrbC{aSA%?MF^SRB z#(?AVvHinw%JhYdPaF|u1gzG0OniWNoVN?jM+H2#`|EDmLw{rxpbB~9qu9L8Vd(90Rd<^@`xsWaTz5Ke8 zDNyXB2FqMrA^p@&^qrwW`kCd^bs3(Rl_E=S{ig>+nQzYK2YYLNQH5h(0Cc1{=h zx9w!(rbeP0{(`Fu`AE<}7T!y`=;H6TfPG;pgltOYkGp-Q7ufrz`%X5@f4CF0cZGt^ z=2v7**Cp^BG>_`;wSYX^512Vf0pvEySynvqB2HTkF>SIMaMBCW*|V91Uo%0)5NYVm zQwPNb=Xu+ocY~0h%r^F3P>U;%P0#_Cb^18^OeIJz72+z(Y&e>Bm;1b56}~TD2BG&$ z!P#>KShvoGJ)&}cT~i!ZHnoLI z44Wngfwk0Lmh*6nyJ>^LF@pUMo9ZFd9+@XVe!Z+E(>urXPKup&LD`4ZuxrFDqFbs$ zty`HlU6)xqfa%B=!{YIv^+~~Dy9xIm}fe+#^VD%QBNz_N^VySc>FX)@`j@N3A(2qO5 z6LQBZH;xJ3CF9e~>HWks%Ct$!D-?lwoIEt1PDYhgp^Tg&sSDXi4HkKk-DLyNG4nj* zyuU z_y{m z-_y`@Pa&(l5KH7=(n! z)`Rt3juxwlq5ZQ5T|IFqW&E(dZUITt=b~OpDeRY#L7VpjAwxTvV?IGCaTzd=(>IT_ z$+`Jel0z3DJG^f%uLM8OiZVlnV6f;Fg7(ap2Ci)?1G80vAMB{nW_1d>81io z-JcazCNPA;aY!2U6L{qvadOJfx-{&D=j=CFSc_2)sM+bK7{#-@-zeKw8wUPdBeL2|6%NzVV zz$h8RKSIg<`_uoPb-oBmB#IjH`vrENFfA8l5BvRhVJz-oFDJ}}nE~WP!v5d){tw(5 BMB)Gd literal 0 HcmV?d00001 diff --git a/tests/collect.py b/tests/collect.py new file mode 100644 index 00000000..313c0871 --- /dev/null +++ b/tests/collect.py @@ -0,0 +1,37 @@ + +import requests +import torch +import argparse +import json +import os + +parser = argparse.ArgumentParser(description='Assets collection') +parser.add_argument('--model-id', help='Model id', required=True) +parser.add_argument('--n_inp', help='Number of inputs', required=True, type=int) +parser.add_argument('--flash', action='store_true') + +args = parser.parse_args() + +url = f"http://0.0.0.0:80/embed" + +INPUTS = [ + "What is Deep Learning?", + "Today I am in Paris and I would like to", + "Paris weather is", + "Great job" +] + +data = {"inputs": INPUTS[:args.n_inp]} +headers = {"Content-Type": "application/json"} + +response = requests.post(url, json=data, headers=headers) + +embedding = torch.Tensor(json.loads(response.text)) + +postfix = "" +if not args.flash: + postfix = "_no_flash" + +save_path = f"./assets/{args.model_id.replace('/', '-')}_inp{args.n_inp}{postfix}.pt" +print(f"Saving embedding of shape {embedding.shape} to {save_path}") +torch.save(embedding, save_path) \ No newline at end of file diff --git a/tests/test_default_model.py b/tests/test_default_model.py index f8ab25fa..595fe6bf 100644 --- a/tests/test_default_model.py +++ b/tests/test_default_model.py @@ -23,6 +23,6 @@ async def test_single_query(default_model): response = requests.post(url, json=data, headers=headers) embedding = torch.Tensor(json.loads(response.text)) - # 
reference_embedding = torch.load("assets/default_model.pt") + reference_embedding = torch.load("./tests/assets/sentence-transformers-all-MiniLM-L6-v2_inp1_no_flash.pt") - # assert torch.allclose(embedding, reference_embedding) \ No newline at end of file + assert torch.allclose(embedding, reference_embedding, atol=1e-3, rtol=1e-3) \ No newline at end of file diff --git a/tests/test_flash_bert.py b/tests/test_flash_bert.py index 38df22e3..3c3fde1c 100644 --- a/tests/test_flash_bert.py +++ b/tests/test_flash_bert.py @@ -23,6 +23,6 @@ async def test_single_query(default_model): response = requests.post(url, json=data, headers=headers) embedding = torch.Tensor(json.loads(response.text)) - # reference_embedding = torch.load("assets/default_model.pt") + reference_embedding = torch.load("./tests/assets/sentence-transformers-all-MiniLM-L6-v2_inp1.pt") - # assert torch.allclose(embedding, reference_embedding) \ No newline at end of file + assert torch.allclose(embedding, reference_embedding, atol=1e-3, rtol=1e-3) \ No newline at end of file
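
For context, the check performed by the updated tests reduces to the minimal sketch below. It only reuses pieces that appear in this patch (the `/embed` route on `0.0.0.0:80`, the asset naming produced by `tests/collect.py`, and the `atol=1e-3, rtol=1e-3` tolerances); it assumes a local `text-embeddings-router` serving `sentence-transformers/all-MiniLM-L6-v2` with flash attention enabled, and that the single-input reference was collected from the first prompt in `tests/collect.py`.
```
# Sketch of the reference comparison used by test_default_model.py / test_flash_bert.py.
# Assumes a running router at http://0.0.0.0:80 with flash attention enabled,
# so the *_inp1.pt asset (not *_inp1_no_flash.pt) is the right reference.
import json

import requests
import torch

url = "http://0.0.0.0:80/embed"
data = {"inputs": ["What is Deep Learning?"]}  # first input from tests/collect.py

response = requests.post(url, json=data, headers={"Content-Type": "application/json"})
embedding = torch.Tensor(json.loads(response.text))

reference_embedding = torch.load(
    "./tests/assets/sentence-transformers-all-MiniLM-L6-v2_inp1.pt"
)

# Same tolerances as the updated tests.
assert torch.allclose(embedding, reference_embedding, atol=1e-3, rtol=1e-3)
print(f"OK: embedding of shape {tuple(embedding.shape)} matches the stored reference")
```
If the server was restarted with `USE_FLASH_ATTENTION=0`, compare against the corresponding `*_no_flash.pt` asset instead, mirroring the collection procedure described in the README above.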