Skip to content

Commit

Permalink
Merge pull request #138 from microsoft/fetch-refac
Browse files Browse the repository at this point in the history
fetch is a class method!
  • Loading branch information
sordonia authored Nov 12, 2024
2 parents a99f41e + 184c839 commit 9e7f882
Show file tree
Hide file tree
Showing 10 changed files with 87 additions and 144 deletions.
4 changes: 2 additions & 2 deletions examples/create_arrow_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from mttl.models.containers.selectors import ArrowSelectorConfig
from mttl.models.expert_model import MultiExpertModel, MultiExpertModelConfig
from mttl.models.library.expert_library import ExpertLibrary
from mttl.models.library.library_transforms import ArrowConfig, ArrowTransform
from mttl.models.library.library_transforms import ArrowTransform, ArrowTransformConfig


@click.command()
Expand All @@ -30,7 +30,7 @@ def make_arrow(experts, push_to_hub):
library.add_expert_from_ckpt(path)

# compute arrow prototypes and store them in the library
arrow_config = ArrowConfig()
arrow_config = ArrowTransformConfig()
transform = ArrowTransform(arrow_config)
transform.transform(library, persist=True)

Expand Down
1 change: 0 additions & 1 deletion mttl/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

# warning if logger is not initialized
logger = logging.getLogger("mttl")
logger.setLevel(logging.WARNING)
logging.getLogger("datasets.arrow_dataset").setLevel(logging.CRITICAL + 1)


Expand Down
13 changes: 7 additions & 6 deletions mttl/models/containers/selectors/arrow_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,12 @@ def compute_arrow_embeddings(
add_base_proto=False,
recompute_prototypes=False,
) -> str:
from mttl.models.library.library_transforms import ArrowConfig, ArrowTransform
from mttl.models.library.library_transforms import (
ArrowTransform,
ArrowTransformConfig,
)

cfg = ArrowConfig(
cfg = ArrowTransformConfig(
name=selector_data_id,
ab_only=ab_only,
tie_params=tie_params or "default",
Expand Down Expand Up @@ -48,8 +51,6 @@ class ArrowSelector(PerTokenSelector):
@artifacts_cache
def load_from_library(cls, config):
"""Fetches prototypes from the library."""
from mttl.models.library.library_transforms import ArrowConfig, ArrowTransform
from mttl.models.library.library_transforms import ArrowTransform

return ArrowTransform(ArrowConfig(name=config.selector_data_id)).fetch(
config.library_id
)
return ArrowTransform.fetch(config.library_id, config.selector_data_id)
12 changes: 5 additions & 7 deletions mttl/models/containers/selectors/average_activation_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,9 @@ class AverageActivationSelector(PerTokenSelector):
@artifacts_cache
def load_from_library(cls, config):
"""Fetches prototypes from the library."""
from mttl.models.library.library_transforms import (
HiddenStateComputer,
HiddenStateComputerConfig,
)
from mttl.models.library.library_transforms import HiddenStateComputer

return HiddenStateComputer(
HiddenStateComputerConfig(name=config.selector_data_id)
).fetch(config.library_id)
return HiddenStateComputer.fetch(
config.library_id,
config.selector_data_id,
)
13 changes: 4 additions & 9 deletions mttl/models/containers/selectors/phatgoose_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,11 @@ def compute_phatgoose_embeddings(
) -> str:
"""Computes Phatgoose embeddings for the given library."""
from mttl.models.library.library_transforms import (
PhatgooseConfig,
PhatgooseTransform,
PhatgooseTransformConfig,
)

cfg = PhatgooseConfig(
cfg = PhatgooseTransformConfig(
n_steps=n_steps_pg,
learning_rate=learning_rate_pg,
name=selector_data_id,
Expand Down Expand Up @@ -70,14 +70,9 @@ def __init__(self, config, **kwargs) -> None:
@artifacts_cache
def load_from_library(cls, config):
"""Fetches prototypes from the library."""
from mttl.models.library.library_transforms import (
PhatgooseConfig,
PhatgooseTransform,
)
from mttl.models.library.library_transforms import PhatgooseTransform

return PhatgooseTransform(PhatgooseConfig(name=config.selector_data_id)).fetch(
config.library_id
)
return PhatgooseTransform.fetch(config.library_id, config.selector_data_id)


@dataclass
Expand Down
2 changes: 1 addition & 1 deletion mttl/models/expert_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,7 +571,7 @@ def from_pretrained_library(
if isinstance(selector_config, LoadableSelectorConfig):
selector_config.library_id = repo_id

elif isinstance(selector_config, dict):
elif isinstance(selector_config, MultiSelectorConfig):
for modifier_name, cfg in selector_config.items():
# inject the library id if it is None
if (
Expand Down
Loading

0 comments on commit 9e7f882

Please sign in to comment.