Commit ad2619dd authored by zhiyang.zhou's avatar zhiyang.zhou

增加”软件所的数据的预处理程序“

parent 3c16b603
...@@ -25,6 +25,8 @@ class Image_DataSet(torch.utils.data.Dataset): ...@@ -25,6 +25,8 @@ class Image_DataSet(torch.utils.data.Dataset):
# image = np.expand_dims(source_image.transpose((2, 0, 1)), 0) # image = np.expand_dims(source_image.transpose((2, 0, 1)), 0)
if self.transform is not None: if self.transform is not None:
image = self.transform(source_image) image = self.transform(source_image)
if image.shape[1]<image.shape[2]:
image = torch.transpose(image, 2, 1)
img_file_only = imgs[index] img_file_only = imgs[index]
tmpstrs = os.path.splitext(img_file_only)[0].split('-') tmpstrs = os.path.splitext(img_file_only)[0].split('-')
...@@ -68,8 +70,11 @@ def check_label(xlsx_path, img_file, property='舌色'): ...@@ -68,8 +70,11 @@ def check_label(xlsx_path, img_file, property='舌色'):
if label in property_label_to_idx: if label in property_label_to_idx:
label = property_label_to_idx[label] label = property_label_to_idx[label]
else: else:
raise ValueError( print('error label {}, which not found in property_label_to_idx'.format(label))
'error label {}, which not found in property_label_to_idx'.format(label)) # raise ValueError(
# 'error label {}, which not found in property_label_to_idx'.format(label))
label = None
return False, tmp_name, label
print('I find [{0} → {1}] in [{2}].[{3}]'.format(tmp_name, label, sheet_name, col_nm2)) print('I find [{0} → {1}] in [{2}].[{3}]'.format(tmp_name, label, sheet_name, col_nm2))
return True, tmp_name, label return True, tmp_name, label
return False, patient_name, None return False, patient_name, None
...@@ -103,11 +108,53 @@ def preprocess_data(): ...@@ -103,11 +108,53 @@ def preprocess_data():
img_file = os.path.join(img_path, img_files[k]) img_file = os.path.join(img_path, img_files[k])
img_file_only = img_files[k] img_file_only = img_files[k]
islabeled, name, label = check_label(xlsx_path, img_file_only, property=property) islabeled, name, label = check_label(xlsx_path, img_file_only, property=property)
if islabeled: if islabeled and label!=None:
if label >= 8:
continue
tmpstrs = os.path.splitext(img_file_only) tmpstrs = os.path.splitext(img_file_only)
dst_img_file_only = tmpstrs[0] + '-' + property + '_' + str(label) + tmpstrs[1] dst_img_file_only = tmpstrs[0] + '-' + property + '_' + str(label) + tmpstrs[1]
dst_img_file = os.path.join(processed_dir, dst_img_file_only) dst_img_file = os.path.join(processed_dir, dst_img_file_only)
copyfile(img_file, dst_img_file) copyfile(img_file, dst_img_file)
# if __name__ == "__main__": def crop_img():
# preprocess_data() from PIL import Image
import re
rootdir = 'E:\\workspace\\zhongyi\\软件所-第一批数据\\JDYT\\'
cropeddir = 'E:\\workspace\\zhongyi\\软件所-第一批数据\\croped\\'
list = os.listdir(rootdir)
for i in range(0, len(list)):
img_file = os.path.join(rootdir, list[i])
img_hold = Image.open(img_file).convert('RGB')
img_size = img_hold.size
if img_size[0]<img_size[1]:
w=img_size[0]
h=img_size[1]
left = int((0.8/3)*w)
right = w-left
bottom = int((9/10)*h)
top = bottom - 1200
box = (left, top, right, bottom)
else:
w = img_size[0]
h = img_size[1]
left = int((0.8 / 10) * w)
right = left + 1200*1.1
top = int((0.9 / 3) * h)
bottom = top+900*1.1
box = (left, top, right, bottom)
# 开始截取
region = img_hold.crop(box)
name = ''.join(re.findall('[\u4e00-\u9fa5]', list[i]))
name = name + '-' + 'croped' + '.jpg'
croped_img_file = os.path.join(cropeddir, name)
region.save(croped_img_file)
if __name__ == "__main__":
preprocess_data()
# crop_img()
...@@ -52,7 +52,7 @@ normalizer = transforms.Normalize(mean=[x / 255.0 for x in [125.3, 123.0, 113.9] ...@@ -52,7 +52,7 @@ normalizer = transforms.Normalize(mean=[x / 255.0 for x in [125.3, 123.0, 113.9]
transform_train = transforms.Compose([ transform_train = transforms.Compose([
transforms.ToTensor(), transforms.ToTensor(),
transforms.ToPILImage(), transforms.ToPILImage(),
transforms.Resize((298, 224)), transforms.Resize((224)),
# transforms.RandomCrop((224, 224), padding=14), # transforms.RandomCrop((224, 224), padding=14),
transforms.RandomHorizontalFlip(), transforms.RandomHorizontalFlip(),
transforms.ToTensor(), transforms.ToTensor(),
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment