from __future__ import print_function, division
from sqlalchemy.orm import aliased
import sqlalchemy.sql.elements
from collections import OrderedDict
from .database import default_db
from .database.schema import schema_description
from . import constants
# Tables holding the per-cell data that classification criteria may refer to.
# Each entry pairs a table name with the attribute on a cell through which that
# table's record is reached (None means the column is read from the cell itself).
_cell_data_tables = [
    ('cell', None),
    ('morphology', 'morphology'),
    ('patch_seq', 'patch_seq'),
    ('intrinsic', 'intrinsic'),
    ('cortical_cell_location', 'cortical_location'),
]

# Lookup from column name to the (table name, cell attribute) pair it came from.
# Tables are visited in the order listed above, so when several tables share a
# column name the entry from the last such table is the one that is kept.
_db_schema = schema_description()
_criteria_attributes = {}
for table_name, table_attr in _cell_data_tables:
    source = (table_name, table_attr)
    _criteria_attributes.update(
        dict.fromkeys(_db_schema[table_name]['columns'].keys(), source)
    )
class CellClass(object):
    """Represents a class of cells as a list of selection criteria.

    Construct with an arbitrary set of keyword arguments, where each argument specifies
    a criterion for matching cells. Keyword argument names must be a column from the
    :class:`Cell <aisynphys.database.schema.Cell>`, :class:`Morphology <aisynphys.database.schema.Morphology>`,
    PatchSeq, Intrinsic, or CorticalCellLocation database tables.

    A keyword value may be a scalar (the cell attribute must equal it) or a tuple/list
    (the cell attribute must be one of its items).

    Can also be filtered with an arbitrary list of expressions (unnamed arguments),
    each of which must be an sqlalchemy BinaryExpression, as used in query.filter(),
    referring to one of the available tables.

    Example::

        pv_class = CellClass(cre_type='pvalb')
        inhibitory_class = CellClass(cre_type=('pvalb', 'sst', 'vip'))
        l23_pyr_class = CellClass(cortical_layer='2/3')
        l5_spiny_class = CellClass(dendrite_type='spiny', cortical_layer='5')
        deep_l3_class = CellClass(
            db.CorticalCellLocation.fractional_layer_depth < 0.5,
            cortical_layer='3')
    """

    def __init__(self, *exprs, name=None, **criteria):
        """
        Parameters
        ----------
        *exprs : sqlalchemy BinaryExpression
            Extra filter expressions; the left side of each must name a known column.
        name : str | None
            Optional display name. If None, the name is derived from the criteria.
        **criteria
            Column name / required value pairs.
        """
        # sanity check inputs
        assert name is None or isinstance(name, str), f"name must be a string or None (got {repr(name)})"
        for k in criteria:
            assert k in _criteria_attributes, f"Key '{k}' is not a valid cell class criterion."
        for ex in exprs:
            assert isinstance(ex, sqlalchemy.sql.elements.BinaryExpression), f"non-keyword arguments must be sqlalchemy binary expressions (got {repr(ex)})"
            assert ex.left.name in _criteria_attributes, f"Key '{ex.left.name}' is not a valid cell class criterion."

        self.criteria = criteria
        self.exprs = exprs
        self._name = name

    @property
    def name(self):
        """A short string representation of this cell class.

        If no name was supplied, then this value is `as_tuple` concatenated with spaces.
        """
        if self._name is not None:
            return self._name
        return ' '.join(str(x) for x in self.as_tuple)

    @property
    def as_tuple(self):
        """A tuple representation of this cell class used for display purposes.

        Order of elements in the tuple is (layer, dendrite_type, pyramidal, cre_type, t_type),
        but elements are only present if they were specified as criteria for the cell class.
        The layer element is taken from *target_layer* if given, otherwise from
        *cortical_layer*, and is prefixed with "L".
        """
        parts = []

        # target_layer takes precedence over cortical_layer
        layer = self.criteria.get('target_layer')
        if layer is None:
            layer = self.criteria.get('cortical_layer')
        if layer is not None:
            # str() so that non-string layer values (e.g. ints) do not break concatenation
            parts.append('L' + str(layer))

        if 'dendrite_type' in self.criteria:
            parts.append(str(self.criteria['dendrite_type']))

        if 'pyramidal' in self.criteria:
            parts.append('pyr' if self.criteria['pyramidal'] else 'nonpyr')

        for key in ('cre_type', 't_type'):
            value = self.criteria.get(key)
            if value is not None:
                parts.append(str(value))

        return tuple(parts)

    @property
    def is_excitatory(self):
        """True if this class includes only excitatory cells; False if this class includes only inhibitory cells;
        None if the class may include a mixture of excitatory and inhibitory cells.

        Relevant criteria used here are:

        * cell.cre_type
        * morphology.dendrite_type
        * cell.cell_class
        * cell.cell_class_nonsynaptic

        Raises
        ------
        ValueError
            If a cell_class / cell_class_nonsynaptic criterion is not 'ex', 'in', or 'mixed'.
        """
        # one True/False vote per criterion that carries information about cell polarity
        is_ex = []

        cre = self.criteria.get('cre_type')
        if not isinstance(cre, (tuple, list)):
            cre = (cre,)
        cre_is_exc = all([c in constants.EXCITATORY_CRE_TYPES for c in cre])
        cre_is_inh = all([c in constants.INHIBITORY_CRE_TYPES for c in cre])
        if cre_is_exc:
            is_ex.append(True)
        elif cre_is_inh:
            is_ex.append(False)

        dendrite = self.criteria.get('dendrite_type')
        if dendrite == 'spiny':
            is_ex.append(True)
        elif dendrite in ['aspiny', 'sparsely spiny']:
            is_ex.append(False)
        elif dendrite is not None:
            # any other dendrite criterion (including a list of types) is inconclusive
            return None

        for cell_class in [self.criteria.get('cell_class'), self.criteria.get('cell_class_nonsynaptic')]:
            if cell_class == 'ex':
                is_ex.append(True)
            elif cell_class == 'in':
                is_ex.append(False)
            elif cell_class == 'mixed':
                return None
            elif cell_class is not None:
                raise ValueError("cell class criteria must be 'ex', 'in', or 'mixed'")

        if len(is_ex) == 0:
            return None
        if all(is_ex):
            return True
        elif not any(is_ex):
            return False
        else:
            # criteria disagree with each other
            return None

    @property
    def output_synapse_type(self):
        """Expected type of synapses "ex", "in", or None to be output from this cell type.
        """
        return {True: 'ex', False: 'in'}.get(self.is_excitatory, None)

    def __contains__(self, cell):
        """Return True if *cell* satisfies every expression and every keyword criterion.

        A class with no expressions and no criteria contains every cell.
        """
        if len(self.criteria) == 0 and len(self.exprs) == 0:
            return True

        # check expressions
        for expr in self.exprs:
            # get variable name and value from expression
            key = expr.left.name
            ref_val = expr.right.value
            # a missing value never satisfies an expression
            val = self._get_cell_subattr(cell, key)
            if val is None or not expr.operator(val, ref_val):
                return False

        # check keyword arg criteria
        for k, v in self.criteria.items():
            if isinstance(v, dict):
                # dict criteria: at least one of the listed attribute/value pairs must match
                or_attr = [self._get_cell_subattr(cell, k2) == v2 for k2, v2 in v.items()]
                if not any(or_attr):
                    return False
            elif isinstance(v, (tuple, list)):
                if self._get_cell_subattr(cell, k) not in v:
                    return False
            else:
                if self._get_cell_subattr(cell, k) != v:
                    return False
        return True

    @staticmethod
    def _get_cell_subattr(cell, attr):
        """Get attribute from cell or one of its linked tables (morphology, intrinsic, etc..)

        Returns None if the linked object holding the attribute is missing from the cell.
        Raises Exception if *attr* is not a known criteria attribute.
        """
        if attr not in _criteria_attributes:
            raise Exception(f'Cannot use "{attr}" for cell typing; attribute not found on cell or linked objects')
        sub_obj_name = _criteria_attributes[attr][1]
        obj = cell if sub_obj_name is None else getattr(cell, sub_obj_name, None)
        if obj is None:
            return None
        return getattr(obj, attr)

    def __hash__(self):
        # hash by name so that a CellClass and its name string find the same dict slot
        return hash(self.name)

    def __eq__(self, a):
        """Cell class is considered equal to its *name* to allow it to be indexed from a dict more
        easily::

            cc = CellClass(cre_type='sst', cortical_layer='6')
            cc.name => 'L6 sst'
            {cc: 1}['L6 sst'] => 1
        """
        if isinstance(a, str):
            return a == self.name
        elif isinstance(a, CellClass):
            return a.name == self.name
        else:
            # let Python try the reflected comparison / fall back to identity
            return NotImplemented

    def __repr__(self):
        return "<CellClass %s>" % self.name

    def __str__(self):
        return self.name

    @staticmethod
    def _find_column(tables, key):
        """Return the attribute *key* from the first of *tables* that has it.

        Raises Exception if none of the tables has the attribute.
        """
        for table in tables:
            if hasattr(table, key):
                return getattr(table, key)
        raise Exception('Cannot use "%s" for cell typing; attribute not found in available tables.' % key)

    def filter_query(self, query, cell_table, db=None):
        """Return a modified query (sqlalchemy) that filters results to include only those in
        this cell class.

        Parameters
        ----------
        query : sqlalchemy query
            The query to be filtered.
        cell_table : table or alias
            The cell table (or alias of it) that *query* selects from; the morphology,
            intrinsic, cortical location and patch-seq tables are outer-joined to it on cell_id.
        db : Database | None
            Database providing the table classes; defaults to the default database.

        Raises
        ------
        Exception
            If an expression or criterion names an attribute found on none of the tables.
        """
        if db is None:
            db = default_db

        # fresh aliases so these joins cannot collide with joins already present in the query
        morpho = aliased(db.Morphology)
        intrinsic = aliased(db.Intrinsic)
        location = aliased(db.CorticalCellLocation)
        patch_seq = aliased(db.PatchSeq)
        query = (query.outerjoin(morpho, morpho.cell_id==cell_table.id)
            .outerjoin(intrinsic, intrinsic.cell_id==cell_table.id)
            .outerjoin(location, location.cell_id==cell_table.id)
            .outerjoin(patch_seq, patch_seq.cell_id==cell_table.id))

        # search order matters: the first table having the attribute is the one used
        tables = [cell_table, morpho, intrinsic, location, patch_seq]

        for expr in self.exprs:
            key = expr.left.name
            ref_val = expr.right.value
            # rebuild the expression against the aliased table used in this query
            query = query.filter(expr.operator(self._find_column(tables, key), ref_val))

        # NOTE(review): dict-valued criteria (accepted by __contains__) get no special
        # handling here and are compared with == like scalars -- confirm intended.
        for k, v in self.criteria.items():
            column = self._find_column(tables, k)
            if isinstance(v, (tuple, list)):
                query = query.filter(column.in_(v))
            else:
                query = query.filter(column==v)

        return query

    def dataframe_mask(self, df, prefix=''):
        """Given a dataframe containing columns describing cell properties, return a boolean
        Series that indicates whether the cell in each row is a member of the class.

        Dataframe columns must be named like "cell.cre_type" or "morphology.dendrite_type", as
        defined in the database. Optionally, these names may begin with *prefix*.

        The returned Series shares its index with *df*.

        Raises
        ------
        NotImplementedError
            If any criterion value is a dict.
        Exception
            If a criterion or expression names an unknown attribute.
        """
        import pandas

        # build the mask on df's own index; a default RangeIndex would be label-aligned
        # against the column comparisons below and give wrong results for any other index
        mask = pandas.Series(True, index=df.index, dtype=bool)

        if len(self.criteria) == 0 and len(self.exprs) == 0:
            return mask

        # check expressions
        for expr in self.exprs:
            # get variable name and value from expression
            key = expr.left.name
            ref_val = expr.right.value
            # check requested value
            col = df[self._get_df_col_name(key, prefix)]
            mask &= expr.operator(col, ref_val)

        # check keyword arg criteria
        for k, v in self.criteria.items():
            if isinstance(v, dict):
                raise NotImplementedError("dict cell class criteria not supported in dataframe_mask")
            col = df[self._get_df_col_name(k, prefix)]
            if isinstance(v, (tuple, list)):
                # lists/tuples: items in column must be in list
                mask &= col.isin(v)
            else:
                # scalars: items in column must be equal to value
                mask &= (col == v)

        return mask

    def _get_df_col_name(self, key, prefix):
        """Return the dataframe column name "<prefix><table>.<key>" for criteria attribute *key*.

        Raises Exception if *key* is not a known criteria attribute.
        """
        table_name = _criteria_attributes.get(key, [None])[0]
        if table_name is None:
            raise Exception(f'Cannot use "{key}" for cell typing; attribute not found on cell or linked tables')
        return prefix + table_name + '.' + key
def classify_cells(cell_classes, cells=None, pairs=None, missing_attr='raise'):
    """Given cell class definitions and a list of cells, return a dict indicating which cells
    are members of each class.

    Parameters
    ----------
    cell_classes : list
        List of CellClass instances
    cells : list | None
        List of Cell instances to be classified.
    pairs : list | None
        List of pairs from which cells will be collected (both the pre_cell and the
        post_cell of every pair). May not be used with *cells*
    missing_attr : str
        Determines the behavior when testing a cell against a class raises an
        exception (for example because a criteria attribute is missing on the
        cell). If 'ignore', then the cell is excluded from that class. Any other
        value re-raises the exception. Default is 'raise'.

    Returns
    -------
    cell_groups : OrderedDict
        Dictionary mapping {cell_class: set of cells}, in the order of *cell_classes*

    Raises
    ------
    TypeError
        If neither *cells* nor *pairs* is given.

    Example
    -------

        pv_cell_class = CellClass(cre_type='pvalb', target_layer='2/3')
        sst_cell_class = CellClass(cre_type='sst', target_layer='2/3')
        cell_classes = [pv_cell_class, sst_cell_class]
        cells = session.query(db.Cell).all()
        grouped_cells = classify_cells(cell_classes, cells=cells)
        pv_cells = grouped_cells[pv_cell_class]
        sst_cells = grouped_cells[sst_cell_class]

    """
    if pairs is not None:
        assert cells is None, "cells and pairs arguments are mutually exclusive"
        # single pass so that one-shot iterables of pairs are handled correctly
        cells = {cell for pair in pairs for cell in (pair.pre_cell, pair.post_cell)}
    elif cells is None:
        raise TypeError("classify_cells requires either the cells or the pairs argument")

    # materialize once: the classes are iterated again for every cell below
    cell_classes = list(cell_classes)

    cell_groups = OrderedDict((cell_class, set()) for cell_class in cell_classes)
    for cell in cells:
        for cell_class in cell_classes:
            try:
                if cell in cell_class:
                    cell_groups[cell_class].add(cell)
            except Exception:
                if missing_attr == 'ignore':
                    continue
                raise
    return cell_groups
def classify_cell_dataframe(cell_classes, df, prefix=''):
    """Classify cells in a pandas dataframe using a dictionary of cell classes.

    Parameters
    ----------
    cell_classes : dict
        Maps keys to objects providing ``dataframe_mask(df, prefix=...)``
        (CellClass instances).
    df : pandas.DataFrame
        Dataframe with one cell per row.
    prefix : str
        Passed through to ``dataframe_mask``.

    Returns a pandas series, sharing its index with *df*, giving the key for the
    first cell class to match each row (None where no class matches).
    """
    import pandas

    # index the result like df so it aligns when assigned back as a column
    match = pandas.Series([None] * len(df), index=df.index, dtype=object)

    # walk the classes last-to-first so that earlier classes overwrite later ones,
    # leaving each row labelled with the first class that matches it
    for k, cls in reversed(list(cell_classes.items())):
        mask = cls.dataframe_mask(df, prefix=prefix)
        # use the mask positionally; a boolean Series indexer would be aligned by label
        match[mask.values] = k
    return match
def classify_pairs(pairs, cell_groups):
    """Sort cell pairs into buckets keyed by (presynaptic class, postsynaptic class).

    Parameters
    ----------
    pairs : list of Pair instances
        The pairs to be grouped; each must provide ``pre_cell`` and ``post_cell``.
    cell_groups : dict
        Maps each cell class to the collection of cells belonging to it, in the
        same format as the output of classify_cells().

    Returns
    -------
    pair_groups : OrderedDict
        Maps {(pre_class, post_class): [list of pairs]} with one entry for every
        ordered combination of classes in *cell_groups*, including combinations
        for which no pair was found (empty list).
    """
    def connects(pair, pre_cells, post_cells):
        # a pair belongs to a bucket when both of its ends fall in the right groups
        return pair.pre_cell in pre_cells and pair.post_cell in post_cells

    return OrderedDict(
        (
            (pre_class, post_class),
            [pair for pair in pairs if connects(pair, pre_cells, post_cells)],
        )
        for pre_class, pre_cells in cell_groups.items()
        for post_class, post_cells in cell_groups.items()
    )
def classify_pair_dataframe(cell_classes, df, col_names=('pre_class', 'post_class')):
    """Label each row of a pair dataframe with its pre- and postsynaptic class.

    Two columns are written into *df* in place: ``col_names[0]`` receives the
    classification obtained from the "pre_"-prefixed columns and ``col_names[1]``
    the one obtained from the "post_"-prefixed columns. Nothing is returned.
    """
    # presynaptic column is written before the postsynaptic classification runs
    for position, prefix in enumerate(('pre_', 'post_')):
        df[col_names[position]] = classify_cell_dataframe(cell_classes, df, prefix=prefix)
_criteria_attribute_cache = {}
def _get_criteria_attributes(db):
"""Return a dict mapping attribute:(table_attribute, db_table) for all cell-related attributes
that can be used for classification criteria.
"""
global _criteria_attribute_cache
if db not in _criteria_attribute_cache:
# tables containing extra per-cell data that can be used as classification criteria
criteria_tables = [
(None, db.Cell),
('morphology', db.Morphology),
('patch_seq', db.PatchSeq),
('intrinsic', db.Intrinsic),
('cortical_location', db.CorticalCellLocation),
]
# names of attributes available for classification, mapped back to their source tables
criteria_attributes = {}
for name,table in criteria_tables:
for k in table.__table__.columns.keys():
criteria_attributes[k] = name,table
_criteria_attribute_cache[db] = criteria_attributes
return _criteria_attribute_cache[db]