redesign SingleBandRasterInput and create RasterBand and RasterInput classes

This commit is contained in:
Emily Soth 2025-05-13 18:40:21 -07:00
parent a78e0df99f
commit 14d696e90e
7 changed files with 145 additions and 107 deletions

View File

@ -274,15 +274,32 @@ class FileInput(Input):
@dataclasses.dataclass
class SingleBandRasterInput(FileInput):
class RasterBand(FileInput):
"""A single-band raster input, or parameter, of an invest model.
This represents a raster file input (all GDAL-supported raster file types
are allowed), where only the first band is needed.
Attributes:
band: An `Input` representing the type of data expected in the
raster's first and only band
band_id: band index used to access the raster band
data_type: float or int
units: units of measurement of the raster band values
"""
band_id: typing.Union[int, str] = 1
data_type: typing.Type = float
units: typing.Union[pint.Unit, None] = None
@dataclasses.dataclass
class RasterInput(FileInput):
"""A raster input, or parameter, of an invest model.
This represents a raster file input (all GDAL-supported raster file types
are allowed), which may have multiple bands.
Attributes:
bands: An iterable of `RasterBand`s representing the bands expected
to be in the raster.
projected: Defaults to None, indicating a projected (as opposed to
geographic) coordinate system is not required. Set to True if a
projected coordinate system is required.
@ -290,7 +307,63 @@ class SingleBandRasterInput(FileInput):
specific unit of projection (such as meters) is required, indicate
it here.
"""
band: typing.Union[Input, None] = None
bands: typing.Iterable[RasterBand] = dataclasses.field(default_factory=[])
projected: typing.Union[bool, None] = None
projection_units: typing.Union[pint.Unit, None] = None
type: typing.ClassVar[str] = 'raster'
@timeout
def validate(self, filepath: str):
"""Validate a raster file against the requirements for this input.
Args:
filepath (string): The filepath to validate.
Returns:
A string error message if an error was found. ``None`` otherwise.
"""
# use FileInput instead of super() because when this is called from
# RasterOrVectorInput.validate, super() refers to multiple parents.
file_warning = FileInput.validate(self, filepath)
if file_warning:
return file_warning
try:
gdal_dataset = gdal.OpenEx(filepath, gdal.OF_RASTER)
except RuntimeError:
return get_message('NOT_GDAL_RASTER')
# Check that an overview .ovr file wasn't opened.
if os.path.splitext(filepath)[1] == '.ovr':
return get_message('OVR_FILE')
srs = gdal_dataset.GetSpatialRef()
projection_warning = _check_projection(srs, self.projected, self.projection_units)
if projection_warning:
return projection_warning
@dataclasses.dataclass
class SingleBandRasterInput(FileInput):
"""A single-band raster input, or parameter, of an invest model.
This represents a raster file input (all GDAL-supported raster file types
are allowed), where only the first band is needed. While he same thing can
be achieved using a `RasterInput`, this class exists to simplify access to
the band properties when there is only one band.
Attributes:
data_type: float or int
units: units of measurement of the raster values
projected: Defaults to None, indicating a projected (as opposed to
geographic) coordinate system is not required. Set to True if a
projected coordinate system is required.
projection_units: Defaults to None. If `projected` is `True`, and a
specific unit of projection (such as meters) is required, indicate
it here.
"""
data_type: typing.Type = float
units: typing.Union[pint.Unit, None] = None
projected: typing.Union[bool, None] = None
projection_units: typing.Union[pint.Unit, None] = None
type: typing.ClassVar[str] = 'raster'
@ -989,10 +1062,11 @@ class SingleBandRasterOutput(Output):
are allowed), where only the first band is used.
Attributes:
band: An `Output` representing the type of data produced in the
raster's first and only band
data_type: float or int
units: units of measurement of the raster values
"""
band: typing.Union[Output, None] = None
data_type: typing.Type = float
units: typing.Union[pint.Unit, None] = None
@dataclasses.dataclass
@ -1168,6 +1242,10 @@ class ModelSpec:
return as_dict
elif isinstance(obj, IterableWithDotAccess):
return obj.to_json()
elif obj is int:
return 'integer'
elif obj is float:
return 'number'
raise TypeError(f'fallback serializer is missing for {type(obj)}')
spec_dict = self.__dict__.copy()
@ -1241,7 +1319,8 @@ def build_input_spec(argkey, arg):
elif t == 'raster':
return SingleBandRasterInput(
**base_attrs,
band=build_input_spec('1', arg['bands'][1]),
data_type=int if arg['bands'][1]['type'] == 'integer' else float,
units=arg['bands'][1].get('units', None),
projected=arg.get('projected', None),
projection_units=arg.get('projection_units', None))
@ -1287,7 +1366,8 @@ def build_input_spec(argkey, arg):
geometry_types=arg['geometries'],
fields=[build_input_spec(key, field_spec)
for key, field_spec in arg['fields'].items()],
band=build_input_spec('1', arg['bands'][1]),
data_type=int if arg['bands'][1]['type'] == 'integer' else float,
units=arg['bands'][1].get('units', None),
projected=arg.get('projected', None),
projection_units=arg.get('projection_units', None))
@ -1337,7 +1417,8 @@ def build_output_spec(key, spec):
elif t == 'raster':
return SingleBandRasterOutput(
**base_attrs,
band=build_output_spec(1, spec['bands'][1]))
data_type=int if spec['bands'][1]['type'] == 'integer' else float,
units=spec['bands'][1].get('units', None))
elif t == 'vector':
return VectorOutput(
@ -1596,6 +1677,9 @@ def format_unit(unit):
Returns:
String describing the unit.
"""
if unit is None:
return ''
if not isinstance(unit, pint.Unit):
raise TypeError(
f'{unit} is of type {type(unit)}. '
@ -1636,36 +1720,6 @@ def format_unit(unit):
return formatted_unit
def serialize_args_spec(spec):
"""Serialize an MODEL_SPEC dict to a JSON string.
Args:
spec (dict): An invest model's MODEL_SPEC.
Raises:
TypeError if any object type within the spec is not handled by
json.dumps or by the fallback serializer.
Returns:
JSON String
"""
def fallback_serializer(obj):
"""Serialize objects that are otherwise not JSON serializeable."""
if isinstance(obj, pint.Unit):
return format_unit(obj)
# Sets are present in 'geometry_types' attributes of some args
# We don't need to worry about deserializing back to a set/array
# so casting to string is okay.
elif isinstance(obj, set):
return str(obj)
elif isinstance(obj, types.FunctionType):
return str(obj)
raise TypeError(f'fallback serializer is missing for {type(obj)}')
return json.dumps(spec, default=fallback_serializer)
# accepted geometry_types for a vector will be displayed in this order
GEOMETRY_ORDER = [
'POINT',
@ -1868,11 +1922,7 @@ def describe_arg_from_spec(name, spec):
in_parentheses = [type_string]
# For numbers and rasters that have units, display the units
units = None
if spec.__class__ is NumberInput:
units = spec.units
elif spec.__class__ is SingleBandRasterInput and spec.band.__class__ is NumberInput:
units = spec.band.units
units = spec.units if hasattr(spec, 'units') else None
if units:
units_string = format_unit(units)
if units_string:
@ -2005,7 +2055,6 @@ def write_metadata_file(datasource_path, spec, keywords_list,
except ValueError as e:
LOGGER.debug(f"Skipping metadata creation for {datasource_path}: {e}")
return None
resource.set_lineage(lineage_statement)
# a pre-existing metadata doc could have keywords
words = resource.get_keywords()
@ -2039,13 +2088,9 @@ def write_metadata_file(datasource_path, spec, keywords_list,
# fields that are in the spec but missing
# from model results because they are conditional.
LOGGER.debug(error)
if hasattr(spec, 'band'):
if isinstance(spec, SingleBandRasterInput) or isinstance(spec, SingleBandRasterOutput):
if len(resource.get_band_description(1).units) < 1:
try:
units = format_unit(spec.band.units)
except AttributeError:
units = ''
units = format_unit(spec.units)
resource.set_band_description(1, units=units)
resource.write(workspace=out_workspace)

View File

@ -8,7 +8,7 @@ MODEL_SPEC = spec.ModelSpec(inputs=[
spec.FileInput(id='foo'),
spec.FileInput(id='bar'),
spec.DirectoryInput(id='data_dir', contents={}),
spec.SingleBandRasterInput(id='raster', band=spec.Input()),
spec.SingleBandRasterInput(id='raster'),
spec.VectorInput(id='vector', fields={}, geometry_types={}),
spec.CSVInput(id='simple_table'),
spec.CSVInput(
@ -18,8 +18,7 @@ MODEL_SPEC = spec.ModelSpec(inputs=[
spec.RasterOrVectorInput(
id='path',
fields={},
geometry_types={'POINT', 'POLYGON'},
band=spec.NumberInput()
geometry_types={'POINT', 'POLYGON'}
)
]
)],

View File

@ -1,7 +1,7 @@
from natcap.invest import spec
MODEL_SPEC = spec.ModelSpec(inputs=[
spec.SingleBandRasterInput(id='raster', band=spec.Input())],
spec.SingleBandRasterInput(id='raster')],
outputs={},
model_id='raster_model',
model_title='',

View File

@ -192,13 +192,8 @@ class ValidateModelSpecs(unittest.TestCase):
self.assertIsInstance(output_spec.units, pint.Unit)
elif t is spec.SingleBandRasterOutput:
# raster type should have a bands property that maps each band
# index to a nested type dictionary describing the band's data
self.assertTrue(hasattr(output_spec, 'band'))
self.validate_output(
output_spec.band,
f'{key}.band',
parent_type=t)
self.assertTrue(hasattr(output_spec, 'data_type'))
self.assertTrue(hasattr(output_spec, 'units'))
elif t is spec.VectorOutput:
# vector type should have:
@ -347,13 +342,8 @@ class ValidateModelSpecs(unittest.TestCase):
self.assertIsInstance(arg.expression, str)
elif t is spec.SingleBandRasterInput:
# raster type should have a bands property that maps each band
# index to a nested type dictionary describing the band's data
self.assertTrue(hasattr(arg, 'band'))
self.validate_args(
arg.band,
f'{name}.band',
parent_type=t)
self.assertTrue(hasattr(arg, 'data_type'))
self.assertTrue(hasattr(arg, 'units'))
# may optionally have a 'projected' attribute that says
# whether the raster must be linearly projected

View File

@ -147,7 +147,7 @@ class TestDescribeArgFromSpec(unittest.TestCase):
def test_raster_spec(self):
raster_spec = spec.SingleBandRasterInput(
band=spec.IntegerInput(),
data_type=int,
about="Description",
name="Bar"
)
@ -158,7 +158,8 @@ class TestDescribeArgFromSpec(unittest.TestCase):
self.assertEqual(repr(out), repr(expected_rst))
raster_spec = spec.SingleBandRasterInput(
band=spec.NumberInput(units=u.millimeter/u.year),
data_type=float,
units=u.millimeter/u.year,
about="Description",
name="Bar"
)
@ -250,7 +251,7 @@ class TestDescribeArgFromSpec(unittest.TestCase):
multi_spec = spec.RasterOrVectorInput(
about="Description",
name="Bar",
band=spec.IntegerInput(),
data_type=int,
geometry_types={"POLYGON"},
fields={}
)
@ -281,12 +282,12 @@ def _generate_files_from_spec(output_spec, workspace):
spec_data.contents, os.path.join(workspace, spec_data.id))
else:
filepath = os.path.join(workspace, spec_data.id)
if hasattr(spec_data, 'band'):
if isinstance(spec_data, spec.SingleBandRasterOutput):
driver = gdal.GetDriverByName('GTIFF')
raster = driver.Create(filepath, 2, 2, 1, gdal.GDT_Byte)
band = raster.GetRasterBand(1)
band.SetNoDataValue(2)
elif hasattr(spec_data, 'fields'):
elif isinstance(spec_data, spec.VectorOutput):
driver = gdal.GetDriverByName('GPKG')
target_vector = driver.CreateDataSource(filepath)
layer_name = os.path.basename(os.path.splitext(filepath)[0])
@ -322,7 +323,8 @@ class TestMetadataFromSpec(unittest.TestCase):
spec.SingleBandRasterOutput(
id="urban_nature_supply_percapita.tif",
about="The calculated supply per capita of urban nature.",
band=spec.NumberInput(units=u.m**2)
data_type=float,
units=u.m**2
),
spec.VectorOutput(
id="admin_boundaries.gpkg",

View File

@ -68,10 +68,10 @@ class UsageLoggingTests(unittest.TestCase):
model_spec = spec.ModelSpec(
model_id='', model_title='', userguide=None, aliases=None,
inputs=[
spec.SingleBandRasterInput(id='raster', band=spec.Input()),
spec.SingleBandRasterInput(id='raster'),
spec.VectorInput(id='vector', geometry_types={}, fields={}),
spec.StringInput(id='not_a_gis_input'),
spec.SingleBandRasterInput(id='blank_raster_path', band=spec.Input()),
spec.SingleBandRasterInput(id='blank_raster_path'),
spec.VectorInput(id='blank_vector_path', geometry_types={}, fields={})
],
outputs={},

View File

@ -21,21 +21,21 @@ from osgeo import osr
from natcap.invest import spec
from natcap.invest.spec import (
u,
ModelSpec,
Input,
FileInput,
CSVInput,
StringInput,
OptionStringInput,
SingleBandRasterInput,
RasterOrVectorInput,
DirectoryInput,
VectorInput,
BooleanInput,
NumberInput,
CSVInput,
DirectoryInput,
FileInput,
Input,
IntegerInput,
ModelSpec,
NumberInput,
OptionStringInput,
PercentInput,
RasterOrVectorInput,
RatioInput,
PercentInput)
SingleBandRasterInput,
StringInput,
VectorInput)
gdal.UseExceptions()
@ -520,7 +520,7 @@ class RasterValidation(unittest.TestCase):
from natcap.invest import validation
filepath = os.path.join(self.workspace_dir, 'file.txt')
error_msg = SingleBandRasterInput(band=Input()).validate(filepath)
error_msg = SingleBandRasterInput().validate(filepath)
self.assertEqual(error_msg, validation.MESSAGES['FILE_NOT_FOUND'])
def test_invalid_raster(self):
@ -531,7 +531,7 @@ class RasterValidation(unittest.TestCase):
with open(filepath, 'w') as bad_raster:
bad_raster.write('not a raster')
error_msg = SingleBandRasterInput(band=Input()).validate(filepath)
error_msg = SingleBandRasterInput().validate(filepath)
self.assertEqual(error_msg, validation.MESSAGES['NOT_GDAL_RASTER'])
def test_invalid_ovr_raster(self):
@ -555,8 +555,7 @@ class RasterValidation(unittest.TestCase):
raster = None
filepath_ovr = os.path.join(self.workspace_dir, 'raster.tif.ovr')
error_msg = SingleBandRasterInput(
band=Input()).validate(filepath_ovr)
error_msg = SingleBandRasterInput().validate(filepath_ovr)
self.assertEqual(error_msg, validation.MESSAGES['OVR_FILE'])
def test_raster_not_projected(self):
@ -572,8 +571,7 @@ class RasterValidation(unittest.TestCase):
raster.SetProjection(wgs84_srs.ExportToWkt())
raster = None
error_msg = SingleBandRasterInput(
band=Input(), projected=True).validate(filepath)
error_msg = SingleBandRasterInput(projected=True).validate(filepath)
self.assertEqual(error_msg, validation.MESSAGES['NOT_PROJECTED'])
def test_raster_incorrect_units(self):
@ -591,7 +589,7 @@ class RasterValidation(unittest.TestCase):
raster = None
error_msg = SingleBandRasterInput(
band=Input(), projected=True, projection_units=spec.u.meter
projected=True, projection_units=spec.u.meter
).validate(filepath)
expected_msg = validation.MESSAGES['WRONG_PROJECTION_UNIT'].format(
unit_a='meter', unit_b='us_survey_foot')
@ -1491,7 +1489,7 @@ class TestGetValidatedDataframe(unittest.TestCase):
spec = CSVInput(columns=[
NumberInput(id='col1'),
SingleBandRasterInput(id='col2', band=NumberInput())
SingleBandRasterInput(id='col2')
])
with self.assertRaises(ValueError) as cm:
spec.get_validated_dataframe(csv_path)
@ -1516,8 +1514,7 @@ class TestGetValidatedDataframe(unittest.TestCase):
spec = CSVInput(columns=[
NumberInput(id='col1'),
SingleBandRasterInput(
id='col2', projected=True, band=NumberInput())
SingleBandRasterInput(id='col2', projected=True)
])
with self.assertRaises(ValueError) as cm:
spec.get_validated_dataframe(csv_path)
@ -1584,7 +1581,6 @@ class TestGetValidatedDataframe(unittest.TestCase):
NumberInput(id='col1'),
RasterOrVectorInput(
id='col2',
band=NumberInput(),
fields={},
geometry_types=['POLYGON']
)
@ -2041,11 +2037,13 @@ class TestValidationFromSpec(unittest.TestCase):
inputs=[
SingleBandRasterInput(
id='raster_a',
band=NumberInput(units=u.none)
data_type=float,
units=u.none
),
SingleBandRasterInput(
id='raster_b',
band=NumberInput(units=u.none)
data_type=float,
units=u.none
),
VectorInput(
id='vector_a',
@ -2111,11 +2109,13 @@ class TestValidationFromSpec(unittest.TestCase):
inputs=[
SingleBandRasterInput(
id='raster_a',
band=NumberInput(units=u.none)
data_type=float,
units=u.none
),
SingleBandRasterInput(
id='raster_b',
band=NumberInput(units=u.none)
data_type=float,
units=u.none
)
],
args_with_spatial_overlap={
@ -2155,11 +2155,13 @@ class TestValidationFromSpec(unittest.TestCase):
inputs=[
SingleBandRasterInput(
id='raster_a',
band=NumberInput(units=u.none)
data_type=float,
units=u.none
),
SingleBandRasterInput(
id='raster_b',
band=NumberInput(units=u.none),
data_type=float,
units=u.none,
required=False
),
VectorInput(