Skip to content
This repository was archived by the owner on Jul 22, 2024. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

class DiscriminativeRepresentationSampling(ActiveLearner):

def __init__(self, max_to_consider=10 ** 6):
def __init__(self, max_to_consider=5 * (10 ** 4)):
self.max_to_consider = max_to_consider
self.sub_batches = 5

Expand Down
20 changes: 17 additions & 3 deletions lrtc_lib/data/load_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@

# LICENSE: Apache License 2.0 (Apache-2.0)
# http://www.apache.org/licenses/LICENSE-2.0

import json
import logging

from lrtc_lib.data_access import single_dataset_loader
from lrtc_lib.data_access.processors.dataset_part import DatasetPart
from lrtc_lib.oracle_data_access import gold_labels_loader
import lrtc_lib.oracle_data_access.core.utils as oracle_utils

logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s')

Expand All @@ -26,6 +27,19 @@ def load(dataset: str, force_new: bool = False):
logging.info('-' * 60)


def clear_labels(dataset_name):
    """Reset all oracle gold labels for the train part of *dataset_name*.

    Loads the on-disk labels dump for ``<dataset_name>_train``, replaces the
    label mapping of every URI with an empty dict, and writes the dump back.
    The set of URIs is preserved; only their labels are removed.

    :param dataset_name: base name of the dataset (the ``_train`` suffix is
        appended here)
    """
    dump_file = oracle_utils.get_labels_dump_filename(dataset_name + "_train")
    # utf-8 is given explicitly so the dump round-trips regardless of the
    # platform's default locale encoding
    with open(dump_file, "r", encoding="utf-8") as fp:
        json_dump = json.load(fp)
    cleared = {uri: {} for uri in json_dump}
    with open(dump_file, "w", encoding="utf-8") as fp:
        json.dump(cleared, fp)


if __name__ == '__main__':
    # NOTE(review): this span mixes two dataset setups — 'polarity' is loaded
    # first, then dataset_name is rebound to 'hemnet_descriptions' whose load
    # call is commented out, so only its labels get cleared. Confirm this
    # sequence is intended rather than leftover diff context.
    dataset_name = 'polarity'
    load(dataset=dataset_name)
    dataset_name = 'hemnet_descriptions'
    #load(dataset=dataset_name, force_new=True)
    clear_labels(dataset_name)


3 changes: 3 additions & 0 deletions lrtc_lib/data_access/data_access_in_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,9 @@ def set_labels(self, workspace_id: str, texts_and_labels: Sequence[Tuple[str, Ma
:param propagate_to_duplicates: if True, also set the same labels for additional URIs that are duplicates
of the URIs provided.
"""
if not texts_and_labels or not texts_and_labels[0]:
return {}

dataset_name = utils.get_dataset_name(texts_and_labels[0][0])
logic.labels_in_memory[workspace_id][dataset_name] = logic.get_labels(workspace_id, dataset_name)
if propagate_to_duplicates:
Expand Down
4 changes: 3 additions & 1 deletion lrtc_lib/data_access/processors/data_processor_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

# LICENSE: Apache License 2.0 (Apache-2.0)
# http://www.apache.org/licenses/LICENSE-2.0

from data_access.processors.process_hemnet_descriptions import HemnetDescriptionsProcessor
from lrtc_lib.data_access.processors.process_ag_new_data import AgNewsProcessor
from lrtc_lib.data_access.processors.process_cola_data import ColaProcessor
from lrtc_lib.data_access.processors.process_isear_data import IsearProcessor
Expand All @@ -22,6 +22,8 @@ def get_data_processor(dataset_name: str) -> DataProcessorAPI:
dataset_source, dataset_part = parse_dataset_name(dataset_name=dataset_name)
if dataset_source == 'trec_50':
return TrecProcessor(dataset_part=dataset_part, use_fine_grained_labels=True)
if dataset_source == 'hemnet_descriptions':
return HemnetDescriptionsProcessor(dataset_part=dataset_part)
if dataset_source == 'trec':
return TrecProcessor(dataset_part=dataset_part, use_fine_grained_labels=False)
if dataset_source == 'isear':
Expand Down
3 changes: 2 additions & 1 deletion lrtc_lib/data_access/processors/process_csv_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ class CsvProcessor(DataProcessorAPI):

"""

def __init__(self, dataset_name: str, dataset_part: DatasetPart, text_col: str = 'text',
def __init__(self, dataset_name: str, dataset_part: DatasetPart,
text_col: str = 'text',
label_col: str = 'label', context_col: str = None,
doc_id_col: str = None,
encoding: str = 'utf-8'):
Expand Down
20 changes: 20 additions & 0 deletions lrtc_lib/data_access/processors/process_hemnet_descriptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# (c) Copyright IBM Corporation 2020.

# LICENSE: Apache License 2.0 (Apache-2.0)
# http://www.apache.org/licenses/LICENSE-2.0

import os
import pandas as pd

from lrtc_lib.data_access.processors.dataset_part import DatasetPart
from lrtc_lib.data_access.processors.process_csv_data import CsvProcessor


class HemnetDescriptionsProcessor(CsvProcessor):
    """CsvProcessor preconfigured for the 'hemnet_descriptions' dataset.

    Fixes the column mapping for this dataset: text is read from the
    'sentence' column and document ids from 'listing_id_sentence_idx'
    (presumably listing id plus sentence index — confirm against the CSV).
    The label column keeps CsvProcessor's default.
    """

    def __init__(self, dataset_part: DatasetPart):
        # Only the dataset part varies between instances; every column
        # mapping is fixed for this dataset.
        super().__init__(dataset_part=dataset_part,
                         dataset_name='hemnet_descriptions',
                         text_col='sentence',
                         doc_id_col='listing_id_sentence_idx')

4 changes: 1 addition & 3 deletions lrtc_lib/data_access/single_dataset_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,7 @@ def clear_all_saved_files(dataset_name):


if __name__ == '__main__':
all_dataset_sources = ['ag_news', 'ag_news_imbalanced_1', 'cola', 'isear',
'polarity', 'polarity_imbalanced_positive',
'subjectivity', 'subjectivity_imbalanced_subjective', 'trec', 'wiki_attack']
all_dataset_sources = ['hemnet_descriptions', 'trec']

for dataset_source in all_dataset_sources:
for part in DatasetPart:
Expand Down
19 changes: 13 additions & 6 deletions lrtc_lib/experiment_runners/experiment_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class ExperimentParams:
def compute_batch_scores(config, elements):
data_access = get_data_access()
unlabeled = data_access.sample_unlabeled_text_elements(config.workspace_id, config.train_dataset_name,
config.category_name, 10 ** 6)["results"]
config.category_name, 5*(10 ** 4))["results"]
unlabeled_emb = np.array(orchestrator_api.infer(config.workspace_id, config.category_name, unlabeled)["embeddings"])
batch_emb = np.array(orchestrator_api.infer(config.workspace_id, config.category_name, elements)["embeddings"])

Expand Down Expand Up @@ -122,12 +122,14 @@ def train_first_model(self, config: ExperimentParams):
if orchestrator_api.workspace_exists(config.workspace_id):
orchestrator_api.delete_workspace(config.workspace_id)

config.dev_dataset_name = None

orchestrator_api.create_workspace(config.workspace_id, config.train_dataset_name,
dev_dataset_name=config.dev_dataset_name)
orchestrator_api.create_new_category(config.workspace_id, config.category_name, "No description for you")

dev_text_elements_uris = orchestrator_api.get_all_text_elements_uris(config.dev_dataset_name)
dev_text_elements_and_labels = oracle_data_access_api.get_gold_labels(config.dev_dataset_name,
dev_text_elements_uris = orchestrator_api.get_all_text_elements_uris(config.test_dataset_name)
dev_text_elements_and_labels = oracle_data_access_api.get_gold_labels(config.test_dataset_name,
dev_text_elements_uris)
if dev_text_elements_and_labels is not None:
orchestrator_api.set_labels(config.workspace_id, dev_text_elements_and_labels)
Expand All @@ -141,7 +143,10 @@ def train_first_model(self, config: ExperimentParams):

# train first model
logging.info(f'Starting first model training (model: {config.model.name})\tworkspace: {config.workspace_id}')
new_model_id = orchestrator_api.train(config.workspace_id, config.category_name, config.model, train_params=config.train_params)
new_model_id = orchestrator_api.train(config.workspace_id,
config.category_name,
config.model,
train_params=config.train_params)
if new_model_id is None:
raise Exception(f'a new model was not trained\tworkspace: {config.workspace_id}')

Expand Down Expand Up @@ -192,8 +197,10 @@ def get_suggested_elements_and_gold_labels(self, config, al):
f'for dataset: {config.train_dataset_name} and category: {config.category_name}.\t'
f'runtime: {end - start}\tworkspace: {config.workspace_id}')
uris_for_labeling = [elem.uri for elem in suggested_text_elements_for_labeling]
uris_and_gold_labels = oracle_data_access_api.get_gold_labels(config.train_dataset_name, uris_for_labeling,
config.category_name)
uris_and_gold_labels = oracle_data_access_api.get_gold_labels(config.train_dataset_name,
uris_for_labeling,
config.category_name,
self.data_access)
return suggested_text_elements_for_labeling, uris_and_gold_labels

def evaluate(self, config: ExperimentParams, al, iteration, eval_dataset,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,9 @@ def set_first_model_positives(self, config, random_seed) -> List[TextElement]:

sampled_uris = [t.uri for t in sampled_unlabeled_text_elements]
sampled_uris_and_gold_labels = dict(
oracle_data_access_api.get_gold_labels(config.train_dataset_name, sampled_uris))
oracle_data_access_api.get_gold_labels(config.train_dataset_name, sampled_uris,
category_name=config.category_name,
data_access=self.data_access))
sampled_uris_and_label = \
[(x.uri, {config.category_name: sampled_uris_and_gold_labels[x.uri][config.category_name]})
for x in sampled_unlabeled_text_elements]
Expand Down Expand Up @@ -143,17 +145,23 @@ def count_query_matches_in_elements(self, train_dataset_name, category_name, ele

# define experiments parameters
experiment_name = 'query_NB'
active_learning_iterations_num = 1
active_learning_iterations_num = 3
num_experiment_repeats = 1
# for full list of datasets and categories available run: python -m lrtc_lib.data_access.loaded_datasets_info
datasets_categories_and_queries = {'trec': {'LOC': ['Where|countr.*|cit.*']}}
classification_models = [ModelTypes.NB]
datasets_categories_and_queries = {
#'trec': {'LOC': ['Where|countr.*|cit.*']}
'hemnet_descriptions': {'needs_renovation': ['renoveringsbehov|drömhem']}
}
classification_models = [
ModelTypes.HFBERT
#ModelTypes.NB
]
train_params = {ModelTypes.HFBERT: {"metric": "f1"}, ModelTypes.NB: {}}
active_learning_strategies = [ActiveLearningStrategies.RANDOM, ActiveLearningStrategies.HARD_MINING]
active_learning_strategies = [ActiveLearningStrategies.DAL]

experiments_runner = ExperimentRunnerImbalancedPractical(first_model_labeled_from_query_num=100,
first_model_negatives_num=100,
active_learning_suggestions_num=50,
experiments_runner = ExperimentRunnerImbalancedPractical(first_model_labeled_from_query_num=150,
first_model_negatives_num=150,
active_learning_suggestions_num=20,
queries_per_dataset=datasets_categories_and_queries)

results_file_path, results_file_path_aggregated = res_handler.get_results_files_paths(
Expand Down
69 changes: 63 additions & 6 deletions lrtc_lib/oracle_data_access/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,16 @@
# http://www.apache.org/licenses/LICENSE-2.0

import ujson as json
import json as json2
import os
import random
from typing import Mapping
from typing import Mapping, List

from data_access.data_access_api import DataAccessApi
from lrtc_lib.definitions import ROOT_DIR
from lrtc_lib.definitions import PROJECT_PROPERTIES
from lrtc_lib.data_access.core.data_structs import nested_default_dict, Label
from orchestrator.orchestrator_api import LABEL_POSITIVE, LABEL_NEGATIVE

gold_labels_per_dataset: (str, nested_default_dict()) = None # (dataset, URIs -> categories -> Label)

Expand All @@ -30,6 +33,15 @@ def get_gold_labels(dataset_name: str, category_name: str = None) -> Mapping[st
categories.
:return: # URIs -> categories -> Label
"""
uri_categories_and_labels_map = _read_gold_labels(dataset_name)[1]

if category_name is not None:
data_view_func = PROJECT_PROPERTIES["data_view_func"]
uri_categories_and_labels_map = data_view_func(category_name, uri_categories_and_labels_map)
return uri_categories_and_labels_map


def _read_gold_labels(dataset_name):
global gold_labels_per_dataset

if gold_labels_per_dataset is None or gold_labels_per_dataset[0] != dataset_name: # not in memory
Expand All @@ -43,11 +55,56 @@ def get_gold_labels(dataset_name: str, category_name: str = None) -> Mapping[st
else: # or create an empty in-memory
gold_labels_per_dataset = (dataset_name, nested_default_dict())

uri_categories_and_labels_map = gold_labels_per_dataset[1]
if category_name is not None:
data_view_func = PROJECT_PROPERTIES["data_view_func"]
uri_categories_and_labels_map = data_view_func(category_name, uri_categories_and_labels_map)
return uri_categories_and_labels_map
return gold_labels_per_dataset


def create_gold_labels_online(dataset_name: str,
                              category_name: str,
                              text_element_uris: List[str],
                              data_access: DataAccessApi) -> None:
    """Interactively collect gold labels for *text_element_uris* from stdin.

    For every requested text element, its text is printed and the user is
    prompted: any non-empty input labels it LABEL_POSITIVE, an empty input
    labels it LABEL_NEGATIVE. After each answer the full labels dump for
    *dataset_name* is rewritten on disk, so progress survives interruption.

    The dump maps URI -> category -> {"labels": [...], "metadata": {}}, e.g.::

        {
            "trec_dev-0-0": {
                "ABBR": {"labels": ["false"], "metadata": {}},
                "DESC": {"labels": ["true"], "metadata": {}}
            }
        }

    :param dataset_name: dataset whose labels dump is updated
    :param category_name: category the collected labels belong to (required)
    :param text_element_uris: URIs of the text elements to label
    :param data_access: data access used to fetch the documents' text
    """
    assert category_name

    # Element URI format (see docstring example) is "<doc_uri>-<element_idx>".
    # rsplit on the last "-" — unlike uri[:-2] — stays correct for element
    # indices >= 10; dict.fromkeys dedups doc URIs while preserving order.
    doc_uris = list(dict.fromkeys(uri.rsplit("-", 1)[0]
                                  for uri in text_element_uris))
    docs = data_access.get_documents(dataset_name, doc_uris)

    dump_filename = get_labels_dump_filename(dataset_name)
    # Read the dump once up front; it is rewritten after every answer below,
    # so re-reading per element (as before) is redundant — we are the only
    # writer during this session.
    with open(dump_filename, "r") as json_file:
        text_and_gold_labels_encoded = json.load(json_file)

    requested_uris = set(text_element_uris)
    labeled_count = 0
    for doc in docs:
        for text_element in doc.text_elements:
            if text_element.uri not in requested_uris:
                continue
            labeled_count += 1
            print("-" * 30, f"{labeled_count}/{len(text_element_uris)}", "-" * 30)
            print(text_element.text)
            label_input = input()
            # Any non-empty answer counts as a positive label.
            label = LABEL_POSITIVE if label_input else LABEL_NEGATIVE
            uri_entry = text_and_gold_labels_encoded.setdefault(text_element.uri, {})
            uri_entry[category_name] = {
                "labels": [label],
                "metadata": {}
            }
            print("-->", label)
            # Persist immediately so an interrupted session loses at most the
            # answer currently being typed.
            with open(dump_filename, "w") as json_file:
                json2.dump(text_and_gold_labels_encoded, json_file)


def sample(dataset_name: str, category_name: str, sample_size: int, random_seed: int, restrict_label: str = None):
Expand Down
20 changes: 17 additions & 3 deletions lrtc_lib/oracle_data_access/oracle_data_access_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@
from typing import Sequence, List, Mapping, Tuple, Set

import lrtc_lib.oracle_data_access.core.utils as oracle_utils
from data_access.data_access_api import DataAccessApi
from lrtc_lib.data_access.core.data_structs import Label
from lrtc_lib.orchestrator.orchestrator_api import LABEL_POSITIVE, LABEL_NEGATIVE
from oracle_data_access.core.utils import create_gold_labels_online


def add_gold_labels(dataset_name: str, text_and_gold_labels: List[Tuple[str, Mapping[str, Label]]]):
Expand All @@ -34,7 +36,8 @@ def add_gold_labels(dataset_name: str, text_and_gold_labels: List[Tuple[str, Map
f.write(gold_labels_encoded)


def get_gold_labels(dataset_name: str, text_element_uris: Sequence[str], category_name: str = None) -> \
def get_gold_labels(dataset_name: str, text_element_uris: Sequence[str], category_name: str = None,
data_access: DataAccessApi = None) -> \
List[Tuple[str, Mapping[str, Label]]]:
"""
Return the gold labels information for the given TextElements uris, keeping the same order, for the given dataset.
Expand All @@ -48,8 +51,19 @@ def get_gold_labels(dataset_name: str, text_element_uris: Sequence[str], categor
same order as the order of the TextElement uris given as input.
"""

gold_labels = oracle_utils.get_gold_labels(dataset_name, category_name)
return [(uri, gold_labels[uri]) for uri in text_element_uris if gold_labels[uri]]
gold_labels = oracle_utils.get_gold_labels(dataset_name, None)
uris_missing = [uri for uri in text_element_uris
if uri not in gold_labels or not gold_labels[uri]]

if uris_missing:
create_gold_labels_online(dataset_name, category_name, uris_missing, data_access)

gold_labels = oracle_utils.get_gold_labels(dataset_name, None)
gold_labels_dataset = [(uri, gold_labels[uri])
for uri in text_element_uris
if uri in gold_labels and gold_labels[uri]]

return gold_labels_dataset


def sample(dataset_name: str, category_name: str, sample_size: int, random_seed: int):
Expand Down
8 changes: 7 additions & 1 deletion lrtc_lib/orchestrator/orchestrator_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import glob
import logging
import os
import random
import traceback
from collections import Counter
from enum import Enum
Expand Down Expand Up @@ -338,7 +339,7 @@ def train(workspace_id: str, category_name: str, model_type: ModelType, train_pa
dataset_name = workspace.dataset_name
(train_data, train_counts), (dev_data, dev_counts) = train_and_dev_sets_selector.get_train_and_dev_sets(
workspace_id=workspace_id, train_dataset_name=dataset_name, category_name=category_name,
dev_dataset_name=workspace.dev_dataset_name)
dev_dataset_name=workspace.test_dataset_name or 'hemnet_descriptions_test')
logging.info(f"training a new model with {train_counts}")

# label_counts != train_counts as train_counts may refer to negative and weak negative labels separately
Expand Down Expand Up @@ -443,6 +444,11 @@ def infer(workspace_id: str, category_name: str, texts_to_infer: Sequence[TextEl

train_and_infer = PROJECT_PROPERTIES["train_and_infer_factory"].get_train_and_infer(model.model_type)
list_of_dicts = [{"text": element.text} for element in texts_to_infer]

#n_sample = 10000
#if len(list_of_dicts) > n_sample:
# list_of_dicts = random.sample(list_of_dicts, n_sample)

infer_results = train_and_infer.infer(model_id=model.model_id, items_to_infer=list_of_dicts,
infer_params=infer_params, use_cache=use_cache)

Expand Down
5 changes: 3 additions & 2 deletions lrtc_lib/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@ numpy==1.16.4
dataclasses==0.6
scikit-learn==0.21.3
jsonpickle==1.3
tensorflow>=2.1.0,<2.2.0
transformers==2.5.1
tensorflow>=2.2.0,<2.3.0
transformers==3.5.1
mip>=1.7.3
filelock==3.0.12
wrapt_timeout_decorator==1.3.1
nose==1.3.7
ujson==3.1.0
seaborn==0.11.0
h5py==2.10.0
torch==1.7.1
Loading