Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DeepFM 1n1d example #360

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 130 additions & 31 deletions RecommenderSystems/deepfm/deepfm_train_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
import oneflow.nn as nn
from petastorm.reader import make_batch_reader

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir)))
sys.path.append(
os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir))
)


def get_args(print_args=True):
Expand All @@ -24,42 +26,68 @@ def str_list(x):

parser.add_argument("--data_dir", type=str, required=True)
parser.add_argument(
"--num_train_samples", type=int, required=True, help="the number of train samples"
"--num_train_samples",
type=int,
required=True,
help="the number of train samples",
)
parser.add_argument(
"--num_val_samples", type=int, required=True, help="the number of validation samples"
"--num_val_samples",
type=int,
required=True,
help="the number of validation samples",
)
parser.add_argument(
"--num_test_samples", type=int, required=True, help="the number of test samples"
)

parser.add_argument("--model_load_dir", type=str, default=None, help="model loading directory")
parser.add_argument("--model_save_dir", type=str, default=None, help="model saving directory")
parser.add_argument(
"--save_initial_model", action="store_true", help="save initial model parameters or not"
"--model_load_dir", type=str, default=None, help="model loading directory"
)
parser.add_argument(
"--model_save_dir", type=str, default=None, help="model saving directory"
)
parser.add_argument(
"--save_initial_model",
action="store_true",
help="save initial model parameters or not",
)
parser.add_argument(
"--save_model_after_each_eval",
action="store_true",
help="save model after each eval or not",
)

parser.add_argument("--embedding_vec_size", type=int, default=16, help="embedding vector size")
parser.add_argument(
"--dnn", type=int_list, default="1000,1000,1000,1000,1000", help="dnn hidden units number"
"--embedding_vec_size", type=int, default=16, help="embedding vector size"
)
parser.add_argument(
"--dnn",
type=int_list,
default="1000,1000,1000,1000,1000",
help="dnn hidden units number",
)
parser.add_argument(
"--net_dropout", type=float, default=0.2, help="net dropout rate"
)
parser.add_argument(
"--disable_fusedmlp", action="store_true", help="disable fused MLP or not"
)
parser.add_argument("--net_dropout", type=float, default=0.2, help="net dropout rate")
parser.add_argument("--disable_fusedmlp", action="store_true", help="disable fused MLP or not")

parser.add_argument("--lr_factor", type=float, default=0.1)
parser.add_argument("--min_lr", type=float, default=1.0e-6)
parser.add_argument("--learning_rate", type=float, default=0.001, help="learning rate")
parser.add_argument(
"--learning_rate", type=float, default=0.001, help="learning rate"
)

parser.add_argument(
"--batch_size", type=int, default=10000, help="training/evaluation batch size"
)
parser.add_argument(
"--train_batches", type=int, default=75000, help="the maximum number of training batches"
"--train_batches",
type=int,
default=75000,
help="the maximum number of training batches",
)
parser.add_argument("--loss_print_interval", type=int, default=100, help="")

Expand All @@ -83,7 +111,10 @@ def str_list(x):
required=True,
)
parser.add_argument(
"--persistent_path", type=str, required=True, help="path for persistent kv store"
"--persistent_path",
type=str,
required=True,
help="path for persistent kv store",
)
parser.add_argument(
"--store_type",
Expand All @@ -99,15 +130,31 @@ def str_list(x):
)

parser.add_argument(
"--amp", action="store_true", help="enable Automatic Mixed Precision(AMP) training or not"
"--amp",
action="store_true",
help="enable Automatic Mixed Precision(AMP) training or not",
)
parser.add_argument(
"--loss_scale_policy", type=str, default="static", help="static or dynamic"
)
parser.add_argument("--loss_scale_policy", type=str, default="static", help="static or dynamic")

parser.add_argument(
"--disable_early_stop", action="store_true", help="enable early stop or not"
)
parser.add_argument("--save_best_model", action="store_true", help="save best model or not")

parser.add_argument(
"--save_best_model", action="store_true", help="save best model or not"
)
parser.add_argument(
"--save_graph_for_serving",
action="store_true",
help="Save Graph and OneEmbedding for serving. ",
)
parser.add_argument(
"--model_serving_path",
type=str,
required=True,
help="Graph object path for model serving",
)
args = parser.parse_args()

if print_args and flow.env.get_rank() == 0:
Expand Down Expand Up @@ -191,7 +238,9 @@ def get_batches(self, reader, batch_size=None):
pos = batch_size - len(tail[0])
tail = list(
[
np.concatenate((tail[i], rglist[i][0 : (batch_size - len(tail[i]))]))
np.concatenate(
(tail[i], rglist[i][0 : (batch_size - len(tail[i]))])
)
for i in range(self.num_fields)
]
)
Expand All @@ -205,7 +254,9 @@ def get_batches(self, reader, batch_size=None):
continue
while (pos + batch_size) <= len(rglist[0]):
label = rglist[0][pos : pos + batch_size]
features = [rglist[j][pos : pos + batch_size] for j in range(1, self.num_fields)]
features = [
rglist[j][pos : pos + batch_size] for j in range(1, self.num_fields)
]
pos += batch_size
yield label, np.stack(features, axis=-1)
if pos != len(rglist[0]):
Expand Down Expand Up @@ -263,7 +314,9 @@ def __init__(

if store_type == "device_mem":
store_options = flow.one_embedding.make_device_mem_store_options(
persistent_path=persistent_path, capacity=vocab_size, size_factor=size_factor,
persistent_path=persistent_path,
capacity=vocab_size,
size_factor=size_factor,
)
elif store_type == "cached_host_mem":
assert cache_memory_budget_mb > 0
Expand Down Expand Up @@ -325,7 +378,9 @@ def __init__(
use_relu = [True] * len(hidden_units) + [not skip_final_activation]
hidden_units = [in_features] + hidden_units + [out_features]
for idx in range(len(hidden_units) - 1):
denses.append(nn.Linear(hidden_units[idx], hidden_units[idx + 1], bias=True))
denses.append(
nn.Linear(hidden_units[idx], hidden_units[idx + 1], bias=True)
)
if use_relu[idx]:
denses.append(nn.ReLU())
if dropout_rates[idx] > 0:
Expand Down Expand Up @@ -425,7 +480,13 @@ def build(self, features):

class DeepFMTrainGraph(flow.nn.Graph):
def __init__(
self, deepfm_module, loss, optimizer, grad_scaler=None, amp=False, lr_scheduler=None,
self,
deepfm_module,
loss,
optimizer,
grad_scaler=None,
amp=False,
lr_scheduler=None,
):
super(DeepFMTrainGraph, self).__init__()
self.module = deepfm_module
Expand All @@ -449,7 +510,9 @@ def make_lr_scheduler(args, optimizer):
batches_per_epoch = math.ceil(args.num_train_samples / args.batch_size)
milestones = [
batches_per_epoch * (i + 1)
for i in range(math.floor(math.log(args.min_lr / args.learning_rate, args.lr_factor)))
for i in range(
math.floor(math.log(args.min_lr / args.learning_rate, args.lr_factor))
)
]
multistep_lr = flow.optim.lr_scheduler.MultiStepLR(
optimizer=optimizer, milestones=milestones, gamma=args.lr_factor,
Expand All @@ -466,7 +529,9 @@ def get_metrics(logs):
return monitor_value


def early_stop(epoch, monitor_value, best_metric, stopping_steps, patience=2, min_delta=1e-6):
def early_stop(
epoch, monitor_value, best_metric, stopping_steps, patience=2, min_delta=1e-6
):
rank = flow.env.get_rank()
stop_training = False
save_best = False
Expand Down Expand Up @@ -525,7 +590,10 @@ def save_model(subdir):
grad_scaler = flow.amp.StaticGradScaler(1024)
else:
grad_scaler = flow.amp.GradScaler(
init_scale=1073741824, growth_factor=2.0, backoff_factor=0.5, growth_interval=2000,
init_scale=1073741824,
growth_factor=2.0,
backoff_factor=0.5,
growth_interval=2000,
)

eval_graph = DeepFMValGraph(deepfm_module, args.amp)
Expand All @@ -541,7 +609,9 @@ def save_model(subdir):
stop_training = False

cached_eval_batches = prefetch_eval_batches(
f"{args.data_dir}/val", args.batch_size, math.ceil(args.num_val_samples / args.batch_size)
f"{args.data_dir}/val",
args.batch_size,
math.ceil(args.num_val_samples / args.batch_size),
)

deepfm_module.train()
Expand Down Expand Up @@ -604,14 +674,24 @@ def save_model(subdir):
print("================ Test Evaluation ================")
eval(args, eval_graph, tag="test", cur_step=step, epoch=epoch)

if args.save_graph_for_serving:
del eval_graph
recompiled_eval_graph = compile_eval_graph(args, deepfm_module, tag="test")
flow.save(recompiled_eval_graph, args.model_serving_path)
flow.save_one_embedding_info(recompiled_eval_graph, args.model_serving_path)


def np_to_global(np):
t = flow.from_numpy(np)
return t.to_global(placement=flow.env.all_device_placement("cpu"), sbp=flow.sbp.split(0))
return t.to_global(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

单卡情况下,设置sbp为broadcast

placement=flow.env.all_device_placement("cpu"), sbp=flow.sbp.broadcast
)


def batch_to_global(np_label, np_features, is_train=True):
labels = np_to_global(np_label.reshape(-1, 1)) if is_train else np_label.reshape(-1, 1)
labels = (
np_to_global(np_label.reshape(-1, 1)) if is_train else np_label.reshape(-1, 1)
)
features = np_to_global(np_features)
return labels, features

Expand Down Expand Up @@ -653,11 +733,15 @@ def eval(args, eval_graph, tag="val", cur_step=0, epoch=0, cached_eval_batches=N
preds.append(pred.to_local())

labels = (
np_to_global(np.concatenate(labels, axis=0)).to_global(sbp=flow.sbp.broadcast()).to_local()
np_to_global(np.concatenate(labels, axis=0))
.to_global(sbp=flow.sbp.broadcast())
.to_local()
)
preds = (
flow.cat(preds, dim=0)
.to_global(placement=flow.env.all_device_placement("cpu"), sbp=flow.sbp.split(0))
.to_global(
placement=flow.env.all_device_placement("cpu"), sbp=flow.sbp.split(0)
)
.to_global(sbp=flow.sbp.broadcast())
.to_local()
)
Expand All @@ -669,7 +753,9 @@ def eval(args, eval_graph, tag="val", cur_step=0, epoch=0, cached_eval_batches=N

metrics_start_time = time.time()
auc = flow.roc_auc_score(labels, preds).numpy()[0]
logloss = flow._C.binary_cross_entropy_loss(preds, labels, weight=None, reduction="mean")
logloss = flow._C.binary_cross_entropy_loss(
preds, labels, weight=None, reduction="mean"
)
metrics_time = time.time() - metrics_start_time

if rank == 0:
Expand All @@ -687,6 +773,19 @@ def eval(args, eval_graph, tag="val", cur_step=0, epoch=0, cached_eval_batches=N
return auc, logloss


def compile_eval_graph(args, deepfm_module, tag="val"):
eval_graph = DeepFMValGraph(deepfm_module, args.amp)
eval_graph.module.eval()
with make_criteo_dataloader(
f"{args.data_dir}/{tag}", args.batch_size, shuffle=False
) as loader:
label, features = batch_to_global(*next(loader), is_train=False)
# Cause we want to infer to GPU, so here set is_train as True to place input Tensor in CUDA Device
features = features.to("cuda")
pred = eval_graph(features)
return eval_graph


if __name__ == "__main__":
os.system(sys.executable + " -m oneflow --doctor")
flow.boxing.nccl.enable_all_to_all(True)
Expand Down
25 changes: 25 additions & 0 deletions RecommenderSystems/deepfm/train_deepfm_1d.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
DATA_DIR=/path/to/deepfm_parquet
PERSISTENT_PATH=/path/to/persistent
MODEL_SAVE_DIR=/path/to/model/save/dir
MODEL_SERVING_PATH=/path/to/model_serving/save/dir

python3 deepfm_train_eval.py \
--data_dir $DATA_DIR \
--persistent_path $PERSISTENT_PATH \
--table_size_array "649,9364,14746,490,476707,11618,4142,1373,7275,13,169,407,1376,1460,583,10131227,2202608,305,24,12517,633,3,93145,5683,8351593,3194,27,14992,5461306,10,5652,2173,4,7046547,18,15,286181,105,142572" \
--store_type 'cached_host_mem' \
--cache_memory_budget_mb 1024 \
--batch_size 10000 \
--train_batches 75000 \
--loss_print_interval 100 \
--dnn "1000,1000,1000,1000,1000" \
--net_dropout 0.2 \
--learning_rate 0.001 \
--embedding_vec_size 16 \
--num_train_samples 36672493 \
--num_val_samples 4584062 \
--num_test_samples 4584062 \
--model_save_dir $MODEL_SAVE_DIR \
--save_best_model \
--save_graph_for_serving \
--model_serving_path $MODEL_SERVING_PATH