
Training Error Occurred: TypeError #118

@KyungdaePark

Description


Hello, I just started training with the re10k dataset, and this error occurred:

================================================================================

❯ python3 -m src.main +experiment=re10k data_loader.train.batch_size=1 
Saving outputs to /n/pixelsplat2/outputs/2025-08-22/16-07-13.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
Using cache found in /home/pkd/.cache/torch/hub/facebookresearch_dino_main
Using cache found in /home/pkd/.cache/torch/hub/facebookresearch_dino_main
[2025-08-22 16:07:14,084][py.warnings][WARNING] - /home/pkd/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.
  warnings.warn(

[2025-08-22 16:07:14,084][py.warnings][WARNING] - /home/pkd/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=None`.
  warnings.warn(msg)

Setting up [LPIPS] perceptual loss: trunk [vgg], v[0.1], spatial [off]
[2025-08-22 16:07:14,328][py.warnings][WARNING] - /home/pkd/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=VGG16_Weights.IMAGENET1K_V1`. You can also use `weights=VGG16_Weights.DEFAULT` to get the most up-to-date weights.
  warnings.warn(msg)

Loading model from: /home/pkd/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/lpips/weights/v0.1/vgg.pth
You are using a CUDA device ('NVIDIA GeForce RTX 4090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type                 | Params | Mode
---------------------------------------------------------
0 | encoder | EncoderEpipolar      | 118 M  | train
1 | decoder | DecoderSplattingCUDA | 0      | train
2 | losses  | ModuleList           | 0      | train
---------------------------------------------------------
118 M     Trainable params
0         Non-trainable params
118 M     Total params
475.918   Total estimated model params size (MB)
493       Modules in train mode
59        Modules in eval mode
[2025-08-22 16:07:15,289][py.warnings][WARNING] - /home/pkd/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.

validation step 0; scene = ['306e2b7785657539']; context = [[48, 73]]
[2025-08-22 16:07:15,920][py.warnings][WARNING] - /home/pkd/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/lightning/pytorch/utilities/data.py:79: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 1. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.

Setting up [LPIPS] perceptual loss: trunk [vgg], v[0.1], spatial [off]
[2025-08-22 16:07:15,921][py.warnings][WARNING] - /home/pkd/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.
  warnings.warn(

[2025-08-22 16:07:15,921][py.warnings][WARNING] - /home/pkd/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=VGG16_Weights.IMAGENET1K_V1`. You can also use `weights=VGG16_Weights.DEFAULT` to get the most up-to-date weights.
  warnings.warn(msg)

Loading model from: /home/pkd/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/lpips/weights/v0.1/vgg.pth
Error executing job with overrides: ['+experiment=re10k', 'data_loader.train.batch_size=1']
Traceback (most recent call last):
  File "/home/pkd/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/jaxtyping/_decorator.py", line 436, in wrapped_fn_impl
    param_fn(*args, **kwargs)
  File "<@beartype(src.visualization.layout.vcat) at 0x74e449816560>", line 55, in vcat
beartype.roar.BeartypeCallHintParamViolation: Function src.visualization.layout.vcat() parameter images="tensor([[[0.3804, 0.4118, 0.4392,  ..., 0.9961, 0.9961, 0.9961],
         [0.3882, 0.4196...')" violates type hint typing.Iterable[jaxtyping.Float[Tensor, 'channel _ _']], as <protocol "torch.Tensor"> index 0 item this array has 2 dimensions, not the 3 expected by the type hint.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/pkd/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/jaxtyping/_decorator.py", line 753, in _get_problem_arg
    fn(*args, **kwargs)
  File "<@beartype(src.visualization.layout.check_single_arg) at 0x74e503b75480>", line 49, in check_single_arg
beartype.roar.BeartypeCallHintParamViolation: Function src.visualization.layout.check_single_arg() parameter images="tensor([[[0.3804, 0.4118, 0.4392,  ..., 0.9961, 0.9961, 0.9961],
         [0.3882, 0.4196...')" violates type hint typing.Iterable[jaxtyping.Float[Tensor, 'channel _ _']], as <protocol "torch.Tensor"> index 0 item this array has 2 dimensions, not the 3 expected by the type hint.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/pkd/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/jaxtyping/_decorator.py", line 441, in wrapped_fn_impl
    argmsg = _get_problem_arg(
  File "/home/pkd/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/jaxtyping/_decorator.py", line 756, in _get_problem_arg
    raise TypeCheckError(
jaxtyping.TypeCheckError:
The problem arose whilst typechecking parameter 'images'.
Actual value: (f32[3,256,256](torch), f32[3,256,256](torch))
Expected type: typing.Iterable[Float[Tensor, 'channel _ _']].

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/n/pixelsplat2/src/main.py", line 128, in train
    trainer.fit(model_wrapper, datamodule=data_module, ckpt_path=checkpoint_path)
  File "/home/pkd/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 561, in fit
    call._call_and_handle_interrupt(
  File "/home/pkd/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py", line 48, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/home/pkd/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 599, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/home/pkd/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 1012, in _run
    results = self._run_stage()
  File "/home/pkd/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 1054, in _run_stage
    self._run_sanity_check()
  File "/home/pkd/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 1083, in _run_sanity_check
    val_loop.run()
  File "/home/pkd/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/lightning/pytorch/loops/utilities.py", line 179, in _decorator
    return loop_run(self, *args, **kwargs)
  File "/home/pkd/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/lightning/pytorch/loops/evaluation_loop.py", line 145, in run
    self._evaluation_step(batch, batch_idx, dataloader_idx, dataloader_iter)
  File "/home/pkd/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/lightning/pytorch/loops/evaluation_loop.py", line 437, in _evaluation_step
    output = call._call_strategy_hook(trainer, hook_name, *step_args)
  File "/home/pkd/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py", line 328, in _call_strategy_hook
    output = fn(*args, **kwargs)
  File "/home/pkd/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/lightning/pytorch/strategies/strategy.py", line 412, in validation_step
    return self.lightning_module.validation_step(*args, **kwargs)
  File "/home/pkd/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/lightning_utilities/core/rank_zero.py", line 41, in wrapped_fn
    return fn(*args, **kwargs)
  File "/home/pkd/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/jaxtyping/_decorator.py", line 549, in wrapped_fn
    return wrapped_fn_impl(args, kwargs, bound, memos)
  File "/home/pkd/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/jaxtyping/_decorator.py", line 473, in wrapped_fn_impl
    out = fn(*args, **kwargs)
  File "/n/pixelsplat2/src/model/model_wrapper.py", line 258, in validation_step
    add_label(vcat(*batch["context"]["image"][0]), "Context"),
  File "/home/pkd/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/jaxtyping/_decorator.py", line 549, in wrapped_fn
    return wrapped_fn_impl(args, kwargs, bound, memos)
  File "/home/pkd/anaconda3/envs/pixelsplat/lib/python3.10/site-packages/jaxtyping/_decorator.py", line 470, in wrapped_fn_impl
    raise TypeCheckError(msg) from e
jaxtyping.TypeCheckError: Type-check error whilst checking the parameters of src.visualization.layout.vcat.
The problem arose whilst typechecking parameter 'images'.
Actual value: (f32[3,256,256](torch), f32[3,256,256](torch))
Expected type: typing.Iterable[Float[Tensor, 'channel _ _']].
----------------------
Called with parameters: {
  'images': (f32[3,256,256](torch), f32[3,256,256](torch)),
  'align': 'start',
  'gap': 8,
  'gap_color': 1
}
Parameter annotations: (*images: Iterable[Float[Tensor, 'channel _ _']], align: Literal['start', 'center', 'end', 'left', 'right'] = 'start', gap: int = 8, gap_color: Union[int, float, Iterable[int], Iterable[float], Float[Tensor, '#channel'], Float[Tensor, '']] = 1) -> Any.


Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace.

================================================================================
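For context, the images unpacked by `vcat(*batch["context"]["image"][0])` do each have shape `[3, 256, 256]`, which matches the `Float[Tensor, 'channel _ _']` hint. A `torch.Tensor` is itself iterable, though, and iterating a 3-D tensor yields 2-D slices; the "index 0 item this array has 2 dimensions" message suggests the `Iterable` check is iterating a tensor rather than the tuple of arguments. A minimal sketch of that behavior (plain PyTorch, not the pixelSplat code):

```python
import torch

# A single context image like the ones passed to vcat: (channel, height, width).
image = torch.rand(3, 256, 256)

# torch.Tensor satisfies typing.Iterable, so a checker that iterates the
# tensor itself sees 2-D (height, width) slices, which would fail a
# 3-dimensional 'channel _ _' shape hint.
first_item = next(iter(image))
print(image.ndim)       # 3
print(first_item.ndim)  # 2
```

If that is the cause, it would point to a jaxtyping/beartype version mismatch rather than a bug in my call site.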

How can I fix it?
