"""
Low-level relational database / sqlalchemy interaction.
The actual schemas for database tables are implemented in other files in this subpackage.
"""
from __future__ import division, print_function
import os, sys, io, json, threading, gc, re, weakref
from collections import OrderedDict, namedtuple
import numpy as np
try:
import queue
except ImportError:
import Queue as queue
from inspect import isclass
import sqlalchemy
from distutils.version import LooseVersion
if LooseVersion(sqlalchemy.__version__) < '1.2':
raise Exception('requires at least sqlalchemy 1.2')
import sqlalchemy.inspection, sqlalchemy.pool
from sqlalchemy import create_engine, Column, Integer, BigInteger, String, Boolean, Float, Date, DateTime, LargeBinary, ForeignKey
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import relationship, deferred, sessionmaker, reconstructor
from sqlalchemy.types import TypeDecorator
from sqlalchemy.sql.expression import func
from .. import config
from neuroanalysis.util.optional_import import optional_import
pandas = optional_import('pandas')
class NDArray(TypeDecorator):
"""For marshalling arrays in/out of binary DB fields.
"""
impl = LargeBinary
hashable = False
cache_ok = False
def process_bind_param(self, value, dialect):
if value is None:
return b''
buf = io.BytesIO()
np.save(buf, value, allow_pickle=False)
return buf.getvalue()
def process_result_value(self, value, dialect):
if value is None or value == b'':
return None
buf = io.BytesIO(value)
return np.load(buf, allow_pickle=False)
@property
def python_type(self):
return np.ndarray
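# Illustrative round trip through NDArray's binary format (a sketch, not part of
# the public API; `example` is a hypothetical array):
#
#     example = np.arange(5)
#     blob = NDArray().process_bind_param(example, dialect=None)
#     restored = NDArray().process_result_value(blob, dialect=None)
#     assert np.array_equal(example, restored)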
class CustomEncoder(json.JSONEncoder):
""" For encoding nonserializable floats into json
"""
def default(self, obj):
if isinstance(obj, np.floating):
return float(obj)
return json.JSONEncoder.default(self, obj)
class JSONObject(TypeDecorator):
"""For marshalling objects in/out of json-encoded text.
"""
impl = String
hashable = False
def process_bind_param(self, value, dialect):
return json.dumps(value, cls=CustomEncoder)
def process_result_value(self, value, dialect):
if value is None:
return None
return json.loads(value)
@property
def python_type(self):
        return object
class FloatType(TypeDecorator):
"""For marshalling float types (including numpy).
"""
impl = Float
cache_ok = False
def process_bind_param(self, value, dialect):
if value is None:
return None
return float(value)
@property
def python_type(self):
return float
column_data_types = {
'int': Integer,
'bigint': BigInteger,
'float': FloatType,
'bool': Boolean,
'str': String,
'date': Date,
'datetime': DateTime,
'array': NDArray,
    # 'object': JSONB,  # provides support for postgres jsonb, but conflicts with sqlite
'object': JSONObject,
}
def make_table_docstring(table):
"""Introspect ORM table class to generate a nice docstring.
"""
docstr = ['Sqlalchemy model for "%s" database table.\n' % table.__name__]
comment = table.__table_args__.get('comment', None)
if comment is not None:
docstr.append(comment.strip() + '\n')
insp = sqlalchemy.inspection.inspect(table)
docstr.append("Attributes\n----------")
for name, prop in insp.relationships.items():
docstr.append("%s : relationship" % name)
if hasattr(prop, 'entity'):
# entity attribute only available in recent sqlalchemy (>=1.3 ?)
docstr.append(" Reference to %s.%s" % (prop.entity.primary_key[0].table.name, prop.entity.primary_key[0].name))
for name, col in insp.columns.items():
typ_str = str(col.type)
docstr.append("%s : %s" % (name, typ_str))
if col.comment is not None:
docstr.append(" " + col.comment)
return '\n'.join(docstr)
def make_table(ormbase, name, columns, base=None, **table_args):
"""Generate an ORM mapping class from a simplified schema format.
Columns named 'id' (int) and 'meta' (object) are added automatically.
    Parameters
    ----------
    ormbase : ORMBase instance
        The sqlalchemy ORM base on which to create this table.
    name : str
        Name of the table, used to set __tablename__ in the new class.
    columns : list of tuple
        List of column specifications. Each column is given as a tuple:
        ``(col_name, data_type, comment, {options})``, where *col_name* and *comment*
        are strings, *data_type* is a key in the column_data_types global, and
        *options* is a dict providing extra initialization arguments to the sqlalchemy
        Column (for example: 'index', 'unique'); the special options 'deferred' and
        'ondelete' are handled by make_table itself. Optionally, *data_type* may be a
        'tablename.id' string indicating that this column is a foreign key referencing
        another table.
    base : class or None
        Base class on which to build the new table class.
    table_args : keyword arguments
        Extra keyword arguments are used to set __table_args__ in the new class.
    """
class_name = ''.join([part.title() for part in name.split('_')])
props = {
'__tablename__': name,
'__table_args__': table_args,
'id': Column(Integer, primary_key=True),
}
for column in columns:
colname, coltype = column[:2]
# avoid weird sqlalchemy issues with case handling
assert colname == colname.lower(), "Column names must be all lowercase (got %s.%s)" % (name, colname)
kwds = {} if len(column) < 4 else column[3]
kwds['comment'] = None if len(column) < 3 else column[2]
defer_col = kwds.pop('deferred', False)
ondelete = kwds.pop('ondelete', None)
if coltype not in column_data_types:
if not coltype.endswith('.id'):
raise ValueError("Unrecognized column type %s" % coltype)
            # force indexing on all foreign keys; otherwise deletes can become very slow
kwds['index'] = True
props[colname] = Column(Integer, ForeignKey(coltype, ondelete=ondelete), **kwds)
else:
ctyp = column_data_types[coltype]
props[colname] = Column(ctyp, **kwds)
if defer_col:
props[colname] = deferred(props[colname])
props['meta'] = Column(column_data_types['object'])
if base is None:
new_table = type(class_name, (ormbase,), props)
else:
# need to jump through a hoop to allow __init__ on table classes;
# see: https://docs.sqlalchemy.org/en/latest/orm/constructors.html
if hasattr(base, '_init_on_load'):
@reconstructor
def _init_on_load(self, *args, **kwds):
base._init_on_load(self)
props['_init_on_load'] = _init_on_load
new_table = type(class_name, (base, ormbase), props)
return new_table
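# Illustrative use of make_table (a sketch; the 'cell' table and its columns are
# hypothetical, and the 'experiment.id' foreign key assumes an 'experiment' table
# already defined on the same declarative base):
#
#     ORMBase = declarative_base()
#     Cell = make_table(ORMBase, 'cell', [
#         ('experiment_id', 'experiment.id', 'parent experiment', {'ondelete': 'CASCADE'}),
#         ('position', 'array', '3D stage position', {'deferred': True}),
#         ('quality', 'float', 'recording quality metric', {'index': True}),
#     ], comment="One cell recorded in an experiment.")
#     Cell.__doc__ = make_table_docstring(Cell)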
class Database(object):
"""Methods for doing relational database maintenance via sqlalchemy.
Supported backends: postgres, sqlite.
Features:
* Automatically build/dispose ro and rw engines (especially after fork)
* Generate ro/rw sessions on demand
* Methods for creating / dropping databases
* Clone databases across backends
"""
_all_dbs = weakref.WeakSet()
default_app_name = (' '.join(sys.argv))[-63:]
def __init__(self, ro_host, rw_host, db_name, ormbase):
self.ormbase = ormbase
self._mappings = {}
# default options for creating DB engines
self._engine_opts = {
'postgresql': {
'ro': {'echo': False, 'poolclass': sqlalchemy.pool.NullPool, 'isolation_level': 'AUTOCOMMIT'}, # {'pool_size': 0, 'max_overflow': 40, }
'rw': {'poolclass': sqlalchemy.pool.NullPool}, #{'pool_size': 0, 'max_overflow': 40},
'maint': {'poolclass': sqlalchemy.pool.NullPool},
}
}
self.ro_host = ro_host
self.rw_host = rw_host
self.db_name = db_name
self._ro_engine = None
self._rw_engine = None
self._maint_engine = None
self._engine_pid = None # pid of process that created these engines.
self._ro_sessionmaker = None
self._rw_sessionmaker = None
self.ro_address = self.db_address(ro_host, db_name)
self.rw_address = None if rw_host is None else self.db_address(rw_host, db_name)
self._all_dbs.add(self)
self._default_session = None
@property
def default_session(self):
self._check_engines()
if self._default_session is None:
self._default_session = self.session(readonly=True)
return self._default_session
def query(self, *args, **kwds):
return self.default_session.query(*args, **kwds)
def _find_mappings(self):
mappings = {cls.__tablename__:cls for cls in self.ormbase.__subclasses__()}
order = [t.name for t in self.ormbase.metadata.sorted_tables]
self._mappings = OrderedDict([(t, mappings[t]) for t in order if t in mappings])
def __getattr__(self, attr):
try:
# pretty sure I'll regret this later: I want to be able to ask for db.TableName
# and return the ORM object for a table.
# convert CamelCase to snake_case (credit: https://stackoverflow.com/a/12867228/643629)
table = re.sub(r'((?<=[a-z0-9])[A-Z]|(?!^)[A-Z](?=[a-z]))', r'_\1', attr).lower()
if table not in self._mappings:
self._find_mappings()
return self._mappings[table]
except Exception:
return object.__getattribute__(self, attr)
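    # Illustrative attribute access (hypothetical table names): `db.Experiment`
    # returns the ORM class mapped to the 'experiment' table, and `db.SyncRec`
    # maps to 'sync_rec' via the CamelCase -> snake_case conversion above.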
def __repr__(self):
return "<%s %s (%s)>" % (type(self).__name__, self.ro_address, 'ro' if self.rw_address is None else 'rw')
def __str__(self):
# str(engine) does a nice job of masking passwords
s = str(self.ro_engine)[7:]
s = s.rstrip(')')
s = s.partition('?')[0]
return s
@property
def backend(self):
"""Return the backend used by this database (sqlite, postgres, etc.)
"""
# maybe ro_engine.name instead?
return self.ro_host.partition(':')[0]
    @classmethod
def db_address(cls, host, db_name=None, app_name=None):
"""Return a complete address for DB access given a host (like postgres://user:pw@host) and database name.
Appends an app name to postgres addresses.
"""
if host.startswith('postgres'):
app_name = app_name or cls.default_app_name
return "{host}/{db_name}?application_name={app_name}".format(host=host, db_name=db_name, app_name=app_name)
else:
# for sqlite, db_name is the file path
if not host.endswith('/'):
host = host + '/'
return host + db_name
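    # For example (hypothetical credentials; the app name defaults to default_app_name):
    #
    #     Database.db_address("postgresql://user:pw@server", "synphys")
    #     -> "postgresql://user:pw@server/synphys?application_name=<app name>"
    #
    #     Database.db_address("sqlite:///", "/tmp/synphys.db") -> "sqlite:////tmp/synphys.db"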
    def get_database(self, db_name):
"""Return a new Database object with the same hosts and orm base, but different db name
"""
return Database(self.ro_host, self.rw_host, db_name, self.ormbase)
    def dispose_engines(self):
"""Dispose any existing DB engines. This is necessary when forking to avoid accessing the same DB
connection simultaneously from two processes.
"""
if self._ro_engine is not None:
self._ro_engine.dispose()
if self._rw_engine is not None:
self._rw_engine.dispose()
if self._maint_engine is not None:
self._maint_engine.dispose()
self._ro_engine = None
self._ro_sessionmaker = None
self._rw_engine = None
self._rw_sessionmaker = None
self._maint_engine = None
self._engine_pid = None
self._default_session = None
# collect now or else we might try to collect engine-related garbage in forked processes,
# which can lead to "OperationalError: server closed the connection unexpectedly"
# Note: if this turns out to be flaky as well, we can just disable connection pooling.
gc.collect()
    @classmethod
def dispose_all_engines(cls):
"""Dispose engines on all Database instances.
"""
for db in cls._all_dbs:
db.dispose_engines()
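    # Typical fork-safety pattern (a sketch): code that uses os.fork() or
    # multiprocessing directly may call Database.dispose_all_engines() in the
    # child before any DB access so each process builds its own connections.
    # (The per-access _check_engines() below also handles this lazily.)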
def _check_engines(self):
"""Dispose engines if they were built for a different PID
"""
if os.getpid() != self._engine_pid:
# In forked processes, we need to re-initialize the engine before
# creating a new session, otherwise child processes will
# inherit and muck with the same connections. See:
# https://docs.sqlalchemy.org/en/latest/faq/connections.html#how-do-i-use-engines-connections-sessions-with-python-multiprocessing-or-os-fork
if self._engine_pid is not None:
print("Making new session for subprocess %d != %d" % (os.getpid(), self._engine_pid))
self.dispose_engines()
@property
def ro_engine(self):
"""The read-only database engine.
"""
self._check_engines()
if self._ro_engine is None:
opts = self._engine_opts.get(self.backend, {}).get('ro', {})
self._ro_engine = create_engine(self.ro_address, **opts)
self._engine_pid = os.getpid()
return self._ro_engine
@property
def rw_engine(self):
"""The read-write database engine.
"""
self._check_engines()
if self._rw_engine is None:
if self.rw_address is None:
return None
opts = self._engine_opts.get(self.backend, {}).get('rw', {})
self._rw_engine = create_engine(self.rw_address, **opts)
self._engine_pid = os.getpid()
return self._rw_engine
@property
def maint_engine(self):
"""The maintenance engine.
For postgres DBs, this connects to the "postgres" database.
"""
self._check_engines()
if self._maint_engine is None:
opts = self._engine_opts.get(self.backend, {}).get('maint', None)
if opts is None:
# maybe just return rw engine for postgres?
raise Exception("no maintenance connection configured for DB %s" % self)
maint_addr = self.db_address(self.rw_host, 'postgres')
self._maint_engine = create_engine(maint_addr, **opts)
self._engine_pid = os.getpid()
return self._maint_engine
# external users should create sessions from here.
    def session(self, readonly=True):
"""Create and return a new database Session instance.
If readonly is True, then the session is created using read-only credentials and has autocommit enabled.
This prevents idle-in-transaction timeouts that occur when GUI analysis tools would otherwise leave transactions
open after each request.
"""
if readonly:
if self._ro_sessionmaker is None:
self._ro_sessionmaker = sessionmaker(bind=self.ro_engine, query_cls=DBQuery)
return self._ro_sessionmaker()
else:
if self.rw_engine is None:
raise RuntimeError("Cannot start read-write DB session; no write access engine is defined (see config.synphys_db_host_rw)")
if self._rw_sessionmaker is None:
self._rw_sessionmaker = sessionmaker(bind=self.rw_engine, query_cls=DBQuery)
return self._rw_sessionmaker()
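    # Typical usage (a sketch; `db.Experiment` is a hypothetical ORM class
    # resolved via __getattr__ above):
    #
    #     session = db.session(readonly=True)
    #     try:
    #         recs = session.query(db.Experiment).limit(10).all()
    #     finally:
    #         session.close()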
    def reset_db(self):
"""Drop the existing database and initialize a new one.
"""
self.dispose_engines()
self.drop_database()
self.create_database()
self.create_tables()
self.grant_readonly_permission()
def list_databases(self):
engine = self.maint_engine
with engine.begin() as conn:
conn.connection.set_isolation_level(0)
return [rec[0] for rec in conn.execute('SELECT datname FROM pg_catalog.pg_database;')]
@property
def exists(self):
"""Bool indicating whether this DB exists yet.
"""
if self.backend == 'sqlite':
return os.path.isfile(self.db_name)
else:
return self.db_name in self.list_databases()
def drop_database(self):
if self.backend == 'sqlite':
if os.path.isfile(self.db_name):
os.remove(self.db_name)
elif self.backend == 'postgresql':
self.dispose_all_engines()
engine = self.maint_engine
with engine.begin() as conn:
conn.connection.set_isolation_level(0)
try:
conn.execute('drop database %s' % self.db_name)
except sqlalchemy.exc.ProgrammingError as err:
if 'does not exist' not in err.args[0]:
raise
else:
raise TypeError("Unsupported database backend %s" % self.backend)
def create_database(self):
if self.backend == 'sqlite':
return
elif self.backend == 'postgresql':
# connect to postgres db just so we can create the new DB
engine = self.maint_engine
with engine.begin() as conn:
conn.connection.set_isolation_level(0)
conn.execute('create database %s' % self.db_name)
# conn.execute('ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT SELECT ON TABLES TO %s;' % ro_user)
else:
raise TypeError("Unsupported database backend %s" % self.backend)
def grant_readonly_permission(self):
if self.backend == 'sqlite':
return
elif self.backend == 'postgresql':
ro_user = config.synphys_db_readonly_user
# grant readonly permissions
with self.rw_engine.begin() as conn:
conn.connection.set_isolation_level(0)
for cmd in [
('GRANT CONNECT ON DATABASE %s TO %s' % (self.db_name, ro_user)),
('GRANT USAGE ON SCHEMA public TO %s' % ro_user),
('GRANT SELECT ON ALL TABLES IN SCHEMA public to %s' % ro_user)]:
conn.execute(cmd)
else:
raise TypeError("Unsupported database backend %s" % self.backend)
    def orm_tables(self):
"""Return a dependency-sorted of ORM mapping objects (tables) that are described by the ORM base for this database.
"""
# need to re-run every time because we can't tell when a new mapping has been added.
self._find_mappings()
return self._mappings
    def table_names(self):
"""Return a list of the names of tables in this database.
May contain names that are not present in metadata_tables or orm_tables.
"""
return self.ro_engine.table_names()
    def create_tables(self, tables=None, initialize=True):
"""Create tables in the database from the ORM base specification.
A list of the names of *tables* may be optionally specified to
create a subset of known tables.
"""
# Create all tables
meta_tables = self.metadata_tables()
if tables is not None:
tables = [meta_tables[t] for t in tables]
self.ormbase.metadata.create_all(bind=self.rw_engine, tables=tables)
self.grant_readonly_permission()
if initialize:
self.initialize_database()
    def initialize_database(self):
"""Optionally called after create_tables.
Initialize is _not_ called when cloning databases.
Default does nothing; subclasses may override.
"""
pass
    def drop_tables(self, tables=None):
"""Drop a list of tables (or all ORM-defined tables, if no list is given) from this database.
"""
drops = []
meta_tables = self.metadata_tables()
db_tables = self.table_names()
for k in meta_tables:
if tables is not None and k not in tables:
continue
if k in db_tables:
drops.append(k)
if len(drops) == 0:
return
if self.backend == 'sqlite':
for table in drops:
self.rw_engine.execute('drop table %s' % table)
else:
self.rw_engine.execute('drop table %s cascade' % (','.join(drops)))
    # Seems not to be working correctly
# def enable_triggers(self, enable):
# """Enable or disable triggers for all tables in this group.
#
# This can be used to temporarily disable constraint checking on tables that are under development,
# allowing the rest of the pipeline to continue operating (for example, if removing an object from
# the pipeline would violate a foreign key constraint, disabling triggers will allow this constraint
# to go unchecked).
# """
# s = Session(readonly=False)
# enable = 'enable' if enable else 'disable'
# for table in self.tables.keys():
# s.execute("alter table %s %s trigger all;" % (table, enable))
# s.commit()
    def vacuum(self, tables=None):
"""Cleans up database and analyzes table statistics in order to improve query planning.
Should be run after any significant changes to the database.
"""
with self.rw_engine.begin() as conn:
if self.backend == 'postgresql':
conn.connection.set_isolation_level(0)
if tables is None:
conn.execute('vacuum analyze')
else:
for table in tables:
conn.execute('vacuum analyze %s' % table)
else:
conn.execute('vacuum')
    def bake_sqlite(self, sqlite_file, **kwds):
"""Dump a copy of this database to an sqlite file.
"""
sqlite_db = Database(ro_host="sqlite:///", rw_host="sqlite:///", db_name=sqlite_file, ormbase=self.ormbase)
sqlite_db.create_tables()
last_size = 0
for table in self.iter_copy_tables(self, sqlite_db, **kwds):
size = os.stat(sqlite_file).st_size
diff = size - last_size
last_size = size
print(" sqlite file size: %0.4fGB (+%0.4fGB for %s)" % (size*1e-9, diff*1e-9, table))
    def clone_database(self, dest_db_name=None, dest_db=None, overwrite=False, **kwds):
"""Copy this database to a new one.
"""
if dest_db_name is not None:
            assert isinstance(dest_db_name, str), "Destination DB name must be a string"
assert dest_db is None, "Only specify one of dest_db_name or dest_db, not both"
dest_db = Database(self.ro_host, self.rw_host, dest_db_name, self.ormbase)
if dest_db.exists:
if overwrite:
dest_db.drop_database()
else:
raise Exception("Destination database %s already exists." % dest_db)
dest_db.create_database()
dest_db.create_tables(initialize=False)
for table in self.iter_copy_tables(self, dest_db, **kwds):
pass
    @staticmethod
def iter_copy_tables(source_db, dest_db, tables=None, skip_tables=(), skip_columns={}, skip_errors=False, vacuum=True):
"""Iterator that copies all tables from one database to another.
Yields each table name as it is completed.
This function does not create tables in dest_db; use db.create_tables if needed.
"""
read_session = source_db.session(readonly=True)
write_session = dest_db.session(readonly=False)
try:
            if dest_db.backend == 'postgresql':
# disables some consistency checks to allow easier replication
write_session.execute("SET session_replication_role = 'replica';")
for table_name, table in source_db.metadata_tables().items():
if (table_name in skip_tables) or (tables is not None and table_name not in tables):
print("Skipping %s.." % table_name)
continue
print("Cloning %s.." % table_name)
# read from table in background thread, write to table in main thread.
skip_cols = skip_columns.get(table_name, [])
reader = TableReadThread(source_db, table, skip_columns=skip_cols)
i = 0
for i,rec in enumerate(reader):
try:
# Note: it is allowed to write `rec` directly back to the db, but
# in some cases (json columns) we run into a sqlalchemy bug. Converting
# to dict first is a workaround.
rec = {k:getattr(rec, k) for k in rec.keys()}
write_session.execute(table.insert(rec))
except Exception:
if skip_errors:
print("Skip record %d:" % i)
sys.excepthook(*sys.exc_info())
else:
raise
if i%1000 == 0:
print("%d/%d %0.2f%%\r" % (i, reader.max_id, (100.0*(i+1.0)/reader.max_id)), end="")
sys.stdout.flush()
print(" committing %d rows.. " % i)
write_session.commit()
read_session.rollback()
yield table_name
if vacuum:
print("Optimizing database..")
dest_db.vacuum()
print("All finished!")
finally:
            if dest_db.backend == 'postgresql':
write_session.execute("SET session_replication_role = 'origin';")
class DBQuery(sqlalchemy.orm.Query):
def add_table_columns(self, table, load_deferred=False):
"""Return a new query with all columns in *table* added.
Parameters
----------
table : sqlalchemy ORM table
The table from which columns will be added.
load_deferred : bool | list
If True, load all columns that are marked as deferred (by default,
these are ignored). Optionally, may specify a list of deferred
column names to load.
"""
assert isinstance(load_deferred, (list, bool)), "load_deferred must be bool or list"
meta = sqlalchemy.inspect(table)
cols = []
for col in meta.columns.keys():
load = (
(not meta.column_attrs[col].deferred) or
(load_deferred is True) or
(isinstance(load_deferred, list) and col in load_deferred)
)
if load:
cols.append(col)
return self.add_columns(*cols)
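    # Illustrative usage (a sketch; `db.Cell` is a hypothetical ORM class):
    #
    #     q = db.query(db.Cell.id).add_table_columns(db.Cell, load_deferred=['meta'])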
def dataframe(self, expand_tables=True, rename_columns=True):
"""Return a pandas dataframe constructed from the results of this query.
Columns are renamed from the original query (see DBQuery.recarray)
Parameters
----------
expand_tables : bool | list
If True, expand all table entities included in the query into individual
columns. Optionally, a list of table names to expand may be provided instead.
rename_columns : bool
If True, columns are renamed (see DBQuery.recarray).
"""
# don't like this; we want a bit more control over how columns are unpacked / renamed
if not rename_columns:
if expand_tables is False:
raise NotImplementedError("The combination expand_tables=False, rename_columns=False is not implemented")
return pandas.read_sql(self.statement, self.session.bind)
recs, col_names, col_types, rec_fields = self._prepare_array(expand_tables=expand_tables)
# coerce types
type_map = {float: 'float', int: pandas.Int64Dtype()}
col_types = [type_map.get(t, 'object') for t in col_types]
data = {}
for i, dest_col_name in enumerate(col_names):
source_col_name = rec_fields[i]
if isinstance(source_col_name, str):
col_data = [getattr(rec, source_col_name) for rec in recs]
elif isinstance(source_col_name, tuple):
col_data = [getattr(getattr(rec, source_col_name[0]), source_col_name[1], None) for rec in recs]
data[dest_col_name] = pandas.Series(col_data, dtype=col_types[i])
return pandas.concat(data, axis=1)
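    # Illustrative usage (a sketch; `db.Experiment` is a hypothetical ORM class):
    #
    #     df = db.query(db.Experiment).limit(100).dataframe()
    #     # expanded columns are named 'experiment.<column>' (see recarray below)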
def recarray(self, expand_tables=True):
"""Return a numpy record array constructed from the results of this query.
Columns are renamed from the original query based on the following rules:
- If a column label is explicitly provided, that label is used without modification
- Result columns that contain a single DB column are renamed to `table.column`
- Result columns that are derived from more complex expressions use a string representation of the expression
- Duplicate column names have `_N` appended
Parameters
----------
expand_tables : bool | list
If True, expand all table entities included in the query into individual
columns. Optionally, a list of table names to expand may be provided instead.
"""
recs, col_names, col_types, rec_fields = self._prepare_array(expand_tables=expand_tables)
# need to represent everything as either float or obj in order to support null values
col_types = ['float' if t is float else 'object' for t in col_types]
dtype = list(zip(col_names, col_types))
# convert records to numpy array
if expand_tables is False:
arr = np.array(recs, dtype=dtype)
else:
arr = np.empty(len(recs), dtype=dtype)
for i, dest_col_name in enumerate(col_names):
source_col_name = rec_fields[i]
if isinstance(source_col_name, str):
arr[dest_col_name] = [getattr(rec, source_col_name) for rec in recs]
elif isinstance(source_col_name, tuple):
arr[dest_col_name] = [getattr(getattr(rec, source_col_name[0]), source_col_name[1], None) for rec in recs]
return arr
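    # Illustrative usage (a sketch; `db.Experiment` is a hypothetical ORM class,
    # though the 'id' and 'meta' columns are added to every table by make_table):
    #
    #     arr = db.query(db.Experiment.id, db.Experiment.meta).recarray()
    #     arr['experiment.id']  # columns named per the rules in the docstring above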
def _prepare_array(self, expand_tables):
recs = self.all()
row_types = (tuple,)
try:
row_types = row_types + (sqlalchemy.engine.row.Row,)
except AttributeError:
pass
if len(recs) > 0 and not isinstance(recs[0], row_types):
# sqlalchemy returns lists of keyed tuples in most cases, but lists of ORM instances if only one
# column was requested. This is a pain to handle later on, so we're normalizing the output here.
rectyp = namedtuple('record', [self.column_descriptions[0]['name']])
recs = [rectyp(x) for x in recs]
# decide on column names and dtypes to use
col_names = []
col_types = []
rec_fields = []
        # DeclarativeMeta moved between sqlalchemy versions
        try:
            from sqlalchemy.ext.declarative.api import DeclarativeMeta
        except ImportError:
            from sqlalchemy.orm.decl_api import DeclarativeMeta
        for col in self.column_descriptions:
if isinstance(col['type'], DeclarativeMeta):
# this column holds an entire table; use table name unless aliased
table_name = col['entity'].__table__.name
aliased_table_name = col['name'] if col['aliased'] else table_name
expand = (
expand_tables is True or (
isinstance(expand_tables, list) and (
(table_name in expand_tables) or
(col['entity'] in expand_tables)
)
)
)
if expand:
# Which columns to expand?
expanded_cols = self._get_expanded_cols(col, recs)
for attribute_name in expanded_cols:
col_names.append(aliased_table_name + '.' + attribute_name)
rec_fields.append((col['name'], attribute_name))
col_types.append(self._get_column_type(getattr(col['entity'], attribute_name)))
else:
col_names.append(aliased_table_name)
col_types.append('object')
rec_fields.append(col['name'])
else:
rec_fields.append(col['name'])
expr = col['expr']
if isinstance(expr, sqlalchemy.sql.elements.Label):
# query specifies a label here; use that name unconditionally
col_names.append(expr.name)
else:
# assign column names as table.column
if isinstance(expr, sqlalchemy.orm.attributes.InstrumentedAttribute):
table_name = sqlalchemy.inspect(col['entity']).name if col['aliased'] else col['entity'].__table__.name
col_names.append(table_name + '.' + col['name'])
elif isinstance(expr, sqlalchemy.sql.annotation.AnnotatedColumn):
col_names.append(expr.table.name + '.' + expr.name)
elif isinstance(expr, sqlalchemy.sql.elements.BinaryExpression):
col_names.append(str(col['expr']) if col['name'] is None else col['name'])
else:
raise TypeError(f"recarray() does not support column of type {repr(expr)} (name: {col['name']})")
col_types.append(self._get_column_type(expr))
# modify any repeated names
seen_names = set()
for i,name in enumerate(col_names):
j = 0
while True:
new_name = name if j == 0 else f"{name}_{j}"
j += 1
if new_name not in seen_names:
seen_names.add(new_name)
col_names[i] = new_name
break
return recs, col_names, col_types, rec_fields
def _get_expanded_cols(self, column_desc, records):
"""Return the list of columns to use when expanding one entity column.
Ideally we'd ask sqlalchemy somehow, but for now we just find the first
non-null entry and ask which attributes were loaded. If no entries are found, then we can only
guess and the safest option is to include all columns.
"""
col_name = column_desc['name']
first_item = None
for rec in records:
first_item = getattr(rec, col_name)
if first_item is not None:
break
insp = sqlalchemy.inspect(column_desc['entity'])
if getattr(insp, 'is_aliased_class', False):
insp = insp.mapper
all_columns = list(insp.columns.keys())
if first_item is None:
# no items returned in this query; just guess all columns
return all_columns
else:
unloaded = sqlalchemy.orm.attributes.instance_state(first_item).unloaded
return [c for c in all_columns if c not in unloaded]
def _get_column_type(self, column_expr):
if isinstance(column_expr, sqlalchemy.sql.elements.Label):
column_expr = column_expr._element
return column_expr.type.python_type
class TableReadThread(threading.Thread):
"""Iterator that yields records (all columns) from a table.
Records are queried chunkwise and queued in a background thread to enable more efficient streaming.
"""
def __init__(self, db, table, chunksize=1000, skip_columns=()):
threading.Thread.__init__(self)
self.daemon = True
self.db = db
self.table = table
self.chunksize = chunksize
self.skip_columns = skip_columns
self.queue = queue.Queue(maxsize=5)
self.max_id = db.session().query(func.max(table.columns['id'])).all()[0][0] or 0
self.start()
def run(self):
try:
session = self.db.session()
table = self.table
chunksize = self.chunksize
all_columns = [col for col in table.columns if col.name not in self.skip_columns]
for i in range(0, self.max_id, chunksize):
query = session.query(*all_columns).filter((table.columns['id'] >= i) & (table.columns['id'] < i+chunksize))
records = query.all()
self.queue.put(records)
self.queue.put(None)
session.rollback()
session.close()
except Exception as exc:
sys.excepthook(*sys.exc_info())
self.queue.put(exc)
raise
def __iter__(self):
while True:
recs = self.queue.get()
if recs is None:
break
if isinstance(recs, Exception):
raise recs
for rec in recs:
yield rec
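# Illustrative standalone use of TableReadThread (a sketch; 'experiment' is a
# hypothetical table name and `process` is a placeholder):
#
#     table = db.metadata_tables()['experiment']
#     reader = TableReadThread(db, table, chunksize=500)
#     for rec in reader:
#         process(rec)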