Source code for elastic_connect.data_types.join

from .base import BaseDataType
import importlib
import logging

logger = logging.getLogger(__name__)

[docs]class Join(BaseDataType): """Abstract parent of model joins - dependent child / parent models.""" def __init__(self, name: str, source: str, target: str): super().__init__(name=name) self.source = None self.target = None self._source_module, self._source = self.split_class_def(source) self._target_module, self._target, self.target_property = self.parse_join_target(target) @staticmethod def split_class_def(class_def: str): module, class_name = class_def.rsplit('.', 1) return module, class_name
[docs] @staticmethod def class_for_name(module_name: str, class_name: str): """Handles import of target model class. Needed to prevent circular dependency.""" m = importlib.import_module(module_name) c = getattr(m, class_name) return c
@staticmethod def parse_join_target(join_target: str): target_property = None parts = join_target.split(':') if len(parts) == 2: target_property = parts[1] module, class_name = Join.split_class_def(parts[0]) return module, class_name, target_property
[docs] def get_target(self): """Gets the target model of the join.""" if not self.target: # get class from string - to avoid circular imports and class self reference self.target = self.class_for_name(self._target_module, self._target) return self.target
[docs] def get_source(self): """Gets the source model of the join.""" if not self.source: # get class from string - to avoid circular imports and class self reference self.source = self.class_for_name(self._source_module, self._source) return self.source
def _get_es_type(self): return 'keyword' def insert_reference(self, value: 'base_model.Model', model: 'base_model.Model'): return None def include_in_flat(self): return False def _is_value_model(self, value): if value is None or isinstance(value, str): return False return True
[docs]class SingleJoin(Join): """1:1 model join.""" def lazy_load(self, model): try: value = model.__getattribute__(self.name).id except AttributeError: value = model.__getattribute__(self.name) loaded = self.get_default_value() if value: loaded = self.get_target().get(value) return loaded def serialize(self, value: (str, 'base_model.Model'), depth: int, to_str: bool = False, flat: bool = True): if depth < 1: try: if value.id: ret = value.id else: if to_str: ret = object.__repr__(value) else: if flat: return None ret = value except (AttributeError, TypeError): ret = value else: ret = value.serialize(depth=depth-1) logger.debug("serialize single %s", ret) return ret def on_update(self, value: 'base_model.Model', model: 'base_model.Model'): if self.target_property and self._is_value_model(value): logger.debug("SingleJoin::on_update %s.%s = %s -> %s", model.__class__.__name__, self.name, value, self.name) target_type = value._mapping[self.target_property] target_type.insert_reference(model, value) return super().on_update(value, model) def insert_reference(self, value: 'base_model.Model', model: 'base_model.Model'): logger.debug("SingleJoin::insert_reference %s %s", self.name, value.id) model.__setattr__(self.name, value) def on_save(self, model): value = model.__getattribute__(self.name) logger.debug("SingleJoin::on_save %s %s %s", self.name, model.id, value and (hasattr(value, 'id') and value.id)) if value and hasattr(value, 'id') and value.id is None: logger.debug("SingleJoin::on_save - saving") value.save() return value return None
[docs]class MultiJoin(Join): """1:N model join.""" def __init__(self, name: str, source: str, target: str, join_by=None): super(MultiJoin, self).__init__(name=name, source=source, target=target) self.join_by = join_by def get_join_by(self): if not self.join_by: self.join_by = self.get_source()._mapping['_name'] + '_id' return self.join_by def lazy_load(self, model): try: value = [v.id for v in model.__getattribute__(self.name)] except AttributeError: value = model.__getattribute__(self.name) return [self.get_target().get(val) for val in value] def serialize(self, value: (str, 'base_model.Model'), depth: int, to_str: bool = False, flat: bool = True): ret = [SingleJoin.serialize(self, value=model, depth=depth, to_str=to_str, flat=flat) for model in value] logger.debug("MultiJoin::serialize %s", ret) return ret def get_default_value(self): return [] def on_update(self, value: 'list[base_model.Model]', model: 'base_model.Model'): if self.target_property: logger.debug("MultiJoin::on_update %s.%s = %s -> %s", model.__class__.__name__, self.name, value, self.name) for val in value: if not self._is_value_model(val): continue target_type = val._mapping[self.target_property] target_type.insert_reference(model, val) return super().on_update(value, model) def insert_reference(self, value: 'base_model.Model', model: 'base_model.Model'): logger.debug("MultiJoin::insert_reference %s.%s = %s", model, self.name, value) referred_attribute = model.__getattribute__(self.name) if value.id not in [r.id for r in referred_attribute if self._is_value_model(r)]: referred_attribute.append(value) def on_save(self, model: 'base_model.Model'): logger.debug("MultiJoin::on_save %s %s", self.name, model.id) ret = [] values = model.__getattribute__(self.name) for value in [v for v in values if v and hasattr(v, 'id') and v.id is None]: ret.append(value.save()) if len(ret): return ret return None def deserialize(self, value): if value is None: return [] return value
[docs]class LooseJoin(Join): def to_es(self, value): return {} def lazy_load(self, value: str): return {self.name: self.get_default_value()} def _has_es_type(self): return False def from_es(self, es_hit): return self.get_default_value()
[docs]class SingleJoinLoose(SingleJoin, LooseJoin): pass
[docs]class MultiJoinLoose(MultiJoin, LooseJoin): pass