From 7429dceae8d201af0e2a441ed01eb1d194df88f5 Mon Sep 17 00:00:00 2001 From: MARD1NO <359521840@qq.com> Date: Tue, 19 Jul 2022 20:31:44 +0800 Subject: [PATCH 1/3] add 1n1d train script --- .../deepfm/deepfm_train_eval.py | 161 ++++++++++++++---- RecommenderSystems/deepfm/train_deepfm_1d.sh | 25 +++ 2 files changed, 155 insertions(+), 31 deletions(-) create mode 100644 RecommenderSystems/deepfm/train_deepfm_1d.sh diff --git a/RecommenderSystems/deepfm/deepfm_train_eval.py b/RecommenderSystems/deepfm/deepfm_train_eval.py index 5c396accb..ad4f9899c 100644 --- a/RecommenderSystems/deepfm/deepfm_train_eval.py +++ b/RecommenderSystems/deepfm/deepfm_train_eval.py @@ -10,7 +10,9 @@ import oneflow.nn as nn from petastorm.reader import make_batch_reader -sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir))) +sys.path.append( + os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir)) +) def get_args(print_args=True): @@ -24,19 +26,31 @@ def str_list(x): parser.add_argument("--data_dir", type=str, required=True) parser.add_argument( - "--num_train_samples", type=int, required=True, help="the number of train samples" + "--num_train_samples", + type=int, + required=True, + help="the number of train samples", ) parser.add_argument( - "--num_val_samples", type=int, required=True, help="the number of validation samples" + "--num_val_samples", + type=int, + required=True, + help="the number of validation samples", ) parser.add_argument( "--num_test_samples", type=int, required=True, help="the number of test samples" ) - parser.add_argument("--model_load_dir", type=str, default=None, help="model loading directory") - parser.add_argument("--model_save_dir", type=str, default=None, help="model saving directory") parser.add_argument( - "--save_initial_model", action="store_true", help="save initial model parameters or not" + "--model_load_dir", type=str, default=None, help="model loading directory" + ) + parser.add_argument( + "--model_save_dir", type=str, default=None, help="model saving directory" + ) + parser.add_argument( + "--save_initial_model", + action="store_true", + help="save initial model parameters or not", ) parser.add_argument( "--save_model_after_each_eval", @@ -44,22 +58,36 @@ def str_list(x): help="save model after each eval or not", ) - parser.add_argument("--embedding_vec_size", type=int, default=16, help="embedding vector size") parser.add_argument( - "--dnn", type=int_list, default="1000,1000,1000,1000,1000", help="dnn hidden units number" + "--embedding_vec_size", type=int, default=16, help="embedding vector size" + ) + parser.add_argument( + "--dnn", + type=int_list, + default="1000,1000,1000,1000,1000", + help="dnn hidden units number", + ) + parser.add_argument( + "--net_dropout", type=float, default=0.2, help="net dropout rate" + ) + parser.add_argument( + "--disable_fusedmlp", action="store_true", help="disable fused MLP or not" ) - parser.add_argument("--net_dropout", type=float, default=0.2, help="net dropout rate") - parser.add_argument("--disable_fusedmlp", action="store_true", help="disable fused MLP or not") parser.add_argument("--lr_factor", type=float, default=0.1) parser.add_argument("--min_lr", type=float, default=1.0e-6) - parser.add_argument("--learning_rate", type=float, default=0.001, help="learning rate") + parser.add_argument( + "--learning_rate", type=float, default=0.001, help="learning rate" + ) parser.add_argument( "--batch_size", type=int, default=10000, help="training/evaluation batch size" ) parser.add_argument( - "--train_batches", type=int, default=75000, help="the maximum number of training batches" + "--train_batches", + type=int, + default=75000, + help="the maximum number of training batches", ) parser.add_argument("--loss_print_interval", type=int, default=100, help="") @@ -83,7 +111,10 @@ def str_list(x): required=True, ) parser.add_argument( - "--persistent_path", type=str, required=True, help="path for persistent kv store" + "--persistent_path", + type=str, + required=True, + help="path for persistent kv store", ) parser.add_argument( "--store_type", @@ -99,15 +130,31 @@ def str_list(x): ) parser.add_argument( - "--amp", action="store_true", help="enable Automatic Mixed Precision(AMP) training or not" + "--amp", + action="store_true", + help="enable Automatic Mixed Precision(AMP) training or not", + ) + parser.add_argument( + "--loss_scale_policy", type=str, default="static", help="static or dynamic" ) - parser.add_argument("--loss_scale_policy", type=str, default="static", help="static or dynamic") parser.add_argument( "--disable_early_stop", action="store_true", help="enable early stop or not" ) - parser.add_argument("--save_best_model", action="store_true", help="save best model or not") - + parser.add_argument( + "--save_best_model", action="store_true", help="save best model or not" + ) + parser.add_argument( + "--save_graph_for_serving", + action="store_true", + help="Save Graph and OneEmbedding for serving. ", + ) + parser.add_argument( + "--model_serving_path", + type=str, + required=True, + help="Graph object path for model serving", + ) args = parser.parse_args() if print_args and flow.env.get_rank() == 0: @@ -191,7 +238,9 @@ def get_batches(self, reader, batch_size=None): pos = batch_size - len(tail[0]) tail = list( [ - np.concatenate((tail[i], rglist[i][0 : (batch_size - len(tail[i]))])) + np.concatenate( + (tail[i], rglist[i][0 : (batch_size - len(tail[i]))]) + ) for i in range(self.num_fields) ] ) @@ -205,7 +254,9 @@ def get_batches(self, reader, batch_size=None): continue while (pos + batch_size) <= len(rglist[0]): label = rglist[0][pos : pos + batch_size] - features = [rglist[j][pos : pos + batch_size] for j in range(1, self.num_fields)] + features = [ + rglist[j][pos : pos + batch_size] for j in range(1, self.num_fields) + ] pos += batch_size yield label, np.stack(features, axis=-1) if pos != len(rglist[0]): @@ -263,7 +314,9 @@ def __init__( if store_type == "device_mem": store_options = flow.one_embedding.make_device_mem_store_options( - persistent_path=persistent_path, capacity=vocab_size, size_factor=size_factor, + persistent_path=persistent_path, + capacity=vocab_size, + size_factor=size_factor, ) elif store_type == "cached_host_mem": assert cache_memory_budget_mb > 0 @@ -325,7 +378,9 @@ def __init__( use_relu = [True] * len(hidden_units) + [not skip_final_activation] hidden_units = [in_features] + hidden_units + [out_features] for idx in range(len(hidden_units) - 1): - denses.append(nn.Linear(hidden_units[idx], hidden_units[idx + 1], bias=True)) + denses.append( + nn.Linear(hidden_units[idx], hidden_units[idx + 1], bias=True) + ) if use_relu[idx]: denses.append(nn.ReLU()) if dropout_rates[idx] > 0: @@ -425,7 +480,13 @@ def build(self, features): class DeepFMTrainGraph(flow.nn.Graph): def __init__( - self, deepfm_module, loss, optimizer, grad_scaler=None, amp=False, lr_scheduler=None, + self, + deepfm_module, + loss, + optimizer, + grad_scaler=None, + amp=False, + lr_scheduler=None, ): super(DeepFMTrainGraph, self).__init__() self.module = deepfm_module @@ -449,7 +510,9 @@ def make_lr_scheduler(args, optimizer): batches_per_epoch = math.ceil(args.num_train_samples / args.batch_size) milestones = [ batches_per_epoch * (i + 1) - for i in range(math.floor(math.log(args.min_lr / args.learning_rate, args.lr_factor))) + for i in range( + math.floor(math.log(args.min_lr / args.learning_rate, args.lr_factor)) + ) ] multistep_lr = flow.optim.lr_scheduler.MultiStepLR( optimizer=optimizer, milestones=milestones, gamma=args.lr_factor, @@ -466,7 +529,9 @@ def get_metrics(logs): return monitor_value -def early_stop(epoch, monitor_value, best_metric, stopping_steps, patience=2, min_delta=1e-6): +def early_stop( + epoch, monitor_value, best_metric, stopping_steps, patience=2, min_delta=1e-6 +): rank = flow.env.get_rank() stop_training = False save_best = False @@ -525,7 +590,10 @@ def save_model(subdir): grad_scaler = flow.amp.StaticGradScaler(1024) else: grad_scaler = flow.amp.GradScaler( - init_scale=1073741824, growth_factor=2.0, backoff_factor=0.5, growth_interval=2000, + init_scale=1073741824, + growth_factor=2.0, + backoff_factor=0.5, + growth_interval=2000, ) eval_graph = DeepFMValGraph(deepfm_module, args.amp) @@ -541,7 +609,9 @@ def save_model(subdir): stop_training = False cached_eval_batches = prefetch_eval_batches( - f"{args.data_dir}/val", args.batch_size, math.ceil(args.num_val_samples / args.batch_size) + f"{args.data_dir}/val", + args.batch_size, + math.ceil(args.num_val_samples / args.batch_size), ) deepfm_module.train() @@ -604,14 +674,24 @@ def save_model(subdir): print("================ Test Evaluation ================") eval(args, eval_graph, tag="test", cur_step=step, epoch=epoch) + if args.save_graph_for_serving: + del eval_graph + recompiled_eval_graph = compile_eval_graph(args, deepfm_module, tag="test") + flow.save(recompiled_eval_graph, args.model_serving_path) + flow.save_one_embedding_info(recompiled_eval_graph, args.model_serving_path) + def np_to_global(np): t = flow.from_numpy(np) - return t.to_global(placement=flow.env.all_device_placement("cpu"), sbp=flow.sbp.split(0)) + return t.to_global( + placement=flow.env.all_device_placement("cpu"), sbp=flow.sbp.broadcast + ) def batch_to_global(np_label, np_features, is_train=True): - labels = np_to_global(np_label.reshape(-1, 1)) if is_train else np_label.reshape(-1, 1) + labels = ( + np_to_global(np_label.reshape(-1, 1)) if is_train else np_label.reshape(-1, 1) + ) features = np_to_global(np_features) return labels, features @@ -653,11 +733,15 @@ def eval(args, eval_graph, tag="val", cur_step=0, epoch=0, cached_eval_batches=N preds.append(pred.to_local()) labels = ( - np_to_global(np.concatenate(labels, axis=0)).to_global(sbp=flow.sbp.broadcast()).to_local() + np_to_global(np.concatenate(labels, axis=0)) + .to_global(sbp=flow.sbp.broadcast()) + .to_local() ) preds = ( flow.cat(preds, dim=0) - .to_global(placement=flow.env.all_device_placement("cpu"), sbp=flow.sbp.split(0)) + .to_global( + placement=flow.env.all_device_placement("cpu"), sbp=flow.sbp.split(0) + ) .to_global(sbp=flow.sbp.broadcast()) .to_local() ) @@ -669,7 +753,9 @@ def eval(args, eval_graph, tag="val", cur_step=0, epoch=0, cached_eval_batches=N metrics_start_time = time.time() auc = flow.roc_auc_score(labels, preds).numpy()[0] - logloss = flow._C.binary_cross_entropy_loss(preds, labels, weight=None, reduction="mean") + logloss = flow._C.binary_cross_entropy_loss( + preds, labels, weight=None, reduction="mean" + ) metrics_time = time.time() - metrics_start_time if rank == 0: @@ -687,6 +773,19 @@ def eval(args, eval_graph, tag="val", cur_step=0, epoch=0, cached_eval_batches=N return auc, logloss +def compile_eval_graph(args, deepfm_module, tag="val"): + eval_graph = DeepFMValGraph(deepfm_module, args.amp) + eval_graph.module.eval() + with make_criteo_dataloader( + f"{args.data_dir}/{tag}", args.batch_size, shuffle=False + ) as loader: + label, features = batch_to_global(*next(loader), is_train=False) + # Cause we want to infer to GPU, so here set is_train as True to place input Tensor in CUDA Device + features = features.to("cuda") + pred = eval_graph(features) + return eval_graph + + if __name__ == "__main__": os.system(sys.executable + " -m oneflow --doctor") flow.boxing.nccl.enable_all_to_all(True) diff --git a/RecommenderSystems/deepfm/train_deepfm_1d.sh b/RecommenderSystems/deepfm/train_deepfm_1d.sh new file mode 100644 index 000000000..cfa9f5b14 --- /dev/null +++ b/RecommenderSystems/deepfm/train_deepfm_1d.sh @@ -0,0 +1,25 @@ +DATA_DIR=/path/to/deepfm_parquet +PERSISTENT_PATH=/path/to/persistent +MODEL_SAVE_DIR=/path/to/model/save/dir +MODEL_SERVING_PATH=/path/to/model_serving/save/dir + +python3 deepfm_train_eval.py \ + --data_dir $DATA_DIR \ + --persistent_path $PERSISTENT_PATH \ + --table_size_array "649,9364,14746,490,476707,11618,4142,1373,7275,13,169,407,1376,1460,583,10131227,2202608,305,24,12517,633,3,93145,5683,8351593,3194,27,14992,5461306,10,5652,2173,4,7046547,18,15,286181,105,142572" \ + --store_type 'cached_host_mem' \ + --cache_memory_budget_mb 1024 \ + --batch_size 10000 \ + --train_batches 75000 \ + --loss_print_interval 100 \ + --dnn "1000,1000,1000,1000,1000" \ + --net_dropout 0.2 \ + --learning_rate 0.001 \ + --embedding_vec_size 16 \ + --num_train_samples 36672493 \ + --num_val_samples 4584062 \ + --num_test_samples 4584062 \ + --model_save_dir $MODEL_SAVE_DIR \ + --save_best_model \ + --save_graph_for_serving \ + --model_serving_path $MODEL_SERVING_PATH From 00395e5ea4c25c5183edb807f3d4663a10e0e2a1 Mon Sep 17 00:00:00 2001 From: MARD1NO <359521840@qq.com> Date: Wed, 20 Jul 2022 09:28:46 +0800 Subject: [PATCH 2/3] format each line size as 100 --- .../deepfm/deepfm_train_eval.py | 139 ++++-------------- 1 file changed, 32 insertions(+), 107 deletions(-) diff --git a/RecommenderSystems/deepfm/deepfm_train_eval.py b/RecommenderSystems/deepfm/deepfm_train_eval.py index ad4f9899c..505b0783d 100644 --- a/RecommenderSystems/deepfm/deepfm_train_eval.py +++ b/RecommenderSystems/deepfm/deepfm_train_eval.py @@ -10,9 +10,7 @@ import oneflow.nn as nn from petastorm.reader import make_batch_reader -sys.path.append( - os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir)) -) +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir))) def get_args(print_args=True): @@ -26,31 +24,19 @@ def str_list(x): parser.add_argument("--data_dir", type=str, required=True) parser.add_argument( - "--num_train_samples", - type=int, - required=True, - help="the number of train samples", + "--num_train_samples", type=int, required=True, help="the number of train samples", ) parser.add_argument( - "--num_val_samples", - type=int, - required=True, - help="the number of validation samples", + "--num_val_samples", type=int, required=True, help="the number of validation samples", ) parser.add_argument( "--num_test_samples", type=int, required=True, help="the number of test samples" ) + parser.add_argument("--model_load_dir", type=str, default=None, help="model loading directory") + parser.add_argument("--model_save_dir", type=str, default=None, help="model saving directory") parser.add_argument( - "--model_load_dir", type=str, default=None, help="model loading directory" - ) - parser.add_argument( - "--model_save_dir", type=str, default=None, help="model saving directory" - ) - parser.add_argument( - "--save_initial_model", - action="store_true", - help="save initial model parameters or not", + "--save_initial_model", action="store_true", help="save initial model parameters or not", ) parser.add_argument( "--save_model_after_each_eval", @@ -58,36 +44,22 @@ def str_list(x): help="save model after each eval or not", ) + parser.add_argument("--embedding_vec_size", type=int, default=16, help="embedding vector size") parser.add_argument( - "--embedding_vec_size", type=int, default=16, help="embedding vector size" - ) - parser.add_argument( - "--dnn", - type=int_list, - default="1000,1000,1000,1000,1000", - help="dnn hidden units number", - ) - parser.add_argument( - "--net_dropout", type=float, default=0.2, help="net dropout rate" - ) - parser.add_argument( - "--disable_fusedmlp", action="store_true", help="disable fused MLP or not" + "--dnn", type=int_list, default="1000,1000,1000,1000,1000", help="dnn hidden units number", ) + parser.add_argument("--net_dropout", type=float, default=0.2, help="net dropout rate") + parser.add_argument("--disable_fusedmlp", action="store_true", help="disable fused MLP or not") parser.add_argument("--lr_factor", type=float, default=0.1) parser.add_argument("--min_lr", type=float, default=1.0e-6) - parser.add_argument( - "--learning_rate", type=float, default=0.001, help="learning rate" - ) + parser.add_argument("--learning_rate", type=float, default=0.001, help="learning rate") parser.add_argument( "--batch_size", type=int, default=10000, help="training/evaluation batch size" ) parser.add_argument( - "--train_batches", - type=int, - default=75000, - help="the maximum number of training batches", + "--train_batches", type=int, default=75000, help="the maximum number of training batches", ) parser.add_argument("--loss_print_interval", type=int, default=100, help="") @@ -111,10 +83,7 @@ def str_list(x): required=True, ) parser.add_argument( - "--persistent_path", - type=str, - required=True, - help="path for persistent kv store", + "--persistent_path", type=str, required=True, help="path for persistent kv store", ) parser.add_argument( "--store_type", @@ -130,30 +99,21 @@ def str_list(x): ) parser.add_argument( - "--amp", - action="store_true", - help="enable Automatic Mixed Precision(AMP) training or not", - ) - parser.add_argument( - "--loss_scale_policy", type=str, default="static", help="static or dynamic" + "--amp", action="store_true", help="enable Automatic Mixed Precision(AMP) training or not", ) + parser.add_argument("--loss_scale_policy", type=str, default="static", help="static or dynamic") parser.add_argument( "--disable_early_stop", action="store_true", help="enable early stop or not" ) - parser.add_argument( - "--save_best_model", action="store_true", help="save best model or not" - ) + parser.add_argument("--save_best_model", action="store_true", help="save best model or not") parser.add_argument( "--save_graph_for_serving", action="store_true", help="Save Graph and OneEmbedding for serving. ", ) parser.add_argument( - "--model_serving_path", - type=str, - required=True, - help="Graph object path for model serving", + "--model_serving_path", type=str, required=True, help="Graph object path for model serving", ) args = parser.parse_args() @@ -238,9 +198,7 @@ def get_batches(self, reader, batch_size=None): pos = batch_size - len(tail[0]) tail = list( [ - np.concatenate( - (tail[i], rglist[i][0 : (batch_size - len(tail[i]))]) - ) + np.concatenate((tail[i], rglist[i][0 : (batch_size - len(tail[i]))])) for i in range(self.num_fields) ] ) @@ -254,9 +212,7 @@ def get_batches(self, reader, batch_size=None): continue while (pos + batch_size) <= len(rglist[0]): label = rglist[0][pos : pos + batch_size] - features = [ - rglist[j][pos : pos + batch_size] for j in range(1, self.num_fields) - ] + features = [rglist[j][pos : pos + batch_size] for j in range(1, self.num_fields)] pos += batch_size yield label, np.stack(features, axis=-1) if pos != len(rglist[0]): @@ -314,9 +270,7 @@ def __init__( if store_type == "device_mem": store_options = flow.one_embedding.make_device_mem_store_options( - persistent_path=persistent_path, - capacity=vocab_size, - size_factor=size_factor, + persistent_path=persistent_path, capacity=vocab_size, size_factor=size_factor, ) elif store_type == "cached_host_mem": assert cache_memory_budget_mb > 0 @@ -378,9 +332,7 @@ def __init__( use_relu = [True] * len(hidden_units) + [not skip_final_activation] hidden_units = [in_features] + hidden_units + [out_features] for idx in range(len(hidden_units) - 1): - denses.append( - nn.Linear(hidden_units[idx], hidden_units[idx + 1], bias=True) - ) + denses.append(nn.Linear(hidden_units[idx], hidden_units[idx + 1], bias=True)) if use_relu[idx]: denses.append(nn.ReLU()) if dropout_rates[idx] > 0: @@ -480,13 +432,7 @@ def build(self, features): class DeepFMTrainGraph(flow.nn.Graph): def __init__( - self, - deepfm_module, - loss, - optimizer, - grad_scaler=None, - amp=False, - lr_scheduler=None, + self, deepfm_module, loss, optimizer, grad_scaler=None, amp=False, lr_scheduler=None, ): super(DeepFMTrainGraph, self).__init__() self.module = deepfm_module @@ -510,9 +456,7 @@ def make_lr_scheduler(args, optimizer): batches_per_epoch = math.ceil(args.num_train_samples / args.batch_size) milestones = [ batches_per_epoch * (i + 1) - for i in range( - math.floor(math.log(args.min_lr / args.learning_rate, args.lr_factor)) - ) + for i in range(math.floor(math.log(args.min_lr / args.learning_rate, args.lr_factor))) ] multistep_lr = flow.optim.lr_scheduler.MultiStepLR( optimizer=optimizer, milestones=milestones, gamma=args.lr_factor, @@ -529,9 +473,7 @@ def get_metrics(logs): return monitor_value -def early_stop( - epoch, monitor_value, best_metric, stopping_steps, patience=2, min_delta=1e-6 -): +def early_stop(epoch, monitor_value, best_metric, stopping_steps, patience=2, min_delta=1e-6): rank = flow.env.get_rank() stop_training = False save_best = False @@ -590,10 +532,7 @@ def save_model(subdir): grad_scaler = flow.amp.StaticGradScaler(1024) else: grad_scaler = flow.amp.GradScaler( - init_scale=1073741824, - growth_factor=2.0, - backoff_factor=0.5, - growth_interval=2000, + init_scale=1073741824, growth_factor=2.0, backoff_factor=0.5, growth_interval=2000, ) eval_graph = DeepFMValGraph(deepfm_module, args.amp) @@ -609,9 +548,7 @@ def save_model(subdir): stop_training = False cached_eval_batches = prefetch_eval_batches( - f"{args.data_dir}/val", - args.batch_size, - math.ceil(args.num_val_samples / args.batch_size), + f"{args.data_dir}/val", args.batch_size, math.ceil(args.num_val_samples / args.batch_size), ) deepfm_module.train() @@ -683,15 +620,11 @@ def save_model(subdir): def np_to_global(np): t = flow.from_numpy(np) - return t.to_global( - placement=flow.env.all_device_placement("cpu"), sbp=flow.sbp.broadcast - ) + return t.to_global(placement=flow.env.all_device_placement("cpu"), sbp=flow.sbp.broadcast) def batch_to_global(np_label, np_features, is_train=True): - labels = ( - np_to_global(np_label.reshape(-1, 1)) if is_train else np_label.reshape(-1, 1) - ) + labels = np_to_global(np_label.reshape(-1, 1)) if is_train else np_label.reshape(-1, 1) features = np_to_global(np_features) return labels, features @@ -733,15 +666,11 @@ def eval(args, eval_graph, tag="val", cur_step=0, epoch=0, cached_eval_batches=N preds.append(pred.to_local()) labels = ( - np_to_global(np.concatenate(labels, axis=0)) - .to_global(sbp=flow.sbp.broadcast()) - .to_local() + np_to_global(np.concatenate(labels, axis=0)).to_global(sbp=flow.sbp.broadcast()).to_local() ) preds = ( flow.cat(preds, dim=0) - .to_global( - placement=flow.env.all_device_placement("cpu"), sbp=flow.sbp.split(0) - ) + .to_global(placement=flow.env.all_device_placement("cpu"), sbp=flow.sbp.split(0)) .to_global(sbp=flow.sbp.broadcast()) .to_local() ) @@ -753,9 +682,7 @@ def eval(args, eval_graph, tag="val", cur_step=0, epoch=0, cached_eval_batches=N metrics_start_time = time.time() auc = flow.roc_auc_score(labels, preds).numpy()[0] - logloss = flow._C.binary_cross_entropy_loss( - preds, labels, weight=None, reduction="mean" - ) + logloss = flow._C.binary_cross_entropy_loss(preds, labels, weight=None, reduction="mean") metrics_time = time.time() - metrics_start_time if rank == 0: @@ -776,9 +703,7 @@ def eval(args, eval_graph, tag="val", cur_step=0, epoch=0, cached_eval_batches=N def compile_eval_graph(args, deepfm_module, tag="val"): eval_graph = DeepFMValGraph(deepfm_module, args.amp) eval_graph.module.eval() - with make_criteo_dataloader( - f"{args.data_dir}/{tag}", args.batch_size, shuffle=False - ) as loader: + with make_criteo_dataloader(f"{args.data_dir}/{tag}", args.batch_size, shuffle=False) as loader: label, features = batch_to_global(*next(loader), is_train=False) # Cause we want to infer to GPU, so here set is_train as True to place input Tensor in CUDA Device features = features.to("cuda") From f26b4e47c465b0e71d85f0a385e482551f8e6dac Mon Sep 17 00:00:00 2001 From: MARD1NO <359521840@qq.com> Date: Thu, 21 Jul 2022 18:17:11 +0800 Subject: [PATCH 3/3] change to use eval state dict to save --- RecommenderSystems/deepfm/deepfm_train_eval.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/RecommenderSystems/deepfm/deepfm_train_eval.py b/RecommenderSystems/deepfm/deepfm_train_eval.py index 505b0783d..06d5fa253 100644 --- a/RecommenderSystems/deepfm/deepfm_train_eval.py +++ b/RecommenderSystems/deepfm/deepfm_train_eval.py @@ -614,8 +614,9 @@ def save_model(subdir): if args.save_graph_for_serving: del eval_graph recompiled_eval_graph = compile_eval_graph(args, deepfm_module, tag="test") + eval_state_dict = recompiled_eval_graph.state_dict() flow.save(recompiled_eval_graph, args.model_serving_path) - flow.save_one_embedding_info(recompiled_eval_graph, args.model_serving_path) + flow.save_one_embedding_info(eval_state_dict, args.model_serving_path) def np_to_global(np):