Merge pull request #1954 from natcap/release/3.16.0a1

3.16.0a1 alpha release
Authored by James Douglass on 2025-05-27 15:10:27 -07:00; committed by GitHub
commit 985e3d0985
121 changed files with 9196 additions and 6829 deletions


@ -412,6 +412,22 @@ jobs:
yarn config set network-timeout 600000 -g
yarn install
- name: Download micromamba for distribution (MacOS)
if: matrix.os == 'macos-13'
run: |
curl -Ls https://micro.mamba.pm/api/micromamba/osx-64/latest | tar -xvj bin/micromamba
mv bin/micromamba dist/
./dist/micromamba --help # make sure the executable works
- name: Download micromamba for distribution (Windows)
if: matrix.os == 'windows-latest'
shell: pwsh
run: |
Invoke-Webrequest -URI https://micro.mamba.pm/api/micromamba/win-64/latest -OutFile micromamba.tar.bz2
tar xf micromamba.tar.bz2
MOVE -Force Library\bin\micromamba.exe dist\micromamba.exe
.\dist\micromamba.exe --help # make sure the executable works
- name: Authenticate GCP
if: github.event_name != 'pull_request'
uses: google-github-actions/auth@v2


@ -42,7 +42,11 @@ jobs:
- uses: actions/checkout@v4
- name: Install dependencies
run: pip install twine
run: |
# Workaround for setuptools generating invalid metadata
# https://github.com/natcap/invest/issues/1913
pip install "twine>=6.1.0"
pip install -U packaging
- name: Extract version from autorelease branch name
if: ${{ github.head_ref }}


@ -23,9 +23,9 @@
- Urban Flood Risk
- Urban Nature Access
- Urban Stormwater Retention
- Visitation: Recreation and Tourism
- Wave Energy
- Wind Energy
- Visitation: Recreation and Tourism
Workbench fixes/enhancements:
- Workbench
@ -64,6 +64,10 @@
Unreleased Changes
------------------
General
=======
* The workbench and the natcap.invest python package now support plugins.
Workbench
=========
* Metadata is now generated for files when creating a datastack (with any
@ -88,6 +92,49 @@ NDR
pixels were mistakenly used as real data. The effect of this change is most
pronounced along stream edges and should not affect the overall pattern of
results. (`#1845 <https://github.com/natcap/invest/issues/1845>`_)
* ``stream.tif`` is now saved in the main output folder rather than the
intermediate folder (`#1864 <https://github.com/natcap/invest/issues/1864>`_).
* Added a feature that allows the nutrient load to be entered as an
application rate or as an "extensive"/export measured value.
Previously, the model's biophysical table expected the ``load_[n|p]``
column to be an "extensive"/export measured value. Now, a new
column for both nitrogen and phosphorus, ``load_type_[n|p]``, is
required with expected values of either ``application-rate`` or
``measured-runoff``. See the Data Needs section of the NDR User
Guide for more details; a minimal example table is sketched after
these NDR entries.
(`#1044 <https://github.com/natcap/invest/issues/1044>`_).
* Fixed a bug where input rasters (e.g. LULC) without a defined nodata value could
cause an OverflowError. (`#1904 <https://github.com/natcap/invest/issues/1904>`_).
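For illustration only, a minimal sketch of a biophysical table using the new ``load_type_[n|p]`` columns described above; the numeric values are placeholders (not recommendations) and other required NDR columns are omitted:

import pandas as pd

# Hypothetical values; other required NDR biophysical columns are omitted here.
biophysical = pd.DataFrame([
    {"lucode": 1, "load_n": 12.0, "load_type_n": "application-rate",
     "load_p": 1.5, "load_type_p": "measured-runoff"},
    {"lucode": 2, "load_n": 4.0, "load_type_n": "measured-runoff",
     "load_p": 0.3, "load_type_p": "measured-runoff"},
])
biophysical.to_csv("ndr_biophysical_table.csv", index=False)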
Seasonal Water Yield
====================
* ``stream.tif`` is now saved in the main output folder rather than the
intermediate folder (`#1864 <https://github.com/natcap/invest/issues/1864>`_).
Urban Flood Risk
================
* The raster output ``Runoff_retention.tif`` has been renamed
``Runoff_retention_index.tif`` to clarify the difference between it and
``Runoff_retention_m3.tif``
(`#1837 <https://github.com/natcap/invest/issues/1837>`_).
Visitation: Recreation and Tourism
==================================
* User-day variables ``pr_PUD``, ``pr_TUD``, and ``avg_pr_UD`` are calculated
and written to ``regression_data.gpkg`` even if the Compute Regression
option is not selected.
(`#1893 <https://github.com/natcap/invest/issues/1893>`_).
Wind Energy
===========
* The model no longer returns results as rasters; instead, values are
written to the output ``wind_energy_points`` shapefile for each point
(`#1698 <https://github.com/natcap/invest/issues/1698>`_).
Architectural Decision Record (ADR): `ADR-0004: Remove Wind Energy Raster Outputs <https://github.com/natcap/invest/blob/main/doc/decision-records/ADR-0004-Remove-Wind-Energy-Raster-Outputs.md>`_
* The output ``wind_energy_points.shp`` no longer returns Harvested or
Valuation-related values for points that are invalid wind farm locations
due to depth or distance constraints
(`#1699 <https://github.com/natcap/invest/issues/1699>`_).
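As a hedged illustration of the Wind Energy change noted above, per-point results can now be read directly from the output vector rather than from rasters; the workspace path below is hypothetical and no particular field names are assumed:

from osgeo import ogr

# Hypothetical path; point this at the model's actual workspace output.
vector = ogr.Open("workspace/output/wind_energy_points.shp")
layer = vector.GetLayer()
for feature in layer:
    # Print every field the model wrote for this point, whatever it is named.
    print({name: feature.GetField(name) for name in feature.keys()})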
3.15.1 (2025-05-06)


@ -2,15 +2,15 @@
DATA_DIR := data
GIT_SAMPLE_DATA_REPO := https://bitbucket.org/natcap/invest-sample-data.git
GIT_SAMPLE_DATA_REPO_PATH := $(DATA_DIR)/invest-sample-data
GIT_SAMPLE_DATA_REPO_REV := ecdab62bd6e2d3d9105e511cfd6884bf07f3d27b
GIT_SAMPLE_DATA_REPO_REV := 3d7a33c3d599daaec087a9c283a0c6b8377210f5
GIT_TEST_DATA_REPO := https://bitbucket.org/natcap/invest-test-data.git
GIT_TEST_DATA_REPO_PATH := $(DATA_DIR)/invest-test-data
GIT_TEST_DATA_REPO_REV := f0ebe739207ae57ae53a285d0fd954d6e8cfee54
GIT_TEST_DATA_REPO_REV := d8a7397ba5992de7a80260e40956e4c29176383d
GIT_UG_REPO := https://github.com/natcap/invest.users-guide
GIT_UG_REPO_PATH := doc/users-guide
GIT_UG_REPO_REV := 7d83c5bf05f0bef8dd4d2a4bd2f565ecf270af75
GIT_UG_REPO_REV := e3f9d7b0ac78d948a532ffaa6d71dc464cf41403
ENV = "./env"
ifeq ($(OS),Windows_NT)


@ -20,6 +20,7 @@ import sys
from unittest.mock import MagicMock
import natcap.invest
from natcap.invest import models
from sphinx.ext import apidoc
DOCS_SOURCE_DIR = os.path.dirname(__file__)
@ -194,26 +195,16 @@ see :ref:`CreatingPythonScripts`.
:local:
"""
MODEL_ENTRYPOINTS_FILE = os.path.join(DOCS_SOURCE_DIR, 'models.rst')
# Find all importable modules with an execute function
# write out to a file models.rst in the source directory
invest_model_modules = {}
for _, name, _ in pkgutil.walk_packages(path=[INVEST_LIB_DIR],
prefix='natcap.'):
module = importlib.import_module(name)
# any module with a MODEL_SPEC is an invest model
if hasattr(module, 'MODEL_SPEC'):
model_title = module.MODEL_SPEC['model_name']
invest_model_modules[model_title] = name
# Write sphinx autodoc function for each entrypoint
with open(MODEL_ENTRYPOINTS_FILE, 'w') as models_rst:
models_rst.write(MODEL_RST_TEMPLATE)
for model_title, name in sorted(invest_model_modules.items()):
for model_id, pyname in sorted(models.model_id_to_pyname.items()):
model_title = models.model_id_to_spec[model_id].model_title
underline = ''.join(['=']*len(model_title))
models_rst.write(
f'{model_title}\n'
f'{underline}\n'
f'.. autofunction:: {name}.execute\n'
f'.. autofunction:: {pyname}.execute\n'
' :noindex:\n\n')


@ -0,0 +1,44 @@
# ADR-0004: Remove Wind Energy Raster Outputs
Author: Megan Nissel
Science Lead: Rob Griffin
## Context
The Wind Energy model has three major data inputs required for all runs: a Wind Data Points CSV, containing Weibull parameters for each wind data point; a Bathymetry raster; and a CSV of global wind energy infrastructure parameters. Within the wind data points CSV, each row represents a discrete geographic coordinate point. During the model run, this CSV gets converted to a point vector and then the data are interpolated onto rasters.
When run without the valuation component, the model outputs the following:
- `density_W_per_m2.tif`: a raster representing power density (W/m^2) centered on a pixel.
- `harvested_energy_MWhr_per_yr.tif`: a raster representing the annual harvested energy from a farm centered on that pixel.
- `wind_energy_points.shp`: a vector (with points corresponding to those in the input Wind Energy points CSV) that summarizes the outputs of the two rasters.
When run with the valuation component, the model outputs three rasters in addition to the two listed above: `carbon_emissions_tons.tif`, `levelized_cost_price_per_kWh.tif`, and `npv.tif`. These values are not currently summarized in `wind_energy_points.shp`.
Users noticed the raster outputs included data in areas outside of those covered by the input Wind Data, resulting from the model's method of interpolating the vector data to the rasters. This led to a larger discussion around the validity of the interpolated raster results.
## Decision
Based on Rob's own use of the model, and review and evaluation of the problem, the consensus is that the model's current use of interpolation introduces too many potential violations of the constraints of the model (e.g. interpolating over areas that are invalid due to ocean depth or distance from shore, or are outside of the areas included in the input wind speed data) and requires assumptions that may not be helpful for users. Rob therefore recommended removing the raster outputs entirely and retaining the associated values in the output `wind_energy_points.shp` vector.
As such, we have decided to move forward with removing the rasterized outputs:
- `carbon_emissions_tons.tif`
- `density_W_per_m2.tif`
- `harvested_energy_MWhr_per_yr.tif`
- `levelized_cost_price_per_kWh.tif`
- `npv.tif`
The model will need to be updated so that the valuation component also writes values to `wind_energy_points.shp`.
## Status
## Consequences
Once released, the model will no longer produce the raster outputs it previously provided. Instead, values for each point will appear in `wind_energy_points.shp`. This vector will also contain valuation data if the model's valuation component is run.
## References
GitHub:
* [Pull Request](https://github.com/natcap/invest/pull/1898)
* [Discussion: Raster result values returned outside of wind data](https://github.com/natcap/invest/issues/1698)
* [User's Guide PR](https://github.com/natcap/invest.users-guide/pull/178)


@ -20,7 +20,7 @@ pypiwin32; sys_platform == 'win32' # pip-only
# 60.7.0 exception because of https://github.com/pyinstaller/pyinstaller/issues/6564
setuptools>=8.0,!=60.7.0
PyInstaller>=4.10 # pip-only
PyInstaller>=6.9.0
setuptools_scm>=6.4.0
requests
coverage


@ -11,6 +11,7 @@ import tempfile
import unittest
from natcap.invest import datastack
from natcap.invest import models
logging.basicConfig(level=logging.INFO)
LOGGER = logging.getLogger('invest-autovalidate.py')
@ -33,7 +34,7 @@ class ValidateExceptionTests(unittest.TestCase):
self.workspace, 'dummy.invs.json')
with open(datastack_path, 'w') as file:
file.write('"args": {"something": "else"},')
file.write('"model_name": natcap.invest.carbon')
file.write('"model_id": "carbon"')
with self.assertRaises(ValueError):
main(self.workspace)
@ -43,7 +44,7 @@ class ValidateExceptionTests(unittest.TestCase):
self.workspace, 'dummy.invs.json')
with open(datastack_path, 'w') as file:
file.write('"args": {"workspace_dir": "/home/foo"},')
file.write('"model_name": natcap.invest.carbon')
file.write('"model_id": "carbon"')
with self.assertRaises(ValueError):
main(self.workspace)
@ -71,17 +72,18 @@ def main(sampledatadir):
if 'workspace_dir' in paramset.args and \
paramset.args['workspace_dir'] != '':
msg = (
'%s : workspace_dir should not be defined '
'for sample datastacks' % datastack_path)
f'{datastack_path} : workspace_dir should not be defined '
'for sample datastacks' )
validation_messages += os.linesep + msg
LOGGER.error(msg)
else:
paramset.args['workspace_dir'] = tempfile.mkdtemp()
model_module = importlib.import_module(name=paramset.model_name)
model_module = importlib.import_module(
name=models.model_id_to_pyname[paramset.model_id])
model_warnings = [] # define here in case of uncaught exception.
try:
LOGGER.info('validating %s ', os.path.abspath(datastack_path))
LOGGER.info(f'validating {os.path.abspath(datastack_path)}')
model_warnings = getattr(
model_module, 'validate')(paramset.args)
except AttributeError as err:


@ -16,15 +16,15 @@ def main(userguide_dir):
Raises:
OSError if any models reference files that do not exist.
"""
from natcap.invest.model_metadata import MODEL_METADATA
from natcap.invest import models
missing_files = []
userguide_dir_source = os.path.join(userguide_dir, 'source', 'en')
for data in MODEL_METADATA.values():
for module in models.pyname_to_module.values():
# html referenced won't exist unless we actually built the UG,
# so check for the rst with the same basename.
model_rst = f'{os.path.splitext(data.userguide)[0]}.rst'
model_rst = f'{os.path.splitext(module.MODEL_SPEC.userguide)[0]}.rst'
if not os.path.exists(os.path.join(
userguide_dir_source, model_rst)):
missing_files.append(data.userguide)


@ -11,10 +11,9 @@ from osgeo import gdal
from osgeo import ogr
from . import gettext
from . import spec_utils
from . import spec
from . import utils
from . import validation
from .model_metadata import MODEL_METADATA
from .unit_registry import u
LOGGER = logging.getLogger(__name__)
@ -105,11 +104,20 @@ WATERSHED_OUTPUT_FIELDS = {
**VALUATION_OUTPUT_FIELDS
}
MODEL_SPEC = {
MODEL_SPEC = spec.build_model_spec({
"model_id": "annual_water_yield",
"model_name": MODEL_METADATA["annual_water_yield"].model_title,
"pyname": MODEL_METADATA["annual_water_yield"].pyname,
"userguide": MODEL_METADATA["annual_water_yield"].userguide,
"model_title": gettext("Annual Water Yield"),
"userguide": "annual_water_yield.html",
"aliases": ("hwy", "awy"),
"ui_spec": {
"order": [
['workspace_dir', 'results_suffix'],
['precipitation_path', 'eto_path', 'depth_to_root_rest_layer_path', 'pawc_path'],
['lulc_path', 'biophysical_table_path', 'seasonality_constant'],
['watersheds_path', 'sub_watersheds_path'],
['demand_table_path', 'valuation_table_path']
]
},
"args_with_spatial_overlap": {
"spatial_keys": ["lulc_path",
"depth_to_root_rest_layer_path",
@ -121,13 +129,13 @@ MODEL_SPEC = {
"different_projections_ok": False,
},
"args": {
"workspace_dir": spec_utils.WORKSPACE,
"results_suffix": spec_utils.SUFFIX,
"n_workers": spec_utils.N_WORKERS,
"workspace_dir": spec.WORKSPACE,
"results_suffix": spec.SUFFIX,
"n_workers": spec.N_WORKERS,
"lulc_path": {
**spec_utils.LULC,
**spec.LULC,
"projected": True,
"about": spec_utils.LULC['about'] + " " + gettext(
"about": spec.LULC['about'] + " " + gettext(
"All values in this raster must have corresponding entries "
"in the Biophysical Table.")
},
@ -145,7 +153,7 @@ MODEL_SPEC = {
"name": gettext("root restricting layer depth")
},
"precipitation_path": {
**spec_utils.PRECIP,
**spec.PRECIP,
"projected": True
},
"pawc_path": {
@ -159,7 +167,7 @@ MODEL_SPEC = {
"name": gettext("plant available water content")
},
"eto_path": {
**spec_utils.ET0,
**spec.ET0,
"projected": True
},
"watersheds_path": {
@ -171,7 +179,7 @@ MODEL_SPEC = {
"about": gettext("Unique identifier for each watershed.")
}
},
"geometries": spec_utils.POLYGON,
"geometries": spec.POLYGON,
"about": gettext(
"Map of watershed boundaries, such that each watershed drains "
"to a point of interest where hydropower production will be "
@ -187,7 +195,7 @@ MODEL_SPEC = {
"about": gettext("Unique identifier for each subwatershed.")
}
},
"geometries": spec_utils.POLYGONS,
"geometries": spec.POLYGONS,
"required": False,
"about": gettext(
"Map of subwatershed boundaries within each watershed in "
@ -197,7 +205,7 @@ MODEL_SPEC = {
"biophysical_table_path": {
"type": "csv",
"columns": {
"lucode": spec_utils.LULC_TABLE_COLUMN,
"lucode": spec.LULC_TABLE_COLUMN,
"lulc_veg": {
"type": "integer",
"about": gettext(
@ -336,7 +344,7 @@ MODEL_SPEC = {
"contents": {
"watershed_results_wyield.shp": {
"fields": {**WATERSHED_OUTPUT_FIELDS},
"geometries": spec_utils.POLYGON,
"geometries": spec.POLYGON,
"about": "Shapefile containing biophysical output values per watershed."
},
"watershed_results_wyield.csv": {
@ -346,7 +354,7 @@ MODEL_SPEC = {
},
"subwatershed_results_wyield.shp": {
"fields": {**SUBWATERSHED_OUTPUT_FIELDS},
"geometries": spec_utils.POLYGON,
"geometries": spec.POLYGON,
"about": "Shapefile containing biophysical output values per subwatershed."
},
"subwatershed_results_wyield.csv": {
@ -435,9 +443,9 @@ MODEL_SPEC = {
}
}
},
"taskgraph_dir": spec_utils.TASKGRAPH_DIR
"taskgraph_dir": spec.TASKGRAPH_DIR
}
}
})
def execute(args):
@ -526,9 +534,8 @@ def execute(args):
'Checking that watersheds have entries for every `ws_id` in the '
'valuation table.')
# Open/read in valuation parameters from CSV file
valuation_df = validation.get_validated_dataframe(
args['valuation_table_path'],
**MODEL_SPEC['args']['valuation_table_path'])
valuation_df = MODEL_SPEC.get_input(
'valuation_table_path').get_validated_dataframe(args['valuation_table_path'])
watershed_vector = gdal.OpenEx(
args['watersheds_path'], gdal.OF_VECTOR)
watershed_layer = watershed_vector.GetLayer()
@ -650,15 +657,16 @@ def execute(args):
'lulc': pygeoprocessing.get_raster_info(clipped_lulc_path)['nodata'][0]}
# Open/read in the csv file into a dictionary and add to arguments
bio_df = validation.get_validated_dataframe(args['biophysical_table_path'],
**MODEL_SPEC['args']['biophysical_table_path'])
bio_df = MODEL_SPEC.get_input('biophysical_table_path').get_validated_dataframe(
args['biophysical_table_path'])
bio_lucodes = set(bio_df.index.values)
bio_lucodes.add(nodata_dict['lulc'])
LOGGER.debug(f'bio_lucodes: {bio_lucodes}')
if 'demand_table_path' in args and args['demand_table_path'] != '':
demand_df = validation.get_validated_dataframe(
args['demand_table_path'], **MODEL_SPEC['args']['demand_table_path'])
demand_df = MODEL_SPEC.get_input('demand_table_path').get_validated_dataframe(
args['demand_table_path'])
demand_reclassify_dict = dict(
[(lucode, row['demand']) for lucode, row in demand_df.iterrows()])
demand_lucodes = set(demand_df.index.values)
@ -1314,5 +1322,4 @@ def validate(args, limit_to=None):
be an empty list if validation succeeds.
"""
return validation.validate(
args, MODEL_SPEC['args'], MODEL_SPEC['args_with_spatial_overlap'])
return validation.validate(args, MODEL_SPEC)


@ -12,9 +12,8 @@ import taskgraph
from . import validation
from . import utils
from . import spec_utils
from . import spec
from .unit_registry import u
from .model_metadata import MODEL_METADATA
from . import gettext
LOGGER = logging.getLogger(__name__)
@ -39,20 +38,29 @@ CARBON_OUTPUTS = {
]
}
MODEL_SPEC = {
MODEL_SPEC = spec.build_model_spec({
"model_id": "carbon",
"model_name": MODEL_METADATA["carbon"].model_title,
"pyname": MODEL_METADATA["carbon"].pyname,
"userguide": MODEL_METADATA["carbon"].userguide,
"model_title": gettext("Carbon Storage and Sequestration"),
"userguide": "carbonstorage.html",
"aliases": (),
"ui_spec": {
"order": [
['workspace_dir', 'results_suffix'],
['lulc_bas_path', 'carbon_pools_path'],
['calc_sequestration', 'lulc_alt_path'],
['do_valuation', 'lulc_bas_year', 'lulc_alt_year', 'price_per_metric_ton_of_c', 'discount_rate', 'rate_change'],
],
"forum_tag": 'carbon'
},
"args_with_spatial_overlap": {
"spatial_keys": ["lulc_bas_path", "lulc_alt_path"],
},
"args": {
"workspace_dir": spec_utils.WORKSPACE,
"results_suffix": spec_utils.SUFFIX,
"n_workers": spec_utils.N_WORKERS,
"workspace_dir": spec.WORKSPACE,
"results_suffix": spec.SUFFIX,
"n_workers": spec.N_WORKERS,
"lulc_bas_path": {
**spec_utils.LULC,
**spec.LULC,
"projected": True,
"projection_units": u.meter,
"about": gettext(
@ -71,10 +79,11 @@ MODEL_SPEC = {
"name": gettext("calculate sequestration")
},
"lulc_alt_path": {
**spec_utils.LULC,
**spec.LULC,
"projected": True,
"projection_units": u.meter,
"required": "calc_sequestration",
"allowed": "calc_sequestration",
"about": gettext(
"A map of LULC for the alternate scenario, which must occur "
"after the baseline scenario. All values in this raster must "
@ -86,7 +95,7 @@ MODEL_SPEC = {
"carbon_pools_path": {
"type": "csv",
"columns": {
"lucode": spec_utils.LULC_TABLE_COLUMN,
"lucode": spec.LULC_TABLE_COLUMN,
"c_above": {
"type": "number",
"units": u.metric_ton/u.hectare,
@ -115,6 +124,7 @@ MODEL_SPEC = {
"type": "number",
"units": u.year_AD,
"required": "do_valuation",
"allowed": "do_valuation",
"about": gettext(
"The calendar year of the baseline scenario depicted in the "
"baseline LULC map. Must be < alternate LULC year. Required "
@ -126,6 +136,7 @@ MODEL_SPEC = {
"type": "number",
"units": u.year_AD,
"required": "do_valuation",
"allowed": "do_valuation",
"about": gettext(
"The calendar year of the alternate scenario depicted in the "
"alternate LULC map. Must be > baseline LULC year. Required "
@ -135,6 +146,7 @@ MODEL_SPEC = {
"do_valuation": {
"type": "boolean",
"required": False,
"allowed": "calc_sequestration",
"about": gettext(
"Calculate net present value for the alternate scenario "
"and report it in the final HTML document."),
@ -144,6 +156,7 @@ MODEL_SPEC = {
"type": "number",
"units": u.currency/u.metric_ton,
"required": "do_valuation",
"allowed": "do_valuation",
"about": gettext(
"The present value of carbon. "
"Required if Run Valuation model is selected."),
@ -152,6 +165,7 @@ MODEL_SPEC = {
"discount_rate": {
"type": "percent",
"required": "do_valuation",
"allowed": "do_valuation",
"about": gettext(
"The annual market discount rate in the price of carbon, "
"which reflects society's preference for immediate benefits "
@ -163,6 +177,7 @@ MODEL_SPEC = {
"rate_change": {
"type": "percent",
"required": "do_valuation",
"allowed": "do_valuation",
"about": gettext(
"The relative annual change of the price of carbon. "
"Required if Run Valuation model is selected."),
@ -210,9 +225,9 @@ MODEL_SPEC = {
**CARBON_OUTPUTS
}
},
"taskgraph_cache": spec_utils.TASKGRAPH_DIR
"taskgraph_cache": spec.TASKGRAPH_DIR
}
}
})
_OUTPUT_BASE_FILES = {
'c_storage_bas': 'c_storage_bas.tif',
@ -305,8 +320,8 @@ def execute(args):
"Baseline LULC Year is earlier than the Alternate LULC Year."
)
carbon_pool_df = validation.get_validated_dataframe(
args['carbon_pools_path'], **MODEL_SPEC['args']['carbon_pools_path'])
carbon_pool_df = MODEL_SPEC.get_input(
'carbon_pools_path').get_validated_dataframe(args['carbon_pools_path'])
try:
n_workers = int(args['n_workers'])
@ -677,5 +692,4 @@ def validate(args, limit_to=None):
the error message in the second part of the tuple. This should
be an empty list if validation succeeds.
"""
return validation.validate(
args, MODEL_SPEC['args'], MODEL_SPEC['args_with_spatial_overlap'])
return validation.validate(args, MODEL_SPEC)
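A hedged sketch of the new spec accessor pattern used throughout this diff, which replaces ``validation.get_validated_dataframe(args[...], **MODEL_SPEC['args'][...])``; the CSV path here is hypothetical and must contain the columns declared in the spec:

from natcap.invest import carbon

# Look up the input's spec by key, then validate a user-supplied table with it.
pools_spec = carbon.MODEL_SPEC.get_input('carbon_pools_path')
pools_df = pools_spec.get_validated_dataframe('my_carbon_pools.csv')  # hypothetical file
print(pools_df.head())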


@ -3,6 +3,7 @@
import argparse
import codecs
import datetime
import gettext
import importlib
import json
import logging
@ -14,84 +15,103 @@ import warnings
import natcap.invest
from natcap.invest import datastack
from natcap.invest import model_metadata
from natcap.invest import spec_utils
from natcap.invest import set_locale
from natcap.invest import spec
from natcap.invest import ui_server
from natcap.invest import utils
from natcap.invest import models
from pygeoprocessing.geoprocessing_core import GDALUseExceptions
DEFAULT_EXIT_CODE = 1
LOGGER = logging.getLogger(__name__)
# Build up an index mapping aliases to model_name.
# ``model_name`` is the key to the MODEL_METADATA dict.
_MODEL_ALIASES = {}
for model_name, meta in model_metadata.MODEL_METADATA.items():
for alias in meta.aliases:
assert alias not in _MODEL_ALIASES, (
'Alias %s already defined for model %s') % (
alias, _MODEL_ALIASES[alias])
_MODEL_ALIASES[alias] = model_name
def build_model_list_table():
def build_model_list_table(locale_code):
"""Build a table of model names, aliases and other details.
This table is a table only in the sense that its contents are aligned
into columns, but are not separated by a delimiter. This table
is intended to be printed to stdout.
Args:
locale_code (str): Language code to pass to gettext. The model names
will be returned in this language.
Returns:
A string representation of the formatted table.
"""
from natcap.invest import gettext
model_names = sorted(model_metadata.MODEL_METADATA.keys())
max_model_name_length = max(len(name) for name in model_names)
from natcap.invest import LOCALE_DIR
translation = gettext.translation(
'messages',
languages=[locale_code],
localedir=LOCALE_DIR,
# fall back to a NullTranslation, which returns the English messages
fallback=True)
max_model_id_length = max(
len(_id) for _id in models.model_id_to_spec.keys())
# Adding 3 to max alias name length for the parentheses plus some padding.
max_alias_name_length = max(len(', '.join(meta.aliases))
for meta in model_metadata.MODEL_METADATA.values()) + 3
template_string = ' {model_name} {aliases} {model_title} {usage}'
strings = [gettext('Available models:')]
for model_name in model_names:
usage_string = '(No GUI available)'
if model_metadata.MODEL_METADATA[model_name].gui is not None:
usage_string = ''
max_alias_name_length = max(len(', '.join(
model_spec.aliases)) for model_spec in models.model_id_to_spec.values()) + 3
template_string = ' {model_id} {aliases} {model_title}'
strings = [translation.gettext('Available models:')]
for model_id, model_spec in models.model_id_to_spec.items():
alias_string = ', '.join(model_metadata.MODEL_METADATA[model_name].aliases)
alias_string = ', '.join(model_spec.aliases)
if alias_string:
alias_string = '(%s)' % alias_string
alias_string = f'({alias_string})'
strings.append(template_string.format(
model_name=model_name.ljust(max_model_name_length),
model_id=model_id.ljust(max_model_id_length),
aliases=alias_string.ljust(max_alias_name_length),
model_title=model_metadata.MODEL_METADATA[model_name].model_title,
usage=usage_string))
model_title=translation.gettext(model_spec.model_title)))
return '\n'.join(strings) + '\n'
def build_model_list_json():
def build_model_list_json(locale_code):
"""Build a json object of relevant information for the CLI.
The JSON object returned uses model IDs for keys, and the values are
dicts containing the translated model title and the aliases
recognized by the CLI.
Args:
locale_code (str): Language code to pass to gettext. The model names
will be returned in this language.
Returns:
A string representation of the JSON object.
"""
from natcap.invest import LOCALE_DIR
translation = gettext.translation(
'messages',
languages=[locale_code],
localedir=LOCALE_DIR,
# fall back to a NullTranslation, which returns the English messages
fallback=True)
json_object = {}
for model_name, model_data in model_metadata.MODEL_METADATA.items():
json_object[model_data.model_title] = {
'model_name': model_name,
'aliases': model_data.aliases
for model_id, model_spec in models.model_id_to_spec.items():
json_object[model_id] = {
'model_title': translation.gettext(model_spec.model_title),
'aliases': model_spec.aliases
}
return json.dumps(json_object)
def export_to_python(target_filepath, model, args_dict=None):
def export_to_python(target_filepath, model_id, args_dict=None):
"""Generate a python script that executes a model.
Args:
target_filepath (str): path to generate the python file
model_id (str): ID of the model to generate the script for
args_dict (dict): If provided, prefill these arg values in the script
Returns:
None
"""
script_template = textwrap.dedent("""\
# coding=UTF-8
# -----------------------------------------------
@ -121,17 +141,14 @@ def export_to_python(target_filepath, model, args_dict=None):
""")
if args_dict is None:
model_module = importlib.import_module(
name=model_metadata.MODEL_METADATA[model].pyname)
spec = model_module.MODEL_SPEC
cast_args = {key: '' for key in spec['args'].keys()}
cast_args = {
arg_spec.id: '' for arg_spec in models.model_id_to_spec[model_id].inputs}
else:
cast_args = dict((str(key), value) for (key, value)
in args_dict.items())
with codecs.open(target_filepath, 'w', encoding='utf-8') as py_file:
args = pprint.pformat(cast_args, indent=4) # 4 spaces
# Tweak formatting from pprint:
# * Bump parameter inline with starting { to next line
# * add trailing comma to last item pair
@ -141,8 +158,8 @@ def export_to_python(target_filepath, model, args_dict=None):
py_file.write(script_template.format(
invest_version=natcap.invest.__version__,
today=datetime.datetime.now().strftime('%c'),
model_title=model_metadata.MODEL_METADATA[model].model_title,
pyname=model_metadata.MODEL_METADATA[model].pyname,
model_title=models.model_id_to_spec[model_id].model_title,
pyname=models.model_id_to_pyname[model_id],
model_args=args))
@ -161,11 +178,11 @@ class SelectModelAction(argparse.Action):
Identifiable model names are:
* the model name (verbatim) as identified in the keys of MODEL_METADATA
* the model id (exactly matching the MODEL_SPEC.model_id)
* a uniquely identifiable prefix for the model name (e.g. "d"
matches "delineateit", but "co" matches both
"coastal_vulnerability" and "coastal_blue_carbon").
* a known model alias, as registered in MODEL_METADATA
* a known model alias, as registered in MODEL_SPEC.aliases
If no single model can be identified based on these rules, an error
message is printed and the parser exits with a nonzero exit code.
@ -176,7 +193,7 @@ class SelectModelAction(argparse.Action):
Overridden from argparse.Action.__call__.
"""
known_models = sorted(list(model_metadata.MODEL_METADATA.keys()))
known_models = sorted(list(models.model_id_to_spec.keys()))
matching_models = [model for model in known_models if
model.startswith(values)]
@ -185,11 +202,11 @@ class SelectModelAction(argparse.Action):
model == values]
if len(matching_models) == 1: # match an identifying substring
modelname = matching_models[0]
elif len(exact_matches) == 1: # match an exact modelname
modelname = exact_matches[0]
elif values in _MODEL_ALIASES: # match an alias
modelname = _MODEL_ALIASES[values]
model_id = matching_models[0]
elif len(exact_matches) == 1: # match an exact model id
model_id = exact_matches[0]
elif values in models.model_alias_to_id: # match an alias
model_id = models.model_alias_to_id[values]
elif len(matching_models) == 0:
parser.exit(status=1, message=(
"Error: '%s' not a known model" % values))
@ -201,7 +218,7 @@ class SelectModelAction(argparse.Action):
" {matching_models}").format(
model=values,
matching_models=' '.join(matching_models)))
setattr(namespace, self.dest, modelname)
setattr(namespace, self.dest, model_id)
def main(user_args=None):
@ -359,12 +376,10 @@ def main(user_args=None):
logging.getLogger('natcap').setLevel(logging.DEBUG)
if args.subcommand == 'list':
# reevaluate the model names in the new language
importlib.reload(model_metadata)
if args.json:
message = build_model_list_json()
message = build_model_list_json(args.language)
else:
message = build_model_list_table()
message = build_model_list_table(args.language)
sys.stdout.write(message)
parser.exit()
@ -379,7 +394,7 @@ def main(user_args=None):
# reload validation module first so it's also in the correct language
importlib.reload(importlib.import_module('natcap.invest.validation'))
model_module = importlib.reload(importlib.import_module(
name=parsed_datastack.model_name))
name=models.model_id_to_pyname[parsed_datastack.model_id]))
try:
validation_result = model_module.validate(parsed_datastack.args)
@ -412,15 +427,12 @@ def main(user_args=None):
parser.exit(0)
if args.subcommand == 'getspec':
target_model = model_metadata.MODEL_METADATA[args.model].pyname
target_model = models.model_id_to_pyname[args.model]
model_module = importlib.reload(
importlib.import_module(name=target_model))
spec = model_module.MODEL_SPEC
model_spec = model_module.MODEL_SPEC
if args.json:
message = spec_utils.serialize_args_spec(spec)
else:
message = pprint.pformat(spec)
message = model_spec.to_json()
sys.stdout.write(message)
parser.exit(0)
@ -448,18 +460,20 @@ def main(user_args=None):
else:
parsed_datastack.args['workspace_dir'] = args.workspace
target_model = model_metadata.MODEL_METADATA[args.model].pyname
target_model = models.model_id_to_pyname[args.model]
model_module = importlib.import_module(name=target_model)
LOGGER.info('Imported target %s from %s',
model_module.__name__, model_module)
with utils.prepare_workspace(parsed_datastack.args['workspace_dir'],
name=parsed_datastack.model_name,
model_id=parsed_datastack.model_id,
logging_level=log_level):
LOGGER.log(datastack.ARGS_LOG_LEVEL,
'Starting model with parameters: \n%s',
datastack.format_args_dict(parsed_datastack.args,
parsed_datastack.model_name))
LOGGER.log(
datastack.ARGS_LOG_LEVEL,
'Starting model with parameters: \n%s',
datastack.format_args_dict(
parsed_datastack.args,
parsed_datastack.model_id))
# We're deliberately not validating here because the user
# can just call ``invest validate <datastack>`` to validate.
@ -472,7 +486,7 @@ def main(user_args=None):
try:
# If there's an exception from creating metadata
# I don't think we want to indicate a model failure
spec_utils.generate_metadata_for_outputs(
spec.generate_metadata_for_outputs(
model_module, parsed_datastack.args)
except Exception as exc:
LOGGER.warning(
@ -489,7 +503,6 @@ def main(user_args=None):
export_to_python(target_filepath, args.model)
parser.exit()
if __name__ == '__main__':
multiprocessing.freeze_support()
main()
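Based on ``build_model_list_json`` above, the ``list`` subcommand now emits JSON keyed by model ID with translated titles and aliases. A small consumption sketch, assuming the ``invest`` console script is on PATH and that the ``--json`` flag corresponds to the ``args.json`` check above:

import json
import subprocess

# Run `invest list --json` and print each model's ID, aliases, and title.
raw = subprocess.run(['invest', 'list', '--json'],
                     capture_output=True, text=True, check=True).stdout
for model_id, info in json.loads(raw).items():
    print(model_id, info['aliases'], info['model_title'])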


@ -104,10 +104,9 @@ import taskgraph
from osgeo import gdal
from .. import gettext
from .. import spec_utils
from .. import spec
from .. import utils
from .. import validation
from ..model_metadata import MODEL_METADATA
from ..unit_registry import u
LOGGER = logging.getLogger(__name__)
@ -159,15 +158,114 @@ CARBON_STOCK_AT_YEAR_RASTER_PATTERN = 'carbon-stock-at-{year}{suffix}.tif'
INTERMEDIATE_DIR_NAME = 'intermediate'
OUTPUT_DIR_NAME = 'output'
MODEL_SPEC = {
BIOPHYSICAL_TABLE_COLUMNS = {
"lucode": {
"type": "integer",
"about": gettext(
"The LULC code that represents this LULC "
"class in the LULC snapshot rasters.")},
"lulc-class": {
"type": "freestyle_string",
"about": gettext(
"Name of the LULC class. This label must be "
"unique among the all the LULC classes.")},
"biomass-initial": {
"type": "number",
"units": u.megatonne/u.hectare,
"about": gettext(
"The initial carbon stocks in the biomass pool for "
"this LULC class.")},
"soil-initial": {
"type": "number",
"units": u.megatonne/u.hectare,
"about": gettext(
"The initial carbon stocks in the soil pool for this "
"LULC class.")},
"litter-initial": {
"type": "number",
"units": u.megatonne/u.hectare,
"about": gettext(
"The initial carbon stocks in the litter pool for "
"this LULC class.")},
"biomass-half-life": {
"type": "number",
"units": u.year,
"expression": "value > 0",
"about": gettext("The half-life of carbon in the biomass pool.")},
"biomass-low-impact-disturb": {
"type": "ratio",
"about": gettext(
"Proportion of carbon stock in the biomass pool that "
"is disturbed when a cell transitions away from this "
" LULC class in a low-impact disturbance.")},
"biomass-med-impact-disturb": {
"type": "ratio",
"about": gettext(
"Proportion of carbon stock in the biomass pool that "
"is disturbed when a cell transitions away from this "
"LULC class in a medium-impact disturbance.")},
"biomass-high-impact-disturb": {
"type": "ratio",
"about": gettext(
"Proportion of carbon stock in the biomass pool that "
"is disturbed when a cell transitions away from this "
"LULC class in a high-impact disturbance.")},
"biomass-yearly-accumulation": {
"type": "number",
"units": u.megatonne/u.hectare/u.year,
"about": gettext(
"Annual rate of CO2E accumulation in the biomass pool.")},
"soil-half-life": {
"type": "number",
"units": u.year,
"expression": "value > 0",
"about": gettext("The half-life of carbon in the soil pool.")},
"soil-low-impact-disturb": {
"type": "ratio",
"about": gettext(
"Proportion of carbon stock in the soil pool that "
"is disturbed when a cell transitions away from this "
"LULC class in a low-impact disturbance.")},
"soil-med-impact-disturb": {
"type": "ratio",
"about": gettext(
"Proportion of carbon stock in the soil pool that "
"is disturbed when a cell transitions away from this "
"LULC class in a medium-impact disturbance.")},
"soil-high-impact-disturb": {
"type": "ratio",
"about": gettext(
"Proportion of carbon stock in the soil pool that "
"is disturbed when a cell transitions away from this "
"LULC class in a high-impact disturbance.")},
"soil-yearly-accumulation": {
"type": "number",
"units": u.megatonne/u.hectare/u.year,
"about": gettext(
"Annual rate of CO2E accumulation in the soil pool.")},
"litter-yearly-accumulation": {
"type": "number",
"units": u.megatonne/u.hectare/u.year,
"about": gettext(
"Annual rate of CO2E accumulation in the litter pool.")}
}
MODEL_SPEC = spec.build_model_spec({
"model_id": "coastal_blue_carbon",
"model_name": MODEL_METADATA["coastal_blue_carbon"].model_title,
"pyname": MODEL_METADATA["coastal_blue_carbon"].pyname,
"userguide": MODEL_METADATA["coastal_blue_carbon"].userguide,
"model_title": gettext("Coastal Blue Carbon"),
"userguide": "coastal_blue_carbon.html",
"aliases": ("cbc",),
"ui_spec": {
"order": [
['workspace_dir', 'results_suffix'],
['landcover_snapshot_csv', 'biophysical_table_path', 'landcover_transitions_table', 'analysis_year'],
['do_economic_analysis', 'use_price_table', 'price', 'inflation_rate', 'price_table_path', 'discount_rate'],
]
},
"args": {
"workspace_dir": spec_utils.WORKSPACE,
"results_suffix": spec_utils.SUFFIX,
"n_workers": spec_utils.N_WORKERS,
"workspace_dir": spec.WORKSPACE,
"results_suffix": spec.SUFFIX,
"n_workers": spec.N_WORKERS,
"landcover_snapshot_csv": {
"type": "csv",
"index_col": "snapshot_year",
@ -208,97 +306,7 @@ MODEL_SPEC = {
"name": gettext("biophysical table"),
"type": "csv",
"index_col": "lucode",
"columns": {
"lucode": {
"type": "integer",
"about": gettext(
"The LULC code that represents this LULC "
"class in the LULC snapshot rasters.")},
"lulc-class": {
"type": "freestyle_string",
"about": gettext(
"Name of the LULC class. This label must be "
"unique among the all the LULC classes.")},
"biomass-initial": {
"type": "number",
"units": u.megatonne/u.hectare,
"about": gettext(
"The initial carbon stocks in the biomass pool for "
"this LULC class.")},
"soil-initial": {
"type": "number",
"units": u.megatonne/u.hectare,
"about": gettext(
"The initial carbon stocks in the soil pool for this "
"LULC class.")},
"litter-initial": {
"type": "number",
"units": u.megatonne/u.hectare,
"about": gettext(
"The initial carbon stocks in the litter pool for "
"this LULC class.")},
"biomass-half-life": {
"type": "number",
"units": u.year,
"expression": "value > 0",
"about": gettext("The half-life of carbon in the biomass pool.")},
"biomass-low-impact-disturb": {
"type": "ratio",
"about": gettext(
"Proportion of carbon stock in the biomass pool that "
"is disturbed when a cell transitions away from this "
" LULC class in a low-impact disturbance.")},
"biomass-med-impact-disturb": {
"type": "ratio",
"about": gettext(
"Proportion of carbon stock in the biomass pool that "
"is disturbed when a cell transitions away from this "
"LULC class in a medium-impact disturbance.")},
"biomass-high-impact-disturb": {
"type": "ratio",
"about": gettext(
"Proportion of carbon stock in the biomass pool that "
"is disturbed when a cell transitions away from this "
"LULC class in a high-impact disturbance.")},
"biomass-yearly-accumulation": {
"type": "number",
"units": u.megatonne/u.hectare/u.year,
"about": gettext(
"Annual rate of CO2E accumulation in the biomass pool.")},
"soil-half-life": {
"type": "number",
"units": u.year,
"expression": "value > 0",
"about": gettext("The half-life of carbon in the soil pool.")},
"soil-low-impact-disturb": {
"type": "ratio",
"about": gettext(
"Proportion of carbon stock in the soil pool that "
"is disturbed when a cell transitions away from this "
"LULC class in a low-impact disturbance.")},
"soil-med-impact-disturb": {
"type": "ratio",
"about": gettext(
"Proportion of carbon stock in the soil pool that "
"is disturbed when a cell transitions away from this "
"LULC class in a medium-impact disturbance.")},
"soil-high-impact-disturb": {
"type": "ratio",
"about": gettext(
"Proportion of carbon stock in the soil pool that "
"is disturbed when a cell transitions away from this "
"LULC class in a high-impact disturbance.")},
"soil-yearly-accumulation": {
"type": "number",
"units": u.megatonne/u.hectare/u.year,
"about": gettext(
"Annual rate of CO2E accumulation in the soil pool.")},
"litter-yearly-accumulation": {
"type": "number",
"units": u.megatonne/u.hectare/u.year,
"about": gettext(
"Annual rate of CO2E accumulation in the litter pool.")}
},
"columns": BIOPHYSICAL_TABLE_COLUMNS,
"about": gettext("Table of biophysical properties for each LULC class.")
},
"landcover_transitions_table": {
@ -369,6 +377,7 @@ MODEL_SPEC = {
"name": gettext("use price table"),
"type": "boolean",
"required": False,
"allowed": "do_economic_analysis",
"about": gettext(
"Use a yearly price table, rather than an initial "
"price and interest rate, to indicate carbon value over time."),
@ -378,6 +387,7 @@ MODEL_SPEC = {
"type": "number",
"units": u.currency/u.megatonne,
"required": "do_economic_analysis and (not use_price_table)",
"allowed": "do_economic_analysis and not use_price_table",
"about": gettext(
"The price of CO2E at the baseline year. Required if Do "
"Valuation is selected and Use Price Table is not selected."),
@ -386,6 +396,7 @@ MODEL_SPEC = {
"name": gettext("interest rate"),
"type": "percent",
"required": "do_economic_analysis and (not use_price_table)",
"allowed": "do_economic_analysis and not use_price_table",
"about": gettext(
"Annual increase in the price of CO2E. Required if Do "
"Valuation is selected and Use Price Table is not selected.")
@ -394,6 +405,7 @@ MODEL_SPEC = {
"name": gettext("price table"),
"type": "csv",
"required": "use_price_table",
"allowed": "use_price_table",
"index_col": "year",
"columns": {
"year": {
@ -415,6 +427,7 @@ MODEL_SPEC = {
"name": gettext("discount rate"),
"type": "percent",
"required": "do_economic_analysis",
"allowed": "do_economic_analysis",
"about": gettext(
"Annual discount rate on the price of carbon. This is "
"compounded each year after the baseline year. "
@ -530,9 +543,9 @@ MODEL_SPEC = {
}
}
},
"taskgraph_cache": spec_utils.TASKGRAPH_DIR
"taskgraph_cache": spec.TASKGRAPH_DIR
}
}
})
def execute(args):
@ -577,10 +590,9 @@ def execute(args):
task_graph, n_workers, intermediate_dir, output_dir, suffix = (
_set_up_workspace(args))
snapshots = validation.get_validated_dataframe(
args['landcover_snapshot_csv'],
**MODEL_SPEC['args']['landcover_snapshot_csv']
)['raster_path'].to_dict()
snapshots = MODEL_SPEC.get_input(
'landcover_snapshot_csv').get_validated_dataframe(
args['landcover_snapshot_csv'])['raster_path'].to_dict()
# Phase 1: alignment and preparation of inputs
baseline_lulc_year = min(snapshots.keys())
@ -600,9 +612,9 @@ def execute(args):
# We're assuming that the LULC initial variables and the carbon pool
# transient table are combined into a single lookup table.
biophysical_df = validation.get_validated_dataframe(
args['biophysical_table_path'],
**MODEL_SPEC['args']['biophysical_table_path'])
biophysical_df = MODEL_SPEC.get_input(
'biophysical_table_path').get_validated_dataframe(
args['biophysical_table_path'])
# LULC Classnames are critical to the transition mapping, so they must be
# unique. This check is here in ``execute`` because it's possible that
@ -970,10 +982,9 @@ def execute(args):
prices = None
if args.get('do_economic_analysis', False): # Do if truthy
if args.get('use_price_table', False):
prices = validation.get_validated_dataframe(
args['price_table_path'],
**MODEL_SPEC['args']['price_table_path']
)['price'].to_dict()
prices = MODEL_SPEC.get_input(
'price_table_path').get_validated_dataframe(
args['price_table_path'])['price'].to_dict()
else:
inflation_rate = float(args['inflation_rate']) * 0.01
annual_price = float(args['price'])
@ -1955,9 +1966,9 @@ def _read_transition_matrix(transition_csv_path, biophysical_df):
landcover transition, and the second contains accumulation rates for
the pool for the landcover transition.
"""
table = validation.get_validated_dataframe(
transition_csv_path, **MODEL_SPEC['args']['landcover_transitions_table']
).reset_index()
table = MODEL_SPEC.get_input(
'landcover_transitions_table').get_validated_dataframe(
transition_csv_path).reset_index()
lulc_class_to_lucode = {}
max_lucode = biophysical_df.index.max()
@ -2171,16 +2182,15 @@ def validate(args, limit_to=None):
A list of tuples where tuple[0] is an iterable of keys that the error
message applies to and tuple[1] is the string validation warning.
"""
validation_warnings = validation.validate(
args, MODEL_SPEC['args'])
validation_warnings = validation.validate(args, MODEL_SPEC)
sufficient_keys = validation.get_sufficient_keys(args)
invalid_keys = validation.get_invalid_keys(validation_warnings)
if ("landcover_snapshot_csv" not in invalid_keys and
"landcover_snapshot_csv" in sufficient_keys):
snapshots = validation.get_validated_dataframe(
args['landcover_snapshot_csv'],
**MODEL_SPEC['args']['landcover_snapshot_csv']
snapshots = MODEL_SPEC.get_input(
'landcover_snapshot_csv').get_validated_dataframe(
args['landcover_snapshot_csv']
)['raster_path'].to_dict()
snapshot_years = set(snapshots.keys())
@ -2200,13 +2210,13 @@ def validate(args, limit_to=None):
# check for invalid options in the translation table
if ("landcover_transitions_table" not in invalid_keys and
"landcover_transitions_table" in sufficient_keys):
transitions_spec = MODEL_SPEC['args']['landcover_transitions_table']
transitions_spec = MODEL_SPEC.get_input('landcover_transitions_table')
transition_options = list(
transitions_spec['columns']['[LULC CODE]']['options'].keys())
transitions_spec.columns.get('[LULC CODE]').options.keys())
# lowercase options since utils call will lowercase table values
transition_options = [x.lower() for x in transition_options]
transitions_df = validation.get_validated_dataframe(
args['landcover_transitions_table'], **transitions_spec)
transitions_df = transitions_spec.get_validated_dataframe(
args['landcover_transitions_table'])
transitions_mask = ~transitions_df.isin(transition_options) & ~transitions_df.isna()
if transitions_mask.any(axis=None):
transition_numpy_mask = transitions_mask.values


@ -9,27 +9,29 @@ import taskgraph
from osgeo import gdal
from .. import gettext
from .. import spec_utils
from .. import spec
from .. import utils
from .. import validation
from ..model_metadata import MODEL_METADATA
from ..unit_registry import u
from . import coastal_blue_carbon
LOGGER = logging.getLogger(__name__)
BIOPHYSICAL_COLUMNS_SPEC = coastal_blue_carbon.MODEL_SPEC[
'args']['biophysical_table_path']['columns']
MODEL_SPEC = {
MODEL_SPEC = spec.build_model_spec({
"model_id": "coastal_blue_carbon_preprocessor",
"model_name": MODEL_METADATA["coastal_blue_carbon_preprocessor"].model_title,
"pyname": MODEL_METADATA["coastal_blue_carbon_preprocessor"].pyname,
"userguide": MODEL_METADATA["coastal_blue_carbon_preprocessor"].userguide,
"model_title": gettext("Coastal Blue Carbon Preprocessor"),
"userguide": "coastal_blue_carbon.html",
"aliases": ("cbc_pre",),
"ui_spec": {
"order": [
['workspace_dir', 'results_suffix'],
['lulc_lookup_table_path', 'landcover_snapshot_csv']
]
},
"args": {
"workspace_dir": spec_utils.WORKSPACE,
"results_suffix": spec_utils.SUFFIX,
"n_workers": spec_utils.N_WORKERS,
"workspace_dir": spec.WORKSPACE,
"results_suffix": spec.SUFFIX,
"n_workers": spec.N_WORKERS,
"lulc_lookup_table_path": {
"name": gettext("LULC lookup table"),
"type": "csv",
@ -117,14 +119,14 @@ MODEL_SPEC = {
"create the biophysical table input to the main model."),
"index_col": "lucode",
"columns": {
**BIOPHYSICAL_COLUMNS_SPEC,
**coastal_blue_carbon.BIOPHYSICAL_TABLE_COLUMNS,
# remove "expression" property which doesn't go in output spec
"biomass-half-life": dict(
set(BIOPHYSICAL_COLUMNS_SPEC["biomass-half-life"].items()) -
set(coastal_blue_carbon.BIOPHYSICAL_TABLE_COLUMNS["biomass-half-life"].items()) -
{("expression", "value > 0")}
),
"soil-half-life": dict(
set(BIOPHYSICAL_COLUMNS_SPEC["soil-half-life"].items()) -
set(coastal_blue_carbon.BIOPHYSICAL_TABLE_COLUMNS["soil-half-life"].items()) -
{("expression", "value > 0")}
)
}
@ -135,9 +137,9 @@ MODEL_SPEC = {
"to match all the other LULC maps."),
"bands": {1: {"type": "integer"}}
},
"taskgraph_cache": spec_utils.TASKGRAPH_DIR
"taskgraph_cache": spec.TASKGRAPH_DIR
}
}
})
ALIGNED_LULC_RASTER_TEMPLATE = 'aligned_lulc_{year}{suffix}.tif'
@ -181,10 +183,9 @@ def execute(args):
os.path.join(args['workspace_dir'], 'taskgraph_cache'),
n_workers, reporting_interval=5.0)
snapshots_dict = validation.get_validated_dataframe(
args['landcover_snapshot_csv'],
**MODEL_SPEC['args']['landcover_snapshot_csv']
)['raster_path'].to_dict()
snapshots_dict = MODEL_SPEC.get_input(
'landcover_snapshot_csv').get_validated_dataframe(
args['landcover_snapshot_csv'])['raster_path'].to_dict()
# Align the raster stack for analyzing the various transitions.
min_pixel_size = float('inf')
@ -214,9 +215,9 @@ def execute(args):
target_path_list=aligned_snapshot_paths,
task_name='Align input landcover rasters')
landcover_df = validation.get_validated_dataframe(
args['lulc_lookup_table_path'],
**MODEL_SPEC['args']['lulc_lookup_table_path'])
landcover_df = MODEL_SPEC.get_input(
'lulc_lookup_table_path').get_validated_dataframe(
args['lulc_lookup_table_path'])
target_transition_table = os.path.join(
output_dir, TRANSITION_TABLE.format(suffix=suffix))
@ -385,8 +386,8 @@ def _create_biophysical_table(landcover_df, target_biophysical_table_path):
``None``
"""
target_column_names = [
colname.lower() for colname in coastal_blue_carbon.MODEL_SPEC['args'][
'biophysical_table_path']['columns']]
spec.id.lower() for spec in
coastal_blue_carbon.MODEL_SPEC.get_input('biophysical_table_path').columns]
with open(target_biophysical_table_path, 'w') as bio_table:
bio_table.write(f"{','.join(target_column_names)}\n")
@ -420,4 +421,4 @@ def validate(args, limit_to=None):
A list of tuples where tuple[0] is an iterable of keys that the error
message applies to and tuple[1] is the string validation warning.
"""
return validation.validate(args, MODEL_SPEC['args'])
return validation.validate(args, MODEL_SPEC)


@ -26,10 +26,9 @@ from shapely.geometry.base import BaseMultipartGeometry
from shapely.strtree import STRtree
from . import gettext
from . import spec_utils
from . import spec
from . import utils
from . import validation
from .model_metadata import MODEL_METADATA
from .unit_registry import u
LOGGER = logging.getLogger(__name__)
@ -119,11 +118,47 @@ WWIII_FIELDS = {
"centered on each main sector direction X.")}
}
MODEL_SPEC = {
def get_vector_colnames(vector_path):
"""Get a list of column names from a vector.
This is used to fill in dropdown menu options.
Args:
vector_path (string): path to a vector file
Returns:
list of vector column names (strings)
"""
# a lot of times the path will be empty so don't even try to open it
if vector_path:
try:
vector = gdal.OpenEx(vector_path, gdal.OF_VECTOR)
return [defn.GetName() for defn in vector.GetLayer().schema]
except Exception as e:
LOGGER.exception(
f'Could not read column names from {vector_path}. ERROR: {e}')
else:
LOGGER.error('Empty vector path.')
return []
MODEL_SPEC = spec.build_model_spec({
"model_id": "coastal_vulnerability",
"model_name": MODEL_METADATA["coastal_vulnerability"].model_title,
"pyname": MODEL_METADATA["coastal_vulnerability"].pyname,
"userguide": MODEL_METADATA["coastal_vulnerability"].userguide,
"model_title": gettext("Coastal Vulnerability"),
"userguide": "coastal_vulnerability.html",
"aliases": ("cv",),
"ui_spec": {
"order": [
['workspace_dir', 'results_suffix'],
['aoi_vector_path', 'model_resolution', 'landmass_vector_path'],
['bathymetry_raster_path', 'wwiii_vector_path', 'max_fetch_distance'],
['habitat_table_path', 'shelf_contour_vector_path', 'dem_path', 'dem_averaging_radius'],
['geomorphology_vector_path', 'geomorphology_fill_value'],
['population_raster_path', 'population_radius'],
['slr_vector_path', 'slr_field']
]
},
"args_with_spatial_overlap": {
"spatial_keys": [
"aoi_vector_path",
@ -137,11 +172,11 @@ MODEL_SPEC = {
"different_projections_ok": True,
},
"args": {
"workspace_dir": spec_utils.WORKSPACE,
"results_suffix": spec_utils.SUFFIX,
"n_workers": spec_utils.N_WORKERS,
"workspace_dir": spec.WORKSPACE,
"results_suffix": spec.SUFFIX,
"n_workers": spec.N_WORKERS,
"aoi_vector_path": {
**spec_utils.AOI,
**spec.AOI,
"projected": True,
"projection_units": u.meter,
"about": gettext("Map of the region over which to run the model.")
@ -157,7 +192,7 @@ MODEL_SPEC = {
"landmass_vector_path": {
"type": "vector",
"fields": {},
"geometries": spec_utils.POLYGONS,
"geometries": spec.POLYGONS,
"about": gettext(
"Map of all landmasses in and around the region of interest. "
"It is not recommended to clip this landmass to the AOI "
@ -170,7 +205,7 @@ MODEL_SPEC = {
"wwiii_vector_path": {
"type": "vector",
"fields": WWIII_FIELDS,
"geometries": spec_utils.POINT,
"geometries": spec.POINT,
"about": gettext(
"Map of gridded wind and wave data that represent storm "
"conditions. This global dataset is provided with the InVEST "
@ -209,14 +244,14 @@ MODEL_SPEC = {
"shelf_contour_vector_path": {
"type": "vector",
"fields": {},
"geometries": spec_utils.LINES,
"geometries": spec.LINES,
"about": gettext(
"Map of the edges of the continental shelf or other locally "
"relevant bathymetry contour."),
"name": gettext("continental shelf contour")
},
"dem_path": {
**spec_utils.DEM,
**spec.DEM,
"bands": {1: {
"type": "number",
"units": u.other # any unit of length is ok
@ -287,7 +322,7 @@ MODEL_SPEC = {
"about": gettext("Relative exposure of the segment of coastline.")
}
},
"geometries": spec_utils.LINES,
"geometries": spec.LINES,
"required": False,
"about": gettext("Map of relative exposure of each segment of coastline."),
"name": gettext("geomorphology")
@ -302,6 +337,7 @@ MODEL_SPEC = {
"5": {"display_name": gettext("5: very high exposure")}
},
"required": "geomorphology_vector_path",
"allowed": "geomorphology_vector_path",
"about": gettext(
"Exposure rank to assign to any shore points that are not "
"near to any segment in the geomorphology vector. "
@ -322,6 +358,7 @@ MODEL_SPEC = {
"units": u.meter,
"expression": "value > 0",
"required": "population_raster_path",
"allowed": "population_raster_path",
"about": gettext(
"The radius around each shore point within which to compute "
"the average population density. "
@ -339,7 +376,7 @@ MODEL_SPEC = {
"units": u.none
}
},
"geometries": spec_utils.POINT,
"geometries": spec.POINT,
"required": False,
"about": gettext(
"Map of sea level rise rates or amounts. May be any sea level "
@ -349,7 +386,9 @@ MODEL_SPEC = {
"slr_field": {
"type": "option_string",
"options": {},
"dropdown_function": lambda args: get_vector_colnames(args['slr_vector_path']),
"required": "slr_vector_path",
"allowed": "slr_vector_path",
"about": gettext(
"Name of the field in the sea level rise vector which "
"contains the sea level rise metric of interest. "
@ -360,7 +399,7 @@ MODEL_SPEC = {
"outputs": {
"coastal_exposure.gpkg": {
"about": "This point vector file contains the final outputs of the model. The points are created based on the input model resolution, landmass, and AOI.",
"geometries": spec_utils.POINT,
"geometries": spec.POINT,
"fields": FINAL_OUTPUT_FIELDS
},
"coastal_exposure.csv": {
@ -374,7 +413,7 @@ MODEL_SPEC = {
"intermediate_exposure.gpkg": {
"about": (
"Shore points with associated exposure variables"),
"geometries": spec_utils.POINT,
"geometries": spec.POINT,
"fields": {
"shore_id": {
"type": "integer",
@ -410,7 +449,7 @@ MODEL_SPEC = {
"geomorphology_protected.gpkg": {
"about": (
"Geomorphology vector reprojected to match the AOI"),
"geometries": spec_utils.LINES,
"geometries": spec.LINES,
"fields": {
"rank": {
"type": "option_string",
@ -435,7 +474,7 @@ MODEL_SPEC = {
"of either the geomorphology or landmass polygon "
"inputs. Editing the geometory of one or both in "
"GIS could help resolve this."),
"geometries": spec_utils.POINT,
"geometries": spec.POINT,
"fields": {
"shore_id": {
"type": "integer",
@ -539,7 +578,7 @@ MODEL_SPEC = {
"clipped_projected_landmass.gpkg": {
"about": "Clipped and reprojected landmass map",
"fields": {},
"geometries": spec_utils.POLYGONS
"geometries": spec.POLYGONS
},
"landmass_line_index.pickle": {
"about": "Pickled landmass index"
@ -551,7 +590,7 @@ MODEL_SPEC = {
},
"shore_points.gpkg": {
"about": "Map of shore points",
"geometries": spec_utils.POINT,
"geometries": spec.POINT,
"fields": {
"shore_id": {
"type": "integer",
@ -562,7 +601,7 @@ MODEL_SPEC = {
"tmp_clipped_landmass.gpkg": {
"about": "Clipped landmass map",
"fields": {},
"geometries": spec_utils.POLYGONS
"geometries": spec.POLYGONS
}
}
},
@ -585,7 +624,7 @@ MODEL_SPEC = {
},
"fetch_points.gpkg": {
"about": "Shore points with added fetch ray data",
"geometries": spec_utils.POINT,
"geometries": spec.POINT,
"fields": {
**WWIII_FIELDS,
"fdist_[SECTOR]": {
@ -606,7 +645,7 @@ MODEL_SPEC = {
"fetch_rays.gpkg": {
"about": (
"Map of fetch rays around each shore point."),
"geometries": spec_utils.LINESTRING,
"geometries": spec.LINESTRING,
"fields": {
"fetch_dist": {
"about": "Fetch distance along the ray",
@ -628,7 +667,7 @@ MODEL_SPEC = {
},
"wave_energies.gpkg": {
"about": "Shore points with associated wave energy data",
"geometries": spec_utils.POINT,
"geometries": spec.POINT,
"fields": {
"E_ocean": {
"about": (
@ -685,16 +724,16 @@ MODEL_SPEC = {
"wind.pickle": {"about": "Pickled wind data"},
"wwiii_shore_points.gpkg": {
"about": "WaveWatch 3 data interpolated to shore points",
"geometries": spec_utils.POINT,
"geometries": spec.POINT,
"fields": WWIII_FIELDS
}
}
}
}
},
"taskgraph_cache": spec_utils.TASKGRAPH_DIR
"taskgraph_cache": spec.TASKGRAPH_DIR
}
}
})
_N_FETCH_RAYS = 16
@ -2299,8 +2338,8 @@ def _schedule_habitat_tasks(
list of pickle file path strings
"""
habitat_dataframe = validation.get_validated_dataframe(
habitat_table_path, **MODEL_SPEC['args']['habitat_table_path']
habitat_dataframe = MODEL_SPEC.get_input(
'habitat_table_path').get_validated_dataframe(habitat_table_path
).rename(columns={'protection distance (m)': 'distance'})
habitat_task_list = []
@ -2829,10 +2868,8 @@ def assemble_results_and_calculate_exposure(
with open(pickle_path, 'rb') as file:
final_values_dict[var_name] = pickle.load(file)
habitat_df = validation.get_validated_dataframe(
habitat_protection_path, **MODEL_SPEC['outputs']['intermediate'][
'contents']['habitats']['contents']['habitat_protection.csv']
).rename(columns={'r_hab': 'R_hab'})
habitat_df = utils.read_csv_to_dataframe(
habitat_protection_path).rename(columns={'r_hab': 'R_hab'})
output_layer.StartTransaction()
for feature in output_layer:
shore_id = feature.GetField(SHORE_ID_FIELD)
@ -3464,8 +3501,7 @@ def validate(args, limit_to=None):
be an empty list if validation succeeds.
"""
validation_warnings = validation.validate(
args, MODEL_SPEC['args'], MODEL_SPEC['args_with_spatial_overlap'])
validation_warnings = validation.validate(args, MODEL_SPEC)
invalid_keys = validation.get_invalid_keys(validation_warnings)
sufficient_keys = validation.get_sufficient_keys(args)
@ -3476,8 +3512,8 @@ def validate(args, limit_to=None):
'slr_field' in sufficient_keys):
fieldnames = validation.load_fields_from_vector(
args['slr_vector_path'])
error_msg = validation.check_option_string(args['slr_field'],
fieldnames)
error_msg = spec.OptionStringInput(
options=fieldnames).validate(args['slr_field'])
if error_msg:
validation_warnings.append((['slr_field'], error_msg))
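The replacement above validates the field name through the spec API rather than a standalone helper. A minimal sketch of that pattern in isolation, assuming the vector path and field name below are hypothetical placeholders:
# Hedged sketch of the option-string check used above; 'sea_level_rise.gpkg'
# and 'Trend' are hypothetical placeholders.
from natcap.invest import spec, validation
fieldnames = validation.load_fields_from_vector('sea_level_rise.gpkg')
error_msg = spec.OptionStringInput(options=fieldnames).validate('Trend')
if error_msg:
    # error_msg describes why 'Trend' is not an allowed option
    print(error_msg)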

View File

@ -12,11 +12,10 @@ from osgeo import gdal
from osgeo import osr
from . import gettext
from . import spec_utils
from . import spec
from . import utils
from . import validation
from .crop_production_regression import NUTRIENTS
from .model_metadata import MODEL_METADATA
from .unit_registry import u
LOGGER = logging.getLogger(__name__)
@ -241,11 +240,17 @@ nutrient_units = {
"vitk": u.microgram/u.hectogram, # vitamin K
}
MODEL_SPEC = {
MODEL_SPEC = spec.build_model_spec({
"model_id": "crop_production_percentile",
"model_name": MODEL_METADATA["crop_production_percentile"].model_title,
"pyname": MODEL_METADATA["crop_production_percentile"].pyname,
"userguide": MODEL_METADATA["crop_production_percentile"].userguide,
"model_title": gettext("Crop Production: Percentile"),
"userguide": "crop_production.html",
"aliases": ("cpp",),
"ui_spec": {
"order": [
['workspace_dir', 'results_suffix'],
['model_data_path', 'landcover_raster_path', 'landcover_to_crop_table_path', 'aggregate_polygon_path']
]
},
"args_with_spatial_overlap": {
"spatial_keys": [
"landcover_raster_path",
@ -254,11 +259,11 @@ MODEL_SPEC = {
"different_projections_ok": True,
},
"args": {
"workspace_dir": spec_utils.WORKSPACE,
"results_suffix": spec_utils.SUFFIX,
"n_workers": spec_utils.N_WORKERS,
"workspace_dir": spec.WORKSPACE,
"results_suffix": spec.SUFFIX,
"n_workers": spec.N_WORKERS,
"landcover_raster_path": {
**spec_utils.LULC,
**spec.LULC,
"projected": True,
"projection_units": u.meter
},
@ -279,7 +284,7 @@ MODEL_SPEC = {
"name": gettext("LULC to Crop Table")
},
"aggregate_polygon_path": {
**spec_utils.AOI,
**spec.AOI,
"projected": True,
"required": False
},
@ -504,9 +509,9 @@ MODEL_SPEC = {
}
}
},
"taskgraph_cache": spec_utils.TASKGRAPH_DIR
"taskgraph_cache": spec.TASKGRAPH_DIR
}
}
})
_INTERMEDIATE_OUTPUT_DIR = 'intermediate_output'
@ -608,9 +613,9 @@ def execute(args):
None.
"""
crop_to_landcover_df = validation.get_validated_dataframe(
args['landcover_to_crop_table_path'],
**MODEL_SPEC['args']['landcover_to_crop_table_path'])
crop_to_landcover_df = MODEL_SPEC.get_input(
'landcover_to_crop_table_path').get_validated_dataframe(
args['landcover_to_crop_table_path'])
lucodes_in_table = set(list(
crop_to_landcover_df[_EXPECTED_LUCODE_TABLE_HEADER]))
@ -716,11 +721,11 @@ def execute(args):
climate_percentile_yield_table_path = os.path.join(
args['model_data_path'],
_CLIMATE_PERCENTILE_TABLE_PATTERN % crop_name)
crop_climate_percentile_df = validation.get_validated_dataframe(
climate_percentile_yield_table_path,
**MODEL_SPEC['args']['model_data_path']['contents'][
'climate_percentile_yield_tables']['contents'][
'[CROP]_percentile_yield_table.csv'])
crop_climate_percentile_df = MODEL_SPEC.get_input(
'model_data_path').contents.get(
'climate_percentile_yield_tables').contents.get(
'[CROP]_percentile_yield_table.csv').get_validated_dataframe(
climate_percentile_yield_table_path)
yield_percentile_headers = [
x for x in crop_climate_percentile_df.columns if x != 'climate_bin']
@ -873,9 +878,11 @@ def execute(args):
# both 'crop_nutrient.csv' and 'crop' are known data/header values for
# this model data.
nutrient_df = validation.get_validated_dataframe(
os.path.join(args['model_data_path'], 'crop_nutrient.csv'),
**MODEL_SPEC['args']['model_data_path']['contents']['crop_nutrient.csv'])
nutrient_df = MODEL_SPEC.get_input(
'model_data_path').contents.get(
'crop_nutrient.csv').get_validated_dataframe(
os.path.join(args['model_data_path'], 'crop_nutrient.csv'))
result_table_path = os.path.join(
output_dir, 'result_table%s.csv' % file_suffix)
@ -1265,5 +1272,4 @@ def validate(args, limit_to=None):
the error message in the second part of the tuple. This should
be an empty list if validation succeeds.
"""
return validation.validate(
args, MODEL_SPEC['args'], MODEL_SPEC['args_with_spatial_overlap'])
return validation.validate(args, MODEL_SPEC)
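The dataframe and validation calls above illustrate the new spec-object API that replaces dict-style lookups into ``MODEL_SPEC['args']``. A minimal sketch of how a caller might use it, assuming the CSV path and workspace below are placeholders:
# Hedged illustration of the spec-object access pattern; the CSV path and
# workspace directory are hypothetical placeholders.
from natcap.invest import crop_production_percentile as cpp
# Look up the input spec by key, then read and validate the table against
# the column definitions declared in MODEL_SPEC.
table_input = cpp.MODEL_SPEC.get_input('landcover_to_crop_table_path')
crop_df = table_input.get_validated_dataframe('landcover_to_crop_table.csv')
# Validation now takes the whole spec object instead of MODEL_SPEC['args'].
warnings = cpp.validate({'workspace_dir': 'workspace'})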

View File

@ -10,10 +10,9 @@ from osgeo import gdal
from osgeo import osr
from . import gettext
from . import spec_utils
from . import spec
from . import utils
from . import validation
from .model_metadata import MODEL_METADATA
from .unit_registry import u
LOGGER = logging.getLogger(__name__)
@ -67,21 +66,27 @@ NUTRIENTS = [
("vitk", "vitamin K", u.microgram/u.hectogram)
]
MODEL_SPEC = {
MODEL_SPEC = spec.build_model_spec({
"model_id": "crop_production_regression",
"model_name": MODEL_METADATA["crop_production_regression"].model_title,
"pyname": MODEL_METADATA["crop_production_regression"].pyname,
"userguide": MODEL_METADATA["crop_production_regression"].userguide,
"model_title": gettext("Crop Production: Regression"),
"userguide": "crop_production.html",
"aliases": ("cpr",),
"ui_spec": {
"order": [
['workspace_dir', 'results_suffix'],
['model_data_path', 'landcover_raster_path', 'landcover_to_crop_table_path', 'fertilization_rate_table_path', 'aggregate_polygon_path'],
]
},
"args_with_spatial_overlap": {
"spatial_keys": ["landcover_raster_path", "aggregate_polygon_path"],
"different_projections_ok": True,
},
"args": {
"workspace_dir": spec_utils.WORKSPACE,
"results_suffix": spec_utils.SUFFIX,
"n_workers": spec_utils.N_WORKERS,
"workspace_dir": spec.WORKSPACE,
"results_suffix": spec.SUFFIX,
"n_workers": spec.N_WORKERS,
"landcover_raster_path": {
**spec_utils.LULC,
**spec.LULC,
"projected": True,
"projection_units": u.meter
},
@ -121,7 +126,7 @@ MODEL_SPEC = {
"name": gettext("fertilization rate table")
},
"aggregate_polygon_path": {
**spec_utils.AOI,
**spec.AOI,
"required": False
},
"model_data_path": {
@ -270,7 +275,7 @@ MODEL_SPEC = {
"contents": {
"aggregate_vector.shp": {
"about": "Copy of input AOI vector",
"geometries": spec_utils.POLYGONS,
"geometries": spec.POLYGONS,
"fields": {}
},
"clipped_[CROP]_climate_bin_map.tif": {
@ -323,9 +328,9 @@ MODEL_SPEC = {
}
}
},
"taskgraph_cache": spec_utils.TASKGRAPH_DIR
"taskgraph_cache": spec.TASKGRAPH_DIR
}
}
})
_INTERMEDIATE_OUTPUT_DIR = 'intermediate_output'
@ -497,13 +502,13 @@ def execute(args):
LOGGER.info(
"Checking if the landcover raster is missing lucodes")
crop_to_landcover_df = validation.get_validated_dataframe(
args['landcover_to_crop_table_path'],
**MODEL_SPEC['args']['landcover_to_crop_table_path'])
crop_to_landcover_df = MODEL_SPEC.get_input(
'landcover_to_crop_table_path').get_validated_dataframe(
args['landcover_to_crop_table_path'])
crop_to_fertilization_rate_df = validation.get_validated_dataframe(
args['fertilization_rate_table_path'],
**MODEL_SPEC['args']['fertilization_rate_table_path'])
crop_to_fertilization_rate_df = MODEL_SPEC.get_input(
'fertilization_rate_table_path').get_validated_dataframe(
args['fertilization_rate_table_path'])
lucodes_in_table = set(list(
crop_to_landcover_df[_EXPECTED_LUCODE_TABLE_HEADER]))
@ -588,12 +593,11 @@ def execute(args):
task_name='crop_climate_bin')
dependent_task_list.append(crop_climate_bin_task)
crop_regression_df = validation.get_validated_dataframe(
os.path.join(args['model_data_path'],
_REGRESSION_TABLE_PATTERN % crop_name),
**MODEL_SPEC['args']['model_data_path']['contents'][
'climate_regression_yield_tables']['contents'][
'[CROP]_regression_yield_table.csv'])
crop_regression_df = MODEL_SPEC.get_input('model_data_path').contents.get(
'climate_regression_yield_tables').contents.get(
'[CROP]_regression_yield_table.csv').get_validated_dataframe(
os.path.join(args['model_data_path'],
_REGRESSION_TABLE_PATTERN % crop_name))
for _, row in crop_regression_df.iterrows():
for header in _EXPECTED_REGRESSION_TABLE_HEADERS:
if numpy.isnan(row[header]):
@ -812,9 +816,9 @@ def execute(args):
# both 'crop_nutrient.csv' and 'crop' are known data/header values for
# this model data.
nutrient_df = validation.get_validated_dataframe(
os.path.join(args['model_data_path'], 'crop_nutrient.csv'),
**MODEL_SPEC['args']['model_data_path']['contents']['crop_nutrient.csv'])
nutrient_df = MODEL_SPEC.get_input('model_data_path').contents.get(
'crop_nutrient.csv').get_validated_dataframe(
os.path.join(args['model_data_path'], 'crop_nutrient.csv'))
LOGGER.info("Generating report table")
crop_names = list(crop_to_landcover_df.index)
@ -1177,5 +1181,4 @@ def validate(args, limit_to=None):
be an empty list if validation succeeds.
"""
return validation.validate(
args, MODEL_SPEC['args'], MODEL_SPEC['args_with_spatial_overlap'])
return validation.validate(args, MODEL_SPEC)

View File

@ -34,9 +34,10 @@ import warnings
from osgeo import gdal
from . import spec_utils
from . import spec
from . import utils
from . import validation
from . import models
try:
from . import __version__
@ -55,10 +56,10 @@ ARGS_LOG_LEVEL = 100 # define high log level so it should always show in logs
DATASTACK_EXTENSION = '.invest.tar.gz'
PARAMETER_SET_EXTENSION = '.invest.json'
DATASTACK_PARAMETER_FILENAME = 'parameters' + PARAMETER_SET_EXTENSION
UNKNOWN = 'UNKNOWN'
ParameterSet = collections.namedtuple('ParameterSet',
'args model_name invest_version')
'args model_id invest_version')
def _tarfile_safe_extract(archive_path, dest_dir_path):
@ -96,54 +97,7 @@ def _tarfile_safe_extract(archive_path, dest_dir_path):
safe_extract(tar, dest_dir_path)
def _copy_spatial_files(spatial_filepath, target_dir):
"""Copy spatial files to a new directory.
Args:
spatial_filepath (str): The filepath to a GDAL-supported file.
target_dir (str): The directory where all component files of
``spatial_filepath`` should be copied. If this directory does not
exist, it will be created.
Returns:
filepath (str): The path to a representative file copied into the
``target_dir``. If possible, this will match the basename of
``spatial_filepath``, so if someone provides an ESRI Shapefile called
``my_vector.shp``, the return value will be ``os.path.join(target_dir,
my_vector.shp)``.
"""
LOGGER.info(f'Copying {spatial_filepath} --> {target_dir}')
if not os.path.exists(target_dir):
os.makedirs(target_dir)
source_basename = os.path.basename(spatial_filepath)
return_filepath = None
spatial_file = gdal.OpenEx(spatial_filepath)
for member_file in spatial_file.GetFileList():
# ArcGIS Binary/Grid format includes the directory in the file listing.
# The parent directory isn't strictly needed, so we can just skip it.
if os.path.isdir(member_file):
continue
target_basename = os.path.basename(member_file)
target_filepath = os.path.join(target_dir, target_basename)
if source_basename == target_basename:
return_filepath = target_filepath
shutil.copyfile(member_file, target_filepath)
spatial_file = None
# I can't conceive of a case where the basename of the source file does not
# match any of the member file basenames, but just in case there's a
# weird GDAL driver that does this, it seems reasonable to fall back to
# whichever of the member files was most recent.
if not return_filepath:
return_filepath = target_filepath
return return_filepath
def format_args_dict(args_dict, model_name):
def format_args_dict(args_dict, model_id):
"""Nicely format an arguments dictionary for writing to a stream.
If printed to a console, the returned string will be aligned in two columns
@ -152,7 +106,7 @@ def format_args_dict(args_dict, model_name):
Args:
args_dict (dict): The args dictionary to format.
model_name (string): The model name (in python package import format)
model_id (string): The model ID (e.g. carbon)
Returns:
A formatted, unicode string.
@ -163,12 +117,11 @@ def format_args_dict(args_dict, model_name):
if len(sorted_args) > 0:
max_key_width = max(len(x[0]) for x in sorted_args)
format_str = "%-" + str(max_key_width) + "s %s"
format_str = f"%-{max_key_width}s %s"
args_string = '\n'.join([format_str % (arg) for arg in sorted_args])
args_string = "Arguments for InVEST %s %s:\n%s\n" % (model_name,
__version__,
args_string)
args_string = (
f"Arguments for InVEST {model_id} {__version__}:\n{args_string}\n")
return args_string
@ -190,7 +143,7 @@ def get_datastack_info(filepath, extract_path=None):
* ``"logfile"`` when the file is a text logfile.
The second item of the tuple is a ParameterSet namedtuple with the raw
parsed args, modelname and invest version that the file was built with.
parsed args, model id and invest version that the file was built with.
"""
if tarfile.is_tarfile(filepath):
if not extract_path:
@ -215,20 +168,19 @@ def get_datastack_info(filepath, extract_path=None):
return 'logfile', extract_parameters_from_logfile(filepath)
def build_datastack_archive(args, model_name, datastack_path):
def build_datastack_archive(args, model_id, datastack_path):
"""Build an InVEST datastack from an arguments dict.
Args:
args (dict): The arguments dictionary to include in the datastack.
model_name (string): The python-importable module string of the model
these args are for.
model_id (string): The ID of the model these args are for.
datastack_path (string): The path to where the datastack archive
should be written.
Returns:
``None``
"""
module = importlib.import_module(name=model_name)
module = importlib.import_module(name=models.model_id_to_pyname[model_id])
args = args.copy()
temp_workspace = tempfile.mkdtemp(prefix='datastack_')
@ -247,10 +199,11 @@ def build_datastack_archive(args, model_name, datastack_path):
# For tracking existing files so we don't copy files in twice
files_found = {}
LOGGER.debug(f'Keys: {sorted(args.keys())}')
args_spec = module.MODEL_SPEC['args']
spatial_types = {'raster', 'vector'}
file_based_types = spatial_types.union({'csv', 'file', 'directory'})
spatial_types = {spec.SingleBandRasterInput, spec.VectorInput,
spec.RasterOrVectorInput}
file_based_types = spatial_types.union({
spec.CSVInput, spec.FileInput, spec.DirectoryInput})
rewritten_args = {}
for key in args:
# Allow the model to override specific arguments in datastack archive
@ -283,11 +236,11 @@ def build_datastack_archive(args, model_name, datastack_path):
LOGGER.info(f'Starting to archive arg "{key}": {args[key]}')
# Possible that a user might pass an args key that doesn't belong to
# this model. Skip if so.
if key not in args_spec:
if key not in module.MODEL_SPEC.inputs:
LOGGER.info(f'Skipping arg {key}; not in model MODEL_SPEC')
input_type = args_spec[key]['type']
if input_type in file_based_types:
input_spec = module.MODEL_SPEC.get_input(key)
if type(input_spec) in file_based_types:
if args[key] in {None, ''}:
LOGGER.info(
f'Skipping key {key}, value is empty and cannot point to '
@ -307,22 +260,16 @@ def build_datastack_archive(args, model_name, datastack_path):
rewritten_args[key] = files_found[source_path]
continue
if input_type == 'csv':
if type(input_spec) is spec.CSVInput:
# check the CSV for columns that may be spatial.
# But also, the columns specification might not be listed, so don't
# require that 'columns' exists in the MODEL_SPEC.
spatial_columns = []
if 'columns' in args_spec[key]:
for col_name, col_definition in (
args_spec[key]['columns'].items()):
# Type attribute may be a string (one type) or set
# (multiple types allowed), so always convert to a set for
# easier comparison.
col_types = col_definition['type']
if isinstance(col_types, str):
col_types = set([col_types])
if col_types.intersection(spatial_types):
spatial_columns.append(col_name)
if input_spec.columns:
for col_spec in input_spec.columns:
if type(col_spec) in spatial_types:
spatial_columns.append(col_spec.id)
LOGGER.debug(f'Detected spatial columns: {spatial_columns}')
target_csv_path = os.path.join(
@ -335,8 +282,7 @@ def build_datastack_archive(args, model_name, datastack_path):
contained_files_dir = os.path.join(
data_dir, f'{key}_csv_data')
dataframe = validation.get_validated_dataframe(
source_path, **args_spec[key])
dataframe = input_spec.get_validated_dataframe(source_path)
csv_source_dir = os.path.abspath(os.path.dirname(source_path))
for spatial_column_name in spatial_columns:
# Iterate through the spatial columns, identify the set of
@ -377,7 +323,7 @@ def build_datastack_archive(args, model_name, datastack_path):
target_dir = os.path.join(
contained_files_dir,
f'{row_index}_{basename}')
target_filepath = _copy_spatial_files(
target_filepath = utils.copy_spatial_files(
source_filepath, target_dir)
target_filepath = os.path.relpath(
target_filepath, data_dir)
@ -396,14 +342,14 @@ def build_datastack_archive(args, model_name, datastack_path):
target_arg_value = target_csv_path
files_found[source_path] = target_arg_value
elif input_type == 'file':
elif type(input_spec) is spec.FileInput:
target_filepath = os.path.join(
data_dir, f'{key}_file')
shutil.copyfile(source_path, target_filepath)
target_arg_value = target_filepath
files_found[source_path] = target_arg_value
elif input_type == 'directory':
elif type(input_spec) is spec.DirectoryInput:
# copy the whole folder
target_directory = os.path.join(data_dir, f'{key}_directory')
os.makedirs(target_directory)
@ -423,22 +369,17 @@ def build_datastack_archive(args, model_name, datastack_path):
target_arg_value = target_directory
files_found[source_path] = target_arg_value
elif input_type in spatial_types:
elif type(input_spec) in spatial_types:
# Create a directory with a readable name, something like
# "aoi_path_vector" or "lulc_cur_path_raster".
spatial_dir = os.path.join(data_dir, f'{key}_{input_type}')
target_arg_value = _copy_spatial_files(
spatial_dir = os.path.join(data_dir, f'{key}_{input_spec.type}')
target_arg_value = utils.copy_spatial_files(
source_path, spatial_dir)
files_found[source_path] = target_arg_value
elif input_type == 'other':
# Note that no models currently use this to the best of my
# knowledge, so better to raise a NotImplementedError
raise NotImplementedError(
'The "other" MODEL_SPEC input type is not supported')
else:
LOGGER.debug(
f"Type {input_type} is not filesystem-based; "
f"Type {type(input_spec)} is not filesystem-based; "
"recording value directly")
# not a filesystem-based type
# Record the value directly
@ -453,17 +394,17 @@ def build_datastack_archive(args, model_name, datastack_path):
param_file_uri = os.path.join(temp_workspace,
'parameters' + PARAMETER_SET_EXTENSION)
parameter_set = build_parameter_set(
rewritten_args, model_name, param_file_uri, relative=True)
rewritten_args, model_id, param_file_uri, relative=True)
# write metadata for all files in args
keywords = [module.MODEL_SPEC['model_id'], 'InVEST']
keywords = [module.MODEL_SPEC.model_id, 'InVEST']
for k, v in args.items():
if isinstance(v, str) and os.path.isfile(v):
this_arg_spec = module.MODEL_SPEC['args'][k]
this_arg_spec = module.MODEL_SPEC.get_input(k)
# write metadata file to target location (in temp dir)
subdir = os.path.dirname(parameter_set['args'][k])
target_location = os.path.join(temp_workspace, subdir)
spec_utils.write_metadata_file(v, this_arg_spec, keywords,
spec.write_metadata_file(v, this_arg_spec, keywords,
out_workspace=target_location)
# Remove the handler before archiving the working dir (and the logfile)
@ -530,12 +471,12 @@ def extract_datastack_archive(datastack_path, dest_dir_path):
return new_args
def build_parameter_set(args, model_name, paramset_path, relative=False):
def build_parameter_set(args, model_id, paramset_path, relative=False):
"""Record a parameter set to a file on disk.
Args:
args (dict): The args dictionary to record to the parameter set.
model_name (string): An identifier string for the callable or InVEST
model_id (string): An identifier string for the callable or InVEST
model that would accept the arguments given.
paramset_path (string): The path to the file on disk where the
parameters should be recorded.
@ -584,7 +525,7 @@ def build_parameter_set(args, model_name, paramset_path, relative=False):
return linux_style_path
return args_param
parameter_data = {
'model_name': model_name,
'model_id': model_id,
'invest_version': __version__,
'args': _recurse(args)
}
@ -612,8 +553,7 @@ def extract_parameter_set(paramset_path):
args (dict): The arguments dict for the callable
invest_version (string): The version of InVEST used to record the
parameter set.
model_name (string): The name of the callable or model that these
arguments are intended for.
model_id (string): The ID of the model that these parameters are for.
"""
paramset_parent_dir = os.path.dirname(os.path.abspath(paramset_path))
with codecs.open(paramset_path, 'r', encoding='UTF-8') as paramset_file:
@ -650,9 +590,16 @@ def extract_parameter_set(paramset_path):
return args_param
return args_param
return ParameterSet(_recurse(read_params['args']),
read_params['model_name'],
read_params['invest_version'])
if 'model_id' in read_params:
# New style datastacks include the model ID
model_id = read_params['model_id']
else:
# Old style datastacks use the pyname (core models only, no plugins)
model_id = models.pyname_to_model_id[read_params['model_name']]
return ParameterSet(
args=_recurse(read_params['args']),
model_id=model_id,
invest_version=read_params['invest_version'])
def extract_parameters_from_logfile(logfile_path):
@ -671,10 +618,7 @@ def extract_parameters_from_logfile(logfile_path):
logfile_path (string): The path to an InVEST logfile on disk.
Returns:
An instance of the ParameterSet namedtuple. If a model name and InVEST
version cannot be parsed from the Arguments section of the logfile,
``ParameterSet.model_name`` and ``ParameterSet.invest_version`` will be
set to ``datastack.UNKNOWN``.
An instance of the ParameterSet namedtuple.
Raises:
ValueError - when no arguments could be parsed from the logfile.
@ -687,16 +631,21 @@ def extract_parameters_from_logfile(logfile_path):
if not args_started:
# Line would look something like this:
# "Arguments for InVEST carbon 3.4.1rc1:\n"
# (new style, using model id)
# or
# "Arguments for InVEST natcap.invest.carbon 3.4.1rc1:\n"
if line.startswith('Arguments'):
try:
modelname, invest_version = line.split(' ')[3:5]
invest_version = invest_version.replace(':', '')
except ValueError:
# Old-style logfiles don't provide the model name or
# version info.
modelname = UNKNOWN
invest_version = UNKNOWN
# (old style, using model pyname)
if line.startswith('Arguments for InVEST'):
identifier, invest_version = line.split(' ')[3:5]
if identifier in models.pyname_to_model_id:
# Old style logfiles use the pyname
# These will be for core models only, not plugins
model_id = models.pyname_to_model_id[identifier]
else:
# New style logfiles use the model id
model_id = identifier
invest_version = invest_version.replace(':', '')
args_started = True
continue
else:
@ -742,4 +691,4 @@ def extract_parameters_from_logfile(logfile_path):
pass
args_dict[args_key] = args_value
return ParameterSet(args_dict, modelname, invest_version)
return ParameterSet(args_dict, model_id, invest_version)
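The changes above replace the stored pyname with a model ID while keeping old datastacks and logfiles readable. A hedged sketch of the resolution logic, mirroring ``extract_parameter_set`` above (the payloads are illustrative only):
# Illustrative payloads only; real parameter sets are read from disk.
from natcap.invest import models
new_style = {'model_id': 'carbon', 'invest_version': '3.16.0', 'args': {}}
old_style = {'model_name': 'natcap.invest.carbon',
             'invest_version': '3.4.1', 'args': {}}
def resolve_model_id(read_params):
    if 'model_id' in read_params:
        # New-style datastacks store the model ID directly.
        return read_params['model_id']
    # Old-style datastacks stored the importable pyname (core models only).
    return models.pyname_to_model_id[read_params['model_name']]
assert resolve_model_id(new_style) == 'carbon'
assert resolve_model_id(old_style) == 'carbon'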

View File

@ -15,30 +15,36 @@ from osgeo import ogr
from osgeo import osr
from .. import gettext
from .. import spec_utils
from .. import spec
from .. import utils
from .. import validation
from ..model_metadata import MODEL_METADATA
from ..unit_registry import u
from . import delineateit_core
LOGGER = logging.getLogger(__name__)
MODEL_SPEC = {
MODEL_SPEC = spec.build_model_spec({
"model_id": "delineateit",
"model_name": MODEL_METADATA["delineateit"].model_title,
"pyname": MODEL_METADATA["delineateit"].pyname,
"userguide": MODEL_METADATA["delineateit"].userguide,
"model_title": gettext("DelineateIt"),
"userguide": "delineateit.html",
"aliases": (),
"ui_spec": {
"order": [
['workspace_dir', 'results_suffix'],
['dem_path', 'detect_pour_points', 'outlet_vector_path', 'skip_invalid_geometry'],
['snap_points', 'flow_threshold', 'snap_distance'],
]
},
"args_with_spatial_overlap": {
"spatial_keys": ["dem_path", "outlet_vector_path"],
"different_projections_ok": True,
},
"args": {
"workspace_dir": spec_utils.WORKSPACE,
"results_suffix": spec_utils.SUFFIX,
"n_workers": spec_utils.N_WORKERS,
"workspace_dir": spec.WORKSPACE,
"results_suffix": spec.SUFFIX,
"n_workers": spec.N_WORKERS,
"dem_path": {
**spec_utils.DEM,
**spec.DEM,
"projected": True
},
"detect_pour_points": {
@ -53,8 +59,9 @@ MODEL_SPEC = {
"outlet_vector_path": {
"type": "vector",
"fields": {},
"geometries": spec_utils.ALL_GEOMS,
"geometries": spec.ALL_GEOMS,
"required": "not detect_pour_points",
"allowed": "not detect_pour_points",
"about": gettext(
"A map of watershed outlets from which to delineate the "
"watersheds. Required if Detect Pour Points is not checked."),
@ -74,10 +81,11 @@ MODEL_SPEC = {
"name": gettext("snap points to the nearest stream")
},
"flow_threshold": {
**spec_utils.THRESHOLD_FLOW_ACCUMULATION,
**spec.THRESHOLD_FLOW_ACCUMULATION,
"required": "snap_points",
"allowed": "snap_points",
"about": gettext(
spec_utils.THRESHOLD_FLOW_ACCUMULATION["about"] +
spec.THRESHOLD_FLOW_ACCUMULATION["about"] +
" Required if Snap Points is selected."),
},
"snap_distance": {
@ -85,6 +93,7 @@ MODEL_SPEC = {
"type": "number",
"units": u.pixels,
"required": "snap_points",
"allowed": "snap_points",
"about": gettext(
"Maximum distance to relocate watershed outlet points in "
"order to snap them to a stream. Required if Snap Points "
@ -94,6 +103,7 @@ MODEL_SPEC = {
"skip_invalid_geometry": {
"type": "boolean",
"required": False,
"allowed": "not detect_pour_points",
"about": gettext(
"Skip delineation for any invalid geometries found in the "
"Outlet Features. Otherwise, an invalid geometry will cause "
@ -102,18 +112,18 @@ MODEL_SPEC = {
}
},
"outputs": {
"filled_dem.tif": spec_utils.FILLED_DEM,
"flow_direction.tif": spec_utils.FLOW_DIRECTION_D8,
"flow_accumulation.tif": spec_utils.FLOW_ACCUMULATION,
"filled_dem.tif": spec.FILLED_DEM,
"flow_direction.tif": spec.FLOW_DIRECTION_D8,
"flow_accumulation.tif": spec.FLOW_ACCUMULATION,
"preprocessed_geometries.gpkg": {
"about": (
"A vector containing only those geometries that the model can "
"verify are valid. The geometries appearing in this vector "
"will be the ones passed to watershed delineation."),
"geometries": spec_utils.ALL_GEOMS,
"geometries": spec.ALL_GEOMS,
"fields": {}
},
"streams.tif": spec_utils.STREAM,
"streams.tif": spec.STREAM,
"snapped_outlets.gpkg": {
"about": (
"A vector that indicates where outlet points (point "
@ -121,7 +131,7 @@ MODEL_SPEC = {
"Threshold Flow Accumulation and Pixel Distance to Snap "
"Outlet Points. Any non-point geometries will also have been "
"copied over to this vector, but will not have been altered."),
"geometries": spec_utils.POINT,
"geometries": spec.POINT,
"fields": {}
},
"watersheds.gpkg": {
@ -129,18 +139,18 @@ MODEL_SPEC = {
"A vector defining the areas that are upstream from the "
"snapped outlet points, where upstream area is defined by the "
"D8 flow algorithm implementation in PyGeoprocessing."),
"geometries": spec_utils.POLYGON,
"geometries": spec.POLYGON,
"fields": {}
},
"pour_points.gpkg": {
"about": (
"Points where water flows off the defined area of the map."),
"geometries": spec_utils.POINT,
"geometries": spec.POINT,
"fields": {}
},
"taskgraph_cache": spec_utils.TASKGRAPH_DIR
"taskgraph_cache": spec.TASKGRAPH_DIR
}
}
})
_OUTPUT_FILES = {
'preprocessed_geometries': 'preprocessed_geometries.gpkg',
@ -818,5 +828,4 @@ def validate(args, limit_to=None):
be an empty list if validation succeeds.
"""
return validation.validate(
args, MODEL_SPEC['args'], MODEL_SPEC['args_with_spatial_overlap'])
return validation.validate(args, MODEL_SPEC)

View File

@ -3,6 +3,7 @@
An implementation of the model described in 'Degradation in carbon stocks
near tropical forest edges', by Chaplin-Kramer et al. (2015).
"""
import copy
import logging
import os
import pickle
@ -21,10 +22,9 @@ from osgeo import gdal
from osgeo import ogr
from . import gettext
from . import spec_utils
from . import spec
from . import utils
from . import validation
from .model_metadata import MODEL_METADATA
from .unit_registry import u
LOGGER = logging.getLogger(__name__)
@ -35,23 +35,32 @@ DISTANCE_UPPER_BOUND = 500e3
# helpful to have a global nodata defined for the whole model
NODATA_VALUE = -1
MODEL_SPEC = {
MODEL_SPEC = spec.build_model_spec({
"model_id": "forest_carbon_edge_effect",
"model_name": MODEL_METADATA["forest_carbon_edge_effect"].model_title,
"pyname": MODEL_METADATA["forest_carbon_edge_effect"].pyname,
"userguide": MODEL_METADATA["forest_carbon_edge_effect"].userguide,
"model_title": gettext("Forest Carbon Edge Effect"),
"userguide": "carbon_edge.html",
"aliases": ("fc",),
"ui_spec": {
"order": [
['workspace_dir', 'results_suffix'],
['lulc_raster_path', 'biophysical_table_path', 'pools_to_calculate'],
['compute_forest_edge_effects', 'tropical_forest_edge_carbon_model_vector_path', 'n_nearest_model_points', 'biomass_to_carbon_conversion_factor'],
['aoi_vector_path']
]
},
"args_with_spatial_overlap": {
"spatial_keys": ["aoi_vector_path", "lulc_raster_path"],
},
"args": {
"workspace_dir": spec_utils.WORKSPACE,
"results_suffix": spec_utils.SUFFIX,
"n_workers": spec_utils.N_WORKERS,
"workspace_dir": spec.WORKSPACE,
"results_suffix": spec.SUFFIX,
"n_workers": spec.N_WORKERS,
"n_nearest_model_points": {
"expression": "value > 0 and value.is_integer()",
"type": "number",
"units": u.none,
"required": "compute_forest_edge_effects",
"allowed": "compute_forest_edge_effects",
"about": gettext(
"Number of closest regression models that are used when "
"calculating the total biomass. Each local model is linearly "
@ -63,7 +72,7 @@ MODEL_SPEC = {
"name": gettext("number of points to average")
},
"aoi_vector_path": {
**spec_utils.AOI,
**spec.AOI,
"projected": True,
"required": False
},
@ -71,7 +80,7 @@ MODEL_SPEC = {
"type": "csv",
"index_col": "lucode",
"columns": {
"lucode": spec_utils.LULC_TABLE_COLUMN,
"lucode": spec.LULC_TABLE_COLUMN,
"is_tropical_forest": {
"type": "boolean",
"about": gettext(
@ -115,8 +124,8 @@ MODEL_SPEC = {
"name": gettext("biophysical table")
},
"lulc_raster_path": {
**spec_utils.LULC,
"about": spec_utils.LULC['about'] + " " + gettext(
**spec.LULC,
"about": spec.LULC['about'] + " " + gettext(
"All values in this raster must "
"have corresponding entries in the Biophysical Table."),
"projected": True,
@ -171,10 +180,11 @@ MODEL_SPEC = {
"θ₃ parameter for the regression equation. "
"Used only for the asymptotic model.")}
},
"geometries": spec_utils.POLYGONS,
"geometries": spec.POLYGONS,
"projected": True,
"projection_units": u.meter,
"required": "compute_forest_edge_effects",
"allowed": "compute_forest_edge_effects",
"about": gettext(
"Map storing the optimal regression model for each tropical "
"subregion and the corresponding theta parameters for that "
@ -185,6 +195,7 @@ MODEL_SPEC = {
"biomass_to_carbon_conversion_factor": {
"type": "ratio",
"required": "compute_forest_edge_effects",
"allowed": "compute_forest_edge_effects",
"about": gettext(
"Proportion of forest edge biomass that is elemental carbon. "
"Required if Compute Forest Edge Effects is selected."),
@ -205,7 +216,7 @@ MODEL_SPEC = {
},
"aggregated_carbon_stocks.shp": {
"about": "AOI map with aggregated carbon statistics.",
"geometries": spec_utils.POLYGONS,
"geometries": spec.POLYGONS,
"fields": {
"c_sum": {
"type": "number",
@ -242,7 +253,7 @@ MODEL_SPEC = {
"about": (
"The regression parameters reprojected to match your "
"study area."),
"geometries": spec_utils.POLYGONS,
"geometries": spec.POLYGONS,
"fields": {}
},
"edge_distance.tif": {
@ -263,14 +274,14 @@ MODEL_SPEC = {
"about": (
"The Global Regression Models shapefile clipped "
"to the study area."),
"geometries": spec_utils.POLYGONS,
"geometries": spec.POLYGONS,
"fields": {}
}
}
},
"taskgraph_cache": spec_utils.TASKGRAPH_DIR
"taskgraph_cache": spec.TASKGRAPH_DIR
}
}
})
def execute(args):
@ -430,9 +441,9 @@ def execute(args):
# Map non-forest landcover codes to carbon biomasses
LOGGER.info('Calculating direct mapped carbon stocks')
carbon_maps = []
biophysical_df = validation.get_validated_dataframe(
args['biophysical_table_path'],
**MODEL_SPEC['args']['biophysical_table_path'])
biophysical_df = MODEL_SPEC.get_input(
'biophysical_table_path').get_validated_dataframe(
args['biophysical_table_path'])
pool_list = [('c_above', True)]
if args['pools_to_calculate'] == 'all':
pool_list.extend([
@ -660,8 +671,8 @@ def _calculate_lulc_carbon_map(
"""
# classify forest pixels from lulc
biophysical_df = validation.get_validated_dataframe(
biophysical_table_path, **MODEL_SPEC['args']['biophysical_table_path'])
biophysical_df = MODEL_SPEC.get_input(
'biophysical_table_path').get_validated_dataframe(biophysical_table_path)
lucode_to_per_cell_carbon = {}
@ -720,8 +731,9 @@ def _map_distance_from_tropical_forest_edge(
"""
# Build a list of forest lucodes
biophysical_df = validation.get_validated_dataframe(
biophysical_table_path, **MODEL_SPEC['args']['biophysical_table_path'])
biophysical_df = MODEL_SPEC.get_input(
'biophysical_table_path').get_validated_dataframe(
biophysical_table_path)
forest_codes = biophysical_df[biophysical_df['is_tropical_forest']].index.values
# Make a raster where 1 is non-forest landcover types and 0 is forest
@ -1133,25 +1145,9 @@ def validate(args, limit_to=None):
be an empty list if validation succeeds.
"""
validation_warnings = validation.validate(
args, MODEL_SPEC['args'], MODEL_SPEC['args_with_spatial_overlap'])
invalid_keys = set([])
for affected_keys, error_msg in validation_warnings:
for key in affected_keys:
invalid_keys.add(key)
if ('pools_to_calculate' not in invalid_keys and
'biophysical_table_path' not in invalid_keys):
if args['pools_to_calculate'] == 'all':
# other fields have already been checked by validate
required_fields = ['c_above', 'c_below', 'c_soil', 'c_dead']
error_msg = validation.check_csv(
args['biophysical_table_path'],
header_patterns=required_fields,
axis=1)
if error_msg:
validation_warnings.append(
(['biophysical_table_path'], error_msg))
return validation_warnings
model_spec = copy.deepcopy(MODEL_SPEC)
if 'pools_to_calculate' in args and args['pools_to_calculate'] == 'all':
model_spec.get_input('biophysical_table_path').columns.get('c_below').required = True
model_spec.get_input('biophysical_table_path').columns.get('c_soil').required = True
model_spec.get_input('biophysical_table_path').columns.get('c_dead').required = True
return validation.validate(args, model_spec)

View File

@ -11,10 +11,9 @@ import taskgraph
from osgeo import gdal
from . import gettext
from . import spec_utils
from . import spec
from . import utils
from . import validation
from .model_metadata import MODEL_METADATA
from .unit_registry import u
LOGGER = logging.getLogger(__name__)
@ -33,11 +32,18 @@ MISSING_MAX_DIST_MSG = gettext(
"Maximum distance value is missing for threats: {threat_list}.")
MISSING_WEIGHT_MSG = gettext("Weight value is missing for threats: {threat_list}.")
MODEL_SPEC = {
MODEL_SPEC = spec.build_model_spec({
"model_id": "habitat_quality",
"model_name": MODEL_METADATA["habitat_quality"].model_title,
"pyname": MODEL_METADATA["habitat_quality"].pyname,
"userguide": MODEL_METADATA["habitat_quality"].userguide,
"model_title": gettext("Habitat Quality"),
"userguide": "habitat_quality.html",
"aliases": ("hq",),
"ui_spec": {
"order": [
['workspace_dir', 'results_suffix'],
['lulc_cur_path', 'lulc_fut_path', 'lulc_bas_path'],
['threats_table_path', 'access_vector_path', 'sensitivity_table_path', 'half_saturation_constant'],
]
},
"args_with_spatial_overlap": {
"spatial_keys": [
"lulc_cur_path", "lulc_fut_path", "lulc_bas_path",
@ -45,11 +51,11 @@ MODEL_SPEC = {
"different_projections_ok": True,
},
"args": {
"workspace_dir": spec_utils.WORKSPACE,
"results_suffix": spec_utils.SUFFIX,
"n_workers": spec_utils.N_WORKERS,
"workspace_dir": spec.WORKSPACE,
"results_suffix": spec.SUFFIX,
"n_workers": spec.N_WORKERS,
"lulc_cur_path": {
**spec_utils.LULC,
**spec.LULC,
"projected": True,
"about": gettext(
"Map of LULC at present. All values in this raster must "
@ -57,7 +63,7 @@ MODEL_SPEC = {
"name": gettext("current land cover")
},
"lulc_fut_path": {
**spec_utils.LULC,
**spec.LULC,
"projected": True,
"required": False,
"about": gettext(
@ -68,7 +74,7 @@ MODEL_SPEC = {
"name": gettext("future land cover")
},
"lulc_bas_path": {
**spec_utils.LULC,
**spec.LULC,
"projected": True,
"required": False,
"about": gettext(
@ -169,7 +175,7 @@ MODEL_SPEC = {
"represents completely accessible.")
}
},
"geometries": spec_utils.POLYGONS,
"geometries": spec.POLYGONS,
"required": False,
"about": gettext(
"Map of the relative protection that legal, institutional, "
@ -181,7 +187,7 @@ MODEL_SPEC = {
"type": "csv",
"index_col": "lucode",
"columns": {
"lucode": spec_utils.LULC_TABLE_COLUMN,
"lucode": spec.LULC_TABLE_COLUMN,
"name": {
"type": "freestyle_string",
"required": False
@ -377,9 +383,9 @@ MODEL_SPEC = {
}
}
},
"taskgraph_cache": spec_utils.TASKGRAPH_DIR
"taskgraph_cache": spec.TASKGRAPH_DIR
}
}
})
# All output rasters besides rarity should be >= 0. Set nodata accordingly.
_OUT_NODATA = float(numpy.finfo(numpy.float32).min)
# Scaling parameter from User's Guide eq. 4 for quality of habitat
@ -448,12 +454,12 @@ def execute(args):
LOGGER.info("Checking Threat and Sensitivity tables for compliance")
# Get CSVs as dictionaries and ensure the key is a string for threats.
threat_df = validation.get_validated_dataframe(
args['threats_table_path'], **MODEL_SPEC['args']['threats_table_path']
).fillna('')
sensitivity_df = validation.get_validated_dataframe(
args['sensitivity_table_path'],
**MODEL_SPEC['args']['sensitivity_table_path'])
threat_df = MODEL_SPEC.get_input(
'threats_table_path').get_validated_dataframe(
args['threats_table_path']).fillna('')
sensitivity_df = MODEL_SPEC.get_input(
'sensitivity_table_path').get_validated_dataframe(
args['sensitivity_table_path'])
half_saturation_constant = float(args['half_saturation_constant'])
@ -1166,8 +1172,7 @@ def validate(args, limit_to=None):
the error message in the second part of the tuple. This should
be an empty list if validation succeeds.
"""
validation_warnings = validation.validate(
args, MODEL_SPEC['args'], MODEL_SPEC['args_with_spatial_overlap'])
validation_warnings = validation.validate(args, MODEL_SPEC)
invalid_keys = validation.get_invalid_keys(validation_warnings)
@ -1175,12 +1180,12 @@ def validate(args, limit_to=None):
"sensitivity_table_path" not in invalid_keys and
"threat_raster_folder" not in invalid_keys):
# Get CSVs as dictionaries and ensure the key is a string for threats.
threat_df = validation.get_validated_dataframe(
args['threats_table_path'],
**MODEL_SPEC['args']['threats_table_path']).fillna('')
sensitivity_df = validation.get_validated_dataframe(
args['sensitivity_table_path'],
**MODEL_SPEC['args']['sensitivity_table_path'])
threat_df = MODEL_SPEC.get_input(
'threats_table_path').get_validated_dataframe(
args['threats_table_path']).fillna('')
sensitivity_df = MODEL_SPEC.get_input(
'sensitivity_table_path').get_validated_dataframe(
args['sensitivity_table_path'])
# check that the threat names in the threats table match with the
# threats columns in the sensitivity table.

View File

@ -16,14 +16,13 @@ from osgeo import gdal
from osgeo import ogr
from osgeo import osr
from . import datastack
from . import gettext
from . import spec_utils
from . import spec
from . import utils
from . import validation
from .model_metadata import MODEL_METADATA
from .unit_registry import u
LOGGER = logging.getLogger(__name__)
# RESILIENCE stressor shorthand to use when parsing tables
@ -49,15 +48,26 @@ _DEFAULT_GTIFF_CREATION_OPTIONS = (
'TILED=YES', 'BIGTIFF=YES', 'COMPRESS=DEFLATE',
'BLOCKXSIZE=256', 'BLOCKYSIZE=256')
MODEL_SPEC = {
MODEL_SPEC = spec.build_model_spec({
"model_id": "habitat_risk_assessment",
"model_name": MODEL_METADATA["habitat_risk_assessment"].model_title,
"pyname": MODEL_METADATA["habitat_risk_assessment"].pyname,
"userguide": MODEL_METADATA["habitat_risk_assessment"].userguide,
"model_title": gettext("Habitat Risk Assessment"),
"userguide": "habitat_risk_assessment.html",
"aliases": ("hra",),
"ui_spec": {
"order": [
['workspace_dir', 'results_suffix'],
['info_table_path', 'criteria_table_path'],
['resolution', 'max_rating'],
['risk_eq', 'decay_eq'],
['aoi_vector_path'],
['n_overlapping_stressors'],
['visualize_outputs']
]
},
"args": {
"workspace_dir": spec_utils.WORKSPACE,
"results_suffix": spec_utils.SUFFIX,
"n_workers": spec_utils.N_WORKERS,
"workspace_dir": spec.WORKSPACE,
"results_suffix": spec.SUFFIX,
"n_workers": spec.N_WORKERS,
"info_table_path": {
"name": gettext("habitat stressor table"),
"about": gettext("A table describing each habitat and stressor."),
@ -81,7 +91,7 @@ MODEL_SPEC = {
"values besides 0 or 1 will be treated as 0.")
}},
"fields": {},
"geometries": spec_utils.ALL_GEOMS,
"geometries": spec.ALL_GEOMS,
"about": gettext(
"Map of where the habitat or stressor exists. For "
"rasters, a pixel value of 1 indicates presence of "
@ -173,7 +183,7 @@ MODEL_SPEC = {
}
},
"aoi_vector_path": {
**spec_utils.AOI,
**spec.AOI,
"projected": True,
"projection_units": u.meter,
"fields": {
@ -286,7 +296,7 @@ MODEL_SPEC = {
"about": (
"Map of habitat-specific risk visualized in gradient "
"color from white to red on a map."),
"geometries": spec_utils.POLYGON,
"geometries": spec.POLYGON,
"fields": {
"Risk Score": {
"type": "integer",
@ -301,7 +311,7 @@ MODEL_SPEC = {
"about": (
"Map of ecosystem risk visualized in gradient "
"color from white to red on a map."),
"geometries": spec_utils.POLYGON,
"geometries": spec.POLYGON,
"fields": {
"Risk Score": {
"type": "integer",
@ -314,7 +324,7 @@ MODEL_SPEC = {
},
"STRESSOR_[STRESSOR].geojson": {
"about": "Map of stressor extent visualized in orange color.",
"geometries": spec_utils.POLYGON,
"geometries": spec.POLYGON,
"fields": {}
},
"SUMMARY_STATISTICS.csv": {
@ -384,7 +394,7 @@ MODEL_SPEC = {
"polygonized_[HABITAT/STRESSOR].gpkg": {
"about": "Polygonized habitat or stressor map",
"fields": {},
"geometries": spec_utils.POLYGON
"geometries": spec.POLYGON
},
"reclass_[HABITAT]_[STRESSOR].tif": {
"about": (
@ -408,7 +418,7 @@ MODEL_SPEC = {
"were provided in a spatial vector format, it will be "
"reprojected to the AOI projection."),
"fields": {},
"geometries": spec_utils.POLYGONS
"geometries": spec.POLYGONS
},
"rewritten_[HABITAT/STRESSOR/CRITERIA].tif": {
"about": (
@ -428,16 +438,16 @@ MODEL_SPEC = {
"provided are simplified to 1/2 the user-defined "
"raster resolution in order to speed up rasterization."),
"fields": {},
"geometries": spec_utils.POLYGONS
"geometries": spec.POLYGONS
}
}
},
"taskgraph_cache": spec_utils.TASKGRAPH_DIR
"taskgraph_cache": spec.TASKGRAPH_DIR
}
}
})
_VALID_RISK_EQS = set(MODEL_SPEC['args']['risk_eq']['options'].keys())
_VALID_DECAY_TYPES = set(MODEL_SPEC['args']['decay_eq']['options'].keys())
_VALID_RISK_EQS = set(MODEL_SPEC.get_input('risk_eq').options.keys())
_VALID_DECAY_TYPES = set(MODEL_SPEC.get_input('decay_eq').options.keys())
def execute(args):
@ -1778,8 +1788,8 @@ def _parse_info_table(info_table_path):
info_table_path = os.path.abspath(info_table_path)
try:
table = validation.get_validated_dataframe(
info_table_path, **MODEL_SPEC['args']['info_table_path'])
table = MODEL_SPEC.get_input(
'info_table_path').get_validated_dataframe(info_table_path)
except ValueError as err:
if 'Index has duplicate keys' in str(err):
raise ValueError("Habitat and stressor names may not overlap.")
@ -2431,7 +2441,7 @@ def _override_datastack_archive_criteria_table_path(
os.path.splitext(os.path.basename(value))[0])
LOGGER.info(f"Copying spatial file {value} --> "
f"{dir_for_this_spatial_data}")
new_path = datastack._copy_spatial_files(
new_path = utils.copy_spatial_files(
value, dir_for_this_spatial_data)
criteria_table_array[row, col] = new_path
known_files[value] = new_path
@ -2461,4 +2471,4 @@ def validate(args, limit_to=None):
be an empty list if validation succeeds.
"""
return validation.validate(args, MODEL_SPEC['args'])
return validation.validate(args, MODEL_SPEC)

View File

@ -1,167 +0,0 @@
import dataclasses
from natcap.invest import gettext
@dataclasses.dataclass
class _MODELMETA:
"""Dataclass to store frequently used model metadata."""
model_title: str # display name for the model
pyname: str # importable python module name for the model
gui: str # importable python class for the corresponding Qt UI
userguide: str # name of the corresponding built userguide file
aliases: tuple # alternate names for the model, if any
MODEL_METADATA = {
'annual_water_yield': _MODELMETA(
model_title=gettext('Annual Water Yield'),
pyname='natcap.invest.annual_water_yield',
gui='annual_water_yield.AnnualWaterYield',
userguide='annual_water_yield.html',
aliases=('hwy', 'awy')),
'carbon': _MODELMETA(
model_title=gettext('Carbon Storage and Sequestration'),
pyname='natcap.invest.carbon',
gui='carbon.Carbon',
userguide='carbonstorage.html',
aliases=()),
'coastal_blue_carbon': _MODELMETA(
model_title=gettext('Coastal Blue Carbon'),
pyname='natcap.invest.coastal_blue_carbon.coastal_blue_carbon',
gui='cbc.CoastalBlueCarbon',
userguide='coastal_blue_carbon.html',
aliases=('cbc',)),
'coastal_blue_carbon_preprocessor': _MODELMETA(
model_title=gettext('Coastal Blue Carbon Preprocessor'),
pyname='natcap.invest.coastal_blue_carbon.preprocessor',
gui='cbc.CoastalBlueCarbonPreprocessor',
userguide='coastal_blue_carbon.html',
aliases=('cbc_pre',)),
'coastal_vulnerability': _MODELMETA(
model_title=gettext('Coastal Vulnerability'),
pyname='natcap.invest.coastal_vulnerability',
gui='coastal_vulnerability.CoastalVulnerability',
userguide='coastal_vulnerability.html',
aliases=('cv',)),
'crop_production_percentile': _MODELMETA(
model_title=gettext('Crop Production: Percentile'),
pyname='natcap.invest.crop_production_percentile',
gui='crop_production.CropProductionPercentile',
userguide='crop_production.html',
aliases=('cpp',)),
'crop_production_regression': _MODELMETA(
model_title=gettext('Crop Production: Regression'),
pyname='natcap.invest.crop_production_regression',
gui='crop_production.CropProductionRegression',
userguide='crop_production.html',
aliases=('cpr',)),
'delineateit': _MODELMETA(
model_title=gettext('DelineateIt'),
pyname='natcap.invest.delineateit.delineateit',
gui='delineateit.Delineateit',
userguide='delineateit.html',
aliases=()),
'forest_carbon_edge_effect': _MODELMETA(
model_title=gettext('Forest Carbon Edge Effect'),
pyname='natcap.invest.forest_carbon_edge_effect',
gui='forest_carbon.ForestCarbonEdgeEffect',
userguide='carbon_edge.html',
aliases=('fc',)),
'habitat_quality': _MODELMETA(
model_title=gettext('Habitat Quality'),
pyname='natcap.invest.habitat_quality',
gui='habitat_quality.HabitatQuality',
userguide='habitat_quality.html',
aliases=('hq',)),
'habitat_risk_assessment': _MODELMETA(
model_title=gettext('Habitat Risk Assessment'),
pyname='natcap.invest.hra',
gui='hra.HabitatRiskAssessment',
userguide='habitat_risk_assessment.html',
aliases=('hra',)),
'ndr': _MODELMETA(
model_title=gettext('Nutrient Delivery Ratio'),
pyname='natcap.invest.ndr.ndr',
gui='ndr.Nutrient',
userguide='ndr.html',
aliases=()),
'pollination': _MODELMETA(
model_title=gettext('Crop Pollination'),
pyname='natcap.invest.pollination',
gui='pollination.Pollination',
userguide='croppollination.html',
aliases=()),
'recreation': _MODELMETA(
model_title=gettext('Visitation: Recreation and Tourism'),
pyname='natcap.invest.recreation.recmodel_client',
gui='recreation.Recreation',
userguide='recreation.html',
aliases=()),
'routedem': _MODELMETA(
model_title=gettext('RouteDEM'),
pyname='natcap.invest.routedem',
gui='routedem.RouteDEM',
userguide='routedem.html',
aliases=()),
'scenario_generator_proximity': _MODELMETA(
model_title=gettext('Scenario Generator: Proximity Based'),
pyname='natcap.invest.scenario_gen_proximity',
gui='scenario_gen.ScenarioGenProximity',
userguide='scenario_gen_proximity.html',
aliases=('sgp',)),
'scenic_quality': _MODELMETA(
model_title=gettext('Scenic Quality'),
pyname='natcap.invest.scenic_quality.scenic_quality',
gui='scenic_quality.ScenicQuality',
userguide='scenic_quality.html',
aliases=('sq',)),
'sdr': _MODELMETA(
model_title=gettext('Sediment Delivery Ratio'),
pyname='natcap.invest.sdr.sdr',
gui='sdr.SDR',
userguide='sdr.html',
aliases=()),
'seasonal_water_yield': _MODELMETA(
model_title=gettext('Seasonal Water Yield'),
pyname='natcap.invest.seasonal_water_yield.seasonal_water_yield',
gui='seasonal_water_yield.SeasonalWaterYield',
userguide='seasonal_water_yield.html',
aliases=('swy',)),
'stormwater': _MODELMETA(
model_title=gettext('Urban Stormwater Retention'),
pyname='natcap.invest.stormwater',
gui='stormwater.Stormwater',
userguide='stormwater.html',
aliases=()),
'wave_energy': _MODELMETA(
model_title=gettext('Wave Energy Production'),
pyname='natcap.invest.wave_energy',
gui='wave_energy.WaveEnergy',
userguide='wave_energy.html',
aliases=()),
'wind_energy': _MODELMETA(
model_title=gettext('Wind Energy Production'),
pyname='natcap.invest.wind_energy',
gui='wind_energy.WindEnergy',
userguide='wind_energy.html',
aliases=()),
'urban_flood_risk_mitigation': _MODELMETA(
model_title=gettext('Urban Flood Risk Mitigation'),
pyname='natcap.invest.urban_flood_risk_mitigation',
gui='urban_flood_risk_mitigation.UrbanFloodRiskMitigation',
userguide='urban_flood_mitigation.html',
aliases=('ufrm',)),
'urban_cooling_model': _MODELMETA(
model_title=gettext('Urban Cooling'),
pyname='natcap.invest.urban_cooling_model',
gui='urban_cooling_model.UrbanCoolingModel',
userguide='urban_cooling_model.html',
aliases=('ucm',)),
'urban_nature_access': _MODELMETA(
model_title=gettext('Urban Nature Access'),
pyname='natcap.invest.urban_nature_access',
gui='urban_nature_access.UrbanNatureAccess',
userguide='urban_nature_access.html',
aliases=('una',)),
}

View File

@ -0,0 +1,59 @@
import importlib
import pkgutil
import natcap.invest
def is_invest_compliant_model(module):
"""Check if a python module is an invest model.
Args:
module (module): python module to check
Returns:
True if the module has a ``MODEL_SPEC`` attribute and
``execute`` and ``validate`` functions, False otherwise
"""
return (
hasattr(module, "execute") and callable(module.execute) and
hasattr(module, "validate") and callable(module.validate) and
# could also validate model spec structure
hasattr(module, "MODEL_SPEC"))
# pyname: importable name e.g. natcap.invest.carbon, natcap.invest.sdr.sdr
# model id: identifier e.g. coastal_blue_carbon
# model title: e.g. Coastal Blue Carbon
pyname_to_module = {}
# discover core invest models. we could maintain a list of these,
# but this way it's one less thing to update
for _, _name, _ispkg in pkgutil.iter_modules(natcap.invest.__path__):
if _name in {'__main__', 'cli', 'ui_server', 'datastack'}:
continue # avoid a circular import
_module = importlib.import_module(f'natcap.invest.{_name}')
if _ispkg:
for _, _sub_name, _ in pkgutil.iter_modules(_module.__path__):
_submodule = importlib.import_module(f'natcap.invest.{_name}.{_sub_name}')
if is_invest_compliant_model(_submodule):
pyname_to_module[f'natcap.invest.{_name}.{_sub_name}'] = _submodule
else:
if is_invest_compliant_model(_module):
pyname_to_module[f'natcap.invest.{_name}'] = _module
# discover plugins: identify packages whose name starts with invest-
# and meet the basic API criteria for an invest plugin
for _, _name, _ispkg in pkgutil.iter_modules():
if _name.startswith('invest'):
_module = importlib.import_module(_name)
if is_invest_compliant_model(_module):
pyname_to_module[_name] = _module
model_id_to_pyname = {}
pyname_to_model_id = {}
model_id_to_spec = {}
model_alias_to_id = {}
for _pyname, _model in pyname_to_module.items():
model_id_to_pyname[_model.MODEL_SPEC.model_id] = _pyname
pyname_to_model_id[_pyname] = _model.MODEL_SPEC.model_id
model_id_to_spec[_model.MODEL_SPEC.model_id] = _model.MODEL_SPEC
for _alias in _model.MODEL_SPEC.aliases:
model_alias_to_id[_alias] = _model.MODEL_SPEC.model_id
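A minimal sketch of a module that would satisfy ``is_invest_compliant_model`` and be picked up by the discovery loops above. The module name and the exact set of keys ``spec.build_model_spec`` requires are assumptions; the keys shown follow the core-model specs in this commit:
# invest_example_plugin.py -- hypothetical plugin; the discovery loop above
# finds it because its import name starts with 'invest'.
from natcap.invest import spec, validation
MODEL_SPEC = spec.build_model_spec({
    "model_id": "example_plugin",
    "model_title": "Example Plugin",
    "userguide": "example_plugin.html",
    "aliases": ("ep",),
    "ui_spec": {"order": [["workspace_dir", "results_suffix"]]},
    "args": {
        "workspace_dir": spec.WORKSPACE,
        "results_suffix": spec.SUFFIX,
        "n_workers": spec.N_WORKERS,
    },
    "outputs": {"taskgraph_cache": spec.TASKGRAPH_DIR},
})
def execute(args):
    """Run the model (a no-op in this sketch)."""
def validate(args, limit_to=None):
    return validation.validate(args, MODEL_SPEC)
Once imported, such a module lands in ``pyname_to_module``, and its ID, spec, and aliases become reachable through ``model_id_to_pyname``, ``model_id_to_spec``, and ``model_alias_to_id``.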

View File

@ -9,13 +9,13 @@ import pygeoprocessing
import pygeoprocessing.routing
import taskgraph
from osgeo import gdal
from osgeo import gdal_array
from osgeo import ogr
from .. import gettext
from .. import spec_utils
from .. import spec
from .. import utils
from .. import validation
from ..model_metadata import MODEL_METADATA
from ..sdr import sdr
from ..unit_registry import u
from . import ndr_core
@ -24,28 +24,37 @@ LOGGER = logging.getLogger(__name__)
MISSING_NUTRIENT_MSG = gettext('Either calc_n or calc_p must be True')
MODEL_SPEC = {
MODEL_SPEC = spec.build_model_spec({
"model_id": "ndr",
"model_name": MODEL_METADATA["ndr"].model_title,
"pyname": MODEL_METADATA["ndr"].pyname,
"userguide": MODEL_METADATA["ndr"].userguide,
"model_title": gettext("Nutrient Delivery Ratio"),
"userguide": "ndr.html",
"aliases": (),
"ui_spec": {
"order": [
['workspace_dir', 'results_suffix'],
['dem_path', 'lulc_path', 'runoff_proxy_path', 'watersheds_path', 'biophysical_table_path'],
['calc_p'],
['calc_n', 'subsurface_critical_length_n', 'subsurface_eff_n'],
['flow_dir_algorithm', 'threshold_flow_accumulation', 'k_param', 'runoff_proxy_av'],
]
},
"args_with_spatial_overlap": {
"spatial_keys": ["dem_path", "lulc_path", "runoff_proxy_path",
"watersheds_path"],
"different_projections_ok": True,
},
"args": {
"workspace_dir": spec_utils.WORKSPACE,
"results_suffix": spec_utils.SUFFIX,
"n_workers": spec_utils.N_WORKERS,
"workspace_dir": spec.WORKSPACE,
"results_suffix": spec.SUFFIX,
"n_workers": spec.N_WORKERS,
"dem_path": {
**spec_utils.DEM,
**spec.DEM,
"projected": True
},
"lulc_path": {
**spec_utils.LULC,
**spec.LULC,
"projected": True,
"about": spec_utils.LULC['about'] + " " + gettext(
"about": spec.LULC['about'] + " " + gettext(
"All values in this raster must "
"have corresponding entries in the Biophysical table.")
},
@ -65,7 +74,7 @@ MODEL_SPEC = {
"watersheds_path": {
"type": "vector",
"projected": True,
"geometries": spec_utils.POLYGONS,
"geometries": spec.POLYGONS,
"fields": {},
"about": gettext(
"Map of the boundaries of the watershed(s) over which to "
@ -76,24 +85,94 @@ MODEL_SPEC = {
"type": "csv",
"index_col": "lucode",
"columns": {
"lucode": spec_utils.LULC_TABLE_COLUMN,
"load_[NUTRIENT]": { # nitrogen or phosphorus nutrient loads
"lucode": spec.LULC_TABLE_COLUMN,
"load_type_p": {
"type": "option_string",
"required": True,
"options": {
"application-rate": {
"description": gettext(
"Treat the load values as nutrient "
"application rates (e.g. fertilizer, livestock "
"waste, ...)."
"The model will adjust the load using the "
"application rate and retention efficiency: "
"load_p * (1 - eff_p).")},
"measured-runoff": {
"description": gettext(
"Treat the load values as measured contaminant "
"runoff.")},
},
"about": gettext(
"Whether the nutrient load in column "
"load_p should be treated as "
"nutrient application rate or measured contaminant "
"runoff. 'application-rate' | 'measured-runoff'")
},
"load_type_n": {
"type": "option_string",
"required": True,
"options": {
"application-rate": {
"description": gettext(
"Treat the load values as nutrient "
"application rates (e.g. fertilizer, livestock "
"waste, ...)."
"The model will adjust the load using the "
"application rate and retention efficiency: "
"load_n * (1 - eff_n).")},
"measured-runoff": {
"description": gettext(
"Treat the load values as measured contaminant "
"runoff.")},
},
"about": gettext(
"Whether the nutrient load in column "
"load_n should be treated as "
"nutrient application rate or measured contaminant "
"runoff. 'application-rate' | 'measured-runoff'")
},
"load_n": { # nitrogen or phosphorus nutrient loads
"type": "number",
"units": u.kilogram/u.hectare/u.year,
"required": "calc_n",
"about": gettext(
"The nutrient loading for this land use class.")},
"eff_[NUTRIENT]": { # nutrient retention capacities
"The nitrogen loading for this land use class.")},
"load_p": {
"type": "number",
"units": u.kilogram/u.hectare/u.year,
"required": "calc_p",
"about": gettext(
"The phosphorus loading for this land use class.")},
"eff_n": {
"type": "ratio",
"required": "calc_n",
"about": gettext(
"Maximum nutrient retention efficiency. This is the "
"maximum proportion of the nutrient that is retained "
"Maximum nitrogen retention efficiency. This is the "
"maximum proportion of the nitrogen that is retained "
"on this LULC class.")},
"crit_len_[NUTRIENT]": { # nutrient critical lengths
"eff_p": {
"type": "ratio",
"required": "calc_p",
"about": gettext(
"Maximum phosphorus retention efficiency. This is the "
"maximum proportion of the phosphorus that is retained "
"on this LULC class.")},
"crit_len_n": {
"type": "number",
"units": u.meter,
"required": "calc_n",
"about": gettext(
"The distance after which it is assumed that this "
"LULC type retains the nutrient at its maximum "
"LULC type retains nitrogen at its maximum "
"capacity.")},
"crit_len_p": {
"type": "number",
"units": u.meter,
"required": "calc_p",
"about": gettext(
"The distance after which it is assumed that this "
"LULC type retains phosphorus at its maximum "
"capacity.")},
"proportion_subsurface_n": {
"type": "ratio",
@ -103,17 +182,15 @@ MODEL_SPEC = {
"is dissolved into the subsurface. By default, this "
"value should be set to 0, indicating that all "
"nutrients are delivered via surface flow. There is "
"no equivalent of this for phosphorus.")}
"no equivalent of this for phosphorus.")},
},
"about": gettext(
"A table mapping each LULC class to its biophysical "
"properties related to nutrient load and retention. Replace "
"'[NUTRIENT]' in the column names with 'n' or 'p' for "
"nitrogen or phosphorus respectively. Nitrogen data must be "
"provided if Calculate Nitrogen is selected. Phosphorus data "
"must be provided if Calculate Phosphorus is selected. All "
"LULC codes in the LULC raster must have corresponding "
"entries in this table."),
"properties related to nutrient load and retention. Nitrogen "
"data must be provided if Calculate Nitrogen is selected. "
"Phosphorus data must be provided if Calculate Phosphorus is "
"selected. All LULC codes in the LULC raster must have "
"corresponding entries in this table."),
"name": gettext("biophysical table")
},
"calc_p": {
@ -127,7 +204,7 @@ MODEL_SPEC = {
"name": gettext("calculate nitrogen")
},
"threshold_flow_accumulation": {
**spec_utils.THRESHOLD_FLOW_ACCUMULATION
**spec.THRESHOLD_FLOW_ACCUMULATION
},
"k_param": {
"type": "number",
@ -159,6 +236,7 @@ MODEL_SPEC = {
"type": "number",
"units": u.meter,
"required": "calc_n",
"allowed": "calc_n",
"name": gettext("subsurface critical length (nitrogen)"),
"about": gettext(
"The distance traveled (subsurface and downslope) after which "
@ -168,6 +246,7 @@ MODEL_SPEC = {
"subsurface_eff_n": {
"type": "ratio",
"required": "calc_n",
"allowed": "calc_n",
"name": gettext("subsurface maximum retention efficiency (nitrogen)"),
"about": gettext(
"The maximum nitrogen retention efficiency that can be "
@ -175,12 +254,12 @@ MODEL_SPEC = {
"retention due to biochemical degradation in soils. Required "
"if Calculate Nitrogen is selected.")
},
**spec_utils.FLOW_DIR_ALGORITHM
**spec.FLOW_DIR_ALGORITHM
},
"outputs": {
"watershed_results_ndr.gpkg": {
"about": "Vector with aggregated nutrient model results per watershed.",
"geometries": spec_utils.POLYGONS,
"geometries": spec.POLYGONS,
"fields": {
"p_surface_load": {
"type": "number",
@ -247,6 +326,7 @@ MODEL_SPEC = {
"units": u.kilogram/u.hectare
}}
},
"stream.tif": spec.STREAM,
"intermediate_outputs": {
"type": "directory",
"contents": {
@ -288,8 +368,8 @@ MODEL_SPEC = {
"about": "Effective phosphorus retention provided by the downslope flow path for each pixel",
"bands": {1: {"type": "ratio"}}
},
"flow_accumulation.tif": spec_utils.FLOW_ACCUMULATION,
"flow_direction.tif": spec_utils.FLOW_DIRECTION,
"flow_accumulation.tif": spec.FLOW_ACCUMULATION,
"flow_direction.tif": spec.FLOW_DIRECTION,
"ic_factor.tif": {
"about": "Index of connectivity",
"bands": {1: {"type": "ratio"}}
@ -346,7 +426,6 @@ MODEL_SPEC = {
"about": "Inverse of slope",
"bands": {1: {"type": "number", "units": u.none}}
},
"stream.tif": spec_utils.STREAM,
"sub_load_n.tif": {
"about": "Nitrogen loads for subsurface transport",
"bands": {1: {
@ -362,14 +441,14 @@ MODEL_SPEC = {
"about": "Above ground nitrogen loads",
"bands": {1: {
"type": "number",
"units": u.kilogram/u.year
"units": u.kilogram/u.hectare/u.year,
}}
},
"surface_load_p.tif": {
"about": "Above ground phosphorus loads",
"bands": {1: {
"type": "number",
"units": u.kilogram/u.year
"units": u.kilogram/u.hectare/u.year,
}}
},
"thresholded_slope.tif": {
@ -416,8 +495,8 @@ MODEL_SPEC = {
"about": "Runoff proxy input masked to exclude pixels outside the watershed",
"bands": {1: {"type": "number", "units": u.none}}
},
"filled_dem.tif": spec_utils.FILLED_DEM,
"slope.tif": spec_utils.SLOPE,
"filled_dem.tif": spec.FILLED_DEM,
"slope.tif": spec.SLOPE,
"subsurface_export_n.pickle": {
"about": "Pickled zonal statistics of nitrogen subsurface export"
},
@ -441,9 +520,9 @@ MODEL_SPEC = {
}
}
},
"taskgraph_cache": spec_utils.TASKGRAPH_DIR
"taskgraph_cache": spec.TASKGRAPH_DIR
}
}
})
_OUTPUT_BASE_FILES = {
'n_surface_export_path': 'n_surface_export.tif',
@ -451,6 +530,7 @@ _OUTPUT_BASE_FILES = {
'n_total_export_path': 'n_total_export.tif',
'p_surface_export_path': 'p_surface_export.tif',
'watershed_results_ndr_path': 'watershed_results_ndr.gpkg',
'stream_path': 'stream.tif'
}
INTERMEDIATE_DIR_NAME = 'intermediate_outputs'
@ -468,7 +548,6 @@ _INTERMEDIATE_BASE_FILES = {
's_accumulation_path': 's_accumulation.tif',
's_bar_path': 's_bar.tif',
's_factor_inverse_path': 's_factor_inverse.tif',
'stream_path': 'stream.tif',
'sub_load_n_path': 'sub_load_n.tif',
'surface_load_n_path': 'surface_load_n.tif',
'surface_load_p_path': 'surface_load_p.tif',
@ -591,9 +670,9 @@ def execute(args):
if args['calc_' + nutrient_id]:
nutrients_to_process.append(nutrient_id)
biophysical_df = validation.get_validated_dataframe(
args['biophysical_table_path'],
**MODEL_SPEC['args']['biophysical_table_path'])
biophysical_df = MODEL_SPEC.get_input(
'biophysical_table_path').get_validated_dataframe(
args['biophysical_table_path'])
# Ensure that if user doesn't explicitly assign a value,
# runoff_proxy_av = None
@ -650,7 +729,7 @@ def execute(args):
'mask_raster_path': f_reg['mask_path'],
'target_masked_raster_path': f_reg['masked_runoff_proxy_path'],
'target_dtype': gdal.GDT_Float32,
'default_nodata': _TARGET_NODATA,
'target_nodata': _TARGET_NODATA,
},
dependent_task_list=[mask_task, align_raster_task],
target_path_list=[f_reg['masked_runoff_proxy_path']],
@ -663,7 +742,7 @@ def execute(args):
'mask_raster_path': f_reg['mask_path'],
'target_masked_raster_path': f_reg['masked_dem_path'],
'target_dtype': gdal.GDT_Float32,
'default_nodata': float(numpy.finfo(numpy.float32).min),
'target_nodata': float(numpy.finfo(numpy.float32).min),
},
dependent_task_list=[mask_task, align_raster_task],
target_path_list=[f_reg['masked_dem_path']],
@ -676,7 +755,7 @@ def execute(args):
'mask_raster_path': f_reg['mask_path'],
'target_masked_raster_path': f_reg['masked_lulc_path'],
'target_dtype': gdal.GDT_Int32,
'default_nodata': numpy.iinfo(numpy.int32).min,
'target_nodata': numpy.iinfo(numpy.int32).min,
},
dependent_task_list=[mask_task, align_raster_task],
target_path_list=[f_reg['masked_lulc_path']],
@ -900,7 +979,10 @@ def execute(args):
func=_calculate_load,
args=(
f_reg['masked_lulc_path'],
biophysical_df[f'load_{nutrient}'],
biophysical_df[
[f'load_{nutrient}', f'eff_{nutrient}',
f'load_type_{nutrient}']].to_dict('index'),
nutrient,
load_path),
dependent_task_list=[align_raster_task, mask_lulc_task],
target_path_list=[load_path],
@ -1158,7 +1240,7 @@ def _create_mask_raster(source_raster_path, source_vector_path,
def _mask_raster(source_raster_path, mask_raster_path,
target_masked_raster_path, default_nodata, target_dtype):
target_masked_raster_path, target_nodata, target_dtype):
"""Using a raster of 1s and 0s, determine which pixels remain in output.
Args:
@ -1170,8 +1252,8 @@ def _mask_raster(source_raster_path, mask_raster_path,
target raster.
target_masked_raster_path (str): The path to where the target raster
should be written.
default_nodata (int, float, None): The nodata value that should be used
if ``source_raster_path`` does not have a defined nodata value.
target_nodata (int, float): The target nodata value that should match
``target_dtype``.
target_dtype (int): The ``gdal.GDT_*`` datatype of the target raster.
Returns:
@ -1179,22 +1261,20 @@ def _mask_raster(source_raster_path, mask_raster_path,
"""
source_raster_info = pygeoprocessing.get_raster_info(source_raster_path)
source_nodata = source_raster_info['nodata'][0]
nodata = source_nodata
if nodata is None:
nodata = default_nodata
target_numpy_dtype = gdal_array.GDALTypeCodeToNumericTypeCode(target_dtype)
def _mask_op(mask, raster):
result = numpy.full(mask.shape, nodata,
dtype=source_raster_info['numpy_type'])
result = numpy.full(mask.shape, target_nodata,
dtype=target_numpy_dtype)
valid_pixels = (
~pygeoprocessing.array_equals_nodata(raster, nodata) &
~pygeoprocessing.array_equals_nodata(raster, source_nodata) &
(mask == 1))
result[valid_pixels] = raster[valid_pixels]
return result
pygeoprocessing.raster_calculator(
[(mask_raster_path, 1), (source_raster_path, 1)], _mask_op,
target_masked_raster_path, target_dtype, nodata)
target_masked_raster_path, target_dtype, target_nodata)
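
For illustration, a minimal sketch of the mask operation above, using hypothetical 2x2 numpy arrays and a plain ``!=`` comparison where the real code uses ``pygeoprocessing.array_equals_nodata``: the output is pre-filled with the caller-supplied ``target_nodata`` in the target dtype, and only pixels that fall inside the mask and are not nodata in the source are copied through.

import numpy

mask = numpy.array([[1, 0], [1, 1]])                    # 1 = keep the pixel
source = numpy.array([[5.0, 7.0], [-9999.0, 3.0]], dtype=numpy.float32)
source_nodata = -9999.0                                 # hypothetical source nodata
target_nodata = float(numpy.finfo(numpy.float32).min)   # matches GDT_Float32

result = numpy.full(mask.shape, target_nodata, dtype=numpy.float32)
valid = (source != source_nodata) & (mask == 1)
result[valid] = source[valid]
# result -> [[5.0, target_nodata], [target_nodata, 3.0]]
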
def _add_fields_to_shapefile(field_pickle_map, target_vector_path):
@ -1256,7 +1336,7 @@ def validate(args, limit_to=None):
be an empty list if validation succeeds.
"""
spec_copy = copy.deepcopy(MODEL_SPEC['args'])
spec_copy = copy.deepcopy(MODEL_SPEC)
# Check required fields given the state of ``calc_n`` and ``calc_p``
nutrients_selected = []
for nutrient_letter in ('n', 'p'):
@ -1265,17 +1345,14 @@ def validate(args, limit_to=None):
for param in ['load', 'eff', 'crit_len']:
for nutrient in nutrients_selected:
spec_copy['biophysical_table_path']['columns'][f'{param}_{nutrient}'] = (
spec_copy['biophysical_table_path']['columns'][f'{param}_[NUTRIENT]'])
spec_copy['biophysical_table_path']['columns'][f'{param}_{nutrient}']['required'] = True
spec_copy['biophysical_table_path']['columns'].pop(f'{param}_[NUTRIENT]')
spec_copy.get_input('biophysical_table_path').columns.get(
f'{param}_{nutrient}').required = True
if 'n' in nutrients_selected:
spec_copy['biophysical_table_path']['columns']['proportion_subsurface_n'][
'required'] = True
spec_copy.get_input('biophysical_table_path').columns.get(
'proportion_subsurface_n').required = True
validation_warnings = validation.validate(
args, spec_copy, MODEL_SPEC['args_with_spatial_overlap'])
validation_warnings = validation.validate(args, spec_copy)
if not nutrients_selected:
validation_warnings.append(
@ -1322,13 +1399,18 @@ def _normalize_raster(base_raster_path_band, target_normalized_raster_path,
target_dtype=numpy.float32)
def _calculate_load(lulc_raster_path, lucode_to_load, target_load_raster):
def _calculate_load(
lulc_raster_path, lucode_to_load, nutrient_type, target_load_raster):
"""Calculate load raster by mapping landcover.
If load type is 'application-rate' adjust by ``1 - efficiency``.
Args:
lulc_raster_path (string): path to integer landcover raster.
lucode_to_load (dict): a mapping of landcover IDs to per-area
nutrient load.
lucode_to_load (dict): a mapping of landcover IDs to nutrient load,
efficiency, and load type. The load type value can be one of:
            [ 'measured-runoff' | 'application-rate' ].
nutrient_type (str): the nutrient type key ('p' | 'n').
target_load_raster (string): path to target raster that will have
load values (kg/ha) mapped to pixels based on LULC.
@ -1336,12 +1418,34 @@ def _calculate_load(lulc_raster_path, lucode_to_load, target_load_raster):
None.
"""
app_rate = 'application-rate'
measured_runoff = 'measured-runoff'
load_key = f'load_{nutrient_type}'
eff_key = f'eff_{nutrient_type}'
load_type_key = f'load_type_{nutrient_type}'
# Raise ValueError if unknown load_type
for key, value in lucode_to_load.items():
load_type = value[load_type_key]
        if load_type not in [app_rate, measured_runoff]:
# unknown load type, raise ValueError
raise ValueError(
'nutrient load type must be: '
f'"{app_rate}" | "{measured_runoff}". Instead '
f'found value of: "{load_type}".')
def _map_load_op(lucode_array):
"""Convert unit load to total load & handle nodata."""
"""Convert unit load to total load."""
result = numpy.empty(lucode_array.shape)
for lucode in numpy.unique(lucode_array):
try:
result[lucode_array == lucode] = (lucode_to_load[lucode])
if lucode_to_load[lucode][load_type_key] == measured_runoff:
result[lucode_array == lucode] = (
lucode_to_load[lucode][load_key])
elif lucode_to_load[lucode][load_type_key] == app_rate:
result[lucode_array == lucode] = (
lucode_to_load[lucode][load_key] * (
1 - lucode_to_load[lucode][eff_key]))
except KeyError:
raise KeyError(
'lucode: %d is present in the landuse raster but '
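
A worked example of the load-type branch in ``_map_load_op`` above, with hypothetical lucodes and values: under ``'application-rate'`` the load is scaled by ``(1 - eff)``, under ``'measured-runoff'`` it is used as-is.

# Hypothetical rows in the shape of the lucode_to_load dict passed in.
lucode_to_load = {
    1: {'load_n': 80.0, 'eff_n': 0.6, 'load_type_n': 'application-rate'},
    2: {'load_n': 80.0, 'eff_n': 0.6, 'load_type_n': 'measured-runoff'},
}
for lucode, row in lucode_to_load.items():
    if row['load_type_n'] == 'application-rate':
        load = row['load_n'] * (1 - row['eff_n'])  # 80 * (1 - 0.6) = 32 kg/ha/yr
    else:  # 'measured-runoff'
        load = row['load_n']                       # 80 kg/ha/yr, unchanged
    print(lucode, load)
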
@ -1362,8 +1466,7 @@ def _map_surface_load(
"""Calculate surface load from landcover raster.
Args:
modified_load_path (string): path to modified load raster with units
of kg/pixel.
modified_load_path (string): path to modified load raster.
lulc_raster_path (string): path to landcover raster.
lucode_to_subsurface_proportion (dict): maps landcover codes to
subsurface proportion values. Or if None, no subsurface transfer

View File

@ -15,25 +15,31 @@ from osgeo import gdal
from osgeo import ogr
from . import gettext
from . import spec_utils
from . import spec
from . import utils
from . import validation
from .model_metadata import MODEL_METADATA
from .unit_registry import u
LOGGER = logging.getLogger(__name__)
MODEL_SPEC = {
MODEL_SPEC = spec.build_model_spec({
"model_id": "pollination",
"model_name": MODEL_METADATA["pollination"].model_title,
"pyname": MODEL_METADATA["pollination"].pyname,
"userguide": MODEL_METADATA["pollination"].userguide,
"model_title": gettext("Crop Pollination"),
"userguide": "croppollination.html",
"aliases": (),
"ui_spec": {
"order": [
['workspace_dir', 'results_suffix'],
['landcover_raster_path', 'landcover_biophysical_table_path'],
['guild_table_path', 'farm_vector_path']
]
},
"args": {
"workspace_dir": spec_utils.WORKSPACE,
"results_suffix": spec_utils.SUFFIX,
"n_workers": spec_utils.N_WORKERS,
"workspace_dir": spec.WORKSPACE,
"results_suffix": spec.SUFFIX,
"n_workers": spec.N_WORKERS,
"landcover_raster_path": {
**spec_utils.LULC,
**spec.LULC,
"projected": True,
"about": gettext(
"Map of LULC codes. All values in this raster must have "
@ -92,7 +98,7 @@ MODEL_SPEC = {
"type": "csv",
"index_col": "lucode",
"columns": {
"lucode": spec_utils.LULC_TABLE_COLUMN,
"lucode": spec.LULC_TABLE_COLUMN,
"nesting_[SUBSTRATE]_availability_index": {
"type": "ratio",
"about": gettext(
@ -167,7 +173,7 @@ MODEL_SPEC = {
"The proportion of pollination required on the farm "
"that is provided by managed pollinators.")}
},
"geometries": spec_utils.POLYGONS,
"geometries": spec.POLYGONS,
"required": False,
"about": gettext(
"Map of farm sites to be analyzed, with pollination data "
@ -180,7 +186,7 @@ MODEL_SPEC = {
"created_if": "farm_vector_path",
"about": gettext(
"A copy of the input farm polygon vector file with additional fields"),
"geometries": spec_utils.POLYGONS,
"geometries": spec.POLYGONS,
"fields": {
"p_abund": {
"about": (
@ -312,13 +318,13 @@ MODEL_SPEC = {
"reprojected_farm_vector.shp": {
"about": "Farm vector reprojected to the LULC projection",
"fields": {},
"geometries": spec_utils.POLYGONS
"geometries": spec.POLYGONS
}
}
},
"taskgraph_cache": spec_utils.TASKGRAPH_DIR
"taskgraph_cache": spec.TASKGRAPH_DIR
}
}
})
_INDEX_NODATA = -1
@ -1214,8 +1220,8 @@ def _parse_scenario_variables(args):
else:
farm_vector_path = None
guild_df = validation.get_validated_dataframe(
guild_table_path, **MODEL_SPEC['args']['guild_table_path'])
guild_df = MODEL_SPEC.get_input(
'guild_table_path').get_validated_dataframe(guild_table_path)
LOGGER.info('Checking to make sure guild table has all expected headers')
for header in _EXPECTED_GUILD_HEADERS:
@ -1226,9 +1232,9 @@ def _parse_scenario_variables(args):
f"'{header}' but was unable to find one. Here are all the "
f"headers from {guild_table_path}: {', '.join(guild_df.columns)}")
landcover_biophysical_df = validation.get_validated_dataframe(
landcover_biophysical_table_path,
**MODEL_SPEC['args']['landcover_biophysical_table_path'])
landcover_biophysical_df = MODEL_SPEC.get_input(
'landcover_biophysical_table_path').get_validated_dataframe(
landcover_biophysical_table_path)
biophysical_table_headers = landcover_biophysical_df.columns
for header in _EXPECTED_BIOPHYSICAL_HEADERS:
matches = re.findall(header, " ".join(biophysical_table_headers))
@ -1492,4 +1498,4 @@ def validate(args, limit_to=None):
# Deliberately not validating the interrelationship of the columns between
# the biophysical table and the guilds table as the model itself already
# does extensive checking for this.
return validation.validate(args, MODEL_SPEC['args'])
return validation.validate(args, MODEL_SPEC)

View File

@ -30,10 +30,9 @@ from osgeo import osr
# prefer to do intrapackage imports to avoid case where global package is
# installed and we import the global version of it rather than the local
from .. import gettext
from .. import spec_utils
from .. import spec
from .. import utils
from .. import validation
from ..model_metadata import MODEL_METADATA
from ..unit_registry import u
LOGGER = logging.getLogger(__name__)
@ -59,7 +58,7 @@ predictor_table_columns = {
"about": gettext("A spatial file to use as a predictor."),
"bands": {1: {"type": "number", "units": u.none}},
"fields": {},
"geometries": spec_utils.ALL_GEOMS
"geometries": spec.ALL_GEOMS
},
"type": {
"type": "option_string",
@ -100,17 +99,26 @@ predictor_table_columns = {
}
MODEL_SPEC = {
MODEL_SPEC = spec.build_model_spec({
"model_id": "recreation",
"model_name": MODEL_METADATA["recreation"].model_title,
"pyname": MODEL_METADATA["recreation"].pyname,
"userguide": MODEL_METADATA["recreation"].userguide,
"model_title": gettext("Visitation: Recreation and Tourism"),
"userguide": "recreation.html",
"aliases": (),
"ui_spec": {
"order": [
['workspace_dir', 'results_suffix'],
['aoi_path'],
['start_year', 'end_year'],
['compute_regression', 'predictor_table_path', 'scenario_predictor_table_path'],
['grid_aoi', 'grid_type', 'cell_size'],
]
},
"args": {
"workspace_dir": spec_utils.WORKSPACE,
"results_suffix": spec_utils.SUFFIX,
"n_workers": spec_utils.N_WORKERS,
"workspace_dir": spec.WORKSPACE,
"results_suffix": spec.SUFFIX,
"n_workers": spec.N_WORKERS,
"aoi_path": {
**spec_utils.AOI,
**spec.AOI,
"about": gettext("Map of area(s) over which to run the model.")
},
"hostname": {
@ -119,7 +127,8 @@ MODEL_SPEC = {
"about": gettext(
"FQDN to a recreation server. If not provided, a default is "
"assumed."),
"name": gettext("hostname")
"name": gettext("hostname"),
"hidden": True
},
"port": {
"type": "number",
@ -129,7 +138,8 @@ MODEL_SPEC = {
"about": gettext(
"the port on ``hostname`` to use for contacting the "
"recreation server."),
"name": gettext("port")
"name": gettext("port"),
"hidden": True
},
"start_year": {
"type": "number",
@ -169,6 +179,7 @@ MODEL_SPEC = {
"hexagon": {"display_name": gettext("hexagon")}
},
"required": "grid_aoi",
"allowed": "grid_aoi",
"about": gettext(
"The shape of grid cells to make within the AOI polygons. "
"Required if Grid AOI is selected."),
@ -179,6 +190,7 @@ MODEL_SPEC = {
"expression": "value > 0",
"units": u.other, # any unit of length is ok
"required": "grid_aoi",
"allowed": "grid_aoi",
"about": gettext(
"Size of grid cells to make, measured in the projection units "
"of the AOI. If the Grid Type is 'square', this is the length "
@ -199,6 +211,7 @@ MODEL_SPEC = {
"index_col": "id",
"columns": predictor_table_columns,
"required": "compute_regression",
"allowed": "compute_regression",
"about": gettext(
"A table that maps predictor IDs to spatial files and their "
"predictor metric types. The file paths can be absolute or "
@ -210,6 +223,7 @@ MODEL_SPEC = {
"index_col": "id",
"columns": predictor_table_columns,
"required": False,
"allowed": "compute_regression",
"about": gettext(
"A table of future or alternative scenario predictors. Maps "
"IDs to files and their types. The file paths can be absolute "
@ -221,7 +235,7 @@ MODEL_SPEC = {
"PUD_results.gpkg": {
"about": gettext(
"Results of photo-user-days aggregations in the AOI."),
"geometries": spec_utils.POLYGONS,
"geometries": spec.POLYGONS,
"fields": {
"PUD_YR_AVG": {
"about": gettext(
@ -240,7 +254,7 @@ MODEL_SPEC = {
"TUD_results.gpkg": {
"about": gettext(
"Results of twitter-user-days aggregations in the AOI."),
"geometries": spec_utils.POLYGONS,
"geometries": spec.POLYGONS,
"fields": {
"PUD_YR_AVG": {
"about": gettext(
@ -295,7 +309,7 @@ MODEL_SPEC = {
"about": gettext(
"AOI polygons with all the variables needed to compute a regression, "
"including predictor attributes and the user-days response variable."),
"geometries": spec_utils.POLYGONS,
"geometries": spec.POLYGONS,
"fields": {
"[PREDICTOR]": {
"type": "number",
@ -340,7 +354,7 @@ MODEL_SPEC = {
"about": gettext(
"Results of scenario, including the predictor data used in the "
"scenario and the predicted visitation patterns for the scenario."),
"geometries": spec_utils.POLYGONS,
"geometries": spec.POLYGONS,
"fields": {
"[PREDICTOR]": {
"type": "number",
@ -365,7 +379,7 @@ MODEL_SPEC = {
"about": gettext(
"Copy of the input AOI, gridded if applicable."),
"fields": {},
"geometries": spec_utils.POLYGONS
"geometries": spec.POLYGONS
},
"aoi.zip": {
"about": gettext("Compressed AOI")
@ -398,9 +412,9 @@ MODEL_SPEC = {
}
}
},
"taskgraph_cache": spec_utils.TASKGRAPH_DIR
"taskgraph_cache": spec.TASKGRAPH_DIR
}
}
})
# Have 5 seconds between timed progress outputs
@ -552,7 +566,7 @@ def execute(args):
prep_aoi_task.join()
# All the server communication happens in this task.
user_days_task = task_graph.add_task(
calc_user_days_task = task_graph.add_task(
func=_retrieve_user_days,
args=(file_registry['local_aoi_path'],
file_registry['compressed_aoi_path'],
@ -566,6 +580,15 @@ def execute(args):
file_registry['server_version']],
task_name='user-day-calculation')
assemble_userday_variables_task = task_graph.add_task(
func=_assemble_regression_data,
args=(file_registry['pud_results_path'],
file_registry['tud_results_path'],
file_registry['regression_vector_path']),
target_path_list=[file_registry['regression_vector_path']],
dependent_task_list=[calc_user_days_task],
task_name='assemble userday variables')
if 'compute_regression' in args and args['compute_regression']:
# Prepare the AOI for geoprocessing.
prepare_response_polygons_task = task_graph.add_task(
@ -579,26 +602,17 @@ def execute(args):
assemble_predictor_data_task = _schedule_predictor_data_processing(
file_registry['local_aoi_path'],
file_registry['response_polygons_lookup'],
prepare_response_polygons_task,
[prepare_response_polygons_task, assemble_userday_variables_task],
args['predictor_table_path'],
file_registry['regression_vector_path'],
intermediate_dir, task_graph)
assemble_regression_data_task = task_graph.add_task(
func=_assemble_regression_data,
args=(file_registry['pud_results_path'],
file_registry['tud_results_path'],
file_registry['regression_vector_path']),
target_path_list=[file_registry['regression_vector_path']],
dependent_task_list=[assemble_predictor_data_task, user_days_task],
task_name='assemble predictor data')
# Compute the regression
coefficient_json_path = os.path.join(
intermediate_dir, 'predictor_estimates.json')
predictor_df = validation.get_validated_dataframe(
args['predictor_table_path'],
**MODEL_SPEC['args']['predictor_table_path'])
predictor_df = MODEL_SPEC.get_input(
'predictor_table_path').get_validated_dataframe(
args['predictor_table_path'])
predictor_id_list = predictor_df.index
compute_regression_task = task_graph.add_task(
func=_compute_and_summarize_regression,
@ -612,16 +626,27 @@ def execute(args):
target_path_list=[file_registry['regression_coefficients'],
file_registry['regression_summary'],
coefficient_json_path],
dependent_task_list=[assemble_regression_data_task],
dependent_task_list=[assemble_predictor_data_task],
task_name='compute regression')
if ('scenario_predictor_table_path' in args and
args['scenario_predictor_table_path'] != ''):
driver = gdal.GetDriverByName('GPKG')
if os.path.exists(file_registry['scenario_results_path']):
driver.Delete(file_registry['scenario_results_path'])
aoi_vector = gdal.OpenEx(file_registry['local_aoi_path'])
target_vector = driver.CreateCopy(
file_registry['scenario_results_path'], aoi_vector)
target_layer = target_vector.GetLayer()
_rename_layer_from_parent(target_layer)
target_vector = target_layer = None
aoi_vector = None
utils.make_directories([scenario_dir])
build_scenario_data_task = _schedule_predictor_data_processing(
file_registry['local_aoi_path'],
file_registry['response_polygons_lookup'],
prepare_response_polygons_task,
[prepare_response_polygons_task],
args['scenario_predictor_table_path'],
file_registry['scenario_results_path'],
scenario_dir, task_graph)
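
Taken together, the ``execute()`` changes above reorder the recreation pipeline. The summary below is inferred from the ``dependent_task_list`` values shown in the diff and is a sketch, not authoritative documentation.

# Downstream dependency chain as declared above (upstream tasks such as AOI
# preparation omitted):
REGRESSION_TASK_CHAIN = [
    'calc_user_days_task',              # server communication; writes PUD/TUD results
    'assemble_userday_variables_task',  # copies the PUD vector to regression_vector_path
                                        # and adds the user-day response fields
    'assemble_predictor_data_task',     # compute_regression only: writes one field per
                                        # predictor into the existing regression vector
    'compute_regression_task',          # fits the regression from that vector
]
# The scenario branch now pre-creates scenario_results_path as a copy of the
# (optionally gridded) AOI, and _json_to_gpkg_table fills fields into that
# existing vector instead of creating a new one.
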
@ -927,9 +952,8 @@ def _grid_vector(vector_path, grid_type, cell_size, out_grid_vector_path):
def _schedule_predictor_data_processing(
response_vector_path, response_polygons_pickle_path,
prepare_response_polygons_task,
predictor_table_path, target_predictor_vector_path,
working_dir, task_graph):
dependent_task_list, predictor_table_path,
target_predictor_vector_path, working_dir, task_graph):
"""Summarize spatial predictor data by polygons in the response vector.
Build a shapefile with geometry from the response vector, and tabular
@ -941,8 +965,7 @@ def _schedule_predictor_data_processing(
response_polygons_pickle_path (string): path to pickle that stores a
dict which maps each feature FID from ``response_vector_path`` to
its shapely geometry.
prepare_response_polygons_task (Taskgraph.Task object):
A Task needed for dependent_task_lists in this scope.
dependent_task_list (list): list of Taskgraph.Task objects.
predictor_table_path (string): path to a CSV file with three columns
'id', 'path' and 'type'. 'id' is the unique ID for that predictor
and must be less than 10 characters long. 'path' indicates the
@ -983,8 +1006,8 @@ def _schedule_predictor_data_processing(
'line_intersect_length': _line_intersect_length,
}
predictor_df = validation.get_validated_dataframe(
predictor_table_path, **MODEL_SPEC['args']['predictor_table_path'])
predictor_df = MODEL_SPEC.get_input(
'predictor_table_path').get_validated_dataframe(predictor_table_path)
predictor_task_list = []
predictor_json_list = [] # tracks predictor files to add to gpkg
@ -1014,7 +1037,7 @@ def _schedule_predictor_data_processing(
args=(predictor_type, response_polygons_pickle_path,
row['path'], predictor_target_path),
target_path_list=[predictor_target_path],
dependent_task_list=[prepare_response_polygons_task],
dependent_task_list=dependent_task_list,
task_name=f'predictor {predictor_id}'))
else:
predictor_target_path = os.path.join(
@ -1025,13 +1048,12 @@ def _schedule_predictor_data_processing(
args=(response_polygons_pickle_path,
row['path'], predictor_target_path),
target_path_list=[predictor_target_path],
dependent_task_list=[prepare_response_polygons_task],
dependent_task_list=dependent_task_list,
task_name=f'predictor {predictor_id}'))
# return predictor_task_list, predictor_json_list
assemble_predictor_data_task = task_graph.add_task(
func=_json_to_gpkg_table,
args=(response_vector_path, target_predictor_vector_path,
args=(target_predictor_vector_path,
predictor_json_list),
target_path_list=[target_predictor_vector_path],
dependent_task_list=predictor_task_list,
@ -1058,20 +1080,11 @@ def _prepare_response_polygons_lookup(
def _json_to_gpkg_table(
response_vector_path, predictor_vector_path,
predictor_json_list):
regression_vector_path, predictor_json_list):
"""Create a GeoPackage and a field with data from each json file."""
driver = gdal.GetDriverByName('GPKG')
if os.path.exists(predictor_vector_path):
driver.Delete(predictor_vector_path)
response_vector = gdal.OpenEx(
response_vector_path, gdal.OF_VECTOR | gdal.GA_Update)
predictor_vector = driver.CreateCopy(
predictor_vector_path, response_vector)
response_vector = None
layer = predictor_vector.GetLayer()
_rename_layer_from_parent(layer)
target_vector = gdal.OpenEx(
regression_vector_path, gdal.OF_VECTOR | gdal.GA_Update)
target_layer = target_vector.GetLayer()
predictor_id_list = []
for json_filename in predictor_json_list:
@ -1079,23 +1092,22 @@ def _json_to_gpkg_table(
predictor_id_list.append(predictor_id)
# Create a new field for the predictor
# Delete the field first if it already exists
field_index = layer.FindFieldIndex(
field_index = target_layer.FindFieldIndex(
str(predictor_id), 1)
if field_index >= 0:
layer.DeleteField(field_index)
target_layer.DeleteField(field_index)
predictor_field = ogr.FieldDefn(str(predictor_id), ogr.OFTReal)
layer.CreateField(predictor_field)
target_layer.CreateField(predictor_field)
with open(json_filename, 'r') as file:
predictor_results = json.load(file)
for feature_id, value in predictor_results.items():
feature = layer.GetFeature(int(feature_id))
feature = target_layer.GetFeature(int(feature_id))
feature.SetField(str(predictor_id), value)
layer.SetFeature(feature)
target_layer.SetFeature(feature)
layer = None
predictor_vector.FlushCache()
predictor_vector = None
target_layer = None
target_vector = None
def _raster_sum_mean(
@ -1377,7 +1389,7 @@ def _ogr_to_geometry_list(vector_path):
def _assemble_regression_data(
pud_vector_path, tud_vector_path, regression_vector_path):
pud_vector_path, tud_vector_path, target_vector_path):
"""Update the vector with the predictor data, adding response variables.
Args:
@ -1385,7 +1397,7 @@ def _assemble_regression_data(
layer with PUD_YR_AVG.
tud_vector_path (string): Path to the vector polygon
layer with TUD_YR_AVG.
regression_vector_path (string): The response polygons with predictor data.
target_vector_path (string): The response polygons with predictor data.
Fields will be added in order to compute the linear regression:
* pr_PUD
* pr_TUD
@ -1401,18 +1413,22 @@ def _assemble_regression_data(
tud_vector = gdal.OpenEx(
tud_vector_path, gdal.OF_VECTOR | gdal.GA_ReadOnly)
tud_layer = tud_vector.GetLayer()
target_vector = gdal.OpenEx(
regression_vector_path, gdal.OF_VECTOR | gdal.GA_Update)
driver = gdal.GetDriverByName('GPKG')
if os.path.exists(target_vector_path):
driver.Delete(target_vector_path)
target_vector = driver.CreateCopy(
target_vector_path, pud_vector)
target_layer = target_vector.GetLayer()
_rename_layer_from_parent(target_layer)
for field in target_layer.schema:
if field.name != POLYGON_ID_FIELD:
target_layer.DeleteField(
target_layer.FindFieldIndex(field.name, 1))
def _create_field(fieldname):
# Create a new field for the predictor
# Delete the field first if it already exists
field_index = target_layer.FindFieldIndex(
str(fieldname), 1)
if field_index >= 0:
target_layer.DeleteField(field_index)
field = ogr.FieldDefn(str(fieldname), ogr.OFTReal)
target_layer.CreateField(field)
@ -1675,7 +1691,7 @@ def _calculate_scenario(
scenario_results_path (string): path to desired output scenario
vector result which will be geometrically a copy of the input
AOI but contain the scenario predictor data fields as well as the
scenario esimated response.
scenario estimated response.
response_id (string): text ID of response variable to write to
the scenario result.
coefficient_json_path (string): path to json file with the pre-existing
@ -1752,8 +1768,8 @@ def _validate_same_id_lengths(table_path):
string message if IDs are too long
"""
predictor_df = validation.get_validated_dataframe(
table_path, **MODEL_SPEC['args']['predictor_table_path'])
predictor_df = MODEL_SPEC.get_input(
'predictor_table_path').get_validated_dataframe(table_path)
too_long = set()
for p_id in predictor_df.index:
if len(p_id) > 10:
@ -1781,12 +1797,13 @@ def _validate_same_ids_and_types(
string message if any of the fields in 'id' and 'type' don't match
between tables.
"""
predictor_df = validation.get_validated_dataframe(
predictor_table_path, **MODEL_SPEC['args']['predictor_table_path'])
predictor_df = MODEL_SPEC.get_input(
'predictor_table_path').get_validated_dataframe(
predictor_table_path)
scenario_predictor_df = validation.get_validated_dataframe(
scenario_predictor_table_path,
**MODEL_SPEC['args']['scenario_predictor_table_path'])
scenario_predictor_df = MODEL_SPEC.get_input(
'scenario_predictor_table_path').get_validated_dataframe(
scenario_predictor_table_path)
predictor_pairs = set([
(p_id, row['type']) for p_id, row in predictor_df.iterrows()])
@ -1811,9 +1828,9 @@ def _validate_same_projection(base_vector_path, table_path):
"""
# This will load the table as a list of paths which we can iterate through
# without bothering the rest of the table structure
data_paths = validation.get_validated_dataframe(
table_path, **MODEL_SPEC['args']['predictor_table_path']
)['path'].tolist()
data_paths = MODEL_SPEC.get_input(
'predictor_table_path').get_validated_dataframe(
table_path)['path'].tolist()
base_vector = gdal.OpenEx(base_vector_path, gdal.OF_VECTOR)
base_layer = base_vector.GetLayer()
@ -1854,8 +1871,8 @@ def _validate_predictor_types(table_path):
string message if any value in the ``type`` column does not match a
valid type, ignoring leading/trailing whitespace.
"""
df = validation.get_validated_dataframe(
table_path, **MODEL_SPEC['args']['predictor_table_path'])
df = MODEL_SPEC.get_input(
'predictor_table_path').get_validated_dataframe(table_path)
# ignore leading/trailing whitespace because it will be removed
# when the type values are used
valid_types = set({'raster_mean', 'raster_sum', 'point_count',
@ -1914,7 +1931,7 @@ def validate(args, limit_to=None):
be an empty list if validation succeeds.
"""
validation_messages = validation.validate(args, MODEL_SPEC['args'])
validation_messages = validation.validate(args, MODEL_SPEC)
sufficient_valid_keys = (validation.get_sufficient_keys(args) -
validation.get_invalid_keys(validation_messages))

View File

@ -7,26 +7,38 @@ import pygeoprocessing.routing
import taskgraph
from . import gettext
from . import spec_utils
from . import spec
from . import utils
from . import validation
from .model_metadata import MODEL_METADATA
from .unit_registry import u
LOGGER = logging.getLogger(__name__)
INVALID_BAND_INDEX_MSG = gettext('Must be between 1 and {maximum}')
MODEL_SPEC = {
MODEL_SPEC = spec.build_model_spec({
"model_id": "routedem",
"model_name": MODEL_METADATA["routedem"].model_title,
"pyname": MODEL_METADATA["routedem"].pyname,
"userguide": MODEL_METADATA["routedem"].userguide,
"model_title": gettext("RouteDEM"),
"userguide": "routedem.html",
"aliases": (),
"ui_spec": {
"order": [
['workspace_dir', 'results_suffix'],
['dem_path', 'dem_band_index'],
['calculate_slope'],
['algorithm'],
['calculate_flow_direction'],
['calculate_flow_accumulation'],
['calculate_stream_threshold', 'threshold_flow_accumulation',
'calculate_downslope_distance', 'calculate_stream_order',
'calculate_subwatersheds']
]
},
"args": {
"workspace_dir": spec_utils.WORKSPACE,
"results_suffix": spec_utils.SUFFIX,
"n_workers": spec_utils.N_WORKERS,
"dem_path": spec_utils.DEM,
"workspace_dir": spec.WORKSPACE,
"results_suffix": spec.SUFFIX,
"n_workers": spec.N_WORKERS,
"dem_path": spec.DEM,
"dem_band_index": {
"type": "number",
"expression": "value >= 1",
@ -62,6 +74,7 @@ MODEL_SPEC = {
"calculate_flow_accumulation": {
"type": "boolean",
"required": False,
"allowed": "calculate_flow_direction",
"about": gettext(
"Calculate flow accumulation from the flow direction output."),
"name": gettext("calculate flow accumulation")
@ -69,20 +82,23 @@ MODEL_SPEC = {
"calculate_stream_threshold": {
"type": "boolean",
"required": False,
"allowed": "calculate_flow_accumulation",
"about": gettext(
"Calculate streams from the flow accumulation output. "),
"name": gettext("calculate streams")
},
"threshold_flow_accumulation": {
**spec_utils.THRESHOLD_FLOW_ACCUMULATION,
**spec.THRESHOLD_FLOW_ACCUMULATION,
"required": "calculate_stream_threshold",
"allowed": "calculate_stream_threshold",
"about": (
spec_utils.THRESHOLD_FLOW_ACCUMULATION['about'] + " " +
spec.THRESHOLD_FLOW_ACCUMULATION['about'] + " " +
gettext("Required if Calculate Streams is selected."))
},
"calculate_downslope_distance": {
"type": "boolean",
"required": False,
"allowed": "calculate_stream_threshold",
"about": gettext(
"Calculate flow distance from each pixel to a stream as "
"defined in the Calculate Streams output."),
@ -97,28 +113,30 @@ MODEL_SPEC = {
"calculate_stream_order": {
"type": "boolean",
"required": False,
"allowed": "calculate_stream_threshold and algorithm == 'D8'",
"about": gettext("Calculate the Strahler Stream order."),
"name": gettext("calculate strahler stream orders (D8 only)"),
},
"calculate_subwatersheds": {
"type": "boolean",
"required": False,
"allowed": "calculate_stream_order and algorithm == 'D8'",
"about": gettext("Determine subwatersheds from the stream order."),
"name": gettext("calculate subwatersheds (D8 only)"),
},
},
"outputs": {
"taskgraph_cache": spec_utils.TASKGRAPH_DIR,
"filled.tif": spec_utils.FILLED_DEM,
"flow_accumulation.tif": spec_utils.FLOW_ACCUMULATION,
"flow_direction.tif": spec_utils.FLOW_DIRECTION,
"slope.tif": spec_utils.SLOPE,
"stream_mask.tif": spec_utils.STREAM,
"taskgraph_cache": spec.TASKGRAPH_DIR,
"filled.tif": spec.FILLED_DEM,
"flow_accumulation.tif": spec.FLOW_ACCUMULATION,
"flow_direction.tif": spec.FLOW_DIRECTION,
"slope.tif": spec.SLOPE,
"stream_mask.tif": spec.STREAM,
"strahler_stream_order.gpkg": {
"about": (
"A vector of line segments indicating the Strahler stream "
"order and other properties of each stream segment."),
"geometries": spec_utils.LINESTRING,
"geometries": spec.LINESTRING,
"fields": {
"order": {
"about": "The Strahler stream order.",
@ -221,7 +239,7 @@ MODEL_SPEC = {
"subwatersheds. A new subwatershed is created for each "
"tributary of a stream and is influenced greatly by "
"your choice of Threshold Flow Accumulation value."),
"geometries": spec_utils.POLYGON,
"geometries": spec.POLYGON,
"fields": {
"stream_id": {
"about": (
@ -260,7 +278,7 @@ MODEL_SPEC = {
},
},
}
}
})
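
The new ``allowed`` expressions above chain the RouteDEM options together. The mapping below restates those expressions as data; the interpretation that ``allowed`` gates when an input may be enabled (mirroring how ``required`` gates when it must be provided) is an assumption, not something stated in this diff.

ALLOWED_WHEN = {  # expressions copied from the spec entries above
    'calculate_flow_accumulation': 'calculate_flow_direction',
    'calculate_stream_threshold': 'calculate_flow_accumulation',
    'threshold_flow_accumulation': 'calculate_stream_threshold',
    'calculate_downslope_distance': 'calculate_stream_threshold',
    'calculate_stream_order': "calculate_stream_threshold and algorithm == 'D8'",
    'calculate_subwatersheds': "calculate_stream_order and algorithm == 'D8'",
}
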
# replace %s with file suffix
@ -522,7 +540,7 @@ def validate(args, limit_to=None):
the error message in the second part of the tuple. This should
be an empty list if validation succeeds.
"""
validation_warnings = validation.validate(args, MODEL_SPEC['args'])
validation_warnings = validation.validate(args, MODEL_SPEC)
invalid_keys = validation.get_invalid_keys(validation_warnings)
sufficient_keys = validation.get_sufficient_keys(args)

View File

@ -17,10 +17,9 @@ import taskgraph
from osgeo import gdal
from . import gettext
from . import spec_utils
from . import spec
from . import utils
from . import validation
from .model_metadata import MODEL_METADATA
from .unit_registry import u
LOGGER = logging.getLogger(__name__)
@ -29,17 +28,27 @@ MISSING_CONVERT_OPTION_MSG = gettext(
'One or more of "convert_nearest_to_edge" or "convert_farthest_from_edge" '
'must be selected')
MODEL_SPEC = {
MODEL_SPEC = spec.build_model_spec({
"model_id": "scenario_generator_proximity",
"model_name": MODEL_METADATA["scenario_generator_proximity"].model_title,
"pyname": MODEL_METADATA["scenario_generator_proximity"].pyname,
"userguide": MODEL_METADATA["scenario_generator_proximity"].userguide,
"model_title": gettext("Scenario Generator: Proximity Based"),
"userguide": "scenario_gen_proximity.html",
"aliases": ("sgp",),
"ui_spec": {
"order": [
['workspace_dir', 'results_suffix'],
['base_lulc_path', 'aoi_path'],
['area_to_convert', 'focal_landcover_codes',
'convertible_landcover_codes', 'replacement_lucode'],
['convert_farthest_from_edge', 'convert_nearest_to_edge',
'n_fragmentation_steps']
]
},
"args": {
"workspace_dir": spec_utils.WORKSPACE,
"results_suffix": spec_utils.SUFFIX,
"n_workers": spec_utils.N_WORKERS,
"workspace_dir": spec.WORKSPACE,
"results_suffix": spec.SUFFIX,
"n_workers": spec.N_WORKERS,
"base_lulc_path": {
**spec_utils.LULC,
**spec.LULC,
"projected": True,
"about": gettext("Base map from which to generate scenarios."),
"name": gettext("base LULC map")
@ -88,7 +97,7 @@ MODEL_SPEC = {
"name": gettext("number of conversion steps")
},
"aoi_path": {
**spec_utils.AOI,
**spec.AOI,
"required": False,
"about": gettext(
"Area over which to run the conversion. Provide this input if "
@ -190,9 +199,9 @@ MODEL_SPEC = {
}
}
},
"taskgraph_cache": spec_utils.TASKGRAPH_DIR
"taskgraph_cache": spec.TASKGRAPH_DIR
}
}
})
# This sets the largest number of elements that will be packed at once and
@ -908,7 +917,7 @@ def validate(args, limit_to=None):
be an empty list if validation succeeds.
"""
validation_warnings = validation.validate(args, MODEL_SPEC['args'])
validation_warnings = validation.validate(args, MODEL_SPEC)
invalid_keys = validation.get_invalid_keys(validation_warnings)
if ('convert_nearest_to_edge' not in invalid_keys and

View File

@ -17,10 +17,9 @@ from osgeo import ogr
from osgeo import osr
from .. import gettext
from .. import spec_utils
from .. import spec
from .. import utils
from .. import validation
from ..model_metadata import MODEL_METADATA
from ..unit_registry import u
LOGGER = logging.getLogger(__name__)
@ -47,26 +46,34 @@ _INTERMEDIATE_BASE_FILES = {
'value_pattern': 'value_{id}.tif',
}
MODEL_SPEC = {
MODEL_SPEC = spec.build_model_spec({
"model_id": "scenic_quality",
"model_name": MODEL_METADATA["scenic_quality"].model_title,
"pyname": MODEL_METADATA["scenic_quality"].pyname,
"userguide": MODEL_METADATA["scenic_quality"].userguide,
"model_title": gettext("Scenic Quality"),
"userguide": "scenic_quality.html",
"aliases": ("sq",),
"ui_spec": {
"order": [
['workspace_dir', 'results_suffix'],
['aoi_path', 'structure_path', 'dem_path', 'refraction'],
['do_valuation', 'valuation_function', 'a_coef', 'b_coef',
'max_valuation_radius'],
]
},
"args_with_spatial_overlap": {
"spatial_keys": ["aoi_path", "structure_path", "dem_path"],
"different_projections_ok": True,
},
"args": {
"workspace_dir": spec_utils.WORKSPACE,
"results_suffix": spec_utils.SUFFIX,
"n_workers": spec_utils.N_WORKERS,
"workspace_dir": spec.WORKSPACE,
"results_suffix": spec.SUFFIX,
"n_workers": spec.N_WORKERS,
"aoi_path": {
**spec_utils.AOI,
**spec.AOI,
},
"structure_path": {
"name": gettext("features impacting scenic quality"),
"type": "vector",
"geometries": spec_utils.POINT,
"geometries": spec.POINT,
"fields": {
"radius": {
"type": "number",
@ -104,7 +111,7 @@ MODEL_SPEC = {
"quality. This must have the same projection as the DEM.")
},
"dem_path": {
**spec_utils.DEM,
**spec.DEM,
"projected": True,
"projection_units": u.meter
},
@ -125,6 +132,7 @@ MODEL_SPEC = {
"name": gettext("Valuation function"),
"type": "option_string",
"required": "do_valuation",
"allowed": "do_valuation",
"options": {
"linear": {"display_name": gettext("linear: a + bx")},
"logarithmic": {"display_name": gettext(
@ -141,6 +149,7 @@ MODEL_SPEC = {
"type": "number",
"units": u.none,
"required": "do_valuation",
"allowed": "do_valuation",
"about": gettext("First coefficient ('a') used by the valuation function"),
},
"b_coef": {
@ -148,6 +157,7 @@ MODEL_SPEC = {
"type": "number",
"units": u.none,
"required": "do_valuation",
"allowed": "do_valuation",
"about": gettext("Second coefficient ('b') used by the valuation function"),
},
"max_valuation_radius": {
@ -155,6 +165,7 @@ MODEL_SPEC = {
"type": "number",
"units": u.meter,
"required": False,
"allowed": "do_valuation",
"expression": "value > 0",
"about": gettext(
"Valuation will only be computed for cells that fall within "
@ -185,7 +196,7 @@ MODEL_SPEC = {
"contents": {
"aoi_reprojected.shp": {
"about": gettext("This vector is the AOI, reprojected to the DEMs spatial reference and projection."),
"geometries": spec_utils.POLYGONS,
"geometries": spec.POLYGONS,
"fields": {}
},
"dem_clipped.tif": {
@ -195,12 +206,12 @@ MODEL_SPEC = {
"structures_clipped.shp": {
"about": gettext(
"Copy of the structures vector, clipped to the AOI extent."),
"geometries": spec_utils.POINT,
"geometries": spec.POINT,
"fields": {}
},
"structures_reprojected.shp": {
"about": gettext("Copy of the structures vector, reprojected to the DEMs spatial reference and projection."),
"geometries": spec_utils.POINT,
"geometries": spec.POINT,
"fields": {}
},
"value_[FEATURE_ID].tif": {
@ -213,9 +224,9 @@ MODEL_SPEC = {
}
}
},
"taskgraph_cache": spec_utils.TASKGRAPH_DIR
"taskgraph_cache": spec.TASKGRAPH_DIR
}
}
})
def execute(args):
@ -272,7 +283,7 @@ def execute(args):
'b': float(args['b_coef']),
}
if (args['valuation_function'] not in
MODEL_SPEC['args']['valuation_function']['options']):
MODEL_SPEC.get_input('valuation_function').options):
raise ValueError('Valuation function type %s not recognized' %
args['valuation_function'])
max_valuation_radius = float(args['max_valuation_radius'])
@ -1109,5 +1120,4 @@ def validate(args, limit_to=None):
be an empty list if validation succeeds.
"""
return validation.validate(
args, MODEL_SPEC['args'], MODEL_SPEC['args_with_spatial_overlap'])
return validation.validate(args, MODEL_SPEC)

View File

@ -18,32 +18,41 @@ from osgeo import gdal
from osgeo import ogr
from .. import gettext
from .. import spec_utils
from .. import spec
from .. import urban_nature_access
from .. import utils
from .. import validation
from ..model_metadata import MODEL_METADATA
from ..unit_registry import u
from . import sdr_core
LOGGER = logging.getLogger(__name__)
MODEL_SPEC = {
MODEL_SPEC = spec.build_model_spec({
"model_id": "sdr",
"model_name": MODEL_METADATA["sdr"].model_title,
"pyname": MODEL_METADATA["sdr"].pyname,
"userguide": MODEL_METADATA["sdr"].userguide,
"model_title": gettext("Sediment Delivery Ratio"),
"userguide": "sdr.html",
"aliases": (),
"ui_spec": {
"order": [
['workspace_dir', 'results_suffix'],
['dem_path', 'erosivity_path', 'erodibility_path'],
['lulc_path', 'biophysical_table_path'],
['watersheds_path', 'drainage_path'],
['flow_dir_algorithm', 'threshold_flow_accumulation',
'k_param', 'sdr_max', 'ic_0_param', 'l_max']
]
},
"args_with_spatial_overlap": {
"spatial_keys": ["dem_path", "erosivity_path", "erodibility_path",
"lulc_path", "drainage_path", "watersheds_path", ],
"different_projections_ok": False,
},
"args": {
"workspace_dir": spec_utils.WORKSPACE,
"results_suffix": spec_utils.SUFFIX,
"n_workers": spec_utils.N_WORKERS,
"workspace_dir": spec.WORKSPACE,
"results_suffix": spec.SUFFIX,
"n_workers": spec.N_WORKERS,
"dem_path": {
**spec_utils.DEM,
**spec.DEM,
"projected": True
},
"erosivity_path": {
@ -70,15 +79,15 @@ MODEL_SPEC = {
"name": gettext("soil erodibility")
},
"lulc_path": {
**spec_utils.LULC,
**spec.LULC,
"projected": True,
"about": spec_utils.LULC['about'] + " " + gettext(
"about": spec.LULC['about'] + " " + gettext(
"All values in this raster must "
"have corresponding entries in the Biophysical Table.")
},
"watersheds_path": {
"type": "vector",
"geometries": spec_utils.POLYGONS,
"geometries": spec.POLYGONS,
"projected": True,
"fields": {},
"about": gettext(
@ -91,7 +100,7 @@ MODEL_SPEC = {
"type": "csv",
"index_col": "lucode",
"columns": {
"lucode": spec_utils.LULC_TABLE_COLUMN,
"lucode": spec.LULC_TABLE_COLUMN,
"usle_c": {
"type": "ratio",
"about": gettext("Cover-management factor for the USLE")},
@ -105,7 +114,7 @@ MODEL_SPEC = {
"corresponding entries in this table."),
"name": gettext("biophysical table")
},
"threshold_flow_accumulation": spec_utils.THRESHOLD_FLOW_ACCUMULATION,
"threshold_flow_accumulation": spec.THRESHOLD_FLOW_ACCUMULATION,
"k_param": {
"type": "number",
"units": u.none,
@ -142,7 +151,7 @@ MODEL_SPEC = {
"streams. Pixels with 0 are not drainages."),
"name": gettext("drainages")
},
**spec_utils.FLOW_DIR_ALGORITHM
**spec.FLOW_DIR_ALGORITHM
},
"outputs": {
"avoided_erosion.tif": {
@ -180,7 +189,7 @@ MODEL_SPEC = {
"units": u.metric_ton/u.hectare
}}
},
"stream.tif": spec_utils.STREAM,
"stream.tif": spec.STREAM,
"stream_and_drainage.tif": {
"created_if": "drainage_path",
"about": "This raster is the union of that layer with the calculated stream layer(Eq. (85)). Values of 1 represent streams, values of 0 are non-stream pixels.",
@ -195,7 +204,7 @@ MODEL_SPEC = {
},
"watershed_results_sdr.shp": {
"about": "Table containing biophysical values for each watershed",
"geometries": spec_utils.POLYGONS,
"geometries": spec.POLYGONS,
"fields": {
"sed_export": {
"type": "number",
@ -261,8 +270,8 @@ MODEL_SPEC = {
"units": u.metric_ton/(u.hectare*u.year)
}}
},
"flow_accumulation.tif": spec_utils.FLOW_ACCUMULATION,
"flow_direction.tif": spec_utils.FLOW_DIRECTION,
"flow_accumulation.tif": spec.FLOW_ACCUMULATION,
"flow_direction.tif": spec.FLOW_DIRECTION,
"ic.tif": {
"about": gettext("Index of connectivity (Eq. (70))"),
"bands": {1: {
@ -277,7 +286,7 @@ MODEL_SPEC = {
"units": u.none
}}
},
"pit_filled_dem.tif": spec_utils.FILLED_DEM,
"pit_filled_dem.tif": spec.FILLED_DEM,
"s_accumulation.tif": {
"about": gettext(
"Flow accumulation weighted by the thresholded slope. "
@ -300,7 +309,7 @@ MODEL_SPEC = {
"about": gettext("Sediment delivery ratio (Eq. (75))"),
"bands": {1: {"type": "ratio"}}
},
"slope.tif": spec_utils.SLOPE,
"slope.tif": spec.SLOPE,
"slope_threshold.tif": {
"about": gettext(
"Percent slope, thresholded to be no less than 0.005 "
@ -450,9 +459,9 @@ MODEL_SPEC = {
}
},
},
"taskgraph_cache": spec_utils.TASKGRAPH_DIR
"taskgraph_cache": spec.TASKGRAPH_DIR
}
}
})
_OUTPUT_BASE_FILES = {
'rkls_path': 'rkls.tif',
@ -553,9 +562,9 @@ def execute(args):
"""
file_suffix = utils.make_suffix_string(args, 'results_suffix')
biophysical_df = validation.get_validated_dataframe(
args['biophysical_table_path'],
**MODEL_SPEC['args']['biophysical_table_path'])
biophysical_df = MODEL_SPEC.get_input(
'biophysical_table_path').get_validated_dataframe(
args['biophysical_table_path'])
# Test to see if c or p values are outside of 0..1
for key in ['usle_c', 'usle_p']:
@ -1590,5 +1599,4 @@ def validate(args, limit_to=None):
be an empty list if validation succeeds.
"""
return validation.validate(
args, MODEL_SPEC['args'], MODEL_SPEC['args_with_spatial_overlap'])
return validation.validate(args, MODEL_SPEC)

View File

@ -14,10 +14,9 @@ from osgeo import gdal
from osgeo import ogr
from .. import gettext
from .. import spec_utils
from .. import spec
from .. import utils
from .. import validation
from ..model_metadata import MODEL_METADATA
from ..unit_registry import u
from . import seasonal_water_yield_core
@ -29,11 +28,22 @@ MONTH_ID_TO_LABEL = [
'jan', 'feb', 'mar', 'apr', 'may', 'jun', 'jul', 'aug', 'sep', 'oct',
'nov', 'dec']
MODEL_SPEC = {
MODEL_SPEC = spec.build_model_spec({
"model_id": "seasonal_water_yield",
"model_name": MODEL_METADATA["seasonal_water_yield"].model_title,
"pyname": MODEL_METADATA["seasonal_water_yield"].pyname,
"userguide": MODEL_METADATA["seasonal_water_yield"].userguide,
"model_title": gettext("Seasonal Water Yield"),
"userguide": "seasonal_water_yield.html",
"aliases": ("swy",),
"ui_spec": {
"order": [
['workspace_dir', 'results_suffix'],
['lulc_raster_path', 'biophysical_table_path'],
['dem_raster_path', 'aoi_path'],
['flow_dir_algorithm', 'threshold_flow_accumulation', 'beta_i', 'gamma'],
['user_defined_local_recharge', 'l_path', 'et0_dir', 'precip_dir', 'soil_group_path'],
['monthly_alpha', 'alpha_m', 'monthly_alpha_path'],
['user_defined_climate_zones', 'rain_events_table_path', 'climate_zone_table_path', 'climate_zone_raster_path'],
]
},
"args_with_spatial_overlap": {
"spatial_keys": ["dem_raster_path", "lulc_raster_path",
"soil_group_path", "aoi_path", "l_path",
@ -41,10 +51,10 @@ MODEL_SPEC = {
"different_projections_ok": True,
},
"args": {
"workspace_dir": spec_utils.WORKSPACE,
"results_suffix": spec_utils.SUFFIX,
"n_workers": spec_utils.N_WORKERS,
"threshold_flow_accumulation": spec_utils.THRESHOLD_FLOW_ACCUMULATION,
"workspace_dir": spec.WORKSPACE,
"results_suffix": spec.SUFFIX,
"n_workers": spec.N_WORKERS,
"threshold_flow_accumulation": spec.THRESHOLD_FLOW_ACCUMULATION,
"et0_dir": {
"type": "directory",
"contents": {
@ -67,6 +77,7 @@ MODEL_SPEC = {
},
},
"required": "not user_defined_local_recharge",
"allowed": "not user_defined_local_recharge",
"about": gettext(
"Directory containing maps of reference evapotranspiration "
"for each month. Only .tif files should be in this folder "
@ -95,6 +106,7 @@ MODEL_SPEC = {
},
},
"required": "not user_defined_local_recharge",
"allowed": "not user_defined_local_recharge",
"about": gettext(
"Directory containing maps of monthly precipitation for each "
"month. Only .tif files should be in this folder (no .tfw, "
@ -103,30 +115,31 @@ MODEL_SPEC = {
"name": gettext("precipitation directory")
},
"dem_raster_path": {
**spec_utils.DEM,
**spec.DEM,
"projected": True
},
"lulc_raster_path": {
**spec_utils.LULC,
**spec.LULC,
"projected": True,
"about": spec_utils.LULC['about'] + " " + gettext(
"about": spec.LULC['about'] + " " + gettext(
"All values in this raster MUST "
"have corresponding entries in the Biophysical Table.")
},
"soil_group_path": {
**spec_utils.SOIL_GROUP,
**spec.SOIL_GROUP,
"projected": True,
"required": "not user_defined_local_recharge"
"required": "not user_defined_local_recharge",
"allowed": "not user_defined_local_recharge"
},
"aoi_path": {
**spec_utils.AOI,
**spec.AOI,
"projected": True
},
"biophysical_table_path": {
"type": "csv",
"index_col": "lucode",
"columns": {
"lucode": spec_utils.LULC_TABLE_COLUMN,
"lucode": spec.LULC_TABLE_COLUMN,
"cn_[SOIL_GROUP]": {
"type": "number",
"units": u.none,
@ -173,6 +186,7 @@ MODEL_SPEC = {
"required": (
"(not user_defined_local_recharge) & (not "
"user_defined_climate_zones)"),
"allowed": "not user_defined_climate_zones",
"about": gettext(
"A table containing the number of rain events for each month. "
"Required if neither User-Defined Local Recharge nor User-"
@ -182,6 +196,7 @@ MODEL_SPEC = {
"alpha_m": {
"type": "freestyle_string",
"required": "not monthly_alpha",
"allowed": "not monthly_alpha",
"about": gettext(
"The proportion of upslope annual available local recharge "
"that is available in each month. Required if Use Monthly "
@ -216,6 +231,7 @@ MODEL_SPEC = {
"units": u.millimeter
}},
"required": "user_defined_local_recharge",
"allowed": "user_defined_local_recharge",
"projected": True,
"about": gettext(
"Map of local recharge data. Required if User-Defined Local "
@ -249,6 +265,7 @@ MODEL_SPEC = {
"for each month.")}
},
"required": "user_defined_climate_zones",
"allowed": "user_defined_climate_zones",
"about": gettext(
"Table of monthly precipitation events for each climate zone. "
"Required if User-Defined Climate Zones is selected."),
@ -258,6 +275,7 @@ MODEL_SPEC = {
"type": "raster",
"bands": {1: {"type": "integer"}},
"required": "user_defined_climate_zones",
"allowed": "user_defined_climate_zones",
"projected": True,
"about": gettext(
"Map of climate zones. All values in this raster must have "
@ -289,12 +307,13 @@ MODEL_SPEC = {
}
},
"required": "monthly_alpha",
"allowed": "monthly_alpha",
"about": gettext(
"Table of alpha values for each month. "
"Required if Use Monthly Alpha Table is selected."),
"name": gettext("monthly alpha table")
},
**spec_utils.FLOW_DIR_ALGORITHM
**spec.FLOW_DIR_ALGORITHM
},
"outputs": {
"B.tif": {
@ -364,6 +383,7 @@ MODEL_SPEC = {
"units": u.millimeter/u.year
}}
},
"stream.tif": spec.STREAM,
"P.tif": {
"about": gettext("The total precipitation across all months on this pixel."),
"bands": {1: {
@ -382,7 +402,7 @@ MODEL_SPEC = {
},
"aggregated_results_swy.shp": {
"about": gettext("Table of biophysical values for each watershed"),
"geometries": spec_utils.POLYGONS,
"geometries": spec.POLYGONS,
"fields": {
"qb": {
"about": gettext(
@ -423,15 +443,6 @@ MODEL_SPEC = {
"units": u.millimeter
}}
},
"stream.tif": {
"about": gettext(
"Stream network map generated from the input DEM and "
"Threshold Flow Accumulation. Values of 1 represent "
"streams, values of 0 are non-stream pixels."),
"bands": {1: {
"type": "integer"
}}
},
'Si.tif': {
"about": gettext("Map of the S_i factor derived from CN"),
"bands": {1: {"type": "number", "units": u.inch}}
@ -455,7 +466,7 @@ MODEL_SPEC = {
"clipped to match the other spatial inputs"),
"bands": {1: {"type": "integer"}}
},
'flow_accum.tif': spec_utils.FLOW_ACCUMULATION,
'flow_accum.tif': spec.FLOW_ACCUMULATION,
'prcp_a[MONTH].tif': {
"bands": {1: {"type": "number", "units": u.millimeter/u.year}},
"about": gettext("Monthly precipitation rasters, aligned and "
@ -486,9 +497,9 @@ MODEL_SPEC = {
}
}
},
"taskgraph_cache": spec_utils.TASKGRAPH_DIR
"taskgraph_cache": spec.TASKGRAPH_DIR
}
}
})
_OUTPUT_BASE_FILES = {
@ -500,6 +511,7 @@ _OUTPUT_BASE_FILES = {
'l_sum_path': 'L_sum.tif',
'l_sum_avail_path': 'L_sum_avail.tif',
'qf_path': 'QF.tif',
'stream_path': 'stream.tif',
'b_sum_path': 'B_sum.tif',
'b_path': 'B.tif',
'vri_path': 'Vri.tif',
@ -510,7 +522,6 @@ _INTERMEDIATE_BASE_FILES = {
'aetm_path_list': ['aetm_%d.tif' % (x+1) for x in range(N_MONTHS)],
'flow_dir_path': 'flow_dir.tif',
'qfm_path_list': ['qf_%d.tif' % (x+1) for x in range(N_MONTHS)],
'stream_path': 'stream.tif',
'si_path': 'Si.tif',
'lulc_aligned_path': 'lulc_aligned.tif',
'dem_aligned_path': 'dem_aligned.tif',
@ -605,20 +616,19 @@ def execute(args):
# fail early on a missing required rain events table
if (not args['user_defined_local_recharge'] and
not args['user_defined_climate_zones']):
rain_events_df = validation.get_validated_dataframe(
args['rain_events_table_path'],
**MODEL_SPEC['args']['rain_events_table_path'])
rain_events_df = MODEL_SPEC.get_input(
'rain_events_table_path').get_validated_dataframe(
args['rain_events_table_path'])
biophysical_df = validation.get_validated_dataframe(
args['biophysical_table_path'],
**MODEL_SPEC['args']['biophysical_table_path'])
biophysical_df = MODEL_SPEC.get_input(
'biophysical_table_path').get_validated_dataframe(
args['biophysical_table_path'])
if args['monthly_alpha']:
# parse out the alpha lookup table of the form (month_id: alpha_val)
alpha_month_map = validation.get_validated_dataframe(
args['monthly_alpha_path'],
**MODEL_SPEC['args']['monthly_alpha_path']
)['alpha'].to_dict()
alpha_month_map = MODEL_SPEC.get_input(
'monthly_alpha_path').get_validated_dataframe(
args['monthly_alpha_path'])['alpha'].to_dict()
else:
# make all 12 entries equal to args['alpha_m']
alpha_m = float(fractions.Fraction(args['alpha_m']))
@ -795,9 +805,9 @@ def execute(args):
'table_name': 'Climate Zone'}
for month_id in range(N_MONTHS):
if args['user_defined_climate_zones']:
cz_rain_events_df = validation.get_validated_dataframe(
args['climate_zone_table_path'],
**MODEL_SPEC['args']['climate_zone_table_path'])
cz_rain_events_df = MODEL_SPEC.get_input(
'climate_zone_table_path').get_validated_dataframe(
args['climate_zone_table_path'])
climate_zone_rain_events_month = (
cz_rain_events_df[MONTH_ID_TO_LABEL[month_id]].to_dict())
n_events_task = task_graph.add_task(
@ -1466,5 +1476,4 @@ def validate(args, limit_to=None):
the error message in the second part of the tuple. This should
be an empty list if validation succeeds.
"""
return validation.validate(args, MODEL_SPEC['args'],
MODEL_SPEC['args_with_spatial_overlap'])
return validation.validate(args, MODEL_SPEC)
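# A minimal sketch (hypothetical CSV path) of the spec API this change migrates
# to: each input spec is looked up on MODEL_SPEC by key and validates/loads its
# own data, replacing the old
# validation.get_validated_dataframe(path, **MODEL_SPEC['args'][key]) pattern.
table_spec = MODEL_SPEC.get_input('biophysical_table_path')
df = table_spec.get_validated_dataframe('path/to/biophysical_table.csv')
table_spec.name  # the input's human-readable name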

2251
src/natcap/invest/spec.py Normal file

File diff suppressed because it is too large

View File

@ -1,744 +0,0 @@
import importlib
import json
import logging
import os
import pprint
import geometamaker
import natcap.invest
import pint
from natcap.invest import utils
from . import gettext
from .unit_registry import u
from pydantic import ValidationError
LOGGER = logging.getLogger(__name__)
# Specs for common arg types ##################################################
WORKSPACE = {
"name": gettext("workspace"),
"about": gettext(
"The folder where all the model's output files will be written. If "
"this folder does not exist, it will be created. If data already "
"exists in the folder, it will be overwritten."),
"type": "directory",
"contents": {},
"must_exist": False,
"permissions": "rwx",
}
SUFFIX = {
"name": gettext("file suffix"),
"about": gettext(
"Suffix that will be appended to all output file names. Useful to "
"differentiate between model runs."),
"type": "freestyle_string",
"required": False,
"regexp": "[a-zA-Z0-9_-]*"
}
N_WORKERS = {
"name": gettext("taskgraph n_workers parameter"),
"about": gettext(
"The n_workers parameter to provide to taskgraph. "
"-1 will cause all jobs to run synchronously. "
"0 will run all jobs in the same process, but scheduling will take "
"place asynchronously. Any other positive integer will cause that "
"many processes to be spawned to execute tasks."),
"type": "number",
"units": u.none,
"required": False,
"expression": "value >= -1"
}
METER_RASTER = {
"type": "raster",
"bands": {
1: {
"type": "number",
"units": u.meter
}
}
}
AOI = {
"type": "vector",
"fields": {},
"geometries": {"POLYGON", "MULTIPOLYGON"},
"name": gettext("area of interest"),
"about": gettext(
"A map of areas over which to aggregate and "
"summarize the final results."),
}
LULC = {
"type": "raster",
"bands": {1: {"type": "integer"}},
"about": gettext(
"Map of land use/land cover codes. Each land use/land cover type "
"must be assigned a unique integer code."),
"name": gettext("land use/land cover")
}
DEM = {
"type": "raster",
"bands": {
1: {
"type": "number",
"units": u.meter
}
},
"about": gettext("Map of elevation above sea level."),
"name": gettext("digital elevation model")
}
PRECIP = {
"type": "raster",
"bands": {
1: {
"type": "number",
"units": u.millimeter/u.year
}
},
"about": gettext("Map of average annual precipitation."),
"name": gettext("precipitation")
}
ET0 = {
"name": gettext("reference evapotranspiration"),
"type": "raster",
"bands": {
1: {
"type": "number",
"units": u.millimeter
}
},
"about": gettext("Map of reference evapotranspiration values.")
}
SOIL_GROUP = {
"type": "raster",
"bands": {1: {"type": "integer"}},
"about": gettext(
"Map of soil hydrologic groups. Pixels may have values 1, 2, 3, or 4, "
"corresponding to soil hydrologic groups A, B, C, or D, respectively."),
"name": gettext("soil hydrologic group")
}
THRESHOLD_FLOW_ACCUMULATION = {
"expression": "value >= 0",
"type": "number",
"units": u.pixel,
"about": gettext(
"The number of upslope pixels that must flow into a pixel "
"before it is classified as a stream."),
"name": gettext("threshold flow accumulation")
}
LULC_TABLE_COLUMN = {
"type": "integer",
"about": gettext(
"LULC codes from the LULC raster. Each code must be a unique "
"integer.")
}
# Specs for common outputs ####################################################
TASKGRAPH_DIR = {
"type": "directory",
"about": (
"Cache that stores data between model runs. This directory contains no "
"human-readable data and you may ignore it."),
"contents": {
"taskgraph.db": {}
}
}
FILLED_DEM = {
"about": gettext("Map of elevation after any pits are filled"),
"bands": {1: {
"type": "number",
"units": u.meter
}}
}
FLOW_ACCUMULATION = {
"about": gettext("Map of flow accumulation"),
"bands": {1: {
"type": "number",
"units": u.none
}}
}
FLOW_DIRECTION = {
"about": gettext(
"MFD flow direction. Note: the pixel values should not "
"be interpreted directly. Each 32-bit number consists "
"of 8 4-bit numbers. Each 4-bit number represents the "
"proportion of flow into one of the eight neighboring "
"pixels."),
"bands": {1: {"type": "integer"}}
}
FLOW_DIRECTION_D8 = {
"about": gettext(
"D8 flow direction."),
"bands": {1: {"type": "integer"}}
}
SLOPE = {
"about": gettext(
"Percent slope, calculated from the pit-filled "
"DEM. 100 is equivalent to a 45 degree slope."),
"bands": {1: {"type": "percent"}}
}
STREAM = {
"about": "Stream network, created using flow direction and flow accumulation derived from the DEM and Threshold Flow Accumulation. Values of 1 represent streams, values of 0 are non-stream pixels.",
"bands": {1: {"type": "integer"}}
}
FLOW_DIR_ALGORITHM = {
"flow_dir_algorithm": {
"type": "option_string",
"options": {
"D8": {
"display_name": gettext("D8"),
"description": "D8 flow direction"
},
"MFD": {
"display_name": gettext("MFD"),
"description": "Multiple flow direction"
}
},
"about": gettext("Flow direction algorithm to use."),
"name": gettext("flow direction algorithm")
}
}
# geometry types ##############################################################
# the full list of ogr geometry types is in an enum in
# https://github.com/OSGeo/gdal/blob/master/gdal/ogr/ogr_core.h
POINT = {'POINT'}
LINESTRING = {'LINESTRING'}
POLYGON = {'POLYGON'}
MULTIPOINT = {'MULTIPOINT'}
MULTILINESTRING = {'MULTILINESTRING'}
MULTIPOLYGON = {'MULTIPOLYGON'}
LINES = LINESTRING | MULTILINESTRING
POLYGONS = POLYGON | MULTIPOLYGON
POINTS = POINT | MULTIPOINT
ALL_GEOMS = LINES | POLYGONS | POINTS
def format_unit(unit):
"""Represent a pint Unit as user-friendly unicode text.
This attempts to follow the style guidelines from the NIST
Guide to the SI (https://www.nist.gov/pml/special-publication-811):
- Use standard symbols rather than spelling out
- Use '/' to represent division
- Use the center dot ' · ' to represent multiplication
- Combine denominators into one, surrounded by parentheses
Args:
unit (pint.Unit): the unit to format
Raises:
TypeError if unit is not an instance of pint.Unit.
Returns:
String describing the unit.
"""
if not isinstance(unit, pint.Unit):
raise TypeError(
f'{unit} is of type {type(unit)}. '
f'It should be an instance of pint.Unit')
# Optionally use a pre-set format for a particular unit
custom_formats = {
u.pixel: gettext('number of pixels'),
u.year_AD: '', # don't need to mention units for a year input
u.other: '', # for inputs that can have any or multiple units
# For soil erodibility (t*h*ha/(ha*MJ*mm)), by convention the ha's
# are left on top and bottom and don't cancel out
# pint always cancels units where it can, so add them back in here
# this isn't a perfect solution
# see https://github.com/hgrecco/pint/issues/1364
u.t * u.hr / (u.MJ * u.mm): 't · h · ha / (ha · MJ · mm)',
u.none: gettext('unitless')
}
if unit in custom_formats:
return custom_formats[unit]
# look up the abbreviated symbol for each unit
# `formatter` expects an iterable of (unit, exponent) pairs, which lives in
# the pint.Unit's `_units` attribute.
unit_items = [(u.get_symbol(key), val) for key, val in unit._units.items()]
formatted_unit = pint.formatting.formatter(
unit_items,
as_ratio=True,
single_denominator=True,
product_fmt=" · ",
division_fmt='/',
power_fmt="{}{}",
parentheses_fmt="({})",
exp_call=pint.formatting._pretty_fmt_exponent)
if 'currency' in formatted_unit:
formatted_unit = formatted_unit.replace('currency', gettext('currency units'))
return formatted_unit
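# A small usage sketch for format_unit; the exact symbols come from the pint
# registry in unit_registry.py, so the outputs shown are approximate.
format_unit(u.meter / u.second)  # e.g. 'm/s'
format_unit(u.none)              # 'unitless' (via the custom formats above)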
def serialize_args_spec(spec):
"""Serialize an MODEL_SPEC dict to a JSON string.
Args:
spec (dict): An invest model's MODEL_SPEC.
Raises:
TypeError if any object type within the spec is not handled by
json.dumps or by the fallback serializer.
Returns:
JSON String
"""
def fallback_serializer(obj):
"""Serialize objects that are otherwise not JSON serializeable."""
if isinstance(obj, pint.Unit):
return format_unit(obj)
# Sets are present in 'geometries' attributes of some args
# We don't need to worry about deserializing back to a set/array
# so casting to string is okay.
elif isinstance(obj, set):
return str(obj)
raise TypeError(f'fallback serializer is missing for {type(obj)}')
return json.dumps(spec, default=fallback_serializer)
# accepted geometries for a vector will be displayed in this order
GEOMETRY_ORDER = [
'POINT',
'MULTIPOINT',
'LINESTRING',
'MULTILINESTRING',
'POLYGON',
'MULTIPOLYGON']
INPUT_TYPES_HTML_FILE = 'input_types.html'
def format_required_string(required):
"""Represent an arg's required status as a user-friendly string.
Args:
required (bool | str | None): required property of an arg. May be
`True`, `False`, `None`, or a conditional string.
Returns:
string
"""
if required is None or required is True:
return gettext('required')
elif required is False:
return gettext('optional')
else:
# assume that the about text will describe the conditional
return gettext('conditionally required')
def format_geometries_string(geometries):
"""Represent a set of allowed vector geometries as user-friendly text.
Args:
geometries (set(str)): set of geometry names
Returns:
string
"""
# sort the geometries so they always display in a consistent order
sorted_geoms = sorted(
geometries,
key=lambda g: GEOMETRY_ORDER.index(g))
return '/'.join(gettext(geom).lower() for geom in sorted_geoms)
def format_permissions_string(permissions):
"""Represent a rwx-style permissions string as user-friendly text.
Args:
permissions (str): rwx-style permissions string
Returns:
string
"""
permissions_strings = []
if 'r' in permissions:
permissions_strings.append(gettext('read'))
if 'w' in permissions:
permissions_strings.append(gettext('write'))
if 'x' in permissions:
permissions_strings.append(gettext('execute'))
return ', '.join(permissions_strings)
def format_options_string_from_dict(options):
"""Represent a dictionary of option: description pairs as a bulleted list.
Args:
options (dict): the dictionary of options to document, where keys are
options and values are dictionaries describing the options.
They may have either or both 'display_name' and 'description' keys,
for example:
{'option1': {'display_name': 'Option 1', 'description': 'the first option'}}
Returns:
list of RST-formatted strings, where each is a line in a bullet list
"""
lines = []
for key, info in options.items():
display_name = info['display_name'] if 'display_name' in info else key
if 'description' in info:
lines.append(f'- {display_name}: {info["description"]}')
else:
lines.append(f'- {display_name}')
# sort the options alphabetically
# casefold() is a more aggressive version of lower() that may work better
# for some languages to remove all case distinctions
return sorted(lines, key=lambda line: line.casefold())
def format_options_string_from_list(options):
"""Represent options as a comma-separated list.
Args:
options (list[str]): the set of options to document
Returns:
string of comma-separated options
"""
return ', '.join(options)
def capitalize(title):
"""Capitalize a string into title case.
Args:
title (str): string to capitalize
Returns:
capitalized string (each word capitalized except linking words)
"""
def capitalize_word(word):
"""Capitalize a word, if appropriate."""
if word in {'of', 'the'}:
return word
else:
return word[0].upper() + word[1:]
title = ' '.join([capitalize_word(word) for word in title.split(' ')])
title = '/'.join([capitalize_word(word) for word in title.split('/')])
return title
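# Small usage sketches for the string helpers above (English locale assumed).
format_geometries_string({'POLYGON', 'POINT'})  # -> 'point/polygon'
format_permissions_string('rw')                 # -> 'read, write'
capitalize('land use/land cover')               # -> 'Land Use/Land Cover'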
def format_type_string(arg_type):
"""Represent an arg type as a user-friendly string.
Args:
arg_type (str|set(str)): the type to format. May be a single type or a
set of types.
Returns:
formatted string that links to a description of the input type(s)
"""
# some types need a more user-friendly name
# all types are listed here so that they can be marked up for translation
type_names = {
'boolean': gettext('true/false'),
'csv': gettext('CSV'),
'directory': gettext('directory'),
'file': gettext('file'),
'freestyle_string': gettext('text'),
'integer': gettext('integer'),
'number': gettext('number'),
'option_string': gettext('option'),
'percent': gettext('percent'),
'raster': gettext('raster'),
'ratio': gettext('ratio'),
'vector': gettext('vector')
}
def format_single_type(arg_type):
"""Represent a type as a link to the corresponding Input Types section.
Args:
arg_type (str): the type to format.
Returns:
formatted string that links to a description of the input type
"""
# Represent the type as a string. Some need a more user-friendly name.
# we can only use standard docutils features here, so no :ref:
# this syntax works to link to a section in a different page, but it
# isn't universally supported and depends on knowing the built page name.
if arg_type == 'freestyle_string':
section_name = 'text'
elif arg_type == 'option_string':
section_name = 'option'
elif arg_type == 'boolean':
section_name = 'truefalse'
elif arg_type == 'csv':
section_name = 'csv'
else:
section_name = arg_type
return f'`{type_names[arg_type]} <{INPUT_TYPES_HTML_FILE}#{section_name}>`__'
if isinstance(arg_type, set):
return ' or '.join(format_single_type(t) for t in sorted(arg_type))
else:
return format_single_type(arg_type)
def describe_arg_from_spec(name, spec):
"""Generate RST documentation for an arg, given an arg spec.
This is used for documenting:
- a single top-level arg
- a row or column in a CSV
- a field in a vector
- an item in a directory
Args:
name (str): Name to give the section. For top-level args this is
arg['name']. For nested args it's typically their key in the
dictionary one level up.
spec (dict): An arg spec dictionary that conforms to the InVEST args
spec specification. It must at least have the key `'type'`, and
whatever other keys are expected for that type.
Returns:
list of strings, where each string is a line of RST-formatted text.
The first line has the arg name, type, required state, description,
and units if applicable. Depending on the type, there may be additional
lines that are indented, that describe details of the arg such as
vector fields and geometries, option_string options, etc.
"""
type_string = format_type_string(spec['type'])
in_parentheses = [type_string]
# For numbers and rasters that have units, display the units
units = None
if spec['type'] == 'number':
units = spec['units']
elif spec['type'] == 'raster' and spec['bands'][1]['type'] == 'number':
units = spec['bands'][1]['units']
if units:
units_string = format_unit(units)
if units_string:
# pybabel can't find the message if it's in the f-string
translated_units = gettext("units")
in_parentheses.append(f'{translated_units}: **{units_string}**')
if spec['type'] == 'vector':
in_parentheses.append(format_geometries_string(spec["geometries"]))
# Represent the required state as a string, defaulting to required
# It doesn't make sense to include this for boolean checkboxes
if spec['type'] != 'boolean':
# get() returns None if the key doesn't exist in the dictionary
required_string = format_required_string(spec.get('required'))
in_parentheses.append(f'*{required_string}*')
# Nested args may not have an about section
if 'about' in spec:
sanitized_about_string = spec["about"].replace("_", "\\_")
about_string = f': {sanitized_about_string}'
else:
about_string = ''
first_line = f"**{name}** ({', '.join(in_parentheses)}){about_string}"
# Add details for the types that have them
indented_block = []
if spec['type'] == 'option_string':
# may be either a dict or set. if it's empty, the options are
# dynamically generated. don't try to document them.
if spec['options']:
if isinstance(spec['options'], dict):
indented_block.append(gettext('Options:'))
indented_block += format_options_string_from_dict(spec['options'])
else:
formatted_options = format_options_string_from_list(spec['options'])
indented_block.append(gettext('Options:') + f' {formatted_options}')
elif spec['type'] == 'csv':
if 'columns' not in spec and 'rows' not in spec:
first_line += gettext(
' Please see the sample data table for details on the format.')
# prepend the indent to each line in the indented block
return [first_line] + ['\t' + line for line in indented_block]
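# A minimal sketch of describe_arg_from_spec with a hypothetical number arg.
# It returns a list of RST lines; the first bundles the name, linked type,
# units, and required state, roughly:
# '**Rainfall depth** (number, units: **mm**, *required*): Depth of rain.'
rst_lines = describe_arg_from_spec(
    'Rainfall depth',
    {'type': 'number', 'units': u.millimeter, 'required': True,
     'about': 'Depth of rain.'})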
def describe_arg_from_name(module_name, *arg_keys):
"""Generate RST documentation for an arg, given its model and name.
Args:
module_name (str): invest model module containing the arg.
*arg_keys: one or more strings that are nested arg keys.
Returns:
String describing the arg in RST format. Contains an anchor named
<arg_keys[0]>-<arg_keys[1]>...-<arg_keys[n]>
where underscores in arg keys are replaced with hyphens.
"""
# import the specified module (that should have an MODEL_SPEC attribute)
module = importlib.import_module(module_name)
# start with the spec for all args
# narrow down to the nested spec indicated by the sequence of arg keys
spec = module.MODEL_SPEC['args']
for i, key in enumerate(arg_keys):
# convert raster band numbers to ints
if arg_keys[i - 1] == 'bands':
key = int(key)
try:
spec = spec[key]
except KeyError:
keys_so_far = '.'.join(arg_keys[:i + 1])
raise ValueError(
f"Could not find the key '{keys_so_far}' in the "
f"{module_name} model's MODEL_SPEC")
# format spec into an RST formatted description string
if 'name' in spec:
arg_name = capitalize(spec['name'])
else:
arg_name = arg_keys[-1]
# anchor names cannot contain underscores. sphinx will replace them
# automatically, but let's explicitly replace them here
anchor_name = '-'.join(arg_keys).replace('_', '-')
rst_description = '\n\n'.join(describe_arg_from_spec(arg_name, spec))
return f'.. _{anchor_name}:\n\n{rst_description}'
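# A minimal sketch of how this helper (removed along with the rest of this
# module) pulled a single arg's docs out of a dict-based MODEL_SPEC; the model
# and arg key are illustrative. The result is an RST block anchored at
# '.. _lulc-cur-path:'.
rst_block = describe_arg_from_name('natcap.invest.carbon', 'lulc_cur_path')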
def write_metadata_file(datasource_path, spec, keywords_list,
lineage_statement='', out_workspace=None):
"""Write a metadata sidecar file for an invest dataset.
Create metadata for invest model inputs or outputs, taking care to
preserve existing human-modified attributes.
Note: We do not want to overwrite any existing metadata so if there is
invalid metadata for the datasource (i.e., doesn't pass geometamaker
validation in ``describe``), this function will NOT create new metadata.
Args:
datasource_path (str) - filepath to the data to describe
spec (dict) - the invest specification for ``datasource_path``
keywords_list (list) - sequence of strings
lineage_statement (str, optional) - string to describe origin of
the dataset
out_workspace (str, optional) - where to write metadata if different
from data location
Returns:
None
"""
def _get_key(key, resource):
"""Map name of actual key in yml from model_spec key name."""
names = {field.name.lower(): field.name
for field in resource.data_model.fields}
return names[key]
try:
resource = geometamaker.describe(datasource_path)
except ValidationError:
LOGGER.debug(
f"Skipping metadata creation for {datasource_path}, as invalid "
"metadata exists.")
return None
# Don't want the function to fail just because metadata can't be created due to an invalid filetype
except ValueError as e:
LOGGER.debug(f"Skipping metadata creation for {datasource_path}: {e}")
return None
resource.set_lineage(lineage_statement)
# a pre-existing metadata doc could have keywords
words = resource.get_keywords()
resource.set_keywords(set(words + keywords_list))
if 'about' in spec and len(resource.get_description()) < 1:
resource.set_description(spec['about'])
attr_spec = None
if 'columns' in spec:
attr_spec = spec['columns']
if 'fields' in spec:
attr_spec = spec['fields']
if attr_spec:
for key, value in attr_spec.items():
try:
# field names in attr_spec are always lowercase, but the
# actual fieldname in the data could be any case because
# invest does not require case-sensitive fieldnames
yaml_key = _get_key(key, resource)
# Field description only gets set if it's empty, i.e. ''
if len(resource.get_field_description(yaml_key)
.description.strip()) < 1:
about = value['about'] if 'about' in value else ''
resource.set_field_description(yaml_key, description=about)
# units only get set if empty
if len(resource.get_field_description(yaml_key)
.units.strip()) < 1:
units = format_unit(value['units']) if 'units' in value else ''
resource.set_field_description(yaml_key, units=units)
except KeyError as error:
# fields that are in the spec but missing
# from model results because they are conditional.
LOGGER.debug(error)
if 'bands' in spec:
for idx, value in spec['bands'].items():
if len(resource.get_band_description(idx).units) < 1:
try:
units = format_unit(spec['bands'][idx]['units'])
except KeyError:
units = ''
resource.set_band_description(idx, units=units)
resource.write(workspace=out_workspace)
def generate_metadata_for_outputs(model_module, args_dict):
"""Create metadata for all items in an invest model output workspace.
Args:
model_module (object) - the natcap.invest module containing
the MODEL_SPEC attribute
args_dict (dict) - the arguments dictionary passed to the
model's ``execute`` function.
Returns:
None
"""
file_suffix = utils.make_suffix_string(args_dict, 'results_suffix')
formatted_args = pprint.pformat(args_dict)
lineage_statement = (
f'Created by {model_module.__name__}.execute(\n{formatted_args})\n'
f'Version {natcap.invest.__version__}')
keywords = [model_module.MODEL_SPEC['model_id'], 'InVEST']
def _walk_spec(output_spec, workspace):
for filename, spec_data in output_spec.items():
if 'type' in spec_data and spec_data['type'] == 'directory':
if 'taskgraph.db' in spec_data['contents']:
continue
_walk_spec(
spec_data['contents'],
os.path.join(workspace, filename))
else:
pre, post = os.path.splitext(filename)
full_path = os.path.join(workspace, f'{pre}{file_suffix}{post}')
if os.path.exists(full_path):
try:
write_metadata_file(
full_path, spec_data, keywords, lineage_statement)
except ValueError as error:
# Some unsupported file formats, e.g. html
LOGGER.debug(error)
_walk_spec(model_module.MODEL_SPEC['outputs'], args_dict['workspace_dir'])
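# A minimal usage sketch (module and workspace path are hypothetical): after a
# model run, this walks MODEL_SPEC['outputs'] and writes geometamaker sidecar
# files for every output file found in the workspace.
from natcap.invest import carbon
generate_metadata_for_outputs(
    carbon,
    {'workspace_dir': 'path/to/carbon_workspace', 'results_suffix': ''})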

View File

@ -12,10 +12,9 @@ from osgeo import ogr
from osgeo import osr
from . import gettext
from . import spec_utils
from . import spec
from . import utils
from . import validation
from .model_metadata import MODEL_METADATA
from .unit_registry import u
LOGGER = logging.getLogger(__name__)
@ -26,31 +25,39 @@ UINT8_NODATA = 255
UINT16_NODATA = 65535
NONINTEGER_SOILS_RASTER_MESSAGE = 'Soil group raster data type must be integer'
MODEL_SPEC = {
MODEL_SPEC = spec.build_model_spec({
"model_id": "stormwater",
"model_name": MODEL_METADATA["stormwater"].model_title,
"pyname": MODEL_METADATA["stormwater"].pyname,
"userguide": MODEL_METADATA["stormwater"].userguide,
"model_title": gettext("Urban Stormwater Retention"),
"userguide": "stormwater.html",
"aliases": (),
"ui_spec": {
"order": [
['workspace_dir', 'results_suffix'],
['lulc_path', 'soil_group_path', 'precipitation_path', 'biophysical_table'],
['adjust_retention_ratios', 'retention_radius', 'road_centerlines_path'],
['aggregate_areas_path', 'replacement_cost'],
]
},
"args_with_spatial_overlap": {
"spatial_keys": ["lulc_path", "soil_group_path", "precipitation_path",
"road_centerlines_path", "aggregate_areas_path"],
"different_projections_ok": True
},
"args": {
"workspace_dir": spec_utils.WORKSPACE,
"results_suffix": spec_utils.SUFFIX,
"n_workers": spec_utils.N_WORKERS,
"workspace_dir": spec.WORKSPACE,
"results_suffix": spec.SUFFIX,
"n_workers": spec.N_WORKERS,
"lulc_path": {
**spec_utils.LULC,
**spec.LULC,
"projected": True
},
"soil_group_path": spec_utils.SOIL_GROUP,
"precipitation_path": spec_utils.PRECIP,
"soil_group_path": spec.SOIL_GROUP,
"precipitation_path": spec.PRECIP,
"biophysical_table": {
"type": "csv",
"index_col": "lucode",
"columns": {
"lucode": spec_utils.LULC_TABLE_COLUMN,
"lucode": spec.LULC_TABLE_COLUMN,
"emc_[POLLUTANT]": {
"type": "number",
"units": u.milligram/u.liter,
@ -115,6 +122,7 @@ MODEL_SPEC = {
"type": "number",
"units": u.other,
"required": "adjust_retention_ratios",
"allowed": "adjust_retention_ratios",
"about": gettext(
"Radius around each pixel to adjust retention ratios. "
"Measured in raster coordinate system units. For the "
@ -128,11 +136,12 @@ MODEL_SPEC = {
"geometries": {"LINESTRING", "MULTILINESTRING"},
"fields": {},
"required": "adjust_retention_ratios",
"allowed": "adjust_retention_ratios",
"about": gettext("Map of road centerlines"),
"name": gettext("Road centerlines")
},
"aggregate_areas_path": {
**spec_utils.AOI,
**spec.AOI,
"required": False,
"about": gettext(
"Areas over which to aggregate results (typically watersheds "
@ -215,7 +224,7 @@ MODEL_SPEC = {
"Map of aggregate data. This is identical to the aggregate "
"areas input vector, but each polygon is given additional "
"fields with the aggregate data."),
"geometries": spec_utils.POLYGONS,
"geometries": spec.POLYGONS,
"fields": {
"mean_retention_ratio": {
"type": "ratio",
@ -297,7 +306,7 @@ MODEL_SPEC = {
"Copy of the road centerlines vector input, "
"reprojected to the LULC raster projection."),
"fields": {},
"geometries": spec_utils.LINES
"geometries": spec.LINES
},
"rasterized_centerlines.tif": {
@ -369,9 +378,9 @@ MODEL_SPEC = {
}
}
},
"taskgraph_cache": spec_utils.TASKGRAPH_DIR
"taskgraph_cache": spec.TASKGRAPH_DIR
}
}
})
INTERMEDIATE_OUTPUTS = {
'lulc_aligned_path': 'lulc_aligned.tif',
@ -487,9 +496,9 @@ def execute(args):
# Build a lookup dictionary mapping each LULC code to its row
# sort by the LULC codes upfront because we use the sorted list in multiple
# places. it's more efficient to do this once.
biophysical_df = validation.get_validated_dataframe(
args['biophysical_table'], **MODEL_SPEC['args']['biophysical_table']
).sort_index()
biophysical_df = MODEL_SPEC.get_input(
'biophysical_table').get_validated_dataframe(
args['biophysical_table']).sort_index()
sorted_lucodes = biophysical_df.index.to_list()
# convert the nested dictionary in to a 2D array where rows are LULC codes
@ -1160,7 +1169,7 @@ def raster_average(raster_path, radius, kernel_path, out_path):
target_nodata=FLOAT_NODATA)
@ validation.invest_validator
@validation.invest_validator
def validate(args, limit_to=None):
"""Validate args to ensure they conform to `execute`'s contract.
@ -1178,8 +1187,7 @@ def validate(args, limit_to=None):
the error message in the second part of the tuple. This should
be an empty list if validation succeeds.
"""
validation_warnings = validation.validate(args, MODEL_SPEC['args'],
MODEL_SPEC['args_with_spatial_overlap'])
validation_warnings = validation.validate(args, MODEL_SPEC)
invalid_keys = validation.get_invalid_keys(validation_warnings)
if 'soil_group_path' not in invalid_keys:
# check that soil group raster has integer type

View File

@ -12,9 +12,10 @@ import natcap.invest
from natcap.invest import cli
from natcap.invest import datastack
from natcap.invest import set_locale
from natcap.invest.model_metadata import MODEL_METADATA
from natcap.invest import spec_utils
from natcap.invest import models
from natcap.invest import spec
from natcap.invest import usage
from natcap.invest import validation
LOGGER = logging.getLogger(__name__)
@ -26,11 +27,6 @@ CORS(app, resources={
}
})
PYNAME_TO_MODEL_NAME_MAP = {
metadata.pyname: model_name
for model_name, metadata in MODEL_METADATA.items()
}
@app.route(f'/{PREFIX}/ready', methods=['GET'])
def get_is_ready():
@ -50,9 +46,8 @@ def get_invest_models():
A JSON string
"""
LOGGER.debug('get model list')
set_locale(request.args.get('language', 'en'))
importlib.reload(natcap.invest.model_metadata)
return cli.build_model_list_json()
locale_code = request.args.get('language', 'en')
return cli.build_model_list_json(locale_code)
@app.route(f'/{PREFIX}/getspec', methods=['POST'])
@ -69,11 +64,36 @@ def get_invest_getspec():
"""
set_locale(request.args.get('language', 'en'))
target_model = request.get_json()
target_module = MODEL_METADATA[target_model].pyname
importlib.reload(natcap.invest.spec_utils)
target_module = models.model_id_to_pyname[target_model]
importlib.reload(natcap.invest.validation)
model_module = importlib.reload(
importlib.import_module(name=target_module))
return spec_utils.serialize_args_spec(model_module.MODEL_SPEC)
return model_module.MODEL_SPEC.to_json()
@app.route(f'/{PREFIX}/dynamic_dropdowns', methods=['POST'])
def get_dynamic_dropdown_options():
"""Gets the list of dynamically populated dropdown options.
Body (JSON string):
model_id: string (e.g. carbon)
args: JSON string of InVEST model args keys and values
Returns:
A JSON string.
"""
payload = request.get_json()
LOGGER.debug(payload)
results = {}
model_module = importlib.import_module(
name=models.model_id_to_pyname[payload['model_id']])
for arg_spec in model_module.MODEL_SPEC.inputs:
if (isinstance(arg_spec, spec.OptionStringInput) and
arg_spec.dropdown_function):
results[arg_spec.id] = arg_spec.dropdown_function(
json.loads(payload['args']))
LOGGER.debug(results)
return json.dumps(results)
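# A hypothetical client-side sketch of the new dynamic_dropdowns endpoint; the
# host, port, and 'api' prefix are assumptions, and the request body mirrors
# the docstring above.
import json
import requests
resp = requests.post(
    'http://localhost:56789/api/dynamic_dropdowns',
    json={'model_id': 'carbon', 'args': json.dumps({})})
options_by_arg = resp.json()  # maps each dropdown arg id to its option list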
@app.route(f'/{PREFIX}/validate', methods=['POST'])
@ -81,7 +101,7 @@ def get_invest_validate():
"""Gets the return value of an InVEST model's validate function.
Body (JSON string):
model_module: string (e.g. natcap.invest.carbon)
model_id: string (e.g. carbon)
args: JSON string of InVEST model args keys and values
Accepts a `language` query parameter which should be an ISO 639-1 language
@ -101,7 +121,8 @@ def get_invest_validate():
set_locale(request.args.get('language', 'en'))
importlib.reload(natcap.invest.validation)
model_module = importlib.reload(
importlib.import_module(name=payload['model_module']))
importlib.import_module(
name=models.model_id_to_pyname[payload['model_id']]))
results = model_module.validate(
json.loads(payload['args']), limit_to=limit_to)
@ -109,36 +130,28 @@ def get_invest_validate():
return json.dumps(results)
@app.route(f'/{PREFIX}/colnames', methods=['POST'])
def get_vector_colnames():
"""Get a list of column names from a vector.
This is used to fill in dropdown menu options in a couple models.
@app.route(f'/{PREFIX}/args_enabled', methods=['POST'])
def get_args_enabled():
"""Gets the return value of an InVEST model's validate function.
Body (JSON string):
vector_path (string): path to a vector file
model_id: string (e.g. carbon)
args: JSON string of InVEST model args keys and values
Accepts a `language` query parameter which should be an ISO 639-1 language
code. Validation messages will be translated to the requested language if
translations are available, or fall back to English otherwise.
Returns:
a JSON string.
A JSON string.
"""
payload = request.get_json()
LOGGER.debug(payload)
vector_path = payload['vector_path']
# a lot of times the path will be empty so don't even try to open it
if vector_path:
try:
vector = gdal.OpenEx(vector_path, gdal.OF_VECTOR)
colnames = [defn.GetName() for defn in vector.GetLayer().schema]
LOGGER.debug(colnames)
return json.dumps(colnames)
except Exception as e:
LOGGER.exception(
f'Could not read column names from {vector_path}. ERROR: {e}')
else:
LOGGER.error('Empty vector path.')
# 422 Unprocessable Entity: the server understands the content type
# of the request entity, and the syntax of the request entity is
# correct, but it was unable to process the contained instructions.
return json.dumps([]), 422
model_spec = importlib.import_module(
name=models.model_id_to_pyname[payload['model_id']]).MODEL_SPEC
results = validation.args_enabled(json.loads(payload['args']), model_spec)
LOGGER.debug(results)
return json.dumps(results)
@app.route(f'/{PREFIX}/post_datastack_file', methods=['POST'])
@ -153,13 +166,10 @@ def post_datastack_file():
payload = request.get_json()
stack_type, stack_info = datastack.get_datastack_info(
payload['filepath'], payload.get('extractPath', None))
model_name = PYNAME_TO_MODEL_NAME_MAP[stack_info.model_name]
result_dict = {
'type': stack_type,
'args': stack_info.args,
'module_name': stack_info.model_name,
'model_run_name': model_name,
'model_human_name': MODEL_METADATA[model_name].model_title,
'model_id': stack_info.model_id,
'invest_version': stack_info.invest_version
}
return json.dumps(result_dict)
@ -171,7 +181,7 @@ def write_parameter_set_file():
Body (JSON string):
filepath: string
moduleName: string(e.g. natcap.invest.carbon)
model_id: string (e.g. carbon)
args: JSON string of InVEST model args keys and values
relativePaths: boolean
@ -182,13 +192,13 @@ def write_parameter_set_file():
"""
payload = request.get_json()
filepath = payload['filepath']
modulename = payload['moduleName']
model_id = payload['model_id']
args = json.loads(payload['args'])
relative_paths = payload['relativePaths']
try:
datastack.build_parameter_set(
args, modulename, filepath, relative=relative_paths)
args, model_id, filepath, relative=relative_paths)
except ValueError as message:
LOGGER.error(str(message))
return {
@ -207,7 +217,7 @@ def save_to_python():
Body (JSON string):
filepath: string
modelname: string (a key in natcap.invest.MODEL_METADATA)
model_id: string (matching a model_id from a MODEL_SPEC)
args_dict: JSON string of InVEST model args keys and values
Returns:
@ -215,11 +225,11 @@ def save_to_python():
"""
payload = request.get_json()
save_filepath = payload['filepath']
modelname = payload['modelname']
model_id = payload['model_id']
args_dict = json.loads(payload['args'])
cli.export_to_python(
save_filepath, modelname, args_dict)
save_filepath, model_id, args_dict)
return 'python script saved'
@ -230,7 +240,7 @@ def build_datastack_archive():
Body (JSON string):
filepath: string - the target path to save the archive
moduleName: string (e.g. natcap.invest.carbon) the python module name
model_id: string (e.g. carbon) the model id
args: JSON string of InVEST model args keys and values
Returns:
@ -242,7 +252,7 @@ def build_datastack_archive():
try:
datastack.build_datastack_archive(
json.loads(payload['args']),
payload['moduleName'],
payload['model_id'],
payload['filepath'])
except ValueError as message:
LOGGER.error(str(message))
@ -260,10 +270,12 @@ def build_datastack_archive():
def log_model_start():
payload = request.get_json()
usage._log_model(
payload['model_pyname'],
json.loads(payload['model_args']),
payload['invest_interface'],
payload['session_id'])
pyname=models.model_id_to_pyname[payload['model_id']],
model_args=json.loads(payload['model_args']),
invest_interface=payload['invest_interface'],
session_id=payload['session_id'],
type=payload['type'],
source=payload.get('source', None)) # source only used for plugins
return 'OK'

View File

@ -1,4 +1,5 @@
"""Urban Cooling Model."""
import copy
import logging
import math
import os
@ -19,32 +20,41 @@ from osgeo import ogr
from osgeo import osr
from . import gettext
from . import spec_utils
from . import spec
from . import utils
from . import validation
from .model_metadata import MODEL_METADATA
from .unit_registry import u
LOGGER = logging.getLogger(__name__)
TARGET_NODATA = -1
_LOGGING_PERIOD = 5
MODEL_SPEC = {
MODEL_SPEC = spec.build_model_spec({
"model_id": "urban_cooling_model",
"model_name": MODEL_METADATA["urban_cooling_model"].model_title,
"pyname": MODEL_METADATA["urban_cooling_model"].pyname,
"userguide": MODEL_METADATA["urban_cooling_model"].userguide,
"model_title": gettext("Urban Cooling"),
"userguide": "urban_cooling_model.html",
"aliases": ("ucm",),
"ui_spec": {
"order": [
['workspace_dir', 'results_suffix'],
['lulc_raster_path', 'ref_eto_raster_path', 'aoi_vector_path', 'biophysical_table_path'],
['t_ref', 'uhi_max', 't_air_average_radius', 'green_area_cooling_distance', 'cc_method'],
['do_energy_valuation', 'building_vector_path', 'energy_consumption_table_path'],
['do_productivity_valuation', 'avg_rel_humidity'],
['cc_weight_shade', 'cc_weight_albedo', 'cc_weight_eti'],
]
},
"args_with_spatial_overlap": {
"spatial_keys": ["lulc_raster_path", "ref_eto_raster_path",
"aoi_vector_path", "building_vector_path"],
"different_projections_ok": True,
},
"args": {
"workspace_dir": spec_utils.WORKSPACE,
"results_suffix": spec_utils.SUFFIX,
"n_workers": spec_utils.N_WORKERS,
"workspace_dir": spec.WORKSPACE,
"results_suffix": spec.SUFFIX,
"n_workers": spec.N_WORKERS,
"lulc_raster_path": {
**spec_utils.LULC,
**spec.LULC,
"projected": True,
"projection_units": u.meter,
"about": gettext(
@ -52,14 +62,14 @@ MODEL_SPEC = {
"raster must have corresponding entries in the Biophysical "
"Table.")
},
"ref_eto_raster_path": spec_utils.ET0,
"aoi_vector_path": spec_utils.AOI,
"ref_eto_raster_path": spec.ET0,
"aoi_vector_path": spec.AOI,
"biophysical_table_path": {
"name": gettext("biophysical table"),
"type": "csv",
"index_col": "lucode",
"columns": {
"lucode": spec_utils.LULC_TABLE_COLUMN,
"lucode": spec.LULC_TABLE_COLUMN,
"kc": {
"type": "number",
"units": u.none,
@ -150,6 +160,7 @@ MODEL_SPEC = {
"name": gettext("average relative humidity"),
"type": "percent",
"required": "do_productivity_valuation",
"allowed": "do_productivity_valuation",
"about": gettext(
"The average relative humidity over the time period of "
"interest. Required if Run Work Productivity Valuation is "
@ -164,8 +175,9 @@ MODEL_SPEC = {
"about": gettext(
"Code indicating the building type. These codes must "
"match those in the Energy Consumption Table.")}},
"geometries": spec_utils.POLYGONS,
"geometries": spec.POLYGONS,
"required": "do_energy_valuation",
"allowed": "do_energy_valuation",
"about": gettext(
"A map of built infrastructure footprints. Required if Run "
"Energy Savings Valuation is selected.")
@ -200,6 +212,7 @@ MODEL_SPEC = {
}
},
"required": "do_energy_valuation",
"allowed": "do_energy_valuation",
"about": gettext(
"A table of energy consumption data for each building type. "
"Required if Run Energy Savings Valuation is selected.")
@ -255,7 +268,7 @@ MODEL_SPEC = {
"about": (
"A copy of the input Area of Interest vector with "
"additional fields."),
"geometries": spec_utils.POLYGONS,
"geometries": spec.POLYGONS,
"fields": {
"avg_cc": {
"about": "Average CC value",
@ -294,7 +307,7 @@ MODEL_SPEC = {
},
"buildings_with_stats.shp": {
"about": "A copy of the input vector “Building Footprints” with additional fields.",
"geometries": spec_utils.POLYGONS,
"geometries": spec.POLYGONS,
"fields": {
"energy_sav": {
"about": "Energy savings value (kWh or currency if optional energy cost input column was provided in the Energy Consumption Table). Savings are relative to a theoretical scenario where the city contains NO natural areas nor green spaces; where CC = 0 for all LULC classes.",
@ -342,14 +355,14 @@ MODEL_SPEC = {
"about": (
"The Area of Interest vector reprojected to the "
"spatial reference of the LULC."),
"geometries": spec_utils.POLYGONS,
"geometries": spec.POLYGONS,
"fields": {}
},
"reprojected_buildings.shp": {
"about": (
"The buildings vector reprojected to the spatial "
"reference of the LULC."),
"geometries": spec_utils.POLYGONS,
"geometries": spec.POLYGONS,
"fields": {}
},
"albedo.tif": {
@ -394,9 +407,9 @@ MODEL_SPEC = {
},
}
},
"taskgraph_cache": spec_utils.TASKGRAPH_DIR
"taskgraph_cache": spec.TASKGRAPH_DIR
}
}
})
def execute(args):
@ -461,9 +474,9 @@ def execute(args):
intermediate_dir = os.path.join(
args['workspace_dir'], 'intermediate')
utils.make_directories([args['workspace_dir'], intermediate_dir])
biophysical_df = validation.get_validated_dataframe(
args['biophysical_table_path'],
**MODEL_SPEC['args']['biophysical_table_path'])
biophysical_df = MODEL_SPEC.get_input(
'biophysical_table_path').get_validated_dataframe(
args['biophysical_table_path'])
# cast to float and calculate relative weights
# Use default weights for shade, albedo, eti if the user didn't provide
@ -1146,9 +1159,9 @@ def calculate_energy_savings(
for field in target_building_layer.schema]
type_field_index = fieldnames.index('type')
energy_consumption_df = validation.get_validated_dataframe(
energy_consumption_table_path,
**MODEL_SPEC['args']['energy_consumption_table_path'])
energy_consumption_df = MODEL_SPEC.get_input(
'energy_consumption_table_path').get_validated_dataframe(
energy_consumption_table_path)
target_building_layer.StartTransaction()
last_time = time.time()
@ -1522,26 +1535,24 @@ def validate(args, limit_to=None):
be an empty list if validation succeeds.
"""
validation_warnings = validation.validate(
args, MODEL_SPEC['args'], MODEL_SPEC['args_with_spatial_overlap'])
validation_warnings = validation.validate(args, MODEL_SPEC)
invalid_keys = validation.get_invalid_keys(validation_warnings)
if ('biophysical_table_path' not in invalid_keys and
'cc_method' not in invalid_keys):
spec = copy.deepcopy(MODEL_SPEC.get_input('biophysical_table_path'))
if args['cc_method'] == 'factors':
extra_biophysical_keys = ['shade', 'albedo']
spec.columns.get('shade').required = True
spec.columns.get('albedo').required = True
else:
# args['cc_method'] must be 'intensity'.
# If args['cc_method'] isn't one of these two allowed values
# ('intensity' or 'factors'), it'll be caught by
# validation.validate due to the allowed values stated in
# MODEL_SPEC.
extra_biophysical_keys = ['building_intensity']
spec.columns.get('building_intensity').required = True
error_msg = validation.check_csv(
args['biophysical_table_path'],
header_patterns=extra_biophysical_keys,
axis=1)
error_msg = spec.validate(args['biophysical_table_path'])
if error_msg:
validation_warnings.append((['biophysical_table_path'], error_msg))

View File

@ -14,19 +14,26 @@ from osgeo import ogr
from osgeo import osr
from . import gettext
from . import spec_utils
from . import spec
from . import utils
from . import validation
from .model_metadata import MODEL_METADATA
from .unit_registry import u
LOGGER = logging.getLogger(__name__)
MODEL_SPEC = {
MODEL_SPEC = spec.build_model_spec({
"model_id": "urban_flood_risk_mitigation",
"model_name": MODEL_METADATA["urban_flood_risk_mitigation"].model_title,
"pyname": MODEL_METADATA["urban_flood_risk_mitigation"].pyname,
"userguide": MODEL_METADATA["urban_flood_risk_mitigation"].userguide,
"model_title": gettext("Urban Flood Risk Mitigation"),
"userguide": "urban_flood_mitigation.html",
"aliases": ("ufrm",),
"ui_spec": {
"order": [
['workspace_dir', 'results_suffix'],
['aoi_watersheds_path', 'rainfall_depth'],
['lulc_path', 'curve_number_table_path', 'soils_hydrological_group_raster_path'],
['built_infrastructure_vector_path', 'infrastructure_damage_loss_table_path']
]
},
"args_with_spatial_overlap": {
"spatial_keys": ["aoi_watersheds_path", "lulc_path",
"built_infrastructure_vector_path",
@ -34,10 +41,10 @@ MODEL_SPEC = {
"different_projections_ok": True,
},
"args": {
"workspace_dir": spec_utils.WORKSPACE,
"results_suffix": spec_utils.SUFFIX,
"n_workers": spec_utils.N_WORKERS,
"aoi_watersheds_path": spec_utils.AOI,
"workspace_dir": spec.WORKSPACE,
"results_suffix": spec.SUFFIX,
"n_workers": spec.N_WORKERS,
"aoi_watersheds_path": spec.AOI,
"rainfall_depth": {
"expression": "value > 0",
"type": "number",
@ -46,14 +53,14 @@ MODEL_SPEC = {
"name": gettext("rainfall depth")
},
"lulc_path": {
**spec_utils.LULC,
**spec.LULC,
"projected": True,
"about": gettext(
"Map of LULC. All values in this raster must have "
"corresponding entries in the Biophysical Table.")
},
"soils_hydrological_group_raster_path": {
**spec_utils.SOIL_GROUP,
**spec.SOIL_GROUP,
"projected": True
},
"curve_number_table_path": {
@ -86,7 +93,7 @@ MODEL_SPEC = {
"Code indicating the building type. These codes "
"must match those in the Damage Loss Table."
)}},
"geometries": spec_utils.POLYGONS,
"geometries": spec.POLYGONS,
"required": False,
"about": gettext("Map of building footprints."),
"name": gettext("built infrastructure")
@ -113,7 +120,7 @@ MODEL_SPEC = {
}
},
"outputs": {
"Runoff_retention.tif": {
"Runoff_retention_index.tif": {
"about": "Map of runoff retention index.",
"bands": {1: {
"type": "number",
@ -136,7 +143,7 @@ MODEL_SPEC = {
},
"flood_risk_service.shp": {
"about": "Aggregated results for each area of interest.",
"geometries": spec_utils.POLYGONS,
"geometries": spec.POLYGONS,
"fields": {
"rnf_rt_idx": {
"about": "Average runoff retention index.",
@ -178,14 +185,14 @@ MODEL_SPEC = {
"about": (
"Copy of AOI vector reprojected to the same spatial "
"reference as the LULC."),
"geometries": spec_utils.POLYGONS,
"geometries": spec.POLYGONS,
"fields": {}
},
"structures_reprojected.shp": {
"about": (
"Copy of built infrastructure vector reprojected to "
"the same spatial reference as the LULC."),
"geometries": spec_utils.POLYGONS,
"geometries": spec.POLYGONS,
"fields": {}
},
"aligned_lulc.tif": {
@ -206,9 +213,9 @@ MODEL_SPEC = {
}
}
},
"taskgraph_cache": spec_utils.TASKGRAPH_DIR
"taskgraph_cache": spec.TASKGRAPH_DIR
}
}
})
def execute(args):
@ -299,9 +306,9 @@ def execute(args):
task_name='align raster stack')
# Load CN table
cn_df = validation.get_validated_dataframe(
args['curve_number_table_path'],
**MODEL_SPEC['args']['curve_number_table_path'])
cn_df = MODEL_SPEC.get_input(
'curve_number_table_path').get_validated_dataframe(
args['curve_number_table_path'])
# make cn_table into a 2d array where first dim is lucode, second is
# 0..3 to correspond to CN_A..CN_D
@ -364,7 +371,7 @@ def execute(args):
# Generate Runoff Retention
runoff_retention_nodata = -9999
runoff_retention_raster_path = os.path.join(
args['workspace_dir'], f'Runoff_retention{file_suffix}.tif')
args['workspace_dir'], f'Runoff_retention_index{file_suffix}.tif')
runoff_retention_task = task_graph.add_task(
func=pygeoprocessing.raster_calculator,
args=([
@ -628,10 +635,9 @@ def _calculate_damage_to_infrastructure_in_aoi(
infrastructure_vector = gdal.OpenEx(structures_vector_path, gdal.OF_VECTOR)
infrastructure_layer = infrastructure_vector.GetLayer()
damage_type_map = validation.get_validated_dataframe(
structures_damage_table,
**MODEL_SPEC['args']['infrastructure_damage_loss_table_path']
)['damage'].to_dict()
damage_type_map = MODEL_SPEC.get_input(
'infrastructure_damage_loss_table_path').get_validated_dataframe(
structures_damage_table)['damage'].to_dict()
infrastructure_layer_defn = infrastructure_layer.GetLayerDefn()
type_index = -1
@ -921,8 +927,7 @@ def validate(args, limit_to=None):
be an empty list if validation succeeds.
"""
validation_warnings = validation.validate(
args, MODEL_SPEC['args'], MODEL_SPEC['args_with_spatial_overlap'])
validation_warnings = validation.validate(args, MODEL_SPEC)
sufficient_keys = validation.get_sufficient_keys(args)
invalid_keys = validation.get_invalid_keys(validation_warnings)
@ -930,9 +935,9 @@ def validate(args, limit_to=None):
if ("curve_number_table_path" not in invalid_keys and
"curve_number_table_path" in sufficient_keys):
# Load CN table. Resulting DF has index and CN_X columns only.
cn_df = validation.get_validated_dataframe(
args['curve_number_table_path'],
**MODEL_SPEC['args']['curve_number_table_path'])
cn_df = MODEL_SPEC.get_input(
'curve_number_table_path').get_validated_dataframe(
args['curve_number_table_path'])
# Check for NaN values.
nan_mask = cn_df.isna()
if nan_mask.any(axis=None):

View File

@ -19,11 +19,10 @@ from osgeo import ogr
from osgeo import osr
from . import gettext
from . import spec_utils
from . import spec
from . import utils
from . import validation
from .model_metadata import MODEL_METADATA
from .spec_utils import u
from .spec import u
LOGGER = logging.getLogger(__name__)
UINT32_NODATA = int(numpy.iinfo(numpy.uint32).max)
@ -39,11 +38,19 @@ RADIUS_OPT_URBAN_NATURE = 'radius per urban nature class'
RADIUS_OPT_POP_GROUP = 'radius per population group'
POP_FIELD_REGEX = '^pop_'
ID_FIELDNAME = 'adm_unit_id'
MODEL_SPEC = {
MODEL_SPEC = spec.build_model_spec({
'model_id': 'urban_nature_access',
'model_name': MODEL_METADATA['urban_nature_access'].model_title,
'pyname': MODEL_METADATA['urban_nature_access'].pyname,
'userguide': MODEL_METADATA['urban_nature_access'].userguide,
'model_title': gettext('Urban Nature Access'),
'userguide': 'urban_nature_access.html',
'aliases': ('una',),
'ui_spec': {
'order': [
['workspace_dir', 'results_suffix'],
['lulc_raster_path', 'lulc_attribute_table'],
['population_raster_path', 'admin_boundaries_vector_path', 'population_group_radii_table', 'urban_nature_demand', 'aggregate_by_pop_group'],
['search_radius_mode', 'decay_function', 'search_radius']
]
},
'args_with_spatial_overlap': {
'spatial_keys': [
'lulc_raster_path', 'population_raster_path',
@ -51,11 +58,11 @@ MODEL_SPEC = {
'different_projections_ok': True,
},
'args': {
'workspace_dir': spec_utils.WORKSPACE,
'results_suffix': spec_utils.SUFFIX,
'n_workers': spec_utils.N_WORKERS,
'workspace_dir': spec.WORKSPACE,
'results_suffix': spec.SUFFIX,
'n_workers': spec.N_WORKERS,
'lulc_raster_path': {
**spec_utils.LULC,
**spec.LULC,
'projected': True,
'projection_units': u.meter,
'about': (
@ -79,7 +86,7 @@ MODEL_SPEC = {
),
'index_col': 'lucode',
'columns': {
'lucode': spec_utils.LULC_TABLE_COLUMN,
'lucode': spec.LULC_TABLE_COLUMN,
'urban_nature': {
'type': 'ratio',
'about': (
@ -122,7 +129,7 @@ MODEL_SPEC = {
'admin_boundaries_vector_path': {
'type': 'vector',
'name': 'administrative boundaries',
'geometries': spec_utils.POLYGONS,
'geometries': spec.POLYGONS,
'fields': {
"pop_[POP_GROUP]": {
"type": "ratio",
@ -251,6 +258,7 @@ MODEL_SPEC = {
'units': u.m,
'expression': 'value > 0',
'required': f'search_radius_mode == "{RADIUS_OPT_UNIFORM}"',
'allowed': f'search_radius_mode == "{RADIUS_OPT_UNIFORM}"',
'about': gettext(
'The search radius to use when running the model under a '
'uniform search radius. Required when running the model '
@ -260,6 +268,7 @@ MODEL_SPEC = {
'name': 'population group radii table',
'type': 'csv',
'required': f'search_radius_mode == "{RADIUS_OPT_POP_GROUP}"',
'allowed': f'search_radius_mode == "{RADIUS_OPT_POP_GROUP}"',
'index_col': 'pop_group',
'columns': {
"pop_group": {
@ -341,7 +350,7 @@ MODEL_SPEC = {
"about": (
"A copy of the user's administrative boundaries "
"vector with a single layer."),
"geometries": spec_utils.POLYGONS,
"geometries": spec.POLYGONS,
"fields": {
"SUP_DEMadm_cap": {
"type": "number",
@ -612,9 +621,9 @@ MODEL_SPEC = {
}
}
},
'taskgraph_cache': spec_utils.TASKGRAPH_DIR,
'taskgraph_cache': spec.TASKGRAPH_DIR,
}
}
})
_OUTPUT_BASE_FILES = {
@ -934,9 +943,8 @@ def execute(args):
aoi_reprojection_task, lulc_mask_task]
)
attr_table = validation.get_validated_dataframe(
args['lulc_attribute_table'],
**MODEL_SPEC['args']['lulc_attribute_table'])
attr_table = MODEL_SPEC.get_input(
'lulc_attribute_table').get_validated_dataframe(args['lulc_attribute_table'])
kernel_paths = {} # search_radius, kernel path
kernel_tasks = {} # search_radius, kernel task
@ -954,15 +962,15 @@ def execute(args):
lucode_to_search_radii = list(
urban_nature_attrs[['search_radius_m']].itertuples(name=None))
elif args['search_radius_mode'] == RADIUS_OPT_POP_GROUP:
pop_group_table = validation.get_validated_dataframe(
args['population_group_radii_table'],
**MODEL_SPEC['args']['population_group_radii_table'])
pop_group_table = MODEL_SPEC.get_input(
'population_group_radii_table').get_validated_dataframe(
args['population_group_radii_table'])
search_radii = set(pop_group_table['search_radius_m'].unique())
# Build a dict of {pop_group: search_radius_m}
search_radii_by_pop_group = pop_group_table['search_radius_m'].to_dict()
else:
valid_options = ', '.join(
MODEL_SPEC['args']['search_radius_mode']['options'].keys())
MODEL_SPEC.get_input('search_radius_mode').options.keys())
raise ValueError(
"Invalid search radius mode provided: "
f"{args['search_radius_mode']}; must be one of {valid_options}")
@ -1834,8 +1842,8 @@ def _reclassify_urban_nature_area(
Returns:
``None``
"""
lulc_attribute_df = validation.get_validated_dataframe(
lulc_attribute_table, **MODEL_SPEC['args']['lulc_attribute_table'])
lulc_attribute_df = MODEL_SPEC.get_input(
'lulc_attribute_table').get_validated_dataframe(lulc_attribute_table)
squared_pixel_area = abs(
numpy.multiply(*_square_off_pixels(lulc_raster_path)))
@ -1867,9 +1875,9 @@ def _reclassify_urban_nature_area(
target_datatype=gdal.GDT_Float32,
target_nodata=FLOAT32_NODATA,
error_details={
'raster_name': MODEL_SPEC['args']['lulc_raster_path']['name'],
'raster_name': MODEL_SPEC.get_input('lulc_raster_path').name,
'column_name': 'urban_nature',
'table_name': MODEL_SPEC['args']['lulc_attribute_table']['name'],
'table_name': MODEL_SPEC.get_input('lulc_attribute_table').name
}
)
@ -2591,5 +2599,4 @@ def _mask_raster(source_raster_path, mask_raster_path, target_raster_path):
def validate(args, limit_to=None):
return validation.validate(
args, MODEL_SPEC['args'], MODEL_SPEC['args_with_spatial_overlap'])
return validation.validate(args, MODEL_SPEC)

View File

@ -18,6 +18,7 @@ import pygeoprocessing
import requests
from . import utils
from . import spec
ENCODING = sys.getfilesystemencoding()
LOGGER = logging.getLogger(__name__)
@ -65,12 +66,12 @@ def log_run(model_pyname, args):
log_exit_thread.start()
def _calculate_args_bounding_box(args, args_spec):
def _calculate_args_bounding_box(args, model_spec):
"""Calculate the bounding boxes of any GIS types found in `args_dict`.
Args:
args (dict): a string key and any value pair dictionary.
args_spec (dict): the model MODEL_SPEC describing args
model_spec (dict): the model's MODEL_SPEC
Returns:
bb_intersection, bb_union tuple that's either the lat/lng bounding
@ -119,10 +120,11 @@ def _calculate_args_bounding_box(args, args_spec):
# should already have been validated so the path is either valid or
# blank.
spatial_info = None
if args_spec['args'][key]['type'] == 'raster' and value.strip() != '':
if (isinstance(model_spec.get_input(key),
spec.SingleBandRasterInput) and value.strip() != ''):
spatial_info = pygeoprocessing.get_raster_info(value)
elif (args_spec['args'][key]['type'] == 'vector'
and value.strip() != ''):
elif (isinstance(model_spec.get_input(key),
spec.VectorInput) and value.strip() != ''):
spatial_info = pygeoprocessing.get_vector_info(value)
if spatial_info:
@ -156,7 +158,7 @@ def _calculate_args_bounding_box(args, args_spec):
LOGGER.exception(
f'Error when transforming coordinates: {transform_error}')
else:
LOGGER.debug(f'Arg {key} of type {args_spec["args"][key]["type"]} '
LOGGER.debug(f'Arg {key} of type {type(model_spec.get_input(key))} '
'excluded from bounding box calculation')
return bb_intersection, bb_union
@ -189,7 +191,7 @@ def _log_exit_status(session_id, status):
f'an exception encountered in _log_exit_status: {str(exception)}')
def _log_model(pyname, model_args, invest_interface, session_id=None):
def _log_model(pyname, model_args, invest_interface, type, source, session_id=None):
"""Log information about a model run to a remote server.
Args:
@ -197,6 +199,10 @@ def _log_model(pyname, model_args, invest_interface, session_id=None):
model_args (dict): the traditional InVEST argument dictionary.
invest_interface (string): a string identifying the calling UI,
e.g. `Qt` or 'Workbench'.
type (string): 'core' or 'plugin'
source (string): For plugins, a string identifying the source of the
plugin (e.g. 'git+https://github.com/foo/bar' or 'local'). For
core models, this is None.
Returns:
None
@ -216,11 +222,11 @@ def _log_model(pyname, model_args, invest_interface, session_id=None):
md5.update(json.dumps(data).encode('utf-8'))
return md5.hexdigest()
args_spec = importlib.import_module(pyname).MODEL_SPEC
model_spec = importlib.import_module(pyname).MODEL_SPEC
try:
bounding_box_intersection, bounding_box_union = (
_calculate_args_bounding_box(model_args, args_spec))
_calculate_args_bounding_box(model_args, model_spec))
log_start_url = requests.get(_ENDPOINTS_INDEX_URL).json()['START']
requests.post(log_start_url, data={
'model_name': pyname,
@ -233,6 +239,8 @@ def _log_model(pyname, model_args, invest_interface, session_id=None):
'bounding_box_intersection': str(bounding_box_intersection),
'bounding_box_union': str(bounding_box_union),
'session_id': session_id,
'type': type,
'source': source
})
except Exception as exception:
# An exception was thrown, we don't care.
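For reference, the two new fields travel alongside the existing telemetry data; illustrative call shapes only (the plugin pyname is hypothetical, the source URL format comes from the docstring above):

    # args is the model's argument dictionary in both cases.
    # Core model run:
    _log_model('natcap.invest.carbon', args, 'Workbench', type='core', source=None)
    # Plugin run installed from a git URL:
    _log_model('foo.bar', args, 'Workbench', type='plugin',
               source='git+https://github.com/foo/bar')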

View File

@ -161,17 +161,14 @@ def _format_time(seconds):
@contextlib.contextmanager
def prepare_workspace(
workspace, name, logging_level=logging.NOTSET, exclude_threads=None):
workspace, model_id, logging_level=logging.NOTSET, exclude_threads=None):
"""Prepare the workspace."""
if not os.path.exists(workspace):
os.makedirs(workspace)
modelname = '-'.join(name.replace(':', '').split(' '))
logfile = os.path.join(
workspace,
'InVEST-{modelname}-log-{timestamp}.txt'.format(
modelname=modelname,
timestamp=datetime.now().strftime("%Y-%m-%d--%H_%M_%S")))
f'InVEST-{model_id}-log-{datetime.now().strftime("%Y-%m-%d--%H_%M_%S")}.txt')
with capture_gdal_logging(), log_to_file(logfile,
exclude_threads=exclude_threads,
@ -186,7 +183,7 @@ def prepare_workspace(
try:
yield
except Exception:
LOGGER.exception(f'Exception while executing {modelname}')
LOGGER.exception(f'Exception while executing {model_id}')
raise
finally:
LOGGER.info('Elapsed time: %s',
@ -797,3 +794,50 @@ def matches_format_string(test_string, format_string):
if re.fullmatch(pattern, test_string):
return True
return False
def copy_spatial_files(spatial_filepath, target_dir):
"""Copy spatial files to a new directory.
Args:
spatial_filepath (str): The filepath to a GDAL-supported file.
target_dir (str): The directory where all component files of
``spatial_filepath`` should be copied. If this directory does not
exist, it will be created.
Returns:
filepath (str): The path to a representative file copied into the
``target_dir``. If possible, this will match the basename of
``spatial_filepath``, so if someone provides an ESRI Shapefile called
``my_vector.shp``, the return value will be ``os.path.join(target_dir,
my_vector.shp)``.
"""
LOGGER.info(f'Copying {spatial_filepath} --> {target_dir}')
if not os.path.exists(target_dir):
os.makedirs(target_dir)
source_basename = os.path.basename(spatial_filepath)
return_filepath = None
spatial_file = gdal.OpenEx(spatial_filepath)
for member_file in spatial_file.GetFileList():
# ArcGIS Binary/Grid format includes the directory in the file listing.
# The parent directory isn't strictly needed, so we can just skip it.
if os.path.isdir(member_file):
continue
target_basename = os.path.basename(member_file)
target_filepath = os.path.join(target_dir, target_basename)
if source_basename == target_basename:
return_filepath = target_filepath
shutil.copyfile(member_file, target_filepath)
spatial_file = None
# I can't conceive of a case where the basename of the source file does not
# match any of the member file basenames, but just in case there's a
# weird GDAL driver that does this, it seems reasonable to fall back to
# the last member file that was copied.
if not return_filepath:
return_filepath = target_filepath
return return_filepath
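A usage sketch for copy_spatial_files (paths are hypothetical; the function is assumed to live in natcap.invest.utils alongside the rest of this file):

    import os
    from natcap.invest.utils import copy_spatial_files

    # Copies my_vector.shp plus its sidecar files (.shx, .dbf, .prj, ...) into
    # 'staging/', creating the directory if needed, and returns the copied .shp.
    copied = copy_spatial_files('data/my_vector.shp', 'staging')
    assert copied == os.path.join('staging', 'my_vector.shp')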

File diff suppressed because it is too large Load Diff

View File

@ -18,10 +18,9 @@ from osgeo import ogr
import taskgraph
import pygeoprocessing
from . import utils
from . import spec_utils
from . import spec
from .unit_registry import u
from . import validation
from .model_metadata import MODEL_METADATA
from . import gettext
@ -131,29 +130,41 @@ CAPTURED_WEM_FIELDS = {
}
}
MODEL_SPEC = {
MODEL_SPEC = spec.build_model_spec({
"model_id": "wave_energy",
"model_name": MODEL_METADATA["wave_energy"].model_title,
"pyname": MODEL_METADATA["wave_energy"].pyname,
"userguide": MODEL_METADATA["wave_energy"].userguide,
"model_title": gettext("Wave Energy Production"),
"userguide": "wave_energy.html",
"aliases": (),
"ui_spec": {
"order": [
['workspace_dir', 'results_suffix'],
['wave_base_data_path', 'analysis_area', 'aoi_path', 'dem_path'],
['machine_perf_path', 'machine_param_path'],
['valuation_container', 'land_gridPts_path', 'machine_econ_path', 'number_of_machines'],
]
},
"args_with_spatial_overlap": {
"spatial_keys": ["aoi_path", "dem_path"],
"different_projections_ok": True
},
"args": {
"workspace_dir": spec_utils.WORKSPACE,
"results_suffix": spec_utils.SUFFIX,
"n_workers": spec_utils.N_WORKERS,
"workspace_dir": spec.WORKSPACE,
"results_suffix": spec.SUFFIX,
"n_workers": spec.N_WORKERS,
"wave_base_data_path": {
"type": "directory",
"contents": {
"NAmerica_WestCoast_4m.shp": {
"type": "vector",
"fields": {},
"geometries": spec_utils.POINT,
"geometries": spec.POINT,
"about": gettext(
"Point vector for the west coast of North America and "
"Hawaii.")},
"WCNA_extract.shp": {
"type": "vector",
"fields": {},
"geometries": spec_utils.POLYGON,
"geometries": spec.POLYGON,
"about": gettext(
"Extract vector for the west coast of North America "
"and Hawaii.")},
@ -165,14 +176,14 @@ MODEL_SPEC = {
"NAmerica_EastCoast_4m.shp": {
"type": "vector",
"fields": {},
"geometries": spec_utils.POINT,
"geometries": spec.POINT,
"about": gettext(
"Point vector for the East Coast of North America and "
"Puerto Rico.")},
"ECNA_extract.shp": {
"type": "vector",
"fields": {},
"geometries": spec_utils.POLYGON,
"geometries": spec.POLYGON,
"about": gettext(
"Extract vector for the East Coast of North America "
"and Puerto Rico.")},
@ -184,13 +195,13 @@ MODEL_SPEC = {
"North_Sea_4m.shp": {
"type": "vector",
"fields": {},
"geometries": spec_utils.POINT,
"geometries": spec.POINT,
"about": gettext(
"Point vector for the North Sea 4 meter resolution.")},
"North_Sea_4m_Extract.shp": {
"type": "vector",
"fields": {},
"geometries": spec_utils.POLYGON,
"geometries": spec.POLYGON,
"about": gettext(
"Extract vector for the North Sea 4 meter resolution.")},
"North_Sea_4m.bin": {
@ -201,13 +212,13 @@ MODEL_SPEC = {
"North_Sea_10m.shp": {
"type": "vector",
"fields": {},
"geometries": spec_utils.POINT,
"geometries": spec.POINT,
"about": gettext(
"Point vector for the North Sea 10 meter resolution.")},
"North_Sea_10m_Extract.shp": {
"type": "vector",
"fields": {},
"geometries": spec_utils.POLYGON,
"geometries": spec.POLYGON,
"about": gettext(
"Extract vector for the North Sea 10 meter resolution.")},
"North_Sea_10m.bin": {
@ -218,12 +229,12 @@ MODEL_SPEC = {
"Australia_4m.shp": {
"type": "vector",
"fields": {},
"geometries": spec_utils.POINT,
"geometries": spec.POINT,
"about": gettext("Point vector for Australia.")},
"Australia_Extract.shp": {
"type": "vector",
"fields": {},
"geometries": spec_utils.POLYGON,
"geometries": spec.POLYGON,
"about": gettext("Extract vector for Australia.")},
"Australia_4m.bin": {
"type": "file",
@ -231,12 +242,12 @@ MODEL_SPEC = {
"Global.shp": {
"type": "vector",
"fields": {},
"geometries": spec_utils.POINT,
"geometries": spec.POINT,
"about": gettext("Global point vector.")},
"Global_extract.shp": {
"type": "vector",
"fields": {},
"geometries": spec_utils.POLYGON,
"geometries": spec.POLYGON,
"about": gettext("Global extract vector.")},
"Global_WW3.txt.bin": {
"type": "file",
@ -266,7 +277,7 @@ MODEL_SPEC = {
"name": gettext("analysis area")
},
"aoi_path": {
**spec_utils.AOI,
**spec.AOI,
"projected": True,
"projection_units": u.meter,
"required": False
@ -306,6 +317,7 @@ MODEL_SPEC = {
"about": gettext("Value of the machine parameter.")
}
},
"index_col": "name",
"about": gettext("Table of parameters for the wave energy machine in use."),
"name": gettext("machine parameter table")
},
@ -325,6 +337,7 @@ MODEL_SPEC = {
"type": "csv",
"columns": LAND_GRID_POINT_FIELDS,
"required": "valuation_container",
"allowed": "valuation_container",
"about": gettext(
"A table of data for each connection point. Required if "
"Run Valuation is selected."),
@ -355,7 +368,9 @@ MODEL_SPEC = {
"about": gettext("Value of the machine parameter.")
}
},
"index_col": "name",
"required": "valuation_container",
"allowed": "valuation_container",
"about": gettext(
"Table of economic parameters for the wave energy machine. "
"Required if Run Valuation is selected."),
@ -366,6 +381,7 @@ MODEL_SPEC = {
"type": "number",
"units": u.none,
"required": "valuation_container",
"allowed": "valuation_container",
"about": gettext(
"Number of wave machines to model. Required if Run Valuation "
"is selected."),
@ -491,17 +507,17 @@ MODEL_SPEC = {
"contents": {
"aoi_clipped_to_extract_path.shp": {
"about": "AOI clipped to the analysis area",
"geometries": spec_utils.POLYGON,
"geometries": spec.POLYGON,
"fields": {}
},
"Captured_WEM_InputOutput_Pts.shp": {
"about": "Map of wave data points.",
"geometries": spec_utils.POINT,
"geometries": spec.POINT,
"fields": CAPTURED_WEM_FIELDS
},
"Final_WEM_InputOutput_Pts.shp": {
"about": "Map of wave data points.",
"geometries": spec_utils.POINT,
"geometries": spec.POINT,
"fields": {
**CAPTURED_WEM_FIELDS,
"W2L_MDIST": {
@ -539,7 +555,7 @@ MODEL_SPEC = {
"Indexed_WEM_InputOutput_Pts.shp": {
"about": "Map of wave data points.",
"fields": INDEXED_WEM_FIELDS,
"geometries": spec_utils.POINT
"geometries": spec.POINT
},
"interpolated_capwe_mwh.tif": {
"about": "Interpolated wave energy",
@ -563,7 +579,7 @@ MODEL_SPEC = {
},
"WEM_InputOutput_Pts.shp": {
"about": "Map of wave data points.",
"geometries": spec_utils.POINT,
"geometries": spec.POINT,
"fields": WEM_FIELDS
},
"GridPt.txt": {
@ -576,9 +592,9 @@ MODEL_SPEC = {
}
}
},
"taskgraph_cache": spec_utils.TASKGRAPH_DIR
"taskgraph_cache": spec.TASKGRAPH_DIR
}
}
})
# Set nodata value and target_pixel_type for new rasters
@ -742,21 +758,15 @@ def execute(args):
LOGGER.debug('Machine Performance Rows : %s', machine_perf_dict['periods'])
LOGGER.debug('Machine Performance Cols : %s', machine_perf_dict['heights'])
machine_param_dict = validation.get_validated_dataframe(
args['machine_param_path'],
index_col='name',
columns={
'name': {'type': 'option_string'},
'value': {'type': 'number'}
},
)['value'].to_dict()
machine_param_dict = MODEL_SPEC.get_input(
'machine_param_path').get_validated_dataframe(
args['machine_param_path'])['value'].to_dict()
# Check if required column fields are entered in the land grid csv file
if 'land_gridPts_path' in args:
# Create a grid_land_df dataframe for later use in valuation
grid_land_df = validation.get_validated_dataframe(
args['land_gridPts_path'],
**MODEL_SPEC['args']['land_gridPts_path'])
grid_land_df = MODEL_SPEC.get_input(
'land_gridPts_path').get_validated_dataframe(args['land_gridPts_path'])
missing_grid_land_fields = []
for field in ['id', 'type', 'lat', 'long', 'location']:
if field not in grid_land_df.columns:
@ -768,14 +778,9 @@ def execute(args):
'Connection Points File: %s' % missing_grid_land_fields)
if 'valuation_container' in args and args['valuation_container']:
machine_econ_dict = validation.get_validated_dataframe(
args['machine_econ_path'],
index_col='name',
columns={
'name': {'type': 'option_string'},
'value': {'type': 'number'}
}
)['value'].to_dict()
machine_econ_dict = MODEL_SPEC.get_input(
'machine_econ_path').get_validated_dataframe(
args['machine_econ_path'])['value'].to_dict()
# Build up a dictionary of possible analysis areas where the key
# is the analysis area selected and the value is a dictionary
@ -2371,4 +2376,4 @@ def validate(args, limit_to=None):
validation warning.
"""
return validation.validate(args, MODEL_SPEC['args'])
return validation.validate(args, MODEL_SPEC)

File diff suppressed because it is too large Load Diff

View File

@ -201,7 +201,7 @@ class CLIHeadlessTests(unittest.TestCase):
from natcap.invest import cli, validation
datastack_dict = {
'model_name': 'natcap.invest.carbon',
'model_id': 'carbon',
'invest_version': '3.10',
'args': {}
}
@ -292,7 +292,7 @@ class CLIHeadlessTests(unittest.TestCase):
"""CLI: Get validation results as JSON from cli."""
from natcap.invest import cli
datastack_dict = {
'model_name': 'natcap.invest.carbon',
'model_id': 'carbon',
'invest_version': '3.10',
'args': {}
}
@ -384,7 +384,7 @@ class CLIUnitTests(unittest.TestCase):
def test_export_to_python_default_args(self):
"""Export a python script w/ default args for a model."""
from natcap.invest import cli, model_metadata
from natcap.invest import cli, models
filename = 'foo.py'
target_filepath = os.path.join(self.workspace_dir, filename)
@ -394,10 +394,10 @@ class CLIUnitTests(unittest.TestCase):
self.assertTrue(os.path.exists(target_filepath))
target_model = model_metadata.MODEL_METADATA[target_model].pyname
target_model = models.model_id_to_pyname[target_model]
model_module = importlib.import_module(name=target_model)
spec = model_module.MODEL_SPEC
expected_args = {key: '' for key in spec['args'].keys()}
expected_args = {key: '' for key in spec.inputs_dict.keys()}
module_name = str(uuid.uuid4()) + 'testscript'
spec = importlib.util.spec_from_file_location(module_name, target_filepath)

View File

@ -211,9 +211,9 @@ class TestPreprocessor(unittest.TestCase):
lulc_csv.write('0,mangrove,True\n')
lulc_csv.write('1,parking lot,False\n')
landcover_df = validation.get_validated_dataframe(
landcover_table_path,
**preprocessor.MODEL_SPEC['args']['lulc_lookup_table_path'])
landcover_df = preprocessor.MODEL_SPEC.get_input(
'lulc_lookup_table_path').get_validated_dataframe(landcover_table_path)
target_table_path = os.path.join(self.workspace_dir,
'transition_table.csv')
@ -227,9 +227,8 @@ class TestPreprocessor(unittest.TestCase):
str(context.exception))
# Re-load the landcover table
landcover_df = validation.get_validated_dataframe(
landcover_table_path,
**preprocessor.MODEL_SPEC['args']['lulc_lookup_table_path'])
landcover_df = preprocessor.MODEL_SPEC.get_input(
'lulc_lookup_table_path').get_validated_dataframe(landcover_table_path)
preprocessor._create_transition_table(
landcover_df, [filename_a, filename_b], target_table_path)
@ -641,10 +640,9 @@ class TestCBC2(unittest.TestCase):
args = TestCBC2._create_model_args(self.workspace_dir)
args['workspace_dir'] = os.path.join(self.workspace_dir, 'workspace')
prior_snapshots = validation.get_validated_dataframe(
args['landcover_snapshot_csv'],
**coastal_blue_carbon.MODEL_SPEC['args']['landcover_snapshot_csv']
)['raster_path'].to_dict()
prior_snapshots = coastal_blue_carbon.MODEL_SPEC.get_input(
'landcover_snapshot_csv').get_validated_dataframe(
args['landcover_snapshot_csv'])['raster_path'].to_dict()
baseline_year = min(prior_snapshots.keys())
baseline_raster = prior_snapshots[baseline_year]
with open(args['landcover_snapshot_csv'], 'w') as snapshot_csv:
@ -819,10 +817,9 @@ class TestCBC2(unittest.TestCase):
args = TestCBC2._create_model_args(self.workspace_dir)
args['workspace_dir'] = os.path.join(self.workspace_dir, 'workspace')
prior_snapshots = validation.get_validated_dataframe(
args['landcover_snapshot_csv'],
**coastal_blue_carbon.MODEL_SPEC['args']['landcover_snapshot_csv']
)['raster_path'].to_dict()
prior_snapshots = coastal_blue_carbon.MODEL_SPEC.get_input(
'landcover_snapshot_csv').get_validated_dataframe(
args['landcover_snapshot_csv'])['raster_path'].to_dict()
baseline_year = min(prior_snapshots.keys())
baseline_raster = prior_snapshots[baseline_year]
with open(args['landcover_snapshot_csv'], 'w') as snapshot_csv:
@ -879,10 +876,9 @@ class TestCBC2(unittest.TestCase):
# Now work through the extra validation warnings.
# test validation: invalid analysis year
prior_snapshots = validation.get_validated_dataframe(
args['landcover_snapshot_csv'],
**coastal_blue_carbon.MODEL_SPEC['args']['landcover_snapshot_csv']
)['raster_path'].to_dict()
prior_snapshots = coastal_blue_carbon.MODEL_SPEC.get_input(
'landcover_snapshot_csv').get_validated_dataframe(
args['landcover_snapshot_csv'])['raster_path'].to_dict()
baseline_year = min(prior_snapshots)
# analysis year must be >= the last transition year.
args['analysis_year'] = baseline_year

View File

@ -1767,6 +1767,20 @@ class CoastalVulnerabilityValidationTests(unittest.TestCase):
validation.MESSAGES['INVALID_OPTION'].format(option_list=['Trend']))
self.assertTrue(expected_err in err_list)
def test_get_vector_colnames(self):
"""CV: get_vector_colnames."""
from natcap.invest import coastal_vulnerability
colnames = coastal_vulnerability.get_vector_colnames('')
self.assertEqual(colnames, [])
# a vector with one column
path = os.path.join(INPUT_DATA, 'sea_level_rise.gpkg')
colnames = coastal_vulnerability.get_vector_colnames(path)
self.assertEqual(colnames, ['Trend'])
# a non-vector file
path = os.path.join(INPUT_DATA, 'dem_wgs84.tif')
colnames = coastal_vulnerability.get_vector_colnames(path)
self.assertEqual(colnames, [])
def make_slr_vector(slr_point_vector_path, fieldname, shapely_feature, srs):
"""Create an SLR vector with a single point feature.

View File

@ -25,6 +25,19 @@ DATA_DIR = os.path.join(_TEST_FILE_CWD,
SAMPLE_DATA_DIR = os.path.join(
_TEST_FILE_CWD, '..', 'data', 'invest-sample-data')
# These modules live in tests/test_datastack_modules
# Each contains a different MODEL_SPEC for the purpose of datastack testing
MOCK_MODEL_ID_TO_PYNAME = {
name: f'test_datastack_modules.{name}' for name in [
'archive_extraction',
'duplicate_filepaths',
'nonspatial_files',
'raster',
'simple_parameters',
'ui_parameter_archive',
'vector'
]
}
# Allow our tests to import the test modules in the test directory.
sys.path.append(_TEST_FILE_CWD)
@ -125,8 +138,10 @@ class DatastackArchiveTests(unittest.TestCase):
archive_path = os.path.join(self.workspace, 'archive.invs.tar.gz')
datastack.build_datastack_archive(
params, 'test_datastack_modules.simple_parameters', archive_path)
with patch('natcap.invest.datastack.models') as p:
p.model_id_to_pyname = MOCK_MODEL_ID_TO_PYNAME
datastack.build_datastack_archive(
params, 'simple_parameters', archive_path)
out_directory = os.path.join(self.workspace, 'extracted_archive')
datastack._tarfile_safe_extract(archive_path, out_directory)
@ -141,6 +156,7 @@ class DatastackArchiveTests(unittest.TestCase):
def test_collect_rasters(self):
"""Datastack: test collect GDAL rasters."""
import natcap.invest.models
from natcap.invest import datastack
for raster_filename in (
'dem', # This is a multipart raster
@ -152,8 +168,10 @@ class DatastackArchiveTests(unittest.TestCase):
# Collect the raster's files into a single archive
archive_path = os.path.join(self.workspace, 'archive.invs.tar.gz')
datastack.build_datastack_archive(
params, 'test_datastack_modules.raster', archive_path)
with patch('natcap.invest.datastack.models') as p:
p.model_id_to_pyname = MOCK_MODEL_ID_TO_PYNAME
datastack.build_datastack_archive(
params, 'raster', archive_path)
# extract the archive
out_directory = os.path.join(self.workspace, 'extracted_archive')
@ -193,9 +211,11 @@ class DatastackArchiveTests(unittest.TestCase):
archive_path = os.path.join(dest_dir,
'archive.invs.tar.gz')
# Collect the vector's files into a single archive
datastack.build_datastack_archive(
params, 'test_datastack_modules.vector', archive_path)
with patch('natcap.invest.datastack.models') as p:
p.model_id_to_pyname = MOCK_MODEL_ID_TO_PYNAME
# Collect the vector's files into a single archive
datastack.build_datastack_archive(
params, 'vector', archive_path)
# extract the archive
out_directory = os.path.join(dest_dir, 'extracted_archive')
@ -243,8 +263,10 @@ class DatastackArchiveTests(unittest.TestCase):
archive_path = os.path.join(self.workspace, 'archive.invs.tar.gz')
datastack.build_datastack_archive(
params, 'test_datastack_modules.archive_extraction', archive_path)
with patch('natcap.invest.datastack.models') as p:
p.model_id_to_pyname = MOCK_MODEL_ID_TO_PYNAME
datastack.build_datastack_archive(
params, 'archive_extraction', archive_path)
# extract the archive
out_directory = os.path.join(self.workspace, 'extracted_archive')
@ -290,8 +312,10 @@ class DatastackArchiveTests(unittest.TestCase):
# Collect the file into an archive
archive_path = os.path.join(self.workspace, 'archive.invs.tar.gz')
datastack.build_datastack_archive(
params, 'test_datastack_modules.nonspatial_files', archive_path)
with patch('natcap.invest.datastack.models') as p:
p.model_id_to_pyname = MOCK_MODEL_ID_TO_PYNAME
datastack.build_datastack_archive(
params, 'nonspatial_files', archive_path)
# extract the archive
out_directory = os.path.join(self.workspace, 'extracted_archive')
@ -328,8 +352,10 @@ class DatastackArchiveTests(unittest.TestCase):
# Collect the file into an archive
archive_path = os.path.join(self.workspace, 'archive.invs.tar.gz')
datastack.build_datastack_archive(
params, 'test_datastack_modules.duplicate_filepaths', archive_path)
with patch('natcap.invest.datastack.models') as p:
p.model_id_to_pyname = MOCK_MODEL_ID_TO_PYNAME
datastack.build_datastack_archive(
params, 'duplicate_filepaths', archive_path)
# extract the archive
out_directory = os.path.join(self.workspace, 'extracted_archive')
@ -358,6 +384,7 @@ class DatastackArchiveTests(unittest.TestCase):
"""Datastack: test archive extraction."""
from natcap.invest import datastack
from natcap.invest import utils
from natcap.invest import spec
from natcap.invest import validation
params = {
@ -412,8 +439,10 @@ class DatastackArchiveTests(unittest.TestCase):
spatial_csv.write(f'4,{target_csv_vector_path}\n')
archive_path = os.path.join(self.workspace, 'archive.invs.tar.gz')
datastack.build_datastack_archive(
params, 'test_datastack_modules.archive_extraction', archive_path)
with patch('natcap.invest.datastack.models') as p:
p.model_id_to_pyname = MOCK_MODEL_ID_TO_PYNAME
datastack.build_datastack_archive(
params, 'archive_extraction', archive_path)
out_directory = os.path.join(self.workspace, 'extracted_archive')
archive_params = datastack.extract_datastack_archive(
archive_path, out_directory)
@ -435,13 +464,14 @@ class DatastackArchiveTests(unittest.TestCase):
self.assertTrue(
filecmp.cmp(archive_params[key], params[key], shallow=False))
spatial_csv_dict = validation.get_validated_dataframe(
archive_params['spatial_table'],
spatial_csv_dict = spec.CSVInput(
index_col='id',
columns={
'id': {'type': 'integer'},
'path': {'type': 'file'}
}).to_dict(orient='index')
columns=[
spec.IntegerInput(id='id'),
spec.FileInput(id='path')]
).get_validated_dataframe(
archive_params['spatial_table']
).to_dict(orient='index')
spatial_csv_dir = os.path.dirname(archive_params['spatial_table'])
numpy.testing.assert_allclose(
pygeoprocessing.raster_to_numpy_array(
@ -467,9 +497,28 @@ class DatastackArchiveTests(unittest.TestCase):
with self.assertRaises(ValueError):
with patch('natcap.invest.datastack.build_parameter_set',
side_effect=ValueError(error_message)):
datastack.build_datastack_archive(
params, 'test_datastack_modules.simple_parameters',
archive_path)
with patch('natcap.invest.datastack.models') as p:
p.model_id_to_pyname = MOCK_MODEL_ID_TO_PYNAME
datastack.build_datastack_archive(
params, 'simple_parameters', archive_path)
def test_extract_old_style_datastack(self):
"""Datastack: extract old-style datastack that uses pyname"""
from natcap.invest import datastack
json_path = os.path.join(self.workspace, 'old_datastack.json')
with open(json_path, 'w') as file:
json.dump({
"args": {
"factor": "",
"raster_path": "",
"results_suffix": "",
"workspace_dir": ""
},
"invest_version": "3.14.2",
"model_name": "natcap.invest.carbon"
}, file)
datastack_info = datastack.extract_parameter_set(json_path)
self.assertEqual(datastack_info.model_id, "carbon")
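For comparison, a new-style parameter set stores the model ID directly, so no pyname-to-ID translation is needed on extraction. A hypothetical counterpart to the JSON written above (version string illustrative):

    new_style = {
        "args": {
            "factor": "",
            "raster_path": "",
            "results_suffix": "",
            "workspace_dir": ""
        },
        "invest_version": "3.16.0",  # illustrative version
        "model_id": "carbon"         # model ID replaces the pyname
    }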
class ParameterSetTest(unittest.TestCase):
@ -661,7 +710,6 @@ class ParameterSetTest(unittest.TestCase):
07/20/2017 16:37:48 natcap.invest.ui.model INFO post args.
"""))
params = datastack.extract_parameters_from_logfile(logfile_path)
expected_params = datastack.ParameterSet(
@ -700,15 +748,18 @@ class ParameterSetTest(unittest.TestCase):
}
archive_path = os.path.join(self.workspace, 'archive.invs.tar.gz')
datastack.build_datastack_archive(
params, 'test_datastack_modules.simple_parameters', archive_path)
with patch('natcap.invest.datastack.models') as p:
p.model_id_to_pyname = MOCK_MODEL_ID_TO_PYNAME
datastack.build_datastack_archive(
params, 'simple_parameters', archive_path)
stack_type, stack_info = datastack.get_datastack_info(
archive_path, extract_path=os.path.join(self.workspace, 'archive'))
self.assertEqual(stack_type, 'archive')
self.assertEqual(stack_info, datastack.ParameterSet(
params, 'test_datastack_modules.simple_parameters',
params, 'simple_parameters',
natcap.invest.__version__))
def test_get_datastack_info_parameter_set(self):
@ -723,7 +774,7 @@ class ParameterSetTest(unittest.TestCase):
'd': '',
}
test_module_name = 'test_datastack_modules.simple_parameters'
test_module_name = 'simple_parameters'
json_path = os.path.join(self.workspace, 'archive.invs.json')
datastack.build_parameter_set(
params, test_module_name, json_path)
@ -756,33 +807,27 @@ class ParameterSetTest(unittest.TestCase):
self.assertEqual(stack_info, datastack.ParameterSet(
args, 'some_modelname', natcap.invest.__version__))
def test_get_datastack_info_logfile_iui_style(self):
"""Datastack: test get datastack info logfile iui style."""
def test_get_datastack_info_logfile_old_style(self):
"""Datastack: test get datastack info logfile old style."""
import natcap.invest
from natcap.invest import datastack
args = {
'a': 1,
'b': 2.7,
'c': [1, 2, 3.55],
'd': 'hello, world!',
'e': False,
}
logfile_path = os.path.join(self.workspace, 'logfile.txt')
with open(logfile_path, 'w') as logfile:
logfile.write(textwrap.dedent("""
Arguments:
suffix foo
some_int 1
some_float 2.33
workspace_dir some_workspace_dir
some other logging here.
"""))
expected_args = {
'suffix': 'foo',
'some_int': 1,
'some_float': 2.33,
'workspace_dir': 'some_workspace_dir',
}
# Old-style log files include the pyname instead of the model ID
logfile.write(datastack.format_args_dict(args, 'natcap.invest.carbon'))
stack_type, stack_info = datastack.get_datastack_info(logfile_path)
self.assertEqual(stack_type, 'logfile')
self.assertEqual(stack_info, datastack.ParameterSet(
expected_args, datastack.UNKNOWN, datastack.UNKNOWN))
args, 'carbon', natcap.invest.__version__))
@unittest.skipUnless(sys.platform.startswith("win"), "requires Windows")
def test_mixed_path_separators_in_paramset_windows(self):
@ -884,8 +929,7 @@ class UtilitiesTest(unittest.TestCase):
'foo': 'bar',
}
args_string = format_args_dict(args_dict=args_dict,
model_name='test_model')
args_string = format_args_dict(args_dict, 'test_model')
expected_string = str(
'Arguments for InVEST test_model %s:\n'
'foo bar\n'

View File

@ -1,26 +1,30 @@
MODEL_SPEC = {
'model_id': 'archive_extraction_model',
'args': {
'blank': {'type': 'freestyle_string'},
'a': {'type': 'integer'},
'b': {'type': 'freestyle_string'},
'c': {'type': 'freestyle_string'},
'foo': {'type': 'file'},
'bar': {'type': 'file'},
'data_dir': {'type': 'directory', 'contents': {}},
'raster': {'type': 'raster'},
'vector': {'type': 'vector'},
'simple_table': {'type': 'csv'},
'spatial_table': {
'type': 'csv',
'columns': {
'ID': {'type': 'integer'},
'path': {
'type': {'raster', 'vector'},
'geometries': {'POINT', 'POLYGON'},
'bands': {1: {'type': 'number'}}
}
}
}
}
}
from natcap.invest import spec
MODEL_SPEC = spec.ModelSpec(inputs=[
spec.StringInput(id='blank'),
spec.IntegerInput(id='a'),
spec.StringInput(id='b'),
spec.StringInput(id='c'),
spec.FileInput(id='foo'),
spec.FileInput(id='bar'),
spec.DirectoryInput(id='data_dir', contents={}),
spec.SingleBandRasterInput(id='raster'),
spec.VectorInput(id='vector', fields={}, geometry_types={}),
spec.CSVInput(id='simple_table'),
spec.CSVInput(
id='spatial_table',
columns=[
spec.IntegerInput(id='ID'),
spec.RasterOrVectorInput(
id='path',
fields={},
geometry_types={'POINT', 'POLYGON'}
)
]
)],
outputs={},
model_id='archive_extraction_model',
model_title='',
userguide='',
input_field_order=[]
)

View File

@ -1,7 +1,13 @@
MODEL_SPEC = {
'model_id': 'duplicate_filepaths_model',
'args': {
'foo': {'type': 'file'},
'bar': {'type': 'file'},
}
}
from natcap.invest import spec
MODEL_SPEC = spec.ModelSpec(
inputs=[
spec.FileInput(id='foo'),
spec.FileInput(id='bar')
],
outputs={},
model_id='duplicate_filepaths_model',
model_title='',
userguide='',
input_field_order=[]
)

View File

@ -1,7 +1,13 @@
MODEL_SPEC = {
'model_id': 'nonspatial_model',
'args': {
'some_file': {'type': 'file'},
'data_dir': {'type': 'directory', 'contents': {}},
}
}
from natcap.invest import spec
MODEL_SPEC = spec.ModelSpec(inputs=[
spec.FileInput(id='some_file'),
spec.DirectoryInput(
id='data_dir',
contents=[])],
outputs={},
model_id='nonspatial_model',
model_title='',
userguide='',
input_field_order=[]
)

View File

@ -1,6 +1,10 @@
MODEL_SPEC = {
'model_id': 'raster_model',
'args': {
'raster': {'type': 'raster'},
}
}
from natcap.invest import spec
MODEL_SPEC = spec.ModelSpec(inputs=[
spec.SingleBandRasterInput(id='raster')],
outputs={},
model_id='raster_model',
model_title='',
userguide='',
input_field_order=[]
)

View File

@ -1,10 +1,17 @@
MODEL_SPEC = {
'model_id': 'simple_model',
'args': {
'a': {'type': 'integer'},
'b': {'type': 'freestyle_string'},
'c': {'type': 'freestyle_string'},
'd': {'type': 'freestyle_string'},
'workspace_dir': {'type': 'directory', 'contents': {}},
}
}
from natcap.invest import spec
MODEL_SPEC = spec.ModelSpec(inputs=[
spec.IntegerInput(id='a'),
spec.StringInput(id='b'),
spec.StringInput(id='c'),
spec.StringInput(id='d'),
spec.DirectoryInput(
id='workspace_dir',
contents=[]
)],
outputs={},
model_id='simple_model',
model_title='',
userguide='',
input_field_order=[]
)

View File

@ -1,7 +1,11 @@
MODEL_SPEC = {
'model_id': 'ui_parameters_model',
'args': {
'foo': {'type': 'freestyle_string'},
'bar': {'type': 'freestyle_string'},
}
}
from natcap.invest import spec
MODEL_SPEC = spec.ModelSpec(inputs=[
spec.StringInput(id='foo'),
spec.StringInput(id='bar')],
outputs={},
model_id='ui_parameters_model',
model_title='',
userguide='',
input_field_order=[]
)

View File

@ -1,6 +1,11 @@
MODEL_SPEC = {
'model_id': 'vector_model',
'args': {
'vector': {'type': 'vector'},
}
}
from natcap.invest import spec
MODEL_SPEC = spec.ModelSpec(inputs=[
spec.VectorInput(
id='vector', fields={}, geometry_types={})],
outputs={},
model_id='vector_model',
model_title='',
userguide='',
input_field_order=[]
)

View File

@ -1161,8 +1161,8 @@ class HRAModelTests(unittest.TestCase):
def test_model(self):
"""HRA: end-to-end test of the model, including datastack."""
from natcap.invest import datastack
from natcap.invest import hra
from natcap.invest import datastack
args = {
'workspace_dir': os.path.join(self.workspace_dir, 'workspace'),
@ -1263,7 +1263,7 @@ class HRAModelTests(unittest.TestCase):
archive_path = os.path.join(self.workspace_dir, 'datstack.tar.gz')
datastack.build_datastack_archive(
args, 'natcap.invest.hra', archive_path)
args, 'habitat_risk_assessment', archive_path)
unarchived_path = os.path.join(self.workspace_dir, 'unarchived_data')
unarchived_args = datastack.extract_datastack_archive(

View File

@ -1,46 +1,116 @@
import importlib
import re
import subprocess
import unittest
import pytest
import pint
from natcap.invest.model_metadata import MODEL_METADATA
from natcap.invest.models import model_id_to_pyname
from natcap.invest import spec
from osgeo import gdal
PLUGIN_URL = 'git+https://github.com/emlys/demo-invest-plugin.git'
PLUGIN_NAME = 'foo-model'
gdal.UseExceptions()
valid_nested_types = {
valid_nested_input_types = {
None: { # if no parent type (arg is top-level), then all types are valid
'boolean',
'integer',
'csv',
'directory',
'file',
'freestyle_string',
'number',
'option_string',
'percent',
'raster',
'ratio',
'vector',
spec.BooleanInput,
spec.CSVInput,
spec.DirectoryInput,
spec.FileInput,
spec.IntegerInput,
spec.NumberInput,
spec.OptionStringInput,
spec.PercentInput,
spec.RasterOrVectorInput,
spec.RatioInput,
spec.SingleBandRasterInput,
spec.StringInput,
spec.VectorInput
},
'raster': {'integer', 'number', 'ratio', 'percent'},
'vector': {
'integer',
'freestyle_string',
'number',
'option_string',
'percent',
'ratio'},
'csv': {
'boolean',
'integer',
'freestyle_string',
'number',
'option_string',
'percent',
'raster',
'ratio',
'vector'},
'directory': {'csv', 'directory', 'file', 'raster', 'vector'}
spec.SingleBandRasterInput: {
spec.IntegerInput,
spec.NumberInput,
spec.PercentInput,
spec.RatioInput
},
spec.VectorInput: {
spec.IntegerInput,
spec.NumberInput,
spec.OptionStringInput,
spec.PercentInput,
spec.RatioInput,
spec.StringInput
},
spec.CSVInput: {
spec.BooleanInput,
spec.IntegerInput,
spec.NumberInput,
spec.OptionStringInput,
spec.PercentInput,
spec.RasterOrVectorInput,
spec.RatioInput,
spec.SingleBandRasterInput,
spec.StringInput,
spec.VectorInput
},
spec.DirectoryInput: {
spec.CSVInput,
spec.DirectoryInput,
spec.FileInput,
spec.RasterOrVectorInput,
spec.SingleBandRasterInput,
spec.VectorInput
}
}
valid_nested_output_types = {
None: { # if no parent type (arg is top-level), then all types are valid
spec.CSVOutput,
spec.DirectoryOutput,
spec.FileOutput,
spec.IntegerOutput,
spec.NumberOutput,
spec.OptionStringOutput,
spec.PercentOutput,
spec.RatioOutput,
spec.SingleBandRasterOutput,
spec.StringOutput,
spec.VectorOutput
},
spec.SingleBandRasterOutput: {
spec.IntegerOutput,
spec.NumberOutput,
spec.PercentOutput,
spec.RatioOutput
},
spec.VectorOutput: {
spec.IntegerOutput,
spec.NumberOutput,
spec.OptionStringOutput,
spec.PercentOutput,
spec.RatioOutput,
spec.StringOutput
},
spec.CSVOutput: {
spec.IntegerOutput,
spec.NumberOutput,
spec.OptionStringOutput,
spec.PercentOutput,
spec.RatioOutput,
spec.SingleBandRasterOutput,
spec.StringOutput,
spec.VectorOutput
},
spec.DirectoryOutput: {
spec.CSVOutput,
spec.DirectoryOutput,
spec.FileOutput,
spec.SingleBandRasterOutput,
spec.VectorOutput
}
}
@ -50,42 +120,47 @@ class ValidateModelSpecs(unittest.TestCase):
def test_model_specs_are_valid(self):
"""MODEL_SPEC: test each spec meets the expected pattern."""
required_keys = {
'model_id', 'model_name', 'pyname', 'userguide', 'args', 'outputs'}
optional_spatial_key = 'args_with_spatial_overlap'
for model_name, metadata in MODEL_METADATA.items():
# metadata is a collections.namedtuple, fields accessible by name
model = importlib.import_module(metadata.pyname)
required_keys = {'model_id', 'model_title', 'userguide',
'aliases', 'inputs', 'input_field_order', 'outputs'}
for model_id, pyname in model_id_to_pyname.items():
model = importlib.import_module(pyname)
# Validate top-level keys are correct
with self.subTest(metadata.pyname):
with self.subTest(pyname):
self.assertTrue(
required_keys.issubset(model.MODEL_SPEC),
required_keys.issubset(set(dir(model.MODEL_SPEC))),
("Required key(s) missing from MODEL_SPEC: "
f"{set(required_keys).difference(model.MODEL_SPEC)}"))
extra_keys = set(model.MODEL_SPEC).difference(required_keys)
if (extra_keys):
self.assertEqual(extra_keys, set([optional_spatial_key]))
self.assertTrue(
set(model.MODEL_SPEC[optional_spatial_key]).issubset(
{'spatial_keys', 'different_projections_ok'}))
f"{set(required_keys).difference(set(dir(model.MODEL_SPEC)))}"))
self.assertIsInstance(model.MODEL_SPEC.input_field_order, list)
found_keys = set()
for group in model.MODEL_SPEC.input_field_order:
self.assertIsInstance(group, list)
for key in group:
self.assertIsInstance(key, str)
self.assertNotIn(key, found_keys)
found_keys.add(key)
for arg_spec in model.MODEL_SPEC.inputs:
if arg_spec.hidden is True:
found_keys.add(arg_spec.id)
self.assertEqual(found_keys, set([s.id for s in model.MODEL_SPEC.inputs]))
# validate that each arg meets the expected pattern
# save up errors to report at the end
for key, arg in model.MODEL_SPEC['args'].items():
for arg_spec in model.MODEL_SPEC.inputs:
# the top level should have 'name' and 'about' attrs
# but they aren't required at nested levels
self.validate_args(arg, f'{model_name}.args.{key}')
self.validate_args(arg_spec, f'{model_id}.inputs.{arg_spec.id}')
for key, spec in model.MODEL_SPEC['outputs'].items():
self.validate_output(spec, f'{model_name}.outputs.{key}')
for output_spec in model.MODEL_SPEC.outputs:
self.validate_output(output_spec, f'{model_id}.outputs.{output_spec.id}')
def validate_output(self, spec, key, parent_type=None):
def validate_output(self, output_spec, key, parent_type=None):
"""
Recursively validate nested output specs against the output spec standard.
Args:
spec (dict): any nested output spec component of a MODEL_SPEC
output_spec (dict): any nested output spec component of a MODEL_SPEC
key (str): key to identify the spec by in error messages
parent_type (str): the type of this output's parent output (or None if
no parent).
@ -100,135 +175,85 @@ class ValidateModelSpecs(unittest.TestCase):
# if parent_type is None: # all top-level args must have these attrs
# for attr in ['about']:
# self.assertIn(attr, spec)
attrs = set(spec.keys())
attrs = set(dir(output_spec)) - set(dir(object()))
if 'type' in spec:
t = spec['type']
else:
file_extension = key.split('.')[-1]
if file_extension == 'tif':
t = 'raster'
elif file_extension in {'shp', 'gpkg', 'geojson'}:
t = 'vector'
elif file_extension == 'csv':
t = 'csv'
elif file_extension in {'json', 'txt', 'pickle', 'db', 'zip',
'dat', 'idx', 'html'}:
t = 'file'
else:
raise Warning(
f'output {key} has no recognized file extension and '
'no "type" property')
t = type(output_spec)
self.assertIn(t, valid_nested_output_types[parent_type])
self.assertIn(t, valid_nested_types[parent_type])
if t == 'number':
if t is spec.NumberOutput:
# number type should have a units property
self.assertIn('units', spec)
self.assertTrue(hasattr(output_spec, 'units'))
# Undefined units should use the custom u.none unit
self.assertIsInstance(spec['units'], pint.Unit)
attrs.remove('units')
self.assertIsInstance(output_spec.units, pint.Unit)
elif t == 'raster':
# raster type should have a bands property that maps each band
# index to a nested type dictionary describing the band's data
self.assertIn('bands', spec)
self.assertIsInstance(spec['bands'], dict)
for band in spec['bands']:
self.assertIsInstance(band, int)
self.validate_output(
spec['bands'][band],
f'{key}.bands.{band}',
parent_type=t)
attrs.remove('bands')
elif t is spec.SingleBandRasterOutput:
self.assertTrue(hasattr(output_spec, 'data_type'))
self.assertTrue(hasattr(output_spec, 'units'))
elif t == 'vector':
elif t is spec.VectorOutput:
# vector type should have:
# - a fields property that maps each field header to a nested
# type dictionary describing the data in that field
# - a geometries property: the set of valid geometry types
self.assertIn('fields', spec)
self.assertIsInstance(spec['fields'], dict)
for field in spec['fields']:
self.assertIsInstance(field, str)
# - a geometry_types property: the set of valid geometry types
self.assertTrue(hasattr(output_spec, 'fields'))
for field in output_spec.fields:
self.validate_output(
spec['fields'][field],
field,
f'{key}.fields.{field}',
parent_type=t)
self.assertIn('geometries', spec)
self.assertIsInstance(spec['geometries'], set)
self.assertTrue(hasattr(output_spec, 'geometry_types'))
self.assertIsInstance(output_spec.geometry_types, set)
attrs.remove('fields')
attrs.remove('geometries')
elif t == 'csv':
elif t is spec.CSVOutput:
# csv type may have a columns property.
# the columns property maps each expected column header
# name/pattern to a nested type dictionary describing the data
# in that column. may be absent if the table structure
# is too complex to describe this way.
self.assertIn('columns', spec)
self.assertIsInstance(spec['columns'], dict)
for column in spec['columns']:
self.assertIsInstance(column, str)
self.assertTrue(hasattr(output_spec, 'columns'))
for column in output_spec.columns:
self.validate_output(
spec['columns'][column],
column,
f'{key}.columns.{column}',
parent_type=t)
if 'index_col' in spec:
self.assertIn(spec['index_col'], spec['columns'])
if output_spec.index_col:
self.assertIn(output_spec.index_col, [s.id for s in output_spec.columns])
attrs.discard('columns')
attrs.discard('index_col')
elif t == 'directory':
elif t is spec.DirectoryOutput:
# directory type should have a contents property that maps each
# expected path name/pattern within the directory to a nested
# type dictionary describing the data at that filepath
self.assertIn('contents', spec)
self.assertIsInstance(spec['contents'], dict)
for path in spec['contents']:
self.assertIsInstance(path, str)
self.assertTrue(hasattr(output_spec, 'contents'))
for path in output_spec.contents:
self.validate_output(
spec['contents'][path],
path,
f'{key}.contents.{path}',
parent_type=t)
attrs.remove('contents')
elif t == 'option_string':
elif t is spec.OptionStringOutput:
# option_string type should have an options property that
# describes the valid options
self.assertIn('options', spec)
self.assertIsInstance(spec['options'], dict)
for option, description in spec['options'].items():
self.assertTrue(hasattr(output_spec, 'options'))
self.assertIsInstance(output_spec.options, dict)
for option, description in output_spec.options.items():
self.assertTrue(
isinstance(option, str) or
isinstance(option, int))
attrs.remove('options')
elif t == 'file':
elif t is spec.FileOutput:
pass
# iterate over the remaining attributes
# type-specific ones have been removed by this point
if 'about' in attrs:
self.assertIsInstance(spec['about'], str)
attrs.remove('about')
if 'created_if' in attrs:
if output_spec.about:
self.assertIsInstance(output_spec.about, str)
if output_spec.created_if:
# should be an arg key indicating that the output is
# created if that arg is provided or checked
self.assertIsInstance(spec['created_if'], str)
attrs.remove('created_if')
if 'type' in attrs:
self.assertIsInstance(spec['type'], str)
attrs.remove('type')
# args should not have any unexpected properties
# all attrs should have been removed by now
if attrs:
raise AssertionError(f'{key} has key(s) {attrs} that are not '
'expected for its type')
self.assertTrue(
isinstance(output_spec.created_if, str) or
isinstance(output_spec.created_if, bool))
def validate_args(self, arg, name, parent_type=None):
"""
@ -249,222 +274,186 @@ class ValidateModelSpecs(unittest.TestCase):
with self.subTest(nested_arg_name=name):
if parent_type is None: # all top-level args must have these attrs
for attr in ['name', 'about']:
self.assertIn(attr, arg)
self.assertTrue(hasattr(arg, attr))
# arg['type'] can be either a string or a set of strings
types = arg['type'] if isinstance(
arg['type'], set) else [arg['type']]
attrs = set(arg.keys())
attrs = set(dir(arg))
for t in types:
self.assertIn(t, valid_nested_types[parent_type])
t = type(arg)
self.assertIn(t, valid_nested_input_types[parent_type])
if t == 'option_string':
# option_string type should have an options property that
# describes the valid options
self.assertIn('options', arg)
# May be a list or dict because some option sets are self
# explanatory and others need a description
self.assertIsInstance(arg['options'], dict)
for key, val in arg['options'].items():
if t is spec.OptionStringInput:
# option_string type should have an options property that
# describes the valid options
self.assertTrue(hasattr(arg, 'options'))
# May be a list or dict because some option sets are self
# explanatory and others need a description
self.assertIsInstance(arg.options, dict)
for key, val in arg.options.items():
self.assertTrue(
isinstance(key, str) or
isinstance(key, int))
self.assertIsInstance(val, dict)
# top-level option_string args are shown as dropdowns
# so each option needs a display name
# an additional description is optional
if parent_type is None:
self.assertTrue(
isinstance(key, str) or
isinstance(key, int))
self.assertIsInstance(val, dict)
# top-level option_string args are shown as dropdowns
# so each option needs a display name
# an additional description is optional
if parent_type is None:
self.assertTrue(
set(val.keys()) == {'display_name'} or
set(val.keys()) == {
'display_name', 'description'})
# option_strings within a CSV or vector don't get a
# display name. the user has to enter the key.
else:
self.assertEqual(set(val.keys()), {'description'})
set(val.keys()) == {'display_name'} or
set(val.keys()) == {
'display_name', 'description'})
# option_strings within a CSV or vector don't get a
# display name. the user has to enter the key.
else:
self.assertEqual(set(val.keys()), {'description'})
if 'display_name' in val:
self.assertIsInstance(val['display_name'], str)
if 'description' in val:
self.assertIsInstance(val['description'], str)
if 'display_name' in val:
self.assertIsInstance(val['display_name'], str)
if 'description' in val:
self.assertIsInstance(val['description'], str)
attrs.remove('options')
attrs.remove('options')
elif t == 'freestyle_string':
# freestyle_string may optionally have a regexp attribute
# this is a regular expression that the string must match
if 'regexp' in arg:
self.assertIsInstance(arg['regexp'], str)
re.compile(arg['regexp']) # should be regex compilable
attrs.remove('regexp')
elif t is spec.StringInput:
# freestyle_string may optionally have a regexp attribute
# this is a regular expression that the string must match
if arg.regexp:
self.assertIsInstance(arg.regexp, str)
re.compile(arg.regexp) # should be regex compilable
attrs.remove('regexp')
elif t == 'number':
# number type should have a units property
self.assertIn('units', arg)
# Undefined units should use the custom u.none unit
self.assertIsInstance(arg['units'], pint.Unit)
attrs.remove('units')
elif t is spec.NumberInput:
# number type should have a units property
self.assertTrue(hasattr(arg, 'units'))
# Undefined units should use the custom u.none unit
self.assertIsInstance(arg.units, pint.Unit)
# number type may optionally have an 'expression' attribute
# this is a string expression to be evaluated with the
# intent of determining that the value is within a range.
# The expression must contain the string ``value``, which
# will represent the user-provided value (after it has been
# cast to a float). Example: "(value >= 0) & (value <= 1)"
if 'expression' in arg:
self.assertIsInstance(arg['expression'], str)
attrs.remove('expression')
# number type may optionally have an 'expression' attribute
# this is a string expression to be evaluated with the
# intent of determining that the value is within a range.
# The expression must contain the string ``value``, which
# will represent the user-provided value (after it has been
# cast to a float). Example: "(value >= 0) & (value <= 1)"
if arg.expression:
self.assertIsInstance(arg.expression, str)
elif t == 'raster':
# raster type should have a bands property that maps each band
# index to a nested type dictionary describing the band's data
self.assertIn('bands', arg)
self.assertIsInstance(arg['bands'], dict)
for band in arg['bands']:
self.assertIsInstance(band, int)
elif t is spec.SingleBandRasterInput:
self.assertTrue(hasattr(arg, 'data_type'))
self.assertTrue(hasattr(arg, 'units'))
# may optionally have a 'projected' attribute that says
# whether the raster must be linearly projected
if arg.projected is not None:
self.assertIsInstance(arg.projected, bool)
attrs.remove('projected')
# if 'projected' is True, may also have a 'projection_units'
# attribute saying the expected linear projection unit
if arg.projection_units:
# doesn't make sense to have projection units unless
# projected is True
self.assertTrue(arg.projected)
self.assertIsInstance(
arg.projection_units, pint.Unit)
attrs.remove('projection_units')
elif t is spec.VectorInput:
# vector type should have:
# - a fields property that maps each field header to a nested
# type dictionary describing the data in that field
# - a geometry_types property: the set of valid geometry types
self.assertTrue(hasattr(arg, 'fields'))
for field in arg.fields:
self.validate_args(
field,
f'{name}.fields.{field}',
parent_type=t)
self.assertTrue(hasattr(arg, 'geometry_types'))
self.assertIsInstance(arg.geometry_types, set)
attrs.remove('fields')
attrs.remove('geometry_types')
# may optionally have a 'projected' attribute that says
# whether the vector must be linearly projected
if arg.projected is not None:
self.assertIsInstance(arg.projected, bool)
attrs.remove('projected')
# if 'projected' is True, may also have a 'projection_units'
# attribute saying the expected linear projection unit
if arg.projection_units:
# doesn't make sense to have projection units unless
# projected is True
self.assertTrue(arg.projected)
self.assertIsInstance(
arg.projection_units, pint.Unit)
attrs.remove('projection_units')
elif t is spec.CSVInput:
# csv type should have a rows property, columns property, or
# neither. rows or columns properties map each expected header
# name/pattern to a nested type dictionary describing the data
# in that row/column. may have neither if the table structure
# is too complex to describe this way.
has_rows = bool(arg.rows)
has_cols = bool(arg.columns)
# should not have both
self.assertFalse(has_rows and has_cols)
if has_cols or has_rows:
direction = 'rows' if has_rows else 'columns'
headers = arg.rows if has_rows else arg.columns
for header in headers:
self.validate_args(
arg['bands'][band],
f'{name}.bands.{band}',
parent_type=t)
attrs.remove('bands')
# may optionally have a 'projected' attribute that says
# whether the raster must be linearly projected
if 'projected' in arg:
self.assertIsInstance(arg['projected'], bool)
attrs.remove('projected')
# if 'projected' is True, may also have a 'projection_units'
# attribute saying the expected linear projection unit
if 'projection_units' in arg:
# doesn't make sense to have projection units unless
# projected is True
self.assertTrue(arg['projected'])
self.assertIsInstance(
arg['projection_units'], pint.Unit)
attrs.remove('projection_units')
elif t == 'vector':
# vector type should have:
# - a fields property that maps each field header to a nested
# type dictionary describing the data in that field
# - a geometries property: the set of valid geometry types
self.assertIn('fields', arg)
self.assertIsInstance(arg['fields'], dict)
for field in arg['fields']:
self.assertIsInstance(field, str)
self.validate_args(
arg['fields'][field],
f'{name}.fields.{field}',
header,
f'{name}.{direction}.{header}',
parent_type=t)
self.assertIn('geometries', arg)
self.assertIsInstance(arg['geometries'], set)
if arg.index_col:
self.assertIn(arg.index_col, [s.id for s in arg.columns])
attrs.remove('fields')
attrs.remove('geometries')
elif t is spec.DirectoryInput:
# directory type should have a contents property that maps each
# expected path name/pattern within the directory to a nested
# type dictionary describing the data at that filepath
self.assertTrue(hasattr(arg, 'contents'))
for path in arg.contents:
self.validate_args(
path,
f'{name}.contents.{path}',
parent_type=t)
attrs.remove('contents')
# may optionally have a 'projected' attribute that says
# whether the vector must be linearly projected
if 'projected' in arg:
self.assertIsInstance(arg['projected'], bool)
attrs.remove('projected')
# if 'projected' is True, may also have a 'projection_units'
# attribute saying the expected linear projection unit
if 'projection_units' in arg:
# doesn't make sense to have projection units unless
# projected is True
self.assertTrue(arg['projected'])
self.assertIsInstance(
arg['projection_units'], pint.Unit)
attrs.remove('projection_units')
# may optionally have a 'permissions' attribute, which is a
# string of the unix-style directory permissions e.g. 'rwx'
if arg.permissions:
self.validate_permissions_value(arg.permissions)
attrs.remove('permissions')
# may optionally have a 'must_exist' attribute, which says
# whether the directory must already exist
# this defaults to True
if arg.must_exist is not None:
self.assertIsInstance(arg.must_exist, bool)
attrs.remove('must_exist')
elif t == 'csv':
# csv type should have a rows property, columns property, or
# neither. rows or columns properties map each expected header
# name/pattern to a nested type dictionary describing the data
# in that row/column. may have neither if the table structure
# is too complex to describe this way.
has_rows = 'rows' in arg
has_cols = 'columns' in arg
# should not have both
self.assertFalse(has_rows and has_cols)
if has_cols or has_rows:
direction = 'rows' if has_rows else 'columns'
headers = arg[direction]
self.assertIsInstance(headers, dict)
for header in headers:
self.assertIsInstance(header, str)
self.validate_args(
headers[header],
f'{name}.{direction}.{header}',
parent_type=t)
if 'index_col' in arg:
self.assertIn(arg['index_col'], arg['columns'])
attrs.discard('index_col')
attrs.discard('rows')
attrs.discard('columns')
elif t == 'directory':
# directory type should have a contents property that maps each
# expected path name/pattern within the directory to a nested
# type dictionary describing the data at that filepath
self.assertIn('contents', arg)
self.assertIsInstance(arg['contents'], dict)
for path in arg['contents']:
self.assertIsInstance(path, str)
self.validate_args(
arg['contents'][path],
f'{name}.contents.{path}',
parent_type=t)
attrs.remove('contents')
# may optionally have a 'permissions' attribute, which is a
# string of the unix-style directory permissions e.g. 'rwx'
if 'permissions' in arg:
self.validate_permissions_value(arg['permissions'])
attrs.remove('permissions')
# may optionally have an 'must_exist' attribute, which says
# whether the directory must already exist
# this defaults to True
if 'must_exist' in arg:
self.assertIsInstance(arg['must_exist'], bool)
attrs.remove('must_exist')
elif t == 'file':
# file type may optionally have a 'permissions' attribute
# this is a string listing the permissions e.g. 'rwx'
if 'permissions' in arg:
self.validate_permissions_value(arg['permissions'])
elif t is spec.FileInput:
# file type may optionally have a 'permissions' attribute
# this is a string listing the permissions e.g. 'rwx'
if arg.permissions:
self.validate_permissions_value(arg.permissions)
# iterate over the remaining attributes
# type-specific ones have been removed by this point
if 'name' in attrs:
self.assertIsInstance(arg['name'], str)
attrs.remove('name')
if 'about' in attrs:
self.assertIsInstance(arg['about'], str)
attrs.remove('about')
if 'required' in attrs:
# required value may be True, False, or a string that can be
# parsed as a python statement that evaluates to True or False
self.assertTrue(isinstance(arg['required'], bool) or
isinstance(arg['required'], str))
attrs.remove('required')
if 'type' in attrs:
self.assertTrue(isinstance(arg['type'], str) or
isinstance(arg['type'], set))
attrs.remove('type')
# args should not have any unexpected properties
# all attrs should have been removed by now
if attrs:
raise AssertionError(f'{name} has key(s) {attrs} that are not '
'expected for its type')
if arg.name:
self.assertIsInstance(arg.name, str)
if arg.about:
self.assertIsInstance(arg.about, str)
# required value may be True, False, or a string that can be
# parsed as a python statement that evaluates to True or False
self.assertTrue(isinstance(arg.required, bool) or
isinstance(arg.required, str))
if arg.allowed:
self.assertIn(type(arg.allowed), {str, bool})
def validate_permissions_value(self, permissions):
"""
@ -492,17 +481,27 @@ class ValidateModelSpecs(unittest.TestCase):
def test_model_specs_serialize(self):
"""MODEL_SPEC: test each arg spec can serialize to JSON."""
from natcap.invest import spec_utils
from natcap.invest import spec
for model_name, metadata in MODEL_METADATA.items():
model = importlib.import_module(metadata.pyname)
try:
_ = spec_utils.serialize_args_spec(model.MODEL_SPEC)
except TypeError as error:
self.fail(
f'Failed to avoid TypeError when serializing '
f'{metadata.pyname}.MODEL_SPEC: \n'
f'{error}')
for pyname in model_id_to_pyname.values():
model = importlib.import_module(pyname)
model.MODEL_SPEC.to_json()
@pytest.mark.skip(reason="Possible race condition of plugin not being uninstalled before other tests are run.")
class PluginTests(unittest.TestCase):
"""Tests for natcap.invest plugins."""
def tearDown(self):
subprocess.run(['pip', 'uninstall', '--yes', PLUGIN_NAME])
def test_plugin(self):
"""natcap.invest locates plugin as a namespace package."""
from natcap.invest import models
self.assertNotIn('foo', models.model_id_to_spec.keys())
subprocess.run(['pip', 'install', '--no-deps', PLUGIN_URL])
models = importlib.reload(models)
self.assertIn('foo', models.model_id_to_spec.keys())
if __name__ == '__main__':

View File

@ -124,18 +124,12 @@ class NDRTests(unittest.TestCase):
# use predefined directory so test can clean up files during teardown
args = NDRTests.generate_base_args(self.workspace_dir)
new_table_path = os.path.join(self.workspace_dir, 'table_c_len_0.csv')
with open(new_table_path, 'w') as target_file:
with open(args['biophysical_table_path'], 'r') as table_file:
target_file.write(table_file.readline())
while True:
line = table_file.readline()
if not line:
break
line_list = line.split(',')
# replace the crit_len_p with 0 in this column
line = (
','.join(line_list[0:12] + ['0.0'] + line_list[13::]))
target_file.write(line)
bio_df = pandas.read_csv(args['biophysical_table_path'])
# replace the crit_len_p with 0 in this column
bio_df['crit_len_p'] = 0
bio_df.to_csv(new_table_path)
bio_df = None
args['biophysical_table_path'] = new_table_path
ndr.execute(args)
@ -376,6 +370,39 @@ class NDRTests(unittest.TestCase):
if mismatch_list:
raise RuntimeError("results not expected: %s" % mismatch_list)
def test_mask_raster_nodata_overflow(self):
"""NDR test when target nodata value overflows source dtype."""
from natcap.invest.ndr import ndr
source_raster_path = os.path.join(self.workspace_dir, 'source.tif')
target_raster_path = os.path.join(
self.workspace_dir, 'target.tif')
source_dtype = numpy.int8
target_dtype = gdal.GDT_Int32
target_nodata = numpy.iinfo(numpy.int32).min
pygeoprocessing.numpy_array_to_raster(
base_array=numpy.full((4, 4), 1, dtype=source_dtype),
target_nodata=None,
pixel_size=(1, -1),
origin=(0, 0),
projection_wkt=None,
target_path=source_raster_path)
ndr._mask_raster(
source_raster_path=source_raster_path,
mask_raster_path=source_raster_path, # mask=source for convenience
target_masked_raster_path=target_raster_path,
target_nodata=target_nodata,
target_dtype=target_dtype)
# Mostly we're testing that _mask_raster did not raise an OverflowError,
# but we can assert the results anyway.
array = pygeoprocessing.raster_to_numpy_array(target_raster_path)
numpy.testing.assert_array_equal(
array,
numpy.full((4, 4), 1, dtype=numpy.int32)) # matches target_dtype
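For context, here is a minimal numpy sketch of the overflow this test guards against, assuming the failure mode was writing an int32 nodata value into an int8-typed array and that the fix casts to the target dtype first (illustrative only, not the model's _mask_raster code):

import numpy

source = numpy.full((4, 4), 1, dtype=numpy.int8)  # like an LULC block with no nodata defined
target_nodata = numpy.iinfo(numpy.int32).min  # -2147483648, not representable in int8
# Casting to the target dtype before introducing nodata keeps the assignment in range;
# writing target_nodata directly into the int8 array is the kind of operation that overflows.
masked = source.astype(numpy.int32)
masked[masked == -1] = target_nodata  # safe no-op here; int32 can hold the value
print(masked.dtype, masked.min())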
def test_validation(self):
"""NDR test argument validation."""
from natcap.invest import validation
@ -537,3 +564,56 @@ class NDRTests(unittest.TestCase):
expected_rpi = runoff_proxy_array/numpy.mean(runoff_proxy_array)
numpy.testing.assert_allclose(actual_rpi, expected_rpi)
def test_calculate_load_type(self):
"""Test ``_calculate_load`` for both load_types."""
from natcap.invest.ndr import ndr
# make simple lulc raster
lulc_path = os.path.join(self.workspace_dir, "lulc-load-type.tif")
lulc_array = numpy.array(
[[1, 2, 3, 4], [4, 3, 2, 1]], dtype=numpy.int16)
srs = osr.SpatialReference()
srs.ImportFromEPSG(26910)
projection_wkt = srs.ExportToWkt()
origin = (461251, 4923445)
pixel_size = (30, -30)
no_data = -1
pygeoprocessing.numpy_array_to_raster(
lulc_array, no_data, pixel_size, origin, projection_wkt,
lulc_path)
target_load_path = os.path.join(self.workspace_dir, "load_raster.tif")
# Calculate load
lucode_to_params = {
1: {'load_n': 10.0, 'eff_n': 0.5, 'load_type_n': 'measured-runoff'},
2: {'load_n': 20.0, 'eff_n': 0.5, 'load_type_n': 'measured-runoff'},
3: {'load_n': 10.0, 'eff_n': 0.5, 'load_type_n': 'application-rate'},
4: {'load_n': 20.0, 'eff_n': 0.5, 'load_type_n': 'application-rate'}}
ndr._calculate_load(lulc_path, lucode_to_params, 'n', target_load_path)
expected_results = numpy.array(
[[10.0, 20.0, 5.0, 10.0], [10.0, 5.0, 20.0, 10.0]])
actual_results = pygeoprocessing.raster_to_numpy_array(target_load_path)
numpy.testing.assert_allclose(actual_results, expected_results)
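The expected array above suggests that 'measured-runoff' loads pass through unchanged while 'application-rate' loads are reduced by the retention efficiency; a hypothetical helper reproducing those numbers (with eff = 0.5 the factors eff and 1 - eff give the same result, so the exact formula here is an assumption, not the model's code):

def expected_load(load, eff, load_type):
    # Illustrative only; mirrors the expected_results asserted above.
    if load_type == 'measured-runoff':
        return load
    if load_type == 'application-rate':
        return load * (1 - eff)  # 10.0 -> 5.0 and 20.0 -> 10.0 when eff = 0.5
    raise ValueError(f'found value of: "{load_type}"')

assert expected_load(10.0, 0.5, 'measured-runoff') == 10.0
assert expected_load(20.0, 0.5, 'application-rate') == 10.0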
def test_calculate_load_type_raises_error(self):
"""Test ``_calculate_load`` raises ValueError on bad load_type's."""
from natcap.invest.ndr import ndr
lulc_path = os.path.join(self.workspace_dir, "lulc-load-type.tif")
target_load_path = os.path.join(self.workspace_dir, "load_raster.tif")
# Calculate load
lucode_to_params = {
1: {'load_n': 10.0, 'eff_n': 0.5, 'load_type_n': 'measured-runoff'},
2: {'load_n': 20.0, 'eff_n': 0.5, 'load_type_n': 'cheese'},
3: {'load_n': 10.0, 'eff_n': 0.5, 'load_type_n': 'application-rate'},
4: {'load_n': 20.0, 'eff_n': 0.5, 'load_type_n': 'application-rate'}}
with self.assertRaises(ValueError) as cm:
ndr._calculate_load(lulc_path, lucode_to_params, 'n', target_load_path)
actual_message = str(cm.exception)
self.assertTrue('found value of: "cheese"' in actual_message)

View File

@ -156,7 +156,7 @@ class TestBufferedNumpyDiskMap(unittest.TestCase):
def test_buffered_multiprocess_operation(self):
"""Recreation test buffered file manager parallel flushes."""
from natcap.invest.recreation import buffered_numpy_disk_map
array1 = numpy.array([1, 2, 3, 4])
array2 = numpy.array([-4, -1, -2, 4])
arraysize = array1.size * buffered_numpy_disk_map.BufferedNumpyDiskMap._ARRAY_TUPLE_TYPE.itemsize
@ -677,6 +677,40 @@ class TestRecClientServer(unittest.TestCase):
"""Delete workspace"""
shutil.rmtree(self.workspace_dir, ignore_errors=True)
def test_execute_no_regression(self):
"""Recreation test userday metrics exist if not computing regression."""
from natcap.invest.recreation import recmodel_client
args = {
'aoi_path': os.path.join(
SAMPLE_DATA, 'andros_aoi.shp'),
'compute_regression': False,
'start_year': recmodel_client.MIN_YEAR,
'end_year': recmodel_client.MAX_YEAR,
'grid_aoi': False,
'workspace_dir': self.workspace_dir,
'hostname': self.hostname,
'port': self.port,
}
recmodel_client.execute(args)
out_regression_vector_path = os.path.join(
args['workspace_dir'], 'regression_data.gpkg')
# These fields should exist even if `compute_regression` is False
expected_fields = ['pr_TUD', 'pr_PUD', 'avg_pr_UD']
# For convenience, assert the sums of the columns instead of all
# the individual values.
actual_sums = sum_vector_columns(
out_regression_vector_path, expected_fields)
expected_sums = {
'pr_TUD': 1.0,
'pr_PUD': 1.0,
'avg_pr_UD': 1.0
}
for key in expected_sums:
numpy.testing.assert_almost_equal(
actual_sums[key], expected_sums[key], decimal=3)
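The expected sums of 1.0 are consistent with pr_TUD and pr_PUD being per-polygon proportions of the AOI's total user-days, which necessarily sum to 1 across all features; a small numpy illustration of that property (hypothetical values, not model output):

import numpy

user_days = numpy.array([120.0, 30.0, 50.0])  # hypothetical per-polygon user-days
proportions = user_days / user_days.sum()  # pr_*-style proportions
numpy.testing.assert_almost_equal(proportions.sum(), 1.0, decimal=3)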
def test_all_metrics_local_server(self):
"""Recreation test with all but trivial predictor metrics."""
from natcap.invest.recreation import recmodel_client
@ -706,9 +740,9 @@ class TestRecClientServer(unittest.TestCase):
out_regression_vector_path = os.path.join(
args['workspace_dir'], f'regression_data_{suffix}.gpkg')
predictor_df = validation.get_validated_dataframe(
os.path.join(SAMPLE_DATA, 'predictors_all.csv'),
**recmodel_client.MODEL_SPEC['args']['predictor_table_path'])
predictor_df = recmodel_client.MODEL_SPEC.get_input(
'predictor_table_path').get_validated_dataframe(
os.path.join(SAMPLE_DATA, 'predictors_all.csv'))
field_list = list(predictor_df.index) + ['pr_TUD', 'pr_PUD', 'avg_pr_UD']
# For convenience, assert the sums of the columns instead of all
@ -1259,65 +1293,6 @@ class RecreationClientRegressionTests(unittest.TestCase):
# andros_aoi.shp fits 71 hexes at 20000 meters cell size
self.assertEqual(n_features, 71)
def test_existing_regression_coef(self):
"""Recreation test regression coefficients handle existing output."""
from natcap.invest.recreation import recmodel_client
from natcap.invest import validation
# Initialize a TaskGraph
taskgraph_db_dir = os.path.join(
self.workspace_dir, '_taskgraph_working_dir')
n_workers = -1 # single process mode.
task_graph = taskgraph.TaskGraph(taskgraph_db_dir, n_workers)
response_vector_path = os.path.join(
self.workspace_dir, 'no_grid_vector_path.gpkg')
response_polygons_lookup_path = os.path.join(
self.workspace_dir, 'response_polygons_lookup.pickle')
recmodel_client._copy_aoi_no_grid(
os.path.join(SAMPLE_DATA, 'andros_aoi.shp'), response_vector_path)
predictor_table_path = os.path.join(SAMPLE_DATA, 'predictors.csv')
# make outputs to be overwritten
predictor_dict = validation.get_validated_dataframe(
predictor_table_path,
**recmodel_client.MODEL_SPEC['args']['predictor_table_path']
).to_dict(orient='index')
predictor_list = predictor_dict.keys()
tmp_working_dir = tempfile.mkdtemp(dir=self.workspace_dir)
empty_json_list = [
os.path.join(tmp_working_dir, x + '.json') for x in predictor_list]
out_coefficient_vector_path = os.path.join(
self.workspace_dir, 'out_coefficient_vector.shp')
_make_empty_files(
[out_coefficient_vector_path] + empty_json_list)
prepare_response_polygons_task = task_graph.add_task(
func=recmodel_client._prepare_response_polygons_lookup,
args=(response_vector_path,
response_polygons_lookup_path),
target_path_list=[response_polygons_lookup_path],
task_name='prepare response polygons for geoprocessing')
# build again to test against overwriting output
recmodel_client._schedule_predictor_data_processing(
response_vector_path, response_polygons_lookup_path,
prepare_response_polygons_task, predictor_table_path,
out_coefficient_vector_path, tmp_working_dir, task_graph)
# Copied over from a shapefile formerly in our test-data repo:
expected_values = {
'bonefish': 19.96503546104,
'airdist': 40977.89565353348,
'ports': 14.0,
'bathy': 1.17308099107
}
vector = gdal.OpenEx(out_coefficient_vector_path)
layer = vector.GetLayer()
for feature in layer:
for k, v in expected_values.items():
numpy.testing.assert_almost_equal(feature.GetField(k), v)
def test_predictor_table_absolute_paths(self):
"""Recreation test validation from full path."""
from natcap.invest.recreation import recmodel_client

View File

@ -371,7 +371,7 @@ class SDRTests(unittest.TestCase):
with self.assertRaises(ValueError) as context:
sdr.execute(args)
self.assertIn(
'could not be interpreted as ratios', str(context.exception))
'could not be interpreted as RatioInput', str(context.exception))
def test_lucode_not_a_number(self):
"""SDR test expected exception for invalid data in lucode column."""
@ -392,7 +392,7 @@ class SDRTests(unittest.TestCase):
with self.assertRaises(ValueError) as context:
sdr.execute(args)
self.assertIn(
'could not be interpreted as integers', str(context.exception))
'could not be interpreted as IntegerInput', str(context.exception))
def test_missing_lulc_value(self):
"""SDR test for ValueError when LULC value not found in table."""

View File

@ -875,7 +875,7 @@ class SeasonalWaterYieldRegressionTests(unittest.TestCase):
with self.assertRaises(ValueError) as context:
seasonal_water_yield.execute(args)
self.assertIn(
'could not be interpreted as numbers', str(context.exception))
'could not be interpreted as NumberInput', str(context.exception))
def test_monthly_alpha_regression(self):
"""SWY monthly alpha values regression test on sample data.

tests/test_spec.py (new file, 380 lines)
View File

@ -0,0 +1,380 @@
import os
import shutil
import tempfile
import types
import unittest
import geometamaker
from natcap.invest import spec
from natcap.invest.unit_registry import u
from osgeo import gdal
from osgeo import ogr
gdal.UseExceptions()
class SpecUtilsUnitTests(unittest.TestCase):
"""Unit tests for natcap.invest.spec."""
def test_format_unit(self):
"""spec: test converting units to strings with format_unit."""
for unit_name, expected in [
('meter', 'm'),
('meter / second', 'm/s'),
('foot * mm', 'ft · mm'),
('t * hr * ha / ha / MJ / mm', 't · h · ha / (ha · MJ · mm)'),
('mm^3 / year', 'mm³/year')
]:
unit = spec.u.Unit(unit_name)
actual = spec.format_unit(unit)
self.assertEqual(expected, actual)
def test_format_unit_raises_error(self):
"""spec: format_unit raises TypeError if not a pint.Unit."""
with self.assertRaises(TypeError):
spec.format_unit({})
class TestDescribeArgFromSpec(unittest.TestCase):
"""Test building RST for various invest args specifications."""
def test_number_spec(self):
number_spec = spec.NumberInput(
name="Bar",
about="Description",
units=u.meter**3/u.month,
expression="value >= 0"
)
out = spec.describe_arg_from_spec(number_spec.name, number_spec)
expected_rst = ([
'**Bar** (`number <input_types.html#number>`__, '
'units: **m³/month**, *required*): Description'])
self.assertEqual(repr(out), repr(expected_rst))
def test_ratio_spec(self):
ratio_spec = spec.RatioInput(
name="Bar",
about="Description"
)
out = spec.describe_arg_from_spec(ratio_spec.name, ratio_spec)
expected_rst = (['**Bar** (`ratio <input_types.html#ratio>`__, '
'*required*): Description'])
self.assertEqual(repr(out), repr(expected_rst))
def test_percent_spec(self):
percent_spec = spec.PercentInput(
name="Bar",
about="Description",
required=False
)
out = spec.describe_arg_from_spec(percent_spec.name, percent_spec)
expected_rst = (['**Bar** (`percent <input_types.html#percent>`__, '
'*optional*): Description'])
self.assertEqual(repr(out), repr(expected_rst))
def test_integer_spec(self):
integer_spec = spec.IntegerInput(
name="Bar",
about="Description",
required=True
)
out = spec.describe_arg_from_spec(integer_spec.name, integer_spec)
expected_rst = (['**Bar** (`integer <input_types.html#integer>`__, '
'*required*): Description'])
self.assertEqual(repr(out), repr(expected_rst))
def test_boolean_spec(self):
boolean_spec = spec.BooleanInput(
name="Bar",
about="Description"
)
out = spec.describe_arg_from_spec(boolean_spec.name, boolean_spec)
expected_rst = (['**Bar** (`true/false <input_types.html#truefalse>'
'`__): Description'])
self.assertEqual(repr(out), repr(expected_rst))
def test_freestyle_string_spec(self):
string_spec = spec.StringInput(
name="Bar",
about="Description"
)
out = spec.describe_arg_from_spec(string_spec.name, string_spec)
expected_rst = (['**Bar** (`text <input_types.html#text>`__, '
'*required*): Description'])
self.assertEqual(repr(out), repr(expected_rst))
def test_option_string_spec_dictionary(self):
option_spec = spec.OptionStringInput(
name="Bar",
about="Description",
options={
"option_a": {
"display_name": "A"
},
"Option_b": {
"description": "do something"
},
"option_c": {
"display_name": "c",
"description": "do something else"
}
}
)
# expect that option case is ignored
# otherwise, c would sort before A
out = spec.describe_arg_from_spec(option_spec.name, option_spec)
expected_rst = ([
'**Bar** (`option <input_types.html#option>`__, *required*): Description',
'\tOptions:',
'\t- A',
'\t- c: do something else',
'\t- Option_b: do something'
])
self.assertEqual(repr(out), repr(expected_rst))
def test_option_string_spec_list(self):
option_spec = spec.OptionStringInput(
name="Bar",
about="Description",
options=["option_a", "Option_b"]
)
out = spec.describe_arg_from_spec(option_spec.name, option_spec)
expected_rst = ([
'**Bar** (`option <input_types.html#option>`__, *required*): Description',
'\tOptions: option_a, Option_b'
])
self.assertEqual(repr(out), repr(expected_rst))
def test_raster_spec(self):
raster_spec = spec.SingleBandRasterInput(
data_type=int,
about="Description",
name="Bar"
)
out = spec.describe_arg_from_spec(raster_spec.name, raster_spec)
expected_rst = ([
'**Bar** (`raster <input_types.html#raster>`__, *required*): Description'
])
self.assertEqual(repr(out), repr(expected_rst))
raster_spec = spec.SingleBandRasterInput(
data_type=float,
units=u.millimeter/u.year,
about="Description",
name="Bar"
)
out = spec.describe_arg_from_spec(raster_spec.name, raster_spec)
expected_rst = ([
'**Bar** (`raster <input_types.html#raster>`__, units: **mm/year**, *required*): Description'
])
self.assertEqual(repr(out), repr(expected_rst))
def test_vector_spec(self):
vector_spec = spec.VectorInput(
fields={},
geometry_types={"LINESTRING"},
about="Description",
name="Bar"
)
out = spec.describe_arg_from_spec(vector_spec.name, vector_spec)
expected_rst = ([
'**Bar** (`vector <input_types.html#vector>`__, linestring, *required*): Description'
])
self.assertEqual(repr(out), repr(expected_rst))
vector_spec = spec.VectorInput(
fields=[
spec.IntegerInput(
id="id",
about="Unique identifier for each feature"
),
spec.NumberInput(
id="precipitation",
units=u.millimeter/u.year,
about="Average annual precipitation over the area"
)
],
geometry_types={"POLYGON", "MULTIPOLYGON"},
about="Description",
name="Bar"
)
out = spec.describe_arg_from_spec(vector_spec.name, vector_spec)
expected_rst = ([
'**Bar** (`vector <input_types.html#vector>`__, polygon/multipolygon, *required*): Description',
])
self.assertEqual(repr(out), repr(expected_rst))
def test_csv_spec(self):
csv_spec = spec.CSVInput(
about="Description.",
name="Bar"
)
out = spec.describe_arg_from_spec(csv_spec.name, csv_spec)
expected_rst = ([
'**Bar** (`CSV <input_types.html#csv>`__, *required*): Description. '
'Please see the sample data table for details on the format.'
])
self.assertEqual(repr(out), repr(expected_rst))
# Test every type that can be nested in a CSV column:
# number, ratio, percent, code,
csv_spec = spec.CSVInput(
about="Description",
name="Bar",
columns=[
spec.RatioInput(
id="b",
about="description"
)
]
)
out = spec.describe_arg_from_spec(csv_spec.name, csv_spec)
expected_rst = ([
'**Bar** (`CSV <input_types.html#csv>`__, *required*): Description'
])
self.assertEqual(repr(out), repr(expected_rst))
def test_directory_spec(self):
self.maxDiff = None
dir_spec = spec.DirectoryInput(
about="Description",
name="Bar",
contents={}
)
out = spec.describe_arg_from_spec(dir_spec.name, dir_spec)
expected_rst = ([
'**Bar** (`directory <input_types.html#directory>`__, *required*): Description'
])
self.assertEqual(repr(out), repr(expected_rst))
def test_multi_type_spec(self):
multi_spec = spec.RasterOrVectorInput(
about="Description",
name="Bar",
data_type=int,
geometry_types={"POLYGON"},
fields={}
)
out = spec.describe_arg_from_spec(multi_spec.name, multi_spec)
expected_rst = ([
'**Bar** (`raster <input_types.html#raster>`__ or `vector <input_types.html#vector>`__, *required*): Description'
])
self.assertEqual(repr(out), repr(expected_rst))
def test_real_model_spec(self):
from natcap.invest import carbon
out = spec.describe_arg_from_name(
'natcap.invest.carbon', 'carbon_pools_path', 'columns', 'lucode')
expected_rst = (
'.. _carbon-pools-path-columns-lucode:\n\n' +
'**lucode** (`integer <input_types.html#integer>`__, *required*): ' +
carbon.MODEL_SPEC.get_input('carbon_pools_path').columns.get('lucode').about
)
self.assertEqual(repr(out), repr(expected_rst))
def _generate_files_from_spec(output_spec, workspace):
"""A utility function to support the metadata test."""
for spec_data in output_spec:
if spec_data.__class__ is spec.DirectoryOutput:
os.mkdir(os.path.join(workspace, spec_data.id))
_generate_files_from_spec(
spec_data.contents, os.path.join(workspace, spec_data.id))
else:
filepath = os.path.join(workspace, spec_data.id)
if isinstance(spec_data, spec.SingleBandRasterOutput):
driver = gdal.GetDriverByName('GTIFF')
raster = driver.Create(filepath, 2, 2, 1, gdal.GDT_Byte)
band = raster.GetRasterBand(1)
band.SetNoDataValue(2)
elif isinstance(spec_data, spec.VectorOutput):
driver = gdal.GetDriverByName('GPKG')
target_vector = driver.CreateDataSource(filepath)
layer_name = os.path.basename(os.path.splitext(filepath)[0])
target_layer = target_vector.CreateLayer(
layer_name, geom_type=ogr.wkbPolygon)
for field_spec in spec_data.fields:
target_layer.CreateField(ogr.FieldDefn(field_spec.id, ogr.OFTInteger))
else:
# Such as taskgraph.db, just create the file.
with open(filepath, 'w') as file:
pass
class TestMetadataFromSpec(unittest.TestCase):
"""Tests for metadata-generation functions."""
def setUp(self):
"""Override setUp function to create temp workspace directory."""
self.workspace_dir = tempfile.mkdtemp()
def tearDown(self):
"""Override tearDown function to remove temporary directory."""
shutil.rmtree(self.workspace_dir)
def test_write_metadata_for_outputs(self):
"""Test writing metadata for an invest output workspace."""
# An example invest output spec
output_spec = [
spec.DirectoryOutput(
id='output',
contents=[
spec.SingleBandRasterOutput(
id="urban_nature_supply_percapita.tif",
about="The calculated supply per capita of urban nature.",
data_type=float,
units=u.m**2
),
spec.VectorOutput(
id="admin_boundaries.gpkg",
about=("A copy of the user's administrative boundaries "
"vector with a single layer."),
geometry_types=spec.POLYGONS,
fields=[
spec.NumberInput(
id="SUP_DEMadm_cap",
units=u.m**2/u.person,
about="The average urban nature supply/demand"
)
]
)
]
),
spec.DirectoryOutput(
id='intermediate',
contents=[
spec.build_output_spec('taskgraph_cache', spec.TASKGRAPH_DIR)
]
)
]
# Generate an output workspace with real files, without
# running an invest model.
_generate_files_from_spec(output_spec, self.workspace_dir)
model_module = types.SimpleNamespace(
__name__='urban_nature_access',
execute=lambda: None,
MODEL_SPEC=spec.ModelSpec(
model_id='urban_nature_access',
model_title='Urban Nature Access',
userguide='',
aliases=[],
input_field_order=[],
inputs={},
outputs=output_spec
)
)
args_dict = {'workspace_dir': self.workspace_dir}
spec.generate_metadata_for_outputs(model_module, args_dict)
files, messages = geometamaker.validate_dir(
self.workspace_dir, recursive=True)
self.assertEqual(len(files), 2)
self.assertFalse(any(messages))
resource = geometamaker.describe(
os.path.join(args_dict['workspace_dir'], 'output',
'urban_nature_supply_percapita.tif'))
self.assertCountEqual(resource.get_keywords(),
[model_module.MODEL_SPEC.model_id, 'InVEST'])

View File

@ -1,402 +0,0 @@
import os
import shutil
import tempfile
import types
import unittest
import geometamaker
from natcap.invest import spec_utils
from natcap.invest.unit_registry import u
from osgeo import gdal
from osgeo import ogr
gdal.UseExceptions()
class SpecUtilsUnitTests(unittest.TestCase):
"""Unit tests for natcap.invest.spec_utils."""
def test_format_unit(self):
"""spec_utils: test converting units to strings with format_unit."""
from natcap.invest import spec_utils
for unit_name, expected in [
('meter', 'm'),
('meter / second', 'm/s'),
('foot * mm', 'ft · mm'),
('t * hr * ha / ha / MJ / mm', 't · h · ha / (ha · MJ · mm)'),
('mm^3 / year', 'mm³/year')
]:
unit = spec_utils.u.Unit(unit_name)
actual = spec_utils.format_unit(unit)
self.assertEqual(expected, actual)
def test_format_unit_raises_error(self):
"""spec_utils: format_unit raises TypeError if not a pint.Unit."""
from natcap.invest import spec_utils
with self.assertRaises(TypeError):
spec_utils.format_unit({})
class TestDescribeArgFromSpec(unittest.TestCase):
"""Test building RST for various invest args specifications."""
def test_number_spec(self):
spec = {
"name": "Bar",
"about": "Description",
"type": "number",
"units": u.meter**3/u.month,
"expression": "value >= 0"
}
out = spec_utils.describe_arg_from_spec(spec['name'], spec)
expected_rst = ([
'**Bar** (`number <input_types.html#number>`__, '
'units: **m³/month**, *required*): Description'])
self.assertEqual(repr(out), repr(expected_rst))
def test_ratio_spec(self):
spec = {
"name": "Bar",
"about": "Description",
"type": "ratio"
}
out = spec_utils.describe_arg_from_spec(spec['name'], spec)
expected_rst = (['**Bar** (`ratio <input_types.html#ratio>`__, '
'*required*): Description'])
self.assertEqual(repr(out), repr(expected_rst))
def test_percent_spec(self):
spec = {
"name": "Bar",
"about": "Description",
"type": "percent",
"required": False
}
out = spec_utils.describe_arg_from_spec(spec['name'], spec)
expected_rst = (['**Bar** (`percent <input_types.html#percent>`__, '
'*optional*): Description'])
self.assertEqual(repr(out), repr(expected_rst))
def test_code_spec(self):
spec = {
"name": "Bar",
"about": "Description",
"type": "integer",
"required": True
}
out = spec_utils.describe_arg_from_spec(spec['name'], spec)
expected_rst = (['**Bar** (`integer <input_types.html#integer>`__, '
'*required*): Description'])
self.assertEqual(repr(out), repr(expected_rst))
def test_boolean_spec(self):
spec = {
"name": "Bar",
"about": "Description",
"type": "boolean"
}
out = spec_utils.describe_arg_from_spec(spec['name'], spec)
expected_rst = (['**Bar** (`true/false <input_types.html#truefalse>'
'`__): Description'])
self.assertEqual(repr(out), repr(expected_rst))
def test_freestyle_string_spec(self):
spec = {
"name": "Bar",
"about": "Description",
"type": "freestyle_string"
}
out = spec_utils.describe_arg_from_spec(spec['name'], spec)
expected_rst = (['**Bar** (`text <input_types.html#text>`__, '
'*required*): Description'])
self.assertEqual(repr(out), repr(expected_rst))
def test_option_string_spec_dictionary(self):
spec = {
"name": "Bar",
"about": "Description",
"type": "option_string",
"options": {
"option_a": {
"display_name": "A"
},
"Option_b": {
"description": "do something"
},
"option_c": {
"display_name": "c",
"description": "do something else"
}
}
}
# expect that option case is ignored
# otherwise, c would sort before A
out = spec_utils.describe_arg_from_spec(spec['name'], spec)
expected_rst = ([
'**Bar** (`option <input_types.html#option>`__, *required*): Description',
'\tOptions:',
'\t- A',
'\t- c: do something else',
'\t- Option_b: do something'
])
self.assertEqual(repr(out), repr(expected_rst))
def test_option_string_spec_list(self):
spec = {
"name": "Bar",
"about": "Description",
"type": "option_string",
"options": ["option_a", "Option_b"]
}
out = spec_utils.describe_arg_from_spec(spec['name'], spec)
expected_rst = ([
'**Bar** (`option <input_types.html#option>`__, *required*): Description',
'\tOptions: option_a, Option_b'
])
self.assertEqual(repr(out), repr(expected_rst))
def test_raster_spec(self):
spec = {
"type": "raster",
"bands": {1: {"type": "integer"}},
"about": "Description",
"name": "Bar"
}
out = spec_utils.describe_arg_from_spec(spec['name'], spec)
expected_rst = ([
'**Bar** (`raster <input_types.html#raster>`__, *required*): Description'
])
self.assertEqual(repr(out), repr(expected_rst))
spec = {
"type": "raster",
"bands": {1: {
"type": "number",
"units": u.millimeter/u.year
}},
"about": "Description",
"name": "Bar"
}
out = spec_utils.describe_arg_from_spec(spec['name'], spec)
expected_rst = ([
'**Bar** (`raster <input_types.html#raster>`__, units: **mm/year**, *required*): Description'
])
self.assertEqual(repr(out), repr(expected_rst))
def test_vector_spec(self):
spec = {
"type": "vector",
"fields": {},
"geometries": {"LINESTRING"},
"about": "Description",
"name": "Bar"
}
out = spec_utils.describe_arg_from_spec(spec['name'], spec)
expected_rst = ([
'**Bar** (`vector <input_types.html#vector>`__, linestring, *required*): Description'
])
self.assertEqual(repr(out), repr(expected_rst))
spec = {
"type": "vector",
"fields": {
"id": {
"type": "integer",
"about": "Unique identifier for each feature"
},
"precipitation": {
"type": "number",
"units": u.millimeter/u.year,
"about": "Average annual precipitation over the area"
}
},
"geometries": {"POLYGON", "MULTIPOLYGON"},
"about": "Description",
"name": "Bar"
}
out = spec_utils.describe_arg_from_spec(spec['name'], spec)
expected_rst = ([
'**Bar** (`vector <input_types.html#vector>`__, polygon/multipolygon, *required*): Description',
])
self.assertEqual(repr(out), repr(expected_rst))
def test_csv_spec(self):
spec = {
"type": "csv",
"about": "Description.",
"name": "Bar"
}
out = spec_utils.describe_arg_from_spec(spec['name'], spec)
expected_rst = ([
'**Bar** (`CSV <input_types.html#csv>`__, *required*): Description. '
'Please see the sample data table for details on the format.'
])
self.assertEqual(repr(out), repr(expected_rst))
# Test every type that can be nested in a CSV column:
# number, ratio, percent, code,
spec = {
"type": "csv",
"about": "Description",
"name": "Bar",
"columns": {
"b": {"type": "ratio", "about": "description"}
}
}
out = spec_utils.describe_arg_from_spec(spec['name'], spec)
expected_rst = ([
'**Bar** (`CSV <input_types.html#csv>`__, *required*): Description'
])
self.assertEqual(repr(out), repr(expected_rst))
def test_directory_spec(self):
self.maxDiff = None
spec = {
"type": "directory",
"about": "Description",
"name": "Bar",
"contents": {}
}
out = spec_utils.describe_arg_from_spec(spec['name'], spec)
expected_rst = ([
'**Bar** (`directory <input_types.html#directory>`__, *required*): Description'
])
self.assertEqual(repr(out), repr(expected_rst))
def test_multi_type_spec(self):
spec = {
"type": {"raster", "vector"},
"about": "Description",
"name": "Bar",
"bands": {1: {"type": "integer"}},
"geometries": {"POLYGON"},
"fields": {}
}
out = spec_utils.describe_arg_from_spec(spec['name'], spec)
expected_rst = ([
'**Bar** (`raster <input_types.html#raster>`__ or `vector <input_types.html#vector>`__, *required*): Description'
])
self.assertEqual(repr(out), repr(expected_rst))
def test_real_model_spec(self):
from natcap.invest import carbon
out = spec_utils.describe_arg_from_name(
'natcap.invest.carbon', 'carbon_pools_path', 'columns', 'lucode')
expected_rst = (
'.. _carbon-pools-path-columns-lucode:\n\n' +
'**lucode** (`integer <input_types.html#integer>`__, *required*): ' +
carbon.MODEL_SPEC['args']['carbon_pools_path']['columns']['lucode']['about']
)
self.assertEqual(repr(out), repr(expected_rst))
def _generate_files_from_spec(output_spec, workspace):
"""A utility function to support the metadata test."""
for filename, spec_data in output_spec.items():
if 'type' in spec_data and spec_data['type'] == 'directory':
os.mkdir(os.path.join(workspace, filename))
_generate_files_from_spec(
spec_data['contents'], os.path.join(workspace, filename))
else:
filepath = os.path.join(workspace, filename)
if 'bands' in spec_data:
driver = gdal.GetDriverByName('GTIFF')
n_bands = len(spec_data['bands'])
raster = driver.Create(
filepath, 2, 2, n_bands, gdal.GDT_Byte)
for i in range(n_bands):
band = raster.GetRasterBand(i + 1)
band.SetNoDataValue(2)
elif 'fields' in spec_data:
if 'geometries' in spec_data:
driver = gdal.GetDriverByName('GPKG')
target_vector = driver.CreateDataSource(filepath)
layer_name = os.path.basename(os.path.splitext(filepath)[0])
target_layer = target_vector.CreateLayer(
layer_name, geom_type=ogr.wkbPolygon)
for field_name, field_data in spec_data['fields'].items():
target_layer.CreateField(ogr.FieldDefn(field_name, ogr.OFTInteger))
else:
# Write a CSV if it has fields but no geometry
with open(filepath, 'w') as file:
file.write(
f"{','.join([field for field in spec_data['fields']])}")
else:
# Such as taskgraph.db, just create the file.
with open(filepath, 'w') as file:
pass
class TestMetadataFromSpec(unittest.TestCase):
"""Tests for metadata-generation functions."""
def setUp(self):
"""Override setUp function to create temp workspace directory."""
self.workspace_dir = tempfile.mkdtemp()
def tearDown(self):
"""Override tearDown function to remove temporary directory."""
shutil.rmtree(self.workspace_dir)
def test_write_metadata_for_outputs(self):
"""Test writing metadata for an invest output workspace."""
# An example invest output spec
output_spec = {
'output': {
"type": "directory",
"contents": {
"urban_nature_supply_percapita.tif": {
"about": (
"The calculated supply per capita of urban nature."),
"bands": {1: {
"type": "number",
"units": u.m**2,
}}},
"admin_boundaries.gpkg": {
"about": (
"A copy of the user's administrative boundaries "
"vector with a single layer."),
"geometries": spec_utils.POLYGONS,
"fields": {
"SUP_DEMadm_cap": {
"type": "number",
"units": u.m**2/u.person,
"about": (
"The average urban nature supply/demand ")
}
}
}
},
},
'intermediate': {
'type': 'directory',
'contents': {
'taskgraph_cache': spec_utils.TASKGRAPH_DIR,
}
}
}
# Generate an output workspace with real files, without
# running an invest model.
_generate_files_from_spec(output_spec, self.workspace_dir)
model_module = types.SimpleNamespace(
__name__='urban_nature_access',
execute=lambda: None,
MODEL_SPEC={
'model_id': 'urban_nature_access',
'outputs': output_spec})
args_dict = {'workspace_dir': self.workspace_dir}
spec_utils.generate_metadata_for_outputs(model_module, args_dict)
files, messages = geometamaker.validate_dir(
self.workspace_dir, recursive=True)
self.assertEqual(len(files), 2)
self.assertFalse(any(messages))
resource = geometamaker.describe(
os.path.join(args_dict['workspace_dir'], 'output',
'urban_nature_supply_percapita.tif'))
self.assertCountEqual(resource.get_keywords(),
[model_module.MODEL_SPEC['model_id'], 'InVEST'])

View File

@ -99,7 +99,7 @@ class TranslationTests(unittest.TestCase):
def test_invest_validate(self):
"""Translation: test that CLI validate output is translated."""
datastack = { # write datastack to a JSON file
'model_name': 'natcap.invest.carbon',
'model_id': 'carbon',
'invest_version': '0.0',
'args': {}
}
@ -124,7 +124,8 @@ class TranslationTests(unittest.TestCase):
'api/models', query_string={'language': TEST_LANG})
result = json.loads(response.get_data(as_text=True))
self.assertIn(
TEST_MESSAGES['Carbon Storage and Sequestration'], result)
TEST_MESSAGES['Carbon Storage and Sequestration'],
[val['model_title'] for val in result.values()])
def test_server_get_invest_getspec(self):
"""Translation: test that /getspec endpoint is translated."""
@ -143,7 +144,7 @@ class TranslationTests(unittest.TestCase):
from natcap.invest import carbon
test_client = ui_server.app.test_client()
payload = {
'model_module': carbon.MODEL_SPEC['pyname'],
'model_id': carbon.MODEL_SPEC.model_id,
'args': json.dumps({})
}
response = test_client.post(

View File

@ -28,32 +28,6 @@ class EndpointFunctionTests(unittest.TestCase):
"""Override tearDown function to remove temporary directory."""
shutil.rmtree(self.workspace_dir)
def test_get_vector_colnames(self):
"""UI server: get_vector_colnames endpoint."""
test_client = ui_server.app.test_client()
# an empty path
response = test_client.post(
f'{ROUTE_PREFIX}/colnames', json={'vector_path': ''})
self.assertEqual(response.status_code, 422)
colnames = json.loads(response.get_data(as_text=True))
self.assertEqual(colnames, [])
# a vector with one column
path = os.path.join(
TEST_DATA_PATH, 'annual_water_yield', 'input',
'watersheds.shp')
response = test_client.post(
f'{ROUTE_PREFIX}/colnames', json={'vector_path': path})
self.assertEqual(response.status_code, 200)
colnames = json.loads(response.get_data(as_text=True))
self.assertEqual(colnames, ['ws_id'])
# a non-vector file
path = os.path.join(TEST_DATA_PATH, 'ndr', 'input', 'dem.tif')
response = test_client.post(
f'{ROUTE_PREFIX}/colnames', json={'vector_path': path})
self.assertEqual(response.status_code, 422)
colnames = json.loads(response.get_data(as_text=True))
self.assertEqual(colnames, [])
def test_get_invest_models(self):
"""UI server: get_invest_models endpoint."""
test_client = ui_server.app.test_client()
@ -61,18 +35,18 @@ class EndpointFunctionTests(unittest.TestCase):
self.assertEqual(response.status_code, 200)
models_dict = json.loads(response.get_data(as_text=True))
for model in models_dict.values():
self.assertEqual(set(model), {'model_name', 'aliases'})
self.assertEqual(set(model), {'model_title', 'aliases'})
def test_get_invest_spec(self):
"""UI server: get_invest_spec endpoint."""
test_client = ui_server.app.test_client()
response = test_client.post(f'{ROUTE_PREFIX}/getspec', json='sdr')
self.assertEqual(response.status_code, 200)
response = test_client.post(f'{ROUTE_PREFIX}/getspec', json='carbon')
spec = json.loads(response.get_data(as_text=True))
self.assertEqual(
set(spec),
{'model_id', 'model_name', 'pyname', 'userguide',
'args_with_spatial_overlap', 'args', 'outputs'})
{'model_id', 'model_title', 'userguide', 'aliases',
'input_field_order', 'different_projections_ok',
'validate_spatial_overlap', 'args', 'outputs'})
def test_get_invest_validate(self):
"""UI server: get_invest_validate endpoint."""
@ -82,7 +56,7 @@ class EndpointFunctionTests(unittest.TestCase):
'workspace_dir': 'foo'
}
payload = {
'model_module': carbon.MODEL_SPEC['pyname'],
'model_id': carbon.MODEL_SPEC.model_id,
'args': json.dumps(args)
}
response = test_client.post(f'{ROUTE_PREFIX}/validate', json=payload)
@ -112,8 +86,7 @@ class EndpointFunctionTests(unittest.TestCase):
response_data = json.loads(response.get_data(as_text=True))
self.assertEqual(
set(response_data),
{'type', 'args', 'module_name', 'model_run_name',
'model_human_name', 'invest_version'})
{'type', 'args', 'model_id', 'invest_version'})
def test_write_parameter_set_file(self):
"""UI server: write_parameter_set_file endpoint."""
@ -121,7 +94,7 @@ class EndpointFunctionTests(unittest.TestCase):
filepath = os.path.join(self.workspace_dir, 'datastack.json')
payload = {
'filepath': filepath,
'moduleName': 'natcap.invest.carbon',
'model_id': 'carbon',
'args': json.dumps({
'workspace_dir': 'foo'
}),
@ -137,7 +110,7 @@ class EndpointFunctionTests(unittest.TestCase):
actual_data = json.loads(file.read())
self.assertEqual(
set(actual_data),
{'args', 'invest_version', 'model_name'})
{'args', 'invest_version', 'model_id'})
def test_write_parameter_set_file_error_handling(self):
"""UI server: write_parameter_set_file endpoint
@ -147,7 +120,7 @@ class EndpointFunctionTests(unittest.TestCase):
filepath = os.path.join(self.workspace_dir, 'datastack.json')
payload = {
'filepath': filepath,
'moduleName': 'natcap.invest.carbon',
'model_id': 'carbon',
'args': json.dumps({
'workspace_dir': 'foo'
}),
@ -169,7 +142,7 @@ class EndpointFunctionTests(unittest.TestCase):
filepath = os.path.join(self.workspace_dir, 'script.py')
payload = {
'filepath': filepath,
'modelname': 'carbon',
'model_id': 'carbon',
'args': json.dumps({
'workspace_dir': 'foo'
}),
@ -189,7 +162,7 @@ class EndpointFunctionTests(unittest.TestCase):
payload = {
'filepath': target_filepath,
'moduleName': 'natcap.invest.carbon',
'model_id': 'carbon',
'args': json.dumps({
'workspace_dir': 'foo',
'carbon_pools_path': data_path
@ -216,7 +189,7 @@ class EndpointFunctionTests(unittest.TestCase):
payload = {
'filepath': target_filepath,
'moduleName': 'natcap.invest.carbon',
'model_id': 'carbon',
'args': json.dumps({
'workspace_dir': 'foo',
'carbon_pools_path': data_path
@ -231,7 +204,7 @@ class EndpointFunctionTests(unittest.TestCase):
self.assertEqual(
response.json,
{'message': error_message, 'error': True})
@patch('natcap.invest.ui_server.usage.requests.post')
@patch('natcap.invest.ui_server.usage.requests.get')
def test_log_model_start(self, mock_get, mock_post):
@ -242,12 +215,13 @@ class EndpointFunctionTests(unittest.TestCase):
mock_get.return_value = mock_response
test_client = ui_server.app.test_client()
payload = {
'model_pyname': 'natcap.invest.carbon',
'model_id': 'carbon',
'model_args': json.dumps({
'workspace_dir': 'foo'
}),
'invest_interface': 'Workbench',
'session_id': '12345'
'session_id': '12345',
'type': 'core'
}
response = test_client.post(
f'{ROUTE_PREFIX}/log_model_start', json=payload)
@ -258,7 +232,7 @@ class EndpointFunctionTests(unittest.TestCase):
self.assertEqual(mock_post.call_args.args[0], mock_url)
self.assertEqual(
mock_post.call_args.kwargs['data']['model_name'],
payload['model_pyname'])
'natcap.invest.carbon')
self.assertEqual(
mock_post.call_args.kwargs['data']['invest_interface'],
payload['invest_interface'])

View File

@ -30,6 +30,7 @@ class UsageLoggingTests(unittest.TestCase):
"""Usage logger test that we can extract bounding boxes."""
from natcap.invest import utils
from natcap.invest import usage
from natcap.invest import spec
srs = osr.SpatialReference()
srs.ImportFromEPSG(32731) # WGS84 / UTM zone 31s
@ -64,20 +65,22 @@ class UsageLoggingTests(unittest.TestCase):
'blank_vector_path': '',
}
args_spec = {
'args': {
'raster': {'type': 'raster'},
'vector': {'type': 'vector'},
'not_a_gis_input': {'type': 'freestyle_string'},
'blank_raster_path': {'type': 'raster'},
'blank_vector_path': {'type': 'vector'},
}
}
model_spec = spec.ModelSpec(
model_id='', model_title='', userguide=None, aliases=None,
inputs=[
spec.SingleBandRasterInput(id='raster'),
spec.VectorInput(id='vector', geometry_types={}, fields={}),
spec.StringInput(id='not_a_gis_input'),
spec.SingleBandRasterInput(id='blank_raster_path'),
spec.VectorInput(id='blank_vector_path', geometry_types={}, fields={})
],
outputs={},
input_field_order=[])
output_logfile = os.path.join(self.workspace_dir, 'logfile.txt')
with utils.log_to_file(output_logfile):
bb_inter, bb_union = usage._calculate_args_bounding_box(
model_args, args_spec)
model_args, model_spec)
numpy.testing.assert_allclose(
bb_inter, [-87.234108, -85.526151, -87.233424, -85.526205])

File diff suppressed because it is too large

View File

@ -400,6 +400,76 @@ class WindEnergyUnitTests(unittest.TestCase):
numpy.testing.assert_allclose(
actual_levelized_array, desired_levelized_array)
def test_raster_values_to_point_vector(self):
"""WindEnergy: testing 'index_raster_values_to_point_vector' function."""
from natcap.invest import wind_energy
srs = osr.SpatialReference()
srs.ImportFromEPSG(3157)
projection_wkt = srs.ExportToWkt()
origin = (443723.127327877911739, 4956546.905980412848294)
pos_x = origin[0]
pos_y = origin[1]
# Setup parameters for creating point shapefile
fields = {'id': ogr.OFTReal}
attrs = [
{'id': 0}, {'id': 1}, {'id': 2},
{'id': 3}, {'id': 4}, {'id': 5}]
valid_ids = [1, 5]
geometries = [
Point(pos_x + 250, pos_y - 150), # out of extent
Point(pos_x + 50, pos_y), # valid
Point(pos_x + 150, pos_y), # invalid: nodata
Point(pos_x + 250, pos_y), # out of extent
Point(pos_x + 50, pos_y - 150), # invalid: nodata
Point(pos_x + 150, pos_y - 150)] # valid
base_vector_path = os.path.join(self.workspace_dir, 'base_vector.shp')
pygeoprocessing.shapely_geometry_to_vector(
geometries, base_vector_path, projection_wkt, 'ESRI Shapefile',
fields=fields, attribute_list=attrs, ogr_geom_type=ogr.wkbPoint)
# Setup parameters for create raster
matrix = numpy.array([
[1, -1],
[-1, 1]], dtype=numpy.int32)
# Create raster to use for testing input
raster_path = os.path.join(
self.workspace_dir, 'raster.tif')
pygeoprocessing.numpy_array_to_raster(
matrix, -1, (100, -100), origin, projection_wkt,
raster_path)
target_vector_path = os.path.join(self.workspace_dir, 'target_vector.shp')
wind_energy._index_raster_values_to_point_vector(
base_vector_path, [(raster_path, 'TestVal')],
target_vector_path, mask_keys=['TestVal'],
mask_field="Masked")
# Confirm that masked-out points still exist in the base vector,
# but have "Masked" values of 1. Out-of-extent points should
# have values of None
base_vector = gdal.OpenEx(base_vector_path, gdal.OF_VECTOR)
base_layer = base_vector.GetLayer()
self.assertEqual(base_layer.GetFeatureCount(), 6)
id_masked = [(feat.GetField('id'), feat.GetField('Masked'))
for feat in base_layer]
self.assertEqual(id_masked,
[(0, None), (1, 0), (2, 1), (3, None), (4, 1), (5, 0)])
base_layer = None
base_vector = None
# Confirm that the target vector has only valid points
target_vector = gdal.OpenEx(target_vector_path, gdal.OF_VECTOR)
target_layer = target_vector.GetLayer()
self.assertEqual(target_layer.GetFeatureCount(), 2)
feat_ids = [feat.GetField('id') for feat in target_layer]
self.assertEqual(feat_ids, valid_ids)
target_layer = None
target_vector = None
class WindEnergyRegressionTests(unittest.TestCase):
"""Regression tests for the Wind Energy module."""
@ -447,16 +517,6 @@ class WindEnergyRegressionTests(unittest.TestCase):
wind_energy.execute(args)
raster_results = [
'density_W_per_m2.tif', 'harvested_energy_MWhr_per_yr.tif']
for raster_path in raster_results:
model_array = pygeoprocessing.raster_to_numpy_array(
os.path.join(args['workspace_dir'], 'output', raster_path))
reg_array = pygeoprocessing.raster_to_numpy_array(
os.path.join(REGRESSION_DATA, 'noaoi', raster_path))
numpy.testing.assert_allclose(model_array, reg_array)
vector_path = 'wind_energy_points.shp'
_assert_vectors_equal(
@ -474,16 +534,6 @@ class WindEnergyRegressionTests(unittest.TestCase):
wind_energy.execute(args)
raster_results = [
'density_W_per_m2.tif', 'harvested_energy_MWhr_per_yr.tif']
for raster_path in raster_results:
model_array = pygeoprocessing.raster_to_numpy_array(
os.path.join(args['workspace_dir'], 'output', raster_path))
reg_array = pygeoprocessing.raster_to_numpy_array(
os.path.join(REGRESSION_DATA, 'nolandpoly', raster_path))
numpy.testing.assert_allclose(model_array, reg_array)
vector_path = 'wind_energy_points.shp'
_assert_vectors_equal(
@ -503,16 +553,6 @@ class WindEnergyRegressionTests(unittest.TestCase):
wind_energy.execute(args)
raster_results = [
'density_W_per_m2.tif', 'harvested_energy_MWhr_per_yr.tif']
for raster_path in raster_results:
model_array = pygeoprocessing.raster_to_numpy_array(
os.path.join(args['workspace_dir'], 'output', raster_path))
reg_array = pygeoprocessing.raster_to_numpy_array(
os.path.join(REGRESSION_DATA, 'nodistances', raster_path))
numpy.testing.assert_allclose(model_array, reg_array)
vector_path = 'wind_energy_points.shp'
_assert_vectors_equal(
@ -552,17 +592,6 @@ class WindEnergyRegressionTests(unittest.TestCase):
# that have already been made, but which need to be created again.
wind_energy.execute(args)
raster_results = [
'carbon_emissions_tons.tif',
'levelized_cost_price_per_kWh.tif', 'npv.tif']
for raster_path in raster_results:
model_array = pygeoprocessing.raster_to_numpy_array(
os.path.join(args['workspace_dir'], 'output', raster_path))
reg_array = pygeoprocessing.raster_to_numpy_array(
os.path.join(REGRESSION_DATA, 'pricevalgrid', raster_path))
numpy.testing.assert_allclose(model_array, reg_array)
vector_path = 'wind_energy_points.shp'
_assert_vectors_equal(
@ -596,19 +625,6 @@ class WindEnergyRegressionTests(unittest.TestCase):
wind_energy.execute(args)
raster_results = [
'carbon_emissions_tons.tif', 'levelized_cost_price_per_kWh.tif',
'npv.tif']
for raster_path in raster_results:
model_array = pygeoprocessing.raster_to_numpy_array(
os.path.join(args['workspace_dir'], 'output', raster_path))
reg_array = pygeoprocessing.raster_to_numpy_array(
os.path.join(REGRESSION_DATA, 'pricevalgridland', raster_path))
# loosened tolerance to pass against GDAL 2.2.4 and 2.4.1
numpy.testing.assert_allclose(
model_array, reg_array, rtol=1e-04)
vector_path = 'wind_energy_points.shp'
_assert_vectors_equal(
os.path.join(args['workspace_dir'], 'output', vector_path),
@ -640,17 +656,6 @@ class WindEnergyRegressionTests(unittest.TestCase):
wind_energy.execute(args)
raster_results = [
'carbon_emissions_tons.tif', 'levelized_cost_price_per_kWh.tif',
'npv.tif']
for raster_path in raster_results:
model_array = pygeoprocessing.raster_to_numpy_array(
os.path.join(args['workspace_dir'], 'output', raster_path))
reg_array = pygeoprocessing.raster_to_numpy_array(
os.path.join(REGRESSION_DATA, 'priceval', raster_path))
numpy.testing.assert_allclose(model_array, reg_array, rtol=1e-6)
vector_path = 'wind_energy_points.shp'
_assert_vectors_equal(
os.path.join(args['workspace_dir'], 'output', vector_path),
@ -683,17 +688,6 @@ class WindEnergyRegressionTests(unittest.TestCase):
wind_energy.execute(args)
raster_results = [
'carbon_emissions_tons.tif', 'levelized_cost_price_per_kWh.tif',
'npv.tif']
for raster_path in raster_results:
model_array = pygeoprocessing.raster_to_numpy_array(
os.path.join(args['workspace_dir'], 'output', raster_path))
reg_array = pygeoprocessing.raster_to_numpy_array(
os.path.join(REGRESSION_DATA, 'priceval', raster_path))
numpy.testing.assert_allclose(model_array, reg_array, rtol=1e-6)
vector_path = 'wind_energy_points.shp'
_assert_vectors_equal(
os.path.join(args['workspace_dir'], 'output', vector_path),
@ -858,7 +852,8 @@ class WindEnergyValidationTests(unittest.TestCase):
from natcap.invest import wind_energy
from natcap.invest import validation
base_required_valuation = ['land_polygon_vector_path',
base_required_valuation = ['aoi_vector_path',
'land_polygon_vector_path',
'min_distance',
'max_distance',
'foundation_cost',

View File

@ -7,7 +7,7 @@ const EXT = OS === 'win32' ? 'exe' : 'dmg';
// Uniquely identify the changeset we're building & packaging.
const investVersion = execFileSync(
'../dist/invest/invest', ['--version']
).toString().trim();
).toString().split('\n')[0].trim();
// the appID may not display anywhere, but seems to control if the
// install overwrites pre-existing or creates a new install directory.
@ -35,6 +35,14 @@ const config = {
from: '../dist/invest',
to: 'invest',
},
{
from: '../dist/micromamba', // mac
to: 'micromamba',
},
{
from: '../dist/micromamba.exe', // windows
to: 'micromamba.exe',
},
{
from: '../dist/userguide',
to: 'documentation',

View File

@ -8,10 +8,10 @@
"start": "yarn build-main && yarn build:preload && concurrently --kill-others \"yarn serve\" \"electron .\"",
"serve": "cross-env MODE=development node scripts/watch.js",
"lint": "eslint --cache --color --ext .jsx,.js src",
"test": "cross-env PORT=56788 jest --runInBand --testPathIgnorePatterns /tests/binary_tests/ /tests/sampledata_linkcheck/",
"test": "jest --runInBand --testPathIgnorePatterns /tests/binary_tests/ /tests/sampledata_linkcheck/",
"test-main": "jest --runInBand --testMatch **/tests/main/*.test.js",
"test-renderer": "jest --runInBand --testMatch **/tests/renderer/*.test.js",
"test-flask": "cross-env PORT=56788 jest --runInBand --testMatch **/tests/invest/*.test.js",
"test-flask": "jest --runInBand --testMatch **/tests/invest/*.test.js",
"test-electron-app": "jest --runInBand --testMatch **/tests/binary_tests/*.test.js",
"test-sampledata-registry": "jest --runInBand --testMatch **/tests/sampledata_linkcheck/*.test.js",
"postinstall": "electron-builder install-app-deps",
@ -58,8 +58,11 @@
"i18next": "^22.4.9",
"localforage": "^1.9.0",
"node-fetch": "^2.6.7",
"nodejs-file-downloader": "^4.13.0",
"prop-types": "^15.7.2",
"react-i18next": "^12.1.4",
"toml": "^3.0.0",
"upath": "^2.0.1",
"yauzl": "^2.10.0"
},
"devDependencies": {
@ -103,4 +106,4 @@
"vite": "^5.4.14",
"yazl": "^2.5.1"
}
}
}

View File

@ -1,26 +1,59 @@
import { spawn, execSync } from 'child_process';
import http from 'http';
import fetch from 'node-fetch';
import { getLogger } from './logger';
import { settingsStore } from './settingsStore';
const logger = getLogger(__filename.split('/').slice(-1)[0]);
const HOSTNAME = 'http://127.0.0.1';
/**
* Spawn a child process running the Python Flask app.
*
* @param {string} investExe - path to executable that launches flask app.
* @returns {ChildProcess} - a reference to the subprocess.
*/
export function createPythonFlaskProcess(investExe) {
const pythonServerProcess = spawn(
investExe,
['--debug', 'serve', '--port', process.env.PORT],
{ shell: true } // necessary in dev mode & relying on a conda env
);
const pidToSubprocess = {};
logger.debug(`Started python process as PID ${pythonServerProcess.pid}`);
// https://stackoverflow.com/a/71178451
async function getFreePort() {
return new Promise((resolve) => {
const srv = http.createServer();
srv.listen(0, () => {
const { port } = srv.address();
srv.close(() => resolve(port));
});
});
}
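The port-0 trick above asks the operating system for an unused ephemeral port; the same idea expressed in Python for comparison (a sketch, not code used by the workbench):

import socket

def get_free_port():
    # Bind to port 0 and let the OS pick an available port, then release it.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as srv:
        srv.bind(('127.0.0.1', 0))
        return srv.getsockname()[1]

print(get_free_port())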
/** Find out if the Flask server is online, waiting until it is.
*
* @param {number} i - the number of previous tries
* @param {number} retries - number of recursive calls this function is allowed.
* @returns { Promise } resolves text indicating success.
*/
export async function getFlaskIsReady(port, i = 0, retries = 41) {
try {
await fetch(`${HOSTNAME}:${port}/api/ready`, {
method: 'get',
});
} catch (error) {
if (error.code === 'ECONNREFUSED') {
while (i < retries) {
i++;
// Try every X ms, usually takes a couple seconds to startup.
await new Promise((resolve) => setTimeout(resolve, 300));
logger.debug(`retry # ${i}`);
return getFlaskIsReady(port, i, retries);
}
logger.error(`Not able to connect to server after ${retries} tries.`);
}
logger.error(error);
throw error;
}
}
/**
* Set up handlers for server process events.
* @param {ChildProcess} pythonServerProcess - server process instance.
* @returns {undefined}
*/
export function setupServerProcessHandlers(pythonServerProcess) {
pythonServerProcess.stdout.on('data', (data) => {
logger.debug(`${data}`);
});
@ -28,9 +61,10 @@ export function createPythonFlaskProcess(investExe) {
logger.debug(`${data}`);
});
pythonServerProcess.on('error', (err) => {
logger.error(pythonServerProcess.spawnargs);
logger.error(err.stack);
logger.error(
`The flask app ${investExe} crashed or failed to start
`The invest flask app crashed or failed to start
so this application must be restarted`
);
throw err;
@ -42,52 +76,84 @@ export function createPythonFlaskProcess(investExe) {
logger.debug(`Flask process exited with code ${code}`);
});
pythonServerProcess.on('disconnect', () => {
logger.debug(`Flask process disconnected`);
logger.debug('Flask process disconnected');
});
return pythonServerProcess;
pidToSubprocess[pythonServerProcess.pid] = pythonServerProcess;
}
/** Find out if the Flask server is online, waiting until it is.
/**
* Spawn a child process running the Python Flask app for core invest.
*
* @param {number} i - the number of previous tries
* @param {number} retries - number of recursive calls this function is allowed.
* @returns { Promise } resolves text indicating success.
* @param {integer} _port - if provided, port to launch server on. Otherwise,
* an available port is chosen.
* @returns { integer } - PID of the process that was launched
*/
export async function getFlaskIsReady({ i = 0, retries = 41 } = {}) {
try {
await fetch(`${HOSTNAME}:${process.env.PORT}/api/ready`, {
method: 'get',
});
} catch (error) {
if (error.code === 'ECONNREFUSED') {
while (i < retries) {
i++;
// Try every X ms, usually takes a couple seconds to startup.
await new Promise((resolve) => setTimeout(resolve, 300));
logger.debug(`retry # ${i}`);
return getFlaskIsReady({ i: i, retries: retries });
}
logger.error(`Not able to connect to server after ${retries} tries.`);
}
logger.error(error);
throw error;
export async function createCoreServerProcess(_port = undefined) {
let port = _port;
if (port === undefined) {
port = await getFreePort();
}
logger.debug('creating invest core server process');
const pythonServerProcess = spawn(
settingsStore.get('investExe'),
['--debug', 'serve', '--port', port],
{ shell: true } // necessary in dev mode & relying on a conda env
);
settingsStore.set('core.port', port);
settingsStore.set('core.pid', pythonServerProcess.pid);
logger.debug(`Started python process as PID ${pythonServerProcess.pid}`);
setupServerProcessHandlers(pythonServerProcess);
await getFlaskIsReady(port, 0, 500);
logger.info('flask is ready');
}
/**
* Spawn a child process running the Python Flask app for a plugin.
* @param {string} modelID - id of the plugin to launch
* @param {integer} _port - if provided, port to launch server on. Otherwise,
* an available port is chosen.
* @returns { integer } - PID of the process that was launched
*/
export async function createPluginServerProcess(modelID, _port = undefined) {
let port = _port;
if (port === undefined) {
port = await getFreePort();
}
logger.debug('creating invest plugin server process');
const micromamba = settingsStore.get('micromamba');
const modelEnvPath = settingsStore.get(`plugins.${modelID}.env`);
const args = [
'run', '--prefix', `"${modelEnvPath}"`,
'invest', '--debug', 'serve', '--port', port];
logger.debug('spawning command:', micromamba, args);
// shell mode is necessary in dev mode & relying on a conda env
const pythonServerProcess = spawn(micromamba, args, { shell: true });
settingsStore.set(`plugins.${modelID}.port`, port);
settingsStore.set(`plugins.${modelID}.pid`, pythonServerProcess.pid);
logger.debug(`Started python process as PID ${pythonServerProcess.pid}`);
setupServerProcessHandlers(pythonServerProcess);
await getFlaskIsReady(port, 0, 500);
logger.info('flask is ready');
return pythonServerProcess.pid;
}
/**
* Kill the process running the Flask app
*
* @param {ChildProcess} subprocess - such as created by child_process.spawn
* @param {number} pid - process ID of the child process to shut down
* @returns {Promise}
*/
export async function shutdownPythonProcess(subprocess) {
export async function shutdownPythonProcess(pid) {
// builtin kill() method on a nodejs ChildProcess doesn't work on windows.
try {
if (process.platform !== 'win32') {
subprocess.kill();
pidToSubprocess[pid.toString()].kill();
} else {
const { pid } = subprocess;
execSync(`taskkill /pid ${pid} /t /f`);
}
} catch (error) {

View File

@ -1,5 +1,6 @@
import path from 'path';
import { spawnSync } from 'child_process';
import upath from 'upath';
import fs from 'fs';
import { execSync, spawnSync } from 'child_process';
import { ipcMain } from 'electron';
@ -14,7 +15,7 @@ const logger = getLogger(__filename.split('/').slice(-1)[0]);
* @param {boolean} isDevMode - a boolean designating dev mode or not.
* @returns {string} invest binary path string.
*/
export default function findInvestBinaries(isDevMode) {
export function findInvestBinaries(isDevMode) {
// Binding to the invest server binary:
let investExe;
const ext = (process.platform === 'win32') ? '.exe' : '';
@ -23,7 +24,7 @@ export default function findInvestBinaries(isDevMode) {
if (isDevMode) {
investExe = filename; // assume an active python env w/ exe on path
} else {
investExe = path.join(process.resourcesPath, 'invest', filename);
investExe = upath.join(process.resourcesPath, 'invest', filename);
// It's likely the exe path includes spaces because it's composed of
// app's Product Name, a user-facing name given to electron-builder.
// Quoting depends on the shell, ('/bin/sh' or 'cmd.exe') and type of
@ -51,3 +52,27 @@ export default function findInvestBinaries(isDevMode) {
);
return investExe;
}
/**
* Return the available micromamba executable.
*
* @param {boolean} isDevMode - a boolean designating dev mode or not.
* @returns {string} micromamba executable.
*/
export function findMicromambaExecutable(isDevMode) {
let micromambaExe;
if (isDevMode) {
micromambaExe = 'micromamba'; // assume that micromamba is available
} else {
micromambaExe = `"${upath.join(process.resourcesPath, 'micromamba')}"`;
}
// Check that the executable is working
const { stderr, error } = spawnSync(micromambaExe, ['--help'], { shell: true });
if (error) {
logger.error(stderr.toString());
logger.error('micromamba executable is not where we expected it.');
throw error;
}
logger.info(`using micromamba executable '${micromambaExe}'`);
return micromambaExe;
}
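// Typical usage (as done at startup in main.js): resolve the executable once
// and store it, e.g.
//   settingsStore.set('micromamba', findMicromambaExecutable(ELECTRON_DEV_MODE));
// so that other modules can retrieve it from the settings store.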

View File

@ -2,22 +2,17 @@
* Assign a class based on text content.
*
* @param {string} message - from a python logger
* @param {string} pyModuleName - e.g. 'natcap.invest.carbon'
* @returns {string} - a class name or an empty string
*/
export default function markupMessage(message, pyModuleName) {
const escapedPyModuleName = pyModuleName.replace(/\./g, '\\.');
const patterns = {
'invest-log-error': /(ERROR|CRITICAL)/,
'invest-log-primary-warning': new RegExp(`${escapedPyModuleName}.*WARNING`),
'invest-log-primary': new RegExp(escapedPyModuleName)
};
// eslint-disable-next-line
for (const [cls, pattern] of Object.entries(patterns)) {
if (pattern.test(message)) {
return cls;
}
export default function markupMessage(message) {
if (/(ERROR|CRITICAL)/.test(message)) {
return 'invest-log-error';
}
if (/natcap\.invest.*WARNING/.test(message)) {
return 'invest-log-primary-warning';
}
if (/natcap\.invest/.test(message)) {
return 'invest-log-primary';
}
return '';
}
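// Illustrative classifications (hypothetical log lines):
//   markupMessage('natcap.invest.carbon ERROR ...')   -> 'invest-log-error'
//   markupMessage('natcap.invest.carbon WARNING ...') -> 'invest-log-primary-warning'
//   markupMessage('natcap.invest.carbon INFO ...')    -> 'invest-log-primary'
//   markupMessage('some other library DEBUG ...')     -> ''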

View File

@ -4,6 +4,7 @@ import fetch from 'node-fetch';
import { getLogger } from './logger';
import pkg from '../../package.json';
import { settingsStore } from './settingsStore';
const logger = getLogger(__filename.split('/').slice(-1)[0]);
const WORKBENCH_VERSION = pkg.version;
@ -13,16 +14,28 @@ const PREFIX = 'api';
export default function investUsageLogger() {
const sessionId = crypto.randomUUID();
function start(modelPyName, args) {
function start(modelID, args, port) {
logger.debug('logging model start');
fetch(`${HOSTNAME}:${process.env.PORT}/${PREFIX}/log_model_start`, {
const body = {
model_id: modelID,
model_args: JSON.stringify(args),
invest_interface: `Workbench ${WORKBENCH_VERSION}`,
session_id: sessionId,
};
const plugins = settingsStore.get('plugins');
if (plugins && Object.keys(plugins).includes(modelID)) {
const source = plugins[modelID].source;
body.type = 'plugin';
// don't log the path to a local plugin, just log that it's local
body.source = source.startsWith('git+') ? source : 'local';
} else {
body.type = 'core';
}
fetch(`${HOSTNAME}:${port}/${PREFIX}/log_model_start`, {
method: 'post',
body: JSON.stringify({
model_pyname: modelPyName,
model_args: JSON.stringify(args),
invest_interface: `Workbench ${WORKBENCH_VERSION}`,
session_id: sessionId,
}),
body: JSON.stringify(body),
headers: { 'Content-Type': 'application/json' },
})
.then(async (response) => {
@ -31,9 +44,9 @@ export default function investUsageLogger() {
.catch((error) => logger.error(error));
}
function exit(status) {
function exit(status, port) {
logger.debug('logging model exit');
fetch(`${HOSTNAME}:${process.env.PORT}/${PREFIX}/log_model_exit`, {
fetch(`${HOSTNAME}:${port}/${PREFIX}/log_model_exit`, {
method: 'post',
body: JSON.stringify({
session_id: sessionId,

View File

@ -1,23 +1,28 @@
export const ipcMainChannels = {
ADD_PLUGIN: 'add-plugin',
BASE_URL: 'base-url',
CHANGE_LANGUAGE: 'change-language',
CHECK_FILE_PERMISSIONS: 'check-file-permissions',
CHECK_STORAGE_TOKEN: 'check-storage-token',
DOWNLOAD_MSVC: 'download-msvc',
DOWNLOAD_URL: 'download-url',
GET_ELECTRON_PATHS: 'get-electron-paths',
GET_N_CPUS: 'get-n-cpus',
GET_SETTING: 'get-setting',
GET_LANGUAGE: 'get-language',
HAS_MSVC: 'has-msvc',
INVEST_KILL: 'invest-kill',
INVEST_READ_LOG: 'invest-read-log',
INVEST_RUN: 'invest-run',
INVEST_VERSION: 'invest-version',
IS_FIRST_RUN: 'is-first-run',
LAUNCH_PLUGIN_SERVER: 'launch-plugin-server',
IS_NEW_VERSION: 'is-new-version',
LOGGER: 'logger',
OPEN_EXTERNAL_URL: 'open-external-url',
OPEN_PATH: 'open-path',
OPEN_LOCAL_HTML: 'open-local-html',
REMOVE_PLUGIN: 'remove-plugin',
SET_SETTING: 'set-setting',
SHOW_ITEM_IN_FOLDER: 'show-item-in-folder',
SHOW_OPEN_DIALOG: 'show-open-dialog',

View File

@ -1,5 +1,4 @@
import path from 'path';
import i18n from './i18n/i18n';
// eslint-disable-next-line import/no-extraneous-dependencies
import {
app,
@ -9,13 +8,29 @@ import {
ipcMain
} from 'electron';
import i18n from './i18n/i18n';
import BASE_URL from './baseUrl';
import {
createPythonFlaskProcess,
getFlaskIsReady,
shutdownPythonProcess,
createCoreServerProcess,
shutdownPythonProcess
} from './createPythonFlaskProcess';
import findInvestBinaries from './findInvestBinaries';
import { findInvestBinaries, findMicromambaExecutable } from './findBinaries';
import setupDownloadHandlers from './setupDownloadHandlers';
import setupDialogs from './setupDialogs';
import setupContextMenu from './setupContextMenu';
import setupCheckFilePermissions from './setupCheckFilePermissions';
import { setupCheckFirstRun } from './setupCheckFirstRun';
import { setupCheckStorageToken } from './setupCheckStorageToken';
import {
setupInvestRunHandlers,
setupLaunchPluginServerHandler,
setupInvestLogReaderHandler
} from './setupInvestHandlers';
import {
setupAddPlugin,
setupRemovePlugin,
setupWindowsMSVCHandlers
} from './setupAddRemovePlugin';
import { ipcMainChannels } from './ipcMainChannels';
import ELECTRON_DEV_MODE from './isDevMode';
import { getLogger } from './logger';
@ -23,24 +38,13 @@ import menuTemplate from './menubar';
import pkg from '../../package.json';
import { settingsStore, setupSettingsHandlers } from './settingsStore';
import { setupBaseUrl } from './setupBaseUrl';
import setupCheckFilePermissions from './setupCheckFilePermissions';
import { setupCheckFirstRun } from './setupCheckFirstRun';
import { setupCheckStorageToken } from './setupCheckStorageToken';
import setupContextMenu from './setupContextMenu';
import setupDialogs from './setupDialogs';
import setupDownloadHandlers from './setupDownloadHandlers';
import setupGetElectronPaths from './setupGetElectronPaths';
import setupGetNCPUs from './setupGetNCPUs';
import {
setupInvestLogReaderHandler,
setupInvestRunHandlers,
} from './setupInvestHandlers';
import { setupIsNewVersion } from './setupIsNewVersion';
import setupOpenExternalUrl from './setupOpenExternalUrl';
import setupOpenLocalHtml from './setupOpenLocalHtml';
import setupRendererLogger from './setupRendererLogger';
const logger = getLogger(__filename.split('/').slice(-1)[0]);
process.on('uncaughtException', (err) => {
@ -53,15 +57,10 @@ process.on('unhandledRejection', (err, promise) => {
process.exit(1);
});
if (!process.env.PORT) {
process.env.PORT = '56789';
}
// Keep a global reference of the window object, if you don't, the window will
// be closed automatically when the JavaScript object is garbage collected.
let mainWindow;
let splashScreen;
let flaskSubprocess;
let forceQuit = false;
export function destroyWindow() {
@ -84,8 +83,19 @@ export const createWindow = async () => {
});
splashScreen.loadURL(path.join(BASE_URL, 'splash.html'));
const investExe = findInvestBinaries(ELECTRON_DEV_MODE);
flaskSubprocess = createPythonFlaskProcess(investExe);
settingsStore.set('investExe', findInvestBinaries(ELECTRON_DEV_MODE));
settingsStore.set('micromamba', findMicromambaExecutable(ELECTRON_DEV_MODE));
// No plugin server processes should persist between workbench sessions.
// In case any records were left behind, clear their stored PIDs and ports.
const plugins = settingsStore.get('plugins');
if (plugins) {
Object.keys(plugins).forEach((modelID) => {
settingsStore.set(`plugins.${modelID}.pid`, '');
settingsStore.set(`plugins.${modelID}.port`, '');
});
}
await createCoreServerProcess();
setupDialogs();
setupCheckFilePermissions();
setupCheckFirstRun();
@ -98,7 +108,6 @@ export const createWindow = async () => {
setupOpenExternalUrl();
setupRendererLogger();
setupBaseUrl();
await getFlaskIsReady();
const devModeArg = ELECTRON_DEV_MODE ? '--devmode' : '';
// Create the browser window.
@ -108,7 +117,7 @@ export const createWindow = async () => {
webPreferences: {
preload: path.join(__dirname, '../preload/preload.js'),
defaultEncoding: 'UTF-8',
additionalArguments: [devModeArg, `--port=${process.env.PORT}`],
additionalArguments: [devModeArg],
},
});
Menu.setApplicationMenu(
@ -151,7 +160,11 @@ export const createWindow = async () => {
// have callbacks that won't work until the invest server is ready.
setupContextMenu(mainWindow);
setupDownloadHandlers(mainWindow);
setupInvestRunHandlers(investExe);
setupInvestRunHandlers();
setupLaunchPluginServerHandler();
setupAddPlugin(i18n);
setupRemovePlugin();
setupWindowsMSVCHandlers();
setupOpenLocalHtml(mainWindow, ELECTRON_DEV_MODE);
if (ELECTRON_DEV_MODE) {
// The timing of this is fussy due a chromium bug. It seems to only
@ -203,8 +216,22 @@ export function main() {
if (shuttingDown) { return; }
event.preventDefault();
shuttingDown = true;
await shutdownPythonProcess(settingsStore.get('core.pid'));
settingsStore.set('core.pid', '');
settingsStore.set('core.port', '');
const pluginServerPIDs = [];
const plugins = settingsStore.get('plugins') || {};
Object.keys(plugins).forEach((pluginID) => {
const pid = settingsStore.get(`plugins.${pluginID}.pid`);
if (pid) {
pluginServerPIDs.push(pid);
}
settingsStore.set(`plugins.${pluginID}.pid`, '');
settingsStore.set(`plugins.${pluginID}.port`, '');
});
await Promise.all(pluginServerPIDs.map((pid) => shutdownPythonProcess(pid)));
removeIpcMainListeners();
await shutdownPythonProcess(flaskSubprocess);
app.quit();
});
}

View File

@ -123,7 +123,7 @@ function createWindow(parentWindow, isDevMode) {
minimumFontSize: 12,
preload: path.join(__dirname, '../preload/preload.js'),
defaultEncoding: 'UTF-8',
additionalArguments: [devModeArg, `--port=${process.env.PORT}`],
additionalArguments: [devModeArg],
},
});
setupContextMenu(win);

View File

@ -73,7 +73,7 @@ export const settingsStore = initStore();
export function setupSettingsHandlers() {
ipcMain.handle(
ipcMainChannels.GET_SETTING,
(event, key) => settingsStore.get(key)
(event, key) => Promise.resolve(settingsStore.get(key))
);
ipcMain.on(

View File

@ -0,0 +1,239 @@
import upath from 'upath';
import fs from 'fs';
import { tmpdir } from 'os';
import toml from 'toml';
import { execFile, execSync, spawn } from 'child_process';
import { promisify } from 'util';
import { app, ipcMain } from 'electron';
import { Downloader } from 'nodejs-file-downloader';
import crypto from 'crypto';
import { getLogger } from './logger';
import { ipcMainChannels } from './ipcMainChannels';
import { settingsStore } from './settingsStore';
const logger = getLogger(__filename.split('/').slice(-1)[0]);
/**
* Spawn a child process and log its stdout, stderr, and any error in spawning.
*
* child_process.spawn is called with the provided cmd, args, and options,
* with the shell and windowsHide options forced to true.
*
* @param {string} cmd - command to pass to spawn
* @param {Array} args - command arguments to pass to spawn
* @param {object} options - options to pass to spawn.
* @returns {Promise} resolves when the command finishes with exit code 0.
* Rejects with error otherwise.
*/
function spawnWithLogging(cmd, args, options) {
logger.info(cmd, args);
const cmdProcess = spawn(
cmd, args, { ...options, shell: true, windowsHide: true });
let errMessage;
if (cmdProcess.stdout) {
cmdProcess.stderr.on('data', (data) => {
errMessage = data.toString();
logger.info(errMessage);
});
cmdProcess.stdout.on('data', (data) => logger.info(data.toString()));
}
return new Promise((resolve, reject) => {
cmdProcess.on('error', (err) => {
logger.error(err);
reject(err);
});
cmdProcess.on('close', (code) => {
if (code === 0) {
resolve(code);
} else {
reject(errMessage);
}
});
});
}
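// Illustrative usage (hypothetical arguments): resolves with 0 on success,
// rejects with the last stderr message on a nonzero exit code:
//   await spawnWithLogging('git', ['--version'], {});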
export function setupAddPlugin(i18n) {
ipcMain.handle(
ipcMainChannels.ADD_PLUGIN,
async (event, url, revision, path) => {
try {
let pyprojectTOML;
let installString;
const micromamba = settingsStore.get('micromamba');
const rootPrefix = upath.join(app.getPath('userData'), 'micromamba_envs');
if (url) { // install from git URL
if (revision) {
installString = `git+${url}@${revision}`;
logger.info(`adding plugin from ${installString}`);
} else {
installString = `git+${url}`;
logger.info(`adding plugin from ${installString} at default branch`);
}
const baseEnvPrefix = upath.join(rootPrefix, 'invest_base');
// Create invest_base environment, if it doesn't already exist
// The purpose of this environment is just to ensure that git is available
if (!fs.existsSync(baseEnvPrefix)) {
event.sender.send('plugin-install-status', i18n.t('Creating base environment...'));
await spawnWithLogging(
micromamba,
['create', '--yes', '--prefix', `"${baseEnvPrefix}"`, '-c', 'conda-forge', 'git']
);
}
// Create a temporary directory and check out the plugin's pyproject.toml,
// without downloading any extra files or git history
event.sender.send('plugin-install-status', i18n.t('Downloading plugin source code...'));
const tmpPluginDir = fs.mkdtempSync(upath.join(tmpdir(), 'natcap-invest-'));
await spawnWithLogging(
micromamba,
['run', '--prefix', `"${baseEnvPrefix}"`,
'git', 'clone', '--depth', 1, '--no-checkout', url, tmpPluginDir]);
let head = 'HEAD';
if (revision) {
head = 'FETCH_HEAD';
await spawnWithLogging(
micromamba,
['run', '--prefix', `"${baseEnvPrefix}"`, 'git', 'fetch', 'origin', `${revision}`],
{ cwd: tmpPluginDir }
);
}
await spawnWithLogging(
micromamba,
['run', '--prefix', `"${baseEnvPrefix}"`, 'git', 'checkout', head, '--', 'pyproject.toml'],
{ cwd: tmpPluginDir }
);
// Read in the plugin's pyproject.toml, then delete the temporary clone
pyprojectTOML = toml.parse(fs.readFileSync(
upath.join(tmpPluginDir, 'pyproject.toml')
).toString());
fs.rmSync(tmpPluginDir, { recursive: true, force: true });
} else { // install from local path
logger.info(`adding plugin from ${path}`);
installString = path;
// Read in the plugin's pyproject.toml
pyprojectTOML = toml.parse(fs.readFileSync(
upath.join(path, 'pyproject.toml')
).toString());
}
// Access plugin metadata from the pyproject.toml
const condaDeps = pyprojectTOML.tool.natcap.invest.conda_dependencies;
const packageName = pyprojectTOML.tool.natcap.invest.package_name;
const version = pyprojectTOML.project.version;
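// For reference, the fields read above assume a pyproject.toml shaped
// roughly like this (values are illustrative):
//   [project]
//   version = "1.0.0"
//   [tool.natcap.invest]
//   package_name = "foo_plugin"
//   conda_dependencies = ["gdal"]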
// Create a conda env containing the plugin and its dependencies
// use timestamp to ensure a unique path
// We would prefer the env path to match the plugin's model_id, but the
// model_id isn't known until after the environment exists (it comes from
// importing the MODEL_SPEC), and mamba does not support renaming or
// moving environments after they're created.
const pluginEnvPrefix = upath.join(rootPrefix, `plugin_${Date.now()}`);
const createCommand = [
'create', '--yes', '--prefix', `"${pluginEnvPrefix}"`,
'-c', 'conda-forge', 'python', 'git'];
if (condaDeps) { // include dependencies read from pyproject.toml
condaDeps.forEach((dep) => createCommand.push(`"${dep}"`));
}
event.sender.send('plugin-install-status', i18n.t('Creating plugin environment...'));
await spawnWithLogging(micromamba, createCommand);
logger.info('created micromamba env for plugin');
event.sender.send('plugin-install-status', i18n.t('Installing plugin into environment...'));
await spawnWithLogging(
micromamba,
['run', '--prefix', `"${pluginEnvPrefix}"`,
'python', '-m', 'pip', 'install', installString]
);
logger.info('installed plugin into its env');
event.sender.send('plugin-install-status', i18n.t('Importing plugin...'));
// Access plugin metadata from the MODEL_SPEC
const modelID = execSync(
  `${micromamba} run --prefix "${pluginEnvPrefix}" ` +
  `python -c "import ${packageName}; print(${packageName}.MODEL_SPEC.model_id)"`
).toString().trim();
const modelTitle = execSync(
  `${micromamba} run --prefix "${pluginEnvPrefix}" ` +
  `python -c "import ${packageName}; print(${packageName}.MODEL_SPEC.model_title)"`
).toString().trim();
// Write plugin metadata to the workbench's config.json
logger.info('writing plugin info to settings store');
// Uniquely identify the plugin by a hash of its ID and version.
// We hash because the version may contain dots, which electron-store's
// set and get methods would interpret as nested keys.
const pluginID = crypto.createHash('sha1').update(`${modelID}@${version}`).digest('hex');
settingsStore.set(
`plugins.${pluginID}`,
{
modelID: modelID,
modelTitle: modelTitle,
type: 'plugin',
source: installString,
env: pluginEnvPrefix,
version: version,
}
);
logger.info('successfully added plugin');
} catch (error) {
console.log(error);
return error;
}
}
);
}
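// Note: the pluginID handled below is the same hashed ID that setupAddPlugin
// stored under `plugins.<hash>`, so removing a plugin deletes both its
// micromamba env and its settings entry.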
export function setupRemovePlugin() {
ipcMain.handle(
ipcMainChannels.REMOVE_PLUGIN,
async (e, pluginID) => {
logger.info('removing plugin', pluginID);
try {
// Delete the plugin's conda env
const env = settingsStore.get(`plugins.${pluginID}.env`);
const micromamba = settingsStore.get('micromamba');
await spawnWithLogging(micromamba, ['env', 'remove', '--yes', '--prefix', `"${env}"`]);
// Delete the plugin's data from storage
settingsStore.delete(`plugins.${pluginID}`);
logger.info('successfully removed plugin');
} catch (error) {
logger.info('Error removing plugin:');
logger.info(error);
return error;
}
}
);
}
export function setupWindowsMSVCHandlers() {
ipcMain.handle(
ipcMainChannels.HAS_MSVC,
() => {
return fs.existsSync(upath.join('C:', 'Windows', 'System32', 'VCRUNTIME140_1.dll'));
}
);
ipcMain.handle(
ipcMainChannels.DOWNLOAD_MSVC,
async () => {
const tmpDir = app.getPath('temp');
const exeName = 'vc_redist.x64.exe';
const downloader = new Downloader({
url: 'https://aka.ms/vs/17/release/vc_redist.x64.exe',
directory: tmpDir,
fileName: exeName,
});
try {
await downloader.download();
logger.info("Download complete");
} catch (error) {
logger.error("Download failed", error);
}
logger.info('Launching MSVC installer');
const exePath = upath.join(tmpDir, exeName);
await promisify(execFile)(exePath, ['/norestart']);
}
);
}

View File

@ -12,6 +12,7 @@ import investUsageLogger from './investUsageLogger';
import markupMessage from './investLogMarkup';
import writeInvestParameters from './writeInvestParameters';
import { settingsStore } from './settingsStore';
import { createPluginServerProcess } from './createPythonFlaskProcess';
const logger = getLogger(__filename.split('/').slice(-1)[0]);
@ -30,7 +31,17 @@ const TGLOGLEVELMAP = {
};
const TEMP_DIR = path.join(app.getPath('userData'), 'tmp');
export function setupInvestRunHandlers(investExe) {
export function setupLaunchPluginServerHandler() {
ipcMain.handle(
ipcMainChannels.LAUNCH_PLUGIN_SERVER,
async (event, pluginID) => {
const pid = await createPluginServerProcess(pluginID);
return pid;
}
);
}
export function setupInvestRunHandlers() {
const runningJobs = {};
ipcMain.on(ipcMainChannels.INVEST_KILL, (event, jobID) => {
@ -46,11 +57,10 @@ export function setupInvestRunHandlers(investExe) {
});
ipcMain.on(ipcMainChannels.INVEST_RUN, async (
event, modelRunName, pyModuleName, args, tabID
event, modelID, args, tabID
) => {
let investRun;
let investStarted = false;
let investStdErr = '';
const investStdErr = '';
const usageLogger = investUsageLogger();
const loggingLevel = settingsStore.get('loggingLevel');
const taskgraphLoggingLevel = settingsStore.get('taskgraphLoggingLevel');
@ -67,7 +77,7 @@ export function setupInvestRunHandlers(investExe) {
const datastackPath = path.join(tempDatastackDir, 'datastack.json');
const payload = {
filepath: datastackPath,
moduleName: pyModuleName,
model_id: modelID,
relativePaths: false,
args: JSON.stringify({
...args,
@ -75,28 +85,46 @@ export function setupInvestRunHandlers(investExe) {
}),
};
await writeInvestParameters(payload);
const cmdArgs = [
LOGLEVELMAP[loggingLevel],
TGLOGLEVELMAP[taskgraphLoggingLevel],
`--language "${language}"`,
'run',
modelRunName,
'--headless',
`-d "${datastackPath}"`,
];
logger.debug(`set to run ${cmdArgs}`);
if (process.platform !== 'win32') {
investRun = spawn(investExe, cmdArgs, {
shell: true, // without shell, IOError when datastack.py loads json
detached: true, // counter-intuitive, but w/ true: invest terminates when this shell terminates
});
} else { // windows
investRun = spawn(investExe, cmdArgs, {
shell: true,
});
let cmd;
let cmdArgs;
let port;
const plugins = settingsStore.get('plugins');
if (plugins && Object.keys(plugins).includes(modelID)) {
cmd = settingsStore.get('micromamba');
cmdArgs = [
'run',
`--prefix "${settingsStore.get(`plugins.${modelID}.env`)}"`,
'invest',
LOGLEVELMAP[loggingLevel],
TGLOGLEVELMAP[taskgraphLoggingLevel],
`--language "${language}"`,
'run',
settingsStore.get(`plugins.${modelID}.modelID`),
`-d "${datastackPath}"`,
];
port = settingsStore.get(`plugins.${modelID}.port`);
} else {
cmd = settingsStore.get('investExe');
cmdArgs = [
LOGLEVELMAP[loggingLevel],
TGLOGLEVELMAP[taskgraphLoggingLevel],
`--language "${language}"`,
'run',
modelID,
`-d "${datastackPath}"`];
port = settingsStore.get('core.port');
}
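// At this point `cmd` is either the micromamba executable (for a plugin run
// inside its env) or the bundled invest binary (for a core model), and
// `port` identifies the matching Flask server used below for usage logging.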
logger.debug(`about to run model with command: ${cmd} ${cmdArgs}`);
// without shell, IOError when datastack.py loads json
const spawnOptions = { shell: true };
if (process.platform !== 'win32') {
// counter-intuitive, but w/ true: invest terminates when this shell terminates
spawnOptions.detached = true;
}
const investRun = spawn(cmd, cmdArgs, spawnOptions);
// There's no general way to know that a spawned process has started,
// so we treat the first data event on stdout as that signal.
// We need to do the following only once after the process starts:
@ -114,7 +142,7 @@ export function setupInvestRunHandlers(investExe) {
);
event.reply(`invest-logging-${tabID}`, path.resolve(investLogfile));
if (!ELECTRON_DEV_MODE && !process.env.PUPPETEER) {
usageLogger.start(pyModuleName, args);
usageLogger.start(modelID, args, port);
}
}
}
@ -122,7 +150,7 @@ export function setupInvestRunHandlers(investExe) {
// only be one logger message at a time.
event.reply(
`invest-stdout-${tabID}`,
[strData, markupMessage(strData, pyModuleName)]
[strData, markupMessage(strData)]
);
};
investRun.stdout.on('data', stdOutCallback);
@ -132,7 +160,7 @@ export function setupInvestRunHandlers(investExe) {
};
investRun.stderr.on('data', stdErrCallback);
investRun.on('close', (code) => {
investRun.on('close', () => {
logger.debug('invest subprocess stdio streams closed');
});
@ -149,14 +177,15 @@ export function setupInvestRunHandlers(investExe) {
});
});
if (!ELECTRON_DEV_MODE && !process.env.PUPPETEER) {
usageLogger.exit(investStdErr);
usageLogger.exit(investStdErr, port);
}
});
});
}
export function setupInvestLogReaderHandler() {
ipcMain.on(ipcMainChannels.INVEST_READ_LOG,
ipcMain.on(
ipcMainChannels.INVEST_READ_LOG,
(event, logfile, channel) => {
const fileStream = fs.createReadStream(logfile);
fileStream.on('error', (err) => {
@ -170,5 +199,6 @@ export function setupInvestLogReaderHandler() {
fileStream.on('data', (data) => {
event.reply(`invest-stdout-${channel}`, [`${data}`, '']);
});
});
}
);
}

View File

@ -1,13 +1,15 @@
import fetch from 'node-fetch';
import { getLogger } from './logger';
import { settingsStore } from './settingsStore';
const logger = getLogger(__filename.split('/').slice(-1)[0]);
const HOSTNAME = 'http://127.0.0.1';
export default function writeParametersToFile(payload) {
const port = settingsStore.get('core.port');
return (
fetch(`${HOSTNAME}:${process.env.PORT}/api/write_parameter_set_file`, {
fetch(`${HOSTNAME}:${port}/api/write_parameter_set_file`, {
method: 'post',
body: JSON.stringify(payload),
headers: { 'Content-Type': 'application/json' },

View File

@ -11,11 +11,10 @@ const ipcRendererChannels = [
/invest-stdout-*/,
/invest-exit-*/,
/download-status/,
/plugin-install-status/,
];
// args sent via `additionalArguments` to `webPreferences` for `BroswerWindow`
const portArg = process.argv.filter((arg) => arg.startsWith('--port'))[0];
const PORT = portArg ? portArg.split('=')[1] : '';
const userPaths = ipcRenderer.sendSync(ipcMainChannels.GET_ELECTRON_PATHS);
const isDevMode = process.argv.includes('--devmode');
@ -31,10 +30,10 @@ const electronLogPath = (userPaths)
: '';
export default {
PORT: PORT, // where the flask app is running
ELECTRON_LOG_PATH: electronLogPath,
USERGUIDE_PATH: userguidePath,
LANGUAGE: ipcRenderer.sendSync(ipcMainChannels.GET_LANGUAGE),
OS: process.platform,
logger: {
debug: (message) => ipcRenderer.send(ipcMainChannels.LOGGER, 'debug', message),
info: (message) => ipcRenderer.send(ipcMainChannels.LOGGER, 'info', message),
@ -56,7 +55,7 @@ export default {
},
sendSync: (channel, ...args) => {
if (Object.values(ipcMainChannels).includes(channel)) {
ipcRenderer.sendSync(channel, ...args);
return ipcRenderer.sendSync(channel, ...args);
}
},
on: (channel, func) => {

View File

@ -32,6 +32,19 @@ export default class InvestJob {
(key) => investJobStore.getItem(key)
));
}
// Migrate old-style jobs.
// We can eventually remove this code once most users are likely to have
// updated to and run a newer version of the workbench.
jobArray = await Promise.all(jobArray.map(async (job) => {
if (job.modelID === undefined) {
job.modelID = job.modelRunName;
job.modelTitle = job.modelHumanName;
delete job.modelRunName;
delete job.modelHumanName;
await investJobStore.setItem(job.hash, job);
}
return job;
}));
return jobArray;
}
@ -40,6 +53,18 @@ export default class InvestJob {
return InvestJob.getJobStore();
}
static async deleteJob(hash) {
await investJobStore.removeItem(hash);
// also remove item from the array that tracks the order of the jobs
const sortedJobHashes = await investJobStore.getItem(HASH_ARRAY_KEY);
const idx = sortedJobHashes.indexOf(hash);
if (idx > -1) {
sortedJobHashes.splice(idx, 1); // remove one item only
}
await investJobStore.setItem(HASH_ARRAY_KEY, sortedJobHashes);
return InvestJob.getJobStore();
}
static async saveJob(job) {
job.hash = window.crypto.getRandomValues(
new Uint32Array(1)
@ -65,30 +90,34 @@ export default class InvestJob {
/**
* @param {object} obj - with the following properties
* @param {string} obj.modelRunName - name to be passed to `invest run`
* @param {string} obj.modelHumanName - colloquial name of the invest model
* @param {string} obj.modelID - model ID to be passed to `invest run`
* @param {string} obj.modelTitle - colloquial name of the invest model
* @param {object} obj.argsValues - an invest "args dict" with initial values
* @param {string} obj.logfile - path to an existing invest logfile
* @param {string} obj.status - one of 'running'|'error'|'success'
* @param {string} obj.type - 'plugin' or 'core'
*/
constructor(
{
modelRunName,
modelHumanName,
modelID,
modelTitle,
argsValues,
logfile,
status,
type,
}
) {
if (!modelRunName || !modelHumanName) {
if (!modelID || !modelTitle) {
throw new Error(
'Cannot create instance of InvestJob without modelRunName and modelHumanName properties')
'Cannot create instance of InvestJob without modelID and modelTitle properties'
);
}
this.modelRunName = modelRunName;
this.modelHumanName = modelHumanName;
this.modelID = modelID;
this.modelTitle = modelTitle;
this.argsValues = argsValues;
this.logfile = logfile;
this.status = status;
this.type = type;
this.hash = null;
}
}

View File

@ -2,6 +2,7 @@ import React from 'react';
import PropTypes from 'prop-types';
import i18n from 'i18next';
import Badge from 'react-bootstrap/Badge';
import TabPane from 'react-bootstrap/TabPane';
import TabContent from 'react-bootstrap/TabContent';
import TabContainer from 'react-bootstrap/TabContainer';
@ -18,12 +19,16 @@ import { AiOutlineTrademarkCircle } from 'react-icons/ai';
import HomeTab from './components/HomeTab';
import InvestTab from './components/InvestTab';
import AppMenu from './components/AppMenu';
import SettingsModal from './components/SettingsModal';
import DataDownloadModal from './components/DataDownloadModal';
import DownloadProgressBar from './components/DownloadProgressBar';
import { getInvestModelNames } from './server_requests';
import PluginModal from './components/PluginModal';
import MetadataModal from './components/MetadataModal';
import InvestJob from './InvestJob';
import { dragOverHandlerNone } from './utils';
import { ipcMainChannels } from '../main/ipcMainChannels';
import { getInvestModelIDs } from './server_requests';
import Changelog from './components/Changelog';
const { ipcRenderer } = window.Workbench.electron;
@ -42,8 +47,11 @@ export default class App extends React.Component {
investList: null,
recentJobs: [],
showDownloadModal: false,
showPluginModal: false,
downloadedNofN: null,
showChangelog: false,
showSettingsModal: false,
showMetadataModal: false,
changelogDismissed: false,
};
this.switchTabs = this.switchTabs.bind(this);
@ -51,21 +59,24 @@ export default class App extends React.Component {
this.closeInvestModel = this.closeInvestModel.bind(this);
this.updateJobProperties = this.updateJobProperties.bind(this);
this.saveJob = this.saveJob.bind(this);
this.deleteJob = this.deleteJob.bind(this);
this.clearRecentJobs = this.clearRecentJobs.bind(this);
this.showDownloadModal = this.showDownloadModal.bind(this);
this.toggleDownloadModal = this.toggleDownloadModal.bind(this);
this.toggleSettingsModal = this.toggleSettingsModal.bind(this);
this.toggleMetadataModal = this.toggleMetadataModal.bind(this);
this.togglePluginModal = this.togglePluginModal.bind(this);
this.updateInvestList = this.updateInvestList.bind(this);
}
/** Initialize the list of invest models, recent invest jobs, etc. */
async componentDidMount() {
const investList = await getInvestModelNames();
const investList = await this.updateInvestList();
const recentJobs = await InvestJob.getJobStore();
this.setState({
investList: investList,
// filter out models that do not exist in current version of invest
recentJobs: recentJobs.filter((job) => (
Object.values(investList)
.map((m) => m.model_name)
.includes(job.modelRunName)
Object.keys(investList)
.includes(job.modelID)
)),
showDownloadModal: this.props.isFirstRun,
// Show changelog if this is a new version,
@ -84,8 +95,8 @@ export default class App extends React.Component {
ipcRenderer.removeAllListeners('download-status');
}
/** Change the tab that is currently visible.
*
/**
* Change the tab that is currently visible.
* @param {string} key - the value of one of the Nav.Link eventKey.
*/
switchTabs(key) {
@ -94,7 +105,7 @@ export default class App extends React.Component {
);
}
showDownloadModal(shouldShow) {
toggleDownloadModal(shouldShow) {
this.setState({
showDownloadModal: shouldShow,
});
@ -114,8 +125,26 @@ export default class App extends React.Component {
});
}
/** Push data for a new InvestTab component to an array.
*
togglePluginModal(show) {
this.setState({
showPluginModal: show
});
}
toggleMetadataModal(show) {
this.setState({
showMetadataModal: show
});
}
toggleSettingsModal(show) {
this.setState({
showSettingsModal: show
});
}
/**
* Push data for a new InvestTab component to an array.
* @param {InvestJob} job - as constructed by new InvestJob()
*/
openInvestModel(job) {
@ -133,7 +162,6 @@ export default class App extends React.Component {
/**
* Click handler for the close-tab button on an Invest model tab.
*
* @param {string} tabID - the eventKey of the tab containing the
* InvestTab component that will be removed.
*/
@ -161,8 +189,8 @@ export default class App extends React.Component {
});
}
/** Update properties of an open InvestTab.
*
/**
* Update properties of an open InvestTab.
* @param {string} tabID - the unique identifier of an open tab
* @param {obj} jobObj - key-value pairs of any job properties to be updated
*/
@ -174,10 +202,8 @@ export default class App extends React.Component {
});
}
/** Save data describing an invest job to a persistent store.
*
* And update the app's view of that store.
*
/**
* Save data describing an invest job to a persistent store.
* @param {string} tabID - the unique identifier of an open InvestTab.
*/
async saveJob(tabID) {
@ -188,6 +214,20 @@ export default class App extends React.Component {
});
}
/**
* Delete the job record from the store.
* @param {string} jobHash - the unique identifier of a saved Job.
*/
async deleteJob(jobHash) {
const recentJobs = await InvestJob.deleteJob(jobHash);
this.setState({
recentJobs: recentJobs,
});
}
/**
* Delete all the jobs from the store.
*/
async clearRecentJobs() {
const recentJobs = await InvestJob.clearStore();
this.setState({
@ -195,6 +235,22 @@ export default class App extends React.Component {
});
}
async updateInvestList() {
const coreModels = {};
const investList = await getInvestModelIDs();
Object.keys(investList).forEach((modelID) => {
coreModels[modelID] = { modelTitle: investList[modelID].model_title, type: 'core' };
});
const plugins = await ipcRenderer.invoke(ipcMainChannels.GET_SETTING, 'plugins') || {};
Object.keys(plugins).forEach((plugin) => {
plugins[plugin].type = 'plugin';
});
this.setState({
investList: { ...coreModels, ...plugins },
});
return { ...coreModels, ...plugins };
}
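// The merged list maps each model ID to { modelTitle, type, ... }, where
// type is 'core' or 'plugin'; HomeTab and InvestTab use `type` to show a
// Plugin badge and to distinguish core models from plugins.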
render() {
const {
investList,
@ -203,7 +259,10 @@ export default class App extends React.Component {
openTabIDs,
activeTab,
showDownloadModal,
showPluginModal,
showChangelog,
showSettingsModal,
showMetadataModal,
downloadedNofN,
} = this.state;
@ -233,13 +292,21 @@ export default class App extends React.Component {
default:
statusSymbol = '';
}
let badge;
if (investList) {
const modelType = investList[job.modelID].type;
if (modelType === 'plugin') {
badge = <Badge className="mr-1" variant="secondary">Plugin</Badge>;
}
}
investNavItems.push(
<OverlayTrigger
key={`${id}-tooltip`}
placement="bottom"
overlay={(
<Tooltip>
{job.modelHumanName}
{job.modelTitle}
</Tooltip>
)}
>
@ -258,11 +325,12 @@ export default class App extends React.Component {
}
}}
>
{badge}
{statusSymbol}
{` ${job.modelHumanName}`}
{` ${job.modelTitle}`}
</Nav.Link>
<Button
aria-label={`close ${job.modelHumanName} tab`}
aria-label={`close ${job.modelTitle} tab`}
className="close-tab"
variant="outline-dark"
onClick={(event) => {
@ -280,13 +348,14 @@ export default class App extends React.Component {
<TabPane
key={id}
eventKey={id}
aria-label={`${job.modelHumanName} tab`}
aria-label={`${job.modelTitle} tab`}
>
<InvestTab
job={job}
tabID={id}
saveJob={this.saveJob}
updateJobProperties={this.updateJobProperties}
investList={investList}
/>
</TabPane>
);
@ -294,17 +363,41 @@ export default class App extends React.Component {
return (
<React.Fragment>
<DataDownloadModal
show={showDownloadModal}
closeModal={() => this.showDownloadModal(false)}
/>
{
showChangelog &&
{showDownloadModal && (
<DataDownloadModal
show={showDownloadModal}
closeModal={() => this.toggleDownloadModal(false)}
/>
)}
{showPluginModal && (
<PluginModal
show={showPluginModal}
closeModal={() => this.togglePluginModal(false)}
openModal={() => this.togglePluginModal(true)}
updateInvestList={this.updateInvestList}
closeInvestModel={this.closeInvestModel}
openJobs={openJobs}
/>
)}
{showChangelog && (
<Changelog
show={showChangelog}
close={() => this.closeChangelogModal()}
/>
}
)}
{showMetadataModal && (
<MetadataModal
show={showMetadataModal}
close={() => this.toggleMetadataModal(false)}
/>
)}
{showSettingsModal && (
<SettingsModal
show={showSettingsModal}
close={() => this.toggleSettingsModal(false)}
nCPU={this.props.nCPU}
/>
)}
<TabContainer activeKey={activeTab}>
<Navbar
onDragOver={dragOverHandlerNone}
@ -346,11 +439,12 @@ export default class App extends React.Component {
)
: <div />
}
<SettingsModal
className="mx-3"
clearJobsStorage={this.clearRecentJobs}
showDownloadModal={() => this.showDownloadModal(true)}
nCPU={this.props.nCPU}
<AppMenu
openDownloadModal={() => this.toggleDownloadModal(true)}
openPluginModal={() => this.togglePluginModal(true)}
openChangelogModal={() => this.setState({ showChangelog: true })}
openSettingsModal={() => this.toggleSettingsModal(true)}
openMetadataModal={() => this.toggleMetadataModal(true)}
/>
</Col>
</Row>
@ -371,9 +465,10 @@ export default class App extends React.Component {
openInvestModel={this.openInvestModel}
recentJobs={recentJobs}
batchUpdateArgs={this.batchUpdateArgs}
deleteJob={this.deleteJob}
clearRecentJobs={this.clearRecentJobs}
/>
)
: <div />}
) : <div />}
</TabPane>
{investTabPanes}
</TabContent>

View File

@ -0,0 +1,56 @@
import React from 'react';
import Dropdown from 'react-bootstrap/Dropdown';
import { useTranslation } from 'react-i18next';
import { GiHamburgerMenu } from 'react-icons/gi';
export default function AppMenu(props) {
const { t } = useTranslation();
return (
<Dropdown>
<Dropdown.Toggle
className="app-menu-button"
aria-label="menu"
childBsPrefix="outline-secondary"
>
<GiHamburgerMenu />
</Dropdown.Toggle>
<Dropdown.Menu
align="right"
className="shadow"
>
<Dropdown.Item
as="button"
onClick={props.openPluginModal}
>
Manage Plugins
</Dropdown.Item>
<Dropdown.Item
as="button"
onClick={props.openDownloadModal}
>
Download Sample Data
</Dropdown.Item>
<Dropdown.Item
as="button"
onClick={props.openMetadataModal}
>
Configure Metadata
</Dropdown.Item>
<Dropdown.Item
as="button"
onClick={props.openChangelogModal}
>
View Changelog
</Dropdown.Item>
<Dropdown.Item
as="button"
onClick={props.openSettingsModal}
>
Settings
</Dropdown.Item>
</Dropdown.Menu>
</Dropdown>
);
}

View File

@ -88,8 +88,7 @@ export default function Changelog(props) {
and not, for example, sourced from user input. */}
<Modal.Body
dangerouslySetInnerHTML={htmlContent}
>
</Modal.Body>
/>
</Modal>
);
}

View File

@ -8,6 +8,7 @@ import Alert from 'react-bootstrap/Alert';
import Table from 'react-bootstrap/Table';
import {
MdErrorOutline,
MdClose,
} from 'react-icons/md';
import { withTranslation } from 'react-i18next';
@ -19,7 +20,7 @@ const { logger } = window.Workbench;
// A URL for sampledata to use in devMode, when the token containing the URL
// associated with a production build of the Workbench does not exist.
const BASE_URL = 'https://storage.googleapis.com/releases.naturalcapitalproject.org/invest/3.13.0/data';
const BASE_URL = 'https://storage.googleapis.com/releases.naturalcapitalproject.org/invest/3.15.1/data';
const DEFAULT_FILESIZE = 0;
/** Render a dialog with a form for configuring global invest settings */
@ -65,13 +66,13 @@ class DataDownloadModal extends React.Component {
const linksArray = [];
const modelCheckBoxState = {};
Object.entries(registry).forEach(([modelName, data]) => {
Object.entries(registry).forEach(([modelTitle, data]) => {
linksArray.push(`${baseURL}/${data.filename}`);
modelCheckBoxState[modelName] = true;
modelCheckBoxState[modelTitle] = true;
try {
registry[modelName].filesize = filesizes[data.filename];
registry[modelTitle].filesize = filesizes[data.filename];
} catch {
registry[modelName].filesize = DEFAULT_FILESIZE;
registry[modelTitle].filesize = DEFAULT_FILESIZE;
}
});
@ -137,20 +138,20 @@ class DataDownloadModal extends React.Component {
});
}
handleCheckList(event, modelName) {
handleCheckList(event, modelTitle) {
let {
selectedLinksArray,
modelCheckBoxState,
dataRegistry,
baseURL,
} = this.state;
const url = `${baseURL}/${dataRegistry[modelName].filename}`;
const url = `${baseURL}/${dataRegistry[modelTitle].filename}`;
if (event.currentTarget.checked) {
selectedLinksArray.push(url);
modelCheckBoxState[modelName] = true;
modelCheckBoxState[modelTitle] = true;
} else {
selectedLinksArray = selectedLinksArray.filter((val) => val !== url);
modelCheckBoxState[modelName] = false;
modelCheckBoxState[modelTitle] = false;
}
this.setState({
allDataCheck: false,
@ -209,26 +210,26 @@ class DataDownloadModal extends React.Component {
const downloadEnabled = Boolean(selectedLinksArray.length);
const DatasetCheckboxRows = [];
Object.keys(modelCheckBoxState)
.forEach((modelName) => {
const filesize = parseFloat(dataRegistry[modelName].filesize);
.forEach((modelTitle) => {
const filesize = parseFloat(dataRegistry[modelTitle].filesize);
const filesizeStr = `${(filesize / 1000000).toFixed(2)} MB`;
const note = dataRegistry[modelName].note || '';
const note = dataRegistry[modelTitle].note || '';
DatasetCheckboxRows.push(
<tr key={modelName}>
<tr key={modelTitle}>
<td>
<Form.Check
className="pt-1"
id={modelName}
id={modelTitle}
>
<Form.Check.Input
type="checkbox"
checked={modelCheckBoxState[modelName]}
checked={modelCheckBoxState[modelTitle]}
onChange={(event) => this.handleCheckList(
event, modelName
event, modelTitle
)}
/>
<Form.Check.Label>
{displayNames[modelName]}
{displayNames[modelTitle]}
</Form.Check.Label>
</Form.Check>
</td>
@ -263,10 +264,20 @@ class DataDownloadModal extends React.Component {
<p className="mb-0"><em>{this.state.alertPath}</em></p>
</Alert>
)
: <Modal.Title id="download-modal-title">
: (
<Modal.Title id="download-modal-title">
{t("Download InVEST sample data")}
</Modal.Title>
)
}
<Button
variant="secondary-outline"
onClick={this.closeDialog}
className="float-right"
aria-label="Close modal"
>
<MdClose />
</Button>
</Modal.Header>
<Modal.Body>
<Table
@ -294,12 +305,6 @@ class DataDownloadModal extends React.Component {
</Table>
</Modal.Body>
<Modal.Footer>
<Button
variant="secondary"
onClick={this.closeDialog}
>
{t("Cancel")}
</Button>
<Button
variant="primary"
onClick={this.handleSubmit}

View File

@ -1,12 +1,17 @@
import React from 'react';
import PropTypes from 'prop-types';
import Badge from 'react-bootstrap/Badge';
import ListGroup from 'react-bootstrap/ListGroup';
import Card from 'react-bootstrap/Card';
import Container from 'react-bootstrap/Container';
import Col from 'react-bootstrap/Col';
import Row from 'react-bootstrap/Row';
import Button from 'react-bootstrap/Button';
import { useTranslation } from 'react-i18next';
import {
MdClose,
} from 'react-icons/md';
import OpenButton from '../OpenButton';
import InvestJob from '../../InvestJob';
@ -20,53 +25,69 @@ const { logger } = window.Workbench;
export default class HomeTab extends React.Component {
constructor(props) {
super(props);
this.state = {
sortedModels: []
};
this.handleClick = this.handleClick.bind(this);
}
componentDidMount() {
// sort the model list alphabetically, by the model title,
// and with special placement of CBC Preprocessor before CBC model.
const sortedModels = Object.keys(this.props.investList).sort();
const cbcpElement = 'Coastal Blue Carbon Preprocessor';
const cbcIdx = sortedModels.indexOf('Coastal Blue Carbon');
const cbcpIdx = sortedModels.indexOf(cbcpElement);
if (cbcIdx > -1 && cbcpIdx > -1) {
sortedModels.splice(cbcpIdx, 1); // remove it
sortedModels.splice(cbcIdx, 0, cbcpElement); // insert it
}
this.setState({
sortedModels: sortedModels
});
}
handleClick(value) {
const { investList, openInvestModel } = this.props;
const modelRunName = investList[value].model_name;
const job = new InvestJob({
modelRunName: modelRunName,
modelHumanName: value
modelID: value,
modelTitle: investList[value].modelTitle,
type: investList[value].type,
});
openInvestModel(job);
}
render() {
const { recentJobs } = this.props;
const { sortedModels } = this.state;
const {
recentJobs,
investList,
openInvestModel,
deleteJob,
clearRecentJobs
} = this.props;
let sortedModelIds = {};
if (investList) {
// sort the model list alphabetically, by the model title,
// and with special placement of CBC Preprocessor before CBC model.
sortedModelIds = Object.keys(investList).sort(
(a, b) => {
if (investList[a].modelTitle > investList[b].modelTitle) {
return 1;
}
if (investList[b].modelTitle > investList[a].modelTitle) {
return -1;
}
return 0;
}
);
const cbcpElement = 'coastal_blue_carbon_preprocessor';
const cbcIdx = sortedModelIds.indexOf('coastal_blue_carbon');
const cbcpIdx = sortedModelIds.indexOf(cbcpElement);
if (cbcIdx > -1 && cbcpIdx > -1) {
sortedModelIds.splice(cbcpIdx, 1); // remove it
sortedModelIds.splice(cbcIdx, 0, cbcpElement); // insert it
}
}
// A button in a table row for each model
const investButtons = [];
sortedModels.forEach((model) => {
sortedModelIds.forEach((modelID) => {
const modelTitle = investList[modelID].modelTitle;
let badge;
if (investList[modelID].type === 'plugin') {
badge = <Badge className="mr-1" variant="secondary">Plugin</Badge>;
}
investButtons.push(
<ListGroup.Item
key={model}
className="invest-button"
title={model}
key={modelTitle}
name={modelTitle}
action
onClick={() => this.handleClick(model)}
onClick={() => this.handleClick(modelID)}
className="invest-button"
>
{model}
{ badge }
<span>{modelTitle}</span>
</ListGroup.Item>
);
});
@ -76,12 +97,25 @@ export default class HomeTab extends React.Component {
<Col md={6} className="invest-list-container">
<ListGroup className="invest-list-group">
{investButtons}
<ListGroup.Item
key="browse"
className="py-2 border-0"
>
<OpenButton
className="w-100 border-1 py-2 pl-3 text-left text-truncate"
openInvestModel={openInvestModel}
investList={investList}
/>
</ListGroup.Item>
</ListGroup>
</Col>
<Col className="recent-job-card-col">
<RecentInvestJobs
openInvestModel={this.props.openInvestModel}
openInvestModel={openInvestModel}
recentJobs={recentJobs}
investList={investList}
deleteJob={deleteJob}
clearRecentJobs={clearRecentJobs}
/>
</Col>
</Row>
@ -92,27 +126,35 @@ export default class HomeTab extends React.Component {
HomeTab.propTypes = {
investList: PropTypes.objectOf(
PropTypes.shape({
model_name: PropTypes.string,
}),
modelTitle: PropTypes.string,
type: PropTypes.string,
})
).isRequired,
openInvestModel: PropTypes.func.isRequired,
recentJobs: PropTypes.arrayOf(
PropTypes.shape({
modelRunName: PropTypes.string.isRequired,
modelHumanName: PropTypes.string.isRequired,
modelID: PropTypes.string.isRequired,
modelTitle: PropTypes.string.isRequired,
argsValues: PropTypes.object,
logfile: PropTypes.string,
status: PropTypes.string,
})
).isRequired,
deleteJob: PropTypes.func.isRequired,
clearRecentJobs: PropTypes.func.isRequired,
};
/**
* Renders a button for each recent invest job.
*/
function RecentInvestJobs(props) {
const { recentJobs, openInvestModel } = props;
const { t, i18n } = useTranslation();
const {
recentJobs,
openInvestModel,
deleteJob,
clearRecentJobs
} = props;
const { t } = useTranslation();
const handleClick = (jobMetadata) => {
try {
@ -124,18 +166,33 @@ function RecentInvestJobs(props) {
const recentButtons = [];
recentJobs.forEach((job) => {
if (job && job.argsValues && job.modelHumanName) {
if (job && job.argsValues && job.modelTitle) {
let badge;
if (job.type === 'plugin') {
badge = <Badge className="mr-1" variant="secondary">Plugin</Badge>;
}
recentButtons.push(
<Card
className="text-left recent-job-card"
as="button"
className="text-left recent-job-card mr-2 w-100"
key={job.hash}
onClick={() => handleClick(job)}
>
<Card.Body>
<Card.Header>
<span className="header-title">{job.modelHumanName}</span>
</Card.Header>
<Card.Header>
{badge}
<span className="header-title">{job.modelTitle}</span>
<Button
variant="outline-light"
onClick={() => deleteJob(job.hash)}
className="float-right p-1 mr-1 border-0"
aria-label="delete"
>
<MdClose />
</Button>
</Card.Header>
<Card.Body
className="text-left border-0"
as="button"
onClick={() => handleClick(job)}
>
<Card.Title>
<span className="text-heading">{'Workspace: '}</span>
<span className="text-mono">{job.argsValues.workspace_dir}</span>
@ -160,47 +217,57 @@ function RecentInvestJobs(props) {
});
return (
<>
<Container>
<Row>
<Col className="recent-header-col">
{recentButtons.length
? (
<h4>
{t('Recent runs:')}
</h4>
)
: (
<div className="default-text">
{t("Set up a model from a sample datastack file (.json) " +
"or from an InVEST model's logfile (.txt): ")}
</div>
)}
</Col>
<Col className="open-button-col">
<OpenButton
className="mr-2"
openInvestModel={openInvestModel}
/>
</Col>
</Row>
</Container>
<React.Fragment>
<Container>
<Row>
{recentButtons.length
? <div />
: (
<Card
className="text-left recent-job-card mr-2 w-100"
key="placeholder"
>
<Card.Header>
<span className="header-title">{t('Welcome!')}</span>
</Card.Header>
<Card.Body>
<Card.Title>
<span className="text-heading">
{t('After running a model, find your recent model runs here.')}
</span>
</Card.Title>
</Card.Body>
</Card>
)}
{recentButtons}
</React.Fragment>
</>
</Row>
{recentButtons.length
? (
<Row>
<Button
variant="secondary"
onClick={clearRecentJobs}
className="mr-2 w-100"
>
{t('Clear all model runs')}
</Button>
</Row>
)
: <div />}
</Container>
);
}
RecentInvestJobs.propTypes = {
recentJobs: PropTypes.arrayOf(
PropTypes.shape({
modelRunName: PropTypes.string.isRequired,
modelHumanName: PropTypes.string.isRequired,
modelID: PropTypes.string.isRequired,
modelTitle: PropTypes.string.isRequired,
argsValues: PropTypes.object,
logfile: PropTypes.string,
status: PropTypes.string,
})
).isRequired,
openInvestModel: PropTypes.func.isRequired,
deleteJob: PropTypes.func.isRequired,
clearRecentJobs: PropTypes.func.isRequired,
};

View File

@ -1,6 +1,7 @@
import React from 'react';
import PropTypes from 'prop-types';
import Spinner from 'react-bootstrap/Spinner';
import TabPane from 'react-bootstrap/TabPane';
import TabContent from 'react-bootstrap/TabContent';
import TabContainer from 'react-bootstrap/TabContainer';
@ -10,7 +11,7 @@ import Col from 'react-bootstrap/Col';
import Modal from 'react-bootstrap/Modal';
import Button from 'react-bootstrap/Button';
import {
MdKeyboardArrowRight,
MdKeyboardArrowRight
} from 'react-icons/md';
import { withTranslation } from 'react-i18next';
@ -19,34 +20,11 @@ import SetupTab from '../SetupTab';
import LogTab from '../LogTab';
import ResourcesLinks from '../ResourcesLinks';
import { getSpec } from '../../server_requests';
import { UI_SPEC } from '../../ui_config';
import { ipcMainChannels } from '../../../main/ipcMainChannels';
const { ipcRenderer } = window.Workbench.electron;
const { logger } = window.Workbench;
/** Get an invest model's MODEL_SPEC when a model button is clicked.
*
* @param {string} modelName - as in a model name appearing in `invest list`
* @returns {object} destructures to:
* { modelSpec, argsSpec, uiSpec }
*/
async function investGetSpec(modelName) {
const spec = await getSpec(modelName);
if (spec) {
const { args, ...modelSpec } = spec;
const uiSpec = UI_SPEC[modelName];
if (uiSpec) {
return { modelSpec: modelSpec, argsSpec: args, uiSpec: uiSpec };
}
logger.error(`no UI spec found for ${modelName}`);
} else {
logger.error(`no args spec found for ${modelName}`);
}
return undefined;
}
/**
* Render an invest model setup form, log display, etc.
* Manage launching of an invest model in a child process.
@ -59,9 +37,9 @@ class InvestTab extends React.Component {
activeTab: 'setup',
modelSpec: null, // MODEL_SPEC dict with all keys except MODEL_SPEC.args
argsSpec: null, // MODEL_SPEC.args, the immutable args stuff
uiSpec: null,
userTerminated: false,
executeClicked: false,
tabStatus: '',
showErrorModal: false,
};
@ -76,14 +54,35 @@ class InvestTab extends React.Component {
async componentDidMount() {
const { job } = this.props;
const {
modelSpec, argsSpec, uiSpec,
} = await investGetSpec(job.modelRunName);
this.setState({
modelSpec: modelSpec,
argsSpec: argsSpec,
uiSpec: uiSpec,
}, () => { this.switchTabs('setup'); });
// if it's a plugin, may need to start up the server
// otherwise, the core invest server should already be running
if (job.type === 'plugin') {
// if plugin server is already running, don't re-launch
// this will happen if we have >1 tab open with the same plugin
let pid = await ipcRenderer.invoke(
ipcMainChannels.GET_SETTING, `plugins.${job.modelID}.pid`);
if (!pid) {
pid = await ipcRenderer.invoke(
ipcMainChannels.LAUNCH_PLUGIN_SERVER,
job.modelID
);
if (!pid) {
this.setState({ tabStatus: 'failed' });
return;
}
}
}
try {
const { args, ...model_spec } = await getSpec(job.modelID);
this.setState({
modelSpec: model_spec,
argsSpec: args,
}, () => { this.switchTabs('setup'); });
} catch (error) {
console.log(error);
this.setState({ tabStatus: 'failed' });
return;
}
const { tabID } = this.props;
ipcRenderer.on(`invest-logging-${tabID}`, this.investLogfileCallback);
ipcRenderer.on(`invest-exit-${tabID}`, this.investExitCallback);
@ -160,8 +159,7 @@ class InvestTab extends React.Component {
ipcRenderer.send(
ipcMainChannels.INVEST_RUN,
job.modelRunName,
this.state.modelSpec.pyname,
job.modelID,
args,
tabID
);
@ -209,27 +207,44 @@ class InvestTab extends React.Component {
activeTab,
modelSpec,
argsSpec,
uiSpec,
executeClicked,
tabStatus,
showErrorModal,
} = this.state;
const {
status,
modelRunName,
modelID,
argsValues,
logfile,
} = this.props.job;
const { tabID, t } = this.props;
const { tabID, investList, t } = this.props;
if (tabStatus === 'failed') {
return (
<div className="invest-tab-loading">
{t('Failed to launch plugin')}
</div>
);
}
// Don't render the model setup & log until data has been fetched.
if (!modelSpec) {
return (<div />);
return (
<div className="invest-tab-loading">
<Spinner animation="border" role="status">
<span className="sr-only">Loading...</span>
</Spinner>
<br />
{t('Starting up model...')}
</div>
);
}
const logDisabled = !logfile;
const sidebarSetupElementId = `sidebar-setup-${tabID}`;
const sidebarFooterElementId = `sidebar-footer-${tabID}`;
const isCoreModel = investList[modelID].type === 'core';
return (
<>
@ -260,7 +275,8 @@ class InvestTab extends React.Component {
/>
<div className="sidebar-row sidebar-links">
<ResourcesLinks
moduleName={modelRunName}
modelID={modelID}
isCoreModel={isCoreModel}
docs={modelSpec.userguide}
/>
</div>
@ -288,17 +304,18 @@ class InvestTab extends React.Component {
aria-label="model setup tab"
>
<SetupTab
pyModuleName={modelSpec.pyname}
userguide={modelSpec.userguide}
modelName={modelRunName}
isCoreModel={isCoreModel}
modelID={modelID}
argsSpec={argsSpec}
uiSpec={uiSpec}
inputFieldOrder={modelSpec.input_field_order}
argsInitValues={argsValues}
investExecute={this.investExecute}
sidebarSetupElementId={sidebarSetupElementId}
sidebarFooterElementId={sidebarFooterElementId}
executeClicked={executeClicked}
switchTabs={this.switchTabs}
investList={investList}
tabID={tabID}
updateJobProperties={this.props.updateJobProperties}
/>
@ -333,15 +350,19 @@ class InvestTab extends React.Component {
InvestTab.propTypes = {
job: PropTypes.shape({
modelRunName: PropTypes.string.isRequired,
modelHumanName: PropTypes.string.isRequired,
modelID: PropTypes.string.isRequired,
argsValues: PropTypes.object,
logfile: PropTypes.string,
status: PropTypes.string,
type: PropTypes.string,
}).isRequired,
tabID: PropTypes.string.isRequired,
saveJob: PropTypes.func.isRequired,
updateJobProperties: PropTypes.func.isRequired,
investList: PropTypes.shape({
modelTitle: PropTypes.string,
}).isRequired,
t: PropTypes.func.isRequired,
};
export default withTranslation()(InvestTab);

View File

@ -6,13 +6,15 @@ import Col from 'react-bootstrap/Col';
import Alert from 'react-bootstrap/Alert';
import Button from 'react-bootstrap/Button';
import Form from 'react-bootstrap/Form';
import Modal from 'react-bootstrap/Modal';
import { MdClose } from 'react-icons/md';
import {
getGeoMetaMakerProfile,
setGeoMetaMakerProfile,
} from '../../../server_requests';
} from '../../server_requests';
import { openLinkInBrowser } from '../../../utils';
import { openLinkInBrowser } from '../../utils';
function AboutMetadataDiv() {
const { t } = useTranslation();
@ -67,9 +69,9 @@ function FormRow(label, value, handler) {
}
/**
* A form for submitting GeoMetaMaker profile data.
* A Modal with form for submitting GeoMetaMaker profile data.
*/
export default function MetadataForm() {
export default function MetadataModal(props) {
const { t } = useTranslation();
const [contactName, setContactName] = useState('');
@ -80,7 +82,6 @@ export default function MetadataForm() {
const [licenseURL, setLicenseURL] = useState('');
const [alertMsg, setAlertMsg] = useState('');
const [alertError, setAlertError] = useState(false);
const [showInfo, setShowInfo] = useState(false);
useEffect(() => {
async function loadProfile() {
@ -127,65 +128,74 @@ export default function MetadataForm() {
};
return (
<div id="metadata-form">
{
(showInfo)
? <AboutMetadataDiv />
: (
<Form onSubmit={handleSubmit} onChange={handleChange}>
<fieldset>
<legend>{t('Contact Information')}</legend>
<Form.Group controlId="name">
{FormRow(t('Full name'), contactName, setContactName)}
</Form.Group>
<Form.Group controlId="email">
{FormRow(t('Email address'), contactEmail, setContactEmail)}
</Form.Group>
<Form.Group controlId="job-title">
{FormRow(t('Job title'), contactPosition, setContactPosition)}
</Form.Group>
<Form.Group controlId="organization">
{FormRow(t('Organization name'), contactOrg, setContactOrg)}
</Form.Group>
</fieldset>
<fieldset>
<legend>{t('Data License Information')}</legend>
<Form.Group controlId="license-title">
{FormRow(t('Title'), licenseTitle, setLicenseTitle)}
</Form.Group>
<Form.Group controlId="license-url">
{FormRow('URL', licenseURL, setLicenseURL)}
</Form.Group>
</fieldset>
<Form.Row>
<Button
type="submit"
variant="primary"
className="my-1 py2 mx-2"
<Modal
show={props.show}
onHide={props.close}
size="lg"
aria-labelledby="metadata-modal-title"
>
<Modal.Header>
<Modal.Title id="metadata-modal-title">
{t('Configure Metadata')}
</Modal.Title>
<Button
variant="secondary-outline"
onClick={props.close}
className="float-right"
aria-label="Close modal"
>
<MdClose />
</Button>
</Modal.Header>
<Modal.Body>
<AboutMetadataDiv />
<hr />
<Form onSubmit={handleSubmit} onChange={handleChange}>
<fieldset>
<legend>{t('Contact Information')}</legend>
<Form.Group controlId="name">
{FormRow(t('Full name'), contactName, setContactName)}
</Form.Group>
<Form.Group controlId="email">
{FormRow(t('Email address'), contactEmail, setContactEmail)}
</Form.Group>
<Form.Group controlId="job-title">
{FormRow(t('Job title'), contactPosition, setContactPosition)}
</Form.Group>
<Form.Group controlId="organization">
{FormRow(t('Organization name'), contactOrg, setContactOrg)}
</Form.Group>
</fieldset>
<fieldset>
<legend>{t('Data License Information')}</legend>
<Form.Group controlId="license-title">
{FormRow(t('Title'), licenseTitle, setLicenseTitle)}
</Form.Group>
<Form.Group controlId="license-url">
{FormRow('URL', licenseURL, setLicenseURL)}
</Form.Group>
</fieldset>
<Form.Row>
<Button
type="submit"
variant="primary"
className="my-1 py2 mx-2"
>
{t('Save Metadata')}
</Button>
{
(alertMsg) && (
<Alert
className="my-1 py-2"
variant={alertError ? 'danger' : 'success'}
>
{t('Save Metadata')}
</Button>
{
(alertMsg) && (
<Alert
className="my-1 py-2"
variant={alertError ? 'danger' : 'success'}
>
{alertMsg}
</Alert>
)
}
</Form.Row>
</Form>
)
}
<Button
variant="outline-secondary"
className="my-1 py-2 mx-2 info-toggle"
onClick={() => setShowInfo((prevState) => !prevState)}
>
{showInfo ? t('Hide Info') : t('More Info')}
</Button>
</div>
{alertMsg}
</Alert>
)
}
</Form.Row>
</Form>
</Modal.Body>
</Modal>
);
}
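
MetadataModal now receives its visibility via props instead of managing an internal showInfo toggle, so a parent component is expected to control it. A minimal sketch under that assumption (the parent state names are hypothetical):

const [showMetadata, setShowMetadata] = useState(false);
// ...
<MetadataModal
  show={showMetadata}
  close={() => setShowMetadata(false)}
/>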

View File

@ -24,7 +24,7 @@ class OpenButton extends React.Component {
}
async browseFile() {
const { t } = this.props;
const { t, investList, openInvestModel } = this.props;
const data = await ipcRenderer.invoke(ipcMainChannels.SHOW_OPEN_DIALOG);
if (!data.canceled) {
let datastack;
@ -38,29 +38,32 @@ class OpenButton extends React.Component {
return;
}
const job = new InvestJob({
modelRunName: datastack.model_run_name,
modelHumanName: datastack.model_human_name,
modelID: datastack.model_id,
modelTitle: investList[datastack.model_id].modelTitle,
argsValues: datastack.args,
type: investList[datastack.model_id].type,
});
this.props.openInvestModel(job);
openInvestModel(job);
}
}
render() {
const { t } = this.props;
const tipText = t('Browse to a datastack (.json) or InVEST logfile (.txt)');
const { t, className } = this.props;
const tipText = t(
`Open an InVEST model by loading input parameters from
a .json, .tgz, or InVEST logfile (.txt)`);
return (
<OverlayTrigger
placement="left"
placement="right"
delay={{ show: 250, hide: 400 }}
overlay={<Tooltip>{tipText}</Tooltip>}
>
<Button
className={this.props.className}
className={className}
onClick={this.browseFile}
variant="outline-dark"
variant="outline-primary"
>
{t("Open")}
{t('Browse to a datastack or InVEST logfile')}
</Button>
</OverlayTrigger>
);
@ -69,6 +72,12 @@ class OpenButton extends React.Component {
OpenButton.propTypes = {
openInvestModel: PropTypes.func.isRequired,
investList: PropTypes.shape({
modelTitle: PropTypes.string,
type: PropTypes.string,
}).isRequired,
t: PropTypes.func.isRequired,
className: PropTypes.string.isRequired,
};
export default withTranslation()(OpenButton);
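
OpenButton now looks up the model title and type from investList, keyed by the datastack's model_id. A sketch of the shape this assumes (the keys, titles, and type values are illustrative):

const investList = {
  carbon: { modelTitle: 'Carbon Storage and Sequestration', type: 'core' },
  my_plugin: { modelTitle: 'My Plugin Model', type: 'plugin' },
};
// investList[datastack.model_id].modelTitle -> 'Carbon Storage and Sequestration'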

View File

@ -0,0 +1,324 @@
import React, { useEffect, useState } from 'react';
import PropTypes from 'prop-types';
import Button from 'react-bootstrap/Button';
import Col from 'react-bootstrap/Col';
import Form from 'react-bootstrap/Form';
import Modal from 'react-bootstrap/Modal';
import Spinner from 'react-bootstrap/Spinner';
import { useTranslation } from 'react-i18next';
import { MdClose } from 'react-icons/md';
import { ipcMainChannels } from '../../../main/ipcMainChannels';
const { ipcRenderer } = window.Workbench.electron;
export default function PluginModal(props) {
const {
updateInvestList,
closeInvestModel,
openJobs,
show,
closeModal,
openModal,
} = props;
const [url, setURL] = useState('');
const [revision, setRevision] = useState('');
const [path, setPath] = useState('');
const [installErr, setInstallErr] = useState('');
const [uninstallErr, setUninstallErr] = useState('');
const [pluginToRemove, setPluginToRemove] = useState('');
const [installLoading, setInstallLoading] = useState(false);
const [uninstallLoading, setUninstallLoading] = useState(false);
const [statusMessage, setStatusMessage] = useState('');
const [needsMSVC, setNeedsMSVC] = useState(false);
const [plugins, setPlugins] = useState({});
const [installFrom, setInstallFrom] = useState('url');
const handleModalClose = () => {
setURL('');
setRevision('');
setInstallErr('');
setUninstallErr('');
closeModal();
};
const addPlugin = () => {
setInstallLoading(true);
    ipcRenderer.on('plugin-install-status', (msg) => { setStatusMessage(msg); });
ipcRenderer.invoke(
ipcMainChannels.ADD_PLUGIN,
installFrom === 'url' ? url : undefined, // url
installFrom === 'url' ? revision : undefined, // revision
installFrom === 'path' ? path : undefined // path
).then((addPluginErr) => {
setInstallLoading(false);
updateInvestList();
if (addPluginErr) {
setInstallErr(addPluginErr);
} else {
// clear the input fields
setURL('');
setRevision('');
setPath('');
}
});
};
const removePlugin = () => {
setUninstallLoading(true);
Object.keys(openJobs).forEach((tabID) => {
if (openJobs[tabID].modelID === pluginToRemove) {
closeInvestModel(tabID);
}
});
    ipcRenderer.invoke(ipcMainChannels.REMOVE_PLUGIN, pluginToRemove).then((err) => {
      if (err) {
        setUninstallErr(err);
      } else {
        updateInvestList();
      }
      // Clear the loading state in both branches so the Remove button
      // re-enables even when uninstalling fails.
      setUninstallLoading(false);
    });
};
  const downloadMSVC = () => {
    closeModal();
    // Pass a callback so the modal reopens after the download completes,
    // rather than invoking openModal immediately.
    ipcRenderer.invoke(ipcMainChannels.DOWNLOAD_MSVC).then(
      () => openModal()
    );
  };
useEffect(() => {
if (show) {
if (window.Workbench.OS === 'win32') {
ipcRenderer.invoke(ipcMainChannels.HAS_MSVC).then((hasMSVC) => {
setNeedsMSVC(!hasMSVC);
});
}
}
}, [show]);
useEffect(() => {
ipcRenderer.invoke(ipcMainChannels.GET_SETTING, 'plugins').then(
(data) => {
if (data) {
setPlugins(data);
setPluginToRemove(Object.keys(data)[0]);
}
}
);
}, [installLoading, uninstallLoading]);
const { t } = useTranslation();
let pluginFields;
if (installFrom === 'url') {
pluginFields = (
<Form.Row>
<Form.Group as={Col} xs={7}>
<Form.Label htmlFor="url">{t('Git URL')}</Form.Label>
<Form.Control
id="url"
type="text"
placeholder="https://github.com/owner/repo.git"
value={url}
onChange={(event) => setURL(event.currentTarget.value)}
/>
<Form.Text className="text-muted">
<i>{t('Default branch used unless otherwise specified')}</i>
</Form.Text>
</Form.Group>
<Form.Group as={Col}>
<Form.Label htmlFor="branch">{t('Branch, tag, or commit')}</Form.Label>
<Form.Control
id="branch"
type="text"
value={revision}
onChange={(event) => setRevision(event.currentTarget.value)}
/>
<Form.Text className="text-muted">
<i>{t('Optional')}</i>
</Form.Text>
</Form.Group>
</Form.Row>
);
} else {
pluginFields = (
<Form.Group>
<Form.Label htmlFor="path">{t('Local absolute path')}</Form.Label>
<Form.Control
id="path"
type="text"
placeholder={window.Workbench.OS === 'darwin'
? '/Users/username/path/to/plugin/'
: 'C:\\Documents\\path\\to\\plugin\\'}
value={path}
onChange={(event) => setPath(event.currentTarget.value)}
/>
</Form.Group>
);
}
let modalBody = (
<Modal.Body>
<Form>
<Form.Group>
<h5 className="mb-3">{t('Add a plugin')}</h5>
<Form.Group>
<Form.Label htmlFor="installFrom">{t('Install from')}</Form.Label>
<Form.Control
id="installFrom"
as="select"
onChange={(event) => setInstallFrom(event.target.value)}
className="w-auto"
>
<option value="url">{t('git URL')}</option>
<option value="path">{t('local path')}</option>
</Form.Control>
</Form.Group>
{pluginFields}
<Button
disabled={installLoading}
onClick={addPlugin}
>
{
installLoading ? (
<div className="adding-button">
<Spinner animation="border" role="status" size="sm" className="plugin-spinner">
<span className="sr-only">{t('Adding...')}</span>
</Spinner>
{t(statusMessage)}
</div>
) : t('Add')
}
</Button>
<Form.Text className="text-muted">
{t('This may take several minutes')}
</Form.Text>
</Form.Group>
<hr />
<Form.Group>
<h5 className="mb-3">{t('Remove a plugin')}</h5>
<Form.Label htmlFor="selectPluginToRemove">{t('Plugin name')}</Form.Label>
<Form.Control
id="selectPluginToRemove"
as="select"
value={pluginToRemove}
onChange={(event) => setPluginToRemove(event.currentTarget.value)}
>
{
Object.keys(plugins).map(
(pluginID) => (
<option
value={pluginID}
key={pluginID}
>
{plugins[pluginID].modelTitle}
</option>
)
)
}
</Form.Control>
<Button
disabled={uninstallLoading || !Object.keys(plugins).length}
className="mt-3"
onClick={removePlugin}
>
{
uninstallLoading ? (
<div className="adding-button">
<Spinner animation="border" role="status" size="sm" className="plugin-spinner">
<span className="sr-only">{t('Removing...')}</span>
</Spinner>
{t('Removing...')}
</div>
) : t('Remove')
}
</Button>
</Form.Group>
</Form>
</Modal.Body>
);
if (installErr) {
modalBody = (
<Modal.Body>
<h5>{t('Error installing plugin:')}</h5>
<div className="plugin-error">{installErr}</div>
<Button
onClick={() => ipcRenderer.send(
ipcMainChannels.SHOW_ITEM_IN_FOLDER,
window.Workbench.ELECTRON_LOG_PATH,
)}
>
{t('Find workbench logs')}
</Button>
</Modal.Body>
);
} else if (uninstallErr) {
modalBody = (
<Modal.Body>
<h5>{t('Error removing plugin:')}</h5>
<div className="plugin-error">{uninstallErr}</div>
<Button
onClick={() => ipcRenderer.send(
ipcMainChannels.SHOW_ITEM_IN_FOLDER,
window.Workbench.ELECTRON_LOG_PATH,
)}
>
{t('Find workbench logs')}
</Button>
</Modal.Body>
);
}
if (needsMSVC) {
modalBody = (
<Modal.Body>
<h5>
{t('Microsoft Visual C++ Redistributable must be installed!')}
</h5>
{t('Plugin features require the ')}
<a href="https://learn.microsoft.com/en-us/cpp/windows/latest-supported-vc-redist">
{t('Microsoft Visual C++ Redistributable')}
</a>
{t('. You must download and install the redistributable before continuing.')}
<Button
className="mt-3"
onClick={downloadMSVC}
>
{t('Continue to download and install')}
</Button>
</Modal.Body>
);
}
return (
<Modal show={show} onHide={handleModalClose} contentClassName="plugin-modal">
<Modal.Header>
<Modal.Title>{t('Manage plugins')}</Modal.Title>
<Button
variant="secondary-outline"
onClick={handleModalClose}
className="float-right"
aria-label="Close modal"
>
<MdClose />
</Button>
</Modal.Header>
{modalBody}
</Modal>
);
}
PluginModal.propTypes = {
show: PropTypes.bool.isRequired,
closeModal: PropTypes.func.isRequired,
openModal: PropTypes.func.isRequired,
updateInvestList: PropTypes.func.isRequired,
closeInvestModel: PropTypes.func.isRequired,
openJobs: PropTypes.shape({
modelID: PropTypes.string,
}).isRequired,
};
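
A sketch of how PluginModal might be mounted by the app shell, and of the two install paths it sends over ipcMainChannels.ADD_PLUGIN (positional url/revision/path arguments, with undefined for whichever source is unused). The parent state and handler names here are assumptions:

<PluginModal
  show={showPluginModal}
  openModal={() => setShowPluginModal(true)}
  closeModal={() => setShowPluginModal(false)}
  updateInvestList={updateInvestList}
  closeInvestModel={closeInvestModel}
  openJobs={openJobs}
/>

// Install from a git URL at an optional revision:
//   ipcRenderer.invoke(ipcMainChannels.ADD_PLUGIN, url, revision, undefined)
// Install from a local path:
//   ipcRenderer.invoke(ipcMainChannels.ADD_PLUGIN, undefined, undefined, path)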

View File

@ -40,17 +40,6 @@ const FORUM_TAGS = {
wind_energy: 'wind-energy',
};
/**
* Open the target href in an electron window.
*/
function handleUGClick(event) {
event.preventDefault();
ipcRenderer.send(
ipcMainChannels.OPEN_LOCAL_HTML, event.currentTarget.href
);
}
/** Render model-relevant links to the User's Guide and Forum.
*
* This should be a link to the model's User's Guide chapter and
@ -58,46 +47,75 @@ function handleUGClick(event) {
* e.g. https://community.naturalcapitalproject.org/tag/carbon
*/
export default function ResourcesTab(props) {
const { docs, moduleName } = props;
const { docs, isCoreModel, modelID } = props;
let forumURL = FORUM_ROOT;
const tagName = FORUM_TAGS[moduleName];
const tagName = FORUM_TAGS[modelID];
if (tagName) {
forumURL = `${FORUM_ROOT}/tag/${tagName}`;
}
const { t, i18n } = useTranslation();
const userGuideURL = `${window.Workbench.USERGUIDE_PATH}/${window.Workbench.LANGUAGE}/${docs}`;
const { t } = useTranslation();
const userGuideURL = (
isCoreModel
? `${window.Workbench.USERGUIDE_PATH}/${window.Workbench.LANGUAGE}/${docs}`
: docs
);
const userGuideDisplayText = isCoreModel ? "User's Guide" : "Plugin Documentation";
const userGuideAddlInfo = isCoreModel ? '(opens in new window)' : '(opens in web browser)';
const userGuideAriaLabel = `${userGuideDisplayText} ${userGuideAddlInfo}`;
/**
* Open the target href in an electron window.
*/
const handleUGClick = (event) => {
event.preventDefault();
if (isCoreModel) {
ipcRenderer.send(
ipcMainChannels.OPEN_LOCAL_HTML, event.currentTarget.href
);
} else {
ipcRenderer.send(
ipcMainChannels.OPEN_EXTERNAL_URL, event.currentTarget.href
);
}
  };
return (
<React.Fragment>
<a
href={userGuideURL}
title={userGuideURL}
aria-label="go to user's guide in web browser"
onClick={handleUGClick}
>
<MdOpenInNew className="mr-1" />
{t("User's Guide")}
</a>
{
userGuideURL
&&
<a
href={userGuideURL}
title={userGuideURL}
aria-label={t(userGuideAriaLabel)}
onClick={handleUGClick}
>
<MdOpenInNew className="mr-1" />
{t(userGuideDisplayText)}
</a>
}
<a
href={forumURL}
title={forumURL}
aria-label="go to frequently asked questions in web browser"
aria-label={t('Frequently Asked Questions (opens in web browser)')}
onClick={openLinkInBrowser}
>
<MdOpenInNew className="mr-1" />
{t("Frequently Asked Questions")}
{t('Frequently Asked Questions')}
</a>
</React.Fragment>
);
}
ResourcesTab.propTypes = {
moduleName: PropTypes.string,
modelID: PropTypes.string,
isCoreModel: PropTypes.bool.isRequired,
docs: PropTypes.string,
};
ResourcesTab.defaultProps = {
moduleName: undefined,
modelID: undefined,
docs: '',
};
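
The two branches of userGuideURL above treat docs differently; a worked example with illustrative values (the filenames and URL are assumptions):

// Core model: docs is a User's Guide chapter filename.
//   docs = 'carbon.html'
//   userGuideURL -> `${USERGUIDE_PATH}/${LANGUAGE}/carbon.html`   (opened via OPEN_LOCAL_HTML)
// Plugin: docs is expected to already be a complete URL.
//   docs = 'https://example.com/my-plugin/docs'
//   userGuideURL -> docs                                          (opened via OPEN_EXTERNAL_URL)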

View File

@ -31,16 +31,16 @@ class SaveAsModal extends React.Component {
async browseSaveFile(event) {
const {
modelName,
modelID,
saveJsonFile,
saveDatastack,
savePythonScript
} = this.props;
const { datastackType, relativePaths } = this.state;
const defaultTargetPaths = {
json: `invest_${modelName}_args.json`,
tgz: `invest_${modelName}_datastack.tgz`,
py: `execute_invest_${modelName}.py`,
json: `invest_${modelID}_args.json`,
tgz: `invest_${modelID}_datastack.tgz`,
py: `execute_invest_${modelID}.py`,
};
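    // For example, with modelID 'carbon' (illustrative) these become:
    //   invest_carbon_args.json, invest_carbon_datastack.tgz, execute_invest_carbon.py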
const data = await ipcRenderer.invoke(

View File

@ -8,7 +8,6 @@ import Form from 'react-bootstrap/Form';
import Button from 'react-bootstrap/Button';
import Modal from 'react-bootstrap/Modal';
import {
MdSettings,
MdClose,
MdTranslate,
} from 'react-icons/md';
@ -17,7 +16,6 @@ import { withTranslation } from 'react-i18next';
import { ipcMainChannels } from '../../../main/ipcMainChannels';
import { getSupportedLanguages } from '../../server_requests';
import MetadataForm from './MetadataForm';
const { ipcRenderer } = window.Workbench.electron;
@ -26,21 +24,17 @@ class SettingsModal extends React.Component {
constructor(props) {
super(props);
this.state = {
show: false,
languageOptions: null,
loggingLevel: null,
taskgraphLoggingLevel: null,
nWorkers: null,
loggingLevel: '',
taskgraphLoggingLevel: '',
nWorkers: -1,
language: window.Workbench.LANGUAGE,
showConfirmLanguageChange: false,
};
this.handleShow = this.handleShow.bind(this);
this.handleClose = this.handleClose.bind(this);
this.handleChange = this.handleChange.bind(this);
this.handleChangeNumber = this.handleChangeNumber.bind(this);
this.loadSettings = this.loadSettings.bind(this);
this.handleChangeLanguage = this.handleChangeLanguage.bind(this);
this.switchToDownloadModal = this.switchToDownloadModal.bind(this);
}
async componentDidMount() {
@ -51,16 +45,6 @@ class SettingsModal extends React.Component {
this.loadSettings();
}
handleClose() {
this.setState({
show: false,
});
}
handleShow() {
this.setState({ show: true });
}
handleChange(event) {
const { name, value } = event.currentTarget;
this.setState({ [name]: value });
@ -97,14 +81,8 @@ class SettingsModal extends React.Component {
}
}
switchToDownloadModal() {
this.props.showDownloadModal();
this.handleClose();
}
render() {
const {
show,
languageOptions,
language,
loggingLevel,
@ -112,7 +90,7 @@ class SettingsModal extends React.Component {
nWorkers,
showConfirmLanguageChange,
} = this.state;
const { clearJobsStorage, nCPU, t } = this.props;
const { show, close, nCPU, t } = this.props;
const nWorkersOptions = [
[-1, `${t('Synchronous')} (-1)`],
@ -129,28 +107,18 @@ class SettingsModal extends React.Component {
};
return (
<React.Fragment>
<Button
aria-label="settings"
className="settings-icon-btn"
onClick={this.handleShow}
>
<MdSettings
className="settings-icon"
/>
</Button>
<Modal
className="settings-modal"
show={show}
onHide={this.handleClose}
onHide={close}
>
<Modal.Header>
<Modal.Title>{t('InVEST Settings')}</Modal.Title>
<Button
variant="secondary-outline"
onClick={this.handleClose}
onClick={close}
className="float-right"
aria-label="close settings"
aria-label="close modal"
>
<MdClose />
</Button>
@ -273,44 +241,14 @@ class SettingsModal extends React.Component {
: <div />
}
<hr />
<Button
variant="primary"
onClick={this.switchToDownloadModal}
className="w-100"
>
{t('Download Sample Data')}
</Button>
<hr />
<Button
variant="secondary"
onClick={clearJobsStorage}
className="mr-2 w-100"
>
{t('Clear Recent Jobs')}
</Button>
<span><em>{t('*no invest workspaces will be deleted')}</em></span>
<hr />
<Accordion>
<Accordion.Toggle
as={Button}
variant="outline-secondary"
eventKey="0"
className="mr-2 w-100"
>
{t('Configure Metadata')}
<BsChevronDown className="mx-1" />
</Accordion.Toggle>
<Accordion.Collapse eventKey="0">
<MetadataForm />
</Accordion.Collapse>
</Accordion>
</Modal.Body>
</Modal>
{
(languageOptions) ? (
<Modal show={showConfirmLanguageChange} className="confirm-modal" >
<Modal show={showConfirmLanguageChange} className="confirm-modal">
<Modal.Header>
<Modal.Title as="h5" >{t('Warning')}</Modal.Title>
<Modal.Title as="h5">{t('Warning')}</Modal.Title>
</Modal.Header>
<Modal.Body>
<p>
@ -336,8 +274,8 @@ class SettingsModal extends React.Component {
}
SettingsModal.propTypes = {
clearJobsStorage: PropTypes.func.isRequired,
showDownloadModal: PropTypes.func.isRequired,
show: PropTypes.bool.isRequired,
close: PropTypes.func.isRequired,
nCPU: PropTypes.number.isRequired,
};
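
With the visibility state lifted out of SettingsModal, the gear button and the modal are expected to live in a parent component. A minimal sketch under that assumption (the state and nCPU wiring are hypothetical):

<SettingsModal
  show={showSettingsModal}
  close={() => setShowSettingsModal(false)}
  nCPU={nCPU}
/>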

View File

@ -164,6 +164,7 @@ export default function ArgInput(props) {
argkey,
argSpec,
userguide,
isCoreModel,
enabled,
updateArgValues,
handleFocus,
@ -322,7 +323,12 @@ export default function ArgInput(props) {
<Col>
<InputGroup>
<div className="d-flex flex-nowrap w-100">
<AboutModal arg={argSpec} userguide={userguide} argkey={argkey} />
<AboutModal
arg={argSpec}
userguide={userguide}
isCoreModel={isCoreModel}
argkey={argkey}
/>
{form}
</div>
{feedback}
@ -341,6 +347,7 @@ ArgInput.propTypes = {
units: PropTypes.string, // for numbers only
}).isRequired,
userguide: PropTypes.string.isRequired,
isCoreModel: PropTypes.bool.isRequired,
value: PropTypes.oneOfType(
[PropTypes.string, PropTypes.bool, PropTypes.number]),
touched: PropTypes.bool,
@ -380,13 +387,16 @@ function AboutModal(props) {
const handleAboutClose = () => setAboutShow(false);
const handleAboutOpen = () => setAboutShow(true);
const { userguide, arg, argkey } = props;
const { userguide, arg, argkey, isCoreModel } = props;
const { t, i18n } = useTranslation();
// create link to users guide entry for this arg
// create link to users guide entry for this arg IFF this is a core model
// anchor name is the arg name, with underscores replaced with hyphens
const userguideURL = `
${window.Workbench.USERGUIDE_PATH}/${window.Workbench.LANGUAGE}/${userguide}#${argkey.replace(/_/g, '-')}`;
const userguideURL = (
isCoreModel
? `${window.Workbench.USERGUIDE_PATH}/${window.Workbench.LANGUAGE}/${userguide}#${argkey.replace(/_/g, '-')}`
: null
);
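  // e.g. for a core model with userguide 'carbon.html' and argkey 'lulc_cur_path'
  // (illustrative values), this resolves to:
  //   <USERGUIDE_PATH>/<LANGUAGE>/carbon.html#lulc-cur-path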
return (
<React.Fragment>
<Button
@ -404,15 +414,19 @@ function AboutModal(props) {
<Modal.Body>
{arg.about}
<br />
<a
href={userguideURL}
title={userguideURL}
aria-label="open user guide section for this input in web browser"
onClick={handleClickUsersGuideLink}
>
{t("User's guide entry")}
<MdOpenInNew className="mr-1" />
</a>
{
isCoreModel
&&
<a
href={userguideURL}
title={userguideURL}
aria-label={t("User's guide entry (opens in new window)")}
onClick={handleClickUsersGuideLink}
>
{t("User's guide entry")}
<MdOpenInNew className="mr-1" />
</a>
}
</Modal.Body>
</Modal>
</React.Fragment>
@ -425,5 +439,6 @@ AboutModal.propTypes = {
about: PropTypes.string,
}).isRequired,
userguide: PropTypes.string.isRequired,
isCoreModel: PropTypes.bool.isRequired,
argkey: PropTypes.string.isRequired,
};

View File

@ -124,6 +124,7 @@ class ArgsForm extends React.Component {
argsEnabled,
argsDropdownOptions,
userguide,
isCoreModel,
scrollEventCount,
} = this.props;
const formItems = [];
@ -137,6 +138,7 @@ class ArgsForm extends React.Component {
argkey={argkey}
argSpec={argsSpec[argkey]}
userguide={userguide}
isCoreModel={isCoreModel}
dropdownOptions={argsDropdownOptions[argkey]}
enabled={argsEnabled[argkey]}
updateArgValues={this.props.updateArgValues}
@ -199,7 +201,9 @@ ArgsForm.propTypes = {
argsOrder: PropTypes.arrayOf(
PropTypes.arrayOf(PropTypes.string)
).isRequired,
argsEnabled: PropTypes.objectOf(PropTypes.bool),
userguide: PropTypes.string.isRequired,
isCoreModel: PropTypes.bool.isRequired,
updateArgValues: PropTypes.func.isRequired,
loadParametersFromFile: PropTypes.func.isRequired,
scrollEventCount: PropTypes.number,

View File

@ -19,6 +19,8 @@ import {
archiveDatastack,
fetchDatastackFromFile,
fetchValidation,
fetchArgsEnabled,
getDynamicDropdowns,
saveToPython,
writeParametersToFile
} from '../../server_requests';
@ -33,7 +35,7 @@ const { logger } = window.Workbench;
* Values initialize with either a complete args dict, or with empty/default values.
*
* @param {object} argsSpec - an InVEST model's MODEL_SPEC.args
* @param {object} uiSpec - the model's UI Spec.
* @param {object} inputFieldOrder - the order in which to display the input fields.
* @param {object} argsDict - key: value pairs of InVEST model arguments, or {}.
*
* @returns {object} to destructure into two args,
@ -43,11 +45,11 @@ const { logger } = window.Workbench;
* {object} argsDropdownOptions - stores lists of dropdown options for
* args of type 'option_string'.
*/
function initializeArgValues(argsSpec, uiSpec, argsDict) {
function initializeArgValues(argsSpec, inputFieldOrder, argsDict) {
const initIsEmpty = Object.keys(argsDict).length === 0;
const argsValues = {};
const argsDropdownOptions = {};
uiSpec.order.flat().forEach((argkey) => {
inputFieldOrder.flat().forEach((argkey) => {
// When initializing with undefined values, assign defaults so that,
// a) values are handled well by the html inputs and
// b) the object exported to JSON on "Save" or "Execute" includes defaults.
@ -85,6 +87,8 @@ class SetupTab extends React.Component {
super(props);
this._isMounted = false;
this.validationTimer = null;
this.enabledTimer = null;
this.dropdownTimer = null;
this.state = {
argsValues: null,
@ -104,13 +108,16 @@ class SetupTab extends React.Component {
this.wrapInvestExecute = this.wrapInvestExecute.bind(this);
this.investValidate = this.investValidate.bind(this);
this.debouncedValidate = this.debouncedValidate.bind(this);
this.investArgsEnabled = this.investArgsEnabled.bind(this);
this.debouncedArgsEnabled = this.debouncedArgsEnabled.bind(this);
this.updateArgTouched = this.updateArgTouched.bind(this);
this.updateArgValues = this.updateArgValues.bind(this);
this.batchUpdateArgs = this.batchUpdateArgs.bind(this);
this.callUISpecFunctions = this.callUISpecFunctions.bind(this);
this.browseForDatastack = this.browseForDatastack.bind(this);
this.loadParametersFromFile = this.loadParametersFromFile.bind(this);
this.triggerScrollEvent = this.triggerScrollEvent.bind(this);
this.callDropdownFunctions = this.callDropdownFunctions.bind(this);
this.debouncedDropdownFunctions = this.debouncedDropdownFunctions.bind(this);
}
componentDidMount() {
@ -122,12 +129,12 @@ class SetupTab extends React.Component {
* not on every re-render.
*/
this._isMounted = true;
const { argsInitValues, argsSpec, uiSpec } = this.props;
const { argsInitValues, argsSpec, inputFieldOrder } = this.props;
const {
argsValues,
argsDropdownOptions,
} = initializeArgValues(argsSpec, uiSpec, argsInitValues || {});
} = initializeArgValues(argsSpec, inputFieldOrder, argsInitValues || {});
// map each arg to an empty object, to fill in later
// here we use the argsSpec because it includes all args, even ones like
@ -136,10 +143,10 @@ class SetupTab extends React.Component {
acc[argkey] = {};
return acc;
}, {});
// here we only use the keys in uiSpec.order because args that
// here we only use the keys in inputFieldOrder because args that
// aren't displayed in the form don't need an enabled/disabled state.
// all args default to being enabled
const argsEnabled = uiSpec.order.flat().reduce((acc, argkey) => {
const argsEnabled = inputFieldOrder.flat().reduce((acc, argkey) => {
acc[argkey] = true;
return acc;
}, {});
@ -151,13 +158,15 @@ class SetupTab extends React.Component {
argsDropdownOptions: argsDropdownOptions,
}, () => {
this.investValidate();
this.callUISpecFunctions();
this.investArgsEnabled();
});
}
componentWillUnmount() {
this._isMounted = false;
clearTimeout(this.validationTimer);
clearTimeout(this.enabledTimer);
clearTimeout(this.dropdownTimer);
}
/**
@ -173,38 +182,6 @@ class SetupTab extends React.Component {
}));
}
/**
* Call functions from the UI spec to determine the enabled/disabled
* state and dropdown options for each input, if applicable.
*
* @returns {undefined}
*/
async callUISpecFunctions() {
const { enabledFunctions, dropdownFunctions } = this.props.uiSpec;
if (enabledFunctions) {
// this model has some fields that are conditionally enabled
const { argsEnabled } = this.state;
Object.keys(enabledFunctions).forEach((key) => {
argsEnabled[key] = enabledFunctions[key](this.state);
});
if (this._isMounted) {
this.setState({ argsEnabled: argsEnabled });
}
}
if (dropdownFunctions) {
// this model has a dropdown that's dynamically populated
const { argsDropdownOptions } = this.state;
await Promise.all(Object.keys(dropdownFunctions).map(async (key) => {
argsDropdownOptions[key] = await dropdownFunctions[key](this.state);
}));
if (this._isMounted) {
this.setState({ argsDropdownOptions: argsDropdownOptions });
}
}
}
/** Save the current invest arguments to a python script via datastack.py API.
*
* @param {string} filepath - desired path to the python script
@ -212,12 +189,12 @@ class SetupTab extends React.Component {
*/
async savePythonScript(filepath) {
const {
modelName,
modelID,
} = this.props;
const args = argsDictFromObject(this.state.argsValues);
const payload = {
filepath: filepath,
modelname: modelName,
model_id: modelID,
args: JSON.stringify(args),
};
const response = await saveToPython(payload);
@ -225,33 +202,29 @@ class SetupTab extends React.Component {
}
async saveJsonFile(datastackPath, relativePaths) {
const {
pyModuleName,
} = this.props;
const { modelID } = this.props;
const args = argsDictFromObject(this.state.argsValues);
const payload = {
filepath: datastackPath,
moduleName: pyModuleName,
model_id: modelID,
relativePaths: relativePaths,
args: JSON.stringify(args),
};
const {message, error} = await writeParametersToFile(payload);
const { message, error } = await writeParametersToFile(payload);
this.setSaveAlert(message, error);
}
async saveDatastack(datastackPath) {
const {
pyModuleName,
} = this.props;
const { modelID } = this.props;
const args = argsDictFromObject(this.state.argsValues);
const payload = {
filepath: datastackPath,
moduleName: pyModuleName,
model_id: modelID,
args: JSON.stringify(args),
};
const key = window.crypto.getRandomValues(new Uint16Array(1))[0].toString();
this.setSaveAlert('archiving...', false, key);
const {message, error} = await archiveDatastack(payload);
const { message, error } = await archiveDatastack(payload);
this.setSaveAlert(message, error, key);
}
@ -296,7 +269,7 @@ class SetupTab extends React.Component {
}
async loadParametersFromFile(filepath) {
const { pyModuleName, switchTabs, t } = this.props;
const { modelID, switchTabs, t, investList } = this.props;
let datastack;
try {
if (filepath.endsWith('gz')) { // .tar.gz, .tgz
@ -313,6 +286,7 @@ class SetupTab extends React.Component {
ipcMainChannels.CHECK_FILE_PERMISSIONS, directoryPath);
if (writable) {
datastack = await fetchDatastackFromFile({
model_id: modelID,
filepath: filepath,
extractPath: directoryPath,
});
@ -323,7 +297,9 @@ class SetupTab extends React.Component {
}
} else { return; } // dialog closed without selection
} else {
datastack = await fetchDatastackFromFile({ filepath: filepath });
datastack = await fetchDatastackFromFile(
{ model_id: modelID, filepath: filepath }
);
}
} catch (error) {
logger.error(error);
@ -332,15 +308,15 @@ class SetupTab extends React.Component {
);
return;
}
if (datastack.module_name === pyModuleName) {
if (datastack.model_id === modelID) {
this.batchUpdateArgs(datastack.args);
switchTabs('setup');
this.triggerScrollEvent();
} else {
alert( // eslint-disable-line no-alert
t(
'Datastack/Logfile for {{modelName}} does not match this model.',
{ modelName: datastack.model_human_name }
'Datastack/Logfile for model {{modelTitle}} does not match this model.',
{ modelTitle: investList[datastack.model_id].modelTitle }
)
);
}
@ -399,7 +375,8 @@ class SetupTab extends React.Component {
status: undefined, // Clear job status to hide model status indicator.
});
this.debouncedValidate();
this.callUISpecFunctions();
this.debouncedArgsEnabled();
this.debouncedDropdownFunctions();
});
}
@ -408,21 +385,82 @@ class SetupTab extends React.Component {
* @param {object} argsDict - key: value pairs of InVEST arguments.
*/
batchUpdateArgs(argsDict) {
const { argsSpec, uiSpec } = this.props;
const { argsSpec, inputFieldOrder } = this.props;
const {
argsValues,
argsDropdownOptions,
} = initializeArgValues(argsSpec, uiSpec, argsDict);
} = initializeArgValues(argsSpec, inputFieldOrder, argsDict);
this.setState({
argsValues: argsValues,
argsDropdownOptions: argsDropdownOptions,
}, () => {
this.investValidate();
this.callUISpecFunctions();
this.investArgsEnabled();
});
}
/** Get a debounced version of investArgsEnabled.
*
* The original function will not be called until after the
* debounced version stops being invoked for 200 milliseconds.
*
* @returns {undefined}
*/
debouncedArgsEnabled() {
if (this.enabledTimer) {
clearTimeout(this.enabledTimer);
}
// we want this check to be very responsive,
// but also to wait for a pause in data entry.
this.enabledTimer = setTimeout(this.investArgsEnabled, 200);
}
/** Set the enabled/disabled status of args.
*
* @returns {undefined}
*/
async investArgsEnabled() {
const { modelID } = this.props;
const { argsValues } = this.state;
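    // fetchArgsEnabled is expected to resolve to an object mapping each argkey
    // to a boolean, e.g. (illustrative): { lulc_path: true, threshold_table_path: false }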
if (this._isMounted) {
this.setState({
argsEnabled: await fetchArgsEnabled({
model_id: modelID,
args: JSON.stringify(argsDictFromObject(argsValues)),
}),
});
}
}
debouncedDropdownFunctions() {
if (this.dropdownTimer) {
clearTimeout(this.dropdownTimer);
}
// we want this check to be very responsive,
// but also to wait for a pause in data entry.
this.dropdownTimer = setTimeout(this.callDropdownFunctions, 200);
}
/** Call endpoint to get dynamically populated dropdown options.
*
* @returns {undefined}
*/
async callDropdownFunctions() {
const { modelID } = this.props;
const { argsValues, argsDropdownOptions } = this.state;
const payload = {
model_id: modelID,
args: JSON.stringify(argsDictFromObject(argsValues)),
};
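    // getDynamicDropdowns is expected to resolve to an object mapping argkeys to
    // lists of option values, e.g. (illustrative): { field_name: ['ws_id', 'subws_id'] }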
const results = await getDynamicDropdowns(payload);
Object.keys(results).forEach((argkey) => {
argsDropdownOptions[argkey] = results[argkey];
});
this.setState({ argsDropdownOptions: argsDropdownOptions });
}
/** Get a debounced version of investValidate.
*
* The original validate function will not be called until after the
@ -444,11 +482,11 @@ class SetupTab extends React.Component {
* @returns {undefined}
*/
async investValidate() {
const { argsSpec, pyModuleName } = this.props;
const { argsSpec, modelID } = this.props;
const { argsValues, argsValidation, argsValid } = this.state;
const keyset = new Set(Object.keys(argsSpec));
const payload = {
model_module: pyModuleName,
model_id: modelID,
args: JSON.stringify(argsDictFromObject(argsValues)),
};
const results = await fetchValidation(payload);
@ -514,11 +552,12 @@ class SetupTab extends React.Component {
const {
argsSpec,
userguide,
isCoreModel,
inputFieldOrder,
sidebarSetupElementId,
sidebarFooterElementId,
executeClicked,
uiSpec,
modelName,
modelID,
} = this.props;
const SaveAlerts = [];
@ -560,6 +599,7 @@ class SetupTab extends React.Component {
)
: <span>{t('Run')}</span>
);
return (
<Container fluid>
<Row>
@ -569,8 +609,9 @@ class SetupTab extends React.Component {
argsValidation={argsValidation}
argsEnabled={argsEnabled}
argsDropdownOptions={argsDropdownOptions}
argsOrder={uiSpec.order}
argsOrder={inputFieldOrder}
userguide={userguide}
isCoreModel={isCoreModel}
updateArgValues={this.updateArgValues}
updateArgTouched={this.updateArgTouched}
loadParametersFromFile={this.loadParametersFromFile}
@ -597,15 +638,13 @@ class SetupTab extends React.Component {
</Button>
</OverlayTrigger>
<SaveAsModal
modelName={modelName}
modelID={modelID}
savePythonScript={this.savePythonScript}
saveJsonFile={this.saveJsonFile}
saveDatastack={this.saveDatastack}
removeSaveErrors={this.removeSaveErrors}
/>
<React.Fragment>
{SaveAlerts}
</React.Fragment>
{SaveAlerts}
</Portal>
<Portal elId={sidebarFooterElementId}>
<Button
@ -629,20 +668,16 @@ class SetupTab extends React.Component {
export default withTranslation()(SetupTab);
SetupTab.propTypes = {
pyModuleName: PropTypes.string.isRequired,
userguide: PropTypes.string.isRequired,
modelName: PropTypes.string.isRequired,
isCoreModel: PropTypes.bool.isRequired,
modelID: PropTypes.string.isRequired,
argsSpec: PropTypes.objectOf(
PropTypes.shape({
name: PropTypes.string,
type: PropTypes.string,
})
).isRequired,
uiSpec: PropTypes.shape({
order: PropTypes.arrayOf(PropTypes.arrayOf(PropTypes.string)).isRequired,
enabledFunctions: PropTypes.objectOf(PropTypes.func),
dropdownFunctions: PropTypes.objectOf(PropTypes.func),
}).isRequired,
inputFieldOrder: PropTypes.arrayOf(PropTypes.arrayOf(PropTypes.string)).isRequired,
argsInitValues: PropTypes.objectOf(PropTypes.oneOfType(
[PropTypes.string, PropTypes.bool, PropTypes.number])),
investExecute: PropTypes.func.isRequired,
@ -650,6 +685,11 @@ SetupTab.propTypes = {
sidebarFooterElementId: PropTypes.string.isRequired,
executeClicked: PropTypes.bool.isRequired,
switchTabs: PropTypes.func.isRequired,
investList: PropTypes.shape({
modelTitle: PropTypes.string,
type: PropTypes.string,
}).isRequired,
t: PropTypes.func.isRequired,
tabID: PropTypes.string.isRequired,
updateJobProperties: PropTypes.func.isRequired,
};

View File

@ -99,7 +99,7 @@
"User's guide entry": "Entrada de la guía del usuario",
"Only drop one file at a time.": "Solo se puede soltar un archivo a la vez.",
"Choose location to extract archive": "Seleccione la ubicación donde extraer el archivo",
"Datastack/Logfile for {{modelName}} does not match this model.": "Datastack/Logfile para {{modelName}} no calza con este modelo.",
"Datastack/Logfile for {{modelTitle}} does not match this model.": "Datastack/Logfile para {{modelTitle}} no calza con este modelo.",
"Running": "En curso",
"Run": "Ejecutar",
"Browse to a datastack (.json, .tgz) or InVEST logfile (.txt)": "Buscar una pila de datos (.json, .tgz) o un archivo de registro InVEST (.txt)",

View File

@ -99,7 +99,7 @@
"User's guide entry": "用户指南条目",
"Only drop one file at a time.": "一次只能丢一个文件。",
"Choose location to extract archive": "选择提取存档的位置",
"Datastack/Logfile for {{modelName}} does not match this model.": "{{modelName}} 个数据堆栈/日志文件与此模型不匹配。",
"Datastack/Logfile for {{modelTitle}} does not match this model.": "{{modelTitle}} 个数据堆栈/日志文件与此模型不匹配。",
"Running": "运行中",
"Run": "运行",
"Browse to a datastack (.json, .tgz) or InVEST logfile (.txt)": "浏览到数据包(.json, .tgz或InVEST日志文件.txt。",
