dggaytan/DCGAN #1321

Closed · wants to merge 1 commit
dcgan/README.md (27 changes: 15 additions & 12 deletions)

@@ -20,31 +20,34 @@ python download.py -c bedroom
 ## Usage
 
 ```
-usage: main.py [-h] --dataset DATASET --dataroot DATAROOT [--workers WORKERS]
-               [--batchSize BATCHSIZE] [--imageSize IMAGESIZE] [--nz NZ]
-               [--ngf NGF] [--ndf NDF] [--niter NITER] [--lr LR]
-               [--beta1 BETA1] [--cuda] [--ngpu NGPU] [--netG NETG]
-               [--netD NETD] [--mps]
+usage: main.py [-h] --dataset DATASET [--dataroot DATAROOT] [--workers WORKERS] [--batchSize BATCHSIZE] [--imageSize IMAGESIZE]
+               [--nz NZ] [--ngf NGF] [--ndf NDF] [--niter NITER] [--lr LR] [--beta1 BETA1] [--cuda] [--xpu] [--dry-run]
+               [--ngpu NGPU] [--netG NETG] [--netD NETD] [--outf OUTF] [--manualSeed MANUALSEED] [--classes CLASSES] [--mps]
 
-optional arguments:
+options:
   -h, --help            show this help message and exit
   --dataset DATASET     cifar10 | lsun | mnist |imagenet | folder | lfw | fake
   --dataroot DATAROOT   path to dataset
   --workers WORKERS     number of data loading workers
-  --batchSize BATCHSIZE input batch size
-  --imageSize IMAGESIZE the height / width of the input image to network
+  --batchSize BATCHSIZE
+                        input batch size
+  --imageSize IMAGESIZE
+                        the height / width of the input image to network
   --nz NZ               size of the latent z vector
-  --ngf NGF             number of filters in the generator
-  --ndf NDF             number of filters in the discriminator
+  --ngf NGF
+  --ndf NDF
   --niter NITER         number of epochs to train for
   --lr LR               learning rate, default=0.0002
   --beta1 BETA1         beta1 for adam. default=0.5
   --cuda                enables cuda
-  --mps                 enables macOS GPU
+  --xpu                 enables XPU training
+  --dry-run             check a single training cycle works
   --ngpu NGPU           number of GPUs to use
   --netG NETG           path to netG (to continue training)
   --netD NETD           path to netD (to continue training)
   --outf OUTF           folder to output images and model checkpoints
-  --manualSeed SEED     manual seed
+  --manualSeed MANUALSEED
+                        manual seed
   --classes CLASSES     comma separated list of classes for the lsun data set
+  --mps                 enables macOS GPU training
 ```
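
For a quick smoke test of the new flag, a run along these lines should complete one training cycle on an XPU device (illustrative command; `--dataset fake` sidesteps the `--dataroot` requirement, per the check in main.py below):

```
python main.py --dataset fake --xpu --dry-run
```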
dcgan/main.py (20 changes: 20 additions & 0 deletions)

@@ -26,6 +26,7 @@
 parser.add_argument('--lr', type=float, default=0.0002, help='learning rate, default=0.0002')
 parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
 parser.add_argument('--cuda', action='store_true', default=False, help='enables cuda')
+parser.add_argument('--xpu', action='store_true', default=False, help='enables XPU training')
 parser.add_argument('--dry-run', action='store_true', help='check a single training cycle works')
 parser.add_argument('--ngpu', type=int, default=1, help='number of GPUs to use')
 parser.add_argument('--netG', default='', help="path to netG (to continue training)")
@@ -57,6 +58,9 @@
 if torch.backends.mps.is_available() and not opt.mps:
     print("WARNING: You have mps device, to enable macOS GPU run with --mps")
 
+if torch.xpu.is_available() and not opt.xpu:
+    print("WARNING: You have XPU device, to enable XPU training run with --xpu")
+
 if opt.dataroot is None and str(opt.dataset).lower() != 'fake':
     raise ValueError("`dataroot` parameter is required for dataset \"%s\"" % opt.dataset)
 
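Before relying on that warning, XPU availability can be probed directly; a minimal check, assuming a PyTorch build that ships the `torch.xpu` backend:

```python
import torch

# Guard with hasattr: older builds may not expose torch.xpu at all.
has_xpu = hasattr(torch, "xpu") and torch.xpu.is_available()
print("XPU available:", has_xpu)
if has_xpu:
    print("XPU device count:", torch.xpu.device_count())
```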
@@ -107,19 +111,29 @@
 dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batchSize,
                                          shuffle=True, num_workers=int(opt.workers))
 use_mps = opt.mps and torch.backends.mps.is_available()
+use_xpu = opt.xpu and torch.xpu.is_available()
+
+# XPU support
 if opt.cuda:
     device = torch.device("cuda:0")
 elif use_mps:
     device = torch.device("mps")
+elif use_xpu:
+    device = torch.device("xpu")
+    print("Number of devices: ", torch.xpu.device_count())
 else:
     device = torch.device("cpu")
 
+
 ngpu = int(opt.ngpu)
 nz = int(opt.nz)
 ngf = int(opt.ngf)
 ndf = int(opt.ndf)
 
 
+print("DEVICE TO USE: ", device)
+print(ngpu)
+
 # custom weights initialization called on netG and netD
 def weights_init(m):
     classname = m.__class__.__name__
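
The chain above falls through CUDA, then MPS, then XPU, and finally CPU. The same priority order as a standalone helper (a hypothetical `pick_device`, not part of this patch):

```python
import torch

def pick_device(cuda=True, mps=True, xpu=True):
    # Priority mirrors main.py: CUDA, then Apple MPS, then Intel XPU, else CPU.
    if cuda and torch.cuda.is_available():
        return torch.device("cuda:0")
    if mps and torch.backends.mps.is_available():
        return torch.device("mps")
    if xpu and hasattr(torch, "xpu") and torch.xpu.is_available():
        return torch.device("xpu")
    return torch.device("cpu")
```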
@@ -160,6 +174,8 @@ def __init__(self, ngpu):
     def forward(self, input):
         if input.is_cuda and self.ngpu > 1:
             output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
+        elif input.is_xpu and self.ngpu > 1:
+            output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
         else:
             output = self.main(input)
         return output
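
Note that a bare second `if` here would let the `else` branch overwrite the CUDA result, and `nn.DataParallel(self.main, input, range(self.ngpu))` would construct a wrapper module (misreading `input` as `device_ids`) rather than run a forward pass; hence the `elif` with the functional `nn.parallel.data_parallel`, mirroring the CUDA path. An alternative is to wrap the model once at construction instead of branching on every forward (a sketch; assumes `DataParallel` supports the target device in the installed PyTorch build):

```python
import torch.nn as nn

# Hypothetical wiring, reusing main.py's Generator, device, and ngpu.
netG = Generator(ngpu).to(device)
if device.type in ("cuda", "xpu") and ngpu > 1:
    netG = nn.DataParallel(netG, device_ids=list(range(ngpu)))
```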
@@ -200,6 +216,8 @@ def __init__(self, ngpu):
     def forward(self, input):
         if input.is_cuda and self.ngpu > 1:
             output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
+        elif input.is_xpu and self.ngpu > 1:
+            output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
         else:
             output = self.main(input)
 
@@ -281,3 +299,5 @@ def forward(self, input):
     # do checkpointing
     torch.save(netG.state_dict(), '%s/netG_epoch_%d.pth' % (opt.outf, epoch))
     torch.save(netD.state_dict(), '%s/netD_epoch_%d.pth' % (opt.outf, epoch))
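
These per-epoch checkpoints pair with the `--netG`/`--netD` flags, so an interrupted run can be resumed; for example (illustrative paths, assuming checkpoints were written to `out/`):

```
python main.py --dataset lsun --dataroot /path/to/lsun --xpu --outf out \
    --netG out/netG_epoch_10.pth --netD out/netD_epoch_10.pth
```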