
Commit d636d35

ClearML integration (#2197)

* clearml
* lint
* include in loggers
* add step
* fix torch dep

Parent: 5aae81b

4 files changed: +174 -3 lines

setup.py (+1)

@@ -55,6 +55,7 @@
     "GPUtil>=1.4.0",
     "protobuf>=3.12.2,<=3.20.3",
     "click>=7.1.2,!=8.0.0",  # latest version < 8.0 + blocked version with reported bug
+    "clearml==1.14.4",
 ]
 _nm_deps = [f"{'sparsezoo' if is_release else 'sparsezoo-nightly'}~={version_nm_deps}"]
 _deepsparse_deps = [
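The pin above makes clearml an install-time dependency of sparseml. A minimal sanity check after installation (a sketch; it assumes the clearml package exposes __version__, as recent releases do):

    import clearml

    # The setup.py pin above requests exactly this release.
    assert clearml.__version__ == "1.14.4", clearml.__version__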

src/sparseml/pytorch/utils/logger.py (+105)
@@ -45,11 +45,21 @@
     wandb = None
     wandb_err = err
 
+
+try:
+    from clearml import Task
+
+    clearml_err = None
+except Exception as err:
+    clearml = None
+    clearml_err = err
+
 from sparseml.utils import ALL_TOKEN, create_dirs
 
 
 __all__ = [
     "BaseLogger",
+    "ClearMLLogger",
     "LambdaLogger",
     "PythonLogger",
     "TensorBoardLogger",
@@ -628,6 +638,101 @@ def save(
         return True
 
 
+class ClearMLLogger(LambdaLogger):
+    @staticmethod
+    def available() -> bool:
+        """
+        :return: True if clearml is available and installed, False otherwise
+        """
+        return not clearml_err
+
+    def __init__(
+        self,
+        name: str = "clearml",
+        enabled: bool = True,
+        project_name: str = "sparseml",
+        task_name: str = "",
+    ):
+        if task_name == "":
+            now = datetime.now()
+            task_name = now.strftime("%d-%m-%Y_%H.%M.%S")
+
+        self.task = Task.init(project_name=project_name, task_name=task_name)
+
+        super().__init__(
+            lambda_func=self.log_scalar,
+            name=name,
+            enabled=enabled,
+        )
+
+    def log_hyperparams(
+        self,
+        params: Dict,
+        level: Optional[int] = None,
+    ) -> bool:
+        """
+        :param params: Each key-value pair in the dictionary is the name of the
+            hyper parameter and its corresponding value.
+        :return: True if logged, False otherwise.
+        """
+        if not self.enabled:
+            return False
+
+        self.task.connect(params)
+        return True
+
+    def log_scalar(
+        self,
+        tag: str,
+        value: float,
+        step: Optional[int] = None,
+        wall_time: Optional[float] = None,
+        level: Optional[int] = None,
+    ) -> bool:
+        """
+        :param tag: identifying tag to log the value with
+        :param value: value to save
+        :param step: global step for when the value was taken
+        :param wall_time: global wall time for when the value was taken,
+            defaults to time.time()
+        :param level: level to log the value at
+        :return: True if logged, False otherwise.
+        """
+        logger = self.task.get_logger()
+        # series reported under the same title are superimposed on one plot
+        logger.report_scalar(
+            title=tag, series=str(level) or tag, value=value, iteration=step
+        )
+        return True
+
+    def log_scalars(
+        self,
+        tag: str,
+        values: Dict[str, float],
+        step: Optional[int] = None,
+        wall_time: Optional[float] = None,
+        level: Optional[int] = None,
+    ) -> bool:
+        """
+        :param tag: identifying tag to log the values with
+        :param values: values to save
+        :param step: global step for when the values were taken
+        :param wall_time: global wall time for when the values were taken,
+            defaults to time.time()
+        :param level: level to log the values at
+        :return: True if logged, False otherwise.
+        """
+        for k, v in values.items():
+            self.log_scalar(
+                tag=f"{tag}.{k}",
+                value=v,
+                step=step,
+                wall_time=wall_time,
+                level=level,
+            )
+        return True
+
+
 class SparsificationGroupLogger(BaseLogger):
     """
     Modifier logger that handles outputting values to other supported systems.
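For orientation, a minimal usage sketch of the logger added above (illustrative values; constructing it calls Task.init, so it assumes ClearML credentials are already configured):

    from sparseml.pytorch.utils import ClearMLLogger

    # Creating the logger starts (or attaches to) a ClearML task.
    logger = ClearMLLogger(project_name="sparseml", task_name="demo-run")

    # Hyperparameters are connected to the task; scalars appear as plots
    # in the ClearML UI, grouped by the tag used as the plot title.
    logger.log_hyperparams({"lr": 0.001, "batch_size": 64})
    logger.log_scalar("train/loss", 0.42, step=10)
    logger.log_scalars("eval", {"accuracy": 0.91, "f1": 0.88}, step=10)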

tests/sparseml/pytorch/utils/test_logger.py (+5 -3)

@@ -20,6 +20,7 @@
 import pytest
 
 from sparseml.pytorch.utils import (
+    ClearMLLogger,
     LambdaLogger,
     LoggerManager,
     PythonLogger,
@@ -45,6 +46,7 @@
         or True
     ),
     *([WANDBLogger()] if WANDBLogger.available() else []),
+    *([ClearMLLogger()] if ClearMLLogger.available() else []),
     SparsificationGroupLogger(
         lambda_func=lambda tag, value, values, step, wall_time, level: logging.info(
             f"{tag}, {value}, {values}, {step}, {wall_time}, {level}"
@@ -79,12 +81,12 @@ def test_log_scalar(self, logger):
 
     def test_log_scalars(self, logger):
         logger.log_scalars("test-scalars-tag", {"scalar1": 0.0, "scalar2": 1.0})
-        logger.log_scalars("test-scalars-tag", {"scalar1": 0.0, "scalar2": 1.0}, 1)
+        logger.log_scalars("test-scalars-tag2", {"scalar1": 0.0, "scalar2": 1.0}, 1)
         logger.log_scalars(
-            "test-scalars-tag", {"scalar1": 0.0, "scalar2": 1.0}, 2, time.time() - 1
+            "test-scalars-tag3", {"scalar1": 0.0, "scalar2": 1.0}, 2, time.time() - 1
         )
         logger.log_scalars(
-            "test-scalars-tag",
+            "test-scalars-tag4",
             {"scalar1": 0.0, "scalar2": 1.0},
             2,
             time.time() - 1,

tests/sparseml/test_clear_ml.py (new file, +63)

@@ -0,0 +1,63 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pathlib import Path

from clearml import Task
from sparseml.transformers import apply
from sparseml.utils import is_package_available


is_torch_available = is_package_available("torch")
if is_torch_available:
    import torch

    torch_err = None
else:
    torch = object
    torch_err = ModuleNotFoundError(
        "`torch` is not installed, use `pip install torch` to log to ClearML"
    )


def test_oneshot_and_finetune(tmp_path: Path):
    recipe_str = "tests/sparseml/transformers/finetune/test_alternate_recipe.yaml"
    model = "Xenova/llama2.c-stories15M"
    device = "cuda:0"
    if is_torch_available and not torch.cuda.is_available():
        device = "cpu"
    dataset = "wikitext"
    dataset_config_name = "wikitext-2-raw-v1"
    concatenate_data = True
    run_stages = True
    output_dir = tmp_path
    max_steps = 50
    splits = {"train": "train[:50%]", "calibration": "train[50%:60%]"}

    # ClearML automatically captures standard entries without explicit
    # logger calls. Logs are accessible at https://app.clear.ml/
    Task.init(project_name="test", task_name="test_oneshot_and_finetune")

    apply(
        model=model,
        dataset=dataset,
        dataset_config_name=dataset_config_name,
        run_stages=run_stages,
        output_dir=output_dir,
        recipe=recipe_str,
        max_steps=max_steps,
        concatenate_data=concatenate_data,
        splits=splits,
        oneshot_device=device,
    )
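After the test runs, the task it created can be queried back from the ClearML server to verify that scalars were captured. A sketch, assuming the standard clearml client API and configured credentials:

    from clearml import Task

    # Fetch the task created by the test above and read back its reported scalars.
    task = Task.get_task(project_name="test", task_name="test_oneshot_and_finetune")
    scalars = task.get_reported_scalars()  # {title: {series: {"x": [...], "y": [...]}}}
    print(sorted(scalars))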
