Use torch.accelerator API in Siamese Network example #1337

Merged: 1 commit on May 2, 2025

30 changes: 30 additions & 0 deletions siamese_network/README.md
@@ -1,7 +1,37 @@
# Siamese Network Example
Siamese network for image similarity estimation.
The network is composed of two identical sub-networks, one for each input.
The outputs of the two sub-networks are concatenated and passed to a linear layer.
The output of the linear layer is passed through a sigmoid function.
[FaceNet](https://arxiv.org/pdf/1503.03832.pdf) is a variant of the Siamese network.
This implementation varies from FaceNet as we use the `ResNet-18` model from
[Deep Residual Learning for Image Recognition](https://arxiv.org/pdf/1512.03385.pdf) as our feature extractor.
In addition, we aren't using `TripletLoss` as the MNIST dataset is simple, so `BCELoss` can do the trick.
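
Below is a minimal, hedged sketch of the architecture described above (a shared `ResNet-18` feature extractor, concatenated features, a linear layer, and a sigmoid trained with `BCELoss`). The class and layer names are illustrative assumptions, not the exact code in `main.py`:

```python
# Illustrative sketch only; see main.py for the actual implementation.
import torch
import torch.nn as nn
import torchvision

class SiameseSketch(nn.Module):
    def __init__(self):
        super().__init__()
        # Shared ResNet-18 backbone used as the feature extractor.
        self.backbone = torchvision.models.resnet18(weights=None)
        # MNIST images are single-channel, so the first conv layer is replaced.
        self.backbone.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        feat_dim = self.backbone.fc.in_features
        self.backbone.fc = nn.Identity()
        # Concatenated features -> linear layer -> sigmoid.
        self.head = nn.Sequential(nn.Linear(feat_dim * 2, 1), nn.Sigmoid())

    def forward(self, x1, x2):
        # The same weights are applied to both inputs.
        f1, f2 = self.backbone(x1), self.backbone(x2)
        return self.head(torch.cat([f1, f2], dim=1)).squeeze(1)

# Example: BCELoss on a dummy batch of image pairs.
model = SiameseSketch()
criterion = nn.BCELoss()
x1, x2 = torch.randn(4, 1, 28, 28), torch.randn(4, 1, 28, 28)
target = torch.randint(0, 2, (4,)).float()
loss = criterion(model(x1, x2), target)
```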

```bash
pip install -r requirements.txt
python main.py
# CUDA_VISIBLE_DEVICES=2 python main.py  # to specify a particular GPU id, e.g. 2
```
Optionally, you can add the following arguments to customize your execution.

```bash
--batch-size input batch size for training (default: 64)
--test-batch-size input batch size for testing (default: 1000)
--epochs number of epochs to train (default: 14)
--lr learning rate (default: 1.0)
--gamma learning rate step gamma (default: 0.7)
--accel use accelerator
--dry-run quickly check a single pass
--seed random seed (default: 1)
--log-interval how many batches to wait before logging training status
--save-model Saving the current Model
```
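
For example, several of these flags can be combined in a single run (the values below are only illustrative):

```bash
python main.py --epochs 5 --batch-size 128 --lr 0.5 --save-model
```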

To execute on a GPU, add the `--accel` argument to the command. For example:

```bash
python main.py --accel
```

This command will execute the example on the detected GPU.
27 changes: 14 additions & 13 deletions siamese_network/main.py
```diff
@@ -105,7 +105,7 @@ def group_examples(self):
         """

         # get the targets from MNIST dataset
-        np_arr = np.array(self.dataset.targets.clone())
+        np_arr = np.array(self.dataset.targets.clone(), dtype=None, copy=None)

         # group examples based on class
         self.grouped_examples = {}
```
```diff
@@ -247,10 +247,8 @@ def main():
                         help='learning rate (default: 1.0)')
     parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
                         help='Learning rate step gamma (default: 0.7)')
-    parser.add_argument('--no-cuda', action='store_true', default=False,
-                        help='disables CUDA training')
-    parser.add_argument('--no-mps', action='store_true', default=False,
-                        help='disables macOS GPU training')
+    parser.add_argument('--accel', action='store_true',
+                        help='use accelerator')
     parser.add_argument('--dry-run', action='store_true', default=False,
                         help='quickly check a single pass')
     parser.add_argument('--seed', type=int, default=1, metavar='S',
```

Contributor review comment (left on the new `--accel` argument):

Similar to #1338 (comment). @eromomon: I did not see this PR before it got merged, but you've changed the default in this example. Previously, acceleration was always used (if available) and there was a flag to disable it. Now the default is to not use acceleration and there is a flag to enable it. I personally like the new behavior better, but we need to change the CI script to reflect that, and that was not done:

```bash
uv run main.py --epochs 1 --dry-run || error "siamese network example failed"
```
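
If the CI is meant to keep exercising the accelerator path under the new default, one plausible update (not taken from this PR) would be to pass the new flag:

```bash
uv run main.py --epochs 1 --dry-run --accel || error "siamese network example failed"
```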
```diff
@@ -260,22 +258,25 @@ def main():
     parser.add_argument('--save-model', action='store_true', default=False,
                         help='For Saving the current Model')
     args = parser.parse_args()
-
-    use_cuda = not args.no_cuda and torch.cuda.is_available()
-    use_mps = not args.no_mps and torch.backends.mps.is_available()

     torch.manual_seed(args.seed)

-    if use_cuda:
-        device = torch.device("cuda")
-    elif use_mps:
-        device = torch.device("mps")
+    if args.accel and not torch.accelerator.is_available():
+        print("ERROR: accelerator is not available, try running on CPU")
+        sys.exit(1)
+    if not args.accel and torch.accelerator.is_available():
+        print("WARNING: accelerator is available, run with --accel to enable it")
+
+    if args.accel:
+        device = torch.accelerator.current_accelerator()
     else:
         device = torch.device("cpu")

+    print(f"Using device: {device}")
+
     train_kwargs = {'batch_size': args.batch_size}
     test_kwargs = {'batch_size': args.test_batch_size}
-    if use_cuda:
+    if device=="cuda":
         cuda_kwargs = {'num_workers': 1,
                        'pin_memory': True,
                        'shuffle': True}
```
2 changes: 1 addition & 1 deletion siamese_network/requirements.txt
```diff
@@ -1,2 +1,2 @@
 torch
-torchvision==0.20.0
+torchvision
```