From a57020a470711053d3cc35a77cc44ffe310408c9 Mon Sep 17 00:00:00 2001
From: Alan Akbik
Date: Wed, 9 Aug 2023 16:31:26 +0200
Subject: [PATCH] fix Ruff issues

---
 flair/data.py                            | 14 +++++++++-----
 flair/datasets/base.py                   |  2 +-
 flair/datasets/sequence_labeling.py      | 12 ++++++------
 flair/models/tars_model.py               |  2 +-
 flair/tokenization.py                    |  6 +++---
 flair/trainers/language_model_trainer.py |  4 ++--
 flair/trainers/trainer.py                |  6 +++---
 flair/training_utils.py                  |  4 ++--
 flair/visual/training_curves.py          |  4 ++--
 tests/model_test_utils.py                |  2 +-
 tests/test_multitask.py                  |  2 +-
 11 files changed, 31 insertions(+), 27 deletions(-)

diff --git a/flair/data.py b/flair/data.py
index 182b12836..4c8f4ace7 100644
--- a/flair/data.py
+++ b/flair/data.py
@@ -3,10 +3,10 @@
 import re
 import typing
 from abc import ABC, abstractmethod
-from collections import Counter, defaultdict, namedtuple
+from collections import Counter, defaultdict
 from operator import itemgetter
 from pathlib import Path
-from typing import Dict, Iterable, List, Optional, Union, cast
+from typing import Dict, Iterable, List, NamedTuple, Optional, Union, cast
 
 import torch
 from deprecated import deprecated
@@ -39,7 +39,11 @@ def _len_dataset(dataset: Optional[Dataset]) -> int:
     return len(loader)
 
 
-BoundingBox = namedtuple("BoundingBox", ["left", "top", "right", "bottom"])
+class BoundingBox(NamedTuple):
+    left: str
+    top: int
+    right: int
+    bottom: int
 
 
 class Dictionary:
@@ -727,7 +731,7 @@ def __init__(
 
         if isinstance(use_tokenizer, Tokenizer):
             tokenizer = use_tokenizer
-        elif type(use_tokenizer) == bool:
+        elif isinstance(use_tokenizer, bool):
             tokenizer = SegtokTokenizer() if use_tokenizer else SpaceTokenizer()
 
         else:
@@ -809,7 +813,7 @@ def _add_token(self, token: Union[Token, str]):
         if isinstance(token, Token):
             assert token.sentence is None
 
-        if type(token) is str:
+        if isinstance(token, str):
             token = Token(token)
         token = cast(Token, token)
 
diff --git a/flair/datasets/base.py b/flair/datasets/base.py
index 98f625e14..f5550b5bc 100644
--- a/flair/datasets/base.py
+++ b/flair/datasets/base.py
@@ -229,7 +229,7 @@ def __getitem__(self, index: int = 0) -> Sentence:
 
 
 def find_train_dev_test_files(data_folder, dev_file, test_file, train_file, autofind_splits=True):
-    if type(data_folder) == str:
+    if isinstance(data_folder, str):
         data_folder: Path = Path(data_folder)
 
     if train_file is not None:
diff --git a/flair/datasets/sequence_labeling.py b/flair/datasets/sequence_labeling.py
index 8c34b991d..d214873b5 100644
--- a/flair/datasets/sequence_labeling.py
+++ b/flair/datasets/sequence_labeling.py
@@ -2976,7 +2976,7 @@ def __init__(
         base_path = flair.cache_root / "datasets" if not base_path else Path(base_path)
 
         # if only one language is given
-        if type(languages) == str:
+        if isinstance(languages, str):
             languages = [languages]
 
         # column format
@@ -3249,7 +3249,7 @@ def __init__(
         in_memory : bool, optional
             Specify that the dataset should be loaded in memory, which speeds up the training process but takes increases the RAM usage significantly.
         """
-        if type(languages) == str:
+        if isinstance(languages, str):
             languages = [languages]
 
         base_path = flair.cache_root / "datasets" if not base_path else Path(base_path)
@@ -3710,7 +3710,7 @@ def __init__(
         ]
 
         # if only one language is given
-        if type(languages) == str:
+        if isinstance(languages, str):
             languages = [languages]
 
         base_path = flair.cache_root / "datasets" if not base_path else Path(base_path)
@@ -3802,7 +3802,7 @@ def __init__(
         base_path = flair.cache_root / "datasets" if not base_path else Path(base_path)
 
         # if only one language is given
-        if type(languages) == str:
+        if isinstance(languages, str):
             languages = [languages]
 
         # column format
@@ -4748,10 +4748,10 @@ def __init__(
         """
         supported_domains = ["WN", "FIC", "ADG"]
 
-        if type(domains) == str and domains == "all":
+        if isinstance(domains, str) and domains == "all":
             domains = supported_domains
 
-        if type(domains) == str:
+        if isinstance(domains, str):
             domains = [domains]
 
         base_path = flair.cache_root / "datasets" if not base_path else Path(base_path)
diff --git a/flair/models/tars_model.py b/flair/models/tars_model.py
index de431c468..6bee5aee1 100644
--- a/flair/models/tars_model.py
+++ b/flair/models/tars_model.py
@@ -195,7 +195,7 @@ def add_and_switch_to_new_task(
         # make label dictionary if no Dictionary object is passed
         if isinstance(label_dictionary, Dictionary):
             label_dictionary = label_dictionary.get_items()
-        if type(label_dictionary) == str:
+        if isinstance(label_dictionary, str):
             label_dictionary = [label_dictionary]
 
         # prepare dictionary of tags (without B- I- prefixes and without UNK)
diff --git a/flair/tokenization.py b/flair/tokenization.py
index fe2ac33eb..ab4c0d239 100644
--- a/flair/tokenization.py
+++ b/flair/tokenization.py
@@ -256,9 +256,9 @@ def combined_rule_prefixes() -> List[str]:
             r"/",  # want to split at every slash
             r"(?<=[0-9])[+\-\*^](?=[0-9-])",
             rf"(?<=[{char_classes.ALPHA_LOWER}])\.(?=[{char_classes.ALPHA_UPPER}])",
-            r"(?<=[{a}]),(?=[{a}])".format(a=char_classes.ALPHA),
-            r'(?<=[{a}])[?";:=,.]*(?:{h})(?=[{a}])'.format(a=char_classes.ALPHA, h=char_classes.HYPHENS),
-            r"(?<=[{a}0-9])[:<>=/](?=[{a}])".format(a=char_classes.ALPHA),
+            rf"(?<=[{char_classes.ALPHA}]),(?=[{char_classes.ALPHA}])",
+            rf'(?<=[{char_classes.ALPHA}])[?";:=,.]*(?:{char_classes.HYPHENS})(?=[{char_classes.ALPHA}])',
+            rf"(?<=[{char_classes.ALPHA}0-9])[:<>=/](?=[{char_classes.ALPHA}])",
         ]
     )
 
diff --git a/flair/trainers/language_model_trainer.py b/flair/trainers/language_model_trainer.py
index 97915f07f..a596a7979 100644
--- a/flair/trainers/language_model_trainer.py
+++ b/flair/trainers/language_model_trainer.py
@@ -56,7 +56,7 @@ def __len__(self) -> int:
 
     def __getitem__(self, index=0) -> torch.Tensor:
         """Tokenizes a text file on character basis."""
-        if type(self.files[index]) is str:
+        if isinstance(self.files[index], str):
             self.files[index] = Path(self.files[index])
 
         assert self.files[index].exists()
@@ -444,7 +444,7 @@ def load_checkpoint(
         corpus: TextCorpus,
         optimizer: Type[Optimizer] = SGD,
     ):
-        if type(checkpoint_file) is str:
+        if isinstance(checkpoint_file, str):
             checkpoint_file = Path(checkpoint_file)
 
         checkpoint = LanguageModel.load_checkpoint(checkpoint_file)
diff --git a/flair/trainers/trainer.py b/flair/trainers/trainer.py
index 8e25fb067..3cf96f559 100644
--- a/flair/trainers/trainer.py
+++ b/flair/trainers/trainer.py
@@ -139,7 +139,7 @@ def train(
         # evaluation and monitoring
         main_evaluation_metric: Tuple[str, str] = ("micro avg", "f1-score"),
         monitor_test: bool = False,
-        monitor_train_sample: Union[float, int] = 0.0,
+        monitor_train_sample: float = 0.0,
         use_final_model_for_eval: bool = False,
         gold_label_dictionary_for_eval: Optional[Dictionary] = None,
         exclude_labels: List[str] = [],
@@ -211,7 +211,7 @@ def fine_tune(
         # evaluation and monitoring
         main_evaluation_metric: Tuple[str, str] = ("micro avg", "f1-score"),
         monitor_test: bool = False,
-        monitor_train_sample: Union[float, int] = 0.0,
+        monitor_train_sample: float = 0.0,
         use_final_model_for_eval: bool = True,
         gold_label_dictionary_for_eval: Optional[Dictionary] = None,
         exclude_labels: List[str] = [],
@@ -302,7 +302,7 @@ def train_custom(
         # evaluation and monitoring
         main_evaluation_metric: Tuple[str, str] = ("micro avg", "f1-score"),
         monitor_test: bool = False,
-        monitor_train_sample: Union[float, int] = 0.0,
+        monitor_train_sample: float = 0.0,
         use_final_model_for_eval: bool = False,
         gold_label_dictionary_for_eval: Optional[Dictionary] = None,
         exclude_labels: List[str] = [],
diff --git a/flair/training_utils.py b/flair/training_utils.py
index f36f95a91..e465f86c1 100644
--- a/flair/training_utils.py
+++ b/flair/training_utils.py
@@ -71,7 +71,7 @@ def to_tsv(self):
     @staticmethod
     def tsv_header(prefix=None):
         if prefix:
-            return "{0}_MEAN_SQUARED_ERROR\t{0}_MEAN_ABSOLUTE_ERROR\t{0}_PEARSON\t{0}_SPEARMAN".format(prefix)
+            return f"{prefix}_MEAN_SQUARED_ERROR\t{prefix}_MEAN_ABSOLUTE_ERROR\t{prefix}_PEARSON\t{prefix}_SPEARMAN"
 
         return "MEAN_SQUARED_ERROR\tMEAN_ABSOLUTE_ERROR\tPEARSON\tSPEARMAN"
 
@@ -99,7 +99,7 @@ class EvaluationMetric(Enum):
 
 class WeightExtractor:
     def __init__(self, directory: Union[str, Path], number_of_weights: int = 10) -> None:
-        if type(directory) is str:
+        if isinstance(directory, str):
             directory = Path(directory)
         self.weights_file = init_output_file(directory, "weights.txt")
         self.weights_dict: Dict[str, Dict[int, List[float]]] = defaultdict(lambda: defaultdict(list))
diff --git a/flair/visual/training_curves.py b/flair/visual/training_curves.py
index f9a3c224b..1fd856b66 100644
--- a/flair/visual/training_curves.py
+++ b/flair/visual/training_curves.py
@@ -67,7 +67,7 @@ def _extract_evaluation_data(file_name: Union[str, Path], score: str = "F1") ->
 
     @staticmethod
     def _extract_weight_data(file_name: Union[str, Path]) -> dict:
-        if type(file_name) is str:
+        if isinstance(file_name, str):
             file_name = Path(file_name)
 
         weights: Dict[str, Dict[str, List[float]]] = defaultdict(lambda: defaultdict(list))
@@ -86,7 +86,7 @@ def _extract_weight_data(file_name: Union[str, Path]) -> dict:
 
     @staticmethod
     def _extract_learning_rate(file_name: Union[str, Path]):
-        if type(file_name) is str:
+        if isinstance(file_name, str):
             file_name = Path(file_name)
 
         lrs = []
diff --git a/tests/model_test_utils.py b/tests/model_test_utils.py
index 3b936d862..10aab0831 100644
--- a/tests/model_test_utils.py
+++ b/tests/model_test_utils.py
@@ -204,7 +204,7 @@ def test_train_load_use_model_multi_label(
                 print(label)
                 assert label.value is not None
                 assert 0.0 <= label.score <= 1.0
-                assert type(label.score) is float
+                assert isinstance(label.score, float)
 
         del trainer, model, multi_class_corpus
         loaded_model = self.model_cls.load(results_base_path / "final-model.pt")
diff --git a/tests/test_multitask.py b/tests/test_multitask.py
index 02dc42c1b..7a43e45a8 100644
--- a/tests/test_multitask.py
+++ b/tests/test_multitask.py
@@ -63,5 +63,5 @@ def test_train_load_use_classifier(results_base_path, tasks_base_path):
     for label in sentence.labels:
         assert label.value is not None
         assert 0.0 <= label.score <= 1.0
-        assert type(label.score) is float
+        assert isinstance(label.score, float)
     del loaded_model