invest/tests/test_validation.py

2299 lines
86 KiB
Python

"""Testing module for validation."""
import functools
import os
import shutil
import stat
import string
import sys
import tempfile
import textwrap
import time
import unittest
import warnings
from unittest.mock import Mock
import numpy
import pandas
from osgeo import gdal
from osgeo import ogr
from osgeo import osr
from natcap.invest import spec
from natcap.invest.spec import (
u,
BooleanInput,
CSVInput,
DirectoryInput,
FileInput,
Input,
IntegerInput,
ModelSpec,
NumberInput,
OptionStringInput,
PercentInput,
RasterOrVectorInput,
RatioInput,
SingleBandRasterInput,
StringInput,
VectorInput)
gdal.UseExceptions()
def model_spec_with_defaults(model_id='', model_title='', userguide='', aliases=None,
inputs={}, outputs={}, input_field_order=[]):
return ModelSpec(model_id=model_id, model_title=model_title, userguide=userguide,
aliases=aliases, inputs=inputs, outputs=outputs,
input_field_order=input_field_order)
def number_input_spec_with_defaults(id='', units=u.none, expression='', **kwargs):
return NumberInput(id=id, units=units, expression=expression, **kwargs)
class SpatialOverlapTest(unittest.TestCase):
"""Test Spatial Overlap."""
def setUp(self):
"""Create a new workspace to use for each test."""
self.workspace_dir = tempfile.mkdtemp()
def tearDown(self):
"""Remove the workspace created for this test."""
shutil.rmtree(self.workspace_dir)
def test_no_overlap(self):
"""Validation: verify lack of overlap."""
import pygeoprocessing
from natcap.invest import validation
driver = gdal.GetDriverByName('GTiff')
filepath_1 = os.path.join(self.workspace_dir, 'raster_1.tif')
filepath_2 = os.path.join(self.workspace_dir, 'raster_2.tif')
filepath_list = []
bbox_list = []
for filepath, geotransform in (
(filepath_1, [1, 1, 0, 1, 0, 1]),
(filepath_2, [100, 1, 0, 100, 0, 1])):
raster = driver.Create(filepath, 3, 3, 1, gdal.GDT_Int32)
wgs84_srs = osr.SpatialReference()
wgs84_srs.ImportFromEPSG(4326)
raster.SetProjection(wgs84_srs.ExportToWkt())
raster.SetGeoTransform(geotransform)
raster = None
filepath_list.append(filepath)
bbox_list.append(
pygeoprocessing.get_raster_info(filepath)['bounding_box'])
error_msg = validation.check_spatial_overlap([filepath_1, filepath_2])
formatted_lists = validation._format_bbox_list(
filepath_list, bbox_list)
self.assertTrue(validation.MESSAGES['BBOX_NOT_INTERSECT'].format(
bboxes=formatted_lists) in error_msg)
def test_overlap(self):
"""Validation: verify overlap."""
from natcap.invest import validation
driver = gdal.GetDriverByName('GTiff')
filepath_1 = os.path.join(self.workspace_dir, 'raster_1.tif')
filepath_2 = os.path.join(self.workspace_dir, 'raster_2.tif')
for filepath, geotransform in (
(filepath_1, [1, 1, 0, 1, 0, 1]),
(filepath_2, [2, 1, 0, 2, 0, 1])):
raster = driver.Create(filepath, 3, 3, 1, gdal.GDT_Int32)
wgs84_srs = osr.SpatialReference()
wgs84_srs.ImportFromEPSG(4326)
raster.SetProjection(wgs84_srs.ExportToWkt())
raster.SetGeoTransform(geotransform)
raster = None
self.assertEqual(
None, validation.check_spatial_overlap([filepath_1, filepath_2]))
def test_check_overlap_undefined_projection(self):
"""Validation: check overlap of raster with an undefined projection."""
from natcap.invest import validation
driver = gdal.GetDriverByName('GTiff')
filepath_1 = os.path.join(self.workspace_dir, 'raster_1.tif')
filepath_2 = os.path.join(self.workspace_dir, 'raster_2.tif')
raster_1 = driver.Create(filepath_1, 3, 3, 1, gdal.GDT_Int32)
wgs84_srs = osr.SpatialReference()
wgs84_srs.ImportFromEPSG(4326)
raster_1.SetProjection(wgs84_srs.ExportToWkt())
raster_1.SetGeoTransform([1, 1, 0, 1, 0, 1])
raster_1 = None
# set up a raster with an undefined projection
raster_2 = driver.Create(filepath_2, 3, 3, 1, gdal.GDT_Int32)
raster_2.SetGeoTransform([2, 1, 0, 2, 0, 1])
raster_2 = None
error_msg = validation.check_spatial_overlap(
[filepath_1, filepath_2], different_projections_ok=True)
expected = validation.MESSAGES['NO_PROJECTION'].format(filepath=filepath_2)
self.assertEqual(error_msg, expected)
def test_check_overlap_unable_to_transform(self):
"""Validation: check overlap when layer cannot transform to EPSG:4326."""
from natcap.invest import validation
driver = gdal.GetDriverByName('GTiff')
filepath_1 = os.path.join(self.workspace_dir, 'raster_1.tif')
filepath_2 = os.path.join(self.workspace_dir, 'raster_2.tif')
raster_1 = driver.Create(filepath_1, 3, 3, 1, gdal.GDT_Int32)
wgs84_srs = osr.SpatialReference()
wgs84_srs.ImportFromEPSG(4326)
raster_1.SetProjection(wgs84_srs.ExportToWkt())
raster_1.SetGeoTransform([1, 1, 0, 1, 0, 1])
raster_1 = None
# set up a raster with an outside-the-globe extent
raster_2 = driver.Create(filepath_2, 3, 3, 1, gdal.GDT_Int32)
eckert_srs = osr.SpatialReference()
proj4str = '+proj=eck4 +lon_0=0 +x_0=0 +y_0=0 +datum=WGS84 +units=m +no_defs +type=crs'
eckert_srs.ImportFromProj4(proj4str)
raster_2.SetProjection(eckert_srs.ExportToWkt())
raster_2.SetGeoTransform(
(-10000000.0, 19531.25, 0.0, 15000000.0, 0.0, -4882.8125))
raster_2 = None
# The improper raster should be skipped by validation, thus no errors
error_msg = validation.check_spatial_overlap(
[filepath_1, filepath_2], different_projections_ok=True)
self.assertEqual(None, error_msg)
@unittest.skip("skipping due to unresolved projection comparison question")
def test_different_projections_not_ok(self):
"""Validation: different projections not allowed by default.
This test illustrates a bug we don't yet have a good solution for
(natcap/invest#558)
When ``different_projections_ok is False``, we don't check that the
projections are actually the same, because there isn't a great way to
do so. So there's the possibility that some bounding boxes overlap
numerically, but have different projections, and thus pass validation
when they shouldn't.
"""
from natcap.invest import validation
driver = gdal.GetDriverByName('GTiff')
filepath_1 = os.path.join(self.workspace_dir, 'raster_1.tif')
filepath_2 = os.path.join(self.workspace_dir, 'raster_2.tif')
# bounding boxes overlap if we don't account for the projections
for filepath, geotransform, epsg in (
(filepath_1, [1, 1, 0, 1, 0, 1], 4326),
(filepath_2, [2, 1, 0, 2, 0, 1], 2193)):
raster = driver.Create(filepath, 3, 3, 1, gdal.GDT_Int32)
wgs84_srs = osr.SpatialReference()
wgs84_srs.ImportFromEPSG(epsg)
raster.SetProjection(wgs84_srs.ExportToWkt())
raster.SetGeoTransform(geotransform)
raster = None
expected = (f'Spatial files {[filepath_1, filepath_2]} do not all '
'have the same projection')
self.assertEqual(
validation.check_spatial_overlap([filepath_1, filepath_2]),
expected)
class ValidatorTest(unittest.TestCase):
"""Test Validator."""
def test_args_wrong_type(self):
"""Validation: check for error when args is the wrong type."""
from natcap.invest import validation
@validation.invest_validator
def validate(args, limit_to=None):
pass
with self.assertRaises(AssertionError):
validate(args=123)
def test_limit_to_wrong_type(self):
"""Validation: check for error when limit_to is the wrong type."""
from natcap.invest import validation
@validation.invest_validator
def validate(args, limit_to=None):
pass
with self.assertRaises(AssertionError):
validate(args={}, limit_to=1234)
def test_limit_to_not_in_args(self):
"""Validation: check for error when limit_to is not a key in args."""
from natcap.invest import validation
@validation.invest_validator
def validate(args, limit_to=None):
pass
with self.assertRaises(AssertionError):
validate(args={}, limit_to='bar')
def test_args_keys_must_be_strings(self):
"""Validation: check for error when args keys are not all strings."""
from natcap.invest import validation
@validation.invest_validator
def validate(args, limit_to=None):
pass
with self.assertRaises(AssertionError):
validate(args={1: 'foo'})
def test_return_keys_in_args(self):
"""Validation: check for error when return keys not all in args."""
from natcap.invest import validation
@validation.invest_validator
def validate(args, limit_to=None):
return [(('a',), 'error 1')]
validation_errors = validate({})
self.assertEqual(validation_errors,
[(('a',), 'error 1')])
def test_wrong_parameter_names(self):
"""Validation: check for error when wrong function signature used."""
from natcap.invest import validation
@validation.invest_validator
def validate(foo):
pass
with self.assertRaises(AssertionError):
validate({})
def test_return_value(self):
"""Validation: validation errors should be returned from decorator."""
from natcap.invest import validation
errors = [(('a', 'b'), 'Error!')]
@validation.invest_validator
def validate(args, limit_to=None):
return errors
validation_errors = validate({'a': 'foo', 'b': 'bar'})
self.assertEqual(validation_errors, errors)
def test_n_workers(self):
"""Validation: validation error returned on invalid n_workers."""
from natcap.invest import validation
args_spec = model_spec_with_defaults(inputs=[
spec.build_input_spec('n_workers', spec.N_WORKERS)])
@validation.invest_validator
def validate(args, limit_to=None):
return validation.validate(args, args_spec)
args = {'n_workers': 'not a number'}
validation_errors = validate(args)
expected = [(
['n_workers'],
validation.MESSAGES['NOT_A_NUMBER'].format(value=args['n_workers']))]
self.assertEqual(validation_errors, expected)
def test_timeout_succeed(self):
from natcap.invest import validation
# both args and the kwarg should be passed to the function
@spec.timeout
def func(arg1, arg2, kwarg=None):
self.assertEqual(kwarg, 'kwarg')
time.sleep(1)
# this will raise an error if the timeout is exceeded
# timeout defaults to 5 seconds so this should pass
func('arg1', 'arg2', kwarg='kwarg')
def test_timeout_fail(self):
from natcap.invest import validation
# both args and the kwarg should be passed to the function
@spec.timeout
def func(arg):
time.sleep(6)
# this will return a warning if the timeout is exceeded
# timeout defaults to 5 seconds so this should fail
with warnings.catch_warnings(record=True) as ws:
# cause all warnings to always be triggered
warnings.simplefilter("always")
func('arg')
self.assertTrue(len(ws) == 1)
self.assertTrue('timed out' in str(ws[0].message))
class DirectoryValidation(unittest.TestCase):
"""Test Directory Validation."""
def setUp(self):
"""Create a new workspace to use for each test."""
self.workspace_dir = tempfile.mkdtemp()
def tearDown(self):
"""Remove the workspace created for this test."""
shutil.rmtree(self.workspace_dir)
def test_exists(self):
"""Validation: when a folder must exist and does."""
from natcap.invest import validation
self.assertIsNone(DirectoryInput().validate(self.workspace_dir))
def test_not_exists(self):
"""Validation: when a folder must exist but does not."""
from natcap.invest import validation
dirpath = os.path.join(self.workspace_dir, 'nonexistent_dir')
validation_warning = DirectoryInput().validate(dirpath)
self.assertEqual(validation_warning, validation.MESSAGES['DIR_NOT_FOUND'])
def test_file(self):
"""Validation: when a file is given to folder validation."""
from natcap.invest import validation
filepath = os.path.join(self.workspace_dir, 'some_file.txt')
with open(filepath, 'w') as opened_file:
opened_file.write('the text itself does not matter.')
validation_warning = DirectoryInput().validate(filepath)
self.assertEqual(validation_warning, validation.MESSAGES['NOT_A_DIR'])
def test_valid_permissions(self):
"""Validation: folder permissions."""
from natcap.invest import validation
self.assertIsNone(DirectoryInput(
permissions='rwx').validate(self.workspace_dir))
def test_workspace_not_exists(self):
"""Validation: when a folder's parent must exist with permissions."""
from natcap.invest import validation
dirpath = 'foo'
new_dir = os.path.join(self.workspace_dir, dirpath)
self.assertIsNone(DirectoryInput(
must_exist=False, permissions='rwx').validate(new_dir))
@unittest.skipIf(
sys.platform.startswith('win'),
'requires support for os.chmod(), which is unreliable on Windows')
class DirectoryValidationMacOnly(unittest.TestCase):
"""Test Directory Permissions Validation."""
def test_invalid_permissions_r(self):
"""Validation: when a folder must have read/write/execute
permissions but is missing write and execute permissions."""
from natcap.invest import validation
with tempfile.TemporaryDirectory() as tempdir:
os.chmod(tempdir, stat.S_IREAD)
validation_warning = DirectoryInput(permissions='rwx').validate(tempdir)
self.assertEqual(
validation_warning,
validation.MESSAGES['NEED_PERMISSION_DIRECTORY'].format(permission='execute'))
def test_invalid_permissions_w(self):
"""Validation: when a folder must have read/write/execute
permissions but is missing read and execute permissions."""
from natcap.invest import validation
with tempfile.TemporaryDirectory() as tempdir:
os.chmod(tempdir, stat.S_IWRITE)
validation_warning = DirectoryInput(permissions='rwx').validate(tempdir)
self.assertEqual(
validation_warning,
validation.MESSAGES['NEED_PERMISSION_DIRECTORY'].format(permission='read'))
def test_invalid_permissions_x(self):
"""Validation: when a folder must have read/write/execute
permissions but is missing read and write permissions."""
from natcap.invest import validation
with tempfile.TemporaryDirectory() as tempdir:
os.chmod(tempdir, stat.S_IEXEC)
validation_warning = DirectoryInput(permissions='rwx').validate(tempdir)
self.assertEqual(
validation_warning,
validation.MESSAGES['NEED_PERMISSION_DIRECTORY'].format(permission='read'))
def test_invalid_permissions_rw(self):
"""Validation: when a folder must have read/write/execute
permissions but is missing execute permission."""
from natcap.invest import validation
with tempfile.TemporaryDirectory() as tempdir:
os.chmod(tempdir, stat.S_IREAD | stat.S_IWRITE)
validation_warning = DirectoryInput(permissions='rwx').validate(tempdir)
self.assertEqual(
validation_warning,
validation.MESSAGES['NEED_PERMISSION_DIRECTORY'].format(permission='execute'))
def test_invalid_permissions_rx(self):
"""Validation: when a folder must have read/write/execute
permissions but is missing write permission."""
from natcap.invest import validation
with tempfile.TemporaryDirectory() as tempdir:
os.chmod(tempdir, stat.S_IREAD | stat.S_IEXEC)
validation_warning = DirectoryInput(permissions='rwx').validate(tempdir)
self.assertEqual(
validation_warning,
validation.MESSAGES['NEED_PERMISSION_DIRECTORY'].format(permission='write'))
def test_invalid_permissions_wx(self):
"""Validation: when a folder must have read/write/execute
permissions but is missing read permission."""
from natcap.invest import validation
with tempfile.TemporaryDirectory() as tempdir:
os.chmod(tempdir, stat.S_IWRITE | stat.S_IEXEC)
validation_warning = DirectoryInput(permissions='rwx').validate(tempdir)
self.assertEqual(
validation_warning,
validation.MESSAGES['NEED_PERMISSION_DIRECTORY'].format(permission='read'))
class FileValidation(unittest.TestCase):
"""Test File Validator."""
def setUp(self):
"""Create a new workspace to use for each test."""
self.workspace_dir = tempfile.mkdtemp()
def tearDown(self):
"""Remove the workspace created for this test."""
shutil.rmtree(self.workspace_dir)
def test_file_exists(self):
"""Validation: test that a file exists."""
from natcap.invest import validation
filepath = os.path.join(self.workspace_dir, 'file.txt')
with open(filepath, 'w') as new_file:
new_file.write("Here's some text.")
self.assertIsNone(FileInput().validate(filepath))
def test_file_not_found(self):
"""Validation: test when a file is not found."""
from natcap.invest import validation
filepath = os.path.join(self.workspace_dir, 'file.txt')
error_msg = FileInput().validate(filepath)
self.assertEqual(error_msg, validation.MESSAGES['FILE_NOT_FOUND'])
class RasterValidation(unittest.TestCase):
"""Test Raster Validation."""
def setUp(self):
"""Create a new workspace to use for each test."""
self.workspace_dir = tempfile.mkdtemp()
def tearDown(self):
"""Remove the workspace created for this test."""
shutil.rmtree(self.workspace_dir)
def test_file_not_found(self):
"""Validation: test that a raster exists."""
from natcap.invest import validation
filepath = os.path.join(self.workspace_dir, 'file.txt')
error_msg = SingleBandRasterInput().validate(filepath)
self.assertEqual(error_msg, validation.MESSAGES['FILE_NOT_FOUND'])
def test_invalid_raster(self):
"""Validation: test when a raster format is invalid."""
from natcap.invest import validation
filepath = os.path.join(self.workspace_dir, 'file.txt')
with open(filepath, 'w') as bad_raster:
bad_raster.write('not a raster')
error_msg = SingleBandRasterInput().validate(filepath)
self.assertEqual(error_msg, validation.MESSAGES['NOT_GDAL_RASTER'])
def test_invalid_ovr_raster(self):
"""Validation: test when a .tif.ovr file is input as a raster."""
from natcap.invest import validation
# Use EPSG:32731 # WGS84 / UTM zone 31s
driver = gdal.GetDriverByName('GTiff')
filepath = os.path.join(self.workspace_dir, 'raster.tif')
raster = driver.Create(filepath, 3, 3, 1, gdal.GDT_Int32)
meters_srs = osr.SpatialReference()
meters_srs.ImportFromEPSG(32731)
raster.SetProjection(meters_srs.ExportToWkt())
raster = None
# I could only create overviews when opening the file, not on creation.
# Build overviews taken from:
# https://gis.stackexchange.com/questions/270498/compress-gtiff-external-overviews-with-gdal-api
raster = gdal.OpenEx(filepath)
gdal.SetConfigOption("COMPRESS_OVERVIEW", "DEFLATE")
raster.BuildOverviews("AVERAGE", [2, 4, 8, 16, 32, 64, 128, 256])
raster = None
filepath_ovr = os.path.join(self.workspace_dir, 'raster.tif.ovr')
error_msg = SingleBandRasterInput().validate(filepath_ovr)
self.assertEqual(error_msg, validation.MESSAGES['OVR_FILE'])
def test_raster_not_projected(self):
"""Validation: test when a raster is not linearly projected."""
from natcap.invest import validation
# use WGS84 as not linearly projected.
driver = gdal.GetDriverByName('GTiff')
filepath = os.path.join(self.workspace_dir, 'raster.tif')
raster = driver.Create(filepath, 3, 3, 1, gdal.GDT_Int32)
wgs84_srs = osr.SpatialReference()
wgs84_srs.ImportFromEPSG(4326)
raster.SetProjection(wgs84_srs.ExportToWkt())
raster = None
error_msg = SingleBandRasterInput(projected=True).validate(filepath)
self.assertEqual(error_msg, validation.MESSAGES['NOT_PROJECTED'])
def test_raster_incorrect_units(self):
"""Validation: test when a raster projection has wrong units."""
from natcap.invest import validation
# Use EPSG:32066 # NAD27 / BLM 16N (in US Survey Feet)
driver = gdal.GetDriverByName('GTiff')
filepath = os.path.join(self.workspace_dir, 'raster.tif')
raster = driver.Create(filepath, 3, 3, 1, gdal.GDT_Int32)
wgs84_srs = osr.SpatialReference()
wgs84_srs.ImportFromEPSG(32066)
raster.SetProjection(wgs84_srs.ExportToWkt())
raster = None
error_msg = SingleBandRasterInput(
projected=True, projection_units=spec.u.meter
).validate(filepath)
expected_msg = validation.MESSAGES['WRONG_PROJECTION_UNIT'].format(
unit_a='meter', unit_b='us_survey_foot')
self.assertEqual(expected_msg, error_msg)
class VectorValidation(unittest.TestCase):
"""Test Vector Validation."""
def setUp(self):
"""Create a new workspace to use for each test."""
self.workspace_dir = tempfile.mkdtemp()
def tearDown(self):
"""Remove the workspace created for this test."""
shutil.rmtree(self.workspace_dir)
def test_file_not_found(self):
"""Validation: test when a vector file is not found."""
from natcap.invest import validation
filepath = os.path.join(self.workspace_dir, 'file.txt')
error_msg = VectorInput(
geometry_types={'POINT'}, fields={}).validate(filepath)
self.assertEqual(error_msg, validation.MESSAGES['FILE_NOT_FOUND'])
def test_invalid_vector(self):
"""Validation: test when a vector's format is invalid."""
from natcap.invest import validation
filepath = os.path.join(self.workspace_dir, 'file.txt')
with open(filepath, 'w') as bad_vector:
bad_vector.write('not a vector')
error_msg = VectorInput(
geometry_types={'POINT'}, fields={}).validate(filepath)
self.assertEqual(error_msg, validation.MESSAGES['NOT_GDAL_VECTOR'])
def test_missing_fieldnames(self):
"""Validation: test when a vector is missing fields."""
from natcap.invest import validation
driver = gdal.GetDriverByName('GPKG')
filepath = os.path.join(self.workspace_dir, 'vector.gpkg')
vector = driver.Create(filepath, 0, 0, 0, gdal.GDT_Unknown)
wgs84_srs = osr.SpatialReference()
wgs84_srs.ImportFromEPSG(4326)
layer = vector.CreateLayer('sample_layer', wgs84_srs, ogr.wkbPoint)
for field_name, field_type in (('COL_A', ogr.OFTInteger),
('col_b', ogr.OFTString)):
layer.CreateField(ogr.FieldDefn(field_name, field_type))
new_feature = ogr.Feature(layer.GetLayerDefn())
new_feature.SetField('COL_A', 1)
new_feature.SetField('col_b', 'hello')
layer.CreateFeature(new_feature)
new_feature = None
layer = None
vector = None
error_msg = VectorInput(
geometry_types={'POINT'},
fields=[
Input(id='col_a'),
Input(id='col_b'),
Input(id='col_c')]
).validate(filepath)
expected = validation.MESSAGES['MATCHED_NO_HEADERS'].format(
header='field', header_name='col_c')
self.assertEqual(error_msg, expected)
def test_vector_projected_in_m(self):
"""Validation: test that a vector's projection has expected units."""
from natcap.invest import validation
driver = gdal.GetDriverByName('GPKG')
filepath = os.path.join(self.workspace_dir, 'vector.gpkg')
vector = driver.Create(filepath, 0, 0, 0, gdal.GDT_Unknown)
meters_srs = osr.SpatialReference()
meters_srs.ImportFromEPSG(32731)
layer = vector.CreateLayer('sample_layer', meters_srs, ogr.wkbPoint)
layer = None
vector = None
error_msg = VectorInput(
fields={}, geometry_types={'POINT'}, projected=True, projection_units=spec.u.foot
).validate(filepath)
expected_msg = validation.MESSAGES['WRONG_PROJECTION_UNIT'].format(
unit_a='foot', unit_b='metre')
self.assertEqual(error_msg, expected_msg)
self.assertIsNone(VectorInput(
fields={}, geometry_types={'POINT'}, projected=True, projection_units=spec.u.meter
).validate(filepath))
def test_wrong_geom_type(self):
"""Validation: checks that the vector's geometry type is correct."""
from natcap.invest import validation
driver = gdal.GetDriverByName('GPKG')
filepath = os.path.join(self.workspace_dir, 'vector.gpkg')
vector = driver.Create(filepath, 0, 0, 0, gdal.GDT_Unknown)
meters_srs = osr.SpatialReference()
meters_srs.ImportFromEPSG(32731)
layer = vector.CreateLayer('sample_layer', meters_srs, ogr.wkbPoint)
layer = None
vector = None
self.assertIsNone(VectorInput(
geometry_types={'POLYGON', 'POINT'}, fields=None).validate(filepath))
self.assertEqual(
VectorInput(fields=None, geometry_types={'MULTIPOINT'}).validate(filepath),
validation.MESSAGES['WRONG_GEOM_TYPE'].format(allowed={'MULTIPOINT'}))
class FreestyleStringValidation(unittest.TestCase):
"""Test Freestyle String Validation."""
def test_int(self):
"""Validation: test that an int can be a valid string."""
from natcap.invest import validation
self.assertIsNone(StringInput().validate(1234))
def test_float(self):
"""Validation: test that a float can be a valid string."""
from natcap.invest import validation
self.assertIsNone(StringInput().validate(1.234))
def test_regexp(self):
"""Validation: test that we can check regex patterns on strings."""
from natcap.invest import validation
from natcap.invest.spec import SUFFIX
self.assertEqual(
None, StringInput(regexp='^1.[0-9]+$').validate(1.234))
regexp = '^[a-zA-Z]+$'
error_msg = StringInput(regexp=regexp).validate('foobar12')
self.assertEqual(
error_msg, validation.MESSAGES['REGEXP_MISMATCH'].format(regexp=regexp))
error_msg = StringInput(regexp=SUFFIX['regexp']).validate('4/20')
self.assertEqual(
error_msg, validation.MESSAGES['REGEXP_MISMATCH'].format(regexp=SUFFIX['regexp']))
class OptionStringValidation(unittest.TestCase):
"""Test Option String Validation."""
def test_valid_option_set(self):
"""Validation: test that a string is a valid option in a set."""
from natcap.invest import validation
self.assertIsNone(OptionStringInput(
options={'foo', 'bar', 'Baz'}).validate('foo'))
def test_invalid_option_set(self):
"""Validation: test when a string is not a valid option in a set."""
from natcap.invest import validation
options = ['foo', 'bar', 'Baz']
error_msg = OptionStringInput(options=options).validate('FOO')
self.assertEqual(
error_msg,
validation.MESSAGES['INVALID_OPTION'].format(
option_list=sorted(options)))
def test_valid_option_dict(self):
"""Validation: test that a string is a valid option in a dict."""
from natcap.invest import validation
self.assertIsNone(OptionStringInput(
options={'foo': 'desc', 'bar': 'desc', 'Baz': 'desc'}).validate('foo'))
def test_invalid_option_dict(self):
"""Validation: test when a string is not a valid option in a dict."""
from natcap.invest import validation
options = {'foo': 'desc', 'bar': 'desc', 'Baz': 'desc'}
error_msg = OptionStringInput(options=options).validate('FOO')
self.assertEqual(
error_msg,
validation.MESSAGES['INVALID_OPTION'].format(
option_list=sorted(options.keys())))
class NumberValidation(unittest.TestCase):
"""Test Number Validation."""
def test_string(self):
"""Validation: test when a string is not a number."""
from natcap.invest import validation
value = 'this is a string'
error_msg = NumberInput().validate(value)
self.assertEqual(
error_msg, validation.MESSAGES['NOT_A_NUMBER'].format(value=value))
def test_expression(self):
"""Validation: test that we can use numeric expressions."""
from natcap.invest import validation
self.assertIsNone(NumberInput(
expression='(value < 100) & (value > 4)').validate(35))
def test_expression_missing_value(self):
"""Validation: test the expression string for the 'value' term."""
from natcap.invest import validation
with self.assertRaises(AssertionError):
error_msg = NumberInput(expression='foo < 5').validate(35)
def test_expression_failure(self):
"""Validation: test when a number does not meet the expression."""
from natcap.invest import validation
value = 35
condition = 'float(value) < 0'
error_msg = NumberInput(expression=condition).validate(value)
self.assertEqual(error_msg, validation.MESSAGES['INVALID_VALUE'].format(
value=value, condition=condition))
def test_expression_failure_string(self):
"""Validation: test when string value does not meet the expression."""
from natcap.invest import validation
value = '35'
condition = 'int(value) < 0'
error_msg = NumberInput(expression=condition).validate(value)
self.assertEqual(error_msg, validation.MESSAGES['INVALID_VALUE'].format(
value=value, condition=condition))
class BooleanValidation(unittest.TestCase):
"""Test Boolean Validation."""
def test_actual_bool(self):
"""Validation: test when boolean type objects are passed."""
from natcap.invest import validation
self.assertIsNone(BooleanInput().validate(True))
self.assertIsNone(BooleanInput().validate(False))
def test_string_boolean(self):
"""Validation: an error should be raised when the type is wrong."""
from natcap.invest import validation
for non_boolean_value in ('true', 1, [], set()):
self.assertIsInstance(
BooleanInput().validate(non_boolean_value), str)
def test_invalid_string(self):
"""Validation: test when invalid strings are passed."""
from natcap.invest import validation
value = 'not clear'
error_msg = BooleanInput().validate(value)
self.assertEqual(
error_msg, validation.MESSAGES['NOT_BOOLEAN'].format(value=value))
class CSVValidation(unittest.TestCase):
"""Test CSV Validation."""
def setUp(self):
"""Create a new workspace to use for each test."""
self.workspace_dir = tempfile.mkdtemp()
def tearDown(self):
"""Remove the workspace created for this test."""
shutil.rmtree(self.workspace_dir, ignore_errors=True)
def test_file_not_found(self):
"""Validation: test when a file is not found."""
from natcap.invest import validation
nonexistent_file = os.path.join(self.workspace_dir, 'nope.txt')
error_msg = CSVInput().validate(nonexistent_file)
self.assertEqual(error_msg, validation.MESSAGES['FILE_NOT_FOUND'])
def test_csv_fieldnames(self):
"""Validation: test that we can check fieldnames in a CSV."""
from natcap.invest import validation
df = pandas.DataFrame([
{'foo': 1, 'bar': 2, 'baz': 3},
{'foo': 2, 'bar': 3, 'baz': 4},
{'foo': 3, 'bar': 4, 'baz': 5}])
target_file = os.path.join(self.workspace_dir, 'test.csv')
df.to_csv(target_file)
self.assertIsNone(
CSVInput(columns=[
IntegerInput(id='foo'),
IntegerInput(id='bar')]
).validate(target_file))
def test_csv_bom_fieldnames(self):
"""Validation: test that we can check fieldnames in a CSV with BOM."""
from natcap.invest import validation
df = pandas.DataFrame([
{'foo': 1, 'bar': 2, 'baz': 3},
{'foo': 2, 'bar': 3, 'baz': 4},
{'foo': 3, 'bar': 4, 'baz': 5}])
target_file = os.path.join(self.workspace_dir, 'test.csv')
df.to_csv(target_file, encoding='utf-8-sig')
self.assertIsNone(
CSVInput(columns=[
IntegerInput(id='foo'),
IntegerInput(id='bar')]
).validate(target_file))
def test_csv_missing_fieldnames(self):
"""Validation: test that we can check missing fieldnames in a CSV."""
from natcap.invest import validation
df = pandas.DataFrame([
{'foo': 1, 'bar': 2, 'baz': 3},
{'foo': 2, 'bar': 3, 'baz': 4},
{'foo': 3, 'bar': 4, 'baz': 5}])
target_file = os.path.join(self.workspace_dir, 'test.csv')
df.to_csv(target_file)
error_msg = CSVInput(
columns=[Input(id='field_a')]).validate(target_file)
expected_msg = validation.MESSAGES['MATCHED_NO_HEADERS'].format(
header='column', header_name='field_a')
self.assertEqual(error_msg, expected_msg)
def test_wrong_filetype(self):
"""Validation: verify CSV type does not open pickles."""
from natcap.invest import validation
df = pandas.DataFrame([
{'foo': 1, 'bar': 2, 'baz': 3},
{'foo': 2, 'bar': 3, 'baz': 4},
{'foo': 3, 'bar': 4, 'baz': 5}])
target_file = os.path.join(self.workspace_dir, 'test.pckl')
df.to_pickle(target_file)
error_msg = CSVInput(
columns=[Input(id='field_a')]).validate(target_file)
self.assertIn('must be encoded as UTF-8', error_msg)
def test_slow_to_open(self):
"""Test timeout by mocking a CSV that is slow to open"""
from natcap.invest import validation
# make an actual file so that `check_file` will pass
path = os.path.join(self.workspace_dir, 'slow.csv')
with open(path, 'w') as file:
file.write('1,2,3')
csv_spec = model_spec_with_defaults(inputs=[CSVInput(id="mock_csv_path")])
# validate a mocked CSV that will take 6 seconds to return a value
args = {"mock_csv_path": path}
# define a side effect for the mock that will sleep
# for longer than the allowed timeout
@spec.timeout
def delay(*args, **kwargs):
time.sleep(7)
return []
# replace the validation.check_csv with the mock function, and try to validate
with unittest.mock.patch('natcap.invest.spec.CSVInput.validate', delay):
with warnings.catch_warnings(record=True) as ws:
# cause all warnings to always be triggered
warnings.simplefilter("always")
validation.validate(args, csv_spec)
self.assertEqual(len(ws), 1)
self.assertTrue('timed out' in str(ws[0].message))
def test_check_headers(self):
"""Validation: check that CSV header validation works."""
from natcap.invest import validation
expected_headers = ['hello', '1']
actual = ['hello', '1', '2']
result = spec.check_headers(expected_headers, actual)
self.assertEqual(result, None)
# each pattern should match at least one header
actual = ['1', '2']
result = spec.check_headers(expected_headers, actual)
expected_msg = validation.MESSAGES['MATCHED_NO_HEADERS'].format(
header='header', header_name='hello')
self.assertEqual(result, expected_msg)
# duplicate headers that match a pattern are not allowed
actual = ['hello', '1', '1']
result = spec.check_headers(expected_headers, actual, 'column')
expected_msg = validation.MESSAGES['DUPLICATE_HEADER'].format(
header='column', header_name='1', number=2)
self.assertEqual(result, expected_msg)
# duplicate headers that don't match a pattern are allowed
actual = ['hello', '1', 'x', 'x']
result = spec.check_headers(expected_headers, actual)
self.assertEqual(result, None)
class TestGetValidatedDataframe(unittest.TestCase):
"""Tests for validation.get_validated_dataframe."""
def setUp(self):
"""Create a new workspace to use for each test."""
self.workspace_dir = tempfile.mkdtemp()
def tearDown(self):
"""Remove the workspace created for this test."""
shutil.rmtree(self.workspace_dir)
def test_get_validated_dataframe(self):
"""validation: test the default behavior"""
from natcap.invest import validation
csv_file = os.path.join(self.workspace_dir, 'csv.csv')
with open(csv_file, 'w') as file_obj:
file_obj.write(textwrap.dedent(
"""\
header, ,
a, ,
b,c
"""
))
input_spec = CSVInput(columns=[StringInput(id='header')])
df = input_spec.get_validated_dataframe(csv_file)
# header and table values should be lowercased
self.assertEqual(df.columns[0], 'header')
self.assertEqual(df['header'][0], 'a')
self.assertEqual(df['header'][1], 'b')
def test_unique_key_not_first_column(self):
"""validation: test success when key field is not first column."""
from natcap.invest import validation
csv_text = ("desc,lucode,val1,val2\n"
"corn,1,0.5,2\n"
"bread,2,1,4\n"
"beans,3,0.5,4\n"
"butter,4,9,1")
table_path = os.path.join(self.workspace_dir, 'table.csv')
with open(table_path, 'w') as table_file:
table_file.write(csv_text)
input_spec = CSVInput(
index_col='lucode',
columns=[
StringInput(id='desc'),
IntegerInput(id='lucode'),
NumberInput(id='val1'),
NumberInput(id='val2')
])
df = input_spec.get_validated_dataframe(table_path)
self.assertEqual(df.index.name, 'lucode')
self.assertEqual(list(df.index.values), [1, 2, 3, 4])
self.assertEqual(df['desc'][2], 'bread')
def test_non_unique_keys(self):
"""validation: test error is raised if keys are not unique."""
from natcap.invest import validation
csv_text = ("lucode,desc,val1,val2\n"
"1,corn,0.5,2\n"
"2,bread,1,4\n"
"2,beans,0.5,4\n"
"4,butter,9,1")
table_path = os.path.join(self.workspace_dir, 'table.csv')
with open(table_path, 'w') as table_file:
table_file.write(csv_text)
input_spec = CSVInput(
index_col='lucode',
columns=[
StringInput(id='desc'),
IntegerInput(id='lucode'),
NumberInput(id='val1'),
NumberInput(id='val2')])
with self.assertRaises(ValueError):
input_spec.get_validated_dataframe(table_path)
def test_missing_key_field(self):
"""validation: test error is raised when missing key field."""
from natcap.invest import validation
csv_text = ("luode,desc,val1,val2\n"
"1,corn,0.5,2\n"
"2,bread,1,4\n"
"3,beans,0.5,4\n"
"4,butter,9,1")
table_path = os.path.join(self.workspace_dir, 'table.csv')
with open(table_path, 'w') as table_file:
table_file.write(csv_text)
input_spec = CSVInput(
index_col='lucode',
columns=[
StringInput(id='desc'),
IntegerInput(id='lucode'),
NumberInput(id='val1'),
NumberInput(id='val2')])
with self.assertRaises(ValueError):
input_spec.get_validated_dataframe(table_path)
def test_column_subset(self):
"""validation: test column subset is properly returned."""
from natcap.invest import validation
table_path = os.path.join(self.workspace_dir, 'table.csv')
with open(table_path, 'w') as table_file:
table_file.write(
"lucode,desc,val1,val2\n"
"1,corn,0.5,2\n"
"2,bread,1,4\n"
"3,beans,0.5,4\n"
"4,butter,9,1")
input_spec = CSVInput(columns=[
IntegerInput(id='lucode'),
NumberInput(id='val1'),
NumberInput(id='val2')
])
df = input_spec.get_validated_dataframe(table_path)
self.assertEqual(list(df.columns), ['lucode', 'val1', 'val2'])
def test_column_pattern_matching(self):
"""validation: test column subset is properly returned."""
from natcap.invest import validation
table_path = os.path.join(self.workspace_dir, 'table.csv')
with open(table_path, 'w') as table_file:
table_file.write(
"lucode,grassland_value,forest_value,wetland_valueee\n"
"1,0.5,2\n"
"2,1,4\n"
"3,0.5,4\n"
"4,9,1")
input_spec = CSVInput(columns=[
IntegerInput(id='lucode'),
NumberInput(id='[HABITAT]_value')
])
df = input_spec.get_validated_dataframe(table_path)
self.assertEqual(
list(df.columns), ['lucode', 'grassland_value', 'forest_value'])
def test_trailing_comma(self):
"""validation: test a trailing comma on first line is handled properly."""
from natcap.invest import validation
table_path = os.path.join(self.workspace_dir, 'table.csv')
with open(table_path, 'w') as table_file:
table_file.write(
"lucode,desc,val1,val2\n"
"1,corn,0.5,2,\n"
"2,bread,1,4\n"
"3,beans,0.5,4\n"
"4,butter,9,1")
input_spec = CSVInput(
columns=[
StringInput(id='desc'),
IntegerInput(id='lucode'),
NumberInput(id='val1'),
NumberInput(id='val2')])
result = input_spec.get_validated_dataframe(table_path)
self.assertEqual(result['val2'][0], 2)
self.assertEqual(result['lucode'][1], 2)
def test_trailing_comma_second_line(self):
"""validation: test a trailing comma on second line is handled properly."""
from natcap.invest import validation
csv_text = ("lucode,desc,val1,val2\n"
"1,corn,0.5,2\n"
"2,bread,1,4,\n"
"3,beans,0.5,4\n"
"4,butter,9,1")
table_path = os.path.join(self.workspace_dir, 'table.csv')
with open(table_path, 'w') as table_file:
table_file.write(csv_text)
input_spec = CSVInput(
index_col='lucode',
columns=[
StringInput(id='desc'),
IntegerInput(id='lucode'),
NumberInput(id='val1'),
NumberInput(id='val2')])
result = input_spec.get_validated_dataframe(table_path).to_dict(orient='index')
expected_result = {
1: {'desc': 'corn', 'val1': 0.5, 'val2': 2},
2: {'desc': 'bread', 'val1': 1, 'val2': 4},
3: {'desc': 'beans', 'val1': 0.5, 'val2': 4},
4: {'desc': 'butter', 'val1': 9, 'val2': 1}}
self.assertDictEqual(result, expected_result)
def test_convert_cols_to_lower(self):
"""validation: test that column names are converted to lowercase"""
from natcap.invest import validation
csv_file = os.path.join(self.workspace_dir, 'csv.csv')
with open(csv_file, 'w') as file_obj:
file_obj.write(textwrap.dedent(
"""\
header,
A,
b
"""
))
input_spec = CSVInput(columns=[StringInput(id='header')])
df = input_spec.get_validated_dataframe(csv_file)
self.assertEqual(df['header'][0], 'a')
def test_convert_vals_to_lower(self):
"""validation: test that values are converted to lowercase"""
from natcap.invest import validation
csv_file = os.path.join(self.workspace_dir, 'csv.csv')
with open(csv_file, 'w') as file_obj:
file_obj.write(textwrap.dedent(
"""\
HEADER,
a,
b
"""
))
input_spec = CSVInput(columns=[StringInput(id='header')])
df = input_spec.get_validated_dataframe(csv_file)
self.assertEqual(df.columns[0], 'header')
def test_integer_type_columns(self):
"""validation: integer column values are returned as integers."""
from natcap.invest import validation
csv_file = os.path.join(self.workspace_dir, 'csv.csv')
with open(csv_file, 'w') as file_obj:
file_obj.write(textwrap.dedent(
"""\
id,header,
1,5.0,
2,-1,
3,
"""
))
input_spec = CSVInput(columns=[
IntegerInput(id='id'),
IntegerInput(id='header')])
df = input_spec.get_validated_dataframe(csv_file)
self.assertIsInstance(df['header'][0], numpy.int64)
self.assertIsInstance(df['header'][1], numpy.int64)
# empty values are returned as pandas.NA
self.assertTrue(pandas.isna(df['header'][2]))
def test_float_type_columns(self):
"""validation: float column values are returned as floats."""
from natcap.invest import validation
csv_file = os.path.join(self.workspace_dir, 'csv.csv')
with open(csv_file, 'w') as file_obj:
file_obj.write(textwrap.dedent(
"""\
h1,h2,h3
5,0.5,.4
-1,.3,
"""
))
input_spec = CSVInput(columns=[
NumberInput(id='h1'),
RatioInput(id='h2'),
PercentInput(id='h3')
])
df = input_spec.get_validated_dataframe(csv_file)
self.assertEqual(df['h1'].dtype, float)
self.assertEqual(df['h2'].dtype, float)
self.assertEqual(df['h3'].dtype, float)
# empty values are returned as numpy.nan
self.assertTrue(numpy.isnan(df['h3'][1]))
def test_string_type_columns(self):
"""validation: string column values are returned as strings."""
from natcap.invest import validation
csv_file = os.path.join(self.workspace_dir, 'csv.csv')
with open(csv_file, 'w') as file_obj:
file_obj.write(textwrap.dedent(
"""\
h1,h2,h3
1,a,foo
2,b,
"""
))
input_spec = CSVInput(columns=[
StringInput(id='h1'),
OptionStringInput(id='h2', options=['a', 'b']),
StringInput(id='h3')
])
df = input_spec.get_validated_dataframe(csv_file)
self.assertEqual(df['h1'][0], '1')
self.assertEqual(df['h2'][1], 'b')
# empty values are returned as NA
self.assertTrue(pandas.isna(df['h3'][1]))
def test_boolean_type_columns(self):
"""validation: boolean column values are returned as booleans."""
from natcap.invest import validation
csv_file = os.path.join(self.workspace_dir, 'csv.csv')
with open(csv_file, 'w') as file_obj:
file_obj.write(textwrap.dedent(
"""\
index,h1
a,1
b,0
c,
"""
))
input_spec = CSVInput(columns=[
StringInput(id='index'),
BooleanInput(id='h1')])
df = input_spec.get_validated_dataframe(csv_file)
self.assertEqual(df['h1'][0], True)
self.assertEqual(df['h1'][1], False)
# empty values are returned as pandas.NA
self.assertTrue(pandas.isna(df['h1'][2]))
def test_expand_path_columns(self):
"""validation: test values in path columns are expanded."""
from natcap.invest import validation
csv_file = os.path.join(self.workspace_dir, 'csv.csv')
# create files so that validation will pass
open(os.path.join(self.workspace_dir, 'foo.txt'), 'w').close()
os.mkdir(os.path.join(self.workspace_dir, 'foo'))
open(os.path.join(self.workspace_dir, 'foo', 'bar.txt'), 'w').close()
with open(csv_file, 'w') as file_obj:
file_obj.write(textwrap.dedent(
f"""\
bar,path
1,foo.txt
2,foo/bar.txt
3,{self.workspace_dir}/foo.txt
4,
"""
))
input_spec = CSVInput(columns=[
IntegerInput(id='bar'),
FileInput(id='path')
])
df = input_spec.get_validated_dataframe(csv_file)
self.assertEqual(
f'{self.workspace_dir}{os.sep}foo.txt',
df['path'][0])
self.assertEqual(
f'{self.workspace_dir}{os.sep}foo{os.sep}bar.txt',
df['path'][1])
self.assertEqual(
f'{self.workspace_dir}{os.sep}foo.txt',
df['path'][2])
# empty values are returned as empty strings
self.assertTrue(pandas.isna(df['path'][3]))
def test_other_kwarg(self):
"""validation: any other kwarg should be passed to pandas.read_csv"""
from natcap.invest import validation
csv_file = os.path.join(self.workspace_dir, 'csv.csv')
with open(csv_file, 'w') as file_obj:
file_obj.write(textwrap.dedent(
"""\
h1;h2;h3
a;b;c
d;e;f
"""
))
# using sep=None with the default engine='python',
# it should infer what the separator is
input_spec = CSVInput(columns=[
StringInput(id='h1'),
StringInput(id='h2'),
StringInput(id='h3')])
df = input_spec.get_validated_dataframe(
csv_file,
read_csv_kwargs={'converters': {'h2': lambda val: f'foo_{val}'}})
self.assertEqual(df.columns[0], 'h1')
self.assertEqual(df['h2'][1], 'foo_e')
def test_csv_with_integer_headers(self):
"""
validation: CSV with integer headers should be read into strings.
This shouldn't matter for any of the models, but if a user inputs a CSV
with extra columns that are labeled with numbers, it should still work.
"""
from natcap.invest import validation
csv_file = os.path.join(self.workspace_dir, 'csv.csv')
with open(csv_file, 'w') as file_obj:
file_obj.write(textwrap.dedent(
"""\
1,2,3
a,b,c
d,e,f
"""
))
input_spec = CSVInput(columns=[
StringInput(id='1'),
StringInput(id='2'),
StringInput(id='3')
])
df = input_spec.get_validated_dataframe(csv_file)
# expect headers to be strings
self.assertEqual(df.columns[0], '1')
self.assertEqual(df['1'][0], 'a')
def test_removal_whitespace(self):
"""validation: test that leading/trailing whitespace is removed."""
from natcap.invest import validation
csv_file = os.path.join(self.workspace_dir, 'csv.csv')
with open(csv_file, 'w') as file_obj:
file_obj.write(" Col1, Col2 ,Col3 \n")
file_obj.write(" val1, val2 ,val3 \n")
file_obj.write(" , 2 1 , ")
input_spec = CSVInput(columns=[
StringInput(id='col1'),
StringInput(id='col2'),
StringInput(id='col3')
])
df = input_spec.get_validated_dataframe(csv_file)
# header should have no leading / trailing whitespace
self.assertEqual(list(df.columns), ['col1', 'col2', 'col3'])
# values should have no leading / trailing whitespace
self.assertEqual(df['col1'][0], 'val1')
self.assertEqual(df['col2'][0], 'val2')
self.assertEqual(df['col3'][0], 'val3')
self.assertEqual(df['col1'][1], '')
self.assertEqual(df['col2'][1], '2 1')
self.assertEqual(df['col3'][1], '')
def test_nan_row(self):
"""validation: test NaN row is dropped."""
from natcap.invest import validation
csv_text = ("lucode,desc,val1,val2\n"
"1,corn,0.5,2\n"
",,,\n"
"3,beans,0.5,4\n"
"4,butter,9,1")
table_path = os.path.join(self.workspace_dir, 'table.csv')
with open(table_path, 'w') as table_file:
table_file.write(csv_text)
input_spec = CSVInput(
index_col='lucode',
columns=[
StringInput(id='desc'),
IntegerInput(id='lucode'),
NumberInput(id='val1'),
NumberInput(id='val2')])
result = input_spec.get_validated_dataframe(
table_path).to_dict(orient='index')
expected_result = {
1: {'desc': 'corn', 'val1': 0.5, 'val2': 2},
3: {'desc': 'beans', 'val1': 0.5, 'val2': 4},
4: {'desc': 'butter', 'val1': 9, 'val2': 1}}
self.assertDictEqual(result, expected_result)
def test_rows(self):
"""validation: read csv with row headers instead of columns"""
from natcap.invest import validation
csv_file = os.path.join(self.workspace_dir, 'csv.csv')
with open(csv_file, 'w') as file_obj:
file_obj.write("row1, a ,b\n")
file_obj.write("row2,1,3\n")
input_spec = CSVInput(rows=[
StringInput(id='row1'),
NumberInput(id='row2')
])
df = input_spec.get_validated_dataframe(csv_file)
# header should have no leading / trailing whitespace
self.assertEqual(list(df.columns), ['row1', 'row2'])
self.assertEqual(df['row1'][0], 'a')
self.assertEqual(df['row1'][1], 'b')
self.assertEqual(df['row2'][0], 1)
self.assertEqual(df['row2'][1], 3)
self.assertEqual(df['row2'].dtype, float)
def test_csv_raster_validation_missing_file(self):
"""validation: validate missing raster within csv column"""
from natcap.invest import validation
csv_path = os.path.join(self.workspace_dir, 'csv.csv')
raster_path = os.path.join(self.workspace_dir, 'foo.tif')
with open(csv_path, 'w') as file_obj:
file_obj.write('col1,col2\n')
file_obj.write(f'1,{raster_path}\n')
input_spec = CSVInput(columns=[
NumberInput(id='col1'),
SingleBandRasterInput(id='col2')
])
with self.assertRaises(ValueError) as cm:
input_spec.get_validated_dataframe(csv_path)
self.assertIn('File not found', str(cm.exception))
def test_csv_raster_validation_not_projected(self):
"""validation: validate unprojected raster within csv column"""
from natcap.invest import validation
# create a non-linear projected raster and validate it
driver = gdal.GetDriverByName('GTiff')
csv_path = os.path.join(self.workspace_dir, 'csv.csv')
raster_path = os.path.join(self.workspace_dir, 'foo.tif')
raster = driver.Create(raster_path, 3, 3, 1, gdal.GDT_Int32)
wgs84_srs = osr.SpatialReference()
wgs84_srs.ImportFromEPSG(4326)
raster.SetProjection(wgs84_srs.ExportToWkt())
raster = None
with open(csv_path, 'w') as file_obj:
file_obj.write('col1,col2\n')
file_obj.write(f'1,{raster_path}\n')
input_spec = CSVInput(columns=[
NumberInput(id='col1'),
SingleBandRasterInput(id='col2', projected=True)
])
with self.assertRaises(ValueError) as cm:
input_spec.get_validated_dataframe(csv_path)
self.assertIn('must be projected', str(cm.exception))
def test_csv_vector_validation_missing_field(self):
"""validation: validate vector missing field in csv column"""
from natcap.invest import validation
import pygeoprocessing
from shapely.geometry import Point
srs = osr.SpatialReference()
srs.ImportFromEPSG(4326)
projection_wkt = srs.ExportToWkt()
csv_path = os.path.join(self.workspace_dir, 'csv.csv')
vector_path = os.path.join(self.workspace_dir, 'test.gpkg')
pygeoprocessing.shapely_geometry_to_vector(
[Point(0.0, 0.0)], vector_path, projection_wkt, 'GPKG',
fields={'b': ogr.OFTInteger},
attribute_list=[{'b': 0}],
ogr_geom_type=ogr.wkbPoint)
with open(csv_path, 'w') as file_obj:
file_obj.write('col1,col2\n')
file_obj.write(f'1,{vector_path}\n')
input_spec = CSVInput(columns=[
NumberInput(id='col1'),
VectorInput(
id='col2',
fields=[
IntegerInput(id='a'),
IntegerInput(id='b')
],
geometry_types=['POINT']
)
])
with self.assertRaises(ValueError) as cm:
input_spec.get_validated_dataframe(csv_path)
self.assertIn(
'Expected the field "a" but did not find it',
str(cm.exception))
def test_csv_raster_or_vector_validation(self):
"""validation: validate vector in raster-or-vector csv column"""
from natcap.invest import validation
import pygeoprocessing
from shapely.geometry import Point
srs = osr.SpatialReference()
srs.ImportFromEPSG(4326)
projection_wkt = srs.ExportToWkt()
csv_path = os.path.join(self.workspace_dir, 'csv.csv')
vector_path = os.path.join(self.workspace_dir, 'test.gpkg')
pygeoprocessing.shapely_geometry_to_vector(
[Point(0.0, 0.0)], vector_path, projection_wkt, 'GPKG',
ogr_geom_type=ogr.wkbPoint)
with open(csv_path, 'w') as file_obj:
file_obj.write('col1,col2\n')
file_obj.write(f'1,{vector_path}\n')
input_spec = CSVInput(columns=[
NumberInput(id='col1'),
RasterOrVectorInput(
id='col2',
fields={},
geometry_types=['POLYGON']
)
])
with self.assertRaises(ValueError) as cm:
input_spec.get_validated_dataframe(csv_path)
self.assertIn(
"Geometry type must be one of ['POLYGON']",
str(cm.exception))
class TestValidationFromSpec(unittest.TestCase):
"""Test Validation From Spec."""
def setUp(self):
"""Create a new workspace to use for each test."""
self.workspace_dir = tempfile.mkdtemp()
def tearDown(self):
"""Remove the workspace created for this test."""
shutil.rmtree(self.workspace_dir)
def test_conditional_requirement(self):
"""Validation: check that conditional requirements works."""
from natcap.invest import validation
model_spec = model_spec_with_defaults(inputs=[
NumberInput(id="number_a"),
NumberInput(id="number_b", required=False),
NumberInput(id="number_c", required="number_b"),
NumberInput(id="number_d", required="number_b | number_c"),
NumberInput(id="number_e", required="number_b & number_d"),
NumberInput(id="number_f", required="not number_b")
])
args = {
"number_a": 123,
"number_b": 456,
}
validation_warnings = validation.validate(args, model_spec)
self.assertEqual(sorted(validation_warnings), [
(['number_c', 'number_d'], validation.MESSAGES['MISSING_KEY']),
])
args = {
"number_a": 123,
"number_b": 456,
"number_c": 1,
"number_d": 3,
"number_e": 4,
}
self.assertEqual([], validation.validate(args, model_spec))
args = {
"number_a": 123,
}
validation_warnings = validation.validate(args, model_spec)
self.assertEqual(sorted(validation_warnings), [
(['number_f'], validation.MESSAGES['MISSING_KEY'])
])
def test_conditional_requirement_missing_var(self):
"""Validation: check AssertionError if expression is missing a var."""
from natcap.invest import validation
model_spec = model_spec_with_defaults(inputs=[
NumberInput(id="number_a"),
NumberInput(id="number_b", required=False),
NumberInput(id="number_c", required="some_var_not_in_args")
])
args = {
"number_a": 123,
"number_b": 456,
}
with self.assertRaises(AssertionError) as cm:
validation_warnings = validation.validate(args, model_spec)
self.assertTrue('some_var_not_in_args' in str(cm.exception))
def test_conditional_requirement_not_required(self):
"""Validation: unrequired conditional requirement should always pass"""
from natcap.invest import validation
csv_a_path = os.path.join(self.workspace_dir, 'csv_a.csv')
csv_b_path = os.path.join(self.workspace_dir, 'csv_b.csv')
# initialize test CSV files
with open(csv_a_path, 'w') as csv:
csv.write('a,b,c')
with open(csv_b_path, 'w') as csv:
csv.write('1,2,3')
model_spec = model_spec_with_defaults(inputs=[
BooleanInput(id="condition", required=False),
CSVInput(id="csv_a", required="condition"),
CSVInput(id="csv_b", required="not condition")
])
args = {
"condition": True,
"csv_a": csv_a_path,
# csv_b is absent, which is okay because it's not required
}
validation_warnings = validation.validate(args, model_spec)
self.assertEqual(validation_warnings, [])
def test_requirement_missing(self):
"""Validation: verify absolute requirement on missing key."""
from natcap.invest import validation
model_spec = model_spec_with_defaults(inputs=[
NumberInput(id="number_a", units=u.none)
])
args = {}
self.assertEqual(
[(['number_a'], validation.MESSAGES['MISSING_KEY'])],
validation.validate(args, model_spec))
def test_requirement_no_value(self):
"""Validation: verify absolute requirement without value."""
from natcap.invest import validation
model_spec = model_spec_with_defaults(inputs=[
NumberInput(id="number_a", units=u.none)
])
args = {'number_a': ''}
self.assertEqual(
[(['number_a'], validation.MESSAGES['MISSING_VALUE'])],
validation.validate(args, model_spec))
args = {'number_a': None}
self.assertEqual(
[(['number_a'], validation.MESSAGES['MISSING_VALUE'])],
validation.validate(args, model_spec))
def test_invalid_value(self):
"""Validation: verify invalidity."""
from natcap.invest import validation
model_spec = model_spec_with_defaults(inputs=[
NumberInput(id="number_a", units=u.none)
])
args = {'number_a': 'not a number'}
self.assertEqual(
[(['number_a'], validation.MESSAGES['NOT_A_NUMBER'].format(
value=args['number_a']))],
validation.validate(args, model_spec))
def test_conditionally_required_no_value(self):
"""Validation: verify conditional requirement when no value."""
from natcap.invest import validation
model_spec = model_spec_with_defaults(inputs=[
NumberInput(id="number_a", units=u.none),
StringInput(id="string_a", required="number_a")])
args = {'string_a': None, "number_a": 1}
self.assertEqual(
[(['string_a'], validation.MESSAGES['MISSING_VALUE'])],
validation.validate(args, model_spec))
def test_conditionally_required_invalid(self):
"""Validation: verify conditional validity behavior when invalid."""
from natcap.invest import validation
model_spec = model_spec_with_defaults(inputs=[
NumberInput(id="number_a", units=u.none),
OptionStringInput(
id="string_a",
required="number_a",
options=['AAA', 'BBB']
)
])
args = {'string_a': "ZZZ", "number_a": 1}
self.assertEqual(
[(['string_a'], validation.MESSAGES['INVALID_OPTION'].format(
option_list=model_spec.get_input('string_a').options))],
validation.validate(args, model_spec))
def test_conditionally_required_vector_fields(self):
"""Validation: conditionally required vector fields."""
from natcap.invest import validation
model_spec = model_spec_with_defaults(inputs=[
NumberInput(
id="some_number",
expression="value > 0.5",
units=u.none
),
VectorInput(
id="vector",
geometry_types=spec.POINTS,
fields=[
RatioInput(id="field_a"),
RatioInput(id="field_b", required="some_number == 2")
]
)
])
def _create_vector(filepath, fields=[]):
gpkg_driver = gdal.GetDriverByName('GPKG')
vector = gpkg_driver.Create(filepath, 0, 0, 0,
gdal.GDT_Unknown)
vector_srs = osr.SpatialReference()
vector_srs.ImportFromEPSG(4326) # WGS84
layer = vector.CreateLayer('layer', vector_srs, ogr.wkbPoint)
for fieldname in fields:
layer.CreateField(ogr.FieldDefn(fieldname, ogr.OFTReal))
new_feature = ogr.Feature(layer.GetLayerDefn())
new_feature.SetGeometry(ogr.CreateGeometryFromWkt('POINT (1 1)'))
layer = None
vector = None
vector_path = os.path.join(self.workspace_dir, 'vector1.gpkg')
_create_vector(vector_path, ['field_a'])
args = {
'some_number': 1,
'vector': vector_path,
}
validation_warnings = validation.validate(args, model_spec)
self.assertEqual(validation_warnings, [])
args = {
'some_number': 2, # trigger validation warning
'vector': vector_path,
}
validation_warnings = validation.validate(args, model_spec)
self.assertEqual(
validation_warnings,
[(['vector'], validation.MESSAGES['MATCHED_NO_HEADERS'].format(
header='field', header_name='field_b'))])
vector_path = os.path.join(self.workspace_dir, 'vector2.gpkg')
_create_vector(vector_path, ['field_a', 'field_b'])
args = {
'some_number': 2, # field_b is present, no validation warning now
'vector': vector_path,
}
validation_warnings = validation.validate(args, model_spec)
self.assertEqual(validation_warnings, [])
def test_conditionally_required_csv_columns(self):
"""Validation: conditionally required csv columns."""
from natcap.invest import validation
model_spec = model_spec_with_defaults(inputs=[
number_input_spec_with_defaults(
id="some_number",
expression="value > 0.5"
),
CSVInput(
id="csv",
columns=[
RatioInput(id="field_a"),
RatioInput(id="field_b", required="some_number == 2")
]
)
])
# Create a CSV file with only field_a
csv_path = os.path.join(self.workspace_dir, 'table1.csv')
with open(csv_path, 'w') as csv_file:
csv_file.write(textwrap.dedent(
"""\
"field_a",
1,"""))
args = {
'some_number': 1,
'csv': csv_path,
}
validation_warnings = validation.validate(args, model_spec)
self.assertEqual(validation_warnings, [])
# trigger validation warning when some_number == 2
args = {
'some_number': 2,
'csv': csv_path,
}
validation_warnings = validation.validate(args, model_spec)
self.assertEqual(
validation_warnings,
[(['csv'], validation.MESSAGES['MATCHED_NO_HEADERS'].format(
header='column', header_name='field_b'))])
# Create a CSV file with both field_a and field_b
csv_path = os.path.join(self.workspace_dir, 'table2.csv')
with open(csv_path, 'w') as csv_file:
csv_file.write(textwrap.dedent(
"""\
"field_a","field_b"
1,2"""))
args = {
'some_number': 2, # field_b is present, no validation warning now
'csv': csv_path,
}
validation_warnings = validation.validate(args, model_spec)
self.assertEqual(validation_warnings, [])
def test_conditionally_required_csv_rows(self):
"""Validation: conditionally required csv rows."""
from natcap.invest import validation
model_spec = model_spec_with_defaults(
inputs=[
number_input_spec_with_defaults(
id="some_number",
expression="value > 0.5"
),
CSVInput(
id="csv",
rows=[
RatioInput(
id="field_a",
required=True),
RatioInput(
id="field_b",
required="some_number == 2"
)
]
)
]
)
# Create a CSV file with only field_a
csv_path = os.path.join(self.workspace_dir, 'table1.csv')
with open(csv_path, 'w') as csv_file:
csv_file.write(textwrap.dedent(
""""field_a",1"""))
args = {
'some_number': 1,
'csv': csv_path,
}
validation_warnings = validation.validate(args, model_spec)
self.assertEqual(validation_warnings, [])
# trigger validation warning when some_number == 2
args = {
'some_number': 2,
'csv': csv_path,
}
validation_warnings = validation.validate(args, model_spec)
self.assertEqual(
validation_warnings,
[(['csv'], validation.MESSAGES['MATCHED_NO_HEADERS'].format(
header='row', header_name='field_b'))])
# Create a CSV file with both field_a and field_b
csv_path = os.path.join(self.workspace_dir, 'table2.csv')
with open(csv_path, 'w') as csv_file:
csv_file.write(textwrap.dedent(
"""\
"field_a",1
"field_b",2"""))
args = {
'some_number': 2, # field_b is present, no validation warning now
'csv': csv_path,
}
validation_warnings = validation.validate(args, model_spec)
self.assertEqual(validation_warnings, [])
def test_validation_exception(self):
"""Validation: Verify error when an unexpected exception occurs."""
from natcap.invest import validation
model_spec = model_spec_with_defaults(inputs=[
NumberInput(id="number_a")
])
args = {'number_a': 1}
# Patch in a new function that raises an exception
with unittest.mock.patch('natcap.invest.spec.NumberInput.validate',
Mock(side_effect=ValueError('foo'))):
validation_warnings = validation.validate(args, model_spec)
self.assertEqual(
validation_warnings,
[(['number_a'], validation.MESSAGES['UNEXPECTED_ERROR'])])
def test_conditionally_required_directory_contents(self):
"""Validation: conditionally required directory contents."""
from natcap.invest import validation
model_spec = model_spec_with_defaults(inputs=[
NumberInput(
id="some_number",
expression="value > 0.5",
units=u.none
),
DirectoryInput(
id="directory",
contents=[
CSVInput(
id="file.1",
required=True,
),
CSVInput(
id="file.2",
required="some_number == 2"
)
]
)
])
path_1 = os.path.join(self.workspace_dir, 'file.1')
with open(path_1, 'w') as my_file:
my_file.write('col1,col2')
args = {
'some_number': 1,
'directory': self.workspace_dir,
}
self.assertEqual([], validation.validate(args, model_spec))
path_2 = os.path.join(self.workspace_dir, 'file.2')
with open(path_2, 'w') as my_file:
my_file.write('col1,col2')
args = {
'some_number': 2,
'directory': self.workspace_dir,
}
self.assertEqual([], validation.validate(args, model_spec))
os.remove(path_2)
self.assertFalse(os.path.exists(path_2))
args = {
'some_number': 2,
'directory': self.workspace_dir,
}
# TODO: directory contents are not actually validated right now
self.assertEqual([], validation.validate(args, model_spec))
def test_conditional_validity_recursive(self):
"""Validation: check that we can require from nested conditions."""
from natcap.invest import validation
specs = []
previous_key = None
args = {}
for letter in string.ascii_uppercase[:10]:
key = f'arg_{letter}'
specs.append(StringInput(
id=key,
required=previous_key
))
previous_key = key
args[key] = key
del args[previous_key] # delete the last addition to the dict.
model_spec = model_spec_with_defaults(inputs=specs)
self.assertEqual(
[(['arg_J'], validation.MESSAGES['MISSING_KEY'])],
validation.validate(args, model_spec))
def test_spatial_overlap_error(self):
"""Validation: check that we return an error on spatial mismatch."""
from natcap.invest import validation
model_spec = model_spec_with_defaults(
inputs=[
SingleBandRasterInput(
id='raster_a',
data_type=float,
units=u.none
),
SingleBandRasterInput(
id='raster_b',
data_type=float,
units=u.none
),
VectorInput(
id='vector_a',
fields={},
geometry_types={'POINT'}
)
]
)
driver = gdal.GetDriverByName('GTiff')
filepath_1 = os.path.join(self.workspace_dir, 'raster_1.tif')
filepath_2 = os.path.join(self.workspace_dir, 'raster_2.tif')
reference_filepath = os.path.join(self.workspace_dir, 'reference.gpkg')
# Filepaths 1 and 2 are obviously outside of UTM zone 31N.
for filepath, geotransform, epsg_code in (
(filepath_1, [1, 1, 0, 1, 0, 1], 4326),
(filepath_2, [100, 1, 0, 100, 0, 1], 4326)):
raster = driver.Create(filepath, 3, 3, 1, gdal.GDT_Int32)
wgs84_srs = osr.SpatialReference()
wgs84_srs.ImportFromEPSG(epsg_code)
raster.SetProjection(wgs84_srs.ExportToWkt())
raster.SetGeoTransform(geotransform)
raster = None
gpkg_driver = gdal.GetDriverByName('GPKG')
vector = gpkg_driver.Create(reference_filepath, 0, 0, 0,
gdal.GDT_Unknown)
vector_srs = osr.SpatialReference()
vector_srs.ImportFromEPSG(32731) # UTM 31N
layer = vector.CreateLayer('layer', vector_srs, ogr.wkbPoint)
new_feature = ogr.Feature(layer.GetLayerDefn())
new_feature.SetGeometry(ogr.CreateGeometryFromWkt('POINT (1 1)'))
layer.CreateFeature(new_feature)
new_feature = None
layer = None
vector = None
args = {
'raster_a': filepath_1,
'raster_b': filepath_2,
'vector_a': reference_filepath,
}
validation_warnings = validation.validate(args, model_spec)
self.assertEqual(len(validation_warnings), 1)
self.assertEqual(set(args.keys()), set(validation_warnings[0][0]))
formatted_bbox_list = '' # allows str matching w/o real bbox str
self.assertTrue(
validation.MESSAGES['BBOX_NOT_INTERSECT'].format(
bboxes=formatted_bbox_list) in validation_warnings[0][1])
def test_spatial_overlap_error_undefined_projection(self):
"""Validation: check spatial overlap message when no projection"""
from natcap.invest import validation
model_spec = model_spec_with_defaults(
inputs=[
SingleBandRasterInput(
id='raster_a',
data_type=float,
units=u.none
),
SingleBandRasterInput(
id='raster_b',
data_type=float,
units=u.none
)
]
)
driver = gdal.GetDriverByName('GTiff')
filepath_1 = os.path.join(self.workspace_dir, 'raster_1.tif')
filepath_2 = os.path.join(self.workspace_dir, 'raster_2.tif')
raster_1 = driver.Create(filepath_1, 3, 3, 1, gdal.GDT_Int32)
wgs84_srs = osr.SpatialReference()
wgs84_srs.ImportFromEPSG(4326)
raster_1.SetProjection(wgs84_srs.ExportToWkt())
raster_1.SetGeoTransform([1, 1, 0, 1, 0, 1])
raster_1 = None
# don't define a projection for the second raster
driver.Create(filepath_2, 3, 3, 1, gdal.GDT_Int32)
args = {
'raster_a': filepath_1,
'raster_b': filepath_2
}
validation_warnings = validation.validate(args, model_spec)
expected = [(['raster_b'], validation.MESSAGES['INVALID_PROJECTION'])]
self.assertEqual(validation_warnings, expected)
def test_spatial_overlap_error_optional_args(self):
"""Validation: check for spatial mismatch with insufficient args."""
from natcap.invest import validation
model_spec = model_spec_with_defaults(
inputs=[
SingleBandRasterInput(
id='raster_a',
data_type=float,
units=u.none
),
SingleBandRasterInput(
id='raster_b',
data_type=float,
units=u.none,
required=False
),
VectorInput(
id='vector_a',
required=False,
fields=[],
geometry_types={'POINT'}
)
]
)
driver = gdal.GetDriverByName('GTiff')
filepath_1 = os.path.join(self.workspace_dir, 'raster_1.tif')
filepath_2 = os.path.join(self.workspace_dir, 'raster_2.tif')
# Filepaths 1 and 2 do not overlap
for filepath, geotransform, epsg_code in (
(filepath_1, [1, 1, 0, 1, 0, 1], 4326),
(filepath_2, [100, 1, 0, 100, 0, 1], 4326)):
raster = driver.Create(filepath, 3, 3, 1, gdal.GDT_Int32)
wgs84_srs = osr.SpatialReference()
wgs84_srs.ImportFromEPSG(epsg_code)
raster.SetProjection(wgs84_srs.ExportToWkt())
raster.SetGeoTransform(geotransform)
raster = None
args = {
'raster_a': filepath_1,
}
# There should not be a spatial overlap check at all
# when less than 2 of the spatial keys are sufficient.
validation_warnings = validation.validate(args, model_spec)
self.assertEqual(len(validation_warnings), 0)
# And even though there are three spatial keys in the spec,
# Only the ones checked should appear in the validation output
args = {
'raster_a': filepath_1,
'raster_b': filepath_2,
}
validation_warnings = validation.validate(args, model_spec)
self.assertEqual(len(validation_warnings), 1)
formatted_bbox_list = '' # allows str matching w/o real bbox str
self.assertTrue(
validation.MESSAGES['BBOX_NOT_INTERSECT'].format(
bboxes=formatted_bbox_list) in validation_warnings[0][1])
self.assertEqual(set(args.keys()), set(validation_warnings[0][0]))
def test_allow_extra_keys(self):
"""Including extra keys in args that aren't in MODEL_SPEC should work"""
from natcap.invest import validation
args = {'a': 'a', 'b': 'b'}
model_spec = model_spec_with_defaults(inputs=[StringInput(id='a')])
message = 'DEBUG:natcap.invest.validation:Provided key b does not exist in MODEL_SPEC'
with self.assertLogs('natcap.invest.validation', level='DEBUG') as cm:
validation.validate(args, model_spec)
self.assertTrue(message in cm.output)
def test_check_ratio(self):
"""Validation: test ratio type validation."""
from natcap.invest import validation
args = {
'a': 'xyz', # not a number
'b': '1.5', # too large
'c': '-1', # too small
'd': '0', # lower bound
'e': '0.5', # middle
'f': '1' # upper bound
}
model_spec = model_spec_with_defaults(inputs=[RatioInput(id=name) for name in args])
expected_warnings = [
(['a'], validation.MESSAGES['NOT_A_NUMBER'].format(value=args['a'])),
(['b'], validation.MESSAGES['NOT_WITHIN_RANGE'].format(
value=args['b'], range='[0, 1]')),
(['c'], validation.MESSAGES['NOT_WITHIN_RANGE'].format(
value=float(args['c']), range='[0, 1]'))]
actual_warnings = validation.validate(args, model_spec)
for warning in actual_warnings:
self.assertTrue(warning in expected_warnings)
def test_check_percent(self):
"""Validation: test percent type validation."""
from natcap.invest import validation
args = {
'a': 'xyz', # not a number
'b': '100.5', # too large
'c': '-1', # too small
'd': '0', # lower bound
'e': '55.5', # middle
'f': '100' # upper bound
}
model_spec = model_spec_with_defaults(
inputs=[PercentInput(id=name) for name in args])
expected_warnings = [
(['a'], validation.MESSAGES['NOT_A_NUMBER'].format(value=args['a'])),
(['b'], validation.MESSAGES['NOT_WITHIN_RANGE'].format(
value=args['b'], range='[0, 100]')),
(['c'], validation.MESSAGES['NOT_WITHIN_RANGE'].format(
value=float(args['c']), range='[0, 100]'))]
actual_warnings = validation.validate(args, model_spec)
for warning in actual_warnings:
self.assertTrue(warning in expected_warnings)
def test_check_integer(self):
"""Validation: test integer type validation."""
from natcap.invest import validation
args = {
'a': 'xyz', # not a number
'b': '1.5', # not an integer
'c': '-1', # negative integers are ok
'd': '0'
}
model_spec = model_spec_with_defaults(inputs=[
IntegerInput(id=name) for name in args])
expected_warnings = [
(['a'], validation.MESSAGES['NOT_A_NUMBER'].format(value=args['a'])),
(['b'], validation.MESSAGES['NOT_AN_INTEGER'].format(value=args['b']))]
actual_warnings = validation.validate(args, model_spec)
self.assertEqual(len(actual_warnings), len(expected_warnings))
for warning in actual_warnings:
self.assertTrue(warning in expected_warnings)
class TestArgsEnabled(unittest.TestCase):
def test_args_enabled(self):
"""Validation: test getting args enabled/disabled status."""
from natcap.invest import validation
model_spec = model_spec_with_defaults(inputs=[
Input(id='a'),
Input(id='b', allowed='a'),
Input(id='c', allowed='not a'),
Input(id='d', allowed='b <= 3')
])
args = {
'a': 'foo',
'b': 2,
'c': 'bar',
'd': None
}
self.assertEqual(
validation.args_enabled(args, model_spec),
{
'a': True,
'b': True,
'c': False,
'd': True
}
)