Source code for superjson._superjson

# -*- coding: utf-8 -*-

import os
import json
import time
import inspect

from collections import OrderedDict, deque
from datetime import datetime
from base64 import b64encode, b64decode

from .helper import get_class_name
from .comments import strip_comments
from .warning import logger, WARN_MSG, prt_console
from .pkg import compresslib
from .pkg.atomicwrites import atomic_write


def get_class_name_from_dumper_loader_method(func):
    """
    Get default value of ``class_name`` argument.

    ``class_name`` is the last positional argument of a dumper / loader
    method, so its default is the *last* entry of the defaults tuple. Taking
    the last entry (rather than the first) keeps the result correct when an
    earlier argument carries a default value as well.

    :param func: a dumper or loader function.
    :return: the default value bound to the trailing argument of ``func``.
    :raises TypeError: if ``func`` declares no default values at all, or if
        :func:`inspect.getfullargspec` cannot introspect it.
    """
    defaults = inspect.getfullargspec(func).defaults
    if not defaults:
        # Same exception type the implicit ``None[0]`` used to produce, so
        # callers that catch TypeError keep working.
        raise TypeError(
            "%r has no default value for its 'class_name' argument" % (func,)
        )
    return defaults[-1]
def is_dumper_method(func):
    """
    Tell whether ``func`` has the dumper signature.

    A dumper is recognised purely by its positional argument names, which
    must be exactly ``self``, ``obj``, ``class_name`` in that order.

    :param func: the callable to inspect.
    :return: ``True`` if the argument names match, ``False`` otherwise.
    """
    expected_args = ["self", "obj", "class_name"]
    actual_args = inspect.getfullargspec(func).args
    return actual_args == expected_args
def is_loader_method(func):
    """
    Tell whether ``func`` has the loader signature.

    A loader is recognised purely by its positional argument names, which
    must be exactly ``self``, ``dct``, ``class_name`` in that order.

    :param func: the callable to inspect.
    :return: ``True`` if the argument names match, ``False`` otherwise.
    """
    expected_args = ["self", "dct", "class_name"]
    actual_args = inspect.getfullargspec(func).args
    return actual_args == expected_args
def _register_converter(attr, value, is_valid, registry,
                        method_type, obj_or_dct, dump_or_load):
    """
    Register one ``dump_*`` / ``load_*`` attribute, or log a warning.

    :param attr: attribute name as found in the class ``__dict__``.
    :param value: the attribute value (expected to be a function).
    :param is_valid: signature test, ``is_dumper_method`` or
        ``is_loader_method``.
    :param registry: dict mapping class name -> converter, updated in place.
    :param method_type: ``WARN_MSG`` field, ``"dumper"`` or ``"loader"``.
    :param obj_or_dct: ``WARN_MSG`` field, ``"obj"`` or ``"dct"``.
    :param dump_or_load: ``WARN_MSG`` field, ``"dump"`` or ``"load"``.
    """
    try:
        if is_valid(value):
            class_name = get_class_name_from_dumper_loader_method(value)
            registry[class_name] = value
            return
    except TypeError:
        # Not introspectable, or no usable ``class_name`` default: fall
        # through to the warning below.
        pass
    # The message is only built when it is actually going to be logged.
    logger.warning(WARN_MSG.format(
        attr=attr,
        method_type=method_type,
        obj_or_dct=obj_or_dct,
        dump_or_load=dump_or_load,
    ))


class Meta(type):
    """
    Metaclass that links dumper / loader methods with their class name.

    On class creation every attribute named ``dump_*`` or ``load_*`` found
    along the MRO is inspected. Attributes with the expected signature are
    stored in ``klass._dumpers`` / ``klass._loaders``, keyed by the default
    value of their ``class_name`` argument; the others trigger a warning.
    """

    def __new__(cls, name, bases, attrs):
        klass = super(Meta, cls).__new__(cls, name, bases, attrs)

        _dumpers = dict()
        _loaders = dict()

        # Walk from the most generic base to the class itself, so that a
        # converter defined on a subclass replaces the one inherited for the
        # same class name (later assignment wins).
        for base in reversed(inspect.getmro(klass)):
            for attr, value in base.__dict__.items():
                # find dumper method
                if attr.startswith("dump_"):
                    _register_converter(
                        attr, value, is_dumper_method, _dumpers,
                        method_type="dumper",
                        obj_or_dct="obj",
                        dump_or_load="dump",
                    )
                # find loader method
                if attr.startswith("load_"):
                    _register_converter(
                        attr, value, is_loader_method, _loaders,
                        method_type="loader",
                        obj_or_dct="dct",
                        dump_or_load="load",
                    )

        klass._dumpers = _dumpers
        klass._loaders = _loaders
        return klass


# Class names (as computed by ``get_class_name``) of two builtin types; used
# as the ``class_name`` defaults of the builtin dumper / loader methods.
bytes_class_name = get_class_name(bytes())
set_class_name = get_class_name(set())
def is_compressed_json_file(abspath):
    """Tell from the file extension whether a json file is compressed.

    The extension is matched case-insensitively:

    - ``*.json``: uncompressed, utf-8 encode json file
    - ``*.js``: uncompressed, utf-8 encode json file
    - ``*.gz``: compressed, utf-8 encode json file

    :param abspath: file path, ``str`` or any ``os.PathLike`` object.
    :return: ``True`` for ``.gz``, ``False`` for ``.json`` / ``.js``.
    :raises ValueError: if the extension is none of the above. The message
        quotes the path exactly as it was given.
    """
    path = os.fspath(abspath)
    ext = os.path.splitext(path)[1].lower()
    if ext in (".json", ".js"):
        return False
    if ext == ".gz":
        return True
    raise ValueError(
        "'%s' is not a valid json file. "
        "extension has to be '.json' or '.js' for uncompressed, '.gz' "
        "for compressed." % (path,))
class BaseSuperJson(metaclass=Meta):
    """
    An extensible json encoder/decoder. You can easily add a custom
    converter for any type.

    A converter is a pair of methods ``dump_xxx(self, obj, class_name=...)``
    and ``load_xxx(self, dct, class_name=...)``. The metaclass collects them
    into ``_dumpers`` / ``_loaders``, keyed by the ``class_name`` default.
    A dumped object is represented as ``{"$<class_name>": <data>}``.
    """

    # Both are replaced by the metaclass with the collected registries.
    _dumpers = dict()
    _loaders = dict()

    def _dump(self, obj):
        """Dump single object to json serializable value.

        :raises TypeError: if no dumper is registered for the class name of
            ``obj``.
        """
        class_name = get_class_name(obj)
        if class_name in self._dumpers:
            return self._dumpers[class_name](self, obj)
        # Wrap in a 1-tuple so that tuple arguments are formatted as a value
        # instead of being unpacked by the ``%`` operator.
        raise TypeError("%r is not JSON serializable" % (obj,))

    def _json_convert(self, obj):
        """Recursive helper method that converts dict types to standard
        library json serializable types, so they can be converted into json.

        Objects without a registered dumper are returned unchanged.
        """
        # OrderedDict: use the registered dumper, else treat it as a dict
        if isinstance(obj, OrderedDict):
            try:
                return self._dump(obj)
            except TypeError:
                return {k: self._json_convert(v) for k, v in obj.items()}

        # nested dict
        elif isinstance(obj, dict):
            return {k: self._json_convert(v) for k, v in obj.items()}

        # list or tuple
        elif isinstance(obj, (list, tuple)):
            return [self._json_convert(v) for v in obj]

        # float: round-trip through the formatter installed by ``dumps``.
        # Fall back to ``repr`` if the attribute is absent (e.g. this method
        # is called before any ``dumps`` call has set it).
        elif isinstance(obj, float):
            float_repr = getattr(json.encoder, "FLOAT_REPR", repr)
            return float(float_repr(obj))

        # single object
        try:
            return self._dump(obj)
        except TypeError:
            return obj

    def _object_hook1(self, dct):
        """A function can convert dict data into object.

        it's an O(1) implementation: only single-key dicts of the form
        ``{"$class_name": obj_data}`` are candidates. Anything else is
        returned unchanged.
        """
        if len(dct) == 1:
            (key,) = dct
            # Require the "$" marker: loaders index ``dct["$" + class_name]``
            # and would raise KeyError on any other first character.
            if isinstance(key, str) and key.startswith("$"):
                loader = self._loaders.get(key[1:])
                if loader is not None:
                    return loader(self, dct)
        return dct

    def _object_hook2(self, dct):  # pragma: no cover
        """Another object hook implementation.

        it's an O(N) implementation: every registered loader is tried.
        """
        for class_name, loader in self._loaders.items():
            if ("$" + class_name) in dct:
                return loader(self, dct)
        return dct

    def dumps(
        self,
        obj,
        indent: int = None,
        sort_keys: bool = None,
        pretty: bool = False,
        float_precision: int = None,
        ensure_ascii: bool = True,
        compress: bool = False,
        **kwargs
    ):
        """Dump any object into json string.

        :param indent: forwarded to :func:`json.dumps`.
        :param sort_keys: forwarded to :func:`json.dumps`.
        :param pretty: if ``True``, dump json into pretty indent and sorted
            key format (overrides ``indent`` and ``sort_keys``).
        :type pretty: bool
        :param float_precision: default ``None``, limit floats to N-decimal
            points.
        :type float_precision: int
        :param ensure_ascii: forwarded to :func:`json.dumps`.
        :param compress: default ``False``. If True, then compress encoded
            string.
        :type compress: bool
        :param kwargs: extra keyword arguments for :func:`json.dumps`.
        :return: the json string (compressed via ``compresslib`` if asked).
        """
        if pretty:
            indent = 4
            sort_keys = True

        # NOTE: this is a module-level attribute, i.e. shared process-wide.
        # ``_json_convert`` reads it back to format floats.
        if float_precision is None:
            json.encoder.FLOAT_REPR = repr
        else:
            json.encoder.FLOAT_REPR = lambda x: format(
                x, ".%sf" % float_precision)

        s = json.dumps(
            self._json_convert(obj),
            indent=indent,
            sort_keys=sort_keys,
            ensure_ascii=ensure_ascii,
            **kwargs
        )

        if compress:
            s = compresslib.compress(s, return_type="str")

        return s

    def loads(
        self,
        s: str,
        object_hook=None,
        decompress: bool = False,
        ignore_comments: bool = False,
        **kwargs,
    ):
        """load object from json encoded string.

        :param object_hook: callable forwarded to :func:`json.loads`;
            defaults to ``self._object_hook1``.
        :param decompress: default ``False``. If True, then decompress
            string.
        :type decompress: bool
        :param ignore_comments: default ``False``. If True, then ignore
            comments.
        :type ignore_comments: bool
        :param kwargs: extra keyword arguments for :func:`json.loads`. Any
            ``object_pairs_hook`` is discarded.
        """
        if decompress:
            s = compresslib.decompress(s, return_type="str")

        if ignore_comments:
            s = strip_comments(s)

        if object_hook is None:
            object_hook = self._object_hook1

        # ``object_pairs_hook`` is always forced to None below; drop any
        # caller-supplied value to avoid a duplicate keyword argument.
        kwargs.pop("object_pairs_hook", None)

        obj = json.loads(
            s,
            object_hook=object_hook,
            object_pairs_hook=None,
            **kwargs
        )

        return obj

    def dump(
        self,
        obj,
        abspath: str,
        indent: int = None,
        sort_keys: bool = None,
        pretty: bool = False,
        float_precision: int = None,
        ensure_ascii: bool = True,
        overwrite: bool = False,
        verbose: bool = True,
        **kwargs
    ):
        """Dump any object into file.

        :param abspath: if ``*.json``, ``*.js`` then do regular dump. if
            ``*.gz``, then perform compression.
        :type abspath: str
        :param pretty: if True, dump json into pretty indent and sorted key
            format.
        :type pretty: bool
        :param float_precision: default ``None``, limit floats to N-decimal
            points.
        :type float_precision: int
        :param overwrite: default ``False``, If ``True``, when you dump to
            existing file, it silently overwrite it. If ``False``, an alert
            message is shown. Default setting ``False`` is to prevent
            overwrite file by mistake.
        :type overwrite: boolean
        :param verbose: default True, help-message-display trigger.
        :type verbose: boolean
        :return: the uncompressed json string that was written, or ``None``
            if the file exists and ``overwrite`` is ``False``.
        :raises ValueError: if the extension of ``abspath`` is not supported.
        """
        prt_console("\nDump to '%s' ..." % abspath, verbose)

        is_compressed = is_compressed_json_file(abspath)

        if not overwrite:
            if os.path.exists(abspath):  # pragma: no cover
                prt_console(
                    " Stop! File exists and overwrite is not allowed",
                    verbose,
                )
                return

        # Wall-clock timer: CPU time would leave out time spent waiting on
        # file I/O.
        st = time.perf_counter()

        s = self.dumps(
            obj,
            indent=indent,
            sort_keys=sort_keys,
            pretty=pretty,
            float_precision=float_precision,
            ensure_ascii=ensure_ascii,
            compress=False,  # use uncompressed string, and directly write to file
            **kwargs,
        )

        with atomic_write(abspath, mode="wb", overwrite=True) as f:
            if is_compressed:
                f.write(compresslib.compress(s, return_type="bytes"))
            else:
                f.write(s.encode("utf-8"))

        prt_console(
            " Complete! Elapse %.6f sec." % (time.perf_counter() - st),
            verbose,
        )
        return s

    def load(
        self,
        abspath: str,
        object_hook=None,
        ignore_comments: bool = False,
        verbose: bool = True,
        **kwargs
    ):
        """load object from json file.

        :param abspath: if ``*.json``, ``*.js`` then do regular load. if
            ``*.gz``, then perform decompression.
        :type abspath: str
        :param object_hook: forwarded to :meth:`loads`.
        :param ignore_comments: default ``False``. If True, then ignore
            comments.
        :type ignore_comments: bool
        :param verbose: default True, help-message-display trigger.
        :type verbose: boolean
        :raises ValueError: if the extension of ``abspath`` is not supported.
        :raises EnvironmentError: if ``abspath`` doesn't exist.
        """
        prt_console("\nLoad from '%s' ..." % abspath, verbose)

        is_compressed = is_compressed_json_file(abspath)

        if not os.path.exists(abspath):
            raise EnvironmentError("'%s' doesn't exist." % abspath)

        # Wall-clock timer: CPU time would leave out time spent waiting on
        # file I/O.
        st = time.perf_counter()

        with open(abspath, "rb") as f:
            if is_compressed:
                s = compresslib.decompress(f.read(), return_type="str")
            else:
                s = f.read().decode("utf-8")

        obj = self.loads(
            s,
            object_hook=object_hook,
            decompress=False,
            ignore_comments=ignore_comments,
            **kwargs,
        )

        prt_console(
            " Complete! Elapse %.6f sec." % (time.perf_counter() - st),
            verbose,
        )

        return obj

    # ----------------------------------------------------------------------
    # Support built-in data types
    # ----------------------------------------------------------------------
    def dump_bytes(self, obj, class_name=bytes_class_name):
        """
        ``bytes`` dumper. The payload is stored as base64 text.
        """
        return {"$" + class_name: b64encode(obj).decode()}

    def load_bytes(self, dct, class_name=bytes_class_name):
        """
        ``bytes`` loader. Decodes the base64 text back to ``bytes``.
        """
        return b64decode(dct["$" + class_name].encode())

    def dump_datetime(self, obj, class_name="datetime.datetime"):
        """
        ``datetime.datetime`` dumper. Stored as ``obj.isoformat()``.
        """
        return {"$" + class_name: obj.isoformat()}

    def load_datetime(self, dct, class_name="datetime.datetime"):
        """
        ``datetime.datetime`` loader. Parsed with ``dateutil.parser.parse``.

        :raises ImportError: if ``python-dateutil`` is not installed.
        """
        try:
            from dateutil.parser import parse
        except ImportError:  # pragma: no cover
            msg = ("You need to install `python-dateutil` to support load/dump for datetime type")
            logger.info(msg)
            raise
        return parse(dct["$" + class_name])

    def dump_date(self, obj, class_name="datetime.date"):
        """
        ``datetime.date`` dumper. Stored as ``str(obj)``.
        """
        return {"$" + class_name: str(obj)}

    def load_date(self, dct, class_name="datetime.date"):
        """
        ``datetime.date`` loader. Expects the ``%Y-%m-%d`` format.
        """
        return datetime.strptime(dct["$" + class_name], "%Y-%m-%d").date()

    def dump_set(self, obj, class_name=set_class_name):
        """
        ``set`` dumper. Stored as a list of converted items.
        """
        return {"$" + class_name: [self._json_convert(item) for item in obj]}

    def load_set(self, dct, class_name=set_class_name):
        """
        ``set`` loader.
        """
        return set(dct["$" + class_name])

    def dump_deque(self, obj, class_name="collections.deque"):
        """
        ``collections.deque`` dumper. Stored as a list of converted items.
        """
        return {"$" + class_name: [self._json_convert(item) for item in obj]}

    def load_deque(self, dct, class_name="collections.deque"):
        """
        ``collections.deque`` loader.
        """
        return deque(dct["$" + class_name])

    def dump_OrderedDict(self, obj, class_name="collections.OrderedDict"):
        """
        ``collections.OrderedDict`` dumper. Stored as a list of
        ``(key, converted value)`` pairs, in iteration order.
        """
        return {
            "$" + class_name: [
                (key, self._json_convert(value)) for key, value in obj.items()
            ]
        }

    def load_OrderedDict(self, dct, class_name="collections.OrderedDict"):
        """
        ``collections.OrderedDict`` loader. Rebuilt from the stored pairs.
        """
        return OrderedDict(dct["$" + class_name])
class SuperJson(BaseSuperJson):
    """
    Concrete encoder/decoder class; adds nothing on top of
    ``BaseSuperJson``.
    """
    pass


# Module-level ``SuperJson`` instance, created at import time.
superjson = SuperJson()