From 8f936f048621311442f4ec9030b8e2a5f192af31 Mon Sep 17 00:00:00 2001 From: Yan Xu Date: Thu, 16 Jan 2020 17:08:17 +0800 Subject: [PATCH] fix saved model path (#1718) --- go.mod | 4 ++- go.sum | 2 ++ pkg/sql/alisa_submitter.go | 43 +++++++++++++++++++++++++-------- pkg/sql/alisa_submitter_test.go | 4 +-- pkg/sql/codegen/pai/codegen.go | 11 ++++----- 5 files changed, 45 insertions(+), 19 deletions(-) diff --git a/go.mod b/go.mod index 54251bebe7..7497722d1a 100644 --- a/go.mod +++ b/go.mod @@ -12,10 +12,11 @@ require ( github.com/fortytw2/leaktest v1.3.0 github.com/go-delve/delve v1.3.2 // indirect - github.com/go-openapi/spec v0.19.4 // indirect + github.com/go-openapi/spec v0.19.5 // indirect github.com/go-sql-driver/mysql v1.4.1 github.com/golang/protobuf v1.3.2 github.com/konsorten/go-windows-terminal-sequences v1.0.2 // indirect + github.com/kr/pty v1.1.5 // indirect github.com/mattn/go-colorable v0.1.4 // indirect github.com/mattn/go-isatty v0.0.11 // indirect github.com/mattn/go-runewidth v0.0.7 // indirect @@ -31,6 +32,7 @@ require ( github.com/sirupsen/logrus v1.4.2 github.com/soniakeys/quant v1.0.0 // indirect github.com/spf13/cobra v0.0.5 // indirect + github.com/stretchr/objx v0.2.0 // indirect github.com/stretchr/testify v1.4.0 go.starlark.net v0.0.0-20191218235703-9fcb808a6221 // indirect golang.org/x/arch v0.0.0-20191126211547-368ea8f32fff // indirect diff --git a/go.sum b/go.sum index beb1fe9ebb..0661d4fee3 100644 --- a/go.sum +++ b/go.sum @@ -95,6 +95,8 @@ github.com/go-openapi/jsonreference v0.19.2/go.mod h1:jMjeRr2HHw6nAVajTXJ4eiUwoh github.com/go-openapi/spec v0.0.0-20160808142527-6aced65f8501/go.mod h1:J8+jY1nAiCcj+friV/PDoE1/3eeccG9LYBs0tYvLOWc= github.com/go-openapi/spec v0.19.4 h1:ixzUSnHTd6hCemgtAJgluaTSGYpLNpJY4mA2DIkdOAo= github.com/go-openapi/spec v0.19.4/go.mod h1:FpwSN1ksY1eteniUU7X0N/BgJ7a4WvBFVA8Lj9mJglo= +github.com/go-openapi/spec v0.19.5 h1:Xm0Ao53uqnk9QE/LlYV5DEU09UAgpliA85QoT9LzqPw= +github.com/go-openapi/spec v0.19.5/go.mod h1:Hm2Jr4jv8G1ciIAo+frC/Ft+rR2kQDh8JHKHb3gWUSk= github.com/go-openapi/swag v0.0.0-20160704191624-1d0bd113de87/go.mod h1:DXUve3Dpr1UfpPtxFw+EFuQ41HhCWZfha5jSVRG7C7I= github.com/go-openapi/swag v0.19.2/go.mod h1:POnQmlKehdgb5mhVOsnJFsivZCEZ/vjK9gh66Z9tfKk= github.com/go-openapi/swag v0.19.5 h1:lTz6Ys4CmqqCQmZPBlbQENR1/GucA2bzYTE12Pw4tFY= diff --git a/pkg/sql/alisa_submitter.go b/pkg/sql/alisa_submitter.go index 83819fa2c1..0bb0b21e56 100644 --- a/pkg/sql/alisa_submitter.go +++ b/pkg/sql/alisa_submitter.go @@ -37,11 +37,11 @@ type alisaSubmitter struct { } func (s *alisaSubmitter) submitAlisaTask(code, resourceName string) error { - _, dSName, err := database.ParseURL(s.Session.DbConnStr) + _, dsName, err := database.ParseURL(s.Session.DbConnStr) if err != nil { return err } - cfg, e := goalisa.ParseDSN(dSName) + cfg, e := goalisa.ParseDSN(dsName) if e != nil { return e } @@ -59,6 +59,22 @@ func (s *alisaSubmitter) submitAlisaTask(code, resourceName string) error { return e } +func (s *alisaSubmitter) getModelPath(modelName string) (string, error) { + _, dsName, err := database.ParseURL(s.Session.DbConnStr) + if err != nil { + return "", err + } + cfg, err := goalisa.ParseDSN(dsName) + if err != nil { + return "", err + } + userID := s.Session.UserId + if userID == "" { + userID = "unkown" + } + return strings.Join([]string{cfg.Project, userID, modelName}, "/"), nil +} + func (s *alisaSubmitter) ExecuteTrain(ts *ir.TrainStmt) (e error) { ts.TmpTrainTable, ts.TmpValidateTable, e = createTempTrainAndValTable(ts.Select, ts.ValidationSelect, s.Session.DbConnStr) if e != nil { @@ -71,12 +87,17 @@ func (s *alisaSubmitter) ExecuteTrain(ts *ir.TrainStmt) (e error) { return e } - paiCmd, e := getPAIcmd(cc, ts.Into, ts.TmpTrainTable, ts.TmpValidateTable, "") + modelPath, e := s.getModelPath(ts.Into) if e != nil { return e } - code, e := pai.TFTrainAndSave(ts, s.Session, ts.Into) + paiCmd, e := getPAIcmd(cc, ts.Into, modelPath, ts.TmpTrainTable, ts.TmpValidateTable, "") + if e != nil { + return e + } + + code, e := pai.TFTrainAndSave(ts, s.Session, modelPath) if e != nil { return e } @@ -121,13 +142,15 @@ func (s *alisaSubmitter) ExecutePredict(ps *ir.PredictStmt) error { if e != nil { return e } - - paiCmd, e := getPAIcmd(cc, ps.Using, ps.TmpPredictTable, "", ps.ResultTable) + modelPath, e := s.getModelPath(ps.Using) if e != nil { return e } - - code, e := pai.TFLoadAndPredict(ps, s.Session, ps.Using) + paiCmd, e := getPAIcmd(cc, ps.Using, modelPath, ps.TmpPredictTable, "", ps.ResultTable) + if e != nil { + return e + } + code, e := pai.TFLoadAndPredict(ps, s.Session, modelPath) if e != nil { return e } @@ -198,14 +221,14 @@ func odpsTables(table string) (string, error) { return fmt.Sprintf("odps://%s/tables/%s", parts[0], parts[1]), nil } -func getPAIcmd(cc *pai.ClusterConfig, modelName, trainTable, valTable, resTable string) (string, error) { +func getPAIcmd(cc *pai.ClusterConfig, modelName, ossModelPath, trainTable, valTable, resTable string) (string, error) { jobName := strings.Replace(strings.Join([]string{"sqlflow", modelName}, "_"), ".", "_", 0) cfString, err := json.Marshal(cc) if err != nil { return "", err } cfQuote := strconv.Quote(string(cfString)) - ckpDir, err := pai.FormatCkptDir(modelName) + ckpDir, err := pai.FormatCkptDir(ossModelPath) if err != nil { return "", err } diff --git a/pkg/sql/alisa_submitter_test.go b/pkg/sql/alisa_submitter_test.go index ddbf654fb2..498864e401 100644 --- a/pkg/sql/alisa_submitter_test.go +++ b/pkg/sql/alisa_submitter_test.go @@ -50,9 +50,9 @@ func TestGetPAICmd(t *testing.T) { } os.Setenv("SQLFLOW_OSS_CHECKPOINT_DIR", "oss://bucket/?role_arn=xxx&host=xxx") defer os.Unsetenv("SQLFLOW_OSS_CHECKPOINT_DIR") - paiCmd, err := getPAIcmd(cc, "my_model", "testdb.test", "", "testdb.result") + paiCmd, err := getPAIcmd(cc, "my_model", "project/12345/my_model", "testdb.test", "", "testdb.result") a.NoError(err) - ckpDir, err := pai.FormatCkptDir("my_model") + ckpDir, err := pai.FormatCkptDir("project/12345/my_model") a.NoError(err) expected := fmt.Sprintf("pai -name tensorflow1120 -DjobName=sqlflow_my_model -Dtags=dnn -Dscript=file://@@task.tar.gz -DentryFile=entry.py -Dtables=odps://testdb/tables/test -Doutputs=odps://testdb/tables/result -DcheckpointDir=\"%s\"", ckpDir) a.Equal(expected, paiCmd) diff --git a/pkg/sql/codegen/pai/codegen.go b/pkg/sql/codegen/pai/codegen.go index 92c9ad5871..2dada147b3 100644 --- a/pkg/sql/codegen/pai/codegen.go +++ b/pkg/sql/codegen/pai/codegen.go @@ -66,8 +66,7 @@ func FormatCkptDir(modelName string) (string, error) { } ossDir := strings.Join([]string{strings.TrimRight(ossURIParts[0], "/"), modelName}, "/") // Form URI like: oss://bucket/your/path/modelname/?args=... - ossCkptDir = strings.Join([]string{ossDir + "/", ossURIParts[1]}, "?") - return ossCkptDir, nil + return strings.Join([]string{ossDir + "/", ossURIParts[1]}, "?"), nil } // wrapper generates a Python program for submit TensorFlow tasks to PAI. @@ -228,7 +227,7 @@ func Train(ir *ir.TrainStmt, session *pb.Session, modelName, cwd string) (string } // TFTrainAndSave generates PAI-TF train program. -func TFTrainAndSave(ir *ir.TrainStmt, session *pb.Session, modelName string) (string, error) { +func TFTrainAndSave(ir *ir.TrainStmt, session *pb.Session, modelPath string) (string, error) { code, err := tensorflow.Train(ir, session) if err != nil { return "", err @@ -236,7 +235,7 @@ func TFTrainAndSave(ir *ir.TrainStmt, session *pb.Session, modelName string) (st // append code snippet to save model var tpl = template.Must(template.New("SaveModel").Parse(tfSaveModelTmplText)) - ckptDir, err := FormatCkptDir(ir.Into) + ckptDir, err := FormatCkptDir(modelPath) if err != nil { return "", err } @@ -332,9 +331,9 @@ func Predict(ir *ir.PredictStmt, session *pb.Session, modelName, cwd string) (st } // TFLoadAndPredict generates PAI-TF prediction program. -func TFLoadAndPredict(ir *ir.PredictStmt, session *pb.Session, modelName string) (string, error) { +func TFLoadAndPredict(ir *ir.PredictStmt, session *pb.Session, modelPath string) (string, error) { var tpl = template.Must(template.New("Predict").Parse(tfPredictTmplText)) - ossModelDir, err := FormatCkptDir(modelName) + ossModelDir, err := FormatCkptDir(modelPath) if err != nil { return "", err }