def main():
    """Build the training graph from the TFRecord input pipeline and run it.

    Reads module-level settings defined elsewhere in the file:
    ``tfrecord_path`` (input data location), ``bs`` (batch size),
    ``sbs`` (shuffle buffer size) and ``num_epochs`` (loop count, see
    the note on the loop below).

    The graph is constructed once. A single session initializes the
    global variables and then repeatedly runs the model's training op.
    """
    # ``train_input_fn(...)`` returns a callable; invoking it builds the
    # input tensors for the training pipeline.
    features, labels = input_fn.train_input_fn(
        tfrecord_path, batch_size=bs, shuffle_buffer_size=sbs
    )()

    # Build the model in TRAIN mode. The returned object is presumably a
    # ``tf.estimator.EstimatorSpec`` (it exposes ``train_op``) -- confirm
    # against ``trainer.model_fn``.
    model = trainer.model_fn(features, labels, tf.estimator.ModeKeys.TRAIN)
    train_op = model.train_op

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # NOTE(review): each ``sess.run(train_op)`` executes the training op
        # once (presumably one optimizer step on one batch), so this loop
        # performs ``num_epochs`` steps, not ``num_epochs`` full passes over
        # the dataset. Confirm the intended meaning of ``num_epochs``.
        for _ in range(num_epochs):
            sess.run(train_op)