@@ -2484,6 +2484,73 @@ def get_val(dtype):
2484
2484
2485
2485
return tuple(samples)
2486
2486
2487
def sample_inputs_multinomial(self, device, dtype, requires_grad, **kwargs):
    """Return SampleInputs for torch.multinomial / Tensor.multinomial.

    Each case is (input shape, num_samples, extra kwargs). Covers 1-D and
    2-D probability-weight inputs, sampling both with and without
    replacement.
    """
    # (shape, num_samples, per-case kwargs)
    cases = [
        ([3], 3, dict()),
        ([10], 3, dict()),
        ([3, 10], 3, dict()),
        ([3], 3, dict(replacement=False)),
        ([3], 3, dict(replacement=True)),
        ([3, 4], 4, dict(replacement=True)),
        ([3, 4], 4, dict(replacement=False)),
    ]

    samples = []
    # NB: the per-case dict is bound to `case_kwargs` (not `kwargs`) so it
    # does not shadow this function's own **kwargs parameter.
    for shape, num_samples, case_kwargs in cases:
        # low=0: multinomial requires non-negative weights.
        t = make_tensor(shape, device, dtype,
                        low=0, high=None,
                        requires_grad=requires_grad)
        samples.append(SampleInput(t, args=(num_samples,), kwargs=case_kwargs))
    return tuple(samples)
2505
+
2506
def sample_inputs_normal_common(self, device, dtype, requires_grad, cases, **kwargs):
    """Return SampleInputs for torch.normal built from *cases*.

    Each case is (mean, std, kwargs) where mean/std may be either a list
    (interpreted as a tensor shape to materialize) or a plain scalar that
    is forwarded unchanged.
    """
    def get_value_or_make_tensor(value_or_shape):
        # A list means "make a tensor of this shape"; anything else is a
        # scalar value passed through as-is.
        if isinstance(value_or_shape, list):
            # low=0: std must be non-negative (mean shares the helper).
            return make_tensor(value_or_shape, device, dtype,
                               low=0, high=None,
                               requires_grad=requires_grad)
        return value_or_shape

    samples = []
    # NB: the per-case dict is bound to `case_kwargs` (not `kwargs`) so it
    # does not shadow this function's own **kwargs parameter.
    for value_or_mean_shape, value_or_std_shape, case_kwargs in cases:
        mean = get_value_or_make_tensor(value_or_mean_shape)
        std = get_value_or_make_tensor(value_or_std_shape)
        samples.append(SampleInput(mean, args=(std,), kwargs=case_kwargs))
    return tuple(samples)
2520
+
2521
def sample_inputs_normal_tensor_first(self, device, dtype, requires_grad, **kwargs):
    """Return SampleInputs for torch.normal where the first argument (mean) is a tensor."""
    # Each case: (mean shape-or-scalar, std shape-or-scalar, kwargs).
    mean_std_cases = [
        ([], [], {}),
        ([3], [3], {}),
        ([3, 4, 2], [3, 4, 2], {}),
        ([2, 3], 1.1, {}),
    ]
    return sample_inputs_normal_common(
        self, device, dtype, requires_grad, mean_std_cases, **kwargs)
2531
+
2532
def sample_inputs_normal_tensor_second(self, device, dtype, requires_grad, **kwargs):
    """Return SampleInputs for the torch.normal variant whose second argument is a scalar std."""
    # Single case: tensor mean of shape [3, 4] paired with scalar std 0.3.
    return sample_inputs_normal_common(
        self, device, dtype, requires_grad, [([3, 4], 0.3, {})], **kwargs)
2537
+
2538
def sample_inputs_bernoulli(self, device, dtype, requires_grad, **kwargs):
    """Return SampleInputs for torch.bernoulli: probability tensors of assorted shapes."""
    # Shapes cover 1-D, scalar (0-dim), zero-element, and 3-D inputs.
    shapes = [[3], [], [0, 3], [2, 3, 4]]
    # low=0 / high=1: bernoulli expects probabilities in [0, 1].
    return tuple(
        SampleInput(make_tensor(shape, device, dtype,
                                low=0, high=1,
                                requires_grad=requires_grad))
        for shape in shapes
    )
2553
+
2487
2554
def sample_inputs_logcumsumexp(self, device, dtype, requires_grad):
2488
2555
inputs = (
2489
2556
((S, S, S), 0),
@@ -13316,6 +13383,76 @@ def ref_pairwise_distance(input1, input2):
13316
13383
DecorateInfo(unittest.skip("Skipped!"), 'TestOperatorSignatures', 'test_get_torch_func_signature_exhaustive'),
13317
13384
),
13318
13385
supports_autograd=False),
13386
    # Random-sampling OpInfos. All of these wrap the op in wrapper_set_seed so
    # the RNG seed is fixed per invocation, and set supports_autograd=False
    # since sampling is not differentiable.
    OpInfo('multinomial',
           op=lambda inp, *args, **kwargs:
               wrapper_set_seed(torch.multinomial, inp, *args, **kwargs),
           method_variant=lambda inp, *args, **kwargs:
               wrapper_set_seed(torch.Tensor.multinomial, inp, *args, **kwargs),
           dtypes=floating_types(),
           dtypesIfCUDA=floating_types_and(torch.half),
           supports_out=True,
           sample_inputs_func=sample_inputs_multinomial,
           skips=(
               # AssertionError: JIT Test does not execute any logic
               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
               # UserWarning not triggered : Resized a non-empty tensor but did not warn about it.
               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out_warning')),
           supports_autograd=False),
    OpInfo('normal',
           op=lambda inp, *args, **kwargs:
               wrapper_set_seed(torch.normal, inp, *args, **kwargs),
           # The inplace variant (Tensor.normal_) is different from torch.normal
           inplace_variant=None,
           dtypes=floating_types_and(torch.bfloat16, torch.half),
           dtypesIfCUDA=floating_types_and(torch.bfloat16, torch.half),
           supports_out=True,
           sample_inputs_func=sample_inputs_normal_tensor_first,
           skips=(
               # AssertionError: JIT Test does not execute any logic
               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
               # UserWarning not triggered : Resized a non-empty tensor but did not warn about it.
               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out_warning'),),
           supports_autograd=False),
    OpInfo('normal',
           # This has its own variant b/c OpInfos assume the first arg is a Tensor but it is not here
           variant_test_name='number_mean',
           # NOTE: the lambda takes (std, mean) so the tensor argument (std)
           # comes first for the OpInfo machinery, then swaps them back into
           # torch.normal's (mean, std) order.
           op=lambda std, mean, *args, **kwargs:
               wrapper_set_seed(torch.normal, mean, std, *args, **kwargs),
           # The inplace variant (Tensor.normal_) is different from torch.normal
           inplace_variant=None,
           dtypes=floating_types_and(torch.bfloat16, torch.half),
           dtypesIfCUDA=floating_types_and(torch.bfloat16, torch.half),
           supports_out=True,
           sample_inputs_func=sample_inputs_normal_tensor_second,
           skips=(
               # AssertionError: JIT Test does not execute any logic
               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
               # Seems like a bug:
               # The size of tensor a (0) must match the size of tensor b (4) at non-singleton dimension 1
               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out'),
               # UserWarning not triggered : Resized a non-empty tensor but did not warn about it.
               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out_warning'),),
           supports_autograd=False),
    OpInfo('bernoulli',
           op=lambda inp, *args, **kwargs:
               wrapper_set_seed(torch.bernoulli, inp, *args, **kwargs),
           # The inplace variant (Tensor.bernoulli_) is different from torch.bernoulli
           inplace_variant=None,
           method_variant=lambda inp, *args, **kwargs:
               wrapper_set_seed(torch.Tensor.bernoulli, inp, *args, **kwargs),
           dtypes=floating_types_and(torch.bfloat16),
           dtypesIfCUDA=floating_types_and(torch.bfloat16, torch.half),
           supports_out=True,
           sample_inputs_func=sample_inputs_bernoulli,
           skips=(
               # AssertionError: JIT Test does not execute any logic
               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
               # Expected RuntimeError when doing an unsafe cast from a result of
               # dtype torch.float32 into an out= with dtype torch.long
               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out'),
               # UserWarning not triggered : Resized a non-empty tensor but did not warn about it.
               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out_warning')),
           supports_autograd=False),
13319
13456
OpInfo('scatter_add',
13320
13457
dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
13321
13458
sample_inputs_func=sample_inputs_scatter_add,
0 commit comments