#!/usr/bin/python3

# TODO: the ASCII export of fits should be versioned.

from typing import Any, List, Tuple

import sys
import os.path
import itertools

import h5py
import matplotlib.figure
import matplotlib.image
import numpy

from matplotlib.backends.backend_qt5agg import (
    FigureCanvasQTAgg, NavigationToolbar2QT)
from matplotlib.pyplot import Rectangle

from PyQt5.Qt import (Qt)
from PyQt5.QtCore import (pyqtSignal)
from PyQt5.QtWidgets import (
    QAction, QApplication, QSlider, QMenuBar, QTabWidget,
    QFileDialog, QStatusBar, QMessageBox, QRadioButton,
    QButtonGroup, QCheckBox, QPushButton, QHBoxLayout,
    QVBoxLayout, QSplitter, QTableWidgetItem, QTableWidget,
    QLabel, QLineEdit, QMainWindow, QWidget, QComboBox,
    QProgressDialog, QDoubleSpinBox)

from scipy.interpolate import griddata
from scipy.spatial.qhull import QhullError


def set_src():
    """Put the source checkout that contains this script first on sys.path.

    Used as a fallback when ``binoculars`` is not installed: the parent
    directory of this file's directory is inserted at ``sys.path[0]``.
    """
    import os.path
    import sys
    here = os.path.dirname(os.path.abspath(__file__))
    sys.path.insert(0, os.path.abspath(os.path.join(here, os.path.pardir)))


try:
    import binoculars.main
    import binoculars.space
    import binoculars.plot
    import binoculars.fit
    import binoculars.util
except ImportError:
    # try to use code from src distribution
    set_src()
    import binoculars.main
    import binoculars.space
    import binoculars.plot
    import binoculars.fit
    import binoculars.util


class Window(QMainWindow):
    """Main window holding one tab (a TopWidget) per open project file."""

    def __init__(self, parent=None):
        super(Window, self).__init__(parent)

        newproject = QAction("New project", self)
        newproject.triggered.connect(self.newproject)

        loadproject = QAction("Open project", self)
        loadproject.triggered.connect(self.loadproject)

        addspace = QAction("Import space", self)
        addspace.triggered.connect(self.add_to_project)

        menu_bar = QMenuBar()
        file_menu = menu_bar.addMenu("&File")
        file_menu.addAction(newproject)
        file_menu.addAction(loadproject)
        file_menu.addAction(addspace)

        self.statusbar = QStatusBar()

        self.tab_widget = QTabWidget(self)
        self.tab_widget.setTabsClosable(True)
        self.tab_widget.tabCloseRequested.connect(self._close_tab)

        self.setCentralWidget(self.tab_widget)
        self.setMenuBar(menu_bar)
        self.setStatusBar(self.statusbar)

    def _close_tab(self, index: int) -> None:
        """Remove tab *index* and schedule its widget for deletion."""
        widget = self.tab_widget.widget(index)
        self.tab_widget.removeTab(index)
        # removeTab() alone only detaches the page; it would otherwise stay
        # alive (with its figures) for the lifetime of the window.
        if widget is not None:
            widget.deleteLater()

    def _open_tab(self, fname: str) -> None:
        """Open project file *fname* in a new tab and make it current."""
        widget = TopWidget(fname, parent=self)
        self.tab_widget.addTab(widget, short_filename(fname))
        self.tab_widget.setCurrentWidget(widget)

    def newproject(self):
        """Ask for a new ``.fit`` filename and open it as a project tab."""
        dialog = QFileDialog(self, "project filename")
        dialog.setNameFilters(['binoculars fit file (*.fit)'])
        dialog.setDefaultSuffix('fit')
        dialog.setFileMode(QFileDialog.AnyFile)
        dialog.setAcceptMode(QFileDialog.AcceptSave)
        if not dialog.exec_():
            return
        fname = dialog.selectedFiles()[0]
        if not fname:
            return
        try:
            self._open_tab(str(fname))
        except Exception as e:
            QMessageBox.critical(
                self,
                'New project',
                'Unable to save project to {}: {}'.format(fname, e)
            )

    def loadproject(self, filename: str=None) -> None:
        """Open existing project(s).

        When *filename* is given it is opened directly, and errors propagate
        to the caller. Otherwise a dialog is shown and every selected file is
        opened; a failure on one file is reported in a message box.
        """
        if filename:
            self._open_tab(str(filename))
            return

        dialog = QFileDialog(self, "Load project")
        dialog.setNameFilters(['binoculars fit file (*.fit)'])
        dialog.setFileMode(QFileDialog.ExistingFiles)
        dialog.setAcceptMode(QFileDialog.AcceptOpen)
        if not dialog.exec_():
            return
        for fname in dialog.selectedFiles():
            if not fname:
                continue
            try:
                self._open_tab(str(fname))
            except Exception as e:
                QMessageBox.critical(
                    self,
                    'Load project',
                    'Unable to load project from {}: {}'.format(fname, e)
                )

    def add_to_project(self, filename: str=None) -> None:
        """Import space file(s) into the project of the current tab.

        If no project is open, the user is first asked to create one. Nothing
        is imported if that step is cancelled or fails.
        """
        if self.tab_widget.count() == 0:
            QMessageBox.warning(
                self, 'Warning', 'First select a file to store data')
            self.newproject()
            if self.tab_widget.count() == 0:
                return  # no project to import into

        if filename:
            self.tab_widget.currentWidget().addspace(filename)
            return

        dialog = QFileDialog(self, "Import spaces")
        dialog.setNameFilters(['binoculars space file (*.hdf5)'])
        dialog.setFileMode(QFileDialog.ExistingFiles)
        dialog.setAcceptMode(QFileDialog.AcceptOpen)
        if not dialog.exec_():
            return
        for name in dialog.selectedFiles():
            try:
                widget = self.tab_widget.currentWidget()
                widget.addspace(str(name))
            except Exception as e:
                QMessageBox.critical(
                    self,
                    'Import spaces',
                    'Unable to import space {}: {}'.format(name, e)
                )


class TopWidget(QWidget):
    """Content of one project tab.

    Left pane: the rod table, the slice slider, the fit-function selector
    and the fit/integrate buttons. Right pane: the Fit, Integrate, plot and
    Peaktracker tabs.
    """

    def __init__(self, filename, parent=None):
        super(TopWidget, self).__init__(parent)

        hbox = QHBoxLayout()
        vbox = QVBoxLayout()
        minihbox = QHBoxLayout()
        minihbox2 = QHBoxLayout()

        self.database = FitData(filename)
        self.table = TableWidget(self.database)
        self.nav = ButtonedSlider()
        self.nav.slice_index.connect(self.index_change)
        self.table.trigger.connect(self.active_change)
        self.table.check_changed.connect(self.refresh_plot)
        self.tab_widget = QTabWidget()

        self.fitwidget = FitWidget(self.database, self)
        self.integratewidget = IntegrateWidget(self.database, self)
        self.plotwidget = OverviewWidget(self.database, self)
        self.peakwidget = PeakWidget(self.database, self)

        self.tab_widget.addTab(self.fitwidget, 'Fit')
        self.tab_widget.addTab(self.integratewidget, 'Integrate')
        self.tab_widget.addTab(self.plotwidget, 'plot')
        self.tab_widget.addTab(self.peakwidget, 'Peaktracker')

        self.emptywidget = QWidget()
        self.emptywidget.setLayout(vbox)

        vbox.addWidget(self.table)
        vbox.addWidget(self.nav)

        # Offer every PeakFitBase subclass exposed by binoculars.fit.
        # self.functions[i] is the class shown at combo-box index i.
        self.functions = list()
        self.function_box = QComboBox()
        for function in dir(binoculars.fit):
            cls = getattr(binoculars.fit, function)
            if isinstance(cls, type)\
               and issubclass(cls, binoculars.fit.PeakFitBase):
                self.functions.append(cls)
                self.function_box.addItem(function)
        self.function_box.setCurrentIndex(
            self.function_box.findText('PolarLorentzian2D')
        )

        vbox.addWidget(self.function_box)
        vbox.addLayout(minihbox)
        vbox.addLayout(minihbox2)

        self.all_button = QPushButton('fit all')
        self.rod_button = QPushButton('fit rod')
        self.slice_button = QPushButton('fit slice')

        self.all_button.clicked.connect(self.fit_all)
        self.rod_button.clicked.connect(self.fit_rod)
        self.slice_button.clicked.connect(self.fit_slice)

        minihbox.addWidget(self.all_button)
        minihbox.addWidget(self.rod_button)
        minihbox.addWidget(self.slice_button)

        self.allint_button = QPushButton('int all')
        self.rodint_button = QPushButton('int rod')
        self.sliceint_button = QPushButton('int slice')

        self.allint_button.clicked.connect(self.int_all)
        self.rodint_button.clicked.connect(self.int_rod)
        self.sliceint_button.clicked.connect(self.int_slice)

        minihbox2.addWidget(self.allint_button)
        minihbox2.addWidget(self.rodint_button)
        minihbox2.addWidget(self.sliceint_button)

        splitter = QSplitter(Qt.Horizontal)

        splitter.addWidget(self.emptywidget)
        splitter.addWidget(self.tab_widget)
        self.tab_widget.currentChanged.connect(self.tab_change)

        hbox.addWidget(splitter)
        self.setLayout(hbox)

    def tab_change(self, index: int):
        """Refresh the overview when its tab (index 2, 'plot') is shown."""
        if index == 2:
            self.refresh_plot()

    def addspace(self, filename: str=None) -> None:
        """Add a space file to the rod table, asking for one if not given."""
        if not filename:
            # PyQt5 returns a (filename, selected_filter) tuple; an empty
            # filename means the dialog was cancelled.
            filename, _ = QFileDialog.getOpenFileName(
                self, 'Open Project', '.', '*.hdf5')
            if not filename:
                return
        self.table.addspace(str(filename))

    def active_change(self) -> None:
        """Rebind the sub-widgets to the selected rod and restore its slice."""
        rodkey, axis, resolution = self.table.currentkey()
        newdatabase = RodData(self.database.filename, rodkey, axis, resolution)
        self.integratewidget.database = newdatabase
        self.peakwidget.database = newdatabase
        self.integratewidget.set_axis()
        self.peakwidget.set_axis()
        self.fitwidget.database = newdatabase
        self.nav.set_length(newdatabase.rodlength())
        index = newdatabase.load('index')
        if index is None:
            index = 0
        self.nav.set_index(index)
        self.index_change(index)

    def index_change(self, index: int) -> None:
        """Persist the slider position and redraw the fit/integrate plots."""
        if index is None:
            index = 0
        self.fitwidget.database.save('index', self.nav.index())
        self.fitwidget.plot(index)
        self.integratewidget.plot(index)

    def refresh_plot(self) -> None:
        """Redraw the overview with every checked rod."""
        self.plotwidget.refresh(
            [RodData(self.database.filename, rodkey, axis, resolution)
             for rodkey, axis, resolution in self.table.checked()]
        )

    @property
    def fitclass(self):
        """Fit class currently selected in the function combo box."""
        return self.functions[self.function_box.currentIndex()]

    def fit_slice(self) -> None:
        """Fit the current slice, then refresh the location guesses."""
        index = self.nav.index()
        space = self.fitwidget.database.space_from_index(index)
        self.fitwidget.fit(index, space, self.fitclass)
        self.fit_loc(self.fitwidget.database)
        self.fitwidget.plot(index)

    def fit_rod(self) -> None:
        """Fit every slice of the active rod (cancellable)."""
        def function(index: int, space):
            self.fitwidget.fit(index, space, self.fitclass)

        database = self.fitwidget.database
        self.progressbox(database.rodkey, function, enumerate(database),
                         database.rodlength())
        self.fit_loc(database)
        self.fitwidget.plot()

    def fit_all(self) -> None:
        """Fit every slice of every checked rod; Cancel stops all rods."""
        def function(index: int, space):
            self.fitwidget.fit(index, space, self.fitclass)

        for rodkey, axis, resolution in self.table.checked():
            self.fitwidget.database = RodData(
                self.database.filename, rodkey, axis, resolution)
            completed = self.progressbox(
                self.fitwidget.database.rodkey,
                function,
                enumerate(self.fitwidget.database),
                self.fitwidget.database.rodlength()
            )
            self.fit_loc(self.fitwidget.database)
            if not completed:
                break

        self.fitwidget.plot()

    def int_slice(self) -> None:
        """Integrate the current slice of the integrate widget's rod."""
        index = self.nav.index()
        # Take the space from the database the integrate widget writes to;
        # after fit_all the fit widget may point at a different rod.
        space = self.integratewidget.database.space_from_index(index)
        self.integratewidget.integrate(index, space)
        self.integratewidget.plot(index)

    def int_rod(self) -> None:
        """Integrate every slice of the active rod (cancellable)."""
        database = self.integratewidget.database
        self.progressbox(
            database.rodkey,
            self.integratewidget.integrate,
            enumerate(database),
            database.rodlength()
        )
        self.integratewidget.plot()

    def int_all(self) -> None:
        """Integrate every checked rod; Cancel stops all rods."""
        for rodkey, axis, resolution in self.table.checked():
            self.integratewidget.database = RodData(
                self.database.filename, rodkey, axis, resolution)
            completed = self.progressbox(
                self.integratewidget.database.rodkey,
                self.integratewidget.integrate,
                enumerate(self.integratewidget.database),
                self.integratewidget.database.rodlength()
            )
            if not completed:
                break
        self.integratewidget.plot()

    def fit_loc(self, database):
        """Smooth fitted peak locations along the rod into guesses.

        For every ``loc*`` slice attribute, a degree-2 polynomial is fitted
        to the unmasked values. It is weighted from the matching ``var_loc*``
        variances. The polynomial's value at every slice is stored as the
        slice attribute ``guessloc*``. Parameters with no valid values, or
        no recorded variance, are skipped.
        """
        deg = 2
        for param in database.all_attrkeys():
            if not param.startswith('loc'):
                continue
            values = database.all_from_key(param)
            variances = database.all_from_key('var_{0}'.format(param))
            if values is None or variances is None:
                continue
            x, y = values
            _, yvar = variances
            x = numpy.asarray(x)

            # Use only slices where both the value and its variance are valid.
            good = ~(numpy.ma.getmaskarray(y) | numpy.ma.getmaskarray(yvar))
            cx = x[good]
            if cx.size == 0:
                continue  # nothing fitted yet; polyfit needs data points
            cy = numpy.ma.getdata(y)[good]
            cvar = numpy.ma.getdata(yvar)[good]

            with numpy.errstate(divide='ignore', invalid='ignore'):
                w = numpy.log(1 / cvar)
            # Zero weight for inf (zero variance), nan and negative values,
            # then keep only weights at or above the median.
            w[w == numpy.inf] = 0
            w = numpy.nan_to_num(w)
            w[w < 0] = 0
            w[w < numpy.median(w)] = 0

            c = numpy.polynomial.polynomial.polyfit(cx, cy, deg, w=w)
            newy = numpy.polynomial.polynomial.polyval(x, c)
            # Strip the 'loc' prefix (str.lstrip would strip a character set).
            suffix = param[len('loc'):]
            for index, newval in enumerate(newy):
                database.save_sliceattr(
                    index,
                    'guessloc{0}'.format(suffix),
                    newval
                )

    def progressbox(self, rodkey, function, iterator, length):
        """Call ``function(*item)`` for each item of *iterator* with a progress dialog.

        Returns True if every item was processed, or False if the user
        pressed Cancel. The dialog is always closed.
        """
        pd = QProgressDialog(
            'Processing {0}'.format(rodkey), 'Cancel', 0, length, self)
        pd.setWindowModality(Qt.WindowModal)
        pd.show()
        try:
            for index, item in enumerate(iterator):
                pd.setValue(index)
                QApplication.processEvents()
                # Stop cleanly instead of raising out of a Qt slot.
                if pd.wasCanceled():
                    return False
                function(*item)
            pd.setValue(length)
            return True
        finally:
            pd.close()


class TableWidget(QWidget):
    """Table of the rods of a project.

    Columns: selection checkbox (carrying the rod key as ``.rodkey``), rod
    name, integration axis, resolution and a remove button. ``trigger`` is
    emitted when a row becomes active; ``check_changed`` when a checkbox is
    clicked.
    """
    trigger = pyqtSignal()
    check_changed = pyqtSignal()

    def __init__(self, database, parent=None):
        super(TableWidget, self).__init__(parent)

        hbox = QHBoxLayout()
        self.database = database

        # Row whose rod is shown in the fit/integrate views.
        self.activeindex = 0

        self.table = QTableWidget(0, 5)
        self.table.setHorizontalHeaderLabels(
            ['', 'rod', 'axis', 'res', 'remove'])

        self.table.cellClicked.connect(self.setlength)

        for index, width in enumerate([25, 150, 40, 50, 70]):
            self.table.setColumnWidth(index, width)

        for filename, rodkey in zip(database.filelist, database.rods()):
            self.addspace(filename, rodkey)

        hbox.addWidget(self.table)
        self.setLayout(hbox)

    def _rows_of(self, rodkey) -> List[int]:
        """Return the rows currently showing *rodkey*, in ascending order."""
        return [row for row in range(self.table.rowCount())
                if self.table.cellWidget(row, 0).rodkey == rodkey]

    def addspace(self, filename: str, rodkey=None) -> None:
        """Register space *filename* in the project and append a table row.

        Without *rodkey* the key is derived from the filename. If that key
        is already taken, an unused one is found and the existing rod is
        first copied under it.
        """
        def remove_callback(rodkey):
            return lambda: self.remove(rodkey)

        def activechange_callback(rodkey):
            # Resolve the row when the signal fires: rows shift after a
            # removal, so an index captured now would go stale.
            def callback():
                rows = self._rows_of(rodkey)
                if rows:
                    self.setlength(rows[0], 1)
            return callback

        if rodkey is None:
            rodkey = short_filename(filename)
            if rodkey in self.database.rods():
                newkey = find_unused_rodkey(rodkey, self.database.rods())
                self.database.copy(rodkey, newkey)
                rodkey = newkey

        old_axis, old_resolution = self.database.load(
            rodkey, 'axis'), self.database.load(rodkey, 'resolution')
        self.database.create_rod(rodkey, filename)
        index = self.table.rowCount()
        self.table.insertRow(index)

        axes = binoculars.space.Axes.fromfile(filename)

        checkboxwidget = QCheckBox()
        checkboxwidget.rodkey = rodkey
        checkboxwidget.setChecked(False)
        self.table.setCellWidget(index, 0, checkboxwidget)
        checkboxwidget.clicked.connect(self.check_changed)

        item = QTableWidgetItem(rodkey)
        self.table.setItem(index, 1, item)

        axis = QComboBox()
        for ax in axes:
            axis.addItem(ax.label)
        self.table.setCellWidget(index, 2, axis)
        if old_axis is not None:
            axis.setCurrentIndex(axes.index(old_axis))
        elif index > 0:
            # Default the new row to the axis chosen in the first row, when
            # this space has an axis with that label.
            position = axis.findText(
                self.table.cellWidget(0, 2).currentText())
            if position >= 0:
                axis.setCurrentIndex(position)

        resolution = QLineEdit()
        if old_resolution is not None:
            resolution.setText(str(old_resolution))
        elif index > 0:
            resolution.setText(self.table.cellWidget(0, 3).text())
        else:
            resolution.setText(
                str(axes[axes.index(str(axis.currentText()))].res))

        resolution.editingFinished.connect(activechange_callback(rodkey))
        self.table.setCellWidget(index, 3, resolution)

        buttonwidget = QPushButton('remove')
        buttonwidget.clicked.connect(remove_callback(rodkey))
        self.table.setCellWidget(index, 4, buttonwidget)

    def remove(self, rodkey) -> None:
        """Remove *rodkey* from the table and delete it from the project."""
        # Remove bottom-up so earlier removals don't shift rows still to do.
        for row in reversed(self._rows_of(rodkey)):
            self.table.removeRow(row)
            if row < self.activeindex:
                self.activeindex -= 1
            elif row == self.activeindex:
                self.activeindex = 0
        self.database.delete_rod(rodkey)
        print('removed: {0}'.format(rodkey))

    def setlength(self, y: int, x: int=1) -> None:
        """Activate row *y* when column 1 (the rod name) was clicked.

        The row's axis and resolution are saved to the project before
        ``trigger`` is emitted.
        """
        if x == 1:
            self.activeindex = y
            rodkey, axis, resolution = self.currentkey()
            self.database.save(rodkey, 'axis', axis)
            self.database.save(rodkey, 'resolution', resolution)
            self.trigger.emit()

    def currentkey(self) -> Tuple[Any, str, float]:
        """Return ``(rodkey, axis, resolution)`` of the active row."""
        rodkey = self.table.cellWidget(self.activeindex, 0).rodkey
        axis = str(self.table.cellWidget(self.activeindex, 2).currentText())
        resolution = float(self.table.cellWidget(self.activeindex, 3).text())
        return rodkey, axis, resolution

    def checked(self) -> List[Tuple[Any, str, float]]:
        """Return ``(rodkey, axis, resolution)`` for every checked row."""
        selection = []
        for index in range(self.table.rowCount()):
            checkbox = self.table.cellWidget(index, 0)
            if checkbox.isChecked():
                axis = str(self.table.cellWidget(index, 2).currentText())
                resolution = float(self.table.cellWidget(index, 3).text())
                selection.append((checkbox.rodkey, axis, resolution))
        return selection


class FitData(object):
    """Access to a project (``.fit``) HDF5 file.

    Each top-level group is a rod, and its ``filename`` attribute records
    the space file it was imported from. Every method opens the file itself
    (mode ``'a'``), so no handle is kept between calls.
    """

    def __init__(self, filename):
        self.filename = filename
        # rodkey -> binoculars Axes read from that rod's space file
        self.axdict = dict()

        with h5py.File(self.filename, 'a') as db:
            # Use the open handle; self.rods() would reopen the same file.
            for rodkey in list(db.keys()):
                spacename = db[rodkey].attrs['filename']
                if not os.path.exists(spacename):
                    spacename = self._locate_space(rodkey, spacename)
                    db[rodkey].attrs['filename'] = spacename
                self.axdict[rodkey] = binoculars.space.Axes.fromfile(spacename)

    @staticmethod
    def _locate_space(rodkey, spacename):
        """Ask the user where the missing space file of *rodkey* now lives.

        Raises IOError if the file dialog is cancelled.
        """
        warningbox = QMessageBox(
            QMessageBox.Warning,
            'Warning',
            'Cannot find space {0} at file {1}; locate proper space'.format(rodkey, spacename),  # noqa
            buttons=QMessageBox.Open
        )
        warningbox.exec_()
        # PyQt5 returns (filename, selected_filter); an empty filename means
        # the dialog was cancelled.
        newname, _ = QFileDialog.getOpenFileName(
            caption='Open space {0}'.format(rodkey),
            directory='.',
            filter='*.hdf5')
        if not newname:
            raise IOError('Select proper input')
        return str(newname)

    def create_rod(self, rodkey, spacename):
        """Create group *rodkey* for space *spacename* unless it exists."""
        with h5py.File(self.filename, 'a') as db:
            if rodkey not in db:
                db.create_group(rodkey)
                db[rodkey].attrs['filename'] = spacename
                self.axdict[rodkey] = binoculars.space.Axes.fromfile(spacename)

    def delete_rod(self, rodkey):
        """Delete the group of *rodkey* (KeyError if absent)."""
        with h5py.File(self.filename, 'a') as db:
            del db[rodkey]

    def rods(self):
        """Return the list of rod keys stored in the project."""
        with h5py.File(self.filename, 'a') as db:
            return list(db.keys())

    def copy(self, oldkey, newkey):
        """Copy rod *oldkey* (group and contents) to *newkey*, if it exists."""
        with h5py.File(self.filename, 'a') as db:
            if oldkey in db:
                db.copy(db[oldkey], db, name=newkey)

    @property
    def filelist(self):
        """Space filenames of all rods, in the same order as ``rods()``."""
        with h5py.File(self.filename, 'a') as db:
            return [db[key].attrs['filename'] for key in db.keys()]

    def save(self, rodkey, key, value):
        """Store *value* as attribute *key* of rod *rodkey*."""
        with h5py.File(self.filename, 'a') as db:
            db[rodkey].attrs[str(key)] = value

    def load(self, rodkey, key):
        """Return attribute *key* of rod *rodkey*, or None if either is missing."""
        name = str(key)
        with h5py.File(self.filename, 'a') as db:
            if rodkey in db and name in db[rodkey].attrs:
                return db[rodkey].attrs[name]
        return None


class RodData(FitData):
    """One rod of a project, sliced along *axis* at *resolution*.

    Per-slice data lives in the HDF5 group ``<rodkey>/<axis>_<resolution>``.
    Scalar per-slice attributes live in its ``attrs`` subgroup as datasets
    of length ``rodlength()``, each with a boolean ``mask<key>`` companion.
    """

    def __init__(self, filename, rodkey, axis, resolution):
        super(RodData, self).__init__(filename)
        self.rodkey = rodkey
        self.slicekey = '{0}_{1}'.format(axis, resolution)
        self.axis = axis
        self.resolution = resolution

        with h5py.File(self.filename, 'a') as db:
            if rodkey in db:
                if self.slicekey not in db[rodkey]:
                    db[rodkey].create_group(self.slicekey)
                    db[rodkey][self.slicekey].create_group('attrs')

    def save(self, key, value):
        """Store rod-level attribute *key* of this rod."""
        super(RodData, self).save(self.rodkey, key, value)

    def load(self, key):
        """Return rod-level attribute *key* of this rod, or None."""
        return super(RodData, self).load(self.rodkey, key)

    def paxes(self):
        """Return the space axes other than the slicing axis."""
        axes = self.axdict[self.rodkey]
        projected = list(axes)
        projected.pop(axes.index(self.axis))
        return projected

    def get_bins(self):
        """Return ``(bins, axis, axis_index)`` for slicing along ``self.axis``."""
        axes = self.axdict[self.rodkey]
        axindex = axes.index(self.axis)
        ax = axes[axindex]

        bins = binoculars.space.get_bins(ax, self.resolution)
        return bins, ax, axindex

    def rodlength(self):
        """Return the number of slices (consecutive bin-edge pairs)."""
        bins, _, _ = self.get_bins()
        # numpy.alen was removed in NumPy 1.23; len() is equivalent here.
        return len(bins) - 1

    def get_index_value(self, index):
        """Return the slicing-axis value of slice *index*."""
        values = binoculars.space.get_axis_values(self.axdict[self.rodkey],
                                                  self.axis,
                                                  self.resolution)
        return values[index]

    def get_key(self, index):
        """Return the per-axis slice list selecting slice *index*."""
        axes = self.axdict[self.rodkey]
        bins, _, axindex = self.get_bins()
        k = [slice(None) for _ in axes]
        k[axindex] = slice(bins[index], bins[index + 1])
        return k

    def space_from_index(self, index):
        """Load slice *index* from the space file, projected along the axis."""
        with h5py.File(self.filename, 'a') as db:
            filename = db[self.rodkey].attrs['filename']
        space = binoculars.space.Space.fromfile(filename, self.get_key(index))
        return space.project(self.axis)

    @staticmethod
    def _replace_dataset(group, name, array):
        """(Re)create gzip-compressed dataset *name* in *group* from *array*."""
        # h5py >= 3 raises ValueError (2.x: RuntimeError) for an existing
        # name, so delete first instead of catching a version-specific error.
        if name in group:
            del group[name]
        group.create_dataset(name, data=array, compression='gzip')

    def save_data(self, index, key, data):
        """Store masked array *data* for slice *index* under *key*."""
        data_name = '{0}_{1}_data'.format(int(index), key)
        mask_name = '{0}_{1}_mask'.format(int(index), key)
        with h5py.File(self.filename, 'a') as db:
            grp = db[self.rodkey][self.slicekey]
            self._replace_dataset(grp, data_name, numpy.ma.getdata(data))
            # getmaskarray expands a scalar 'nomask' to a full boolean array.
            self._replace_dataset(grp, mask_name, numpy.ma.getmaskarray(data))

    def load_data(self, index, key):
        """Return the masked array saved by ``save_data``, or None."""
        data_name = '{0}_{1}_data'.format(int(index), key)
        mask_name = '{0}_{1}_mask'.format(int(index), key)
        with h5py.File(self.filename, 'a') as db:
            grp = db[self.rodkey][self.slicekey]
            try:
                return numpy.ma.array(grp[data_name][...],
                                      mask=grp[mask_name][...])
            except KeyError:
                return None

    def _attrs_group(self, db):
        """Return the ``attrs`` subgroup, creating it if missing."""
        slicegroup = db[self.rodkey][self.slicekey]
        if 'attrs' not in slicegroup:  # absent in files from the old fitaid
            slicegroup.create_group('attrs')
        return slicegroup['attrs']

    def save_sliceattr(self, index, key, value):
        """Store scalar *value* of attribute *key* for slice *index* and unmask it."""
        mkey = 'mask{0}'.format(key)
        length = self.rodlength()
        with h5py.File(self.filename, 'a') as db:
            group = self._attrs_group(db)
            if key not in group:
                group.create_dataset(key, (length,))
                # numpy.bool was removed in NumPy 1.24; builtin bool is
                # what it aliased. All slices start out masked.
                dataset = group.create_dataset(mkey, (length,), dtype=bool)
                dataset.write_direct(numpy.ones(length, dtype=bool))
            group[key][index] = value
            group[mkey][index] = 0

    def load_sliceattr(self, index, key):
        """Return attribute *key* of slice *index* as a masked scalar, or None."""
        mkey = 'mask{0}'.format(key)
        with h5py.File(self.filename, 'a') as db:
            group = self._attrs_group(db)
            if key not in group:
                return None
            return numpy.ma.array(group[key][index], mask=group[mkey][index])

    def all_attrkeys(self):
        """Return the names of all datasets in the ``attrs`` subgroup."""
        with h5py.File(self.filename, 'a') as db:
            group = db[self.rodkey][self.slicekey]['attrs']
            return list(group.keys())

    def all_from_key(self, key):
        """Return ``(axis_values, masked_values)`` for attribute *key*, or None."""
        mkey = 'mask{0}'.format(key)
        axes = self.axdict[self.rodkey]
        with h5py.File(self.filename, 'a') as db:
            group = db[self.rodkey][self.slicekey]['attrs']
            if key not in group:
                return None
            return (binoculars.space.get_axis_values(axes, self.axis,
                                                     self.resolution),
                    numpy.ma.array(group[key],
                                   mask=numpy.array(group[mkey])))

    def _load_numbered(self, index, prefix):
        """Collect ``<prefix>0``, ``<prefix>1``, ... for slice *index* until one is missing."""
        values = []
        for i in itertools.count():
            value = self.load_sliceattr(index, '{0}{1}'.format(prefix, i))
            if value is None:
                break
            values.append(value)
        return values

    def load_loc(self, index):
        """Return the location guess of slice *index*, or None.

        Prefers the ``guessloc*`` attributes and falls back to the fitted
        ``loc*`` ones.
        """
        loc = self._load_numbered(index, 'guessloc')
        if not loc:
            loc = self._load_numbered(index, 'loc')
        return loc or None

    def save_loc(self, index, loc):
        """Store *loc* as ``guessloc0``, ``guessloc1``, ... of slice *index*."""
        for i, value in enumerate(loc):
            self.save_sliceattr(index, 'guessloc{0}'.format(i), value)

    def save_segments(self, segments):
        """Store the *segments* array, replacing any previous one."""
        with h5py.File(self.filename, 'a') as db:
            grp = db[self.rodkey][self.slicekey]
            self._replace_dataset(grp, 'segment', segments)

    def load_segments(self):
        """Return the stored segments array, or None."""
        with h5py.File(self.filename, 'a') as db:
            try:
                return numpy.array(db[self.rodkey][self.slicekey]['segment'][:])  # noqa
            except KeyError:
                return None

    def __iter__(self):
        """Yield the projected space of every slice, in order."""
        for index in range(self.rodlength()):
            yield self.space_from_index(index)


def short_filename(filename):
    """Return the last '/'-separated component of *filename*, up to its first dot."""
    basename = filename.rsplit('/', 1)[-1]
    stem, _, _ = basename.partition('.')
    return stem


class HiddenToolbar(NavigationToolbar2QT):
    """Navigation toolbar used only for its zoom handling.

    It is started in zoom mode. When a press/release pair leaves every
    axis' limits unchanged (a click rather than a zoom drag), it calls
    ``corner_callback(event.xdata, event.ydata)``.
    """

    def __init__(self, corner_callback, canvas):
        super(HiddenToolbar, self).__init__(canvas, None)
        self._corner_callback = corner_callback
        # Axis limits captured on press; None when no press is pending, so
        # a release seen before any press cannot raise AttributeError.
        self._corner_preclick = None
        self.zoom()

    def _generate_key(self):
        """Return ``[[xlim, ylim], ...]`` for every axes of the figure."""
        limits = []
        for a in self.canvas.figure.get_axes():
            limits.append([a.get_xlim(), a.get_ylim()])
        return limits

    # NOTE(review): press/release are overridden hooks, presumably called by
    # NavigationToolbar2 on zoom press/release; newer matplotlib deprecated
    # them. Confirm against the installed version.
    def press(self, event):
        self._corner_preclick = self._generate_key()

    def release(self, event):
        preclick = self._corner_preclick
        self._corner_preclick = None
        if preclick is not None and preclick == self._generate_key():
            self._corner_callback(event.xdata, event.ydata)


class FitWidget(QWidget):
    """Plot of the current slice and its fit.

    A click (not a zoom drag) on the plot stores the clicked point as the
    location guess for the current slice.
    """

    def __init__(self, database, parent=None):
        super(FitWidget, self).__init__(parent)

        self.database = database
        # Last axes drawn by plot(); None until the first plot.
        self.ax = None
        layout = QHBoxLayout()

        self.figure = matplotlib.figure.Figure()
        self.canvas = FigureCanvasQTAgg(self.figure)
        self.toolbar = HiddenToolbar(self.loc_callback, self.canvas)

        layout.addWidget(self.canvas)
        self.setLayout(layout)

    def loc_callback(self, x, y):
        """Save ``[x, y]`` as the location guess of the current slice."""
        # Ignore clicks before anything is plotted, and clicks outside any
        # axes (matplotlib reports xdata/ydata as None there).
        if self.ax is None or x is None or y is None:
            return
        self.database.save_loc(self.currentindex(), numpy.array([x, y]))

    def plot(self, index=None):
        """Draw slice *index* (default: the saved current index) and its fit."""
        if index is None:
            index = self.currentindex()
        space = self.database.space_from_index(index)
        fitdata = self.database.load_data(index, 'fit')
        self.figure.clear()
        self.figure.space_axes = space.axes
        info = self.database.get_index_value(index)
        label = self.database.axis

        if fitdata is not None:
            if space.dimension == 1:
                self.ax = self.figure.add_subplot(111)
                binoculars.plot.plot(
                    space, self.figure, self.ax, fit=fitdata)
            elif space.dimension == 2:
                # Data on the left, fit on the right.
                self.ax = self.figure.add_subplot(121)
                binoculars.plot.plot(space, self.figure, self.ax, fit=None)
                self.ax = self.figure.add_subplot(122)
                binoculars.plot.plot(
                    space, self.figure, self.ax, fit=fitdata)
        else:
            self.ax = self.figure.add_subplot(111)
            binoculars.plot.plot(space, self.figure, self.ax)
        self.figure.suptitle('{0}, res = {1}, {2} = {3}'.format(
            self.database.rodkey, self.database.resolution, label, info))
        self.canvas.draw()

    def fit(self, index, space, function):
        """Fit *space* (slice *index*) with fit class *function* and store results.

        The fitted data, each fitted parameter and its variance
        (``var_<name>``) are saved. Fully masked spaces are skipped.
        """
        print(index)
        masked = space.get_masked()
        if len(masked.compressed()) == 0:
            return
        loc = self.get_loc()
        fit = function(space, loc=loc)
        fit.fitdata.mask = masked.mask
        self.database.save_data(index, 'fit', fit.fitdata)
        # Parameter names are the text before ':' on each summary line.
        params = [line.split(':')[0] for line in fit.summary.split('\n')]
        print(fit.result, fit.variance)
        for key, value in zip(params, fit.result):
            self.database.save_sliceattr(index, key, value)
        for key, value in zip(params, fit.variance):
            self.database.save_sliceattr(
                index, 'var_{0}'.format(key), value)

    def get_loc(self):
        """Return the location guess of the current slice, or None."""
        return self.database.load_loc(self.currentindex())

    def currentindex(self):
        """Return the slice index saved in the database (0 if none)."""
        index = self.database.load('index')
        if index is None:
            return 0
        return index


class IntegrateWidget(QWidget):
    """Integrate a region of interest (ROI) around the peak of each slice.

    The ROI is a box of size ``hsize`` x ``vsize`` centred on the peak
    location. That location comes either from the fit tab ("peak from fit")
    or from the peak segments table ("peak from segment"). The background
    comes either from four stripes around the ROI ("background around roi")
    or from a single box given in absolute axis coordinates. integrate()
    stores the resulting structure factors as the slice attributes 'sf',
    'nisf' (from the raw, non-interpolated pixels) and, when fit data
    exists, 'fitsf'.
    """

    def __init__(self, database, parent=None):
        super(IntegrateWidget, self).__init__(parent)
        # NOTE: shadows QObject.parent(); get_loc() and loc_callback() use
        # this attribute to reach the sibling peak widget.
        self.parent = parent
        self.database = database
        # Axes of the last plot(); None until something has been plotted,
        # so that loc_callback() ignores clicks made before the first plot.
        self.ax = None

        self.figure = matplotlib.figure.Figure()
        self.canvas = FigureCanvasQTAgg(self.figure)
        self.toolbar = HiddenToolbar(self.loc_callback, self.canvas)

        hbox = QHBoxLayout()

        splitter = QSplitter(Qt.Vertical)
        self.make_controlwidget()

        splitter.addWidget(self.canvas)
        splitter.addWidget(self.control_widget)

        hbox.addWidget(splitter)
        self.setLayout(hbox)

    def make_controlwidget(self):
        """Build the ROI/background spin boxes and the peak-source buttons."""
        self.control_widget = QWidget()

        integratebox = QVBoxLayout()
        intensitybox = QHBoxLayout()
        backgroundbox = QHBoxLayout()

        self.aroundroi = QCheckBox('background around roi')
        self.aroundroi.setChecked(True)
        self.aroundroi.clicked.connect(self.refresh_aroundroi)

        self.hsize = QDoubleSpinBox()
        self.vsize = QDoubleSpinBox()

        intensitybox.addWidget(QLabel('roi size:'))
        intensitybox.addWidget(self.hsize)
        intensitybox.addWidget(self.vsize)

        self.left = QDoubleSpinBox()
        self.right = QDoubleSpinBox()
        self.top = QDoubleSpinBox()
        self.bottom = QDoubleSpinBox()

        # every change of the ROI geometry is saved and redrawn at once
        for box in (self.hsize, self.vsize, self.left, self.right,
                    self.top, self.bottom):
            box.valueChanged.connect(self.send)

        backgroundbox.addWidget(self.aroundroi)
        for box in (self.left, self.right, self.top, self.bottom):
            backgroundbox.addWidget(box)

        integratebox.addLayout(intensitybox)
        integratebox.addLayout(backgroundbox)

        self.fromfit = QRadioButton('peak from fit', self)
        self.fromfit.setChecked(True)
        self.fromfit.toggled.connect(self.plot_box)
        self.fromfit.toggled.connect(self.refresh_tracker)

        self.fromsegment = QRadioButton('peak from segment', self)
        self.fromsegment.setChecked(False)
        self.fromsegment.toggled.connect(self.plot_box)
        self.fromsegment.toggled.connect(self.refresh_tracker)

        self.trackergroup = QButtonGroup(self)
        self.trackergroup.addButton(self.fromfit)
        self.trackergroup.addButton(self.fromsegment)

        radiobox = QHBoxLayout()
        radiobox.addWidget(self.fromfit)
        radiobox.addWidget(self.fromsegment)

        integratebox.addLayout(radiobox)

        self.control_widget.setLayout(integratebox)

    def refresh_aroundroi(self):
        """Persist the background mode and rescale the background spin boxes.

        With "background around roi" checked, the four boxes are stripe
        widths (0 up to the axis span). Otherwise they are absolute axis
        coordinates bounding a single background box.
        """
        self.database.save('aroundroi', self.aroundroi.checkState())
        axes = self.database.paxes()
        if not self.aroundroi.checkState():
            self.left.setMinimum(axes[0].min)
            self.left.setMaximum(axes[0].max)
            self.right.setMinimum(axes[0].min)
            self.right.setMaximum(axes[0].max)
            self.top.setMinimum(axes[1].min)
            self.top.setMaximum(axes[1].max)
            self.bottom.setMinimum(axes[1].min)
            self.bottom.setMaximum(axes[1].max)
        else:
            self.left.setMinimum(0)
            self.left.setMaximum(axes[0].max - axes[0].min)
            self.right.setMinimum(0)
            self.right.setMaximum(axes[0].max - axes[0].min)
            self.top.setMinimum(0)
            self.top.setMaximum(axes[1].max - axes[1].min)
            self.bottom.setMinimum(0)
            self.bottom.setMaximum(axes[1].max - axes[1].min)

    def refresh_tracker(self):
        """Persist which peak source is selected and redraw the boxes."""
        self.database.save('fromfit', self.fromfit.isChecked())
        self.plot_box()

    @staticmethod
    def _decimals(step):
        """Return the number of decimals a spin box needs to show *step*.

        Positional formatting is used so that small steps are handled
        (``str(1e-05)`` is '1e-05'). At least one decimal is always kept.
        """
        fraction = numpy.format_float_positional(float(step)).partition('.')[2]
        return max(len(fraction), 1)

    def set_axis(self):
        """Load the stored ROI settings of the current rod into the controls."""
        # Read the stored roi first: the range/step changes below may fire
        # valueChanged -> send(), which overwrites the stored 'roi'.
        roi = self.database.load('roi')

        aroundroi = self.database.load('aroundroi')
        if aroundroi is not None:
            self.aroundroi.setChecked(aroundroi)
        else:
            self.aroundroi.setChecked(True)
        self.refresh_aroundroi()

        axes = self.database.paxes()

        # hsize/left/right step along axes[1], vsize/top/bottom along axes[0]
        for box, axis in ((self.hsize, axes[1]), (self.vsize, axes[0]),
                          (self.left, axes[1]), (self.right, axes[1]),
                          (self.top, axes[0]), (self.bottom, axes[0])):
            box.setSingleStep(axis.res)
            box.setDecimals(self._decimals(axis.res))

        tracker = self.database.load('fromfit')
        if tracker is not None:
            if tracker:
                self.fromfit.setChecked(True)
            else:
                self.fromsegment.setChecked(True)

        if roi is not None:
            boxes = [self.hsize, self.vsize, self.left, self.right,
                     self.top, self.bottom]
            for box, value in zip(boxes, roi):
                box.setValue(value)

    def send(self):
        """Save the six ROI spin-box values as 'roi' and redraw the boxes."""
        roi = [self.hsize.value(), self.vsize.value(), self.left.value(),
               self.right.value(), self.top.value(), self.bottom.value()]
        self.database.save('roi', roi)
        self.plot_box()

    @staticmethod
    def _structure_factor(intensity, bkg):
        """Return sqrt(sum(I) - len(I) / len(bkg) * sum(bkg)).

        Returns nan when *intensity* is empty, and sqrt(sum(I)) when there are
        no background pixels. A negative difference yields nan (numpy.sqrt).
        """
        if len(intensity) == 0:
            return numpy.nan
        if len(bkg) == 0:
            return numpy.sqrt(intensity.sum())
        return numpy.sqrt(
            intensity.sum() - len(intensity) * 1.0 / len(bkg) * bkg.sum())

    def integrate(self, index, space):
        """Integrate the ROI of slice *index* and store the structure factors.

        Saves 'fitsf' (from the stored fit, if any), 'sf' (from the
        interpolated data, which is also saved as 'inter' data) and 'nisf'
        (from the unmasked raw pixels) as slice attributes. Does nothing when
        no peak location is available.
        """
        loc = self.get_loc()
        if loc is None:
            return
        axes = space.axes

        intkey = self.intkey(loc, axes)
        key = space.get_key(intkey)

        fitdata = self.database.load_data(index, 'fit')
        if fitdata is not None:
            fitintensity = fitdata[key].data.flatten()
            fitbkg = numpy.hstack(
                [fitdata[space.get_key(bkgkey)].data.flatten()
                 for bkgkey in self.bkgkeys(loc, axes)])
            self.database.save_sliceattr(
                index, 'fitsf', self._structure_factor(fitintensity, fitbkg))

        niintensity = space[intkey].get_masked().compressed()

        try:
            intensity = interpolate(space[intkey]).flatten()
            bkg = numpy.hstack([space[bkgkey].get_masked().compressed()
                                for bkgkey in self.bkgkeys(loc, axes)])
            interdata = space.get_masked()
            interdata[key] = intensity.reshape(interdata[key].shape)
            interdata[key].mask = numpy.zeros_like(interdata[key])
            self.database.save_data(index, 'inter', interdata)
        except (ValueError, QhullError) as e:
            print('Warning error interpolating silce {0}: {1}'.format(index, e))  # noqa
            intensity = numpy.array([])
            bkg = numpy.array([])

        if len(intensity) == 0:
            structurefactor = numpy.nan
            nistructurefactor = numpy.nan
        else:
            structurefactor = self._structure_factor(intensity, bkg)
            nistructurefactor = self._structure_factor(niintensity, bkg)

        self.database.save_sliceattr(index, 'sf', structurefactor)
        self.database.save_sliceattr(index, 'nisf', nistructurefactor)

        print('Structurefactor {0}: {1}'.format(index, structurefactor))

    def intkey(self, coords, axes):
        """Return the ROI as one restricted slice per axis, centred on *coords*."""
        vsize = self.vsize.value() / 2
        hsize = self.hsize.value() / 2
        return tuple(ax.restrict(slice(coord - size, coord + size))
                     for ax, coord, size in zip(axes, coords, [vsize, hsize]))

    def bkgkeys(self, coords, axes):
        """Return the background regions as (axis-0 slice, axis-1 slice) pairs.

        Around the ROI: four stripes of the configured widths to the left,
        right, top and bottom of the ROI. Otherwise: a single box bounded by
        the left/right values on axes[0] and top/bottom values on axes[1].
        """
        aroundroi = self.database.load('aroundroi')
        if aroundroi:
            key = self.intkey(coords, axes)

            vsize = self.vsize.value() / 2
            hsize = self.hsize.value() / 2

            leftkey = (key[0], axes[1].restrict(slice(coords[1] - hsize - self.left.value(), coords[1] - hsize)))  # noqa
            rightkey = (key[0], axes[1].restrict(slice(coords[1] + hsize, coords[1] + hsize + self.right.value())))  # noqa
            topkey = (axes[0].restrict(slice(coords[0] - vsize - self.top.value(), coords[0] - vsize)), key[1])  # noqa
            bottomkey = (axes[0].restrict(slice(coords[0] + vsize, coords[0] + vsize + self.bottom.value())), key[1])  # noqa

            return leftkey, rightkey, topkey, bottomkey
        else:
            return [(axes[0].restrict(slice(self.left.value(),
                                            self.right.value())),
                     axes[1].restrict(slice(self.top.value(),
                                            self.bottom.value())))]

    def get_loc(self):
        """Return the peak location of the current slice from the chosen source."""
        if self.fromfit.isChecked():
            return self.database.load_loc(self.currentindex())
        else:
            index = self.currentindex()
            indexvalue = self.database.get_index_value(index)
            return self.parent.peakwidget.get_coords(indexvalue)

    def loc_callback(self, x, y):
        """Record the peak location clicked at (x, y) and redraw the boxes."""
        if self.ax is None:
            # nothing plotted yet
            return
        if self.fromfit.isChecked():
            self.database.save_loc(self.currentindex(), numpy.array([x, y]))
        else:
            index = self.currentindex()
            indexvalue = self.database.get_index_value(index)
            self.parent.peakwidget.add_row(numpy.array([indexvalue, x, y]))
        self.plot_box()

    def plot(self, index=None):
        """Plot slice *index* (default: current) next to its interpolated data."""
        if index is None:
            index = self.currentindex()
        space = self.database.space_from_index(index)
        interdata = self.database.load_data(index, 'inter')
        info = self.database.get_index_value(index)
        label = self.database.axis

        self.figure.clear()
        self.figure.space_axes = space.axes

        if interdata is not None:
            if space.dimension == 1:
                self.ax = self.figure.add_subplot(111)
                binoculars.plot.plot(
                    space, self.figure, self.ax, fit=interdata)
            elif space.dimension == 2:
                self.ax = self.figure.add_subplot(121)
                binoculars.plot.plot(space, self.figure, self.ax, fit=None)
                self.ax = self.figure.add_subplot(122)
                binoculars.plot.plot(
                    space, self.figure, self.ax, fit=interdata)
        else:
            self.ax = self.figure.add_subplot(111)
            binoculars.plot.plot(space, self.figure, self.ax)

        self.figure.suptitle('{0}, res = {1}, {2} = {3}'.format(
            self.database.rodkey, self.database.resolution, label, info))

        self.plot_box()
        self.canvas.draw()

    def plot_box(self):
        """Draw the ROI (black) and background regions (red) on the first axes."""
        loc = self.get_loc()
        if len(self.figure.get_axes()) == 0 or loc is None:
            return
        ax = self.figure.get_axes()[0]
        axes = self.figure.space_axes
        key = self.intkey(loc, axes)
        bkgkey = self.bkgkeys(loc, axes)
        # Axes.patches is a read-only list since Matplotlib 3.5: remove the
        # previously drawn boxes one by one instead of reassigning it.
        for patch in list(ax.patches):
            patch.remove()
        regions = [(key, 'k')] + [(k, 'r') for k in bkgkey]
        for k, color in regions:
            ax.add_patch(Rectangle((k[0].start, k[1].start),
                                   k[0].stop - k[0].start,
                                   k[1].stop - k[1].start,
                                   alpha=0.2, color=color))
        self.canvas.draw()

    def currentindex(self):
        """Return the slice index saved under 'index', defaulting to 0."""
        index = self.database.load('index')
        if index is None:
            return 0
        else:
            return index


class ButtonedSlider(QWidget):
    """Horizontal slider flanked by first/previous and next/last buttons.

    ``slice_index`` is emitted with the slider position when the handle is
    released or when one of the four navigation buttons is clicked. The
    widget starts disabled; set_length() enables it.
    """
    slice_index = pyqtSignal(int)

    def __init__(self, parent=None):
        super(ButtonedSlider, self).__init__(parent)

        self.navigation_button_left_end = QPushButton('|<')
        self.navigation_button_left_one = QPushButton('<')
        self.navigation_slider = QSlider(Qt.Horizontal)
        self.navigation_slider.sliderReleased.connect(self.send)

        self.navigation_button_right_one = QPushButton('>')
        self.navigation_button_right_end = QPushButton('>|')

        wiring = (
            (self.navigation_button_left_end, self.slider_change_left_end),
            (self.navigation_button_left_one, self.slider_change_left_one),
            (self.navigation_button_right_end, self.slider_change_right_end),
            (self.navigation_button_right_one, self.slider_change_right_one),
        )
        for button, handler in wiring:
            button.setMaximumWidth(20)
            button.clicked.connect(handler)

        layout = QHBoxLayout()
        for widget in (self.navigation_button_left_end,
                       self.navigation_button_left_one,
                       self.navigation_slider,
                       self.navigation_button_right_one,
                       self.navigation_button_right_end):
            layout.addWidget(widget)

        self.setDisabled(True)
        self.setLayout(layout)

    def set_length(self, length):
        """Allow positions 0 .. length - 1, reset to 0 and enable the widget."""
        slider = self.navigation_slider
        slider.setMinimum(0)
        slider.setMaximum(length - 1)
        slider.setTickPosition(QSlider.TicksBelow)
        slider.setValue(0)
        self.setEnabled(True)

    def send(self):
        """Emit ``slice_index`` with the current slider position."""
        self.slice_index.emit(self.navigation_slider.value())

    def _jump(self, position):
        # move the handle, then notify listeners
        self.navigation_slider.setValue(position)
        self.send()

    def slider_change_left_one(self):
        """Step one position back (not below 0) and emit."""
        self._jump(max(self.navigation_slider.value() - 1, 0))

    def slider_change_left_end(self):
        """Jump to the first position and emit."""
        self._jump(0)

    def slider_change_right_one(self):
        """Step one position forward (not past the maximum) and emit."""
        upper = self.navigation_slider.maximum()
        self._jump(min(self.navigation_slider.value() + 1, upper))

    def slider_change_right_end(self):
        """Jump to the last position and emit."""
        self._jump(self.navigation_slider.maximum())

    def index(self):
        """Return the current slider position."""
        return self.navigation_slider.value()

    def set_index(self, index):
        """Move the slider to *index* without emitting ``slice_index``."""
        self.navigation_slider.setValue(index)


class HiddenToolbar2(NavigationToolbar2QT):
    """Parentless navigation toolbar that switches *canvas* into zoom mode.

    The toolbar is created without a parent widget. Constructing it
    immediately calls ``zoom()``, so the canvas is ready for rubber-band
    zooming.
    """

    def __init__(self, canvas):
        super().__init__(canvas, None)
        # NavigationToolbar2.zoom() toggles zoom mode; it starts off, so on
        self.zoom()


class OverviewWidget(QWidget):
    """Overview plot of per-slice parameters against the rod index value.

    A table lists every parameter found in the loaded databases, plus the
    segment-derived 'locx_s' / 'locy_s' for rods that have peak segments.
    The checked parameters are plotted for every database and can be
    exported as two-column text files.
    """

    def __init__(self, database, parent=None):
        super(OverviewWidget, self).__init__(parent)

        # databases shown in the overview; replaced by refresh()
        self.databaselist = list()

        self.figure = matplotlib.figure.Figure()
        self.canvas = FigureCanvasQTAgg(self.figure)
        self.toolbar = HiddenToolbar2(self.canvas)

        self.table = QTableWidget(0, 2)
        self.make_table()

        self.table.cellClicked.connect(self.plot)

        hbox = QHBoxLayout()

        splitter = QSplitter(Qt.Horizontal)

        splitter.addWidget(self.canvas)
        splitter.addWidget(self.control_widget)

        hbox.addWidget(splitter)
        self.setLayout(hbox)

    def select(self):
        """Return the names of the parameters whose checkbox is ticked."""
        selection = []
        for index in range(self.table.rowCount()):
            checkbox = self.table.cellWidget(index, 0)
            if checkbox.checkState():
                selection.append(str(self.table.cellWidget(index, 1).text()))
        return selection

    def make_table(self):
        """Build the parameter table, the log toggle and the export button."""
        self.control_widget = QWidget()
        vbox = QVBoxLayout()
        minibox = QHBoxLayout()

        vbox.addWidget(self.table)
        self.table.setHorizontalHeaderLabels(['', 'param'])
        for index, width in enumerate([25, 50]):
            self.table.setColumnWidth(index, width)
        self.log = QCheckBox('log')
        self.log.clicked.connect(self.plot)
        self.export_button = QPushButton('export curves')

        self.export_button.clicked.connect(self.export)

        minibox.addWidget(self.log)
        minibox.addWidget(self.export_button)
        vbox.addLayout(minibox)
        self.control_widget.setLayout(vbox)

    @staticmethod
    def _curve(database, param):
        """Return the (x, y) arrays of *param* for *database*, or None.

        'locx_s' / 'locy_s' are not stored attributes. They are the peak
        coordinates interpolated from the database's segments at every index
        value, and None is returned when there is nothing to interpolate.
        Any other parameter is read with ``database.all_from_key``.
        """
        if param in ('locx_s', 'locy_s'):
            segments = database.load_segments()
            length = database.rodlength()
            if segments is None or len(segments) == 0 or length == 0:
                return None
            x = numpy.hstack(
                [database.get_index_value(index) for index in range(length)])
            y = numpy.vstack([get_coords(xvalue, segments) for xvalue in x])
            column = 0 if param == 'locx_s' else 1
            return x, y[:, column]
        return database.all_from_key(param)

    def export(self):
        """Write every selected curve to '<param>_<rodkey>.txt', sorted by x."""
        folder = str(QFileDialog.getExistingDirectory(
            self, "Select directory to save curves"))
        if not folder:
            # dialog cancelled: do not silently write into the CWD
            return
        for param in self.select():
            for database in self.databaselist:
                curve = self._curve(database, param)
                if curve is None:
                    continue
                x, y = curve
                order = numpy.argsort(x)
                filename = '{0}_{1}.txt'.format(param, database.rodkey)
                # vstack needs a sequence; generators are rejected by NumPy
                numpy.savetxt(os.path.join(folder, filename),
                              numpy.vstack([x[order], y[order]]).T)

    def refresh(self, databaselist):
        """Show *databaselist*, rebuild the parameter table and replot.

        Parameters that were ticked before the refresh stay ticked.
        """
        self.databaselist = databaselist
        selected = self.select()
        while self.table.rowCount() > 0:
            self.table.removeRow(0)

        allparams = [[param
                      for param in database.all_attrkeys()
                      if not param.startswith('mask')]
                     for database in databaselist]

        # segment-derived pseudo-parameters, only for rods that have segments
        allparams.extend(['locx_s', 'locy_s']
                         for database in databaselist
                         if database.load_segments() is not None)

        uniqueparams = sorted(set(itertools.chain.from_iterable(allparams)))

        for param in uniqueparams:
            index = self.table.rowCount()
            self.table.insertRow(index)

            checkboxwidget = QCheckBox()
            checkboxwidget.setChecked(param in selected)
            self.table.setCellWidget(index, 0, checkboxwidget)
            checkboxwidget.clicked.connect(self.plot)

            item = QLabel(param)
            self.table.setCellWidget(index, 1, item)

        self.plot()

    def plot(self):
        """Plot every selected parameter of every database against x."""
        params = self.select()
        self.figure.clear()

        self.ax = self.figure.add_subplot(111)
        for param in params:
            for database in self.databaselist:
                curve = self._curve(database, param)
                if curve is None:
                    continue
                x, y = curve
                self.ax.plot(
                    x, y, '+',
                    label='{0} - {1}'.format(param, database.rodkey)
                )

        # avoid Matplotlib's "no artists with labels" warning on empty plots
        handles, _ = self.ax.get_legend_handles_labels()
        if handles:
            self.ax.legend()
        if self.log.checkState():
            self.ax.semilogy()
        self.canvas.draw()


class PeakWidget(QWidget):
    """Editable table of peak positions ("segments") along the rod.

    Each row holds an index value and the peak coordinates along the two
    projected axes. Every edit is written back with
    ``database.save_segments``. get_coords() interpolates the peak position
    at an arbitrary index value.
    """

    def __init__(self, database, parent=None):
        super(PeakWidget, self).__init__(parent)
        self.database = database

        # create a QTableWidget
        self.table = QTableWidget(0, 3, self)
        self.table.horizontalHeader().setStretchLastSection(True)
        self.table.verticalHeader().setVisible(False)
        self.table.itemChanged.connect(self.save)

        self.btn_add_row = QPushButton('+', self)
        # clicked() delivers a `checked` bool that add_row() would receive as
        # `row` (then failing on False[0]); the lambda drops it.
        self.btn_add_row.clicked.connect(lambda: self.add_row())

        self.buttonRemove = QPushButton('-', self)
        self.buttonRemove.clicked.connect(self.remove)

        vbox = QVBoxLayout()
        hbox = QHBoxLayout()

        hbox.addWidget(self.btn_add_row)
        hbox.addWidget(self.buttonRemove)

        vbox.addLayout(hbox)
        vbox.addWidget(self.table)
        self.setLayout(vbox)

    def set_axis(self):
        """Reload the table from the database's segments and axis labels."""
        self.axes = self.database.paxes()
        while self.table.rowCount() > 0:
            self.table.removeRow(0)
        segments = self.database.load_segments()
        if segments is not None:
            for index in range(segments.shape[0]):
                self.add_row(segments[index, :])
        self.table.setHorizontalHeaderLabels(
            ['{0}'.format(self.database.axis),
             '{0}'.format(self.axes[0].label),
             '{0}'.format(self.axes[1].label)]
        )

    def add_row(self, row=None):
        """Append a row, filled with the first three values of *row* if given."""
        rowindex = self.table.rowCount()
        self.table.insertRow(rowindex)
        if row is not None:
            for index in range(3):
                newitem = QTableWidgetItem(str(row[index]))
                self.table.setItem(rowindex, index, newitem)

    def remove(self):
        """Delete the current row and save the table."""
        self.table.removeRow(self.table.currentRow())
        self.save()

    def axis_coords(self):
        """Return the table contents as a (rows, columns) float array.

        Empty cells, and cells whose text is not a number (for example
        while the user is still typing), are returned as 0.0.
        """
        a = numpy.zeros((self.table.rowCount(), self.table.columnCount()))
        for rowindex in range(a.shape[0]):
            for columnindex in range(a.shape[1]):
                item = self.table.item(rowindex, columnindex)
                if item is None:
                    continue
                try:
                    a[rowindex, columnindex] = float(item.text())
                except ValueError:
                    # Deliberately treated like an empty cell: raising here
                    # would propagate out of the itemChanged slot.
                    pass
        return a

    def save(self):
        """Write the table contents to the database as segments."""
        self.database.save_segments(self.axis_coords())

    def get_coords(self, x):
        """Return the peak coordinates interpolated at index value *x*."""
        return get_coords(x, self.axis_coords())


def get_coords(x, coords):
    """Linearly interpolate the two coordinates of *coords* at position *x*.

    *coords* is a 2-D array whose first column is the abscissa and whose
    next two columns are the values to interpolate. Returns None when
    *coords* has no rows, and the row's values (``coords[0, 1:]``) when it
    has exactly one. Otherwise the rows are sorted by their first column
    and a length-2 array is returned, computed on the straight line through
    two rows. Beyond either end, the two outermost rows are used to
    extrapolate.
    """
    rows = coords.shape[0]
    if rows == 0:
        return None
    if rows == 1:
        return coords[0, 1:]

    ordered = coords[numpy.argsort(coords[:, 0])]
    xs = ordered[:, 0]

    if x < xs.min():
        lo, hi = 0, 1
    elif x > xs.max():
        lo, hi = -2, -1
    else:
        hi = numpy.searchsorted(xs, x)
        lo = hi - 1

    span = xs[hi] - xs[lo]
    values = []
    for column in (1, 2):
        ys = ordered[:, column]
        slope = (ys[hi] - ys[lo]) / span
        intercept = ys[lo] - slope * xs[lo]
        values.append(slope * x + intercept)
    return numpy.array(values)


def interpolate(space):
    """Return the values of *space* with its masked pixels filled in.

    Masked pixels are first estimated by linear interpolation
    (scipy.interpolate.griddata) from the unmasked ones. Any pixel that is
    NaN afterwards (e.g. outside the convex hull of the known points) is
    then given the value of the nearest non-NaN pixel.

    When *space* has no masked pixels, or no unmasked ones, nothing is
    interpolated and the flattened unmasked values are returned instead.
    This is a 1-D array rather than the full-shaped array returned
    otherwise.
    """
    def split_coordinates(mask):
        # (npoints, ndim) grid coordinates where mask is False, then True
        unmasked = numpy.vstack(
            [numpy.ma.array(g, mask=mask).compressed()
             for g in space.get_grid()]).T
        masked = numpy.vstack(
            [numpy.ma.array(g, mask=numpy.invert(mask)).compressed()
             for g in space.get_grid()]).T
        return unmasked, masked

    data = space.get_masked()
    known, unknown = split_coordinates(data.mask)
    if unknown.shape[0] == 0 or known.shape[0] == 0:
        return data.compressed()

    values = data.data.copy()
    values[data.mask] = griddata(known, data.compressed(), unknown)

    holes = numpy.isnan(values)
    if holes.any():
        known, unknown = split_coordinates(holes)
        filled = numpy.ma.array(values, mask=holes).compressed()
        values[holes] = griddata(known, filled, unknown, method='nearest')
    return values


def find_unused_rodkey(rodkey: str, rods: Any) -> str:
    """Return *rodkey*, or the first '<rodkey>_<n>' (n = 0, 1, ...) not in *rods*.

    *rods* is a collection of existing rod keys that supports ``in``. It was
    previously annotated ``str``, but a string would turn ``in`` into a
    substring test.
    """
    if rodkey not in rods:
        return rodkey
    for index in itertools.count(0):
        newkey = '{0}_{1}'.format(rodkey, index)
        if newkey not in rods:
            return newkey


if __name__ == "__main__":
    # Launch the GUI: create one main window, run Qt's event loop and exit
    # the process with the loop's return code.
    app = QApplication(sys.argv)

    main = Window()
    main.resize(1000, 600)
    main.show()

    status = app.exec_()
    sys.exit(status)
