#This file is part of Tryton.  The COPYRIGHT file at the top level of
#this repository contains the full copyright notices and license terms.
from trytond.model import fields
from trytond.pool import Pool, PoolMeta
from trytond.transaction import Transaction

# Classes exported by this module (presumably registered in the Pool by the
# package __init__ — verify there).
__all__ = ['Payroll', 'Account']
# Python 2 idiom: a module-level __metaclass__ applies to every class below
# that is declared without explicit bases, so Payroll and Account are built
# with trytond's PoolMeta (used to extend the existing pooled models named in
# their __name__). This has no effect under Python 3.
__metaclass__ = PoolMeta


class Payroll:
    # Extension of staff.payroll: each payroll points to an analytic account
    # selection, exposed to clients through dynamic fields named
    # "analytic_account_<root id>" (handled in fields_get/read/create/write).
    __name__ = "staff.payroll"
    analytic_accounts = fields.Many2One('analytic_account.account.selection',
        'Analytic Accounts')

    @classmethod
    def _view_look_dom_arch(cls, tree, type, field_children=None):
        """Parse a view after letting the analytic account model adapt it.

        ``AnalyticAccount.convert_view`` is called on *tree* first (its
        return value is ignored, so it presumably edits the tree in place);
        the parent implementation then builds the returned
        ``(arch, fields)`` pair.
        """
        AnalyticAccount = Pool().get('analytic_account.account')
        AnalyticAccount.convert_view(tree)
        arch, view_fields = super(Payroll, cls)._view_look_dom_arch(tree,
            type, field_children=field_children)
        return arch, view_fields

    @classmethod
    def fields_get(cls, fields_names=None):
        """Return field definitions, including the dynamic analytic fields.

        The parent definitions are extended with the result of
        ``AnalyticAccount.analytic_accounts_fields_get``, which receives the
        definition of ``analytic_accounts``, the requested names and the
        ``states`` of ``analytic_accounts``.
        """
        AnalyticAccount = Pool().get('analytic_account.account')

        res = super(Payroll, cls).fields_get(fields_names)

        # Fetched on its own because fields_names may not include it.
        analytic_accounts_field = super(Payroll, cls).fields_get(
                ['analytic_accounts'])['analytic_accounts']

        res.update(AnalyticAccount.analytic_accounts_fields_get(
                analytic_accounts_field, fields_names,
                states=cls.analytic_accounts.states))
        return res

    @classmethod
    def default_get(cls, fields, with_rec_name=True, with_on_change=True):
        """Return default values, skipping the analytic_account_* names.

        The dynamic names are removed before delegating to the parent
        implementation (they are presumably unknown to it).
        """
        fields = [x for x in fields if not x.startswith('analytic_account_')]
        return super(Payroll, cls).default_get(fields,
            with_rec_name=with_rec_name, with_on_change=with_on_change)

    @classmethod
    def read(cls, ids, fields_names=None):
        """Read payrolls and fill in the dynamic analytic account fields.

        ``analytic_account_<root id>`` is set to the id of the selected
        account under that root (None when there is none), and
        ``analytic_account_<root id>.<field>`` to ``account[<field>]``.
        All other names are read by the parent implementation.
        """
        prefix = 'analytic_account_'
        if fields_names:
            real_fields = [x for x in fields_names
                if not x.startswith(prefix)]
        else:
            real_fields = fields_names

        res = super(Payroll, cls).read(ids, fields_names=real_fields)

        if not fields_names:
            fields_names = list(cls._fields.keys())

        # Only plain "analytic_account_<id>" names identify a root; dotted
        # names are sub-fields of the account selected under that root.
        root_ids = [int(field[len(prefix):]) for field in fields_names
            if field.startswith(prefix) and '.' not in field]
        if not root_ids:
            return res

        id2record = {record['id']: record for record in res}
        for payroll in cls.browse(ids):
            record = id2record[payroll.id]
            for root_id in root_ids:
                record[prefix + str(root_id)] = None
            if not payroll.analytic_accounts:
                continue
            for account in payroll.analytic_accounts.accounts:
                account_root = account.root.id
                if account_root not in root_ids:
                    continue
                record[prefix + str(account_root)] = account.id
                sub_prefix = prefix + str(account_root) + '.'
                for field in fields_names:
                    if field.startswith(sub_prefix):
                        _, field2 = field.split('.', 1)
                        record[field] = account[field2]
        return res

    @classmethod
    def create(cls, vlist):
        """Create payrolls, storing analytic_account_* values in a selection.

        Each non-empty ``analytic_account_*`` value is added to the selection
        given in ``analytic_accounts`` (or to a newly created one) and removed
        from the values passed to the parent ``create``. The caller's
        dictionaries are copied, not modified.
        """
        Selection = Pool().get('analytic_account.account.selection')
        vlist = [x.copy() for x in vlist]
        for vals in vlist:
            selection_vals = {}
            # Iterate over a snapshot of the keys: entries are deleted in the
            # loop, which raises RuntimeError on a live dict view (Python 3).
            for field in list(vals.keys()):
                if field.startswith('analytic_account_'):
                    if vals[field]:
                        selection_vals.setdefault('accounts', []).append(
                            ('add', [vals[field]]))
                    del vals[field]
            if vals.get('analytic_accounts'):
                Selection.write([Selection(vals['analytic_accounts'])],
                    selection_vals)
            else:
                selection, = Selection.create([selection_vals])
                vals['analytic_accounts'] = selection.id
        return super(Payroll, cls).create(vlist)

    @classmethod
    def write(cls, *args):
        """Write payrolls, routing analytic_account_<root id> values.

        *args* alternates lists of payrolls and value dictionaries. In each
        pair, an ``analytic_account_<root id>`` entry replaces the account
        under that root in every payroll's selection (a falsy value removes
        it); payrolls without a selection get one created first. The
        remaining values are passed to the parent ``write``.
        """
        Selection = Pool().get('analytic_account.account.selection')

        actions = iter(args)
        args = []
        for payrolls, values in zip(actions, actions):
            values = values.copy()
            selection_vals = {}
            # Snapshot the items: keys are deleted while iterating.
            for field, value in list(values.items()):
                if field.startswith('analytic_account_'):
                    root_id = int(field[len('analytic_account_'):])
                    selection_vals[root_id] = value
                    del values[field]
            if selection_vals:
                for payroll in payrolls:
                    accounts = []
                    if not payroll.analytic_accounts:
                        # Create missing selection; done as user 0,
                        # presumably to bypass access rights.
                        with Transaction().set_user(0):
                            selection, = Selection.create([{}])
                        cls.write([payroll], {
                            'analytic_accounts': selection.id,
                            })
                    # Keep accounts of roots not being written; replace the
                    # ones whose root appears in selection_vals.
                    for account in payroll.analytic_accounts.accounts:
                        if account.root.id in selection_vals:
                            value = selection_vals[account.root.id]
                            if value:
                                accounts.append(value)
                        else:
                            accounts.append(account.id)
                    for account_id in selection_vals.values():
                        if account_id and account_id not in accounts:
                            accounts.append(account_id)
                    to_remove = list(
                        set(a.id for a in payroll.analytic_accounts.accounts)
                        - set(accounts))
                    Selection.write([payroll.analytic_accounts], {
                            'accounts': [
                                ('remove', to_remove),
                                ('add', accounts),
                                ],
                            })
            args.extend((payrolls, values))
        return super(Payroll, cls).write(*args)

    @classmethod
    def delete(cls, lines):
        """Delete payrolls *lines* and then their analytic selections."""
        Selection = Pool().get('analytic_account.account.selection')

        # Collected before deletion, while the payrolls can still be read.
        selections = [line.analytic_accounts for line in lines
            if line.analytic_accounts]

        super(Payroll, cls).delete(lines)
        Selection.delete(selections)

    @classmethod
    def copy(cls, payrolls, default=None):
        """Copy payrolls, giving each copy its own analytic selection.

        The parent ``copy`` makes the new payrolls point to the same
        selection as the originals; each copy is then re-pointed to a
        duplicate of that selection. The original payrolls are not written.
        """
        Selection = Pool().get('analytic_account.account.selection')

        new_payrolls = super(Payroll, cls).copy(payrolls, default=default)

        # Fix: iterate over the copies; rewriting the originals would mutate
        # them (and fail on records that no longer accept writes).
        for new_payroll in new_payrolls:
            if new_payroll.analytic_accounts:
                selection, = Selection.copy([new_payroll.analytic_accounts])
                cls.write([new_payroll], {
                    'analytic_accounts': selection.id,
                    })
        return new_payrolls

    def get_moves_lines(self):
        """Return the parent move line values with analytic lines attached.

        When the payroll's selection has accounts, every line with a non-zero
        debit receives ``analytic_lines = [('create', [...])]`` holding one
        analytic line per account. Each analytic line copies the move line's
        description, debit, credit and party, and uses the payroll's journal,
        ``date_effective`` and ``number``. Lines with a zero debit are left
        untouched.
        """
        moves_lines = super(Payroll, self).get_moves_lines()
        if not (self.analytic_accounts and self.analytic_accounts.accounts):
            return moves_lines
        accounts = self.analytic_accounts.accounts
        for value in moves_lines:
            if value['debit'] == 0:
                continue
            to_create = [{
                    'name': value['description'],
                    'debit': value['debit'],
                    'credit': value['credit'],
                    'account': account.id,
                    'journal': self.journal.id,
                    'date': self.date_effective,
                    'reference': self.number,
                    'party': value['party'],
                    } for account in accounts]
            value['analytic_lines'] = [('create', to_create)]
        return moves_lines


class Account:
    # Extension of analytic_account.account. staff.payroll views are passed
    # through analytic_account.account.convert_view when they are built, so
    # any create/write/delete of an analytic account drops the cached views.
    __name__ = 'analytic_account.account'

    @staticmethod
    def _payroll_model():
        "Return the staff.payroll model from the current pool."
        return Pool().get('staff.payroll')

    @classmethod
    def create(cls, vlist):
        "Create accounts from *vlist*, then reset the payroll view cache."
        payroll_model = cls._payroll_model()
        created_accounts = super(Account, cls).create(vlist)
        payroll_model._fields_view_get_cache.clear()
        return created_accounts

    @classmethod
    def write(cls, accounts, values, *args):
        "Write *values* (and any extra pairs in *args*), then reset the cache."
        payroll_model = cls._payroll_model()
        super(Account, cls).write(accounts, values, *args)
        payroll_model._fields_view_get_cache.clear()

    @classmethod
    def delete(cls, accounts):
        "Delete *accounts*, then reset the payroll view cache."
        payroll_model = cls._payroll_model()
        super(Account, cls).delete(accounts)
        payroll_model._fields_view_get_cache.clear()
