I am trying to train NVAE on a medical chest X-ray dataset (IU-Xray). The dataset contains grayscale images (1 channel), not RGB images. During training, I encountered several issues related to channel handling and batch normalization.
Initially, the model assumed RGB input (3 channels), which caused a runtime error when the grayscale images (1 channel) were passed to the first convolution layer. I fixed this by modifying the model stem to use Cin = 1 for the medical dataset.
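For reference, the channel fix can be illustrated with a minimal sketch in plain PyTorch. The helper name make_stem and the output width of 32 are assumptions for illustration only; NVAE's real stem is built by its own code and may use a custom convolution wrapper.

```python
import torch
import torch.nn as nn


def make_stem(in_channels: int, out_channels: int = 32) -> nn.Conv2d:
    """Hypothetical stand-in for the model stem: a single 3x3 convolution."""
    return nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=True)


# A fake batch of 4 grayscale images, 64x64, shaped (N, C, H, W) with C = 1.
gray_batch = torch.randn(4, 1, 64, 64)

rgb_stem = make_stem(in_channels=3)   # what an RGB-only model would build
gray_stem = make_stem(in_channels=1)  # Cin = 1 for the X-ray data

# The RGB stem rejects 1-channel input with a RuntimeError.
try:
    rgb_stem(gray_batch)
    rgb_stem_failed = False
except RuntimeError:
    rgb_stem_failed = True

# The grayscale stem accepts it and keeps the spatial size (padding=1).
stem_out = gray_stem(gray_batch)
```

Only the input side is covered here. Any output layer or likelihood that assumes 3 channels would presumably need the same kind of change.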
After that, I encountered a BatchNorm error ("running_mean should contain 256 elements not 32") when running in distributed/debug mode on a single GPU. This appears to be caused by using distributed training / SyncBatchNorm in a single-GPU setting, which leaves the BatchNorm buffers (32 elements) out of step with the number of feature channels actually reaching the layer (256).
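To find which normalization layer is affected, something like the following sketch may help. It assumes the model's normalization layers subclass PyTorch's internal _BatchNorm base class (true for nn.BatchNorm2d and nn.SyncBatchNorm; a custom layer may not). The function name find_batchnorm_mismatches and the toy model are hypothetical.

```python
import torch
import torch.nn as nn


def find_batchnorm_mismatches(model: nn.Module, sample: torch.Tensor):
    """Return (layer name, num_features, input channels) for mismatched BN layers."""
    mismatches = []
    handles = []

    def make_hook(name):
        def hook(module, inputs):
            channels = inputs[0].shape[1]  # channel dim of an (N, C, ...) tensor
            if channels != module.num_features:
                mismatches.append((name, module.num_features, channels))
        return hook

    # Attach a pre-forward hook to every BatchNorm-like layer.
    for name, module in model.named_modules():
        if isinstance(module, nn.modules.batchnorm._BatchNorm):
            handles.append(module.register_forward_pre_hook(make_hook(name)))

    try:
        with torch.no_grad():
            model(sample)
    except RuntimeError:
        pass  # the forward pass is expected to stop at the first mismatch
    finally:
        for handle in handles:
            handle.remove()
    return mismatches


# Toy reproduction: a conv producing 256 channels feeding a 32-feature BatchNorm.
broken = nn.Sequential(nn.Conv2d(1, 256, 3, padding=1), nn.BatchNorm2d(32))
bn_mismatches = find_batchnorm_mismatches(broken, torch.randn(2, 1, 16, 16))
```

In the toy model, the hook records the layer named "1" with 32 features receiving 256 channels, which is the same pair of numbers as in the error above.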
In summary, the main issues were:
NVAE assumes RGB input by default, but medical X-ray images are grayscale.
Converting X-ray images to RGB works technically but introduces redundant channels and is not ideal for medical modeling.
Distributed mode with SyncBatchNorm can cause channel mismatch errors when training on a single GPU.
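One possible workaround for the last point is to swap synchronized BatchNorm layers for ordinary ones when only one GPU is used. The sketch below assumes the layers are torch.nn.SyncBatchNorm operating on 4D image tensors; NVAE ships its own sync-BN implementation, so this may not apply as-is. revert_sync_batchnorm is a hypothetical helper, not a PyTorch or NVAE function.

```python
import torch
import torch.nn as nn


def revert_sync_batchnorm(module: nn.Module) -> nn.Module:
    """Recursively replace nn.SyncBatchNorm with nn.BatchNorm2d, keeping parameters."""
    converted = module
    if isinstance(module, nn.SyncBatchNorm):
        converted = nn.BatchNorm2d(
            module.num_features,
            eps=module.eps,
            momentum=module.momentum,
            affine=module.affine,
            track_running_stats=module.track_running_stats,
        )
        # Both layer types store the same parameter and buffer names.
        converted.load_state_dict(module.state_dict())
    for name, child in module.named_children():
        converted.add_module(name, revert_sync_batchnorm(child))
    return converted


# Toy single-channel model with a synchronized BatchNorm layer.
model = nn.Sequential(
    nn.Conv2d(1, 32, kernel_size=3, padding=1),
    nn.SyncBatchNorm(32),
    nn.ReLU(),
)
model = revert_sync_batchnorm(model)

# After conversion the model runs on a single device, including CPU.
out = model(torch.randn(2, 1, 16, 16))
```

This only removes the cross-process synchronization. If the channel counts themselves disagree, as in the error above, that still has to be fixed in the model definition.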