Skip to content

Use iterators for graph vertices #36558

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jun 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 74 additions & 0 deletions internal/dag/seq.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1

package dag

import (
"iter"
"slices"
)

type VertexSeq[T Vertex] iter.Seq[T]

func (seq VertexSeq[T]) Collect() []T {
return slices.Collect(iter.Seq[T](seq))
}

func (seq VertexSeq[T]) AsGeneric() VertexSeq[Vertex] {
return func(yield func(Vertex) bool) {
for v := range seq {
if !yield(v) {
return
}
}
}
}

// Vertices returns an iterator over all the vertices in the graph.
func (g *Graph) VerticesSeq() VertexSeq[Vertex] {
return func(yield func(v Vertex) bool) {
for _, v := range g.vertices {
v, ok := v.(Vertex)
if !ok {
continue
}
if !yield(v) {
return
}
}
}
}

// SelectSeq filters a sequence to include only elements that can be type-asserted to type U.
// It returns a new sequence containing only the matching elements.
// The yield function can return false to stop iteration early.
func SelectSeq[U Vertex](seq VertexSeq[Vertex]) VertexSeq[U] {
return func(yield func(U) bool) {
for v := range seq {
// if the item is not of the type we're looking for, skip it
u, ok := any(v).(U)
if !ok {
continue
}
if !yield(u) {
return
}
}
}
}

// ExcludeSeq filters a sequence to exclude elements that can be type-asserted to type U.
// It returns a new sequence containing only the non-matching elements.
// The yield function can return false to stop iteration early.
func ExcludeSeq[U Vertex](seq VertexSeq[Vertex]) VertexSeq[Vertex] {
return func(yield func(Vertex) bool) {
for v := range seq {
if _, ok := any(v).(U); ok {
continue
}
if !yield(v) {
return
}
}
}
}
85 changes: 85 additions & 0 deletions internal/dag/seq_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1

package dag

import (
"testing"
)

// Mock implementation of SeqVertex for testing
type MockVertex struct {
id int
}

func (v MockVertex) ZeroValue() any {
return MockVertex{}
}

type MockVertex2 struct {
id int
}

func TestSelectSeq(t *testing.T) {
v1 := MockVertex{id: 1}
v11 := MockVertex{id: 11}
v2 := MockVertex2{id: 2}
vertices := Set{v1: v1, v11: v11, v2: v2}

graph := &Graph{vertices: vertices}
seq := SelectSeq[MockVertex](graph.VerticesSeq())
t.Run("Select objects of given type", func(t *testing.T) {
count := len(seq.Collect())
if count != 2 {
t.Errorf("Expected 2, got %d", count)
}
})

t.Run("Returns empty when looking for incompatible types", func(t *testing.T) {
seq := SelectSeq[MockVertex2](seq.AsGeneric())
count := len(seq.Collect())
if count != 0 {
t.Errorf("Expected empty, got %d", count)
}
})

t.Run("Select objects of given interface", func(t *testing.T) {
seq := SelectSeq[interface{ ZeroValue() any }](graph.VerticesSeq())
count := len(seq.Collect())
if count != 2 {
t.Errorf("Expected 1, got %d", count)
}
})
}

func TestExcludeSeq(t *testing.T) {
v1 := MockVertex{id: 1}
v11 := MockVertex{id: 11}
v2 := MockVertex2{id: 2}
vertices := Set{v1: v1, v11: v11, v2: v2}

graph := &Graph{vertices: vertices}
seq := ExcludeSeq[MockVertex](graph.VerticesSeq())
t.Run("Exclude objects of given type", func(t *testing.T) {
count := len(seq.Collect())
if count != 1 {
t.Errorf("Expected 1, got %d", count)
}
})

t.Run("Returns empty when looking for incompatible types", func(t *testing.T) {
seq := ExcludeSeq[MockVertex2](seq)
count := len(seq.Collect())
if count != 0 {
t.Errorf("Expected empty, got %d", count)
}
})

t.Run("Exclude objects of given interface", func(t *testing.T) {
seq := ExcludeSeq[interface{ ZeroValue() any }](graph.VerticesSeq())
count := len(seq.Collect())
if count != 1 {
t.Errorf("Expected 1, got %d", count)
}
})
}
13 changes: 6 additions & 7 deletions internal/moduletest/graph/test_graph_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (

"github.com/hashicorp/terraform/internal/addrs"
"github.com/hashicorp/terraform/internal/backend/backendrun"
"github.com/hashicorp/terraform/internal/dag"
"github.com/hashicorp/terraform/internal/moduletest"
"github.com/hashicorp/terraform/internal/terraform"
"github.com/hashicorp/terraform/internal/tfdiags"
Expand Down Expand Up @@ -62,13 +63,11 @@ func (b *TestGraphBuilder) Steps() []terraform.GraphTransformer {
}

func validateRunConfigs(g *terraform.Graph) error {
for _, v := range g.Vertices() {
if node, ok := v.(*NodeTestRun); ok {
diags := node.run.Config.Validate(node.run.ModuleConfig)
node.run.Diagnostics = node.run.Diagnostics.Append(diags)
if diags.HasErrors() {
node.run.Status = moduletest.Error
}
for node := range dag.SelectSeq[*NodeTestRun](g.VerticesSeq()) {
diags := node.run.Config.Validate(node.run.ModuleConfig)
node.run.Diagnostics = node.run.Diagnostics.Append(diags)
if diags.HasErrors() {
node.run.Status = moduletest.Error
}
}
return nil
Expand Down
6 changes: 1 addition & 5 deletions internal/moduletest/graph/transform_close_graph.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,7 @@ func (t *CloseTestGraphTransformer) Transform(g *terraform.Graph) error {
closeRoot := &nodeCloseTest{}
g.Add(closeRoot)

for _, v := range g.Vertices() {
if v == closeRoot {
continue
}

for v := range dag.ExcludeSeq[*nodeCloseTest](g.VerticesSeq()) {
// since this is closing the graph, make it depend on everything in
// the graph that does not have a parent. Such nodes are the real roots
// of the graph, and since they are now siblings of the closing root node,
Expand Down
2 changes: 1 addition & 1 deletion internal/moduletest/graph/transform_context.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func (e *EvalContextTransformer) Transform(graph *terraform.Graph) error {
}

graph.Add(node)
for _, v := range graph.Vertices() {
for v := range graph.VerticesSeq() {
if v == node {
continue
}
Expand Down
9 changes: 2 additions & 7 deletions internal/moduletest/graph/transform_providers.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,7 @@ func (t *TestProvidersTransformer) Transform(g *terraform.Graph) error {
// a root provider node that will add the providers to the context
rootProviderNode := t.createRootNode(g, runProviderMap)

for _, v := range g.Vertices() {
node, ok := v.(*NodeTestRun)
if !ok {
continue
}

for node := range dag.SelectSeq[*NodeTestRun](g.VerticesSeq()) {
// Get the providers that the test run depends on
configKey := node.run.GetModuleConfigID()
if _, ok := configsProviderMap[configKey]; !ok {
Expand All @@ -36,7 +31,7 @@ func (t *TestProvidersTransformer) Transform(g *terraform.Graph) error {
runProviderMap[node] = configsProviderMap[configKey]

// Add an edge from the test run node to the root provider node
g.Connect(dag.BasicEdge(v, rootProviderNode))
g.Connect(dag.BasicEdge(node, rootProviderNode))
}

return nil
Expand Down
8 changes: 2 additions & 6 deletions internal/moduletest/graph/transform_state_cleanup.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,7 @@ type TestStateCleanupTransformer struct {
func (t *TestStateCleanupTransformer) Transform(g *terraform.Graph) error {
cleanupMap := make(map[string]*NodeStateCleanup)

for _, v := range g.Vertices() {
node, ok := v.(*NodeTestRun)
if !ok {
continue
}
for node := range dag.SelectSeq[*NodeTestRun](g.VerticesSeq()) {
key := node.run.GetStateKey()
if _, exists := cleanupMap[key]; !exists {
cleanupMap[key] = &NodeStateCleanup{stateKey: key, opts: t.opts}
Expand All @@ -40,7 +36,7 @@ func (t *TestStateCleanupTransformer) Transform(g *terraform.Graph) error {
// existing CLI output.
rootCleanupNode := t.addRootCleanupNode(g)

for _, v := range g.Vertices() {
for v := range g.VerticesSeq() {
switch node := v.(type) {
case *NodeTestRun:
// All the runs that share the same state, must share the same cleanup node,
Expand Down