Commit 3c9c5251 authored by zhiyang.zhou

sehtai

parent de5d8ac1
......@@ -8,6 +8,8 @@ import torch
# Class index -> label name for the 8 colour classes (presumably the tongue-colour
# property '舌色' that check_label defaults to -- confirm); index 7 is the catch-all entry.
property_idx_to_label = {0: '淡红', 1: '淡白', 2: '红', 3: '淡紫', 4: '紫暗', 5: '紫红', 6: '青紫', 7: '其他'}
# Inverse of property_idx_to_label: label name -> class index.
property_label_to_idx = {'淡红': 0, '淡白': 1, '红': 2, '淡紫': 3, '紫暗': 4, '紫红': 5, '青紫': 6, '其他': 7}
# property_idx_to_shetai = {0: '润', 1: '欠润', 2: '滑', 3: '燥', 4: '糙'}
# Label name -> index for the "shetai" labels. Unlike the table above, these values
# start at 1, so a consumer still has to subtract 1 to obtain a 0-based class index.
property_shetai_to_idx = {'润': 1, '欠润': 2, '滑': 3, '燥': 4, '糙': 5} # needs to have 1 subtracted again (values are 1-based)
class Image_DataSet(torch.utils.data.Dataset):
......@@ -21,7 +23,7 @@ class Image_DataSet(torch.utils.data.Dataset):
img_file = os.path.join(self.processed_dir, imgs[index])
source_image = PIL.Image.open(img_file).convert('RGB')
# source_image=np.array(source_image)
# tmp=np.array(source_image)
# image = np.expand_dims(source_image.transpose((2, 0, 1)), 0)
if self.transform is not None:
image = self.transform(source_image)
......@@ -69,12 +71,12 @@ def check_label(xlsx_path, img_file, property='舌色'):
if is_digit(str(label)):
label = int(label)
else:
if label in property_label_to_idx:
label = property_label_to_idx[label]
if label in property_shetai_to_idx:
label = property_shetai_to_idx[label]
else:
print('error label {}, which not found in property_label_to_idx'.format(label))
print('error label {}, which not found in property_shetai_to_idx'.format(label))
# raise ValueError(
# 'error label {}, which not found in property_label_to_idx'.format(label))
# 'error label {}, which not found in property_shetai_to_idx'.format(label))
return False, tmp_name, None
print('I find [{0} → {1}] in [{2}].[{3}]'.format(tmp_name, label, sheet_name, col_nm2))
return True, tmp_name, label
......@@ -82,12 +84,12 @@ def check_label(xlsx_path, img_file, property='舌色'):
def preprocess_data():
processed_dir = './processed-data'
processed_dir = './processed-data-shetai'
if not os.path.exists(processed_dir):
os.makedirs(processed_dir)
rootdir = './tongue-tops'
property = '舌'
rootdir = 'E:\\workspace\\zhongyi'
property = '舌苔润燥'
list = os.listdir(rootdir) # 列出文件夹下所有的目录与文件
for i in range(0, len(list)):
path = os.path.join(rootdir, list[i])
......@@ -111,10 +113,15 @@ def preprocess_data():
img_file_only = img_files[k]
islabeled, name, label = check_label(xlsx_path, img_file_only, property=property)
if islabeled and label!=None:
if label >= 8:
if label >= 6:
continue
if list[i] == '软件所-第一批数据':
new_label = label
else:
new_label = label - 1
print('原始label: {0} 重写为 {1}'.format(label, new_label))
tmpstrs = os.path.splitext(img_file_only)
dst_img_file_only = tmpstrs[0] + '-' + property + '_' + str(label) + tmpstrs[1]
dst_img_file_only = tmpstrs[0] + '-' + property + '_' + str(new_label) + tmpstrs[1]
dst_img_file = os.path.join(processed_dir, dst_img_file_only)
copyfile(img_file, dst_img_file)
......
......@@ -20,14 +20,14 @@ parser.add_argument('--no_cuda', action='store_true', default=False, help='disab
parser.add_argument('-b', '--batch_size', default=64, type=int, help='mini-batch size (default: 64)')
parser.add_argument('--lr', type=float, default=0.001, metavar='LR', help='learning rate')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M', help='SGD momentum')
parser.add_argument('--weight_decay', '--wd', default=2e-4, type=float, metavar='W')
parser.add_argument('--weight_decay', '--wd', default=0, type=float, metavar='W')
parser.add_argument('--epochs', type=int, default=40, metavar='N', help='number of epochs to train')
parser.add_argument('--schedule', type=int, nargs='+', default=[25, 30, 35],
help='Decrease learning rate at these epochs.')
parser.add_argument('--gamma', type=float, default=0.2, help='LR is multiplied by gamma on schedule.')
parser.add_argument('--model_dir', default='dnn_models/tongue_modes/', help='directory of model for saving checkpoint')
parser.add_argument('--training_data_dir', default='./processed-data', help='directory of training data')
parser.add_argument('--test_data_dir', default='./processed-testdata', help='directory of test data')
parser.add_argument('--model_dir', default='dnn_models/tongue_modes/shetai/', help='directory of model for saving checkpoint')
parser.add_argument('--training_data_dir', default='./processed-data-shetai', help='directory of training data')
parser.add_argument('--test_data_dir', default='./shetai-test', help='directory of test data')
parser.add_argument('--save_freq', '-s', default=1, type=int, metavar='N', help='save frequency')
parser.add_argument('--log_interval', type=int, default=10, metavar='N',
help='how many batches to wait before logging training status')
......@@ -61,7 +61,7 @@ transform_train = transforms.Compose([
transforms.Resize((224)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
normalizer
# normalizer
])
training_data_dir = args.training_data_dir
tongue_dataset_train = Image_DataSet(processed_dir=training_data_dir, transform=transform_train)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment