
Commit 3b50fa3: "fixing pre commit"

1 parent: e8c0598

17 files changed: +174, -54 lines

fairseq/checkpoint_utils.py (+4, -2)

@@ -22,8 +22,10 @@

 from fairseq.data import data_utils
 from fairseq.dataclass.configs import CheckpointConfig
-from fairseq.dataclass.utils import (convert_namespace_to_omegaconf,
-                                     overwrite_args_by_name)
+from fairseq.dataclass.utils import (
+    convert_namespace_to_omegaconf,
+    overwrite_args_by_name,
+)
 from fairseq.distributed.fully_sharded_data_parallel import FSDP, has_FSDP
 from fairseq.file_io import PathManager
 from fairseq.models import FairseqDecoder, FairseqEncoder

fairseq/data/dictionary.py (+5, -5)

@@ -9,11 +9,11 @@

 import torch

-# from fairseq import utils
-# from fairseq.data import data_utils
-# from fairseq.file_chunker_utils import Chunker, find_offsets
-# from fairseq.file_io import PathManager
-# from fairseq.tokenizer import tokenize_line
+from fairseq import utils
+from fairseq.data import data_utils
+from fairseq.file_chunker_utils import Chunker, find_offsets
+from fairseq.file_io import PathManager
+from fairseq.tokenizer import tokenize_line


 class Dictionary:

fairseq/dataclass/configs.py (+11, -9)

@@ -11,15 +11,17 @@
 import torch
 from omegaconf import II, MISSING

-from fairseq.dataclass.constants import (DATASET_IMPL_CHOICES,
-                                         DDP_BACKEND_CHOICES,
-                                         DDP_COMM_HOOK_CHOICES,
-                                         GENERATION_CONSTRAINTS_CHOICES,
-                                         GENERATION_DECODING_FORMAT_CHOICES,
-                                         LOG_FORMAT_CHOICES,
-                                         PIPELINE_CHECKPOINT_CHOICES,
-                                         PRINT_ALIGNMENT_CHOICES,
-                                         ZERO_SHARDING_CHOICES)
+from fairseq.dataclass.constants import (
+    DATASET_IMPL_CHOICES,
+    DDP_BACKEND_CHOICES,
+    DDP_COMM_HOOK_CHOICES,
+    GENERATION_CONSTRAINTS_CHOICES,
+    GENERATION_DECODING_FORMAT_CHOICES,
+    LOG_FORMAT_CHOICES,
+    PIPELINE_CHECKPOINT_CHOICES,
+    PRINT_ALIGNMENT_CHOICES,
+    ZERO_SHARDING_CHOICES,
+)


 @dataclass

fairseq/dataclass/utils.py (+1, -2)

@@ -345,8 +345,7 @@ def override_module_args(args: Namespace) -> Tuple[List[str], List[str]]:

     no_dc = True
     if hasattr(args, "arch"):
-        from fairseq.models import (ARCH_MODEL_NAME_REGISTRY,
-                                    ARCH_MODEL_REGISTRY)
+        from fairseq.models import ARCH_MODEL_NAME_REGISTRY, ARCH_MODEL_REGISTRY

         if args.arch in ARCH_MODEL_REGISTRY:
             m_cls = ARCH_MODEL_REGISTRY[args.arch]

fairseq/distributed/__init__.py (+5, -2)

@@ -3,8 +3,11 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.

-from .fully_sharded_data_parallel import (FullyShardedDataParallel,
-                                          fsdp_enable_wrap, fsdp_wrap)
+from .fully_sharded_data_parallel import (
+    FullyShardedDataParallel,
+    fsdp_enable_wrap,
+    fsdp_wrap,
+)

 __all__ = [
     "fsdp_enable_wrap",

fairseq/distributed/fully_sharded_data_parallel.py (+3, -2)

@@ -12,8 +12,9 @@
 from fairseq.distributed import utils as dist_utils

 try:
-    from fairscale.nn.data_parallel import \
-        FullyShardedDataParallel as FSDP  # type: ignore
+    from fairscale.nn.data_parallel import (
+        FullyShardedDataParallel as FSDP,  # type: ignore
+    )

     has_FSDP = True
 except ImportError:
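
The guarded import keeps the module importable when fairscale is absent: has_FSDP records whether the real FullyShardedDataParallel was found, and callers check it before wrapping anything. A minimal sketch of that consuming pattern (the wrapper function and error text are illustrative, not part of this commit):

# Hypothetical consumer of the guarded import above.
import torch.nn as nn

from fairseq.distributed.fully_sharded_data_parallel import FSDP, has_FSDP


def wrap_model(model: nn.Module) -> nn.Module:
    if not has_FSDP:
        # fairscale is unavailable; fail loudly instead of calling FSDP
        raise ImportError("FullyShardedDataParallel requires fairscale")
    return FSDP(model)  # shard parameters and gradients across workers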

fairseq/file_chunker_utils.py (+84, new file)

@@ -0,0 +1,84 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import typing as tp
+
+
+def _safe_readline(fd) -> str:
+    pos = fd.tell()
+    while True:
+        try:
+            return fd.readline()
+        except UnicodeDecodeError:
+            pos -= 1
+            fd.seek(pos)  # search where this character begins
+
+
+def find_offsets(filename: str, num_chunks: int) -> tp.List[int]:
+    """
+    given a file and a number of chunks, find the offsets in the file
+    to be able to chunk around full lines.
+    """
+    with open(filename, "r", encoding="utf-8") as f:
+        size = os.fstat(f.fileno()).st_size
+        chunk_size = size // num_chunks
+        offsets = [0 for _ in range(num_chunks + 1)]
+        for i in range(1, num_chunks):
+            f.seek(chunk_size * i)
+            _safe_readline(f)
+            offsets[i] = f.tell()
+        offsets[-1] = size
+        return offsets
+
+
+class ChunkLineIterator:
+    """
+    Iterator to properly iterate over lines of a file chunk.
+    """
+
+    def __init__(self, fd, start_offset: int, end_offset: int):
+        self._fd = fd
+        self._start_offset = start_offset
+        self._end_offset = end_offset
+
+    def __iter__(self) -> tp.Iterable[str]:
+        self._fd.seek(self._start_offset)
+        # next(f) breaks f.tell(), hence readline() must be used
+        line = _safe_readline(self._fd)
+        while line:
+            pos = self._fd.tell()
+            # f.tell() does not always give the byte position in the file
+            # sometimes it skips to a very large number
+            # it is unlikely that through a normal read we go from
+            # end bytes to end + 2**32 bytes (4 GB) and this makes it unlikely
+            # that the procedure breaks by the nondeterministic behavior of
+            # f.tell()
+            if (
+                self._end_offset > 0
+                and pos > self._end_offset
+                and pos < self._end_offset + 2**32
+            ):
+                break
+            yield line
+            line = self._fd.readline()
+
+
+class Chunker:
+    """
+    contextmanager to read a chunk of a file line by line.
+    """
+
+    def __init__(self, path: str, start_offset: int, end_offset: int):
+        self.path = path
+        self.start_offset = start_offset
+        self.end_offset = end_offset
+
+    def __enter__(self) -> ChunkLineIterator:
+        self.fd = open(self.path, "r", encoding="utf-8")
+        return ChunkLineIterator(self.fd, self.start_offset, self.end_offset)
+
+    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
+        self.fd.close()
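
The new module splits a text file into byte ranges that never cut a line in half: find_offsets picks the boundaries and Chunker streams one range at a time, which is what the re-enabled imports in fairseq/data/dictionary.py above consume. A minimal sketch of how the two entry points compose (the file path and the line-counting task are hypothetical, not part of this commit):

# Hypothetical usage sketch for the helpers added above.
from fairseq.file_chunker_utils import Chunker, find_offsets


def count_lines_in_chunks(path: str, num_chunks: int) -> int:
    # find_offsets returns num_chunks + 1 boundaries; consecutive pairs
    # delimit chunks that always end on a complete line.
    offsets = find_offsets(path, num_chunks)
    total = 0
    for start, end in zip(offsets, offsets[1:]):
        with Chunker(path, start, end) as lines:
            total += sum(1 for _ in lines)
    return total


# "corpus.txt" is a stand-in path; in practice each (start, end) pair can be
# handed to a separate worker process, since the chunks are independent.
print(count_lines_in_chunks("corpus.txt", num_chunks=4))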

fairseq/models/fairseq_model.py (+4, -2)

@@ -18,8 +18,10 @@

 from fairseq import utils
 from fairseq.data import Dictionary
-from fairseq.dataclass.utils import (convert_namespace_to_omegaconf,
-                                     gen_parser_from_dataclass)
+from fairseq.dataclass.utils import (
+    convert_namespace_to_omegaconf,
+    gen_parser_from_dataclass,
+)
 from fairseq.models import FairseqDecoder, FairseqEncoder

 logger = logging.getLogger(__name__)

fairseq/models/hubert/hubert.py (+11, -7)

@@ -17,14 +17,18 @@
 from fairseq.data.dictionary import Dictionary
 from fairseq.dataclass import ChoiceEnum, FairseqDataclass
 from fairseq.models import BaseFairseqModel, register_model
-from fairseq.models.wav2vec.wav2vec2 import (EXTRACTOR_MODE_CHOICES,
-                                             LAYER_TYPE_CHOICES,
-                                             MASKING_DISTRIBUTION_CHOICES,
-                                             ConvFeatureExtractionModel,
-                                             TransformerEncoder)
+from fairseq.models.wav2vec.wav2vec2 import (
+    EXTRACTOR_MODE_CHOICES,
+    LAYER_TYPE_CHOICES,
+    MASKING_DISTRIBUTION_CHOICES,
+    ConvFeatureExtractionModel,
+    TransformerEncoder,
+)
 from fairseq.modules import GradMultiply, LayerNorm
-from fairseq.tasks.hubert_pretraining import (HubertPretrainingConfig,
-                                              HubertPretrainingTask)
+from fairseq.tasks.hubert_pretraining import (
+    HubertPretrainingConfig,
+    HubertPretrainingTask,
+)

 logger = logging.getLogger(__name__)

fairseq/models/wav2vec/wav2vec2.py (+12, -6)

@@ -16,13 +16,19 @@
 from fairseq.data.data_utils import compute_mask_indices
 from fairseq.dataclass import ChoiceEnum, FairseqDataclass
 from fairseq.distributed import fsdp_wrap
-from fairseq.distributed.fully_sharded_data_parallel import \
-    FullyShardedDataParallel
+from fairseq.distributed.fully_sharded_data_parallel import FullyShardedDataParallel
 from fairseq.models import BaseFairseqModel, register_model
-from fairseq.modules import (Fp32GroupNorm, Fp32LayerNorm, GradMultiply,
-                             GumbelVectorQuantizer, LayerNorm,
-                             MultiheadAttention, RelPositionalEncoding,
-                             SamePad, TransposeLast)
+from fairseq.modules import (
+    Fp32GroupNorm,
+    Fp32LayerNorm,
+    GradMultiply,
+    GumbelVectorQuantizer,
+    LayerNorm,
+    MultiheadAttention,
+    RelPositionalEncoding,
+    SamePad,
+    TransposeLast,
+)
 from fairseq.modules.checkpoint_activations import checkpoint_wrapper
 from fairseq.modules.conformer_layer import ConformerWav2Vec2EncoderLayer
 from fairseq.modules.transformer_sentence_encoder import init_bert_params

fairseq/modules/__init__.py (+5, -3)

@@ -3,9 +3,11 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.

-from .espnet_multihead_attention import (ESPNETMultiHeadedAttention,
-                                         RelPositionMultiHeadedAttention,
-                                         RotaryPositionMultiHeadedAttention)
+from .espnet_multihead_attention import (
+    ESPNETMultiHeadedAttention,
+    RelPositionMultiHeadedAttention,
+    RotaryPositionMultiHeadedAttention,
+)
 from .fp32_group_norm import Fp32GroupNorm
 from .grad_multiply import GradMultiply
 from .gumbel_vector_quantizer import GumbelVectorQuantizer

fairseq/modules/conformer_layer.py (+7, -4)

@@ -8,10 +8,13 @@

 import torch

-from fairseq.modules import (ESPNETMultiHeadedAttention, LayerNorm,
-                             MultiheadAttention,
-                             RelPositionMultiHeadedAttention,
-                             RotaryPositionMultiHeadedAttention)
+from fairseq.modules import (
+    ESPNETMultiHeadedAttention,
+    LayerNorm,
+    MultiheadAttention,
+    RelPositionMultiHeadedAttention,
+    RotaryPositionMultiHeadedAttention,
+)
 from fairseq.utils import get_activation_fn

fairseq/modules/espnet_multihead_attention.py (+3, -1)

@@ -12,7 +12,9 @@
 from torch import nn

 from fairseq.modules.rotary_positional_embedding import (
-    RotaryPositionalEmbedding, apply_rotary_pos_emb)
+    RotaryPositionalEmbedding,
+    apply_rotary_pos_emb,
+)


 class ESPNETMultiHeadedAttention(nn.Module):

fairseq/modules/multihead_attention.py (+1, -2)

@@ -20,8 +20,7 @@
     _xformers_available = False

 from fairseq import utils
-from fairseq.models.fairseq_incremental_decoder import \
-    FairseqIncrementalDecoder
+from fairseq.models.fairseq_incremental_decoder import FairseqIncrementalDecoder
 from fairseq.modules.fairseq_dropout import FairseqDropout
 from fairseq.modules.quant_noise import quant_noise

fairseq/search.py (+5, -3)

@@ -10,9 +10,11 @@
 import torch.nn as nn
 from torch import Tensor

-from fairseq.token_generation_constraints import (ConstraintState,
-                                                  OrderedConstraintState,
-                                                  UnorderedConstraintState)
+from fairseq.token_generation_constraints import (
+    ConstraintState,
+    OrderedConstraintState,
+    UnorderedConstraintState,
+)


 class Search(nn.Module):

fairseq/tasks/fairseq_task.py (+5, -4)

@@ -13,8 +13,7 @@
 from omegaconf import DictConfig

 from fairseq import search, tokenizer, utils
-from fairseq.data import (Dictionary, FairseqDataset, data_utils, encoders,
-                          iterators)
+from fairseq.data import Dictionary, FairseqDataset, data_utils, encoders, iterators
 from fairseq.dataclass import FairseqDataclass
 from fairseq.dataclass.utils import gen_parser_from_dataclass
 from fairseq.logging import metrics
@@ -412,8 +411,10 @@ def build_generator(
                 compute_alignment=getattr(args, "print_alignment", False),
             )

-        from fairseq.sequence_generator import (SequenceGenerator,
-                                                SequenceGeneratorWithAlignment)
+        from fairseq.sequence_generator import (
+            SequenceGenerator,
+            SequenceGeneratorWithAlignment,
+        )

         # Choose search strategy. Defaults to Beam Search.
         sampling = getattr(args, "sampling", False)

setup.cfg (+8, new file)

@@ -0,0 +1,8 @@
+[flake8]
+ignore = H102,H103,W503,H238,E203,H301,H306,E231
+max-line-length = 130
+[pycodestyle]
+ignore = H102,H103,W503,H238,E203,H301,H306,E231
+max-line-length = 130
+[isort]
+profile = black
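
The ignored codes (E203, W503, and the H-series) are the usual suppressions for black-compatible linting, and profile = black is what produced the parenthesized, trailing-comma import style in the hunks above. A sketch of the same normalization driven through isort's Python API, assuming isort >= 5 is installed; the sample string is illustrative:

# Illustrative only: reproduce this commit's import reformatting with isort.
import isort

before = (
    "from fairseq.dataclass.utils import (convert_namespace_to_omegaconf,\n"
    "                                     overwrite_args_by_name)\n"
)
# profile="black" mirrors the [isort] section added to setup.cfg
after = isort.code(before, profile="black")
print(after)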
