rewrite timeout to work as a decorator; ModelSpec to_json method
This commit is contained in:
parent
830c1fbf42
commit
fddbac6080
|
@ -438,7 +438,7 @@ def main(user_args=None):
|
||||||
importlib.import_module(name=target_model))
|
importlib.import_module(name=target_model))
|
||||||
spec = model_module.MODEL_SPEC
|
spec = model_module.MODEL_SPEC
|
||||||
|
|
||||||
message = spec_utils.serialize_args_spec(spec)
|
message = spec.to_json()
|
||||||
sys.stdout.write(message)
|
sys.stdout.write(message)
|
||||||
parser.exit(0)
|
parser.exit(0)
|
||||||
|
|
||||||
|
|
|
@ -31,7 +31,7 @@ LOGGER = logging.getLogger(__name__)
|
||||||
# accessing a file could take a long time if it's in a file streaming service
|
# accessing a file could take a long time if it's in a file streaming service
|
||||||
# to prevent the UI from hanging due to slow validation,
|
# to prevent the UI from hanging due to slow validation,
|
||||||
# set a timeout for these functions.
|
# set a timeout for these functions.
|
||||||
def timeout(func, *args, timeout=5, **kwargs):
|
def timeout(func, timeout=5):
|
||||||
"""Stop a function after a given amount of time.
|
"""Stop a function after a given amount of time.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -51,26 +51,28 @@ def timeout(func, *args, timeout=5, **kwargs):
|
||||||
# the target function puts the return value from `func` into shared memory
|
# the target function puts the return value from `func` into shared memory
|
||||||
message_queue = queue.Queue()
|
message_queue = queue.Queue()
|
||||||
|
|
||||||
def wrapper_func():
|
def wrapper(*args, **kwargs):
|
||||||
message_queue.put(func(*args, **kwargs))
|
def put_fn():
|
||||||
|
message_queue.put(func(*args, **kwargs))
|
||||||
|
thread = threading.Thread(target=put_fn)
|
||||||
|
LOGGER.debug(f'Starting file checking thread with timeout={timeout}')
|
||||||
|
thread.start()
|
||||||
|
thread.join(timeout=timeout)
|
||||||
|
if thread.is_alive():
|
||||||
|
# first arg to `check_csv`, `check_raster`, `check_vector` is the path
|
||||||
|
warnings.warn(
|
||||||
|
f'Validation of file {args[0]} timed out. If this file '
|
||||||
|
'is stored in a file streaming service, it may be taking a long '
|
||||||
|
'time to download. Try storing it locally instead.')
|
||||||
|
return None
|
||||||
|
|
||||||
thread = threading.Thread(target=wrapper_func)
|
else:
|
||||||
LOGGER.debug(f'Starting file checking thread with timeout={timeout}')
|
LOGGER.debug('File checking thread completed.')
|
||||||
thread.start()
|
# get any warning messages returned from the thread
|
||||||
thread.join(timeout=timeout)
|
a = message_queue.get()
|
||||||
if thread.is_alive():
|
return a
|
||||||
# first arg to `check_csv`, `check_raster`, `check_vector` is the path
|
|
||||||
warnings.warn(
|
|
||||||
f'Validation of file {args[0]} timed out. If this file '
|
|
||||||
'is stored in a file streaming service, it may be taking a long '
|
|
||||||
'time to download. Try storing it locally instead.')
|
|
||||||
return None
|
|
||||||
|
|
||||||
else:
|
return wrapper
|
||||||
LOGGER.debug('File checking thread completed.')
|
|
||||||
# get any warning messages returned from the thread
|
|
||||||
a = message_queue.get()
|
|
||||||
return a
|
|
||||||
|
|
||||||
def check_headers(expected_headers, actual_headers, header_type='header'):
|
def check_headers(expected_headers, actual_headers, header_type='header'):
|
||||||
"""Validate that expected headers are in a list of actual headers.
|
"""Validate that expected headers are in a list of actual headers.
|
||||||
|
@ -879,7 +881,6 @@ class UISpec:
|
||||||
hidden: list = None
|
hidden: list = None
|
||||||
dropdown_functions: dict = dataclasses.field(default_factory=dict)
|
dropdown_functions: dict = dataclasses.field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class ModelSpec:
|
class ModelSpec:
|
||||||
model_id: str
|
model_id: str
|
||||||
|
@ -898,6 +899,49 @@ class ModelSpec:
|
||||||
def get_input(self, key):
|
def get_input(self, key):
|
||||||
return self.inputs_dict[key]
|
return self.inputs_dict[key]
|
||||||
|
|
||||||
|
def to_json(self):
|
||||||
|
"""Serialize an MODEL_SPEC dict to a JSON string.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
spec (dict): An invest model's MODEL_SPEC.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError if any object type within the spec is not handled by
|
||||||
|
json.dumps or by the fallback serializer.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
JSON String
|
||||||
|
"""
|
||||||
|
|
||||||
|
def fallback_serializer(obj):
|
||||||
|
"""Serialize objects that are otherwise not JSON serializeable."""
|
||||||
|
if isinstance(obj, pint.Unit):
|
||||||
|
return format_unit(obj)
|
||||||
|
# Sets are present in 'geometries' attributes of some args
|
||||||
|
# We don't need to worry about deserializing back to a set/array
|
||||||
|
# so casting to string is okay.
|
||||||
|
elif isinstance(obj, set):
|
||||||
|
return str(obj)
|
||||||
|
elif isinstance(obj, types.FunctionType):
|
||||||
|
return str(obj)
|
||||||
|
elif dataclasses.is_dataclass(obj):
|
||||||
|
as_dict = dataclasses.asdict(obj)
|
||||||
|
if hasattr(obj, 'type'):
|
||||||
|
as_dict['type'] = obj.type
|
||||||
|
return as_dict
|
||||||
|
elif isinstance(obj, IterableWithDotAccess):
|
||||||
|
return obj.to_json()
|
||||||
|
raise TypeError(f'fallback serializer is missing for {type(obj)}')
|
||||||
|
|
||||||
|
spec_dict = self.__dict__.copy()
|
||||||
|
# rename 'inputs' to 'args' to stay consistent with the old api
|
||||||
|
spec_dict.pop('inputs')
|
||||||
|
spec_dict.pop('inputs_dict')
|
||||||
|
spec_dict.pop('outputs_dict')
|
||||||
|
spec_dict['args'] = self.inputs_dict
|
||||||
|
spec_dict['outputs'] = self.outputs_dict
|
||||||
|
return json.dumps(spec_dict, default=fallback_serializer, ensure_ascii=False)
|
||||||
|
|
||||||
|
|
||||||
def build_model_spec(model_spec):
|
def build_model_spec(model_spec):
|
||||||
inputs = [
|
inputs = [
|
||||||
|
@ -1370,46 +1414,6 @@ def format_unit(unit):
|
||||||
formatted_unit = formatted_unit.replace('currency', gettext('currency units'))
|
formatted_unit = formatted_unit.replace('currency', gettext('currency units'))
|
||||||
return formatted_unit
|
return formatted_unit
|
||||||
|
|
||||||
|
|
||||||
def serialize_args_spec(spec):
|
|
||||||
"""Serialize an MODEL_SPEC dict to a JSON string.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
spec (dict): An invest model's MODEL_SPEC.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
TypeError if any object type within the spec is not handled by
|
|
||||||
json.dumps or by the fallback serializer.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
JSON String
|
|
||||||
"""
|
|
||||||
|
|
||||||
def fallback_serializer(obj):
|
|
||||||
"""Serialize objects that are otherwise not JSON serializeable."""
|
|
||||||
if isinstance(obj, pint.Unit):
|
|
||||||
return format_unit(obj)
|
|
||||||
# Sets are present in 'geometries' attributes of some args
|
|
||||||
# We don't need to worry about deserializing back to a set/array
|
|
||||||
# so casting to string is okay.
|
|
||||||
elif isinstance(obj, set):
|
|
||||||
return str(obj)
|
|
||||||
elif isinstance(obj, types.FunctionType):
|
|
||||||
return str(obj)
|
|
||||||
elif dataclasses.is_dataclass(obj):
|
|
||||||
as_dict = dataclasses.asdict(obj)
|
|
||||||
if hasattr(obj, 'type'):
|
|
||||||
as_dict['type'] = obj.type
|
|
||||||
return as_dict
|
|
||||||
elif isinstance(obj, IterableWithDotAccess):
|
|
||||||
return obj.to_json()
|
|
||||||
raise TypeError(f'fallback serializer is missing for {type(obj)}')
|
|
||||||
|
|
||||||
spec_dict = json.loads(json.dumps(spec, default=fallback_serializer, ensure_ascii=False))
|
|
||||||
spec_dict['args'] = spec_dict.pop('inputs')
|
|
||||||
return json.dumps(spec_dict, ensure_ascii=False)
|
|
||||||
|
|
||||||
|
|
||||||
# accepted geometries for a vector will be displayed in this order
|
# accepted geometries for a vector will be displayed in this order
|
||||||
GEOMETRY_ORDER = [
|
GEOMETRY_ORDER = [
|
||||||
'POINT',
|
'POINT',
|
||||||
|
|
|
@ -68,7 +68,7 @@ def get_invest_getspec():
|
||||||
importlib.reload(natcap.invest.validation)
|
importlib.reload(natcap.invest.validation)
|
||||||
model_module = importlib.reload(
|
model_module = importlib.reload(
|
||||||
importlib.import_module(name=target_module))
|
importlib.import_module(name=target_module))
|
||||||
return spec_utils.serialize_args_spec(model_module.MODEL_SPEC)
|
return model_module.MODEL_SPEC.to_json()
|
||||||
|
|
||||||
|
|
||||||
@app.route(f'/{PREFIX}/dynamic_dropdowns', methods=['POST'])
|
@app.route(f'/{PREFIX}/dynamic_dropdowns', methods=['POST'])
|
||||||
|
|
|
@ -509,7 +509,7 @@ class ValidateModelSpecs(unittest.TestCase):
|
||||||
|
|
||||||
for pyname in model_id_to_pyname.values():
|
for pyname in model_id_to_pyname.values():
|
||||||
model = importlib.import_module(pyname)
|
model = importlib.import_module(pyname)
|
||||||
spec_utils.serialize_args_spec(model.MODEL_SPEC)
|
model.MODEL_SPEC.to_json()
|
||||||
|
|
||||||
|
|
||||||
class SpecUtilsTests(unittest.TestCase):
|
class SpecUtilsTests(unittest.TestCase):
|
||||||
|
|
|
@ -76,16 +76,16 @@ class TranslationTests(unittest.TestCase):
|
||||||
self.locales_patcher.stop()
|
self.locales_patcher.stop()
|
||||||
shutil.rmtree(self.workspace_dir)
|
shutil.rmtree(self.workspace_dir)
|
||||||
|
|
||||||
# def test_invest_list(self):
|
def test_invest_list(self):
|
||||||
# """Translation: test that CLI list output is translated."""
|
"""Translation: test that CLI list output is translated."""
|
||||||
# from natcap.invest import cli
|
from natcap.invest import cli
|
||||||
# with patch('sys.stdout', new=io.StringIO()) as out:
|
with patch('sys.stdout', new=io.StringIO()) as out:
|
||||||
# with self.assertRaises(SystemExit):
|
with self.assertRaises(SystemExit):
|
||||||
# cli.main(['--language', TEST_LANG, 'list'])
|
cli.main(['--language', TEST_LANG, 'list'])
|
||||||
# result = out.getvalue()
|
result = out.getvalue()
|
||||||
# self.assertIn(TEST_MESSAGES['Available models:'], result)
|
self.assertIn(TEST_MESSAGES['Available models:'], result)
|
||||||
# self.assertIn(
|
self.assertIn(
|
||||||
# TEST_MESSAGES['Carbon Storage and Sequestration'], result)
|
TEST_MESSAGES['Carbon Storage and Sequestration'], result)
|
||||||
|
|
||||||
def test_invest_getspec(self):
|
def test_invest_getspec(self):
|
||||||
"""Translation: test that CLI getspec output is translated."""
|
"""Translation: test that CLI getspec output is translated."""
|
||||||
|
@ -96,72 +96,72 @@ class TranslationTests(unittest.TestCase):
|
||||||
result = out.getvalue()
|
result = out.getvalue()
|
||||||
self.assertIn(TEST_MESSAGES['baseline LULC'], result)
|
self.assertIn(TEST_MESSAGES['baseline LULC'], result)
|
||||||
|
|
||||||
# def test_invest_validate(self):
|
def test_invest_validate(self):
|
||||||
# """Translation: test that CLI validate output is translated."""
|
"""Translation: test that CLI validate output is translated."""
|
||||||
# datastack = { # write datastack to a JSON file
|
datastack = { # write datastack to a JSON file
|
||||||
# 'model_id': 'carbon',
|
'model_id': 'carbon',
|
||||||
# 'invest_version': '0.0',
|
'invest_version': '0.0',
|
||||||
# 'args': {}
|
'args': {}
|
||||||
# }
|
}
|
||||||
# datastack_path = os.path.join(self.workspace_dir, 'datastack.json')
|
datastack_path = os.path.join(self.workspace_dir, 'datastack.json')
|
||||||
# with open(datastack_path, 'w') as file:
|
with open(datastack_path, 'w') as file:
|
||||||
# json.dump(datastack, file)
|
json.dump(datastack, file)
|
||||||
|
|
||||||
# from natcap.invest import cli
|
from natcap.invest import cli
|
||||||
# with patch('sys.stdout', new=io.StringIO()) as out:
|
with patch('sys.stdout', new=io.StringIO()) as out:
|
||||||
# with self.assertRaises(SystemExit):
|
with self.assertRaises(SystemExit):
|
||||||
# cli.main(
|
cli.main(
|
||||||
# ['--language', TEST_LANG, 'validate', datastack_path])
|
['--language', TEST_LANG, 'validate', datastack_path])
|
||||||
|
|
||||||
# result = out.getvalue()
|
result = out.getvalue()
|
||||||
# self.assertIn(TEST_MESSAGES[missing_key_msg], result)
|
self.assertIn(TEST_MESSAGES[missing_key_msg], result)
|
||||||
|
|
||||||
# def test_server_get_invest_models(self):
|
def test_server_get_invest_models(self):
|
||||||
# """Translation: test that /models endpoint is translated."""
|
"""Translation: test that /models endpoint is translated."""
|
||||||
# from natcap.invest import ui_server
|
from natcap.invest import ui_server
|
||||||
# test_client = ui_server.app.test_client()
|
test_client = ui_server.app.test_client()
|
||||||
# response = test_client.get(
|
response = test_client.get(
|
||||||
# 'api/models', query_string={'language': TEST_LANG})
|
'api/models', query_string={'language': TEST_LANG})
|
||||||
# result = json.loads(response.get_data(as_text=True))
|
result = json.loads(response.get_data(as_text=True))
|
||||||
# self.assertIn(
|
self.assertIn(
|
||||||
# TEST_MESSAGES['Carbon Storage and Sequestration'],
|
TEST_MESSAGES['Carbon Storage and Sequestration'],
|
||||||
# [val['model_title'] for val in result.values()])
|
[val['model_title'] for val in result.values()])
|
||||||
|
|
||||||
# def test_server_get_invest_getspec(self):
|
def test_server_get_invest_getspec(self):
|
||||||
# """Translation: test that /getspec endpoint is translated."""
|
"""Translation: test that /getspec endpoint is translated."""
|
||||||
# from natcap.invest import ui_server
|
from natcap.invest import ui_server
|
||||||
# test_client = ui_server.app.test_client()
|
test_client = ui_server.app.test_client()
|
||||||
# response = test_client.post(
|
response = test_client.post(
|
||||||
# 'api/getspec', json='carbon', query_string={'language': TEST_LANG})
|
'api/getspec', json='carbon', query_string={'language': TEST_LANG})
|
||||||
# spec = json.loads(response.get_data(as_text=True))
|
spec = json.loads(response.get_data(as_text=True))
|
||||||
# self.assertEqual(
|
self.assertEqual(
|
||||||
# spec['inputs']['lulc_bas_path']['name'],
|
spec['args']['lulc_bas_path']['name'],
|
||||||
# TEST_MESSAGES['baseline LULC'])
|
TEST_MESSAGES['baseline LULC'])
|
||||||
|
|
||||||
# def test_server_get_invest_validate(self):
|
def test_server_get_invest_validate(self):
|
||||||
# """Translation: test that /validate endpoint is translated."""
|
"""Translation: test that /validate endpoint is translated."""
|
||||||
# from natcap.invest import ui_server
|
from natcap.invest import ui_server
|
||||||
# from natcap.invest import carbon
|
from natcap.invest import carbon
|
||||||
# test_client = ui_server.app.test_client()
|
test_client = ui_server.app.test_client()
|
||||||
# payload = {
|
payload = {
|
||||||
# 'model_id': carbon.MODEL_SPEC.model_id,
|
'model_id': carbon.MODEL_SPEC.model_id,
|
||||||
# 'args': json.dumps({})
|
'args': json.dumps({})
|
||||||
# }
|
}
|
||||||
# response = test_client.post(
|
response = test_client.post(
|
||||||
# 'api/validate', json=payload,
|
'api/validate', json=payload,
|
||||||
# query_string={'language': TEST_LANG})
|
query_string={'language': TEST_LANG})
|
||||||
# results = json.loads(response.get_data(as_text=True))
|
results = json.loads(response.get_data(as_text=True))
|
||||||
# messages = [item[1] for item in results]
|
messages = [item[1] for item in results]
|
||||||
# self.assertIn(TEST_MESSAGES[missing_key_msg], messages)
|
self.assertIn(TEST_MESSAGES[missing_key_msg], messages)
|
||||||
|
|
||||||
# def test_translate_formatted_string(self):
|
def test_translate_formatted_string(self):
|
||||||
# """Translation: test that f-string can be translated."""
|
"""Translation: test that f-string can be translated."""
|
||||||
# from natcap.invest import carbon, validation, set_locale
|
from natcap.invest import carbon, validation, set_locale
|
||||||
# set_locale(TEST_LANG)
|
set_locale(TEST_LANG)
|
||||||
# importlib.reload(validation)
|
importlib.reload(validation)
|
||||||
# importlib.reload(carbon)
|
importlib.reload(carbon)
|
||||||
# args = {'n_workers': 'not a number'}
|
args = {'n_workers': 'not a number'}
|
||||||
# validation_messages = carbon.validate(args)
|
validation_messages = carbon.validate(args)
|
||||||
# self.assertIn(
|
self.assertIn(
|
||||||
# TEST_MESSAGES[not_a_number_msg].format(value=args['n_workers']),
|
TEST_MESSAGES[not_a_number_msg].format(value=args['n_workers']),
|
||||||
# str(validation_messages))
|
str(validation_messages))
|
||||||
|
|
|
@ -295,18 +295,20 @@ class ValidatorTest(unittest.TestCase):
|
||||||
from natcap.invest import validation
|
from natcap.invest import validation
|
||||||
|
|
||||||
# both args and the kwarg should be passed to the function
|
# both args and the kwarg should be passed to the function
|
||||||
|
@spec_utils.timeout
|
||||||
def func(arg1, arg2, kwarg=None):
|
def func(arg1, arg2, kwarg=None):
|
||||||
self.assertEqual(kwarg, 'kwarg')
|
self.assertEqual(kwarg, 'kwarg')
|
||||||
time.sleep(1)
|
time.sleep(1)
|
||||||
|
|
||||||
# this will raise an error if the timeout is exceeded
|
# this will raise an error if the timeout is exceeded
|
||||||
# timeout defaults to 5 seconds so this should pass
|
# timeout defaults to 5 seconds so this should pass
|
||||||
spec_utils.timeout(func, 'arg1', 'arg2', kwarg='kwarg')
|
func('arg1', 'arg2', kwarg='kwarg')
|
||||||
|
|
||||||
def test_timeout_fail(self):
|
def test_timeout_fail(self):
|
||||||
from natcap.invest import validation
|
from natcap.invest import validation
|
||||||
|
|
||||||
# both args and the kwarg should be passed to the function
|
# both args and the kwarg should be passed to the function
|
||||||
|
@spec_utils.timeout
|
||||||
def func(arg):
|
def func(arg):
|
||||||
time.sleep(6)
|
time.sleep(6)
|
||||||
|
|
||||||
|
@ -315,7 +317,7 @@ class ValidatorTest(unittest.TestCase):
|
||||||
with warnings.catch_warnings(record=True) as ws:
|
with warnings.catch_warnings(record=True) as ws:
|
||||||
# cause all warnings to always be triggered
|
# cause all warnings to always be triggered
|
||||||
warnings.simplefilter("always")
|
warnings.simplefilter("always")
|
||||||
spec_utils.timeout(func, 'arg')
|
func('arg')
|
||||||
self.assertTrue(len(ws) == 1)
|
self.assertTrue(len(ws) == 1)
|
||||||
self.assertTrue('timed out' in str(ws[0].message))
|
self.assertTrue('timed out' in str(ws[0].message))
|
||||||
|
|
||||||
|
@ -924,13 +926,13 @@ class CSVValidation(unittest.TestCase):
|
||||||
|
|
||||||
# define a side effect for the mock that will sleep
|
# define a side effect for the mock that will sleep
|
||||||
# for longer than the allowed timeout
|
# for longer than the allowed timeout
|
||||||
|
@spec_utils.timeout
|
||||||
def delay(*args, **kwargs):
|
def delay(*args, **kwargs):
|
||||||
time.sleep(7)
|
time.sleep(7)
|
||||||
return []
|
return []
|
||||||
|
|
||||||
# replace the validation.check_csv with the mock function, and try to validate
|
# replace the validation.check_csv with the mock function, and try to validate
|
||||||
with unittest.mock.patch('natcap.invest.spec_utils.CSVInput.validate',
|
with unittest.mock.patch('natcap.invest.spec_utils.CSVInput.validate', delay):
|
||||||
staticmethod(functools.partial(spec_utils.timeout, delay))):
|
|
||||||
with warnings.catch_warnings(record=True) as ws:
|
with warnings.catch_warnings(record=True) as ws:
|
||||||
# cause all warnings to always be triggered
|
# cause all warnings to always be triggered
|
||||||
warnings.simplefilter("always")
|
warnings.simplefilter("always")
|
||||||
|
|
Loading…
Reference in New Issue