Skip to content

Commit af6305c

Browse files
xuzhao9 and facebook-github-bot
authored and committed
Patch speech_transformer data loading to load from absolute path (#699)
Summary: Lazy Tensor users would like to run the benchmark from a different directory. Yet the data loading script in speech transformer model always assume the code runs under the `benchmark/` directory. This PR patches the dataloading code of speech transformer model such that it finds the absolute path of input data, and load data from there. It also fixes a bug in `run.py` related to `extra_args`. Pull Request resolved: #699 Reviewed By: wconstab Differential Revision: D33625399 Pulled By: xuzhao9 fbshipit-source-id: 2e9fdc9184bb397601dbb287dd872c281707d5a9
1 parent 9e35edc commit af6305c

File tree

4 files changed

+27
-9
lines changed

4 files changed

+27
-9
lines changed

run.py

+14-4
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,8 @@ def _validate_devices(devices: str):
163163
print(f"Unable to find model matching {args.model}.")
164164
exit(-1)
165165
model_args = inspect.signature(Model)
166-
if extra_args and not 'extra_args' in model_args.parameters:
166+
support_extra_args = 'extra_args' in model_args.parameters
167+
if extra_args and not support_extra_args:
167168
print(f"The model {args.model} doesn't accept extra args: {extra_args}")
168169
exit(-1)
169170
print(f"Running {args.test} method from {Model.name} on {args.device} in {args.mode} mode.")
@@ -174,14 +175,23 @@ def _validate_devices(devices: str):
174175
if args.bs:
175176
try:
176177
if args.test == "eval":
177-
m = Model(device=args.device, jit=(args.mode == "jit"), eval_bs=args.bs, extra_args=extra_args)
178+
if support_extra_args:
179+
m = Model(device=args.device, jit=(args.mode == "jit"), eval_bs=args.bs, extra_args=extra_args)
180+
else:
181+
m = Model(device=args.device, jit=(args.mode == "jit"), eval_bs=args.bs)
178182
elif args.test == "train":
179-
m = Model(device=args.device, jit=(args.mode == "jit"), train_bs=args.bs, extra_args=extra_args)
183+
if support_extra_args:
184+
m = Model(device=args.device, jit=(args.mode == "jit"), train_bs=args.bs, extra_args=extra_args)
185+
else:
186+
m = Model(device=args.device, jit=(args.mode == "jit"), eval_bs=args.bs)
180187
except:
181188
print(f"The model {args.model} doesn't support specifying batch size, please remove --bs argument in the commandline.")
182189
exit(1)
183190
else:
184-
m = Model(device=args.device, jit=(args.mode == "jit"), extra_args=extra_args)
191+
if support_extra_args:
192+
m = Model(device=args.device, jit=(args.mode == "jit"), extra_args=extra_args)
193+
else:
194+
m = Model(device=args.device, jit=(args.mode == "jit"))
185195

186196
test = getattr(m, args.test)
187197
model_flops = None

torchbenchmark/models/speech_transformer/__init__.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@ def __init__(self, device=None, jit=False, train_bs=32):
3030
return
3131
self.traincfg = SpeechTransformerTrainConfig(prefetch=True, train_bs=train_bs, num_train_batch=NUM_TRAIN_BATCH)
3232
self.evalcfg = SpeechTransformerEvalConfig(self.traincfg, num_eval_batch=NUM_EVAL_BATCH)
33-
self.traincfg.model.cuda()
34-
self.evalcfg.model.cuda()
33+
self.traincfg.model.to(self.device)
34+
self.evalcfg.model.to(self.device)
3535

3636
def get_module(self):
3737
if self.device == "cpu":
@@ -40,7 +40,7 @@ def get_module(self):
4040
raise NotImplementedError("JIT is not supported by this model")
4141
for data in self.traincfg.tr_loader:
4242
padded_input, input_lengths, padded_target = data
43-
return self.traincfg.model, (padded_input.cuda(), input_lengths.cuda(), padded_target.cuda())
43+
return self.traincfg.model, (padded_input.to(self.device), input_lengths.to(self.device), padded_target.to(self.device))
4444

4545
def train(self, niter=1):
4646
if self.device == "cpu":

torchbenchmark/models/speech_transformer/config.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,11 @@ class SpeechTransformerTrainConfig:
3737
batch_frames = 15000
3838
maxlen_in = 800
3939
maxlen_out = 150
40-
num_workers = 4
40+
# don't use subprocess in dataloader
41+
# because TorchBench is only running 1 batch
42+
num_workers = 0
43+
# original value
44+
# num_workers = 4
4145
# optimizer
4246
k = 0.2
4347
warmup_steps = 1

torchbenchmark/models/speech_transformer/speech_transformer/data/data.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
"""
1111
import json
1212

13+
from pathlib import Path
1314
import numpy as np
1415
import torch
1516
import torch.utils.data as data
@@ -144,7 +145,10 @@ def load_inputs_and_targets(batch, LFR_m=1, LFR_n=1):
144145
# load acoustic features and target sequence of token ids
145146
# for b in batch:
146147
# print(b[1]['input'][0]['feat'])
147-
xs = [kaldi_io.read_mat(b[1]['input'][0]['feat']) for b in batch]
148+
# TorchBench: Patch the input data with current file directory
149+
# Current file path: TORCHBENCH_ROOT/torchbenchmark/models/speech_transformer/speech_transformer/data/data.py
150+
TORCHBENCH_ROOT = Path(__file__).parents[5]
151+
xs = [kaldi_io.read_mat(str(TORCHBENCH_ROOT.joinpath(b[1]['input'][0]['feat']).resolve())) for b in batch]
148152
ys = [b[1]['output'][0]['tokenid'].split() for b in batch]
149153

150154
if LFR_m != 1 or LFR_n != 1:

0 commit comments

Comments (0)