Skip to content

Commit

Permalink
bf16 from Jonghyun
Browse files Browse the repository at this point in the history
  • Loading branch information
pzhanggit committed Mar 7, 2024
1 parent 75df65e commit f0f38a2
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion hydragnn/train/train_validate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,9 @@ def train(
tr.start("forward")
with record_function("forward"):
# Perform forward pass and backward pass under autocast
with autocast(enabled=use_tensor_cores):
with autocast(enabled=use_tensor_cores, dtype=torch.bfloat16):
#with autocast(enabled=use_tensor_cores, dtype=torch.float16):
#with autocast(enabled=use_tensor_cores, dtype=torch.float32):
data = data.to(get_device())
pred = model(data)
loss, tasks_loss = model.module.loss(pred, data.y, head_index)
Expand Down

0 comments on commit f0f38a2

Please sign in to comment.