Skip to content

fix: Project model directory #274

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 28 additions & 20 deletions label_studio_ml/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@
import attr
import io
import rq

try:
import torch.multiprocessing as mp

try:
mp.set_start_method('spawn')
except RuntimeError:
Expand Down Expand Up @@ -39,6 +41,7 @@
LABEL_STUDIO_ML_BACKEND_V2_DEFAULT = False
AUTO_UPDATE_DEFAULT = False


@attr.s
class ModelWrapper(object):
model = attr.ib()
Expand Down Expand Up @@ -109,8 +112,9 @@ def get_result_from_job_id(self, job_id):
DON'T OVERRIDE THIS FUNCTION! Instead, override _get_result_from_job_id
"""
result = self._get_result_from_job_id(job_id)
assert isinstance(result, dict), f"Job {job_id} was finished unsuccessfully. No result was saved in job folder." \
f"Please clean up failed job folders to remove this error from log."
assert isinstance(result, dict), \
f"Training job {job_id} was finished unsuccessfully. No result was saved in job folder." \
f"Please clean up failed job folders to remove this error from log."
result['job_id'] = job_id
return result

Expand Down Expand Up @@ -195,8 +199,11 @@ def _get_result_from_job_id(self, job_id):
return None
result_file = os.path.join(job_dir, self.JOB_RESULT)
if not os.path.exists(result_file):
logger.warning(f"=> Warning: {job_id} dir doesn't contain result file. "
f"It seems that previous training session ended with error.")
logger.warning(
f"=> Warning: {job_id} dir doesn't contain result file. "
f"It seems that previous training session was never done or ended with error. "
f"It is normal to see it if your model doesn't have fit() implementation at all. "
)
# Return empty dict if training is failed OR None if Error message is needed in case of failed train
IGNORE_FAILED_TRAINING = get_env("IGNORE_FAILED_TRAINING", is_bool=True)
return {} if IGNORE_FAILED_TRAINING else None
Expand Down Expand Up @@ -232,6 +239,7 @@ class RQJobManager(JobManager):
"""

MAX_QUEUE_LEN = 1 # Controls a maximal amount of simultaneous jobs running in queue.

# If exceeded, new jobs are ignored

def __init__(self, redis_host, redis_port, redis_queue):
Expand Down Expand Up @@ -293,7 +301,6 @@ def post_process(self, event, data, job_id, result):


class LabelStudioMLBase(ABC):

TRAIN_EVENTS = (
'ANNOTATION_CREATED',
'ANNOTATION_UPDATED',
Expand All @@ -307,7 +314,7 @@ def __init__(self, label_config=None, train_output=None, **kwargs):
self.parsed_label_config = parse_config(self.label_config) if self.label_config else {}
self.train_output = train_output or {}
self.hostname = kwargs.get('hostname', '') or get_env('HOSTNAME')
self.access_token = kwargs.get('access_token', '') or get_env('ACCESS_TOKEN') or get_env('API_KEY')
self.access_token = kwargs.get('access_token', '') or get_env('ACCESS_TOKEN') or get_env('API_KEY')

@abstractmethod
def predict(self, tasks, **kwargs):
Expand All @@ -328,7 +335,6 @@ def get_local_path(self, url, project_dir=None):


class LabelStudioMLManager(object):

model_class = None
model_dir = None
redis_host = None
Expand All @@ -342,8 +348,8 @@ class LabelStudioMLManager(object):

@classmethod
def initialize(
cls, model_class, model_dir=None, redis_host='localhost', redis_port=6379, redis_queue='default',
**init_kwargs
cls, model_class, model_dir=None, redis_host='localhost', redis_port=6379, redis_queue='default',
**init_kwargs
):
if not issubclass(model_class, LabelStudioMLBase):
raise ValueError('Inference class should be the subclass of ' + LabelStudioMLBase.__class__.__name__)
Expand Down Expand Up @@ -452,15 +458,15 @@ def _key(cls, project):
@classmethod
def has_active_model(cls, project):
if not os.getenv('LABEL_STUDIO_ML_BACKEND_V2', default=LABEL_STUDIO_ML_BACKEND_V2_DEFAULT):
# TODO: Deprecated branch since LS 1.5
# TODO: Deprecated branch since LS 1.5
return cls._key(project) in cls._current_model
else:
return cls._current_model is not None

@classmethod
def get(cls, project):
if not os.getenv('LABEL_STUDIO_ML_BACKEND_V2', default=LABEL_STUDIO_ML_BACKEND_V2_DEFAULT):
# TODO: Deprecated branch since LS 1.5
# TODO: Deprecated branch since LS 1.5
key = cls._key(project)
logger.debug('Get project ' + str(key))
return cls._current_model.get(key)
Expand Down Expand Up @@ -488,7 +494,7 @@ def create(cls, project=None, label_config=None, train_output=None, version=None

@classmethod
def get_or_create(
cls, project=None, label_config=None, force_reload=False, train_output=None, version=None, **kwargs
cls, project=None, label_config=None, force_reload=False, train_output=None, version=None, **kwargs
):
m = cls.get(project)
# reload new model if model is not loaded into memory OR force_reload=True OR model versions are mismatched
Expand Down Expand Up @@ -529,7 +535,8 @@ def fetch(cls, project=None, label_config=None, force_reload=False, **kwargs):
return cls.get_or_create(project, label_config, force_reload, train_output, version, **kwargs)

model_version = kwargs.get('model_version')
if not cls._current_model or (model_version != cls._current_model.model_version and model_version is not None) or \
if not cls._current_model or (
model_version != cls._current_model.model_version and model_version is not None) or \
os.getenv('AUTO_UPDATE', default=AUTO_UPDATE_DEFAULT):
jm = cls.get_job_manager()
model_version = kwargs.get('model_version')
Expand Down Expand Up @@ -593,7 +600,7 @@ def is_training(cls, project):

@classmethod
def predict(
cls, tasks, project=None, label_config=None, force_reload=False, try_fetch=True, **kwargs
cls, tasks, project=None, label_config=None, force_reload=False, try_fetch=True, **kwargs
):
"""
Make prediction for tasks
Expand Down Expand Up @@ -636,7 +643,7 @@ def create_data_snapshot(cls, data_iter, workdir):

@classmethod
def train_script_wrapper(
cls, project, label_config, train_kwargs, initialization_params=None, tasks=()
cls, project, label_config, train_kwargs, initialization_params=None, tasks=()
):

if initialization_params:
Expand Down Expand Up @@ -750,14 +757,16 @@ def _get_models_from_workdir(cls, project):
@return: List of model versions for current model
"""
V2 = os.getenv('LABEL_STUDIO_ML_BACKEND_V2', default=LABEL_STUDIO_ML_BACKEND_V2_DEFAULT)
project_model_dir = os.path.join(cls.model_dir, project or '')
if not V2:
project_model_dir = os.path.join(cls.model_dir, project or '')
if not os.path.exists(project_model_dir):
return []
else:
project_model_dir = cls.model_dir
# get directories with traing results
# else:
# project_model_dir = cls.model_dir

# get directories with training results
final_models = []

for subdir in map(int, filter(lambda d: d.isdigit(), os.listdir(project_model_dir))):
job_result_file = os.path.join(project_model_dir, str(subdir), 'job_result.json')
# check if there is job result
Expand All @@ -773,7 +782,6 @@ def _get_models_from_workdir(cls, project):
return final_models



def get_all_classes_inherited_LabelStudioMLBase(script_file):
names = set()
abs_path = os.path.abspath(script_file)
Expand Down