Skip to content

Commit de72d0a

Browse files
committed
Allow discriminated unions of more advanced types
This extends `z.discriminatedUnion` to support not just objects but intersections, unions, and discriminated unions of objects, and nested combinations thereof, so long as the discriminator is always present on the final object value.
1 parent b13975e commit de72d0a

File tree

4 files changed

+212
-52
lines changed

4 files changed

+212
-52
lines changed

deno/lib/__tests__/discriminated-unions.test.ts

+27
Original file line numberDiff line numberDiff line change
@@ -319,3 +319,30 @@ test("readonly array of options", () => {
319319
z.discriminatedUnion("type", options).parse({ type: "x", val: 1 })
320320
).toEqual({ type: "x", val: 1 });
321321
});
322+
323+
test("valid - unions and intersections of objects", () => {
324+
expect(
325+
z
326+
.discriminatedUnion("type", [
327+
z.object({ type: z.literal("a"), a: z.string() }),
328+
z.object({ type: z.literal("b") }).and(z.object({ b: z.string() })),
329+
z
330+
.object({ type: z.literal("c"), c: z.string() })
331+
.or(z.object({ type: z.literal("c"), c: z.number() }))
332+
.or(z.object({ type: z.literal("d"), d: z.string() })),
333+
z
334+
.object({ type: z.literal("e"), e: z.string() })
335+
.and(z.object({ foo: z.string() }).or(z.object({ bar: z.string() }))),
336+
z
337+
.object({ type: z.literal("f"), f: z.string() })
338+
.or(
339+
z.object({ type: z.literal("f") }).and(z.object({ f: z.number() }))
340+
),
341+
z.discriminatedUnion("foo", [
342+
z.object({ type: z.literal("g"), foo: z.literal("bar") }),
343+
z.object({ type: z.literal("h"), foo: z.literal("baz") }),
344+
]),
345+
])
346+
.parse({ type: "f", f: "123" })
347+
).toEqual({ type: "f", f: "123" });
348+
});

deno/lib/types.ts

+79-26
Original file line numberDiff line numberDiff line change
@@ -3080,7 +3080,9 @@ export type AnyZodObject = ZodObject<any, any, any>;
30803080
////////// //////////
30813081
////////////////////////////////////////
30823082
////////////////////////////////////////
3083-
export type ZodUnionOptions = Readonly<[ZodTypeAny, ...ZodTypeAny[]]>;
3083+
export type ZodUnionOptions<T extends ZodTypeAny = ZodTypeAny> = Readonly<
3084+
[T, ...T[]]
3085+
>;
30843086
export interface ZodUnionDef<
30853087
T extends ZodUnionOptions = Readonly<
30863088
[ZodTypeAny, ZodTypeAny, ...ZodTypeAny[]]
@@ -3222,46 +3224,95 @@ export class ZodUnion<T extends ZodUnionOptions> extends ZodType<
32223224
/////////////////////////////////////////////////////
32233225
/////////////////////////////////////////////////////
32243226

3225-
const getDiscriminator = <T extends ZodTypeAny>(type: T): Primitive[] => {
3227+
type AnyZodDiscriminatedUnionOption =
3228+
| SomeZodObject
3229+
| ZodIntersection<
3230+
AnyZodDiscriminatedUnionOption,
3231+
AnyZodDiscriminatedUnionOption
3232+
>
3233+
| ZodUnion<ZodUnionOptions<AnyZodDiscriminatedUnionOption>>
3234+
| ZodDiscriminatedUnion<any, readonly AnyZodDiscriminatedUnionOption[]>;
3235+
3236+
export type ZodDiscriminatedUnionOption<Discriminator extends string> =
3237+
| ZodObject<
3238+
{ [key in Discriminator]: ZodTypeAny } & ZodRawShape,
3239+
UnknownKeysParam,
3240+
ZodTypeAny
3241+
>
3242+
| ZodIntersection<
3243+
ZodDiscriminatedUnionOption<Discriminator>,
3244+
AnyZodDiscriminatedUnionOption
3245+
>
3246+
| ZodIntersection<
3247+
AnyZodDiscriminatedUnionOption,
3248+
ZodDiscriminatedUnionOption<Discriminator>
3249+
>
3250+
| ZodUnion<ZodUnionOptions<ZodDiscriminatedUnionOption<Discriminator>>>
3251+
| ZodDiscriminatedUnion<
3252+
any,
3253+
readonly ZodDiscriminatedUnionOption<Discriminator>[]
3254+
>;
3255+
3256+
const getDiscriminatorValues = (
3257+
discriminator: string,
3258+
type: AnyZodDiscriminatedUnionOption,
3259+
values: Set<Primitive>
3260+
) => {
3261+
if (type instanceof ZodObject) {
3262+
getDiscriminator(type.shape[discriminator], values);
3263+
} else if (type instanceof ZodIntersection) {
3264+
getDiscriminatorValues(discriminator, type._def.left, values);
3265+
getDiscriminatorValues(discriminator, type._def.right, values);
3266+
} else if (
3267+
type instanceof ZodUnion ||
3268+
type instanceof ZodDiscriminatedUnion
3269+
) {
3270+
for (const optionType of type.options) {
3271+
getDiscriminatorValues(discriminator, optionType, values);
3272+
}
3273+
}
3274+
};
3275+
3276+
const getDiscriminator = <T extends ZodTypeAny>(
3277+
type: T,
3278+
values: Set<Primitive>
3279+
) => {
32263280
if (type instanceof ZodLazy) {
3227-
return getDiscriminator(type.schema);
3281+
getDiscriminator(type.schema, values);
32283282
} else if (type instanceof ZodEffects) {
3229-
return getDiscriminator(type.innerType());
3283+
getDiscriminator(type.innerType(), values);
32303284
} else if (type instanceof ZodLiteral) {
3231-
return [type.value];
3285+
values.add(type.value);
32323286
} else if (type instanceof ZodEnum) {
3233-
return type.options;
3287+
for (const value of type.options) {
3288+
values.add(value);
3289+
}
32343290
} else if (type instanceof ZodNativeEnum) {
32353291
// eslint-disable-next-line ban/ban
3236-
return util.objectValues(type.enum as any);
3292+
for (const value of util.objectValues(type.enum as any)) {
3293+
values.add(value);
3294+
}
32373295
} else if (type instanceof ZodDefault) {
3238-
return getDiscriminator(type._def.innerType);
3296+
getDiscriminator(type._def.innerType, values);
32393297
} else if (type instanceof ZodUndefined) {
3240-
return [undefined];
3298+
values.add(undefined);
32413299
} else if (type instanceof ZodNull) {
3242-
return [null];
3300+
values.add(null);
32433301
} else if (type instanceof ZodOptional) {
3244-
return [undefined, ...getDiscriminator(type.unwrap())];
3302+
values.add(undefined);
3303+
getDiscriminator(type.unwrap(), values);
32453304
} else if (type instanceof ZodNullable) {
3246-
return [null, ...getDiscriminator(type.unwrap())];
3305+
values.add(null);
3306+
getDiscriminator(type.unwrap(), values);
32473307
} else if (type instanceof ZodBranded) {
3248-
return getDiscriminator(type.unwrap());
3308+
getDiscriminator(type.unwrap(), values);
32493309
} else if (type instanceof ZodReadonly) {
3250-
return getDiscriminator(type.unwrap());
3310+
getDiscriminator(type.unwrap(), values);
32513311
} else if (type instanceof ZodCatch) {
3252-
return getDiscriminator(type._def.innerType);
3253-
} else {
3254-
return [];
3312+
getDiscriminator(type._def.innerType, values);
32553313
}
32563314
};
32573315

3258-
export type ZodDiscriminatedUnionOption<Discriminator extends string> =
3259-
ZodObject<
3260-
{ [key in Discriminator]: ZodTypeAny } & ZodRawShape,
3261-
UnknownKeysParam,
3262-
ZodTypeAny
3263-
>;
3264-
32653316
export interface ZodDiscriminatedUnionDef<
32663317
Discriminator extends string,
32673318
Options extends readonly ZodDiscriminatedUnionOption<string>[] = ZodDiscriminatedUnionOption<string>[]
@@ -3357,9 +3408,10 @@ export class ZodDiscriminatedUnion<
33573408
const optionsMap: Map<Primitive, Types[number]> = new Map();
33583409

33593410
// try {
3411+
const discriminatorValues = new Set<Primitive>();
33603412
for (const type of options) {
3361-
const discriminatorValues = getDiscriminator(type.shape[discriminator]);
3362-
if (!discriminatorValues.length) {
3413+
getDiscriminatorValues(discriminator, type, discriminatorValues);
3414+
if (discriminatorValues.size < 1) {
33633415
throw new Error(
33643416
`A discriminator value for key \`${discriminator}\` could not be extracted from all schema options`
33653417
);
@@ -3375,6 +3427,7 @@ export class ZodDiscriminatedUnion<
33753427

33763428
optionsMap.set(value, type);
33773429
}
3430+
discriminatorValues.clear();
33783431
}
33793432

33803433
return new ZodDiscriminatedUnion<

src/__tests__/discriminated-unions.test.ts

+27
Original file line numberDiff line numberDiff line change
@@ -318,3 +318,30 @@ test("readonly array of options", () => {
318318
z.discriminatedUnion("type", options).parse({ type: "x", val: 1 })
319319
).toEqual({ type: "x", val: 1 });
320320
});
321+
322+
test("valid - unions and intersections of objects", () => {
323+
expect(
324+
z
325+
.discriminatedUnion("type", [
326+
z.object({ type: z.literal("a"), a: z.string() }),
327+
z.object({ type: z.literal("b") }).and(z.object({ b: z.string() })),
328+
z
329+
.object({ type: z.literal("c"), c: z.string() })
330+
.or(z.object({ type: z.literal("c"), c: z.number() }))
331+
.or(z.object({ type: z.literal("d"), d: z.string() })),
332+
z
333+
.object({ type: z.literal("e"), e: z.string() })
334+
.and(z.object({ foo: z.string() }).or(z.object({ bar: z.string() }))),
335+
z
336+
.object({ type: z.literal("f"), f: z.string() })
337+
.or(
338+
z.object({ type: z.literal("f") }).and(z.object({ f: z.number() }))
339+
),
340+
z.discriminatedUnion("foo", [
341+
z.object({ type: z.literal("g"), foo: z.literal("bar") }),
342+
z.object({ type: z.literal("h"), foo: z.literal("baz") }),
343+
]),
344+
])
345+
.parse({ type: "f", f: "123" })
346+
).toEqual({ type: "f", f: "123" });
347+
});

src/types.ts

+79-26
Original file line numberDiff line numberDiff line change
@@ -3075,7 +3075,9 @@ export type AnyZodObject = ZodObject<any, any, any>;
30753075
////////// //////////
30763076
////////////////////////////////////////
30773077
////////////////////////////////////////
3078-
export type ZodUnionOptions = Readonly<[ZodTypeAny, ...ZodTypeAny[]]>;
3078+
export type ZodUnionOptions<T extends ZodTypeAny = ZodTypeAny> = Readonly<
3079+
[T, ...T[]]
3080+
>;
30793081
export interface ZodUnionDef<
30803082
T extends ZodUnionOptions = Readonly<
30813083
[ZodTypeAny, ZodTypeAny, ...ZodTypeAny[]]
@@ -3217,46 +3219,95 @@ export class ZodUnion<T extends ZodUnionOptions> extends ZodType<
32173219
/////////////////////////////////////////////////////
32183220
/////////////////////////////////////////////////////
32193221

3220-
const getDiscriminator = <T extends ZodTypeAny>(type: T): Primitive[] => {
3222+
type AnyZodDiscriminatedUnionOption =
3223+
| SomeZodObject
3224+
| ZodIntersection<
3225+
AnyZodDiscriminatedUnionOption,
3226+
AnyZodDiscriminatedUnionOption
3227+
>
3228+
| ZodUnion<ZodUnionOptions<AnyZodDiscriminatedUnionOption>>
3229+
| ZodDiscriminatedUnion<any, readonly AnyZodDiscriminatedUnionOption[]>;
3230+
3231+
export type ZodDiscriminatedUnionOption<Discriminator extends string> =
3232+
| ZodObject<
3233+
{ [key in Discriminator]: ZodTypeAny } & ZodRawShape,
3234+
UnknownKeysParam,
3235+
ZodTypeAny
3236+
>
3237+
| ZodIntersection<
3238+
ZodDiscriminatedUnionOption<Discriminator>,
3239+
AnyZodDiscriminatedUnionOption
3240+
>
3241+
| ZodIntersection<
3242+
AnyZodDiscriminatedUnionOption,
3243+
ZodDiscriminatedUnionOption<Discriminator>
3244+
>
3245+
| ZodUnion<ZodUnionOptions<ZodDiscriminatedUnionOption<Discriminator>>>
3246+
| ZodDiscriminatedUnion<
3247+
any,
3248+
readonly ZodDiscriminatedUnionOption<Discriminator>[]
3249+
>;
3250+
3251+
const getDiscriminatorValues = (
3252+
discriminator: string,
3253+
type: AnyZodDiscriminatedUnionOption,
3254+
values: Set<Primitive>
3255+
) => {
3256+
if (type instanceof ZodObject) {
3257+
getDiscriminator(type.shape[discriminator], values);
3258+
} else if (type instanceof ZodIntersection) {
3259+
getDiscriminatorValues(discriminator, type._def.left, values);
3260+
getDiscriminatorValues(discriminator, type._def.right, values);
3261+
} else if (
3262+
type instanceof ZodUnion ||
3263+
type instanceof ZodDiscriminatedUnion
3264+
) {
3265+
for (const optionType of type.options) {
3266+
getDiscriminatorValues(discriminator, optionType, values);
3267+
}
3268+
}
3269+
};
3270+
3271+
const getDiscriminator = <T extends ZodTypeAny>(
3272+
type: T,
3273+
values: Set<Primitive>
3274+
) => {
32213275
if (type instanceof ZodLazy) {
3222-
return getDiscriminator(type.schema);
3276+
getDiscriminator(type.schema, values);
32233277
} else if (type instanceof ZodEffects) {
3224-
return getDiscriminator(type.innerType());
3278+
getDiscriminator(type.innerType(), values);
32253279
} else if (type instanceof ZodLiteral) {
3226-
return [type.value];
3280+
values.add(type.value);
32273281
} else if (type instanceof ZodEnum) {
3228-
return type.options;
3282+
for (const value of type.options) {
3283+
values.add(value);
3284+
}
32293285
} else if (type instanceof ZodNativeEnum) {
32303286
// eslint-disable-next-line ban/ban
3231-
return util.objectValues(type.enum as any);
3287+
for (const value of util.objectValues(type.enum as any)) {
3288+
values.add(value);
3289+
}
32323290
} else if (type instanceof ZodDefault) {
3233-
return getDiscriminator(type._def.innerType);
3291+
getDiscriminator(type._def.innerType, values);
32343292
} else if (type instanceof ZodUndefined) {
3235-
return [undefined];
3293+
values.add(undefined);
32363294
} else if (type instanceof ZodNull) {
3237-
return [null];
3295+
values.add(null);
32383296
} else if (type instanceof ZodOptional) {
3239-
return [undefined, ...getDiscriminator(type.unwrap())];
3297+
values.add(undefined);
3298+
getDiscriminator(type.unwrap(), values);
32403299
} else if (type instanceof ZodNullable) {
3241-
return [null, ...getDiscriminator(type.unwrap())];
3300+
values.add(null);
3301+
getDiscriminator(type.unwrap(), values);
32423302
} else if (type instanceof ZodBranded) {
3243-
return getDiscriminator(type.unwrap());
3303+
getDiscriminator(type.unwrap(), values);
32443304
} else if (type instanceof ZodReadonly) {
3245-
return getDiscriminator(type.unwrap());
3305+
getDiscriminator(type.unwrap(), values);
32463306
} else if (type instanceof ZodCatch) {
3247-
return getDiscriminator(type._def.innerType);
3248-
} else {
3249-
return [];
3307+
getDiscriminator(type._def.innerType, values);
32503308
}
32513309
};
32523310

3253-
export type ZodDiscriminatedUnionOption<Discriminator extends string> =
3254-
ZodObject<
3255-
{ [key in Discriminator]: ZodTypeAny } & ZodRawShape,
3256-
UnknownKeysParam,
3257-
ZodTypeAny
3258-
>;
3259-
32603311
export interface ZodDiscriminatedUnionDef<
32613312
Discriminator extends string,
32623313
Options extends readonly ZodDiscriminatedUnionOption<string>[] = ZodDiscriminatedUnionOption<string>[]
@@ -3352,9 +3403,10 @@ export class ZodDiscriminatedUnion<
33523403
const optionsMap: Map<Primitive, Types[number]> = new Map();
33533404

33543405
// try {
3406+
const discriminatorValues = new Set<Primitive>();
33553407
for (const type of options) {
3356-
const discriminatorValues = getDiscriminator(type.shape[discriminator]);
3357-
if (!discriminatorValues.length) {
3408+
getDiscriminatorValues(discriminator, type, discriminatorValues);
3409+
if (discriminatorValues.size < 1) {
33583410
throw new Error(
33593411
`A discriminator value for key \`${discriminator}\` could not be extracted from all schema options`
33603412
);
@@ -3370,6 +3422,7 @@ export class ZodDiscriminatedUnion<
33703422

33713423
optionsMap.set(value, type);
33723424
}
3425+
discriminatorValues.clear();
33733426
}
33743427

33753428
return new ZodDiscriminatedUnion<

0 commit comments

Comments
 (0)