Commit 620a1fc

zou3519 authored and facebook-github-bot committed
OpInfos for: normal, bernoulli, multinomial (pytorch#66358)
Summary: Pull Request resolved: pytorch#66358

Test Plan:
- run tests

Reviewed By: mruberry

Differential Revision: D31551695

Pulled By: zou3519

fbshipit-source-id: cf1b43118a0414a1af9ece9ae8c0598b2701aa0a
1 parent 4829dce commit 620a1fc

File tree

3 files changed: +140 -1 lines changed


test/test_fx.py

Lines changed: 0 additions & 1 deletion
@@ -3161,7 +3161,6 @@ def test_get_torch_func_signature_exhaustive(self, device, dtype, op):
         raise RuntimeError(f'Did not match any schemas for op {op.name}!')
 
 
-
 class TestFXAPIBackwardCompatibility(JitTestCase):
     def setUp(self):
         self.maxDiff = None

test/test_fx_experimental.py

Lines changed: 3 additions & 0 deletions
@@ -1526,6 +1526,9 @@ def test_normalize_operator_exhaustive(self, device, dtype, op):
         'new_empty',
         'new_zeros',
         'new_full',
+        'normal',
+        'multinomial',
+        'bernoulli',
         "__getitem__",
         "__radd__",
         "__rsub__",

torch/testing/_internal/common_methods_invocations.py

Lines changed: 137 additions & 0 deletions
@@ -2484,6 +2484,73 @@ def get_val(dtype):
 
     return tuple(samples)
 
+def sample_inputs_multinomial(self, device, dtype, requires_grad, **kwargs):
+    cases = [
+        ([3], 3, dict()),
+        ([10], 3, dict()),
+        ([3, 10], 3, dict()),
+        ([3], 3, dict(replacement=False)),
+        ([3], 3, dict(replacement=True)),
+        ([3, 4], 4, dict(replacement=True)),
+        ([3, 4], 4, dict(replacement=False)),
+    ]
+
+    samples = []
+    for shape, num_samples, kwargs in cases:
+        t = make_tensor(shape, device, dtype,
+                        low=0, high=None,
+                        requires_grad=requires_grad)
+        samples.append(SampleInput(t, args=(num_samples,), kwargs=kwargs))
+    return tuple(samples)
+
+def sample_inputs_normal_common(self, device, dtype, requires_grad, cases, **kwargs):
+    def get_value_or_make_tensor(value_or_shape):
+        if isinstance(value_or_shape, list):
+            return make_tensor(value_or_shape, device, dtype,
+                               low=0, high=None,
+                               requires_grad=requires_grad)
+        return value_or_shape
+
+    samples = []
+    for value_or_mean_shape, value_or_std_shape, kwargs in cases:
+        mean = get_value_or_make_tensor(value_or_mean_shape)
+        std = get_value_or_make_tensor(value_or_std_shape)
+        samples.append(SampleInput(mean, args=(std,), kwargs=kwargs))
+    return tuple(samples)
+
+def sample_inputs_normal_tensor_first(self, device, dtype, requires_grad, **kwargs):
+    # value_or_size, value_or_size, kwargs
+    cases = [
+        ([], [], {}),
+        ([3], [3], {}),
+        ([3, 4, 2], [3, 4, 2], {}),
+        ([2, 3], 1.1, {}),
+    ]
+
+    return sample_inputs_normal_common(self, device, dtype, requires_grad, cases, **kwargs)
+
+def sample_inputs_normal_tensor_second(self, device, dtype, requires_grad, **kwargs):
+    cases = [
+        ([3, 4], 0.3, {}),
+    ]
+    return sample_inputs_normal_common(self, device, dtype, requires_grad, cases, **kwargs)
+
+def sample_inputs_bernoulli(self, device, dtype, requires_grad, **kwargs):
+    shapes = [
+        [3],
+        [],
+        [0, 3],
+        [2, 3, 4],
+    ]
+
+    samples = []
+    for shape in shapes:
+        t = make_tensor(shape, device, dtype,
+                        low=0, high=1,
+                        requires_grad=requires_grad)
+        samples.append(SampleInput(t))
+    return tuple(samples)
+
 def sample_inputs_logcumsumexp(self, device, dtype, requires_grad):
     inputs = (
         ((S, S, S), 0),
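
For context, a minimal sketch of what one generated multinomial sample exercises, using plain torch calls in place of the make_tensor/SampleInput helpers above:

import torch

# Mirrors the ([3, 4], 4, dict(replacement=True)) case: make_tensor(low=0)
# yields non-negative weights, and multinomial draws num_samples indices per row.
weights = torch.rand(3, 4)
idx = torch.multinomial(weights, 4, replacement=True)
print(idx.shape)  # torch.Size([3, 4]) -- one set of draws per weight row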
@@ -13316,6 +13383,76 @@ def ref_pairwise_distance(input1, input2):
                DecorateInfo(unittest.skip("Skipped!"), 'TestOperatorSignatures', 'test_get_torch_func_signature_exhaustive'),
            ),
            supports_autograd=False),
+    OpInfo('multinomial',
+           op=lambda inp, *args, **kwargs:
+               wrapper_set_seed(torch.multinomial, inp, *args, **kwargs),
+           method_variant=lambda inp, *args, **kwargs:
+               wrapper_set_seed(torch.Tensor.multinomial, inp, *args, **kwargs),
+           dtypes=floating_types(),
+           dtypesIfCUDA=floating_types_and(torch.half),
+           supports_out=True,
+           sample_inputs_func=sample_inputs_multinomial,
+           skips=(
+               # AssertionError: JIT Test does not execute any logic
+               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
+               # UserWarning not triggered : Resized a non-empty tensor but did not warn about it.
+               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out_warning')),
+           supports_autograd=False),
+    OpInfo('normal',
+           op=lambda inp, *args, **kwargs:
+               wrapper_set_seed(torch.normal, inp, *args, **kwargs),
+           # The inplace variant (Tensor.normal_) is different from torch.normal
+           inplace_variant=None,
+           dtypes=floating_types_and(torch.bfloat16, torch.half),
+           dtypesIfCUDA=floating_types_and(torch.bfloat16, torch.half),
+           supports_out=True,
+           sample_inputs_func=sample_inputs_normal_tensor_first,
+           skips=(
+               # AssertionError: JIT Test does not execute any logic
+               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
+               # UserWarning not triggered : Resized a non-empty tensor but did not warn about it.
+               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out_warning'),),
+           supports_autograd=False),
+    OpInfo('normal',
+           # This has its own variant b/c OpInfos assume the first arg is a Tensor but it is not here
+           variant_test_name='number_mean',
+           op=lambda std, mean, *args, **kwargs:
+               wrapper_set_seed(torch.normal, mean, std, *args, **kwargs),
+           # The inplace variant (Tensor.normal_) is different from torch.normal
+           inplace_variant=None,
+           dtypes=floating_types_and(torch.bfloat16, torch.half),
+           dtypesIfCUDA=floating_types_and(torch.bfloat16, torch.half),
+           supports_out=True,
+           sample_inputs_func=sample_inputs_normal_tensor_second,
+           skips=(
+               # AssertionError: JIT Test does not execute any logic
+               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
+               # Seems like a bug:
+               # The size of tensor a (0) must match the size of tensor b (4) at non-singleton dimension 1
+               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out'),
+               # UserWarning not triggered : Resized a non-empty tensor but did not warn about it.
+               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out_warning'),),
+           supports_autograd=False),
+    OpInfo('bernoulli',
+           op=lambda inp, *args, **kwargs:
+               wrapper_set_seed(torch.bernoulli, inp, *args, **kwargs),
+           # The inplace variant (Tensor.bernoulli_) is different from torch.bernoulli
+           inplace_variant=None,
+           method_variant=lambda inp, *args, **kwargs:
+               wrapper_set_seed(torch.Tensor.bernoulli, inp, *args, **kwargs),
+           dtypes=floating_types_and(torch.bfloat16),
+           dtypesIfCUDA=floating_types_and(torch.bfloat16, torch.half),
+           supports_out=True,
+           sample_inputs_func=sample_inputs_bernoulli,
+           skips=(
+               # AssertionError: JIT Test does not execute any logic
+               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
+               # Expected RuntimeError when doing an unsafe cast from a result of
+               # dtype torch.float32 into an out= with dtype torch.long
+               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out'),
+               # UserWarning not triggered : Resized a non-empty tensor but did not warn about it.
+               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out_warning')),
+           supports_autograd=False),
     OpInfo('scatter_add',
            dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
            sample_inputs_func=sample_inputs_scatter_add,
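
Each OpInfo wraps its op in wrapper_set_seed because OpInfo-based tests compare variants (function, method, out=) against each other, and for random ops that comparison only holds if every call starts from the same RNG state. A minimal sketch of the idea (the real helper lives in this same file; the seed value here is illustrative):

import torch

def wrapper_set_seed_sketch(op, *args, **kwargs):
    # Re-seed so repeated calls of different variants produce identical draws.
    torch.manual_seed(42)
    return op(*args, **kwargs)

probs = torch.full((3,), 0.5)
a = wrapper_set_seed_sketch(torch.bernoulli, probs)
b = wrapper_set_seed_sketch(torch.Tensor.bernoulli, probs)
assert torch.equal(a, b)  # function and method variants agree under a fixed seed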
