Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fetch is a class method! #138

Merged
merged 6 commits into from
Nov 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions examples/create_arrow_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from mttl.models.containers.selectors import ArrowSelectorConfig
from mttl.models.expert_model import MultiExpertModel, MultiExpertModelConfig
from mttl.models.library.expert_library import ExpertLibrary
from mttl.models.library.library_transforms import ArrowConfig, ArrowTransform
from mttl.models.library.library_transforms import ArrowTransform, ArrowTransformConfig


@click.command()
Expand All @@ -30,7 +30,7 @@ def make_arrow(experts, push_to_hub):
library.add_expert_from_ckpt(path)

# compute arrow prototypes and store them in the library
arrow_config = ArrowConfig()
arrow_config = ArrowTransformConfig()
transform = ArrowTransform(arrow_config)
transform.transform(library, persist=True)

Expand Down
1 change: 0 additions & 1 deletion mttl/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

# warning if logger is not initialized
logger = logging.getLogger("mttl")
logger.setLevel(logging.WARNING)
logging.getLogger("datasets.arrow_dataset").setLevel(logging.CRITICAL + 1)


Expand Down
13 changes: 7 additions & 6 deletions mttl/models/containers/selectors/arrow_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,12 @@ def compute_arrow_embeddings(
add_base_proto=False,
recompute_prototypes=False,
) -> str:
from mttl.models.library.library_transforms import ArrowConfig, ArrowTransform
from mttl.models.library.library_transforms import (
ArrowTransform,
ArrowTransformConfig,
)

cfg = ArrowConfig(
cfg = ArrowTransformConfig(
name=selector_data_id,
ab_only=ab_only,
tie_params=tie_params or "default",
Expand Down Expand Up @@ -48,8 +51,6 @@ class ArrowSelector(PerTokenSelector):
@artifacts_cache
def load_from_library(cls, config):
"""Fetches prototypes from the library."""
from mttl.models.library.library_transforms import ArrowConfig, ArrowTransform
from mttl.models.library.library_transforms import ArrowTransform

return ArrowTransform(ArrowConfig(name=config.selector_data_id)).fetch(
config.library_id
)
return ArrowTransform.fetch(config.library_id, config.selector_data_id)
12 changes: 5 additions & 7 deletions mttl/models/containers/selectors/average_activation_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,9 @@ class AverageActivationSelector(PerTokenSelector):
@artifacts_cache
def load_from_library(cls, config):
"""Fetches prototypes from the library."""
from mttl.models.library.library_transforms import (
HiddenStateComputer,
HiddenStateComputerConfig,
)
from mttl.models.library.library_transforms import HiddenStateComputer

return HiddenStateComputer(
HiddenStateComputerConfig(name=config.selector_data_id)
).fetch(config.library_id)
return HiddenStateComputer.fetch(
config.library_id,
config.selector_data_id,
)
13 changes: 4 additions & 9 deletions mttl/models/containers/selectors/phatgoose_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,11 @@ def compute_phatgoose_embeddings(
) -> str:
"""Computes Phatgoose embeddings for the given library."""
from mttl.models.library.library_transforms import (
PhatgooseConfig,
PhatgooseTransform,
PhatgooseTransformConfig,
)

cfg = PhatgooseConfig(
cfg = PhatgooseTransformConfig(
n_steps=n_steps_pg,
learning_rate=learning_rate_pg,
name=selector_data_id,
Expand Down Expand Up @@ -70,14 +70,9 @@ def __init__(self, config, **kwargs) -> None:
@artifacts_cache
def load_from_library(cls, config):
"""Fetches prototypes from the library."""
from mttl.models.library.library_transforms import (
PhatgooseConfig,
PhatgooseTransform,
)
from mttl.models.library.library_transforms import PhatgooseTransform

return PhatgooseTransform(PhatgooseConfig(name=config.selector_data_id)).fetch(
config.library_id
)
return PhatgooseTransform.fetch(config.library_id, config.selector_data_id)


@dataclass
Expand Down
2 changes: 1 addition & 1 deletion mttl/models/expert_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,7 +571,7 @@ def from_pretrained_library(
if isinstance(selector_config, LoadableSelectorConfig):
selector_config.library_id = repo_id

elif isinstance(selector_config, dict):
elif isinstance(selector_config, MultiSelectorConfig):
for modifier_name, cfg in selector_config.items():
# inject the library id if it is None
if (
Expand Down
Loading