Skip to content

Commit

Permalink
device type fix
Browse files Browse the repository at this point in the history
Signed-off-by: Nithin Rao Koluguri <nithinraok>
  • Loading branch information
Nithin Rao Koluguri committed Sep 17, 2024
1 parent 973b729 commit 87e6981
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions nemo/collections/audio/modules/masking.py
Original file line number Diff line number Diff line change
Expand Up @@ -668,6 +668,7 @@ def forward(self, input: torch.Tensor, activity: torch.Tensor) -> torch.Tensor:
"""
B, num_inputs, F, T = input.shape
num_outputs = activity.size(1)
device = input.device.type

if activity.size(0) != B:
raise ValueError(f'Batch dimension mismatch: activity {activity.shape} vs input {input.shape}')
Expand All @@ -678,7 +679,7 @@ def forward(self, input: torch.Tensor, activity: torch.Tensor) -> torch.Tensor:
if num_outputs == 1:
raise ValueError(f'Expecting multiple outputs, got {num_outputs}')

with torch.amp.autocast(self.device.type, enabled=False):
with torch.amp.autocast(device, enabled=False):
input = input.to(dtype=self.dtype)

assert input.is_complex(), f'Expecting complex input, got {input.dtype}'
Expand Down Expand Up @@ -1039,8 +1040,9 @@ def forward(
shape (B, C, F, T).
"""
io_dtype = input.dtype
device = input.device.type

with torch.amp.autocast(self.device.type, enabled=False):
with torch.amp.autocast(device, enabled=False):
output = input.to(dtype=self.dtype)

if not output.is_complex():
Expand Down

0 comments on commit 87e6981

Please sign in to comment.