Skip to content

Commit

Permalink
Add schemas methods
Browse files Browse the repository at this point in the history
  • Loading branch information
troopa81 authored and nyalldawson committed Jan 29, 2021
1 parent 7b77243 commit ba7d0fb
Show file tree
Hide file tree
Showing 6 changed files with 122 additions and 9 deletions.
12 changes: 12 additions & 0 deletions src/providers/oracle/qgsoracleconn.cpp
Expand Up @@ -23,6 +23,7 @@
#include "qgsfields.h"
#include "qgsoracletablemodel.h"
#include "qgssettings.h"
#include "qgsoracleconnpool.h"

#include <QSqlError>

Expand Down Expand Up @@ -986,4 +987,15 @@ QList<QgsVectorDataProvider::NativeType> QgsOracleConn::nativeTypes()
<< QgsVectorDataProvider::NativeType( tr( "Date & Time" ), "TIMESTAMP(6)", QVariant::DateTime, 38, 38, 6, 6 );
}

QgsPoolOracleConn::QgsPoolOracleConn( const QString &connInfo )
: mConn( QgsOracleConnPool::instance()->acquireConnection( connInfo ) )
{
}

QgsPoolOracleConn::~QgsPoolOracleConn()
{
if ( mConn )
QgsOracleConnPool::instance()->releaseConnection( mConn );
}

// vim: sw=2 :
15 changes: 15 additions & 0 deletions src/providers/oracle/qgsoracleconn.h
Expand Up @@ -112,6 +112,21 @@ struct QgsOracleLayerProperty
#endif
};

/**
* Wraps acquireConnection() and releaseConnection() from a QgsOracleConnPool.
* This can be used to ensure a connection is correctly released when scope ends
*/
class QgsPoolOracleConn
{
class QgsOracleConn *mConn;
public:
QgsPoolOracleConn( const QString &connInfo );
~QgsPoolOracleConn();

class QgsOracleConn *get() const { return mConn; }
};


class QgsOracleConn : public QObject
{
Q_OBJECT
Expand Down
85 changes: 80 additions & 5 deletions src/providers/oracle/qgsoracleproviderconnection.cpp
Expand Up @@ -21,6 +21,8 @@
#include "qgsexception.h"
#include "qgsapplication.h"

#include <QSqlRecord>

QgsOracleProviderConnection::QgsOracleProviderConnection( const QString &name )
: QgsAbstractDatabaseProviderConnection( name )
{
Expand All @@ -47,7 +49,7 @@ void QgsOracleProviderConnection::setDefaultCapabilities()
Capability::DropVectorTable,
Capability::DropRasterTable,
Capability::CreateVectorTable,
Capability::RenameSchema,
//Capability::RenameSchema,
Capability::DropSchema,
Capability::CreateSchema,
Capability::RenameVectorTable,
Expand Down Expand Up @@ -129,15 +131,88 @@ void QgsOracleProviderConnection::remove( const QString &name ) const
QList<QgsVectorDataProvider::NativeType> QgsOracleProviderConnection::nativeTypes() const
{
QList<QgsVectorDataProvider::NativeType> types;
QgsOracleConn *conn = QgsOracleConnPool::instance()->acquireConnection( QgsDataSourceUri{ uri() }.connectionInfo( false ) );
if ( conn )
QgsPoolOracleConn conn( QgsDataSourceUri{ uri() }.connectionInfo( false ) );
if ( conn.get() )
{
types = conn->nativeTypes();
QgsOracleConnPool::instance()->releaseConnection( conn );
types = conn.get()->nativeTypes();
}
if ( types.isEmpty() )
{
throw QgsProviderConnectionException( QObject::tr( "Error retrieving native types for connection %1" ).arg( uri() ) );
}
return types;
}

void QgsOracleProviderConnection::createSchema( const QString &name ) const
{
checkCapability( Capability::CreateSchema );
executeSqlPrivate( QStringLiteral( "CREATE USER %1" )
.arg( QgsOracleConn::quotedIdentifier( name ) ) );
}

void QgsOracleProviderConnection::dropSchema( const QString &name, bool force ) const
{
checkCapability( Capability::DropSchema );
executeSqlPrivate( QStringLiteral( "DROP USER %1 %2" )
.arg( QgsOracleConn::quotedIdentifier( name ) )
.arg( force ? QStringLiteral( "CASCADE" ) : QString() ) );
}

QStringList QgsOracleProviderConnection::schemas( ) const
{
checkCapability( Capability::Schemas );
QStringList schemas;

QList<QVariantList> users = executeSqlPrivate( QStringLiteral( "SELECT USERNAME FROM ALL_USERS" ) );
for ( QVariantList userInfos : users )
schemas << userInfos.at( 0 ).toString();

return schemas;
}

QList<QVariantList> QgsOracleProviderConnection::executeSqlPrivate( const QString &sql, QgsFeedback *feedback ) const
{
QList<QVariantList> results;

// Check feedback first!
if ( feedback && feedback->isCanceled() )
{
return results;
}

QgsPoolOracleConn pconn( QgsDataSourceUri{ uri() }.connectionInfo( false ) );
if ( !pconn.get() )
{
throw QgsProviderConnectionException( QObject::tr( "Connection failed: %1" ).arg( uri() ) );
}

if ( feedback && feedback->isCanceled() )
{
return results;
}

QSqlQuery qry( *pconn.get() );
if ( !qry.exec( sql ) )
{
throw QgsProviderConnectionException( QObject::tr( "SQL error: %1 returned %2" )
.arg( qry.lastQuery(),
qry.lastError().text() ) );
}

const int nbFields = qry.record().count();
while ( qry.next() )
{
if ( feedback && feedback->isCanceled() )
{
return results;
}

QVariantList cols;
for ( int i = 0; i < nbFields; i++ )
cols << qry.value( i );

results << cols;
}

return results;
}
5 changes: 5 additions & 0 deletions src/providers/oracle/qgsoracleproviderconnection.h
Expand Up @@ -28,12 +28,17 @@ class QgsOracleProviderConnection : public QgsAbstractDatabaseProviderConnection

// QgsAbstractProviderConnection interface

void createSchema( const QString &name ) const override;
void dropSchema( const QString &name, bool force = false ) const override;

QStringList schemas( ) const override;
void store( const QString &name ) const override;
void remove( const QString &name ) const override;
QList<QgsVectorDataProvider::NativeType> nativeTypes() const override;

private:

QList<QVariantList> executeSqlPrivate( const QString &sql, QgsFeedback *feedback = nullptr ) const;
void setDefaultCapabilities();
};

Expand Down
10 changes: 6 additions & 4 deletions tests/src/python/test_qgsproviderconnection_base.py
Expand Up @@ -47,6 +47,8 @@ class TestPyQgsProviderConnectionBase():
# Provider test cases must define the provider name (e.g. "postgres" or "ogr")
providerKey = ''

configuration = {}

@classmethod
def setUpClass(cls):
"""Run before all tests"""
Expand All @@ -63,10 +65,10 @@ def tearDownClass(cls):
def setUp(self):
QgsSettings().clear()

def _test_save_load(self, md, uri):
def _test_save_load(self, md, uri, configuration):
"""Common tests on connection save and load"""

conn = md.createConnection(uri, {})
conn = md.createConnection(uri, configuration)

md.saveConnection(conn, 'qgis_test1')
# Check that we retrieve the new connection
Expand Down Expand Up @@ -399,7 +401,7 @@ def test_errors(self):
"""Test SQL errors"""

md = QgsProviderRegistry.instance().providerMetadata(self.providerKey)
conn = self._test_save_load(md, self.uri)
conn = self._test_save_load(md, self.uri, self.configuration)

if conn.capabilities() & QgsAbstractDatabaseProviderConnection.Schemas:
with self.assertRaises(QgsProviderConnectionException) as ex:
Expand All @@ -423,7 +425,7 @@ def test_connections(self):
created_spy = QSignalSpy(md.connectionCreated)
changed_spy = QSignalSpy(md.connectionChanged)

conn = self._test_save_load(md, self.uri)
conn = self._test_save_load(md, self.uri, self.configuration)

self.assertEqual(len(created_spy), 1)
self.assertEqual(len(changed_spy), 0)
Expand Down
4 changes: 4 additions & 0 deletions tests/src/python/test_qgsproviderconnection_oracle.py
Expand Up @@ -30,6 +30,10 @@ class TestPyQgsProviderConnectionOracle(unittest.TestCase, TestPyQgsProviderConn
# Provider test cases must define the provider name (e.g. "postgres" or "ogr")
providerKey = 'oracle'

# there is no service for oracle provider test so we need to save user and password
# to keep them when storing/loading connections in parent class _test_save_load method
configuration = {"saveUsername": True, "savePassword": True}

@classmethod
def setUpClass(cls):
"""Run before all tests"""
Expand Down

0 comments on commit ba7d0fb

Please sign in to comment.