Skip to content

Commit 5329156

Browse files
committed
DGS-22899 Fix support for wrapped Avro unions
1 parent ece423f commit 5329156

File tree

3 files changed

+478
-7
lines changed

3 files changed

+478
-7
lines changed

src/confluent_kafka/schema_registry/common/avro.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from collections import defaultdict
66
from copy import deepcopy
77
from io import BytesIO
8-
from typing import Dict, Optional, Set, Union
8+
from typing import Dict, Optional, Set, Tuple, Union
99

1010
from fastavro import repository, validate
1111
from fastavro.schema import load_schema
@@ -42,6 +42,7 @@
4242
bytes, # 'bytes'
4343
list, # 'array'
4444
dict, # 'map' and 'record'
45+
tuple, # wrapped union type
4546
]
4647
AvroSchema = Union[str, list, dict]
4748

@@ -108,10 +109,13 @@ def transform(
108109
if field_ctx is not None:
109110
field_ctx.field_type = get_type(schema)
110111
if isinstance(schema, list):
111-
subschema = _resolve_union(schema, message)
112+
(subschema, submessage) = _resolve_union(schema, message)
112113
if subschema is None:
113114
return message
114-
return transform(ctx, subschema, message, field_transform)
115+
submessage = transform(ctx, subschema, submessage, field_transform)
116+
if isinstance(message, tuple) and len(message) == 2:
117+
return (message[0], submessage)
118+
return submessage
115119
elif isinstance(schema, dict):
116120
schema_type = schema.get("type")
117121
if schema_type == 'array':
@@ -207,14 +211,23 @@ def _disjoint(tags1: Set[str], tags2: Set[str]) -> bool:
207211
return True
208212

209213

210-
def _resolve_union(schema: AvroSchema, message: AvroMessage) -> Tuple[Optional[AvroSchema], AvroMessage]:
    """Select the branch of a union schema that applies to ``message``.

    Three message shapes are recognised:

    * wrapped union -- a 2-tuple ``(branch_name, value)``; the branch is the
      named subschema whose ``"name"`` equals ``branch_name`` and the
      returned message is the unwrapped ``value``.
    * typed union -- a dict carrying a ``'-type'`` key; the branch is the
      named subschema whose ``"name"`` equals ``message['-type']`` and the
      message is returned unchanged.
    * anything else -- the first subschema that ``validate`` accepts.

    :param schema: the union schema, an iterable of subschemas.
    :param message: the datum to resolve against the union.
    :return: ``(subschema, submessage)``; ``subschema`` is ``None`` (and the
        original ``message`` is returned untouched) when no branch matches.
    """
    is_wrapped_union = isinstance(message, tuple) and len(message) == 2
    is_typed_union = isinstance(message, dict) and '-type' in message

    if is_wrapped_union or is_typed_union:
        if is_wrapped_union:
            branch_name, submessage = message[0], message[1]
        else:
            branch_name, submessage = message['-type'], message
        for subschema in schema:
            # Only dict subschemas that actually carry a "name" can be
            # selected by name; unnamed branches (e.g. arrays, maps) and
            # primitive type names are skipped explicitly rather than by
            # relying on a swallowed KeyError.
            if (isinstance(subschema, dict)
                    and "name" in subschema
                    and subschema["name"] == branch_name):
                return (subschema, submessage)
        return (None, message)

    for subschema in schema:
        try:
            validate(message, subschema)
        except Exception:
            # This branch does not accept the message; try the next one.
            # ``Exception`` (not a bare ``except``) so KeyboardInterrupt and
            # SystemExit still propagate.
            continue
        return (subschema, message)
    return (None, message)
218231

219232

220233
def get_inline_tags(schema: AvroSchema) -> Dict[str, Set[str]]:

tests/schema_registry/_async/test_avro_serdes.py

Lines changed: 229 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1275,6 +1275,235 @@ async def test_avro_encryption_deterministic():
12751275
assert obj == obj2
12761276

12771277

1278+
async def test_avro_encryption_wrapped_union():
    """Round-trip a record whose union field is given in wrapped (tuple) form.

    ``result`` is supplied as ``(branch_name, value)``. The test checks that
    the PII-tagged ``secret`` field of the selected ``Data`` branch no longer
    holds its plaintext after serialization, and that deserialization yields
    the plain (unwrapped) record again.
    """
    executor = FieldEncryptionExecutor.register_with_clock(FakeClock())

    client = AsyncSchemaRegistryClient.new_client({'url': _BASE_URL})
    serializer_conf = {'auto.register.schemas': False, 'use.latest.version': True}
    rule_conf = {'secret': 'mysecret'}

    # Union branches of the "result" field. Key order is kept as in the
    # registered schema because json.dumps preserves insertion order.
    data_record = {
        "fields": [
            {
                "name": "code",
                "type": "int"
            },
            {
                "confluent:tags": [
                    "PII"
                ],
                "name": "secret",
                "type": [
                    "null",
                    "string"
                ]
            }
        ],
        "name": "Data",
        "type": "record"
    }
    error_record = {
        "fields": [
            {
                "name": "code",
                "type": "int"
            },
            {
                "name": "reason",
                "type": [
                    "null",
                    "string"
                ]
            }
        ],
        "name": "Error",
        "type": "record"
    }
    schema = {
        "fields": [
            {
                "name": "id",
                "type": "int"
            },
            {
                "name": "result",
                "type": [
                    "null",
                    data_record,
                    error_record
                ]
            }
        ],
        "name": "Result",
        "namespace": "com.acme",
        "type": "record"
    }

    encrypt_params = RuleParams({
        "encrypt.kek.name": "kek1",
        "encrypt.kms.type": "local-kms",
        "encrypt.kms.key.id": "mykey"
    })
    encrypt_rule = Rule(
        "test-encrypt",
        "",
        RuleKind.TRANSFORM,
        RuleMode.WRITEREAD,
        "ENCRYPT",
        ["PII"],
        encrypt_params,
        None,
        None,
        "ERROR,NONE",
        False
    )
    registered = Schema(
        json.dumps(schema),
        "AVRO",
        [],
        None,
        RuleSet(None, [encrypt_rule])
    )
    await client.register_schema(_SUBJECT, registered)

    # The union value is wrapped as (branch name, branch value).
    obj = {
        'id': 123,
        'result': (
            'com.acme.Data', {
                'code': 456,
                'secret': 'mypii'
            }
        )
    }
    serializer = await AsyncAvroSerializer(
        client, schema_str=None, conf=serializer_conf, rule_conf=rule_conf)
    dek_client = executor.executor.client
    ser_ctx = SerializationContext(_TOPIC, MessageField.VALUE)
    obj_bytes = await serializer(obj, ser_ctx)

    # Serialization must have replaced the plaintext inside the wrapped value.
    wrapped_value = obj['result'][1]
    assert wrapped_value['secret'] != 'mypii'

    # Expected output: plaintext restored and the tuple wrapper removed.
    obj['result'] = {
        'code': 456,
        'secret': 'mypii'
    }

    deserializer = await AsyncAvroDeserializer(client, rule_conf=rule_conf)
    # NOTE(review): presumably creating the deserializer swaps the executor's
    # DEK client; put back the one captured after the serializer was built.
    executor.executor.client = dek_client
    obj2 = await deserializer(obj_bytes, ser_ctx)
    assert obj == obj2
1391+
1392+
1393+
async def test_avro_encryption_typed_union():
    """Round-trip a record whose union field carries a ``'-type'`` hint.

    ``result`` is supplied as a dict with a ``'-type'`` key naming the union
    branch. The test checks that the PII-tagged ``secret`` field no longer
    holds its plaintext after serialization, and that deserialization yields
    the plain record without the ``'-type'`` key.
    """
    executor = FieldEncryptionExecutor.register_with_clock(FakeClock())

    client = AsyncSchemaRegistryClient.new_client({'url': _BASE_URL})
    serializer_conf = {'auto.register.schemas': False, 'use.latest.version': True}
    rule_conf = {'secret': 'mysecret'}

    # Fields of the two record branches of the "result" union. Key order is
    # kept as in the registered schema because json.dumps preserves it.
    data_fields = [
        {
            "name": "code",
            "type": "int"
        },
        {
            "confluent:tags": [
                "PII"
            ],
            "name": "secret",
            "type": [
                "null",
                "string"
            ]
        }
    ]
    error_fields = [
        {
            "name": "code",
            "type": "int"
        },
        {
            "name": "reason",
            "type": [
                "null",
                "string"
            ]
        }
    ]
    result_union = [
        "null",
        {
            "fields": data_fields,
            "name": "Data",
            "type": "record"
        },
        {
            "fields": error_fields,
            "name": "Error",
            "type": "record"
        }
    ]
    schema = {
        "fields": [
            {
                "name": "id",
                "type": "int"
            },
            {
                "name": "result",
                "type": result_union
            }
        ],
        "name": "Result",
        "namespace": "com.acme",
        "type": "record"
    }

    rule = Rule(
        "test-encrypt",
        "",
        RuleKind.TRANSFORM,
        RuleMode.WRITEREAD,
        "ENCRYPT",
        ["PII"],
        RuleParams({
            "encrypt.kek.name": "kek1",
            "encrypt.kms.type": "local-kms",
            "encrypt.kms.key.id": "mykey"
        }),
        None,
        None,
        "ERROR,NONE",
        False
    )
    rule_set = RuleSet(None, [rule])
    await client.register_schema(
        _SUBJECT,
        Schema(json.dumps(schema), "AVRO", [], None, rule_set)
    )

    # The union branch is named through the '-type' key of the value itself.
    obj = {
        'id': 123,
        'result': {
            '-type': 'com.acme.Data',
            'code': 456,
            'secret': 'mypii'
        }
    }
    serializer = await AsyncAvroSerializer(
        client, schema_str=None, conf=serializer_conf, rule_conf=rule_conf)
    dek_client = executor.executor.client
    ser_ctx = SerializationContext(_TOPIC, MessageField.VALUE)
    obj_bytes = await serializer(obj, ser_ctx)

    # Serialization must have replaced the plaintext in the input record.
    typed_value = obj['result']
    assert typed_value['secret'] != 'mypii'

    # Expected output: plaintext restored and no '-type' hint.
    obj['result'] = {
        'code': 456,
        'secret': 'mypii'
    }

    deserializer = await AsyncAvroDeserializer(client, rule_conf=rule_conf)
    # NOTE(review): presumably creating the deserializer swaps the executor's
    # DEK client; put back the one captured after the serializer was built.
    executor.executor.client = dek_client
    obj2 = await deserializer(obj_bytes, ser_ctx)
    assert obj == obj2
1505+
1506+
12781507
async def test_avro_encryption_cel():
12791508
executor = FieldEncryptionExecutor.register_with_clock(FakeClock())
12801509

0 commit comments

Comments
 (0)