Skip to content

Commit f206180

Browse files
842974287facebook-github-bot
authored andcommitted
add validation for weights tensor in sparse feature (pytorch#680)
Summary: Pull Request resolved: pytorch#680 Validate weights tensor if it's provided. Reviewed By: zyan0 Differential Revision: D40031144 fbshipit-source-id: ce6bc715c3ab8b06999c59d8c1a70415e819cd38
1 parent a3a2fe0 commit f206180

File tree

3 files changed

+20
-7
lines changed

3 files changed

+20
-7
lines changed

torchrec/inference/include/torchrec/inference/Validation.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,11 @@ namespace torchrec {
1616
// Currently validates:
1717
// 1. Whether sum(lengths) == size(values)
1818
// 2. Whether there are negative values in lengths
19-
bool validateSparseFeatures(at::Tensor& values, at::Tensor& lengths);
19+
// 3. If weights is present, whether sum(lengths) == size(weights)
20+
bool validateSparseFeatures(
21+
at::Tensor& values,
22+
at::Tensor& lengths,
23+
c10::optional<at::Tensor> maybeWeights = c10::nullopt);
2024

2125
// Returns whether dense features are valid.
2226
// Currently validates:

torchrec/inference/src/Validation.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,20 @@
1111

1212
namespace torchrec {
1313

14-
bool validateSparseFeatures(at::Tensor& values, at::Tensor& lengths) {
14+
bool validateSparseFeatures(
15+
at::Tensor& values,
16+
at::Tensor& lengths,
17+
c10::optional<at::Tensor> maybeWeights) {
1518
auto flatLengths = lengths.view(-1);
1619

17-
// validate sum of lengths equals number of values
20+
// validate sum of lengths equals number of values/weights
1821
auto lengthsTotal = at::sum(flatLengths).item<int>();
1922
if (lengthsTotal != values.size(0)) {
2023
return false;
2124
}
25+
if (maybeWeights.has_value() && lengthsTotal != maybeWeights->size(0)) {
26+
return false;
27+
}
2228

2329
// Validate no negative values in lengths.
2430
// Use faster path if contiguous.

torchrec/inference/tests/ValidationTest.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,21 +14,24 @@
1414
TEST(ValidationTest, validateSparseFeatures) {
1515
auto values = at::tensor({1, 2, 3, 4});
1616
auto lengths = at::tensor({1, 1, 1, 1});
17+
auto weights = at::tensor({.1, .2, .3, .4});
1718

1819
// pass 1D
19-
EXPECT_TRUE(torchrec::validateSparseFeatures(values, lengths));
20+
EXPECT_TRUE(torchrec::validateSparseFeatures(values, lengths, weights));
2021

2122
// pass 2D
2223
lengths.reshape({2, 2});
23-
EXPECT_TRUE(torchrec::validateSparseFeatures(values, lengths));
24+
EXPECT_TRUE(torchrec::validateSparseFeatures(values, lengths, weights));
2425

2526
// fail 1D
2627
auto invalidLengths = at::tensor({1, 2, 1, 1});
27-
EXPECT_FALSE(torchrec::validateSparseFeatures(values, invalidLengths));
28+
EXPECT_FALSE(
29+
torchrec::validateSparseFeatures(values, invalidLengths, weights));
2830

2931
// fail 2D
3032
invalidLengths.reshape({2, 2});
31-
EXPECT_FALSE(torchrec::validateSparseFeatures(values, invalidLengths));
33+
EXPECT_FALSE(
34+
torchrec::validateSparseFeatures(values, invalidLengths, weights));
3235
}
3336

3437
TEST(ValidationTest, validateDenseFeatures) {

0 commit comments

Comments
 (0)