From cbbdf0893ce4f79668a13c0879a7f143bdd0f472 Mon Sep 17 00:00:00 2001 From: steven Date: Fri, 13 Sep 2019 16:41:51 +0800 Subject: [PATCH] all: fine tune exec module code * fine tune exec module code. and else/end opcode should not discard any stack item, need not add unreachable instruction to instr stream to compile * refine code, and add uint test --- disasm/disasm.go | 157 +++++++++++++------------------ exec/internal/compile/compile.go | 29 ------ wasm/index.go | 59 ++++++++++++ wasm/module_test.go | 91 ++++++++++++++++++ wasm/testdata/spec/sigtest.wasm | Bin 0 -> 125 bytes wasm/testdata/spec/sigtest.wat | 23 +++++ 6 files changed, 237 insertions(+), 122 deletions(-) create mode 100644 wasm/testdata/spec/sigtest.wasm create mode 100644 wasm/testdata/spec/sigtest.wat diff --git a/disasm/disasm.go b/disasm/disasm.go index 0d4c3936..49747f94 100644 --- a/disasm/disasm.go +++ b/disasm/disasm.go @@ -107,7 +107,6 @@ func NewDisassembly(fn wasm.Function, module *wasm.Module) (*Disassembly, error) stackDepths.Push(0) blockIndices := &stack.Stack{} // a stack of indices to operators which start new blocks curIndex := 0 - var lastOpReturn bool for _, instr := range instrs { logger.Printf("stack top is %d", stackDepths.Top()) @@ -129,13 +128,28 @@ func NewDisassembly(fn wasm.Function, module *wasm.Module) (*Disassembly, error) instr.Unreachable = !isInstrReachable(blockPolymorphicOps) } + var blockStartIndex uint64 + switch op { + case ops.End, ops.Else: + blockStartIndex = blockIndices.Pop() + if op == ops.Else { + blockIndices.Push(uint64(curIndex)) + } + case ops.Block, ops.Loop, ops.If: + blockIndices.Push(uint64(curIndex)) + } + + if instr.Unreachable { + continue + } + logger.Printf("op: %s, unreachable: %v", opStr.Name, instr.Unreachable) - if !opStr.Polymorphic && !instr.Unreachable { + if !opStr.Polymorphic { top := int(stackDepths.Top()) top -= len(opStr.Args) stackDepths.SetTop(uint64(top)) - if top < -1 { - return nil, ErrStackUnderflow + if top < 0 { + panic("underflow during validation") } if opStr.Returns != wasm.ValueType(wasm.BlockTypeEmpty) { top++ @@ -148,23 +162,13 @@ func NewDisassembly(fn wasm.Function, module *wasm.Module) (*Disassembly, error) case ops.Unreachable: pushPolymorphicOp(blockPolymorphicOps, curIndex) case ops.Drop: - if !instr.Unreachable { - stackDepths.SetTop(stackDepths.Top() - 1) - } + stackDepths.SetTop(stackDepths.Top() - 1) case ops.Select: - if !instr.Unreachable { - stackDepths.SetTop(stackDepths.Top() - 2) - } + stackDepths.SetTop(stackDepths.Top() - 2) case ops.Return: - if !instr.Unreachable { - stackDepths.SetTop(stackDepths.Top() - uint64(len(fn.Sig.ReturnTypes))) - } + stackDepths.SetTop(stackDepths.Top() - uint64(len(fn.Sig.ReturnTypes))) pushPolymorphicOp(blockPolymorphicOps, curIndex) - lastOpReturn = true case ops.End, ops.Else: - // The max depth reached while execing the current block - curDepth := stackDepths.Top() - blockStartIndex := blockIndices.Pop() blockSig := disas.Code[blockStartIndex].Block.Signature instr.Block = &BlockInfo{ Start: false, @@ -175,7 +179,7 @@ func NewDisassembly(fn wasm.Function, module *wasm.Module) (*Disassembly, error) disas.Code[blockStartIndex].Block.EndIndex = curIndex } else { // ops.Else instr.Block.ElseIfIndex = int(blockStartIndex) - disas.Code[blockStartIndex].Block.IfElseIndex = int(blockStartIndex) + disas.Code[blockStartIndex].Block.IfElseIndex = int(curIndex) } // The max depth reached while execing the last block @@ -187,60 +191,28 @@ func NewDisassembly(fn wasm.Function, module *wasm.Module) (*Disassembly, error) prevDepthIndex := stackDepths.Len() - 2 prevDepth := stackDepths.Get(prevDepthIndex) - if op != ops.Else && blockSig != wasm.BlockTypeEmpty && !instr.Unreachable { + if op != ops.Else && blockSig != wasm.BlockTypeEmpty { stackDepths.Set(prevDepthIndex, prevDepth+1) disas.checkMaxDepth(int(stackDepths.Get(prevDepthIndex))) } - if !lastOpReturn { - elemsDiscard := int(curDepth) - int(prevDepth) - if elemsDiscard < -1 { - return nil, ErrStackUnderflow - } - instr.NewStack = &StackInfo{ - StackTopDiff: int64(elemsDiscard), - PreserveTop: blockSig != wasm.BlockTypeEmpty, - } - logger.Printf("discard %d elements, preserve top: %v", elemsDiscard, instr.NewStack.PreserveTop) - } else { - instr.NewStack = &StackInfo{} - } - logger.Printf("setting new stack for %s block (%d)", disas.Code[blockStartIndex].Op.Name, blockStartIndex) - disas.Code[blockStartIndex].NewStack = instr.NewStack - if !instr.Unreachable { - blockPolymorphicOps = blockPolymorphicOps[:len(blockPolymorphicOps)-1] - } + blockPolymorphicOps = blockPolymorphicOps[:len(blockPolymorphicOps)-1] stackDepths.Pop() if op == ops.Else { stackDepths.Push(stackDepths.Top()) - blockIndices.Push(uint64(curIndex)) - if !instr.Unreachable { - blockPolymorphicOps = append(blockPolymorphicOps, []int{}) - } + blockPolymorphicOps = append(blockPolymorphicOps, []int{}) } - case ops.Block, ops.Loop, ops.If: sig := instr.Immediates[0].(wasm.BlockType) logger.Printf("if, depth is %d", stackDepths.Top()) stackDepths.Push(stackDepths.Top()) - // If this new block is unreachable, its - // entire instruction sequence is unreachable - // as well. To make sure that isInstrReachable - // returns the correct value, we don't push a new - // array to blockPolymorphicOps. - if !instr.Unreachable { - // Therefore, only push a new array if this instruction - // is reachable. - blockPolymorphicOps = append(blockPolymorphicOps, []int{}) - } + blockPolymorphicOps = append(blockPolymorphicOps, []int{}) instr.Block = &BlockInfo{ Start: true, Signature: sig, } - - blockIndices.Push(uint64(curIndex)) case ops.Br, ops.BrIf: depth := instr.Immediates[0].(uint32) if int(depth) == blockIndices.Len() { @@ -260,10 +232,13 @@ func NewDisassembly(fn wasm.Function, module *wasm.Module) (*Disassembly, error) // No need to subtract 2 here, we are getting the block // we need to branch to. - index := blockIndices.Get(blockIndices.Len() - 1 - int(depth)) - instr.NewStack = &StackInfo{ - StackTopDiff: int64(elemsDiscard), - PreserveTop: disas.Code[index].Block.Signature != wasm.BlockTypeEmpty, + // No need Discard one. and PreserveTop. + if elemsDiscard > 1 { + index := blockIndices.Get(blockIndices.Len() - 1 - int(depth)) + instr.NewStack = &StackInfo{ + StackTopDiff: int64(elemsDiscard), + PreserveTop: disas.Code[index].Block.Signature != wasm.BlockTypeEmpty, + } } } if op == ops.Br { @@ -271,9 +246,7 @@ func NewDisassembly(fn wasm.Function, module *wasm.Module) (*Disassembly, error) } case ops.BrTable: - if !instr.Unreachable { - stackDepths.SetTop(stackDepths.Top() - 1) - } + stackDepths.SetTop(stackDepths.Top() - 1) targetCount := instr.Immediates[0].(uint32) for i := uint32(0); i < targetCount; i++ { entry := instr.Immediates[i+1].(uint32) @@ -318,42 +291,40 @@ func NewDisassembly(fn wasm.Function, module *wasm.Module) (*Disassembly, error) pushPolymorphicOp(blockPolymorphicOps, curIndex) case ops.Call, ops.CallIndirect: index := instr.Immediates[0].(uint32) - if !instr.Unreachable { - var sig *wasm.FunctionSig - top := int(stackDepths.Top()) - if op == ops.CallIndirect { - if module.Types == nil { - return nil, errors.New("missing types section") - } - sig = &module.Types.Entries[index] - top-- - } else { - sig = module.GetFunction(int(index)).Sig + var sig *wasm.FunctionSig + top := int(stackDepths.Top()) + + switch op { + case ops.CallIndirect: + if module.Types == nil { + return nil, errors.New("missing types section") } - top -= len(sig.ParamTypes) - top += len(sig.ReturnTypes) - stackDepths.SetTop(uint64(top)) - disas.checkMaxDepth(top) - } - case ops.GetLocal, ops.SetLocal, ops.TeeLocal, ops.GetGlobal, ops.SetGlobal: - if !instr.Unreachable { - top := stackDepths.Top() - switch op { - case ops.GetLocal, ops.GetGlobal: - top++ - stackDepths.SetTop(top) - disas.checkMaxDepth(int(top)) - case ops.SetLocal, ops.SetGlobal: - top-- - stackDepths.SetTop(top) - case ops.TeeLocal: - // stack remains unchanged for tee_local + sig = &module.Types.Entries[index] + top-- + default: + sig, err = module.GetFunctionSig(index) + if err != nil { + return nil, err } } - } - if op != ops.Return { - lastOpReturn = false + top -= len(sig.ParamTypes) + top += len(sig.ReturnTypes) + stackDepths.SetTop(uint64(top)) + disas.checkMaxDepth(top) + case ops.GetLocal, ops.SetLocal, ops.TeeLocal, ops.GetGlobal, ops.SetGlobal: + top := stackDepths.Top() + switch op { + case ops.GetLocal, ops.GetGlobal: + top++ + stackDepths.SetTop(top) + disas.checkMaxDepth(int(top)) + case ops.SetLocal, ops.SetGlobal: + top-- + stackDepths.SetTop(top) + case ops.TeeLocal: + // stack remains unchanged for tee_local + } } disas.Code = append(disas.Code, instr) diff --git a/exec/internal/compile/compile.go b/exec/internal/compile/compile.go index 784dc096..b959b7ea 100644 --- a/exec/internal/compile/compile.go +++ b/exec/internal/compile/compile.go @@ -206,29 +206,15 @@ func Compile(disassembly []disasm.Instr) ([]byte, *BytecodeMetadata) { offset: int64(buffer.Len()), ifBlock: false, loopBlock: true, - discard: *instr.NewStack, } continue case ops.Block: curBlockDepth++ blocks[curBlockDepth] = &block{ ifBlock: false, - discard: *instr.NewStack, } continue case ops.Else: - ifInstr := disassembly[instr.Block.ElseIfIndex] // the corresponding `if` instruction for this else - if ifInstr.NewStack != nil && ifInstr.NewStack.StackTopDiff != 0 { - // add code for jumping out of a taken if branch - op := OpDiscard - if ifInstr.NewStack.PreserveTop { - op = OpDiscardPreserveTop - } - - emitMetadata(op, buffer.Len(), instAndInt64Len) - buffer.WriteByte(op) - binary.Write(buffer, binary.LittleEndian, ifInstr.NewStack.StackTopDiff) - } emitMetadata(OpJmp, buffer.Len(), instAndInt64Len) buffer.WriteByte(OpJmp) ifBlockEndOffset := int64(buffer.Len()) @@ -247,21 +233,6 @@ func Compile(disassembly []disasm.Instr) ([]byte, *BytecodeMetadata) { depth := curBlockDepth block := blocks[depth] - if instr.NewStack.StackTopDiff != 0 { - // when exiting a block, discard elements to - // restore stack height. - op := OpDiscard - if instr.NewStack.PreserveTop { - // this is true when the block has a - // signature, and therefore pushes - // a value on to the stack - op = OpDiscardPreserveTop - } - emitMetadata(op, buffer.Len(), instAndInt64Len) - buffer.WriteByte(op) - binary.Write(buffer, binary.LittleEndian, instr.NewStack.StackTopDiff) - } - if !block.loopBlock { // is a normal block block.offset = int64(buffer.Len()) if block.ifBlock { diff --git a/wasm/index.go b/wasm/index.go index 0184e396..6b1ea8ee 100644 --- a/wasm/index.go +++ b/wasm/index.go @@ -6,6 +6,7 @@ package wasm import ( "bytes" + "errors" "fmt" "reflect" ) @@ -82,6 +83,7 @@ func (m *Module) populateFunctions() error { } funcs := make([]uint32, 0, len(m.Function.Types)+len(m.imports.Funcs)) + funcs = append(funcs, m.imports.Funcs...) funcs = append(funcs, m.Function.Types...) m.Function.Types = funcs @@ -98,6 +100,36 @@ func (m *Module) GetFunction(i int) *Function { return &m.FunctionIndexSpace[i] } +func (m *Module) GetFunctionSig(i uint32) (*FunctionSig, error) { + var funcindex uint32 + if m.Import == nil { + if i >= uint32(len(m.Function.Types)) { + return nil, errors.New("fsig out of len") + } + typeindex := m.Function.Types[i] + return &m.Types.Entries[typeindex], nil + } + + for _, importEntry := range m.Import.Entries { + if importEntry.Type.Kind() == ExternalFunction { + if funcindex == i { + typeindex := importEntry.Type.(FuncImport).Type + return &m.Types.Entries[typeindex], nil + } + + funcindex++ + } + } + + i = i - (funcindex - uint32(len(m.imports.Funcs))) + if i >= uint32(len(m.Function.Types)) { + return nil, errors.New("fsig out of len") + } + + typeindex := m.Function.Types[i] + return &m.Types.Entries[typeindex], nil +} + func (m *Module) populateGlobals() error { if m.Global == nil { return nil @@ -118,6 +150,33 @@ func (m *Module) GetGlobal(i int) *GlobalEntry { return &m.GlobalIndexSpace[i] } +func (m *Module) GetGlobalType(i uint32) (*GlobalVar, error) { + var globalindex uint32 + + if m.Import == nil { + if i >= uint32(len(m.Global.Globals)) { + return nil, errors.New("global index out of len") + } + return &m.Global.Globals[i].Type, nil + } + + for _, importEntry := range m.Import.Entries { + if importEntry.Type.Kind() == ExternalGlobal { + if globalindex == i { + v := importEntry.Type.(GlobalVarImport).Type + return &v, nil + } + globalindex++ + } + } + + i = i - (globalindex - uint32(m.imports.Globals)) + if i >= uint32(len(m.Global.Globals)) { + return nil, errors.New("global index out of len") + } + return &m.Global.Globals[i].Type, nil +} + func (m *Module) populateTables() error { if m.Table == nil || len(m.Table.Entries) == 0 || m.Elements == nil || len(m.Elements.Entries) == 0 { return nil diff --git a/wasm/module_test.go b/wasm/module_test.go index 5098c219..dcad6708 100644 --- a/wasm/module_test.go +++ b/wasm/module_test.go @@ -7,6 +7,7 @@ package wasm_test import ( "bytes" "io/ioutil" + "os" "path/filepath" "reflect" "testing" @@ -173,3 +174,93 @@ func TestDuplicateExportError_NoStackOverflow(t *testing.T) { err := wasm.DuplicateExportError("h") _ = err.Error() } + +func TestGetFuntionSig(t *testing.T) { + f, err := os.Open("testdata/spec/sigtest.wasm") + if err != nil { + t.Fatalf("%v", err) + } + defer f.Close() + m, err := wasm.ReadModule(f, nil) + if err != nil { + t.Fatalf("error reading module %v", err) + } + + // check first sig + fsig, err := m.GetFunctionSig(0) + if err != nil { + t.Fatalf("get fsig error") + } + if !(len(fsig.ParamTypes) == 1 && fsig.ParamTypes[0] == wasm.ValueTypeI64) { + t.Fatalf("error param sig, %v", fsig.ParamTypes) + } + if !(len(fsig.ReturnTypes) == 1 && fsig.ReturnTypes[0] == wasm.ValueTypeI64) { + t.Fatalf("error return sig, %v", fsig.ReturnTypes) + } + + // check second sig + fsig, err = m.GetFunctionSig(1) + if err != nil { + t.Fatalf("get fsig error") + } + if !(len(fsig.ParamTypes) == 2 && fsig.ParamTypes[0] == wasm.ValueTypeI32 && fsig.ParamTypes[1] == wasm.ValueTypeI32) { + t.Fatalf("error param sig, %v", fsig.ParamTypes) + } + if !(len(fsig.ReturnTypes) == 1 && fsig.ReturnTypes[0] == wasm.ValueTypeI32) { + t.Fatalf("error return sig, %v", fsig.ReturnTypes) + } + + // check third sig + fsig, err = m.GetFunctionSig(2) + if err != nil { + t.Fatalf("get fsig error") + } + if !(len(fsig.ParamTypes) == 1 && fsig.ParamTypes[0] == wasm.ValueTypeI32) { + t.Fatalf("error param sig, %v", fsig.ParamTypes) + } + if !(len(fsig.ReturnTypes) == 1 && fsig.ReturnTypes[0] == wasm.ValueTypeI32) { + t.Fatalf("error return sig, %v", fsig.ReturnTypes) + } + + // check fourth sig + fsig, err = m.GetFunctionSig(3) + if err != nil { + t.Fatalf("get fsig error") + } + if !(len(fsig.ParamTypes) == 0) { + t.Fatalf("error param sig, %v", fsig.ParamTypes) + } + if !(len(fsig.ReturnTypes) == 1) && fsig.ReturnTypes[0] == wasm.ValueTypeI32 { + t.Fatalf("error return sig, %v", fsig.ReturnTypes) + } + + fsig, err = m.GetFunctionSig(4) + if err == nil { + t.Fatalf("get fsig error") + } + + // check global var sig + gsig, err := m.GetGlobalType(0) + if err != nil { + t.Fatalf("get global type error") + } + + if gsig.Type != wasm.ValueTypeI64 { + t.Fatalf("error global type sig, %v", gsig.Type) + } + + gsig, err = m.GetGlobalType(1) + if err != nil { + t.Fatalf("get global type error") + } + + if gsig.Type != wasm.ValueTypeI32 { + t.Fatalf("error global type sig, %v", gsig.Type) + } + + gsig, err = m.GetGlobalType(2) + if err == nil { + t.Fatalf("get global type error") + } + +} diff --git a/wasm/testdata/spec/sigtest.wasm b/wasm/testdata/spec/sigtest.wasm new file mode 100644 index 0000000000000000000000000000000000000000..16bd3ecb26ff06cb816aaaddc783521332b6c537 GIT binary patch literal 125 zcmWlRJqp7x7(*q0CI)*6g$~^^^&YuNelTsC)E;62UF_@iAU#N(L3)E8!nnT)fLA@R zunel4jT&ZVbywoUafu