import PIL
import numpy as np
import os
import pandas as pd
from shutil import copyfile
import torch

# Index <-> label-name lookup tables for eight classes; the two dicts are
# inverses of each other. Presumably these are the tongue-colour ('舌色')
# classes -- TODO confirm; neither table is referenced elsewhere in the
# visible part of this file.
property_idx_to_label = {0: '淡红', 1: '淡白', 2: '红', 3: '淡紫', 4: '紫暗', 5: '紫红', 6: '青紫', 7: '其他'}
property_label_to_idx = {'淡红': 0, '淡白': 1, '红': 2, '淡紫': 3, '紫暗': 4, '紫红': 5, '青紫': 6, '其他': 7}

# property_idx_to_shetai = {0: '润', 1: '欠润', 2: '滑', 3: '燥', 4: '糙'}
# Text label -> 1-based index, used by check_label() for non-numeric cells.
# preprocess_data() subtracts 1 afterwards for every batch directory except
# '软件所-第一批数据'.
property_shetai_to_idx = {'润': 1, '欠润': 2, '滑': 3, '燥': 4, '糙': 5} # 1-based: 1 still has to be subtracted

class Image_DataSet(torch.utils.data.Dataset):
    """Dataset over labelled image files stored flat in one directory.

    The integer label of every sample is parsed from its file name: the
    extension is dropped, the part after the last ``'-'`` is taken, and of
    that the part after the last ``'_'`` is converted with ``int`` (this is
    the ``<name>-<property>_<label><ext>`` form written by
    ``preprocess_data``).

    Args:
        processed_dir: Directory containing the labelled image files.
        transform: Optional callable applied to the RGB ``PIL.Image``. When
            it is ``None`` the PIL image itself is returned.
    """

    def __init__(self, processed_dir = './processed-data', transform=None):
        self.processed_dir = processed_dir
        self.transform = transform

    def _list_images(self):
        """Return the directory entries in sorted order.

        ``os.listdir`` gives no ordering guarantee; sorting makes the
        index -> file mapping reproducible.
        """
        return sorted(os.listdir(self.processed_dir))

    def _image_name(self, index):
        """Return the file name for ``index``, wrapping around the listing.

        Raises:
            IndexError: If the directory contains no entries.
        """
        imgs = self._list_images()
        if not imgs:
            raise IndexError(
                'no images found in {}'.format(self.processed_dir))
        # Wrap with the full length: wrapping with ``len(imgs) - 1`` would
        # never yield the last file and would divide by zero for one file.
        return imgs[index % len(imgs)]

    @staticmethod
    def _label_from_name(img_file_only):
        """Parse the integer label encoded at the end of a file name."""
        stem = os.path.splitext(img_file_only)[0]
        return int(stem.split('-')[-1].split('_')[-1])

    def __getitem__(self, index):
        """Return ``(image, label)`` for the sample at ``index``.

        Raises:
            IndexError: If the directory contains no entries.
            ValueError: If the file name does not end in an integer label.
        """
        # ``import PIL`` alone does not load the ``PIL.Image`` submodule, so
        # it is imported explicitly here.
        from PIL import Image

        img_file_only = self._image_name(index)
        img_file = os.path.join(self.processed_dir, img_file_only)
        # ``convert`` returns an independent image, so the file handle can
        # be released as soon as the conversion is done.
        with Image.open(img_file) as opened:
            source_image = opened.convert('RGB')

        if self.transform is None:
            image = source_image
        else:
            image = self.transform(source_image)

        # For tensors, swap dims 1 and 2 whenever dim 1 is the smaller one
        # (assumes a channel-first (C, H, W) layout -- TODO confirm).
        if torch.is_tensor(image) and image.shape[1] < image.shape[2]:
            image = torch.transpose(image, 2, 1)

        return image, self._label_from_name(img_file_only)

    def __len__(self):
        """Return the number of entries in ``processed_dir``."""
        return len(self._list_images())


def check_label(xlsx_path, img_file, property='舌色'):
    """Look up the label recorded for the patient an image belongs to.

    The patient name is the part of ``img_file`` before the first ``'-'``.
    Every sheet of the workbook is searched for a column whose header
    contains ``property`` together with a column whose header contains
    ``'姓名'``. The first such pair in which exactly one row matches the
    patient name decides the result.

    NOTE: the workbook is re-read from disk on every call.

    Args:
        xlsx_path: Path of the Excel workbook to search.
        img_file: Image file name of the form ``<patient name>-...``.
        property: Substring identifying the header of the label column.

    Returns:
        A ``(found, name, label)`` tuple. On success ``found`` is ``True``
        and ``label`` is an ``int``: numeric cells are truncated to an
        integer, text cells are mapped through ``property_shetai_to_idx``.
        Otherwise ``found`` is ``False`` and ``label`` is ``None``.
    """
    import math

    patient_info = pd.read_excel(xlsx_path, sheet_name=None)
    patient_name = img_file.split('-')[0]
    print('searching [{}] in [{}]'.format(img_file, xlsx_path))
    for sheet_name, sheet in patient_info.items():
        # Headers are not guaranteed to be strings (a sheet may have numeric
        # headers), so the substring tests go through ``str``.
        label_cols = [col for col in sheet.columns if property in str(col)]
        name_cols = [col for col in sheet.columns if '姓名' in str(col)]
        for col_nm1 in label_cols:
            for col_nm2 in name_cols:
                flags = (patient_name == sheet[col_nm2])
                matched_names = sheet[col_nm2][flags].values
                matched_labels = sheet[col_nm1][flags].values
                # Only an unambiguous single-row match is accepted.
                if len(matched_names) != 1 or len(matched_labels) != 1:
                    continue
                tmp_name = matched_names[0]
                label = matched_labels[0]
                if pd.isna(label):
                    return False, patient_name, None

                # Parse via the string form so numeric text such as '3.0'
                # is handled the same way as a numeric cell.
                try:
                    numeric = float(str(label))
                except ValueError:
                    numeric = None

                if numeric is not None:
                    if not math.isfinite(numeric):
                        print('error label {}, which is not a finite number'.format(label))
                        return False, tmp_name, None
                    label = int(numeric)
                else:
                    mapped = property_shetai_to_idx.get(label)
                    if mapped is None:
                        print('error label {}, which not found in property_shetai_to_idx'.format(label))
                        return False, tmp_name, None
                    label = mapped
                print('I find [{0} → {1}] in [{2}].[{3}]'.format(tmp_name, label, sheet_name, col_nm2))
                return True, tmp_name, label
    return False, patient_name, None


def preprocess_data(rootdir='E:\\workspace\\zhongyi',
                    processed_dir='./processed-data-shetai',
                    property_name='舌苔润燥',
                    unshifted_dirs=('软件所-第一批数据',)):
    """Copy every labelled image into one flat directory, label in the name.

    Each sub-directory of ``rootdir`` is treated as one batch: its own
    sub-directories hold images and an Excel workbook directly inside it
    holds the labels. For every image whose patient has a label (looked up
    with ``check_label``) the file is copied to ``processed_dir`` as
    ``<stem>-<property_name>_<label><ext>``.

    Args:
        rootdir: Directory containing the batch directories.
        processed_dir: Output directory; created when missing.
        property_name: Label column to look up; also embedded in the output
            file names.
        unshifted_dirs: Names of batch directories whose labels are written
            unchanged. Labels from all other batches have 1 subtracted.
    """
    workbook_exts = ('.xls', '.xlsx', '.xlsm', '.xlsb', '.ods')
    os.makedirs(processed_dir, exist_ok=True)

    for batch_name in os.listdir(rootdir):
        path = os.path.join(rootdir, batch_name)
        if not os.path.isdir(path):
            continue

        # Split the batch directory into image folders and the workbook.
        img_paths = []
        xlsx_path = None
        for entry in os.listdir(path):
            sub_path = os.path.join(path, entry)
            if os.path.isdir(sub_path):
                img_paths.append(sub_path)
            elif (os.path.isfile(sub_path)
                  and entry.lower().endswith(workbook_exts)
                  and not entry.startswith('~$')):
                # Only real workbooks qualify; '~$' files are Office lock
                # files and other stray files are not label sources.
                xlsx_path = sub_path

        if xlsx_path is None:
            if img_paths:
                print('no workbook found in [{}], skipping'.format(path))
            continue

        for img_path in img_paths:
            for img_file_only in os.listdir(img_path):
                img_file = os.path.join(img_path, img_file_only)
                if not os.path.isfile(img_file):
                    continue
                islabeled, _, label = check_label(
                    xlsx_path, img_file_only, property=property_name)
                if not islabeled or label is None:
                    continue
                if label >= 6:
                    continue
                if batch_name in unshifted_dirs:
                    new_label = label
                else:
                    new_label = label - 1
                if new_label < 0:
                    # A negative label would be written as '_-1', which the
                    # dataset's '-'-splitting file-name parser reads back as
                    # a positive class, so such samples are dropped.
                    print('skipping [{0}]: label {1} is out of range'.format(
                        img_file_only, label))
                    continue
                print('原始label: {0} 重写为 {1}'.format(label, new_label))
                stem, ext = os.path.splitext(img_file_only)
                dst_img_file_only = stem + '-' + property_name + '_' + str(new_label) + ext
                dst_img_file = os.path.join(processed_dir, dst_img_file_only)
                copyfile(img_file, dst_img_file)

def crop_img(rootdir='E:\\workspace\\zhongyi\\软件所-第一批数据\\JDYT\\',
             cropeddir='E:\\workspace\\zhongyi\\软件所-第一批数据\\croped\\'):
    """Crop a fixed region out of every image in ``rootdir``.

    Portrait images (width < height) and all other images use different
    hard-coded boxes. Each crop is saved to ``cropeddir`` as
    ``<CJK characters of the source file name>-croped.jpg``.

    NOTE: source files whose names contain the same CJK characters map to
    the same output name, so later ones overwrite earlier ones.

    Args:
        rootdir: Directory holding the source images.
        cropeddir: Output directory; created when missing.
    """
    from PIL import Image
    import re

    os.makedirs(cropeddir, exist_ok=True)
    cjk_chars = re.compile('[\u4e00-\u9fa5]')

    for file_name in os.listdir(rootdir):
        img_file = os.path.join(rootdir, file_name)
        if not os.path.isfile(img_file):
            continue
        # ``convert`` returns an independent image, so the source file can
        # be closed straight away.
        with Image.open(img_file) as opened:
            img_hold = opened.convert('RGB')

        w, h = img_hold.size
        if w < h:
            # Portrait: symmetric horizontal margins, 1200 px tall band
            # ending at 90% of the height.
            left = int((0.8/3)*w)
            right = w - left
            bottom = int((9/10)*h)
            top = bottom - 1200
        else:
            # Otherwise: fixed-size window, kept in whole pixels.
            left = int((0.8 / 10) * w)
            right = left + round(1200*1.1)
            top = int((0.9 / 3) * h)
            bottom = top + round(900*1.1)

        region = img_hold.crop((left, top, right, bottom))

        name = ''.join(cjk_chars.findall(file_name))
        name = name + '-' + 'croped' + '.jpg'
        region.save(os.path.join(cropeddir, name))



if __name__ == "__main__":
    # Script entry point: build the flat, labelled image directory.
    preprocess_data()
    # Cropping step, currently disabled:
    # crop_img()
