remove ModelInputs and ModelOutputs classes

This commit is contained in:
Emily Soth 2025-05-02 14:23:14 -07:00
parent 18d8d10992
commit 960b2999c2
40 changed files with 274 additions and 245 deletions

View File

@ -535,7 +535,7 @@ def execute(args):
'Checking that watersheds have entries for every `ws_id` in the '
'valuation table.')
# Open/read in valuation parameters from CSV file
valuation_df = MODEL_SPEC.inputs.get(
valuation_df = MODEL_SPEC.get_input(
'valuation_table_path').get_validated_dataframe(args['valuation_table_path'])
watershed_vector = gdal.OpenEx(
args['watersheds_path'], gdal.OF_VECTOR)
@ -658,15 +658,15 @@ def execute(args):
'lulc': pygeoprocessing.get_raster_info(clipped_lulc_path)['nodata'][0]}
# Open/read in the csv file into a dictionary and add to arguments
bio_df = MODEL_SPEC.inputs.get(
'biophysical_table_path').get_validated_dataframe(args['biophysical_table_path'])
bio_df = MODEL_SPEC.get_input('biophysical_table_path').get_validated_dataframe(
args['biophysical_table_path'])
bio_lucodes = set(bio_df.index.values)
bio_lucodes.add(nodata_dict['lulc'])
LOGGER.debug(f'bio_lucodes: {bio_lucodes}')
if 'demand_table_path' in args and args['demand_table_path'] != '':
demand_df = MODEL_SPEC.inputs.get('demand_table_path').get_validated_dataframe(
demand_df = MODEL_SPEC.get_input('demand_table_path').get_validated_dataframe(
args['demand_table_path'])
demand_reclassify_dict = dict(
[(lucode, row['demand']) for lucode, row in demand_df.iterrows()])

View File

@ -321,7 +321,7 @@ def execute(args):
"Baseline LULC Year is earlier than the Alternate LULC Year."
)
carbon_pool_df = MODEL_SPEC.inputs.get(
carbon_pool_df = MODEL_SPEC.get_input(
'carbon_pools_path').get_validated_dataframe(args['carbon_pools_path'])
try:

View File

@ -148,7 +148,7 @@ def export_to_python(target_filepath, model_id, args_dict=None):
if args_dict is None:
cast_args = {
key: '' for key in models.model_id_to_spec[model_id].inputs.__dict__.keys()}
key: '' for key in models.model_id_to_spec[model_id].inputs_dict.keys()}
else:
cast_args = dict((str(key), value) for (key, value)
in args_dict.items())

View File

@ -585,7 +585,7 @@ def execute(args):
task_graph, n_workers, intermediate_dir, output_dir, suffix = (
_set_up_workspace(args))
snapshots = MODEL_SPEC.inputs.get(
snapshots = MODEL_SPEC.get_input(
'landcover_snapshot_csv').get_validated_dataframe(
args['landcover_snapshot_csv'])['raster_path'].to_dict()
@ -607,7 +607,7 @@ def execute(args):
# We're assuming that the LULC initial variables and the carbon pool
# transient table are combined into a single lookup table.
biophysical_df = MODEL_SPEC.inputs.get(
biophysical_df = MODEL_SPEC.get_input(
'biophysical_table_path').get_validated_dataframe(
args['biophysical_table_path'])
@ -977,7 +977,7 @@ def execute(args):
prices = None
if args.get('do_economic_analysis', False): # Do if truthy
if args.get('use_price_table', False):
prices = MODEL_SPEC.inputs.get(
prices = MODEL_SPEC.get_input(
'price_table_path').get_validated_dataframe(
args['price_table_path'])['price'].to_dict()
else:
@ -1961,7 +1961,7 @@ def _read_transition_matrix(transition_csv_path, biophysical_df):
landcover transition, and the second contains accumulation rates for
the pool for the landcover transition.
"""
table = MODEL_SPEC.inputs.get(
table = MODEL_SPEC.get_input(
'landcover_transitions_table').get_validated_dataframe(
transition_csv_path).reset_index()
@ -2183,7 +2183,7 @@ def validate(args, limit_to=None):
if ("landcover_snapshot_csv" not in invalid_keys and
"landcover_snapshot_csv" in sufficient_keys):
snapshots = MODEL_SPEC.inputs.get(
snapshots = MODEL_SPEC.get_input(
'landcover_snapshot_csv').get_validated_dataframe(
args['landcover_snapshot_csv']
)['raster_path'].to_dict()
@ -2205,7 +2205,7 @@ def validate(args, limit_to=None):
# check for invalid options in the translation table
if ("landcover_transitions_table" not in invalid_keys and
"landcover_transitions_table" in sufficient_keys):
transitions_spec = MODEL_SPEC.inputs.get('landcover_transitions_table')
transitions_spec = MODEL_SPEC.get_input('landcover_transitions_table')
transition_options = list(
transitions_spec.columns.get('[LULC CODE]').options.keys())
# lowercase options since utils call will lowercase table values

View File

@ -184,7 +184,7 @@ def execute(args):
os.path.join(args['workspace_dir'], 'taskgraph_cache'),
n_workers, reporting_interval=5.0)
snapshots_dict = MODEL_SPEC.inputs.get(
snapshots_dict = MODEL_SPEC.get_input(
'landcover_snapshot_csv').get_validated_dataframe(
args['landcover_snapshot_csv'])['raster_path'].to_dict()
@ -216,7 +216,7 @@ def execute(args):
target_path_list=aligned_snapshot_paths,
task_name='Align input landcover rasters')
landcover_df = MODEL_SPEC.inputs.get(
landcover_df = MODEL_SPEC.get_input(
'lulc_lookup_table_path').get_validated_dataframe(
args['lulc_lookup_table_path'])
@ -388,7 +388,7 @@ def _create_biophysical_table(landcover_df, target_biophysical_table_path):
"""
target_column_names = [
spec.id.lower() for spec in
coastal_blue_carbon.MODEL_SPEC.inputs.get('biophysical_table_path').columns]
coastal_blue_carbon.MODEL_SPEC.get_input('biophysical_table_path').columns]
with open(target_biophysical_table_path, 'w') as bio_table:
bio_table.write(f"{','.join(target_column_names)}\n")

View File

@ -2341,7 +2341,7 @@ def _schedule_habitat_tasks(
list of pickle file path strings
"""
habitat_dataframe = MODEL_SPEC.inputs.get(
habitat_dataframe = MODEL_SPEC.get_input(
'habitat_table_path').get_validated_dataframe(habitat_table_path
).rename(columns={'protection distance (m)': 'distance'})

View File

@ -614,7 +614,7 @@ def execute(args):
None.
"""
crop_to_landcover_df = MODEL_SPEC.inputs.get(
crop_to_landcover_df = MODEL_SPEC.get_input(
'landcover_to_crop_table_path').get_validated_dataframe(
args['landcover_to_crop_table_path'])
@ -696,7 +696,7 @@ def execute(args):
climate_percentile_yield_table_path = os.path.join(
args['model_data_path'],
_CLIMATE_PERCENTILE_TABLE_PATTERN % crop_name)
crop_climate_percentile_df = MODEL_SPEC.inputs.get(
crop_climate_percentile_df = MODEL_SPEC.get_input(
'model_data_path').contents.get(
'climate_percentile_yield_tables').contents.get(
'[CROP]_percentile_yield_table.csv').get_validated_dataframe(
@ -853,7 +853,7 @@ def execute(args):
# both 'crop_nutrient.csv' and 'crop' are known data/header values for
# this model data.
nutrient_df = MODEL_SPEC.inputs.get(
nutrient_df = MODEL_SPEC.get_input(
'model_data_path').contents.get(
'crop_nutrient.csv').get_validated_dataframe(
os.path.join(args['model_data_path'], 'crop_nutrient.csv'))

View File

@ -503,11 +503,11 @@ def execute(args):
LOGGER.info(
"Checking if the landcover raster is missing lucodes")
crop_to_landcover_df = MODEL_SPEC.inputs.get(
crop_to_landcover_df = MODEL_SPEC.get_input(
'landcover_to_crop_table_path').get_validated_dataframe(
args['landcover_to_crop_table_path'])
crop_to_fertilization_rate_df = MODEL_SPEC.inputs.get(
crop_to_fertilization_rate_df = MODEL_SPEC.get_input(
'fertilization_rate_table_path').get_validated_dataframe(
args['fertilization_rate_table_path'])
@ -584,7 +584,7 @@ def execute(args):
task_name='crop_climate_bin')
dependent_task_list.append(crop_climate_bin_task)
crop_regression_df = MODEL_SPEC.inputs.get('model_data_path').contents.get(
crop_regression_df = MODEL_SPEC.get_input('model_data_path').contents.get(
'climate_regression_yield_tables').contents.get(
'[CROP]_regression_yield_table.csv').get_validated_dataframe(
os.path.join(args['model_data_path'],
@ -807,7 +807,7 @@ def execute(args):
# both 'crop_nutrient.csv' and 'crop' are known data/header values for
# this model data.
nutrient_df = MODEL_SPEC.inputs.get('model_data_path').contents.get(
nutrient_df = MODEL_SPEC.get_input('model_data_path').contents.get(
'crop_nutrient.csv').get_validated_dataframe(
os.path.join(args['model_data_path'], 'crop_nutrient.csv'))

View File

@ -239,7 +239,7 @@ def build_datastack_archive(args, model_id, datastack_path):
if key not in module.MODEL_SPEC.inputs:
LOGGER.info(f'Skipping arg {key}; not in model MODEL_SPEC')
input_spec = module.MODEL_SPEC.inputs.get(key)
input_spec = module.MODEL_SPEC.get_input(key)
if input_spec.__class__ in file_based_types:
if args[key] in {None, ''}:
LOGGER.info(

View File

@ -425,7 +425,7 @@ def execute(args):
# Map non-forest landcover codes to carbon biomasses
LOGGER.info('Calculating direct mapped carbon stocks')
carbon_maps = []
biophysical_df = MODEL_SPEC.inputs.get(
biophysical_df = MODEL_SPEC.get_input(
'biophysical_table_path').get_validated_dataframe(
args['biophysical_table_path'])
pool_list = [('c_above', True)]
@ -644,7 +644,7 @@ def _calculate_lulc_carbon_map(
"""
# classify forest pixels from lulc
biophysical_df = MODEL_SPEC.inputs.get(
biophysical_df = MODEL_SPEC.get_input(
'biophysical_table_path').get_validated_dataframe(biophysical_table_path)
lucode_to_per_cell_carbon = {}
@ -704,7 +704,7 @@ def _map_distance_from_tropical_forest_edge(
"""
# Build a list of forest lucodes
biophysical_df = MODEL_SPEC.inputs.get(
biophysical_df = MODEL_SPEC.get_input(
'biophysical_table_path').get_validated_dataframe(
biophysical_table_path)
forest_codes = biophysical_df[biophysical_df['is_tropical_forest']].index.values
@ -1036,7 +1036,7 @@ def validate(args, limit_to=None):
"""
model_spec = copy.deepcopy(MODEL_SPEC)
if 'pools_to_calculate' in args and args['pools_to_calculate'] == 'all':
model_spec.inputs.get('biophysical_table_path').columns.get('c_below').required = True
model_spec.inputs.get('biophysical_table_path').columns.get('c_soil').required = True
model_spec.inputs.get('biophysical_table_path').columns.get('c_dead').required = True
model_spec.get_input('biophysical_table_path').columns.get('c_below').required = True
model_spec.get_input('biophysical_table_path').columns.get('c_soil').required = True
model_spec.get_input('biophysical_table_path').columns.get('c_dead').required = True
return validation.validate(args, model_spec)

View File

@ -455,10 +455,10 @@ def execute(args):
LOGGER.info("Checking Threat and Sensitivity tables for compliance")
# Get CSVs as dictionaries and ensure the key is a string for threats.
threat_df = MODEL_SPEC.inputs.get(
threat_df = MODEL_SPEC.get_input(
'threats_table_path').get_validated_dataframe(
args['threats_table_path']).fillna('')
sensitivity_df = MODEL_SPEC.inputs.get(
sensitivity_df = MODEL_SPEC.get_input(
'sensitivity_table_path').get_validated_dataframe(
args['sensitivity_table_path'])
@ -1181,10 +1181,10 @@ def validate(args, limit_to=None):
"sensitivity_table_path" not in invalid_keys and
"threat_raster_folder" not in invalid_keys):
# Get CSVs as dictionaries and ensure the key is a string for threats.
threat_df = MODEL_SPEC.inputs.get(
threat_df = MODEL_SPEC.get_input(
'threats_table_path').get_validated_dataframe(
args['threats_table_path']).fillna('')
sensitivity_df = MODEL_SPEC.inputs.get(
sensitivity_df = MODEL_SPEC.get_input(
'sensitivity_table_path').get_validated_dataframe(
args['sensitivity_table_path'])

View File

@ -447,8 +447,8 @@ MODEL_SPEC = spec_utils.build_model_spec({
}
})
_VALID_RISK_EQS = set(MODEL_SPEC.inputs.get('risk_eq').options.keys())
_VALID_DECAY_TYPES = set(MODEL_SPEC.inputs.get('decay_eq').options.keys())
_VALID_RISK_EQS = set(MODEL_SPEC.get_input('risk_eq').options.keys())
_VALID_DECAY_TYPES = set(MODEL_SPEC.get_input('decay_eq').options.keys())
def execute(args):
@ -1791,7 +1791,7 @@ def _parse_info_table(info_table_path):
info_table_path = os.path.abspath(info_table_path)
try:
table = MODEL_SPEC.inputs.get(
table = MODEL_SPEC.get_input(
'info_table_path').get_validated_dataframe(info_table_path)
except ValueError as err:
if 'Index has duplicate keys' in str(err):

View File

@ -624,7 +624,7 @@ def execute(args):
if args['calc_' + nutrient_id]:
nutrients_to_process.append(nutrient_id)
biophysical_df = MODEL_SPEC.inputs.get(
biophysical_df = MODEL_SPEC.get_input(
'biophysical_table_path').get_validated_dataframe(
args['biophysical_table_path'])
@ -1298,11 +1298,11 @@ def validate(args, limit_to=None):
for param in ['load', 'eff', 'crit_len']:
for nutrient in nutrients_selected:
spec_copy.inputs.get('biophysical_table_path').columns.get(
spec_copy.get_input('biophysical_table_path').columns.get(
f'{param}_{nutrient}').required = True
if 'n' in nutrients_selected:
spec_copy.inputs.get('biophysical_table_path').columns.get(
spec_copy.get_input('biophysical_table_path').columns.get(
'proportion_subsurface_n').required = True
validation_warnings = validation.validate(args, spec_copy)

View File

@ -1221,7 +1221,7 @@ def _parse_scenario_variables(args):
else:
farm_vector_path = None
guild_df = MODEL_SPEC.inputs.get(
guild_df = MODEL_SPEC.get_input(
'guild_table_path').get_validated_dataframe(guild_table_path)
LOGGER.info('Checking to make sure guild table has all expected headers')
@ -1233,7 +1233,7 @@ def _parse_scenario_variables(args):
f"'{header}' but was unable to find one. Here are all the "
f"headers from {guild_table_path}: {', '.join(guild_df.columns)}")
landcover_biophysical_df = MODEL_SPEC.inputs.get(
landcover_biophysical_df = MODEL_SPEC.get_input(
'landcover_biophysical_table_path').get_validated_dataframe(
landcover_biophysical_table_path)
biophysical_table_headers = landcover_biophysical_df.columns

View File

@ -609,7 +609,7 @@ def execute(args):
# Compute the regression
coefficient_json_path = os.path.join(
intermediate_dir, 'predictor_estimates.json')
predictor_df = MODEL_SPEC.inputs.get(
predictor_df = MODEL_SPEC.get_input(
'predictor_table_path').get_validated_dataframe(
args['predictor_table_path'])
predictor_id_list = predictor_df.index
@ -996,7 +996,7 @@ def _schedule_predictor_data_processing(
'line_intersect_length': _line_intersect_length,
}
predictor_df = MODEL_SPEC.inputs.get(
predictor_df = MODEL_SPEC.get_input(
'predictor_table_path').get_validated_dataframe(predictor_table_path)
predictor_task_list = []
predictor_json_list = [] # tracks predictor files to add to gpkg
@ -1765,7 +1765,7 @@ def _validate_same_id_lengths(table_path):
string message if IDs are too long
"""
predictor_df = MODEL_SPEC.inputs.get(
predictor_df = MODEL_SPEC.get_input(
'predictor_table_path').get_validated_dataframe(table_path)
too_long = set()
for p_id in predictor_df.index:
@ -1794,11 +1794,11 @@ def _validate_same_ids_and_types(
string message if any of the fields in 'id' and 'type' don't match
between tables.
"""
predictor_df = MODEL_SPEC.inputs.get(
predictor_df = MODEL_SPEC.get_input(
'predictor_table_path').get_validated_dataframe(
predictor_table_path)
scenario_predictor_df = MODEL_SPEC.inputs.get(
scenario_predictor_df = MODEL_SPEC.get_input(
'scenario_predictor_table_path').get_validated_dataframe(
scenario_predictor_table_path)
@ -1825,7 +1825,7 @@ def _validate_same_projection(base_vector_path, table_path):
"""
# This will load the table as a list of paths which we can iterate through
# without bothering the rest of the table structure
data_paths = MODEL_SPEC.inputs.get(
data_paths = MODEL_SPEC.get_input(
'predictor_table_path').get_validated_dataframe(
table_path)['path'].tolist()
@ -1868,7 +1868,7 @@ def _validate_predictor_types(table_path):
string message if any value in the ``type`` column does not match a
valid type, ignoring leading/trailing whitespace.
"""
df = MODEL_SPEC.inputs.get(
df = MODEL_SPEC.get_input(
'predictor_table_path').get_validated_dataframe(table_path)
# ignore leading/trailing whitespace because it will be removed
# when the type values are used

View File

@ -283,7 +283,7 @@ def execute(args):
'b': float(args['b_coef']),
}
if (args['valuation_function'] not in
MODEL_SPEC.inputs.get('valuation_function').options):
MODEL_SPEC.get_input('valuation_function').options):
raise ValueError('Valuation function type %s not recognized' %
args['valuation_function'])
max_valuation_radius = float(args['max_valuation_radius'])

View File

@ -562,7 +562,7 @@ def execute(args):
"""
file_suffix = utils.make_suffix_string(args, 'results_suffix')
biophysical_df = MODEL_SPEC.inputs.get(
biophysical_df = MODEL_SPEC.get_input(
'biophysical_table_path').get_validated_dataframe(
args['biophysical_table_path'])

View File

@ -625,17 +625,17 @@ def execute(args):
# fail early on a missing required rain events table
if (not args['user_defined_local_recharge'] and
not args['user_defined_climate_zones']):
rain_events_df = MODEL_SPEC.inputs.get(
rain_events_df = MODEL_SPEC.get_input(
'rain_events_table_path').get_validated_dataframe(
args['rain_events_table_path'])
biophysical_df = MODEL_SPEC.inputs.get(
biophysical_df = MODEL_SPEC.get_input(
'biophysical_table_path').get_validated_dataframe(
args['biophysical_table_path'])
if args['monthly_alpha']:
# parse out the alpha lookup table of the form (month_id: alpha_val)
alpha_month_map = MODEL_SPEC.inputs.get(
alpha_month_map = MODEL_SPEC.get_input(
'monthly_alpha_path').get_validated_dataframe(
args['monthly_alpha_path'])['alpha'].to_dict()
else:
@ -814,7 +814,7 @@ def execute(args):
'table_name': 'Climate Zone'}
for month_id in range(N_MONTHS):
if args['user_defined_climate_zones']:
cz_rain_events_df = MODEL_SPEC.inputs.get(
cz_rain_events_df = MODEL_SPEC.get_input(
'climate_zone_table_path').get_validated_dataframe(
args['climate_zone_table_path'])
climate_zone_rain_events_month = (

View File

@ -8,8 +8,8 @@ import queue
import re
import threading
import types
import typing
import warnings
from typing import Union, ClassVar
from osgeo import gdal
from osgeo import ogr
@ -199,13 +199,6 @@ class IterableWithDotAccess():
# else:
# raise StopIteration
class ModelInputs(IterableWithDotAccess):
pass
class ModelOutputs(IterableWithDotAccess):
pass
class Rows(IterableWithDotAccess):
pass
@ -223,19 +216,19 @@ class Input:
id: str = ''
name: str = ''
about: str = ''
required: Union[bool, str] = True
allowed: Union[bool, str] = True
required: typing.Union[bool, str] = True
allowed: typing.Union[bool, str] = True
@dataclasses.dataclass
class Output:
id: str = ''
about: str = ''
created_if: Union[bool, str] = True
created_if: typing.Union[bool, str] = True
@dataclasses.dataclass
class FileInput(Input):
permissions: str = 'r'
type: ClassVar[str] = 'file'
type: typing.ClassVar[str] = 'file'
# @timeout
def validate(self, filepath):
@ -269,10 +262,10 @@ class FileInput(Input):
@dataclasses.dataclass
class SingleBandRasterInput(FileInput):
band: Union[Input, None] = None
projected: Union[bool, None] = None
projection_units: Union[pint.Unit, None] = None
type: ClassVar[str] = 'raster'
band: typing.Union[Input, None] = None
projected: typing.Union[bool, None] = None
projection_units: typing.Union[pint.Unit, None] = None
type: typing.ClassVar[str] = 'raster'
# @timeout
def validate(self, filepath):
@ -313,10 +306,10 @@ class SingleBandRasterInput(FileInput):
@dataclasses.dataclass
class VectorInput(FileInput):
geometries: set = dataclasses.field(default_factory=dict)
fields: Union[Fields, None] = None
projected: Union[bool, None] = None
projection_units: Union[pint.Unit, None] = None
type: ClassVar[str] = 'vector'
fields: typing.Union[Fields, None] = None
projected: typing.Union[bool, None] = None
projection_units: typing.Union[pint.Unit, None] = None
type: typing.ClassVar[str] = 'vector'
# @timeout
def validate(self, filepath):
@ -397,12 +390,12 @@ class VectorInput(FileInput):
@dataclasses.dataclass
class RasterOrVectorInput(SingleBandRasterInput, VectorInput):
band: Union[Input, None] = None
band: typing.Union[Input, None] = None
geometries: set = dataclasses.field(default_factory=dict)
fields: Union[Fields, None] = None
projected: Union[bool, None] = None
projection_units: Union[pint.Unit, None] = None
type: ClassVar[str] = 'raster_or_vector'
fields: typing.Union[Fields, None] = None
projected: typing.Union[bool, None] = None
projection_units: typing.Union[pint.Unit, None] = None
type: typing.ClassVar[str] = 'raster_or_vector'
# @timeout
def validate(self, filepath):
@ -427,10 +420,10 @@ class RasterOrVectorInput(SingleBandRasterInput, VectorInput):
@dataclasses.dataclass
class CSVInput(FileInput):
columns: Union[Columns, None] = None
rows: Union[Rows, None] = None
index_col: Union[str, None] = None
type: ClassVar[str] = 'csv'
columns: typing.Union[Columns, None] = None
rows: typing.Union[Rows, None] = None
index_col: typing.Union[str, None] = None
type: typing.ClassVar[str] = 'csv'
# @timeout
def validate(self, filepath):
@ -548,10 +541,10 @@ class CSVInput(FileInput):
@dataclasses.dataclass
class DirectoryInput(Input):
contents: Union[Contents, None] = None
contents: typing.Union[Contents, None] = None
permissions: str = ''
must_exist: bool = True
type: ClassVar[str] = 'directory'
type: typing.ClassVar[str] = 'directory'
# @timeout
def validate(self, dirpath):
@ -619,9 +612,9 @@ class DirectoryInput(Input):
@dataclasses.dataclass
class NumberInput(Input):
units: Union[pint.Unit, None] = None
expression: Union[str, None] = None
type: ClassVar[str] = 'number'
units: typing.Union[pint.Unit, None] = None
expression: typing.Union[str, None] = None
type: typing.ClassVar[str] = 'number'
def validate(self, value):
"""Validate numbers.
@ -662,7 +655,7 @@ class NumberInput(Input):
@dataclasses.dataclass
class IntegerInput(Input):
type: ClassVar[str] = 'integer'
type: typing.ClassVar[str] = 'integer'
def validate(self, value):
"""Validate an integer.
@ -689,7 +682,7 @@ class IntegerInput(Input):
@dataclasses.dataclass
class RatioInput(Input):
type: ClassVar[str] = 'ratio'
type: typing.ClassVar[str] = 'ratio'
def validate(self, value):
"""Validate a ratio (a proportion expressed as a value from 0 to 1).
@ -717,7 +710,7 @@ class RatioInput(Input):
@dataclasses.dataclass
class PercentInput(Input):
type: ClassVar[str] = 'percent'
type: typing.ClassVar[str] = 'percent'
def validate(self, value):
"""Validate a percent (a proportion expressed as a value from 0 to 100).
@ -744,7 +737,7 @@ class PercentInput(Input):
@dataclasses.dataclass
class BooleanInput(Input):
type: ClassVar[str] = 'boolean'
type: typing.ClassVar[str] = 'boolean'
def validate(self, value):
"""Validate a boolean value.
@ -767,8 +760,8 @@ class BooleanInput(Input):
@dataclasses.dataclass
class StringInput(Input):
regexp: Union[str, None] = None
type: ClassVar[str] = 'string'
regexp: typing.Union[str, None] = None
type: typing.ClassVar[str] = 'string'
def validate(self, value):
"""Validate an arbitrary string.
@ -794,8 +787,8 @@ class StringInput(Input):
@dataclasses.dataclass
class OptionStringInput(Input):
options: Union[list, None] = None
type: ClassVar[str] = 'option_string'
options: typing.Union[list, None] = None
type: typing.ClassVar[str] = 'option_string'
def validate(self, value):
"""Validate that a string is in a set of options.
@ -828,26 +821,26 @@ class OtherInput(Input):
@dataclasses.dataclass
class SingleBandRasterOutput(Output):
band: Union[Input, None] = None
projected: Union[bool, None] = None
projection_units: Union[pint.Unit, None] = None
band: typing.Union[Input, None] = None
projected: typing.Union[bool, None] = None
projection_units: typing.Union[pint.Unit, None] = None
@dataclasses.dataclass
class VectorOutput(Output):
geometries: set = dataclasses.field(default_factory=dict)
fields: Union[Fields, None] = None
projected: Union[bool, None] = None
projection_units: Union[pint.Unit, None] = None
fields: typing.Union[Fields, None] = None
projected: typing.Union[bool, None] = None
projection_units: typing.Union[pint.Unit, None] = None
@dataclasses.dataclass
class CSVOutput(Output):
columns: Union[Columns, None] = None
rows: Union[Rows, None] = None
index_col: Union[str, None] = None
columns: typing.Union[Columns, None] = None
rows: typing.Union[Rows, None] = None
index_col: typing.Union[str, None] = None
@dataclasses.dataclass
class DirectoryOutput(Output):
contents: Union[Contents, None] = None
contents: typing.Union[Contents, None] = None
permissions: str = ''
must_exist: bool = True
@ -857,8 +850,8 @@ class FileOutput(Output):
@dataclasses.dataclass
class NumberOutput(Output):
units: Union[pint.Unit, None] = None
expression: Union[str, None] = None
units: typing.Union[pint.Unit, None] = None
expression: typing.Union[str, None] = None
@dataclasses.dataclass
class IntegerOutput(Output):
@ -874,15 +867,15 @@ class PercentOutput(Output):
@dataclasses.dataclass
class StringOutput(Output):
regexp: Union[str, None] = None
regexp: typing.Union[str, None] = None
@dataclasses.dataclass
class OptionStringOutput(Output):
options: Union[list, None] = None
options: typing.Union[list, None] = None
@dataclasses.dataclass
class UISpec:
order: Union[list, None] = None
order: typing.Union[list, None] = None
hidden: list = None
dropdown_functions: dict = dataclasses.field(default_factory=dict)
@ -892,21 +885,26 @@ class ModelSpec:
model_id: str
model_title: str
userguide: str
aliases: set
ui_spec: UISpec
inputs: ModelInputs
outputs: set
inputs: typing.Iterable[Input]
outputs: typing.Iterable[Output]
args_with_spatial_overlap: dict
aliases: set = dataclasses.field(default_factory=set)
def __post_init__(self):
self.inputs_dict = {_input.id: _input for _input in self.inputs}
self.outputs_dict = {_output.id: _output for _output in self.outputs}
def get_input(self, key):
return self.inputs_dict[key]
def build_model_spec(model_spec):
input_specs = [
inputs = [
build_input_spec(argkey, argspec)
for argkey, argspec in model_spec['args'].items()]
output_specs = [
outputs = [
build_output_spec(argkey, argspec) for argkey, argspec in model_spec['outputs'].items()]
inputs = ModelInputs(*input_specs)
outputs = ModelOutputs(*output_specs)
ui_spec = UISpec(
order=model_spec['ui_spec']['order'],
hidden=model_spec['ui_spec'].get('hidden', None),
@ -1682,8 +1680,8 @@ def describe_arg_from_name(module_name, *arg_keys):
module = importlib.import_module(module_name)
# start with the spec for all args
# narrow down to the nested spec indicated by the sequence of arg keys
spec = module.MODEL_SPEC.inputs
for i, key in enumerate(arg_keys):
spec = module.MODEL_SPEC.get_input(arg_keys[0])
for i, key in enumerate(arg_keys[1:]):
# convert raster band numbers to ints
if arg_keys[i - 1] == 'bands':
key = int(key)

View File

@ -497,7 +497,7 @@ def execute(args):
# Build a lookup dictionary mapping each LULC code to its row
# sort by the LULC codes upfront because we use the sorted list in multiple
# places. it's more efficient to do this once.
biophysical_df = MODEL_SPEC.inputs.get(
biophysical_df = MODEL_SPEC.get_input(
'biophysical_table').get_validated_dataframe(
args['biophysical_table']).sort_index()
sorted_lucodes = biophysical_df.index.to_list()

View File

@ -475,7 +475,7 @@ def execute(args):
intermediate_dir = os.path.join(
args['workspace_dir'], 'intermediate')
utils.make_directories([args['workspace_dir'], intermediate_dir])
biophysical_df = MODEL_SPEC.inputs.get(
biophysical_df = MODEL_SPEC.get_input(
'biophysical_table_path').get_validated_dataframe(
args['biophysical_table_path'])
@ -1160,7 +1160,7 @@ def calculate_energy_savings(
for field in target_building_layer.schema]
type_field_index = fieldnames.index('type')
energy_consumption_df = MODEL_SPEC.inputs.get(
energy_consumption_df = MODEL_SPEC.get_input(
'energy_consumption_table_path').get_validated_dataframe(
energy_consumption_table_path)
@ -1541,7 +1541,7 @@ def validate(args, limit_to=None):
invalid_keys = validation.get_invalid_keys(validation_warnings)
if ('biophysical_table_path' not in invalid_keys and
'cc_method' not in invalid_keys):
spec = copy.deepcopy(MODEL_SPEC.inputs.get('biophysical_table_path'))
spec = copy.deepcopy(MODEL_SPEC.get_input('biophysical_table_path'))
if args['cc_method'] == 'factors':
spec.columns.get('shade').required = True
spec.columns.get('albedo').required = True

View File

@ -307,7 +307,7 @@ def execute(args):
task_name='align raster stack')
# Load CN table
cn_df = MODEL_SPEC.inputs.get(
cn_df = MODEL_SPEC.get_input(
'curve_number_table_path').get_validated_dataframe(
args['curve_number_table_path'])
@ -636,7 +636,7 @@ def _calculate_damage_to_infrastructure_in_aoi(
infrastructure_vector = gdal.OpenEx(structures_vector_path, gdal.OF_VECTOR)
infrastructure_layer = infrastructure_vector.GetLayer()
damage_type_map = MODEL_SPEC.inputs.get(
damage_type_map = MODEL_SPEC.get_input(
'infrastructure_damage_loss_table_path').get_validated_dataframe(
structures_damage_table)['damage'].to_dict()
@ -936,7 +936,7 @@ def validate(args, limit_to=None):
if ("curve_number_table_path" not in invalid_keys and
"curve_number_table_path" in sufficient_keys):
# Load CN table. Resulting DF has index and CN_X columns only.
cn_df = MODEL_SPEC.inputs.get(
cn_df = MODEL_SPEC.get_input(
'curve_number_table_path').get_validated_dataframe(
args['curve_number_table_path'])
# Check for NaN values.

View File

@ -944,7 +944,7 @@ def execute(args):
aoi_reprojection_task, lulc_mask_task]
)
attr_table = MODEL_SPEC.inputs.get(
attr_table = MODEL_SPEC.get_input(
'lulc_attribute_table').get_validated_dataframe(args['lulc_attribute_table'])
kernel_paths = {} # search_radius, kernel path
kernel_tasks = {} # search_radius, kernel task
@ -963,7 +963,7 @@ def execute(args):
lucode_to_search_radii = list(
urban_nature_attrs[['search_radius_m']].itertuples(name=None))
elif args['search_radius_mode'] == RADIUS_OPT_POP_GROUP:
pop_group_table = MODEL_SPEC.inputs.get(
pop_group_table = MODEL_SPEC.get_input(
'population_group_radii_table').get_validated_dataframe(
args['population_group_radii_table'])
search_radii = set(pop_group_table['search_radius_m'].unique())
@ -971,7 +971,7 @@ def execute(args):
search_radii_by_pop_group = pop_group_table['search_radius_m'].to_dict()
else:
valid_options = ', '.join(
MODEL_SPEC.inputs.get('search_radius_mode').options.keys())
MODEL_SPEC.get_input('search_radius_mode').options.keys())
raise ValueError(
"Invalid search radius mode provided: "
f"{args['search_radius_mode']}; must be one of {valid_options}")
@ -1843,7 +1843,7 @@ def _reclassify_urban_nature_area(
Returns:
``None``
"""
lulc_attribute_df = MODEL_SPEC.inputs.get(
lulc_attribute_df = MODEL_SPEC.get_input(
'lulc_attribute_table').get_validated_dataframe(lulc_attribute_table)
squared_pixel_area = abs(
@ -1876,9 +1876,9 @@ def _reclassify_urban_nature_area(
target_datatype=gdal.GDT_Float32,
target_nodata=FLOAT32_NODATA,
error_details={
'raster_name': MODEL_SPEC.inputs.get('lulc_raster_path').name,
'raster_name': MODEL_SPEC.get_input('lulc_raster_path').name,
'column_name': 'urban_nature',
'table_name': MODEL_SPEC.inputs.get('lulc_attribute_table').name
'table_name': MODEL_SPEC.get_input('lulc_attribute_table').name
}
)

View File

@ -120,10 +120,10 @@ def _calculate_args_bounding_box(args, model_spec):
# should already have been validated so the path is either valid or
# blank.
spatial_info = None
if (isinstance(model_spec.inputs.get(key),
if (isinstance(model_spec.get_input(key),
spec_utils.SingleBandRasterInput) and value.strip() != ''):
spatial_info = pygeoprocessing.get_raster_info(value)
elif (isinstance(model_spec.inputs.get(key),
elif (isinstance(model_spec.get_input(key),
spec_utils.VectorInput) and value.strip() != ''):
spatial_info = pygeoprocessing.get_vector_info(value)
@ -158,7 +158,7 @@ def _calculate_args_bounding_box(args, model_spec):
LOGGER.exception(
f'Error when transforming coordinates: {transform_error}')
else:
LOGGER.debug(f'Arg {key} of type {model_spec.inputs.get(key).__class__} '
LOGGER.debug(f'Arg {key} of type {model_spec.get_input(key).__class__} '
'excluded from bounding box calculation')
return bb_intersection, bb_union

View File

@ -303,7 +303,7 @@ def validate(args, spec):
# we don't need to try to validate them
try:
# Using deepcopy to make sure we don't modify the original spec
parameter_spec = copy.deepcopy(spec.inputs.get(key))
parameter_spec = copy.deepcopy(spec.get_input(key))
except KeyError:
LOGGER.debug(f'Provided key {key} does not exist in MODEL_SPEC')
continue
@ -418,14 +418,12 @@ def invest_validator(validate_func):
warnings_ = validate_func(args, limit_to)
return warnings_
args_spec = model_module.MODEL_SPEC.inputs
if limit_to is None:
LOGGER.info('Starting whole-model validation with MODEL_SPEC')
warnings_ = validate_func(args)
else:
LOGGER.info('Starting single-input validation with MODEL_SPEC')
args_key_spec = args_spec.get(limit_to)
args_key_spec = model_module.MODEL_SPEC.get_input(limit_to)
args_value = args[limit_to]
error_msg = None

View File

@ -755,14 +755,14 @@ def execute(args):
LOGGER.debug('Machine Performance Rows : %s', machine_perf_dict['periods'])
LOGGER.debug('Machine Performance Cols : %s', machine_perf_dict['heights'])
machine_param_dict = MODEL_SPEC.inputs.get(
machine_param_dict = MODEL_SPEC.get_input(
'machine_param_path').get_validated_dataframe(
args['machine_param_path'])['value'].to_dict()
# Check if required column fields are entered in the land grid csv file
if 'land_gridPts_path' in args:
# Create a grid_land_df dataframe for later use in valuation
grid_land_df = MODEL_SPEC.inputs.get(
grid_land_df = MODEL_SPEC.get_input(
'land_gridPts_path').get_validated_dataframe(args['land_gridPts_path'])
missing_grid_land_fields = []
for field in ['id', 'type', 'lat', 'long', 'location']:
@ -775,7 +775,7 @@ def execute(args):
'Connection Points File: %s' % missing_grid_land_fields)
if 'valuation_container' in args and args['valuation_container']:
machine_econ_dict = MODEL_SPEC.inputs.get(
machine_econ_dict = MODEL_SPEC.get_input(
'machine_econ_path').get_validated_dataframe(
args['machine_econ_path'])['value'].to_dict()

View File

@ -717,11 +717,11 @@ def execute(args):
number_of_turbines = int(args['number_of_turbines'])
# Read the biophysical turbine parameters into a dictionary
turbine_dict = MODEL_SPEC.inputs.get(
turbine_dict = MODEL_SPEC.get_input(
'turbine_parameters_path').get_validated_dataframe(
args['turbine_parameters_path']).iloc[0].to_dict()
# Read the biophysical global parameters into a dictionary
global_params_dict = MODEL_SPEC.inputs.get(
global_params_dict = MODEL_SPEC.get_input(
'global_wind_parameters_path').get_validated_dataframe(
args['global_wind_parameters_path']).iloc[0].to_dict()
@ -741,7 +741,7 @@ def execute(args):
# If Price Table provided use that for price of energy, validate inputs
time = parameters_dict['time_period']
if args['price_table']:
wind_price_df = MODEL_SPEC.inputs.get(
wind_price_df = MODEL_SPEC.get_input(
'wind_schedule').get_validated_dataframe(
args['wind_schedule']).sort_index() # sort by year
@ -1112,7 +1112,7 @@ def execute(args):
LOGGER.info('Grid Points Provided. Reading in the grid points')
# Read the grid points csv, and convert it to land and grid dictionary
grid_land_df = MODEL_SPEC.inputs.get(
grid_land_df = MODEL_SPEC.get_input(
'grid_points_path').get_validated_dataframe(args['grid_points_path'])
# Convert the dataframes to dictionaries, using 'ID' (the index) as key
@ -1933,7 +1933,7 @@ def _compute_density_harvested_fields(
# Read the wind energy data into a dictionary
LOGGER.info('Reading in Wind Data into a dictionary')
wind_point_df = MODEL_SPEC.inputs.get(
wind_point_df = MODEL_SPEC.get_input(
'wind_data_path').get_validated_dataframe(wind_data_path)
wind_point_df.columns = wind_point_df.columns.str.upper()
# Calculate scale value at new hub height given reference values.
@ -2672,7 +2672,7 @@ def validate(args, limit_to=None):
'global_wind_parameters_path' in valid_sufficient_keys):
year_count = utils.read_csv_to_dataframe(
args['wind_schedule']).shape[0]
time = MODEL_SPEC.inputs.get(
time = MODEL_SPEC.get_input(
'global_wind_parameters_path').get_validated_dataframe(
args['global_wind_parameters_path']).iloc[0]['time_period']
if year_count != time + 1:

View File

@ -397,7 +397,7 @@ class CLIUnitTests(unittest.TestCase):
target_model = models.model_id_to_pyname[target_model]
model_module = importlib.import_module(name=target_model)
spec = model_module.MODEL_SPEC
expected_args = {key: '' for key in spec.inputs.__dict__.keys()}
expected_args = {key: '' for key in spec.inputs_dict.keys()}
module_name = str(uuid.uuid4()) + 'testscript'
spec = importlib.util.spec_from_file_location(module_name, target_filepath)

View File

@ -211,7 +211,7 @@ class TestPreprocessor(unittest.TestCase):
lulc_csv.write('0,mangrove,True\n')
lulc_csv.write('1,parking lot,False\n')
landcover_df = preprocessor.MODEL_SPEC.inputs.get(
landcover_df = preprocessor.MODEL_SPEC.get_input(
'lulc_lookup_table_path').get_validated_dataframe(landcover_table_path)
target_table_path = os.path.join(self.workspace_dir,
@ -227,7 +227,7 @@ class TestPreprocessor(unittest.TestCase):
str(context.exception))
# Re-load the landcover table
landcover_df = preprocessor.MODEL_SPEC.inputs.get(
landcover_df = preprocessor.MODEL_SPEC.get_input(
'lulc_lookup_table_path').get_validated_dataframe(landcover_table_path)
preprocessor._create_transition_table(
landcover_df, [filename_a, filename_b], target_table_path)
@ -640,7 +640,7 @@ class TestCBC2(unittest.TestCase):
args = TestCBC2._create_model_args(self.workspace_dir)
args['workspace_dir'] = os.path.join(self.workspace_dir, 'workspace')
prior_snapshots = coastal_blue_carbon.MODEL_SPEC.inputs.get(
prior_snapshots = coastal_blue_carbon.MODEL_SPEC.get_input(
'landcover_snapshot_csv').get_validated_dataframe(
args['landcover_snapshot_csv'])['raster_path'].to_dict()
baseline_year = min(prior_snapshots.keys())
@ -817,7 +817,7 @@ class TestCBC2(unittest.TestCase):
args = TestCBC2._create_model_args(self.workspace_dir)
args['workspace_dir'] = os.path.join(self.workspace_dir, 'workspace')
prior_snapshots = coastal_blue_carbon.MODEL_SPEC.inputs.get(
prior_snapshots = coastal_blue_carbon.MODEL_SPEC.get_input(
'landcover_snapshot_csv').get_validated_dataframe(
args['landcover_snapshot_csv'])['raster_path'].to_dict()
baseline_year = min(prior_snapshots.keys())
@ -876,7 +876,7 @@ class TestCBC2(unittest.TestCase):
# Now work through the extra validation warnings.
# test validation: invalid analysis year
prior_snapshots = coastal_blue_carbon.MODEL_SPEC.inputs.get(
prior_snapshots = coastal_blue_carbon.MODEL_SPEC.get_input(
'landcover_snapshot_csv').get_validated_dataframe(
args['landcover_snapshot_csv'])['raster_path'].to_dict()
baseline_year = min(prior_snapshots)

View File

@ -1,7 +1,6 @@
from types import SimpleNamespace
from natcap.invest import spec_utils
MODEL_SPEC = SimpleNamespace(inputs=spec_utils.ModelInputs(
MODEL_SPEC = spec_utils.ModelSpec(inputs=[
spec_utils.StringInput(id='blank'),
spec_utils.IntegerInput(id='a'),
spec_utils.StringInput(id='b'),
@ -23,5 +22,11 @@ MODEL_SPEC = SimpleNamespace(inputs=spec_utils.ModelInputs(
band=spec_utils.NumberInput()
)
)
)
))
)],
outputs={},
model_id='',
model_title='',
userguide='',
ui_spec=spec_utils.UISpec(),
args_with_spatial_overlap={}
)

View File

@ -1,7 +1,14 @@
from types import SimpleNamespace
from natcap.invest import spec_utils
MODEL_SPEC = SimpleNamespace(inputs=spec_utils.ModelInputs(
spec_utils.FileInput(id='foo'),
spec_utils.FileInput(id='bar'),
))
MODEL_SPEC = spec_utils.ModelSpec(
inputs=[
spec_utils.FileInput(id='foo'),
spec_utils.FileInput(id='bar')
],
outputs={},
model_id='',
model_title='',
userguide='',
ui_spec=spec_utils.UISpec(),
args_with_spatial_overlap={}
)

View File

@ -1,9 +1,14 @@
from types import SimpleNamespace
from natcap.invest import spec_utils
MODEL_SPEC = SimpleNamespace(inputs=spec_utils.ModelInputs(
MODEL_SPEC = spec_utils.ModelSpec(inputs=[
spec_utils.FileInput(id='some_file'),
spec_utils.DirectoryInput(
id='data_dir',
contents=spec_utils.Contents())
))
contents=spec_utils.Contents())],
outputs={},
model_id='',
model_title='',
userguide='',
ui_spec=spec_utils.UISpec(),
args_with_spatial_overlap={}
)

View File

@ -1,6 +1,11 @@
from types import SimpleNamespace
from natcap.invest import spec_utils
MODEL_SPEC = SimpleNamespace(inputs=spec_utils.ModelInputs(
spec_utils.SingleBandRasterInput(id='raster', band=spec_utils.Input())
))
MODEL_SPEC = spec_utils.ModelSpec(inputs=[
spec_utils.SingleBandRasterInput(id='raster', band=spec_utils.Input())],
outputs={},
model_id='',
model_title='',
userguide='',
ui_spec=spec_utils.UISpec(),
args_with_spatial_overlap={}
)

View File

@ -1,7 +1,6 @@
from types import SimpleNamespace
from natcap.invest import spec_utils
MODEL_SPEC = SimpleNamespace(inputs=spec_utils.ModelInputs(
MODEL_SPEC = spec_utils.ModelSpec(inputs=[
spec_utils.IntegerInput(id='a'),
spec_utils.StringInput(id='b'),
spec_utils.StringInput(id='c'),
@ -9,5 +8,11 @@ MODEL_SPEC = SimpleNamespace(inputs=spec_utils.ModelInputs(
spec_utils.DirectoryInput(
id='workspace_dir',
contents=spec_utils.Contents()
)
))
)],
outputs={},
model_id='',
model_title='',
userguide='',
ui_spec=spec_utils.UISpec(),
args_with_spatial_overlap={}
)

View File

@ -1,7 +1,12 @@
from types import SimpleNamespace
from natcap.invest import spec_utils
MODEL_SPEC = SimpleNamespace(inputs=spec_utils.ModelInputs(
MODEL_SPEC = SimpleNamespace(inputs=[
spec_utils.StringInput(id='foo'),
spec_utils.StringInput(id='bar')
))
spec_utils.StringInput(id='bar')],
outputs={},
model_id='',
model_title='',
userguide='',
ui_spec=spec_utils.UISpec(),
args_with_spatial_overlap={}
)

View File

@ -1,7 +1,12 @@
from types import SimpleNamespace
from natcap.invest import spec_utils
MODEL_SPEC = SimpleNamespace(inputs=spec_utils.ModelInputs(
MODEL_SPEC = spec_utils.ModelSpec(inputs=[
spec_utils.VectorInput(
id='vector', fields={}, geometries={})
))
id='vector', fields={}, geometries={})],
outputs={},
model_id='',
model_title='',
userguide='',
ui_spec=spec_utils.UISpec(),
args_with_spatial_overlap={}
)

View File

@ -710,7 +710,7 @@ class TestRecClientServer(unittest.TestCase):
out_regression_vector_path = os.path.join(
args['workspace_dir'], f'regression_data_{suffix}.gpkg')
predictor_df = recmodel_client.MODEL_SPEC.inputs.get(
predictor_df = recmodel_client.MODEL_SPEC.get_input(
'predictor_table_path').get_validated_dataframe(
os.path.join(SAMPLE_DATA, 'predictors_all.csv'))
field_list = list(predictor_df.index) + ['pr_TUD', 'pr_PUD', 'avg_pr_UD']
@ -1284,7 +1284,7 @@ class RecreationClientRegressionTests(unittest.TestCase):
predictor_table_path = os.path.join(SAMPLE_DATA, 'predictors.csv')
# make outputs to be overwritten
predictor_dict = recmodel_client.MODEL_SPEC.inputs.get(
predictor_dict = recmodel_client.MODEL_SPEC.get_input(
'predictor_table_path').get_validated_dataframe(
predictor_table_path).to_dict(orient='index')
predictor_list = predictor_dict.keys()

View File

@ -269,7 +269,7 @@ class TestDescribeArgFromSpec(unittest.TestCase):
expected_rst = (
'.. _carbon-pools-path-columns-lucode:\n\n' +
'**lucode** (`integer <input_types.html#integer>`__, *required*): ' +
carbon.MODEL_SPEC.inputs.get('carbon_pools_path').columns.get('lucode').about
carbon.MODEL_SPEC.get_input('carbon_pools_path').columns.get('lucode').about
)
self.assertEqual(repr(out), repr(expected_rst))
@ -318,7 +318,7 @@ class TestMetadataFromSpec(unittest.TestCase):
"""Test writing metadata for an invest output workspace."""
# An example invest output spec
output_spec = spec_utils.ModelOutputs(
output_spec = [
spec_utils.DirectoryOutput(
id='output',
contents=spec_utils.Contents(
@ -348,7 +348,7 @@ class TestMetadataFromSpec(unittest.TestCase):
spec_utils.build_output_spec('taskgraph_cache', spec_utils.TASKGRAPH_DIR)
)
)
)
]
# Generate an output workspace with real files, without
# running an invest model.
_generate_files_from_spec(output_spec, self.workspace_dir)
@ -362,7 +362,7 @@ class TestMetadataFromSpec(unittest.TestCase):
userguide='',
aliases=[],
ui_spec={},
inputs=spec_utils.ModelInputs(),
inputs={},
args_with_spatial_overlap={},
outputs=output_spec
)

View File

@ -68,13 +68,13 @@ class UsageLoggingTests(unittest.TestCase):
model_spec = spec_utils.ModelSpec(
model_id='', model_title='', userguide=None,
aliases=None, ui_spec=spec_utils.UISpec(order=[], hidden={}),
inputs=spec_utils.ModelInputs(
inputs=[
spec_utils.SingleBandRasterInput(id='raster', band=spec_utils.Input()),
spec_utils.VectorInput(id='vector', geometries={}, fields={}),
spec_utils.StringInput(id='not_a_gis_input'),
spec_utils.SingleBandRasterInput(id='blank_raster_path', band=spec_utils.Input()),
spec_utils.VectorInput(id='blank_vector_path', geometries={}, fields={})
),
],
outputs={},
args_with_spatial_overlap=None)

View File

@ -22,7 +22,6 @@ from natcap.invest import spec_utils
from natcap.invest.spec_utils import (
u,
ModelSpec,
ModelInputs,
UISpec,
Fields,
Contents,
@ -50,7 +49,7 @@ def ui_spec_with_defaults(order=[], hidden=[]):
return UISpec(order=order, hidden=hidden)
def model_spec_with_defaults(model_id='', model_title='', userguide='', aliases=None,
ui_spec=ui_spec_with_defaults(), inputs=ModelInputs(), outputs=set(),
ui_spec=ui_spec_with_defaults(), inputs={}, outputs={},
args_with_spatial_overlap=[]):
return ModelSpec(model_id=model_id, model_title=model_title, userguide=userguide,
aliases=aliases, ui_spec=ui_spec, inputs=inputs, outputs=outputs,
@ -278,8 +277,8 @@ class ValidatorTest(unittest.TestCase):
from natcap.invest import spec_utils
from natcap.invest import validation
args_spec = model_spec_with_defaults(inputs=ModelInputs(
spec_utils.build_input_spec('n_workers', spec_utils.N_WORKERS)))
args_spec = model_spec_with_defaults(inputs=[
spec_utils.build_input_spec('n_workers', spec_utils.N_WORKERS)])
@validation.invest_validator
def validate(args, limit_to=None):
@ -918,8 +917,7 @@ class CSVValidation(unittest.TestCase):
with open(path, 'w') as file:
file.write('1,2,3')
spec = model_spec_with_defaults(inputs=ModelInputs(
CSVInput(id="mock_csv_path")))
spec = model_spec_with_defaults(inputs=[CSVInput(id="mock_csv_path")])
# validate a mocked CSV that will take 6 seconds to return a value
args = {"mock_csv_path": path}
@ -1590,14 +1588,14 @@ class TestValidationFromSpec(unittest.TestCase):
"""Validation: check that conditional requirements works."""
from natcap.invest import validation
spec = model_spec_with_defaults(inputs=ModelInputs(
spec = model_spec_with_defaults(inputs=[
NumberInput(id="number_a"),
NumberInput(id="number_b", required=False),
NumberInput(id="number_c", required="number_b"),
NumberInput(id="number_d", required="number_b | number_c"),
NumberInput(id="number_e", required="number_b & number_d"),
NumberInput(id="number_f", required="not number_b")
))
])
args = {
"number_a": 123,
@ -1629,11 +1627,11 @@ class TestValidationFromSpec(unittest.TestCase):
"""Validation: check AssertionError if expression is missing a var."""
from natcap.invest import validation
spec = model_spec_with_defaults(inputs=ModelInputs(
spec = model_spec_with_defaults(inputs=[
NumberInput(id="number_a"),
NumberInput(id="number_b", required=False),
NumberInput(id="number_c", required="some_var_not_in_args")
))
])
args = {
"number_a": 123,
@ -1655,11 +1653,11 @@ class TestValidationFromSpec(unittest.TestCase):
with open(csv_b_path, 'w') as csv:
csv.write('1,2,3')
spec = model_spec_with_defaults(inputs=ModelInputs(
spec = model_spec_with_defaults(inputs=[
BooleanInput(id="condition", required=False),
CSVInput(id="csv_a", required="condition"),
CSVInput(id="csv_b", required="not condition")
))
])
args = {
"condition": True,
@ -1673,9 +1671,9 @@ class TestValidationFromSpec(unittest.TestCase):
def test_requirement_missing(self):
"""Validation: verify absolute requirement on missing key."""
from natcap.invest import validation
spec = model_spec_with_defaults(inputs=ModelInputs(
spec = model_spec_with_defaults(inputs=[
NumberInput(id="number_a", units=u.none)
))
])
args = {}
self.assertEqual(
[(['number_a'], validation.MESSAGES['MISSING_KEY'])],
@ -1684,9 +1682,9 @@ class TestValidationFromSpec(unittest.TestCase):
def test_requirement_no_value(self):
"""Validation: verify absolute requirement without value."""
from natcap.invest import validation
spec = model_spec_with_defaults(inputs=ModelInputs(
spec = model_spec_with_defaults(inputs=[
NumberInput(id="number_a", units=u.none)
))
])
args = {'number_a': ''}
self.assertEqual(
@ -1701,9 +1699,9 @@ class TestValidationFromSpec(unittest.TestCase):
def test_invalid_value(self):
"""Validation: verify invalidity."""
from natcap.invest import validation
spec = model_spec_with_defaults(inputs=ModelInputs(
spec = model_spec_with_defaults(inputs=[
NumberInput(id="number_a", units=u.none)
))
])
args = {'number_a': 'not a number'}
self.assertEqual(
@ -1714,9 +1712,9 @@ class TestValidationFromSpec(unittest.TestCase):
def test_conditionally_required_no_value(self):
"""Validation: verify conditional requirement when no value."""
from natcap.invest import validation
spec = model_spec_with_defaults(inputs=ModelInputs(
spec = model_spec_with_defaults(inputs=[
NumberInput(id="number_a", units=u.none),
StringInput(id="string_a", required="number_a")))
StringInput(id="string_a", required="number_a")])
args = {'string_a': None, "number_a": 1}
@ -1727,27 +1725,27 @@ class TestValidationFromSpec(unittest.TestCase):
def test_conditionally_required_invalid(self):
"""Validation: verify conditional validity behavior when invalid."""
from natcap.invest import validation
spec = model_spec_with_defaults(inputs=ModelInputs(
spec = model_spec_with_defaults(inputs=[
NumberInput(id="number_a", units=u.none),
OptionStringInput(
id="string_a",
required="number_a",
options=['AAA', 'BBB']
)
))
])
args = {'string_a': "ZZZ", "number_a": 1}
self.assertEqual(
[(['string_a'], validation.MESSAGES['INVALID_OPTION'].format(
option_list=spec.inputs.get('string_a').options))],
option_list=spec.get_input('string_a').options))],
validation.validate(args, spec))
def test_conditionally_required_vector_fields(self):
"""Validation: conditionally required vector fields."""
from natcap.invest import spec_utils
from natcap.invest import validation
spec = model_spec_with_defaults(inputs=ModelInputs(
spec = model_spec_with_defaults(inputs=[
NumberInput(
id="some_number",
expression="value > 0.5",
@ -1761,7 +1759,7 @@ class TestValidationFromSpec(unittest.TestCase):
RatioInput(id="field_b", required="some_number == 2")
)
)
))
])
def _create_vector(filepath, fields=[]):
gpkg_driver = gdal.GetDriverByName('GPKG')
@ -1808,7 +1806,7 @@ class TestValidationFromSpec(unittest.TestCase):
def test_conditionally_required_csv_columns(self):
"""Validation: conditionally required csv columns."""
from natcap.invest import validation
spec = model_spec_with_defaults(inputs=ModelInputs(
spec = model_spec_with_defaults(inputs=[
number_input_spec_with_defaults(
id="some_number",
expression="value > 0.5"
@ -1820,7 +1818,7 @@ class TestValidationFromSpec(unittest.TestCase):
RatioInput(id="field_b", required="some_number == 2")
)
)
))
])
# Create a CSV file with only field_a
csv_path = os.path.join(self.workspace_dir, 'table1.csv')
@ -1865,7 +1863,7 @@ class TestValidationFromSpec(unittest.TestCase):
"""Validation: conditionally required csv rows."""
from natcap.invest import validation
spec = model_spec_with_defaults(
inputs=ModelInputs(
inputs=[
number_input_spec_with_defaults(
id="some_number",
expression="value > 0.5"
@ -1882,7 +1880,7 @@ class TestValidationFromSpec(unittest.TestCase):
)
)
)
)
]
)
# Create a CSV file with only field_a
csv_path = os.path.join(self.workspace_dir, 'table1.csv')
@ -1924,9 +1922,9 @@ class TestValidationFromSpec(unittest.TestCase):
def test_validation_exception(self):
"""Validation: Verify error when an unexpected exception occurs."""
from natcap.invest import validation
spec = model_spec_with_defaults(inputs=ModelInputs(
spec = model_spec_with_defaults(inputs=[
NumberInput(id="number_a")
))
])
args = {'number_a': 1}
# Patch in a new function that raises an exception
@ -1941,7 +1939,7 @@ class TestValidationFromSpec(unittest.TestCase):
def test_conditionally_required_directory_contents(self):
"""Validation: conditionally required directory contents."""
from natcap.invest import validation
spec = model_spec_with_defaults(inputs=ModelInputs(
spec = model_spec_with_defaults(inputs=[
NumberInput(
id="some_number",
expression="value > 0.5",
@ -1960,7 +1958,7 @@ class TestValidationFromSpec(unittest.TestCase):
)
)
)
))
])
path_1 = os.path.join(self.workspace_dir, 'file.1')
with open(path_1, 'w') as my_file:
my_file.write('col1,col2')
@ -1991,9 +1989,9 @@ class TestValidationFromSpec(unittest.TestCase):
def test_validation_other(self):
"""Validation: verify no error when 'other' type."""
from natcap.invest import validation
spec = model_spec_with_defaults(inputs=ModelInputs(
spec = model_spec_with_defaults(inputs=[
OtherInput(id="number_a")
))
])
args = {'number_a': 1}
self.assertEqual([], validation.validate(args, spec))
@ -2015,7 +2013,7 @@ class TestValidationFromSpec(unittest.TestCase):
del args[previous_key] # delete the last addition to the dict.
spec = model_spec_with_defaults(inputs=ModelInputs(*specs))
spec = model_spec_with_defaults(inputs=specs)
self.assertEqual(
[(['arg_J'], validation.MESSAGES['MISSING_KEY'])],
validation.validate(args, spec))
@ -2025,7 +2023,7 @@ class TestValidationFromSpec(unittest.TestCase):
from natcap.invest import validation
spec = model_spec_with_defaults(
inputs=ModelInputs(
inputs=[
SingleBandRasterInput(
id='raster_a',
band=NumberInput(units=u.none)
@ -2039,7 +2037,7 @@ class TestValidationFromSpec(unittest.TestCase):
fields={},
geometries={'POINT'}
)
),
],
args_with_spatial_overlap={
'spatial_keys': ['raster_a', 'raster_b', 'vector_a'],
'different_projections_ok': True
@ -2095,7 +2093,7 @@ class TestValidationFromSpec(unittest.TestCase):
from natcap.invest import validation
spec = model_spec_with_defaults(
inputs=ModelInputs(
inputs=[
SingleBandRasterInput(
id='raster_a',
band=NumberInput(units=u.none)
@ -2104,7 +2102,7 @@ class TestValidationFromSpec(unittest.TestCase):
id='raster_b',
band=NumberInput(units=u.none)
)
),
],
args_with_spatial_overlap={
'spatial_keys': ['raster_a', 'raster_b'],
'different_projections_ok': True
@ -2139,7 +2137,7 @@ class TestValidationFromSpec(unittest.TestCase):
from natcap.invest import validation
spec = model_spec_with_defaults(
inputs=ModelInputs(
inputs=[
SingleBandRasterInput(
id='raster_a',
band=NumberInput(units=u.none)
@ -2155,7 +2153,7 @@ class TestValidationFromSpec(unittest.TestCase):
fields=Fields(),
geometries={'POINT'}
)
),
],
args_with_spatial_overlap={
'spatial_keys': ['raster_a', 'raster_b', 'vector_a'],
'different_projections_ok': True
@ -2205,8 +2203,7 @@ class TestValidationFromSpec(unittest.TestCase):
from natcap.invest import validation
args = {'a': 'a', 'b': 'b'}
spec = model_spec_with_defaults(inputs=ModelInputs(
StringInput(id='a')))
spec = model_spec_with_defaults(inputs=[StringInput(id='a')])
message = 'DEBUG:natcap.invest.validation:Provided key b does not exist in MODEL_SPEC'
with self.assertLogs('natcap.invest.validation', level='DEBUG') as cm:
@ -2224,8 +2221,7 @@ class TestValidationFromSpec(unittest.TestCase):
'e': '0.5', # middle
'f': '1' # upper bound
}
spec = model_spec_with_defaults(inputs=ModelInputs(
*(RatioInput(id=name) for name in args)))
spec = model_spec_with_defaults(inputs=[RatioInput(id=name) for name in args])
expected_warnings = [
(['a'], validation.MESSAGES['NOT_A_NUMBER'].format(value=args['a'])),
@ -2248,8 +2244,8 @@ class TestValidationFromSpec(unittest.TestCase):
'e': '55.5', # middle
'f': '100' # upper bound
}
spec = model_spec_with_defaults(inputs=ModelInputs(
*[PercentInput(id=name) for name in args]))
spec = model_spec_with_defaults(
inputs=[PercentInput(id=name) for name in args])
expected_warnings = [
(['a'], validation.MESSAGES['NOT_A_NUMBER'].format(value=args['a'])),
@ -2270,8 +2266,8 @@ class TestValidationFromSpec(unittest.TestCase):
'c': '-1', # negative integers are ok
'd': '0'
}
spec = model_spec_with_defaults(inputs=ModelInputs(
*[IntegerInput(id=name) for name in args]))
spec = model_spec_with_defaults(inputs=[
IntegerInput(id=name) for name in args])
expected_warnings = [
(['a'], validation.MESSAGES['NOT_A_NUMBER'].format(value=args['a'])),
@ -2300,12 +2296,12 @@ class TestArgsEnabled(unittest.TestCase):
def test_args_enabled(self):
"""Validation: test getting args enabled/disabled status."""
from natcap.invest import validation
spec = model_spec_with_defaults(inputs=ModelInputs(
spec = model_spec_with_defaults(inputs=[
Input(id='a'),
Input(id='b', allowed='a'),
Input(id='c', allowed='not a'),
Input(id='d', allowed='b <= 3')
))
])
args = {
'a': 'foo',
'b': 2,