Source code for modularodm.storedobject

# -*- coding: utf-8 -*-

import six
import copy
import logging
import warnings
from functools import wraps

from . import signals
from . import exceptions
from .fields import Field, ListField, ForeignList, AbstractForeignList
from .storage import Storage
from .query import QueryBase, RawQuery, QueryGroup
from .frozen import FrozenDict
from .cache import Cache
from .writequeue import WriteQueue, WriteAction


logger = logging.getLogger(__name__)


class ContextLogger(object):

    @staticmethod
    def sort_func(e):
        return (e.xtra._name if e.xtra else None, e.func.__name__)

    def report(self, sort_func=None):
        return self.logger.report(sort_func or self.sort_func)

    def __init__(self, log_level=None, xtra=None, sort_func=None):
        self.log_level = log_level
        self.xtra = xtra
        self.sort_func = sort_func or self.sort_func
        self.logger = Storage.logger

    def __enter__(self):
        self.listening = self.logger.listen(self.xtra)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.listening:
            report = self.logger.report(
                lambda e: (e.xtra._name if e.xtra else None, e.func.__name__)
            )
            if self.log_level is not None:
                logging.log(self.log_level, report)
            self.logger.clear()
        self.logger.pop()


def deref(data, keys, missing=None):
    if keys[0] in data:
        if len(keys) == 1:
            return data[keys[0]]
        return deref(data[keys[0]], keys[1:], missing=missing)
    return missing
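
# Illustration only (not part of the original module): `deref` walks a
# nested dict along a list of keys, returning `missing` when any key on the
# path is absent. Sample data is hypothetical.
#
#     >>> deref({'a': {'b': 1}}, ['a', 'b'])
#     1
#     >>> deref({'a': {}}, ['a', 'b'], missing=[])
#     []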


def flatten_backrefs(data, stack=None):

    stack = stack or []

    if isinstance(data, list):
        return [(stack, item) for item in data]

    out = []
    for key, val in data.items():
        out.extend(flatten_backrefs(val, stack + [key]))

    return out
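
# Illustration only: `flatten_backrefs` turns the nested back-reference
# structure {backref_key: {schema_name: {field_name: [keys]}}} into flat
# (stack, key) pairs, the shape consumed by `rm_fwd_refs` below. The sample
# data is hypothetical.
#
#     >>> flatten_backrefs({'tagged': {'tag': {'tags': ['a', 'b']}}})
#     [(['tagged', 'tag', 'tags'], 'a'), (['tagged', 'tag', 'tags'], 'b')]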


def log_storage(func):

    @wraps(func)
    def wrapped(this, *args, **kwargs):

        cls = this if isinstance(this, type) else type(this)

        with ContextLogger(log_level=this._log_level, xtra=cls):
            return func(this, *args, **kwargs)

    return wrapped


def warn_if_detached(func):
    """ Warn if self / cls is detached. """
    @wraps(func)
    def wrapped(this, *args, **kwargs):
        # Check for _detached in __dict__ instead of using hasattr
        # to avoid infinite loop in __getattr__
        if '_detached' in this.__dict__ and this._detached:
            warnings.warn('Accessing a detached record.')
        return func(this, *args, **kwargs)
    return wrapped


def has_storage(func):
    """ Ensure that self/cls contains a Storage backend. """
    @wraps(func)
    def wrapped(*args, **kwargs):
        me = args[0]
        if not hasattr(me, '_storage') or not me._storage:
            raise exceptions.ImproperConfigurationError(
                'No storage backend attached to schema <{0}>.'.format(
                    me._name.upper()
                )
            )
        return func(*args, **kwargs)
    return wrapped
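
# Illustration only: methods guarded by @has_storage raise
# ImproperConfigurationError until a backend is attached via set_storage().
# `Person` is a hypothetical schema; EphemeralStorage is assumed to be the
# in-memory backend shipped in modularodm.storage.
#
#     Person.find()                            # raises ImproperConfigurationError:
#                                              # No storage backend attached to schema <PERSON>.
#     Person.set_storage(EphemeralStorage())
#     Person.find()                            # now queries the backend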


class ObjectMeta(type):

    def add_field(cls, name, field):

        # Skip if not descriptor
        if not isinstance(field, Field):
            return

        # Memorize parent references
        field._schema_class = cls
        field._field_name = name

        # Check for primary key
        if field._is_primary:
            if cls._primary_name is None:
                cls._primary_name = name
                cls._primary_type = field.data_type
            else:
                raise AttributeError(
                    'Multiple primary keys are not supported.')

        # Wrap in list
        if field._list and not isinstance(field, ListField):
            field = ListField(
                field,
                **field._kwargs
            )
            # Memorize parent references
            field._schema_class = cls
            field._field_name = name
            # Set parent pointer of child field to list field
            field._field_instance._list_container = field

        # Subscribe to schema events
        field.subscribe(sender=cls)

        # Store descriptor to cls, cls._fields
        setattr(cls, name, field)
        cls._fields[name] = field

    def __init__(cls, name, bases, dct):

        # Run super-metaclass __init__
        super(ObjectMeta, cls).__init__(name, bases, dct)

        # Store prettified name
        cls._name = name.lower()

        # Store parameters from _meta
        my_meta = cls.__dict__.get('_meta', {})

        cls._is_optimistic = my_meta.get('optimistic', False)
        cls._is_abstract = my_meta.get('abstract', False)
        cls._log_level = my_meta.get('log_level', None)
        cls._version_of = my_meta.get('version_of', None)
        cls._version = my_meta.get('version', 1)
        cls._record_validators = my_meta.get('validators', [])

        # Prepare fields
        cls._fields = {}
        cls._primary_name = None
        cls._primary_type = None

        for key, value in cls.__dict__.items():
            cls.add_field(key, value)

        for base in bases:
            if not hasattr(base, '_fields') or not isinstance(base._fields, dict):
                continue
            for key, value in base._fields.items():
                cls.add_field(key, copy.deepcopy(value))

        # Impute field named _id as primary if no primary field specified;
        # must be exactly one primary field unless abstract
        if cls._fields:
            cls._is_root = False
            if cls._primary_name is None:
                if '_id' in cls._fields:
                    primary_field = cls._fields['_id']
                    primary_field._is_primary = True
                    if 'index' not in primary_field._kwargs or not primary_field._kwargs['index']:
                        primary_field._index = True
                    cls._primary_name = '_id'
                    cls._primary_type = cls._fields['_id'].data_type
                elif not cls._is_abstract:
                    raise AttributeError(
                        'Schemas must either define a field named _id or '
                        'specify exactly one field as primary.')
            # Register
            cls.register_collection()
        else:
            cls._is_root = True

    @property
    def _translator(cls):
        return cls._storage[0].translator


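# Illustration only: a minimal concrete schema, showing what ObjectMeta does
# at class-creation time. The schema below is hypothetical; the field API
# matches the example in the _migrate docstring further down.
#
#     class Person(StoredObject):
#         _id = fields.StringField(primary=True, index=True)
#         name = fields.StringField()
#         _meta = {'optimistic': True}
#
# ObjectMeta sets Person._name = 'person', collects Person._fields, marks
# `_id` as the primary key, and registers the collection.
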
@six.add_metaclass(ObjectMeta)
class StoredObject(object):
    """ Base class to be used for models. """

    _collections = {}
    _cache = Cache()
    _object_cache = Cache()
    queue = WriteQueue()

    def __init__(self, **kwargs):

        # Crash if abstract
        if self._is_abstract:
            raise TypeError('Cannot instantiate abstract schema')

        self.__backrefs = {}
        self._dirty = False
        self._detached = False
        self._is_loaded = kwargs.pop('_is_loaded', False)
        self._stored_key = None

        # Impute non-lazy default values (e.g. datetime with auto_now=True)
        for value in self._fields.values():
            if not value.lazy_default:
                value.__set__(self, value._gen_default(), safe=True)

        # Add kwargs to instance
        for key, value in kwargs.items():
            try:
                field = self._fields[key]
                field.__set__(self, value, safe=True)
            except KeyError:
                if key == '__backrefs':
                    key = '_StoredObject__backrefs'
                setattr(self, key, value)

        if self._is_loaded:
            self._set_cache(self._primary_key, self, kwargs)

    def __eq__(self, other):
        try:
            if self is other:
                return True
            if self._primary_key != other._primary_key:
                return False
            return self.to_storage() == other.to_storage()
        except (AttributeError, TypeError):
            # Can't compare with "other". Try the reverse comparison
            return NotImplemented

    def __ne__(self, other):
        equal = self.__eq__(other)
        return equal if equal is NotImplemented else not equal

    def __hash__(self):
        # TODO: Is this the right thing to do?
        return id(self)

    @warn_if_detached
    def __unicode__(self):
        return unicode({
            field: unicode(getattr(self, field))
            for field in self._fields
        })

    @warn_if_detached
    def __str__(self):
        return unicode(self).decode('ascii', 'replace')

    @classmethod
    def register_collection(cls):
        cls._collections[cls._name] = cls

    @classmethod
    def get_collection(cls, name):
        return cls._collections[name.lower()]

    @property
    def _primary_key(self):
        return getattr(self, self._primary_name)

    @_primary_key.setter
    def _primary_key(self, value):
        setattr(self, self._primary_name, value)

    @property
    def _storage_key(self):
        """ Primary key passed through translator. """
""" return self._pk_to_storage(self._primary_key) @property @has_storage def _translator(self): return self.__class__._translator @has_storage def to_storage(self, translator=None, clone=False): data = {} for field_name, field_object in self._fields.items(): # Ignore primary and foreign fields if cloning # TODO: test this if clone: if field_object._is_primary or field_object._is_foreign: continue field_value = field_object.to_storage( field_object._get_underlying_data(self), translator ) data[field_name] = field_value data['_version'] = self._version if not clone and self.__backrefs: data['__backrefs'] = self.__backrefs return data @classmethod @has_storage def from_storage(cls, data, translator=None): result = {} for key, value in data.items(): field_object = cls._fields.get(key, None) if isinstance(field_object, Field): data_value = data[key] if data_value is None: value = None result[key] = None else: value = field_object.from_storage(data_value, translator) result[key] = value else: result[key] = value return result def clone(self): return self.load( data=self.to_storage(clone=True), _is_loaded=False ) # Backreferences @property def _backrefs(self): return FrozenDict(**self.__backrefs) @_backrefs.setter def _backrefs(self, _): raise exceptions.ModularOdmException('Cannot modify _backrefs.') @property def _backrefs_flat(self): return flatten_backrefs(self.__backrefs) def _remove_backref(self, backref_key, parent, parent_field_name, strict=False): try: self.__backrefs[backref_key][parent._name][parent_field_name].remove(parent._primary_key) self.save(force=True) except (KeyError, ValueError): if strict: raise def _update_backref(self, backref_key, parent, parent_field_name): updated = False try: refs = self.__backrefs[backref_key][parent._name][parent_field_name] if refs[refs.index(parent._stored_key)] != parent._primary_key: refs[refs.index(parent._stored_key)] = parent._primary_key updated = True except (KeyError, ValueError): self._set_backref(backref_key, parent_field_name, parent) return True if updated: self.save(force=True) return True return False def _set_backref(self, backref_key, parent_field_name, backref_value): backref_value_class_name = backref_value.__class__._name backref_value_primary_key = backref_value._primary_key if backref_value_primary_key is None: raise exceptions.DatabaseError('backref object\'s primary key must be saved first') if backref_key not in self.__backrefs: self.__backrefs[backref_key] = {} if backref_value_class_name not in self.__backrefs[backref_key]: self.__backrefs[backref_key][backref_value_class_name] = {} if parent_field_name not in self.__backrefs[backref_key][backref_value_class_name]: self.__backrefs[backref_key][backref_value_class_name][parent_field_name] = [] append_to = self.__backrefs[backref_key][backref_value_class_name][parent_field_name] if backref_value_primary_key not in append_to: append_to.append(backref_value_primary_key) self.save(force=True) @classmethod def set_storage(cls, storage): if not isinstance(storage, Storage): raise TypeError('Argument to set_storage must be an instance of Storage.') if not hasattr(cls, '_storage'): cls._storage = [] for field_name, field_object in cls._fields.items(): if field_object._index: storage._ensure_index(field_name) cls._storage.append(storage) # Caching ################################################################ @classmethod def _is_cached(cls, key): return cls._object_cache.get(cls._name, key) is not None @classmethod def _load_from_cache(cls, key): trans_key = 
        return cls._object_cache.get(cls._name, trans_key)

    @classmethod
    def _set_cache(cls, key, obj, data=None):
        trans_key = cls._pk_to_storage(key)
        cls._object_cache.set(cls._name, trans_key, obj)
        cls._cache.set(cls._name, trans_key, data)

    @classmethod
    def _get_cache(cls, key):
        trans_key = cls._pk_to_storage(key)
        return cls._object_cache.get(cls._name, trans_key)

    @classmethod
    def _get_cached_data(cls, key):
        return cls._cache.get(cls._name, key)

    def get_changed_fields(self, cached_data, storage_data):
        """Get fields that differ between the cache and the current object.
        Validation and after_save methods should only be run on diffed
        fields.

        :param cached_data: Storage-formatted data from cache
        :param storage_data: Storage-formatted data from object
        :return: List of diffed fields
        """
        if not self._is_loaded or cached_data is None:
            return []
        return [
            field
            for field in self._fields
            if cached_data.get(field) != storage_data.get(field)
        ]

    # Cache clearing

    @classmethod
    def _clear_data_cache(cls, key=None):
        if not cls._fields:
            cls._cache.clear()
        elif key is not None:
            cls._cache.pop(cls._name, key)
        else:
            cls._cache.clear_schema(cls._name)

    @classmethod
    def _clear_object_cache(cls, key=None):
        if not cls._fields:
            cls._object_cache.clear()
        elif key is not None:
            cls._object_cache.pop(cls._name, key)
        else:
            cls._object_cache.clear_schema(cls._name)

    @classmethod
    def _clear_caches(cls, key=None):
        cls._clear_data_cache(key)
        cls._clear_object_cache(key)

    ###########################################################################

    @classmethod
    def _to_primary_key(cls, value):
        if value is None:
            return value
        if isinstance(value, cls):
            return value._primary_key
        return cls._check_pk_type(value)

    @classmethod
    def _check_pk_type(cls, key):

        if isinstance(key, cls._primary_type):
            return key

        try:
            cls._primary_type()
            cast_type = cls._primary_type
        except:
            cast_type = str

        try:
            key = cast_type(key)
        except:
            raise TypeError(
                'Invalid key type: {key}, {type}, {ptype}.'.format(
                    key=key, type=type(key), ptype=cast_type
                )
            )

        return key

    @classmethod
    @has_storage
    @log_storage
    def load(cls, key=None, data=None, _is_loaded=True):
        """Get a record by its primary key. """

        # Emit load signal
        signals.load.send(
            cls,
            key=key,
            data=data,
        )

        if key is not None:
            key = cls._check_pk_type(key)
            cached_object = cls._load_from_cache(key)
            if cached_object is not None:
                return cached_object

        # Try loading from backend
        if data is None:
            data = cls._storage[0].get(cls._primary_name, cls._pk_to_storage(key))

        # If not found, return None
        if data is None:
            return None

        # Convert storage data to ODM
        data = cls.from_storage(data)

        if cls._version_of and '_version' in data and data['_version'] != cls._version:
            old_object = cls._version_of.load(data=data)
            new_object = cls(_is_loaded=_is_loaded)
            cls.migrate(old_object, new_object)
            new_object._stored_key = new_object._primary_key
            return new_object

        ret = cls(_is_loaded=_is_loaded, **data)
        ret._stored_key = ret._primary_key
        return ret
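
    # Illustration only: loading by primary key consults the object cache
    # before the backend, so repeated loads return the same instance.
    # `Person` and the key are hypothetical.
    #
    #     person = Person.load('abc12')
    #     assert person is Person.load('abc12')   # cached instance is reused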

    @classmethod
    def migrate_all(cls):
        """Migrate all records in this collection."""
        for record in cls.find():
            record.save()

    @classmethod
    def migrate(cls, old, new, verbose=True, dry_run=False, rm_refs=True):
        """Migrate record to new schema.

        :param old: Record from original schema
        :param new: Record from new schema
        :param verbose: Print detailed info
        :param dry_run: Dry run; make no changes if true
        :param rm_refs: Remove references on deleted fields
        """
        if verbose:
            logging.basicConfig(format='%(levelname)s %(filename)s: %(message)s', level=logging.DEBUG)

        # Check deleted, added fields
        deleted_fields = [field for field in old._fields if field not in new._fields]
        added_fields = [field for field in new._fields if field not in old._fields]
        logging.info('Will delete fields: {0}'.format(deleted_fields))
        logging.info('Will add fields: {0}'.format(added_fields))

        # Check change in primary key
        if old._primary_name != new._primary_name:
            logging.info(
                'The primary key will change from {old_name}: {old_field} to '
                '{new_name}: {new_field} in this migration. Primary keys and '
                'backreferences will not be automatically migrated. If you want '
                'to migrate primary keys, you should handle this in your '
                'migrate() method.'.format(
                    old_name=old._primary_name,
                    old_field=old._fields[old._primary_name],
                    new_name=new._primary_name,
                    new_field=new._fields[new._primary_name],
                )
            )

        # Copy fields to new object
        for field in old._fields:

            # Delete forward references on deleted fields
            if field not in cls._fields:
                if rm_refs:
                    logging.info(
                        'Backreferences to this object keyed on foreign '
                        'field {name}: {field} will be deleted in this migration. '
                        'To prevent this behavior, re-run with <rm_fwd_refs> '
                        'set to False.'.format(name=field, field=old._fields[field])
                    )
                    if not dry_run:
                        rm_fwd_refs(old)
                else:
                    logging.info(
                        'Backreferences to this object keyed on foreign field '
                        '{name}: {field} will not be deleted in this migration. '
                        'To add this behavior, re-run with <rm_fwd_refs> '
                        'set to True.'.format(name=field, field=old._fields[field])
                    )
                continue

            # Check for field change
            old_field_obj = old._fields[field]
            new_field_obj = new._fields[field]
            if old_field_obj != new_field_obj:
                if not old_field_obj._required and new_field_obj._required:
                    logging.info(
                        'Field {name!r} is now required '
                        'and therefore needs a default value '
                        'for existing records. You can set '
                        'this value in the _migrate() method. '
                        '\nExample: '
                        '\n    if not old.{name}:'
                        "\n        new.{name} = 'default value'".format(name=field)
                    )
                else:
                    logging.info(
                        'Old field {name}: {old_field} differs from new field '
                        '{name}: {new_field}. This field will not be '
                        'automatically migrated. If you want to migrate this '
                        'field, you should handle this in your migrate() '
                        'method.'.format(
                            name=field,
                            old_field=old_field_obj,
                            new_field=new_field_obj,
                        )
                    )
                continue

            # Copy values of retained fields
            if not dry_run:
                field_object = cls._fields[field]
                field_object.__set__(
                    new,
                    getattr(old, field),
                    safe=True
                )

        # Copy backreferences
        if not dry_run:
            new.__backrefs = old.__backrefs

        # Run custom migration
        if not dry_run:
            cls._migrate(old, new)

    @classmethod
    def _migrate(cls, old, new):
        """Subclasses can override this method to perform a custom migration.
        This is run after the migrate() method.

        Example::

            class NewSchema(StoredObject):
                _id = fields.StringField(primary=True, index=True)
                my_string = fields.StringField()

                @classmethod
                def _migrate(cls, old, new):
                    new.my_string = old.my_string + 'yo'

                _meta = {
                    'version_of': OldSchema,
                    'version': 2,
                    'optimistic': True
                }

        :param old: Record from original schema
        :param new: Record from new schema
        """
        return new

    @classmethod
    def explain_migration(cls):

        logging.basicConfig(format='%(levelname)s %(filename)s: %(message)s', level=logging.DEBUG)

        classes = [cls]
        methods = [cls._migrate]

        klass = cls
        while klass._version and klass._version_of:
            classes.insert(0, klass._version_of)
            try:
                methods.insert(0, klass._migrate)
            except AttributeError:
                methods.insert(0, None)
            klass = klass._version_of

        for step in range(len(classes) - 1):

            fr = classes[step]
            to = classes[step + 1]

            logging.info('From schema {0}'.format(fr._name))
            logging.info('\n'.join('\t{0}'.format(field) for field in fr._fields))
            logging.info('')

            logging.info('To schema {0}'.format(to._name))
            logging.info('\n'.join('\t{0}'.format(field) for field in to._fields))
            logging.info('')

            to.migrate(fr, to, verbose=True, dry_run=True)

    @classmethod
    def _must_be_loaded(cls, value):
        if value is not None and not value._is_loaded:
            raise exceptions.DatabaseError('Record must be loaded.')

    @has_storage
    @log_storage
    def _optimistic_insert(self):
        self._primary_key = self._storage[0]._optimistic_insert(
            self._primary_name,
            self.to_storage()
        )

    def validate_record(self):
        """Apply record-level validation. Run on `save`."""
        for validator in self._record_validators:
            validator(self)

    @has_storage
    @log_storage
    def save(self, force=False):
        """Save a record.

        :param bool force: Save even if no fields have changed; used to
            update back-references
        :returns: List of changed fields
        """
        if self._detached:
            raise exceptions.DatabaseError('Cannot save detached object.')

        for field_name, field_object in self._fields.items():
            if hasattr(field_object, 'on_before_save'):
                field_object.on_before_save(self)

        signals.before_save.send(
            self.__class__,
            instance=self
        )

        cached_data = self._get_cached_data(self._stored_key)
        storage_data = self.to_storage()

        if self._primary_key is not None and cached_data is not None:
            fields_changed = set(
                self.get_changed_fields(cached_data, storage_data)
            )
        else:
            fields_changed = set(self._fields.keys())

        # Quit if no diffs
        if not fields_changed and not force:
            return []

        # Apply field-level validation
        for field_name in fields_changed:
            field_object = self._fields[field_name]
            field_object.do_validate(getattr(self, field_name), self)

        # Apply record-level validation
        self.validate_record()

        primary_changed = (
            self._primary_key != self._stored_key
            and self._primary_name in fields_changed
        )

        if self._is_loaded:
            if primary_changed and not getattr(self, '_updating_key', False):
                self.delegate(
                    self._storage[0].remove,
                    False,
                    RawQuery(self._primary_name, 'eq', self._stored_key)
                )
                self._clear_caches(self._stored_key)
                self.insert(self._primary_key, storage_data)
            else:
                self.update_one(self, storage_data=storage_data, saved=True, inmem=True)
        elif self._is_optimistic and self._primary_key is None:
            self._optimistic_insert()
        else:
            self.insert(self._primary_key, storage_data)

        # If the primary key has changed, follow back-references and update,
        # AND run after_save or after_save_on_difference
        if self._is_loaded and primary_changed:
            if not getattr(self, '_updating_key', False):
                self._updating_key = True
                update_backref_keys(self)
                self._stored_key = self._primary_key
                self._updating_key = False
        else:
            self._stored_key = self._primary_key

        self._is_loaded = True

        signals.save.send(
            self.__class__,
            instance=self,
            fields_changed=fields_changed,
            cached_data=cached_data or {},
        )

        storage_data[self._primary_name] = self._storage_key
        self._set_cache(self._primary_key, self, storage_data)

        return fields_changed
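
    # Illustration only: a typical create/update cycle with an optimistic
    # schema (primary key generated on first insert). Names are hypothetical.
    #
    #     person = Person(name='Fred')
    #     person.save()                  # first save inserts the record
    #     person.name = 'Freddie'
    #     person.save()                  # later saves validate and write only
    #                                    # the diffed fields, here {'name'}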

    def update_fields(self, **kwargs):
        """Update multiple fields, specified by keyword arguments.

        Example::

            person.update_fields(given_name='Fred', family_name='Mercury')

        ...is equivalent to::

            person.given_name = 'Fred'
            person.family_name = 'Mercury'

        :param **kwargs: Field names and the values to set
        """
        for key, value in kwargs.items():
            self._fields[key].__set__(self, value, safe=True)

    def reload(self):

        storage_data = self._storage[0].get(self._primary_name, self._storage_key)

        for key, value in storage_data.items():
            field_object = self._fields.get(key, None)
            if isinstance(field_object, Field):
                data_value = storage_data[key]
                if data_value is None:
                    value = None
                else:
                    value = field_object.from_storage(data_value)
                field_object.__set__(self, value, safe=True)
            elif key == '__backrefs':
                self._StoredObject__backrefs = value

        self._stored_key = self._primary_key
        self._set_cache(self._storage_key, self, storage_data)

    @warn_if_detached
    def __getattr__(self, item):

        errmsg = '{cls} object has no attribute {item}'.format(
            cls=self.__class__.__name__,
            item=item
        )

        if item in self.__backrefs:
            backrefs = []
            for parent, rest0 in six.iteritems(self.__backrefs[item]):
                for field, rest1 in six.iteritems(rest0):
                    backrefs.extend([
                        (key, parent)
                        for key in rest1
                    ])
            return AbstractForeignList(backrefs)

        # Retrieve back-references
        if '__' in item and not item.startswith('__'):
            item_split = item.split('__')
            if len(item_split) == 2:
                parent_schema_name, backref_key = item_split
                backrefs = deref(self.__backrefs, [backref_key, parent_schema_name], missing={})
                ids = sum(
                    backrefs.values(),
                    []
                )
            elif len(item_split) == 3:
                parent_schema_name, backref_key, parent_field_name = item_split
                ids = deref(self.__backrefs, [backref_key, parent_schema_name, parent_field_name], missing=[])
            else:
                raise AttributeError(errmsg)
            try:
                base_class = self.get_collection(parent_schema_name)
            except KeyError:
                raise exceptions.ModularOdmException(
                    'Unknown schema <{0}>'.format(
                        parent_schema_name
                    )
                )
            return ForeignList(ids, literal=True, base_class=base_class)

        raise AttributeError(errmsg)

    @warn_if_detached
    def __setattr__(self, key, value):
        if key not in self._fields and not key.startswith('_'):
            warnings.warn('Setting an attribute that is neither a field nor a protected value.')
        super(StoredObject, self).__setattr__(key, value)

    # Querying ######

    @classmethod
    def _parse_key_value(cls, value):
        if isinstance(value, StoredObject):
            return value._primary_key, value
        return value, cls.load(cls._pk_to_storage(value))

    @classmethod
    @has_storage
    def _pk_to_storage(cls, key):
        return cls._fields[cls._primary_name].to_storage(key)

    @classmethod
    def _process_query(cls, query):
        if isinstance(query, RawQuery):
            field = cls._fields.get(query.attribute)
            if field is None:
                return
            if field._is_foreign:
                if getattr(query.argument, '_fields', None):
                    if field._is_abstract:
                        query.argument = (
                            query.argument._primary_key,
                            query.argument._name,
                        )
                    else:
                        query.argument = query.argument._primary_key
        elif isinstance(query, QueryGroup):
            for node in query.nodes:
                cls._process_query(node)

    @classmethod
    @has_storage
    @log_storage
    def find(cls, query=None, **kwargs):
        """Query the backend for records matching `query`.

        :param query: Query object
        :param kwargs: Extra arguments passed to the backend's ``find``
        :return: An iterable of :class:`StoredObject` instances
        """
        cls._process_query(query)
        return cls._storage[0].QuerySet(
            cls,
            cls._storage[0].find(query, **kwargs)
        )

    @classmethod
    @has_storage
    @log_storage
    def find_one(cls, query=None, **kwargs):
        cls._process_query(query)
        stored_data = cls._storage[0].find_one(query, **kwargs)
        return cls.load(
            key=stored_data[cls._primary_name],
            data=stored_data
        )
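
    # Illustration only: querying with modular-odm's Q syntax (assuming
    # `from modularodm import Q`), against a hypothetical Person schema.
    #
    #     freds = Person.find(Q('name', 'eq', 'Fred'))       # iterable QuerySet
    #     fred = Person.find_one(Q('_id', 'eq', 'abc12'))    # exactly one match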

    # Queueing

    @classmethod
    def delegate(cls, method, conflict=None, *args, **kwargs):
        """Execute or queue a database action. Variable positional and
        keyword arguments are passed to the provided method.

        :param function method: Method to execute or queue
        :param bool conflict: Potential conflict between cache and backend,
            e.g., in the event of bulk updates or removes that bypass the
            cache
        """
        if cls.queue.active:
            action = WriteAction(method, *args, **kwargs)
            if conflict:
                logger.warn(
                    'Delayed write {0!r} may cause the cache to '
                    'diverge from the database until changes are '
                    'committed.'.format(action)
                )
            cls.queue.push(action)
        else:
            method(*args, **kwargs)

    @classmethod
    def start_queue(cls):
        """Start the queue. Between calling `start_queue` and `commit_queue`,
        all writes will be deferred to the queue.
        """
        cls.queue.start()

    @classmethod
    def clear_queue(cls):
        """Clear the queue."""
        cls.queue.clear()

    @classmethod
    def cancel_queue(cls):
        """Cancel any pending actions. This method clears the queue and
        also clears caches if any actions are pending.
        """
        if cls.queue:
            cls._cache.clear()
            cls._object_cache.clear()
        cls.clear_queue()

    @classmethod
    def commit_queue(cls):
        """Commit all queued actions. If any actions fail, clear caches.
        Note: the queue will be cleared whether an error is raised or not.
        """
        try:
            cls.queue.commit()
            cls.clear_queue()
        except:
            cls.cancel_queue()
            raise
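
    # Illustration only: deferring writes with the queue. Between
    # start_queue() and commit_queue(), actions sent through delegate() are
    # collected on cls.queue and executed together on commit.
    #
    #     Person.start_queue()
    #     person.name = 'Freddie'
    #     person.save()                  # queued, not yet written
    #     Person.commit_queue()          # flush queued writes to the backend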

    @classmethod
    def subscribe(cls, signal_name, weak=True):
        """Subscribe a callback to a model signal.

        :param str signal_name: Name of signal to subscribe to; must be
            found in ``signals.py``
        :param bool weak: Create weak reference to callback
        :returns: Decorator created by ``Signal::connect_via``
        :raises: ValueError if signal is not found

        Example usage::

            >>> @Schema.subscribe('before_save')
            ... def listener(cls, instance):
            ...     instance.value += 1

        """
        try:
            signal = getattr(signals, signal_name)
        except AttributeError:
            raise ValueError(
                'Signal {0} not found'.format(signal_name)
            )
        sender = None if cls._is_root else cls
        return signal.connect_via(sender, weak)

    @classmethod
    @has_storage
    def insert(cls, key, val):
        cls.delegate(
            cls._storage[0].insert,
            False,
            cls._primary_name,
            cls._pk_to_storage(key),
            val
        )

    @classmethod
    def _includes_foreign(cls, keys):
        for key in keys:
            if key in cls._fields and cls._fields[key]._is_foreign:
                return True
        return False

    @classmethod
    def _data_to_storage(cls, data):
        storage_data = {}
        for key, value in data.items():
            if key in cls._fields:
                field_object = cls._fields[key]
                if key == cls._primary_name:
                    continue
                storage_data[key] = field_object.to_storage(value)
            else:
                storage_data[key] = value
        return storage_data

    def _update_in_memory(self, storage_data):
        for field_name, data_value in storage_data.items():
            field_object = self._fields[field_name]
            field_object.__set__(self, data_value, safe=True)
        self.save()

    @classmethod
    def _which_to_obj(cls, which):
        if isinstance(which, QueryBase):
            return cls.find_one(which)
        if isinstance(which, StoredObject):
            return which
        return cls.load(cls._pk_to_storage(which))

    @classmethod
    @has_storage
    def update_one(cls, which, data=None, storage_data=None, saved=False, inmem=False):

        storage_data = storage_data or cls._data_to_storage(data)
        includes_foreign = cls._includes_foreign(storage_data.keys())

        obj = cls._which_to_obj(which)

        if saved or not includes_foreign:
            cls.delegate(
                cls._storage[0].update,
                False,
                RawQuery(
                    cls._primary_name,
                    'eq',
                    obj._primary_key
                ),
                storage_data,
            )
            if obj and not inmem:
                obj._dirty = True
                if not saved:
                    cls._clear_caches(obj._storage_key)
        else:
            obj._update_in_memory(storage_data)

    @classmethod
    @has_storage
    def update(cls, query, data=None, storage_data=None):

        storage_data = storage_data or cls._data_to_storage(data)
        includes_foreign = cls._includes_foreign(storage_data.keys())

        objs = cls.find(query)
        keys = objs.get_keys()

        if not includes_foreign:
            cls.delegate(
                cls._storage[0].update,
                True,
                query,
                storage_data
            )
            for key in keys:
                obj = cls._get_cache(key)
                if obj is not None:
                    obj._dirty = True
        else:
            for obj in objs:
                obj._update_in_memory(storage_data)

    @classmethod
    @has_storage
    def remove_one(cls, which, rm=True):
        """Remove an object, along with its references and back-references.
        Removes the object from the cache and sets its _detached flag to
        True.

        :param which: Object selector: Query, StoredObject, or primary key
        :param rm: Remove data from backend
        """
        # Look up object
        obj = cls._which_to_obj(which)

        # Remove references
        rm_fwd_refs(obj)
        rm_back_refs(obj)

        # Remove from cache
        cls._clear_caches(obj._storage_key)

        # Remove from backend
        if rm:
            cls.delegate(
                cls._storage[0].remove,
                False,
                RawQuery(obj._primary_name, 'eq', obj._storage_key)
            )

        # Set detached
        obj._detached = True

    @classmethod
    @has_storage
    def remove(cls, query=None):
        """Remove objects by query.

        :param query: Query object
        """
        objs = cls.find(query)
        for obj in objs:
            cls.remove_one(obj, rm=False)
        cls.delegate(
            cls._storage[0].remove,
            False,
            query
        )
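
    # Illustration only: removing records. remove_one() accepts a query, an
    # instance, or a primary key; remove() deletes everything matching a
    # query (Q usage assumed as above).
    #
    #     Person.remove_one(person)                 # by instance
    #     Person.remove(Q('name', 'eq', 'Fred'))    # by query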


def rm_fwd_refs(obj):
    """When removing an object, other objects with references to the current
    object should remove those references. This function identifies objects
    with forward references to the current object, then removes those
    references.

    :param obj: Object for which forward references should be removed
    """
    for stack, key in obj._backrefs_flat:

        # Unpack stack
        backref_key, parent_schema_name, parent_field_name = stack

        # Get parent info
        parent_schema = obj._collections[parent_schema_name]
        parent_key_store = parent_schema._pk_to_storage(key)
        parent_object = parent_schema.load(parent_key_store)
        if parent_object is None:
            continue

        # Remove forward references
        if parent_object._fields[parent_field_name]._list:
            getattr(parent_object, parent_field_name).remove(obj)
        else:
            parent_field_object = parent_object._fields[parent_field_name]
            setattr(parent_object, parent_field_name, parent_field_object._gen_default())

        # Save
        parent_object.save()


def _collect_refs(obj, fields=None):
    """Collect the foreign references of an object that specify
    back-reference keys, optionally restricted to the given fields.
    """
    refs = []
    fields = fields or []

    for field_name, field_object in obj._fields.items():

        # Skip if not foreign field
        if not field_object._is_foreign:
            continue

        # Skip if value is None
        value = getattr(obj, field_name)
        if value is None:
            continue

        # Skip if not in fields
        if fields and field_name not in fields:
            continue

        field_refs = []

        # Build list of linked objects if ListField, else single field
        if isinstance(field_object, ListField):
            field_refs.extend([v for v in value if v])
            field_instance = field_object._field_instance
        else:
            field_refs.append(value)
            field_instance = field_object

        # Skip if field does not specify back-references
        if not field_instance._backref_field_name:
            continue

        refs.extend([
            {
                'value': ref,
                'field_name': field_name,
                'field_instance': field_instance,
            }
            for ref in field_refs
        ])

    return refs


def rm_back_refs(obj):
    """When removing an object with foreign fields, back-references from
    other objects to the current object should be deleted. This function
    identifies foreign fields of the specified object whose values are not
    None and which specify back-reference keys, then removes back-references
    from linked objects to the specified object.

    :param obj: Object for which back-references should be removed
    """
    for ref in _collect_refs(obj):
        ref['value']._remove_backref(
            ref['field_instance']._backref_field_name,
            obj,
            ref['field_name'],
            strict=False
        )


def ensure_backrefs(obj, fields=None):
    """Ensure that all forward references on the provided object have the
    appropriate backreferences.

    :param StoredObject obj: Database record
    :param list fields: Optional list of field names to check
    """
    for ref in _collect_refs(obj, fields):
        updated = ref['value']._update_backref(
            ref['field_instance']._backref_field_name,
            obj,
            ref['field_name'],
        )
        if updated:
            logging.debug('Updated reference {}:{}:{}:{}:{}'.format(
                obj._name, obj._primary_key, ref['field_name'],
                ref['value']._name, ref['value']._primary_key,
            ))


def update_backref_keys(obj):
    """Update back-references and forward references pointing at an object
    after its primary key has changed.
    """
    for ref in _collect_refs(obj):
        ref['value']._update_backref(
            ref['field_instance']._backref_field_name,
            obj,
            ref['field_name'],
        )

    for stack, key in obj._backrefs_flat:

        # Unpack stack
        backref_key, parent_schema_name, parent_field_name = stack

        # Get parent info
        parent_schema = obj._collections[parent_schema_name]
        parent_key_store = parent_schema._pk_to_storage(key)
        parent_object = parent_schema.load(parent_key_store)
        if parent_object is None:
            continue

        # Update the forward reference on the parent
        field_object = parent_object._fields[parent_field_name]
        if field_object._list:
            value = getattr(parent_object, parent_field_name)
            if field_object._is_abstract:
                idx = value.index((obj._stored_key, obj._name))
                value[idx] = (obj._primary_key, obj._name)
            else:
                idx = value.index(obj._stored_key)
                value[idx] = obj
        else:
            setattr(parent_object, parent_field_name, obj)

        # Save
        parent_object.save()