@@ -3080,7 +3080,9 @@ export type AnyZodObject = ZodObject<any, any, any>;
3080
3080
////////// //////////
3081
3081
////////////////////////////////////////
3082
3082
////////////////////////////////////////
3083
- export type ZodUnionOptions = Readonly < [ ZodTypeAny , ...ZodTypeAny [ ] ] > ;
3083
+ export type ZodUnionOptions < T extends ZodTypeAny = ZodTypeAny > = Readonly <
3084
+ [ T , ...T [ ] ]
3085
+ > ;
3084
3086
export interface ZodUnionDef <
3085
3087
T extends ZodUnionOptions = Readonly <
3086
3088
[ ZodTypeAny , ZodTypeAny , ...ZodTypeAny [ ] ]
@@ -3222,46 +3224,161 @@ export class ZodUnion<T extends ZodUnionOptions> extends ZodType<
3222
3224
/////////////////////////////////////////////////////
3223
3225
/////////////////////////////////////////////////////
3224
3226
3225
- const getDiscriminator = < T extends ZodTypeAny > ( type : T ) : Primitive [ ] => {
3227
+ type AnyZodDiscriminatedUnionOption =
3228
+ | SomeZodObject
3229
+ | ZodIntersection <
3230
+ AnyZodDiscriminatedUnionOption ,
3231
+ AnyZodDiscriminatedUnionOption
3232
+ >
3233
+ | ZodUnion < ZodUnionOptions < AnyZodDiscriminatedUnionOption > >
3234
+ | ZodDiscriminatedUnion < any , readonly AnyZodDiscriminatedUnionOption [ ] > ;
3235
+
3236
+ export type ZodDiscriminatedUnionOption < Discriminator extends string > =
3237
+ | ZodObject <
3238
+ { [ key in Discriminator ] : ZodTypeAny } & ZodRawShape ,
3239
+ UnknownKeysParam ,
3240
+ ZodTypeAny
3241
+ >
3242
+ | ZodIntersection <
3243
+ ZodDiscriminatedUnionOption < Discriminator > ,
3244
+ AnyZodDiscriminatedUnionOption
3245
+ >
3246
+ | ZodIntersection <
3247
+ AnyZodDiscriminatedUnionOption ,
3248
+ ZodDiscriminatedUnionOption < Discriminator >
3249
+ >
3250
+ | ZodUnion < ZodUnionOptions < ZodDiscriminatedUnionOption < Discriminator > > >
3251
+ | ZodDiscriminatedUnion <
3252
+ any ,
3253
+ readonly ZodDiscriminatedUnionOption < Discriminator > [ ]
3254
+ > ;
3255
+
3256
+ const getDiscriminatorValues = (
3257
+ discriminator : string ,
3258
+ type : AnyZodDiscriminatedUnionOption ,
3259
+ values : Set < Primitive >
3260
+ ) => {
3261
+ if ( type instanceof ZodObject ) {
3262
+ getDiscriminator ( type . shape [ discriminator ] , values ) ;
3263
+ } else if ( type instanceof ZodIntersection ) {
3264
+ const leftHasDiscriminator = hasDiscriminator (
3265
+ discriminator ,
3266
+ type . _def . left
3267
+ ) ;
3268
+ const rightHasDiscriminator = hasDiscriminator (
3269
+ discriminator ,
3270
+ type . _def . right
3271
+ ) ;
3272
+
3273
+ if ( leftHasDiscriminator && rightHasDiscriminator ) {
3274
+ const leftValues = new Set < Primitive > ( ) ;
3275
+ const rightValues = new Set < Primitive > ( ) ;
3276
+
3277
+ getDiscriminatorValues ( discriminator , type . _def . left , leftValues ) ;
3278
+ getDiscriminatorValues ( discriminator , type . _def . right , rightValues ) ;
3279
+
3280
+ for ( const value of leftValues ) {
3281
+ if ( rightValues . has ( value ) ) {
3282
+ values . add ( value ) ;
3283
+ }
3284
+ }
3285
+ } else if ( leftHasDiscriminator ) {
3286
+ getDiscriminatorValues ( discriminator , type . _def . left , values ) ;
3287
+ } else if ( rightHasDiscriminator ) {
3288
+ getDiscriminatorValues ( discriminator , type . _def . right , values ) ;
3289
+ }
3290
+ } else if (
3291
+ type instanceof ZodUnion ||
3292
+ type instanceof ZodDiscriminatedUnion
3293
+ ) {
3294
+ for ( const optionType of type . options ) {
3295
+ getDiscriminatorValues ( discriminator , optionType , values ) ;
3296
+ }
3297
+ }
3298
+ } ;
3299
+
3300
+ const hasDiscriminator = (
3301
+ discriminator : string ,
3302
+ type : AnyZodDiscriminatedUnionOption
3303
+ ) : boolean => {
3304
+ if ( type instanceof ZodObject ) {
3305
+ return discriminator in type . shape ;
3306
+ } else if ( type instanceof ZodIntersection ) {
3307
+ return (
3308
+ hasDiscriminator ( discriminator , type . _def . left ) ||
3309
+ hasDiscriminator ( discriminator , type . _def . right )
3310
+ ) ;
3311
+ } else if (
3312
+ type instanceof ZodUnion ||
3313
+ type instanceof ZodDiscriminatedUnion
3314
+ ) {
3315
+ return type . options . some ( ( optionType ) =>
3316
+ hasDiscriminator ( discriminator , optionType )
3317
+ ) ;
3318
+ } else {
3319
+ return false ;
3320
+ }
3321
+ } ;
3322
+
3323
+ const getDiscriminator = < T extends ZodTypeAny > (
3324
+ type : T ,
3325
+ values : Set < Primitive >
3326
+ ) => {
3226
3327
if ( type instanceof ZodLazy ) {
3227
- return getDiscriminator ( type . schema ) ;
3328
+ getDiscriminator ( type . schema , values ) ;
3228
3329
} else if ( type instanceof ZodEffects ) {
3229
- return getDiscriminator ( type . innerType ( ) ) ;
3330
+ getDiscriminator ( type . innerType ( ) , values ) ;
3230
3331
} else if ( type instanceof ZodLiteral ) {
3231
- return [ type . value ] ;
3332
+ values . add ( type . value ) ;
3232
3333
} else if ( type instanceof ZodEnum ) {
3233
- return type . options ;
3334
+ for ( const value of type . options ) {
3335
+ values . add ( value ) ;
3336
+ }
3234
3337
} else if ( type instanceof ZodNativeEnum ) {
3235
3338
// eslint-disable-next-line ban/ban
3236
- return util . objectValues ( type . enum as any ) ;
3339
+ for ( const value of util . objectValues ( type . enum as any ) ) {
3340
+ values . add ( value ) ;
3341
+ }
3237
3342
} else if ( type instanceof ZodDefault ) {
3238
- return getDiscriminator ( type . _def . innerType ) ;
3343
+ getDiscriminator ( type . _def . innerType , values ) ;
3239
3344
} else if ( type instanceof ZodUndefined ) {
3240
- return [ undefined ] ;
3345
+ values . add ( undefined ) ;
3241
3346
} else if ( type instanceof ZodNull ) {
3242
- return [ null ] ;
3347
+ values . add ( null ) ;
3243
3348
} else if ( type instanceof ZodOptional ) {
3244
- return [ undefined , ...getDiscriminator ( type . unwrap ( ) ) ] ;
3349
+ values . add ( undefined ) ;
3350
+ getDiscriminator ( type . unwrap ( ) , values ) ;
3245
3351
} else if ( type instanceof ZodNullable ) {
3246
- return [ null , ...getDiscriminator ( type . unwrap ( ) ) ] ;
3352
+ values . add ( null ) ;
3353
+ getDiscriminator ( type . unwrap ( ) , values ) ;
3247
3354
} else if ( type instanceof ZodBranded ) {
3248
- return getDiscriminator ( type . unwrap ( ) ) ;
3355
+ getDiscriminator ( type . unwrap ( ) , values ) ;
3249
3356
} else if ( type instanceof ZodReadonly ) {
3250
- return getDiscriminator ( type . unwrap ( ) ) ;
3357
+ getDiscriminator ( type . unwrap ( ) , values ) ;
3251
3358
} else if ( type instanceof ZodCatch ) {
3252
- return getDiscriminator ( type . _def . innerType ) ;
3253
- } else {
3254
- return [ ] ;
3359
+ getDiscriminator ( type . _def . innerType , values ) ;
3360
+ } else if ( type instanceof ZodIntersection ) {
3361
+ const leftValues = new Set < Primitive > ( ) ;
3362
+ const rightValues = new Set < Primitive > ( ) ;
3363
+
3364
+ getDiscriminator ( type . _def . left , leftValues ) ;
3365
+ getDiscriminator ( type . _def . right , rightValues ) ;
3366
+
3367
+ for ( const value of leftValues ) {
3368
+ if ( rightValues . has ( value ) ) {
3369
+ values . add ( value ) ;
3370
+ }
3371
+ }
3372
+ } else if (
3373
+ type instanceof ZodUnion ||
3374
+ type instanceof ZodDiscriminatedUnion
3375
+ ) {
3376
+ for ( const optionType of type . options ) {
3377
+ getDiscriminator ( optionType , values ) ;
3378
+ }
3255
3379
}
3256
3380
} ;
3257
3381
3258
- export type ZodDiscriminatedUnionOption < Discriminator extends string > =
3259
- ZodObject <
3260
- { [ key in Discriminator ] : ZodTypeAny } & ZodRawShape ,
3261
- UnknownKeysParam ,
3262
- ZodTypeAny
3263
- > ;
3264
-
3265
3382
export interface ZodDiscriminatedUnionDef <
3266
3383
Discriminator extends string ,
3267
3384
Options extends readonly ZodDiscriminatedUnionOption < string > [ ] = ZodDiscriminatedUnionOption < string > [ ]
@@ -3357,9 +3474,10 @@ export class ZodDiscriminatedUnion<
3357
3474
const optionsMap : Map < Primitive , Types [ number ] > = new Map ( ) ;
3358
3475
3359
3476
// try {
3477
+ const discriminatorValues = new Set < Primitive > ( ) ;
3360
3478
for ( const type of options ) {
3361
- const discriminatorValues = getDiscriminator ( type . shape [ discriminator ] ) ;
3362
- if ( ! discriminatorValues . length ) {
3479
+ getDiscriminatorValues ( discriminator , type , discriminatorValues ) ;
3480
+ if ( discriminatorValues . size < 1 ) {
3363
3481
throw new Error (
3364
3482
`A discriminator value for key \`${ discriminator } \` could not be extracted from all schema options`
3365
3483
) ;
@@ -3375,6 +3493,7 @@ export class ZodDiscriminatedUnion<
3375
3493
3376
3494
optionsMap . set ( value , type ) ;
3377
3495
}
3496
+ discriminatorValues . clear ( ) ;
3378
3497
}
3379
3498
3380
3499
return new ZodDiscriminatedUnion <
0 commit comments