# -*- coding: utf-8 -*-
import re
import pymongo
from .base import Storage
from ..query.queryset import BaseQuerySet
from ..query.query import QueryGroup
from ..query.query import RawQuery
from modularodm.exceptions import (
KeyExistsException,
MultipleResultsFound,
NoResultsFound,
)
# Operator names borrowed from mongoengine.queryset.transform.

# Operators that map directly onto a MongoDB "$<name>" query operator.
COMPARISON_OPERATORS = (
    'ne',
    'gt',
    'gte',
    'lt',
    'lte',
    'in',
    'nin',
    'mod',
    'all',
    'size',
    'exists',
    'not',
    'elemMatch',
)

# Operators that are expressed as a regular-expression match.
STRING_OPERATORS = (
    'contains',
    'icontains',
    'startswith',
    'istartswith',
    'endswith',
    'iendswith',
    'exact',
    'iexact',
)

# Operator families present in mongoengine but not enabled in this module;
# kept commented out for reference.
# GEO_OPERATORS = ('within_distance', 'within_spherical_distance',
#                  'within_box', 'within_polygon', 'near', 'near_sphere',
#                  'max_distance', 'geo_within', 'geo_within_box',
#                  'geo_within_polygon', 'geo_within_center',
#                  'geo_within_sphere', 'geo_intersects')
# CUSTOM_OPERATORS = ('match',)
# MATCH_OPERATORS = (COMPARISON_OPERATORS + GEO_OPERATORS +
#                    STRING_OPERATORS + CUSTOM_OPERATORS)
# UPDATE_OPERATORS = ('set', 'unset', 'inc', 'dec', 'pop', 'push',
#                     'push_all', 'pull', 'pull_all', 'add_to_set',
#                     'set_on_insert')
# Adapted from mongoengine.fields
def prepare_query_value(op, value):
    """Convert the argument of a string operator into a compiled regex.

    :param str op: Operator name. After leading ``i`` characters are
        stripped it is compared against ``startswith``, ``endswith``,
        ``contains`` and ``exact``; an operator starting with ``i`` makes
        the match case-insensitive.
    :param value: Operator argument. Must be a string when ``op`` is one of
        the string operators above.
    :return: A compiled pattern for the string operators; ``value``
        unchanged for any other operator.
    """
    base_op = op.lstrip('i')
    if base_op not in ('startswith', 'endswith', 'contains', 'exact'):
        return value

    flags = re.IGNORECASE if op.startswith('i') else 0

    # "contains" has no anchors, so it falls back to the bare template.
    anchored_templates = {
        'startswith': r'^%s',
        'endswith': r'%s$',
        'exact': r'^%s$',
    }
    template = anchored_templates.get(base_op, r'%s')

    # Escape the value so regex metacharacters in it are matched literally
    # and cannot trigger a re.error.
    escaped = re.escape(value)
    return re.compile(template % escaped, flags)
# TODO: Test me
def translate_query(query=None, mongo_query=None):
    """Translate a query object into a MongoDB query document.

    :param query: A ``RawQuery``, a ``QueryGroup``, or ``None``.
    :param dict mongo_query: Optional query document to extend. Only used
        when ``query`` is a ``RawQuery``; a falsy value is replaced by a
        new empty dict.
    :return dict: The MongoDB query document. ``None`` translates to an
        empty document.
    :raises ValueError: If a ``RawQuery`` uses an operator that is not
        ``eq`` and not listed in ``COMPARISON_OPERATORS`` or
        ``STRING_OPERATORS``, or if a ``QueryGroup`` operator is not
        ``and``, ``or`` or ``not``.
    :raises TypeError: If ``query`` is not a ``RawQuery``, a ``QueryGroup``
        or ``None``.
    """
    mongo_query = mongo_query or {}

    if isinstance(query, RawQuery):
        attribute, operator, argument = \
            query.attribute, query.operator, query.argument

        if operator == 'eq':
            mongo_query[attribute] = argument
        elif operator in COMPARISON_OPERATORS:
            mongo_query.setdefault(attribute, {})['$' + operator] = argument
        elif operator in STRING_OPERATORS:
            mongo_regex = prepare_query_value(operator, argument)
            mongo_query.setdefault(attribute, {})['$regex'] = mongo_regex
        else:
            # An unrecognised operator used to be dropped silently, leaving
            # an empty query document that matches every document. Fail
            # loudly instead, as the QueryGroup branch already does.
            raise ValueError(
                'Unsupported query operator: {0!r}'.format(operator)
            )
        return mongo_query

    if isinstance(query, QueryGroup):
        if query.operator == 'and':
            return {'$and': [translate_query(node) for node in query.nodes]}
        if query.operator == 'or':
            return {'$or': [translate_query(node) for node in query.nodes]}
        if query.operator == 'not':
            # Hack: A nor A == not A
            subquery = translate_query(query.nodes[0])
            return {'$nor': [subquery, subquery]}
        raise ValueError('QueryGroup operator must be <and>, <or>, or <not>.')

    if query is None:
        return {}

    raise TypeError('Query must be a QueryGroup or Query object.')
class MongoQuerySet(BaseQuerySet):
    """Query set backed by a cursor stored in ``self.data``.

    :param schema: Schema whose ``load`` method builds objects from raw
        records.
    :param cursor: Cursor over the matching records.
    """

    # NOTE(review): presumably read by BaseQuerySet to allow negative
    # indices -- confirm against the base class.
    _NEGATIVE_INDEXING = True

    def __init__(self, schema, cursor):
        super(MongoQuerySet, self).__init__(schema)
        self.data = cursor
        # Default sorting; remembered so negative indexing can reverse it.
        self._order = [('_id', 1)]

    def _do_getitem(self, index, raw=False):
        """Return the record at ``index``, or a new query set for a slice.

        :param index: An integer position (negative counts from the end)
            or a ``slice``.
        :param bool raw: If true, return only the record's primary-key
            value instead of a loaded object.
        """
        if isinstance(index, slice):
            return MongoQuerySet(self.schema, self.data.clone()[index])

        if index >= 0:
            record = self.data[index]
        else:
            # Flip every sort direction on a cloned cursor, then read from
            # the front: index -1 becomes position 0, -2 becomes 1, etc.
            flipped = [
                (field, -direction) for field, direction in self._order
            ]
            record = self.data.clone().sort(flipped)[-index - 1]

        if raw:
            return record[self.primary]
        return self.schema.load(data=record)

    def __iter__(self, raw=False):
        """Iterate over a clone of the cursor.

        :param bool raw: If true, return a list of primary-key values;
            otherwise return a generator of loaded objects.
        """
        records = self.data.clone()
        if not raw:
            return (self.schema.load(data=record) for record in records)
        return [record[self.primary] for record in records]

    def __len__(self):
        """Return the cursor's count, honouring any limit and skip."""
        return self.data.count(with_limit_and_skip=True)

    count = __len__

    def get_key(self, index):
        """Return the primary-key value of the record at ``index``."""
        return self.__getitem__(index, raw=True)

    def get_keys(self):
        """Return the primary-key values of all records as a list."""
        return list(self.__iter__(raw=True))

    def sort(self, *keys):
        """Sort by the given field names and return ``self``.

        A name with a leading ``-`` sorts descending; any other name sorts
        ascending.
        """
        ordering = [
            (key.lstrip('-'), pymongo.DESCENDING)
            if key.startswith('-')
            else (key, pymongo.ASCENDING)
            for key in keys
        ]
        self._order = ordering
        self.data = self.data.sort(ordering)
        return self

    def offset(self, n):
        """Skip the first ``n`` records and return ``self``."""
        self.data = self.data.skip(n)
        return self

    def limit(self, n):
        """Restrict the cursor to at most ``n`` records and return ``self``."""
        self.data = self.data.limit(n)
        return self
class MongoStorage(Storage):
    """Wrap a MongoDB collection. Note: `store` is a property instead of an
    attribute to handle passing `db` as a proxy.

    :param Database db:
    :param str collection:
    """

    QuerySet = MongoQuerySet

    def __init__(self, db, collection):
        self.db = db
        self.collection = collection

    @property
    def store(self):
        """Look up the wrapped collection on ``self.db`` at access time."""
        return self.db[self.collection]

    def _ensure_index(self, key):
        """Ensure an index on ``key`` exists in the collection."""
        self.store.ensure_index(key)

    def find(self, query=None, **kwargs):
        """Return a cursor over the records matching ``query``.

        :param query: Query object accepted by ``translate_query``, or
            ``None`` for an empty query document.
        :param kwargs: Accepted but not used here.
        """
        mongo_query = translate_query(query)
        return self.store.find(mongo_query)

    def find_one(self, query=None, **kwargs):
        """Return the single record matching ``query``.

        :param query: Query object accepted by ``translate_query``, or
            ``None``.
        :param kwargs: Accepted but not used here.
        :raises NoResultsFound: If no record matches.
        :raises MultipleResultsFound: If more than one record matches.
        """
        mongo_query = translate_query(query)
        matches = self.store.find(mongo_query).limit(2)

        # Count once: each count() call is a separate request, and the
        # branches below (and the error message) must agree on one value.
        count = matches.count()
        if count == 1:
            return matches[0]
        if count == 0:
            raise NoResultsFound()
        raise MultipleResultsFound(
            'Query for find_one must return exactly one result; '
            'returned {0}'.format(count)
        )

    def get(self, primary_name, key):
        """Return the record whose ``primary_name`` field equals ``key``."""
        return self.store.find_one({primary_name: key})

    def insert(self, primary_name, key, value):
        """Insert ``value`` into the collection.

        :param str primary_name: Name of the primary-key field.
        :param key: Primary-key value, added when ``value`` lacks the field.
        :param dict value: Record to insert.
        :raises KeyExistsException: If the insert raises a duplicate-key
            error.
        """
        if primary_name not in value:
            # Copy first so the caller's dict does not gain the key field.
            value = value.copy()
            value[primary_name] = key
        try:
            self.store.insert(value)
        except pymongo.errors.DuplicateKeyError:
            raise KeyExistsException

    def update(self, query, data):
        """Apply ``data`` as a ``$set`` to every record matching ``query``.

        :param query: Query object accepted by ``translate_query``.
        :param dict data: Field names mapped to their new values.
        """
        mongo_query = translate_query(query)

        # Field "_id" shouldn't appear in both search and update queries; else
        # MongoDB will raise a "Mod on _id not allowed" error
        if '_id' in mongo_query:
            update_data = {k: v for k, v in data.items() if k != '_id'}
        else:
            update_data = data

        update_query = {'$set': update_data}

        self.store.update(
            mongo_query,
            update_query,
            upsert=False,
            multi=True,
        )

    def remove(self, query=None):
        """Remove the records matching ``query``.

        ``None`` translates to an empty query document.
        """
        mongo_query = translate_query(query)
        self.store.remove(mongo_query)

    def flush(self):
        """Do nothing."""
        pass