Skip to content

Commit

Permalink
[processing] keep only one zonal statistics algorithm
Browse files Browse the repository at this point in the history
  • Loading branch information
nirvn committed Jun 29, 2017
1 parent 58f6f93 commit 26d9c74
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 354 deletions.
7 changes: 3 additions & 4 deletions python/plugins/processing/algs/qgis/QGISAlgorithmProvider.py
Expand Up @@ -72,7 +72,7 @@
from .SpatialiteExecuteSQL import SpatialiteExecuteSQL
from .SymmetricalDifference import SymmetricalDifference
from .VectorSplit import VectorSplit
from .ZonalStatisticsQgis import ZonalStatisticsQgis
from .ZonalStatistics import ZonalStatistics

# from .ExtractByLocation import ExtractByLocation
# from .PointsInPolygon import PointsInPolygon
Expand Down Expand Up @@ -119,7 +119,6 @@
# from .JoinAttributes import JoinAttributes
# from .CreateConstantRaster import CreateConstantRaster
# from .PointsDisplacement import PointsDisplacement
# from .ZonalStatistics import ZonalStatistics
# from .PointsFromPolygons import PointsFromPolygons
# from .PointsFromLines import PointsFromLines
# from .RandomPointsExtent import RandomPointsExtent
Expand Down Expand Up @@ -210,7 +209,7 @@ def getAlgs(self):
# EquivalentNumField(),
# StatisticsByCategories(), ConcaveHull(),
# RasterLayerStatistics(), PointsDisplacement(),
# ZonalStatistics(), PointsFromPolygons(),
# PointsFromPolygons(),
# PointsFromLines(), RandomPointsExtent(),
# RandomPointsLayer(), RandomPointsPolygonsFixed(),
# RandomPointsPolygonsVariable(),
Expand Down Expand Up @@ -273,7 +272,7 @@ def getAlgs(self):
SpatialiteExecuteSQL(),
SymmetricalDifference(),
VectorSplit(),
ZonalStatisticsQgis()
ZonalStatistics()
]

if hasPlotly:
Expand Down
302 changes: 74 additions & 228 deletions python/plugins/processing/algs/qgis/ZonalStatistics.py
Expand Up @@ -4,8 +4,8 @@
***************************************************************************
ZonalStatistics.py
---------------------
Date : August 2013
Copyright : (C) 2013 by Alexander Bruy
Date : September 2016
Copyright : (C) 2016 by Alexander Bruy
Email : alexander dot bruy at gmail dot com
***************************************************************************
* *
Expand All @@ -16,41 +16,33 @@
* *
***************************************************************************
"""
from builtins import str

__author__ = 'Alexander Bruy'
__date__ = 'August 2013'
__copyright__ = '(C) 2013, Alexander Bruy'
__date__ = 'September 2016'
__copyright__ = '(C) 2016, Alexander Bruy'

# This will get replaced with a git SHA1 when you do a git archive

__revision__ = '$Format:%H$'

import numpy
import os

try:
from scipy.stats.mstats import mode
hasSciPy = True
except:
hasSciPy = False
from qgis.PyQt.QtGui import QIcon

from osgeo import gdal, ogr, osr
from qgis.core import (QgsApplication,
QgsFeatureSink,
QgsRectangle,
QgsGeometry,
QgsFeature,
QgsProcessingUtils)
from qgis.analysis import QgsZonalStatistics
from qgis.core import (QgsFeatureSink,
QgsProcessingUtils,
QgsProcessingParameterDefinition,
QgsProcessingParameterVectorLayer,
QgsProcessingParameterRasterLayer,
QgsProcessingParameterString,
QgsProcessingParameterNumber,
QgsProcessingParameterEnum,
QgsProcessingOutputVectorLayer)

from processing.algs.qgis.QgisAlgorithm import QgisAlgorithm
from processing.core.parameters import ParameterVector
from processing.core.parameters import ParameterRaster
from processing.core.parameters import ParameterString
from processing.core.parameters import ParameterNumber
from processing.core.parameters import ParameterBoolean
from processing.core.outputs import OutputVector
from processing.tools.raster import mapToPixel
from processing.tools import dataobjects, vector

pluginPath = os.path.split(os.path.split(os.path.dirname(__file__))[0])[0]


class ZonalStatistics(QgisAlgorithm):
Expand All @@ -59,26 +51,48 @@ class ZonalStatistics(QgisAlgorithm):
RASTER_BAND = 'RASTER_BAND'
INPUT_VECTOR = 'INPUT_VECTOR'
COLUMN_PREFIX = 'COLUMN_PREFIX'
GLOBAL_EXTENT = 'GLOBAL_EXTENT'
OUTPUT_LAYER = 'OUTPUT_LAYER'
STATISTICS = 'STATS'

def icon(self):
return QIcon(os.path.join(pluginPath, 'images', 'zonalstats.png'))

def group(self):
return self.tr('Raster tools')

def __init__(self):
super().__init__()
self.addParameter(ParameterRaster(self.INPUT_RASTER,
self.tr('Raster layer')))
self.addParameter(ParameterNumber(self.RASTER_BAND,
self.tr('Raster band'), 1, 999, 1))
self.addParameter(ParameterVector(self.INPUT_VECTOR,
self.tr('Vector layer containing zones'),
[dataobjects.TYPE_VECTOR_POLYGON]))
self.addParameter(ParameterString(self.COLUMN_PREFIX,
self.tr('Output column prefix'), '_'))
self.addParameter(ParameterBoolean(self.GLOBAL_EXTENT,
self.tr('Load whole raster in memory')))
self.addOutput(OutputVector(self.OUTPUT_LAYER, self.tr('Zonal statistics'), datatype=[dataobjects.TYPE_VECTOR_POLYGON]))
self.STATS = {self.tr('Count'): QgsZonalStatistics.Count,
self.tr('Sum'): QgsZonalStatistics.Sum,
self.tr('Mean'): QgsZonalStatistics.Mean,
self.tr('Median'): QgsZonalStatistics.Median,
self.tr('Std. dev.'): QgsZonalStatistics.StDev,
self.tr('Min'): QgsZonalStatistics.Min,
self.tr('Max'): QgsZonalStatistics.Max,
self.tr('Range'): QgsZonalStatistics.Range,
self.tr('Minority'): QgsZonalStatistics.Minority,
self.tr('Majority (mode)'): QgsZonalStatistics.Majority,
self.tr('Variety'): QgsZonalStatistics.Variety,
self.tr('Variance'): QgsZonalStatistics.Variance,
self.tr('All'): QgsZonalStatistics.All
}

self.addParameter(QgsProcessingParameterRasterLayer(self.INPUT_RASTER,
self.tr('Raster layer')))
self.addParameter(QgsProcessingParameterNumber(self.RASTER_BAND,
self.tr('Raster band'),
minValue=1, maxValue=999, defaultValue=1))
self.addParameter(QgsProcessingParameterVectorLayer(self.INPUT_VECTOR,
self.tr('Vector layer containing zones'),
[QgsProcessingParameterDefinition.TypeVectorPolygon]))
self.addParameter(QgsProcessingParameterString(self.COLUMN_PREFIX,
self.tr('Output column prefix'), '_'))
self.addParameter(QgsProcessingParameterEnum(self.STATISTICS,
self.tr('Statistics to calculate'),
list(self.STATS.keys()),
allowMultiple=True))
self.addOutput(QgsProcessingOutputVectorLayer(self.INPUT_VECTOR,
self.tr('Zonal statistics'),
QgsProcessingParameterDefinition.TypeVectorPolygon))

def name(self):
return 'zonalstatistics'
Expand All @@ -87,191 +101,23 @@ def displayName(self):
return self.tr('Zonal Statistics')

def processAlgorithm(self, parameters, context, feedback):
""" Based on code by Matthew Perry
https://gist.github.com/perrygeo/5667173
:param parameters:
:param context:
"""

layer = QgsProcessingUtils.mapLayerFromString(self.getParameterValue(self.INPUT_VECTOR), context)

rasterPath = str(self.getParameterValue(self.INPUT_RASTER))
bandNumber = self.getParameterValue(self.RASTER_BAND)
columnPrefix = self.getParameterValue(self.COLUMN_PREFIX)
useGlobalExtent = self.getParameterValue(self.GLOBAL_EXTENT)

rasterDS = gdal.Open(rasterPath, gdal.GA_ReadOnly)
geoTransform = rasterDS.GetGeoTransform()
rasterBand = rasterDS.GetRasterBand(bandNumber)
noData = rasterBand.GetNoDataValue()

cellXSize = abs(geoTransform[1])
cellYSize = abs(geoTransform[5])
rasterXSize = rasterDS.RasterXSize
rasterYSize = rasterDS.RasterYSize

rasterBBox = QgsRectangle(geoTransform[0],
geoTransform[3] - cellYSize * rasterYSize,
geoTransform[0] + cellXSize * rasterXSize,
geoTransform[3])

rasterGeom = QgsGeometry.fromRect(rasterBBox)

crs = osr.SpatialReference()
crs.ImportFromProj4(str(layer.crs().toProj4()))

if useGlobalExtent:
xMin = rasterBBox.xMinimum()
xMax = rasterBBox.xMaximum()
yMin = rasterBBox.yMinimum()
yMax = rasterBBox.yMaximum()

(startColumn, startRow) = mapToPixel(xMin, yMax, geoTransform)
(endColumn, endRow) = mapToPixel(xMax, yMin, geoTransform)

width = endColumn - startColumn
height = endRow - startRow

srcOffset = (startColumn, startRow, width, height)
srcArray = rasterBand.ReadAsArray(*srcOffset)
srcArray = srcArray * rasterBand.GetScale() + rasterBand.GetOffset()

newGeoTransform = (
geoTransform[0] + srcOffset[0] * geoTransform[1],
geoTransform[1],
0.0,
geoTransform[3] + srcOffset[1] * geoTransform[5],
0.0,
geoTransform[5],
)

memVectorDriver = ogr.GetDriverByName('Memory')
memRasterDriver = gdal.GetDriverByName('MEM')

fields = layer.fields()
(idxMin, fields) = vector.findOrCreateField(layer, fields,
columnPrefix + 'min', 21, 6)
(idxMax, fields) = vector.findOrCreateField(layer, fields,
columnPrefix + 'max', 21, 6)
(idxSum, fields) = vector.findOrCreateField(layer, fields,
columnPrefix + 'sum', 21, 6)
(idxCount, fields) = vector.findOrCreateField(layer, fields,
columnPrefix + 'count', 21, 6)
(idxMean, fields) = vector.findOrCreateField(layer, fields,
columnPrefix + 'mean', 21, 6)
(idxStd, fields) = vector.findOrCreateField(layer, fields,
columnPrefix + 'std', 21, 6)
(idxUnique, fields) = vector.findOrCreateField(layer, fields,
columnPrefix + 'unique', 21, 6)
(idxRange, fields) = vector.findOrCreateField(layer, fields,
columnPrefix + 'range', 21, 6)
(idxVar, fields) = vector.findOrCreateField(layer, fields,
columnPrefix + 'var', 21, 6)
(idxMedian, fields) = vector.findOrCreateField(layer, fields,
columnPrefix + 'median', 21, 6)
if hasSciPy:
(idxMode, fields) = vector.findOrCreateField(layer, fields,
columnPrefix + 'mode', 21, 6)

writer = self.getOutputFromName(self.OUTPUT_LAYER).getVectorWriter(fields, layer.wkbType(),
layer.crs(), context)

outFeat = QgsFeature()

outFeat.initAttributes(len(fields))
outFeat.setFields(fields)

features = QgsProcessingUtils.getFeatures(layer, context)
total = 100.0 / layer.featureCount() if layer.featureCount() else 0
for current, f in enumerate(features):
geom = f.geometry()

intersectedGeom = rasterGeom.intersection(geom)
ogrGeom = ogr.CreateGeometryFromWkt(intersectedGeom.exportToWkt())

if not useGlobalExtent:
bbox = intersectedGeom.boundingBox()

xMin = bbox.xMinimum()
xMax = bbox.xMaximum()
yMin = bbox.yMinimum()
yMax = bbox.yMaximum()

(startColumn, startRow) = mapToPixel(xMin, yMax, geoTransform)
(endColumn, endRow) = mapToPixel(xMax, yMin, geoTransform)

width = endColumn - startColumn
height = endRow - startRow

if width == 0 or height == 0:
continue

srcOffset = (startColumn, startRow, width, height)
srcArray = rasterBand.ReadAsArray(*srcOffset)
srcArray = srcArray * rasterBand.GetScale() + rasterBand.GetOffset()

newGeoTransform = (
geoTransform[0] + srcOffset[0] * geoTransform[1],
geoTransform[1],
0.0,
geoTransform[3] + srcOffset[1] * geoTransform[5],
0.0,
geoTransform[5],
)

# Create a temporary vector layer in memory
memVDS = memVectorDriver.CreateDataSource('out')
memLayer = memVDS.CreateLayer('poly', crs, ogr.wkbPolygon)

ft = ogr.Feature(memLayer.GetLayerDefn())
ft.SetGeometry(ogrGeom)
memLayer.CreateFeature(ft)
ft.Destroy()

# Rasterize it
rasterizedDS = memRasterDriver.Create('', srcOffset[2],
srcOffset[3], 1, gdal.GDT_Byte)
rasterizedDS.SetGeoTransform(newGeoTransform)
gdal.RasterizeLayer(rasterizedDS, [1], memLayer, burn_values=[1])
rasterizedArray = rasterizedDS.ReadAsArray()

srcArray = numpy.nan_to_num(srcArray)
masked = numpy.ma.MaskedArray(srcArray,
mask=numpy.logical_or(srcArray == noData,
numpy.logical_not(rasterizedArray)))

outFeat.setGeometry(geom)

attrs = f.attributes()
v = float(masked.min())
attrs.insert(idxMin, None if numpy.isnan(v) else v)
v = float(masked.max())
attrs.insert(idxMax, None if numpy.isnan(v) else v)
v = float(masked.sum())
attrs.insert(idxSum, None if numpy.isnan(v) else v)
attrs.insert(idxCount, int(masked.count()))
v = float(masked.mean())
attrs.insert(idxMean, None if numpy.isnan(v) else v)
v = float(masked.std())
attrs.insert(idxStd, None if numpy.isnan(v) else v)
attrs.insert(idxUnique, numpy.unique(masked.compressed()).size)
v = float(masked.max()) - float(masked.min())
attrs.insert(idxRange, None if numpy.isnan(v) else v)
v = float(masked.var())
attrs.insert(idxVar, None if numpy.isnan(v) else v)
v = float(numpy.ma.median(masked))
attrs.insert(idxMedian, None if numpy.isnan(v) else v)
if hasSciPy:
attrs.insert(idxMode, float(mode(masked, axis=None)[0][0]))

outFeat.setAttributes(attrs)
writer.addFeature(outFeat, QgsFeatureSink.FastInsert)

memVDS = None
rasterizedDS = None

feedback.setProgress(int(current * total))

rasterDS = None

del writer
bandNumber = self.parameterAsInt(parameters, self.RASTER_BAND, context)
columnPrefix = self.parameterAsString(parameters, self.COLUMN_PREFIX, context)
st = self.parameterAsEnums(parameters, self.STATISTICS, context)

vectorLayer = self.parameterAsVectorLayer(parameters, self.INPUT_VECTOR, context)
rasterLayer = self.parameterAsRasterLayer(parameters, self.INPUT_RASTER, context)

keys = list(self.STATS.keys())
selectedStats = 0
for i in st:
selectedStats |= self.STATS[keys[i]]

zs = QgsZonalStatistics(vectorLayer,
rasterLayer,
columnPrefix,
bandNumber,
selectedStats)
zs.calculateStatistics(feedback)

return {self.INPUT_VECTOR: vectorLayer}

0 comments on commit 26d9c74

Please sign in to comment.