diff --git a/lrtc_lib/active_learning/core/strategy/discriminative_representation_sampling.py b/lrtc_lib/active_learning/core/strategy/discriminative_representation_sampling.py index f7bd8b5..077101a 100644 --- a/lrtc_lib/active_learning/core/strategy/discriminative_representation_sampling.py +++ b/lrtc_lib/active_learning/core/strategy/discriminative_representation_sampling.py @@ -21,7 +21,7 @@ class DiscriminativeRepresentationSampling(ActiveLearner): - def __init__(self, max_to_consider=10 ** 6): + def __init__(self, max_to_consider=5 * (10 ** 4)): self.max_to_consider = max_to_consider self.sub_batches = 5 diff --git a/lrtc_lib/data/load_dataset.py b/lrtc_lib/data/load_dataset.py index c79307b..47399e0 100644 --- a/lrtc_lib/data/load_dataset.py +++ b/lrtc_lib/data/load_dataset.py @@ -2,12 +2,13 @@ # LICENSE: Apache License 2.0 (Apache-2.0) # http://www.apache.org/licenses/LICENSE-2.0 - +import json import logging from lrtc_lib.data_access import single_dataset_loader from lrtc_lib.data_access.processors.dataset_part import DatasetPart from lrtc_lib.oracle_data_access import gold_labels_loader +import lrtc_lib.oracle_data_access.core.utils as oracle_utils logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s') @@ -26,6 +27,19 @@ def load(dataset: str, force_new: bool = False): logging.info('-' * 60) +def clear_labels(dataset_name): + f = oracle_utils.get_labels_dump_filename(dataset_name+"_train") + with open(f, "r") as fp: + json_dump = json.load(fp) + for uri in json_dump.keys(): + json_dump[uri] = {} + with open(f, "w") as fp: + json.dump(json_dump, fp) + + if __name__ == '__main__': - dataset_name = 'polarity' - load(dataset=dataset_name) \ No newline at end of file + dataset_name = 'hemnet_descriptions' + #load(dataset=dataset_name, force_new=True) + clear_labels(dataset_name) + + diff --git a/lrtc_lib/data_access/data_access_in_memory.py b/lrtc_lib/data_access/data_access_in_memory.py index a7b00ec..53223e5 100644 --- a/lrtc_lib/data_access/data_access_in_memory.py +++ b/lrtc_lib/data_access/data_access_in_memory.py @@ -74,6 +74,9 @@ def set_labels(self, workspace_id: str, texts_and_labels: Sequence[Tuple[str, Ma :param propagate_to_duplicates: if True, also set the same labels for additional URIs that are duplicates of the URIs provided. """ + if not texts_and_labels or not texts_and_labels[0]: + return {} + dataset_name = utils.get_dataset_name(texts_and_labels[0][0]) logic.labels_in_memory[workspace_id][dataset_name] = logic.get_labels(workspace_id, dataset_name) if propagate_to_duplicates: diff --git a/lrtc_lib/data_access/processors/data_processor_factory.py b/lrtc_lib/data_access/processors/data_processor_factory.py index 9d33978..9693bc2 100644 --- a/lrtc_lib/data_access/processors/data_processor_factory.py +++ b/lrtc_lib/data_access/processors/data_processor_factory.py @@ -2,7 +2,7 @@ # LICENSE: Apache License 2.0 (Apache-2.0) # http://www.apache.org/licenses/LICENSE-2.0 - +from data_access.processors.process_hemnet_descriptions import HemnetDescriptionsProcessor from lrtc_lib.data_access.processors.process_ag_new_data import AgNewsProcessor from lrtc_lib.data_access.processors.process_cola_data import ColaProcessor from lrtc_lib.data_access.processors.process_isear_data import IsearProcessor @@ -22,6 +22,8 @@ def get_data_processor(dataset_name: str) -> DataProcessorAPI: dataset_source, dataset_part = parse_dataset_name(dataset_name=dataset_name) if dataset_source == 'trec_50': return TrecProcessor(dataset_part=dataset_part, use_fine_grained_labels=True) + if dataset_source == 'hemnet_descriptions': + return HemnetDescriptionsProcessor(dataset_part=dataset_part) if dataset_source == 'trec': return TrecProcessor(dataset_part=dataset_part, use_fine_grained_labels=False) if dataset_source == 'isear': diff --git a/lrtc_lib/data_access/processors/process_csv_data.py b/lrtc_lib/data_access/processors/process_csv_data.py index 1cb6ee8..aaeeded 100644 --- a/lrtc_lib/data_access/processors/process_csv_data.py +++ b/lrtc_lib/data_access/processors/process_csv_data.py @@ -33,7 +33,8 @@ class CsvProcessor(DataProcessorAPI): """ - def __init__(self, dataset_name: str, dataset_part: DatasetPart, text_col: str = 'text', + def __init__(self, dataset_name: str, dataset_part: DatasetPart, + text_col: str = 'text', label_col: str = 'label', context_col: str = None, doc_id_col: str = None, encoding: str = 'utf-8'): diff --git a/lrtc_lib/data_access/processors/process_hemnet_descriptions.py b/lrtc_lib/data_access/processors/process_hemnet_descriptions.py new file mode 100644 index 0000000..2031694 --- /dev/null +++ b/lrtc_lib/data_access/processors/process_hemnet_descriptions.py @@ -0,0 +1,20 @@ +# (c) Copyright IBM Corporation 2020. + +# LICENSE: Apache License 2.0 (Apache-2.0) +# http://www.apache.org/licenses/LICENSE-2.0 + +import os +import pandas as pd + +from lrtc_lib.data_access.processors.dataset_part import DatasetPart +from lrtc_lib.data_access.processors.process_csv_data import CsvProcessor + + +class HemnetDescriptionsProcessor(CsvProcessor): + + def __init__(self, dataset_part: DatasetPart): + super().__init__(dataset_name='hemnet_descriptions', + text_col='sentence', + doc_id_col='listing_id_sentence_idx', + dataset_part=dataset_part) + diff --git a/lrtc_lib/data_access/single_dataset_loader.py b/lrtc_lib/data_access/single_dataset_loader.py index 4184eda..872741f 100644 --- a/lrtc_lib/data_access/single_dataset_loader.py +++ b/lrtc_lib/data_access/single_dataset_loader.py @@ -58,9 +58,7 @@ def clear_all_saved_files(dataset_name): if __name__ == '__main__': - all_dataset_sources = ['ag_news', 'ag_news_imbalanced_1', 'cola', 'isear', - 'polarity', 'polarity_imbalanced_positive', - 'subjectivity', 'subjectivity_imbalanced_subjective', 'trec', 'wiki_attack'] + all_dataset_sources = ['hemnet_descriptions', 'trec'] for dataset_source in all_dataset_sources: for part in DatasetPart: diff --git a/lrtc_lib/experiment_runners/experiment_runner.py b/lrtc_lib/experiment_runners/experiment_runner.py index 60108a3..a108fc8 100644 --- a/lrtc_lib/experiment_runners/experiment_runner.py +++ b/lrtc_lib/experiment_runners/experiment_runner.py @@ -46,7 +46,7 @@ class ExperimentParams: def compute_batch_scores(config, elements): data_access = get_data_access() unlabeled = data_access.sample_unlabeled_text_elements(config.workspace_id, config.train_dataset_name, - config.category_name, 10 ** 6)["results"] + config.category_name, 5*(10 ** 4))["results"] unlabeled_emb = np.array(orchestrator_api.infer(config.workspace_id, config.category_name, unlabeled)["embeddings"]) batch_emb = np.array(orchestrator_api.infer(config.workspace_id, config.category_name, elements)["embeddings"]) @@ -122,12 +122,14 @@ def train_first_model(self, config: ExperimentParams): if orchestrator_api.workspace_exists(config.workspace_id): orchestrator_api.delete_workspace(config.workspace_id) + config.dev_dataset_name = None + orchestrator_api.create_workspace(config.workspace_id, config.train_dataset_name, dev_dataset_name=config.dev_dataset_name) orchestrator_api.create_new_category(config.workspace_id, config.category_name, "No description for you") - dev_text_elements_uris = orchestrator_api.get_all_text_elements_uris(config.dev_dataset_name) - dev_text_elements_and_labels = oracle_data_access_api.get_gold_labels(config.dev_dataset_name, + dev_text_elements_uris = orchestrator_api.get_all_text_elements_uris(config.test_dataset_name) + dev_text_elements_and_labels = oracle_data_access_api.get_gold_labels(config.test_dataset_name, dev_text_elements_uris) if dev_text_elements_and_labels is not None: orchestrator_api.set_labels(config.workspace_id, dev_text_elements_and_labels) @@ -141,7 +143,10 @@ def train_first_model(self, config: ExperimentParams): # train first model logging.info(f'Starting first model training (model: {config.model.name})\tworkspace: {config.workspace_id}') - new_model_id = orchestrator_api.train(config.workspace_id, config.category_name, config.model, train_params=config.train_params) + new_model_id = orchestrator_api.train(config.workspace_id, + config.category_name, + config.model, + train_params=config.train_params) if new_model_id is None: raise Exception(f'a new model was not trained\tworkspace: {config.workspace_id}') @@ -192,8 +197,10 @@ def get_suggested_elements_and_gold_labels(self, config, al): f'for dataset: {config.train_dataset_name} and category: {config.category_name}.\t' f'runtime: {end - start}\tworkspace: {config.workspace_id}') uris_for_labeling = [elem.uri for elem in suggested_text_elements_for_labeling] - uris_and_gold_labels = oracle_data_access_api.get_gold_labels(config.train_dataset_name, uris_for_labeling, - config.category_name) + uris_and_gold_labels = oracle_data_access_api.get_gold_labels(config.train_dataset_name, + uris_for_labeling, + config.category_name, + self.data_access) return suggested_text_elements_for_labeling, uris_and_gold_labels def evaluate(self, config: ExperimentParams, al, iteration, eval_dataset, diff --git a/lrtc_lib/experiment_runners/experiment_runner_imbalanced_practical.py b/lrtc_lib/experiment_runners/experiment_runner_imbalanced_practical.py index 8d06ee2..bfb33f4 100644 --- a/lrtc_lib/experiment_runners/experiment_runner_imbalanced_practical.py +++ b/lrtc_lib/experiment_runners/experiment_runner_imbalanced_practical.py @@ -75,7 +75,9 @@ def set_first_model_positives(self, config, random_seed) -> List[TextElement]: sampled_uris = [t.uri for t in sampled_unlabeled_text_elements] sampled_uris_and_gold_labels = dict( - oracle_data_access_api.get_gold_labels(config.train_dataset_name, sampled_uris)) + oracle_data_access_api.get_gold_labels(config.train_dataset_name, sampled_uris, + category_name=config.category_name, + data_access=self.data_access)) sampled_uris_and_label = \ [(x.uri, {config.category_name: sampled_uris_and_gold_labels[x.uri][config.category_name]}) for x in sampled_unlabeled_text_elements] @@ -143,17 +145,23 @@ def count_query_matches_in_elements(self, train_dataset_name, category_name, ele # define experiments parameters experiment_name = 'query_NB' - active_learning_iterations_num = 1 + active_learning_iterations_num = 3 num_experiment_repeats = 1 # for full list of datasets and categories available run: python -m lrtc_lib.data_access.loaded_datasets_info - datasets_categories_and_queries = {'trec': {'LOC': ['Where|countr.*|cit.*']}} - classification_models = [ModelTypes.NB] + datasets_categories_and_queries = { + #'trec': {'LOC': ['Where|countr.*|cit.*']} + 'hemnet_descriptions': {'needs_renovation': ['renoveringsbehov|drömhem']} + } + classification_models = [ + ModelTypes.HFBERT + #ModelTypes.NB + ] train_params = {ModelTypes.HFBERT: {"metric": "f1"}, ModelTypes.NB: {}} - active_learning_strategies = [ActiveLearningStrategies.RANDOM, ActiveLearningStrategies.HARD_MINING] + active_learning_strategies = [ActiveLearningStrategies.DAL] - experiments_runner = ExperimentRunnerImbalancedPractical(first_model_labeled_from_query_num=100, - first_model_negatives_num=100, - active_learning_suggestions_num=50, + experiments_runner = ExperimentRunnerImbalancedPractical(first_model_labeled_from_query_num=150, + first_model_negatives_num=150, + active_learning_suggestions_num=20, queries_per_dataset=datasets_categories_and_queries) results_file_path, results_file_path_aggregated = res_handler.get_results_files_paths( diff --git a/lrtc_lib/oracle_data_access/core/utils.py b/lrtc_lib/oracle_data_access/core/utils.py index 8e3b32a..c92c77f 100644 --- a/lrtc_lib/oracle_data_access/core/utils.py +++ b/lrtc_lib/oracle_data_access/core/utils.py @@ -4,13 +4,16 @@ # http://www.apache.org/licenses/LICENSE-2.0 import ujson as json +import json as json2 import os import random -from typing import Mapping +from typing import Mapping, List +from data_access.data_access_api import DataAccessApi from lrtc_lib.definitions import ROOT_DIR from lrtc_lib.definitions import PROJECT_PROPERTIES from lrtc_lib.data_access.core.data_structs import nested_default_dict, Label +from orchestrator.orchestrator_api import LABEL_POSITIVE, LABEL_NEGATIVE gold_labels_per_dataset: (str, nested_default_dict()) = None # (dataset, URIs -> categories -> Label) @@ -30,6 +33,15 @@ def get_gold_labels(dataset_name: str, category_name: str = None) -> Mapping[st categories. :return: # URIs -> categories -> Label """ + uri_categories_and_labels_map = _read_gold_labels(dataset_name)[1] + + if category_name is not None: + data_view_func = PROJECT_PROPERTIES["data_view_func"] + uri_categories_and_labels_map = data_view_func(category_name, uri_categories_and_labels_map) + return uri_categories_and_labels_map + + +def _read_gold_labels(dataset_name): global gold_labels_per_dataset if gold_labels_per_dataset is None or gold_labels_per_dataset[0] != dataset_name: # not in memory @@ -43,11 +55,56 @@ def get_gold_labels(dataset_name: str, category_name: str = None) -> Mapping[st else: # or create an empty in-memory gold_labels_per_dataset = (dataset_name, nested_default_dict()) - uri_categories_and_labels_map = gold_labels_per_dataset[1] - if category_name is not None: - data_view_func = PROJECT_PROPERTIES["data_view_func"] - uri_categories_and_labels_map = data_view_func(category_name, uri_categories_and_labels_map) - return uri_categories_and_labels_map + return gold_labels_per_dataset + + +def create_gold_labels_online(dataset_name: str, + category_name: str, + text_element_uris: List[str], + data_access: DataAccessApi) -> None: + """ + { + "trec_dev-0-0": { + "ABBR": { + "labels": [ + "false" + ], + "metadata": {} + }, + "DESC": { + "labels": [ + "true" + ], + "metadata": {} + }, + + """ + assert category_name + + doc_uris = [uri[:-2] for uri in text_element_uris] + docs = data_access.get_documents(dataset_name, doc_uris) + + i = 0 + + for doc in docs: + for text_element in doc.text_elements: + if text_element.uri in text_element_uris: + i += 1 + print("-" * 30, f"{i}/{len(text_element_uris)}", "-" * 30) + print(text_element.text) + label_input = input() + label = LABEL_POSITIVE if label_input else LABEL_NEGATIVE + with open(get_labels_dump_filename(dataset_name), "r") as json_file: + text_and_gold_labels_encoded = json.load(json_file) + text_and_gold_labels_encoded[text_element.uri] = text_and_gold_labels_encoded.get( + text_element.uri, {}) + text_and_gold_labels_encoded[text_element.uri][category_name] = { + "labels": [label], + "metadata": {} + } + print("-->", label) + with open(get_labels_dump_filename(dataset_name), "w") as json_file: + json2.dump(text_and_gold_labels_encoded, json_file) def sample(dataset_name: str, category_name: str, sample_size: int, random_seed: int, restrict_label: str = None): diff --git a/lrtc_lib/oracle_data_access/oracle_data_access_api.py b/lrtc_lib/oracle_data_access/oracle_data_access_api.py index 6ec766a..3deddd1 100644 --- a/lrtc_lib/oracle_data_access/oracle_data_access_api.py +++ b/lrtc_lib/oracle_data_access/oracle_data_access_api.py @@ -9,8 +9,10 @@ from typing import Sequence, List, Mapping, Tuple, Set import lrtc_lib.oracle_data_access.core.utils as oracle_utils +from data_access.data_access_api import DataAccessApi from lrtc_lib.data_access.core.data_structs import Label from lrtc_lib.orchestrator.orchestrator_api import LABEL_POSITIVE, LABEL_NEGATIVE +from oracle_data_access.core.utils import create_gold_labels_online def add_gold_labels(dataset_name: str, text_and_gold_labels: List[Tuple[str, Mapping[str, Label]]]): @@ -34,7 +36,8 @@ def add_gold_labels(dataset_name: str, text_and_gold_labels: List[Tuple[str, Map f.write(gold_labels_encoded) -def get_gold_labels(dataset_name: str, text_element_uris: Sequence[str], category_name: str = None) -> \ +def get_gold_labels(dataset_name: str, text_element_uris: Sequence[str], category_name: str = None, + data_access: DataAccessApi = None) -> \ List[Tuple[str, Mapping[str, Label]]]: """ Return the gold labels information for the given TextElements uris, keeping the same order, for the given dataset. @@ -48,8 +51,19 @@ def get_gold_labels(dataset_name: str, text_element_uris: Sequence[str], categor same order as the order of the TextElement uris given as input. """ - gold_labels = oracle_utils.get_gold_labels(dataset_name, category_name) - return [(uri, gold_labels[uri]) for uri in text_element_uris if gold_labels[uri]] + gold_labels = oracle_utils.get_gold_labels(dataset_name, None) + uris_missing = [uri for uri in text_element_uris + if uri not in gold_labels or not gold_labels[uri]] + + if uris_missing: + create_gold_labels_online(dataset_name, category_name, uris_missing, data_access) + + gold_labels = oracle_utils.get_gold_labels(dataset_name, None) + gold_labels_dataset = [(uri, gold_labels[uri]) + for uri in text_element_uris + if uri in gold_labels and gold_labels[uri]] + + return gold_labels_dataset def sample(dataset_name: str, category_name: str, sample_size: int, random_seed: int): diff --git a/lrtc_lib/orchestrator/orchestrator_api.py b/lrtc_lib/orchestrator/orchestrator_api.py index d38a230..4318fb3 100644 --- a/lrtc_lib/orchestrator/orchestrator_api.py +++ b/lrtc_lib/orchestrator/orchestrator_api.py @@ -6,6 +6,7 @@ import glob import logging import os +import random import traceback from collections import Counter from enum import Enum @@ -338,7 +339,7 @@ def train(workspace_id: str, category_name: str, model_type: ModelType, train_pa dataset_name = workspace.dataset_name (train_data, train_counts), (dev_data, dev_counts) = train_and_dev_sets_selector.get_train_and_dev_sets( workspace_id=workspace_id, train_dataset_name=dataset_name, category_name=category_name, - dev_dataset_name=workspace.dev_dataset_name) + dev_dataset_name=workspace.test_dataset_name or 'hemnet_descriptions_test') logging.info(f"training a new model with {train_counts}") # label_counts != train_counts as train_counts may refer to negative and weak negative labels separately @@ -443,6 +444,11 @@ def infer(workspace_id: str, category_name: str, texts_to_infer: Sequence[TextEl train_and_infer = PROJECT_PROPERTIES["train_and_infer_factory"].get_train_and_infer(model.model_type) list_of_dicts = [{"text": element.text} for element in texts_to_infer] + + #n_sample = 10000 + #if len(list_of_dicts) > n_sample: + # list_of_dicts = random.sample(list_of_dicts, n_sample) + infer_results = train_and_infer.infer(model_id=model.model_id, items_to_infer=list_of_dicts, infer_params=infer_params, use_cache=use_cache) diff --git a/lrtc_lib/requirements.txt b/lrtc_lib/requirements.txt index b7fa4ef..7cd583b 100644 --- a/lrtc_lib/requirements.txt +++ b/lrtc_lib/requirements.txt @@ -4,8 +4,8 @@ numpy==1.16.4 dataclasses==0.6 scikit-learn==0.21.3 jsonpickle==1.3 -tensorflow>=2.1.0,<2.2.0 -transformers==2.5.1 +tensorflow>=2.2.0,<2.3.0 +transformers==3.5.1 mip>=1.7.3 filelock==3.0.12 wrapt_timeout_decorator==1.3.1 @@ -13,3 +13,4 @@ nose==1.3.7 ujson==3.1.0 seaborn==0.11.0 h5py==2.10.0 +torch==1.7.1 \ No newline at end of file diff --git a/lrtc_lib/train_and_infer_service/train_and_infer_hf.py b/lrtc_lib/train_and_infer_service/train_and_infer_hf.py index 2afad14..a7b03dd 100644 --- a/lrtc_lib/train_and_infer_service/train_and_infer_hf.py +++ b/lrtc_lib/train_and_infer_service/train_and_infer_hf.py @@ -2,7 +2,7 @@ # LICENSE: Apache License 2.0 (Apache-2.0) # http://www.apache.org/licenses/LICENSE-2.0 - +import datetime import logging import os import pickle @@ -17,8 +17,8 @@ from tensorflow.python.distribute import parameter_server_strategy from tensorflow.python.keras.callbacks import EarlyStopping, ModelCheckpoint from tensorflow.python.keras.engine import data_adapter -from transformers import BertTokenizer, TFBertForSequenceClassification, InputFeatures -from tensorflow.python.keras.mixed_precision.experimental import (loss_scale_optimizer as lso) +from transformers import AutoTokenizer, TFBertForSequenceClassification, InputFeatures +from tensorflow.keras.mixed_precision import LossScaleOptimizer from tensorflow.keras import backend as K from lrtc_lib.definitions import ROOT_DIR @@ -26,10 +26,10 @@ MODEL_DIR = os.path.join(ROOT_DIR, "output", "models", "transformers") HF_CACHE_DIR = os.path.join(ROOT_DIR, "output", "temp", "hf_cache") - +HF_MODEL_ID = 'KBLab/sentence-bert-swedish-cased' class TrainAndInferHF(TrainAndInferAPI): - def __init__(self, batch_size, infer_batch_size=10, learning_rate=5e-5, debug=False, model_dir=MODEL_DIR, + def __init__(self, batch_size, infer_batch_size=128, learning_rate=5e-5, debug=False, model_dir=MODEL_DIR, infer_with_cls=False): """ :param batch_size: @@ -50,7 +50,7 @@ def __init__(self, batch_size, infer_batch_size=10, learning_rate=5e-5, debug=Fa self.tokenizer = self.get_tokenizer() # Prepare training: Compile tf.keras model with optimizer, loss and learning rate schedule self.learning_rate = learning_rate - self.max_length = 100 + self.max_length = 64 self.batch_size = batch_size if infer_batch_size == -1: self.infer_batch_size = batch_size @@ -68,7 +68,7 @@ def __setstate__(self, d): self.__dict__["tokenizer"] = self.get_tokenizer() def get_tokenizer(self): - return BertTokenizer.from_pretrained('bert-base-uncased', cache_dir=HF_CACHE_DIR) + return AutoTokenizer.from_pretrained(HF_MODEL_ID, cache_dir=HF_CACHE_DIR) def process_inputs(self, texts, labels=None, to_dataset=True): """ @@ -106,7 +106,7 @@ def train(self, train_data, dev_data, test_data, train_params: dict) -> str: fl.write("") # init - model = TFBertForSequenceClassification.from_pretrained('bert-base-uncased', cache_dir=HF_CACHE_DIR) + model = TFBertForSequenceClassification.from_pretrained(HF_MODEL_ID, from_pt=True, cache_dir=HF_CACHE_DIR) model.config.output_hidden_states = True optimizer = tf.keras.optimizers.Adam(learning_rate=self.learning_rate, epsilon=1e-06) loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True) @@ -140,8 +140,11 @@ def train(self, train_data, dev_data, test_data, train_params: dict) -> str: os.makedirs(model_dir) model_checkpoint = ModelCheckpoint(model_dir, save_best_only=True, save_weights_only=True) - early_stopping = EarlyStopping(monitor="val_" + metric_name, patience=epochs) - history = model.fit(x=input, validation_data=dev_input, epochs=epochs, + early_stopping = EarlyStopping(monitor="val_" + metric_name, + patience=epochs) + history = model.fit(x=input, + validation_data=dev_input, + epochs=epochs, callbacks=[early_stopping, model_checkpoint]) with open(params_file, "wb") as fl: pickle.dump(self, fl) @@ -158,11 +161,14 @@ def train(self, train_data, dev_data, test_data, train_params: dict) -> str: @infer_with_cache def infer(self, model_id, items_to_infer, infer_params: dict, use_cache=True): logging.info("Inferring with hf model...") + ts_start = datetime.datetime.now() + items_to_infer = [x["text"] for x in items_to_infer] model = TFBertForSequenceClassification.from_pretrained(self.get_model_dir_by_id(model_id)) if self.debug: items_to_infer = items_to_infer[:self.infer_batch_size] + #items_to_infer = items_to_infer[:10000] input = self.process_inputs(items_to_infer).batch(self.infer_batch_size) if self.infer_with_cls: # get embeddings for CLS token in last hidden layer @@ -181,7 +187,8 @@ def infer(self, model_id, items_to_infer, infer_params: dict, use_cache=True): labels = [int(np.argmax(logit)) for logit in logits] predictions = softmax(logits, axis=1) scores = [float(prediction[label]) for label, prediction in zip(labels, predictions)] - logging.info("Infer hf model done") + duration_seconds = (datetime.datetime.now() - ts_start).total_seconds() + logging.info(f"Infer hf model done, took {duration_seconds} seconds") return {"labels": labels, "scores": scores, "logits": logits.numpy().tolist(), "embeddings": out_emb.numpy().tolist()} @@ -272,7 +279,7 @@ def _process_input_data(x, y, sample_weight, model): def _clip_scale_grads(strategy, tape, optimizer, loss, params): with tape: - if isinstance(optimizer, lso.LossScaleOptimizer): + if isinstance(optimizer, LossScaleOptimizer): loss = optimizer.get_scaled_loss(loss) gradients = tape.gradient(loss, params)