from __future__ import annotations
__author__ = 'github.com/arm61'
import copy
from numbers import Number
from typing import Optional
from typing import Union
import numpy as np
from easyscience import global_object
from easyscience.Objects.ObjectClasses import BaseObj
from easyscience.Objects.variable import Parameter
from easyreflectometry.sample import BaseAssembly
from easyreflectometry.sample import Sample
from easyreflectometry.utils import get_as_parameter
from easyreflectometry.utils import yaml_dump
from .resolution_functions import PercentageFwhm
from .resolution_functions import ResolutionFunction
# Default settings used by ``Model``: the 'scale' and 'background' entries are
# passed to ``get_as_parameter`` when building the corresponding Parameters, and
# the 'resolution' value is the constant given to the default ``PercentageFwhm``.
DEFAULTS = {
    'scale': {
        'description': 'Scaling of the reflectometry profile',
        'url': 'https://github.com/reflectivity/edu_outreach/blob/master/refl_maths/paper.tex',
        'value': 1.0,
        'min': 0,
        'max': np.inf,
        'fixed': True,
    },
    'background': {
        'description': 'Linear background to include in reflectometry data',
        'url': 'https://github.com/reflectivity/edu_outreach/blob/master/refl_maths/paper.tex',
        'value': 1e-8,
        'min': 0.0,
        'max': np.inf,
        'fixed': True,
    },
    'resolution': {
        'value': 5.0,
    },
}
class Model(BaseObj):
    """Model is the class that represents the experiment.

    It is used to store the information about the experiment and to perform the calculations.
    """

    # Added in super().__init__
    name: str
    sample: Sample
    scale: Parameter
    background: Parameter

    def __init__(
        self,
        sample: Union[Sample, None] = None,
        scale: Union[Parameter, Number, None] = None,
        background: Union[Parameter, Number, None] = None,
        resolution_function: Union[ResolutionFunction, None] = None,
        name: str = 'EasyModel',
        color: str = 'black',
        unique_name: Optional[str] = None,
        interface=None,
    ):
        """Constructor.

        :param sample: The sample being modelled; when `None` a new `Sample` using `interface` is created.
        :param scale: Scaling factor of profile, as a `Parameter` or a number; resolved via
            `get_as_parameter` with the module `DEFAULTS`.
        :param background: Linear background magnitude, as a `Parameter` or a number; resolved via
            `get_as_parameter` with the module `DEFAULTS`.
        :param resolution_function: Resolution function, defaults to `PercentageFwhm` built from
            `DEFAULTS['resolution']['value']`.
        :param name: Name of the model, defaults to 'EasyModel'.
        :param color: Color of the model (stored as `self.color`), defaults to 'black'.
        :param unique_name: Unique name of the model; generated from the class name when `None`.
        :param interface: Calculator interface, defaults to `None`.
        """
        if unique_name is None:
            unique_name = global_object.generate_unique_name(self.__class__.__name__)
        if sample is None:
            sample = Sample(interface=interface)
        if resolution_function is None:
            resolution_function = PercentageFwhm(DEFAULTS['resolution']['value'])
        scale = get_as_parameter('scale', scale, DEFAULTS)
        background = get_as_parameter('background', background, DEFAULTS)
        self.color = color

        super().__init__(
            name=name,
            unique_name=unique_name,
            sample=sample,
            scale=scale,
            background=background,
        )
        self.resolution_function = resolution_function
        # Must be set after resolution function: the interface setter reads
        # `self._resolution_function` and forwards it to the calculator.
        self.interface = interface

    def _register_last_assembly(self) -> None:
        """Tell the calculator interface (if any) about the last assembly in the sample."""
        if self.interface is not None:
            self.interface().add_item_to_model(self.sample[-1].unique_name, self.unique_name)

    def add_assemblies(self, *assemblies: BaseAssembly) -> None:
        """Add assemblies to the model sample.

        When called without arguments, `Sample.add_assembly()` is called with no arguments.
        Every argument is validated before anything is added, so an invalid argument
        leaves the sample unchanged.

        :param assemblies: Assemblies to add to model sample.
        :raises ValueError: If any argument is not an instance of `BaseAssembly`.
        """
        if not assemblies:
            self.sample.add_assembly()
            self._register_last_assembly()
            return

        # Validate first so a bad argument does not leave the sample half-updated.
        for assembly in assemblies:
            if not isinstance(assembly, BaseAssembly):
                raise ValueError(f'Object {assembly} is not a valid type, must be a child of BaseAssembly.')

        for assembly in assemblies:
            self.sample.add_assembly(assembly)
            self._register_last_assembly()

    def duplicate_assembly(self, index: int) -> None:
        """Duplicate a given item or layer in a sample.

        The last assembly of the sample is registered with the interface afterwards,
        i.e. the duplicate is expected to be appended at the end of the sample.

        :param index: Index of the item or layer to duplicate.
        """
        self.sample.duplicate_assembly(index)
        self._register_last_assembly()

    def remove_assembly(self, index: int) -> None:
        """Remove an assembly from the model.

        :param index: Index of the item to remove.
        """
        # Capture the name before removal; the item is gone from the sample afterwards.
        assembly_unique_name = self.sample[index].unique_name
        self.sample.remove_assembly(index)
        if self.interface is not None:
            self.interface().remove_item_from_model(assembly_unique_name, self.unique_name)

    @property
    def resolution_function(self) -> ResolutionFunction:
        """Return the resolution function."""
        return self._resolution_function

    @resolution_function.setter
    def resolution_function(self, resolution_function: ResolutionFunction) -> None:
        """Set the resolution function for the model and forward it to the interface, if any."""
        self._resolution_function = resolution_function
        if self.interface is not None:
            self.interface().set_resolution_function(self._resolution_function)

    @property
    def interface(self):
        """
        Get the current interface of the object
        """
        return self._interface

    @interface.setter
    def interface(self, new_interface) -> None:
        """Set the interface for the model, regenerate bindings and push the resolution function."""
        # From super class
        self._interface = new_interface
        if new_interface is not None:
            self.generate_bindings()
            self._interface().set_resolution_function(self._resolution_function)

    # Representation
    @property
    def _dict_repr(self) -> dict[str, dict]:
        """A simplified dict representation."""
        if isinstance(self._resolution_function, PercentageFwhm):
            resolution_value = self._resolution_function.as_dict()['constant']
            resolution = f'{resolution_value} %'
        else:
            resolution = 'function of Q'

        return {
            self.name: {
                'scale': float(self.scale.value),
                'background': float(self.background.value),
                'resolution': resolution,
                'color': self.color,
                'sample': self.sample._dict_repr,
            }
        }

    def __repr__(self) -> str:
        """String representation of the model."""
        return yaml_dump(self._dict_repr)

    def as_dict(self, skip: Optional[list[str]] = None) -> dict:
        """Produces a cleaned dict using a custom as_dict method to skip necessary things.

        The resulting dict matches the parameters in __init__

        :param skip: List of keys to skip, defaults to `None`. The caller's list is not modified.
        :return: Dictionary representation of the model.
        """
        # Work on a copy so the caller's list is not extended as a side effect.
        skip = [] if skip is None else list(skip)
        skip.extend(['sample', 'resolution_function', 'interface'])
        this_dict = super().as_dict(skip=skip)
        this_dict['sample'] = self.sample.as_dict(skip=skip)
        this_dict['resolution_function'] = self.resolution_function.as_dict(skip=skip)
        if self.interface is None:
            this_dict['interface'] = None
        else:
            this_dict['interface'] = self.interface().name
        return this_dict

    @classmethod
    def from_dict(cls, passed_dict: dict) -> Model:
        """
        Create a Model from a dictionary.

        :param passed_dict: Dictionary of the Model, as produced by `as_dict`; it is not modified.
            A missing 'interface' entry is treated as `None`.
        :return: Model
        """
        # Causes circular import if imported at the top
        from easyreflectometry.calculators import CalculatorFactory

        this_dict = copy.deepcopy(passed_dict)
        resolution_function = ResolutionFunction.from_dict(this_dict.pop('resolution_function'))
        interface_name = this_dict.pop('interface', None)
        if interface_name is not None:
            interface = CalculatorFactory()
            interface.switch(interface_name)
        else:
            interface = None

        model = super().from_dict(this_dict)

        model.resolution_function = resolution_function
        model.interface = interface
        return model