Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[compiler] InferReferenceEffects outputs a disjoint set of aliases #30974

Open
wants to merge 5 commits into
base: gh/mvitousek/35/base
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ import {
eachTerminalOperand,
eachTerminalSuccessor,
} from '../HIR/visitors';
import DisjointSet from '../Utils/DisjointSet';
import {assertExhaustive} from '../Utils/utils';
import {
inferTerminalFunctionEffects,
Expand Down Expand Up @@ -105,7 +106,7 @@ const UndefinedValue: InstructionValue = {
export default function inferReferenceEffects(
fn: HIRFunction,
options: {isFunctionExpression: boolean} = {isFunctionExpression: false},
): void {
): DisjointSet<IdentifierId> {
/*
* Initial state contains function params
* TODO: include module declarations here as well
Expand Down Expand Up @@ -225,6 +226,7 @@ export default function inferReferenceEffects(
}
queue(fn.body.entry, initialState);

const finishedStates: Map<BlockId, InferenceState> = new Map();
const functionEffects: Array<FunctionEffect> = fn.effects ?? [];

while (queuedStates.size !== 0) {
Expand All @@ -237,6 +239,7 @@ export default function inferReferenceEffects(

statesByBlock.set(blockId, incomingState);
const state = incomingState.clone();
finishedStates.set(blockId, state);
inferBlock(fn.env, state, block, functionEffects);

for (const nextBlockId of eachTerminalSuccessor(block.terminal)) {
Expand All @@ -250,6 +253,12 @@ export default function inferReferenceEffects(
} else {
raiseFunctionEffectErrors(functionEffects);
}

const summaryState = Array(...finishedStates.values()).reduce(
(acc, state) => acc.merge(state) ?? acc,
);

return summaryState.aliases;
}

type FreezeAction = {values: Set<InstructionValue>; reason: Set<ValueReason>};
Expand All @@ -267,18 +276,26 @@ class InferenceState {
*/
#variables: Map<IdentifierId, Set<InstructionValue>>;

#aliases: DisjointSet<IdentifierId>;

constructor(
env: Environment,
values: Map<InstructionValue, AbstractValue>,
variables: Map<IdentifierId, Set<InstructionValue>>,
aliases: DisjointSet<IdentifierId>,
) {
this.#env = env;
this.#values = values;
this.#variables = variables;
this.#aliases = aliases;
}

get aliases(): DisjointSet<IdentifierId> {
return this.#aliases;
}

static empty(env: Environment): InferenceState {
return new InferenceState(env, new Map(), new Map());
return new InferenceState(env, new Map(), new Map(), new DisjointSet());
}

// (Re)initializes a @param value with its default @param kind.
Expand Down Expand Up @@ -338,6 +355,7 @@ class InferenceState {
suggestions: null,
});
this.#variables.set(place.identifier.id, new Set(values));
this.#aliases.union([place.identifier.id, value.identifier.id]);
place.abstractValue = value.abstractValue;
}

Expand Down Expand Up @@ -549,6 +567,7 @@ class InferenceState {
merge(other: InferenceState): InferenceState | null {
let nextValues: Map<InstructionValue, AbstractValue> | null = null;
let nextVariables: Map<IdentifierId, Set<InstructionValue>> | null = null;
let nextAliases: DisjointSet<IdentifierId> | null = null;

for (const [id, thisValue] of this.#values) {
const otherValue = other.#values.get(id);
Expand Down Expand Up @@ -593,13 +612,21 @@ class InferenceState {
nextVariables.set(id, new Set(otherValues));
}

if (nextVariables === null && nextValues === null) {
if (!this.#aliases.equals(other.#aliases)) {
nextAliases = this.#aliases.copy();
for (const otherAliasSet of other.#aliases.buildSets()) {
nextAliases.union(Array(...otherAliasSet));
}
}

if (nextVariables === null && nextValues === null && nextAliases === null) {
return null;
} else {
return new InferenceState(
this.#env,
nextValues ?? new Map(this.#values),
nextVariables ?? new Map(this.#variables),
nextAliases ?? this.#aliases.copy(),
);
}
}
Expand All @@ -614,6 +641,7 @@ class InferenceState {
this.#env,
new Map(this.#values),
new Map(this.#variables),
this.#aliases.copy(),
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,54 @@ export default class DisjointSet<T> {
return [...sets.values()];
}

copy(): DisjointSet<T> {
const copy = new DisjointSet<T>();
copy.#entries = new Map(this.#entries);
return copy;
}

equals(other: DisjointSet<T>): boolean {
if (this.size !== other.size) {
return false;
}
const rootMap = new Map<T, T>();
for (const thisGroupId of this.#entries.values()) {
const otherGroupId = other.find(thisGroupId);
if (otherGroupId === null || this.find(otherGroupId) !== thisGroupId) {
return false;
}
rootMap.set(thisGroupId, otherGroupId);
}

for (const otherGroupId of other.#entries.values()) {
if (!new Set(rootMap.values()).has(otherGroupId)) {
return false;
}
}

for (const item of this.#entries.keys()) {
const otherRoot = other.find(item);
if (otherRoot === null) {
return false;
}
const thisRoot = this.find(item);
CompilerError.invariant(thisRoot != null, {
reason: 'Expected item to be in set',
loc: null,
});
if (rootMap.get(thisRoot) !== otherRoot) {
return false;
}
}
for (const item of other.#entries.keys()) {
const thisRoot = this.find(item);
if (thisRoot === null) {
return false;
}
}
return true;
}

get size(): number {
return this.#entries.size;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,4 +116,45 @@ describe('DisjointSet', () => {

identifiers.forEach((_, group) => expect(group).toBe(z));
});

it('`.equals` is false when it should be', () => {
const x = new DisjointSet<TestIdentifier>();
const y = new DisjointSet<TestIdentifier>();

const [a, b] = makeIdentifiers('a', 'b');
x.union([a, b]);
y.union([a]);
y.union([b]);

expect(x.equals(y)).toBe(false);
expect(y.equals(x)).toBe(false);
});

it('`.equals` is true when it should be', () => {
const x = new DisjointSet<TestIdentifier>();
const y = new DisjointSet<TestIdentifier>();

const [a, b, c] = makeIdentifiers('a', 'b', 'c');
x.union([a, b, c]);
y.union([a, b]);
y.union([b, c]);

expect(x.equals(y)).toBe(true);
expect(y.equals(x)).toBe(true);
});

it('`.copy` doesnt mutate the underlying', () => {
const x = new DisjointSet<TestIdentifier>();

const [a, b] = makeIdentifiers('a', 'b');
x.union([a]);
x.union([b]);

const y = x.copy();

y.union([a, b]);

expect(x.find(a) !== x.find(b)).toBe(true);
expect(y.find(a) === y.find(b)).toBe(true);
});
});