remove ModelInputs and ModelOutputs classes
parent 18d8d10992
commit 960b2999c2
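The change is mechanical across call sites: `ModelSpec.inputs` becomes a plain list of `Input` objects instead of a `ModelInputs` container, and lookups move from `MODEL_SPEC.inputs.get(key)` to `MODEL_SPEC.get_input(key)`. A minimal runnable sketch of the new mechanism — the stub classes below stand in for the real types in natcap.invest.spec_utils, but the `__post_init__` and `get_input` bodies are taken verbatim from the spec_utils hunk further down this page:

import dataclasses
import typing


@dataclasses.dataclass
class Input:
    # stub for spec_utils.Input; only the id field matters for lookup
    id: str = ''


@dataclasses.dataclass
class ModelSpec:
    # inputs is now a plain iterable of Input objects
    inputs: typing.Iterable[Input]

    def __post_init__(self):
        # index the inputs by id once, so later lookups are O(1)
        self.inputs_dict = {_input.id: _input for _input in self.inputs}

    def get_input(self, key):
        return self.inputs_dict[key]


MODEL_SPEC = ModelSpec(inputs=[Input(id='biophysical_table_path')])
# before: MODEL_SPEC.inputs.get('biophysical_table_path')
# after:
assert MODEL_SPEC.get_input('biophysical_table_path').id == 'biophysical_table_path'
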
@@ -535,7 +535,7 @@ def execute(args):
     'Checking that watersheds have entries for every `ws_id` in the '
     'valuation table.')
 # Open/read in valuation parameters from CSV file
-valuation_df = MODEL_SPEC.inputs.get(
+valuation_df = MODEL_SPEC.get_input(
     'valuation_table_path').get_validated_dataframe(args['valuation_table_path'])
 watershed_vector = gdal.OpenEx(
     args['watersheds_path'], gdal.OF_VECTOR)
@@ -658,15 +658,15 @@ def execute(args):
     'lulc': pygeoprocessing.get_raster_info(clipped_lulc_path)['nodata'][0]}

 # Open/read in the csv file into a dictionary and add to arguments
-bio_df = MODEL_SPEC.inputs.get(
-    'biophysical_table_path').get_validated_dataframe(args['biophysical_table_path'])
+bio_df = MODEL_SPEC.get_input('biophysical_table_path').get_validated_dataframe(
+    args['biophysical_table_path'])

 bio_lucodes = set(bio_df.index.values)
 bio_lucodes.add(nodata_dict['lulc'])
 LOGGER.debug(f'bio_lucodes: {bio_lucodes}')

 if 'demand_table_path' in args and args['demand_table_path'] != '':
-    demand_df = MODEL_SPEC.inputs.get('demand_table_path').get_validated_dataframe(
+    demand_df = MODEL_SPEC.get_input('demand_table_path').get_validated_dataframe(
         args['demand_table_path'])
     demand_reclassify_dict = dict(
         [(lucode, row['demand']) for lucode, row in demand_df.iterrows()])
@@ -321,7 +321,7 @@ def execute(args):
     "Baseline LULC Year is earlier than the Alternate LULC Year."
 )

-carbon_pool_df = MODEL_SPEC.inputs.get(
+carbon_pool_df = MODEL_SPEC.get_input(
     'carbon_pools_path').get_validated_dataframe(args['carbon_pools_path'])

 try:
@@ -148,7 +148,7 @@ def export_to_python(target_filepath, model_id, args_dict=None):

 if args_dict is None:
     cast_args = {
-        key: '' for key in models.model_id_to_spec[model_id].inputs.__dict__.keys()}
+        key: '' for key in models.model_id_to_spec[model_id].inputs_dict.keys()}
 else:
     cast_args = dict((str(key), value) for (key, value)
                      in args_dict.items())
@@ -585,7 +585,7 @@ def execute(args):
 task_graph, n_workers, intermediate_dir, output_dir, suffix = (
     _set_up_workspace(args))

-snapshots = MODEL_SPEC.inputs.get(
+snapshots = MODEL_SPEC.get_input(
     'landcover_snapshot_csv').get_validated_dataframe(
         args['landcover_snapshot_csv'])['raster_path'].to_dict()
@@ -607,7 +607,7 @@ def execute(args):

 # We're assuming that the LULC initial variables and the carbon pool
 # transient table are combined into a single lookup table.
-biophysical_df = MODEL_SPEC.inputs.get(
+biophysical_df = MODEL_SPEC.get_input(
     'biophysical_table_path').get_validated_dataframe(
         args['biophysical_table_path'])
@@ -977,7 +977,7 @@ def execute(args):
 prices = None
 if args.get('do_economic_analysis', False): # Do if truthy
     if args.get('use_price_table', False):
-        prices = MODEL_SPEC.inputs.get(
+        prices = MODEL_SPEC.get_input(
             'price_table_path').get_validated_dataframe(
                 args['price_table_path'])['price'].to_dict()
     else:
@@ -1961,7 +1961,7 @@ def _read_transition_matrix(transition_csv_path, biophysical_df):
     landcover transition, and the second contains accumulation rates for
     the pool for the landcover transition.
     """
-    table = MODEL_SPEC.inputs.get(
+    table = MODEL_SPEC.get_input(
         'landcover_transitions_table').get_validated_dataframe(
             transition_csv_path).reset_index()
@@ -2183,7 +2183,7 @@ def validate(args, limit_to=None):

 if ("landcover_snapshot_csv" not in invalid_keys and
         "landcover_snapshot_csv" in sufficient_keys):
-    snapshots = MODEL_SPEC.inputs.get(
+    snapshots = MODEL_SPEC.get_input(
         'landcover_snapshot_csv').get_validated_dataframe(
             args['landcover_snapshot_csv']
         )['raster_path'].to_dict()
@@ -2205,7 +2205,7 @@ def validate(args, limit_to=None):
 # check for invalid options in the translation table
 if ("landcover_transitions_table" not in invalid_keys and
         "landcover_transitions_table" in sufficient_keys):
-    transitions_spec = MODEL_SPEC.inputs.get('landcover_transitions_table')
+    transitions_spec = MODEL_SPEC.get_input('landcover_transitions_table')
     transition_options = list(
         transitions_spec.columns.get('[LULC CODE]').options.keys())
     # lowercase options since utils call will lowercase table values
@@ -184,7 +184,7 @@ def execute(args):
     os.path.join(args['workspace_dir'], 'taskgraph_cache'),
     n_workers, reporting_interval=5.0)

-snapshots_dict = MODEL_SPEC.inputs.get(
+snapshots_dict = MODEL_SPEC.get_input(
     'landcover_snapshot_csv').get_validated_dataframe(
         args['landcover_snapshot_csv'])['raster_path'].to_dict()
@@ -216,7 +216,7 @@ def execute(args):
     target_path_list=aligned_snapshot_paths,
     task_name='Align input landcover rasters')

-landcover_df = MODEL_SPEC.inputs.get(
+landcover_df = MODEL_SPEC.get_input(
     'lulc_lookup_table_path').get_validated_dataframe(
         args['lulc_lookup_table_path'])
@@ -388,7 +388,7 @@ def _create_biophysical_table(landcover_df, target_biophysical_table_path):
 """
 target_column_names = [
     spec.id.lower() for spec in
-    coastal_blue_carbon.MODEL_SPEC.inputs.get('biophysical_table_path').columns]
+    coastal_blue_carbon.MODEL_SPEC.get_input('biophysical_table_path').columns]

 with open(target_biophysical_table_path, 'w') as bio_table:
     bio_table.write(f"{','.join(target_column_names)}\n")
@@ -2341,7 +2341,7 @@ def _schedule_habitat_tasks(
     list of pickle file path strings

 """
-habitat_dataframe = MODEL_SPEC.inputs.get(
+habitat_dataframe = MODEL_SPEC.get_input(
     'habitat_table_path').get_validated_dataframe(habitat_table_path
     ).rename(columns={'protection distance (m)': 'distance'})
@@ -614,7 +614,7 @@ def execute(args):
     None.

 """
-crop_to_landcover_df = MODEL_SPEC.inputs.get(
+crop_to_landcover_df = MODEL_SPEC.get_input(
     'landcover_to_crop_table_path').get_validated_dataframe(
         args['landcover_to_crop_table_path'])
@@ -696,7 +696,7 @@ def execute(args):
 climate_percentile_yield_table_path = os.path.join(
     args['model_data_path'],
     _CLIMATE_PERCENTILE_TABLE_PATTERN % crop_name)
-crop_climate_percentile_df = MODEL_SPEC.inputs.get(
+crop_climate_percentile_df = MODEL_SPEC.get_input(
     'model_data_path').contents.get(
     'climate_percentile_yield_tables').contents.get(
     '[CROP]_percentile_yield_table.csv').get_validated_dataframe(
@@ -853,7 +853,7 @@ def execute(args):

 # both 'crop_nutrient.csv' and 'crop' are known data/header values for
 # this model data.
-nutrient_df = MODEL_SPEC.inputs.get(
+nutrient_df = MODEL_SPEC.get_input(
     'model_data_path').contents.get(
     'crop_nutrient.csv').get_validated_dataframe(
     os.path.join(args['model_data_path'], 'crop_nutrient.csv'))
@@ -503,11 +503,11 @@ def execute(args):

 LOGGER.info(
     "Checking if the landcover raster is missing lucodes")
-crop_to_landcover_df = MODEL_SPEC.inputs.get(
+crop_to_landcover_df = MODEL_SPEC.get_input(
     'landcover_to_crop_table_path').get_validated_dataframe(
         args['landcover_to_crop_table_path'])

-crop_to_fertilization_rate_df = MODEL_SPEC.inputs.get(
+crop_to_fertilization_rate_df = MODEL_SPEC.get_input(
     'fertilization_rate_table_path').get_validated_dataframe(
         args['fertilization_rate_table_path'])
@@ -584,7 +584,7 @@ def execute(args):
     task_name='crop_climate_bin')
 dependent_task_list.append(crop_climate_bin_task)

-crop_regression_df = MODEL_SPEC.inputs.get('model_data_path').contents.get(
+crop_regression_df = MODEL_SPEC.get_input('model_data_path').contents.get(
     'climate_regression_yield_tables').contents.get(
     '[CROP]_regression_yield_table.csv').get_validated_dataframe(
     os.path.join(args['model_data_path'],
@@ -807,7 +807,7 @@ def execute(args):

 # both 'crop_nutrient.csv' and 'crop' are known data/header values for
 # this model data.
-nutrient_df = MODEL_SPEC.inputs.get('model_data_path').contents.get(
+nutrient_df = MODEL_SPEC.get_input('model_data_path').contents.get(
     'crop_nutrient.csv').get_validated_dataframe(
     os.path.join(args['model_data_path'], 'crop_nutrient.csv'))
@@ -239,7 +239,7 @@ def build_datastack_archive(args, model_id, datastack_path):
 if key not in module.MODEL_SPEC.inputs:
     LOGGER.info(f'Skipping arg {key}; not in model MODEL_SPEC')

-input_spec = module.MODEL_SPEC.inputs.get(key)
+input_spec = module.MODEL_SPEC.get_input(key)
 if input_spec.__class__ in file_based_types:
     if args[key] in {None, ''}:
         LOGGER.info(
@@ -425,7 +425,7 @@ def execute(args):
 # Map non-forest landcover codes to carbon biomasses
 LOGGER.info('Calculating direct mapped carbon stocks')
 carbon_maps = []
-biophysical_df = MODEL_SPEC.inputs.get(
+biophysical_df = MODEL_SPEC.get_input(
     'biophysical_table_path').get_validated_dataframe(
         args['biophysical_table_path'])
 pool_list = [('c_above', True)]
@@ -644,7 +644,7 @@ def _calculate_lulc_carbon_map(

 """
 # classify forest pixels from lulc
-biophysical_df = MODEL_SPEC.inputs.get(
+biophysical_df = MODEL_SPEC.get_input(
     'biophysical_table_path').get_validated_dataframe(biophysical_table_path)

 lucode_to_per_cell_carbon = {}
@@ -704,7 +704,7 @@ def _map_distance_from_tropical_forest_edge(

 """
 # Build a list of forest lucodes
-biophysical_df = MODEL_SPEC.inputs.get(
+biophysical_df = MODEL_SPEC.get_input(
     'biophysical_table_path').get_validated_dataframe(
         biophysical_table_path)
 forest_codes = biophysical_df[biophysical_df['is_tropical_forest']].index.values
@@ -1036,7 +1036,7 @@ def validate(args, limit_to=None):
 """
 model_spec = copy.deepcopy(MODEL_SPEC)
 if 'pools_to_calculate' in args and args['pools_to_calculate'] == 'all':
-    model_spec.inputs.get('biophysical_table_path').columns.get('c_below').required = True
-    model_spec.inputs.get('biophysical_table_path').columns.get('c_soil').required = True
-    model_spec.inputs.get('biophysical_table_path').columns.get('c_dead').required = True
+    model_spec.get_input('biophysical_table_path').columns.get('c_below').required = True
+    model_spec.get_input('biophysical_table_path').columns.get('c_soil').required = True
+    model_spec.get_input('biophysical_table_path').columns.get('c_dead').required = True
 return validation.validate(args, model_spec)
@@ -455,10 +455,10 @@ def execute(args):

 LOGGER.info("Checking Threat and Sensitivity tables for compliance")
 # Get CSVs as dictionaries and ensure the key is a string for threats.
-threat_df = MODEL_SPEC.inputs.get(
+threat_df = MODEL_SPEC.get_input(
     'threats_table_path').get_validated_dataframe(
         args['threats_table_path']).fillna('')
-sensitivity_df = MODEL_SPEC.inputs.get(
+sensitivity_df = MODEL_SPEC.get_input(
     'sensitivity_table_path').get_validated_dataframe(
         args['sensitivity_table_path'])
@@ -1181,10 +1181,10 @@ def validate(args, limit_to=None):
         "sensitivity_table_path" not in invalid_keys and
         "threat_raster_folder" not in invalid_keys):
     # Get CSVs as dictionaries and ensure the key is a string for threats.
-    threat_df = MODEL_SPEC.inputs.get(
+    threat_df = MODEL_SPEC.get_input(
         'threats_table_path').get_validated_dataframe(
             args['threats_table_path']).fillna('')
-    sensitivity_df = MODEL_SPEC.inputs.get(
+    sensitivity_df = MODEL_SPEC.get_input(
         'sensitivity_table_path').get_validated_dataframe(
             args['sensitivity_table_path'])
@@ -447,8 +447,8 @@ MODEL_SPEC = spec_utils.build_model_spec({
     }
 })

-_VALID_RISK_EQS = set(MODEL_SPEC.inputs.get('risk_eq').options.keys())
-_VALID_DECAY_TYPES = set(MODEL_SPEC.inputs.get('decay_eq').options.keys())
+_VALID_RISK_EQS = set(MODEL_SPEC.get_input('risk_eq').options.keys())
+_VALID_DECAY_TYPES = set(MODEL_SPEC.get_input('decay_eq').options.keys())


 def execute(args):
@@ -1791,7 +1791,7 @@ def _parse_info_table(info_table_path):
 info_table_path = os.path.abspath(info_table_path)

 try:
-    table = MODEL_SPEC.inputs.get(
+    table = MODEL_SPEC.get_input(
         'info_table_path').get_validated_dataframe(info_table_path)
 except ValueError as err:
     if 'Index has duplicate keys' in str(err):
@@ -624,7 +624,7 @@ def execute(args):
     if args['calc_' + nutrient_id]:
         nutrients_to_process.append(nutrient_id)

-biophysical_df = MODEL_SPEC.inputs.get(
+biophysical_df = MODEL_SPEC.get_input(
     'biophysical_table_path').get_validated_dataframe(
         args['biophysical_table_path'])
@@ -1298,11 +1298,11 @@ def validate(args, limit_to=None):

 for param in ['load', 'eff', 'crit_len']:
     for nutrient in nutrients_selected:
-        spec_copy.inputs.get('biophysical_table_path').columns.get(
+        spec_copy.get_input('biophysical_table_path').columns.get(
             f'{param}_{nutrient}').required = True

 if 'n' in nutrients_selected:
-    spec_copy.inputs.get('biophysical_table_path').columns.get(
+    spec_copy.get_input('biophysical_table_path').columns.get(
         'proportion_subsurface_n').required = True

 validation_warnings = validation.validate(args, spec_copy)
@@ -1221,7 +1221,7 @@ def _parse_scenario_variables(args):
 else:
     farm_vector_path = None

-guild_df = MODEL_SPEC.inputs.get(
+guild_df = MODEL_SPEC.get_input(
     'guild_table_path').get_validated_dataframe(guild_table_path)

 LOGGER.info('Checking to make sure guild table has all expected headers')
@@ -1233,7 +1233,7 @@ def _parse_scenario_variables(args):
     f"'{header}' but was unable to find one. Here are all the "
     f"headers from {guild_table_path}: {', '.join(guild_df.columns)}")

-landcover_biophysical_df = MODEL_SPEC.inputs.get(
+landcover_biophysical_df = MODEL_SPEC.get_input(
     'landcover_biophysical_table_path').get_validated_dataframe(
         landcover_biophysical_table_path)
 biophysical_table_headers = landcover_biophysical_df.columns
@@ -609,7 +609,7 @@ def execute(args):
 # Compute the regression
 coefficient_json_path = os.path.join(
     intermediate_dir, 'predictor_estimates.json')
-predictor_df = MODEL_SPEC.inputs.get(
+predictor_df = MODEL_SPEC.get_input(
     'predictor_table_path').get_validated_dataframe(
         args['predictor_table_path'])
 predictor_id_list = predictor_df.index
@@ -996,7 +996,7 @@ def _schedule_predictor_data_processing(
     'line_intersect_length': _line_intersect_length,
 }

-predictor_df = MODEL_SPEC.inputs.get(
+predictor_df = MODEL_SPEC.get_input(
     'predictor_table_path').get_validated_dataframe(predictor_table_path)
 predictor_task_list = []
 predictor_json_list = [] # tracks predictor files to add to gpkg
@@ -1765,7 +1765,7 @@ def _validate_same_id_lengths(table_path):
     string message if IDs are too long

 """
-predictor_df = MODEL_SPEC.inputs.get(
+predictor_df = MODEL_SPEC.get_input(
     'predictor_table_path').get_validated_dataframe(table_path)
 too_long = set()
 for p_id in predictor_df.index:
@@ -1794,11 +1794,11 @@ def _validate_same_ids_and_types(
     string message if any of the fields in 'id' and 'type' don't match
     between tables.
 """
-predictor_df = MODEL_SPEC.inputs.get(
+predictor_df = MODEL_SPEC.get_input(
     'predictor_table_path').get_validated_dataframe(
         predictor_table_path)

-scenario_predictor_df = MODEL_SPEC.inputs.get(
+scenario_predictor_df = MODEL_SPEC.get_input(
     'scenario_predictor_table_path').get_validated_dataframe(
         scenario_predictor_table_path)
@@ -1825,7 +1825,7 @@ def _validate_same_projection(base_vector_path, table_path):
 """
 # This will load the table as a list of paths which we can iterate through
 # without bothering the rest of the table structure
-data_paths = MODEL_SPEC.inputs.get(
+data_paths = MODEL_SPEC.get_input(
     'predictor_table_path').get_validated_dataframe(
         table_path)['path'].tolist()
@@ -1868,7 +1868,7 @@ def _validate_predictor_types(table_path):
     string message if any value in the ``type`` column does not match a
     valid type, ignoring leading/trailing whitespace.
 """
-df = MODEL_SPEC.inputs.get(
+df = MODEL_SPEC.get_input(
     'predictor_table_path').get_validated_dataframe(table_path)
 # ignore leading/trailing whitespace because it will be removed
 # when the type values are used
@@ -283,7 +283,7 @@ def execute(args):
     'b': float(args['b_coef']),
 }
 if (args['valuation_function'] not in
-        MODEL_SPEC.inputs.get('valuation_function').options):
+        MODEL_SPEC.get_input('valuation_function').options):
     raise ValueError('Valuation function type %s not recognized' %
                      args['valuation_function'])
 max_valuation_radius = float(args['max_valuation_radius'])
@@ -562,7 +562,7 @@ def execute(args):

 """
 file_suffix = utils.make_suffix_string(args, 'results_suffix')
-biophysical_df = MODEL_SPEC.inputs.get(
+biophysical_df = MODEL_SPEC.get_input(
     'biophysical_table_path').get_validated_dataframe(
         args['biophysical_table_path'])
@@ -625,17 +625,17 @@ def execute(args):
 # fail early on a missing required rain events table
 if (not args['user_defined_local_recharge'] and
         not args['user_defined_climate_zones']):
-    rain_events_df = MODEL_SPEC.inputs.get(
+    rain_events_df = MODEL_SPEC.get_input(
         'rain_events_table_path').get_validated_dataframe(
             args['rain_events_table_path'])

-biophysical_df = MODEL_SPEC.inputs.get(
+biophysical_df = MODEL_SPEC.get_input(
     'biophysical_table_path').get_validated_dataframe(
         args['biophysical_table_path'])

 if args['monthly_alpha']:
     # parse out the alpha lookup table of the form (month_id: alpha_val)
-    alpha_month_map = MODEL_SPEC.inputs.get(
+    alpha_month_map = MODEL_SPEC.get_input(
         'monthly_alpha_path').get_validated_dataframe(
             args['monthly_alpha_path'])['alpha'].to_dict()
 else:
@@ -814,7 +814,7 @@ def execute(args):
     'table_name': 'Climate Zone'}
 for month_id in range(N_MONTHS):
     if args['user_defined_climate_zones']:
-        cz_rain_events_df = MODEL_SPEC.inputs.get(
+        cz_rain_events_df = MODEL_SPEC.get_input(
             'climate_zone_table_path').get_validated_dataframe(
                 args['climate_zone_table_path'])
         climate_zone_rain_events_month = (
@@ -8,8 +8,8 @@ import queue
 import re
 import threading
 import types
+import typing
 import warnings
-from typing import Union, ClassVar

 from osgeo import gdal
 from osgeo import ogr
@@ -199,13 +199,6 @@ class IterableWithDotAccess():
     # else:
     #     raise StopIteration

-
-class ModelInputs(IterableWithDotAccess):
-    pass
-
-class ModelOutputs(IterableWithDotAccess):
-    pass
-
 class Rows(IterableWithDotAccess):
     pass
@@ -223,19 +216,19 @@ class Input:
     id: str = ''
     name: str = ''
    ​ about: str = ''
-    required: Union[bool, str] = True
-    allowed: Union[bool, str] = True
+    required: typing.Union[bool, str] = True
+    allowed: typing.Union[bool, str] = True

 @dataclasses.dataclass
 class Output:
     id: str = ''
     about: str = ''
-    created_if: Union[bool, str] = True
+    created_if: typing.Union[bool, str] = True

 @dataclasses.dataclass
 class FileInput(Input):
     permissions: str = 'r'
-    type: ClassVar[str] = 'file'
+    type: typing.ClassVar[str] = 'file'

     # @timeout
     def validate(self, filepath):
@@ -269,10 +262,10 @@ class FileInput(Input):

 @dataclasses.dataclass
 class SingleBandRasterInput(FileInput):
-    band: Union[Input, None] = None
-    projected: Union[bool, None] = None
-    projection_units: Union[pint.Unit, None] = None
-    type: ClassVar[str] = 'raster'
+    band: typing.Union[Input, None] = None
+    projected: typing.Union[bool, None] = None
+    projection_units: typing.Union[pint.Unit, None] = None
+    type: typing.ClassVar[str] = 'raster'

     # @timeout
     def validate(self, filepath):
@@ -313,10 +306,10 @@ class SingleBandRasterInput(FileInput):
 @dataclasses.dataclass
 class VectorInput(FileInput):
     geometries: set = dataclasses.field(default_factory=dict)
-    fields: Union[Fields, None] = None
-    projected: Union[bool, None] = None
-    projection_units: Union[pint.Unit, None] = None
-    type: ClassVar[str] = 'vector'
+    fields: typing.Union[Fields, None] = None
+    projected: typing.Union[bool, None] = None
+    projection_units: typing.Union[pint.Unit, None] = None
+    type: typing.ClassVar[str] = 'vector'

     # @timeout
     def validate(self, filepath):
@@ -397,12 +390,12 @@ class VectorInput(FileInput):

 @dataclasses.dataclass
 class RasterOrVectorInput(SingleBandRasterInput, VectorInput):
-    band: Union[Input, None] = None
+    band: typing.Union[Input, None] = None
     geometries: set = dataclasses.field(default_factory=dict)
-    fields: Union[Fields, None] = None
-    projected: Union[bool, None] = None
-    projection_units: Union[pint.Unit, None] = None
-    type: ClassVar[str] = 'raster_or_vector'
+    fields: typing.Union[Fields, None] = None
+    projected: typing.Union[bool, None] = None
+    projection_units: typing.Union[pint.Unit, None] = None
+    type: typing.ClassVar[str] = 'raster_or_vector'

     # @timeout
     def validate(self, filepath):
@@ -427,10 +420,10 @@ class RasterOrVectorInput(SingleBandRasterInput, VectorInput):

 @dataclasses.dataclass
 class CSVInput(FileInput):
-    columns: Union[Columns, None] = None
-    rows: Union[Rows, None] = None
-    index_col: Union[str, None] = None
-    type: ClassVar[str] = 'csv'
+    columns: typing.Union[Columns, None] = None
+    rows: typing.Union[Rows, None] = None
+    index_col: typing.Union[str, None] = None
+    type: typing.ClassVar[str] = 'csv'

     # @timeout
     def validate(self, filepath):
@@ -548,10 +541,10 @@ class CSVInput(FileInput):

 @dataclasses.dataclass
 class DirectoryInput(Input):
-    contents: Union[Contents, None] = None
+    contents: typing.Union[Contents, None] = None
     permissions: str = ''
     must_exist: bool = True
-    type: ClassVar[str] = 'directory'
+    type: typing.ClassVar[str] = 'directory'

     # @timeout
     def validate(self, dirpath):
@@ -619,9 +612,9 @@ class DirectoryInput(Input):

 @dataclasses.dataclass
 class NumberInput(Input):
-    units: Union[pint.Unit, None] = None
-    expression: Union[str, None] = None
-    type: ClassVar[str] = 'number'
+    units: typing.Union[pint.Unit, None] = None
+    expression: typing.Union[str, None] = None
+    type: typing.ClassVar[str] = 'number'

     def validate(self, value):
         """Validate numbers.
@@ -662,7 +655,7 @@ class NumberInput(Input):

 @dataclasses.dataclass
 class IntegerInput(Input):
-    type: ClassVar[str] = 'integer'
+    type: typing.ClassVar[str] = 'integer'

     def validate(self, value):
         """Validate an integer.
@@ -689,7 +682,7 @@ class IntegerInput(Input):

 @dataclasses.dataclass
 class RatioInput(Input):
-    type: ClassVar[str] = 'ratio'
+    type: typing.ClassVar[str] = 'ratio'

     def validate(self, value):
         """Validate a ratio (a proportion expressed as a value from 0 to 1).
@@ -717,7 +710,7 @@ class RatioInput(Input):

 @dataclasses.dataclass
 class PercentInput(Input):
-    type: ClassVar[str] = 'percent'
+    type: typing.ClassVar[str] = 'percent'

     def validate(self, value):
         """Validate a percent (a proportion expressed as a value from 0 to 100).
@@ -744,7 +737,7 @@ class PercentInput(Input):

 @dataclasses.dataclass
 class BooleanInput(Input):
-    type: ClassVar[str] = 'boolean'
+    type: typing.ClassVar[str] = 'boolean'

     def validate(self, value):
         """Validate a boolean value.
@@ -767,8 +760,8 @@ class BooleanInput(Input):

 @dataclasses.dataclass
 class StringInput(Input):
-    regexp: Union[str, None] = None
-    type: ClassVar[str] = 'string'
+    regexp: typing.Union[str, None] = None
+    type: typing.ClassVar[str] = 'string'

     def validate(self, value):
         """Validate an arbitrary string.
@@ -794,8 +787,8 @@ class StringInput(Input):

 @dataclasses.dataclass
 class OptionStringInput(Input):
-    options: Union[list, None] = None
-    type: ClassVar[str] = 'option_string'
+    options: typing.Union[list, None] = None
+    type: typing.ClassVar[str] = 'option_string'

     def validate(self, value):
         """Validate that a string is in a set of options.
@@ -828,26 +821,26 @@ class OtherInput(Input):

 @dataclasses.dataclass
 class SingleBandRasterOutput(Output):
-    band: Union[Input, None] = None
-    projected: Union[bool, None] = None
-    projection_units: Union[pint.Unit, None] = None
+    band: typing.Union[Input, None] = None
+    projected: typing.Union[bool, None] = None
+    projection_units: typing.Union[pint.Unit, None] = None

 @dataclasses.dataclass
 class VectorOutput(Output):
     geometries: set = dataclasses.field(default_factory=dict)
-    fields: Union[Fields, None] = None
-    projected: Union[bool, None] = None
-    projection_units: Union[pint.Unit, None] = None
+    fields: typing.Union[Fields, None] = None
+    projected: typing.Union[bool, None] = None
+    projection_units: typing.Union[pint.Unit, None] = None

 @dataclasses.dataclass
 class CSVOutput(Output):
-    columns: Union[Columns, None] = None
-    rows: Union[Rows, None] = None
-    index_col: Union[str, None] = None
+    columns: typing.Union[Columns, None] = None
+    rows: typing.Union[Rows, None] = None
+    index_col: typing.Union[str, None] = None

 @dataclasses.dataclass
 class DirectoryOutput(Output):
-    contents: Union[Contents, None] = None
+    contents: typing.Union[Contents, None] = None
     permissions: str = ''
     must_exist: bool = True
@@ -857,8 +850,8 @@ class FileOutput(Output):

 @dataclasses.dataclass
 class NumberOutput(Output):
-    units: Union[pint.Unit, None] = None
-    expression: Union[str, None] = None
+    units: typing.Union[pint.Unit, None] = None
+    expression: typing.Union[str, None] = None

 @dataclasses.dataclass
 class IntegerOutput(Output):
@@ -874,15 +867,15 @@ class PercentOutput(Output):

 @dataclasses.dataclass
 class StringOutput(Output):
-    regexp: Union[str, None] = None
+    regexp: typing.Union[str, None] = None

 @dataclasses.dataclass
 class OptionStringOutput(Output):
-    options: Union[list, None] = None
+    options: typing.Union[list, None] = None

 @dataclasses.dataclass
 class UISpec:
-    order: Union[list, None] = None
+    order: typing.Union[list, None] = None
     hidden: list = None
     dropdown_functions: dict = dataclasses.field(default_factory=dict)
@@ -892,21 +885,26 @@ class ModelSpec:
     model_id: str
     model_title: str
     userguide: str
-    aliases: set
     ui_spec: UISpec
-    inputs: ModelInputs
-    outputs: set
+    inputs: typing.Iterable[Input]
+    outputs: typing.Iterable[Output]
     args_with_spatial_overlap: dict
+    aliases: set = dataclasses.field(default_factory=set)
+
+    def __post_init__(self):
+        self.inputs_dict = {_input.id: _input for _input in self.inputs}
+        self.outputs_dict = {_output.id: _output for _output in self.outputs}
+
+    def get_input(self, key):
+        return self.inputs_dict[key]


 def build_model_spec(model_spec):
-    input_specs = [
+    inputs = [
         build_input_spec(argkey, argspec)
         for argkey, argspec in model_spec['args'].items()]
-    output_specs = [
+    outputs = [
         build_output_spec(argkey, argspec) for argkey, argspec in model_spec['outputs'].items()]
-    inputs = ModelInputs(*input_specs)
-    outputs = ModelOutputs(*output_specs)
     ui_spec = UISpec(
         order=model_spec['ui_spec']['order'],
         hidden=model_spec['ui_spec'].get('hidden', None),
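Besides `get_input`, the `inputs_dict` built in `__post_init__` above also replaces iteration over the old container's `__dict__` (see the cli.py and test hunks elsewhere on this page that switch to `spec.inputs_dict.keys()`). Continuing the stub sketch from the top of this page:

# enumerate all arg keys, as export_to_python now does via inputs_dict
spec = ModelSpec(inputs=[Input(id='foo'), Input(id='bar')])
cast_args = {key: '' for key in spec.inputs_dict.keys()}
assert cast_args == {'foo': '', 'bar': ''}
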
@@ -1682,8 +1680,8 @@ def describe_arg_from_name(module_name, *arg_keys):
 module = importlib.import_module(module_name)
 # start with the spec for all args
 # narrow down to the nested spec indicated by the sequence of arg keys
-spec = module.MODEL_SPEC.inputs
-for i, key in enumerate(arg_keys):
+spec = module.MODEL_SPEC.get_input(arg_keys[0])
+for i, key in enumerate(arg_keys[1:]):
     # convert raster band numbers to ints
     if arg_keys[i - 1] == 'bands':
         key = int(key)
@@ -497,7 +497,7 @@ def execute(args):
 # Build a lookup dictionary mapping each LULC code to its row
 # sort by the LULC codes upfront because we use the sorted list in multiple
 # places. it's more efficient to do this once.
-biophysical_df = MODEL_SPEC.inputs.get(
+biophysical_df = MODEL_SPEC.get_input(
     'biophysical_table').get_validated_dataframe(
         args['biophysical_table']).sort_index()
 sorted_lucodes = biophysical_df.index.to_list()
@@ -475,7 +475,7 @@ def execute(args):
 intermediate_dir = os.path.join(
     args['workspace_dir'], 'intermediate')
 utils.make_directories([args['workspace_dir'], intermediate_dir])
-biophysical_df = MODEL_SPEC.inputs.get(
+biophysical_df = MODEL_SPEC.get_input(
     'biophysical_table_path').get_validated_dataframe(
         args['biophysical_table_path'])
@@ -1160,7 +1160,7 @@ def calculate_energy_savings(
     for field in target_building_layer.schema]
 type_field_index = fieldnames.index('type')

-energy_consumption_df = MODEL_SPEC.inputs.get(
+energy_consumption_df = MODEL_SPEC.get_input(
     'energy_consumption_table_path').get_validated_dataframe(
         energy_consumption_table_path)
@@ -1541,7 +1541,7 @@ def validate(args, limit_to=None):
 invalid_keys = validation.get_invalid_keys(validation_warnings)
 if ('biophysical_table_path' not in invalid_keys and
         'cc_method' not in invalid_keys):
-    spec = copy.deepcopy(MODEL_SPEC.inputs.get('biophysical_table_path'))
+    spec = copy.deepcopy(MODEL_SPEC.get_input('biophysical_table_path'))
     if args['cc_method'] == 'factors':
         spec.columns.get('shade').required = True
         spec.columns.get('albedo').required = True
@@ -307,7 +307,7 @@ def execute(args):
     task_name='align raster stack')

 # Load CN table
-cn_df = MODEL_SPEC.inputs.get(
+cn_df = MODEL_SPEC.get_input(
     'curve_number_table_path').get_validated_dataframe(
         args['curve_number_table_path'])
@@ -636,7 +636,7 @@ def _calculate_damage_to_infrastructure_in_aoi(
 infrastructure_vector = gdal.OpenEx(structures_vector_path, gdal.OF_VECTOR)
 infrastructure_layer = infrastructure_vector.GetLayer()

-damage_type_map = MODEL_SPEC.inputs.get(
+damage_type_map = MODEL_SPEC.get_input(
     'infrastructure_damage_loss_table_path').get_validated_dataframe(
         structures_damage_table)['damage'].to_dict()
@@ -936,7 +936,7 @@ def validate(args, limit_to=None):
 if ("curve_number_table_path" not in invalid_keys and
         "curve_number_table_path" in sufficient_keys):
     # Load CN table. Resulting DF has index and CN_X columns only.
-    cn_df = MODEL_SPEC.inputs.get(
+    cn_df = MODEL_SPEC.get_input(
         'curve_number_table_path').get_validated_dataframe(
             args['curve_number_table_path'])
     # Check for NaN values.
@@ -944,7 +944,7 @@ def execute(args):
     aoi_reprojection_task, lulc_mask_task]
 )

-attr_table = MODEL_SPEC.inputs.get(
+attr_table = MODEL_SPEC.get_input(
     'lulc_attribute_table').get_validated_dataframe(args['lulc_attribute_table'])
 kernel_paths = {} # search_radius, kernel path
 kernel_tasks = {} # search_radius, kernel task
@@ -963,7 +963,7 @@ def execute(args):
     lucode_to_search_radii = list(
         urban_nature_attrs[['search_radius_m']].itertuples(name=None))
 elif args['search_radius_mode'] == RADIUS_OPT_POP_GROUP:
-    pop_group_table = MODEL_SPEC.inputs.get(
+    pop_group_table = MODEL_SPEC.get_input(
         'population_group_radii_table').get_validated_dataframe(
             args['population_group_radii_table'])
     search_radii = set(pop_group_table['search_radius_m'].unique())
@@ -971,7 +971,7 @@ def execute(args):
     search_radii_by_pop_group = pop_group_table['search_radius_m'].to_dict()
 else:
     valid_options = ', '.join(
-        MODEL_SPEC.inputs.get('search_radius_mode').options.keys())
+        MODEL_SPEC.get_input('search_radius_mode').options.keys())
     raise ValueError(
         "Invalid search radius mode provided: "
         f"{args['search_radius_mode']}; must be one of {valid_options}")
@@ -1843,7 +1843,7 @@ def _reclassify_urban_nature_area(
 Returns:
     ``None``
 """
-lulc_attribute_df = MODEL_SPEC.inputs.get(
+lulc_attribute_df = MODEL_SPEC.get_input(
     'lulc_attribute_table').get_validated_dataframe(lulc_attribute_table)

 squared_pixel_area = abs(
@@ -1876,9 +1876,9 @@ def _reclassify_urban_nature_area(
     target_datatype=gdal.GDT_Float32,
     target_nodata=FLOAT32_NODATA,
     error_details={
-        'raster_name': MODEL_SPEC.inputs.get('lulc_raster_path').name,
+        'raster_name': MODEL_SPEC.get_input('lulc_raster_path').name,
         'column_name': 'urban_nature',
-        'table_name': MODEL_SPEC.inputs.get('lulc_attribute_table').name
+        'table_name': MODEL_SPEC.get_input('lulc_attribute_table').name
     }
 )
@@ -120,10 +120,10 @@ def _calculate_args_bounding_box(args, model_spec):
 # should already have been validated so the path is either valid or
 # blank.
 spatial_info = None
-if (isinstance(model_spec.inputs.get(key),
+if (isinstance(model_spec.get_input(key),
         spec_utils.SingleBandRasterInput) and value.strip() != ''):
     spatial_info = pygeoprocessing.get_raster_info(value)
-elif (isinstance(model_spec.inputs.get(key),
+elif (isinstance(model_spec.get_input(key),
         spec_utils.VectorInput) and value.strip() != ''):
     spatial_info = pygeoprocessing.get_vector_info(value)
@@ -158,7 +158,7 @@ def _calculate_args_bounding_box(args, model_spec):
     LOGGER.exception(
         f'Error when transforming coordinates: {transform_error}')
 else:
-    LOGGER.debug(f'Arg {key} of type {model_spec.inputs.get(key).__class__} '
+    LOGGER.debug(f'Arg {key} of type {model_spec.get_input(key).__class__} '
                  'excluded from bounding box calculation')

 return bb_intersection, bb_union
@@ -303,7 +303,7 @@ def validate(args, spec):
 # we don't need to try to validate them
 try:
     # Using deepcopy to make sure we don't modify the original spec
-    parameter_spec = copy.deepcopy(spec.inputs.get(key))
+    parameter_spec = copy.deepcopy(spec.get_input(key))
 except KeyError:
     LOGGER.debug(f'Provided key {key} does not exist in MODEL_SPEC')
     continue
@@ -418,14 +418,12 @@ def invest_validator(validate_func):
     warnings_ = validate_func(args, limit_to)
     return warnings_

-args_spec = model_module.MODEL_SPEC.inputs
-
 if limit_to is None:
     LOGGER.info('Starting whole-model validation with MODEL_SPEC')
     warnings_ = validate_func(args)
 else:
     LOGGER.info('Starting single-input validation with MODEL_SPEC')
-    args_key_spec = args_spec.get(limit_to)
+    args_key_spec = model_module.MODEL_SPEC.get_input(limit_to)

     args_value = args[limit_to]
     error_msg = None
@@ -755,14 +755,14 @@ def execute(args):
 LOGGER.debug('Machine Performance Rows : %s', machine_perf_dict['periods'])
 LOGGER.debug('Machine Performance Cols : %s', machine_perf_dict['heights'])

-machine_param_dict = MODEL_SPEC.inputs.get(
+machine_param_dict = MODEL_SPEC.get_input(
     'machine_param_path').get_validated_dataframe(
         args['machine_param_path'])['value'].to_dict()

 # Check if required column fields are entered in the land grid csv file
 if 'land_gridPts_path' in args:
     # Create a grid_land_df dataframe for later use in valuation
-    grid_land_df = MODEL_SPEC.inputs.get(
+    grid_land_df = MODEL_SPEC.get_input(
         'land_gridPts_path').get_validated_dataframe(args['land_gridPts_path'])
     missing_grid_land_fields = []
     for field in ['id', 'type', 'lat', 'long', 'location']:
@@ -775,7 +775,7 @@ def execute(args):
     'Connection Points File: %s' % missing_grid_land_fields)

 if 'valuation_container' in args and args['valuation_container']:
-    machine_econ_dict = MODEL_SPEC.inputs.get(
+    machine_econ_dict = MODEL_SPEC.get_input(
         'machine_econ_path').get_validated_dataframe(
             args['machine_econ_path'])['value'].to_dict()
@@ -717,11 +717,11 @@ def execute(args):
 number_of_turbines = int(args['number_of_turbines'])

 # Read the biophysical turbine parameters into a dictionary
-turbine_dict = MODEL_SPEC.inputs.get(
+turbine_dict = MODEL_SPEC.get_input(
     'turbine_parameters_path').get_validated_dataframe(
         args['turbine_parameters_path']).iloc[0].to_dict()
 # Read the biophysical global parameters into a dictionary
-global_params_dict = MODEL_SPEC.inputs.get(
+global_params_dict = MODEL_SPEC.get_input(
     'global_wind_parameters_path').get_validated_dataframe(
         args['global_wind_parameters_path']).iloc[0].to_dict()
@@ -741,7 +741,7 @@ def execute(args):
 # If Price Table provided use that for price of energy, validate inputs
 time = parameters_dict['time_period']
 if args['price_table']:
-    wind_price_df = MODEL_SPEC.inputs.get(
+    wind_price_df = MODEL_SPEC.get_input(
         'wind_schedule').get_validated_dataframe(
             args['wind_schedule']).sort_index() # sort by year
@@ -1112,7 +1112,7 @@ def execute(args):
 LOGGER.info('Grid Points Provided. Reading in the grid points')

 # Read the grid points csv, and convert it to land and grid dictionary
-grid_land_df = MODEL_SPEC.inputs.get(
+grid_land_df = MODEL_SPEC.get_input(
     'grid_points_path').get_validated_dataframe(args['grid_points_path'])

 # Convert the dataframes to dictionaries, using 'ID' (the index) as key
@@ -1933,7 +1933,7 @@ def _compute_density_harvested_fields(

 # Read the wind energy data into a dictionary
 LOGGER.info('Reading in Wind Data into a dictionary')
-wind_point_df = MODEL_SPEC.inputs.get(
+wind_point_df = MODEL_SPEC.get_input(
     'wind_data_path').get_validated_dataframe(wind_data_path)
 wind_point_df.columns = wind_point_df.columns.str.upper()
 # Calculate scale value at new hub height given reference values.
@@ -2672,7 +2672,7 @@ def validate(args, limit_to=None):
         'global_wind_parameters_path' in valid_sufficient_keys):
     year_count = utils.read_csv_to_dataframe(
         args['wind_schedule']).shape[0]
-    time = MODEL_SPEC.inputs.get(
+    time = MODEL_SPEC.get_input(
         'global_wind_parameters_path').get_validated_dataframe(
             args['global_wind_parameters_path']).iloc[0]['time_period']
     if year_count != time + 1:
@@ -397,7 +397,7 @@ class CLIUnitTests(unittest.TestCase):
 target_model = models.model_id_to_pyname[target_model]
 model_module = importlib.import_module(name=target_model)
 spec = model_module.MODEL_SPEC
-expected_args = {key: '' for key in spec.inputs.__dict__.keys()}
+expected_args = {key: '' for key in spec.inputs_dict.keys()}

 module_name = str(uuid.uuid4()) + 'testscript'
 spec = importlib.util.spec_from_file_location(module_name, target_filepath)
@@ -211,7 +211,7 @@ class TestPreprocessor(unittest.TestCase):
     lulc_csv.write('0,mangrove,True\n')
     lulc_csv.write('1,parking lot,False\n')

-landcover_df = preprocessor.MODEL_SPEC.inputs.get(
+landcover_df = preprocessor.MODEL_SPEC.get_input(
     'lulc_lookup_table_path').get_validated_dataframe(landcover_table_path)

 target_table_path = os.path.join(self.workspace_dir,
@@ -227,7 +227,7 @@ class TestPreprocessor(unittest.TestCase):
     str(context.exception))

 # Re-load the landcover table
-landcover_df = preprocessor.MODEL_SPEC.inputs.get(
+landcover_df = preprocessor.MODEL_SPEC.get_input(
     'lulc_lookup_table_path').get_validated_dataframe(landcover_table_path)
 preprocessor._create_transition_table(
     landcover_df, [filename_a, filename_b], target_table_path)
@@ -640,7 +640,7 @@ class TestCBC2(unittest.TestCase):
 args = TestCBC2._create_model_args(self.workspace_dir)
 args['workspace_dir'] = os.path.join(self.workspace_dir, 'workspace')

-prior_snapshots = coastal_blue_carbon.MODEL_SPEC.inputs.get(
+prior_snapshots = coastal_blue_carbon.MODEL_SPEC.get_input(
     'landcover_snapshot_csv').get_validated_dataframe(
         args['landcover_snapshot_csv'])['raster_path'].to_dict()
 baseline_year = min(prior_snapshots.keys())
@@ -817,7 +817,7 @@ class TestCBC2(unittest.TestCase):
 args = TestCBC2._create_model_args(self.workspace_dir)
 args['workspace_dir'] = os.path.join(self.workspace_dir, 'workspace')

-prior_snapshots = coastal_blue_carbon.MODEL_SPEC.inputs.get(
+prior_snapshots = coastal_blue_carbon.MODEL_SPEC.get_input(
     'landcover_snapshot_csv').get_validated_dataframe(
         args['landcover_snapshot_csv'])['raster_path'].to_dict()
 baseline_year = min(prior_snapshots.keys())
@@ -876,7 +876,7 @@ class TestCBC2(unittest.TestCase):

 # Now work through the extra validation warnings.
 # test validation: invalid analysis year
-prior_snapshots = coastal_blue_carbon.MODEL_SPEC.inputs.get(
+prior_snapshots = coastal_blue_carbon.MODEL_SPEC.get_input(
     'landcover_snapshot_csv').get_validated_dataframe(
         args['landcover_snapshot_csv'])['raster_path'].to_dict()
 baseline_year = min(prior_snapshots)
@@ -1,7 +1,6 @@
-from types import SimpleNamespace
 from natcap.invest import spec_utils

-MODEL_SPEC = SimpleNamespace(inputs=spec_utils.ModelInputs(
+MODEL_SPEC = spec_utils.ModelSpec(inputs=[
     spec_utils.StringInput(id='blank'),
     spec_utils.IntegerInput(id='a'),
     spec_utils.StringInput(id='b'),
@@ -23,5 +22,11 @@ MODEL_SPEC = SimpleNamespace(inputs=spec_utils.ModelInputs(
             band=spec_utils.NumberInput()
         )
     )
 )
-))
+)],
+    outputs={},
+    model_id='',
+    model_title='',
+    userguide='',
+    ui_spec=spec_utils.UISpec(),
+    args_with_spatial_overlap={}
+)
@@ -1,7 +1,14 @@
-from types import SimpleNamespace
 from natcap.invest import spec_utils

-MODEL_SPEC = SimpleNamespace(inputs=spec_utils.ModelInputs(
-    spec_utils.FileInput(id='foo'),
-    spec_utils.FileInput(id='bar'),
-))
+MODEL_SPEC = spec_utils.ModelSpec(
+    inputs=[
+        spec_utils.FileInput(id='foo'),
+        spec_utils.FileInput(id='bar')
+    ],
+    outputs={},
+    model_id='',
+    model_title='',
+    userguide='',
+    ui_spec=spec_utils.UISpec(),
+    args_with_spatial_overlap={}
+)
@@ -1,9 +1,14 @@
-from types import SimpleNamespace
 from natcap.invest import spec_utils

-MODEL_SPEC = SimpleNamespace(inputs=spec_utils.ModelInputs(
+MODEL_SPEC = spec_utils.ModelSpec(inputs=[
     spec_utils.FileInput(id='some_file'),
     spec_utils.DirectoryInput(
         id='data_dir',
-        contents=spec_utils.Contents())
-))
+        contents=spec_utils.Contents())],
+    outputs={},
+    model_id='',
+    model_title='',
+    userguide='',
+    ui_spec=spec_utils.UISpec(),
+    args_with_spatial_overlap={}
+)
@@ -1,6 +1,11 @@
-from types import SimpleNamespace
 from natcap.invest import spec_utils

-MODEL_SPEC = SimpleNamespace(inputs=spec_utils.ModelInputs(
-    spec_utils.SingleBandRasterInput(id='raster', band=spec_utils.Input())
-))
+MODEL_SPEC = spec_utils.ModelSpec(inputs=[
+    spec_utils.SingleBandRasterInput(id='raster', band=spec_utils.Input())],
+    outputs={},
+    model_id='',
+    model_title='',
+    userguide='',
+    ui_spec=spec_utils.UISpec(),
+    args_with_spatial_overlap={}
+)
@@ -1,7 +1,6 @@
-from types import SimpleNamespace
 from natcap.invest import spec_utils

-MODEL_SPEC = SimpleNamespace(inputs=spec_utils.ModelInputs(
+MODEL_SPEC = spec_utils.ModelSpec(inputs=[
     spec_utils.IntegerInput(id='a'),
     spec_utils.StringInput(id='b'),
     spec_utils.StringInput(id='c'),
@@ -9,5 +8,11 @@ MODEL_SPEC = SimpleNamespace(inputs=spec_utils.ModelInputs(
     spec_utils.DirectoryInput(
         id='workspace_dir',
         contents=spec_utils.Contents()
     )
-))
+)],
+    outputs={},
+    model_id='',
+    model_title='',
+    userguide='',
+    ui_spec=spec_utils.UISpec(),
+    args_with_spatial_overlap={}
+)
@@ -1,7 +1,12 @@
 from types import SimpleNamespace
 from natcap.invest import spec_utils

-MODEL_SPEC = SimpleNamespace(inputs=spec_utils.ModelInputs(
+MODEL_SPEC = SimpleNamespace(inputs=[
     spec_utils.StringInput(id='foo'),
-    spec_utils.StringInput(id='bar')
-))
+    spec_utils.StringInput(id='bar')],
+    outputs={},
+    model_id='',
+    model_title='',
+    userguide='',
+    ui_spec=spec_utils.UISpec(),
+    args_with_spatial_overlap={}
+)
@@ -1,7 +1,12 @@
-from types import SimpleNamespace
 from natcap.invest import spec_utils

-MODEL_SPEC = SimpleNamespace(inputs=spec_utils.ModelInputs(
+MODEL_SPEC = spec_utils.ModelSpec(inputs=[
     spec_utils.VectorInput(
-        id='vector', fields={}, geometries={})
-))
+        id='vector', fields={}, geometries={})],
+    outputs={},
+    model_id='',
+    model_title='',
+    userguide='',
+    ui_spec=spec_utils.UISpec(),
+    args_with_spatial_overlap={}
+)
@@ -710,7 +710,7 @@ class TestRecClientServer(unittest.TestCase):
 out_regression_vector_path = os.path.join(
     args['workspace_dir'], f'regression_data_{suffix}.gpkg')

-predictor_df = recmodel_client.MODEL_SPEC.inputs.get(
+predictor_df = recmodel_client.MODEL_SPEC.get_input(
     'predictor_table_path').get_validated_dataframe(
         os.path.join(SAMPLE_DATA, 'predictors_all.csv'))
 field_list = list(predictor_df.index) + ['pr_TUD', 'pr_PUD', 'avg_pr_UD']
@@ -1284,7 +1284,7 @@ class RecreationClientRegressionTests(unittest.TestCase):
 predictor_table_path = os.path.join(SAMPLE_DATA, 'predictors.csv')

 # make outputs to be overwritten
-predictor_dict = recmodel_client.MODEL_SPEC.inputs.get(
+predictor_dict = recmodel_client.MODEL_SPEC.get_input(
     'predictor_table_path').get_validated_dataframe(
         predictor_table_path).to_dict(orient='index')
 predictor_list = predictor_dict.keys()
@@ -269,7 +269,7 @@ class TestDescribeArgFromSpec(unittest.TestCase):
 expected_rst = (
     '.. _carbon-pools-path-columns-lucode:\n\n' +
     '**lucode** (`integer <input_types.html#integer>`__, *required*): ' +
-    carbon.MODEL_SPEC.inputs.get('carbon_pools_path').columns.get('lucode').about
+    carbon.MODEL_SPEC.get_input('carbon_pools_path').columns.get('lucode').about
 )
 self.assertEqual(repr(out), repr(expected_rst))
@@ -318,7 +318,7 @@ class TestMetadataFromSpec(unittest.TestCase):
 """Test writing metadata for an invest output workspace."""

 # An example invest output spec
-output_spec = spec_utils.ModelOutputs(
+output_spec = [
     spec_utils.DirectoryOutput(
         id='output',
         contents=spec_utils.Contents(
@@ -348,7 +348,7 @@ class TestMetadataFromSpec(unittest.TestCase):
         spec_utils.build_output_spec('taskgraph_cache', spec_utils.TASKGRAPH_DIR)
         )
     )
-)
+]
 # Generate an output workspace with real files, without
 # running an invest model.
 _generate_files_from_spec(output_spec, self.workspace_dir)
@@ -362,7 +362,7 @@ class TestMetadataFromSpec(unittest.TestCase):
     userguide='',
     aliases=[],
     ui_spec={},
-    inputs=spec_utils.ModelInputs(),
+    inputs={},
     args_with_spatial_overlap={},
     outputs=output_spec
 )
@@ -68,13 +68,13 @@ class UsageLoggingTests(unittest.TestCase):
 model_spec = spec_utils.ModelSpec(
     model_id='', model_title='', userguide=None,
     aliases=None, ui_spec=spec_utils.UISpec(order=[], hidden={}),
-    inputs=spec_utils.ModelInputs(
+    inputs=[
         spec_utils.SingleBandRasterInput(id='raster', band=spec_utils.Input()),
         spec_utils.VectorInput(id='vector', geometries={}, fields={}),
         spec_utils.StringInput(id='not_a_gis_input'),
         spec_utils.SingleBandRasterInput(id='blank_raster_path', band=spec_utils.Input()),
         spec_utils.VectorInput(id='blank_vector_path', geometries={}, fields={})
-    ),
+    ],
     outputs={},
     args_with_spatial_overlap=None)
@@ -22,7 +22,6 @@ from natcap.invest import spec_utils
 from natcap.invest.spec_utils import (
     u,
     ModelSpec,
-    ModelInputs,
     UISpec,
     Fields,
     Contents,
@@ -50,7 +49,7 @@ def ui_spec_with_defaults(order=[], hidden=[]):
     return UISpec(order=order, hidden=hidden)

 def model_spec_with_defaults(model_id='', model_title='', userguide='', aliases=None,
-        ui_spec=ui_spec_with_defaults(), inputs=ModelInputs(), outputs=set(),
+        ui_spec=ui_spec_with_defaults(), inputs={}, outputs={},
         args_with_spatial_overlap=[]):
     return ModelSpec(model_id=model_id, model_title=model_title, userguide=userguide,
         aliases=aliases, ui_spec=ui_spec, inputs=inputs, outputs=outputs,
@@ -278,8 +277,8 @@ class ValidatorTest(unittest.TestCase):
 from natcap.invest import spec_utils
 from natcap.invest import validation

-args_spec = model_spec_with_defaults(inputs=ModelInputs(
-    spec_utils.build_input_spec('n_workers', spec_utils.N_WORKERS)))
+args_spec = model_spec_with_defaults(inputs=[
+    spec_utils.build_input_spec('n_workers', spec_utils.N_WORKERS)])

 @validation.invest_validator
 def validate(args, limit_to=None):
@@ -918,8 +917,8 @@ class CSVValidation(unittest.TestCase):
 with open(path, 'w') as file:
     file.write('1,2,3')

-spec = model_spec_with_defaults(inputs=ModelInputs(
-    CSVInput(id="mock_csv_path")))
+spec = model_spec_with_defaults(inputs=[CSVInput(id="mock_csv_path")])

 # validate a mocked CSV that will take 6 seconds to return a value
 args = {"mock_csv_path": path}
@@ -1590,14 +1588,14 @@ class TestValidationFromSpec(unittest.TestCase):
 """Validation: check that conditional requirements works."""
 from natcap.invest import validation

-spec = model_spec_with_defaults(inputs=ModelInputs(
+spec = model_spec_with_defaults(inputs=[
     NumberInput(id="number_a"),
     NumberInput(id="number_b", required=False),
     NumberInput(id="number_c", required="number_b"),
     NumberInput(id="number_d", required="number_b | number_c"),
     NumberInput(id="number_e", required="number_b & number_d"),
     NumberInput(id="number_f", required="not number_b")
-))
+])

 args = {
     "number_a": 123,
@@ -1629,11 +1627,11 @@ class TestValidationFromSpec(unittest.TestCase):
 """Validation: check AssertionError if expression is missing a var."""
 from natcap.invest import validation

-spec = model_spec_with_defaults(inputs=ModelInputs(
+spec = model_spec_with_defaults(inputs=[
     NumberInput(id="number_a"),
     NumberInput(id="number_b", required=False),
     NumberInput(id="number_c", required="some_var_not_in_args")
-))
+])

 args = {
     "number_a": 123,
@@ -1655,11 +1653,11 @@ class TestValidationFromSpec(unittest.TestCase):
 with open(csv_b_path, 'w') as csv:
     csv.write('1,2,3')

-spec = model_spec_with_defaults(inputs=ModelInputs(
+spec = model_spec_with_defaults(inputs=[
     BooleanInput(id="condition", required=False),
     CSVInput(id="csv_a", required="condition"),
     CSVInput(id="csv_b", required="not condition")
-))
+])

 args = {
     "condition": True,
@@ -1673,9 +1671,9 @@ class TestValidationFromSpec(unittest.TestCase):
 def test_requirement_missing(self):
     """Validation: verify absolute requirement on missing key."""
     from natcap.invest import validation
-    spec = model_spec_with_defaults(inputs=ModelInputs(
+    spec = model_spec_with_defaults(inputs=[
         NumberInput(id="number_a", units=u.none)
-    ))
+    ])
     args = {}
     self.assertEqual(
         [(['number_a'], validation.MESSAGES['MISSING_KEY'])],
@@ -1684,9 +1682,9 @@ class TestValidationFromSpec(unittest.TestCase):
 def test_requirement_no_value(self):
     """Validation: verify absolute requirement without value."""
     from natcap.invest import validation
-    spec = model_spec_with_defaults(inputs=ModelInputs(
+    spec = model_spec_with_defaults(inputs=[
         NumberInput(id="number_a", units=u.none)
-    ))
+    ])

     args = {'number_a': ''}
     self.assertEqual(
@@ -1701,9 +1699,9 @@ class TestValidationFromSpec(unittest.TestCase):
    def test_invalid_value(self):
        """Validation: verify invalidity."""
        from natcap.invest import validation
        spec = model_spec_with_defaults(inputs=ModelInputs(
        spec = model_spec_with_defaults(inputs=[
            NumberInput(id="number_a", units=u.none)
        ))
        ])

        args = {'number_a': 'not a number'}
        self.assertEqual(
@@ -1714,9 +1712,9 @@ class TestValidationFromSpec(unittest.TestCase):
    def test_conditionally_required_no_value(self):
        """Validation: verify conditional requirement when no value."""
        from natcap.invest import validation
        spec = model_spec_with_defaults(inputs=ModelInputs(
        spec = model_spec_with_defaults(inputs=[
            NumberInput(id="number_a", units=u.none),
            StringInput(id="string_a", required="number_a")))
            StringInput(id="string_a", required="number_a")])

        args = {'string_a': None, "number_a": 1}
@@ -1727,27 +1725,27 @@ class TestValidationFromSpec(unittest.TestCase):
    def test_conditionally_required_invalid(self):
        """Validation: verify conditional validity behavior when invalid."""
        from natcap.invest import validation
        spec = model_spec_with_defaults(inputs=ModelInputs(
        spec = model_spec_with_defaults(inputs=[
            NumberInput(id="number_a", units=u.none),
            OptionStringInput(
                id="string_a",
                required="number_a",
                options=['AAA', 'BBB']
            )
        ))
        ])

        args = {'string_a': "ZZZ", "number_a": 1}

        self.assertEqual(
            [(['string_a'], validation.MESSAGES['INVALID_OPTION'].format(
                option_list=spec.inputs.get('string_a').options))],
                option_list=spec.get_input('string_a').options))],
            validation.validate(args, spec))

    def test_conditionally_required_vector_fields(self):
        """Validation: conditionally required vector fields."""
        from natcap.invest import spec_utils
        from natcap.invest import validation
        spec = model_spec_with_defaults(inputs=ModelInputs(
        spec = model_spec_with_defaults(inputs=[
            NumberInput(
                id="some_number",
                expression="value > 0.5",
@@ -1761,7 +1759,7 @@ class TestValidationFromSpec(unittest.TestCase):
                    RatioInput(id="field_b", required="some_number == 2")
                )
            )
        ))
        ])

        def _create_vector(filepath, fields=[]):
            gpkg_driver = gdal.GetDriverByName('GPKG')
@@ -1808,7 +1806,7 @@ class TestValidationFromSpec(unittest.TestCase):
    def test_conditionally_required_csv_columns(self):
        """Validation: conditionally required csv columns."""
        from natcap.invest import validation
        spec = model_spec_with_defaults(inputs=ModelInputs(
        spec = model_spec_with_defaults(inputs=[
            number_input_spec_with_defaults(
                id="some_number",
                expression="value > 0.5"
@@ -1820,7 +1818,7 @@ class TestValidationFromSpec(unittest.TestCase):
                    RatioInput(id="field_b", required="some_number == 2")
                )
            )
        ))
        ])

        # Create a CSV file with only field_a
        csv_path = os.path.join(self.workspace_dir, 'table1.csv')
@@ -1865,7 +1863,7 @@ class TestValidationFromSpec(unittest.TestCase):
        """Validation: conditionally required csv rows."""
        from natcap.invest import validation
        spec = model_spec_with_defaults(
            inputs=ModelInputs(
            inputs=[
                number_input_spec_with_defaults(
                    id="some_number",
                    expression="value > 0.5"
@@ -1882,7 +1880,7 @@ class TestValidationFromSpec(unittest.TestCase):
                        )
                    )
                )
            )
            ]
        )
        # Create a CSV file with only field_a
        csv_path = os.path.join(self.workspace_dir, 'table1.csv')
@@ -1924,9 +1922,9 @@ class TestValidationFromSpec(unittest.TestCase):
    def test_validation_exception(self):
        """Validation: Verify error when an unexpected exception occurs."""
        from natcap.invest import validation
        spec = model_spec_with_defaults(inputs=ModelInputs(
        spec = model_spec_with_defaults(inputs=[
            NumberInput(id="number_a")
        ))
        ])
        args = {'number_a': 1}

        # Patch in a new function that raises an exception
@@ -1941,7 +1939,7 @@ class TestValidationFromSpec(unittest.TestCase):
    def test_conditionally_required_directory_contents(self):
        """Validation: conditionally required directory contents."""
        from natcap.invest import validation
        spec = model_spec_with_defaults(inputs=ModelInputs(
        spec = model_spec_with_defaults(inputs=[
            NumberInput(
                id="some_number",
                expression="value > 0.5",
@@ -1960,7 +1958,7 @@ class TestValidationFromSpec(unittest.TestCase):
                    )
                )
            )
        ))
        ])
        path_1 = os.path.join(self.workspace_dir, 'file.1')
        with open(path_1, 'w') as my_file:
            my_file.write('col1,col2')
@@ -1991,9 +1989,9 @@ class TestValidationFromSpec(unittest.TestCase):
    def test_validation_other(self):
        """Validation: verify no error when 'other' type."""
        from natcap.invest import validation
        spec = model_spec_with_defaults(inputs=ModelInputs(
        spec = model_spec_with_defaults(inputs=[
            OtherInput(id="number_a")
        ))
        ])
        args = {'number_a': 1}
        self.assertEqual([], validation.validate(args, spec))
@@ -2015,7 +2013,7 @@ class TestValidationFromSpec(unittest.TestCase):

            del args[previous_key] # delete the last addition to the dict.

        spec = model_spec_with_defaults(inputs=ModelInputs(*specs))
        spec = model_spec_with_defaults(inputs=specs)
        self.assertEqual(
            [(['arg_J'], validation.MESSAGES['MISSING_KEY'])],
            validation.validate(args, spec))
@@ -2025,7 +2023,7 @@ class TestValidationFromSpec(unittest.TestCase):
        from natcap.invest import validation

        spec = model_spec_with_defaults(
            inputs=ModelInputs(
            inputs=[
                SingleBandRasterInput(
                    id='raster_a',
                    band=NumberInput(units=u.none)
@@ -2039,7 +2037,7 @@ class TestValidationFromSpec(unittest.TestCase):
                    fields={},
                    geometries={'POINT'}
                )
            ),
            ],
            args_with_spatial_overlap={
                'spatial_keys': ['raster_a', 'raster_b', 'vector_a'],
                'different_projections_ok': True
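The args_with_spatial_overlap block above asks validation to confirm that the listed spatial inputs cover a common area. A minimal sketch of the underlying idea (intersecting bounding boxes in a shared projection); this helper is hypothetical, not the library's code:

    def boxes_intersect(bounding_boxes):
        # Each box is (xmin, ymin, xmax, ymax) in one shared projection.
        xmin = max(box[0] for box in bounding_boxes)
        ymin = max(box[1] for box in bounding_boxes)
        xmax = min(box[2] for box in bounding_boxes)
        ymax = min(box[3] for box in bounding_boxes)
        # A non-empty intersection remains only if the clipped box is valid.
        return xmin < xmax and ymin < ymax

    # e.g. boxes_intersect([(0, 0, 2, 2), (1, 1, 3, 3)]) -> True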
@@ -2095,7 +2093,7 @@ class TestValidationFromSpec(unittest.TestCase):
        from natcap.invest import validation

        spec = model_spec_with_defaults(
            inputs=ModelInputs(
            inputs=[
                SingleBandRasterInput(
                    id='raster_a',
                    band=NumberInput(units=u.none)
@@ -2104,7 +2102,7 @@ class TestValidationFromSpec(unittest.TestCase):
                    id='raster_b',
                    band=NumberInput(units=u.none)
                )
            ),
            ],
            args_with_spatial_overlap={
                'spatial_keys': ['raster_a', 'raster_b'],
                'different_projections_ok': True
@@ -2139,7 +2137,7 @@ class TestValidationFromSpec(unittest.TestCase):
        from natcap.invest import validation

        spec = model_spec_with_defaults(
            inputs=ModelInputs(
            inputs=[
                SingleBandRasterInput(
                    id='raster_a',
                    band=NumberInput(units=u.none)
@@ -2155,7 +2153,7 @@ class TestValidationFromSpec(unittest.TestCase):
                    fields=Fields(),
                    geometries={'POINT'}
                )
            ),
            ],
            args_with_spatial_overlap={
                'spatial_keys': ['raster_a', 'raster_b', 'vector_a'],
                'different_projections_ok': True
@@ -2205,8 +2203,7 @@ class TestValidationFromSpec(unittest.TestCase):
        from natcap.invest import validation

        args = {'a': 'a', 'b': 'b'}
        spec = model_spec_with_defaults(inputs=ModelInputs(
            StringInput(id='a')))
        spec = model_spec_with_defaults(inputs=[StringInput(id='a')])
        message = 'DEBUG:natcap.invest.validation:Provided key b does not exist in MODEL_SPEC'

        with self.assertLogs('natcap.invest.validation', level='DEBUG') as cm:
@@ -2224,8 +2221,7 @@ class TestValidationFromSpec(unittest.TestCase):
            'e': '0.5', # middle
            'f': '1' # upper bound
        }
        spec = model_spec_with_defaults(inputs=ModelInputs(
            *(RatioInput(id=name) for name in args)))
        spec = model_spec_with_defaults(inputs=[RatioInput(id=name) for name in args])

        expected_warnings = [
            (['a'], validation.MESSAGES['NOT_A_NUMBER'].format(value=args['a'])),
@@ -2248,8 +2244,8 @@ class TestValidationFromSpec(unittest.TestCase):
            'e': '55.5', # middle
            'f': '100' # upper bound
        }
        spec = model_spec_with_defaults(inputs=ModelInputs(
            *[PercentInput(id=name) for name in args]))
        spec = model_spec_with_defaults(
            inputs=[PercentInput(id=name) for name in args])

        expected_warnings = [
            (['a'], validation.MESSAGES['NOT_A_NUMBER'].format(value=args['a'])),
@@ -2270,8 +2266,8 @@ class TestValidationFromSpec(unittest.TestCase):
            'c': '-1', # negative integers are ok
            'd': '0'
        }
        spec = model_spec_with_defaults(inputs=ModelInputs(
            *[IntegerInput(id=name) for name in args]))
        spec = model_spec_with_defaults(inputs=[
            IntegerInput(id=name) for name in args])

        expected_warnings = [
            (['a'], validation.MESSAGES['NOT_A_NUMBER'].format(value=args['a'])),
@@ -2300,12 +2296,12 @@ class TestArgsEnabled(unittest.TestCase):
    def test_args_enabled(self):
        """Validation: test getting args enabled/disabled status."""
        from natcap.invest import validation
        spec = model_spec_with_defaults(inputs=ModelInputs(
        spec = model_spec_with_defaults(inputs=[
            Input(id='a'),
            Input(id='b', allowed='a'),
            Input(id='c', allowed='not a'),
            Input(id='d', allowed='b <= 3')
        ))
        ])
        args = {
            'a': 'foo',
            'b': 2,
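The allowed='b <= 3' expressions above gate whether an arg is enabled based on other arg values, so comparisons (not just truthiness) must be supported. A rough sketch under that assumption; the is_enabled helper is hypothetical and the real evaluation logic may differ:

    def is_enabled(expression, args):
        # Substitute raw arg values so comparisons like `b <= 3` work;
        # missing or empty args are treated as False.
        scope = {key: (value if value not in (None, '') else False)
                 for key, value in args.items()}
        return bool(eval(expression, {'__builtins__': {}}, scope))

    # e.g. is_enabled('b <= 3', {'a': 'foo', 'b': 2}) -> True
    #      is_enabled('not a', {'a': 'foo', 'b': 2}) -> False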