
Commit

Merge branch 'main' into update-perplexity-ci-install
archana-ramalingam authored Nov 27, 2024
2 parents c7eca71 + 9e1e121 commit 86eec50
Showing 16 changed files with 669 additions and 127 deletions.
39 changes: 39 additions & 0 deletions .github/workflows/ci-sharktank.yml
@@ -22,6 +22,45 @@ concurrency:
cancel-in-progress: true

jobs:
test_punet:
name: "Integration Tests - punet"
runs-on: nodai-amdgpu-mi250-x86-64
env:
PIP_CACHE_DIR: "${{ github.workspace }}/.pip-cache"
steps:
- name: "Setting up Python"
id: setup_python
uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
with:
python-version: 3.11

- name: "Checkout Code"
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2

- name: Cache Pip Packages
uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2
id: cache-pip
with:
path: ${{ env.PIP_CACHE_DIR }}
key: pip-${{ steps.setup_python.outputs.python-version }}-${{ hashFiles('*requirements*.txt','sharktank/requirements*.txt') }}

- name: Install pip deps
run: |
python -m pip install --no-compile --upgrade pip
# Note: We install in three steps in order to satisfy requirements
# from non default locations first. Installing the PyTorch CPU
# wheels saves multiple minutes and a lot of bandwidth on runner setup.
pip install --no-compile -r pytorch-cpu-requirements.txt
pip install --no-compile -r requirements.txt -r sharktank/requirements-tests.txt -e sharktank/
# Update to the latest iree packages.
pip install -f https://iree.dev/pip-release-links.html --upgrade --pre \
iree-base-compiler iree-base-runtime --src deps \
-e "git+https://github.com/iree-org/iree-turbine.git#egg=iree-turbine"
- name: Run punet tests
run: |
pytest -v sharktank/ -m model_punet
test:
name: "Unit Tests and Type Checking"
strategy:
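The new `test_punet` job above installs the pinned PyTorch CPU wheels, the sharktank requirements, and the latest IREE packages, then selects tests by pytest marker (`-m model_punet`). As a rough illustration of how a test opts into that marker — the decorator is standard pytest, but the registration location and the test body below are assumptions, not code from this commit:

```python
# Hypothetical sketch: a test tagged so that `pytest -v sharktank/ -m model_punet`
# picks it up. The real punet tests live under sharktank/integration/models/punet/.
import pytest


@pytest.mark.model_punet  # custom marker; registering it in pytest config avoids warnings
def test_punet_marker_example():
    assert True
```
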
41 changes: 29 additions & 12 deletions sharktank/conftest.py
@@ -153,16 +153,28 @@ def pytest_addoption(parser):
# --outtype=f32 \
# t5-v1_1-small
parser.addoption(
"--google-t5-v1-1-small-fp32-model-path",
"--google-t5-v1-1-small-f32-model-path",
type=Path,
default="/data/t5/small/google__t5-v1_1-small_fp32.gguf",
help="Google T5 v1.1 small fp32 model path",
default="/data/t5/small/google__t5-v1_1-small_f32.gguf",
help="Google T5 v1.1 small float32 model path",
)
parser.addoption(
"--google-t5-v1-1-xxl-fp32-model-path",
"--google-t5-v1-1-small-bf16-model-path",
type=Path,
default="/data/t5/xxl/google__t5-v1_1-xxl_fp32.gguf",
help="Google T5 v1.1 XXL fp32 model path",
default="/data/t5/small/google__t5-v1_1-small_bf16.gguf",
help="Google T5 v1.1 small bfloat16 model path",
)
parser.addoption(
"--google-t5-v1-1-xxl-f32-model-path",
type=Path,
default="/data/t5/xxl/google__t5-v1_1-xxl_f32.gguf",
help="Google T5 v1.1 XXL float32 model path",
)
parser.addoption(
"--google-t5-v1-1-xxl-bf16-model-path",
type=Path,
default="/data/t5/xxl/google__t5-v1_1-xxl_bf16.gguf",
help="Google T5 v1.1 XXL bfloat16 model path",
)

parser.addoption(
@@ -288,15 +300,20 @@ def get_model_artifacts(request: FixtureRequest):
model_path["llama3_405b_fp8_model_path"] = set_fixture_from_cli_option(
request, "--llama3-405b-fp8-model-path", "llama3_405b_fp8_model"
)
model_path["google__t5_v1_1_small_fp32_model_path"] = set_fixture_from_cli_option(
model_path["google__t5_v1_1_small_f32_model_path"] = set_fixture_from_cli_option(
request,
"--google-t5-v1-1-small-f32-model-path",
"google__t5_v1_1_small_f32_model",
)
model_path["google__t5_v1_1_small_bf16_model_path"] = set_fixture_from_cli_option(
request,
"--google-t5-v1-1-small-fp32-model-path",
"google__t5_v1_1_small_fp32_model",
"--google-t5-v1-1-small-bf16-model-path",
"google__t5_v1_1_small_bf16_model",
)
model_path["google__t5_v1_1_xxl_fp32_model_path"] = set_fixture_from_cli_option(
model_path["google__t5_v1_1_xxl_f32_model_path"] = set_fixture_from_cli_option(
request,
"--google-t5-v1-1-xxl-fp32-model-path",
"google__t5_v1_1_xxl_fp32_model",
"--google-t5-v1-1-xxl-f32-model-path",
"google__t5_v1_1_xxl_f32_model",
)
return model_path

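The conftest changes above rename the T5 `fp32` options to `f32` and add `bf16` variants, each wired to a model-path entry through `set_fixture_from_cli_option`. A minimal sketch of that pattern, assuming the helper simply reads the CLI option off the pytest request (the real helper in `sharktank/conftest.py` may differ in details):

```python
# Assumed shape of the helper used in get_model_artifacts: look up a CLI option
# and return its value so it can be stored under a fixture-style key.
from pathlib import Path


def set_fixture_from_cli_option(request, option_name: str, fixture_name: str) -> Path:
    # fixture_name is kept only to mirror the call sites shown above.
    value = request.config.getoption(option_name)
    return Path(value)
```
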
7 changes: 4 additions & 3 deletions sharktank/integration/models/punet/integration_test.py
@@ -89,12 +89,13 @@ def sdxl_fp16_dataset(sdxl_fp16_base_files, temp_dir):
def sdxl_int8_base_files():
from huggingface_hub import hf_hub_download

REPO_ID = "amd-shark/sdxl-quant-models"
REVISION = "942e771bf0c2657a8b33380103d04747a75dfa4a"
REPO_ID = "amd-shark/sdxl-quant-int8"
SUBFOLDER = "mi300_all_sym_8_step14_fp32"
REVISION = "efda8afb35fd72c1769e02370b320b1011622958"

def download(filename):
return hf_hub_download(
repo_id=REPO_ID, subfolder="unet/int8", filename=filename, revision=REVISION
repo_id=REPO_ID, subfolder=SUBFOLDER, filename=filename, revision=REVISION
)

return {
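The int8 base files now come from the `amd-shark/sdxl-quant-int8` repository, with an explicit subfolder and a pinned revision. For reference, the `hf_hub_download` call pattern looks like this; the filename below is a placeholder, not necessarily a file present in that subfolder:

```python
# Illustrative hf_hub_download call with the new repo layout; "config.json" is a
# placeholder filename. The test downloads the specific quantized UNet artifacts it needs.
from huggingface_hub import hf_hub_download

local_path = hf_hub_download(
    repo_id="amd-shark/sdxl-quant-int8",
    subfolder="mi300_all_sym_8_step14_fp32",
    filename="config.json",
    revision="efda8afb35fd72c1769e02370b320b1011622958",
)
print(local_path)  # path of the file in the local Hugging Face cache
```
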
2 changes: 1 addition & 1 deletion sharktank/requirements.txt
@@ -1,7 +1,7 @@
iree-turbine

# Runtime deps.
gguf==0.6.0
gguf==0.10.0
numpy<2.0

# Needed for newer gguf versions (TODO: remove when gguf package includes this)
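The gguf pin moves from 0.6.0 to 0.10.0, presumably to pick up support for newer GGUF tensor types such as BF16 (see the `gguf_interop` change further down). A quick sanity check with the newer package might look like the following sketch; the path is one of the defaults from `conftest.py`:

```python
# Sketch: enumerate the tensors in a GGUF file with the newer gguf package.
from gguf import GGUFReader

reader = GGUFReader("/data/t5/small/google__t5-v1_1-small_bf16.gguf")
for tensor in reader.tensors:
    print(tensor.name, tensor.tensor_type, tensor.shape)
```
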
17 changes: 15 additions & 2 deletions sharktank/sharktank/layers/configs/llm_configs.py
@@ -227,6 +227,8 @@ def from_gguf_properties(properties: dict[str, Any], **kwargs):
== properties["t5.attention.layer_norm_rms_epsilon"]
)

all_kwargs = {"vocab_size": None, "feed_forward_proj": None}

gguf_to_config_names_map = {
"t5.context_length": ["context_length"],
"t5.embedding_length": ["d_model"],
@@ -236,18 +238,29 @@ def from_gguf_properties(properties: dict[str, Any], **kwargs):
"t5.attention.key_length": ["d_kv"],
"t5.attention.layer_norm_epsilon": ["layer_norm_epsilon"],
"t5.attention.relative_buckets_count": ["relative_attention_num_buckets"],
"t5.decoder_start_token_id": ["decoder_start_token_id"],
"tokenizer.ggml.eos_token_id": ["eos_token_id"],
"tokenizer.ggml.padding_token_id": ["pad_token_id"],
}
all_kwargs = {"vocab_size": None, "feed_forward_proj": None}
all_kwargs.update(
{
config_name: properties[gguf_name]
for gguf_name, config_names in gguf_to_config_names_map.items()
for config_name in config_names
}
)

gguf_to_optional_config_names_map = {
"t5.decoder_start_token_id": ["decoder_start_token_id"],
}
all_kwargs.update(
{
config_name: properties[gguf_name]
for gguf_name, config_names in gguf_to_optional_config_names_map.items()
for config_name in config_names
if gguf_name in properties
}
)

if "tokenizer.ggml.tokens" in properties:
all_kwargs["vocab_size"] = len(properties["tokenizer.ggml.tokens"])
all_kwargs.update(kwargs)
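The config loader now splits GGUF properties into a required map and an optional map, so files that omit `t5.decoder_start_token_id` still load. Condensed, the logic amounts to the following sketch, using a subset of the keys shown above:

```python
# Condensed sketch of the mapping above: required keys raise KeyError when
# absent, optional keys are copied only if present in the GGUF properties.
def build_config_kwargs(properties: dict) -> dict:
    required = {"t5.embedding_length": "d_model"}
    optional = {"t5.decoder_start_token_id": "decoder_start_token_id"}
    kwargs = {cfg: properties[key] for key, cfg in required.items()}
    kwargs.update(
        {cfg: properties[key] for key, cfg in optional.items() if key in properties}
    )
    return kwargs
```
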
21 changes: 11 additions & 10 deletions sharktank/sharktank/layers/linear.py
@@ -31,9 +31,8 @@ class LinearLayer(ThetaLayer):
x = x * premul_input
matmul(x, weight.T) + bias
fake_quant exists to allow export without adding dequant ops.
when fake_quant is True, the op will in quant dequant fashion.
When false, it will keep quantized types.
fake quant only exists in order to allow for q_input to act as qdq.
when fake quant is false, q_input will quantize normally.
```
"""

@@ -43,7 +42,7 @@ def __init__(
*,
weight_name: str = "weight",
bias_name: str = "bias",
fake_quant: bool = True,
fake_quant: bool = False,
):
super().__init__(theta)
self._simulate_native_quant = True
@@ -74,21 +73,23 @@ def forward(self, x):
x = q_input.quantize(x)
if self.fake_quant:
x = x.unpack().dequant()
elif qdq_input is not None and self.fake_quant:

elif qdq_input is not None:
x = qdq_input.quantize(x).unpack().dequant()

y = ops.linear(x, weight, bias)

# Unconditionally dequantize.
if isinstance(y, QuantizedTensor) and not self.fake_quant:
if isinstance(y, QuantizedTensor):
y = y.unpack().dequant()
# Note that f8_e4m3fnuz types on AMD GPUs accumulate to fp32.
# We can truncate to fp16 in iree, so we do a cast here
# to account for this in the IR. This may not be the right
# level to do this, but for now it's here.
if not self.fake_quant and y.dtype == torch.float8_e4m3fnuz:
y = ops.to(y, torch.float16)
return y
if qdq_output is not None and self.fake_quant:
if not isinstance(y, QuantizedTensor):
if y.dtype == torch.float8_e4m3fnuz:
y = ops.to(y, torch.float16)
return y
if qdq_output is not None:
y = qdq_output.quantize(y).unpack().dequant()
return y
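With `fake_quant` now defaulting to `False`, `LinearLayer.forward` always dequantizes quantized matmul results and casts `float8_e4m3fnuz` outputs down to `float16`, matching the accumulation behaviour described in the comment above. A standalone torch sketch of that final cast, assuming a PyTorch build with float8 support:

```python
# Minimal illustration of the f8 -> fp16 truncation done at the end of forward();
# this mirrors ops.to(y, torch.float16) in LinearLayer.
import torch

y = torch.randn(4, 8, dtype=torch.float32).to(torch.float8_e4m3fnuz)  # stand-in matmul output
y = y.to(torch.float16)
print(y.dtype)  # torch.float16
```
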
3 changes: 2 additions & 1 deletion sharktank/sharktank/layers/token_embedding.py
@@ -5,6 +5,7 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import torch
from typing import Optional

from .. import ops
from .base import Theta, ThetaLayer
@@ -16,7 +17,7 @@ def __init__(
theta: Theta,
*,
weight_name: str = "weight",
dtype: torch.dtype = torch.float32,
dtype: Optional[torch.dtype] = torch.float32,
):
super().__init__(theta)
self.weight = self.theta_tensor(weight_name)
20 changes: 17 additions & 3 deletions sharktank/sharktank/models/t5/export.py
@@ -4,12 +4,15 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from typing import Union
import functools
from typing import Optional, Union
from pathlib import Path
import torch
from copy import copy

from .t5 import T5Config, T5Encoder
from ...types import Dataset
from ...transforms.dataset import set_float_dtype
from iree.turbine.aot import FxProgramsBuilder, export

__all__ = [
@@ -91,7 +94,18 @@ def prune_decoder_parameters(dataset: Dataset):
pass


def export_encoder_iree_parameters(model_path: str, output_path: str):
dataset = Dataset.load(model_path)
def export_encoder_iree_parameters(
model_path_or_dataset: str | Dataset,
output_path: str,
dtype: Optional[torch.dtype] = None,
):
if isinstance(model_path_or_dataset, Dataset):
dataset = copy(model_path_or_dataset)
else:
dataset = Dataset.load(model_path_or_dataset)
if dtype:
dataset.root_theta = dataset.root_theta.transform(
functools.partial(set_float_dtype, dtype=dtype)
)
prune_decoder_parameters(dataset)
dataset.save(output_path)
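`export_encoder_iree_parameters` now accepts either a GGUF/IRPA path or an already loaded `Dataset`, plus an optional `dtype` that is applied to all floating-point parameters via `set_float_dtype` before saving. A hedged usage example; both paths are placeholders:

```python
# Example call using the extended signature: cast the T5 encoder parameters to
# bfloat16 while exporting them. Input and output paths are placeholders.
import torch
from sharktank.models.t5.export import export_encoder_iree_parameters

export_encoder_iree_parameters(
    "/data/t5/small/google__t5-v1_1-small_f32.gguf",  # or a preloaded Dataset
    "/tmp/t5-v1_1-small-bf16.irpa",
    dtype=torch.bfloat16,
)
```
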
8 changes: 6 additions & 2 deletions sharktank/sharktank/models/t5/t5.py
@@ -684,7 +684,9 @@ def __init__(self, theta: Theta, config: T5Config, embed_tokens=None):
self.add_module(
"final_layer_norm",
RMSNormLayer(
theta(f"{theta_prefix}.output_norm"), epsilon=config.layer_norm_epsilon
theta(f"{theta_prefix}.output_norm"),
epsilon=config.layer_norm_epsilon,
dtype=config.activation_dtype,
),
)

Expand Down Expand Up @@ -1046,7 +1048,9 @@ def __init__(self, theta: Theta, config: T5Config):
super().__init__()
self.add_module(
"token_embedding",
TokenEmbeddingLayer(theta("token_embd"), dtype=config.activation_dtype),
TokenEmbeddingLayer(
theta("token_embd"), dtype=theta("token_embd").tensor("weight").dtype
),
)

encoder_config = copy.deepcopy(config)
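Together with the `Optional[torch.dtype]` change in `TokenEmbeddingLayer`, the encoder now derives its token-embedding dtype from the stored weight itself rather than from `config.activation_dtype`, while the final layer norm is built with the activation dtype explicitly. A small sketch of the effect, with a hypothetical stand-in for `theta("token_embd").tensor("weight")`:

```python
# Sketch of the new dtype flow: a bf16 GGUF checkpoint yields a bf16 embedding
# table without an extra cast; the tensor below is a stand-in weight.
import torch

token_embd_weight = torch.zeros(32128, 512, dtype=torch.bfloat16)  # hypothetical weight
embedding_dtype = token_embd_weight.dtype  # passed as TokenEmbeddingLayer(dtype=...)
print(embedding_dtype)  # torch.bfloat16
```
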
1 change: 1 addition & 0 deletions sharktank/sharktank/transforms/dataset/__init__.py
@@ -5,3 +5,4 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from .sharding import *
from .dataset import *
19 changes: 19 additions & 0 deletions sharktank/sharktank/transforms/dataset/dataset.py
@@ -0,0 +1,19 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import torch

from ...types.tensors import InferenceTensor, PrimitiveTensor, DefaultPrimitiveTensor
from ... import ops


def set_float_dtype(tensor: InferenceTensor, dtype: torch.dtype) -> InferenceTensor:
if isinstance(tensor, PrimitiveTensor) and tensor.dtype.is_floating_point:
return DefaultPrimitiveTensor(
name=tensor.name, data=ops.to(tensor, dtype=dtype)
)

return tensor
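`set_float_dtype` converts only floating-point `PrimitiveTensor`s and leaves everything else (integer or quantized tensors) untouched. Combined with `Theta.transform`, it is used exactly as in the export change above; a usage sketch with a placeholder path:

```python
# Usage sketch mirroring export.py: cast all float parameters of a Dataset to bf16.
import functools
import torch
from sharktank.types import Dataset
from sharktank.transforms.dataset import set_float_dtype

dataset = Dataset.load("/data/t5/small/google__t5-v1_1-small_f32.gguf")  # placeholder path
dataset.root_theta = dataset.root_theta.transform(
    functools.partial(set_float_dtype, dtype=torch.bfloat16)
)
```
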
9 changes: 9 additions & 0 deletions sharktank/sharktank/types/gguf_interop/base.py
@@ -118,6 +118,15 @@ def _wrap_tensor(
name=name, data=_externalize_tensor(name, data, logical_shape)
)

if type_name == "BF16":
assert data.dtype == np.uint8
return DefaultPrimitiveTensor(
name=name,
data=_externalize_tensor(name, data.view(np.int16), logical_shape).view(
dtype=torch.bfloat16
),
)

quantized_type = _quantized_types.get(type_name)
if quantized_type is not None:
return quantized_type(
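The new `BF16` branch reinterprets the raw `uint8` buffer that GGUF provides as 16-bit values and then views it as `bfloat16`, avoiding any copy or numeric conversion. A standalone illustration of that reinterpretation:

```python
# Reinterpret raw little-endian bytes as bfloat16: 0x3F80 encodes 1.0 in bfloat16.
import numpy as np
import torch

raw = np.array([0x80, 0x3F] * 4, dtype=np.uint8)  # four bf16 values as raw bytes
as_i16 = raw.view(np.int16)                        # pair bytes into 16-bit words
bf16 = torch.from_numpy(as_i16).view(dtype=torch.bfloat16)
print(bf16)  # tensor([1., 1., 1., 1.], dtype=torch.bfloat16)
```
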