Skip to content

Commit a87696a

Browse files
committed
Allow discriminated unions of more advanced types
This extends `z.discriminatedUnion` to support not just objects but intersections, unions, and discriminated unions of objects, and nested combinations thereof, so long as the discriminator is always present on the final object value.
1 parent b13975e commit a87696a

File tree

4 files changed

+434
-52
lines changed

4 files changed

+434
-52
lines changed

deno/lib/__tests__/discriminated-unions.test.ts

+72
Original file line numberDiff line numberDiff line change
@@ -319,3 +319,75 @@ test("readonly array of options", () => {
319319
z.discriminatedUnion("type", options).parse({ type: "x", val: 1 })
320320
).toEqual({ type: "x", val: 1 });
321321
});
322+
323+
test("valid - unions and intersections of objects", () => {
324+
const union = z.discriminatedUnion("type", [
325+
z.object({ type: z.literal("a"), a: z.string() }),
326+
z.object({ type: z.literal("b") }).and(z.object({ b: z.string() })),
327+
z
328+
.object({ type: z.literal("c"), c: z.string() })
329+
.or(z.object({ type: z.literal("c"), c: z.number() }))
330+
.or(z.object({ type: z.literal("d"), d: z.string() })),
331+
z
332+
.object({ type: z.literal("e"), e: z.string() })
333+
.and(z.object({ foo: z.string() }).or(z.object({ bar: z.string() }))),
334+
z
335+
.object({ type: z.literal("f"), f: z.string() })
336+
.or(z.object({ type: z.literal("f") }).and(z.object({ f: z.number() }))),
337+
z.discriminatedUnion("foo", [
338+
z.object({ type: z.literal("g"), foo: z.literal("bar") }),
339+
z.object({ type: z.literal("h"), foo: z.literal("baz") }),
340+
]),
341+
z
342+
.object({ type: z.literal("i").or(z.literal("j")) })
343+
.and(
344+
z.object({
345+
type: z.literal("i").or(z.literal("j")).and(z.literal("i")),
346+
})
347+
)
348+
.and(z.object({ type: z.literal("i") }))
349+
.and(z.object({ foo: z.string() })),
350+
]);
351+
352+
expect(union.parse({ type: "a", a: "123" })).toEqual({ type: "a", a: "123" });
353+
expect(union.parse({ type: "b", b: "123" })).toEqual({ type: "b", b: "123" });
354+
expect(union.parse({ type: "c", c: "123" })).toEqual({ type: "c", c: "123" });
355+
expect(union.parse({ type: "c", c: 123 })).toEqual({ type: "c", c: 123 });
356+
expect(union.parse({ type: "d", d: "123" })).toEqual({ type: "d", d: "123" });
357+
expect(() => {
358+
union.parse({ type: "d", c: "123" });
359+
}).toThrow();
360+
expect(union.parse({ type: "e", e: "123", foo: "456" })).toEqual({
361+
type: "e",
362+
e: "123",
363+
foo: "456",
364+
});
365+
expect(union.parse({ type: "e", e: "123", bar: "456" })).toEqual({
366+
type: "e",
367+
e: "123",
368+
bar: "456",
369+
});
370+
expect(() => {
371+
union.parse({ type: "e", e: "123" });
372+
}).toThrow();
373+
expect(union.parse({ type: "f", f: "123" })).toEqual({ type: "f", f: "123" });
374+
expect(union.parse({ type: "f", f: 123 })).toEqual({ type: "f", f: 123 });
375+
expect(union.parse({ type: "g", foo: "bar" })).toEqual({
376+
type: "g",
377+
foo: "bar",
378+
});
379+
expect(union.parse({ type: "h", foo: "baz" })).toEqual({
380+
type: "h",
381+
foo: "baz",
382+
});
383+
expect(() => {
384+
union.parse({ type: "h", foo: "bar" });
385+
}).toThrow();
386+
expect(union.parse({ type: "i", foo: "123" })).toEqual({
387+
type: "i",
388+
foo: "123",
389+
});
390+
expect(() => {
391+
union.parse({ type: "j", foo: "123" });
392+
}).toThrow();
393+
});

deno/lib/types.ts

+145-26
Original file line numberDiff line numberDiff line change
@@ -3080,7 +3080,9 @@ export type AnyZodObject = ZodObject<any, any, any>;
30803080
////////// //////////
30813081
////////////////////////////////////////
30823082
////////////////////////////////////////
3083-
export type ZodUnionOptions = Readonly<[ZodTypeAny, ...ZodTypeAny[]]>;
3083+
export type ZodUnionOptions<T extends ZodTypeAny = ZodTypeAny> = Readonly<
3084+
[T, ...T[]]
3085+
>;
30843086
export interface ZodUnionDef<
30853087
T extends ZodUnionOptions = Readonly<
30863088
[ZodTypeAny, ZodTypeAny, ...ZodTypeAny[]]
@@ -3222,46 +3224,161 @@ export class ZodUnion<T extends ZodUnionOptions> extends ZodType<
32223224
/////////////////////////////////////////////////////
32233225
/////////////////////////////////////////////////////
32243226

3225-
const getDiscriminator = <T extends ZodTypeAny>(type: T): Primitive[] => {
3227+
type AnyZodDiscriminatedUnionOption =
3228+
| SomeZodObject
3229+
| ZodIntersection<
3230+
AnyZodDiscriminatedUnionOption,
3231+
AnyZodDiscriminatedUnionOption
3232+
>
3233+
| ZodUnion<ZodUnionOptions<AnyZodDiscriminatedUnionOption>>
3234+
| ZodDiscriminatedUnion<any, readonly AnyZodDiscriminatedUnionOption[]>;
3235+
3236+
export type ZodDiscriminatedUnionOption<Discriminator extends string> =
3237+
| ZodObject<
3238+
{ [key in Discriminator]: ZodTypeAny } & ZodRawShape,
3239+
UnknownKeysParam,
3240+
ZodTypeAny
3241+
>
3242+
| ZodIntersection<
3243+
ZodDiscriminatedUnionOption<Discriminator>,
3244+
AnyZodDiscriminatedUnionOption
3245+
>
3246+
| ZodIntersection<
3247+
AnyZodDiscriminatedUnionOption,
3248+
ZodDiscriminatedUnionOption<Discriminator>
3249+
>
3250+
| ZodUnion<ZodUnionOptions<ZodDiscriminatedUnionOption<Discriminator>>>
3251+
| ZodDiscriminatedUnion<
3252+
any,
3253+
readonly ZodDiscriminatedUnionOption<Discriminator>[]
3254+
>;
3255+
3256+
const getDiscriminatorValues = (
3257+
discriminator: string,
3258+
type: AnyZodDiscriminatedUnionOption,
3259+
values: Set<Primitive>
3260+
) => {
3261+
if (type instanceof ZodObject) {
3262+
getDiscriminator(type.shape[discriminator], values);
3263+
} else if (type instanceof ZodIntersection) {
3264+
const leftHasDiscriminator = hasDiscriminator(
3265+
discriminator,
3266+
type._def.left
3267+
);
3268+
const rightHasDiscriminator = hasDiscriminator(
3269+
discriminator,
3270+
type._def.right
3271+
);
3272+
3273+
if (leftHasDiscriminator && rightHasDiscriminator) {
3274+
const leftValues = new Set<Primitive>();
3275+
const rightValues = new Set<Primitive>();
3276+
3277+
getDiscriminatorValues(discriminator, type._def.left, leftValues);
3278+
getDiscriminatorValues(discriminator, type._def.right, rightValues);
3279+
3280+
for (const value of leftValues) {
3281+
if (rightValues.has(value)) {
3282+
values.add(value);
3283+
}
3284+
}
3285+
} else if (leftHasDiscriminator) {
3286+
getDiscriminatorValues(discriminator, type._def.left, values);
3287+
} else if (rightHasDiscriminator) {
3288+
getDiscriminatorValues(discriminator, type._def.right, values);
3289+
}
3290+
} else if (
3291+
type instanceof ZodUnion ||
3292+
type instanceof ZodDiscriminatedUnion
3293+
) {
3294+
for (const optionType of type.options) {
3295+
getDiscriminatorValues(discriminator, optionType, values);
3296+
}
3297+
}
3298+
};
3299+
3300+
const hasDiscriminator = (
3301+
discriminator: string,
3302+
type: AnyZodDiscriminatedUnionOption
3303+
): boolean => {
3304+
if (type instanceof ZodObject) {
3305+
return discriminator in type.shape;
3306+
} else if (type instanceof ZodIntersection) {
3307+
return (
3308+
hasDiscriminator(discriminator, type._def.left) ||
3309+
hasDiscriminator(discriminator, type._def.right)
3310+
);
3311+
} else if (
3312+
type instanceof ZodUnion ||
3313+
type instanceof ZodDiscriminatedUnion
3314+
) {
3315+
return type.options.some((optionType) =>
3316+
hasDiscriminator(discriminator, optionType)
3317+
);
3318+
} else {
3319+
return false;
3320+
}
3321+
};
3322+
3323+
const getDiscriminator = <T extends ZodTypeAny>(
3324+
type: T,
3325+
values: Set<Primitive>
3326+
) => {
32263327
if (type instanceof ZodLazy) {
3227-
return getDiscriminator(type.schema);
3328+
getDiscriminator(type.schema, values);
32283329
} else if (type instanceof ZodEffects) {
3229-
return getDiscriminator(type.innerType());
3330+
getDiscriminator(type.innerType(), values);
32303331
} else if (type instanceof ZodLiteral) {
3231-
return [type.value];
3332+
values.add(type.value);
32323333
} else if (type instanceof ZodEnum) {
3233-
return type.options;
3334+
for (const value of type.options) {
3335+
values.add(value);
3336+
}
32343337
} else if (type instanceof ZodNativeEnum) {
32353338
// eslint-disable-next-line ban/ban
3236-
return util.objectValues(type.enum as any);
3339+
for (const value of util.objectValues(type.enum as any)) {
3340+
values.add(value);
3341+
}
32373342
} else if (type instanceof ZodDefault) {
3238-
return getDiscriminator(type._def.innerType);
3343+
getDiscriminator(type._def.innerType, values);
32393344
} else if (type instanceof ZodUndefined) {
3240-
return [undefined];
3345+
values.add(undefined);
32413346
} else if (type instanceof ZodNull) {
3242-
return [null];
3347+
values.add(null);
32433348
} else if (type instanceof ZodOptional) {
3244-
return [undefined, ...getDiscriminator(type.unwrap())];
3349+
values.add(undefined);
3350+
getDiscriminator(type.unwrap(), values);
32453351
} else if (type instanceof ZodNullable) {
3246-
return [null, ...getDiscriminator(type.unwrap())];
3352+
values.add(null);
3353+
getDiscriminator(type.unwrap(), values);
32473354
} else if (type instanceof ZodBranded) {
3248-
return getDiscriminator(type.unwrap());
3355+
getDiscriminator(type.unwrap(), values);
32493356
} else if (type instanceof ZodReadonly) {
3250-
return getDiscriminator(type.unwrap());
3357+
getDiscriminator(type.unwrap(), values);
32513358
} else if (type instanceof ZodCatch) {
3252-
return getDiscriminator(type._def.innerType);
3253-
} else {
3254-
return [];
3359+
getDiscriminator(type._def.innerType, values);
3360+
} else if (type instanceof ZodIntersection) {
3361+
const leftValues = new Set<Primitive>();
3362+
const rightValues = new Set<Primitive>();
3363+
3364+
getDiscriminator(type._def.left, leftValues);
3365+
getDiscriminator(type._def.right, rightValues);
3366+
3367+
for (const value of leftValues) {
3368+
if (rightValues.has(value)) {
3369+
values.add(value);
3370+
}
3371+
}
3372+
} else if (
3373+
type instanceof ZodUnion ||
3374+
type instanceof ZodDiscriminatedUnion
3375+
) {
3376+
for (const optionType of type.options) {
3377+
getDiscriminator(optionType, values);
3378+
}
32553379
}
32563380
};
32573381

3258-
export type ZodDiscriminatedUnionOption<Discriminator extends string> =
3259-
ZodObject<
3260-
{ [key in Discriminator]: ZodTypeAny } & ZodRawShape,
3261-
UnknownKeysParam,
3262-
ZodTypeAny
3263-
>;
3264-
32653382
export interface ZodDiscriminatedUnionDef<
32663383
Discriminator extends string,
32673384
Options extends readonly ZodDiscriminatedUnionOption<string>[] = ZodDiscriminatedUnionOption<string>[]
@@ -3357,9 +3474,10 @@ export class ZodDiscriminatedUnion<
33573474
const optionsMap: Map<Primitive, Types[number]> = new Map();
33583475

33593476
// try {
3477+
const discriminatorValues = new Set<Primitive>();
33603478
for (const type of options) {
3361-
const discriminatorValues = getDiscriminator(type.shape[discriminator]);
3362-
if (!discriminatorValues.length) {
3479+
getDiscriminatorValues(discriminator, type, discriminatorValues);
3480+
if (discriminatorValues.size < 1) {
33633481
throw new Error(
33643482
`A discriminator value for key \`${discriminator}\` could not be extracted from all schema options`
33653483
);
@@ -3375,6 +3493,7 @@ export class ZodDiscriminatedUnion<
33753493

33763494
optionsMap.set(value, type);
33773495
}
3496+
discriminatorValues.clear();
33783497
}
33793498

33803499
return new ZodDiscriminatedUnion<

src/__tests__/discriminated-unions.test.ts

+72
Original file line numberDiff line numberDiff line change
@@ -318,3 +318,75 @@ test("readonly array of options", () => {
318318
z.discriminatedUnion("type", options).parse({ type: "x", val: 1 })
319319
).toEqual({ type: "x", val: 1 });
320320
});
321+
322+
test("valid - unions and intersections of objects", () => {
323+
const union = z.discriminatedUnion("type", [
324+
z.object({ type: z.literal("a"), a: z.string() }),
325+
z.object({ type: z.literal("b") }).and(z.object({ b: z.string() })),
326+
z
327+
.object({ type: z.literal("c"), c: z.string() })
328+
.or(z.object({ type: z.literal("c"), c: z.number() }))
329+
.or(z.object({ type: z.literal("d"), d: z.string() })),
330+
z
331+
.object({ type: z.literal("e"), e: z.string() })
332+
.and(z.object({ foo: z.string() }).or(z.object({ bar: z.string() }))),
333+
z
334+
.object({ type: z.literal("f"), f: z.string() })
335+
.or(z.object({ type: z.literal("f") }).and(z.object({ f: z.number() }))),
336+
z.discriminatedUnion("foo", [
337+
z.object({ type: z.literal("g"), foo: z.literal("bar") }),
338+
z.object({ type: z.literal("h"), foo: z.literal("baz") }),
339+
]),
340+
z
341+
.object({ type: z.literal("i").or(z.literal("j")) })
342+
.and(
343+
z.object({
344+
type: z.literal("i").or(z.literal("j")).and(z.literal("i")),
345+
})
346+
)
347+
.and(z.object({ type: z.literal("i") }))
348+
.and(z.object({ foo: z.string() })),
349+
]);
350+
351+
expect(union.parse({ type: "a", a: "123" })).toEqual({ type: "a", a: "123" });
352+
expect(union.parse({ type: "b", b: "123" })).toEqual({ type: "b", b: "123" });
353+
expect(union.parse({ type: "c", c: "123" })).toEqual({ type: "c", c: "123" });
354+
expect(union.parse({ type: "c", c: 123 })).toEqual({ type: "c", c: 123 });
355+
expect(union.parse({ type: "d", d: "123" })).toEqual({ type: "d", d: "123" });
356+
expect(() => {
357+
union.parse({ type: "d", c: "123" });
358+
}).toThrow();
359+
expect(union.parse({ type: "e", e: "123", foo: "456" })).toEqual({
360+
type: "e",
361+
e: "123",
362+
foo: "456",
363+
});
364+
expect(union.parse({ type: "e", e: "123", bar: "456" })).toEqual({
365+
type: "e",
366+
e: "123",
367+
bar: "456",
368+
});
369+
expect(() => {
370+
union.parse({ type: "e", e: "123" });
371+
}).toThrow();
372+
expect(union.parse({ type: "f", f: "123" })).toEqual({ type: "f", f: "123" });
373+
expect(union.parse({ type: "f", f: 123 })).toEqual({ type: "f", f: 123 });
374+
expect(union.parse({ type: "g", foo: "bar" })).toEqual({
375+
type: "g",
376+
foo: "bar",
377+
});
378+
expect(union.parse({ type: "h", foo: "baz" })).toEqual({
379+
type: "h",
380+
foo: "baz",
381+
});
382+
expect(() => {
383+
union.parse({ type: "h", foo: "bar" });
384+
}).toThrow();
385+
expect(union.parse({ type: "i", foo: "123" })).toEqual({
386+
type: "i",
387+
foo: "123",
388+
});
389+
expect(() => {
390+
union.parse({ type: "j", foo: "123" });
391+
}).toThrow();
392+
});

0 commit comments

Comments
 (0)