Skip to content

Commit

Permalink
Always avoid duplicates
Browse files Browse the repository at this point in the history
  • Loading branch information
YoannQDQ authored and nyalldawson committed Mar 30, 2023
1 parent 6cf3d72 commit a083874
Showing 1 changed file with 44 additions and 90 deletions.
134 changes: 44 additions & 90 deletions src/analysis/processing/qgsalgorithmrandomextract.cpp
Expand Up @@ -67,10 +67,6 @@ void QgsRandomExtractAlgorithm::initAlgorithm( const QVariantMap & )
addParameter( new QgsProcessingParameterNumber( QStringLiteral( "NUMBER" ), QObject::tr( "Number/percentage of features" ),
QgsProcessingParameterNumber::Integer, 10, false, 0 ) );

std::unique_ptr< QgsProcessingParameterBoolean > noDuplicates = std::make_unique< QgsProcessingParameterBoolean >( QStringLiteral( "NO_DUPLICATES" ), QObject::tr( "Avoid duplicates" ), false );
noDuplicates->setHelp( QObject::tr( "If checked, ensure the resulting subset contains no duplicated feature, at the cost of slightly worse performance" ) ) ;
addParameter( noDuplicates.release() );

addParameter( new QgsProcessingParameterFeatureSink( QStringLiteral( "OUTPUT" ), QObject::tr( "Extracted (random)" ) ) );
}

Expand All @@ -88,8 +84,6 @@ QVariantMap QgsRandomExtractAlgorithm::processAlgorithm( const QVariantMap &para

const int method = parameterAsEnum( parameters, QStringLiteral( "METHOD" ), context );
int number = parameterAsInt( parameters, QStringLiteral( "NUMBER" ), context );
const bool noDuplicates = parameterAsBool( parameters, QStringLiteral( "NO_DUPLICATES" ), context );

const long count = source->featureCount();

if ( method == 0 )
Expand Down Expand Up @@ -126,95 +120,55 @@ QVariantMap QgsRandomExtractAlgorithm::processAlgorithm( const QVariantMap &para
// initialize random engine
std::random_device randomDevice;
std::mt19937 mersenneTwister( randomDevice() );

if ( !noDuplicates )
std::uniform_int_distribution<size_t> fidsDistribution;

// If the number of features to select is greater than half the total number of features
// we will instead randomly select features to *exclude* from the output layer
size_t shuffledFeatureCount = number;
bool invertSelection = number > count / 2;
if ( invertSelection )
shuffledFeatureCount = count - number;

size_t nb = count;

// Shuffle <number> features at the start of the iterator
feedback->pushInfo( QObject::tr( "Randomly select %1 features" ).arg( number ) );
auto cursor = allFeats.begin();
using difference_type = std::vector<QgsFeatureId>::difference_type;
while ( shuffledFeatureCount-- )
{
feedback->pushInfo( QObject::tr( "Randomly select %1 features" ).arg( number ) );
const std::uniform_int_distribution<size_t> fidsDistribution( 0, allFeats.size() - 1 );

std::vector< size_t > indexes( number );
std::generate( indexes.begin(), indexes.end(), bind( fidsDistribution, mersenneTwister ) );
QHash< QgsFeatureId, int > idsCount;
for ( size_t i : indexes )
{
const QgsFeatureId id = allFeats.at( i );
if ( feedback->isCanceled() )
return QVariantMap();

idsCount[ id ] += 1;
}

const QgsFeatureIds ids = qgis::listToSet( idsCount.keys() );

feedback->pushInfo( QObject::tr( "Adding selected features" ) );
QgsFeatureIterator fit = source->getFeatures( QgsFeatureRequest().setFilterFids( ids ), QgsProcessingFeatureSource::FlagSkipGeometryValidityChecks );
while ( fit.nextFeature( f ) )
{
if ( feedback->isCanceled() )
return QVariantMap();

const int count = idsCount.value( f.id() );
for ( int i = 0; i < count; ++i )
{
if ( !sink->addFeature( f, QgsFeatureSink::FastInsert ) )
throw QgsProcessingException( writeFeatureError( sink.get(), parameters, QStringLiteral( "OUTPUT" ) ) );
}
}
if ( feedback->isCanceled() )
return QVariantMap();

// Update the distribution to match the number of unshuffled features
fidsDistribution.param( std::uniform_int_distribution<size_t>::param_type( 0, nb - 1 ) );
// Swap the current feature with a random one
std::swap( *cursor, *( cursor + static_cast<difference_type>( fidsDistribution( mersenneTwister ) ) ) );
// Move the cursor to the next feature
++cursor;

// Decrement the number of unshuffled features
--nb;
}

// No duplicates
// Insert the selected features into a QgsFeatureIds set
QgsFeatureIds selected;
if ( invertSelection )
for ( auto it = cursor; it != allFeats.end(); ++it )
selected.insert( *it );
else
for ( auto it = allFeats.begin(); it != cursor; ++it )
selected.insert( *it );

feedback->pushInfo( QObject::tr( "Adding selected features" ) );
fit = source->getFeatures( QgsFeatureRequest().setFilterFids( selected ), QgsProcessingFeatureSource::FlagSkipGeometryValidityChecks );
while ( fit.nextFeature( f ) )
{
feedback->pushInfo( QObject::tr( "Randomly select %1 features" ).arg( number ) );
std::uniform_int_distribution<size_t> fidsDistribution;

// If the number of features to select is greater than half the total number of features
// we will instead randomly select features to *exclude* from the output layer
size_t shuffledFeatureCount = number;
bool invertSelection = number > count / 2;
if ( invertSelection )
shuffledFeatureCount = count - number;

size_t nb = count;

// Shuffle <number> features at the start of the iterator
auto cursor = allFeats.begin();
using difference_type = std::vector<QgsFeatureId>::difference_type;
while ( shuffledFeatureCount-- )
{
if ( feedback->isCanceled() )
return QVariantMap();

// Update the distribution to match the number of unshuffled features
fidsDistribution.param( std::uniform_int_distribution<size_t>::param_type( 0, nb - 1 ) );
// Swap the current feature with a random one
std::swap( *cursor, *( cursor + static_cast<difference_type>( fidsDistribution( mersenneTwister ) ) ) );
// Move the cursor to the next feature
++cursor;

// Decrement the number of unshuffled features
--nb;
}

// Insert the selected features into a QgsFeatureIds set
QgsFeatureIds selected;
if ( invertSelection )
for ( auto it = cursor; it != allFeats.end(); ++it )
selected.insert( *it );
else
for ( auto it = allFeats.begin(); it != cursor; ++it )
selected.insert( *it );

feedback->pushInfo( QObject::tr( "Adding selected features" ) );
fit = source->getFeatures( QgsFeatureRequest().setFilterFids( selected ), QgsProcessingFeatureSource::FlagSkipGeometryValidityChecks );
while ( fit.nextFeature( f ) )
{
if ( feedback->isCanceled() )
return QVariantMap();

if ( !sink->addFeature( f, QgsFeatureSink::FastInsert ) )
throw QgsProcessingException( writeFeatureError( sink.get(), parameters, QStringLiteral( "OUTPUT" ) ) );
}
if ( feedback->isCanceled() )
return QVariantMap();

if ( !sink->addFeature( f, QgsFeatureSink::FastInsert ) )
throw QgsProcessingException( writeFeatureError( sink.get(), parameters, QStringLiteral( "OUTPUT" ) ) );
}

QVariantMap outputs;
Expand Down

0 comments on commit a083874

Please sign in to comment.