fix variable name conflicts and get tests passing

This commit is contained in:
Emily Soth 2025-05-15 12:30:35 -07:00
parent d663d379d5
commit 0407999347
14 changed files with 32 additions and 56 deletions

View File

@ -1321,6 +1321,9 @@ def build_model_spec(model_spec):
for argkey, argspec in model_spec['args'].items()] for argkey, argspec in model_spec['args'].items()]
outputs = [ outputs = [
build_output_spec(argkey, argspec) for argkey, argspec in model_spec['outputs'].items()] build_output_spec(argkey, argspec) for argkey, argspec in model_spec['outputs'].items()]
different_projections_ok = False
if 'args_with_spatial_overlap' in model_spec:
different_projections_ok = model_spec['args_with_spatial_overlap'].get('different_projections_ok', False)
return ModelSpec( return ModelSpec(
model_id=model_spec['model_id'], model_id=model_spec['model_id'],
model_title=model_spec['model_title'], model_title=model_spec['model_title'],
@ -1330,7 +1333,7 @@ def build_model_spec(model_spec):
outputs=outputs, outputs=outputs,
input_field_order=model_spec['ui_spec']['order'], input_field_order=model_spec['ui_spec']['order'],
validate_spatial_overlap=True, validate_spatial_overlap=True,
different_projections_ok=model_spec['args_with_spatial_overlap'].get('different_projections_ok', False)) different_projections_ok=different_projections_ok)
def build_input_spec(argkey, arg): def build_input_spec(argkey, arg):

View File

@ -249,7 +249,7 @@ def _format_bbox_list(file_list, bbox_list):
file_list, bbox_list)]) file_list, bbox_list)])
def validate(args, spec): def validate(args, model_spec):
"""Validate an args dict against a model spec. """Validate an args dict against a model spec.
Validates an arguments dictionary according to the rules laid out in Validates an arguments dictionary according to the rules laid out in
@ -257,7 +257,7 @@ def validate(args, spec):
Args: Args:
args (dict): The InVEST model args dict to validate. args (dict): The InVEST model args dict to validate.
spec (dict): The InVEST model spec dict to validate against. model_spec (dict): The InVEST model spec dict to validate against.
Returns: Returns:
A list of tuples where the first element of the tuple is an iterable of A list of tuples where the first element of the tuple is an iterable of
@ -272,9 +272,9 @@ def validate(args, spec):
missing_keys = set() missing_keys = set()
required_keys_with_no_value = set() required_keys_with_no_value = set()
expression_values = { expression_values = {
input_spec.id: args.get(input_spec.id, False) for input_spec in spec.inputs} input_spec.id: args.get(input_spec.id, False) for input_spec in model_spec.inputs}
keys_with_falsey_values = set() keys_with_falsey_values = set()
for parameter_spec in spec.inputs: for parameter_spec in model_spec.inputs:
key = parameter_spec.id key = parameter_spec.id
required = parameter_spec.required required = parameter_spec.required
@ -312,7 +312,7 @@ def validate(args, spec):
# we don't need to try to validate them # we don't need to try to validate them
try: try:
# Using deepcopy to make sure we don't modify the original spec # Using deepcopy to make sure we don't modify the original spec
parameter_spec = copy.deepcopy(spec.get_input(key)) parameter_spec = copy.deepcopy(model_spec.get_input(key))
except KeyError: except KeyError:
LOGGER.debug(f'Provided key {key} does not exist in MODEL_SPEC') LOGGER.debug(f'Provided key {key} does not exist in MODEL_SPEC')
continue continue
@ -341,10 +341,10 @@ def validate(args, spec):
validation_warnings.append(([key], get_message('UNEXPECTED_ERROR'))) validation_warnings.append(([key], get_message('UNEXPECTED_ERROR')))
# Phase 3: Check spatial overlap if applicable # Phase 3: Check spatial overlap if applicable
if spec.validate_spatial_overlap: if model_spec.validate_spatial_overlap:
spatial_keys = set() spatial_keys = set()
i in spec.inputs: for i in model_spec.inputs:
if isinstance(i, spec.SingleBandRasterInput) or isinstance(i, spec.VectorInput): if i.type in['raster', 'vector']:
spatial_keys.add(i.id) spatial_keys.add(i.id)
# Only test for spatial overlap once all the sufficient spatial keys # Only test for spatial overlap once all the sufficient spatial keys
@ -361,7 +361,7 @@ def validate(args, spec):
checked_keys.append(key) checked_keys.append(key)
spatial_overlap_error = check_spatial_overlap( spatial_overlap_error = check_spatial_overlap(
spatial_files, spec.different_projections_ok) spatial_files, model_spec.different_projections_ok)
if spatial_overlap_error: if spatial_overlap_error:
validation_warnings.append( validation_warnings.append(
(checked_keys, spatial_overlap_error)) (checked_keys, spatial_overlap_error))
@ -459,12 +459,12 @@ def invest_validator(validate_func):
return _wrapped_validate_func return _wrapped_validate_func
def args_enabled(args, spec): def args_enabled(args, model_spec):
"""Get enabled/disabled status of arg fields given their values and spec. """Get enabled/disabled status of arg fields given their values and spec.
Args: Args:
args (dict): Dict mapping arg keys to user-provided values args (dict): Dict mapping arg keys to user-provided values
spec (dict): MODEL_SPEC dictionary model_spec (dict): MODEL_SPEC dictionary
Returns: Returns:
Dictionary mapping each arg key to a boolean value - True if the Dictionary mapping each arg key to a boolean value - True if the
@ -472,8 +472,8 @@ def args_enabled(args, spec):
""" """
enabled = {} enabled = {}
expression_values = { expression_values = {
arg_spec.id: args.get(arg_spec.id, False) for arg_spec in spec.inputs} arg_spec.id: args.get(arg_spec.id, False) for arg_spec in model_spec.inputs}
for arg_spec in spec.inputs: for arg_spec in model_spec.inputs:
if isinstance(arg_spec.allowed, str): if isinstance(arg_spec.allowed, str):
enabled[arg_spec.id] = bool(_evaluate_expression( enabled[arg_spec.id] = bool(_evaluate_expression(
arg_spec.allowed, expression_values)) arg_spec.allowed, expression_values))

View File

@ -26,6 +26,5 @@ MODEL_SPEC = spec.ModelSpec(inputs=[
model_id='archive_extraction_model', model_id='archive_extraction_model',
model_title='', model_title='',
userguide='', userguide='',
input_field_order=[], input_field_order=[]
args_with_spatial_overlap={}
) )

View File

@ -9,6 +9,5 @@ MODEL_SPEC = spec.ModelSpec(
model_id='duplicate_filepaths_model', model_id='duplicate_filepaths_model',
model_title='', model_title='',
userguide='', userguide='',
input_field_order=[], input_field_order=[]
args_with_spatial_overlap={}
) )

View File

@ -9,6 +9,5 @@ MODEL_SPEC = spec.ModelSpec(inputs=[
model_id='nonspatial_model', model_id='nonspatial_model',
model_title='', model_title='',
userguide='', userguide='',
input_field_order=[], input_field_order=[]
args_with_spatial_overlap={}
) )

View File

@ -6,6 +6,5 @@ MODEL_SPEC = spec.ModelSpec(inputs=[
model_id='raster_model', model_id='raster_model',
model_title='', model_title='',
userguide='', userguide='',
input_field_order=[], input_field_order=[]
args_with_spatial_overlap={}
) )

View File

@ -13,6 +13,5 @@ MODEL_SPEC = spec.ModelSpec(inputs=[
model_id='simple_model', model_id='simple_model',
model_title='', model_title='',
userguide='', userguide='',
input_field_order=[], input_field_order=[]
args_with_spatial_overlap={}
) )

View File

@ -7,6 +7,5 @@ MODEL_SPEC = SimpleNamespace(inputs=[
model_id='ui_parameters_model', model_id='ui_parameters_model',
model_title='', model_title='',
userguide='', userguide='',
input_field_order=[], input_field_order=[]
args_with_spatial_overlap={}
) )

View File

@ -7,6 +7,5 @@ MODEL_SPEC = spec.ModelSpec(inputs=[
model_id='vector_model', model_id='vector_model',
model_title='', model_title='',
userguide='', userguide='',
input_field_order=[], input_field_order=[]
args_with_spatial_overlap={}
) )

View File

@ -132,11 +132,6 @@ class ValidateModelSpecs(unittest.TestCase):
("Required key(s) missing from MODEL_SPEC: " ("Required key(s) missing from MODEL_SPEC: "
f"{set(required_keys).difference(set(dir(model.MODEL_SPEC)))}")) f"{set(required_keys).difference(set(dir(model.MODEL_SPEC)))}"))
if model.MODEL_SPEC.args_with_spatial_overlap:
self.assertTrue(
set(model.MODEL_SPEC.args_with_spatial_overlap).issubset(
{'spatial_keys', 'different_projections_ok'}))
self.assertIsInstance(model.MODEL_SPEC.input_field_order, list) self.assertIsInstance(model.MODEL_SPEC.input_field_order, list)
found_keys = set() found_keys = set()
for group in model.MODEL_SPEC.input_field_order: for group in model.MODEL_SPEC.input_field_order:

View File

@ -362,7 +362,6 @@ class TestMetadataFromSpec(unittest.TestCase):
aliases=[], aliases=[],
input_field_order=[], input_field_order=[],
inputs={}, inputs={},
args_with_spatial_overlap={},
outputs=output_spec outputs=output_spec
) )
) )

View File

@ -45,7 +45,8 @@ class EndpointFunctionTests(unittest.TestCase):
self.assertEqual( self.assertEqual(
set(spec), set(spec),
{'model_id', 'model_title', 'userguide', 'aliases', {'model_id', 'model_title', 'userguide', 'aliases',
'input_field_order', 'args_with_spatial_overlap', 'args', 'outputs'}) 'input_field_order', 'different_projections_ok',
'validate_spatial_overlap', 'args', 'outputs'})
def test_get_invest_validate(self): def test_get_invest_validate(self):
"""UI server: get_invest_validate endpoint.""" """UI server: get_invest_validate endpoint."""

View File

@ -75,8 +75,7 @@ class UsageLoggingTests(unittest.TestCase):
spec.VectorInput(id='blank_vector_path', geometry_types={}, fields={}) spec.VectorInput(id='blank_vector_path', geometry_types={}, fields={})
], ],
outputs={}, outputs={},
input_field_order=[], input_field_order=[])
args_with_spatial_overlap=None)
output_logfile = os.path.join(self.workspace_dir, 'logfile.txt') output_logfile = os.path.join(self.workspace_dir, 'logfile.txt')
with utils.log_to_file(output_logfile): with utils.log_to_file(output_logfile):

View File

@ -40,12 +40,10 @@ from natcap.invest.spec import (
gdal.UseExceptions() gdal.UseExceptions()
def model_spec_with_defaults(model_id='', model_title='', userguide='', aliases=None, def model_spec_with_defaults(model_id='', model_title='', userguide='', aliases=None,
inputs={}, outputs={}, input_field_order=[], inputs={}, outputs={}, input_field_order=[]):
args_with_spatial_overlap=[]):
return ModelSpec(model_id=model_id, model_title=model_title, userguide=userguide, return ModelSpec(model_id=model_id, model_title=model_title, userguide=userguide,
aliases=aliases, inputs=inputs, outputs=outputs, aliases=aliases, inputs=inputs, outputs=outputs,
input_field_order=input_field_order, input_field_order=input_field_order)
args_with_spatial_overlap=args_with_spatial_overlap)
def number_input_spec_with_defaults(id='', units=u.none, expression='', **kwargs): def number_input_spec_with_defaults(id='', units=u.none, expression='', **kwargs):
return NumberInput(id=id, units=units, expression=expression, **kwargs) return NumberInput(id=id, units=units, expression=expression, **kwargs)
@ -2050,11 +2048,7 @@ class TestValidationFromSpec(unittest.TestCase):
fields={}, fields={},
geometry_types={'POINT'} geometry_types={'POINT'}
) )
], ]
args_with_spatial_overlap={
'spatial_keys': ['raster_a', 'raster_b', 'vector_a'],
'different_projections_ok': True
}
) )
driver = gdal.GetDriverByName('GTiff') driver = gdal.GetDriverByName('GTiff')
@ -2117,11 +2111,7 @@ class TestValidationFromSpec(unittest.TestCase):
data_type=float, data_type=float,
units=u.none units=u.none
) )
], ]
args_with_spatial_overlap={
'spatial_keys': ['raster_a', 'raster_b'],
'different_projections_ok': True
}
) )
driver = gdal.GetDriverByName('GTiff') driver = gdal.GetDriverByName('GTiff')
@ -2170,11 +2160,7 @@ class TestValidationFromSpec(unittest.TestCase):
fields=[], fields=[],
geometry_types={'POINT'} geometry_types={'POINT'}
) )
], ]
args_with_spatial_overlap={
'spatial_keys': ['raster_a', 'raster_b', 'vector_a'],
'different_projections_ok': True
}
) )
driver = gdal.GetDriverByName('GTiff') driver = gdal.GetDriverByName('GTiff')