import functools
import os
from shutil import copyfile

import numpy as np
import pandas as pd
import PIL
import torch

# Category names of the '舌色' (tongue colour) property, keyed by class index:
# pale red, pale white, red, pale purple, dark purple, purple-red,
# bluish purple, other.
property_idx_to_label = {
    0: '淡红',
    1: '淡白',
    2: '红',
    3: '淡紫',
    4: '紫暗',
    5: '紫红',
    6: '青紫',
    7: '其他',
}
# Reverse lookup (category name -> class index), derived from the map above so
# the two can never drift apart.
property_label_to_idx = {label: idx for idx, label in property_idx_to_label.items()}


class Image_DataSet(torch.utils.data.Dataset):
    """Dataset over a directory of labelled image files.

    Every file in ``processed_dir`` is one sample. Its integer label is parsed
    from the file name, whose stem must end in ``-<anything>_<label>``
    (e.g. ``张三-001-舌色_2.jpg`` -> 2).

    Args:
        transform: optional callable applied to the opened PIL image; when
            None the PIL image itself is returned.
        processed_dir: directory holding the labelled images.
    """

    def __init__(self, transform=None, processed_dir='./processed-data'):
        self.processed_dir = processed_dir
        self.transform = transform
        # The directory is listed lazily on first use and then reused, so
        # __len__ and __getitem__ always agree and the directory is not
        # re-listed for every sample.
        self._files = None

    def _file_names(self):
        """Return the sorted file names in ``processed_dir`` (listed once)."""
        if self._files is None:
            # os.listdir order is arbitrary; sorting makes index -> file stable.
            self._files = sorted(os.listdir(self.processed_dir))
        return self._files

    @staticmethod
    def _label_from_filename(file_name):
        """Parse the integer label from a ``...-<prefix>_<label>.<ext>`` name."""
        stem = os.path.splitext(file_name)[0]
        return int(stem.split('-')[-1].split('_')[-1])

    def __getitem__(self, index):
        """Return ``(image, label)``; indices wrap around modulo the length."""
        files = self._file_names()
        if not files:
            raise IndexError('no images found in {}'.format(self.processed_dir))
        # Modulo over the full length so every file, including the last one,
        # is reachable.
        file_name = files[index % len(files)]

        # `import PIL` alone does not guarantee the Image submodule is loaded.
        from PIL import Image

        img_file = os.path.join(self.processed_dir, file_name)
        # The context manager releases the file handle; pixel data is loaded
        # (or copied) before the file is closed.
        with Image.open(img_file) as source_image:
            source_image.load()
            if self.transform is not None:
                image = self.transform(source_image)
            else:
                image = source_image.copy()
        return image, self._label_from_filename(file_name)

    def __len__(self):
        return len(self._file_names())


@functools.lru_cache(maxsize=8)
def _read_patient_sheets(xlsx_path):
    """Read every sheet of the workbook at ``xlsx_path`` (cached per path).

    Callers typically look up many images against the same workbook; caching
    avoids re-parsing the Excel file for every image. The cache is not
    invalidated if the file is modified while the process is running.
    """
    return pd.read_excel(xlsx_path, sheet_name=None)


def _label_to_idx(raw_label):
    """Convert a label cell (numeric code or category name) to an int index.

    Raises:
        ValueError: if the value is neither numeric nor a key of
            ``property_label_to_idx``.
    """
    text = str(raw_label).strip()
    try:
        # float() first so codes stored as '2.0' are accepted too.
        return int(float(text))
    except (ValueError, OverflowError):
        pass
    if text in property_label_to_idx:
        return property_label_to_idx[text]
    raise ValueError(
        'error label {}, which not found in property_label_to_idx'.format(raw_label))


def check_label(xlsx_path, img_file, property='舌色', patient_info=None):
    """Look up the label of the patient an image file belongs to.

    The patient name is the part of ``img_file`` before the first '-'. Each
    sheet is searched for a column whose header contains ``property`` and a
    column whose header contains '姓名' (name). The first such column pair in
    which the patient name occurs exactly once, with a non-empty label cell,
    yields the result.

    Args:
        xlsx_path: path of the patient-info workbook (also shown in log output).
        img_file: image file name without directory.
        property: substring identifying the label column header.
        patient_info: optional pre-loaded ``{sheet_name: DataFrame}`` mapping;
            when given, ``xlsx_path`` is not read.

    Returns:
        ``(True, name, label_idx)`` when found, else
        ``(False, patient_name, None)``.

    Raises:
        ValueError: if a found label is neither numeric nor a known category.
    """
    if patient_info is None:
        patient_info = _read_patient_sheets(xlsx_path)
    patient_name = img_file.split('-')[0]
    print('searching [{}] in [{}]'.format(patient_name, xlsx_path))
    for sheet_name, sheet in patient_info.items():
        # Headers may be non-strings (e.g. ints); compare on their text.
        name_columns = [col for col in sheet.columns if '姓名' in str(col)]
        for prop_col in sheet.columns:
            if property not in str(prop_col):
                continue
            for name_col in name_columns:
                flags = sheet[name_col] == patient_name
                matched_names = sheet[name_col][flags].values
                labels = sheet[prop_col][flags].values
                # Skip missing or ambiguous (duplicate-name) matches.
                if len(matched_names) != 1 or len(labels) != 1:
                    continue
                raw_label = labels[0]
                if pd.isna(raw_label):
                    # Empty cell: this row carries no label; keep searching.
                    continue
                tmp_name = matched_names[0]
                label = _label_to_idx(raw_label)
                print('I find [{0} → {1}] in [{2}].[{3}]'.format(tmp_name, label, sheet_name, name_col))
                return True, tmp_name, label
    return False, patient_name, None


def preprocess_data():
    rootdir = './tongue-tops'
    property = '舌色'
    list = os.listdir(rootdir)  # 列出文件夹下所有的目录与文件
    for i in range(0, len(list)):
        path = os.path.join(rootdir, list[i])
        if os.path.isdir(path):
            # take image dir and xlsx file
            sub_paths = os.listdir(path)
            img_path = None
            xlsx_path = None
            for j in range(0, len(sub_paths)):
                sub_path = os.path.join(path, sub_paths[j])
                if os.path.isdir(sub_path):
                    img_path = sub_path
                elif os.path.isfile(sub_path):
                    xlsx_path = sub_path
            # take patient info
            # take images
            processed_dir = './processed-data'
            if not os.path.exists(processed_dir):
                os.makedirs(processed_dir)
            img_files = os.listdir(img_path)

            for k in range(0, len(img_files)):
                img_file = os.path.join(img_path, img_files[k])
                img_file_only = img_files[k]
                islabeled, name, label = check_label(xlsx_path, img_file_only, property=property)
                if islabeled:
                    tmpstrs = os.path.splitext(img_file_only)
                    dst_img_file_only = tmpstrs[0] + '-' + property + '_' + str(label) + tmpstrs[1]
                    dst_img_file = os.path.join(processed_dir, dst_img_file_only)
                    copyfile(img_file, dst_img_file)

# if __name__ == "__main__":
#     preprocess_data()
