Evaluating Trained Models #100

@Markus28

How are the trained models intended to be evaluated? As far as I can tell, the paper does not state how many samples were used to compute the metrics. The code gives some indication, but the testing functionality appears to be broken.
Assuming we train a model via:

python main.py +experiment=planar dataset=planar ++hydra.run.dir=<HEAD>/planar_debug

I would expect that we evaluate it on the test set via:

python main.py +experiment=planar dataset=planar general.test_only=<HEAD>/planar_debug/checkpoints/planar/last.ckpt

Unfortunately, this command fails with the following stack trace:

Traceback (most recent call last):
  File "/fs/gpfs41/lv11/fileset01/pool/pool-krimmel/DiGress/src/main.py", line 244, in <module>
    main()
  File "/fs/pool/pool-krimmel/miniconda3/envs/digress/lib/python3.9/site-packages/hydra/main.py", line 94, in decorated_main
    _run_hydra(
  File "/fs/pool/pool-krimmel/miniconda3/envs/digress/lib/python3.9/site-packages/hydra/_internal/utils.py", line 394, in _run_hydra
    _run_app(
  File "/fs/pool/pool-krimmel/miniconda3/envs/digress/lib/python3.9/site-packages/hydra/_internal/utils.py", line 457, in _run_app
    run_and_report(
  File "/fs/pool/pool-krimmel/miniconda3/envs/digress/lib/python3.9/site-packages/hydra/_internal/utils.py", line 223, in run_and_report
    raise ex
  File "/fs/pool/pool-krimmel/miniconda3/envs/digress/lib/python3.9/site-packages/hydra/_internal/utils.py", line 220, in run_and_report
    return func()
  File "/fs/pool/pool-krimmel/miniconda3/envs/digress/lib/python3.9/site-packages/hydra/_internal/utils.py", line 458, in <lambda>
    lambda: hydra.run(
  File "/fs/pool/pool-krimmel/miniconda3/envs/digress/lib/python3.9/site-packages/hydra/_internal/hydra.py", line 132, in run
    _ = ret.return_value
  File "/fs/pool/pool-krimmel/miniconda3/envs/digress/lib/python3.9/site-packages/hydra/core/utils.py", line 260, in return_value
    raise self._return_value
  File "/fs/pool/pool-krimmel/miniconda3/envs/digress/lib/python3.9/site-packages/hydra/core/utils.py", line 186, in run_job
    ret.return_value = task_function(task_cfg)
  File "/fs/gpfs41/lv11/fileset01/pool/pool-krimmel/DiGress/src/main.py", line 176, in main
    cfg, _ = get_resume(cfg, model_kwargs)
  File "/fs/gpfs41/lv11/fileset01/pool/pool-krimmel/DiGress/src/main.py", line 31, in get_resume
    model = DiscreteDenoisingDiffusion.load_from_checkpoint(resume, **model_kwargs)
  File "/fs/pool/pool-krimmel/miniconda3/envs/digress/lib/python3.9/site-packages/pytorch_lightning/core/module.py", line 1520, in load_from_checkpoint
    loaded = _load_from_checkpoint(
  File "/fs/pool/pool-krimmel/miniconda3/envs/digress/lib/python3.9/site-packages/pytorch_lightning/core/saving.py", line 62, in _load_from_checkpoint
    checkpoint = pl_load(checkpoint_path, map_location=map_location)
  File "/fs/pool/pool-krimmel/miniconda3/envs/digress/lib/python3.9/site-packages/lightning_fabric/utilities/cloud_io.py", line 51, in _load
    return torch.load(f, map_location=map_location)  # type: ignore[arg-type]
  File "/fs/pool/pool-krimmel/miniconda3/envs/digress/lib/python3.9/site-packages/torch/serialization.py", line 809, in load
    return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
  File "/fs/pool/pool-krimmel/miniconda3/envs/digress/lib/python3.9/site-packages/torch/serialization.py", line 1172, in _load
    result = unpickler.load()
  File "/fs/pool/pool-krimmel/miniconda3/envs/digress/lib/python3.9/pickle.py", line 1212, in load
    dispatch[key[0]](self)
  File "/fs/pool/pool-krimmel/miniconda3/envs/digress/lib/python3.9/pickle.py", line 1717, in load_build
    setstate(state)
  File "/fs/pool/pool-krimmel/miniconda3/envs/digress/lib/python3.9/site-packages/torch/nn/parallel/distributed.py", line 881, in __setstate__
    self.process_group = _get_default_group()
  File "/fs/pool/pool-krimmel/miniconda3/envs/digress/lib/python3.9/site-packages/torch/distributed/distributed_c10d.py", line 707, in _get_default_group
    raise RuntimeError(
RuntimeError: Default process group has not been initialized, please make sure to call init_process_group.

I am using torch==2.0.1+cu118 and pytorch-lightning==2.0.4, as specified in the requirements.
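
Reading the last frames of the trace, the failure happens during unpickling: `DistributedDataParallel.__setstate__` calls `_get_default_group()`, which suggests that the checkpoint contains a pickled DDP-wrapped object and that loading it requires an initialized process group. The following is a toy, standard-library-only sketch of that mechanism as I understand it. All names in it (`Wrapper`, `init_group`, `get_default_group`, `roundtrip_without_group`) are hypothetical stand-ins, not DiGress or PyTorch code.

```python
import pickle

# Toy stand-in for process-wide state such as torch.distributed's default
# process group. Everything in this block is hypothetical illustration code.
_default_group = None


def init_group(name="single-process"):
    """Mimic initializing the default process group."""
    global _default_group
    _default_group = name


def get_default_group():
    """Return the group, or fail the way the trace above does."""
    if _default_group is None:
        raise RuntimeError("Default process group has not been initialized")
    return _default_group


class Wrapper:
    """Toy analogue of a DDP-style wrapper that rebinds to the group on load."""

    def __init__(self, payload):
        self.payload = payload
        self.group = get_default_group()

    def __getstate__(self):
        # Only the payload is serialized; the group handle is dropped.
        return {"payload": self.payload}

    def __setstate__(self, state):
        self.__dict__.update(state)
        # Unpickling looks up the global group again, so it fails if the
        # loading process never initialized one.
        self.group = get_default_group()


def roundtrip_without_group():
    """Pickle with a group present, then unpickle without and with one."""
    global _default_group
    init_group()
    blob = pickle.dumps(Wrapper(payload=[1, 2, 3]))

    _default_group = None  # simulate a fresh process with no group
    try:
        pickle.loads(blob)
    except RuntimeError as err:
        error = str(err)
    else:
        error = None

    init_group()  # once a group exists, the same bytes load fine
    restored = pickle.loads(blob)
    return error, restored.payload
```

If this reading is right, one possible workaround would be to initialize a single-process group with `torch.distributed.init_process_group` before `load_from_checkpoint` is called. I have not verified this against DiGress, and it seems like something the repository should handle rather than the user.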

So how are we actually supposed to evaluate the model? I think some instructions in the README would be valuable. Thanks for your help!
