rewrite timeout to work as a decorator; ModelSpec to_json method

This commit is contained in:
Emily Soth 2025-05-05 14:33:53 -07:00
parent 830c1fbf42
commit fddbac6080
6 changed files with 146 additions and 140 deletions

View File

@ -438,7 +438,7 @@ def main(user_args=None):
importlib.import_module(name=target_model))
spec = model_module.MODEL_SPEC
message = spec_utils.serialize_args_spec(spec)
message = spec.to_json()
sys.stdout.write(message)
parser.exit(0)

View File

@ -31,7 +31,7 @@ LOGGER = logging.getLogger(__name__)
# Accessing a file can take a long time if it lives in a file streaming
# service. To keep the UI from hanging on slow validation, the file-checking
# functions are run with a timeout.
def timeout(func, timeout=5):
    """Decorate ``func`` so callers wait at most ``timeout`` seconds for it.

    Each call to the decorated function runs ``func`` in a new thread and
    waits up to ``timeout`` seconds for it to finish. If it has not finished
    by then, a warning is issued and ``None`` is returned. The thread itself
    is not stopped; it keeps running in the background and whatever it
    eventually produces is discarded.

    Args:
        func (callable): the function to wrap. Its first positional argument
            is reported as the file path in the timeout warning.
        timeout (number): seconds to wait for ``func`` before giving up.

    Returns:
        A wrapper that accepts the same arguments as ``func`` and returns
        ``func``'s return value, or ``None`` if the timeout was exceeded.

    Raises:
        The wrapper re-raises, in the calling thread, any exception that
        ``func`` raised before the timeout elapsed.
    """
    # Imported locally so this function does not depend on a file-level
    # ``import functools`` being present.
    import functools

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # One queue per call. A queue shared between calls would let the
        # late result of a timed-out call be returned by a later call.
        message_queue = queue.Queue(maxsize=1)

        def put_fn():
            # Always put exactly one item, even on failure, so that the
            # `get` below can never block forever. BaseException is caught
            # on purpose: it is handed back to the caller, not swallowed.
            try:
                message_queue.put((True, func(*args, **kwargs)))
            except BaseException as error:
                message_queue.put((False, error))

        thread = threading.Thread(target=put_fn)
        LOGGER.debug(f'Starting file checking thread with timeout={timeout}')
        thread.start()
        thread.join(timeout=timeout)

        if thread.is_alive():
            # first arg to `check_csv`, `check_raster`, `check_vector` is
            # the path; fall back to a placeholder if it was not passed
            # positionally rather than failing with an IndexError
            file_label = args[0] if args else '<unknown>'
            warnings.warn(
                f'Validation of file {file_label} timed out. If this file '
                'is stored in a file streaming service, it may be taking a long '
                'time to download. Try storing it locally instead.')
            return None

        LOGGER.debug('File checking thread completed.')
        # get the return value (or the exception) produced by the thread
        succeeded, payload = message_queue.get()
        if not succeeded:
            raise payload
        return payload

    return wrapper
def check_headers(expected_headers, actual_headers, header_type='header'):
"""Validate that expected headers are in a list of actual headers.
@ -879,7 +881,6 @@ class UISpec:
hidden: list = None
dropdown_functions: dict = dataclasses.field(default_factory=dict)
@dataclasses.dataclass
class ModelSpec:
model_id: str
@ -898,6 +899,49 @@ class ModelSpec:
def get_input(self, key):
return self.inputs_dict[key]
def to_json(self):
    """Serialize this model spec to a JSON string.

    The instance attributes are copied into a plain dict. The ``inputs``,
    ``inputs_dict`` and ``outputs_dict`` entries are removed from that copy;
    ``inputs_dict`` is then emitted under the key ``args`` (to stay
    consistent with the old API) and ``outputs_dict`` under ``outputs``.

    Raises:
        TypeError if any object type within the spec is not handled by
        json.dumps or by the fallback serializer.
        KeyError if one of the removed attributes is not set on the instance.

    Returns:
        JSON String
    """
    def fallback_serializer(obj):
        """Convert objects that json.dumps cannot serialize on its own."""
        if isinstance(obj, pint.Unit):
            return format_unit(obj)
        # Sets are present in 'geometries' attributes of some args.
        # They never need to be deserialized back into a set/array,
        # so a plain string representation is sufficient.
        if isinstance(obj, set):
            return str(obj)
        if isinstance(obj, types.FunctionType):
            return str(obj)
        if dataclasses.is_dataclass(obj):
            fields_as_dict = dataclasses.asdict(obj)
            # `asdict` only covers dataclass fields; carry over `type`
            # explicitly whenever the object exposes it.
            if hasattr(obj, 'type'):
                fields_as_dict['type'] = obj.type
            return fields_as_dict
        if isinstance(obj, IterableWithDotAccess):
            return obj.to_json()
        raise TypeError(f'fallback serializer is missing for {type(obj)}')

    payload = dict(self.__dict__)
    # these attributes are not emitted under their own names
    for hidden_key in ('inputs', 'inputs_dict', 'outputs_dict'):
        del payload[hidden_key]
    # 'args' (rather than 'inputs') stays consistent with the old api
    payload.update(args=self.inputs_dict, outputs=self.outputs_dict)
    return json.dumps(
        payload, default=fallback_serializer, ensure_ascii=False)
def build_model_spec(model_spec):
inputs = [
@ -1370,46 +1414,6 @@ def format_unit(unit):
formatted_unit = formatted_unit.replace('currency', gettext('currency units'))
return formatted_unit
def serialize_args_spec(spec):
    """Serialize a model's MODEL_SPEC to a JSON string.

    The spec is first serialized to JSON and parsed back into plain Python
    containers; the top-level ``inputs`` entry is then renamed to ``args``
    before the final JSON string is produced.

    Args:
        spec: An invest model's MODEL_SPEC. It must serialize to a JSON
            object that has an ``inputs`` key.

    Raises:
        TypeError if any object type within the spec is not handled by
        json.dumps or by the fallback serializer.

    Returns:
        JSON String
    """
    def fallback_serializer(obj):
        """Convert objects that json.dumps cannot serialize on its own."""
        if isinstance(obj, pint.Unit):
            return format_unit(obj)
        # Sets are present in 'geometries' attributes of some args.
        # They never need to be deserialized back into a set/array,
        # so a plain string representation is sufficient.
        if isinstance(obj, set):
            return str(obj)
        if isinstance(obj, types.FunctionType):
            return str(obj)
        if dataclasses.is_dataclass(obj):
            fields_as_dict = dataclasses.asdict(obj)
            # `asdict` only covers dataclass fields; carry over `type`
            # explicitly whenever the object exposes it.
            if hasattr(obj, 'type'):
                fields_as_dict['type'] = obj.type
            return fields_as_dict
        if isinstance(obj, IterableWithDotAccess):
            return obj.to_json()
        raise TypeError(f'fallback serializer is missing for {type(obj)}')

    # round-trip through JSON to reduce everything to plain containers
    first_pass = json.dumps(
        spec, default=fallback_serializer, ensure_ascii=False)
    plain_spec = json.loads(first_pass)
    renamed_inputs = plain_spec.pop('inputs')
    plain_spec['args'] = renamed_inputs
    return json.dumps(plain_spec, ensure_ascii=False)
# accepted geometries for a vector will be displayed in this order
GEOMETRY_ORDER = [
'POINT',

View File

@ -68,7 +68,7 @@ def get_invest_getspec():
importlib.reload(natcap.invest.validation)
model_module = importlib.reload(
importlib.import_module(name=target_module))
return spec_utils.serialize_args_spec(model_module.MODEL_SPEC)
return model_module.MODEL_SPEC.to_json()
@app.route(f'/{PREFIX}/dynamic_dropdowns', methods=['POST'])

View File

@ -509,7 +509,7 @@ class ValidateModelSpecs(unittest.TestCase):
for pyname in model_id_to_pyname.values():
model = importlib.import_module(pyname)
spec_utils.serialize_args_spec(model.MODEL_SPEC)
model.MODEL_SPEC.to_json()
class SpecUtilsTests(unittest.TestCase):

View File

@ -76,16 +76,16 @@ class TranslationTests(unittest.TestCase):
self.locales_patcher.stop()
shutil.rmtree(self.workspace_dir)
# def test_invest_list(self):
# """Translation: test that CLI list output is translated."""
# from natcap.invest import cli
# with patch('sys.stdout', new=io.StringIO()) as out:
# with self.assertRaises(SystemExit):
# cli.main(['--language', TEST_LANG, 'list'])
# result = out.getvalue()
# self.assertIn(TEST_MESSAGES['Available models:'], result)
# self.assertIn(
# TEST_MESSAGES['Carbon Storage and Sequestration'], result)
def test_invest_list(self):
"""Translation: test that CLI list output is translated."""
from natcap.invest import cli
with patch('sys.stdout', new=io.StringIO()) as out:
with self.assertRaises(SystemExit):
cli.main(['--language', TEST_LANG, 'list'])
result = out.getvalue()
self.assertIn(TEST_MESSAGES['Available models:'], result)
self.assertIn(
TEST_MESSAGES['Carbon Storage and Sequestration'], result)
def test_invest_getspec(self):
"""Translation: test that CLI getspec output is translated."""
@ -96,72 +96,72 @@ class TranslationTests(unittest.TestCase):
result = out.getvalue()
self.assertIn(TEST_MESSAGES['baseline LULC'], result)
# def test_invest_validate(self):
# """Translation: test that CLI validate output is translated."""
# datastack = { # write datastack to a JSON file
# 'model_id': 'carbon',
# 'invest_version': '0.0',
# 'args': {}
# }
# datastack_path = os.path.join(self.workspace_dir, 'datastack.json')
# with open(datastack_path, 'w') as file:
# json.dump(datastack, file)
def test_invest_validate(self):
"""Translation: test that CLI validate output is translated."""
datastack = { # write datastack to a JSON file
'model_id': 'carbon',
'invest_version': '0.0',
'args': {}
}
datastack_path = os.path.join(self.workspace_dir, 'datastack.json')
with open(datastack_path, 'w') as file:
json.dump(datastack, file)
# from natcap.invest import cli
# with patch('sys.stdout', new=io.StringIO()) as out:
# with self.assertRaises(SystemExit):
# cli.main(
# ['--language', TEST_LANG, 'validate', datastack_path])
from natcap.invest import cli
with patch('sys.stdout', new=io.StringIO()) as out:
with self.assertRaises(SystemExit):
cli.main(
['--language', TEST_LANG, 'validate', datastack_path])
# result = out.getvalue()
# self.assertIn(TEST_MESSAGES[missing_key_msg], result)
result = out.getvalue()
self.assertIn(TEST_MESSAGES[missing_key_msg], result)
# def test_server_get_invest_models(self):
# """Translation: test that /models endpoint is translated."""
# from natcap.invest import ui_server
# test_client = ui_server.app.test_client()
# response = test_client.get(
# 'api/models', query_string={'language': TEST_LANG})
# result = json.loads(response.get_data(as_text=True))
# self.assertIn(
# TEST_MESSAGES['Carbon Storage and Sequestration'],
# [val['model_title'] for val in result.values()])
def test_server_get_invest_models(self):
"""Translation: test that /models endpoint is translated."""
from natcap.invest import ui_server
test_client = ui_server.app.test_client()
response = test_client.get(
'api/models', query_string={'language': TEST_LANG})
result = json.loads(response.get_data(as_text=True))
self.assertIn(
TEST_MESSAGES['Carbon Storage and Sequestration'],
[val['model_title'] for val in result.values()])
# def test_server_get_invest_getspec(self):
# """Translation: test that /getspec endpoint is translated."""
# from natcap.invest import ui_server
# test_client = ui_server.app.test_client()
# response = test_client.post(
# 'api/getspec', json='carbon', query_string={'language': TEST_LANG})
# spec = json.loads(response.get_data(as_text=True))
# self.assertEqual(
# spec['inputs']['lulc_bas_path']['name'],
# TEST_MESSAGES['baseline LULC'])
def test_server_get_invest_getspec(self):
"""Translation: test that /getspec endpoint is translated."""
from natcap.invest import ui_server
test_client = ui_server.app.test_client()
response = test_client.post(
'api/getspec', json='carbon', query_string={'language': TEST_LANG})
spec = json.loads(response.get_data(as_text=True))
self.assertEqual(
spec['args']['lulc_bas_path']['name'],
TEST_MESSAGES['baseline LULC'])
# def test_server_get_invest_validate(self):
# """Translation: test that /validate endpoint is translated."""
# from natcap.invest import ui_server
# from natcap.invest import carbon
# test_client = ui_server.app.test_client()
# payload = {
# 'model_id': carbon.MODEL_SPEC.model_id,
# 'args': json.dumps({})
# }
# response = test_client.post(
# 'api/validate', json=payload,
# query_string={'language': TEST_LANG})
# results = json.loads(response.get_data(as_text=True))
# messages = [item[1] for item in results]
# self.assertIn(TEST_MESSAGES[missing_key_msg], messages)
def test_server_get_invest_validate(self):
"""Translation: test that /validate endpoint is translated."""
from natcap.invest import ui_server
from natcap.invest import carbon
test_client = ui_server.app.test_client()
payload = {
'model_id': carbon.MODEL_SPEC.model_id,
'args': json.dumps({})
}
response = test_client.post(
'api/validate', json=payload,
query_string={'language': TEST_LANG})
results = json.loads(response.get_data(as_text=True))
messages = [item[1] for item in results]
self.assertIn(TEST_MESSAGES[missing_key_msg], messages)
# def test_translate_formatted_string(self):
# """Translation: test that f-string can be translated."""
# from natcap.invest import carbon, validation, set_locale
# set_locale(TEST_LANG)
# importlib.reload(validation)
# importlib.reload(carbon)
# args = {'n_workers': 'not a number'}
# validation_messages = carbon.validate(args)
# self.assertIn(
# TEST_MESSAGES[not_a_number_msg].format(value=args['n_workers']),
# str(validation_messages))
def test_translate_formatted_string(self):
"""Translation: test that f-string can be translated."""
from natcap.invest import carbon, validation, set_locale
set_locale(TEST_LANG)
importlib.reload(validation)
importlib.reload(carbon)
args = {'n_workers': 'not a number'}
validation_messages = carbon.validate(args)
self.assertIn(
TEST_MESSAGES[not_a_number_msg].format(value=args['n_workers']),
str(validation_messages))

View File

@ -295,18 +295,20 @@ class ValidatorTest(unittest.TestCase):
from natcap.invest import validation
# both args and the kwarg should be passed to the function
@spec_utils.timeout
def func(arg1, arg2, kwarg=None):
self.assertEqual(kwarg, 'kwarg')
time.sleep(1)
# this will raise an error if the timeout is exceeded
# timeout defaults to 5 seconds so this should pass
spec_utils.timeout(func, 'arg1', 'arg2', kwarg='kwarg')
func('arg1', 'arg2', kwarg='kwarg')
def test_timeout_fail(self):
from natcap.invest import validation
# both args and the kwarg should be passed to the function
@spec_utils.timeout
def func(arg):
time.sleep(6)
@ -315,7 +317,7 @@ class ValidatorTest(unittest.TestCase):
with warnings.catch_warnings(record=True) as ws:
# cause all warnings to always be triggered
warnings.simplefilter("always")
spec_utils.timeout(func, 'arg')
func('arg')
self.assertTrue(len(ws) == 1)
self.assertTrue('timed out' in str(ws[0].message))
@ -924,13 +926,13 @@ class CSVValidation(unittest.TestCase):
# define a side effect for the mock that will sleep
# for longer than the allowed timeout
@spec_utils.timeout
def delay(*args, **kwargs):
time.sleep(7)
return []
# replace the validation.check_csv with the mock function, and try to validate
with unittest.mock.patch('natcap.invest.spec_utils.CSVInput.validate',
staticmethod(functools.partial(spec_utils.timeout, delay))):
with unittest.mock.patch('natcap.invest.spec_utils.CSVInput.validate', delay):
with warnings.catch_warnings(record=True) as ws:
# cause all warnings to always be triggered
warnings.simplefilter("always")