Skip to content

Commit

Permalink
fix Ruff issues
Browse files Browse the repository at this point in the history
  • Loading branch information
alanakbik committed Aug 9, 2023
1 parent 1590089 commit a57020a
Show file tree
Hide file tree
Showing 11 changed files with 31 additions and 27 deletions.
14 changes: 9 additions & 5 deletions flair/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
import re
import typing
from abc import ABC, abstractmethod
from collections import Counter, defaultdict, namedtuple
from collections import Counter, defaultdict
from operator import itemgetter
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Union, cast
from typing import Dict, Iterable, List, NamedTuple, Optional, Union, cast

import torch
from deprecated import deprecated
Expand Down Expand Up @@ -39,7 +39,11 @@ def _len_dataset(dataset: Optional[Dataset]) -> int:
return len(loader)


BoundingBox = namedtuple("BoundingBox", ["left", "top", "right", "bottom"])
class BoundingBox(NamedTuple):
left: str
top: int
right: int
bottom: int


class Dictionary:
Expand Down Expand Up @@ -727,7 +731,7 @@ def __init__(
if isinstance(use_tokenizer, Tokenizer):
tokenizer = use_tokenizer

elif type(use_tokenizer) == bool:
elif isinstance(use_tokenizer, bool):
tokenizer = SegtokTokenizer() if use_tokenizer else SpaceTokenizer()

else:
Expand Down Expand Up @@ -809,7 +813,7 @@ def _add_token(self, token: Union[Token, str]):
if isinstance(token, Token):
assert token.sentence is None

if type(token) is str:
if isinstance(token, str):
token = Token(token)
token = cast(Token, token)

Expand Down
2 changes: 1 addition & 1 deletion flair/datasets/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ def __getitem__(self, index: int = 0) -> Sentence:


def find_train_dev_test_files(data_folder, dev_file, test_file, train_file, autofind_splits=True):
if type(data_folder) == str:
if isinstance(data_folder, str):
data_folder: Path = Path(data_folder)

if train_file is not None:
Expand Down
12 changes: 6 additions & 6 deletions flair/datasets/sequence_labeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -2976,7 +2976,7 @@ def __init__(
base_path = flair.cache_root / "datasets" if not base_path else Path(base_path)

# if only one language is given
if type(languages) == str:
if isinstance(languages, str):
languages = [languages]

# column format
Expand Down Expand Up @@ -3249,7 +3249,7 @@ def __init__(
in_memory : bool, optional
Specify that the dataset should be loaded in memory, which speeds up the training process but takes increases the RAM usage significantly.
"""
if type(languages) == str:
if isinstance(languages, str):
languages = [languages]

base_path = flair.cache_root / "datasets" if not base_path else Path(base_path)
Expand Down Expand Up @@ -3710,7 +3710,7 @@ def __init__(
]

# if only one language is given
if type(languages) == str:
if isinstance(languages, str):
languages = [languages]

base_path = flair.cache_root / "datasets" if not base_path else Path(base_path)
Expand Down Expand Up @@ -3802,7 +3802,7 @@ def __init__(
base_path = flair.cache_root / "datasets" if not base_path else Path(base_path)

# if only one language is given
if type(languages) == str:
if isinstance(languages, str):
languages = [languages]

# column format
Expand Down Expand Up @@ -4748,10 +4748,10 @@ def __init__(
"""
supported_domains = ["WN", "FIC", "ADG"]

if type(domains) == str and domains == "all":
if isinstance(domains, str) and domains == "all":
domains = supported_domains

if type(domains) == str:
if isinstance(domains, str):
domains = [domains]

base_path = flair.cache_root / "datasets" if not base_path else Path(base_path)
Expand Down
2 changes: 1 addition & 1 deletion flair/models/tars_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ def add_and_switch_to_new_task(
# make label dictionary if no Dictionary object is passed
if isinstance(label_dictionary, Dictionary):
label_dictionary = label_dictionary.get_items()
if type(label_dictionary) == str:
if isinstance(label_dictionary, str):
label_dictionary = [label_dictionary]

# prepare dictionary of tags (without B- I- prefixes and without UNK)
Expand Down
6 changes: 3 additions & 3 deletions flair/tokenization.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,9 +256,9 @@ def combined_rule_prefixes() -> List[str]:
r"/", # want to split at every slash
r"(?<=[0-9])[+\-\*^](?=[0-9-])",
rf"(?<=[{char_classes.ALPHA_LOWER}])\.(?=[{char_classes.ALPHA_UPPER}])",
r"(?<=[{a}]),(?=[{a}])".format(a=char_classes.ALPHA),
r'(?<=[{a}])[?";:=,.]*(?:{h})(?=[{a}])'.format(a=char_classes.ALPHA, h=char_classes.HYPHENS),
r"(?<=[{a}0-9])[:<>=/](?=[{a}])".format(a=char_classes.ALPHA),
rf"(?<=[{char_classes.ALPHA}]),(?=[{char_classes.ALPHA}])",
rf'(?<=[{char_classes.ALPHA}])[?";:=,.]*(?:{char_classes.HYPHENS})(?=[{char_classes.ALPHA}])',
rf"(?<=[{char_classes.ALPHA}0-9])[:<>=/](?=[{char_classes.ALPHA}])",
]
)

Expand Down
4 changes: 2 additions & 2 deletions flair/trainers/language_model_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def __len__(self) -> int:

def __getitem__(self, index=0) -> torch.Tensor:
"""Tokenizes a text file on character basis."""
if type(self.files[index]) is str:
if isinstance(self.files[index], str):
self.files[index] = Path(self.files[index])
assert self.files[index].exists()

Expand Down Expand Up @@ -444,7 +444,7 @@ def load_checkpoint(
corpus: TextCorpus,
optimizer: Type[Optimizer] = SGD,
):
if type(checkpoint_file) is str:
if isinstance(checkpoint_file, str):
checkpoint_file = Path(checkpoint_file)

checkpoint = LanguageModel.load_checkpoint(checkpoint_file)
Expand Down
6 changes: 3 additions & 3 deletions flair/trainers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def train(
# evaluation and monitoring
main_evaluation_metric: Tuple[str, str] = ("micro avg", "f1-score"),
monitor_test: bool = False,
monitor_train_sample: Union[float, int] = 0.0,
monitor_train_sample: float = 0.0,
use_final_model_for_eval: bool = False,
gold_label_dictionary_for_eval: Optional[Dictionary] = None,
exclude_labels: List[str] = [],
Expand Down Expand Up @@ -211,7 +211,7 @@ def fine_tune(
# evaluation and monitoring
main_evaluation_metric: Tuple[str, str] = ("micro avg", "f1-score"),
monitor_test: bool = False,
monitor_train_sample: Union[float, int] = 0.0,
monitor_train_sample: float = 0.0,
use_final_model_for_eval: bool = True,
gold_label_dictionary_for_eval: Optional[Dictionary] = None,
exclude_labels: List[str] = [],
Expand Down Expand Up @@ -302,7 +302,7 @@ def train_custom(
# evaluation and monitoring
main_evaluation_metric: Tuple[str, str] = ("micro avg", "f1-score"),
monitor_test: bool = False,
monitor_train_sample: Union[float, int] = 0.0,
monitor_train_sample: float = 0.0,
use_final_model_for_eval: bool = False,
gold_label_dictionary_for_eval: Optional[Dictionary] = None,
exclude_labels: List[str] = [],
Expand Down
4 changes: 2 additions & 2 deletions flair/training_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def to_tsv(self):
@staticmethod
def tsv_header(prefix=None):
if prefix:
return "{0}_MEAN_SQUARED_ERROR\t{0}_MEAN_ABSOLUTE_ERROR\t{0}_PEARSON\t{0}_SPEARMAN".format(prefix)
return f"{prefix}_MEAN_SQUARED_ERROR\t{prefix}_MEAN_ABSOLUTE_ERROR\t{prefix}_PEARSON\t{prefix}_SPEARMAN"

return "MEAN_SQUARED_ERROR\tMEAN_ABSOLUTE_ERROR\tPEARSON\tSPEARMAN"

Expand Down Expand Up @@ -99,7 +99,7 @@ class EvaluationMetric(Enum):

class WeightExtractor:
def __init__(self, directory: Union[str, Path], number_of_weights: int = 10) -> None:
if type(directory) is str:
if isinstance(directory, str):
directory = Path(directory)
self.weights_file = init_output_file(directory, "weights.txt")
self.weights_dict: Dict[str, Dict[int, List[float]]] = defaultdict(lambda: defaultdict(list))
Expand Down
4 changes: 2 additions & 2 deletions flair/visual/training_curves.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def _extract_evaluation_data(file_name: Union[str, Path], score: str = "F1") ->

@staticmethod
def _extract_weight_data(file_name: Union[str, Path]) -> dict:
if type(file_name) is str:
if isinstance(file_name, str):
file_name = Path(file_name)

weights: Dict[str, Dict[str, List[float]]] = defaultdict(lambda: defaultdict(list))
Expand All @@ -86,7 +86,7 @@ def _extract_weight_data(file_name: Union[str, Path]) -> dict:

@staticmethod
def _extract_learning_rate(file_name: Union[str, Path]):
if type(file_name) is str:
if isinstance(file_name, str):
file_name = Path(file_name)

lrs = []
Expand Down
2 changes: 1 addition & 1 deletion tests/model_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def test_train_load_use_model_multi_label(
print(label)
assert label.value is not None
assert 0.0 <= label.score <= 1.0
assert type(label.score) is float
assert isinstance(label.score, float)

del trainer, model, multi_class_corpus
loaded_model = self.model_cls.load(results_base_path / "final-model.pt")
Expand Down
2 changes: 1 addition & 1 deletion tests/test_multitask.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,5 +63,5 @@ def test_train_load_use_classifier(results_base_path, tasks_base_path):
for label in sentence.labels:
assert label.value is not None
assert 0.0 <= label.score <= 1.0
assert type(label.score) is float
assert isinstance(label.score, float)
del loaded_model

0 comments on commit a57020a

Please sign in to comment.