from __future__ import absolute_import
import ast
import sys
import zlib
import warnings
import json

import numpy as np

from ..handlers import BaseHandler, register, unregister
from ..compat import numeric_types
from ..util import b64decode, b64encode
from .. import compat


# Public API of this module: only the handler (un)registration entry points.
__all__ = ['register_handlers', 'unregister_handlers']

# Byte-order character of the running interpreter: '<' on little-endian
# machines, '>' on anything else.
native_byteorder = {'little': '<'}.get(sys.byteorder, '>')


def get_byteorder(arr):
    """Return the byte-order character of ``arr.dtype``.

    numpy reports native byte order as ``'='``; that is translated to the
    explicit ``native_byteorder`` character. Any other value (``'<'``,
    ``'>'`` or ``'|'``) is returned unchanged.
    """
    order = arr.dtype.byteorder
    if order == '=':
        return native_byteorder
    return order


class NumpyBaseHandler(BaseHandler):
    """Shared dtype (de)serialization helpers for the numpy handlers."""

    def flatten_dtype(self, dtype, data):
        """Store a description of ``dtype`` under ``data['dtype']``.

        If ``dtype`` exposes a ``tostring`` method its result is stored
        directly. Otherwise the dtype is converted with ``compat.ustr``; when
        that text has the form ``(numpy.record, <spec>)`` only ``<spec>`` is
        kept.
        """
        if hasattr(dtype, 'tostring'):
            data['dtype'] = dtype.tostring()
            return

        text = compat.ustr(dtype)
        record_prefix = '(numpy.record, '
        if text.startswith(record_prefix):
            # drop the wrapper prefix and its closing parenthesis
            text = text[len(record_prefix):-1]
        data['dtype'] = text

    def restore_dtype(self, data):
        """Rebuild a ``np.dtype`` from ``data['dtype']``.

        Descriptions that start with ``{`` or ``[`` are parsed with
        ``ast.literal_eval`` before being handed to ``np.dtype``; anything
        else is passed through as-is.
        """
        spec = data['dtype']
        is_literal = spec.startswith('{') or spec.startswith('[')
        return np.dtype(ast.literal_eval(spec) if is_literal else spec)


class NumpyDTypeHandler(NumpyBaseHandler):
    """Handler for dtype objects themselves."""

    def flatten(self, obj, data):
        """Write the description of dtype ``obj`` into ``data``; return ``data``."""
        self.flatten_dtype(dtype=obj, data=data)
        return data

    def restore(self, data):
        """Return the ``np.dtype`` described by ``data``."""
        restored = self.restore_dtype(data)
        return restored


class NumpyGenericHandler(NumpyBaseHandler):
    """Handler for numpy scalar values."""

    def flatten(self, obj, data):
        """Store the native-byteorder dtype and the flattened value of ``obj``.

        The value is taken from ``obj.tolist()`` and passed through the
        context's ``flatten`` before being stored under ``data['value']``.
        """
        native_dtype = obj.dtype.newbyteorder('N')
        self.flatten_dtype(native_dtype, data)
        plain_value = obj.tolist()
        data['value'] = self.context.flatten(plain_value, reset=False)
        return data

    def restore(self, data):
        """Rebuild the scalar by calling the restored dtype's type on the value."""
        value = self.context.restore(data['value'], reset=False)
        scalar_type = self.restore_dtype(data).type
        return scalar_type(value)


class NumpyNDArrayHandler(NumpyBaseHandler):
    """Stores arrays as nested lists (text), ignoring view relationships"""

    def flatten_flags(self, obj, data):
        """Record ``writeable: False`` in ``data`` for read-only arrays."""
        if obj.flags.writeable is not False:
            return
        data['writeable'] = False

    def restore_flags(self, data, arr):
        """Make ``arr`` read-only when ``data`` marks it as not writeable."""
        if data.get('writeable', True):
            return
        arr.flags.writeable = False

    def flatten(self, obj, data):
        """Encode ``obj`` into ``data`` as dtype, flags and nested-list values."""
        native_dtype = obj.dtype.newbyteorder('N')
        self.flatten_dtype(native_dtype, data)
        self.flatten_flags(obj, data)

        nested = obj.tolist()
        data['values'] = self.context.flatten(nested, reset=False)

        if 0 in obj.shape:
            # an empty nested list does not carry the shape, so record it
            data['shape'] = obj.shape
        return data

    def restore(self, data):
        """Rebuild an array from the nested-list representation in ``data``."""
        values = self.context.restore(data['values'], reset=False)
        dtype = self.restore_dtype(data)
        order = data.get('order', 'C')
        arr = np.array(values, dtype=dtype, order=order)

        shape = data.get('shape')
        if shape is not None:
            arr = arr.reshape(shape)

        self.restore_flags(data, arr)
        return arr


class NumpyNDArrayHandlerBinary(NumpyNDArrayHandler):
    """stores arrays with size greater than 'size_threshold' as
    (optionally) compressed base64

    Notes
    -----
    This would be easier to implement using np.save/np.load, but
    that would be less language-agnostic
    """

    def __init__(self, size_threshold=16, compression=zlib):
        """
        :param size_threshold: nonnegative int or None
            valid values for 'size_threshold' are all nonnegative
            integers and None
            if size_threshold is None, values are always stored as nested lists
        :param compression: a compression module or None
            valid values for 'compression' are {zlib, bz2, None}
            if compression is None, no compression is applied
        """
        self.size_threshold = size_threshold
        self.compression = compression

    def flatten_byteorder(self, obj, data):
        """Record the byte order of ``obj`` unless its dtype has none ('|')."""
        byteorder = obj.dtype.byteorder
        if byteorder != '|':
            data['byteorder'] = get_byteorder(obj)

    def restore_byteorder(self, data, arr):
        """Reinterpret ``arr`` in place using the byte order stored in ``data``."""
        byteorder = data.get('byteorder', None)
        if byteorder:
            arr.dtype = arr.dtype.newbyteorder(byteorder)

    def flatten(self, obj, data):
        """encode numpy to json

        Arrays whose size does not exceed ``size_threshold`` (or all arrays
        when the threshold is None) are stored as nested lists by the parent
        class. Larger arrays are stored as base64 text of their raw bytes,
        compressed with ``self.compression`` when one is configured.
        """
        if self.size_threshold is None or self.size_threshold >= obj.size:
            # encode as text
            data = super(NumpyNDArrayHandlerBinary, self).flatten(obj, data)
        else:
            # encode as binary
            # NOTE: compare against the builtin ``object``; the ``np.object``
            # alias was deprecated in NumPy 1.20 and removed in 1.24.
            if obj.dtype == object:
                # There's a bug deep in the bowels of numpy that causes a
                # segfault when round-tripping an ndarray of dtype object.
                # E.g., the following will result in a segfault:
                #     import numpy as np
                #     arr = np.array([str(i) for i in range(3)],
                #                    dtype=object)
                #     dtype = arr.dtype
                #     shape = arr.shape
                #     buf = arr.tobytes()
                #     del arr
                #     arr = np.ndarray(buffer=buf, dtype=dtype,
                #                      shape=shape).copy()
                # So, save as a binary-encoded list in this case
                buf = json.dumps(obj.tolist()).encode()
            elif hasattr(obj, 'tobytes'):
                # numpy docstring is lacking as of 1.11.2,
                # but this is the option we need
                buf = obj.tobytes(order='a')
            else:
                # numpy < 1.9 compatibility
                buf = obj.tostring(order='a')
            if self.compression:
                buf = self.compression.compress(buf)
            data['values'] = b64encode(buf)
            data['shape'] = obj.shape
            self.flatten_dtype(obj.dtype.newbyteorder('N'), data)
            self.flatten_byteorder(obj, data)
            self.flatten_flags(obj, data)

            if not obj.flags.c_contiguous:
                data['order'] = 'F'

        return data

    def restore(self, data):
        """decode numpy from json

        ``data['values']`` selects the decoding path: a list is handled by
        the parent class (text form), a bare number becomes a one-element
        array, and anything else is treated as base64 text that is decoded
        and, when ``self.compression`` is set, decompressed with it.
        """
        values = data['values']
        if isinstance(values, list):
            # decode text representation
            arr = super(NumpyNDArrayHandlerBinary, self).restore(data)
        elif isinstance(values, numeric_types):
            # single-value array
            arr = np.array([values], dtype=self.restore_dtype(data))
        else:
            # decode binary representation
            dtype = self.restore_dtype(data)
            buf = b64decode(values)
            if self.compression:
                buf = self.compression.decompress(buf)
            # Object arrays are stored as a JSON-encoded list rather than raw
            # bytes (numpy can segfault when an object array is rebuilt from
            # a raw buffer), so they are decoded from JSON here.
            # The builtin ``object`` is used because ``np.object`` no longer
            # exists in NumPy >= 1.24.
            if dtype == object:
                values = json.loads(buf.decode())
                arr = np.array(values, dtype=dtype, order=data.get('order', 'C'))
                shape = data.get('shape', None)
                if shape is not None:
                    arr = arr.reshape(shape)
            else:
                arr = np.ndarray(
                    buffer=buf,
                    dtype=dtype,
                    shape=data.get('shape'),
                    order=data.get('order', 'C'),
                ).copy()  # make a copy, to force the result to own the data
                self.restore_byteorder(data, arr)
            self.restore_flags(data, arr)

        return arr


class NumpyNDArrayHandlerView(NumpyNDArrayHandlerBinary):
    """Pickles references inside ndarrays, or array-views

    Notes
    -----
    The current implementation has some restrictions.

    'base' arrays, or arrays which are viewed by other arrays,
    must be f-or-c-contiguous.
    This is not such a large restriction in practice, because all
    numpy array creation is c-contiguous by default.
    Relaxing this restriction would be nice though; especially if
    it can be done without bloating the design too much.

    Furthermore, ndarrays which are views of array-like objects
    implementing __array_interface__,
    but which are not themselves nd-arrays, are deepcopied with
    a warning (by default),
    as we cannot guarantee whatever custom logic such classes
    implement is correctly reproduced.
    """

    def __init__(self, mode='warn', size_threshold=16, compression=zlib):
        """
        :param mode: {'warn', 'raise', 'ignore'}
            How to react when encountering array-like objects whose
            references we cannot safely serialize
        :param size_threshold: nonnegative int or None
            valid values for 'size_threshold' are all nonnegative
            integers and None
            if size_threshold is None, values are always stored as nested lists
        :param compression: a compression module or None
            valid values for 'compression' are {zlib, bz2, None}
            if compression is None, no compression is applied
        """
        super(NumpyNDArrayHandlerView, self).__init__(
            size_threshold=size_threshold, compression=compression
        )
        self.mode = mode

    def _flatten_reference(self, obj, base, data):
        """Describe ``obj`` in ``data`` as a view onto the ndarray ``base``."""
        data['base'] = self.context.flatten(base, reset=False)

        # distance between the start of the view and the start of its base
        byte_offset = obj.ctypes.data - base.ctypes.data
        if byte_offset:
            data['offset'] = byte_offset

        if not obj.flags.c_contiguous:
            data['strides'] = obj.strides

        data['shape'] = obj.shape
        self.flatten_dtype(obj.dtype.newbyteorder('N'), data)
        self.flatten_flags(obj, data)

        # 'S' marks a view whose byte order is swapped relative to its base
        view_order = get_byteorder(obj)
        if view_order != '|' and view_order != get_byteorder(base):
            data['byteorder'] = 'S'

        threshold = self.size_threshold
        if threshold is None or threshold >= obj.size:
            # restore ignores these values when a base is present; they are
            # only written so that the output is human-readable
            super(NumpyNDArrayHandlerBinary, self).flatten(obj, data)

    def flatten(self, obj, data):
        """encode numpy to json

        An array that has no base and is C- or F-contiguous is stored by
        value. An array whose base is a contiguous ndarray is stored as a
        reference to that base. Anything else is copied and stored by value,
        after warning or raising according to ``self.mode``.
        """
        base = obj.base

        if base is None and obj.flags.forc:
            # store by value
            data = super(NumpyNDArrayHandlerView, self).flatten(obj, data)
            # make sure arrays stored as text are read back with the
            # right memory order
            if not obj.flags.c_contiguous:
                data['order'] = 'F'
            return data

        if isinstance(base, np.ndarray) and base.flags.forc:
            # store by reference
            self._flatten_reference(obj, base, data)
            return data

        # store a deepcopy or fail
        if self.mode == 'warn':
            msg = (
                "ndarray is defined by reference to an object "
                "we do not know how to serialize. "
                "A deep copy is serialized instead, breaking "
                "memory aliasing."
            )
            warnings.warn(msg)
        elif self.mode == 'raise':
            msg = (
                "ndarray is defined by reference to an object we do "
                "not know how to serialize."
            )
            raise ValueError(msg)
        return super(NumpyNDArrayHandlerView, self).flatten(obj.copy(), data)

    def restore(self, data):
        """decode numpy from json

        Without a ``base`` entry the array is restored by value through the
        parent class. With one, the base is restored first and the result is
        an ndarray built on the base's buffer.
        """
        flat_base = data.get('base')
        if flat_base is None:
            # decode array with owndata=True
            return super(NumpyNDArrayHandlerView, self).restore(data)

        # decode array view, which references the data of another array
        base = self.context.restore(flat_base, reset=False)
        assert (
            base.flags.forc
        ), "Current implementation assumes base is C or F contiguous"

        buffer = base.data
        stored_dtype = self.restore_dtype(data)
        view_dtype = stored_dtype.newbyteorder(data.get('byteorder', '|'))
        arr = np.ndarray(
            buffer=buffer,
            dtype=view_dtype,
            shape=data.get('shape'),
            offset=data.get('offset', 0),
            strides=data.get('strides'),
        )

        self.restore_flags(data, arr)
        return arr


def register_handlers():
    """Register the numpy handlers for dtype, scalar and ndarray objects.

    Every handler is registered with ``base=True``.
    """
    handlers = (
        (np.dtype, NumpyDTypeHandler),
        (np.generic, NumpyGenericHandler),
        (np.ndarray, NumpyNDArrayHandlerView()),
    )
    for numpy_type, handler in handlers:
        register(numpy_type, handler, base=True)


def unregister_handlers():
    """Remove the handlers installed for ``np.dtype``, ``np.generic`` and ``np.ndarray``."""
    for numpy_type in (np.dtype, np.generic, np.ndarray):
        unregister(numpy_type)