diff --git a/common_utility/configLoader.py b/common_utility/configLoader.py index 300d9d3..daf692b 100644 --- a/common_utility/configLoader.py +++ b/common_utility/configLoader.py @@ -16,7 +16,7 @@ class IConfigLoader(object): def load(self, argument_parser: ArgumentParser) -> Namespace: raise NotImplementedError() - def dump(self, argument_parser: ArgumentParser, config: Namespace, file: TextIO = sys.stdout) -> None: + def dump(self, argument_parser: ArgumentParser, file: TextIO = sys.stdout) -> None: raise NotImplementedError() @@ -60,7 +60,8 @@ def load(self, argument_parser: ArgumentParser) -> Namespace: return Namespace(**configuration) - def dump(self, argument_parser: ArgumentParser, config: Namespace, file: TextIO = sys.stdout) -> None: + def dump(self, argument_parser: ArgumentParser, file: TextIO = sys.stdout) -> None: + arguments = argument_parser.parse_known_args()[0] for group in argument_parser._action_groups: section = group.title if group.title else 'DEFAULT' values = {} @@ -68,7 +69,7 @@ def dump(self, argument_parser: ArgumentParser, config: Namespace, file: TextIO for action in group._group_actions: if not action.dest or action.dest == "help": continue - value = getattr(config, action.dest, None) + value = getattr(arguments, action.dest, None) if value is None: continue values[action.dest] = str(value) diff --git a/tests/configLoaderTest.py b/tests/configLoaderTest.py index 8ac7756..17b4972 100644 --- a/tests/configLoaderTest.py +++ b/tests/configLoaderTest.py @@ -1,11 +1,11 @@ import sys import unittest -from argparse import ArgumentParser, Namespace +from argparse import ArgumentParser, _ArgumentGroup from configparser import ConfigParser from io import StringIO from pathlib import Path from unittest import TestCase -from unittest.mock import patch +from unittest.mock import patch, MagicMock from context_logger import setup_logging @@ -222,11 +222,11 @@ def test_dump_when_values_present_then_write_config_sections(self): network_group.add_argument('--port') runtime_group = argument_parser.add_argument_group('runtime') runtime_group.add_argument('--debug') - config = Namespace(host='localhost', port=8080, debug=True) output = StringIO() # When - config_loader.dump(argument_parser, config, output) + with patch.object(sys, 'argv', ['test', '--host', 'localhost', '--port', '8080', '--debug', 'True']): + config_loader.dump(argument_parser, output) # Then parser = ConfigParser(interpolation=None) @@ -242,11 +242,11 @@ def test_dump_when_value_is_none_then_skip_value(self): runtime_group = argument_parser.add_argument_group('runtime') runtime_group.add_argument('--timeout') runtime_group.add_argument('--retries') - config = Namespace(timeout=None, retries=3) output = StringIO() # When - config_loader.dump(argument_parser, config, output) + with patch.object(sys, 'argv', ['test', '--retries', '3']): + config_loader.dump(argument_parser, output) # Then parser = ConfigParser(interpolation=None) @@ -261,11 +261,11 @@ def test_dump_when_all_values_in_group_are_none_then_omit_section(self): argument_parser = ArgumentParser() secret_group = argument_parser.add_argument_group('secret') secret_group.add_argument('--token') - config = Namespace(token=None) output = StringIO() # When - config_loader.dump(argument_parser, config, output) + with patch.object(sys, 'argv', ['test']): + config_loader.dump(argument_parser, output) # Then self.assertNotIn('[secret]', output.getvalue()) @@ -275,18 +275,16 @@ def test_dump_when_group_has_no_title_then_uses_default_section(self): config_loader = ConfigLoader(Path(DEFAULT_CONFIG_FILE)) argument_parser = ArgumentParser(add_help=False) region_action = argument_parser.add_argument('--region') + dummy_group = MagicMock(spec=_ArgumentGroup) + dummy_group.title = None + dummy_group._group_actions = [region_action] - class DummyGroup(object): - def __init__(self): - self.title = None - self._group_actions = [region_action] - - argument_parser._action_groups = [DummyGroup()] - config = Namespace(region='eu-central') + argument_parser._action_groups = [dummy_group] output = StringIO() # When - config_loader.dump(argument_parser, config, output) + with patch.object(sys, 'argv', ['test', '--region', 'eu-central']): + config_loader.dump(argument_parser, output) # Then parser = ConfigParser(interpolation=None)