aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMarc Bonnici <marc.bonnici@arm.com>2018-11-27 17:34:52 +0000
committersetrofim <setrofim@gmail.com>2018-12-07 09:55:17 +0000
commit0426a966dab4ed66f557f9b6ca41726df69ae086 (patch)
tree7fdf662febde47c600f3c873280a41099bf759f0
parenteabe15750c7f65f89410baac1f50e5b7f66316e5 (diff)
utils/postgres: Relocate functions to retrieve schema information
Move the functions to retrieve schema information to general utilities to be used in other classes.
-rw-r--r--wa/commands/create.py30
-rw-r--r--wa/output_processors/postgresql.py15
-rw-r--r--wa/utils/postgres.py41
3 files changed, 52 insertions, 34 deletions
diff --git a/wa/commands/create.py b/wa/commands/create.py
index 603feeae..7db6dd52 100644
--- a/wa/commands/create.py
+++ b/wa/commands/create.py
@@ -40,12 +40,11 @@ from wa.framework.exception import ConfigError, CommandError
from wa.instruments.energy_measurement import EnergyInstrumentBackend
from wa.utils.misc import (ensure_directory_exists as _d, capitalize,
ensure_file_directory_exists as _f)
-from wa.utils.postgres import get_schema
+from wa.utils.postgres import get_schema, POSTGRES_SCHEMA_DIR
from wa.utils.serializer import yaml
TEMPLATES_DIR = os.path.join(os.path.dirname(__file__), 'templates')
-POSTGRES_SCHEMA_DIR = os.path.join(os.path.dirname(__file__), 'postgres_schemas')
class CreateDatabaseSubcommand(SubCommand):
@@ -114,7 +113,7 @@ class CreateDatabaseSubcommand(SubCommand):
raise ValueError('Databasename to create cannot be postgres.')
self._parse_args(args)
- self.schema_major, self.schema_minor, self.sql_commands = _get_schema(self.schemafilepath)
+ self.schema_major, self.schema_minor, self.sql_commands = get_schema(self.schemafilepath)
# Display the version if needed and exit
if args.schema_version:
@@ -191,7 +190,7 @@ class CreateDatabaseSubcommand(SubCommand):
def update_schema(self):
self._validate_version()
- schema_major, schema_minor, _ = _get_schema(self.schemafilepath)
+ schema_major, schema_minor, _ = get_schema(self.schemafilepath)
meta_oid, current_major, current_minor = self._get_database_schema_version()
while not (schema_major == current_major and schema_minor == current_minor):
@@ -209,7 +208,7 @@ class CreateDatabaseSubcommand(SubCommand):
if not os.path.exists(schema_update):
break
- _, _, sql_commands = _get_schema(schema_update)
+ _, _, sql_commands = get_schema(schema_update)
self._apply_database_schema(sql_commands, major, minor, meta_oid)
msg = "Updated the database schema to v{}.{}"
self.logger.debug(msg.format(major, minor))
@@ -226,7 +225,7 @@ class CreateDatabaseSubcommand(SubCommand):
# Reset minor to 0 with major version bump
current_minor = 0
- _, _, sql_commands = _get_schema(schema_update)
+ _, _, sql_commands = get_schema(schema_update)
self._apply_database_schema(sql_commands, current_major, current_minor, meta_oid)
msg = "Updated the database schema to v{}.{}"
self.logger.debug(msg.format(current_major, current_minor))
@@ -567,22 +566,3 @@ def get_class_name(name, postfix=''):
def touch(path):
with open(path, 'w') as _: # NOQA
pass
-
-
-def _get_schema(schemafilepath):
- sqlfile_path = os.path.join(
- POSTGRES_SCHEMA_DIR, schemafilepath)
-
- with open(sqlfile_path, 'r') as sqlfile:
- sql_commands = sqlfile.read()
-
- schema_major = None
- schema_minor = None
- # Extract schema version if present
- if sql_commands.startswith('--!VERSION'):
- splitcommands = sql_commands.split('!ENDVERSION!\n')
- schema_major, schema_minor = splitcommands[0].strip('--!VERSION!').split('.')
- schema_major = int(schema_major)
- schema_minor = int(schema_minor)
- sql_commands = splitcommands[1]
- return schema_major, schema_minor, sql_commands
diff --git a/wa/output_processors/postgresql.py b/wa/output_processors/postgresql.py
index 5c7059de..38dd6af1 100644
--- a/wa/output_processors/postgresql.py
+++ b/wa/output_processors/postgresql.py
@@ -31,7 +31,7 @@ from wa.framework.target.info import CpuInfo
from wa.utils.postgres import (POSTGRES_SCHEMA_DIR, cast_level, cast_vanilla,
adapt_vanilla, return_as_is, adapt_level,
ListOfLevel, adapt_ListOfX, create_iterable_adapter,
- get_schema, get_database_schema_version)
+ get_schema_versions)
from wa.utils.serializer import json
from wa.utils.types import level
@@ -127,7 +127,7 @@ class PostgresqlResultProcessor(OutputProcessor):
# N.B. Typecasters are for postgres->python and adapters the opposite
self.connect_to_database()
self.cursor = self.conn.cursor()
- self.check_schema_versions()
+ self.verify_schema_versions()
# Register the adapters and typecasters for enum types
self.cursor.execute("SELECT NULL::status_enum")
@@ -520,11 +520,9 @@ class PostgresqlResultProcessor(OutputProcessor):
self.conn.commit()
self.conn.reset()
- def check_schema_versions(self):
- schemafilepath = os.path.join(POSTGRES_SCHEMA_DIR, 'postgres_schema.sql')
- cur_major_version, cur_minor_version, _ = get_schema(schemafilepath)
- db_schema_version = get_database_schema_version(self.cursor)
- if (cur_major_version, cur_minor_version) != db_schema_version:
+ def verify_schema_versions(self):
+ local_schema_version, db_schema_version = get_schema_versions(self.cursor)
+ if local_schema_version != db_schema_version:
self.cursor.close()
self.cursor = None
self.conn.commit()
@@ -532,8 +530,7 @@ class PostgresqlResultProcessor(OutputProcessor):
msg = 'The current database schema is v{} however the local ' \
'schema version is v{}. Please update your database ' \
'with the create command'
- raise OutputProcessorError(msg.format(db_schema_version,
- (cur_major_version, cur_minor_version)))
+ raise OutputProcessorError(msg.format(db_schema_version, local_schema_version))
def _sql_write_lobject(self, source, lobject):
with open(source) as lobj_file:
diff --git a/wa/utils/postgres.py b/wa/utils/postgres.py
index 3a983204..1bedbbc6 100644
--- a/wa/utils/postgres.py
+++ b/wa/utils/postgres.py
@@ -28,6 +28,7 @@ http://initd.org/psycopg/docs/extensions.html#sql-adaptation-protocol-objects
"""
import re
+import os
try:
from psycopg2 import InterfaceError
@@ -39,6 +40,12 @@ except ImportError:
from wa.utils.types import level
+POSTGRES_SCHEMA_DIR = os.path.join(os.path.dirname(__file__),
+ '..',
+ 'commands',
+ 'postgres_schemas')
+
+
def cast_level(value, cur): # pylint: disable=unused-argument
"""Generic Level caster for psycopg2"""
if not InterfaceError:
@@ -217,3 +224,37 @@ def adapt_list(param):
final_string = final_string + str(item) + ","
final_string = "{" + final_string + "}"
return AsIs("'{}'".format(final_string))
+
+
+def get_schema(schemafilepath):
+ with open(schemafilepath, 'r') as sqlfile:
+ sql_commands = sqlfile.read()
+
+ schema_major = None
+ schema_minor = None
+ # Extract schema version if present
+ if sql_commands.startswith('--!VERSION'):
+ splitcommands = sql_commands.split('!ENDVERSION!\n')
+ schema_major, schema_minor = splitcommands[0].strip('--!VERSION!').split('.')
+ schema_major = int(schema_major)
+ schema_minor = int(schema_minor)
+ sql_commands = splitcommands[1]
+ return schema_major, schema_minor, sql_commands
+
+
+def get_database_schema_version(conn):
+ with conn.cursor() as cursor:
+ cursor.execute('''SELECT
+ DatabaseMeta.schema_major,
+ DatabaseMeta.schema_minor
+ FROM
+ DatabaseMeta;''')
+ schema_major, schema_minor = cursor.fetchone()
+ return (schema_major, schema_minor)
+
+
+def get_schema_versions(conn):
+ schemafilepath = os.path.join(POSTGRES_SCHEMA_DIR, 'postgres_schema.sql')
+ cur_major_version, cur_minor_version, _ = get_schema(schemafilepath)
+ db_schema_version = get_database_schema_version(conn)
+ return (cur_major_version, cur_minor_version), db_schema_version