Skip to content

Commit 26d9c74

Browse files
committedJun 29, 2017
[processing] keep only one zonal statistics algorithm
1 parent 58f6f93 commit 26d9c74

File tree

3 files changed

+77
-354
lines changed

3 files changed

+77
-354
lines changed
 

‎python/plugins/processing/algs/qgis/QGISAlgorithmProvider.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@
7272
from .SpatialiteExecuteSQL import SpatialiteExecuteSQL
7373
from .SymmetricalDifference import SymmetricalDifference
7474
from .VectorSplit import VectorSplit
75-
from .ZonalStatisticsQgis import ZonalStatisticsQgis
75+
from .ZonalStatistics import ZonalStatistics
7676

7777
# from .ExtractByLocation import ExtractByLocation
7878
# from .PointsInPolygon import PointsInPolygon
@@ -119,7 +119,6 @@
119119
# from .JoinAttributes import JoinAttributes
120120
# from .CreateConstantRaster import CreateConstantRaster
121121
# from .PointsDisplacement import PointsDisplacement
122-
# from .ZonalStatistics import ZonalStatistics
123122
# from .PointsFromPolygons import PointsFromPolygons
124123
# from .PointsFromLines import PointsFromLines
125124
# from .RandomPointsExtent import RandomPointsExtent
@@ -210,7 +209,7 @@ def getAlgs(self):
210209
# EquivalentNumField(),
211210
# StatisticsByCategories(), ConcaveHull(),
212211
# RasterLayerStatistics(), PointsDisplacement(),
213-
# ZonalStatistics(), PointsFromPolygons(),
212+
# PointsFromPolygons(),
214213
# PointsFromLines(), RandomPointsExtent(),
215214
# RandomPointsLayer(), RandomPointsPolygonsFixed(),
216215
# RandomPointsPolygonsVariable(),
@@ -273,7 +272,7 @@ def getAlgs(self):
273272
SpatialiteExecuteSQL(),
274273
SymmetricalDifference(),
275274
VectorSplit(),
276-
ZonalStatisticsQgis()
275+
ZonalStatistics()
277276
]
278277

279278
if hasPlotly:

‎python/plugins/processing/algs/qgis/ZonalStatistics.py

Lines changed: 74 additions & 228 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
***************************************************************************
55
ZonalStatistics.py
66
---------------------
7-
Date : August 2013
8-
Copyright : (C) 2013 by Alexander Bruy
7+
Date : September 2016
8+
Copyright : (C) 2016 by Alexander Bruy
99
Email : alexander dot bruy at gmail dot com
1010
***************************************************************************
1111
* *
@@ -16,41 +16,33 @@
1616
* *
1717
***************************************************************************
1818
"""
19-
from builtins import str
2019

2120
__author__ = 'Alexander Bruy'
22-
__date__ = 'August 2013'
23-
__copyright__ = '(C) 2013, Alexander Bruy'
21+
__date__ = 'September 2016'
22+
__copyright__ = '(C) 2016, Alexander Bruy'
2423

2524
# This will get replaced with a git SHA1 when you do a git archive
2625

2726
__revision__ = '$Format:%H$'
2827

29-
import numpy
28+
import os
3029

31-
try:
32-
from scipy.stats.mstats import mode
33-
hasSciPy = True
34-
except:
35-
hasSciPy = False
30+
from qgis.PyQt.QtGui import QIcon
3631

37-
from osgeo import gdal, ogr, osr
38-
from qgis.core import (QgsApplication,
39-
QgsFeatureSink,
40-
QgsRectangle,
41-
QgsGeometry,
42-
QgsFeature,
43-
QgsProcessingUtils)
32+
from qgis.analysis import QgsZonalStatistics
33+
from qgis.core import (QgsFeatureSink,
34+
QgsProcessingUtils,
35+
QgsProcessingParameterDefinition,
36+
QgsProcessingParameterVectorLayer,
37+
QgsProcessingParameterRasterLayer,
38+
QgsProcessingParameterString,
39+
QgsProcessingParameterNumber,
40+
QgsProcessingParameterEnum,
41+
QgsProcessingOutputVectorLayer)
4442

4543
from processing.algs.qgis.QgisAlgorithm import QgisAlgorithm
46-
from processing.core.parameters import ParameterVector
47-
from processing.core.parameters import ParameterRaster
48-
from processing.core.parameters import ParameterString
49-
from processing.core.parameters import ParameterNumber
50-
from processing.core.parameters import ParameterBoolean
51-
from processing.core.outputs import OutputVector
52-
from processing.tools.raster import mapToPixel
53-
from processing.tools import dataobjects, vector
44+
45+
pluginPath = os.path.split(os.path.split(os.path.dirname(__file__))[0])[0]
5446

5547

5648
class ZonalStatistics(QgisAlgorithm):
@@ -59,26 +51,48 @@ class ZonalStatistics(QgisAlgorithm):
5951
RASTER_BAND = 'RASTER_BAND'
6052
INPUT_VECTOR = 'INPUT_VECTOR'
6153
COLUMN_PREFIX = 'COLUMN_PREFIX'
62-
GLOBAL_EXTENT = 'GLOBAL_EXTENT'
63-
OUTPUT_LAYER = 'OUTPUT_LAYER'
54+
STATISTICS = 'STATS'
55+
56+
def icon(self):
57+
return QIcon(os.path.join(pluginPath, 'images', 'zonalstats.png'))
6458

6559
def group(self):
6660
return self.tr('Raster tools')
6761

6862
def __init__(self):
6963
super().__init__()
70-
self.addParameter(ParameterRaster(self.INPUT_RASTER,
71-
self.tr('Raster layer')))
72-
self.addParameter(ParameterNumber(self.RASTER_BAND,
73-
self.tr('Raster band'), 1, 999, 1))
74-
self.addParameter(ParameterVector(self.INPUT_VECTOR,
75-
self.tr('Vector layer containing zones'),
76-
[dataobjects.TYPE_VECTOR_POLYGON]))
77-
self.addParameter(ParameterString(self.COLUMN_PREFIX,
78-
self.tr('Output column prefix'), '_'))
79-
self.addParameter(ParameterBoolean(self.GLOBAL_EXTENT,
80-
self.tr('Load whole raster in memory')))
81-
self.addOutput(OutputVector(self.OUTPUT_LAYER, self.tr('Zonal statistics'), datatype=[dataobjects.TYPE_VECTOR_POLYGON]))
64+
self.STATS = {self.tr('Count'): QgsZonalStatistics.Count,
65+
self.tr('Sum'): QgsZonalStatistics.Sum,
66+
self.tr('Mean'): QgsZonalStatistics.Mean,
67+
self.tr('Median'): QgsZonalStatistics.Median,
68+
self.tr('Std. dev.'): QgsZonalStatistics.StDev,
69+
self.tr('Min'): QgsZonalStatistics.Min,
70+
self.tr('Max'): QgsZonalStatistics.Max,
71+
self.tr('Range'): QgsZonalStatistics.Range,
72+
self.tr('Minority'): QgsZonalStatistics.Minority,
73+
self.tr('Majority (mode)'): QgsZonalStatistics.Majority,
74+
self.tr('Variety'): QgsZonalStatistics.Variety,
75+
self.tr('Variance'): QgsZonalStatistics.Variance,
76+
self.tr('All'): QgsZonalStatistics.All
77+
}
78+
79+
self.addParameter(QgsProcessingParameterRasterLayer(self.INPUT_RASTER,
80+
self.tr('Raster layer')))
81+
self.addParameter(QgsProcessingParameterNumber(self.RASTER_BAND,
82+
self.tr('Raster band'),
83+
minValue=1, maxValue=999, defaultValue=1))
84+
self.addParameter(QgsProcessingParameterVectorLayer(self.INPUT_VECTOR,
85+
self.tr('Vector layer containing zones'),
86+
[QgsProcessingParameterDefinition.TypeVectorPolygon]))
87+
self.addParameter(QgsProcessingParameterString(self.COLUMN_PREFIX,
88+
self.tr('Output column prefix'), '_'))
89+
self.addParameter(QgsProcessingParameterEnum(self.STATISTICS,
90+
self.tr('Statistics to calculate'),
91+
list(self.STATS.keys()),
92+
allowMultiple=True))
93+
self.addOutput(QgsProcessingOutputVectorLayer(self.INPUT_VECTOR,
94+
self.tr('Zonal statistics'),
95+
QgsProcessingParameterDefinition.TypeVectorPolygon))
8296

8397
def name(self):
8498
return 'zonalstatistics'
@@ -87,191 +101,23 @@ def displayName(self):
87101
return self.tr('Zonal Statistics')
88102

89103
def processAlgorithm(self, parameters, context, feedback):
90-
""" Based on code by Matthew Perry
91-
https://gist.github.com/perrygeo/5667173
92-
:param parameters:
93-
:param context:
94-
"""
95-
96-
layer = QgsProcessingUtils.mapLayerFromString(self.getParameterValue(self.INPUT_VECTOR), context)
97-
98-
rasterPath = str(self.getParameterValue(self.INPUT_RASTER))
99-
bandNumber = self.getParameterValue(self.RASTER_BAND)
100-
columnPrefix = self.getParameterValue(self.COLUMN_PREFIX)
101-
useGlobalExtent = self.getParameterValue(self.GLOBAL_EXTENT)
102-
103-
rasterDS = gdal.Open(rasterPath, gdal.GA_ReadOnly)
104-
geoTransform = rasterDS.GetGeoTransform()
105-
rasterBand = rasterDS.GetRasterBand(bandNumber)
106-
noData = rasterBand.GetNoDataValue()
107-
108-
cellXSize = abs(geoTransform[1])
109-
cellYSize = abs(geoTransform[5])
110-
rasterXSize = rasterDS.RasterXSize
111-
rasterYSize = rasterDS.RasterYSize
112-
113-
rasterBBox = QgsRectangle(geoTransform[0],
114-
geoTransform[3] - cellYSize * rasterYSize,
115-
geoTransform[0] + cellXSize * rasterXSize,
116-
geoTransform[3])
117-
118-
rasterGeom = QgsGeometry.fromRect(rasterBBox)
119-
120-
crs = osr.SpatialReference()
121-
crs.ImportFromProj4(str(layer.crs().toProj4()))
122-
123-
if useGlobalExtent:
124-
xMin = rasterBBox.xMinimum()
125-
xMax = rasterBBox.xMaximum()
126-
yMin = rasterBBox.yMinimum()
127-
yMax = rasterBBox.yMaximum()
128-
129-
(startColumn, startRow) = mapToPixel(xMin, yMax, geoTransform)
130-
(endColumn, endRow) = mapToPixel(xMax, yMin, geoTransform)
131-
132-
width = endColumn - startColumn
133-
height = endRow - startRow
134-
135-
srcOffset = (startColumn, startRow, width, height)
136-
srcArray = rasterBand.ReadAsArray(*srcOffset)
137-
srcArray = srcArray * rasterBand.GetScale() + rasterBand.GetOffset()
138-
139-
newGeoTransform = (
140-
geoTransform[0] + srcOffset[0] * geoTransform[1],
141-
geoTransform[1],
142-
0.0,
143-
geoTransform[3] + srcOffset[1] * geoTransform[5],
144-
0.0,
145-
geoTransform[5],
146-
)
147-
148-
memVectorDriver = ogr.GetDriverByName('Memory')
149-
memRasterDriver = gdal.GetDriverByName('MEM')
150-
151-
fields = layer.fields()
152-
(idxMin, fields) = vector.findOrCreateField(layer, fields,
153-
columnPrefix + 'min', 21, 6)
154-
(idxMax, fields) = vector.findOrCreateField(layer, fields,
155-
columnPrefix + 'max', 21, 6)
156-
(idxSum, fields) = vector.findOrCreateField(layer, fields,
157-
columnPrefix + 'sum', 21, 6)
158-
(idxCount, fields) = vector.findOrCreateField(layer, fields,
159-
columnPrefix + 'count', 21, 6)
160-
(idxMean, fields) = vector.findOrCreateField(layer, fields,
161-
columnPrefix + 'mean', 21, 6)
162-
(idxStd, fields) = vector.findOrCreateField(layer, fields,
163-
columnPrefix + 'std', 21, 6)
164-
(idxUnique, fields) = vector.findOrCreateField(layer, fields,
165-
columnPrefix + 'unique', 21, 6)
166-
(idxRange, fields) = vector.findOrCreateField(layer, fields,
167-
columnPrefix + 'range', 21, 6)
168-
(idxVar, fields) = vector.findOrCreateField(layer, fields,
169-
columnPrefix + 'var', 21, 6)
170-
(idxMedian, fields) = vector.findOrCreateField(layer, fields,
171-
columnPrefix + 'median', 21, 6)
172-
if hasSciPy:
173-
(idxMode, fields) = vector.findOrCreateField(layer, fields,
174-
columnPrefix + 'mode', 21, 6)
175-
176-
writer = self.getOutputFromName(self.OUTPUT_LAYER).getVectorWriter(fields, layer.wkbType(),
177-
layer.crs(), context)
178-
179-
outFeat = QgsFeature()
180-
181-
outFeat.initAttributes(len(fields))
182-
outFeat.setFields(fields)
183-
184-
features = QgsProcessingUtils.getFeatures(layer, context)
185-
total = 100.0 / layer.featureCount() if layer.featureCount() else 0
186-
for current, f in enumerate(features):
187-
geom = f.geometry()
188-
189-
intersectedGeom = rasterGeom.intersection(geom)
190-
ogrGeom = ogr.CreateGeometryFromWkt(intersectedGeom.exportToWkt())
191-
192-
if not useGlobalExtent:
193-
bbox = intersectedGeom.boundingBox()
194-
195-
xMin = bbox.xMinimum()
196-
xMax = bbox.xMaximum()
197-
yMin = bbox.yMinimum()
198-
yMax = bbox.yMaximum()
199-
200-
(startColumn, startRow) = mapToPixel(xMin, yMax, geoTransform)
201-
(endColumn, endRow) = mapToPixel(xMax, yMin, geoTransform)
202-
203-
width = endColumn - startColumn
204-
height = endRow - startRow
205-
206-
if width == 0 or height == 0:
207-
continue
208-
209-
srcOffset = (startColumn, startRow, width, height)
210-
srcArray = rasterBand.ReadAsArray(*srcOffset)
211-
srcArray = srcArray * rasterBand.GetScale() + rasterBand.GetOffset()
212-
213-
newGeoTransform = (
214-
geoTransform[0] + srcOffset[0] * geoTransform[1],
215-
geoTransform[1],
216-
0.0,
217-
geoTransform[3] + srcOffset[1] * geoTransform[5],
218-
0.0,
219-
geoTransform[5],
220-
)
221-
222-
# Create a temporary vector layer in memory
223-
memVDS = memVectorDriver.CreateDataSource('out')
224-
memLayer = memVDS.CreateLayer('poly', crs, ogr.wkbPolygon)
225-
226-
ft = ogr.Feature(memLayer.GetLayerDefn())
227-
ft.SetGeometry(ogrGeom)
228-
memLayer.CreateFeature(ft)
229-
ft.Destroy()
230-
231-
# Rasterize it
232-
rasterizedDS = memRasterDriver.Create('', srcOffset[2],
233-
srcOffset[3], 1, gdal.GDT_Byte)
234-
rasterizedDS.SetGeoTransform(newGeoTransform)
235-
gdal.RasterizeLayer(rasterizedDS, [1], memLayer, burn_values=[1])
236-
rasterizedArray = rasterizedDS.ReadAsArray()
237-
238-
srcArray = numpy.nan_to_num(srcArray)
239-
masked = numpy.ma.MaskedArray(srcArray,
240-
mask=numpy.logical_or(srcArray == noData,
241-
numpy.logical_not(rasterizedArray)))
242-
243-
outFeat.setGeometry(geom)
244-
245-
attrs = f.attributes()
246-
v = float(masked.min())
247-
attrs.insert(idxMin, None if numpy.isnan(v) else v)
248-
v = float(masked.max())
249-
attrs.insert(idxMax, None if numpy.isnan(v) else v)
250-
v = float(masked.sum())
251-
attrs.insert(idxSum, None if numpy.isnan(v) else v)
252-
attrs.insert(idxCount, int(masked.count()))
253-
v = float(masked.mean())
254-
attrs.insert(idxMean, None if numpy.isnan(v) else v)
255-
v = float(masked.std())
256-
attrs.insert(idxStd, None if numpy.isnan(v) else v)
257-
attrs.insert(idxUnique, numpy.unique(masked.compressed()).size)
258-
v = float(masked.max()) - float(masked.min())
259-
attrs.insert(idxRange, None if numpy.isnan(v) else v)
260-
v = float(masked.var())
261-
attrs.insert(idxVar, None if numpy.isnan(v) else v)
262-
v = float(numpy.ma.median(masked))
263-
attrs.insert(idxMedian, None if numpy.isnan(v) else v)
264-
if hasSciPy:
265-
attrs.insert(idxMode, float(mode(masked, axis=None)[0][0]))
266-
267-
outFeat.setAttributes(attrs)
268-
writer.addFeature(outFeat, QgsFeatureSink.FastInsert)
269-
270-
memVDS = None
271-
rasterizedDS = None
272-
273-
feedback.setProgress(int(current * total))
274-
275-
rasterDS = None
276-
277-
del writer
104+
bandNumber = self.parameterAsInt(parameters, self.RASTER_BAND, context)
105+
columnPrefix = self.parameterAsString(parameters, self.COLUMN_PREFIX, context)
106+
st = self.parameterAsEnums(parameters, self.STATISTICS, context)
107+
108+
vectorLayer = self.parameterAsVectorLayer(parameters, self.INPUT_VECTOR, context)
109+
rasterLayer = self.parameterAsRasterLayer(parameters, self.INPUT_RASTER, context)
110+
111+
keys = list(self.STATS.keys())
112+
selectedStats = 0
113+
for i in st:
114+
selectedStats |= self.STATS[keys[i]]
115+
116+
zs = QgsZonalStatistics(vectorLayer,
117+
rasterLayer,
118+
columnPrefix,
119+
bandNumber,
120+
selectedStats)
121+
zs.calculateStatistics(feedback)
122+
123+
return {self.INPUT_VECTOR: vectorLayer}

‎python/plugins/processing/algs/qgis/ZonalStatisticsQgis.py

Lines changed: 0 additions & 122 deletions
This file was deleted.

0 commit comments

Comments
 (0)
Please sign in to comment.