Source code for django_celery_results.backends.database

from __future__ import absolute_import, unicode_literals

import json

from celery import maybe_signature
from celery.backends.base import BaseDictBackend
from celery.exceptions import ChordError
from celery.result import allow_join_result
from celery.utils.serialization import b64encode, b64decode
from celery.utils.log import get_logger
from django.db import transaction

from ..models import TaskResult, ChordCounter


logger = get_logger(__name__)


[docs]class DatabaseBackend(BaseDictBackend): """The Django database backend, using models to store task state.""" TaskModel = TaskResult subpolling_interval = 0.5 def _store_result(self, task_id, result, status, traceback=None, request=None, using=None): """Store return value and status of an executed task.""" content_type, content_encoding, result = self.encode_content(result) _, _, meta = self.encode_content({ 'children': self.current_task_children(request), }) task_name = getattr(request, 'task', None) _, _, task_args = self.encode_content( getattr(request, 'argsrepr', getattr(request, 'args', None)) ) _, _, task_kwargs = self.encode_content( getattr(request, 'kwargsrepr', getattr(request, 'kwargs', None)) ) worker = getattr(request, 'hostname', None) self.TaskModel._default_manager.store_result( content_type, content_encoding, task_id, result, status, traceback=traceback, meta=meta, task_name=task_name, task_args=task_args, task_kwargs=task_kwargs, worker=worker, using=using, ) return result def _get_task_meta_for(self, task_id): """Get task metadata for a task by id.""" obj = self.TaskModel._default_manager.get_task(task_id) res = obj.as_dict() meta = self.decode_content(obj, res.pop('meta', None)) or {} result = self.decode_content(obj, res.get('result')) task_args = self.decode_content(obj, res.get('task_args')) task_kwargs = self.decode_content(obj, res.get('task_kwargs')) res.update( meta, result=result, task_args=task_args, task_kwargs=task_kwargs, ) return self.meta_from_decoded(res)
[docs] def encode_content(self, data): content_type, content_encoding, content = self._encode(data) if content_encoding == 'binary': content = b64encode(content) return content_type, content_encoding, content
[docs] def decode_content(self, obj, content): if content: if obj.content_encoding == 'binary': content = b64decode(content) return self.decode(content)
def _forget(self, task_id): try: self.TaskModel._default_manager.get(task_id=task_id).delete() except self.TaskModel.DoesNotExist: pass
[docs] def cleanup(self): """Delete expired metadata.""" self.TaskModel._default_manager.delete_expired(self.expires)
[docs] def apply_chord(self, header_result, body, **kwargs): """Add a ChordCounter with the expected number of results""" results = [r.as_tuple() for r in header_result] data = json.dumps(results) ChordCounter.objects.create( group_id=header_result.id, sub_tasks=data, count=len(results) )
[docs] def on_chord_part_return(self, request, state, result, **kwargs): """Called on finishing each part of a Chord header""" tid, gid = request.id, request.group if not gid or not tid: return call_callback = False with transaction.atomic(): # We need to know if `count` hits 0. # wrap the update in a transaction # with a `select_for_update` lock to prevent race conditions. # SELECT FOR UPDATE is not supported on all databases chord_counter = ( ChordCounter.objects.select_for_update() .get(group_id=gid) ) chord_counter.count -= 1 if chord_counter.count != 0: chord_counter.save() else: # Last task in the chord header has finished call_callback = True chord_counter.delete() if call_callback: deps = chord_counter.group_result(app=self.app) if deps.ready(): callback = maybe_signature(request.chord, app=self.app) trigger_callback( app=self.app, callback=callback, group_result=deps )
[docs]def trigger_callback(app, callback, group_result): """Add the callback to the queue or mark the callback as failed Implementation borrowed from `celery.app.builtins.unlock_chord` """ j = ( group_result.join_native if group_result.supports_native_join else group_result.join ) try: with allow_join_result(): ret = j(timeout=app.conf.result_chord_join_timeout, propagate=True) except Exception as exc: # pylint: disable=broad-except try: culprit = next(group_result._failed_join_report()) reason = "Dependency {0.id} raised {1!r}".format(culprit, exc) except StopIteration: reason = repr(exc) logger.exception("Chord %r raised: %r", group_result.id, exc) app.backend.chord_error_from_stack(callback, ChordError(reason)) else: try: callback.delay(ret) except Exception as exc: # pylint: disable=broad-except logger.exception("Chord %r raised: %r", group_result.id, exc) app.backend.chord_error_from_stack( callback, exc=ChordError("Callback error: {0!r}".format(exc)) )