Skip to content

Commit bca25d9

Browse files
Jeongmin Leefacebook-github-bot
Jeongmin Lee
authored andcommitted
[itemwise-dropout][1/x][low-level module] Implement Itemwise Sparse Feature Dropout in Dper3 (pytorch#59322)
Summary: Pull Request resolved: pytorch#59322 Implement sparse feature dropout (with replacement) that can drop out individual items in each sparse feature. For example, the existing sparse feature dropout with replacement drops out whole feature (e.g., a list of page ids) when the feature is selected for drop out. This itemwise dropout assigns probability and drops out to individual items in sparse features. Test Plan: ``` buck test mode/dev caffe2/torch/fb/sparsenn:test ``` https://www.internalfb.com/intern/testinfra/testrun/281475166777899/ ``` buck test mode/dev //dper3/dper3/modules/tests:sparse_itemwise_dropout_with_replacement_test ``` https://www.internalfb.com/intern/testinfra/testrun/6473924504443423 ``` buck test mode/opt caffe2/caffe2/python:layers_test ``` https://www.internalfb.com/intern/testinfra/testrun/2533274848456607 ``` buck test mode/opt caffe2/caffe2/python/operator_test:sparse_itemwise_dropout_with_replacement_op_test ``` https://www.internalfb.com/intern/testinfra/testrun/8725724318782701 Reviewed By: Wakeupbuddy Differential Revision: D27867213 fbshipit-source-id: 8e173c7b3294abbc8bf8a3b04f723cb170446b96
1 parent 68df4d4 commit bca25d9

5 files changed

+365
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
#include "caffe2/operators/sparse_itemwise_dropout_with_replacement_op.h"
2+
3+
#include <algorithm>
4+
#include <iterator>
5+
6+
namespace caffe2 {
7+
8+
template <>
9+
bool SparseItemwiseDropoutWithReplacementOp<CPUContext>::RunOnDevice() {
10+
auto& X = Input(0);
11+
CAFFE_ENFORCE_EQ(X.ndim(), 1, "Input tensor should be 1-D");
12+
const int64_t* Xdata = X.data<int64_t>();
13+
auto& Lengths = Input(1);
14+
CAFFE_ENFORCE_EQ(Lengths.ndim(), 1, "Lengths tensor should be 1-D");
15+
auto* OutputLengths = Output(1, Lengths.size(), at::dtype<int32_t>());
16+
int32_t const* input_lengths_data = Lengths.template data<int32_t>();
17+
int32_t* output_lengths_data =
18+
OutputLengths->template mutable_data<int32_t>();
19+
// Check that input lengths add up to the length of input data
20+
int total_input_length = 0;
21+
for (int i = 0; i < Lengths.numel(); ++i) {
22+
total_input_length += input_lengths_data[i];
23+
}
24+
CAFFE_ENFORCE_EQ(
25+
total_input_length,
26+
X.numel(),
27+
"Inconsistent input data. Number of elements should match total length.");
28+
29+
at::bernoulli_distribution<double> dist(1. - ratio_);
30+
auto* gen = context_.RandGenerator();
31+
const float _BARNUM = 0.5;
32+
vector<bool> selected(total_input_length, false);
33+
for (int i = 0; i < total_input_length; ++i) {
34+
if (dist(gen) > _BARNUM) {
35+
selected[i] = true;
36+
}
37+
}
38+
39+
for (int i = 0; i < Lengths.numel(); ++i) {
40+
output_lengths_data[i] = input_lengths_data[i];
41+
}
42+
43+
auto* Y = Output(0, {total_input_length}, at::dtype<int64_t>());
44+
int64_t* Ydata = Y->template mutable_data<int64_t>();
45+
46+
for (int i = 0; i < total_input_length; ++i) {
47+
if (selected[i]) {
48+
// Copy logical elements from input to output
49+
Ydata[i] = Xdata[i];
50+
} else {
51+
Ydata[i] = replacement_value_;
52+
}
53+
}
54+
return true;
55+
}
56+
57+
// Registers the CPU kernel and declares the operator's schema: two inputs
// (values, lengths), the same number of outputs, and no gradient.
REGISTER_CPU_OPERATOR(
    SparseItemwiseDropoutWithReplacement,
    SparseItemwiseDropoutWithReplacementOp<CPUContext>);

OPERATOR_SCHEMA(SparseItemwiseDropoutWithReplacement)
    .NumInputs(2)
    .SameNumberOfOutput()
    .SetDoc(R"DOC(

`SparseItemwiseDropoutWithReplacement` takes a 1-d input tensor and a lengths tensor.
Values in the Lengths tensor represent how many input elements constitute each
example in a given batch. Each input value of an example can independently be
replaced with the replacement value with probability given by the `ratio`
argument. The lengths are passed through unchanged.

<details>

<summary> <b>Example</b> </summary>

**Code**

```
workspace.ResetWorkspace()

op = core.CreateOperator(
    "SparseItemwiseDropoutWithReplacement",
    ["X", "Lengths"],
    ["Y", "OutputLengths"],
    ratio=0.5,
    replacement_value=-1
)

workspace.FeedBlob("X", np.array([1, 2, 3, 4, 5]).astype(np.int64))
workspace.FeedBlob("Lengths", np.array([2, 3]).astype(np.int32))
print("X:", workspace.FetchBlob("X"))
print("Lengths:", workspace.FetchBlob("Lengths"))
workspace.RunOperatorOnce(op)
print("Y:", workspace.FetchBlob("Y"))
print("OutputLengths:", workspace.FetchBlob("OutputLengths"))
```

**Result**

```
X: [1, 2, 3, 4, 5]
Lengths: [2, 3]
Y: [1, 2, 3, -1, 5]
OutputLengths: [2, 3]
```

</details>

)DOC")
    .Arg(
        "ratio",
        "*(type: float; default: 0.0)* Probability of an element to be replaced.")
    .Arg(
        "replacement_value",
        "*(type: int64_t; default: 0)* Value elements are replaced with.")
    .Input(0, "X", "*(type: Tensor`<int64_t>`)* Input data tensor.")
    .Input(
        1,
        "Lengths",
        "*(type: Tensor`<int32_t>`)* Lengths tensor for input.")
    .Output(0, "Y", "*(type: Tensor`<int64_t>`)* Output tensor.")
    .Output(
        1,
        "OutputLengths",
        "*(type: Tensor`<int32_t>`)* Lengths tensor for output; identical to the input lengths.");

NO_GRADIENT(SparseItemwiseDropoutWithReplacement);
125+
} // namespace caffe2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
#ifndef CAFFE2_OPERATORS_SPARSE_ITEMWISE_DROPOUT_WITH_REPLACEMENT_OP_H_
2+
#define CAFFE2_OPERATORS_SPARSE_ITEMWISE_DROPOUT_WITH_REPLACEMENT_OP_H_
3+
4+
#include "caffe2/core/context.h"
5+
#include "caffe2/core/logging.h"
6+
#include "caffe2/core/operator.h"
7+
#include "caffe2/utils/math.h"
8+
9+
namespace caffe2 {
10+
11+
template <class Context>
12+
class SparseItemwiseDropoutWithReplacementOp final : public Operator<Context> {
13+
public:
14+
USE_OPERATOR_CONTEXT_FUNCTIONS;
15+
template <class... Args>
16+
explicit SparseItemwiseDropoutWithReplacementOp(Args&&... args)
17+
: Operator<Context>(std::forward<Args>(args)...),
18+
ratio_(this->template GetSingleArgument<float>("ratio", 0.0)),
19+
replacement_value_(
20+
this->template GetSingleArgument<int64_t>("replacement_value", 0)) {
21+
// It is allowed to drop all or drop none.
22+
CAFFE_ENFORCE_GE(ratio_, 0.0, "Ratio should be a valid probability");
23+
CAFFE_ENFORCE_LE(ratio_, 1.0, "Ratio should be a valid probability");
24+
}
25+
26+
bool RunOnDevice() override;
27+
28+
private:
29+
float ratio_;
30+
int64_t replacement_value_;
31+
};
32+
33+
} // namespace caffe2
34+
35+
#endif // CAFFE2_OPERATORS_SPARSE_ITEMWISE_DROPOUT_WITH_REPLACEMENT_OP_H_
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
2+
3+
4+
5+
6+
from caffe2.python import schema
7+
from caffe2.python.layers.layers import (
8+
IdList,
9+
ModelLayer,
10+
)
11+
12+
# Model layer for implementing probabilistic replacement of individual elements in
13+
# IdLists. Takes probabilities for train, eval and predict nets as input, as
14+
# well as the replacement value when dropout happens. For features we may have
15+
# available to us in train net but not in predict net, we'd set dropout
16+
# probability for predict net to be 1.0 and set the feature to the replacement
17+
# value given here. This way, the value is tied to the particular model and not
18+
# to any specific logic in feature processing in serving.
19+
20+
# Consider the following example where X is the values in the IdList and Lengths
21+
# is the number of values corresponding to each example.
22+
# X: [1, 2, 3, 4, 5]
23+
# Lengths: [2, 3]
24+
# This IdList contains 2 IdList features of lengths 2, 3. Let's assume we used a
25+
# ratio of 0.5 and ended up dropping out 2nd item in 2nd IdList feature, and used a
26+
# replacement value of -1. We will end up with the following IdList.
27+
28+
# Y: [1, 2, 3, -1, 5]
29+
# OutputLengths: [2, 3]
30+
# where the 2nd item in 2nd IdList feature [4] was replaced with [-1].
31+
32+
class SparseItemwiseDropoutWithReplacement(ModelLayer):
    """Layer that adds a SparseItemwiseDropoutWithReplacement op on an IdList.

    A separate dropout probability is used for each net type (train, eval,
    predict); the same replacement value is used for all of them.

    Args:
        model: the model this layer is added to.
        input_record: input record; must match the IdList schema.
        dropout_prob_train: ratio passed to the op in the train net.
        dropout_prob_eval: ratio passed to the op in the eval net.
        dropout_prob_predict: ratio passed to the op in the predict net.
        replacement_value: value written in place of dropped items
            (converted with int()).
        name: layer name.
        **kwargs: forwarded to ModelLayer.__init__.

    Raises:
        AssertionError: if the input is not an IdList, if any probability is
            outside [0, 1], or if all three probabilities are 0.
    """

    def __init__(
            self,
            model,
            input_record,
            dropout_prob_train,
            dropout_prob_eval,
            dropout_prob_predict,
            replacement_value,
            name='sparse_itemwise_dropout',
            **kwargs):

        super(SparseItemwiseDropoutWithReplacement, self).__init__(
            model, name, input_record, **kwargs)
        assert schema.equal_schemas(input_record, IdList), "Incorrect input type"

        self.dropout_prob_train = float(dropout_prob_train)
        self.dropout_prob_eval = float(dropout_prob_eval)
        self.dropout_prob_predict = float(dropout_prob_predict)
        self.replacement_value = int(replacement_value)

        # Validate every probability the same way and always report the
        # converted float that is actually stored on the layer.
        for prob_name in (
                'dropout_prob_train',
                'dropout_prob_eval',
                'dropout_prob_predict'):
            prob = getattr(self, prob_name)
            assert 0 <= prob <= 1.0, \
                "Expected 0 <= %s <= 1, but got %s" % (prob_name, prob)

        assert (self.dropout_prob_train > 0 or
                self.dropout_prob_eval > 0 or
                self.dropout_prob_predict > 0), \
            "Ratios all set to 0.0 for train, eval and predict"

        # The output is a fresh IdList that carries over any metadata present
        # on the input's lengths and items.
        self.output_schema = schema.NewRecord(model.net, IdList)
        if input_record.lengths.metadata:
            self.output_schema.lengths.set_metadata(
                input_record.lengths.metadata)
        if input_record.items.metadata:
            self.output_schema.items.set_metadata(
                input_record.items.metadata)

    def _add_ops(self, net, ratio):
        """Add the dropout op to `net` using the given dropout `ratio`."""
        input_values_blob = self.input_record.items()
        input_lengths_blob = self.input_record.lengths()

        output_lengths_blob = self.output_schema.lengths()
        output_values_blob = self.output_schema.items()

        net.SparseItemwiseDropoutWithReplacement(
            [
                input_values_blob,
                input_lengths_blob
            ],
            [
                output_values_blob,
                output_lengths_blob
            ],
            ratio=ratio,
            replacement_value=self.replacement_value
        )

    def add_train_ops(self, net):
        """Add the op with the train-time dropout probability."""
        self._add_ops(net, self.dropout_prob_train)

    def add_eval_ops(self, net):
        """Add the op with the eval-time dropout probability."""
        self._add_ops(net, self.dropout_prob_eval)

    def add_ops(self, net):
        """Add the op with the predict-time dropout probability."""
        self._add_ops(net, self.dropout_prob_predict)

caffe2/python/layers_test.py

+34
Original file line numberDiff line numberDiff line change
@@ -2480,3 +2480,37 @@ def testSparseLookupWithAttentionWeightOnIdScoreList(self):
24802480

24812481
predict_net = self.get_predict_net()
24822482
self.assertNetContainOps(predict_net, [sparse_lookup_op_spec])
2483+
2484+
def testSparseItemwiseDropoutWithReplacement(self):
    """Check the layer in train (ratio 0.0), eval (0.5) and predict (1.0) nets.

    Train must leave the values untouched, predict must replace every value,
    and eval may replace any subset; lengths must be preserved in all cases.
    """
    input_record = schema.NewRecord(self.model.net, IdList)
    self.model.output_schema = schema.Struct()

    lengths_blob = input_record.field_blobs()[0]
    values_blob = input_record.field_blobs()[1]
    lengths = np.array([1] * 10).astype(np.int32)
    values = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]).astype(np.int64)
    workspace.FeedBlob(lengths_blob, lengths)
    workspace.FeedBlob(values_blob, values)

    replacement_value = -1
    out = self.model.SparseItemwiseDropoutWithReplacement(
        input_record, 0.0, 0.5, 1.0, replacement_value,
        output_names_or_num=1)
    self.assertEqual(schema.List(schema.Scalar(np.int64,)), out)

    train_init_net, train_net = self.get_training_nets()
    eval_net = self.get_eval_net()
    predict_net = self.get_predict_net()

    # Train net: ratio 0.0, nothing is replaced.
    workspace.RunNetOnce(train_init_net)
    workspace.RunNetOnce(train_net)
    out_values = workspace.FetchBlob(out.items())
    out_lengths = workspace.FetchBlob(out.lengths())
    self.assertBlobsEqual(out_values, values)
    self.assertBlobsEqual(out_lengths, lengths)

    # Eval net: ratio 0.5, the outcome is random, but every item is either
    # the original value or the replacement value and lengths are unchanged.
    workspace.RunNetOnce(eval_net)
    eval_values = workspace.FetchBlob(out.items())
    eval_lengths = workspace.FetchBlob(out.lengths())
    self.assertTrue(np.all(
        (eval_values == values) | (eval_values == replacement_value)))
    self.assertBlobsEqual(eval_lengths, lengths)

    # Predict net: ratio 1.0, every item is replaced. Outputs are read through
    # the returned record rather than through auto-generated blob names.
    workspace.RunNetOnce(predict_net)
    predict_values = workspace.FetchBlob(out.items())
    predict_lengths = workspace.FetchBlob(out.lengths())
    self.assertBlobsEqual(
        predict_values,
        np.array([replacement_value] * 10).astype(np.int64))
    self.assertBlobsEqual(predict_lengths, lengths)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
2+
3+
4+
5+
6+
from caffe2.python import core
7+
from hypothesis import given
8+
import caffe2.python.hypothesis_test_util as hu
9+
import numpy as np
10+
11+
12+
class SparseItemwiseDropoutWithReplacementTest(hu.HypothesisTestCase):
    """Operator-level tests for SparseItemwiseDropoutWithReplacement on CPU."""

    def _run_op(self, X, Lengths, ratio, replacement_value):
        """Feed X and Lengths, run the op once and return (Y, output lengths)."""
        self.ws.create_blob("X").feed(X)
        self.ws.create_blob("Lengths").feed(Lengths)
        sparse_dropout_op = core.CreateOperator(
            "SparseItemwiseDropoutWithReplacement", ["X", "Lengths"], ["Y", "LY"],
            ratio=ratio, replacement_value=replacement_value)
        self.ws.run(sparse_dropout_op)
        return self.ws.blobs["Y"].fetch(), self.ws.blobs["LY"].fetch()

    @given(**hu.gcs_cpu_only)
    def test_no_dropout(self, gc, dc):
        """ratio=0.0 must leave both values and lengths unchanged."""
        X = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]).astype(np.int64)
        Lengths = np.array([2, 2, 2, 2, 2]).astype(np.int32)
        replacement_value = -1
        Y, OutputLengths = self._run_op(X, Lengths, 0.0, replacement_value)
        self.assertListEqual(X.tolist(), Y.tolist(),
                             "Values should stay unchanged")
        self.assertListEqual(Lengths.tolist(), OutputLengths.tolist(),
                             "Lengths should stay unchanged.")

    @given(**hu.gcs_cpu_only)
    def test_all_dropout(self, gc, dc):
        """ratio=1.0 must replace every value and keep every length."""
        X = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]).astype(np.int64)
        Lengths = np.array([2, 2, 2, 2, 2]).astype(np.int32)
        replacement_value = -1
        y, lengths = self._run_op(X, Lengths, 1.0, replacement_value)
        for elem in y:
            # Adjacent literals are used instead of a backslash continuation
            # so that source indentation never ends up inside the message.
            self.assertEqual(
                elem, replacement_value,
                "Expected all elements to equal the replacement value "
                "when dropout ratio is 1.")
        for length in lengths:
            self.assertEqual(length, 2)
        self.assertEqual(sum(lengths), len(y))

    @given(**hu.gcs_cpu_only)
    def test_all_dropout_empty_input(self, gc, dc):
        """An empty values tensor with a single zero length is handled."""
        X = np.array([]).astype(np.int64)
        Lengths = np.array([0]).astype(np.int32)
        replacement_value = -1
        y, lengths = self._run_op(X, Lengths, 1.0, replacement_value)
        self.assertEqual(len(y), 0, "Expected no dropout value")
        self.assertEqual(
            len(lengths), 1,
            "Expected single element in lengths array")
        self.assertEqual(lengths[0], 0, "Expected 0 as sole length")
        self.assertEqual(sum(lengths), len(y))

0 commit comments

Comments
 (0)