Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

train.py中import导入错误have error,from dassl.data.datasets import OfficeHome, VisDA17, Office31 #8

Open
dw1360585641 opened this issue Sep 26, 2024 · 0 comments

Comments

@dw1360585641
Copy link

dw1360585641 commented Sep 26, 2024

我发现你在train.py中的 from dassl.utils import setup_logger, set_random_seed, collect_env_info
from dassl.config import get_cfg_default
from dassl.engine import build_trainer
from dassl.data.datasets import OfficeHome, VisDA17, Office31 最后这一个from dassl.data.datasets import OfficeHome, VisDA17, Office31有问题,你都在本项目中的.datasets中重写了着三个数据集装载,为什么还要导入dassl中的dassl.data.datasets这些,我刚刚花了一晚上时间去排查这个问题,因为dassl包中自带的office31.py并不能直接用,会出现[https://github.com/KaiyangZhou/Dassl.pytorch/issues/52]这个问题:在 “{}” `处找不到文件“'.format(fpath),然后我修改了dassl包中的office31.py:代码如下:```
def _read_data(self, input_domains):
items = []

for domain, dname in enumerate(input_domains):
    domain_dir = os.path.join(self.dataset_dir, dname)
    print("domain_dir:", domain_dir)
    class_names = listdir_nohidden(domain_dir)
    class_names.sort()

    for label, class_name in enumerate(class_names):
        class_path = os.path.join(domain_dir, class_name)
        print("class_path:", class_path)
        imnames = listdir_nohidden(class_path)
        print("imnames:", imnames)

        for imname in imnames:
            # 检查 imname 是否是目录
            impath_dir = os.path.join(class_path, imname)
            if os.path.isdir(impath_dir):
                # 读取子目录下的图像文件
                sub_imnames = listdir_nohidden(impath_dir)
                for sub_imname in sub_imnames:
                    impath = os.path.join(impath_dir, sub_imname)
                    print("impath:", impath)
                    if os.path.isfile(impath):
                        item = Datum(
                            impath=impath,
                            label=label,
                            domain=domain,
                            classname=class_name
                        )
                        print("item:",item)
                        items.append(item)
                    else:
                        print("No such file or directory:", impath)
            else:
                print("Not a directory:", impath_dir)

return items

`这样会出现这个问题:
File "/data/kb/anaconda3/envs/zsh1/lib/python3.7/site-packages/torch/nn/functional.py", line 2247, in _verify_batch_size
raise ValueError("Expected more than 1 value per channel when training, got input size {}".format(size))
ValueError: Expected more than 1 value per channel when training, got input size torch.Size([1, 64])
在这里搜了其他人的解决方法,然后我去找到了dasll中的data_manager中的dataloader,发现代码里本身就是drop_last=True,所以并没有解决问题,
然后我又去修改config中的trainer中的offce31.yaml,调整了train_x,_u,test的输出通道为16,重新运行依然是got input size torch.Size([1, 64])这样不能解决,然后我就放弃了,但是这样也花费了数个小时。
综上,我希望您把train.py中的这行代码修改以下。
当然,也许是我自己的问题,毕竟dassl工具包第一次接触,看不懂他是怎么根据cfg去生成不同的trainer

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant