diff --git a/src/masks/multiblock.py b/src/masks/multiblock.py index c82bf0f6..82d093d6 100644 --- a/src/masks/multiblock.py +++ b/src/masks/multiblock.py @@ -86,8 +86,8 @@ def constrain_mask(mask, tries=0): valid_mask = False while not valid_mask: # -- Sample block top-left corner - top = torch.randint(0, self.height - h, (1,)) - left = torch.randint(0, self.width - w, (1,)) + top = torch.randint(0, self.height - h + 1, (1,)) + left = torch.randint(0, self.width - w + 1, (1,)) mask = torch.zeros((self.height, self.width), dtype=torch.int32) mask[top:top+h, left:left+w] = 1 # -- Constrain mask to a set of acceptable regions