Skip to content

Commit

Permalink
Typing fixes for new version of mypy
Browse files Browse the repository at this point in the history
  • Loading branch information
avylove committed Mar 20, 2023
1 parent fe2bf24 commit 5d6686b
Show file tree
Hide file tree
Showing 6 changed files with 107 additions and 148 deletions.
6 changes: 2 additions & 4 deletions lisa/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,15 +486,13 @@ def load_environments(
class EnvironmentHookSpec:
@hookspec
def get_environment_information(self, environment: Environment) -> Dict[str, str]:
...
raise NotImplementedError


class EnvironmentHookImpl:
@hookimpl
def get_environment_information(self, environment: Environment) -> Dict[str, str]:
information: Dict[str, str] = {}
information["name"] = environment.name

information: Dict[str, str] = {"name": environment.name}
if environment.nodes:
node = environment.default_node
try:
Expand Down
2 changes: 1 addition & 1 deletion lisa/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -733,7 +733,7 @@ def quick_connect(
class NodeHookSpec:
@hookspec
def get_node_information(self, node: Node) -> Dict[str, str]:
...
raise NotImplementedError


class NodeHookImpl:
Expand Down
13 changes: 5 additions & 8 deletions lisa/sut_orchestrator/aws/common.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from dataclasses import InitVar, dataclass, field
from dataclasses import dataclass, field
from typing import Dict, List, Optional

from dataclasses_json import dataclass_json
Expand Down Expand Up @@ -67,16 +67,13 @@ class AwsNodeSchema:
data_disk_size: int = 32
disk_type: str = ""

# for marketplace image, which need to accept terms
_marketplace: InitVar[Optional[AwsVmMarketplaceSchema]] = None
def __post_init__(self) -> None:
# Caching for marketplace image
self._marketplace: Optional[AwsVmMarketplaceSchema] = None

@property
def marketplace(self) -> AwsVmMarketplaceSchema:
# this is a safe guard and prevent mypy error on typing
if not hasattr(self, "_marketplace"):
self._marketplace: Optional[AwsVmMarketplaceSchema] = None

if not self._marketplace:
if self._marketplace is None:
assert isinstance(
self.marketplace_raw, str
), f"actual: {type(self.marketplace_raw)}"
Expand Down
216 changes: 96 additions & 120 deletions lisa/sut_orchestrator/azure/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import re
import sys
from dataclasses import InitVar, dataclass, field
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from pathlib import Path
from threading import Lock
Expand Down Expand Up @@ -189,13 +189,12 @@ class AzureNodeSchema:
# image.
is_linux: Optional[bool] = None

_marketplace: InitVar[Optional[AzureVmMarketplaceSchema]] = None
def __post_init__(self) -> None:
# Caching
self._marketplace: Optional[AzureVmMarketplaceSchema] = None
self._shared_gallery: Optional[SharedImageGallerySchema] = None
self._vhd: Optional[VhdSchema] = None

_shared_gallery: InitVar[Optional[SharedImageGallerySchema]] = None

_vhd: InitVar[Optional[VhdSchema]] = None

def __post_init__(self, *args: Any, **kwargs: Any) -> None:
# trim whitespace of values.
strip_strs(
self,
Expand All @@ -218,109 +217,96 @@ def __post_init__(self, *args: Any, **kwargs: Any) -> None:

@property
def marketplace(self) -> Optional[AzureVmMarketplaceSchema]:
# this is a safe guard and prevent mypy error on typing
if not hasattr(self, "_marketplace"):
self._marketplace: Optional[AzureVmMarketplaceSchema] = None
marketplace: Optional[AzureVmMarketplaceSchema] = self._marketplace
if not marketplace:
if isinstance(self.marketplace_raw, dict):
if self._marketplace is not None:
return self._marketplace

if isinstance(self.marketplace_raw, dict):
# Users decide the cases of image names,
# inconsistent cases cause a mismatch error in notifiers.
# lower() normalizes the image names, it has no impact on deployment
self.marketplace_raw = {
k: v.lower() for k, v in self.marketplace_raw.items()
}
self._marketplace = schema.load_by_type(
AzureVmMarketplaceSchema, self.marketplace_raw
)
# Validated marketplace_raw and filter out any unwanted content
self.marketplace_raw = self._marketplace.to_dict() # type: ignore

elif self.marketplace_raw:
assert isinstance(
self.marketplace_raw, str
), f"actual: {type(self.marketplace_raw)}"

self.marketplace_raw = self.marketplace_raw.strip()

if self.marketplace_raw:
# Users decide the cases of image names,
# the inconsistent cases cause the mismatched error in notifiers.
# The lower() normalizes the image names,
# it has no impact on deployment.
self.marketplace_raw = dict(
(k, v.lower()) for k, v in self.marketplace_raw.items()
)
marketplace = schema.load_by_type(
AzureVmMarketplaceSchema, self.marketplace_raw
)
# this step makes marketplace_raw is validated, and
# filter out any unwanted content.
self.marketplace_raw = marketplace.to_dict() # type: ignore
elif self.marketplace_raw:
assert isinstance(
self.marketplace_raw, str
), f"actual: {type(self.marketplace_raw)}"

self.marketplace_raw = self.marketplace_raw.strip()

if self.marketplace_raw:
# Users decide the cases of image names,
# the inconsistent cases cause the mismatched error in notifiers.
# The lower() normalizes the image names,
# it has no impact on deployment.
marketplace_strings = re.split(
r"[:\s]+", self.marketplace_raw.lower()
# inconsistent cases cause a mismatch error in notifiers.
# lower() normalizes the image names, it has no impact on deployment
marketplace_strings = re.split(r"[:\s]+", self.marketplace_raw.lower())

if len(marketplace_strings) != 4:
raise LisaException(
"Invalid value for the provided marketplace "
f"parameter: '{self.marketplace_raw}'."
"The marketplace parameter should be in the format: "
"'<Publisher> <Offer> <Sku> <Version>' "
"or '<Publisher>:<Offer>:<Sku>:<Version>'"
)
self._marketplace = AzureVmMarketplaceSchema(*marketplace_strings)
# marketplace_raw is used
self.marketplace_raw = (
self._marketplace.to_dict() # type: ignore [attr-defined]
)

if len(marketplace_strings) == 4:
marketplace = AzureVmMarketplaceSchema(*marketplace_strings)
# marketplace_raw is used
self.marketplace_raw = marketplace.to_dict() # type: ignore
else:
raise LisaException(
f"Invalid value for the provided marketplace "
f"parameter: '{self.marketplace_raw}'."
f"The marketplace parameter should be in the format: "
f"'<Publisher> <Offer> <Sku> <Version>' "
f"or '<Publisher>:<Offer>:<Sku>:<Version>'"
)
self._marketplace = marketplace
return marketplace
return self._marketplace

@marketplace.setter
def marketplace(self, value: Optional[AzureVmMarketplaceSchema]) -> None:
self._marketplace = value
if value is None:
self.marketplace_raw = None
else:
self.marketplace_raw = value.to_dict() # type: ignore
# dataclass_json doesn't use a protocol return type, so to_dict() is unknown
self.marketplace_raw = (
None if value is None else value.to_dict() # type: ignore [attr-defined]
)

@property
def shared_gallery(self) -> Optional[SharedImageGallerySchema]:
# this is a safe guard and prevent mypy error on typing
if not hasattr(self, "_shared_gallery"):
self._shared_gallery: Optional[SharedImageGallerySchema] = None
shared_gallery: Optional[SharedImageGallerySchema] = self._shared_gallery
if shared_gallery:
return shared_gallery
if self._shared_gallery is not None:
return self._shared_gallery

if isinstance(self.shared_gallery_raw, dict):
# Users decide the cases of image names,
# the inconsistent cases cause the mismatched error in notifiers.
# The lower() normalizes the image names,
# it has no impact on deployment.
self.shared_gallery_raw = dict(
(k, v.lower()) for k, v in self.shared_gallery_raw.items()
)
shared_gallery = schema.load_by_type(
# inconsistent cases cause a mismatch error in notifiers.
# lower() normalizes the image names, it has no impact on deployment
self.shared_gallery_raw = {
k: v.lower() for k, v in self.shared_gallery_raw.items()
}

self._shared_gallery = schema.load_by_type(
SharedImageGallerySchema, self.shared_gallery_raw
)
if not shared_gallery.subscription_id:
shared_gallery.subscription_id = self.subscription_id
# this step makes shared_gallery_raw is validated, and
# filter out any unwanted content.
self.shared_gallery_raw = shared_gallery.to_dict() # type: ignore
if not self._shared_gallery.subscription_id:
self._shared_gallery.subscription_id = self.subscription_id
# Validated shared_gallery_raw and filter out any unwanted content
self.shared_gallery_raw = self._shared_gallery.to_dict() # type: ignore

elif self.shared_gallery_raw:
assert isinstance(
self.shared_gallery_raw, str
), f"actual: {type(self.shared_gallery_raw)}"
# Users decide the cases of image names,
# the inconsistent cases cause the mismatched error in notifiers.
# The lower() normalizes the image names,
# it has no impact on deployment.
# inconsistent cases cause a mismatch error in notifiers.
# lower() normalizes the image names, it has no impact on deployment
shared_gallery_strings = re.split(
r"[/]+", self.shared_gallery_raw.strip().lower()
)
if len(shared_gallery_strings) == 5:
shared_gallery = SharedImageGallerySchema(*shared_gallery_strings)
# shared_gallery_raw is used
self.shared_gallery_raw = shared_gallery.to_dict() # type: ignore
self._shared_gallery = SharedImageGallerySchema(*shared_gallery_strings)
elif len(shared_gallery_strings) == 3:
shared_gallery = SharedImageGallerySchema(
self._shared_gallery = SharedImageGallerySchema(
self.subscription_id, None, *shared_gallery_strings
)
# shared_gallery_raw is used
self.shared_gallery_raw = shared_gallery.to_dict() # type: ignore
else:
raise LisaException(
f"Invalid value for the provided shared gallery "
Expand All @@ -330,51 +316,43 @@ def shared_gallery(self) -> Optional[SharedImageGallerySchema]:
f"<image_definition>/<image_version>' or '<image_gallery>/"
f"<image_definition>/<image_version>'"
)
self._shared_gallery = shared_gallery
return shared_gallery
self.shared_gallery_raw = self._shared_gallery.to_dict() # type: ignore

return self._shared_gallery

@shared_gallery.setter
def shared_gallery(self, value: Optional[SharedImageGallerySchema]) -> None:
self._shared_gallery = value
if value is None:
self.shared_gallery_raw = None
else:
self.shared_gallery_raw = value.to_dict() # type: ignore
# dataclass_json doesn't use a protocol return type, so to_dict() is unknown
self.shared_gallery_raw = (
None if value is None else value.to_dict() # type: ignore [attr-defined]
)

@property
def vhd(self) -> Optional[VhdSchema]:
# this is a safe guard and prevent mypy error on typing
if not hasattr(self, "_vhd"):
self._vhd: Optional[VhdSchema] = None
vhd: Optional[VhdSchema] = self._vhd
if vhd:
return vhd
if self._vhd is not None:
return self._vhd

if isinstance(self.vhd_raw, dict):
vhd = schema.load_by_type(VhdSchema, self.vhd_raw)
add_secret(vhd.vhd_path, PATTERN_URL)
if vhd.vmgs_path:
add_secret(vhd.vmgs_path, PATTERN_URL)
# this step makes vhd_raw is validated, and
# filter out any unwanted content.
self.vhd_raw = vhd.to_dict() # type: ignore
self._vhd = schema.load_by_type(VhdSchema, self.vhd_raw)
add_secret(self._vhd.vhd_path, PATTERN_URL)
if self._vhd.vmgs_path:
add_secret(self._vhd.vmgs_path, PATTERN_URL)
# Validated vhd_raw and filter out any unwanted content
self.vhd_raw = self._vhd.to_dict() # type: ignore

elif self.vhd_raw is not None:
assert isinstance(self.vhd_raw, str), f"actual: {type(self.vhd_raw)}"
vhd = VhdSchema(self.vhd_raw)
add_secret(vhd.vhd_path, PATTERN_URL)
self.vhd_raw = vhd.to_dict() # type: ignore
self._vhd = vhd
if vhd:
return vhd
else:
return None
self._vhd = VhdSchema(self.vhd_raw)
add_secret(self._vhd.vhd_path, PATTERN_URL)
self.vhd_raw = self._vhd.to_dict() # type: ignore

return self._vhd

@vhd.setter
def vhd(self, value: Optional[VhdSchema]) -> None:
self._vhd = value
if value is None:
self.vhd_raw = None
else:
self.vhd_raw = self._vhd.to_dict() # type: ignore
self.vhd_raw = None if value is None else self._vhd.to_dict() # type: ignore

def get_image_name(self) -> str:
result = ""
Expand All @@ -385,7 +363,7 @@ def get_image_name(self) -> str:
self.shared_gallery_raw, dict
), f"actual type: {type(self.shared_gallery_raw)}"
if self.shared_gallery.resource_group_name:
result = "/".join([x for x in self.shared_gallery_raw.values()])
result = "/".join(self.shared_gallery_raw.values())
else:
result = (
f"{self.shared_gallery.image_gallery}/"
Expand All @@ -396,7 +374,7 @@ def get_image_name(self) -> str:
assert isinstance(
self.marketplace_raw, dict
), f"actual type: {type(self.marketplace_raw)}"
result = " ".join([x for x in self.marketplace_raw.values()])
result = " ".join(self.marketplace_raw.values())
return result


Expand All @@ -420,9 +398,7 @@ def from_node_runbook(cls, runbook: AzureNodeSchema) -> "AzureNodeArmParameter":
parameters["vhd_raw"] = parameters["vhd"]
del parameters["vhd"]

arm_parameters = AzureNodeArmParameter(**parameters)

return arm_parameters
return AzureNodeArmParameter(**parameters)


class DataDiskCreateOption:
Expand Down
2 changes: 1 addition & 1 deletion lisa/sut_orchestrator/azure/platform_.py
Original file line number Diff line number Diff line change
Expand Up @@ -2162,7 +2162,7 @@ def _get_vhd_os_disk_size(self, blob_url: str) -> int:
assert properties.size, f"fail to get blob size of {blob_url}"
# Azure requires only megabyte alignment of vhds, round size up
# for cases where the size is megabyte aligned
return math.ceil(properties.size / 1024 / 1024 / 1024)
return int(math.ceil(properties.size / 1024 / 1024 / 1024))

def _get_sig_info(
self, shared_image: SharedImageGallerySchema
Expand Down
Loading

0 comments on commit 5d6686b

Please sign in to comment.