redesign SingleBandRasterInput and create RasterBand and RasterInput classes
This commit is contained in:
parent
a78e0df99f
commit
14d696e90e
|
@ -274,15 +274,32 @@ class FileInput(Input):
|
|||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class SingleBandRasterInput(FileInput):
|
||||
class RasterBand(FileInput):
|
||||
"""A single-band raster input, or parameter, of an invest model.
|
||||
|
||||
This represents a raster file input (all GDAL-supported raster file types
|
||||
are allowed), where only the first band is needed.
|
||||
|
||||
Attributes:
|
||||
band: An `Input` representing the type of data expected in the
|
||||
raster's first and only band
|
||||
band_id: band index used to access the raster band
|
||||
data_type: float or int
|
||||
units: units of measurement of the raster band values
|
||||
"""
|
||||
band_id: typing.Union[int, str] = 1
|
||||
data_type: typing.Type = float
|
||||
units: typing.Union[pint.Unit, None] = None
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class RasterInput(FileInput):
|
||||
"""A raster input, or parameter, of an invest model.
|
||||
|
||||
This represents a raster file input (all GDAL-supported raster file types
|
||||
are allowed), which may have multiple bands.
|
||||
|
||||
Attributes:
|
||||
bands: An iterable of `RasterBand`s representing the bands expected
|
||||
to be in the raster.
|
||||
projected: Defaults to None, indicating a projected (as opposed to
|
||||
geographic) coordinate system is not required. Set to True if a
|
||||
projected coordinate system is required.
|
||||
|
@ -290,7 +307,63 @@ class SingleBandRasterInput(FileInput):
|
|||
specific unit of projection (such as meters) is required, indicate
|
||||
it here.
|
||||
"""
|
||||
band: typing.Union[Input, None] = None
|
||||
bands: typing.Iterable[RasterBand] = dataclasses.field(default_factory=[])
|
||||
projected: typing.Union[bool, None] = None
|
||||
projection_units: typing.Union[pint.Unit, None] = None
|
||||
type: typing.ClassVar[str] = 'raster'
|
||||
|
||||
@timeout
|
||||
def validate(self, filepath: str):
|
||||
"""Validate a raster file against the requirements for this input.
|
||||
|
||||
Args:
|
||||
filepath (string): The filepath to validate.
|
||||
|
||||
Returns:
|
||||
A string error message if an error was found. ``None`` otherwise.
|
||||
"""
|
||||
# use FileInput instead of super() because when this is called from
|
||||
# RasterOrVectorInput.validate, super() refers to multiple parents.
|
||||
file_warning = FileInput.validate(self, filepath)
|
||||
if file_warning:
|
||||
return file_warning
|
||||
|
||||
try:
|
||||
gdal_dataset = gdal.OpenEx(filepath, gdal.OF_RASTER)
|
||||
except RuntimeError:
|
||||
return get_message('NOT_GDAL_RASTER')
|
||||
|
||||
# Check that an overview .ovr file wasn't opened.
|
||||
if os.path.splitext(filepath)[1] == '.ovr':
|
||||
return get_message('OVR_FILE')
|
||||
|
||||
srs = gdal_dataset.GetSpatialRef()
|
||||
projection_warning = _check_projection(srs, self.projected, self.projection_units)
|
||||
if projection_warning:
|
||||
return projection_warning
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class SingleBandRasterInput(FileInput):
|
||||
"""A single-band raster input, or parameter, of an invest model.
|
||||
|
||||
This represents a raster file input (all GDAL-supported raster file types
|
||||
are allowed), where only the first band is needed. While he same thing can
|
||||
be achieved using a `RasterInput`, this class exists to simplify access to
|
||||
the band properties when there is only one band.
|
||||
|
||||
Attributes:
|
||||
data_type: float or int
|
||||
units: units of measurement of the raster values
|
||||
projected: Defaults to None, indicating a projected (as opposed to
|
||||
geographic) coordinate system is not required. Set to True if a
|
||||
projected coordinate system is required.
|
||||
projection_units: Defaults to None. If `projected` is `True`, and a
|
||||
specific unit of projection (such as meters) is required, indicate
|
||||
it here.
|
||||
"""
|
||||
data_type: typing.Type = float
|
||||
units: typing.Union[pint.Unit, None] = None
|
||||
projected: typing.Union[bool, None] = None
|
||||
projection_units: typing.Union[pint.Unit, None] = None
|
||||
type: typing.ClassVar[str] = 'raster'
|
||||
|
@ -989,10 +1062,11 @@ class SingleBandRasterOutput(Output):
|
|||
are allowed), where only the first band is used.
|
||||
|
||||
Attributes:
|
||||
band: An `Output` representing the type of data produced in the
|
||||
raster's first and only band
|
||||
data_type: float or int
|
||||
units: units of measurement of the raster values
|
||||
"""
|
||||
band: typing.Union[Output, None] = None
|
||||
data_type: typing.Type = float
|
||||
units: typing.Union[pint.Unit, None] = None
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
|
@ -1168,6 +1242,10 @@ class ModelSpec:
|
|||
return as_dict
|
||||
elif isinstance(obj, IterableWithDotAccess):
|
||||
return obj.to_json()
|
||||
elif obj is int:
|
||||
return 'integer'
|
||||
elif obj is float:
|
||||
return 'number'
|
||||
raise TypeError(f'fallback serializer is missing for {type(obj)}')
|
||||
|
||||
spec_dict = self.__dict__.copy()
|
||||
|
@ -1241,7 +1319,8 @@ def build_input_spec(argkey, arg):
|
|||
elif t == 'raster':
|
||||
return SingleBandRasterInput(
|
||||
**base_attrs,
|
||||
band=build_input_spec('1', arg['bands'][1]),
|
||||
data_type=int if arg['bands'][1]['type'] == 'integer' else float,
|
||||
units=arg['bands'][1].get('units', None),
|
||||
projected=arg.get('projected', None),
|
||||
projection_units=arg.get('projection_units', None))
|
||||
|
||||
|
@ -1287,7 +1366,8 @@ def build_input_spec(argkey, arg):
|
|||
geometry_types=arg['geometries'],
|
||||
fields=[build_input_spec(key, field_spec)
|
||||
for key, field_spec in arg['fields'].items()],
|
||||
band=build_input_spec('1', arg['bands'][1]),
|
||||
data_type=int if arg['bands'][1]['type'] == 'integer' else float,
|
||||
units=arg['bands'][1].get('units', None),
|
||||
projected=arg.get('projected', None),
|
||||
projection_units=arg.get('projection_units', None))
|
||||
|
||||
|
@ -1337,7 +1417,8 @@ def build_output_spec(key, spec):
|
|||
elif t == 'raster':
|
||||
return SingleBandRasterOutput(
|
||||
**base_attrs,
|
||||
band=build_output_spec(1, spec['bands'][1]))
|
||||
data_type=int if spec['bands'][1]['type'] == 'integer' else float,
|
||||
units=spec['bands'][1].get('units', None))
|
||||
|
||||
elif t == 'vector':
|
||||
return VectorOutput(
|
||||
|
@ -1596,6 +1677,9 @@ def format_unit(unit):
|
|||
Returns:
|
||||
String describing the unit.
|
||||
"""
|
||||
if unit is None:
|
||||
return ''
|
||||
|
||||
if not isinstance(unit, pint.Unit):
|
||||
raise TypeError(
|
||||
f'{unit} is of type {type(unit)}. '
|
||||
|
@ -1636,36 +1720,6 @@ def format_unit(unit):
|
|||
return formatted_unit
|
||||
|
||||
|
||||
def serialize_args_spec(spec):
|
||||
"""Serialize an MODEL_SPEC dict to a JSON string.
|
||||
|
||||
Args:
|
||||
spec (dict): An invest model's MODEL_SPEC.
|
||||
|
||||
Raises:
|
||||
TypeError if any object type within the spec is not handled by
|
||||
json.dumps or by the fallback serializer.
|
||||
|
||||
Returns:
|
||||
JSON String
|
||||
"""
|
||||
|
||||
def fallback_serializer(obj):
|
||||
"""Serialize objects that are otherwise not JSON serializeable."""
|
||||
if isinstance(obj, pint.Unit):
|
||||
return format_unit(obj)
|
||||
# Sets are present in 'geometry_types' attributes of some args
|
||||
# We don't need to worry about deserializing back to a set/array
|
||||
# so casting to string is okay.
|
||||
elif isinstance(obj, set):
|
||||
return str(obj)
|
||||
elif isinstance(obj, types.FunctionType):
|
||||
return str(obj)
|
||||
raise TypeError(f'fallback serializer is missing for {type(obj)}')
|
||||
|
||||
return json.dumps(spec, default=fallback_serializer)
|
||||
|
||||
|
||||
# accepted geometry_types for a vector will be displayed in this order
|
||||
GEOMETRY_ORDER = [
|
||||
'POINT',
|
||||
|
@ -1868,11 +1922,7 @@ def describe_arg_from_spec(name, spec):
|
|||
in_parentheses = [type_string]
|
||||
|
||||
# For numbers and rasters that have units, display the units
|
||||
units = None
|
||||
if spec.__class__ is NumberInput:
|
||||
units = spec.units
|
||||
elif spec.__class__ is SingleBandRasterInput and spec.band.__class__ is NumberInput:
|
||||
units = spec.band.units
|
||||
units = spec.units if hasattr(spec, 'units') else None
|
||||
if units:
|
||||
units_string = format_unit(units)
|
||||
if units_string:
|
||||
|
@ -2005,7 +2055,6 @@ def write_metadata_file(datasource_path, spec, keywords_list,
|
|||
except ValueError as e:
|
||||
LOGGER.debug(f"Skipping metadata creation for {datasource_path}: {e}")
|
||||
return None
|
||||
|
||||
resource.set_lineage(lineage_statement)
|
||||
# a pre-existing metadata doc could have keywords
|
||||
words = resource.get_keywords()
|
||||
|
@ -2039,13 +2088,9 @@ def write_metadata_file(datasource_path, spec, keywords_list,
|
|||
# fields that are in the spec but missing
|
||||
# from model results because they are conditional.
|
||||
LOGGER.debug(error)
|
||||
|
||||
if hasattr(spec, 'band'):
|
||||
if isinstance(spec, SingleBandRasterInput) or isinstance(spec, SingleBandRasterOutput):
|
||||
if len(resource.get_band_description(1).units) < 1:
|
||||
try:
|
||||
units = format_unit(spec.band.units)
|
||||
except AttributeError:
|
||||
units = ''
|
||||
units = format_unit(spec.units)
|
||||
resource.set_band_description(1, units=units)
|
||||
|
||||
resource.write(workspace=out_workspace)
|
||||
|
|
|
@ -8,7 +8,7 @@ MODEL_SPEC = spec.ModelSpec(inputs=[
|
|||
spec.FileInput(id='foo'),
|
||||
spec.FileInput(id='bar'),
|
||||
spec.DirectoryInput(id='data_dir', contents={}),
|
||||
spec.SingleBandRasterInput(id='raster', band=spec.Input()),
|
||||
spec.SingleBandRasterInput(id='raster'),
|
||||
spec.VectorInput(id='vector', fields={}, geometry_types={}),
|
||||
spec.CSVInput(id='simple_table'),
|
||||
spec.CSVInput(
|
||||
|
@ -18,8 +18,7 @@ MODEL_SPEC = spec.ModelSpec(inputs=[
|
|||
spec.RasterOrVectorInput(
|
||||
id='path',
|
||||
fields={},
|
||||
geometry_types={'POINT', 'POLYGON'},
|
||||
band=spec.NumberInput()
|
||||
geometry_types={'POINT', 'POLYGON'}
|
||||
)
|
||||
]
|
||||
)],
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
from natcap.invest import spec
|
||||
|
||||
MODEL_SPEC = spec.ModelSpec(inputs=[
|
||||
spec.SingleBandRasterInput(id='raster', band=spec.Input())],
|
||||
spec.SingleBandRasterInput(id='raster')],
|
||||
outputs={},
|
||||
model_id='raster_model',
|
||||
model_title='',
|
||||
|
|
|
@ -192,13 +192,8 @@ class ValidateModelSpecs(unittest.TestCase):
|
|||
self.assertIsInstance(output_spec.units, pint.Unit)
|
||||
|
||||
elif t is spec.SingleBandRasterOutput:
|
||||
# raster type should have a bands property that maps each band
|
||||
# index to a nested type dictionary describing the band's data
|
||||
self.assertTrue(hasattr(output_spec, 'band'))
|
||||
self.validate_output(
|
||||
output_spec.band,
|
||||
f'{key}.band',
|
||||
parent_type=t)
|
||||
self.assertTrue(hasattr(output_spec, 'data_type'))
|
||||
self.assertTrue(hasattr(output_spec, 'units'))
|
||||
|
||||
elif t is spec.VectorOutput:
|
||||
# vector type should have:
|
||||
|
@ -347,13 +342,8 @@ class ValidateModelSpecs(unittest.TestCase):
|
|||
self.assertIsInstance(arg.expression, str)
|
||||
|
||||
elif t is spec.SingleBandRasterInput:
|
||||
# raster type should have a bands property that maps each band
|
||||
# index to a nested type dictionary describing the band's data
|
||||
self.assertTrue(hasattr(arg, 'band'))
|
||||
self.validate_args(
|
||||
arg.band,
|
||||
f'{name}.band',
|
||||
parent_type=t)
|
||||
self.assertTrue(hasattr(arg, 'data_type'))
|
||||
self.assertTrue(hasattr(arg, 'units'))
|
||||
|
||||
# may optionally have a 'projected' attribute that says
|
||||
# whether the raster must be linearly projected
|
||||
|
|
|
@ -147,7 +147,7 @@ class TestDescribeArgFromSpec(unittest.TestCase):
|
|||
|
||||
def test_raster_spec(self):
|
||||
raster_spec = spec.SingleBandRasterInput(
|
||||
band=spec.IntegerInput(),
|
||||
data_type=int,
|
||||
about="Description",
|
||||
name="Bar"
|
||||
)
|
||||
|
@ -158,7 +158,8 @@ class TestDescribeArgFromSpec(unittest.TestCase):
|
|||
self.assertEqual(repr(out), repr(expected_rst))
|
||||
|
||||
raster_spec = spec.SingleBandRasterInput(
|
||||
band=spec.NumberInput(units=u.millimeter/u.year),
|
||||
data_type=float,
|
||||
units=u.millimeter/u.year,
|
||||
about="Description",
|
||||
name="Bar"
|
||||
)
|
||||
|
@ -250,7 +251,7 @@ class TestDescribeArgFromSpec(unittest.TestCase):
|
|||
multi_spec = spec.RasterOrVectorInput(
|
||||
about="Description",
|
||||
name="Bar",
|
||||
band=spec.IntegerInput(),
|
||||
data_type=int,
|
||||
geometry_types={"POLYGON"},
|
||||
fields={}
|
||||
)
|
||||
|
@ -281,12 +282,12 @@ def _generate_files_from_spec(output_spec, workspace):
|
|||
spec_data.contents, os.path.join(workspace, spec_data.id))
|
||||
else:
|
||||
filepath = os.path.join(workspace, spec_data.id)
|
||||
if hasattr(spec_data, 'band'):
|
||||
if isinstance(spec_data, spec.SingleBandRasterOutput):
|
||||
driver = gdal.GetDriverByName('GTIFF')
|
||||
raster = driver.Create(filepath, 2, 2, 1, gdal.GDT_Byte)
|
||||
band = raster.GetRasterBand(1)
|
||||
band.SetNoDataValue(2)
|
||||
elif hasattr(spec_data, 'fields'):
|
||||
elif isinstance(spec_data, spec.VectorOutput):
|
||||
driver = gdal.GetDriverByName('GPKG')
|
||||
target_vector = driver.CreateDataSource(filepath)
|
||||
layer_name = os.path.basename(os.path.splitext(filepath)[0])
|
||||
|
@ -322,7 +323,8 @@ class TestMetadataFromSpec(unittest.TestCase):
|
|||
spec.SingleBandRasterOutput(
|
||||
id="urban_nature_supply_percapita.tif",
|
||||
about="The calculated supply per capita of urban nature.",
|
||||
band=spec.NumberInput(units=u.m**2)
|
||||
data_type=float,
|
||||
units=u.m**2
|
||||
),
|
||||
spec.VectorOutput(
|
||||
id="admin_boundaries.gpkg",
|
||||
|
|
|
@ -68,10 +68,10 @@ class UsageLoggingTests(unittest.TestCase):
|
|||
model_spec = spec.ModelSpec(
|
||||
model_id='', model_title='', userguide=None, aliases=None,
|
||||
inputs=[
|
||||
spec.SingleBandRasterInput(id='raster', band=spec.Input()),
|
||||
spec.SingleBandRasterInput(id='raster'),
|
||||
spec.VectorInput(id='vector', geometry_types={}, fields={}),
|
||||
spec.StringInput(id='not_a_gis_input'),
|
||||
spec.SingleBandRasterInput(id='blank_raster_path', band=spec.Input()),
|
||||
spec.SingleBandRasterInput(id='blank_raster_path'),
|
||||
spec.VectorInput(id='blank_vector_path', geometry_types={}, fields={})
|
||||
],
|
||||
outputs={},
|
||||
|
|
|
@ -21,21 +21,21 @@ from osgeo import osr
|
|||
from natcap.invest import spec
|
||||
from natcap.invest.spec import (
|
||||
u,
|
||||
ModelSpec,
|
||||
Input,
|
||||
FileInput,
|
||||
CSVInput,
|
||||
StringInput,
|
||||
OptionStringInput,
|
||||
SingleBandRasterInput,
|
||||
RasterOrVectorInput,
|
||||
DirectoryInput,
|
||||
VectorInput,
|
||||
BooleanInput,
|
||||
NumberInput,
|
||||
CSVInput,
|
||||
DirectoryInput,
|
||||
FileInput,
|
||||
Input,
|
||||
IntegerInput,
|
||||
ModelSpec,
|
||||
NumberInput,
|
||||
OptionStringInput,
|
||||
PercentInput,
|
||||
RasterOrVectorInput,
|
||||
RatioInput,
|
||||
PercentInput)
|
||||
SingleBandRasterInput,
|
||||
StringInput,
|
||||
VectorInput)
|
||||
|
||||
gdal.UseExceptions()
|
||||
|
||||
|
@ -520,7 +520,7 @@ class RasterValidation(unittest.TestCase):
|
|||
from natcap.invest import validation
|
||||
|
||||
filepath = os.path.join(self.workspace_dir, 'file.txt')
|
||||
error_msg = SingleBandRasterInput(band=Input()).validate(filepath)
|
||||
error_msg = SingleBandRasterInput().validate(filepath)
|
||||
self.assertEqual(error_msg, validation.MESSAGES['FILE_NOT_FOUND'])
|
||||
|
||||
def test_invalid_raster(self):
|
||||
|
@ -531,7 +531,7 @@ class RasterValidation(unittest.TestCase):
|
|||
with open(filepath, 'w') as bad_raster:
|
||||
bad_raster.write('not a raster')
|
||||
|
||||
error_msg = SingleBandRasterInput(band=Input()).validate(filepath)
|
||||
error_msg = SingleBandRasterInput().validate(filepath)
|
||||
self.assertEqual(error_msg, validation.MESSAGES['NOT_GDAL_RASTER'])
|
||||
|
||||
def test_invalid_ovr_raster(self):
|
||||
|
@ -555,8 +555,7 @@ class RasterValidation(unittest.TestCase):
|
|||
raster = None
|
||||
|
||||
filepath_ovr = os.path.join(self.workspace_dir, 'raster.tif.ovr')
|
||||
error_msg = SingleBandRasterInput(
|
||||
band=Input()).validate(filepath_ovr)
|
||||
error_msg = SingleBandRasterInput().validate(filepath_ovr)
|
||||
self.assertEqual(error_msg, validation.MESSAGES['OVR_FILE'])
|
||||
|
||||
def test_raster_not_projected(self):
|
||||
|
@ -572,8 +571,7 @@ class RasterValidation(unittest.TestCase):
|
|||
raster.SetProjection(wgs84_srs.ExportToWkt())
|
||||
raster = None
|
||||
|
||||
error_msg = SingleBandRasterInput(
|
||||
band=Input(), projected=True).validate(filepath)
|
||||
error_msg = SingleBandRasterInput(projected=True).validate(filepath)
|
||||
self.assertEqual(error_msg, validation.MESSAGES['NOT_PROJECTED'])
|
||||
|
||||
def test_raster_incorrect_units(self):
|
||||
|
@ -591,7 +589,7 @@ class RasterValidation(unittest.TestCase):
|
|||
raster = None
|
||||
|
||||
error_msg = SingleBandRasterInput(
|
||||
band=Input(), projected=True, projection_units=spec.u.meter
|
||||
projected=True, projection_units=spec.u.meter
|
||||
).validate(filepath)
|
||||
expected_msg = validation.MESSAGES['WRONG_PROJECTION_UNIT'].format(
|
||||
unit_a='meter', unit_b='us_survey_foot')
|
||||
|
@ -1491,7 +1489,7 @@ class TestGetValidatedDataframe(unittest.TestCase):
|
|||
|
||||
spec = CSVInput(columns=[
|
||||
NumberInput(id='col1'),
|
||||
SingleBandRasterInput(id='col2', band=NumberInput())
|
||||
SingleBandRasterInput(id='col2')
|
||||
])
|
||||
with self.assertRaises(ValueError) as cm:
|
||||
spec.get_validated_dataframe(csv_path)
|
||||
|
@ -1516,8 +1514,7 @@ class TestGetValidatedDataframe(unittest.TestCase):
|
|||
|
||||
spec = CSVInput(columns=[
|
||||
NumberInput(id='col1'),
|
||||
SingleBandRasterInput(
|
||||
id='col2', projected=True, band=NumberInput())
|
||||
SingleBandRasterInput(id='col2', projected=True)
|
||||
])
|
||||
with self.assertRaises(ValueError) as cm:
|
||||
spec.get_validated_dataframe(csv_path)
|
||||
|
@ -1584,7 +1581,6 @@ class TestGetValidatedDataframe(unittest.TestCase):
|
|||
NumberInput(id='col1'),
|
||||
RasterOrVectorInput(
|
||||
id='col2',
|
||||
band=NumberInput(),
|
||||
fields={},
|
||||
geometry_types=['POLYGON']
|
||||
)
|
||||
|
@ -2041,11 +2037,13 @@ class TestValidationFromSpec(unittest.TestCase):
|
|||
inputs=[
|
||||
SingleBandRasterInput(
|
||||
id='raster_a',
|
||||
band=NumberInput(units=u.none)
|
||||
data_type=float,
|
||||
units=u.none
|
||||
),
|
||||
SingleBandRasterInput(
|
||||
id='raster_b',
|
||||
band=NumberInput(units=u.none)
|
||||
data_type=float,
|
||||
units=u.none
|
||||
),
|
||||
VectorInput(
|
||||
id='vector_a',
|
||||
|
@ -2111,11 +2109,13 @@ class TestValidationFromSpec(unittest.TestCase):
|
|||
inputs=[
|
||||
SingleBandRasterInput(
|
||||
id='raster_a',
|
||||
band=NumberInput(units=u.none)
|
||||
data_type=float,
|
||||
units=u.none
|
||||
),
|
||||
SingleBandRasterInput(
|
||||
id='raster_b',
|
||||
band=NumberInput(units=u.none)
|
||||
data_type=float,
|
||||
units=u.none
|
||||
)
|
||||
],
|
||||
args_with_spatial_overlap={
|
||||
|
@ -2155,11 +2155,13 @@ class TestValidationFromSpec(unittest.TestCase):
|
|||
inputs=[
|
||||
SingleBandRasterInput(
|
||||
id='raster_a',
|
||||
band=NumberInput(units=u.none)
|
||||
data_type=float,
|
||||
units=u.none
|
||||
),
|
||||
SingleBandRasterInput(
|
||||
id='raster_b',
|
||||
band=NumberInput(units=u.none),
|
||||
data_type=float,
|
||||
units=u.none,
|
||||
required=False
|
||||
),
|
||||
VectorInput(
|
||||
|
|
Loading…
Reference in New Issue