Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 56 additions & 24 deletions src/azure-cli/azure/cli/command_modules/batch/_command_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -689,11 +689,29 @@ def get_track1_validations(self, cls):
return filtered_members

def convert_to_track1_type(self, original_type):
# Handle Python 3.14 pipe union syntax at the top level: "A | B | None"
# Only applies when not inside brackets (e.g. not List[A | B])
if original_type is not None and " | " in original_type and "[" not in original_type:
parts = [p.strip() for p in original_type.split(' | ')]
non_none_parts = [p for p in parts if p != 'None']
if non_none_parts:
if len(non_none_parts) > 1:
original_type = next((p for p in non_none_parts if p != 'str'), non_none_parts[0])
else:
original_type = non_none_parts[0]
# Handle Python 3.14 pipe union syntax inside brackets: "List[str | SomeType]"
# Replace inner "str | X" with just "X"
if original_type is not None and " | " in original_type:
original_type = re.sub(r'\bstr\b\s*\|\s*', '', original_type)
if original_type is not None and "ForwardRef" in original_type:
pattern = r"ForwardRef\('_models\.(.*?)'\)"
original_type = re.sub(pattern, r'\1', original_type)
if original_type is not None and "_models." in original_type:
original_type = original_type.replace("_models.", "")
if original_type is not None and "_enums." in original_type:
original_type = original_type.replace("_enums.", "")
if original_type is not None and "azure.batch.models." in original_type:
original_type = original_type.replace("azure.batch.models.", "")
if original_type is not None and "typing.List" in original_type:
original_type = original_type.replace("typing.List", "List")
if original_type is not None and "typing.Dict" in original_type:
Expand All @@ -710,8 +728,8 @@ def convert_to_track1_type(self, original_type):
pattern = r"typing\.Union\[str, (.+?)\]"
original_type = re.sub(pattern, r"\1", original_type)

if original_type is not None and "<class" in original_type:
pattern = r"<class '([\w\.]+)'>"
if original_type is not None and ("<class" in original_type or "<enum" in original_type):
pattern = r"<(?:class|enum) '([\w\.]+)'>"
match = re.search(pattern, original_type)
if match:
original_type = match.group(1)
Expand All @@ -735,31 +753,42 @@ def get_track1_rest_names(self, cls):
rest_names[name] = rest_name
return rest_names

def _resolve_track1_type_hint(self, type_hint):
"""Resolve type hints to the legacy track1 type string format."""
args = get_args(type_hint)

# Optional[T] / Union[..., None] -> select the best non-None candidate.
if type(None) in args:
non_none_args = [arg for arg in args if arg is not type(None)]
preferred_args = [arg for arg in non_none_args if arg != str] or non_none_args
selected = preferred_args[0] if preferred_args else type_hint
return self.convert_to_track1_type(str(selected))

# Union[str, X] -> prefer X for command argument flattening.
if args and str in args:
non_str_args = [arg for arg in args if arg != str]
if non_str_args:
return self.convert_to_track1_type(str(non_str_args[0]))

return self.convert_to_track1_type(str(type_hint))

def get_track1_attribute_map(self, cls):
# pylint: disable=protected-access
member_types = {}
pattern1 = r"^typing\.Union\[str, (.+), NoneType\]$"
pattern2 = r"^typing\.Union\[(.+), NoneType\]$"
pattern3 = r"^typing\.Optional\[(.+)\]$"

rest_names = self.get_track1_rest_names(cls)
for name, typ in cls.__annotations__.items():
if hasattr(typ, '_name') and typ._name is not None and typ._name == 'Optional':
track1_type = self.convert_to_track1_type(str(get_args(typ)[0]))
else:
track1_type = str(typ)

if re.match(pattern1, track1_type):
track1_type = self.convert_to_track1_type(str(get_args(typ)[1]))
elif re.match(pattern2, track1_type):
track1_type = self.convert_to_track1_type(str(get_args(typ)[0]))
elif re.match(pattern3, track1_type):
track1_type = self.convert_to_track1_type(str(get_args(typ)[0]))
else:
track1_type = self.convert_to_track1_type(track1_type)

if rest_names[name] is None:
print("none")
# Use get_type_hints to resolve ForwardRef strings and get resolved type information
globalns = {}
globalns.update(vars(importlib.import_module(cls.__module__)))
# Azure Batch model annotations use aliases like `_models.Foo` and `_enums.Bar`.
# `_models` aliases resolve via azure.batch.models exports; `_enums` points to the generated enums module.
globalns['_models'] = importlib.import_module('azure.batch.models')
globalns['_enums'] = importlib.import_module('azure.batch.models._enums')
hints = get_type_hints(cls, globalns=globalns)

for name, type_hint in hints.items():
track1_type = self._resolve_track1_type_hint(type_hint)
member_types[name] = {'key': rest_names[name], 'type': track1_type}

return member_types
Expand All @@ -770,14 +799,17 @@ def get_optional_state(self, cls):
globalns = {}
# Add the global namespace of the module where the class is defined
globalns.update(vars(importlib.import_module(cls.__module__)))
# azure batch models uses an alias _models which throws off the get_type_hints eval, need this to correct
# Azure Batch model annotations use aliases like `_models.Foo` and `_enums.Bar`.
# `_models` aliases resolve via azure.batch.models exports; `_enums` points to the generated enums module.
globalns['_models'] = importlib.import_module('azure.batch.models')
globalns['_enums'] = importlib.import_module('azure.batch.models._enums')

members = get_type_hints(cls, globalns=globalns)
filtered_members = {}
for name, type_hint in members.items():
is_optional = (type_hint._name == 'Optional' or type_hint._name is None
if hasattr(type_hint, '_name') else False)
# Use get_args() to detect optional types (stable across Python 3.13 and 3.14)
args = get_args(type_hint)
is_optional = type(None) in args
filtered_members[name] = {'required': not is_optional}
return filtered_members

Expand Down
Loading