from .base import BaseDataType
import importlib
import logging
logger = logging.getLogger(__name__)
class Join(BaseDataType):
    """Abstract parent of model joins - dependent child / parent models.

    The source and target model classes are given as dotted path strings
    (``"package.module.ClassName"``) and are only imported the first time
    :meth:`get_source` / :meth:`get_target` is called. Resolving them lazily
    avoids circular imports and allows a model class to reference itself.
    """

    def __init__(self, name: str, source: str, target: str):
        """
        :param name: name of the join; forwarded to the parent constructor.
        :param source: dotted path of the source model class.
        :param target: dotted path of the target model class, optionally
            followed by ``:<property>`` (see :meth:`parse_join_target`).
        :raises ValueError: if ``source`` or the class part of ``target``
            contains no ``.`` separator.
        """
        super().__init__(name=name)
        # Resolved classes; populated lazily by get_source() / get_target().
        self.source = None
        self.target = None
        self._source_module, self._source = self.split_class_def(source)
        (self._target_module,
         self._target,
         self.target_property) = self.parse_join_target(target)

    @staticmethod
    def split_class_def(class_def: str):
        """
        Split ``"package.module.ClassName"`` on its last dot.

        :param class_def: dotted path of a class.
        :return: ``(module_path, class_name)`` tuple.
        :raises ValueError: if ``class_def`` contains no ``.``.
        """
        module, separator, class_name = class_def.rpartition('.')
        if not separator:
            # Same exception type the bare tuple-unpacking used to raise,
            # but with a message that names the offending value.
            raise ValueError(
                "expected a dotted class path like 'package.module.Class', "
                "got %r" % (class_def,))
        return module, class_name

    @staticmethod
    def class_for_name(module_name: str, class_name: str):
        """
        Handles import of target model class. Needed to prevent circular
        dependency.

        :param module_name: dotted module path to import.
        :param class_name: attribute to fetch from the imported module.
        :return: the attribute named ``class_name`` of the imported module.
        """
        module = importlib.import_module(module_name)
        return getattr(module, class_name)

    @staticmethod
    def parse_join_target(join_target: str):
        """
        Parse a join target of the form ``"module.Class"`` or
        ``"module.Class:property"``.

        :param join_target: target definition string.
        :return: ``(module_path, class_name, target_property)``;
            ``target_property`` is ``None`` unless the string contains
            exactly one ``:``.
        :raises ValueError: if the class part contains no ``.``.
        """
        parts = join_target.split(':')
        # Only a single "class:property" pair yields a property; any other
        # number of ':' separated parts leaves it unset.
        target_property = parts[1] if len(parts) == 2 else None
        module, class_name = Join.split_class_def(parts[0])
        return module, class_name, target_property

    def get_target(self):
        """
        Gets the target model of the join.

        The class is imported on first use and cached on ``self.target``.
        """
        # Explicit ``is None``: the cached object is a class, and its
        # truthiness may be customised (e.g. by a metaclass), which would
        # otherwise defeat the cache.
        if self.target is None:
            # get class from string - to avoid circular imports and class
            # self reference
            self.target = self.class_for_name(self._target_module,
                                              self._target)
        return self.target

    def get_source(self):
        """
        Gets the source model of the join.

        The class is imported on first use and cached on ``self.source``.
        """
        if self.source is None:
            # get class from string - to avoid circular imports and class
            # self reference
            self.source = self.class_for_name(self._source_module,
                                              self._source)
        return self.source

    def _get_es_type(self):
        """Return the type definition of this field: ``keyword``."""
        return {'type': 'keyword'}

    def insert_reference(self,
                         value: 'base_model.Model',  # noqa: F821
                         model: 'base_model.Model'):  # noqa: F821
        """Base implementation does nothing; subclasses override it."""
        return None

    def include_in_flat(self):
        """Always return ``False``."""
        return False

    def _is_value_model(self, value):
        """
        Tell whether ``value`` should be treated as a model instance.

        :return: ``False`` for ``None`` and for strings, ``True`` for
            anything else.
        """
        return not (value is None or isinstance(value, str))
class SingleJoin(Join):
    """
    1:1 model join.
    """

    def lazy_load(self, model):
        """
        Load the joined target object for ``model``.

        The join attribute of ``model`` is read; if it has an ``id`` that id
        is used, otherwise the attribute value itself is used as the
        identifier.

        :param model: model instance owning this join.
        :return: ``self.get_target().get(identifier)`` when the identifier
            is truthy, otherwise the default value of this field.
        """
        # NOTE(review): ``__getattribute__`` is invoked directly, which -
        # unlike ``getattr()`` - does not fall back to ``__getattr__``.
        # Kept as in the original; confirm whether that is intended.
        try:
            value = model.__getattribute__(self.name).id
        except AttributeError:
            # The attribute has no ``id`` (e.g. it is None or a bare
            # identifier), so use it as-is.
            value = model.__getattribute__(self.name)
        loaded = self.get_default_value()
        if value:
            loaded = self.get_target().get(value)
        return loaded

    def serialize(self,
                  value: 'str | base_model.Model',  # noqa: F821
                  depth: int,
                  to_str: bool = False,
                  flat: bool = True):
        """
        Serialize the joined value.

        :param value: joined model, or a non-model value (``None`` / string).
        :param depth: remaining nesting depth. With ``depth < 1`` the value
            is reduced to its ``id``; otherwise ``value.serialize`` is
            called with ``depth - 1``.
        :param to_str: with ``depth < 1`` and a falsy ``id``, return
            ``object.__repr__(value)``.
        :param flat: with ``depth < 1``, a falsy ``id`` and ``to_str``
            unset, return ``None`` instead of the value itself.
        :return: the id, the nested serialization, the string
            representation, the value itself, or ``None`` as described.
        """
        if depth < 1:
            try:
                if value.id:
                    ret = value.id
                elif to_str:
                    ret = object.__repr__(value)
                elif flat:
                    return None
                else:
                    ret = value
            except (AttributeError, TypeError):
                # Not a model (no usable ``id``): pass the value through.
                ret = value
        elif self._is_value_model(value):
            ret = value.serialize(depth=depth - 1)
        else:
            # None / plain string have no ``serialize``; pass them through
            # the same way the ``depth < 1`` branch does instead of raising
            # AttributeError.
            ret = value
        logger.debug("serialize single %s", ret)
        return ret

    def on_update(self,
                  value: 'base_model.Model',  # noqa: F821
                  model: 'base_model.Model'):  # noqa: F821
        """
        React to the join attribute of ``model`` being set to ``value``.

        When this join has a ``target_property`` and ``value`` is a model,
        the field registered under that property in ``value._mapping`` is
        asked to insert a reference back to ``model``. The parent
        ``on_update`` is then called and its result returned.
        """
        if self.target_property and self._is_value_model(value):
            logger.debug("SingleJoin::on_update %s.%s = %s -> %s",
                         model.__class__.__name__,
                         self.name, value,
                         self.name)
            target_type = value._mapping[self.target_property]
            # Argument order is (referenced model, owning model): ``model``
            # becomes the value stored on ``value``.
            target_type.insert_reference(model, value)
        return super().on_update(value, model)

    def insert_reference(self,
                         value: 'base_model.Model',  # noqa: F821
                         model: 'base_model.Model'):  # noqa: F821
        """Set this join's attribute on ``model`` to ``value``."""
        logger.debug("SingleJoin::insert_reference %s %s",
                     self.name,
                     value.id)
        model.__setattr__(self.name, value)

    def on_save(self, model):
        """
        Save the joined object if it has not been saved yet.

        :param model: model instance being saved.
        :return: the joined object if it was truthy, had an ``id``
            attribute equal to ``None`` and was therefore saved here;
            otherwise ``None``.
        """
        value = model.__getattribute__(self.name)
        logger.debug("SingleJoin::on_save %s %s %s",
                     self.name,
                     model.id,
                     value and (hasattr(value, 'id') and value.id))
        if value and hasattr(value, 'id') and value.id is None:
            logger.debug("SingleJoin::on_save - saving")
            value.save()
            return value
        return None
class MultiJoin(Join):
    """1:N model join."""

    def __init__(self, name: str, source: str, target: str, join_by=None):
        """
        :param name: name of the join.
        :param source: dotted path of the source model class.
        :param target: dotted path of the target model class, optionally
            followed by ``:<property>``.
        :param join_by: optional join key; when falsy it is derived on
            demand by :meth:`get_join_by`.
        """
        super().__init__(name=name, source=source, target=target)
        self.join_by = join_by

    def get_join_by(self):
        """
        Return the join key, deriving and caching it when not set.

        The derived key is the source model's ``_mapping['_name']`` with
        ``'_id'`` appended.
        """
        if not self.join_by:
            self.join_by = self.get_source()._mapping['_name'] + '_id'
        return self.join_by

    def lazy_load(self, model):
        """
        Load the joined target objects for ``model``.

        :param model: model instance owning this join.
        :return: ``self.get_target().get(ids)`` where ``ids`` is the list of
            ``id`` attributes of the current values, or the raw attribute
            value when an element has no ``id``.
        """
        try:
            value = [v.id for v in model.__getattribute__(self.name)]
        except AttributeError:
            # At least one element has no ``id``; use the values as-is.
            value = model.__getattribute__(self.name)
        return self.get_target().get(value)

    def serialize(self,
                  value: 'list[str | base_model.Model]',  # noqa: F821
                  depth: int,
                  to_str: bool = False,
                  flat: bool = True):
        """
        Serialize every element of ``value`` with ``SingleJoin.serialize``.

        :param value: iterable of joined values.
        :param depth: forwarded unchanged for each element.
        :param to_str: forwarded unchanged for each element.
        :param flat: forwarded unchanged for each element.
        :return: list with one serialized entry per element.
        """
        ret = [SingleJoin.serialize(self, value=item, depth=depth,
                                    to_str=to_str, flat=flat)
               for item in value]
        logger.debug("MultiJoin::serialize %s", ret)
        return ret

    def get_default_value(self):
        """Return a new empty list."""
        return []

    def on_update(self,
                  value: 'list[base_model.Model]',  # noqa: F821
                  model: 'base_model.Model'):  # noqa: F821
        """
        React to the join attribute of ``model`` being set to ``value``.

        When this join has a ``target_property``, every model in ``value``
        gets a reference back to ``model`` inserted through the field
        registered under that property in its ``_mapping``. Elements that
        are ``None`` or strings are skipped. The parent ``on_update`` is
        then called and its result returned.
        """
        if self.target_property:
            logger.debug("MultiJoin::on_update %s.%s = %s -> %s",
                         model.__class__.__name__,
                         self.name,
                         value,
                         self.name)
            for val in value:
                if not self._is_value_model(val):
                    continue
                target_type = val._mapping[self.target_property]
                target_type.insert_reference(model, val)
        return super().on_update(value, model)

    def insert_reference(self,
                         value: 'base_model.Model',  # noqa: F821
                         model: 'base_model.Model'):  # noqa: F821
        """
        Append ``value`` to this join's list on ``model`` unless it is
        already referenced there.

        A value that has an id is considered present when a model with the
        same id is in the list. A value whose id is ``None`` is considered
        present only when that very object is in the list.
        """
        logger.debug("MultiJoin::insert_reference %s.%s = %s",
                     model, self.name, value)
        referred_attribute = model.__getattribute__(self.name)
        if value.id is None:
            # Unsaved objects all share ``id is None``; comparing ids would
            # make distinct unsaved objects look like duplicates and drop
            # every one after the first. Compare by identity instead.
            already_referred = any(r is value for r in referred_attribute)
        else:
            referred_ids = [r.id for r in referred_attribute
                            if self._is_value_model(r)]
            already_referred = value.id in referred_ids
        if not already_referred:
            referred_attribute.append(value)

    def on_save(self, model: 'base_model.Model'):  # noqa: F821
        """
        Save every joined object that has not been saved yet.

        :param model: model instance being saved.
        :return: list of the ``save()`` results of the objects saved here,
            or ``None`` when nothing was saved.
        """
        logger.debug("MultiJoin::on_save %s %s", self.name, model.id)
        values = model.__getattribute__(self.name)
        # Select first, save afterwards: the selection must not be
        # influenced by ids assigned while saving.
        unsaved = [v for v in values
                   if v and hasattr(v, 'id') and v.id is None]
        ret = [v.save() for v in unsaved]
        return ret or None

    def deserialize(self, value):
        """Return ``value`` unchanged, or an empty list for ``None``."""
        if value is None:
            return []
        return value
class LooseJoin(Join):
    """
    Join variant whose conversion hooks ignore their input.

    ``to_es`` always yields an empty dict, ``from_es`` always yields the
    field's default value and ``_has_es_type`` reports ``False``.
    """

    def to_es(self, value):
        """Return an empty dict regardless of ``value``."""
        return dict()

    def lazy_load(self, value: str):
        """
        Return a single-entry dict mapping this join's name to its default
        value; ``value`` is not used.
        """
        key = self.name
        default = self.get_default_value()
        return {key: default}

    def _has_es_type(self):
        """Always return ``False``."""
        return False

    def from_es(self, es_hit):
        """Return the field's default value regardless of ``es_hit``."""
        default = self.get_default_value()
        return default
class SingleJoinLoose(SingleJoin, LooseJoin):
    """
    1:1 join resolved by querying the target model for the object whose
    ``target_property`` equals the id of the owning object.
    """

    def __init__(self, name: str, source: str, target: str,
                 do_lazy_load=False):
        """
        :param name: name of the join.
        :param source: dotted path of the source model class.
        :param target: dotted path of the target model class followed by
            ``:<property>``; the property is required by :meth:`lazy_load`.
        :param do_lazy_load: when falsy, :meth:`lazy_load` is a no-op.
        """
        super().__init__(name=name, source=source, target=target)
        self.do_lazy_load = do_lazy_load

    def lazy_load(self, value):
        """
        Find the target object referring to ``value``.

        :param value: object whose ``id`` is searched for in the target's
            ``target_property``.
        :return: first result of ``target.find_by(<target_property>=value.id,
            size=1)``; ``None`` when lazy loading is disabled or the lookup
            raises ``IndexError`` (no result).
        :raises AssertionError: if the join has no ``target_property``.
        """
        if not self.do_lazy_load:
            # Lazy logger arguments: nothing is formatted unless debug
            # logging is enabled.
            logger.debug("SingleJoinLoose::lazy_load %s of %s skipped",
                         self.name, value)
            return None
        assert self.target_property, \
            "SingleJoinLoose requires a target property ('module.Class:prop')"
        target = self.get_target()
        find_by = {self.target_property: value.id}
        try:
            ret = target.find_by(**find_by, size=1)[0]
        except IndexError:
            ret = None
        return ret
class MultiJoinLoose(MultiJoin, LooseJoin):
    """
    Important! Doesn't preserve order!

    1:N join resolved by querying the target model for the objects whose
    ``target_property`` equals the id of the owning object.
    """

    def __init__(self, name: str, source: str, target: str, join_by=None,
                 do_lazy_load=False):
        """
        :param name: name of the join.
        :param source: dotted path of the source model class.
        :param target: dotted path of the target model class followed by
            ``:<property>``; the property is required by :meth:`lazy_load`.
        :param join_by: optional join key, forwarded to ``MultiJoin``.
        :param do_lazy_load: when falsy, :meth:`lazy_load` returns an empty
            list without querying.
        """
        # ``join_by`` must be forwarded explicitly; otherwise the parent
        # constructor falls back to its default and the argument is lost.
        super().__init__(name=name, source=source, target=target,
                         join_by=join_by)
        self.do_lazy_load = do_lazy_load

    def lazy_load(self, value):
        """
        Find the target objects referring to ``value``.

        :param value: object whose ``id`` is searched for in the target's
            ``target_property``.
        :return: result of ``target.find_by(<target_property>=value.id)``,
            or an empty list when lazy loading is disabled.
        :raises AssertionError: if the join has no ``target_property``.
        """
        if not self.do_lazy_load:
            # Lazy logger arguments: nothing is formatted unless debug
            # logging is enabled.
            logger.debug("MultiJoinLoose::lazy_load %s of %s skipped",
                         self.name, value)
            return []
        assert self.target_property, \
            "MultiJoinLoose requires a target property ('module.Class:prop')"
        target = self.get_target()
        find_by = {self.target_property: value.id}
        return target.find_by(**find_by)