Merged
Changes from all commits
52 commits
694c317
Add decilm modelling code
danielkorzekwa Nov 3, 2025
991659f
Add decilm modelling code.
danielkorzekwa Nov 3, 2025
8489cee
Add transformers codebase
danielkorzekwa Nov 3, 2025
f0afefe
Add transformers code
danielkorzekwa Nov 3, 2025
b3ed5bc
Add decilm modelling code
danielkorzekwa Nov 3, 2025
a700da5
Add decilm modelling code
danielkorzekwa Nov 3, 2025
b59b679
Correct licence headers
danielkorzekwa Nov 4, 2025
1abdf3e
Correct licence headers
danielkorzekwa Nov 4, 2025
66609b1
Add decilm code
danielkorzekwa Nov 4, 2025
7da0a8a
Add decilm code
danielkorzekwa Nov 4, 2025
6e09a81
Add decilm code
danielkorzekwa Nov 4, 2025
2e3f5da
Add decilm code
danielkorzekwa Nov 4, 2025
418890e
Add decilm code
danielkorzekwa Nov 4, 2025
01f4fc1
Make llama3 converter self-contained (no deps on internal Nvidia code)
danielkorzekwa Nov 4, 2025
c57eed4
Add common module
danielkorzekwa Nov 4, 2025
3dc37b3
module refactoring
danielkorzekwa Nov 4, 2025
10ffdfe
refactoring
danielkorzekwa Nov 5, 2025
27a4456
add shared_checkpointing_utils
danielkorzekwa Nov 5, 2025
b0e22b7
Add json tools
danielkorzekwa Nov 5, 2025
52e7827
add logger
danielkorzekwa Nov 5, 2025
f5c1c87
import refactoring
danielkorzekwa Nov 5, 2025
0aa6320
add post_init_sparse module
danielkorzekwa Nov 5, 2025
35d0dbc
Add post_init_sparse
danielkorzekwa Nov 5, 2025
e39a1ad
merging hydra.py and hydra_utils.py
danielkorzekwa Nov 5, 2025
1bd0c67
Add integration test for attention pruning
danielkorzekwa Nov 5, 2025
0ecd52b
add score_pruning_activations
danielkorzekwa Nov 5, 2025
278c6b7
import refactoring
danielkorzekwa Nov 5, 2025
7a0af16
add dist_utils
danielkorzekwa Nov 5, 2025
0f0cbbd
Add validate_model
danielkorzekwa Nov 5, 2025
cb5cf25
Add activation scoring hooks for pruning
danielkorzekwa Nov 5, 2025
6f82a67
make validate_model self-contained
danielkorzekwa Nov 6, 2025
a87fb79
update validate_pipeline to use DeciLMForCausalLM from modelopt
danielkorzekwa Nov 6, 2025
b227521
fix imports
danielkorzekwa Nov 6, 2025
ca7ab3f
add sewing_kit
danielkorzekwa Nov 6, 2025
a7a4adc
add sewing_kit
danielkorzekwa Nov 6, 2025
ad84c26
fix imports
danielkorzekwa Nov 6, 2025
3d7e8a2
fix imports
danielkorzekwa Nov 6, 2025
3d755b2
add pruning_ckpts
danielkorzekwa Nov 6, 2025
845d453
add pruning_ckpts
danielkorzekwa Nov 6, 2025
4fd921b
import refactoring
danielkorzekwa Nov 6, 2025
3641847
refactor imports
danielkorzekwa Nov 6, 2025
8d6333b
import refactoring
danielkorzekwa Nov 6, 2025
b6b7ca9
Merge branch 'feature/compress' into dkorzekwa/pruning_ckpts_1
danielkorzekwa Nov 25, 2025
7ab69e6
Delete not needed mistral tokenizer
danielkorzekwa Nov 25, 2025
2217a2a
Improve doc strings
danielkorzekwa Nov 25, 2025
a281ff7
Delete empty module
danielkorzekwa Nov 25, 2025
c1fb32c
Add doc string
danielkorzekwa Nov 25, 2025
5203169
Add doc string + remove references to 'lustre'
danielkorzekwa Nov 25, 2025
e2eee60
Add typeguard to compress dependencies in setup.py
danielkorzekwa Nov 26, 2025
6e26074
Improve docs
danielkorzekwa Nov 26, 2025
42da180
fix imports
danielkorzekwa Nov 26, 2025
6200962
fix import ordering
danielkorzekwa Nov 26, 2025
@@ -18,13 +18,13 @@
 import hydra
 import torch
 from omegaconf import DictConfig
-from modelopt.torch._compress.utils.parsing import format_global_config
 
 from modelopt.torch._compress.tools.hydra_utils import register_hydra_resolvers
 from modelopt.torch._compress.tools.logger import mprint
 from modelopt.torch._compress.tools.runtime import BaseRuntime, NativeDdpRuntime
 from modelopt.torch._compress.tools.validate_model import validate_model
 from modelopt.torch._compress.utils.dist_utils import is_distributed
+from modelopt.torch._compress.utils.parsing import format_global_config
 
 
 def has_checkpoint_support(activation_hooks_kwargs: dict) -> bool:
4 changes: 2 additions & 2 deletions modelopt/torch/_compress/compress.py
@@ -23,12 +23,12 @@
 import build_library_and_stats
 import mip_and_realize_models
 import pruning_ckpts
-import modelopt.torch._compress.activation_scoring.score_pruning_activations as score_pruning_activations
 import scoring
 from omegaconf import DictConfig
-from modelopt.torch._compress.tools.runtime import IRuntime
 
+import modelopt.torch._compress.activation_scoring.score_pruning_activations as score_pruning_activations
 from modelopt.torch._compress.tools.hydra_utils import initialize_hydra_config_for_dir
+from modelopt.torch._compress.tools.runtime import IRuntime
 
 
 def compress(
@@ -25,11 +25,11 @@
 
 import build_library_and_stats
 import mip_and_realize_models
-import pruning_ckpts
 import scoring
 import torch
 from torch import nn
 
+import modelopt.torch._compress.pruning.pruning_ckpts as pruning_ckpts
 from modelopt.torch._compress.activation_scoring import score_pruning_activations
 from modelopt.torch._compress.decilm.converters.convert_llama3_to_decilm import (
     convert_llama3_to_decilm,
351 changes: 351 additions & 0 deletions modelopt/torch/_compress/pruning/pruning_ckpts.py
Collaborator

FYI, all of this is natively supported in DynamicModule if we implement the DeciLM model as a DynamicModule. That way, child_init.py could also be greatly simplified.

We just need to set the hyperparameter's active value to the pruned value and assign a ranking; exporting will then produce the sorted and pruned module. If no ranking order is assigned, it will simply truncate.

Author

will create an issue (once gitlab is up again:)

Large diffs are not rendered by default.
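
To make the suggestion above concrete, here is a minimal, self-contained PyTorch sketch of the described behavior: set an "active" output size, optionally sort channels by an importance ranking, and export a truncated module. The helper name export_pruned_linear and the magnitude-based importance proxy are hypothetical illustrations only, not the actual DynamicModule API in modelopt.

from typing import Optional

import torch
import torch.nn as nn


def export_pruned_linear(
    linear: nn.Linear, active_out: int, ranking: Optional[torch.Tensor] = None
) -> nn.Linear:
    """Export a Linear layer keeping only `active_out` output features.

    If `ranking` is given (higher = more important), output channels are sorted by
    importance before truncation; otherwise the first `active_out` rows are kept
    (plain truncation), mirroring the "no ranking assigned" case described above.
    """
    if ranking is not None:
        # Keep the `active_out` most important channels, most important first.
        order = torch.argsort(ranking, descending=True)[:active_out]
    else:
        order = torch.arange(active_out)

    pruned = nn.Linear(linear.in_features, active_out, bias=linear.bias is not None)
    with torch.no_grad():
        pruned.weight.copy_(linear.weight[order])
        if linear.bias is not None:
            pruned.bias.copy_(linear.bias[order])
    return pruned


# Usage: keep the 2 most important of 4 output features.
layer = nn.Linear(8, 4)
importance = layer.weight.abs().sum(dim=1)  # crude magnitude-based importance proxy
small = export_pruned_linear(layer, active_out=2, ranking=importance)
print(small)  # Linear(in_features=8, out_features=2, bias=True)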

17 changes: 9 additions & 8 deletions modelopt/torch/_compress/sewing_kit/__init__.py
@@ -13,22 +13,23 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # mypy: ignore-errors
+
 from .core import (
-    Needle,
+    CantResolveNodeDependenciesException,
+    ConstantTarget,
+    ExternalTarget,
+    FunctionTarget,
+    InputsLoopFoundException,
     KnotException,
     LoopFoundException,
-    InputsLoopFoundException,
+    ModuleTarget,
     MultipleExternalNodesException,
+    Needle,
     OnlyInternalNodesException,
     OutputsLoopFoundException,
-    ExternalTarget,
-    ModuleTarget,
-    ConstantTarget,
-    FunctionTarget,
     RemoteTarget,
     StitchedModule,
     StitchedModuleException,
-    CantResolveNodeDependenciesException,
     StitchedModuleOutput,
 )
-from .passage import always_false_predicate, always_true_predicate, InputArgs
+from .passage import InputArgs, always_false_predicate, always_true_predicate
14 changes: 8 additions & 6 deletions modelopt/torch/_compress/sewing_kit/core.py
@@ -14,11 +14,14 @@
 # limitations under the License.
 
 # mypy: ignore-errors
+
 from __future__ import annotations
+
 from abc import ABC
 from collections import defaultdict
 from dataclasses import dataclass, field
 from typing import Any, Callable, Iterable, Literal, Optional, Sequence, Union
+
 from typing_extensions import override
 
 try:
@@ -30,19 +33,18 @@
 import torch.distributed
 import torch.nn as nn
 
-from .utils import distributed_isend_obj, distributed_recv_obj, dynamo_skip
 from .passage import (
-    Passage,
     InputArgs,
     OutputValue,
-    Predicate,
-    always_false_predicate,
+    Passage,
     PassageInputAdapter,
-    PassageOutputAdapter,
     PassageInputOverrides,
+    PassageOutputAdapter,
     PassageOutputOverrides,
+    Predicate,
+    always_false_predicate,
 )
-
+from .utils import distributed_isend_obj, distributed_recv_obj, dynamo_skip
 
 InputAdapter = Callable[[InputArgs], InputArgs]
 OutputAdapter = Callable[..., OutputValue]
11 changes: 6 additions & 5 deletions modelopt/torch/_compress/sewing_kit/passage/__init__.py
@@ -13,16 +13,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+
 from .core import (
-    Passage,
-    PassageOutput,
     InputArgs,
     OutputValue,
-    Predicate,
+    Passage,
     PassageInputAdapter,
-    PassageOutputAdapter,
     PassageInputOverrides,
+    PassageOutput,
+    PassageOutputAdapter,
     PassageOutputOverrides,
-    always_true_predicate,
+    Predicate,
     always_false_predicate,
+    always_true_predicate,
 )
13 changes: 6 additions & 7 deletions modelopt/torch/_compress/sewing_kit/passage/core.py
@@ -15,10 +15,9 @@
 
 # mypy: ignore-errors
 from __future__ import annotations
-import sys
 
-from collections.abc import Sequence, Callable
-
+import sys
+from collections.abc import Callable, Sequence
 from dataclasses import dataclass
 from typing import Any, ContextManager, Iterable, Mapping, Optional, Union
 
@@ -27,19 +26,19 @@
 except ImportError:
     from typing_extensions import Self
 
+import torch.nn as nn
 from typing_extensions import override
 
-import torch.nn as nn
+from ..common import logger
 from ..utils import (
     ActivityContext,
-    has_fake_tensor,
+    dynamo_skip,
     fake_tensors,
+    has_fake_tensor,
     is_submodule_of,
     is_submodule_or_same,
     real_tensors,
-    dynamo_skip,
 )
-from ..common import logger
 
 
 @dataclass
14 changes: 7 additions & 7 deletions modelopt/torch/_compress/sewing_kit/utils.py
@@ -16,7 +16,7 @@
 from __future__ import annotations
 
 import inspect
-from collections.abc import Sequence, Mapping
+from collections.abc import Mapping, Sequence
 from contextlib import contextmanager
 from typing import (
     Any,
@@ -31,17 +31,17 @@
     cast,
     overload,
 )
-from typing_extensions import override
 
 import torch
-import torch.distributed
-import torch._dynamo
 import torch._C
-from torch import Tensor
-import torch.utils._pytree as pytree
+import torch._dynamo
+import torch.distributed
 import torch.nn as nn
 import torch.nn.functional as F
+import torch.utils._pytree as pytree
+from torch import Tensor
 from torch._subclasses import FakeTensor, FakeTensorMode
 
+from typing_extensions import override
 
 Fn = TypeVar("Fn", bound=Callable)
