Commit

make style
sanchit-gandhi committed Apr 5, 2023
1 parent abcacc1 commit ab6d1fc
Showing 16 changed files with 2,318 additions and 2,236 deletions.
18 changes: 5 additions & 13 deletions app/app.py
@@ -1,10 +1,8 @@
 import os
 
 import gradio as gr
 import requests
-
 from transformers.models.whisper.tokenization_whisper import TO_LANGUAGE_CODE
 from transformers.pipelines.audio_utils import ffmpeg_read
-
 
 title = "Whisper JAX: The Fastest Whisper API Available ⚡️"
@@ -56,7 +54,7 @@
"""
article = article.replace("{URL}", API_URL)

language_names = sorted(list(TO_LANGUAGE_CODE.keys()))
language_names = sorted(TO_LANGUAGE_CODE.keys())


def query(files_payload, json_payload):
@@ -65,18 +63,12 @@ def query(files_payload, json_payload):


 def inference(input, language, task, return_timestamps):
-    json_payload = {
-        "task": task,
-        "return_timestamps": return_timestamps
-    }
+    json_payload = {"task": task, "return_timestamps": return_timestamps}
 
     if language:
         json_payload["language"] = f"<|{TO_LANGUAGE_CODE[language]}|>"
 
-    data = query(
-        {"inputs": {"array": input[1], "sampling_rate": input[0]}},
-        json_payload
-    )
+    data = query({"inputs": {"array": input[1], "sampling_rate": input[0]}}, json_payload)
 
     text = data[0]["text"]
 
@@ -94,7 +86,7 @@ def inference(input, language, task, return_timestamps):
         gr.inputs.Audio(source="upload", label="Input"),
         gr.inputs.Dropdown(language_names, label="Language", default=None),
         gr.inputs.Dropdown(["transcribe", "translate"], label="Task", default="transcribe"),
-        gr.inputs.Checkbox(default=False, label="Return timestamps")
+        gr.inputs.Checkbox(default=False, label="Return timestamps"),
     ],
     outputs=[
         gr.outputs.Textbox(label="Transcription"),
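Note on the `language` handling in `inference` above: `TO_LANGUAGE_CODE` maps human-readable language names to Whisper's short language codes, which the app wraps into `<|..|>` task tokens before sending the payload. A minimal sketch, assuming only that `transformers` is installed:

```python
from transformers.models.whisper.tokenization_whisper import TO_LANGUAGE_CODE

# TO_LANGUAGE_CODE maps names like "french" to short codes like "fr";
# the app turns these into Whisper task tokens such as "<|fr|>".
language_names = sorted(TO_LANGUAGE_CODE.keys())  # the dropdown entries
print(f"<|{TO_LANGUAGE_CODE['french']}|>")  # -> <|fr|>
```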
39 changes: 27 additions & 12 deletions app/fastapi_app.py
@@ -1,10 +1,10 @@
+import jax.numpy as jnp
 import numpy as np
-from fastapi import FastAPI, Request, HTTPException
+from fastapi import FastAPI, HTTPException, Request
+from transformers.models.whisper.tokenization_whisper import TO_LANGUAGE_CODE
 
 from whisper_jax import FlaxWhisperPipline
-import jax.numpy as jnp
-
-from transformers.models.whisper.tokenization_whisper import TO_LANGUAGE_CODE
+
 
 checkpoint = "openai/whisper-tiny"
 
@@ -13,43 +13,58 @@

 app = FastAPI()
 
+
 @app.get("/")
 def read_root():
     return {"Hello": "World"}
 
+
 language_codes = [f"<|{lang_id}|>" for lang_id in TO_LANGUAGE_CODE.values()]
 
+
 def check_inputs(inputs, language, task, return_timestamps):
     # required pre-processing to handle different input types efficiently over requests
     if isinstance(inputs, dict):
         if not ("sampling_rate" in inputs and "array" in inputs):
-            raise HTTPException(status_code=404, detail=("When passing a dictionary as inputs, the dict needs to contain an "
-                '"array" key containing the numpy array representing the audio, and a "sampling_rate" key, '
-                "containing the sampling_rate associated with that array"))
+            raise HTTPException(
+                status_code=404,
+                detail=(
+                    "When passing a dictionary as inputs, the dict needs to contain an "
+                    '"array" key containing the numpy array representing the audio, and a "sampling_rate" key, '
+                    "containing the sampling_rate associated with that array"
+                ),
+            )
 
         if isinstance(inputs["array"], str):
             inputs["array"] = np.fromstring(inputs["array"], dtype=np.int16)
 
         if not isinstance(inputs["array"], np.ndarray):
             raise HTTPException(status_code=404, detail=f"We expect a numpy ndarray as input, got `{type(inputs)}`")
-
 
         if len(inputs.shape) != 1:
             raise ValueError("We expect a single channel audio input for AutomaticSpeechRecognitionPipeline")
 
     if language is not None:
         if not isinstance("language", str) or language not in language_codes:
-            raise HTTPException(status_code=404, detail=(f"language argument should be in ")) #TODO(SG): handle language as string
+            raise HTTPException(
+                status_code=404, detail=("language argument should be in ")
+            )  # TODO(SG): handle language as string
 
     if task is not None:
         if not isinstance("task", str) or task not in ["transcribe", "translate"]:
-            raise HTTPException(status_code=404, detail=(f"task argument should be either"
-                f'"transcribe" or "translate", got {task}.'))
+            raise HTTPException(
+                status_code=404, detail=(f"task argument should be either" f'"transcribe" or "translate", got {task}.')
+            )
 
     if return_timestamps is not None:
         if not isinstance(return_timestamps, bool):
-            raise HTTPException(status_code=404,
-                detail=(f"return_timestamps should be a boolean value of either 'True' or 'False', got {return_timestamps}"))
+            raise HTTPException(
+                status_code=404,
+                detail=(
+                    f"return_timestamps should be a boolean value of either 'True' or 'False', got {return_timestamps}"
+                ),
+            )
 
+
 @app.post("/generate/")
 async def generate(request: Request):
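For reference, a sketch of the argument shapes that `check_inputs` above is written to validate (hypothetical client-side values; the route itself is `/generate/`):

```python
import numpy as np

# Hypothetical example arguments shaped the way check_inputs expects:
inputs = {
    "array": np.zeros(16000, dtype=np.int16),  # mono audio samples
    "sampling_rate": 16000,                    # rate associated with that array
}
language = "<|en|>"        # must be one of the "<|{lang_id}|>" codes built above
task = "transcribe"        # or "translate"
return_timestamps = False  # must be a bool
```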
30 changes: 18 additions & 12 deletions benchmarks/run_pjit.py
@@ -1,33 +1,37 @@
 import argparse
 import time
-from datasets import load_dataset, concatenate_datasets
-from flax.core.frozen_dict import freeze
-import jax.numpy as jnp
 
+import datasets
 import jax
-from jax.sharding import PartitionSpec as P
+import jax.numpy as jnp
+from datasets import concatenate_datasets, load_dataset
+from flax.core.frozen_dict import freeze
 from jax.experimental.compilation_cache import compilation_cache as cc
-from transformers import WhisperProcessor, WhisperConfig
+from jax.sharding import PartitionSpec as P
+from transformers import WhisperConfig, WhisperProcessor
 
+from whisper_jax import FlaxWhisperForConditionalGeneration, InferenceState, PjitPartitioner
 
-from whisper_jax import FlaxWhisperForConditionalGeneration, PjitPartitioner, InferenceState
-
-import datasets
 datasets.logging.set_verbosity(datasets.logging.CRITICAL)
 
 cc.initialize_cache("./jax_cache")
 jax.config.update("jax_array", True)
 
+
 def parse_args():
     parser = argparse.ArgumentParser(description="Benchmark Whisper large-v2")
     parser.add_argument(
         "--model_parallel_submesh",
         type=int,
-        nargs='+',
+        nargs="+",
         default=(2, 2, 1, 1),
         help="Model parallel submesh.",
     )
     args = parser.parse_args()
     return args
 
+
 BATCH_SIZES = [4, 8, 16, 32]
 NUM_BATCHES = 100
 NUM_TOKENS = 25
@@ -45,9 +49,10 @@ def parse_args():
("kv", None),
("length", None),
("num_mel", None),
("channels", None)
("channels", None),
]


def main():
args = parse_args()
print(args.model_parallel_submesh)
@@ -134,14 +139,15 @@ def generate(params, input_features):

     # warm-up step
     batch = next(iter(eval_dataloader))
-    pred_ids = p_generate(freeze(params), batch["input_features"])
+    p_generate(freeze(params), batch["input_features"])
 
     start = time.time()
     for batch in eval_dataloader:
-        pred_ids = p_generate(freeze(params), batch["input_features"])
+        p_generate(freeze(params), batch["input_features"])
     runtime = time.time() - start
 
     print(f"{batch_size}: {runtime:.06}")
 
+
 if __name__ == "__main__":
-    main()
+    main()
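The warm-up step above matters because the first call to a jit/pjit-compiled function triggers XLA compilation, so the benchmark keeps it out of the timed loop. A self-contained sketch of the same pattern with plain `jax.jit` (the `step` function is a stand-in, not the script's partitioned `generate`):

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def step(x):
    # stand-in for the partitioned generate step
    return (x * 2.0).sum()


x = jnp.ones((1024, 1024))
step(x).block_until_ready()  # warm-up: triggers XLA compilation

start = time.time()
step(x).block_until_ready()  # steady-state call, compilation excluded
print(f"runtime: {time.time() - start:.06f}s")
```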
22 changes: 15 additions & 7 deletions benchmarks/run_pjit_dataloader.py
@@ -1,18 +1,18 @@
 import time
 
+import jax
+import jax.numpy as jnp
 import numpy as np
 from datasets import load_dataset
 from flax.core.frozen_dict import freeze
-import jax.numpy as jnp
-import jax
-from jax.sharding import PartitionSpec as P
 from jax.experimental.compilation_cache import compilation_cache as cc
+from jax.sharding import PartitionSpec as P
 from torch.utils.data import DataLoader
+from tqdm import tqdm
 from transformers import WhisperProcessor
 
-from whisper_jax import FlaxWhisperForConditionalGeneration, PjitPartitioner, InferenceState
+from whisper_jax import FlaxWhisperForConditionalGeneration, InferenceState, PjitPartitioner
 
-from tqdm import tqdm
 
 cc.initialize_cache("./jax_cache")
 jax.config.update("jax_array", True)
@@ -36,7 +36,7 @@
("kv", None),
("length", None),
("num_mel", None),
("channels", None)
("channels", None),
]

model, params = FlaxWhisperForConditionalGeneration.from_pretrained(
@@ -45,6 +45,7 @@
     dtype=jnp.bfloat16,
 )
 
+
 def init_fn():
     input_shape = (1, 80, 3000)
 
@@ -68,6 +69,7 @@ def init_fn():
     )
     return init_params
 
+
 # Axis names metadata
 param_axes = jax.eval_shape(init_fn)["params_axes"]
 
@@ -90,10 +92,12 @@ def init_fn():

 p_shard_params = partitioner.partition(model.to_bf16, (params_spec,), params_spec)
 
+
 def generate(params, input_features):
     output_ids = model.generate(input_features, params=params, max_length=NUM_TOKENS).sequences
     return output_ids
 
+
 p_generate = partitioner.partition(
     generate,
     in_axis_resources=(params_spec, P("data")),
@@ -109,18 +113,22 @@ def generate(params, input_features):
 # processors/tokenizers are the same for all models, so just load from tiny and preprocess once
 processor = WhisperProcessor.from_pretrained("openai/whisper-large-v2")
 
+
 def preprocess(batch):
     batch["input_features"] = processor(
         batch["audio"]["array"], sampling_rate=16000, return_tensors="np"
     ).input_features[0]
     return batch
 
+
 librispeech = load_dataset("speechcolab/gigaspeech", "l", split="train", streaming=STREAMING, use_auth_token=True)
 librispeech_features = librispeech.features.keys()
 
 librispeech_processed = librispeech.map(preprocess, remove_columns=librispeech_features)
 
-eval_dataloader = DataLoader(librispeech_processed, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, persistent_workers=True)
+eval_dataloader = DataLoader(
+    librispeech_processed, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, persistent_workers=True
+)
 
 all_load_times = 0
 all_runtimes = 0
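The `("kv", None)` and `("channels", None)` entries these benchmark scripts share are logical-axis rules: each pair maps a named parameter or activation dimension onto a physical mesh axis, and `None` replicates that dimension instead of sharding it. A minimal sketch of the idea, with illustrative axis names (the scripts' full table includes the entries shown above):

```python
# Each rule maps a logical (named) array dimension to a mesh axis;
# None means the dimension is replicated rather than sharded.
logical_axis_rules = [
    ("batch", "data"),   # shard the batch dimension across the "data" axis
    ("mlp", "model"),    # shard MLP hidden dimensions across the "model" axis
    ("kv", None),        # replicate key/value head dimensions (as above)
    ("channels", None),  # replicate conv channel dimensions (as above)
]
```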
18 changes: 12 additions & 6 deletions benchmarks/run_pjit_save_transcriptions.py
@@ -1,17 +1,17 @@
 from pathlib import Path
 
+import jax
+import jax.numpy as jnp
 import numpy as np
 from datasets import load_dataset
 from flax.core.frozen_dict import freeze
-import jax.numpy as jnp
-import jax
-from jax.sharding import PartitionSpec as P
 from jax.experimental.compilation_cache import compilation_cache as cc
+from jax.sharding import PartitionSpec as P
+from tqdm import tqdm
 from transformers import WhisperProcessor
 
-from whisper_jax import FlaxWhisperForConditionalGeneration, PjitPartitioner, InferenceState
+from whisper_jax import FlaxWhisperForConditionalGeneration, InferenceState, PjitPartitioner
 
-from tqdm import tqdm
 
 cc.initialize_cache("./jax_cache")
 jax.config.update("jax_array", True)
@@ -33,18 +33,20 @@
("kv", None),
("length", None),
("num_mel", None),
("channels", None)
("channels", None),
]

# processors/tokenizers are the same for all models, so just load from tiny and preprocess once
processor = WhisperProcessor.from_pretrained("openai/whisper-large-v2")


def preprocess(batch):
batch["input_features"] = processor(
batch["audio"]["array"], sampling_rate=16000, return_tensors="np"
).input_features[0]
return batch


librispeech = load_dataset("librispeech_asr", "all", streaming=True)
librispeech_features = list(next(iter(librispeech.values())).features.keys())

@@ -54,6 +56,7 @@ def preprocess(batch):
     dtype=jnp.bfloat16,
 )
 
+
 def init_fn():
     input_shape = (1, 80, 3000)
 
@@ -77,6 +80,7 @@ def init_fn():
     )
     return init_params
 
+
 # Axis names metadata
 param_axes = jax.eval_shape(init_fn)["params_axes"]
 
@@ -99,10 +103,12 @@ def init_fn():

 p_shard_params = partitioner.partition(model.to_bf16, (params_spec,), params_spec)
 
+
 def generate(params, input_features):
     output_ids = model.generate(input_features, params=params, max_length=NUM_TOKENS).sequences
     return output_ids
 
+
 p_generate = partitioner.partition(
     generate,
     in_axis_resources=(params_spec, P("data")),
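Underneath `PjitPartitioner`, these scripts build on JAX's named-mesh sharding. A self-contained sketch of those primitives with a one-axis "data" mesh (an assumption for illustration; the scripts derive a 2D data/model mesh from `--model_parallel_submesh`):

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P

devices = np.array(jax.devices())  # e.g. a single CPU device locally
mesh = Mesh(devices, axis_names=("data",))

# P("data") shards the leading (batch) axis across the "data" mesh axis,
# mirroring in_axis_resources=(params_spec, P("data")) above.
sharding = NamedSharding(mesh, P("data"))

batch = jnp.zeros((2 * len(devices), 80, 3000))  # dummy log-mel features
batch = jax.device_put(batch, sharding)
print(batch.sharding)
```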