work in progress: updating tests

Emily Soth 2025-04-28 12:07:59 -07:00
parent 57c151f92e
commit 73b42aecc5
48 changed files with 799 additions and 855 deletions

View File

@ -536,9 +536,8 @@ def execute(args):
'Checking that watersheds have entries for every `ws_id` in the '
'valuation table.')
# Open/read in valuation parameters from CSV file
valuation_df = validation.get_validated_dataframe(
args['valuation_table_path'],
MODEL_SPEC.inputs.valuation_table_path)
valuation_df = MODEL_SPEC.inputs.get(
'valuation_table_path').get_validated_dataframe(args['valuation_table_path'])
watershed_vector = gdal.OpenEx(
args['watersheds_path'], gdal.OF_VECTOR)
watershed_layer = watershed_vector.GetLayer()
@ -660,15 +659,16 @@ def execute(args):
'lulc': pygeoprocessing.get_raster_info(clipped_lulc_path)['nodata'][0]}
# Open/read in the csv file into a dictionary and add to arguments
bio_df = validation.get_validated_dataframe(args['biophysical_table_path'],
MODEL_SPEC.inputs.biophysical_table_path)
bio_df = MODEL_SPEC.inputs.get(
'biophysical_table_path').get_validated_dataframe(args['biophysical_table_path'])
bio_lucodes = set(bio_df.index.values)
bio_lucodes.add(nodata_dict['lulc'])
LOGGER.debug(f'bio_lucodes: {bio_lucodes}')
if 'demand_table_path' in args and args['demand_table_path'] != '':
demand_df = validation.get_validated_dataframe(
args['demand_table_path'], MODEL_SPEC.inputs.demand_table_path)
demand_df = MODEL_SPEC.inputs.get('demand_table_path').get_validated_dataframe(
args['demand_table_path'])
demand_reclassify_dict = dict(
[(lucode, row['demand']) for lucode, row in demand_df.iterrows()])
demand_lucodes = set(demand_df.index.values)
@ -1324,5 +1324,4 @@ def validate(args, limit_to=None):
be an empty list if validation succeeds.
"""
return validation.validate(
args, MODEL_SPEC.inputs, MODEL_SPEC.args_with_spatial_overlap)
return validation.validate(args, MODEL_SPEC)
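Note: calls like the one above collapse the old three-argument validation.validate(args, MODEL_SPEC.inputs, MODEL_SPEC.args_with_spatial_overlap) form into validation.validate(args, MODEL_SPEC). A minimal sketch of what that signature allows, assuming the iterable ModelInputs container and per-spec validate() methods introduced later in this commit (illustrative only, not the actual implementation):

    def validate(args, model_spec, limit_to=None):
        # Hedged sketch: input specs and the spatial-overlap configuration
        # now travel together on one ModelSpec object.
        warnings = []
        for input_spec in model_spec.inputs:          # ModelInputs is iterable
            if limit_to and input_spec.id != limit_to:
                continue
            if input_spec.id in args:
                error_msg = input_spec.validate(args[input_spec.id])
                if error_msg:
                    warnings.append(([input_spec.id], error_msg))
        # spatial-overlap checks would read model_spec.args_with_spatial_overlap here
        return warnings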

View File

@ -322,8 +322,8 @@ def execute(args):
"Baseline LULC Year is earlier than the Alternate LULC Year."
)
carbon_pool_df = validation.get_validated_dataframe(
args['carbon_pools_path'], MODEL_SPEC.inputs.carbon_pools_path)
carbon_pool_df = MODEL_SPEC.inputs.get(
'carbon_pools_path').get_validated_dataframe(args['carbon_pools_path'])
try:
n_workers = int(args['n_workers'])
@ -694,5 +694,4 @@ def validate(args, limit_to=None):
the error message in the second part of the tuple. This should
be an empty list if validation succeeds.
"""
return validation.validate(
args, MODEL_SPEC.inputs, MODEL_SPEC.args_with_spatial_overlap)
return validation.validate(args, MODEL_SPEC)

View File

@ -586,10 +586,9 @@ def execute(args):
task_graph, n_workers, intermediate_dir, output_dir, suffix = (
_set_up_workspace(args))
snapshots = validation.get_validated_dataframe(
args['landcover_snapshot_csv'],
MODEL_SPEC.inputs.landcover_snapshot_csv
)['raster_path'].to_dict()
snapshots = MODEL_SPEC.inputs.get(
'landcover_snapshot_csv').get_validated_dataframe(
args['landcover_snapshot_csv'])['raster_path'].to_dict()
# Phase 1: alignment and preparation of inputs
baseline_lulc_year = min(snapshots.keys())
@ -609,9 +608,9 @@ def execute(args):
# We're assuming that the LULC initial variables and the carbon pool
# transient table are combined into a single lookup table.
biophysical_df = validation.get_validated_dataframe(
args['biophysical_table_path'],
MODEL_SPEC.inputs.biophysical_table_path)
biophysical_df = MODEL_SPEC.inputs.get(
'biophysical_table_path').get_validated_dataframe(
args['biophysical_table_path'])
# LULC Classnames are critical to the transition mapping, so they must be
# unique. This check is here in ``execute`` because it's possible that
@ -979,10 +978,9 @@ def execute(args):
prices = None
if args.get('do_economic_analysis', False): # Do if truthy
if args.get('use_price_table', False):
prices = validation.get_validated_dataframe(
args['price_table_path'],
MODEL_SPEC.inputs.price_table_path
)['price'].to_dict()
prices = MODEL_SPEC.inputs.get(
'price_table_path').get_validated_dataframe(
args['price_table_path'])['price'].to_dict()
else:
inflation_rate = float(args['inflation_rate']) * 0.01
annual_price = float(args['price'])
@ -1964,9 +1962,9 @@ def _read_transition_matrix(transition_csv_path, biophysical_df):
landcover transition, and the second contains accumulation rates for
the pool for the landcover transition.
"""
table = validation.get_validated_dataframe(
transition_csv_path, MODEL_SPEC.inputs.landcover_transitions_table
).reset_index()
table = MODEL_SPEC.inputs.get(
'landcover_transitions_table').get_validated_dataframe(
transition_csv_path).reset_index()
lulc_class_to_lucode = {}
max_lucode = biophysical_df.index.max()
@ -2180,15 +2178,15 @@ def validate(args, limit_to=None):
A list of tuples where tuple[0] is an iterable of keys that the error
message applies to and tuple[1] is the string validation warning.
"""
validation_warnings = validation.validate(args, MODEL_SPEC.inputs)
validation_warnings = validation.validate(args, MODEL_SPEC)
sufficient_keys = validation.get_sufficient_keys(args)
invalid_keys = validation.get_invalid_keys(validation_warnings)
if ("landcover_snapshot_csv" not in invalid_keys and
"landcover_snapshot_csv" in sufficient_keys):
snapshots = validation.get_validated_dataframe(
args['landcover_snapshot_csv'],
MODEL_SPEC.inputs.landcover_snapshot_csv
snapshots = MODEL_SPEC.inputs.get(
'landcover_snapshot_csv').get_validated_dataframe(
args['landcover_snapshot_csv']
)['raster_path'].to_dict()
snapshot_years = set(snapshots.keys())
@ -2208,13 +2206,13 @@ def validate(args, limit_to=None):
# check for invalid options in the transition table
if ("landcover_transitions_table" not in invalid_keys and
"landcover_transitions_table" in sufficient_keys):
transitions_spec = MODEL_SPEC.inputs.landcover_transitions_table
transitions_spec = MODEL_SPEC.inputs.get('landcover_transitions_table')
transition_options = list(
getattr(transitions_spec.columns, '[LULC CODE]').options.keys())
transitions_spec.columns.get('[LULC CODE]').options.keys())
# lowercase options since utils call will lowercase table values
transition_options = [x.lower() for x in transition_options]
transitions_df = validation.get_validated_dataframe(
args['landcover_transitions_table'], transitions_spec)
transitions_df = transitions_spec.get_validated_dataframe(
args['landcover_transitions_table'])
transitions_mask = ~transitions_df.isin(transition_options) & ~transitions_df.isna()
if transitions_mask.any(axis=None):
transition_numpy_mask = transitions_mask.values

View File

@ -185,10 +185,9 @@ def execute(args):
os.path.join(args['workspace_dir'], 'taskgraph_cache'),
n_workers, reporting_interval=5.0)
snapshots_dict = validation.get_validated_dataframe(
args['landcover_snapshot_csv'],
MODEL_SPEC.inputs.landcover_snapshot_csv
)['raster_path'].to_dict()
snapshots_dict = MODEL_SPEC.inputs.get(
'landcover_snapshot_csv').get_validated_dataframe(
args['landcover_snapshot_csv'])['raster_path'].to_dict()
# Align the raster stack for analyzing the various transitions.
min_pixel_size = float('inf')
@ -218,9 +217,9 @@ def execute(args):
target_path_list=aligned_snapshot_paths,
task_name='Align input landcover rasters')
landcover_df = validation.get_validated_dataframe(
args['lulc_lookup_table_path'],
MODEL_SPEC.inputs.lulc_lookup_table_path)
landcover_df = MODEL_SPEC.inputs.get(
'lulc_lookup_table_path').get_validated_dataframe(
args['lulc_lookup_table_path'])
target_transition_table = os.path.join(
output_dir, TRANSITION_TABLE.format(suffix=suffix))
@ -389,8 +388,8 @@ def _create_biophysical_table(landcover_df, target_biophysical_table_path):
``None``
"""
target_column_names = [
colname.lower() for colname in
coastal_blue_carbon.MODEL_SPEC.inputs.biophysical_table_path.columns.__dict__]
spec.id.lower() for spec in
coastal_blue_carbon.MODEL_SPEC.inputs.get('biophysical_table_path').columns]
with open(target_biophysical_table_path, 'w') as bio_table:
bio_table.write(f"{','.join(target_column_names)}\n")
@ -424,4 +423,4 @@ def validate(args, limit_to=None):
A list of tuples where tuple[0] is an iterable of keys that the error
message applies to and tuple[1] is the string validation warning.
"""
return validation.validate(args, MODEL_SPEC.inputs)
return validation.validate(args, MODEL_SPEC)

View File

@ -2342,8 +2342,8 @@ def _schedule_habitat_tasks(
list of pickle file path strings
"""
habitat_dataframe = validation.get_validated_dataframe(
habitat_table_path, MODEL_SPEC.inputs.habitat_table_path
habitat_dataframe = MODEL_SPEC.inputs.get(
'habitat_table_path').get_validated_dataframe(habitat_table_path
).rename(columns={'protection distance (m)': 'distance'})
habitat_task_list = []
@ -2872,10 +2872,9 @@ def assemble_results_and_calculate_exposure(
with open(pickle_path, 'rb') as file:
final_values_dict[var_name] = pickle.load(file)
habitat_df = validation.get_validated_dataframe(
habitat_protection_path, **MODEL_SPEC['outputs']['intermediate'][
'contents']['habitats']['contents']['habitat_protection.csv']
).rename(columns={'r_hab': 'R_hab'})
habitat_df = MODEL_SPEC.outputs.get('intermediate').contents.get(
'habitats').contents.get('habitat_protection.csv').get_validated_dataframe(
habitat_protection_path).rename(columns={'r_hab': 'R_hab'})
output_layer.StartTransaction()
for feature in output_layer:
shore_id = feature.GetField(SHORE_ID_FIELD)
@ -3507,8 +3506,7 @@ def validate(args, limit_to=None):
be an empty list if validation succeeds.
"""
validation_warnings = validation.validate(
args, MODEL_SPEC.inputs, MODEL_SPEC.args_with_spatial_overlap)
validation_warnings = validation.validate(args, MODEL_SPEC)
invalid_keys = validation.get_invalid_keys(validation_warnings)
sufficient_keys = validation.get_sufficient_keys(args)
@ -3519,8 +3517,8 @@ def validate(args, limit_to=None):
'slr_field' in sufficient_keys):
fieldnames = validation.load_fields_from_vector(
args['slr_vector_path'])
error_msg = validation.check_option_string(args['slr_field'],
fieldnames)
error_msg = spec_utils.OptionStringInputSpec(
options=fieldnames).validate(args['slr_field'])
if error_msg:
validation_warnings.append((['slr_field'], error_msg))
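The slr_field check above swaps validation.check_option_string for a throwaway spec object. A hedged sketch of that pattern (the import path and field names are assumptions for illustration):

    from natcap.invest import spec_utils

    fieldnames = ['shore_id', 'slr_cm']          # hypothetical fields read from the vector
    spec = spec_utils.OptionStringInputSpec(options=fieldnames)
    error_msg = spec.validate('slr_cm')          # returns None when the value is an allowed option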

View File

@ -615,9 +615,10 @@ def execute(args):
None.
"""
crop_to_landcover_df = validation.get_validated_dataframe(
args['landcover_to_crop_table_path'],
MODEL_SPEC.inputs.landcover_to_crop_table_path)
crop_to_landcover_df = MODEL_SPEC.inputs.get(
'landcover_to_crop_table_path').get_validated_dataframe(
args['landcover_to_crop_table_path'])
bad_crop_name_list = []
for crop_name in crop_to_landcover_df.index:
crop_climate_bin_raster_path = os.path.join(
@ -696,11 +697,11 @@ def execute(args):
climate_percentile_yield_table_path = os.path.join(
args['model_data_path'],
_CLIMATE_PERCENTILE_TABLE_PATTERN % crop_name)
crop_climate_percentile_df = validation.get_validated_dataframe(
climate_percentile_yield_table_path,
MODEL_SPEC.inputs.model_data_path['contents'][
'climate_percentile_yield_tables']['contents'][
'[CROP]_percentile_yield_table.csv'])
crop_climate_percentile_df = MODEL_SPEC.inputs.get(
'model_data_path').contents.get(
'climate_percentile_yield_tables').contents.get(
'[CROP]_percentile_yield_table.csv').get_validated_dataframe(
climate_percentile_yield_table_path)
yield_percentile_headers = [
x for x in crop_climate_percentile_df.columns if x != 'climate_bin']
@ -853,9 +854,11 @@ def execute(args):
# both 'crop_nutrient.csv' and 'crop' are known data/header values for
# this model data.
nutrient_df = validation.get_validated_dataframe(
os.path.join(args['model_data_path'], 'crop_nutrient.csv'),
MODEL_SPEC.inputs.model_data_path['contents']['crop_nutrient.csv'])
nutrient_df = MODEL_SPEC.inputs.get(
'model_data_path').contents.get(
'crop_nutrient.csv').get_validated_dataframe(
os.path.join(args['model_data_path'], 'crop_nutrient.csv'))
result_table_path = os.path.join(
output_dir, 'result_table%s.csv' % file_suffix)
@ -1246,4 +1249,4 @@ def validate(args, limit_to=None):
be an empty list if validation succeeds.
"""
return validation.validate(
args, MODEL_SPEC.inputs, MODEL_SPEC.args_with_spatial_overlap)
args, MODEL_SPEC, MODEL_SPEC.args_with_spatial_overlap)

View File

@ -504,13 +504,13 @@ def execute(args):
LOGGER.info(
"Checking if the landcover raster is missing lucodes")
crop_to_landcover_df = validation.get_validated_dataframe(
args['landcover_to_crop_table_path'],
MODEL_SPEC.inputs.landcover_to_crop_table_path)
crop_to_landcover_df = MODEL_SPEC.inputs.get(
'landcover_to_crop_table_path').get_validated_dataframe(
args['landcover_to_crop_table_path'])
crop_to_fertilization_rate_df = validation.get_validated_dataframe(
args['fertilization_rate_table_path'],
MODEL_SPEC.inputs.fertilization_rate_table_path)
crop_to_fertilization_rate_df = MODEL_SPEC.inputs.get(
'fertilization_rate_table_path').get_validated_dataframe(
args['fertilization_rate_table_path'])
crop_lucodes = list(crop_to_landcover_df[_EXPECTED_LUCODE_TABLE_HEADER])
@ -585,12 +585,11 @@ def execute(args):
task_name='crop_climate_bin')
dependent_task_list.append(crop_climate_bin_task)
crop_regression_df = validation.get_validated_dataframe(
os.path.join(args['model_data_path'],
_REGRESSION_TABLE_PATTERN % crop_name),
MODEL_SPEC.inputs.model_data_path['contents'][
'climate_regression_yield_tables']['contents'][
'[CROP]_regression_yield_table.csv'])
crop_regression_df = MODEL_SPEC.inputs.get('model_data_path').contents.get(
'climate_regression_yield_tables').contents.get(
'[CROP]_regression_yield_table.csv').get_validated_dataframe(
os.path.join(args['model_data_path'],
_REGRESSION_TABLE_PATTERN % crop_name))
for _, row in crop_regression_df.iterrows():
for header in _EXPECTED_REGRESSION_TABLE_HEADERS:
if numpy.isnan(row[header]):
@ -809,9 +808,9 @@ def execute(args):
# both 'crop_nutrient.csv' and 'crop' are known data/header values for
# this model data.
nutrient_df = validation.get_validated_dataframe(
os.path.join(args['model_data_path'], 'crop_nutrient.csv'),
MODEL_SPEC.inputs.model_data_path['contents']['crop_nutrient.csv'])
nutrient_df = MODEL_SPEC.inputs.get('model_data_path').contents.get(
'crop_nutrient.csv').get_validated_dataframe(
os.path.join(args['model_data_path'], 'crop_nutrient.csv'))
LOGGER.info("Generating report table")
crop_names = list(crop_to_landcover_df.index)
@ -1174,5 +1173,4 @@ def validate(args, limit_to=None):
be an empty list if validation succeeds.
"""
return validation.validate(
args, MODEL_SPEC.inputs, MODEL_SPEC.args_with_spatial_overlap)
return validation.validate(args, MODEL_SPEC)

View File

@ -37,6 +37,7 @@ from osgeo import gdal
from . import utils
from . import validation
from . import models
from . import spec_utils
try:
from . import __version__
@ -198,10 +199,11 @@ def build_datastack_archive(args, model_id, datastack_path):
# For tracking existing files so we don't copy files in twice
files_found = {}
LOGGER.debug(f'Keys: {sorted(args.keys())}')
args_spec = module.MODEL_SPEC['args']
spatial_types = {'raster', 'vector'}
file_based_types = spatial_types.union({'csv', 'file', 'directory'})
spatial_types = {spec_utils.SingleBandRasterInputSpec, spec_utils.VectorInputSpec,
spec_utils.RasterOrVectorInputSpec}
file_based_types = spatial_types.union({
spec_utils.CSVInputSpec, spec_utils.FileInputSpec, spec_utils.DirectoryInputSpec})
rewritten_args = {}
for key in args:
# Allow the model to override specific arguments in datastack archive
@ -234,11 +236,11 @@ def build_datastack_archive(args, model_id, datastack_path):
LOGGER.info(f'Starting to archive arg "{key}": {args[key]}')
# Possible that a user might pass an args key that doesn't belong to
# this model. Skip if so.
if key not in args_spec:
if key not in module.MODEL_SPEC.inputs:
LOGGER.info(f'Skipping arg {key}; not in model MODEL_SPEC')
input_type = args_spec[key]['type']
if input_type in file_based_types:
input_spec = module.MODEL_SPEC.inputs.get(key)
if input_spec.__class__ in file_based_types:
if args[key] in {None, ''}:
LOGGER.info(
f'Skipping key {key}, value is empty and cannot point to '
@ -258,22 +260,16 @@ def build_datastack_archive(args, model_id, datastack_path):
rewritten_args[key] = files_found[source_path]
continue
if input_type == 'csv':
if input_spec.__class__ is spec_utils.CSVInputSpec:
# check the CSV for columns that may be spatial.
# But also, the columns specification might not be listed, so don't
# require that 'columns' exists in the MODEL_SPEC.
spatial_columns = []
if 'columns' in args_spec[key]:
for col_name, col_definition in (
args_spec[key]['columns'].items()):
# Type attribute may be a string (one type) or set
# (multiple types allowed), so always convert to a set for
# easier comparison.
col_types = col_definition['type']
if isinstance(col_types, str):
col_types = set([col_types])
if col_types.intersection(spatial_types):
spatial_columns.append(col_name)
if input_spec.columns:
for col_spec in input_spec.columns:
if col_spec.__class__ in spatial_types:
spatial_columns.append(col_spec.id)
LOGGER.debug(f'Detected spatial columns: {spatial_columns}')
target_csv_path = os.path.join(
@ -286,8 +282,7 @@ def build_datastack_archive(args, model_id, datastack_path):
contained_files_dir = os.path.join(
data_dir, f'{key}_csv_data')
dataframe = validation.get_validated_dataframe(
source_path, **args_spec[key])
dataframe = input_spec.get_validated_dataframe(source_path)
csv_source_dir = os.path.abspath(os.path.dirname(source_path))
for spatial_column_name in spatial_columns:
# Iterate through the spatial columns, identify the set of
@ -347,7 +342,7 @@ def build_datastack_archive(args, model_id, datastack_path):
target_arg_value = target_csv_path
files_found[source_path] = target_arg_value
elif input_type == 'file':
elif input_spec.__class__ is spec_utils.FileInputSpec:
target_filepath = os.path.join(
data_dir, f'{key}_file')
shutil.copyfile(source_path, target_filepath)
@ -356,7 +351,7 @@ def build_datastack_archive(args, model_id, datastack_path):
target_arg_value = target_filepath
files_found[source_path] = target_arg_value
elif input_type == 'directory':
elif input_spec.__class__ is spec_utils.DirectoryInputSpec:
# copy the whole folder
target_directory = os.path.join(data_dir, f'{key}_directory')
os.makedirs(target_directory)
@ -376,22 +371,17 @@ def build_datastack_archive(args, model_id, datastack_path):
target_arg_value = target_directory
files_found[source_path] = target_arg_value
elif input_type in spatial_types:
elif input_spec.__class__ in spatial_types:
# Create a directory with a readable name, something like
# "aoi_path_vector" or "lulc_cur_path_raster".
spatial_dir = os.path.join(data_dir, f'{key}_{input_type}')
spatial_dir = os.path.join(data_dir, f'{key}_{input_spec.__class__}')
target_arg_value = utils.copy_spatial_files(
source_path, spatial_dir)
files_found[source_path] = target_arg_value
elif input_type == 'other':
# Note that no models currently use this to the best of my
# knowledge, so better to raise a NotImplementedError
raise NotImplementedError(
'The "other" MODEL_SPEC input type is not supported')
else:
LOGGER.debug(
f"Type {input_type} is not filesystem-based; "
f"Type {input_spec.__class__} is not filesystem-based; "
"recording value directly")
# not a filesystem-based type
# Record the value directly
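The archive logic above now dispatches on spec classes rather than type strings. A hedged sketch of that membership test; note it compares input_spec.__class__ against an explicit set instead of using isinstance, which is why RasterOrVectorInputSpec must be listed alongside the raster and vector classes:

    from natcap.invest.spec_utils import (        # assumed import path
        SingleBandRasterInputSpec, VectorInputSpec, RasterOrVectorInputSpec,
        CSVInputSpec, FileInputSpec, DirectoryInputSpec)

    spatial_types = {SingleBandRasterInputSpec, VectorInputSpec, RasterOrVectorInputSpec}
    file_based_types = spatial_types | {CSVInputSpec, FileInputSpec, DirectoryInputSpec}

    input_spec = module.MODEL_SPEC.inputs.get('lulc_cur_path')   # 'module' is the model module; key is hypothetical
    if input_spec.__class__ in file_based_types:
        pass   # copy the file(s) this arg points to into the archive
    else:
        pass   # not filesystem-based; record the value directly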

View File

@ -830,5 +830,4 @@ def validate(args, limit_to=None):
be an empty list if validation succeeds.
"""
return validation.validate(
args, MODEL_SPEC.inputs, MODEL_SPEC.args_with_spatial_overlap)
return validation.validate(args, MODEL_SPEC)

View File

@ -425,9 +425,9 @@ def execute(args):
# Map non-forest landcover codes to carbon biomasses
LOGGER.info('Calculating direct mapped carbon stocks')
carbon_maps = []
biophysical_df = validation.get_validated_dataframe(
args['biophysical_table_path'],
MODEL_SPEC.inputs.biophysical_table_path)
biophysical_df = MODEL_SPEC.inputs.get(
'biophysical_table_path').get_validated_dataframe(
args['biophysical_table_path'])
pool_list = [('c_above', True)]
if args['pools_to_calculate'] == 'all':
pool_list.extend([
@ -644,8 +644,8 @@ def _calculate_lulc_carbon_map(
"""
# classify forest pixels from lulc
biophysical_df = validation.get_validated_dataframe(
biophysical_table_path, MODEL_SPEC.inputs.biophysical_table_path)
biophysical_df = MODEL_SPEC.inputs.get(
'biophysical_table_path').get_validated_dataframe(biophysical_table_path)
lucode_to_per_cell_carbon = {}
@ -704,8 +704,9 @@ def _map_distance_from_tropical_forest_edge(
"""
# Build a list of forest lucodes
biophysical_df = validation.get_validated_dataframe(
biophysical_table_path, MODEL_SPEC.inputs.biophysical_table_path)
biophysical_df = MODEL_SPEC.inputs.get(
'biophysical_table_path').get_validated_dataframe(
biophysical_table_path)
forest_codes = biophysical_df[biophysical_df['is_tropical_forest']].index.values
# Make a raster where 1 is non-forest landcover types and 0 is forest
@ -1033,8 +1034,7 @@ def validate(args, limit_to=None):
be an empty list if validation succeeds.
"""
validation_warnings = validation.validate(
args, MODEL_SPEC.inputs, MODEL_SPEC.args_with_spatial_overlap)
validation_warnings = validation.validate(args, MODEL_SPEC)
invalid_keys = set([])
for affected_keys, error_msg in validation_warnings:

View File

@ -456,12 +456,12 @@ def execute(args):
LOGGER.info("Checking Threat and Sensitivity tables for compliance")
# Get CSVs as dictionaries and ensure the key is a string for threats.
threat_df = validation.get_validated_dataframe(
args['threats_table_path'], MODEL_SPEC.inputs.threats_table_path
).fillna('')
sensitivity_df = validation.get_validated_dataframe(
args['sensitivity_table_path'],
MODEL_SPEC.inputs.sensitivity_table_path)
threat_df = MODEL_SPEC.inputs.get(
'threats_table_path').get_validated_dataframe(
args['threats_table_path']).fillna('')
sensitivity_df = MODEL_SPEC.inputs.get(
'sensitivity_table_path').get_validated_dataframe(
args['sensitivity_table_path'])
half_saturation_constant = float(args['half_saturation_constant'])
@ -1174,8 +1174,7 @@ def validate(args, limit_to=None):
the error message in the second part of the tuple. This should
be an empty list if validation succeeds.
"""
validation_warnings = validation.validate(
args, MODEL_SPEC.inputs, MODEL_SPEC.args_with_spatial_overlap)
validation_warnings = validation.validate(args, MODEL_SPEC)
invalid_keys = validation.get_invalid_keys(validation_warnings)
@ -1183,12 +1182,12 @@ def validate(args, limit_to=None):
"sensitivity_table_path" not in invalid_keys and
"threat_raster_folder" not in invalid_keys):
# Get CSVs as dictionaries and ensure the key is a string for threats.
threat_df = validation.get_validated_dataframe(
args['threats_table_path'],
MODEL_SPEC.inputs.threats_table_path).fillna('')
sensitivity_df = validation.get_validated_dataframe(
args['sensitivity_table_path'],
MODEL_SPEC.inputs.sensitivity_table_path)
threat_df = MODEL_SPEC.inputs.get(
'threats_table_path').get_validated_dataframe(
args['threats_table_path']).fillna('')
sensitivity_df = MODEL_SPEC.inputs.get(
'sensitivity_table_path').get_validated_dataframe(
args['sensitivity_table_path'])
# check that the threat names in the threats table match with the
# threats columns in the sensitivity table.

View File

@ -448,8 +448,8 @@ MODEL_SPEC = spec_utils.build_model_spec({
}
})
_VALID_RISK_EQS = set(MODEL_SPEC.inputs.risk_eq.options.keys())
_VALID_DECAY_TYPES = set(MODEL_SPEC.inputs.decay_eq.options.keys())
_VALID_RISK_EQS = set(MODEL_SPEC.inputs.get('risk_eq').options.keys())
_VALID_DECAY_TYPES = set(MODEL_SPEC.inputs.get('decay_eq').options.keys())
def execute(args):
@ -1792,8 +1792,8 @@ def _parse_info_table(info_table_path):
info_table_path = os.path.abspath(info_table_path)
try:
table = validation.get_validated_dataframe(
info_table_path, MODEL_SPEC.inputs.info_table_path)
table = MODEL_SPEC.inputs.get(
'info_table_path').get_validated_dataframe(info_table_path)
except ValueError as err:
if 'Index has duplicate keys' in str(err):
raise ValueError("Habitat and stressor names may not overlap.")
@ -2475,4 +2475,4 @@ def validate(args, limit_to=None):
be an empty list if validation succeeds.
"""
return validation.validate(args, MODEL_SPEC.inputs)
return validation.validate(args, MODEL_SPEC)

View File

@ -87,23 +87,47 @@ MODEL_SPEC = spec_utils.build_model_spec({
"index_col": "lucode",
"columns": {
"lucode": spec_utils.LULC_TABLE_COLUMN,
"load_[NUTRIENT]": { # nitrogen or phosphorus nutrient loads
"load_n": {
"type": "number",
"units": u.kilogram/u.hectare/u.year,
"required": "calc_n",
"about": gettext(
"The nutrient loading for this land use class.")},
"eff_[NUTRIENT]": { # nutrient retention capacities
"The nitrogen loading for this land use class.")},
"load_p": {
"type": "number",
"units": u.kilogram/u.hectare/u.year,
"required": "calc_p",
"about": gettext(
"The phosphorus loading for this land use class.")},
"eff_n": {
"type": "ratio",
"required": "calc_n",
"about": gettext(
"Maximum nutrient retention efficiency. This is the "
"maximum proportion of the nutrient that is retained "
"Maximum nitrogen retention efficiency. This is the "
"maximum proportion of the nitrogen that is retained "
"on this LULC class.")},
"crit_len_[NUTRIENT]": { # nutrient critical lengths
"eff_p": {
"type": "ratio",
"required": "calc_p",
"about": gettext(
"Maximum phosphorus retention efficiency. This is the "
"maximum proportion of the phosphorus that is retained "
"on this LULC class.")},
"crit_len_n": {
"type": "number",
"units": u.meter,
"required": "calc_n",
"about": gettext(
"The distance after which it is assumed that this "
"LULC type retains the nutrient at its maximum "
"LULC type retains nitrogen at its maximum "
"capacity.")},
"crit_len_p": {
"type": "number",
"units": u.meter,
"required": "calc_p",
"about": gettext(
"The distance after which it is assumed that this "
"LULC type retains phosphorus at its maximum "
"capacity.")},
"proportion_subsurface_n": {
"type": "ratio",
@ -117,13 +141,11 @@ MODEL_SPEC = spec_utils.build_model_spec({
},
"about": gettext(
"A table mapping each LULC class to its biophysical "
"properties related to nutrient load and retention. Replace "
"'[NUTRIENT]' in the column names with 'n' or 'p' for "
"nitrogen or phosphorus respectively. Nitrogen data must be "
"provided if Calculate Nitrogen is selected. Phosphorus data "
"must be provided if Calculate Phosphorus is selected. All "
"LULC codes in the LULC raster must have corresponding "
"entries in this table."),
"properties related to nutrient load and retention. Nitrogen "
"data must be provided if Calculate Nitrogen is selected. "
"Phosphorus data must be provided if Calculate Phosphorus is "
"selected. All LULC codes in the LULC raster must have "
"corresponding entries in this table."),
"name": gettext("biophysical table")
},
"calc_p": {
@ -603,9 +625,9 @@ def execute(args):
if args['calc_' + nutrient_id]:
nutrients_to_process.append(nutrient_id)
biophysical_df = validation.get_validated_dataframe(
args['biophysical_table_path'],
MODEL_SPEC.inputs.biophysical_table_path)
biophysical_df = MODEL_SPEC.inputs.get(
'biophysical_table_path').get_validated_dataframe(
args['biophysical_table_path'])
# Ensure that if user doesn't explicitly assign a value,
# runoff_proxy_av = None
@ -1268,7 +1290,7 @@ def validate(args, limit_to=None):
be an empty list if validation succeeds.
"""
spec_copy = copy.deepcopy(MODEL_SPEC.inputs)
spec_copy = copy.deepcopy(MODEL_SPEC)
# Check required fields given the state of ``calc_n`` and ``calc_p``
nutrients_selected = []
for nutrient_letter in ('n', 'p'):
@ -1277,16 +1299,14 @@ def validate(args, limit_to=None):
for param in ['load', 'eff', 'crit_len']:
for nutrient in nutrients_selected:
spec_copy.biophysical_table_path.columns[f'{param}_{nutrient}'] = (
spec_copy.biophysical_table_path.columns[f'{param}_[NUTRIENT]'])
spec_copy.biophysical_table_path.columns[f'{param}_{nutrient}']['required'] = True
spec_copy.biophysical_table_path.columns.pop(f'{param}_[NUTRIENT]')
spec_copy.inputs.get('biophysical_table_path').columns.get(
f'{param}_{nutrient}').required = True
if 'n' in nutrients_selected:
spec_copy.biophysical_table_path.columns.proportion_subsurface_n.required = True
spec_copy.inputs.get('biophysical_table_path').columns.get(
'proportion_subsurface_n').required = True
validation_warnings = validation.validate(
args, spec_copy, MODEL_SPEC.args_with_spatial_overlap)
validation_warnings = validation.validate(args, spec_copy)
if not nutrients_selected:
validation_warnings.append(

View File

@ -1222,8 +1222,8 @@ def _parse_scenario_variables(args):
else:
farm_vector_path = None
guild_df = validation.get_validated_dataframe(
guild_table_path, MODEL_SPEC.inputs.guild_table_path)
guild_df = MODEL_SPEC.inputs.get(
'guild_table_path').get_validated_dataframe(guild_table_path)
LOGGER.info('Checking to make sure guild table has all expected headers')
for header in _EXPECTED_GUILD_HEADERS:
@ -1234,9 +1234,9 @@ def _parse_scenario_variables(args):
f"'{header}' but was unable to find one. Here are all the "
f"headers from {guild_table_path}: {', '.join(guild_df.columns)}")
landcover_biophysical_df = validation.get_validated_dataframe(
landcover_biophysical_table_path,
MODEL_SPEC.inputs.landcover_biophysical_table_path)
landcover_biophysical_df = MODEL_SPEC.inputs.get(
'landcover_biophysical_table_path').get_validated_dataframe(
landcover_biophysical_table_path)
biophysical_table_headers = landcover_biophysical_df.columns
for header in _EXPECTED_BIOPHYSICAL_HEADERS:
matches = re.findall(header, " ".join(biophysical_table_headers))
@ -1500,4 +1500,4 @@ def validate(args, limit_to=None):
# Deliberately not validating the interrelationship of the columns between
# the biophysical table and the guilds table as the model itself already
# does extensive checking for this.
return validation.validate(args, MODEL_SPEC.inputs)
return validation.validate(args, MODEL_SPEC)

View File

@ -610,9 +610,9 @@ def execute(args):
# Compute the regression
coefficient_json_path = os.path.join(
intermediate_dir, 'predictor_estimates.json')
predictor_df = validation.get_validated_dataframe(
args['predictor_table_path'],
MODEL_SPEC.inputs.predictor_table_path)
predictor_df = MODEL_SPEC.inputs.get(
'predictor_table_path').get_validated_dataframe(
args['predictor_table_path'])
predictor_id_list = predictor_df.index
compute_regression_task = task_graph.add_task(
func=_compute_and_summarize_regression,
@ -997,8 +997,8 @@ def _schedule_predictor_data_processing(
'line_intersect_length': _line_intersect_length,
}
predictor_df = validation.get_validated_dataframe(
predictor_table_path, MODEL_SPEC.inputs.predictor_table_path)
predictor_df = MODEL_SPEC.inputs.get(
'predictor_table_path').get_validated_dataframe(predictor_table_path)
predictor_task_list = []
predictor_json_list = [] # tracks predictor files to add to gpkg
@ -1766,8 +1766,8 @@ def _validate_same_id_lengths(table_path):
string message if IDs are too long
"""
predictor_df = validation.get_validated_dataframe(
table_path, MODEL_SPEC.inputs.predictor_table_path)
predictor_df = MODEL_SPEC.inputs.get(
'predictor_table_path').get_validated_dataframe(table_path)
too_long = set()
for p_id in predictor_df.index:
if len(p_id) > 10:
@ -1795,12 +1795,13 @@ def _validate_same_ids_and_types(
string message if any of the fields in 'id' and 'type' don't match
between tables.
"""
predictor_df = validation.get_validated_dataframe(
predictor_table_path, MODEL_SPEC.inputs.predictor_table_path)
predictor_df = MODEL_SPEC.inputs.get(
'predictor_table_path').get_validated_dataframe(
predictor_table_path)
scenario_predictor_df = validation.get_validated_dataframe(
scenario_predictor_table_path,
MODEL_SPEC.inputs.scenario_predictor_table_path)
scenario_predictor_df = MODEL_SPEC.inputs.get(
'scenario_predictor_table_path').get_validated_dataframe(
scenario_predictor_table_path)
predictor_pairs = set([
(p_id, row['type']) for p_id, row in predictor_df.iterrows()])
@ -1825,9 +1826,9 @@ def _validate_same_projection(base_vector_path, table_path):
"""
# This will load the table as a list of paths which we can iterate through
# without bothering the rest of the table structure
data_paths = validation.get_validated_dataframe(
table_path, MODEL_SPEC.inputs.predictor_table_path
)['path'].tolist()
data_paths = MODEL_SPEC.inputs.get(
'predictor_table_path').get_validated_dataframe(
table_path)['path'].tolist()
base_vector = gdal.OpenEx(base_vector_path, gdal.OF_VECTOR)
base_layer = base_vector.GetLayer()
@ -1868,8 +1869,8 @@ def _validate_predictor_types(table_path):
string message if any value in the ``type`` column does not match a
valid type, ignoring leading/trailing whitespace.
"""
df = validation.get_validated_dataframe(
table_path, MODEL_SPEC.inputs.predictor_table_path)
df = MODEL_SPEC.inputs.get(
'predictor_table_path').get_validated_dataframe(table_path)
# ignore leading/trailing whitespace because it will be removed
# when the type values are used
valid_types = set({'raster_mean', 'raster_sum', 'point_count',
@ -1928,7 +1929,7 @@ def validate(args, limit_to=None):
be an empty list if validation succeeds.
"""
validation_messages = validation.validate(args, MODEL_SPEC.inputs)
validation_messages = validation.validate(args, MODEL_SPEC)
sufficient_valid_keys = (validation.get_sufficient_keys(args) -
validation.get_invalid_keys(validation_messages))

View File

@ -540,7 +540,7 @@ def validate(args, limit_to=None):
the error message in the second part of the tuple. This should
be an empty list if validation succeeds.
"""
validation_warnings = validation.validate(args, MODEL_SPEC.inputs)
validation_warnings = validation.validate(args, MODEL_SPEC)
invalid_keys = validation.get_invalid_keys(validation_warnings)
sufficient_keys = validation.get_sufficient_keys(args)

View File

@ -917,7 +917,7 @@ def validate(args, limit_to=None):
be an empty list if validation succeeds.
"""
validation_warnings = validation.validate(args, MODEL_SPEC.inputs)
validation_warnings = validation.validate(args, MODEL_SPEC)
invalid_keys = validation.get_invalid_keys(validation_warnings)
if ('convert_nearest_to_edge' not in invalid_keys and

View File

@ -284,7 +284,7 @@ def execute(args):
'b': float(args['b_coef']),
}
if (args['valuation_function'] not in
MODEL_SPEC['args']['valuation_function']['options']):
MODEL_SPEC.inputs.get('valuation_function').options):
raise ValueError('Valuation function type %s not recognized' %
args['valuation_function'])
max_valuation_radius = float(args['max_valuation_radius'])
@ -1121,5 +1121,4 @@ def validate(args, limit_to=None):
be an empty list if validation succeeds.
"""
return validation.validate(
args, MODEL_SPEC.inputs, MODEL_SPEC.args_with_spatial_overlap)
return validation.validate(args, MODEL_SPEC)

View File

@ -563,9 +563,9 @@ def execute(args):
"""
file_suffix = utils.make_suffix_string(args, 'results_suffix')
biophysical_df = validation.get_validated_dataframe(
args['biophysical_table_path'],
MODEL_SPEC.inputs.biophysical_table_path)
biophysical_df = MODEL_SPEC.inputs.get(
'biophysical_table_path').get_validated_dataframe(
args['biophysical_table_path'])
# Test to see if c or p values are outside of 0..1
for key in ['usle_c', 'usle_p']:
@ -1600,5 +1600,4 @@ def validate(args, limit_to=None):
be an empty list if validation succeeds.
"""
return validation.validate(
args, MODEL_SPEC.inputs, MODEL_SPEC.args_with_spatial_overlap)
return validation.validate(args, MODEL_SPEC)

View File

@ -626,20 +626,19 @@ def execute(args):
# fail early on a missing required rain events table
if (not args['user_defined_local_recharge'] and
not args['user_defined_climate_zones']):
rain_events_df = validation.get_validated_dataframe(
args['rain_events_table_path'],
MODEL_SPEC.inputs.rain_events_table_path)
rain_events_df = MODEL_SPEC.inputs.get(
'rain_events_table_path').get_validated_dataframe(
args['rain_events_table_path'])
biophysical_df = validation.get_validated_dataframe(
args['biophysical_table_path'],
MODEL_SPEC.inputs.biophysical_table_path)
biophysical_df = MODEL_SPEC.inputs.get(
'biophysical_table_path').get_validated_dataframe(
args['biophysical_table_path'])
if args['monthly_alpha']:
# parse out the alpha lookup table of the form (month_id: alpha_val)
alpha_month_map = validation.get_validated_dataframe(
args['monthly_alpha_path'],
MODEL_SPEC.inputs.monthly_alpha_path
)['alpha'].to_dict()
alpha_month_map = MODEL_SPEC.inputs.get(
'monthly_alpha_path').get_validated_dataframe(
args['monthly_alpha_path'])['alpha'].to_dict()
else:
# make all 12 entries equal to args['alpha_m']
alpha_m = float(fractions.Fraction(args['alpha_m']))
@ -816,9 +815,9 @@ def execute(args):
'table_name': 'Climate Zone'}
for month_id in range(N_MONTHS):
if args['user_defined_climate_zones']:
cz_rain_events_df = validation.get_validated_dataframe(
args['climate_zone_table_path'],
MODEL_SPEC.inputs.climate_zone_table_path)
cz_rain_events_df = MODEL_SPEC.inputs.get(
'climate_zone_table_path').get_validated_dataframe(
args['climate_zone_table_path'])
climate_zone_rain_events_month = (
cz_rain_events_df[MONTH_ID_TO_LABEL[month_id]].to_dict())
n_events_task = task_graph.add_task(
@ -1486,5 +1485,4 @@ def validate(args, limit_to=None):
the error message in the second part of the tuple. This should
be an empty list if validation succeeds.
"""
return validation.validate(args, MODEL_SPEC.inputs,
MODEL_SPEC.args_with_spatial_overlap)
return validation.validate(args, MODEL_SPEC)

View File

@ -28,7 +28,8 @@ from .unit_registry import u
LOGGER = logging.getLogger(__name__)
def get_validated_dataframe(key, model_spec, args):
return model_spec.inputs.get(key).get_validated_dataframe(args[key])
# accessing a file could take a long time if it's in a file streaming service
# to prevent the UI from hanging due to slow validation,
@ -176,7 +177,6 @@ def _check_projection(srs, projected, projection_units):
class IterableWithDotAccess():
def __init__(self, *args):
print(args)
self.args = args
self.inputs_dict = {i.id: i for i in args}
self.iter_index = 0
@ -190,6 +190,9 @@ class IterableWithDotAccess():
def get(self, key):
return self.inputs_dict[key]
def to_json(self):
return self.inputs_dict
# def __next__(self):
# print('next')
# if self.iter_index < len(self.args):
@ -203,6 +206,9 @@ class IterableWithDotAccess():
class ModelInputs(IterableWithDotAccess):
pass
class ModelOutputs(IterableWithDotAccess):
pass
class Rows(IterableWithDotAccess):
pass
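A hedged usage sketch for the dot-access containers above (the key and path are examples): each spec is stored under its id, so model code can look specs up by key or iterate the whole container.

    spec = MODEL_SPEC.inputs.get('biophysical_table_path')      # lookup by id
    df = spec.get_validated_dataframe('biophysical.csv')        # hypothetical table path
    input_ids = [s.id for s in MODEL_SPEC.inputs]               # iteration yields each spec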
@ -230,8 +236,9 @@ class InputSpec:
@dataclasses.dataclass(kw_only=True)
class OutputSpec:
about: str
created_if: bool | str
id: str = ''
about: str = ''
created_if: bool | str = True
@dataclasses.dataclass(kw_only=True)
class FileInputSpec(InputSpec):
@ -289,7 +296,7 @@ class SingleBandRasterInputSpec(FileInputSpec):
A string error message if an error was found. ``None`` otherwise.
"""
file_warning = super().validate(filepath)
file_warning = FileInputSpec.validate(self, filepath)
if file_warning:
return file_warning
@ -504,12 +511,11 @@ class CSVInputSpec(FileInputSpec):
available_cols -= set(matching_cols)
for col in matching_cols:
try:
print(df[col])
df[col] = col_spec.format_column(df[col], csv_path)
except Exception as err:
raise ValueError(
f'Value(s) in the "{col}" column could not be interpreted '
f'as {type(col_spec)}s. Original error: {err}')
f'as {type(col_spec).__name__}s. Original error: {err}')
if type(col_spec) in {SingleBandRasterInputSpec,
VectorInputSpec, RasterOrVectorInputSpec}:
@ -548,7 +554,7 @@ class CSVInputSpec(FileInputSpec):
@dataclasses.dataclass(kw_only=True)
class DirectoryInputSpec(InputSpec):
contents: Contents | None = None
permissions: str = 'rx'
permissions: str = ''
must_exist: bool = True
# @timeout
@ -822,21 +828,21 @@ class SingleBandRasterOutputSpec(OutputSpec):
@dataclasses.dataclass(kw_only=True)
class VectorOutputSpec(OutputSpec):
geometries: set
fields: types.SimpleNamespace
projected: bool
projection_units: pint.Unit
fields: Fields
projected: bool | None = None
projection_units: pint.Unit | None = None
@dataclasses.dataclass(kw_only=True)
class CSVOutputSpec(OutputSpec):
columns: types.SimpleNamespace
rows: types.SimpleNamespace
columns: Columns
rows: Rows
index_col: str
@dataclasses.dataclass(kw_only=True)
class DirectoryOutputSpec(OutputSpec):
contents: types.SimpleNamespace
permissions: str
must_exist: bool
contents: Contents | None = None
permissions: str = ''
must_exist: bool = True
@dataclasses.dataclass(kw_only=True)
class FileOutputSpec(OutputSpec):
@ -870,7 +876,7 @@ class OptionStringOutputSpec(OutputSpec):
@dataclasses.dataclass(kw_only=True)
class UISpec:
order: list
hidden: list
hidden: list = None
dropdown_functions: dict = dataclasses.field(default_factory=dict)
@ -887,18 +893,16 @@ class ModelSpec:
def build_model_spec(model_spec):
x = [build_input_spec(argkey, argspec) for argkey, argspec in model_spec['args'].items()]
print(x)
print(*x)
inputs = ModelInputs(*x)
for i in inputs:
print(i.id)
outputs = types.SimpleNamespace({
argkey: build_output_spec(argspec, argkey) for argkey, argspec in model_spec['outputs'].items()
})
input_specs = [
build_input_spec(argkey, argspec)
for argkey, argspec in model_spec['args'].items()]
output_specs = [
build_output_spec(argkey, argspec) for argkey, argspec in model_spec['outputs'].items()]
inputs = ModelInputs(*input_specs)
outputs = ModelOutputs(*output_specs)
ui_spec = UISpec(
order=model_spec['ui_spec']['order'],
hidden=model_spec['ui_spec']['hidden'],
hidden=model_spec['ui_spec'].get('hidden', None),
dropdown_functions=model_spec['ui_spec'].get('dropdown_functions', None))
return ModelSpec(
model_id=model_spec['model_id'],
@ -989,7 +993,7 @@ def build_input_spec(argkey, arg):
return DirectoryInputSpec(
contents=Contents(*[
build_input_spec(k, v) for k, v in arg['contents'].items()]),
permissions=arg.get('permissions', None),
permissions=arg.get('permissions', 'rx'),
must_exist=arg.get('must_exist', None),
**base_attrs)
@ -1010,8 +1014,9 @@ def build_input_spec(argkey, arg):
raise ValueError
def build_output_spec(spec, key=None):
def build_output_spec(key, spec):
base_attrs = {
'id': key,
'about': spec.get('about', None),
'created_if': spec.get('created_if', None)
}
@ -1052,7 +1057,7 @@ def build_output_spec(spec, key=None):
elif t == 'raster':
return SingleBandRasterOutputSpec(
**base_attrs,
band=build_output_spec(spec['bands'][1]),
band=build_output_spec(1, spec['bands'][1]),
projected=None,
projection_units=None)
@ -1060,27 +1065,23 @@ def build_output_spec(spec, key=None):
return VectorOutputSpec(
**base_attrs,
geometries=spec['geometries'],
fields=types.SimpleNamespace({
key: build_output_spec(field_spec) for key, field_spec in spec['fields'].items()
}),
fields=Fields(*[
build_output_spec(key, field_spec) for key, field_spec in spec['fields'].items()]),
projected=None,
projection_units=None)
elif t == 'csv':
columns = types.SimpleNamespace()
for col_name, col_spec in spec['columns'].items():
setattr(columns, col_name, build_output_spec(col_spec))
return CSVOutputSpec(
**base_attrs,
columns=columns,
columns=Columns(*[
build_output_spec(key, col_spec) for key, col_spec in spec['columns'].items()]),
rows=None,
index_col=spec.get('index_col', None))
elif t == 'directory':
return DirectoryOutputSpec(
contents=types.SimpleNamespace({
k: build_output_spec(v, k) for k, v in spec['contents'].items()
}),
contents=Contents(*[
build_output_spec(k, v) for k, v in spec['contents'].items()]),
permissions=None,
must_exist=None,
**base_attrs)
@ -1390,9 +1391,15 @@ def serialize_args_spec(spec):
return str(obj)
elif isinstance(obj, types.FunctionType):
return str(obj)
elif dataclasses.is_dataclass(obj):
return dataclasses.asdict(obj)
elif isinstance(obj, IterableWithDotAccess):
return obj.to_json()
raise TypeError(f'fallback serializer is missing for {type(obj)}')
return json.dumps(spec, default=fallback_serializer)
x = json.dumps(spec, default=fallback_serializer)
print(x)
return x
# accepted geometries for a vector will be displayed in this order
@ -1534,50 +1541,40 @@ def format_type_string(arg_type):
# some types need a more user-friendly name
# all types are listed here so that they can be marked up for translation
type_names = {
'boolean': gettext('true/false'),
'csv': gettext('CSV'),
'directory': gettext('directory'),
'file': gettext('file'),
'freestyle_string': gettext('text'),
'integer': gettext('integer'),
'number': gettext('number'),
'option_string': gettext('option'),
'percent': gettext('percent'),
'raster': gettext('raster'),
'ratio': gettext('ratio'),
'vector': gettext('vector')
BooleanInputSpec: gettext('true/false'),
CSVInputSpec: gettext('CSV'),
DirectoryInputSpec: gettext('directory'),
FileInputSpec: gettext('file'),
StringInputSpec: gettext('text'),
IntegerInputSpec: gettext('integer'),
NumberInputSpec: gettext('number'),
OptionStringInputSpec: gettext('option'),
PercentInputSpec: gettext('percent'),
SingleBandRasterInputSpec: gettext('raster'),
RatioInputSpec: gettext('ratio'),
VectorInputSpec: gettext('vector'),
RasterOrVectorInputSpec: gettext('raster or vector')
}
def format_single_type(arg_type):
"""Represent a type as a link to the corresponding Input Types section.
Args:
arg_type (str): the type to format.
Returns:
formatted string that links to a description of the input type
"""
# Represent the type as a string. Some need a more user-friendly name.
# we can only use standard docutils features here, so no :ref:
# this syntax works to link to a section in a different page, but it
# isn't universally supported and depends on knowing the built page name.
if arg_type == 'freestyle_string':
section_name = 'text'
elif arg_type == 'option_string':
section_name = 'option'
elif arg_type == 'boolean':
section_name = 'truefalse'
elif arg_type == 'csv':
section_name = 'csv'
else:
section_name = arg_type
return f'`{type_names[arg_type]} <{INPUT_TYPES_HTML_FILE}#{section_name}>`__'
if isinstance(arg_type, set):
return ' or '.join(format_single_type(t) for t in sorted(arg_type))
else:
return format_single_type(arg_type)
type_sections = { # names of section headers to link to in the RST
BooleanInputSpec: 'truefalse',
CSVInputSpec: 'csv',
DirectoryInputSpec: 'directory',
FileInputSpec: 'file',
StringInputSpec: 'text',
IntegerInputSpec: 'integer',
NumberInputSpec: 'number',
OptionStringInputSpec: 'option',
PercentInputSpec: 'percent',
SingleBandRasterInputSpec: 'raster',
RatioInputSpec: 'ratio',
VectorInputSpec: 'vector',
RasterOrVectorInputSpec: 'raster'
}
if arg_type is RasterOrVectorInputSpec:
return (
f'`{type_names[SingleBandRasterInputSpec]} <{INPUT_TYPES_HTML_FILE}#{type_sections[SingleBandRasterInputSpec]}>`__ or '
f'`{type_names[VectorInputSpec]} <{INPUT_TYPES_HTML_FILE}#{type_sections[VectorInputSpec]}>`__')
return f'`{type_names[arg_type]} <{INPUT_TYPES_HTML_FILE}#{type_sections[arg_type]}>`__'
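A hedged usage sketch for the rewritten format_type_string, which now keys on spec classes (the rendered link target depends on INPUT_TYPES_HTML_FILE, whose value is not shown in this diff):

    format_type_string(CSVInputSpec)
    # -> '`CSV <input_types.html#csv>`__'  (assuming INPUT_TYPES_HTML_FILE == 'input_types.html')
    format_type_string(RasterOrVectorInputSpec)
    # -> two links joined with 'or', built from the raster and vector entries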
def describe_arg_from_spec(name, spec):
@ -1603,15 +1600,15 @@ def describe_arg_from_spec(name, spec):
lines that are indented, that describe details of the arg such as
vector fields and geometries, option_string options, etc.
"""
type_string = format_type_string(spec['type'])
type_string = format_type_string(spec.__class__)
in_parentheses = [type_string]
# For numbers and rasters that have units, display the units
units = None
if spec['type'] == 'number':
units = spec['units']
elif spec['type'] == 'raster' and spec['bands'][1]['type'] == 'number':
units = spec['bands'][1]['units']
if spec.__class__ is NumberInputSpec:
units = spec.units
elif spec.__class__ is SingleBandRasterInputSpec and spec.band.__class__ is NumberInputSpec:
units = spec.band.units
if units:
units_string = format_unit(units)
if units_string:
@ -1619,19 +1616,18 @@ def describe_arg_from_spec(name, spec):
translated_units = gettext("units")
in_parentheses.append(f'{translated_units}: **{units_string}**')
if spec['type'] == 'vector':
in_parentheses.append(format_geometries_string(spec["geometries"]))
if spec.__class__ is VectorInputSpec:
in_parentheses.append(format_geometries_string(spec.geometries))
# Represent the required state as a string, defaulting to required
# It doesn't make sense to include this for boolean checkboxes
if spec['type'] != 'boolean':
# get() returns None if the key doesn't exist in the dictionary
required_string = format_required_string(spec.get('required'))
if spec.__class__ is not BooleanInputSpec:
required_string = format_required_string(spec.required)
in_parentheses.append(f'*{required_string}*')
# Nested args may not have an about section
if 'about' in spec:
sanitized_about_string = spec["about"].replace("_", "\\_")
if spec.about:
sanitized_about_string = spec.about.replace("_", "\\_")
about_string = f': {sanitized_about_string}'
else:
about_string = ''
@ -1640,19 +1636,19 @@ def describe_arg_from_spec(name, spec):
# Add details for the types that have them
indented_block = []
if spec['type'] == 'option_string':
if spec.__class__ is OptionStringInputSpec:
# may be either a dict or set. if it's empty, the options are
# dynamically generated. don't try to document them.
if spec['options']:
if isinstance(spec['options'], dict):
if spec.options:
if isinstance(spec.options, dict):
indented_block.append(gettext('Options:'))
indented_block += format_options_string_from_dict(spec['options'])
indented_block += format_options_string_from_dict(spec.options)
else:
formatted_options = format_options_string_from_list(spec['options'])
formatted_options = format_options_string_from_list(spec.options)
indented_block.append(gettext('Options:') + f' {formatted_options}')
elif spec['type'] == 'csv':
if 'columns' not in spec and 'rows' not in spec:
elif spec.__class__ is CSVInputSpec:
if not spec.columns and not spec.rows:
first_line += gettext(
' Please see the sample data table for details on the format.')
@ -1676,22 +1672,25 @@ def describe_arg_from_name(module_name, *arg_keys):
module = importlib.import_module(module_name)
# start with the spec for all args
# narrow down to the nested spec indicated by the sequence of arg keys
spec = module.MODEL_SPEC['args']
spec = module.MODEL_SPEC.inputs
for i, key in enumerate(arg_keys):
# convert raster band numbers to ints
if arg_keys[i - 1] == 'bands':
key = int(key)
try:
spec = spec[key]
except KeyError:
keys_so_far = '.'.join(arg_keys[:i + 1])
raise ValueError(
f"Could not find the key '{keys_so_far}' in the "
f"{module_name} model's MODEL_SPEC")
if key in {'bands', 'fields', 'contents', 'columns', 'rows'}:
spec = getattr(spec, key)
else:
try:
spec = spec.get(key)
except KeyError:
keys_so_far = '.'.join(arg_keys[:i + 1])
raise ValueError(
f"Could not find the key '{keys_so_far}' in the "
f"{module_name} model's MODEL_SPEC")
# format spec into an RST formatted description string
if 'name' in spec:
arg_name = capitalize(spec['name'])
if spec.name:
arg_name = capitalize(spec.name)
else:
arg_name = arg_keys[-1]
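A hedged usage sketch for describe_arg_from_name: the arg_keys walk from a top-level input into nested specs via the reserved keys ('bands', 'fields', 'contents', 'columns', 'rows') handled above. The key path here is hypothetical:

    rst_description = describe_arg_from_name(
        'natcap.invest.carbon', 'carbon_pools_path', 'columns', 'c_above')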
@ -1721,31 +1720,26 @@ def write_metadata_file(datasource_path, spec, lineage_statement, keywords_list)
words = resource.get_keywords()
resource.set_keywords(set(words + keywords_list))
if 'about' in spec:
resource.set_description(spec['about'])
attr_spec = None
if 'columns' in spec:
attr_spec = spec['columns']
if 'fields' in spec:
attr_spec = spec['fields']
if attr_spec:
for key, value in attr_spec.items():
about = value['about'] if 'about' in value else ''
units = format_unit(value['units']) if 'units' in value else ''
if spec.about:
resource.set_description(spec.about)
attr_specs = None
if hasattr(spec, 'columns') and spec.columns:
attr_specs = spec.columns
if hasattr(spec, 'fields') and spec.fields:
attr_specs = spec.fields
if attr_specs:
for nested_spec in attr_specs:
units = format_unit(nested_spec.units) if hasattr(nested_spec, 'units') else ''
try:
resource.set_field_description(
key, description=about, units=units)
nested_spec.id, description=nested_spec.about, units=units)
except KeyError as error:
# fields that are in the spec but missing
# from model results because they are conditional.
LOGGER.debug(error)
if 'bands' in spec:
for idx, value in spec['bands'].items():
try:
units = format_unit(spec['bands'][idx]['units'])
except KeyError:
units = ''
resource.set_band_description(idx, units=units)
if hasattr(spec, 'band'):
units = format_unit(spec.band.units)
resource.set_band_description(1, units=units)
resource.write()
@ -1768,18 +1762,19 @@ def generate_metadata(model_module, args_dict):
lineage_statement = (
f'Created by {model_module.__name__}.execute(\n{formatted_args})\n'
f'Version {natcap.invest.__version__}')
keywords = [model_module.MODEL_SPEC['model_id'], 'InVEST']
keywords = [model_module.MODEL_SPEC.model_id, 'InVEST']
def _walk_spec(output_spec, workspace):
for filename, spec_data in output_spec.items():
if 'type' in spec_data and spec_data['type'] == 'directory':
if 'taskgraph.db' in spec_data['contents']:
for spec_data in output_spec:
print(spec_data)
if spec_data.__class__ is DirectoryOutputSpec:
if 'taskgraph.db' in [s.id for s in spec_data.contents]:
continue
_walk_spec(
spec_data['contents'],
os.path.join(workspace, filename))
spec_data.contents,
os.path.join(workspace, spec_data.id))
else:
pre, post = os.path.splitext(filename)
pre, post = os.path.splitext(spec_data.id)
full_path = os.path.join(workspace, f'{pre}{file_suffix}{post}')
if os.path.exists(full_path):
try:
@ -1789,4 +1784,4 @@ def generate_metadata(model_module, args_dict):
# Some unsupported file formats, e.g. html
LOGGER.debug(error)
_walk_spec(model_module.MODEL_SPEC['outputs'], args_dict['workspace_dir'])
_walk_spec(model_module.MODEL_SPEC.outputs, args_dict['workspace_dir'])

View File

@ -498,9 +498,9 @@ def execute(args):
# Build a lookup dictionary mapping each LULC code to its row
# sort by the LULC codes upfront because we use the sorted list in multiple
# places. it's more efficient to do this once.
biophysical_df = validation.get_validated_dataframe(
args['biophysical_table'], MODEL_SPEC.inputs.biophysical_table
).sort_index()
biophysical_df = MODEL_SPEC.inputs.get(
'biophysical_table').get_validated_dataframe(
args['biophysical_table']).sort_index()
sorted_lucodes = biophysical_df.index.to_list()
# convert the nested dictionary in to a 2D array where rows are LULC codes
@ -1171,7 +1171,7 @@ def raster_average(raster_path, radius, kernel_path, out_path):
target_nodata=FLOAT_NODATA)
@ validation.invest_validator
@validation.invest_validator
def validate(args, limit_to=None):
"""Validate args to ensure they conform to `execute`'s contract.
@ -1189,8 +1189,7 @@ def validate(args, limit_to=None):
the error message in the second part of the tuple. This should
be an empty list if validation succeeds.
"""
validation_warnings = validation.validate(args, MODEL_SPEC.inputs,
MODEL_SPEC.args_with_spatial_overlap)
validation_warnings = validation.validate(args, MODEL_SPEC)
invalid_keys = validation.get_invalid_keys(validation_warnings)
if 'soil_group_path' not in invalid_keys:
# check that soil group raster has integer type

View File

@ -1,4 +1,5 @@
"""Urban Cooling Model."""
import copy
import logging
import math
import os
@ -475,9 +476,9 @@ def execute(args):
intermediate_dir = os.path.join(
args['workspace_dir'], 'intermediate')
utils.make_directories([args['workspace_dir'], intermediate_dir])
biophysical_df = validation.get_validated_dataframe(
args['biophysical_table_path'],
MODEL_SPEC.inputs.biophysical_table_path)
biophysical_df = MODEL_SPEC.inputs.get(
'biophysical_table_path').get_validated_dataframe(
args['biophysical_table_path'])
# cast to float and calculate relative weights
# Use default weights for shade, albedo, eti if the user didn't provide
@ -1160,9 +1161,9 @@ def calculate_energy_savings(
for field in target_building_layer.schema]
type_field_index = fieldnames.index('type')
energy_consumption_df = validation.get_validated_dataframe(
energy_consumption_table_path,
MODEL_SPEC.inputs.energy_consumption_table_path)
energy_consumption_df = MODEL_SPEC.inputs.get(
'energy_consumption_table_path').get_validated_dataframe(
energy_consumption_table_path)
target_building_layer.StartTransaction()
last_time = time.time()
@ -1536,26 +1537,24 @@ def validate(args, limit_to=None):
be an empty list if validation succeeds.
"""
validation_warnings = validation.validate(
args, MODEL_SPEC.args, MODEL_SPEC.args_with_spatial_overlap)
validation_warnings = validation.validate(args, MODEL_SPEC)
invalid_keys = validation.get_invalid_keys(validation_warnings)
if ('biophysical_table_path' not in invalid_keys and
'cc_method' not in invalid_keys):
spec = copy.deepcopy(MODEL_SPEC.inputs.get('biophysical_table_path'))
if args['cc_method'] == 'factors':
extra_biophysical_keys = ['shade', 'albedo']
spec.columns.get('shade').required = True
spec.columns.get('albedo').required = True
else:
# args['cc_method'] must be 'intensity'.
# If args['cc_method'] isn't one of these two allowed values
# ('intensity' or 'factors'), it'll be caught by
# validation.validate due to the allowed values stated in
# MODEL_SPEC.
extra_biophysical_keys = ['building_intensity']
spec.columns.get('building_intensity').required = True
error_msg = validation.check_csv(
args['biophysical_table_path'],
header_patterns=extra_biophysical_keys,
axis=1)
error_msg = spec.validate(args['biophysical_table_path'])
if error_msg:
validation_warnings.append((['biophysical_table_path'], error_msg))
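A note on the deepcopy above: MODEL_SPEC is module-level state shared across calls, so the per-run column requirements are flipped on a copy rather than on the shared spec. A rough sketch of the idea, assuming column specs are mutable and spec.validate() returns an error message or None as in this hunk:

import copy

spec = copy.deepcopy(MODEL_SPEC.inputs.get('biophysical_table_path'))
required_cols = (['shade', 'albedo'] if args['cc_method'] == 'factors'
                 else ['building_intensity'])
for col in required_cols:
    spec.columns.get(col).required = True  # mutates only the copy
error_msg = spec.validate(args['biophysical_table_path'])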

View File

@ -308,9 +308,9 @@ def execute(args):
task_name='align raster stack')
# Load CN table
cn_df = validation.get_validated_dataframe(
args['curve_number_table_path'],
MODEL_SPEC.inputs.curve_number_table_path)
cn_df = MODEL_SPEC.inputs.get(
'curve_number_table_path').get_validated_dataframe(
args['curve_number_table_path'])
# make cn_table into a 2d array where first dim is lucode, second is
# 0..3 to correspond to CN_A..CN_D
@ -637,10 +637,9 @@ def _calculate_damage_to_infrastructure_in_aoi(
infrastructure_vector = gdal.OpenEx(structures_vector_path, gdal.OF_VECTOR)
infrastructure_layer = infrastructure_vector.GetLayer()
damage_type_map = validation.get_validated_dataframe(
structures_damage_table,
MODEL_SPEC.inputs.infrastructure_damage_loss_table_path
)['damage'].to_dict()
damage_type_map = MODEL_SPEC.inputs.get(
'infrastructure_damage_loss_table_path').get_validated_dataframe(
structures_damage_table)['damage'].to_dict()
infrastructure_layer_defn = infrastructure_layer.GetLayerDefn()
type_index = -1
@ -930,8 +929,7 @@ def validate(args, limit_to=None):
be an empty list if validation succeeds.
"""
validation_warnings = validation.validate(
args, MODEL_SPEC.inputs, MODEL_SPEC.args_with_spatial_overlap)
validation_warnings = validation.validate(args, MODEL_SPEC)
sufficient_keys = validation.get_sufficient_keys(args)
invalid_keys = validation.get_invalid_keys(validation_warnings)
@ -939,9 +937,9 @@ def validate(args, limit_to=None):
if ("curve_number_table_path" not in invalid_keys and
"curve_number_table_path" in sufficient_keys):
# Load CN table. Resulting DF has index and CN_X columns only.
cn_df = validation.get_validated_dataframe(
args['curve_number_table_path'],
MODEL_SPEC.inputs.curve_number_table_path)
cn_df = MODEL_SPEC.inputs.get(
'curve_number_table_path').get_validated_dataframe(
args['curve_number_table_path'])
# Check for NaN values.
nan_mask = cn_df.isna()
if nan_mask.any(axis=None):

View File

@ -945,9 +945,8 @@ def execute(args):
aoi_reprojection_task, lulc_mask_task]
)
attr_table = validation.get_validated_dataframe(
args['lulc_attribute_table'],
MODEL_SPEC.inputs.lulc_attribute_table)
attr_table = MODEL_SPEC.inputs.get(
'lulc_attribute_table').get_validated_dataframe(args['lulc_attribute_table'])
kernel_paths = {} # search_radius, kernel path
kernel_tasks = {} # search_radius, kernel task
@ -965,15 +964,15 @@ def execute(args):
lucode_to_search_radii = list(
urban_nature_attrs[['search_radius_m']].itertuples(name=None))
elif args['search_radius_mode'] == RADIUS_OPT_POP_GROUP:
pop_group_table = validation.get_validated_dataframe(
args['population_group_radii_table'],
MODEL_SPEC.inputs.population_group_radii_table)
pop_group_table = MODEL_SPEC.inputs.get(
'population_group_radii_table').get_validated_dataframe(
args['population_group_radii_table'])
search_radii = set(pop_group_table['search_radius_m'].unique())
# Build a dict of {pop_group: search_radius_m}
search_radii_by_pop_group = pop_group_table['search_radius_m'].to_dict()
else:
valid_options = ', '.join(
MODEL_SPEC['args']['search_radius_mode']['options'].keys())
MODEL_SPEC.inputs.get('search_radius_mode').options.keys())
raise ValueError(
"Invalid search radius mode provided: "
f"{args['search_radius_mode']}; must be one of {valid_options}")
@ -1845,8 +1844,8 @@ def _reclassify_urban_nature_area(
Returns:
``None``
"""
lulc_attribute_df = validation.get_validated_dataframe(
lulc_attribute_table, MODEL_SPEC.inputs.lulc_attribute_table)
lulc_attribute_df = MODEL_SPEC.inputs.get(
'lulc_attribute_table').get_validated_dataframe(lulc_attribute_table)
squared_pixel_area = abs(
numpy.multiply(*_square_off_pixels(lulc_raster_path)))
@ -1878,9 +1877,9 @@ def _reclassify_urban_nature_area(
target_datatype=gdal.GDT_Float32,
target_nodata=FLOAT32_NODATA,
error_details={
'raster_name': MODEL_SPEC['args']['lulc_raster_path']['name'],
'raster_name': MODEL_SPEC.inputs.get('lulc_raster_path').name,
'column_name': 'urban_nature',
'table_name': MODEL_SPEC['args']['lulc_attribute_table']['name'],
'table_name': MODEL_SPEC.inputs.get('lulc_attribute_table').name
}
)
@ -2602,5 +2601,4 @@ def _mask_raster(source_raster_path, mask_raster_path, target_raster_path):
def validate(args, limit_to=None):
return validation.validate(
args, MODEL_SPEC.inputs, MODEL_SPEC.args_with_spatial_overlap)
return validation.validate(args, MODEL_SPEC)
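Spec metadata that used to come from nested dict lookups is now read off attributes, as in the .options and .name accesses above. A compact sketch, assuming OptionStringInputSpec.options still behaves like the old 'options' mapping:

mode_spec = MODEL_SPEC.inputs.get('search_radius_mode')
valid_options = ', '.join(mode_spec.options.keys())
error_details = {
    'raster_name': MODEL_SPEC.inputs.get('lulc_raster_path').name,
    'column_name': 'urban_nature',
    'table_name': MODEL_SPEC.inputs.get('lulc_attribute_table').name,
}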

View File

@ -18,6 +18,7 @@ import pygeoprocessing
import requests
from . import utils
from . import spec_utils
ENCODING = sys.getfilesystemencoding()
LOGGER = logging.getLogger(__name__)
@ -65,12 +66,12 @@ def log_run(model_pyname, args):
log_exit_thread.start()
def _calculate_args_bounding_box(args, args_spec):
def _calculate_args_bounding_box(args, model_spec):
"""Calculate the bounding boxes of any GIS types found in `args_dict`.
Args:
args (dict): a dictionary mapping string keys to values of any type.
args_spec (dict): the model MODEL_SPEC describing args
model_spec (spec_utils.ModelSpec): the model's MODEL_SPEC object
Returns:
bb_intersection, bb_union tuple that's either the lat/lng bounding
@ -119,10 +120,11 @@ def _calculate_args_bounding_box(args, args_spec):
# should already have been validated so the path is either valid or
# blank.
spatial_info = None
if args_spec['args'][key]['type'] == 'raster' and value.strip() != '':
if (isinstance(model_spec.inputs.get(key),
spec_utils.SingleBandRasterInputSpec) and value.strip() != ''):
spatial_info = pygeoprocessing.get_raster_info(value)
elif (args_spec['args'][key]['type'] == 'vector'
and value.strip() != ''):
elif (isinstance(model_spec.inputs.get(key),
spec_utils.VectorInputSpec) and value.strip() != ''):
spatial_info = pygeoprocessing.get_vector_info(value)
if spatial_info:
@ -156,7 +158,7 @@ def _calculate_args_bounding_box(args, args_spec):
LOGGER.exception(
f'Error when transforming coordinates: {transform_error}')
else:
LOGGER.debug(f'Arg {key} of type {args_spec["args"][key]["type"]} '
LOGGER.debug(f'Arg {key} of type {model_spec.inputs.get(key).__class__} '
'excluded from bounding box calculation')
return bb_intersection, bb_union
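If more spatial spec types need handling here later, one way to keep the isinstance branches flat is a class-to-function map. A sketch only, assuming the two classes above remain the spatial input types of interest and reusing usage.py's existing imports:

_SPATIAL_INFO_FUNCS = {
    spec_utils.SingleBandRasterInputSpec: pygeoprocessing.get_raster_info,
    spec_utils.VectorInputSpec: pygeoprocessing.get_vector_info,
}

def _get_spatial_info(input_spec, value):
    # Return raster/vector info for a non-blank path, else None.
    for spec_cls, info_func in _SPATIAL_INFO_FUNCS.items():
        if isinstance(input_spec, spec_cls) and value.strip() != '':
            return info_func(value)
    return None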
@ -216,11 +218,11 @@ def _log_model(pyname, model_args, invest_interface, session_id=None):
md5.update(json.dumps(data).encode('utf-8'))
return md5.hexdigest()
args_spec = importlib.import_module(pyname).MODEL_SPEC
model_spec = importlib.import_module(pyname).MODEL_SPEC
try:
bounding_box_intersection, bounding_box_union = (
_calculate_args_bounding_box(model_args, args_spec))
_calculate_args_bounding_box(model_args, model_spec))
log_start_url = requests.get(_ENDPOINTS_INDEX_URL).json()['START']
requests.post(log_start_url, data={
'model_name': pyname,

View File

@ -238,24 +238,15 @@ def _format_bbox_list(file_list, bbox_list):
file_list, bbox_list)])
def validate(args, spec, spatial_overlap_opts=None):
def validate(args, spec):
"""Validate an args dict against a model spec.
Validates an arguments dictionary according to the rules laid out in
``spec``. If ``spatial_overlap_opts`` is also provided, valid spatial
inputs will be checked for spatial overlap.
``spec``.
Args:
args (dict): The InVEST model args dict to validate.
spec (dict): The InVEST model spec dict to validate against.
spatial_overlap_opts=None (dict): A dict. If provided, the key
``"spatial_keys"`` is required to be a list of keys that may be present
in the args dict and (if provided in args) will be checked for
overlap with all other keys in this list. If the key
``"reference_key"`` is also present in this dict, the bounding
boxes of each of the files represented by
``spatial_overlap_opts["spatial_keys"]`` will be transformed to the
SRS of the dataset at this key.
Returns:
A list of tuples where the first element of the tuple is an iterable of
@ -339,8 +330,8 @@ def validate(args, spec, spatial_overlap_opts=None):
validation_warnings.append(([key], MESSAGES['UNEXPECTED_ERROR']))
# Phase 3: Check spatial overlap if applicable
if spatial_overlap_opts:
spatial_keys = set(spatial_overlap_opts['spatial_keys'])
if spec.args_with_spatial_overlap:
spatial_keys = set(spec.args_with_spatial_overlap['spatial_keys'])
# Only test for spatial overlap once all the sufficient spatial keys
# are otherwise valid. And then only when there are at least 2.
@ -357,7 +348,7 @@ def validate(args, spec, spatial_overlap_opts=None):
try:
different_projections_ok = (
spatial_overlap_opts['different_projections_ok'])
spec.args_with_spatial_overlap['different_projections_ok'])
except KeyError:
different_projections_ok = False
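Callers now pass only the args dict and the model spec; the spatial-overlap settings ride along on the spec itself. A minimal usage sketch from a model module's point of view (the spatial keys shown are hypothetical):

# Assuming the model spec was built with something like:
# MODEL_SPEC = spec_utils.ModelSpec(
#     ...,
#     args_with_spatial_overlap={
#         'spatial_keys': ['dem_path', 'aoi_path'],  # hypothetical keys
#         'different_projections_ok': True})
validation_warnings = validation.validate(args, MODEL_SPEC)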
@ -422,62 +413,43 @@ def invest_validator(validate_func):
# which gets imported into itself here and fails.
# Since this decorator might not be needed in the future,
# just ignore failed imports; assume they have no MODEL_SPEC.
try:
model_module = importlib.import_module(validate_func.__module__)
except Exception:
LOGGER.warning(
'Unable to import module %s: assuming no MODEL_SPEC.',
validate_func.__module__)
model_module = None
# If the module has an MODEL_SPEC defined, validate against that.
if hasattr(model_module, 'MODEL_SPEC'):
LOGGER.debug('Using MODEL_SPEC for validation')
args_spec = model_module.MODEL_SPEC.inputs
if limit_to is None:
LOGGER.info('Starting whole-model validation with MODEL_SPEC')
warnings_ = validate_func(args)
else:
LOGGER.info('Starting single-input validation with MODEL_SPEC')
args_key_spec = args_spec[limit_to]
args_value = args[limit_to]
error_msg = None
# We're only validating a single input. This is not officially
# supported in the validation function, but we can make it work
# within this decorator.
try:
if args_key_spec['required'] is True:
if args_value in ('', None):
error_msg = "Value is required"
except KeyError:
# If required is not defined in the args_spec, we default
# to False. If 'required' is an expression, we can't
# validate that outside of whole-model validation.
pass
# If the input is not required and does not have a value, no
# need to validate it.
if args_value not in ('', None):
input_type = args_key_spec['type']
if isinstance(input_type, set):
input_type = frozenset(input_type)
validator_func = _VALIDATION_FUNCS[input_type]
error_msg = validator_func(args_value, **args_key_spec)
if error_msg is None:
warnings_ = []
else:
warnings_ = [([limit_to], error_msg)]
else: # args_spec is not defined for this function.
LOGGER.warning('MODEL_SPEC not defined for this model')
model_module = importlib.import_module(validate_func.__module__)
if model_module.__name__ == 'test_validation':
warnings_ = validate_func(args, limit_to)
return warnings_
LOGGER.debug('Validation warnings: %s',
pprint.pformat(warnings_))
args_spec = model_module.MODEL_SPEC.inputs
if limit_to is None:
LOGGER.info('Starting whole-model validation with MODEL_SPEC')
warnings_ = validate_func(args)
else:
LOGGER.info('Starting single-input validation with MODEL_SPEC')
args_key_spec = args_spec.get(limit_to)
args_value = args[limit_to]
error_msg = None
# We're only validating a single input. This is not officially
# supported in the validation function, but we can make it work
# within this decorator.
# If 'required' is an expression, we can't
# validate that outside of whole-model validation.
if args_key_spec.required is True:
if args_value in ('', None):
error_msg = "Value is required"
# If the input is not required and does not have a value, no
# need to validate it.
if args_value not in ('', None):
error_msg = args_key_spec.validate(args_value)
if error_msg is None:
warnings_ = []
else:
warnings_ = [([limit_to], error_msg)]
LOGGER.debug(f'Validation warnings: {pprint.pformat(warnings_)}')
return warnings_
return _wrapped_validate_func
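For reference, the decorator's application in the model modules is unchanged by this rewrite; limit_to routes through the single-input branch above:

from natcap.invest import validation

@validation.invest_validator
def validate(args, limit_to=None):
    return validation.validate(args, MODEL_SPEC)

# whole-model check:   validate(args)
# single-input check:  validate(args, limit_to='workspace_dir')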

View File

@ -315,6 +315,7 @@ MODEL_SPEC = spec_utils.build_model_spec({
"about": gettext("Value of the machine parameter.")
}
},
"index_col": "name",
"about": gettext("Table of parameters for the wave energy machine in use."),
"name": gettext("machine parameter table")
},
@ -365,6 +366,7 @@ MODEL_SPEC = spec_utils.build_model_spec({
"about": gettext("Value of the machine parameter.")
}
},
"index_col": "name",
"required": "valuation_container",
"allowed": "valuation_container",
"about": gettext(
@ -754,21 +756,15 @@ def execute(args):
LOGGER.debug('Machine Performance Rows : %s', machine_perf_dict['periods'])
LOGGER.debug('Machine Performance Cols : %s', machine_perf_dict['heights'])
machine_param_dict = validation.get_validated_dataframe(
args['machine_param_path'],
index_col='name',
columns={
'name': {'type': 'option_string'},
'value': {'type': 'number'}
},
)['value'].to_dict()
machine_param_dict = MODEL_SPEC.inputs.get(
'machine_param_path').get_validated_dataframe(
args['machine_param_path'])['value'].to_dict()
# Check if required column fields are entered in the land grid csv file
if 'land_gridPts_path' in args:
# Create a grid_land_df dataframe for later use in valuation
grid_land_df = validation.get_validated_dataframe(
args['land_gridPts_path'],
MODEL_SPEC.inputs.land_gridPts_path)
grid_land_df = MODEL_SPEC.inputs.get(
'land_gridPts_path').get_validated_dataframe(args['land_gridPts_path'])
missing_grid_land_fields = []
for field in ['id', 'type', 'lat', 'long', 'location']:
if field not in grid_land_df.columns:
@ -780,14 +776,9 @@ def execute(args):
'Connection Points File: %s' % missing_grid_land_fields)
if 'valuation_container' in args and args['valuation_container']:
machine_econ_dict = validation.get_validated_dataframe(
args['machine_econ_path'],
index_col='name',
columns={
'name': {'type': 'option_string'},
'value': {'type': 'number'}
}
)['value'].to_dict()
machine_econ_dict = MODEL_SPEC.inputs.get(
'machine_econ_path').get_validated_dataframe(
args['machine_econ_path'])['value'].to_dict()
# Build up a dictionary of possible analysis areas where the key
# is the analysis area selected and the value is a dictionary
@ -2383,4 +2374,4 @@ def validate(args, limit_to=None):
validation warning.
"""
return validation.validate(args, MODEL_SPEC.inputs)
return validation.validate(args, MODEL_SPEC)
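The two 'index_col': 'name' additions above are what allow the inline column definitions to drop out of execute(): the spec-driven loader indexes the machine tables by name, so ['value'].to_dict() yields the name-to-value mapping directly. An illustrative sketch with made-up values:

# Hypothetical machine parameter CSV (column names from the spec above,
# row values invented):
#   name,value
#   capmax,250
#   hsmax,5.5
machine_param_dict = MODEL_SPEC.inputs.get(
    'machine_param_path').get_validated_dataframe(
    args['machine_param_path'])['value'].to_dict()
# -> {'capmax': 250.0, 'hsmax': 5.5}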

View File

@ -718,15 +718,13 @@ def execute(args):
number_of_turbines = int(args['number_of_turbines'])
# Read the biophysical turbine parameters into a dictionary
turbine_dict = validation.get_validated_dataframe(
args['turbine_parameters_path'],
MODEL_SPEC.inputs.turbine_parameters_path
).iloc[0].to_dict()
turbine_dict = MODEL_SPEC.inputs.get(
'turbine_parameters_path').get_validated_dataframe(
args['turbine_parameters_path']).iloc[0].to_dict()
# Read the biophysical global parameters into a dictionary
global_params_dict = validation.get_validated_dataframe(
args['global_wind_parameters_path'],
MODEL_SPEC.inputs.global_wind_parameters_path
).iloc[0].to_dict()
global_params_dict = MODEL_SPEC.inputs.get(
'global_wind_parameters_path').get_validated_dataframe(
args['global_wind_parameters_path']).iloc[0].to_dict()
# Combine the turbine and global parameters into one dictionary
parameters_dict = global_params_dict.copy()
@ -744,9 +742,9 @@ def execute(args):
# If Price Table provided use that for price of energy, validate inputs
time = parameters_dict['time_period']
if args['price_table']:
wind_price_df = validation.get_validated_dataframe(
args['wind_schedule'], MODEL_SPEC.inputs.wind_schedule
).sort_index() # sort by year
wind_price_df = MODEL_SPEC.inputs.get(
'wind_schedule').get_validated_dataframe(
args['wind_schedule']).sort_index() # sort by year
year_count = len(wind_price_df)
if year_count != time + 1:
@ -1115,8 +1113,8 @@ def execute(args):
LOGGER.info('Grid Points Provided. Reading in the grid points')
# Read the grid points csv, and convert it to land and grid dictionary
grid_land_df = validation.get_validated_dataframe(
args['grid_points_path'], MODEL_SPEC.inputs.grid_points_path)
grid_land_df = MODEL_SPEC.inputs.get(
'grid_points_path').get_validated_dataframe(args['grid_points_path'])
# Convert the dataframes to dictionaries, using 'ID' (the index) as key
grid_dict = grid_land_df[grid_land_df['type'] == 'grid'].to_dict('index')
@ -1936,8 +1934,8 @@ def _compute_density_harvested_fields(
# Read the wind energy data into a dictionary
LOGGER.info('Reading in Wind Data into a dictionary')
wind_point_df = validation.get_validated_dataframe(
wind_data_path, MODEL_SPEC.inputs.wind_data_path)
wind_point_df = MODEL_SPEC.inputs.get(
'wind_data_path').get_validated_dataframe(wind_data_path)
wind_point_df.columns = wind_point_df.columns.str.upper()
# Calculate scale value at new hub height given reference values.
# See equation 3 in users guide
@ -2666,8 +2664,7 @@ def validate(args, limit_to=None):
A list of tuples where tuple[0] is an iterable of keys that the error
message applies to and tuple[1] is the str validation warning.
"""
validation_warnings = validation.validate(args, MODEL_SPEC.inputs,
MODEL_SPEC.args_with_spatial_overlap)
validation_warnings = validation.validate(args, MODEL_SPEC)
invalid_keys = validation.get_invalid_keys(validation_warnings)
sufficient_keys = validation.get_sufficient_keys(args)
valid_sufficient_keys = sufficient_keys - invalid_keys
@ -2676,15 +2673,9 @@ def validate(args, limit_to=None):
'global_wind_parameters_path' in valid_sufficient_keys):
year_count = utils.read_csv_to_dataframe(
args['wind_schedule']).shape[0]
time = validation.get_validated_dataframe(
args['global_wind_parameters_path'],
index_col='0',
columns={
'0': {'type': 'freestyle_string'},
'1': {'type': 'number'}
},
read_csv_kwargs={'header': None}
)['1']['time_period']
time = MODEL_SPEC.inputs.get(
'global_wind_parameters_path').get_validated_dataframe(
args['global_wind_parameters_path']).iloc[0]['time_period']
if year_count != time + 1:
validation_warnings.append((
['wind_schedule'],

View File

@ -211,9 +211,9 @@ class TestPreprocessor(unittest.TestCase):
lulc_csv.write('0,mangrove,True\n')
lulc_csv.write('1,parking lot,False\n')
landcover_df = validation.get_validated_dataframe(
landcover_table_path,
preprocessor.MODEL_SPEC.inputs.lulc_lookup_table_path)
landcover_df = preprocessor.MODEL_SPEC.inputs.get(
'lulc_lookup_table_path').get_validated_dataframe(landcover_table_path)
target_table_path = os.path.join(self.workspace_dir,
'transition_table.csv')
@ -227,9 +227,8 @@ class TestPreprocessor(unittest.TestCase):
str(context.exception))
# Re-load the landcover table
landcover_df = validation.get_validated_dataframe(
landcover_table_path,
preprocessor.MODEL_SPEC.inputs.lulc_lookup_table_path)
landcover_df = preprocessor.MODEL_SPEC.inputs.get(
'lulc_lookup_table_path').get_validated_dataframe(landcover_table_path)
preprocessor._create_transition_table(
landcover_df, [filename_a, filename_b], target_table_path)
@ -641,10 +640,9 @@ class TestCBC2(unittest.TestCase):
args = TestCBC2._create_model_args(self.workspace_dir)
args['workspace_dir'] = os.path.join(self.workspace_dir, 'workspace')
prior_snapshots = validation.get_validated_dataframe(
args['landcover_snapshot_csv'],
coastal_blue_carbon.MODEL_SPEC.inputs.landcover_snapshot_csv
)['raster_path'].to_dict()
prior_snapshots = coastal_blue_carbon.MODEL_SPEC.inputs.get(
'landcover_snapshot_csv').get_validated_dataframe(
args['landcover_snapshot_csv'])['raster_path'].to_dict()
baseline_year = min(prior_snapshots.keys())
baseline_raster = prior_snapshots[baseline_year]
with open(args['landcover_snapshot_csv'], 'w') as snapshot_csv:
@ -819,10 +817,9 @@ class TestCBC2(unittest.TestCase):
args = TestCBC2._create_model_args(self.workspace_dir)
args['workspace_dir'] = os.path.join(self.workspace_dir, 'workspace')
prior_snapshots = validation.get_validated_dataframe(
args['landcover_snapshot_csv'],
coastal_blue_carbon.MODEL_SPEC.inputs.landcover_snapshot_csv
)['raster_path'].to_dict()
prior_snapshots = coastal_blue_carbon.MODEL_SPEC.inputs.get(
'landcover_snapshot_csv').get_validated_dataframe(
args['landcover_snapshot_csv'])['raster_path'].to_dict()
baseline_year = min(prior_snapshots.keys())
baseline_raster = prior_snapshots[baseline_year]
with open(args['landcover_snapshot_csv'], 'w') as snapshot_csv:
@ -879,10 +876,9 @@ class TestCBC2(unittest.TestCase):
# Now work through the extra validation warnings.
# test validation: invalid analysis year
prior_snapshots = validation.get_validated_dataframe(
args['landcover_snapshot_csv'],
coastal_blue_carbon.MODEL_SPEC.inputs.landcover_snapshot_csv
)['raster_path'].to_dict()
prior_snapshots = coastal_blue_carbon.MODEL_SPEC.inputs.get(
'landcover_snapshot_csv').get_validated_dataframe(
args['landcover_snapshot_csv'])['raster_path'].to_dict()
baseline_year = min(prior_snapshots)
# analysis year must be >= the last transition year.
args['analysis_year'] = baseline_year

View File

@ -328,6 +328,7 @@ class DatastackArchiveTests(unittest.TestCase):
"""Datastack: test archive extraction."""
from natcap.invest import datastack
from natcap.invest import utils
from natcap.invest import spec_utils
from natcap.invest import validation
params = {
@ -407,13 +408,14 @@ class DatastackArchiveTests(unittest.TestCase):
self.assertTrue(
filecmp.cmp(archive_params[key], params[key], shallow=False))
spatial_csv_dict = validation.get_validated_dataframe(
archive_params['spatial_table'],
spatial_csv_dict = spec_utils.CSVInputSpec(
index_col='id',
columns={
'id': {'type': 'integer'},
'path': {'type': 'file'}
}).to_dict(orient='index')
columns=spec_utils.Columns(
spec_utils.IntegerInputSpec(id='id'),
spec_utils.FileInputSpec(id='path'))
).get_validated_dataframe(
archive_params['spatial_table']
).to_dict(orient='index')
spatial_csv_dir = os.path.dirname(archive_params['spatial_table'])
numpy.testing.assert_allclose(
pygeoprocessing.raster_to_numpy_array(

View File

@ -1,25 +1,27 @@
MODEL_SPEC = {
'args': {
'blank': {'type': 'freestyle_string'},
'a': {'type': 'integer'},
'b': {'type': 'freestyle_string'},
'c': {'type': 'freestyle_string'},
'foo': {'type': 'file'},
'bar': {'type': 'file'},
'data_dir': {'type': 'directory'},
'raster': {'type': 'raster'},
'vector': {'type': 'vector'},
'simple_table': {'type': 'csv'},
'spatial_table': {
'type': 'csv',
'columns': {
'ID': {'type': 'integer'},
'path': {
'type': {'raster', 'vector'},
'geometries': {'POINT', 'POLYGON'},
'bands': {1: {'type': 'number'}}
}
}
}
}
}
from types import SimpleNamespace
from natcap.invest import spec_utils
MODEL_SPEC = SimpleNamespace(inputs=spec_utils.ModelInputs(
spec_utils.StringInputSpec(id='blank'),
spec_utils.IntegerInputSpec(id='a'),
spec_utils.StringInputSpec(id='b'),
spec_utils.StringInputSpec(id='c'),
spec_utils.FileInputSpec(id='foo'),
spec_utils.FileInputSpec(id='bar'),
spec_utils.DirectoryInputSpec(id='data_dir', contents={}),
spec_utils.SingleBandRasterInputSpec(id='raster', band=spec_utils.InputSpec()),
spec_utils.VectorInputSpec(id='vector', fields={}, geometries={}),
spec_utils.CSVInputSpec(id='simple_table'),
spec_utils.CSVInputSpec(
id='spatial_table',
columns=spec_utils.Columns(
spec_utils.IntegerInputSpec(id='ID'),
spec_utils.RasterOrVectorInputSpec(
id='path',
fields={},
geometries={'POINT', 'POLYGON'},
band=spec_utils.NumberInputSpec()
)
)
)
))
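A rough sketch of how a test spec like this gets exercised, based on the patterns used elsewhere in this commit (inputs are looked up by id, and each input spec can validate a raw value; the path below is hypothetical):

table_spec = MODEL_SPEC.inputs.get('spatial_table')
error = table_spec.validate('data/spatial_table.csv')  # message or None
raster_spec = MODEL_SPEC.inputs.get('raster')
assert type(raster_spec).__name__ == 'SingleBandRasterInputSpec'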

View File

@ -1,6 +1,7 @@
MODEL_SPEC = {
'args': {
'foo': {'type': 'file'},
'bar': {'type': 'file'},
}
}
from types import SimpleNamespace
from natcap.invest import spec_utils
MODEL_SPEC = SimpleNamespace(inputs=spec_utils.ModelInputs(
spec_utils.FileInputSpec(id='foo'),
spec_utils.FileInputSpec(id='bar'),
))

View File

@ -1,6 +1,9 @@
MODEL_SPEC = {
'args': {
'some_file': {'type': 'file'},
'data_dir': {'type': 'directory'},
}
}
from types import SimpleNamespace
from natcap.invest import spec_utils
MODEL_SPEC = SimpleNamespace(inputs=spec_utils.ModelInputs(
spec_utils.FileInputSpec(id='some_file'),
spec_utils.DirectoryInputSpec(
id='data_dir',
contents=spec_utils.Contents())
))

View File

@ -1,5 +1,6 @@
MODEL_SPEC = {
'args': {
'raster': {'type': 'raster'},
}
}
from types import SimpleNamespace
from natcap.invest import spec_utils
MODEL_SPEC = SimpleNamespace(inputs=spec_utils.ModelInputs(
spec_utils.SingleBandRasterInputSpec(id='raster', band=spec_utils.InputSpec())
))

View File

@ -1,9 +1,13 @@
MODEL_SPEC = {
'args': {
'a': {'type': 'integer'},
'b': {'type': 'freestyle_string'},
'c': {'type': 'freestyle_string'},
'd': {'type': 'freestyle_string'},
'workspace_dir': {'type': 'directory'},
}
}
from types import SimpleNamespace
from natcap.invest import spec_utils
MODEL_SPEC = SimpleNamespace(inputs=spec_utils.ModelInputs(
spec_utils.IntegerInputSpec(id='a'),
spec_utils.StringInputSpec(id='b'),
spec_utils.StringInputSpec(id='c'),
spec_utils.StringInputSpec(id='d'),
spec_utils.DirectoryInputSpec(
id='workspace_dir',
contents=spec_utils.Contents()
)
))

View File

@ -1,6 +1,7 @@
MODEL_SPEC = {
'args': {
'foo': {'type': 'freestyle_string'},
'bar': {'type': 'freestyle_string'},
}
}
from types import SimpleNamespace
from natcap.invest import spec_utils
MODEL_SPEC = SimpleNamespace(inputs=spec_utils.ModelInputs(
spec_utils.StringInputSpec(id='foo'),
spec_utils.StringInputSpec(id='bar')
))

View File

@ -1,5 +1,7 @@
MODEL_SPEC = {
'args': {
'vector': {'type': 'vector'},
}
}
from types import SimpleNamespace
from natcap.invest import spec_utils
MODEL_SPEC = SimpleNamespace(inputs=spec_utils.ModelInputs(
spec_utils.VectorInputSpec(
id='vector', fields={}, geometries={})
))

View File

@ -26,8 +26,8 @@ REGRESSION_DATA = os.path.join(
'delineateit')
# Skipping all compiled model tests temporarily for feature/plugins
pytestmark = pytest.mark.skip(
reason="Temporarily ignoring compiled models for feature/plugins")
# pytestmark = pytest.mark.skip(
# reason="Temporarily ignoring compiled models for feature/plugins")
@contextlib.contextmanager
def capture_logging(logger, level=logging.NOTSET):

View File

@ -18,8 +18,8 @@ REGRESSION_DATA = os.path.join(
os.path.dirname(__file__), '..', 'data', 'invest-test-data', 'ndr')
# Skipping all compiled model tests temporarily for feature/plugins
pytestmark = pytest.mark.skip(
reason="Temporarily ignoring compiled models for feature/plugins")
# pytestmark = pytest.mark.skip(
# reason="Temporarily ignoring compiled models for feature/plugins")
class NDRTests(unittest.TestCase):
"""Regression tests for InVEST SDR model."""

View File

@ -44,8 +44,8 @@ SAMPLE_DATA = os.path.join(REGRESSION_DATA, 'input')
LOGGER = logging.getLogger('test_recreation')
# Skipping all compiled model tests temporarily for feature/plugins
pytestmark = pytest.mark.skip(
reason="Temporarily ignoring compiled models for feature/plugins")
# pytestmark = pytest.mark.skip(
# reason="Temporarily ignoring compiled models for feature/plugins")
def _timeout(max_timeout):
"""Timeout decorator, parameter in seconds."""
@ -710,9 +710,9 @@ class TestRecClientServer(unittest.TestCase):
out_regression_vector_path = os.path.join(
args['workspace_dir'], f'regression_data_{suffix}.gpkg')
predictor_df = validation.get_validated_dataframe(
os.path.join(SAMPLE_DATA, 'predictors_all.csv'),
**recmodel_client.MODEL_SPEC['args']['predictor_table_path'])
predictor_df = recmodel_client.MODEL_SPEC.inputs.get(
'predictor_table_path').get_validated_dataframe(
os.path.join(SAMPLE_DATA, 'predictors_all.csv'))
field_list = list(predictor_df.index) + ['pr_TUD', 'pr_PUD', 'avg_pr_UD']
# For convenience, assert the sums of the columns instead of all
@ -1284,10 +1284,9 @@ class RecreationClientRegressionTests(unittest.TestCase):
predictor_table_path = os.path.join(SAMPLE_DATA, 'predictors.csv')
# make outputs to be overwritten
predictor_dict = validation.get_validated_dataframe(
predictor_table_path,
**recmodel_client.MODEL_SPEC['args']['predictor_table_path']
).to_dict(orient='index')
predictor_dict = recmodel_client.MODEL_SPEC.inputs.get(
'predictor_table_path').get_validated_dataframe(
predictor_table_path).to_dict(orient='index')
predictor_list = predictor_dict.keys()
tmp_working_dir = tempfile.mkdtemp(dir=self.workspace_dir)
empty_json_list = [

View File

@ -21,8 +21,8 @@ _SRS.ImportFromEPSG(32731) # WGS84 / UTM zone 31s
WKT = _SRS.ExportToWkt()
# Skipping all compiled model tests temporarily for feature/plugins
pytestmark = pytest.mark.skip(
reason="Temporarily ignoring compiled models for feature/plugins")
# pytestmark = pytest.mark.skip(
# reason="Temporarily ignoring compiled models for feature/plugins")
class ScenicQualityTests(unittest.TestCase):
"""Tests for the InVEST Scenic Quality model."""

View File

@ -16,8 +16,8 @@ REGRESSION_DATA = os.path.join(
SAMPLE_DATA = os.path.join(REGRESSION_DATA, 'input')
# Skipping all compiled model tests temporarily for feature/plugins
pytestmark = pytest.mark.skip(
reason="Temporarily ignoring compiled models for feature/plugins")
# pytestmark = pytest.mark.skip(
# reason="Temporarily ignoring compiled models for feature/plugins")
def assert_expected_results_in_vector(expected_results, vector_path):
"""Assert one feature vector maps to expected_results key/value pairs."""
@ -375,7 +375,7 @@ class SDRTests(unittest.TestCase):
with self.assertRaises(ValueError) as context:
sdr.execute(args)
self.assertIn(
'could not be interpreted as ratios', str(context.exception))
'could not be interpreted as RatioInput', str(context.exception))
def test_lucode_not_a_number(self):
"""SDR test expected exception for invalid data in lucode column."""
@ -396,7 +396,7 @@ class SDRTests(unittest.TestCase):
with self.assertRaises(ValueError) as context:
sdr.execute(args)
self.assertIn(
'could not be interpreted as integers', str(context.exception))
'could not be interpreted as IntegerInput', str(context.exception))
def test_missing_lulc_value(self):
"""SDR test for ValueError when LULC value not found in table."""

View File

@ -17,9 +17,9 @@ REGRESSION_DATA = os.path.join(
os.path.dirname(__file__), '..', 'data', 'invest-test-data',
'seasonal_water_yield')
# Skipping all compiled model tests temporarily for feature/plugins
pytestmark = pytest.mark.skip(
reason="Temporarily ignoring compiled models for feature/plugins")
# # Skipping all compiled model tests temporarily for feature/plugins
# pytestmark = pytest.mark.skip(
# reason="Temporarily ignoring compiled models for feature/plugins")
def make_simple_shp(base_shp_path, origin):
"""Make a 100x100 ogr rectangular geometry shapefile.
@ -879,7 +879,7 @@ class SeasonalWaterYieldRegressionTests(unittest.TestCase):
with self.assertRaises(ValueError) as context:
seasonal_water_yield.execute(args)
self.assertIn(
'could not be interpreted as numbers', str(context.exception))
'could not be interpreted as NumberInput', str(context.exception))
def test_monthly_alpha_regression(self):
"""SWY monthly alpha values regression test on sample data.

View File

@ -41,82 +41,75 @@ class TestDescribeArgFromSpec(unittest.TestCase):
"""Test building RST for various invest args specifications."""
def test_number_spec(self):
spec = {
"name": "Bar",
"about": "Description",
"type": "number",
"units": u.meter**3/u.month,
"expression": "value >= 0"
}
out = spec_utils.describe_arg_from_spec(spec['name'], spec)
spec = spec_utils.NumberInputSpec(
name="Bar",
about="Description",
units=u.meter**3/u.month,
expression="value >= 0"
)
out = spec_utils.describe_arg_from_spec(spec.name, spec)
expected_rst = ([
'**Bar** (`number <input_types.html#number>`__, '
'units: **m³/month**, *required*): Description'])
self.assertEqual(repr(out), repr(expected_rst))
def test_ratio_spec(self):
spec = {
"name": "Bar",
"about": "Description",
"type": "ratio"
}
out = spec_utils.describe_arg_from_spec(spec['name'], spec)
spec = spec_utils.RatioInputSpec(
name="Bar",
about="Description"
)
out = spec_utils.describe_arg_from_spec(spec.name, spec)
expected_rst = (['**Bar** (`ratio <input_types.html#ratio>`__, '
'*required*): Description'])
self.assertEqual(repr(out), repr(expected_rst))
def test_percent_spec(self):
spec = {
"name": "Bar",
"about": "Description",
"type": "percent",
"required": False
}
out = spec_utils.describe_arg_from_spec(spec['name'], spec)
spec = spec_utils.PercentInputSpec(
name="Bar",
about="Description",
required=False
)
out = spec_utils.describe_arg_from_spec(spec.name, spec)
expected_rst = (['**Bar** (`percent <input_types.html#percent>`__, '
'*optional*): Description'])
self.assertEqual(repr(out), repr(expected_rst))
def test_code_spec(self):
spec = {
"name": "Bar",
"about": "Description",
"type": "integer",
"required": True
}
out = spec_utils.describe_arg_from_spec(spec['name'], spec)
def test_integer_spec(self):
spec = spec_utils.IntegerInputSpec(
name="Bar",
about="Description",
required=True
)
out = spec_utils.describe_arg_from_spec(spec.name, spec)
expected_rst = (['**Bar** (`integer <input_types.html#integer>`__, '
'*required*): Description'])
self.assertEqual(repr(out), repr(expected_rst))
def test_boolean_spec(self):
spec = {
"name": "Bar",
"about": "Description",
"type": "boolean"
}
out = spec_utils.describe_arg_from_spec(spec['name'], spec)
spec = spec_utils.BooleanInputSpec(
name="Bar",
about="Description"
)
out = spec_utils.describe_arg_from_spec(spec.name, spec)
expected_rst = (['**Bar** (`true/false <input_types.html#truefalse>'
'`__): Description'])
self.assertEqual(repr(out), repr(expected_rst))
def test_freestyle_string_spec(self):
spec = {
"name": "Bar",
"about": "Description",
"type": "freestyle_string"
}
out = spec_utils.describe_arg_from_spec(spec['name'], spec)
spec = spec_utils.StringInputSpec(
name="Bar",
about="Description"
)
out = spec_utils.describe_arg_from_spec(spec.name, spec)
expected_rst = (['**Bar** (`text <input_types.html#text>`__, '
'*required*): Description'])
self.assertEqual(repr(out), repr(expected_rst))
def test_option_string_spec_dictionary(self):
spec = {
"name": "Bar",
"about": "Description",
"type": "option_string",
"options": {
spec = spec_utils.OptionStringInputSpec(
name="Bar",
about="Description",
options={
"option_a": {
"display_name": "A"
},
@ -128,10 +121,10 @@ class TestDescribeArgFromSpec(unittest.TestCase):
"description": "do something else"
}
}
}
)
# expect that option case is ignored
# otherwise, c would sort before A
out = spec_utils.describe_arg_from_spec(spec['name'], spec)
out = spec_utils.describe_arg_from_spec(spec.name, spec)
expected_rst = ([
'**Bar** (`option <input_types.html#option>`__, *required*): Description',
'\tOptions:',
@ -142,13 +135,12 @@ class TestDescribeArgFromSpec(unittest.TestCase):
self.assertEqual(repr(out), repr(expected_rst))
def test_option_string_spec_list(self):
spec = {
"name": "Bar",
"about": "Description",
"type": "option_string",
"options": ["option_a", "Option_b"]
}
out = spec_utils.describe_arg_from_spec(spec['name'], spec)
spec = spec_utils.OptionStringInputSpec(
name="Bar",
about="Description",
options=["option_a", "Option_b"]
)
out = spec_utils.describe_arg_from_spec(spec.name, spec)
expected_rst = ([
'**Bar** (`option <input_types.html#option>`__, *required*): Description',
'\tOptions: option_a, Option_b'
@ -156,77 +148,69 @@ class TestDescribeArgFromSpec(unittest.TestCase):
self.assertEqual(repr(out), repr(expected_rst))
def test_raster_spec(self):
spec = {
"type": "raster",
"bands": {1: {"type": "integer"}},
"about": "Description",
"name": "Bar"
}
out = spec_utils.describe_arg_from_spec(spec['name'], spec)
spec = spec_utils.SingleBandRasterInputSpec(
band=spec_utils.IntegerInputSpec(),
about="Description",
name="Bar"
)
out = spec_utils.describe_arg_from_spec(spec.name, spec)
expected_rst = ([
'**Bar** (`raster <input_types.html#raster>`__, *required*): Description'
])
self.assertEqual(repr(out), repr(expected_rst))
spec = {
"type": "raster",
"bands": {1: {
"type": "number",
"units": u.millimeter/u.year
}},
"about": "Description",
"name": "Bar"
}
out = spec_utils.describe_arg_from_spec(spec['name'], spec)
spec = spec_utils.SingleBandRasterInputSpec(
band=spec_utils.NumberInputSpec(units=u.millimeter/u.year),
about="Description",
name="Bar"
)
out = spec_utils.describe_arg_from_spec(spec.name, spec)
expected_rst = ([
'**Bar** (`raster <input_types.html#raster>`__, units: **mm/year**, *required*): Description'
])
self.assertEqual(repr(out), repr(expected_rst))
def test_vector_spec(self):
spec = {
"type": "vector",
"fields": {},
"geometries": {"LINESTRING"},
"about": "Description",
"name": "Bar"
}
out = spec_utils.describe_arg_from_spec(spec['name'], spec)
spec = spec_utils.VectorInputSpec(
fields={},
geometries={"LINESTRING"},
about="Description",
name="Bar"
)
out = spec_utils.describe_arg_from_spec(spec.name, spec)
expected_rst = ([
'**Bar** (`vector <input_types.html#vector>`__, linestring, *required*): Description'
])
self.assertEqual(repr(out), repr(expected_rst))
spec = {
"type": "vector",
"fields": {
"id": {
"type": "integer",
"about": "Unique identifier for each feature"
},
"precipitation": {
"type": "number",
"units": u.millimeter/u.year,
"about": "Average annual precipitation over the area"
}
},
"geometries": {"POLYGON", "MULTIPOLYGON"},
"about": "Description",
"name": "Bar"
}
out = spec_utils.describe_arg_from_spec(spec['name'], spec)
spec = spec_utils.VectorInputSpec(
fields=spec_utils.Fields(
spec_utils.IntegerInputSpec(
id="id",
about="Unique identifier for each feature"
),
spec_utils.NumberInputSpec(
id="precipitation",
units=u.millimeter/u.year,
about="Average annual precipitation over the area"
)
),
geometries={"POLYGON", "MULTIPOLYGON"},
about="Description",
name="Bar"
)
out = spec_utils.describe_arg_from_spec(spec.name, spec)
expected_rst = ([
'**Bar** (`vector <input_types.html#vector>`__, polygon/multipolygon, *required*): Description',
])
self.assertEqual(repr(out), repr(expected_rst))
def test_csv_spec(self):
spec = {
"type": "csv",
"about": "Description.",
"name": "Bar"
}
out = spec_utils.describe_arg_from_spec(spec['name'], spec)
spec = spec_utils.CSVInputSpec(
about="Description.",
name="Bar"
)
out = spec_utils.describe_arg_from_spec(spec.name, spec)
expected_rst = ([
'**Bar** (`CSV <input_types.html#csv>`__, *required*): Description. '
'Please see the sample data table for details on the format.'
@ -235,15 +219,17 @@ class TestDescribeArgFromSpec(unittest.TestCase):
# Test every type that can be nested in a CSV column:
# number, ratio, percent, code,
spec = {
"type": "csv",
"about": "Description",
"name": "Bar",
"columns": {
"b": {"type": "ratio", "about": "description"}
}
}
out = spec_utils.describe_arg_from_spec(spec['name'], spec)
spec = spec_utils.CSVInputSpec(
about="Description",
name="Bar",
columns=spec_utils.Columns(
spec_utils.RatioInputSpec(
id="b",
about="description"
)
)
)
out = spec_utils.describe_arg_from_spec(spec.name, spec)
expected_rst = ([
'**Bar** (`CSV <input_types.html#csv>`__, *required*): Description'
])
@ -251,28 +237,26 @@ class TestDescribeArgFromSpec(unittest.TestCase):
def test_directory_spec(self):
self.maxDiff = None
spec = {
"type": "directory",
"about": "Description",
"name": "Bar",
"contents": {}
}
out = spec_utils.describe_arg_from_spec(spec['name'], spec)
spec = spec_utils.DirectoryInputSpec(
about="Description",
name="Bar",
contents={}
)
out = spec_utils.describe_arg_from_spec(spec.name, spec)
expected_rst = ([
'**Bar** (`directory <input_types.html#directory>`__, *required*): Description'
])
self.assertEqual(repr(out), repr(expected_rst))
def test_multi_type_spec(self):
spec = {
"type": {"raster", "vector"},
"about": "Description",
"name": "Bar",
"bands": {1: {"type": "integer"}},
"geometries": {"POLYGON"},
"fields": {}
}
out = spec_utils.describe_arg_from_spec(spec['name'], spec)
spec = spec_utils.RasterOrVectorInputSpec(
about="Description",
name="Bar",
band=spec_utils.IntegerInputSpec(),
geometries={"POLYGON"},
fields={}
)
out = spec_utils.describe_arg_from_spec(spec.name, spec)
expected_rst = ([
'**Bar** (`raster <input_types.html#raster>`__ or `vector <input_types.html#vector>`__, *required*): Description'
])
@ -285,42 +269,34 @@ class TestDescribeArgFromSpec(unittest.TestCase):
expected_rst = (
'.. _carbon-pools-path-columns-lucode:\n\n' +
'**lucode** (`integer <input_types.html#integer>`__, *required*): ' +
carbon.MODEL_SPEC['args']['carbon_pools_path']['columns']['lucode']['about']
carbon.MODEL_SPEC.inputs.get('carbon_pools_path').columns.get('lucode').about
)
self.assertEqual(repr(out), repr(expected_rst))
def _generate_files_from_spec(output_spec, workspace):
"""A utility function to support the metadata test."""
for filename, spec_data in output_spec.items():
if 'type' in spec_data and spec_data['type'] == 'directory':
os.mkdir(os.path.join(workspace, filename))
for spec_data in output_spec:
print(spec_data.__class__)
if spec_data.__class__ is spec_utils.DirectoryOutputSpec:
os.mkdir(os.path.join(workspace, spec_data.id))
_generate_files_from_spec(
spec_data['contents'], os.path.join(workspace, filename))
spec_data.contents, os.path.join(workspace, spec_data.id))
else:
filepath = os.path.join(workspace, filename)
if 'bands' in spec_data:
filepath = os.path.join(workspace, spec_data.id)
if hasattr(spec_data, 'band'):
driver = gdal.GetDriverByName('GTIFF')
n_bands = len(spec_data['bands'])
raster = driver.Create(
filepath, 2, 2, n_bands, gdal.GDT_Byte)
for i in range(n_bands):
band = raster.GetRasterBand(i + 1)
band.SetNoDataValue(2)
elif 'fields' in spec_data:
if 'geometries' in spec_data:
driver = gdal.GetDriverByName('GPKG')
target_vector = driver.CreateDataSource(filepath)
layer_name = os.path.basename(os.path.splitext(filepath)[0])
target_layer = target_vector.CreateLayer(
layer_name, geom_type=ogr.wkbPolygon)
for field_name, field_data in spec_data['fields'].items():
target_layer.CreateField(ogr.FieldDefn(field_name, ogr.OFTInteger))
else:
# Write a CSV if it has fields but no geometry
with open(filepath, 'w') as file:
file.write(
f"{','.join([field for field in spec_data['fields']])}")
raster = driver.Create(filepath, 2, 2, 1, gdal.GDT_Byte)
band = raster.GetRasterBand(1)
band.SetNoDataValue(2)
elif hasattr(spec_data, 'fields'):
driver = gdal.GetDriverByName('GPKG')
target_vector = driver.CreateDataSource(filepath)
layer_name = os.path.basename(os.path.splitext(filepath)[0])
target_layer = target_vector.CreateLayer(
layer_name, geom_type=ogr.wkbPolygon)
for field_spec in spec_data.fields:
target_layer.CreateField(ogr.FieldDefn(field_spec.id, ogr.OFTInteger))
else:
# Such as taskgraph.db, just create the file.
with open(filepath, 'w') as file:
@ -342,40 +318,37 @@ class TestMetadataFromSpec(unittest.TestCase):
"""Test writing metadata for an invest output workspace."""
# An example invest output spec
output_spec = {
'output': {
"type": "directory",
"contents": {
"urban_nature_supply_percapita.tif": {
"about": (
"The calculated supply per capita of urban nature."),
"bands": {1: {
"type": "number",
"units": u.m**2,
}}},
"admin_boundaries.gpkg": {
"about": (
"A copy of the user's administrative boundaries "
"vector with a single layer."),
"geometries": spec_utils.POLYGONS,
"fields": {
"SUP_DEMadm_cap": {
"type": "number",
"units": u.m**2/u.person,
"about": (
"The average urban nature supply/demand ")
}
}
}
},
},
'intermediate': {
'type': 'directory',
'contents': {
'taskgraph_cache': spec_utils.TASKGRAPH_DIR,
}
}
}
output_spec = spec_utils.ModelOutputs(
spec_utils.DirectoryOutputSpec(
id='output',
contents=spec_utils.Contents(
spec_utils.SingleBandRasterOutputSpec(
id="urban_nature_supply_percapita.tif",
about="The calculated supply per capita of urban nature.",
band=spec_utils.NumberInputSpec(units=u.m**2)
),
spec_utils.VectorOutputSpec(
id="admin_boundaries.gpkg",
about=("A copy of the user's administrative boundaries "
"vector with a single layer."),
geometries=spec_utils.POLYGONS,
fields=spec_utils.Fields(
spec_utils.NumberInputSpec(
id="SUP_DEMadm_cap",
units=u.m**2/u.person,
about="The average urban nature supply/demand"
)
)
)
)
),
spec_utils.DirectoryOutputSpec(
id='intermediate',
contents=spec_utils.Contents(
spec_utils.build_output_spec('taskgraph_cache', spec_utils.TASKGRAPH_DIR)
)
)
)
# Generate an output workspace with real files, without
# running an invest model.
_generate_files_from_spec(output_spec, self.workspace_dir)
@ -383,9 +356,17 @@ class TestMetadataFromSpec(unittest.TestCase):
model_module = types.SimpleNamespace(
__name__='urban_nature_access',
execute=lambda: None,
MODEL_SPEC={
'model_id': 'urban_nature_access',
'outputs': output_spec})
MODEL_SPEC=spec_utils.ModelSpec(
model_id='urban_nature_access',
model_title='Urban Nature Access',
userguide='',
aliases=[],
ui_spec={},
inputs=spec_utils.ModelInputs(),
args_with_spatial_overlap={},
outputs=output_spec
)
)
args_dict = {'workspace_dir': self.workspace_dir}
@ -399,4 +380,4 @@ class TestMetadataFromSpec(unittest.TestCase):
os.path.join(args_dict['workspace_dir'], 'output',
'urban_nature_supply_percapita.tif'))
self.assertCountEqual(resource.get_keywords(),
[model_module.MODEL_SPEC['model_id'], 'InVEST'])
[model_module.MODEL_SPEC.model_id, 'InVEST'])

View File

@ -44,8 +44,8 @@ class EndpointFunctionTests(unittest.TestCase):
spec = json.loads(response.get_data(as_text=True))
self.assertEqual(
set(spec),
{'model_id', 'model_title', 'pyname', 'userguide', 'aliases',
'ui_spec', 'args_with_spatial_overlap', 'args', 'outputs'})
{'model_id', 'model_title', 'userguide', 'aliases',
'ui_spec', 'args_with_spatial_overlap', 'inputs', 'outputs'})
def test_get_invest_validate(self):
"""UI server: get_invest_validate endpoint."""
@ -55,7 +55,7 @@ class EndpointFunctionTests(unittest.TestCase):
'workspace_dir': 'foo'
}
payload = {
'model_id': carbon.MODEL_SPEC['model_id'],
'model_id': carbon.MODEL_SPEC.model_id,
'args': json.dumps(args)
}
response = test_client.post(f'{ROUTE_PREFIX}/validate', json=payload)

View File

@ -30,6 +30,7 @@ class UsageLoggingTests(unittest.TestCase):
"""Usage logger test that we can extract bounding boxes."""
from natcap.invest import utils
from natcap.invest import usage
from natcap.invest import spec_utils
srs = osr.SpatialReference()
srs.ImportFromEPSG(32731) # WGS84 / UTM zone 31s
@ -64,20 +65,23 @@ class UsageLoggingTests(unittest.TestCase):
'blank_vector_path': '',
}
args_spec = {
'args': {
'raster': {'type': 'raster'},
'vector': {'type': 'vector'},
'not_a_gis_input': {'type': 'freestyle_string'},
'blank_raster_path': {'type': 'raster'},
'blank_vector_path': {'type': 'vector'},
}
}
model_spec = spec_utils.ModelSpec(
model_id='', model_title='', userguide=None,
aliases=None, ui_spec=spec_utils.UISpec(order=[], hidden={}),
inputs=spec_utils.ModelInputs(
spec_utils.SingleBandRasterInputSpec(id='raster', band=spec_utils.InputSpec()),
spec_utils.VectorInputSpec(id='vector', geometries={}, fields={}),
spec_utils.StringInputSpec(id='not_a_gis_input'),
spec_utils.SingleBandRasterInputSpec(id='blank_raster_path', band=spec_utils.InputSpec()),
spec_utils.VectorInputSpec(id='blank_vector_path', geometries={}, fields={})
),
outputs={},
args_with_spatial_overlap=None)
output_logfile = os.path.join(self.workspace_dir, 'logfile.txt')
with utils.log_to_file(output_logfile):
bb_inter, bb_union = usage._calculate_args_bounding_box(
model_args, args_spec)
model_args, model_spec)
numpy.testing.assert_allclose(
bb_inter, [-87.234108, -85.526151, -87.233424, -85.526205])

View File

@ -2039,7 +2039,11 @@ class TestValidationFromSpec(unittest.TestCase):
fields={},
geometries={'POINT'}
)
)
),
args_with_spatial_overlap={
'spatial_keys': ['raster_a', 'raster_b', 'vector_a'],
'different_projections_ok': True
}
)
driver = gdal.GetDriverByName('GTiff')
@ -2078,9 +2082,7 @@ class TestValidationFromSpec(unittest.TestCase):
'vector_a': reference_filepath,
}
validation_warnings = validation.validate(
args, spec, {'spatial_keys': list(args.keys()),
'different_projections_ok': True})
validation_warnings = validation.validate(args, spec)
self.assertEqual(len(validation_warnings), 1)
self.assertEqual(set(args.keys()), set(validation_warnings[0][0]))
formatted_bbox_list = '' # allows str matching w/o real bbox str
@ -2102,7 +2104,11 @@ class TestValidationFromSpec(unittest.TestCase):
id='raster_b',
band=NumberInputSpec(units=u.none)
)
)
),
args_with_spatial_overlap={
'spatial_keys': ['raster_a', 'raster_b'],
'different_projections_ok': True
}
)
driver = gdal.GetDriverByName('GTiff')
@ -2124,9 +2130,7 @@ class TestValidationFromSpec(unittest.TestCase):
'raster_b': filepath_2
}
validation_warnings = validation.validate(
args, spec, {'spatial_keys': list(args.keys()),
'different_projections_ok': True})
validation_warnings = validation.validate(args, spec)
expected = [(['raster_b'], validation.MESSAGES['INVALID_PROJECTION'])]
self.assertEqual(validation_warnings, expected)
@ -2151,7 +2155,11 @@ class TestValidationFromSpec(unittest.TestCase):
fields=Fields(),
geometries={'POINT'}
)
)
),
args_with_spatial_overlap={
'spatial_keys': ['raster_a', 'raster_b', 'vector_a'],
'different_projections_ok': True
}
)
driver = gdal.GetDriverByName('GTiff')
@ -2174,9 +2182,7 @@ class TestValidationFromSpec(unittest.TestCase):
}
# There should not be a spatial overlap check at all
# when less than 2 of the spatial keys are sufficient.
validation_warnings = validation.validate(
args, spec, {'spatial_keys': [i.id for i in spec.inputs],
'different_projections_ok': True})
validation_warnings = validation.validate(args, spec)
print(validation_warnings)
self.assertEqual(len(validation_warnings), 0)
@ -2186,9 +2192,7 @@ class TestValidationFromSpec(unittest.TestCase):
'raster_a': filepath_1,
'raster_b': filepath_2,
}
validation_warnings = validation.validate(
args, spec, {'spatial_keys': [i.id for i in spec.inputs],
'different_projections_ok': True})
validation_warnings = validation.validate(args, spec)
self.assertEqual(len(validation_warnings), 1)
formatted_bbox_list = '' # allows str matching w/o real bbox str
self.assertTrue(