Skip to content

Commit 283204f

Browse files
committed
lint
1 parent aefb19d commit 283204f

File tree

6 files changed

+10
-14
lines changed

6 files changed

+10
-14
lines changed

torchtitan/hf_datasets/flux_dataset.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from torch.distributed.checkpoint.stateful import Stateful
1919

2020
from torch.utils.data import IterableDataset
21+
2122
from torchtitan.components.dataloader import ParallelAwareDataloader
2223

2324
from torchtitan.components.tokenizer import BaseTokenizer

torchtitan/models/flux/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,12 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
from torchtitan.components.loss import build_mse_loss
78
from torchtitan.components.lr_scheduler import build_lr_schedulers
89
from torchtitan.components.optimizer import build_optimizers
9-
from torchtitan.protocols.train_spec import TrainSpec
1010

1111
from torchtitan.datasets.flux_dataset import build_flux_dataloader
12-
from torchtitan.components.loss import build_mse_loss
12+
from torchtitan.protocols.train_spec import TrainSpec
1313
from .infra.parallelize import parallelize_flux
1414
from .model.args import FluxModelArgs
1515
from .model.autoencoder import AutoEncoderParams

torchtitan/models/flux/model/hf_embedder.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,19 +40,15 @@ def __init__(self, version: str, random_init=False, **hf_kwargs):
4040
version, **hf_kwargs
4141
)
4242

43-
4443
self.hf_module = self.hf_module.eval().requires_grad_(False)
4544
# This is to make sure the encoders works with FSDP
4645
self.make_parameters_contiguous()
4746

4847
def make_parameters_contiguous(self):
4948
"""Make all non-contiguous parameters contiguous to avoid FSDP issues."""
50-
strided_count = 0
5149
for name, param in self.hf_module.named_parameters():
5250
if not param.is_contiguous():
53-
strided_count += 1
5451
param.data = param.data.contiguous()
55-
5652

5753
def forward(self, batch_tokens: Tensor) -> Tensor:
5854
"""

torchtitan/models/flux/tokenizer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,10 @@
1111
from typing import List
1212

1313
import torch
14+
from transformers import CLIPTokenizer, T5Tokenizer
15+
1416
from torchtitan.components.tokenizer import BaseTokenizer, HuggingFaceTokenizer
1517
from torchtitan.config import JobConfig
16-
from transformers import CLIPTokenizer, T5Tokenizer
1718

1819

1920
class FluxTestTokenizer(BaseTokenizer):

torchtitan/models/flux/train.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,6 @@
1111

1212
from torchtitan.config import ConfigManager, JobConfig, TORCH_DTYPE_MAP
1313
from torchtitan.distributed import utils as dist_utils
14-
from torchtitan.tools.logging import init_logger, logger
15-
from torchtitan.train import Trainer
1614

1715
from torchtitan.models.flux.infra.parallelize import parallelize_encoders
1816
from torchtitan.models.flux.model.autoencoder import load_ae
@@ -23,6 +21,8 @@
2321
preprocess_data,
2422
unpack_latents,
2523
)
24+
from torchtitan.tools.logging import init_logger, logger
25+
from torchtitan.train import Trainer
2626

2727

2828
class FluxTrainer(Trainer):

torchtitan/models/flux/validate.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,13 @@
1717
from torchtitan.components.tokenizer import BaseTokenizer
1818
from torchtitan.components.validate import Validator
1919
from torchtitan.config import JobConfig
20+
from torchtitan.datasets.flux_dataset import build_flux_validation_dataloader
2021
from torchtitan.distributed import ParallelDims, utils as dist_utils
21-
from torchtitan.datasets.flux_dataset import (
22-
build_flux_validation_dataloader,
23-
)
24-
25-
from torchtitan.models.flux.tokenizer import build_flux_tokenizer
2622
from torchtitan.models.flux.model.autoencoder import AutoEncoder
2723
from torchtitan.models.flux.model.hf_embedder import FluxEmbedder
2824
from torchtitan.models.flux.sampling import generate_image, save_image
25+
26+
from torchtitan.models.flux.tokenizer import build_flux_tokenizer
2927
from torchtitan.models.flux.utils import (
3028
create_position_encoding_for_latents,
3129
pack_latents,

0 commit comments

Comments
 (0)