rewrite timeout to work as a decorator; ModelSpec to_json method
This commit is contained in:
parent
830c1fbf42
commit
fddbac6080
|
@ -438,7 +438,7 @@ def main(user_args=None):
|
|||
importlib.import_module(name=target_model))
|
||||
spec = model_module.MODEL_SPEC
|
||||
|
||||
message = spec_utils.serialize_args_spec(spec)
|
||||
message = spec.to_json()
|
||||
sys.stdout.write(message)
|
||||
parser.exit(0)
|
||||
|
||||
|
|
|
@ -31,7 +31,7 @@ LOGGER = logging.getLogger(__name__)
|
|||
# accessing a file could take a long time if it's in a file streaming service
|
||||
# to prevent the UI from hanging due to slow validation,
|
||||
# set a timeout for these functions.
|
||||
def timeout(func, *args, timeout=5, **kwargs):
|
||||
def timeout(func, timeout=5):
|
||||
"""Stop a function after a given amount of time.
|
||||
|
||||
Args:
|
||||
|
@ -51,26 +51,28 @@ def timeout(func, *args, timeout=5, **kwargs):
|
|||
# the target function puts the return value from `func` into shared memory
|
||||
message_queue = queue.Queue()
|
||||
|
||||
def wrapper_func():
|
||||
message_queue.put(func(*args, **kwargs))
|
||||
def wrapper(*args, **kwargs):
|
||||
def put_fn():
|
||||
message_queue.put(func(*args, **kwargs))
|
||||
thread = threading.Thread(target=put_fn)
|
||||
LOGGER.debug(f'Starting file checking thread with timeout={timeout}')
|
||||
thread.start()
|
||||
thread.join(timeout=timeout)
|
||||
if thread.is_alive():
|
||||
# first arg to `check_csv`, `check_raster`, `check_vector` is the path
|
||||
warnings.warn(
|
||||
f'Validation of file {args[0]} timed out. If this file '
|
||||
'is stored in a file streaming service, it may be taking a long '
|
||||
'time to download. Try storing it locally instead.')
|
||||
return None
|
||||
|
||||
thread = threading.Thread(target=wrapper_func)
|
||||
LOGGER.debug(f'Starting file checking thread with timeout={timeout}')
|
||||
thread.start()
|
||||
thread.join(timeout=timeout)
|
||||
if thread.is_alive():
|
||||
# first arg to `check_csv`, `check_raster`, `check_vector` is the path
|
||||
warnings.warn(
|
||||
f'Validation of file {args[0]} timed out. If this file '
|
||||
'is stored in a file streaming service, it may be taking a long '
|
||||
'time to download. Try storing it locally instead.')
|
||||
return None
|
||||
else:
|
||||
LOGGER.debug('File checking thread completed.')
|
||||
# get any warning messages returned from the thread
|
||||
a = message_queue.get()
|
||||
return a
|
||||
|
||||
else:
|
||||
LOGGER.debug('File checking thread completed.')
|
||||
# get any warning messages returned from the thread
|
||||
a = message_queue.get()
|
||||
return a
|
||||
return wrapper
|
||||
|
||||
def check_headers(expected_headers, actual_headers, header_type='header'):
|
||||
"""Validate that expected headers are in a list of actual headers.
|
||||
|
@ -879,7 +881,6 @@ class UISpec:
|
|||
hidden: list = None
|
||||
dropdown_functions: dict = dataclasses.field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ModelSpec:
|
||||
model_id: str
|
||||
|
@ -898,6 +899,49 @@ class ModelSpec:
|
|||
def get_input(self, key):
|
||||
return self.inputs_dict[key]
|
||||
|
||||
def to_json(self):
|
||||
"""Serialize an MODEL_SPEC dict to a JSON string.
|
||||
|
||||
Args:
|
||||
spec (dict): An invest model's MODEL_SPEC.
|
||||
|
||||
Raises:
|
||||
TypeError if any object type within the spec is not handled by
|
||||
json.dumps or by the fallback serializer.
|
||||
|
||||
Returns:
|
||||
JSON String
|
||||
"""
|
||||
|
||||
def fallback_serializer(obj):
|
||||
"""Serialize objects that are otherwise not JSON serializeable."""
|
||||
if isinstance(obj, pint.Unit):
|
||||
return format_unit(obj)
|
||||
# Sets are present in 'geometries' attributes of some args
|
||||
# We don't need to worry about deserializing back to a set/array
|
||||
# so casting to string is okay.
|
||||
elif isinstance(obj, set):
|
||||
return str(obj)
|
||||
elif isinstance(obj, types.FunctionType):
|
||||
return str(obj)
|
||||
elif dataclasses.is_dataclass(obj):
|
||||
as_dict = dataclasses.asdict(obj)
|
||||
if hasattr(obj, 'type'):
|
||||
as_dict['type'] = obj.type
|
||||
return as_dict
|
||||
elif isinstance(obj, IterableWithDotAccess):
|
||||
return obj.to_json()
|
||||
raise TypeError(f'fallback serializer is missing for {type(obj)}')
|
||||
|
||||
spec_dict = self.__dict__.copy()
|
||||
# rename 'inputs' to 'args' to stay consistent with the old api
|
||||
spec_dict.pop('inputs')
|
||||
spec_dict.pop('inputs_dict')
|
||||
spec_dict.pop('outputs_dict')
|
||||
spec_dict['args'] = self.inputs_dict
|
||||
spec_dict['outputs'] = self.outputs_dict
|
||||
return json.dumps(spec_dict, default=fallback_serializer, ensure_ascii=False)
|
||||
|
||||
|
||||
def build_model_spec(model_spec):
|
||||
inputs = [
|
||||
|
@ -1370,46 +1414,6 @@ def format_unit(unit):
|
|||
formatted_unit = formatted_unit.replace('currency', gettext('currency units'))
|
||||
return formatted_unit
|
||||
|
||||
|
||||
def serialize_args_spec(spec):
|
||||
"""Serialize an MODEL_SPEC dict to a JSON string.
|
||||
|
||||
Args:
|
||||
spec (dict): An invest model's MODEL_SPEC.
|
||||
|
||||
Raises:
|
||||
TypeError if any object type within the spec is not handled by
|
||||
json.dumps or by the fallback serializer.
|
||||
|
||||
Returns:
|
||||
JSON String
|
||||
"""
|
||||
|
||||
def fallback_serializer(obj):
|
||||
"""Serialize objects that are otherwise not JSON serializeable."""
|
||||
if isinstance(obj, pint.Unit):
|
||||
return format_unit(obj)
|
||||
# Sets are present in 'geometries' attributes of some args
|
||||
# We don't need to worry about deserializing back to a set/array
|
||||
# so casting to string is okay.
|
||||
elif isinstance(obj, set):
|
||||
return str(obj)
|
||||
elif isinstance(obj, types.FunctionType):
|
||||
return str(obj)
|
||||
elif dataclasses.is_dataclass(obj):
|
||||
as_dict = dataclasses.asdict(obj)
|
||||
if hasattr(obj, 'type'):
|
||||
as_dict['type'] = obj.type
|
||||
return as_dict
|
||||
elif isinstance(obj, IterableWithDotAccess):
|
||||
return obj.to_json()
|
||||
raise TypeError(f'fallback serializer is missing for {type(obj)}')
|
||||
|
||||
spec_dict = json.loads(json.dumps(spec, default=fallback_serializer, ensure_ascii=False))
|
||||
spec_dict['args'] = spec_dict.pop('inputs')
|
||||
return json.dumps(spec_dict, ensure_ascii=False)
|
||||
|
||||
|
||||
# accepted geometries for a vector will be displayed in this order
|
||||
GEOMETRY_ORDER = [
|
||||
'POINT',
|
||||
|
|
|
@ -68,7 +68,7 @@ def get_invest_getspec():
|
|||
importlib.reload(natcap.invest.validation)
|
||||
model_module = importlib.reload(
|
||||
importlib.import_module(name=target_module))
|
||||
return spec_utils.serialize_args_spec(model_module.MODEL_SPEC)
|
||||
return model_module.MODEL_SPEC.to_json()
|
||||
|
||||
|
||||
@app.route(f'/{PREFIX}/dynamic_dropdowns', methods=['POST'])
|
||||
|
|
|
@ -509,7 +509,7 @@ class ValidateModelSpecs(unittest.TestCase):
|
|||
|
||||
for pyname in model_id_to_pyname.values():
|
||||
model = importlib.import_module(pyname)
|
||||
spec_utils.serialize_args_spec(model.MODEL_SPEC)
|
||||
model.MODEL_SPEC.to_json()
|
||||
|
||||
|
||||
class SpecUtilsTests(unittest.TestCase):
|
||||
|
|
|
@ -76,16 +76,16 @@ class TranslationTests(unittest.TestCase):
|
|||
self.locales_patcher.stop()
|
||||
shutil.rmtree(self.workspace_dir)
|
||||
|
||||
# def test_invest_list(self):
|
||||
# """Translation: test that CLI list output is translated."""
|
||||
# from natcap.invest import cli
|
||||
# with patch('sys.stdout', new=io.StringIO()) as out:
|
||||
# with self.assertRaises(SystemExit):
|
||||
# cli.main(['--language', TEST_LANG, 'list'])
|
||||
# result = out.getvalue()
|
||||
# self.assertIn(TEST_MESSAGES['Available models:'], result)
|
||||
# self.assertIn(
|
||||
# TEST_MESSAGES['Carbon Storage and Sequestration'], result)
|
||||
def test_invest_list(self):
|
||||
"""Translation: test that CLI list output is translated."""
|
||||
from natcap.invest import cli
|
||||
with patch('sys.stdout', new=io.StringIO()) as out:
|
||||
with self.assertRaises(SystemExit):
|
||||
cli.main(['--language', TEST_LANG, 'list'])
|
||||
result = out.getvalue()
|
||||
self.assertIn(TEST_MESSAGES['Available models:'], result)
|
||||
self.assertIn(
|
||||
TEST_MESSAGES['Carbon Storage and Sequestration'], result)
|
||||
|
||||
def test_invest_getspec(self):
|
||||
"""Translation: test that CLI getspec output is translated."""
|
||||
|
@ -96,72 +96,72 @@ class TranslationTests(unittest.TestCase):
|
|||
result = out.getvalue()
|
||||
self.assertIn(TEST_MESSAGES['baseline LULC'], result)
|
||||
|
||||
# def test_invest_validate(self):
|
||||
# """Translation: test that CLI validate output is translated."""
|
||||
# datastack = { # write datastack to a JSON file
|
||||
# 'model_id': 'carbon',
|
||||
# 'invest_version': '0.0',
|
||||
# 'args': {}
|
||||
# }
|
||||
# datastack_path = os.path.join(self.workspace_dir, 'datastack.json')
|
||||
# with open(datastack_path, 'w') as file:
|
||||
# json.dump(datastack, file)
|
||||
def test_invest_validate(self):
|
||||
"""Translation: test that CLI validate output is translated."""
|
||||
datastack = { # write datastack to a JSON file
|
||||
'model_id': 'carbon',
|
||||
'invest_version': '0.0',
|
||||
'args': {}
|
||||
}
|
||||
datastack_path = os.path.join(self.workspace_dir, 'datastack.json')
|
||||
with open(datastack_path, 'w') as file:
|
||||
json.dump(datastack, file)
|
||||
|
||||
# from natcap.invest import cli
|
||||
# with patch('sys.stdout', new=io.StringIO()) as out:
|
||||
# with self.assertRaises(SystemExit):
|
||||
# cli.main(
|
||||
# ['--language', TEST_LANG, 'validate', datastack_path])
|
||||
from natcap.invest import cli
|
||||
with patch('sys.stdout', new=io.StringIO()) as out:
|
||||
with self.assertRaises(SystemExit):
|
||||
cli.main(
|
||||
['--language', TEST_LANG, 'validate', datastack_path])
|
||||
|
||||
# result = out.getvalue()
|
||||
# self.assertIn(TEST_MESSAGES[missing_key_msg], result)
|
||||
result = out.getvalue()
|
||||
self.assertIn(TEST_MESSAGES[missing_key_msg], result)
|
||||
|
||||
# def test_server_get_invest_models(self):
|
||||
# """Translation: test that /models endpoint is translated."""
|
||||
# from natcap.invest import ui_server
|
||||
# test_client = ui_server.app.test_client()
|
||||
# response = test_client.get(
|
||||
# 'api/models', query_string={'language': TEST_LANG})
|
||||
# result = json.loads(response.get_data(as_text=True))
|
||||
# self.assertIn(
|
||||
# TEST_MESSAGES['Carbon Storage and Sequestration'],
|
||||
# [val['model_title'] for val in result.values()])
|
||||
def test_server_get_invest_models(self):
|
||||
"""Translation: test that /models endpoint is translated."""
|
||||
from natcap.invest import ui_server
|
||||
test_client = ui_server.app.test_client()
|
||||
response = test_client.get(
|
||||
'api/models', query_string={'language': TEST_LANG})
|
||||
result = json.loads(response.get_data(as_text=True))
|
||||
self.assertIn(
|
||||
TEST_MESSAGES['Carbon Storage and Sequestration'],
|
||||
[val['model_title'] for val in result.values()])
|
||||
|
||||
# def test_server_get_invest_getspec(self):
|
||||
# """Translation: test that /getspec endpoint is translated."""
|
||||
# from natcap.invest import ui_server
|
||||
# test_client = ui_server.app.test_client()
|
||||
# response = test_client.post(
|
||||
# 'api/getspec', json='carbon', query_string={'language': TEST_LANG})
|
||||
# spec = json.loads(response.get_data(as_text=True))
|
||||
# self.assertEqual(
|
||||
# spec['inputs']['lulc_bas_path']['name'],
|
||||
# TEST_MESSAGES['baseline LULC'])
|
||||
def test_server_get_invest_getspec(self):
|
||||
"""Translation: test that /getspec endpoint is translated."""
|
||||
from natcap.invest import ui_server
|
||||
test_client = ui_server.app.test_client()
|
||||
response = test_client.post(
|
||||
'api/getspec', json='carbon', query_string={'language': TEST_LANG})
|
||||
spec = json.loads(response.get_data(as_text=True))
|
||||
self.assertEqual(
|
||||
spec['args']['lulc_bas_path']['name'],
|
||||
TEST_MESSAGES['baseline LULC'])
|
||||
|
||||
# def test_server_get_invest_validate(self):
|
||||
# """Translation: test that /validate endpoint is translated."""
|
||||
# from natcap.invest import ui_server
|
||||
# from natcap.invest import carbon
|
||||
# test_client = ui_server.app.test_client()
|
||||
# payload = {
|
||||
# 'model_id': carbon.MODEL_SPEC.model_id,
|
||||
# 'args': json.dumps({})
|
||||
# }
|
||||
# response = test_client.post(
|
||||
# 'api/validate', json=payload,
|
||||
# query_string={'language': TEST_LANG})
|
||||
# results = json.loads(response.get_data(as_text=True))
|
||||
# messages = [item[1] for item in results]
|
||||
# self.assertIn(TEST_MESSAGES[missing_key_msg], messages)
|
||||
def test_server_get_invest_validate(self):
|
||||
"""Translation: test that /validate endpoint is translated."""
|
||||
from natcap.invest import ui_server
|
||||
from natcap.invest import carbon
|
||||
test_client = ui_server.app.test_client()
|
||||
payload = {
|
||||
'model_id': carbon.MODEL_SPEC.model_id,
|
||||
'args': json.dumps({})
|
||||
}
|
||||
response = test_client.post(
|
||||
'api/validate', json=payload,
|
||||
query_string={'language': TEST_LANG})
|
||||
results = json.loads(response.get_data(as_text=True))
|
||||
messages = [item[1] for item in results]
|
||||
self.assertIn(TEST_MESSAGES[missing_key_msg], messages)
|
||||
|
||||
# def test_translate_formatted_string(self):
|
||||
# """Translation: test that f-string can be translated."""
|
||||
# from natcap.invest import carbon, validation, set_locale
|
||||
# set_locale(TEST_LANG)
|
||||
# importlib.reload(validation)
|
||||
# importlib.reload(carbon)
|
||||
# args = {'n_workers': 'not a number'}
|
||||
# validation_messages = carbon.validate(args)
|
||||
# self.assertIn(
|
||||
# TEST_MESSAGES[not_a_number_msg].format(value=args['n_workers']),
|
||||
# str(validation_messages))
|
||||
def test_translate_formatted_string(self):
|
||||
"""Translation: test that f-string can be translated."""
|
||||
from natcap.invest import carbon, validation, set_locale
|
||||
set_locale(TEST_LANG)
|
||||
importlib.reload(validation)
|
||||
importlib.reload(carbon)
|
||||
args = {'n_workers': 'not a number'}
|
||||
validation_messages = carbon.validate(args)
|
||||
self.assertIn(
|
||||
TEST_MESSAGES[not_a_number_msg].format(value=args['n_workers']),
|
||||
str(validation_messages))
|
||||
|
|
|
@ -295,18 +295,20 @@ class ValidatorTest(unittest.TestCase):
|
|||
from natcap.invest import validation
|
||||
|
||||
# both args and the kwarg should be passed to the function
|
||||
@spec_utils.timeout
|
||||
def func(arg1, arg2, kwarg=None):
|
||||
self.assertEqual(kwarg, 'kwarg')
|
||||
time.sleep(1)
|
||||
|
||||
# this will raise an error if the timeout is exceeded
|
||||
# timeout defaults to 5 seconds so this should pass
|
||||
spec_utils.timeout(func, 'arg1', 'arg2', kwarg='kwarg')
|
||||
func('arg1', 'arg2', kwarg='kwarg')
|
||||
|
||||
def test_timeout_fail(self):
|
||||
from natcap.invest import validation
|
||||
|
||||
# both args and the kwarg should be passed to the function
|
||||
@spec_utils.timeout
|
||||
def func(arg):
|
||||
time.sleep(6)
|
||||
|
||||
|
@ -315,7 +317,7 @@ class ValidatorTest(unittest.TestCase):
|
|||
with warnings.catch_warnings(record=True) as ws:
|
||||
# cause all warnings to always be triggered
|
||||
warnings.simplefilter("always")
|
||||
spec_utils.timeout(func, 'arg')
|
||||
func('arg')
|
||||
self.assertTrue(len(ws) == 1)
|
||||
self.assertTrue('timed out' in str(ws[0].message))
|
||||
|
||||
|
@ -924,13 +926,13 @@ class CSVValidation(unittest.TestCase):
|
|||
|
||||
# define a side effect for the mock that will sleep
|
||||
# for longer than the allowed timeout
|
||||
@spec_utils.timeout
|
||||
def delay(*args, **kwargs):
|
||||
time.sleep(7)
|
||||
return []
|
||||
|
||||
# replace the validation.check_csv with the mock function, and try to validate
|
||||
with unittest.mock.patch('natcap.invest.spec_utils.CSVInput.validate',
|
||||
staticmethod(functools.partial(spec_utils.timeout, delay))):
|
||||
with unittest.mock.patch('natcap.invest.spec_utils.CSVInput.validate', delay):
|
||||
with warnings.catch_warnings(record=True) as ws:
|
||||
# cause all warnings to always be triggered
|
||||
warnings.simplefilter("always")
|
||||
|
|
Loading…
Reference in New Issue