

add BINDER model #1956

Triggered via pull request August 22, 2023 11:07
Status Failure
Total duration 24m 58s

ci.yml

on: pull_request

Annotations

6 errors
test: flair/__init__.py#L1
mypy-status mypy exited with status 1.
test: flair/models/binder_model.py#L1
Black format check

--- /home/runner/work/flair/flair/flair/models/binder_model.py 2023-08-22 11:07:45.415345+00:00
+++ /home/runner/work/flair/flair/flair/models/binder_model.py 2023-08-22 11:14:02.362458+00:00
@@ -29,11 +29,11 @@
         log_probs = log_probs[batch_indices, start_positions, end_positions]
     else:
         log_probs = masked_log_softmax(scores, mask)
         batch_indices = list(range(batch_size))
         log_probs = log_probs[batch_indices, positions]
-    return - log_probs.mean()
+    return -log_probs.mean()
 
 
 def masked_log_softmax(vector: torch.Tensor, mask: torch.BoolTensor, dim: int = -1) -> torch.Tensor:
     if mask is not None:
         while mask.dim() < vector.dim():
@@ -52,10 +52,11 @@
         return 1e-13
     elif dtype == torch.half:
         return 1e-4
     else:
         raise TypeError("Does not support dtype " + str(dtype))
+
 
 class BinderModel(flair.nn.Classifier[Sentence]):
     """This model implements the BINDER architecture for token classification using contrastive learning and a bi-encoder.
     Paper: https://openreview.net/forum?id=9EAQVEINuum
     """
@@ -98,17 +99,18 @@
         self.token_start_linear = torch.nn.Linear(token_embedding_size, linear_size)
         self.token_end_linear = torch.nn.Linear(token_embedding_size, linear_size)
         self.max_span_width = max_span_width
 
         if use_span_width_embeddings:
-            self.token_span_linear = torch.nn.Linear(
-                self.token_encoder.embedding_length + linear_size, linear_size
+            self.token_span_linear = torch.nn.Linear(self.token_encoder.embedding_length + linear_size, linear_size)
+            assert (
+                self.token_encoder.model.config.max_position_embeddings
+                == self.label_encoder.model.config.max_position_embeddings
+            ), (
+                "The maximum position embeddings for the token encoder and label encoder must be the same when using "
+                "span width embeddings."
             )
-            assert (self.token_encoder.model.config.max_position_embeddings ==
-                    self.label_encoder.model.config.max_position_embeddings), \
-                "The maximum position embeddings for the token encoder and label encoder must be the same when using " \
-                "span width embeddings."
             span_width = self.token_encoder.model.config.max_position_embeddings
             self.width_embeddings = torch.nn.Embedding(span_width, linear_size, padding_idx=0)
         else:
             self.token_span_linear = torch.nn.Linear(self.token_encoder.embedding_length, linear_size)
             self.width_embeddings = None
@@ -142,14 +144,11 @@
         assert start_scores.shape == end_scores.shape, "Start and end scores must have the same shape."
         assert span_scores.shape[2] == span_scores.shape[3], "Span scores must be square."
 
         # Get spans
-        start_mask, end_mask, span_mask = self._get_masks(
-            start_scores.size(),
-            lengths
-        )
+        start_mask, end_mask, span_mask = self._get_masks(start_scores.size(), lengths)
 
         # Calculate loss
         total_loss, num_spans = self._calculate_loss(
             data_points, start_scores, end_scores, span_scores, start_mask, end_mask, span_mask
         )
 
@@ -197,11 +196,13 @@
         batch_size, org_seq_length, _ = token_hidden_states.size()
         seq_length = org_seq_length * 2 + 1
         num_types, _ = label_hidden_states.size()
         hidden_size = self.token_encoder.embedding_length // 2
 
-        token_hidden_states = token_hidden_states.view(batch_size, org_seq_length, 2, hidden_size).view(batch_size, 2 * org_seq_length, hidden_size)
+        token_hidden_states = token_hidden_states.view(batch_size, org_seq_length, 2, hidden_size).view(
+            batch_size, 2 * org_seq_length, hidden_size
+        )
 
         cls_hidden_state = torch.stack([sentence.get_embedding() for sentence in sentences])
 
         # Shape: batch_size x seq_length * 2 (start + end hidden state per token) + 1
 
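For reference, the masked_log_softmax helper touched by this diff follows the pattern used in AllenNLP; a minimal sketch under that assumption (not necessarily the exact code in this PR) would be:

import torch
import torch.nn.functional as F


def tiny_value_of_dtype(dtype: torch.dtype) -> float:
    # Smallest value that can be added before a log without producing -inf for that dtype.
    if dtype in (torch.float, torch.double):
        return 1e-13
    elif dtype == torch.half:
        return 1e-4
    else:
        raise TypeError("Does not support dtype " + str(dtype))


def masked_log_softmax(vector: torch.Tensor, mask: torch.BoolTensor, dim: int = -1) -> torch.Tensor:
    if mask is not None:
        while mask.dim() < vector.dim():
            mask = mask.unsqueeze(1)
        # log(mask + tiny) is ~0 for kept positions and a large negative number for masked
        # ones, so masked positions receive negligible probability mass after the softmax.
        vector = vector + (mask + tiny_value_of_dtype(vector.dtype)).log()
    return F.log_softmax(vector, dim=dim)

In the first hunk, the helper's output is indexed at the gold positions and averaged, so the only change Black makes there is the whitespace in return -log_probs.mean().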
test: flair/models/binder_model.py#L341
ruff pytest_ruff.RuffError:
flair/models/binder_model.py:1:1: I001 [*] Import block is un-sorted or un-formatted
   |
 1 | / from typing import List, Tuple
 2 | | import logging
 3 | |
 4 | | import numpy as np
 5 | | import torch
 6 | | import torch.nn.functional as F
 7 | |
 8 | | import flair
 9 | | from flair.data import Sentence, Dictionary, DT, Span, Union, Optional
10 | | from flair.embeddings import TokenEmbeddings, DocumentEmbeddings
11 | | from flair.training_utils import store_embeddings
12 | |
13 | | log = logging.getLogger("flair")
   | |_^ I001
   |
   = help: Organize imports

flair/models/binder_model.py:59:5: D205 1 blank line required between summary line and description
   |
58 |   class BinderModel(flair.nn.Classifier[Sentence]):
59 |       """This model implements the BINDER architecture for token classification using contrastive learning and a bi-encoder.
   |  _____^
60 | |     Paper: https://openreview.net/forum?id=9EAQVEINuum
61 | |     """
   | |_______^ D205
62 |
63 |       def __init__(
   |
   = help: Insert single blank line

flair/models/binder_model.py:59:5: D415 [*] First line should end with a period, question mark, or exclamation point
   |
58 |   class BinderModel(flair.nn.Classifier[Sentence]):
59 |       """This model implements the BINDER architecture for token classification using contrastive learning and a bi-encoder.
   |  _____^
60 | |     Paper: https://openreview.net/forum?id=9EAQVEINuum
61 | |     """
   | |_______^ D415
62 |
63 |       def __init__(
   |
   = help: Add closing punctuation

flair/models/binder_model.py:84:52: E712 [*] Comparison to `True` should be `cond is True` or `if cond:`
   |
82 |             raise RuntimeError("The token encoder must use first_last subtoken pooling when using BINDER model.")
83 |
84 |         if not token_encoder.document_embedding == True:
   |                                                    ^^^^ E712
85 |             raise RuntimeError("The token encoder must include the CLS token when using BINDER model.")
   |
   = help: Replace with `cond is True`

flair/models/binder_model.py:108:17: ISC002 Implicitly concatenated string literals over multiple lines
    |
106 |             assert (self.token_encoder.model.config.max_position_embeddings ==
107 |                     self.label_encoder.model.config.max_position_embeddings), \
108 |                 "The maximum position embeddings for the token encoder and label encoder must be the same when using " \
    |  _________________^
109 | |                 "span width embeddings."
    | |________________________________________^ ISC002
110 |             span_width = self.token_encoder.model.config.max_position_embeddings
111 |             self.width_embeddings = torch.nn.Embedding(span_width, linear_size, padding_idx=0)
    |

flair/models/binder_model.py:417:66: Q000 [*] Single quotes found but double quotes preferred
    |
415 |     def _remove_overlaps(predictions):
416 |         # Sort the predictions based on the start values
417 |         sorted_predictions = sorted(predictions, key=lambda x: x['start'])
    |                                                                  ^^^^^^^ Q000
418 |
419 |         # Initialize a list to store non-overlapping predictions
    |
    = help: Replace single quotes with double quotes

flair/models/binder_model.py:428:31: Q000 [*] Single quotes found but double quotes preferred
    |
426 |             # Check for overlap with the last prediction in the non-overlapping list
427 |             last_prediction = non_overlapping[-1]
428 |             if prediction['start'] > last_prediction['end']:
    |                           ^^^^^^^ Q000
429 |                 non_overlapping.append(prediction)
430 |             else:
    |
    = help: Replace single quotes with double quotes

flair/models/binder_model.py:428:58: Q000 [*] Single quotes found but double quotes preferred
    |
426 |             # Check for overlap with the last prediction in the non-overlapping list
427 |             last_predic
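The findings above are largely mechanical; a hedged sketch of the corrected forms follows (only the flagged lines come from the log, the surrounding context is assumed, and the exact import order depends on the repository's ruff/isort configuration):

# I001: one possible sorted import block (stdlib, third-party, first-party), roughly what
# `ruff --fix` or isort would produce for the imports shown above.
import logging
from typing import List, Tuple

import numpy as np
import torch
import torch.nn.functional as F

import flair
from flair.data import DT, Dictionary, Optional, Sentence, Span, Union
from flair.embeddings import DocumentEmbeddings, TokenEmbeddings
from flair.training_utils import store_embeddings

log = logging.getLogger("flair")

# D205 / D415: separate the summary line with a blank line and end it with punctuation:
#   """This model implements the BINDER architecture for token classification using contrastive learning and a bi-encoder.
#
#   Paper: https://openreview.net/forum?id=9EAQVEINuum
#   """

# E712: drop the explicit comparison to True, e.g.
#   if not token_encoder.document_embedding:
#       raise RuntimeError("The token encoder must include the CLS token when using BINDER model.")

# ISC002: replace the backslash-continued, implicitly concatenated assert message with an
# explicit parenthesised form (the Black diff in the previous annotation already shows it).

# Q000: prefer double quotes for string literals, e.g.
#   sorted_predictions = sorted(predictions, key=lambda x: x["start"])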
test: flair/models/binder_model.py#L1
flair/models/binder_model.py
19: error: Incompatible default for argument "mask" (default has type "None", argument has type "FloatTensor") [assignment]
19: note: PEP 484 prohibits implicit Optional. Accordingly, mypy has changed its default to no_implicit_optional=True
19: note: Use https://github.com/hauntsaninja/no_implicit_optional to automatically upgrade your codebase
23: error: Incompatible types in assignment (expression has type "Tensor", variable has type "FloatTensor") [assignment]
24: error: Incompatible types in assignment (expression has type "Tensor", variable has type "FloatTensor") [assignment]
25: error: Argument 2 to "masked_log_softmax" has incompatible type "FloatTensor"; expected "BoolTensor" [arg-type]
31: error: Argument 2 to "masked_log_softmax" has incompatible type "FloatTensor"; expected "BoolTensor" [arg-type]
40: error: Incompatible types in assignment (expression has type "Tensor", variable has type "BoolTensor") [assignment]
106: error: Item "Tensor" of "Union[Tensor, Module]" has no attribute "config" [union-attr]
106: error: Item "Tensor" of "Union[Any, Tensor, Module]" has no attribute "max_position_embeddings" [union-attr]
107: error: Item "Tensor" of "Union[Tensor, Module]" has no attribute "config" [union-attr]
107: error: Item "Tensor" of "Union[Any, Tensor, Module]" has no attribute "max_position_embeddings" [union-attr]
110: error: Item "Tensor" of "Union[Tensor, Module]" has no attribute "config" [union-attr]
110: error: Item "Tensor" of "Union[Any, Tensor, Module]" has no attribute "max_position_embeddings" [union-attr]
111: error: Argument 1 to "Embedding" has incompatible type "Union[Any, Tensor, Module]"; expected "int" [arg-type]
114: error: Incompatible types in assignment (expression has type "None", variable has type "Embedding") [assignment]
135: error: "DT" has no attribute "get_spans" [attr-defined]
138: error: List item 0 has incompatible type "List[DT]"; expected "DT" [list-item]
230: error: Module has no attribute "LongTensor" [attr-defined]
364: error: Argument 1 to "enumerate" has incompatible type "Union[List[DT], DT]"; expected "Iterable[DT]" [arg-type]
365: error: Argument 1 to "remove_labels" of "DataPoint" has incompatible type "Optional[str]"; expected "str" [arg-type]
374: error: "DT" has no attribute "tokens" [attr-defined]
405: error: Argument 1 to "store_embeddings" has incompatible type "Union[List[DT], DT]"; expected "Union[List[DT], Dataset[Any]]" [arg-type]
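The first group of errors is the usual implicit-Optional problem: the mask argument defaults to None but is annotated as a plain FloatTensor. A minimal sketch of the annotation mypy expects (the function name and the other parameters here are hypothetical; only the mask annotation mirrors the messages above):

from typing import Optional

import torch


def span_loss(  # hypothetical name, standing in for the loss helper around line 19
    scores: torch.Tensor,
    positions: torch.Tensor,
    mask: Optional[torch.BoolTensor] = None,  # was: mask: torch.FloatTensor = None
) -> torch.Tensor:
    ...

Annotating the mask as a BoolTensor (or building it with a boolean dtype) would also address the arg-type errors on lines 25 and 31, since masked_log_softmax is declared to take a BoolTensor.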
test: flair/nn/model.py#L341
ruff pytest_ruff.RuffError:
flair/nn/model.py:663:12: E721 Do not compare types, use `isinstance()`
    |
661 |     @multi_label_threshold.setter
662 |     def multi_label_threshold(self, x):  # setter method
663 |         if type(x) is dict:
    |            ^^^^^^^^^^^^^^^ E721
664 |             if "default" in x:
665 |                 self._multi_label_threshold = x
    |
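A small, self-contained sketch of the isinstance() form ruff suggests; the property scaffolding and class are added here only for illustration, and whatever follows line 665 is elided as in the log:

class Example:  # stand-in class; only the setter body mirrors the annotation above
    _multi_label_threshold: dict

    @property
    def multi_label_threshold(self):
        return self._multi_label_threshold

    @multi_label_threshold.setter
    def multi_label_threshold(self, x):  # setter method
        if isinstance(x, dict):  # was: if type(x) is dict:
            if "default" in x:
                self._multi_label_threshold = x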
test
Process completed with exit code 1.