Commit 103c8ddd authored by zhiyang.zhou

Support specifying training and test set directories

parent 445d587e
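
This commit replaces the hard-coded './processed-data' path with two command-line options. A minimal sketch of how the new options behave (standalone, not the project's full parser; the defaults are taken from the diff below, everything else is illustrative):

# Standalone sketch of the two options added in this commit.
import argparse

parser = argparse.ArgumentParser(description='PyTorch Tongue-diagnosis Training')
parser.add_argument('--training_data_dir', default='./processed-data',
                    help='directory of training data')
parser.add_argument('--test_data_dir', default='./processed-testdata',
                    help='directory of test data')

# Passing no flags keeps the old behavior; e.g.
#   --training_data_dir /data/train --test_data_dir /data/test
# points the loaders elsewhere.
args = parser.parse_args([])
print(args.training_data_dir, args.test_data_dir)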
@@ -48,7 +48,7 @@ def check_label(xlsx_path, img_file, property='舌色'):
     patient_info = pd.read_excel(xlsx_path, sheet_name=None)
     patient_name = img_file.split('-')[0]
-    print('searching [{}] in [{}]'.format(patient_name, xlsx_path))
+    print('searching [{}] in [{}]'.format(img_file, xlsx_path))
     for sheet_name, sheet in patient_info.items():
         # print(v)
         for col_nm1 in sheet.columns:
@@ -64,6 +64,8 @@ def check_label(xlsx_path, img_file, property='舌色'):
                 label = label[0]
                 # if type(label) == float64:
                 #     label=int(label)
+                if pd.isna(label):
+                    return False, patient_name, None
                 if is_digit(str(label)):
                     label = int(label)
                 else:
@@ -73,14 +75,17 @@ def check_label(xlsx_path, img_file, property='舌色'):
                     print('error label {}, which not found in property_label_to_idx'.format(label))
                     # raise ValueError(
                     #     'error label {}, which not found in property_label_to_idx'.format(label))
-                    label = None
-                return False, tmp_name, label
+                    return False, tmp_name, None
                 print('I find [{0} → {1}] in [{2}].[{3}]'.format(tmp_name, label, sheet_name, col_nm2))
                 return True, tmp_name, label
     return False, patient_name, None
 
 def preprocess_data():
+    processed_dir = './processed-data'
+    if not os.path.exists(processed_dir):
+        os.makedirs(processed_dir)
     rootdir = './tongue-tops'
     property = '舌色'
     list = os.listdir(rootdir)  # list all directories and files under the folder
@@ -100,11 +105,7 @@ def preprocess_data():
     # take images
     for img_path in img_paths:
-        processed_dir = './processed-data'
-        if not os.path.exists(processed_dir):
-            os.makedirs(processed_dir)
         img_files = os.listdir(img_path)
         for k in range(0, len(img_files)):
             img_file = os.path.join(img_path, img_files[k])
             img_file_only = img_files[k]
...
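
The pd.isna guard added above matters because pandas returns NaN for an empty spreadsheet cell, and NaN would previously slip past the digit check into the label lookup and be reported as an error label. A standalone sketch of the behavior:

# Standalone sketch: why check_label now rejects empty cells up front.
import pandas as pd

label = float('nan')   # what read_excel yields for an empty cell
print(pd.isna(label))  # True -> new code returns (False, patient_name, None)
print(str(label))      # 'nan' -> the string the old code went on to inspect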
@@ -12,9 +12,9 @@ from models import wideresnet, resnet
 from torchvision import models
 
-parser = argparse.ArgumentParser(description='PyTorch DenseNet Training')
+parser = argparse.ArgumentParser(description='PyTorch Tongue-diagnosis Training')
 parser.add_argument('--model_name', default='resnet-50',
-                    help='models: densenet121, resnet-50, resnet-101, inception_v3, vgg19_bn')
+                    help='models: densenet121, resnet-50, vgg19_bn')
 parser.add_argument('--gpuid', default='0', type=str, help='which gpu to use')
 parser.add_argument('--no_cuda', action='store_true', default=False, help='disables CUDA training')
 parser.add_argument('-b', '--batch_size', default=64, type=int, help='mini-batch size (default: 64)')
@@ -26,6 +26,8 @@ parser.add_argument('--schedule', type=int, nargs='+', default=[25, 30, 35],
                     help='Decrease learning rate at these epochs.')
 parser.add_argument('--gamma', type=float, default=0.2, help='LR is multiplied by gamma on schedule.')
 parser.add_argument('--model_dir', default='dnn_models/tongue_modes/', help='directory of model for saving checkpoint')
+parser.add_argument('--training_data_dir', default='./processed-data', help='directory of training data')
+parser.add_argument('--test_data_dir', default='./processed-testdata', help='directory of test data')
 parser.add_argument('--save_freq', '-s', default=1, type=int, metavar='N', help='save frequency')
 parser.add_argument('--log_interval', type=int, default=10, metavar='N',
                     help='how many batches to wait before logging training status')
@@ -52,19 +54,31 @@ torch.backends.cudnn.enabled = False
 kwargs = {'num_workers': 4, 'pin_memory': True} if use_cuda else {}
 
 normalizer = transforms.Normalize(mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
                                   std=[x / 255.0 for x in [63.0, 62.1, 66.7]])
 
+# training data
 transform_train = transforms.Compose([
     transforms.ToTensor(),
     transforms.ToPILImage(),
     transforms.Resize((224)),
-    # transforms.RandomCrop((224, 224), padding=14),
     transforms.RandomHorizontalFlip(),
     transforms.ToTensor(),
     normalizer
 ])
 
-processed_dir = './processed-data'
-tongue_dataset_train = Image_DataSet(processed_dir=processed_dir, transform=transform_train)
+training_data_dir = args.training_data_dir
+tongue_dataset_train = Image_DataSet(processed_dir=training_data_dir, transform=transform_train)
 tongue_train_loader = torch.utils.data.DataLoader(tongue_dataset_train, batch_size=args.batch_size, shuffle=True,
                                                   **kwargs)
+# test data
+transform_test = transforms.Compose([
+    transforms.ToTensor(),
+    transforms.ToPILImage(),
+    transforms.Resize((298, 224)),
+    transforms.ToTensor(),
+    normalizer
+])
+test_data_dir = args.test_data_dir
+tongue_dataset_test = Image_DataSet(processed_dir=test_data_dir, transform=transform_test)
+tongue_test_loader = torch.utils.data.DataLoader(tongue_dataset_test, batch_size=args.batch_size, shuffle=False,
+                                                 **kwargs)
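
Note the asymmetry between the two pipelines above: Resize(224) with a single int scales the shorter image edge to 224 and preserves aspect ratio, while Resize((298, 224)) forces an exact (height, width) output; the test loader also drops the random flip and disables shuffling. A standalone check of the two Resize forms:

# Standalone check of the two Resize signatures used above (torchvision).
from PIL import Image
from torchvision import transforms

img = Image.new('RGB', (400, 300))              # PIL size is (width, height)
print(transforms.Resize(224)(img).size)         # shorter edge -> 224, aspect kept
print(transforms.Resize((298, 224))(img).size)  # exact (h, w) = 298x224 -> (224, 298)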
@@ -78,10 +92,6 @@ def get_model(model_name):
         model = models.vgg19_bn(pretrained=True)
         model.classifier[6] = torch.nn.Linear(in_features=4096, out_features=NUM_CLASSES, bias=True)
         return model.to(device)
-    elif model_name == 'inception_v3':
-        model = models.inception_v3(pretrained=True)
-        model.fc = torch.nn.Linear(in_features=2048, out_features=1000, bias=True)
-        return model.to(device)
     else:
         raise ValueError('Unsupport model: {0}', model_name)
@@ -176,7 +186,7 @@ def main():
         # evaluation on natural examples
         print('================================================================')
         eval_train(model, tongue_train_loader)
-        # eval_test(model, test_loader)
+        eval_test(model, tongue_test_loader)
         print('================================================================')
 
         # save checkpoint
...
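
With the test loader wired in, main() now evaluates on held-out data every epoch instead of only on the training set. eval_test itself is outside this diff; a sketch of what such a routine typically looks like, assuming batches of (image, label) tensors:

# Sketch only: the project's real eval_test is not shown in this diff.
import torch

def eval_test_sketch(model, loader, device='cuda'):
    model.eval()                          # disable dropout/batch-norm updates
    correct, total = 0, 0
    with torch.no_grad():                 # no gradients needed for evaluation
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            pred = model(data).argmax(dim=1)
            correct += (pred == target).sum().item()
            total += target.size(0)
    print('test accuracy: {:.2f}%'.format(100.0 * correct / total))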