From 0e4b257e8fe78bde7b450d4b7ea225a908e405bc Mon Sep 17 00:00:00 2001 From: Bill Fraser Date: Mon, 1 Apr 2024 13:03:49 -0700 Subject: [PATCH] add type annotations to generator python code (#148) --- .github/workflows/cargo-test.yml | 2 +- generator/rust.py | 67 ++++++----- generator/rust.stoneg.py | 108 ++++++++++------- generator/test.stoneg.py | 199 ++++++++++++++++++++++--------- 4 files changed, 245 insertions(+), 131 deletions(-) diff --git a/.github/workflows/cargo-test.yml b/.github/workflows/cargo-test.yml index 8109008..77df172 100644 --- a/.github/workflows/cargo-test.yml +++ b/.github/workflows/cargo-test.yml @@ -15,7 +15,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v5.0.0 with: - python-version: 3.8 + python-version: "3.10" - name: Install Python dependencies run: | diff --git a/generator/rust.py b/generator/rust.py index 0ef56ab..51ce791 100644 --- a/generator/rust.py +++ b/generator/rust.py @@ -1,4 +1,6 @@ +from abc import ABC from contextlib import contextmanager +from typing import Optional, Iterator from stone import ir from stone.backend import CodeBackend @@ -31,25 +33,32 @@ EXTRA_DISPLAY_TYPES = ["auth::RateLimitReason"] -class RustHelperBackend(CodeBackend): +def _arg_list(args: list[str]) -> str: + arg_list = '' + for arg in args: + arg_list += (', ' if arg_list != '' else '') + arg + return arg_list + + +class RustHelperBackend(CodeBackend, ABC): """ A superclass for RustGenerator and TestGenerator to contain some common rust-generation methods. """ - def _dent_len(self): + def _dent_len(self) -> int: if self.tabs_for_indents: return 4 * self.cur_indent else: return self.cur_indent - def _arg_list(self, args): - arg_list = '' - for arg in args: - arg_list += (', ' if arg_list != '' else '') + arg - return arg_list - @contextmanager - def emit_rust_function_def(self, name, args=None, return_type=None, access=None): + def emit_rust_function_def( + self, + name: str, + args: Optional[list[str]] = None, + return_type: Optional[str] = None, + access: Optional[str] = None, + ) -> Iterator[None]: """ A Rust function definition context manager. """ @@ -60,7 +69,7 @@ def emit_rust_function_def(self, name, args=None, return_type=None, access=None) else: access += ' ' ret = f' -> {return_type}' if return_type is not None else '' - one_line = f'{access}fn {name}({self._arg_list(args)}){ret} {{' + one_line = f'{access}fn {name}({_arg_list(args)}){ret} {{' if self._dent_len() + len(one_line) < 100: # one-line version self.emit(one_line) @@ -76,35 +85,35 @@ def emit_rust_function_def(self, name, args=None, return_type=None, access=None) yield self.emit('}') - def emit_rust_fn_call(self, func_name, args, end=None): + def emit_rust_fn_call(self, func_name: str, args: list[str], end: Optional[str] = None) -> None: """ Emit a Rust function call. Wraps arguments to multiple lines if it gets too long. If `end` is None, the call ends without any semicolon. """ if end is None: end = '' - one_line = f'{func_name}({self._arg_list(args)}){end}' + one_line = f'{func_name}({_arg_list(args)}){end}' if self._dent_len() + len(one_line) < 100: self.emit(one_line) else: self.emit(func_name + '(') with self.indent(): for i, arg in enumerate(args): - self.emit(arg + (',' if i+1 < len(args) else (')' + end))) + self.emit(arg + (',' if i + 1 < len(args) else (')' + end))) - def is_enum_type(self, typ): + def is_enum_type(self, typ: ir.DataType) -> bool: return isinstance(typ, ir.Union) or \ (isinstance(typ, ir.Struct) and typ.has_enumerated_subtypes()) - def is_nullary_struct(self, typ): + def is_nullary_struct(self, typ: ir.DataType) -> bool: return isinstance(typ, ir.Struct) and not typ.all_fields - def is_closed_union(self, typ): + def is_closed_union(self, typ: ir.DataType) -> bool: return (isinstance(typ, ir.Union) and typ.closed) \ or (isinstance(typ, ir.Struct) and typ.has_enumerated_subtypes() and not typ.is_catch_all()) - def get_enum_variants(self, typ): + def get_enum_variants(self, typ: ir.DataType) -> list[ir.StructField]: if isinstance(typ, ir.Union): return typ.all_fields elif isinstance(typ, ir.Struct) and typ.has_enumerated_subtypes(): @@ -112,49 +121,49 @@ def get_enum_variants(self, typ): else: return [] - def namespace_name(self, ns): + def namespace_name(self, ns: ir.ApiNamespace) -> str: return self.namespace_name_raw(ns.name) - def namespace_name_raw(self, ns_name): + def namespace_name_raw(self, ns_name: str) -> str: name = fmt_underscores(ns_name) if name in RUST_RESERVED_WORDS + RUST_GLOBAL_NAMESPACE: name = 'dbx_' + name return name - def struct_name(self, struct): + def struct_name(self, struct: ir.Struct) -> str: name = fmt_pascal(struct.name) if name in RUST_RESERVED_WORDS + RUST_GLOBAL_NAMESPACE: name += 'Struct' return name - def enum_name(self, union): + def enum_name(self, union: ir.DataType) -> str: name = fmt_pascal(union.name) if name in RUST_RESERVED_WORDS + RUST_GLOBAL_NAMESPACE: name += 'Union' return name - def field_name(self, field): + def field_name(self, field: ir.StructField) -> str: return self.field_name_raw(field.name) - def field_name_raw(self, name): + def field_name_raw(self, name: str) -> str: name = fmt_underscores(name) if name in RUST_RESERVED_WORDS: name += '_field' return name - def enum_variant_name(self, field): + def enum_variant_name(self, field: ir.UnionField) -> str: return self.enum_variant_name_raw(field.name) - def enum_variant_name_raw(self, name): + def enum_variant_name_raw(self, name: str) -> str: name = fmt_pascal(name) if name in RUST_RESERVED_WORDS: name += 'Variant' return name - def route_name(self, route): + def route_name(self, route: ir.ApiRoute) -> str: return self.route_name_raw(route.name, route.version) - def route_name_raw(self, name, version): + def route_name_raw(self, name: str, version: int) -> str: name = fmt_underscores(name) if version > 1: name = f'{name}_v{version}' @@ -162,13 +171,13 @@ def route_name_raw(self, name, version): name = 'do_' + name return name - def alias_name(self, alias): + def alias_name(self, alias: ir.Alias) -> str: name = fmt_pascal(alias.name) if name in RUST_RESERVED_WORDS + RUST_GLOBAL_NAMESPACE: name += 'Alias' return name - def rust_type(self, typ, current_namespace, no_qualify=False, crate='crate'): + def rust_type(self, typ: ir.DataType, current_namespace: str, no_qualify: bool = False, crate: str ='crate') -> str: if isinstance(typ, ir.Nullable): t = self.rust_type(typ.data_type, current_namespace, no_qualify, crate) return f'Option<{t}>' diff --git a/generator/rust.stoneg.py b/generator/rust.stoneg.py index 03c65da..20c4690 100644 --- a/generator/rust.stoneg.py +++ b/generator/rust.stoneg.py @@ -1,5 +1,6 @@ import contextlib from contextlib import contextmanager +from typing import Iterator, Optional, Sequence from rust import RustHelperBackend, EXTRA_DISPLAY_TYPES, REQUIRED_NAMESPACES from stone import ir @@ -9,19 +10,23 @@ DERIVE_TRAITS = ['Debug', 'Clone', 'PartialEq'] -def fmt_shouting_snake(name): +def fmt_shouting_snake(name: str) -> str: return '_'.join([word.upper() for word in split_words(name)]) class RustBackend(RustHelperBackend): - def __init__(self, target_folder_path, args): - super(RustBackend, self).__init__(target_folder_path, args) - self._modules = [] + def __init__(self, target_folder_path: str, args: Optional[Sequence[str]]) -> None: + super().__init__(target_folder_path, args) self.preserve_aliases = True + self._all_types: dict[str, dict[str, ir.UserDefined]] = dict() + self._current_namespace: str = '' + self._error_types: set[Optional[ir.DataType]] = set() + self._modules: list[str] = [] + # File Generators - def generate(self, api): + def generate(self, api: ir.Api) -> None: self._all_types = {ns.name: {typ.name: typ for typ in ns.data_types} for ns in api.namespaces.values()} @@ -44,7 +49,7 @@ def generate(self, api): self._emit_namespace(namespace) self._generate_mod_file() - def _generate_mod_file(self): + def _generate_mod_file(self) -> None: with self.output_to_relative_path('mod.rs'): self._emit_header() self.emit('#![allow(missing_docs)]') @@ -65,7 +70,7 @@ def _generate_mod_file(self): # Type Emitters - def _emit_namespace(self, namespace): + def _emit_namespace(self, namespace: ir.ApiNamespace) -> None: ns = self.namespace_name(namespace) with self.output_to_relative_path(ns + '.rs'): self._current_namespace = namespace.name @@ -97,7 +102,7 @@ def _emit_namespace(self, namespace): self._modules.append(namespace.name) - def _emit_header(self): + def _emit_header(self) -> None: self.emit('// DO NOT EDIT') self.emit('// This file was @generated by Stone') self.emit() @@ -108,7 +113,7 @@ def _emit_header(self): self.emit(')]') self.emit() - def _emit_struct(self, struct): + def _emit_struct(self, struct: ir.Struct) -> None: struct_name = self.struct_name(struct) self._emit_doc(struct.doc) derive_traits = list(DERIVE_TRAITS) @@ -148,7 +153,7 @@ def _emit_struct(self, struct): else: self._impl_from_for_struct(struct, struct.parent_type) - def _emit_polymorphic_struct(self, struct): + def _emit_polymorphic_struct(self, struct: ir.Struct) -> None: enum_name = self.enum_name(struct) self._emit_doc(struct.doc) derive_traits = list(DERIVE_TRAITS) @@ -168,7 +173,7 @@ def _emit_polymorphic_struct(self, struct): self._impl_serde_for_polymorphic_struct(struct) - def _emit_union(self, union): + def _emit_union(self, union: ir.Union) -> None: enum_name = self.enum_name(union) self._emit_doc(union.doc) derive_traits = list(DERIVE_TRAITS) @@ -204,7 +209,13 @@ def _emit_union(self, union): if union.parent_type: self._impl_from_for_union(union, union.parent_type) - def _emit_route(self, ns, fn, auth_trait = None): + def _emit_route(self, ns: str, fn: ir.ApiRoute, auth_trait: Optional[str] = None) -> None: + # work around lazy init messing with mypy + assert fn.attrs is not None + assert fn.arg_data_type is not None + assert fn.result_data_type is not None + assert fn.error_data_type is not None + route_name = self.route_name(fn) host = fn.attrs.get('host', 'api') if host == 'api': @@ -324,11 +335,12 @@ def _emit_route(self, ns, fn, auth_trait = None): raise RuntimeError(f'ERROR: unknown route style: {style}') self.emit() - def _emit_alias(self, alias): + def _emit_alias(self, alias: ir.Alias) -> None: alias_name = self.alias_name(alias) + assert isinstance(alias.data_type, ir.DataType) self.emit(f'pub type {alias_name} = {self._rust_type(alias.data_type)};') - def _emit_other_variant(self): + def _emit_other_variant(self) -> None: self.emit_wrapped_text( 'Catch-all used for unrecognized values returned from the server.' ' Encountering this value typically indicates that this SDK version is' @@ -338,7 +350,7 @@ def _emit_other_variant(self): # Serialization - def _impl_serde_for_struct(self, struct): + def _impl_serde_for_struct(self, struct: ir.Struct) -> None: """ Emit internal_deserialize() and possibly internal_deserialize_opt(). internal_deserialize[_opt] takes a map and deserializes it into the struct. It reads the @@ -458,6 +470,7 @@ def _impl_serde_for_struct(self, struct): self.emit(f's.serialize_field("{field.name}", val)?;') else: fieldval = f'self.{self.field_name(field)}' + ctx: contextlib.AbstractContextManager if field.has_default: if isinstance(field.data_type, ir.String) and not field.default: ctx = self.block(f'if !{fieldval}.is_empty()') @@ -503,7 +516,7 @@ def _impl_serde_for_struct(self, struct): self.emit('s.end()') self.emit() - def _impl_serde_for_polymorphic_struct(self, struct): + def _impl_serde_for_polymorphic_struct(self, struct: ir.Struct) -> None: type_name = self.enum_name(struct) with self._impl_deserialize(type_name): self.emit('// polymorphic struct deserializer') @@ -571,7 +584,7 @@ def _impl_serde_for_polymorphic_struct(self, struct): '"cannot serialize unknown variant"))') self.emit() - def _impl_serde_for_union(self, union): + def _impl_serde_for_union(self, union: ir.Union) -> None: type_name = self.enum_name(union) with self._impl_deserialize(type_name): self.emit('// union deserializer') @@ -722,7 +735,7 @@ def _impl_serde_for_union(self, union): # "extends" for structs means the subtype adds additional fields to the supertype, so we can # convert from the subtype to the supertype - def _impl_from_for_struct(self, struct, parent): + def _impl_from_for_struct(self, struct: ir.Struct, parent: ir.Struct) -> None: subtype = self._rust_type(struct) supertype = self._rust_type(parent) self.emit(f'// struct extends {supertype}') @@ -738,13 +751,14 @@ def _impl_from_for_struct(self, struct, parent): # "extends" for polymorphic structs means it's one of the supertype's variants, so we can # convert from the subtype to the supertype. - def _impl_from_for_polymorphic_struct(self, struct, parent): - subtype = self._rust_type(struct) + def _impl_from_for_polymorphic_struct(self, struct: ir.Struct, parent: ir.Struct) -> None: + thistype = self._rust_type(struct) supertype = self._rust_type(parent) self.emit(f'// struct extends polymorphic struct {supertype}') - with self.block(f'impl From<{subtype}> for {supertype}'): - with self.block(f'fn from(subtype: {subtype}) -> Self'): + with self.block(f'impl From<{thistype}> for {supertype}'): + with self.block(f'fn from(subtype: {thistype}) -> Self'): for subtype in parent.get_enumerated_subtypes(): + assert isinstance(subtype, ir.UnionField) if subtype.data_type != struct: continue variant_name = self.enum_variant_name(subtype) @@ -752,7 +766,7 @@ def _impl_from_for_polymorphic_struct(self, struct, parent): # "extends" for unions means the subtype adds additional variants, so we can convert from the # supertype to the subtype. - def _impl_from_for_union(self, union, parent): + def _impl_from_for_union(self, union: ir.Union, parent: ir.Union) -> None: subtype = self._rust_type(union) supertype = self._rust_type(parent) self.emit(f'// union extends {supertype}') @@ -766,7 +780,7 @@ def _impl_from_for_union(self, union, parent): # Helpers - def _emit_doc(self, doc_string, prefix='///'): + def _emit_doc(self, doc_string: Optional[str], prefix: str = '///') -> None: if doc_string is not None: for idx, chunk in enumerate(doc_string.split('\n\n')): if idx != 0: @@ -776,11 +790,11 @@ def _emit_doc(self, doc_string, prefix='///'): self.process_doc(chunk, docf), prefix=prefix + ' ', width=100) - def _docf(self, tag, val): + def _docf(self, tag: str, val: str) -> str: if tag == 'route': if ':' in val: - val, version = val.split(':') - version = int(version) + val, vstr = val.split(':') + version = int(vstr) else: version = 1 if '.' in val: @@ -795,12 +809,12 @@ def _docf(self, tag, val): elif tag == 'field': if '.' in val: cls_name, field = val.rsplit('.', 1) - assert('.' not in cls_name) # dunno if this is even allowed, but we don't handle it + assert '.' not in cls_name # dunno if this is even allowed, but we don't handle it typ = self._all_types[self._current_namespace][cls_name] type_name = self._rust_type(typ) if self.is_enum_type(typ): if isinstance(typ, ir.Struct) and typ.has_enumerated_subtypes() \ - and field in (field.name for field in typ.fields): + and typ.fields and field in (field.name for field in typ.fields): # This is actually a link to a field in a polymorphic struct, not a enum # variant. Because Rust doesn't have polymorphism, we make the fields be # present on all enum variants, so this is a link to a field in the current @@ -851,7 +865,7 @@ def _docf(self, tag, val): return f'`{val}`' @contextmanager - def _impl_deserialize(self, type_name): + def _impl_deserialize(self, type_name: str) -> Iterator[None]: with self.block(f'impl<\'de> ::serde::de::Deserialize<\'de> for {type_name}'), \ self.emit_rust_function_def( 'deserialize>', @@ -860,7 +874,7 @@ def _impl_deserialize(self, type_name): yield @contextmanager - def _impl_serialize(self, type_name): + def _impl_serialize(self, type_name: str) -> Iterator[None]: with self.block(f'impl ::serde::ser::Serialize for {type_name}'), \ self.emit_rust_function_def( 'serialize', @@ -868,7 +882,7 @@ def _impl_serialize(self, type_name): 'Result'): yield - def _impl_default_for_struct(self, struct): + def _impl_default_for_struct(self, struct: ir.Struct) -> None: struct_name = self.struct_name(struct) with self.block(f'impl Default for {struct_name}'): with self.emit_rust_function_def('default', [], 'Self'): @@ -878,10 +892,12 @@ def _impl_default_for_struct(self, struct): value = self._default_value(field) self.emit(f'{name}: {value},') - def _impl_struct(self, struct): - return self.block(f'impl {self.struct_name(struct)}') + @contextmanager + def _impl_struct(self, struct: ir.Struct) -> Iterator[None]: + with self.block(f'impl {self.struct_name(struct)}'): + yield - def _emit_new_for_struct(self, struct): + def _emit_new_for_struct(self, struct: ir.Struct) -> None: struct_name = self.struct_name(struct) first = True @@ -927,11 +943,11 @@ def _emit_new_for_struct(self, struct): self.emit(f'self.{field_name} = {value};') self.emit('self') - def _default_value(self, field): + def _default_value(self, field: ir.StructField) -> str: if isinstance(field.data_type, ir.Nullable): return 'None' elif ir.is_numeric_type(ir.unwrap_aliases(field.data_type)[0]): - return field.default + return str(field.default) elif isinstance(field.default, ir.TagRef): default_variant = None for variant in field.default.union_data_type.all_fields: @@ -958,9 +974,9 @@ def _default_value(self, field): print(f' in field: {field}') if isinstance(field.data_type, ir.Alias): print(' unwrapped alias:', ir.unwrap_aliases(field.data_type)[0]) - return field.default + return str(field.default) - def _can_derive_eq(self, typ): + def _can_derive_eq(self, typ: ir.DataType) -> bool: if isinstance(typ, ir.Float32) or isinstance(typ, ir.Float64): # These are the only primitive types that don't have strict equality. return False @@ -968,7 +984,7 @@ def _can_derive_eq(self, typ): # Check for various kinds of compound types and check all fields: if hasattr(typ, "data_type"): return self._can_derive_eq(typ.data_type) - if hasattr(typ, "has_enumerated_subtypes") and typ.has_enumerated_subtypes(): + if isinstance(typ, ir.Struct) and typ.has_enumerated_subtypes(): for styp in typ.get_enumerated_subtypes(): if not self._can_derive_eq(styp): return False @@ -982,7 +998,7 @@ def _can_derive_eq(self, typ): # All other primitive types are strict-comparable. return True - def _needs_explicit_default(self, field): + def _needs_explicit_default(self, field: ir.StructField) -> bool: if isinstance(field.data_type, ir.Nullable): # default is always None return False @@ -991,7 +1007,7 @@ def _needs_explicit_default(self, field): elif ir.is_numeric_type(ir.unwrap_aliases(field.data_type)[0]): return field.default != 0 elif isinstance(field.data_type, ir.Boolean): - return field.default + return bool(field.default) elif isinstance(field.data_type, ir.String): return len(field.default) != 0 else: @@ -1000,10 +1016,10 @@ def _needs_explicit_default(self, field): print('its default is', field.default) return True - def _is_error_type(self, typ): + def _is_error_type(self, typ: ir.DataType) -> bool: return typ in self._error_types - def _impl_error(self, typ): + def _impl_error(self, typ: ir.DataType) -> None: type_name = self.enum_name(typ) # N.B.: error types SHOULD always be enums, but there's at least one type used as the error @@ -1028,7 +1044,7 @@ def _impl_error(self, typ): self.emit() self._impl_display(typ) - def _impl_display(self, typ): + def _impl_display(self, typ: ir.DataType) -> None: type_name = self.enum_name(typ) variants = self.get_enum_variants(typ) @@ -1112,5 +1128,5 @@ def _impl_display(self, typ): # Naming Rules - def _rust_type(self, typ, no_qualify=False): + def _rust_type(self, typ: ir.DataType, no_qualify: bool = False) -> str: return self.rust_type(typ, self._current_namespace, no_qualify) diff --git a/generator/test.stoneg.py b/generator/test.stoneg.py index 688e18b..8364fed 100644 --- a/generator/test.stoneg.py +++ b/generator/test.stoneg.py @@ -1,11 +1,16 @@ +from __future__ import annotations + import datetime import os.path import re import string import sys +from _ast import Module +from contextlib import contextmanager +from typing import Any, Dict, Iterator, Optional, Protocol try: - import re._parser as sre_parse + import re._parser as sre_parse # type: ignore except ImportError: # Python < 3.11 import sre_parse @@ -17,7 +22,7 @@ class Permissions(object): @property - def permissions(self): + def permissions(self) -> list[str]: # For generating tests, make sure we include any internal # fields/structs if we're using internal specs. If we're not using # internal specs, this is a no-op, so just do it all the time. Note @@ -26,8 +31,22 @@ def permissions(self): return ['internal'] +#JsonEncodeType = Callable[[Any, object, list[str], **kwargs], str] +class JsonEncodeType(Protocol): + def __call__( + self, + data_type: Any, + obj: object, + permissions: Permissions, + _1: Any = None, + _2: bool = False, + _3: bool = False, + ) -> str: + ... + + class TestBackend(RustHelperBackend): - def __init__(self, target_folder_path, args): + def __init__(self, target_folder_path: str, args: list[str]) -> None: super(TestBackend, self).__init__(target_folder_path, args) # Don't import other generators until here, otherwise stone.cli will @@ -36,16 +55,16 @@ def __init__(self, target_folder_path, args): self.target_path = target_folder_path self.ref_path = os.path.join(target_folder_path, 'reference') self.reference = PythonTypesBackend(self.ref_path, args + ["--package", "reference"]) - self.reference_impls = {} + self.reference_impls: Dict[str, Module] = {} # Make test values for this type. # If it's a union or polymorphic type, make values for all variants. # If the type or any of its variants have optional fields, also make two versions: one with all # fields filled in, and one with all optional fields omitted. This helps catch backwards-compat # issues as well as checking (de)serialization of None. - def make_test_values(self, typ): - vals = [] - if ir.is_struct_type(typ): + def make_test_values(self, typ: ir.Struct | ir.Union) -> list[TestStruct | TestUnion | TestPolymorphicStruct]: + vals: list[TestStruct | TestUnion | TestPolymorphicStruct] = [] + if isinstance(typ, ir.Struct): if typ.has_enumerated_subtypes(): for variant in typ.get_enumerated_subtypes(): vals.append(TestPolymorphicStruct( @@ -58,7 +77,7 @@ def make_test_values(self, typ): vals.append(TestStruct(self, typ, self.reference_impls, no_optional_fields=False)) if typ.all_optional_fields: vals.append(TestStruct(self, typ, self.reference_impls, no_optional_fields=True)) - elif ir.is_union_type(typ): + elif isinstance(typ, ir.Union): for variant in typ.all_fields: vals.append(TestUnion( self, typ, self.reference_impls, variant, no_optional_fields=False)) @@ -66,7 +85,7 @@ def make_test_values(self, typ): raise RuntimeError(f'ERROR: type {typ} is neither struct nor union') return vals - def generate(self, api): + def generate(self, api: ir.Api) -> None: print('Generating Python reference code') self.reference.generate(api) with self.output_to_relative_path('reference/__init__.py'): @@ -76,13 +95,13 @@ def generate(self, api): sys.path.insert(0, self.target_path) sys.path.insert(1, "stone") from stone.backends.python_rsrc.stone_serializers import json_encode - for ns in api.namespaces: - print('\t' + ns) - python_ns = ns - if ns == 'async': + for ns_name in api.namespaces: + print('\t' + ns_name) + python_ns = ns_name + if ns_name == 'async': # hack to work around 'async' being a Python3 keyword python_ns = 'async_' - self.reference_impls[ns] = __import__('reference.'+python_ns).__dict__[python_ns] + self.reference_impls[ns_name] = __import__('reference.'+python_ns).__dict__[python_ns] print('Generating test code') for ns in api.namespaces.values(): @@ -93,6 +112,7 @@ def generate(self, api): self._emit_tests(ns, typ, json_encode) if self.is_closed_union(typ): + assert isinstance(typ, ir.Struct | ir.Union) self._emit_closed_union_test(ns, typ) for route in ns.routes: @@ -103,13 +123,13 @@ def generate(self, api): self.emit('#[path = "../noop_client.rs"]') self.emit('pub mod noop_client;') self.emit() - for ns in api.namespaces: - if ns not in REQUIRED_NAMESPACES: - self.emit(f'#[cfg(feature = "dbx_{ns}")]') - self.emit(f'mod {self.namespace_name_raw(ns)};') + for ns_name in api.namespaces: + if ns_name not in REQUIRED_NAMESPACES: + self.emit(f'#[cfg(feature = "dbx_{ns_name}")]') + self.emit(f'mod {self.namespace_name_raw(ns_name)};') self.emit() - def _emit_header(self): + def _emit_header(self) -> None: self.emit('// DO NOT EDIT') self.emit('// This file was @generated by Stone') self.emit() @@ -125,7 +145,12 @@ def _emit_header(self): self.emit(')]') self.emit() - def _emit_tests(self, ns, typ, json_encode): + def _emit_tests( + self, + ns: ir.ApiNamespace, + typ: ir.UserDefined, + json_encode: JsonEncodeType, + ) -> None: ns_name = self.namespace_name(ns) type_name = self.struct_name(typ) @@ -179,7 +204,7 @@ def _emit_tests(self, ns, typ, json_encode): self.emit('assert!(::serde_json::to_string(&x).is_err());') self.emit() - def _emit_closed_union_test(self, ns, typ): + def _emit_closed_union_test(self, ns: ir.ApiNamespace, typ: ir.Struct | ir.Union) -> None: ns_name = self.namespace_name(ns) type_name = self.struct_name(typ) with self._test_fn("ClosedUnion_" + type_name): @@ -204,11 +229,23 @@ def _emit_closed_union_test(self, ns, typ): self.emit('}') self.emit() - def _emit_route_test(self, ns, route, json_encode, auth_type=None): + def _emit_route_test( + self, + ns: ir.ApiNamespace, + route: ir.ApiRoute, + json_encode: JsonEncodeType, + auth_type: Optional[str] = None, + ) -> None: + assert route.arg_data_type + assert route.result_data_type + assert route.error_data_type + assert route.attrs + arg_typ = self.rust_type(route.arg_data_type, '', crate='dropbox_sdk') if arg_typ == '()': json = "{}" else: + assert isinstance(route.arg_data_type, ir.Union | ir.Struct) arg_value = self.make_test_values(route.arg_data_type)[0] pyname = fmt_py_class(route.arg_data_type.name) json = json_encode( @@ -268,27 +305,36 @@ def _emit_route_test(self, ns, route, json_encode, auth_type=None): self.emit('assert!(matches!(ret, Err(dropbox_sdk::Error::HttpClient(..))));') self.emit() - def _test_fn(self, name): + @contextmanager + def _test_fn(self, name: str) -> Iterator[None]: self.emit('#[test]') - return self.emit_rust_function_def('test_' + name) + with self.emit_rust_function_def('test_' + name): + yield -def _typ_or_void(typ): +def _typ_or_void(typ: ir.DataType) -> ir.DataType: if typ is None: - return ir.Void + return ir.Void() else: return typ class TestField(object): - def __init__(self, name, python_value, test_value, stone_type, option): + def __init__( + self, + name: str, + python_value: Any, + test_value: TestStruct | TestUnion | TestPolymorphicStruct, + stone_type: ir.DataType, + option: bool, + ) -> None: self.name = name self.value = python_value self.test_value = test_value self.typ = stone_type self.option = option - def emit_assert(self, codegen, expression_path): + def emit_assert(self, codegen: RustHelperBackend, expression_path: str) -> None: extra = ('.' + self.name) if self.name else '' if self.option: if self.value is None: @@ -307,6 +353,7 @@ def emit_assert(self, codegen, expression_path): elif ir.is_boolean_type(self.typ): codegen.emit(f'assert_eq!({expression}, {"true" if self.value else "false"});') elif ir.is_timestamp_type(self.typ): + assert isinstance(self.typ, ir.Timestamp) codegen.emit(f'assert_eq!({expression}.as_str(), "{self.value.strftime(self.typ.format)}");') elif ir.is_bytes_type(self.typ): codegen.emit(f'assert_eq!(&{expression}, &[{",".join(str(x) for x in self.value)}]);') @@ -316,24 +363,30 @@ def emit_assert(self, codegen, expression_path): class TestValue(object): - def __init__(self, rust_generator): + def __init__(self, rust_generator: RustHelperBackend) -> None: self.rust_generator = rust_generator - self.fields = [] + self.fields: list[TestField] = [] self.value = None - def emit_asserts(self, codegen, expression_path): + def emit_asserts(self, codegen: RustHelperBackend, expression_path: str) -> None: raise NotImplementedError('you\'re supposed to implement TestValue.emit_asserts') - def is_serializable(self): + def is_serializable(self) -> bool: # Not all types can round-trip back from Rust to JSON. return True - def test_suffix(self): + def test_suffix(self) -> str: return "" class TestStruct(TestValue): - def __init__(self, rust_generator: TestBackend, stone_type: ir.Struct, reference_impls, no_optional_fields): + def __init__( + self, + rust_generator: RustHelperBackend, + stone_type: ir.Struct, + reference_impls: Dict[str, Any], + no_optional_fields: bool, + ) -> None: super(TestStruct, self).__init__(rust_generator) if stone_type.has_enumerated_subtypes(): @@ -370,18 +423,26 @@ def __init__(self, rust_generator: TestBackend, stone_type: ir.Struct, reference raise RuntimeError(f'Error generating value for {stone_type.name}.{field.name}: {e}') self.fields.append(field_value) - def emit_asserts(self, codegen, expression_path): + def emit_asserts(self, codegen: RustHelperBackend, expression_path: str) -> None: for field in self.fields: field.emit_assert(codegen, expression_path) - def test_suffix(self): + def test_suffix(self) -> str: if self._no_optional_fields: return "_OnlyRequiredFields" else: return "" + class TestUnion(TestValue): - def __init__(self, rust_generator, stone_type, reference_impls, variant, no_optional_fields): + def __init__( + self, + rust_generator: RustHelperBackend, + stone_type: ir.Struct | ir.Union, # Struct because TestPolymorphicStruct also uses this + reference_impls: Dict[str, Module], + variant: ir.UnionField, + no_optional_fields: bool, + ) -> None: super(TestUnion, self).__init__(rust_generator) self._stone_type = stone_type self._reference_impls = reference_impls @@ -402,7 +463,7 @@ def __init__(self, rust_generator, stone_type, reference_impls, variant, no_opti self.value = self.get_from_inner_value(variant.name, self._inner_value) - def get_from_inner_value(self, variant_name, generated_field): + def get_from_inner_value(self, variant_name: str, generated_field: TestField) -> Any: pyname = fmt_py_class(self._stone_type.name) try: return self._reference_impls[self._stone_type.namespace.name] \ @@ -410,12 +471,13 @@ def get_from_inner_value(self, variant_name, generated_field): except Exception as e: raise RuntimeError(f'Error generating value for {self._stone_type.name}.{variant_name}: {e}') - def has_other_variants(self): - return len(self._stone_type.all_fields) > 1 or not self._stone_type.closed + def has_other_variants(self) -> bool: + return len(self._stone_type.all_fields) > 1 \ + or (isinstance(self._stone_type, ir.Union) and not self._stone_type.closed) - def emit_asserts(self, codegen, expression_path): + def emit_asserts(self, codegen: RustHelperBackend, expression_path: str) -> None: if expression_path[0] == '(' and expression_path[-1] == ')': - expression_path = expression_path[1:-1] # strip off superfluous parens + expression_path = expression_path[1:-1] # strip off superfluous parens with codegen.block(f'match {expression_path}'): path = f'::dropbox_sdk::{self._rust_namespace_name}::{self._rust_name}::{self._rust_variant_name}' @@ -430,10 +492,10 @@ def emit_asserts(self, codegen, expression_path): if self.has_other_variants(): codegen.emit('_ => panic!("wrong variant")') - def is_serializable(self): + def is_serializable(self) -> bool: return not self._variant.catch_all - def test_suffix(self): + def test_suffix(self) -> str: suf = "_" + self._rust_variant_name if self._no_optional_fields: suf += "_OnlyRequiredFields" @@ -441,16 +503,21 @@ def test_suffix(self): class TestPolymorphicStruct(TestUnion): - def get_from_inner_value(self, variant_name, generated_field): + def get_from_inner_value(self, variant_name: str, generated_field: TestField) -> Any: return generated_field.value - def has_other_variants(self): + def has_other_variants(self) -> bool: return len(self._stone_type.get_enumerated_subtypes()) > 1 \ or self._stone_type.is_catch_all() class TestList(TestValue): - def __init__(self, rust_generator, stone_type, reference_impls): + def __init__( + self, + rust_generator: RustHelperBackend, + stone_type: ir.DataType, + reference_impls: Dict[str, Module], + ) -> None: super(TestList, self).__init__(rust_generator) self._stone_type = stone_type self._reference_impls = reference_impls @@ -462,12 +529,17 @@ def __init__(self, rust_generator, stone_type, reference_impls): self.value = self._inner_value.value - def emit_asserts(self, codegen, expression_path): + def emit_asserts(self, codegen: RustHelperBackend, expression_path: str) -> None: self._inner_value.emit_assert(codegen, expression_path + '[0]') class TestMap(TestValue): - def __init__(self, rust_generator, stone_type, reference_impls): + def __init__( + self, + rust_generator: RustHelperBackend, + stone_type: ir.Map, + reference_impls: Dict[str, Module], + ) -> None: super(TestMap, self).__init__(rust_generator) self._stone_type = stone_type self._reference_impls = reference_impls @@ -477,22 +549,33 @@ def __init__(self, rust_generator, stone_type, reference_impls): reference_impls) self.value = {self._key_value.value: self._val_value.value} - def emit_asserts(self, codegen, expression_path): + def emit_asserts(self, codegen: RustHelperBackend, expression_path: str) -> None: key_str = f'["{self._key_value.value}"]' self._val_value.emit_assert(codegen, expression_path + key_str) # Make a TestField with a specific value. -def test_field_with_value(field_name, value, stone_type, rust_generator, reference_impls): +def test_field_with_value( + field_name: str, + value: Any, + stone_type: ir.DataType, + rust_generator: RustHelperBackend, + reference_impls: Dict[str, Module], +) -> TestField: typ, option = ir.unwrap_nullable(stone_type) inner = None if ir.is_tag_ref(value): + assert isinstance(stone_type, ir.Union) + assert isinstance(value, ir.TagRef) # TagRef means we need to instantiate the named variant of this union, so find the right # field (variant) of the union and change the value to a TestUnion of it + variant = None for f in stone_type.all_fields: + assert isinstance(f, ir.UnionField) if f.name == value.tag_name: variant = f break + assert variant, f"no appropriate variant found for tag name {value.tag_name}" inner = TestUnion(rust_generator, typ, reference_impls, variant, no_optional_fields=True) value = inner.value return TestField( @@ -505,7 +588,13 @@ def test_field_with_value(field_name, value, stone_type, rust_generator, referen # Make a TestField with an arbitrary value that satisfies constraints. If no_optional_fields is True # then optional or nullable fields will be left unset. -def make_test_field(field_name, stone_type, rust_generator, reference_impls, no_optional_fields): +def make_test_field( + field_name: Optional[str], + stone_type: ir.DataType, + rust_generator: RustHelperBackend, + reference_impls: Dict[str, Module], + no_optional_fields: bool = False, +) -> TestField: rust_name = rust_generator.field_name_raw(field_name) if field_name is not None else None typ, option = ir.unwrap_nullable(stone_type) @@ -560,15 +649,15 @@ class Unregex(object): Generate a minimal string that passes a regex and optionally is of a given minimum length. """ - def __init__(self, regex_string, min_len=None): + def __init__(self, regex_string: str, min_len: Optional[int] = None) -> None: self._min_len = min_len self._group_refs = {} self._tokens = sre_parse.parse(regex_string) - def generate(self): + def generate(self) -> str: return self._generate(self._tokens) - def _generate(self, tokens): + def _generate(self, tokens: Any) -> str: result = '' for (opcode, argument) in tokens: opcode = str(opcode).lower()