Merge pull request #1690 from emlys/task/1570

enable gdal.UseExceptions in entrypoints and tests
James Douglass 2024-11-21 16:36:26 -08:00 committed by GitHub
commit ec88d787e4
43 changed files with 394 additions and 384 deletions

View File

@ -48,6 +48,8 @@ Unreleased Changes
      reflect changes in how InVEST is installed on modern systems, and also to
      include images of the InVEST workbench instead of just broken links.
      https://github.com/natcap/invest/issues/1660
* natcap.invest now works with (and requires) ``gdal.UseExceptions``. A
``FutureWarning`` is raised on import if GDAL exceptions are not enabled.
* Workbench
    * Several small updates to the model input form UI to improve usability
      and visual consistency (https://github.com/natcap/invest/issues/912).
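
For downstream scripts that use the natcap.invest API directly, the practical consequence of this entry is that GDAL exceptions should be enabled before the package is imported. A minimal illustrative sketch (not part of this diff; any model module works the same way):

from osgeo import gdal

# Enable GDAL exceptions up front; if they are disabled, importing
# natcap.invest emits the FutureWarning described in the changelog entry.
gdal.UseExceptions()

import natcap.invest  # imports cleanly, no FutureWarning
print(natcap.invest.__version__)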

View File

@ -4,8 +4,10 @@ import logging
import os
import sys
from gettext import translation
import warnings

import babel
from osgeo import gdal

LOGGER = logging.getLogger('natcap.invest')
LOGGER.addHandler(logging.NullHandler())
@ -28,6 +30,14 @@ LOCALE_NAME_MAP = {
    locale: babel.Locale(locale).display_name for locale in LOCALES
}

if not gdal.GetUseExceptions():
    warnings.warn(('''
        natcap.invest requires GDAL exceptions to be enabled. You must
        call gdal.UseExceptions() to avoid unexpected behavior from
        natcap.invest. A future version will enable exceptions on import.
        gdal.UseExceptions() affects global state, so this may affect the
        behavior of other packages.'''), FutureWarning)


def set_locale(locale_code):
    """Set the `gettext` attribute of natcap.invest.

View File

@ -12,14 +12,15 @@ import sys
import textwrap
import warnings

from pygeoprocessing.geoprocessing_core import GDALUseExceptions

with GDALUseExceptions():
    import natcap.invest
    from natcap.invest import datastack
    from natcap.invest import model_metadata
    from natcap.invest import set_locale
    from natcap.invest import spec_utils
    from natcap.invest import ui_server
    from natcap.invest import utils

DEFAULT_EXIT_CODE = 1
LOGGER = logging.getLogger(__name__)
@ -218,267 +219,268 @@ def main(user_args=None):
    so models may be run in this way without having GUI packages
    installed.
    """
    with GDALUseExceptions():
        parser = argparse.ArgumentParser(
            description=(
                'Integrated Valuation of Ecosystem Services and Tradeoffs. '
                'InVEST (Integrated Valuation of Ecosystem Services and '
                'Tradeoffs) is a family of tools for quantifying the values of '
                'natural capital in clear, credible, and practical ways. In '
                'promising a return (of societal benefits) on investments in '
                'nature, the scientific community needs to deliver knowledge and '
                'tools to quantify and forecast this return. InVEST enables '
                'decision-makers to quantify the importance of natural capital, '
                'to assess the tradeoffs associated with alternative choices, '
                'and to integrate conservation and human development. \n\n'
                'Older versions of InVEST ran as script tools in the ArcGIS '
                'ArcToolBox environment, but have almost all been ported over to '
                'a purely open-source python environment.'),
            prog='invest'
        )
        parser.add_argument('--version', action='version',
                            version=natcap.invest.__version__)
        verbosity_group = parser.add_mutually_exclusive_group()
        verbosity_group.add_argument(
            '-v', '--verbose', dest='verbosity', default=0, action='count',
            help=('Increase verbosity. Affects how much logging is printed to '
                  'the console and (if running in headless mode) how much is '
                  'written to the logfile.'))
        verbosity_group.add_argument(
            '--debug', dest='log_level', default=logging.ERROR,
            action='store_const', const=logging.DEBUG,
            help='Enable debug logging. Alias for -vvv')

        parser.add_argument(
            '--taskgraph-log-level', dest='taskgraph_log_level', default='ERROR',
            type=str, choices=['DEBUG', 'INFO', 'WARNING', 'ERROR'],
            help=('Set the logging level for Taskgraph. Affects how much logging '
                  'Taskgraph prints to the console and (if running in headless '
                  'mode) how much is written to the logfile.'))

        # list the language code and corresponding language name (in that language)
        supported_languages_string = ', '.join([
            f'{locale} ({display_name})'
            for locale, display_name in natcap.invest.LOCALE_NAME_MAP.items()])
        parser.add_argument(
            '-L', '--language', default='en',
            choices=natcap.invest.LOCALES,
            help=('Choose a language. Model specs, names, and validation messages '
                  'will be translated. Log messages are not translated. Value '
                  'should be an ISO 639-1 language code. Supported options are: '
                  f'{supported_languages_string}.'))

        subparsers = parser.add_subparsers(dest='subcommand')

        listmodels_subparser = subparsers.add_parser(
            'list', help='List the available InVEST models')
        listmodels_subparser.add_argument(
            '--json', action='store_true', help='Write output as a JSON object')

        run_subparser = subparsers.add_parser(
            'run', help='Run an InVEST model')
        # Recognize '--headless' for backwards compatibility.
        # This arg is otherwise unused.
        run_subparser.add_argument(
            '-l', '--headless', action='store_true',
            help=argparse.SUPPRESS)
        run_subparser.add_argument(
            '-d', '--datastack', default=None, nargs='?',
            help=('Run the specified model with this JSON datastack. '
                  'Required if using --headless'))
        run_subparser.add_argument(
            '-w', '--workspace', default=None, nargs='?',
            help=('The workspace in which outputs will be saved. '
                  'Required if using --headless'))
        run_subparser.add_argument(
            'model', action=SelectModelAction,  # Assert valid model name
            help=('The model to run. Use "invest list" to list the available '
                  'models.'))

        validate_subparser = subparsers.add_parser(
            'validate', help=(
                'Validate the parameters of a datastack'))
        validate_subparser.add_argument(
            '--json', action='store_true', help='Write output as a JSON object')
        validate_subparser.add_argument(
            'datastack', help=('Path to a JSON datastack.'))

        getspec_subparser = subparsers.add_parser(
            'getspec', help=('Get the specification of a model.'))
        getspec_subparser.add_argument(
            '--json', action='store_true', help='Write output as a JSON object')
        getspec_subparser.add_argument(
            'model', action=SelectModelAction,  # Assert valid model name
            help=('The model for which the spec should be fetched. Use "invest '
                  'list" to list the available models.'))

        serve_subparser = subparsers.add_parser(
            'serve', help=('Start the flask app on the localhost.'))
        serve_subparser.add_argument(
            '--port', type=int, default=56789,
            help='Port number for the Flask server')

        export_py_subparser = subparsers.add_parser(
            'export-py', help=('Save a python script that executes a model.'))
        export_py_subparser.add_argument(
            'model', action=SelectModelAction,  # Assert valid model name
            help=('The model that the python script will execute. Use "invest '
                  'list" to list the available models.'))
        export_py_subparser.add_argument(
            '-f', '--filepath', default=None,
            help='Define a location for the saved .py file')

        args = parser.parse_args(user_args)
        natcap.invest.set_locale(args.language)

        root_logger = logging.getLogger()
        handler = logging.StreamHandler(sys.stdout)
        formatter = logging.Formatter(
            fmt='%(asctime)s %(name)-18s %(levelname)-8s %(message)s',
            datefmt='%m/%d/%Y %H:%M:%S ')
        handler.setFormatter(formatter)

        # Set the log level based on what the user provides in the available
        # arguments. Verbosity: the more v's the lower the logging threshold.
        # If --debug is used, the logging threshold is 10.
        # If the user goes lower than logging.DEBUG, default to logging.DEBUG.
        log_level = min(args.log_level, logging.ERROR - (args.verbosity*10))
        handler.setLevel(max(log_level, logging.DEBUG))  # don't go below DEBUG
        root_logger.addHandler(handler)
        LOGGER.info('Setting handler log level to %s', log_level)

        # Set the log level for taskgraph.
        taskgraph_log_level = logging.getLevelName(args.taskgraph_log_level.upper())
        logging.getLogger('taskgraph').setLevel(taskgraph_log_level)
        LOGGER.debug('Setting taskgraph log level to %s', taskgraph_log_level)

        # FYI: Root logger by default has a level of logging.WARNING.
        # To capture ALL logging produced in this system at runtime, use this:
        # logging.getLogger().setLevel(logging.DEBUG)
        # Also FYI: using logging.DEBUG means that the logger will defer to
        # the setting of the parent logger.
        logging.getLogger('natcap').setLevel(logging.DEBUG)

        if args.subcommand == 'list':
            # reevaluate the model names in the new language
            importlib.reload(model_metadata)
            if args.json:
                message = build_model_list_json()
            else:
                message = build_model_list_table()
            sys.stdout.write(message)
            parser.exit()

        if args.subcommand == 'validate':
            try:
                parsed_datastack = datastack.extract_parameter_set(args.datastack)
            except Exception as error:
                parser.exit(
                    1, "Error when parsing JSON datastack:\n    " + str(error))

            # reload validation module first so it's also in the correct language
            importlib.reload(importlib.import_module('natcap.invest.validation'))
            model_module = importlib.reload(importlib.import_module(
                name=parsed_datastack.model_name))

            try:
                validation_result = model_module.validate(parsed_datastack.args)
            except KeyError as missing_keys_error:
                if args.json:
                    message = json.dumps(
                        {'validation_results': {
                            str(list(missing_keys_error.args)): 'Key is missing'}})
                else:
                    message = ('Datastack is missing keys:\n    ' +
                               str(missing_keys_error.args))

                # Missing keys have an exit code of 1 because that would indicate
                # probably programmer error.
                sys.stdout.write(message)
                parser.exit(1)
            except Exception as error:
                parser.exit(
                    1, ('Datastack could not be validated:\n    ' +
                        str(error)))

            # Even validation errors will have an exit code of 0
            if args.json:
                message = json.dumps({
                    'validation_results': validation_result})
            else:
                message = pprint.pformat(validation_result)

            sys.stdout.write(message)
            parser.exit(0)

        if args.subcommand == 'getspec':
            target_model = model_metadata.MODEL_METADATA[args.model].pyname
            model_module = importlib.reload(
                importlib.import_module(name=target_model))
            spec = model_module.MODEL_SPEC

            if args.json:
                message = spec_utils.serialize_args_spec(spec)
            else:
                message = pprint.pformat(spec)

            sys.stdout.write(message)
            parser.exit(0)

        if args.subcommand == 'run':
            if args.headless:
                warnings.warn(
                    '--headless (-l) is now the default (and only) behavior '
                    'for `invest run`. This flag will not be recognized '
                    'in the future.', FutureWarning, stacklevel=2)  # 2 for brevity

            if not args.datastack:
                parser.exit(1, 'Datastack required for execution.')

            try:
                parsed_datastack = datastack.extract_parameter_set(args.datastack)
            except Exception as error:
                parser.exit(
                    1, "Error when parsing JSON datastack:\n    " + str(error))

            if not args.workspace:
                if ('workspace_dir' not in parsed_datastack.args or
                        parsed_datastack.args['workspace_dir'] in ['', None]):
                    parser.exit(
                        1, ('Workspace must be defined at the command line '
                            'or in the datastack file'))
            else:
                parsed_datastack.args['workspace_dir'] = args.workspace

            target_model = model_metadata.MODEL_METADATA[args.model].pyname
            model_module = importlib.import_module(name=target_model)
            LOGGER.info('Imported target %s from %s',
                        model_module.__name__, model_module)

            with utils.prepare_workspace(parsed_datastack.args['workspace_dir'],
                                         name=parsed_datastack.model_name,
                                         logging_level=log_level):
                LOGGER.log(datastack.ARGS_LOG_LEVEL,
                           'Starting model with parameters: \n%s',
                           datastack.format_args_dict(parsed_datastack.args,
                                                      parsed_datastack.model_name))

                # We're deliberately not validating here because the user
                # can just call ``invest validate <datastack>`` to validate.
                #
                # Exceptions will already be logged to the logfile but will ALSO be
                # written to stdout if this exception is uncaught. This is by
                # design.
                model_module.execute(parsed_datastack.args)

        if args.subcommand == 'serve':
            ui_server.app.run(port=args.port)
            parser.exit(0)

        if args.subcommand == 'export-py':
            target_filepath = args.filepath
            if not args.filepath:
                target_filepath = f'{args.model}_execute.py'
            export_to_python(target_filepath, args.model)
            parser.exit()


if __name__ == '__main__':
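
Both the import block and the body of ``main()`` above are now wrapped in ``GDALUseExceptions`` from ``pygeoprocessing.geoprocessing_core``. Its actual implementation lives in pygeoprocessing; the sketch below shows roughly equivalent behavior, assuming it simply enables GDAL exceptions on entry and restores the previous global setting on exit:

import contextlib

from osgeo import gdal


@contextlib.contextmanager
def gdal_use_exceptions():
    """Rough stand-in for pygeoprocessing's GDALUseExceptions (assumed behavior)."""
    previously_enabled = gdal.GetUseExceptions()
    gdal.UseExceptions()
    try:
        yield
    finally:
        if not previously_enabled:
            # Only restore the old state if exceptions were off before,
            # so callers that opted in elsewhere are left alone.
            gdal.DontUseExceptions()


with gdal_use_exceptions():
    try:
        gdal.OpenEx('example.tif', gdal.OF_RASTER)
    except RuntimeError:
        pass  # inside the block, GDAL failures surface as exceptions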

View File

@ -1737,15 +1737,16 @@ def extract_bathymetry_along_ray(
    iy = int((point.y - bathy_gt[3]) / bathy_gt[5])
    win_size = 1
-   value = bathy_band.ReadAsArray(
-       xoff=ix, yoff=iy,
-       win_xsize=win_size, win_ysize=win_size)
-   if value is None:
+   try:
+       value = bathy_band.ReadAsArray(
+           xoff=ix, yoff=iy,
+           win_xsize=win_size, win_ysize=win_size)
+   except RuntimeError as ex:
        location = {'xoff': ix, 'yoff': iy, 'win_xsize': win_size,
                    'win_ysize': win_size}
        raise ValueError(
-           f'got a {value} when trying to read bathymetry at {location}. '
-           'Does the bathymetry input fully cover the fetch ray area?')
+           f'Failed to read bathymetry at {location}. Does the bathymetry '
+           'input fully cover the fetch ray area?') from ex
    if bathy_nodata is None or not numpy.isclose(
            value[0][0], bathy_nodata, equal_nan=True):
        bathy_values.append(value)
@ -2468,25 +2469,26 @@ def search_for_vector_habitat(
        geometry = feature.GetGeometryRef()
        if not geometry.IsValid():
-           geometry = geometry.Buffer(0)  # sometimes this fixes geometry
-       if geometry is not None:  # geometry is None if the buffer failed.
+           try:
+               geometry = geometry.Buffer(0)  # sometimes this fixes geometry
+           except RuntimeError:
+               LOGGER.warning(
+                   f"FID {feature.GetFID()} in {habitat_vector_path} has invalid "
+                   "geometry and will be excluded")
+               continue
        clipped_geometry = geometry.Intersection(base_srs_clipping_geom)
        if not clipped_geometry.IsEmpty():
            if target_spatial_reference != base_spatial_reference:
                err_code = clipped_geometry.Transform(transform)
                if err_code != 0:
                    LOGGER.warning(
                        f"Could not transform feature from "
                        f"{habitat_vector_path} to spatial reference "
                        "system of AOI")
                    continue
            shapely_geom = shapely.wkb.loads(
                bytes(clipped_geometry.ExportToWkb()))
            shapely_geometry_list.extend(_list_geometry(shapely_geom))
-       else:
-           LOGGER.warning(
-               f"FID {feature.GetFID()} in {habitat_vector_path} has invalid "
-               "geometry and will be excluded")

    if not shapely_geometry_list:
        LOGGER.warning(f'No valid features exist in {habitat_vector_path}')

View File

@ -782,14 +782,17 @@ def _build_spatial_index(
    # put all the polygons in the kd_tree because it's fast and simple
    for poly_feature in model_layer:
        poly_geom = poly_feature.GetGeometryRef()
        if poly_geom.IsValid():
            poly_centroid = poly_geom.Centroid()
            # put in row/col order since rasters are row/col indexed
            kd_points.append([poly_centroid.GetY(), poly_centroid.GetX()])
            theta_model_parameters.append([
                poly_feature.GetField(feature_id) for feature_id in
                ['theta1', 'theta2', 'theta3']])
            method_model_parameter.append(poly_feature.GetField('method'))
        else:
            LOGGER.warning(f'skipping invalid geometry {poly_geom}')

    method_model_parameter = numpy.array(
        method_model_parameter, dtype=numpy.int32)
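
The centroids collected above are stored in (y, x) order so they can be matched against raster row/col coordinates when the kd_tree is queried. A self-contained sketch of that lookup pattern; using ``scipy.spatial.cKDTree`` here is an assumption for illustration, not necessarily the tree implementation the model uses:

import numpy
from scipy.spatial import cKDTree

# Centroid coordinates stored (y, x) to line up with raster row/col order.
kd_points = [(4650000.5, 435000.5), (4650030.5, 435030.5)]
theta_model_parameters = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]

tree = cKDTree(numpy.array(kd_points))
# Query with a pixel center in the same (y, x) order to find the nearest
# valid polygon and its associated theta parameters.
distance, index = tree.query((4650010.0, 435010.0))
print(distance, theta_model_parameters[index])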

View File

@ -1603,26 +1603,18 @@ def _validate_same_projection(base_vector_path, table_path):
    invalid_projections = False
    for path in data_paths:
-       def error_handler(err_level, err_no, err_msg):
-           """Empty error handler to avoid stderr output."""
-           pass
-       gdal.PushErrorHandler(error_handler)
-       raster = gdal.OpenEx(path, gdal.OF_RASTER)
-       gdal.PopErrorHandler()
-       if raster is not None:
+       gis_type = pygeoprocessing.get_gis_type(path)
+       if gis_type == pygeoprocessing.UNKNOWN_TYPE:
+           return f"{path} did not load"
+       elif gis_type == pygeoprocessing.RASTER_TYPE:
+           raster = gdal.OpenEx(path, gdal.OF_RASTER)
            projection_as_str = raster.GetProjection()
            ref = osr.SpatialReference()
            ref.ImportFromWkt(projection_as_str)
-           raster = None
        else:
            vector = gdal.OpenEx(path, gdal.OF_VECTOR)
-           if vector is None:
-               return f"{path} did not load"
            layer = vector.GetLayer()
            ref = osr.SpatialReference(layer.GetSpatialRef().ExportToWkt())
-           layer = None
-           vector = None
        if not base_ref.IsSame(ref):
            invalid_projections = True
    if invalid_projections:

View File

@ -50,7 +50,9 @@ def _log_gdal_errors(*args, **kwargs):
"""Log error messages to osgeo. """Log error messages to osgeo.
All error messages are logged with reasonable ``logging`` levels based All error messages are logged with reasonable ``logging`` levels based
on the GDAL error level. on the GDAL error level. While we are now using ``gdal.UseExceptions()``,
we still need this to handle GDAL logging that does not get raised as
an exception.
Note: Note:
This function is designed to accept any number of positional and This function is designed to accept any number of positional and
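
The mechanism behind this function is GDAL's error-handler hook: a Python callable registered with ``gdal.PushErrorHandler`` receives every GDAL message, including warnings that ``gdal.UseExceptions()`` never turns into exceptions. A simplified sketch of the idea (not the natcap.invest implementation):

import logging

from osgeo import gdal

LOGGER = logging.getLogger('gdal_messages')


def _log_gdal_message(err_level, err_no, err_msg):
    # Map GDAL error classes onto logging levels so warnings are kept
    # even though they never raise exceptions.
    level_map = {
        gdal.CE_Debug: logging.DEBUG,
        gdal.CE_Warning: logging.WARNING,
        gdal.CE_Failure: logging.ERROR,
        gdal.CE_Fatal: logging.CRITICAL,
    }
    LOGGER.log(level_map.get(err_level, logging.INFO),
               '[errno %s] %s', err_no, err_msg)


gdal.PushErrorHandler(_log_gdal_message)
try:
    pass  # GDAL calls made here have their messages routed to LOGGER
finally:
    gdal.PopErrorHandler()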

View File

@ -304,19 +304,16 @@ def check_raster(filepath, projected=False, projection_units=None, **kwargs):
    if file_warning:
        return file_warning

-   gdal.PushErrorHandler('CPLQuietErrorHandler')
-   gdal_dataset = gdal.OpenEx(filepath, gdal.OF_RASTER)
-   gdal.PopErrorHandler()
-   if gdal_dataset is None:
+   try:
+       gdal_dataset = gdal.OpenEx(filepath, gdal.OF_RASTER)
+   except RuntimeError:
        return MESSAGES['NOT_GDAL_RASTER']

    # Check that an overview .ovr file wasn't opened.
    if os.path.splitext(filepath)[1] == '.ovr':
        return MESSAGES['OVR_FILE']

-   srs = osr.SpatialReference()
-   srs.ImportFromWkt(gdal_dataset.GetProjection())
+   srs = gdal_dataset.GetSpatialRef()
    projection_warning = _check_projection(srs, projected, projection_units)
    if projection_warning:
        gdal_dataset = None
@ -378,9 +375,10 @@ def check_vector(filepath, geometries, fields=None, projected=False,
    if file_warning:
        return file_warning

-   gdal.PushErrorHandler('CPLQuietErrorHandler')
-   gdal_dataset = gdal.OpenEx(filepath, gdal.OF_VECTOR)
-   gdal.PopErrorHandler()
+   try:
+       gdal_dataset = gdal.OpenEx(filepath, gdal.OF_VECTOR)
+   except RuntimeError:
+       return MESSAGES['NOT_GDAL_VECTOR']

    geom_map = {
        'POINT': [ogr.wkbPoint, ogr.wkbPointM, ogr.wkbPointZM,
@ -402,9 +400,6 @@ def check_vector(filepath, geometries, fields=None, projected=False,
    for geom in geometries:
        allowed_geom_types += geom_map[geom]

-   if gdal_dataset is None:
-       return MESSAGES['NOT_GDAL_VECTOR']
    # NOTE: this only checks the layer geometry type, not the types of the
    # actual geometries (layer.GetGeometryTypes()). This is probably equivalent
    # in most cases, and it's more efficient than checking every geometry, but
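
Across both of these checks the migration is the same: with exceptions enabled, a failed ``gdal.OpenEx`` raises ``RuntimeError`` instead of returning ``None``, so the quiet error handlers and ``is None`` checks go away. A compact standalone sketch of the resulting pattern:

from osgeo import gdal

gdal.UseExceptions()


def is_gdal_raster(filepath):
    """Return True if GDAL can open ``filepath`` as a raster."""
    try:
        gdal.OpenEx(filepath, gdal.OF_RASTER)
        return True
    except RuntimeError:
        # Under gdal.UseExceptions() a failed open raises rather than
        # returning None, so no None check is needed.
        return False


print(is_gdal_raster('not_a_real_file.tif'))  # False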

View File

@ -1148,23 +1148,18 @@ def _copy_vector_or_raster(base_file_path, target_file_path):
        ValueError if the base file can't be opened by GDAL.

    """
-   # Open the file as raster first
-   source_dataset = gdal.OpenEx(base_file_path, gdal.OF_RASTER)
-   target_driver_name = _RASTER_DRIVER_NAME
-   if source_dataset is None:
-       # File didn't open as a raster; assume it's a vector
-       source_dataset = gdal.OpenEx(base_file_path, gdal.OF_VECTOR)
-       target_driver_name = _VECTOR_DRIVER_NAME
-
-   # Raise an exception if the file can't be opened by GDAL
-   if source_dataset is None:
-       raise ValueError(
-           'File %s is neither a GDAL-compatible raster nor vector.'
-           % base_file_path)
+   gis_type = pygeoprocessing.get_gis_type(base_file_path)
+   if gis_type == pygeoprocessing.RASTER_TYPE:
+       source_dataset = gdal.OpenEx(base_file_path, gdal.OF_RASTER)
+       target_driver_name = _RASTER_DRIVER_NAME
+   elif gis_type == pygeoprocessing.VECTOR_TYPE:
+       source_dataset = gdal.OpenEx(base_file_path, gdal.OF_VECTOR)
+       target_driver_name = _VECTOR_DRIVER_NAME
+   else:
+       raise ValueError(f'File {base_file_path} is neither a GDAL-compatible '
+                        'raster nor vector.')

    driver = gdal.GetDriverByName(target_driver_name)
    driver.CreateCopy(target_file_path, source_dataset)
-   source_dataset = None


def _interpolate_vector_field_onto_raster(

View File

@ -1845,7 +1845,7 @@ def _mask_by_distance(base_raster_path, min_dist, max_dist, out_nodata,
def _create_distance_raster(base_raster_path, base_vector_path,
-                           target_dist_raster_path, work_dir):
+                           target_dist_raster_path, work_dir, where_clause=None):
    """Create and rasterize vector onto a raster, and calculate dist transform.

    Create a raster where the pixel values represent the euclidean distance to
@ -1857,6 +1857,9 @@ def _create_distance_raster(base_raster_path, base_vector_path,
        base_vector_path (str): path to vector to be rasterized.
        target_dist_raster_path (str): path to raster with distance transform.
        work_dir (str): path to create a temp folder for saving files.
        where_clause (str): If not None, is an SQL query-like string to filter
            which features are rasterized. This kwarg is passed to
            ``pygeoprocessing.rasterize``.

    Returns:
        None
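
The docstring above says the new ``where_clause`` is handed straight to ``pygeoprocessing.rasterize``, which applies it as an attribute filter so only matching features are burned in. A hedged usage sketch (paths are placeholders, and the target raster must already exist):

import pygeoprocessing

# Burn only the feature whose FID is 3, mirroring how
# _calculate_distances_land_grid (below) rasterizes one point at a time.
pygeoprocessing.rasterize(
    'land_points.shp',           # placeholder vector path
    'rasterized_feature.tif',    # placeholder path to an existing raster
    burn_values=[1],
    option_list=['ALL_TOUCHED=TRUE'],
    where_clause='FID=3')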
@ -1884,7 +1887,8 @@ def _create_distance_raster(base_raster_path, base_vector_path,
        base_vector_path,
        rasterized_raster_path,
        burn_values=[1],
-       option_list=["ALL_TOUCHED=TRUE"])
+       option_list=["ALL_TOUCHED=TRUE"],
+       where_clause=where_clause)

    # Calculate euclidean distance transform
    pygeoprocessing.distance_transform_edt(
@ -2589,67 +2593,25 @@ def _calculate_distances_land_grid(base_point_vector_path, base_raster_path,
    # A list to hold the land to grid distances in order for each point
    # features 'L2G' field
    l2g_dist = []
-   # A list to hold the individual distance transform path's in order
+   # A list to hold the individual distance transform paths in order
    land_point_dist_raster_path_list = []

-   # Get the original layer definition which holds needed attribute values
-   base_layer_defn = base_point_layer.GetLayerDefn()
-   file_ext, driver_name = _get_file_ext_and_driver_name(
-       base_point_vector_path)
-   output_driver = ogr.GetDriverByName(driver_name)
-   single_feature_vector_path = os.path.join(
-       temp_dir, 'single_feature' + file_ext)
-   target_vector = output_driver.CreateDataSource(single_feature_vector_path)
-
-   # Create the new layer for target_vector using same name and
-   # geometry type from base_vector as well as spatial reference
-   target_layer = target_vector.CreateLayer(base_layer_defn.GetName(),
-                                            base_point_layer.GetSpatialRef(),
-                                            base_layer_defn.GetGeomType())
-
-   # Get the number of fields in original_layer
-   base_field_count = base_layer_defn.GetFieldCount()
-
-   # For every field, create a duplicate field and add it to the new
-   # shapefiles layer
-   for fld_index in range(base_field_count):
-       base_field = base_layer_defn.GetFieldDefn(fld_index)
-       target_field = ogr.FieldDefn(base_field.GetName(),
-                                    base_field.GetType())
-       # NOT setting the WIDTH or PRECISION because that seems to be
-       # unneeded and causes interesting OGR conflicts
-       target_layer.CreateField(target_field)
+   fid_field = base_point_layer.GetFIDColumn()
+   if not fid_field:
+       fid_field = 'FID'

    # Create a new shapefile with only one feature to burn onto a raster
    # in order to get the distance transform based on that one feature
    for feature_index, point_feature in enumerate(base_point_layer):
        # Get the point features land to grid value and add it to the list
-       field_index = point_feature.GetFieldIndex('L2G')
-       l2g_dist.append(float(point_feature.GetField(field_index)))
-
-       # Copy original_datasource's feature and set as new shapes feature
-       output_feature = ogr.Feature(feature_def=target_layer.GetLayerDefn())
-
-       # Since the original feature is of interest add its fields and
-       # Values to the new feature from the intersecting geometries
-       # The False in SetFrom() signifies that the fields must match
-       # exactly
-       output_feature.SetFrom(point_feature, False)
-       target_layer.CreateFeature(output_feature)
-       target_vector.SyncToDisk()
-       target_layer.DeleteFeature(point_feature.GetFID())
-
-       dist_raster_path = os.path.join(temp_dir,
-                                       'dist_%s.tif' % feature_index)
-       _create_distance_raster(base_raster_path, single_feature_vector_path,
-                               dist_raster_path, work_dir)
+       l2g_dist.append(float(point_feature.GetField('L2G')))
+
+       dist_raster_path = os.path.join(temp_dir, f'dist_{feature_index}.tif')
+       _create_distance_raster(
+           base_raster_path, base_point_vector_path, dist_raster_path,
+           work_dir, where_clause=f'{fid_field}={point_feature.GetFID()}')
        # Add each features distance transform result to list
        land_point_dist_raster_path_list.append(dist_raster_path)

-   target_layer = None
-   target_vector = None
-   base_point_layer = None
-   base_point_vector = None
    l2g_dist_array = numpy.array(l2g_dist)


def _min_land_ocean_dist(*grid_distances):

View File

@ -6,13 +6,14 @@ import os
import pandas
import numpy
from osgeo import gdal
import pygeoprocessing

REGRESSION_DATA = os.path.join(
    os.path.dirname(__file__), '..', 'data', 'invest-test-data', 'annual_water_yield')
SAMPLE_DATA = os.path.join(REGRESSION_DATA, 'input')

gdal.UseExceptions()


class AnnualWaterYieldTests(unittest.TestCase):
    """Regression Tests for Annual Water Yield Model."""

View File

@ -12,6 +12,7 @@ import numpy.random
import numpy.testing
import pygeoprocessing

gdal.UseExceptions()


def make_simple_raster(base_raster_path, fill_val, nodata_val):
    """Create a 10x10 raster on designated path with fill value.

View File

@ -10,12 +10,13 @@ import json
import importlib
import uuid

try:
    from StringIO import StringIO
except ImportError:
    from io import StringIO

from osgeo import gdal

gdal.UseExceptions()


@contextlib.contextmanager
def redirect_stdout():

View File

@ -17,6 +17,7 @@ from natcap.invest import validation
from osgeo import gdal
from osgeo import osr

gdal.UseExceptions()

REGRESSION_DATA = os.path.join(
    os.path.dirname(__file__), '..', 'data', 'invest-test-data',
    'coastal_blue_carbon')

View File

@ -20,6 +20,7 @@ from shapely.geometry import MultiPolygon
from shapely.geometry import Point
from shapely.geometry import Polygon

gdal.UseExceptions()

REGRESSION_DATA = os.path.join(
    os.path.dirname(__file__), '..', 'data', 'invest-test-data',
    'coastal_vulnerability')

View File

@ -9,6 +9,7 @@ from osgeo import gdal
import pandas
import pygeoprocessing

gdal.UseExceptions()

MODEL_DATA_PATH = os.path.join(
    os.path.dirname(__file__), '..', 'data', 'invest-test-data',
    'crop_production_model', 'model_data')

View File

@ -18,6 +18,7 @@ import shapely.geometry
from osgeo import gdal
from osgeo import ogr

gdal.UseExceptions()

_TEST_FILE_CWD = os.path.dirname(os.path.abspath(__file__))
DATA_DIR = os.path.join(_TEST_FILE_CWD,
                        '..', 'data', 'invest-test-data', 'data_stack')

View File

@ -19,6 +19,7 @@ from shapely.geometry import box
from shapely.geometry import MultiPoint
from shapely.geometry import Point

gdal.UseExceptions()

REGRESSION_DATA = os.path.join(
    os.path.dirname(__file__), '..', 'data', 'invest-test-data',
    'delineateit')

View File

@ -7,7 +7,7 @@ import os
from osgeo import gdal
import numpy

gdal.UseExceptions()

REGRESSION_DATA = os.path.join(
    os.path.dirname(__file__), '..', 'data', 'invest-test-data',
    'forest_carbon_edge_effect')

View File

@ -12,6 +12,7 @@ from osgeo import ogr
from osgeo import osr
from shapely.geometry import Polygon

gdal.UseExceptions()


def make_raster_from_array(
        base_array, base_raster_path, nodata_val=-1, gdal_type=gdal.GDT_Int32):

View File

@ -18,6 +18,7 @@ from osgeo import gdal
from osgeo import ogr
from osgeo import osr

gdal.UseExceptions()

ORIGIN = (1180000.0, 690000.0)
_SRS = osr.SpatialReference()
_SRS.ImportFromEPSG(26910)  # UTM zone 10N

View File

@ -3,6 +3,8 @@
import unittest
import os

from osgeo import gdal

gdal.UseExceptions()


class FileRegistryTests(unittest.TestCase):
    """Tests for the InVEST file registry builder."""

View File

@ -4,7 +4,9 @@ import unittest
import pint

from natcap.invest.model_metadata import MODEL_METADATA
from osgeo import gdal

gdal.UseExceptions()

valid_nested_types = {
    None: {  # if no parent type (arg is top-level), then all types are valid
        'boolean',

View File

@ -12,6 +12,7 @@ from osgeo import gdal
from osgeo import ogr
from osgeo import osr

gdal.UseExceptions()

REGRESSION_DATA = os.path.join(
    os.path.dirname(__file__), '..', 'data', 'invest-test-data', 'ndr')

View File

@ -7,9 +7,11 @@ import unittest
import numpy
import pygeoprocessing
import shapely.geometry
from osgeo import gdal
from osgeo import ogr
from osgeo import osr

gdal.UseExceptions()

REGRESSION_DATA = os.path.join(
    os.path.dirname(__file__), '..', 'data', 'invest-test-data', 'pollination')

View File

@ -28,6 +28,7 @@ import warnings
from natcap.invest import utils

gdal.UseExceptions()

Pyro4.config.SERIALIZER = 'marshal'  # allow null bytes in strings

REGRESSION_DATA = os.path.join(

View File

@ -9,6 +9,7 @@ import numpy
from osgeo import gdal
from osgeo import osr

gdal.UseExceptions()


class RouteDEMTests(unittest.TestCase):
    """Tests for RouteDEM with Pygeoprocessing 1.x routing API."""

View File

@ -5,7 +5,9 @@ import shutil
import os
import pandas

from osgeo import gdal

gdal.UseExceptions()

TEST_DATA_DIR = os.path.join(
    os.path.dirname(__file__), '..', 'data', 'invest-test-data',
    'scenario_gen_proximity')

View File

@ -14,6 +14,7 @@ from shapely.geometry import LineString
from shapely.geometry import Point
from shapely.geometry import Polygon

gdal.UseExceptions()

_SRS = osr.SpatialReference()
_SRS.ImportFromEPSG(32731)  # WGS84 / UTM zone 31s
WKT = _SRS.ExportToWkt()

View File

@ -9,6 +9,7 @@ import pygeoprocessing
from osgeo import gdal
from osgeo import osr

gdal.UseExceptions()

REGRESSION_DATA = os.path.join(
    os.path.dirname(__file__), '..', 'data', 'invest-test-data', 'sdr')
SAMPLE_DATA = os.path.join(REGRESSION_DATA, 'input')

View File

@ -10,6 +10,7 @@ from osgeo import gdal
from osgeo import ogr
from osgeo import osr

gdal.UseExceptions()

REGRESSION_DATA = os.path.join(
    os.path.dirname(__file__), '..', 'data', 'invest-test-data',
    'seasonal_water_yield')

View File

@ -2,7 +2,9 @@ import unittest
from natcap.invest import spec_utils
from natcap.invest.unit_registry import u
from osgeo import gdal

gdal.UseExceptions()


class TestSpecUtils(unittest.TestCase):

View File

@ -13,7 +13,7 @@ import pygeoprocessing
from pygeoprocessing.geoprocessing_core import (
    DEFAULT_GTIFF_CREATION_TUPLE_OPTIONS as opts_tuple)

gdal.UseExceptions()

TEST_DATA = os.path.join(os.path.dirname(
    __file__), '..', 'data', 'invest-test-data', 'stormwater')

View File

@ -10,7 +10,9 @@ from unittest.mock import patch
from babel.messages import Catalog, mofile

import natcap.invest
from natcap.invest import validation
from osgeo import gdal

gdal.UseExceptions()

TEST_LANG = 'll'

# assign to local variable so that it won't be changed by translation

View File

@ -8,6 +8,7 @@ import numpy
import pandas
from osgeo import gdal

gdal.UseExceptions()

REGRESSION_DATA = os.path.join(
    os.path.dirname(__file__), '..', 'data', 'invest-test-data', 'ucm')

View File

@ -13,6 +13,7 @@ from osgeo import gdal
from osgeo import ogr
from osgeo import osr

gdal.UseExceptions()


class UFRMTests(unittest.TestCase):
    """Tests for the Urban Flood Risk Mitigation Model."""

View File

@ -6,7 +6,9 @@ import unittest
from unittest.mock import Mock, patch

from natcap.invest import ui_server
from osgeo import gdal

gdal.UseExceptions()

TEST_DATA_PATH = os.path.join(
    os.path.dirname(__file__), '..', 'data', 'invest-test-data')

View File

@ -18,6 +18,7 @@ from osgeo import gdal
from osgeo import ogr
from osgeo import osr

gdal.UseExceptions()

_DEFAULT_ORIGIN = (444720, 3751320)
_DEFAULT_PIXEL_SIZE = (30, -30)
_DEFAULT_EPSG = 3116

View File

@ -13,6 +13,7 @@ import shapely.geometry
import numpy
import numpy.testing

gdal.UseExceptions()


class UsageLoggingTests(unittest.TestCase):
    """Tests for the InVEST usage logging framework."""

View File

@ -25,6 +25,7 @@ from osgeo import osr
from shapely.geometry import Point
from shapely.geometry import Polygon

gdal.UseExceptions()


class SuffixUtilsTests(unittest.TestCase):
    """Tests for natcap.invest.utils.make_suffix_string."""
@ -412,17 +413,22 @@ class GDALWarningsLoggingTests(unittest.TestCase):
        logfile = os.path.join(self.workspace, 'logfile.txt')

-       # this warning should go to stdout.
-       gdal.Open('this_file_should_not_exist.tif')
+       invalid_polygon = ogr.CreateGeometryFromWkt(
+           'POLYGON ((-20 -20, -16 -20, -20 -16, -16 -16, -20 -20))')
+
+       # This produces a GDAL warning that does not raise an
+       # exception with UseExceptions(). Without capture_gdal_logging,
+       # it will be printed directly to stderr
+       invalid_polygon.IsValid()

        with utils.log_to_file(logfile) as handler:
            with utils.capture_gdal_logging():
                # warning should be captured.
-               gdal.Open('file_file_should_also_not_exist.tif')
+               invalid_polygon.IsValid()
            handler.flush()

-       # warning should go to stdout
-       gdal.Open('this_file_should_not_exist.tif')
+       # warning should go to stderr
+       invalid_polygon.IsValid()

        with open(logfile) as opened_logfile:
            messages = [msg for msg in opened_logfile.read().split('\n')
@ -499,7 +505,11 @@ class PrepareWorkspaceTests(unittest.TestCase):
            with utils.prepare_workspace(workspace,
                                         'some_model'):
                warnings.warn('deprecated', UserWarning)
-               gdal.Open('file should not exist')
+               invalid_polygon = ogr.CreateGeometryFromWkt(
+                   'POLYGON ((-20 -20, -16 -20, -20 -16, -16 -16, -20 -20))')
+               # This produces a GDAL warning that does not raise an
+               # exception with UseExceptions()
+               invalid_polygon.IsValid()

            self.assertTrue(os.path.exists(workspace))
            logfile_glob = glob.glob(os.path.join(workspace, '*.txt'))
@ -509,11 +519,9 @@ class PrepareWorkspaceTests(unittest.TestCase):
            with open(logfile_glob[0]) as logfile:
                logfile_text = logfile.read()
            # all the following strings should be in the logfile.
-           expected_string = (
-               'file should not exist: No such file or directory')
-           self.assertTrue(
-               expected_string in logfile_text)  # gdal error captured
-           self.assertEqual(len(re.findall('WARNING', logfile_text)), 1)
+           self.assertTrue(  # gdal logging captured
+               'Self-intersection at or near point -18 -18' in logfile_text)
+           self.assertEqual(len(re.findall('WARNING', logfile_text)), 2)
            self.assertTrue('Elapsed time:' in logfile_text)
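
The bow-tie polygon used in these tests is a convenient way to see the distinction the assertions rely on: an invalid geometry makes GDAL/OGR emit a warning through its error handler, but the call itself still returns normally even with exceptions enabled. A standalone sketch:

from osgeo import gdal, ogr

gdal.UseExceptions()

# Self-intersecting "bow-tie" polygon, the same shape used in the tests above.
invalid_polygon = ogr.CreateGeometryFromWkt(
    'POLYGON ((-20 -20, -16 -20, -20 -16, -16 -16, -20 -20))')

# Emits a GDAL warning ("Self-intersection at or near point ...") but does
# not raise; it simply reports the geometry as invalid.
print(invalid_polygon.IsValid())  # False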

View File

@ -18,6 +18,7 @@ from osgeo import gdal
from osgeo import ogr
from osgeo import osr

gdal.UseExceptions()


class SpatialOverlapTest(unittest.TestCase):
    """Test Spatial Overlap."""

View File

@ -18,6 +18,7 @@ from shapely.geometry import Point
from natcap.invest import utils
import pygeoprocessing

gdal.UseExceptions()

REGRESSION_DATA = os.path.join(
    os.path.dirname(__file__), '..', 'data', 'invest-test-data', 'wave_energy')
SAMPLE_DATA = os.path.join(REGRESSION_DATA, 'input')

View File

@ -17,6 +17,7 @@ from osgeo import osr
import pygeoprocessing

gdal.UseExceptions()

SAMPLE_DATA = os.path.join(
    os.path.dirname(__file__), '..', 'data', 'invest-test-data', 'wind_energy',
    'input')