From fddbac60802bd8a102493fe08d799621f42aeaaa Mon Sep 17 00:00:00 2001 From: Emily Soth Date: Mon, 5 May 2025 14:33:53 -0700 Subject: [PATCH] rewrite timeout to work as a decorator; ModelSpec to_json method --- src/natcap/invest/cli.py | 2 +- src/natcap/invest/spec_utils.py | 124 ++++++++++++++------------- src/natcap/invest/ui_server.py | 2 +- tests/test_model_specs.py | 2 +- tests/test_translation.py | 146 ++++++++++++++++---------------- tests/test_validation.py | 10 ++- 6 files changed, 146 insertions(+), 140 deletions(-) diff --git a/src/natcap/invest/cli.py b/src/natcap/invest/cli.py index b4551a145..bb6eba24c 100644 --- a/src/natcap/invest/cli.py +++ b/src/natcap/invest/cli.py @@ -438,7 +438,7 @@ def main(user_args=None): importlib.import_module(name=target_model)) spec = model_module.MODEL_SPEC - message = spec_utils.serialize_args_spec(spec) + message = spec.to_json() sys.stdout.write(message) parser.exit(0) diff --git a/src/natcap/invest/spec_utils.py b/src/natcap/invest/spec_utils.py index 1eeb5344c..bbcae004c 100644 --- a/src/natcap/invest/spec_utils.py +++ b/src/natcap/invest/spec_utils.py @@ -31,7 +31,7 @@ LOGGER = logging.getLogger(__name__) # accessing a file could take a long time if it's in a file streaming service # to prevent the UI from hanging due to slow validation, # set a timeout for these functions. -def timeout(func, *args, timeout=5, **kwargs): +def timeout(func, timeout=5): """Stop a function after a given amount of time. Args: @@ -51,26 +51,28 @@ def timeout(func, *args, timeout=5, **kwargs): # the target function puts the return value from `func` into shared memory message_queue = queue.Queue() - def wrapper_func(): - message_queue.put(func(*args, **kwargs)) + def wrapper(*args, **kwargs): + def put_fn(): + message_queue.put(func(*args, **kwargs)) + thread = threading.Thread(target=put_fn) + LOGGER.debug(f'Starting file checking thread with timeout={timeout}') + thread.start() + thread.join(timeout=timeout) + if thread.is_alive(): + # first arg to `check_csv`, `check_raster`, `check_vector` is the path + warnings.warn( + f'Validation of file {args[0]} timed out. If this file ' + 'is stored in a file streaming service, it may be taking a long ' + 'time to download. Try storing it locally instead.') + return None - thread = threading.Thread(target=wrapper_func) - LOGGER.debug(f'Starting file checking thread with timeout={timeout}') - thread.start() - thread.join(timeout=timeout) - if thread.is_alive(): - # first arg to `check_csv`, `check_raster`, `check_vector` is the path - warnings.warn( - f'Validation of file {args[0]} timed out. If this file ' - 'is stored in a file streaming service, it may be taking a long ' - 'time to download. Try storing it locally instead.') - return None + else: + LOGGER.debug('File checking thread completed.') + # get any warning messages returned from the thread + a = message_queue.get() + return a - else: - LOGGER.debug('File checking thread completed.') - # get any warning messages returned from the thread - a = message_queue.get() - return a + return wrapper def check_headers(expected_headers, actual_headers, header_type='header'): """Validate that expected headers are in a list of actual headers. @@ -879,7 +881,6 @@ class UISpec: hidden: list = None dropdown_functions: dict = dataclasses.field(default_factory=dict) - @dataclasses.dataclass class ModelSpec: model_id: str @@ -898,6 +899,49 @@ class ModelSpec: def get_input(self, key): return self.inputs_dict[key] + def to_json(self): + """Serialize an MODEL_SPEC dict to a JSON string. + + Args: + spec (dict): An invest model's MODEL_SPEC. + + Raises: + TypeError if any object type within the spec is not handled by + json.dumps or by the fallback serializer. + + Returns: + JSON String + """ + + def fallback_serializer(obj): + """Serialize objects that are otherwise not JSON serializeable.""" + if isinstance(obj, pint.Unit): + return format_unit(obj) + # Sets are present in 'geometries' attributes of some args + # We don't need to worry about deserializing back to a set/array + # so casting to string is okay. + elif isinstance(obj, set): + return str(obj) + elif isinstance(obj, types.FunctionType): + return str(obj) + elif dataclasses.is_dataclass(obj): + as_dict = dataclasses.asdict(obj) + if hasattr(obj, 'type'): + as_dict['type'] = obj.type + return as_dict + elif isinstance(obj, IterableWithDotAccess): + return obj.to_json() + raise TypeError(f'fallback serializer is missing for {type(obj)}') + + spec_dict = self.__dict__.copy() + # rename 'inputs' to 'args' to stay consistent with the old api + spec_dict.pop('inputs') + spec_dict.pop('inputs_dict') + spec_dict.pop('outputs_dict') + spec_dict['args'] = self.inputs_dict + spec_dict['outputs'] = self.outputs_dict + return json.dumps(spec_dict, default=fallback_serializer, ensure_ascii=False) + def build_model_spec(model_spec): inputs = [ @@ -1370,46 +1414,6 @@ def format_unit(unit): formatted_unit = formatted_unit.replace('currency', gettext('currency units')) return formatted_unit - -def serialize_args_spec(spec): - """Serialize an MODEL_SPEC dict to a JSON string. - - Args: - spec (dict): An invest model's MODEL_SPEC. - - Raises: - TypeError if any object type within the spec is not handled by - json.dumps or by the fallback serializer. - - Returns: - JSON String - """ - - def fallback_serializer(obj): - """Serialize objects that are otherwise not JSON serializeable.""" - if isinstance(obj, pint.Unit): - return format_unit(obj) - # Sets are present in 'geometries' attributes of some args - # We don't need to worry about deserializing back to a set/array - # so casting to string is okay. - elif isinstance(obj, set): - return str(obj) - elif isinstance(obj, types.FunctionType): - return str(obj) - elif dataclasses.is_dataclass(obj): - as_dict = dataclasses.asdict(obj) - if hasattr(obj, 'type'): - as_dict['type'] = obj.type - return as_dict - elif isinstance(obj, IterableWithDotAccess): - return obj.to_json() - raise TypeError(f'fallback serializer is missing for {type(obj)}') - - spec_dict = json.loads(json.dumps(spec, default=fallback_serializer, ensure_ascii=False)) - spec_dict['args'] = spec_dict.pop('inputs') - return json.dumps(spec_dict, ensure_ascii=False) - - # accepted geometries for a vector will be displayed in this order GEOMETRY_ORDER = [ 'POINT', diff --git a/src/natcap/invest/ui_server.py b/src/natcap/invest/ui_server.py index 3c28920e0..259fab66b 100644 --- a/src/natcap/invest/ui_server.py +++ b/src/natcap/invest/ui_server.py @@ -68,7 +68,7 @@ def get_invest_getspec(): importlib.reload(natcap.invest.validation) model_module = importlib.reload( importlib.import_module(name=target_module)) - return spec_utils.serialize_args_spec(model_module.MODEL_SPEC) + return model_module.MODEL_SPEC.to_json() @app.route(f'/{PREFIX}/dynamic_dropdowns', methods=['POST']) diff --git a/tests/test_model_specs.py b/tests/test_model_specs.py index b996b0708..9438dc76d 100644 --- a/tests/test_model_specs.py +++ b/tests/test_model_specs.py @@ -509,7 +509,7 @@ class ValidateModelSpecs(unittest.TestCase): for pyname in model_id_to_pyname.values(): model = importlib.import_module(pyname) - spec_utils.serialize_args_spec(model.MODEL_SPEC) + model.MODEL_SPEC.to_json() class SpecUtilsTests(unittest.TestCase): diff --git a/tests/test_translation.py b/tests/test_translation.py index 84ffc4679..f145a2365 100644 --- a/tests/test_translation.py +++ b/tests/test_translation.py @@ -76,16 +76,16 @@ class TranslationTests(unittest.TestCase): self.locales_patcher.stop() shutil.rmtree(self.workspace_dir) - # def test_invest_list(self): - # """Translation: test that CLI list output is translated.""" - # from natcap.invest import cli - # with patch('sys.stdout', new=io.StringIO()) as out: - # with self.assertRaises(SystemExit): - # cli.main(['--language', TEST_LANG, 'list']) - # result = out.getvalue() - # self.assertIn(TEST_MESSAGES['Available models:'], result) - # self.assertIn( - # TEST_MESSAGES['Carbon Storage and Sequestration'], result) + def test_invest_list(self): + """Translation: test that CLI list output is translated.""" + from natcap.invest import cli + with patch('sys.stdout', new=io.StringIO()) as out: + with self.assertRaises(SystemExit): + cli.main(['--language', TEST_LANG, 'list']) + result = out.getvalue() + self.assertIn(TEST_MESSAGES['Available models:'], result) + self.assertIn( + TEST_MESSAGES['Carbon Storage and Sequestration'], result) def test_invest_getspec(self): """Translation: test that CLI getspec output is translated.""" @@ -96,72 +96,72 @@ class TranslationTests(unittest.TestCase): result = out.getvalue() self.assertIn(TEST_MESSAGES['baseline LULC'], result) - # def test_invest_validate(self): - # """Translation: test that CLI validate output is translated.""" - # datastack = { # write datastack to a JSON file - # 'model_id': 'carbon', - # 'invest_version': '0.0', - # 'args': {} - # } - # datastack_path = os.path.join(self.workspace_dir, 'datastack.json') - # with open(datastack_path, 'w') as file: - # json.dump(datastack, file) + def test_invest_validate(self): + """Translation: test that CLI validate output is translated.""" + datastack = { # write datastack to a JSON file + 'model_id': 'carbon', + 'invest_version': '0.0', + 'args': {} + } + datastack_path = os.path.join(self.workspace_dir, 'datastack.json') + with open(datastack_path, 'w') as file: + json.dump(datastack, file) - # from natcap.invest import cli - # with patch('sys.stdout', new=io.StringIO()) as out: - # with self.assertRaises(SystemExit): - # cli.main( - # ['--language', TEST_LANG, 'validate', datastack_path]) + from natcap.invest import cli + with patch('sys.stdout', new=io.StringIO()) as out: + with self.assertRaises(SystemExit): + cli.main( + ['--language', TEST_LANG, 'validate', datastack_path]) - # result = out.getvalue() - # self.assertIn(TEST_MESSAGES[missing_key_msg], result) + result = out.getvalue() + self.assertIn(TEST_MESSAGES[missing_key_msg], result) - # def test_server_get_invest_models(self): - # """Translation: test that /models endpoint is translated.""" - # from natcap.invest import ui_server - # test_client = ui_server.app.test_client() - # response = test_client.get( - # 'api/models', query_string={'language': TEST_LANG}) - # result = json.loads(response.get_data(as_text=True)) - # self.assertIn( - # TEST_MESSAGES['Carbon Storage and Sequestration'], - # [val['model_title'] for val in result.values()]) + def test_server_get_invest_models(self): + """Translation: test that /models endpoint is translated.""" + from natcap.invest import ui_server + test_client = ui_server.app.test_client() + response = test_client.get( + 'api/models', query_string={'language': TEST_LANG}) + result = json.loads(response.get_data(as_text=True)) + self.assertIn( + TEST_MESSAGES['Carbon Storage and Sequestration'], + [val['model_title'] for val in result.values()]) - # def test_server_get_invest_getspec(self): - # """Translation: test that /getspec endpoint is translated.""" - # from natcap.invest import ui_server - # test_client = ui_server.app.test_client() - # response = test_client.post( - # 'api/getspec', json='carbon', query_string={'language': TEST_LANG}) - # spec = json.loads(response.get_data(as_text=True)) - # self.assertEqual( - # spec['inputs']['lulc_bas_path']['name'], - # TEST_MESSAGES['baseline LULC']) + def test_server_get_invest_getspec(self): + """Translation: test that /getspec endpoint is translated.""" + from natcap.invest import ui_server + test_client = ui_server.app.test_client() + response = test_client.post( + 'api/getspec', json='carbon', query_string={'language': TEST_LANG}) + spec = json.loads(response.get_data(as_text=True)) + self.assertEqual( + spec['args']['lulc_bas_path']['name'], + TEST_MESSAGES['baseline LULC']) - # def test_server_get_invest_validate(self): - # """Translation: test that /validate endpoint is translated.""" - # from natcap.invest import ui_server - # from natcap.invest import carbon - # test_client = ui_server.app.test_client() - # payload = { - # 'model_id': carbon.MODEL_SPEC.model_id, - # 'args': json.dumps({}) - # } - # response = test_client.post( - # 'api/validate', json=payload, - # query_string={'language': TEST_LANG}) - # results = json.loads(response.get_data(as_text=True)) - # messages = [item[1] for item in results] - # self.assertIn(TEST_MESSAGES[missing_key_msg], messages) + def test_server_get_invest_validate(self): + """Translation: test that /validate endpoint is translated.""" + from natcap.invest import ui_server + from natcap.invest import carbon + test_client = ui_server.app.test_client() + payload = { + 'model_id': carbon.MODEL_SPEC.model_id, + 'args': json.dumps({}) + } + response = test_client.post( + 'api/validate', json=payload, + query_string={'language': TEST_LANG}) + results = json.loads(response.get_data(as_text=True)) + messages = [item[1] for item in results] + self.assertIn(TEST_MESSAGES[missing_key_msg], messages) - # def test_translate_formatted_string(self): - # """Translation: test that f-string can be translated.""" - # from natcap.invest import carbon, validation, set_locale - # set_locale(TEST_LANG) - # importlib.reload(validation) - # importlib.reload(carbon) - # args = {'n_workers': 'not a number'} - # validation_messages = carbon.validate(args) - # self.assertIn( - # TEST_MESSAGES[not_a_number_msg].format(value=args['n_workers']), - # str(validation_messages)) + def test_translate_formatted_string(self): + """Translation: test that f-string can be translated.""" + from natcap.invest import carbon, validation, set_locale + set_locale(TEST_LANG) + importlib.reload(validation) + importlib.reload(carbon) + args = {'n_workers': 'not a number'} + validation_messages = carbon.validate(args) + self.assertIn( + TEST_MESSAGES[not_a_number_msg].format(value=args['n_workers']), + str(validation_messages)) diff --git a/tests/test_validation.py b/tests/test_validation.py index 703f09899..792f86f73 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -295,18 +295,20 @@ class ValidatorTest(unittest.TestCase): from natcap.invest import validation # both args and the kwarg should be passed to the function + @spec_utils.timeout def func(arg1, arg2, kwarg=None): self.assertEqual(kwarg, 'kwarg') time.sleep(1) # this will raise an error if the timeout is exceeded # timeout defaults to 5 seconds so this should pass - spec_utils.timeout(func, 'arg1', 'arg2', kwarg='kwarg') + func('arg1', 'arg2', kwarg='kwarg') def test_timeout_fail(self): from natcap.invest import validation # both args and the kwarg should be passed to the function + @spec_utils.timeout def func(arg): time.sleep(6) @@ -315,7 +317,7 @@ class ValidatorTest(unittest.TestCase): with warnings.catch_warnings(record=True) as ws: # cause all warnings to always be triggered warnings.simplefilter("always") - spec_utils.timeout(func, 'arg') + func('arg') self.assertTrue(len(ws) == 1) self.assertTrue('timed out' in str(ws[0].message)) @@ -924,13 +926,13 @@ class CSVValidation(unittest.TestCase): # define a side effect for the mock that will sleep # for longer than the allowed timeout + @spec_utils.timeout def delay(*args, **kwargs): time.sleep(7) return [] # replace the validation.check_csv with the mock function, and try to validate - with unittest.mock.patch('natcap.invest.spec_utils.CSVInput.validate', - staticmethod(functools.partial(spec_utils.timeout, delay))): + with unittest.mock.patch('natcap.invest.spec_utils.CSVInput.validate', delay): with warnings.catch_warnings(record=True) as ws: # cause all warnings to always be triggered warnings.simplefilter("always")