Commit 9a3e236d authored by zhiyang.zhou

Initial commit
# ==== img_dataset.py (module name inferred from the training script's import) ====
import os
from shutil import copyfile

import PIL.Image
import pandas as pd
import torch.utils.data

# Tongue-color class labels. The Chinese strings are kept verbatim because they
# are matched against cell values in the patients' Excel sheets:
# 淡红 = pale red, 淡白 = pale white, 红 = red, 淡紫 = pale purple,
# 紫暗 = dark purple, 紫红 = purplish red, 青紫 = bluish purple, 其他 = other.
property_idx_to_label = {0: '淡红', 1: '淡白', 2: '红', 3: '淡紫', 4: '紫暗', 5: '紫红', 6: '青紫', 7: '其他'}
property_label_to_idx = {label: idx for idx, label in property_idx_to_label.items()}
class Image_DataSet(torch.utils.data.Dataset):
    """Preprocessed tongue images; the class label is encoded in each file
    name as '<...>-<property>_<label><ext>' by preprocess_data() below."""

    def __init__(self, transform=None):
        self.processed_dir = './processed-data'
        self.transform = transform
        # Cache the file list once (sorted for a deterministic order) instead
        # of re-reading the directory on every access.
        self.imgs = sorted(os.listdir(self.processed_dir))

    def __getitem__(self, index):
        img_file = os.path.join(self.processed_dir, self.imgs[index])
        image = PIL.Image.open(img_file)
        if self.transform is not None:
            image = self.transform(image)
        # Parse the label from the file name, e.g. 'xxx-舌色_2.jpg' -> 2.
        stem = os.path.splitext(self.imgs[index])[0]
        y = int(stem.split('-')[-1].split('_')[-1])
        return image, y

    def __len__(self):
        return len(self.imgs)
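
# A minimal sketch (not part of the original commit) of the file-name
# convention __getitem__ relies on; the name below is hypothetical:
#
#   fname = 'patient01-舌色_2.jpg'
#   stem = os.path.splitext(fname)[0]              # 'patient01-舌色_2'
#   y = int(stem.split('-')[-1].split('_')[-1])    # -> 2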

def check_label(xlsx_path, img_file, prop='舌色'):
    """Look up the label of the patient named in `img_file` for property
    `prop` (tongue color) in the Excel workbook at `xlsx_path`.

    Returns (found, patient_name, label_index)."""

    def is_digit(s):
        try:
            float(s)  # accepts both integers and reals
            return True
        except ValueError:
            return False

    # sheet_name=None loads every sheet into a {sheet_name: DataFrame} dict.
    patient_info = pd.read_excel(xlsx_path, sheet_name=None)
    patient_name = img_file.split('-')[0]
    print('searching [{}] in [{}]'.format(patient_name, xlsx_path))
    for sheet_name, sheet in patient_info.items():
        for col_nm1 in sheet.columns:
            if prop in col_nm1:
                for col_nm2 in sheet.columns:
                    if '姓名' in col_nm2:  # '姓名' = patient-name column
                        flags = (patient_name == sheet[col_nm2])
                        tmp_name = sheet[col_nm2][flags].values
                        label = sheet[col_nm1][flags].values
                        if len(tmp_name) == 1 and len(label) == 1:
                            tmp_name = tmp_name[0]
                            label = label[0]
                            if pd.isna(label):
                                continue  # row exists but carries no label
                            if is_digit(str(label)):
                                label = int(label)
                            elif label in property_label_to_idx:
                                label = property_label_to_idx[label]
                            else:
                                raise ValueError(
                                    'label {} not found in property_label_to_idx'.format(label))
                            print('I find [{0} → {1}] in [{2}].[{3}]'.format(tmp_name, label, sheet_name, col_nm2))
                            return True, tmp_name, label
    return False, patient_name, None
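
# Hedged usage sketch for check_label (the workbook path and image file name
# are hypothetical; real paths come from preprocess_data() below):
#
#   found, name, label = check_label('./tongue-tops/group1/info.xlsx',
#                                    '张三-001.jpg', prop='舌色')
#   if found:
#       print(name, label)  # label indexes into property_idx_to_label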

def preprocess_data():
    """Walk ./tongue-tops, look up each image's tongue-color label in the
    accompanying Excel file, and copy labeled images into ./processed-data
    with the label appended to the file name."""
    rootdir = './tongue-tops'
    prop = '舌色'  # tongue color
    for entry in os.listdir(rootdir):  # list all sub-directories and files
        path = os.path.join(rootdir, entry)
        if os.path.isdir(path):
            # Each group directory holds one image directory and one xlsx file.
            img_path = None
            xlsx_path = None
            for sub in os.listdir(path):
                sub_path = os.path.join(path, sub)
                if os.path.isdir(sub_path):
                    img_path = sub_path
                elif os.path.isfile(sub_path):
                    xlsx_path = sub_path
            if img_path is None or xlsx_path is None:
                continue  # skip groups missing images or patient info
            processed_dir = './processed-data'
            if not os.path.exists(processed_dir):
                os.makedirs(processed_dir)
            for img_file_only in os.listdir(img_path):
                img_file = os.path.join(img_path, img_file_only)
                islabeled, name, label = check_label(xlsx_path, img_file_only, prop=prop)
                if islabeled:
                    # e.g. '张三-001.jpg' -> '张三-001-舌色_2.jpg'
                    stem, ext = os.path.splitext(img_file_only)
                    dst_img_file_only = stem + '-' + prop + '_' + str(label) + ext
                    copyfile(img_file, os.path.join(processed_dir, dst_img_file_only))
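
# Expected on-disk layout, reconstructed from the loop above (directory and
# file names are illustrative, not taken from the commit):
#
#   ./tongue-tops/
#       group1/
#           images/              <- any sub-directory: source photos
#               张三-001.jpg
#           patients.xlsx        <- any plain file: per-patient labels
#   ./processed-data/            <- created here; labeled copies land in it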

if __name__ == "__main__":
    preprocess_data()

# ==== training script ====
import argparse
import os

import torch
import torch.nn.functional as F
import torch.utils.data
import torchvision.transforms as transforms
from torch import optim

from img_dataset import Image_DataSet
from models import wideresnet, resnet
parser = argparse.ArgumentParser(description='PyTorch tongue-color classification training')
parser.add_argument('--model_name', default='wide_resnet-34x10',
                    help='model name, cifar-10 models: wide_resnet-34x10, cifar-100 models: resnet-50, resnet-101, '
                         'resnet-152')
parser.add_argument('--gpuid', default='0', type=str, help='which gpu to use')
parser.add_argument('--no_cuda', action='store_true', default=False, help='disables CUDA training')
parser.add_argument('-b', '--batch_size', default=64, type=int, help='mini-batch size (default: 64)')
parser.add_argument('--lr', type=float, default=0.1, metavar='LR', help='learning rate')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M', help='SGD momentum')
parser.add_argument('--weight_decay', '--wd', default=2e-4, type=float, metavar='W', help='weight decay')
parser.add_argument('--epochs', type=int, default=200, metavar='N', help='number of epochs to train')
parser.add_argument('--schedule', type=int, nargs='+', default=[100, 150],
                    help='decrease learning rate at these epochs')
parser.add_argument('--gamma', type=float, default=0.1, help='LR is multiplied by gamma on schedule.')
# train() reads args.log_interval when printing progress; the default of 10
# is an assumed value.
parser.add_argument('--log_interval', type=int, default=10, metavar='N',
                    help='how many batches to wait before logging training status')
parser.add_argument('--model_dir', default='dnn_models/tongue_modes/', help='directory of model for saving checkpoint')
parser.add_argument('--save_freq', '-s', default=1, type=int, metavar='N', help='save frequency')
args = parser.parse_args()
# 8 tongue-color classes: indices 0-7 in property_idx_to_label, including
# '其他' (other); a label of 7 would crash F.cross_entropy with only 7 outputs.
NUM_CLASSES = 8
use_cuda = not args.no_cuda and torch.cuda.is_available()
GPUID = args.gpuid
os.environ["CUDA_VISIBLE_DEVICES"] = str(GPUID)
device = torch.device("cuda" if use_cuda else "cpu")
torch.manual_seed(1)
# settings
model_dir = args.model_dir
if not os.path.exists(model_dir):
    os.makedirs(model_dir)
# training data and test data settings
kwargs = {'num_workers': 4, 'pin_memory': True} if use_cuda else {}
# CIFAR-10 channel statistics; currently unused because the normalizer is
# commented out of the pipeline below.
normalizer = transforms.Normalize(mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
                                  std=[x / 255.0 for x in [63.0, 62.1, 66.7]])
transform_train = transforms.Compose([
    # The dataset yields PIL images, so the crops can be applied directly;
    # no ToTensor()/ToPILImage() round-trip is needed first.
    transforms.RandomResizedCrop((224, 168)),
    transforms.RandomCrop((224, 168), padding=14),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    # normalizer
])
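
# Quick sanity-check sketch (not part of the commit): the pipeline above maps
# any RGB PIL image to a 3x224x168 float tensor with values in [0, 1].
#
#   import PIL.Image
#   img = PIL.Image.new('RGB', (640, 480))    # hypothetical blank input
#   assert transform_train(img).shape == torch.Size([3, 224, 168])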
tongue_dataset_train = Image_DataSet(transform=transform_train)
tongue_train_loader = torch.utils.data.DataLoader(tongue_dataset_train, batch_size=args.batch_size, shuffle=True,
                                                  **kwargs)

def get_model(model_name, num_classes=NUM_CLASSES):
    if model_name == 'wide_resnet-34x10':
        return wideresnet.WideResNet(num_classes=num_classes)
    elif model_name == 'resnet-50':
        return resnet.ResNet50(num_classes=num_classes)
    elif model_name == 'resnet-101':
        return resnet.ResNet101(num_classes=num_classes)
    elif model_name == 'resnet-152':
        return resnet.ResNet152(num_classes=num_classes)
    else:
        raise ValueError('Unsupported model: {}'.format(model_name))

def adjust_learning_rate(optimizer, epoch):
    """Decrease the learning rate by a factor of gamma at each scheduled epoch."""
    epoch_lr = args.lr
    for i in range(len(args.schedule)):
        if epoch >= args.schedule[i]:
            epoch_lr = args.lr * args.gamma ** (i + 1)
    for param_group in optimizer.param_groups:
        param_group['lr'] = epoch_lr
    return epoch_lr
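
# Worked example of the schedule above with the defaults
# (--lr 0.1 --schedule 100 150 --gamma 0.1):
#   epochs   1- 99 -> lr = 0.1
#   epochs 100-149 -> lr = 0.1 * 0.1  = 0.01
#   epochs 150-200 -> lr = 0.1 * 0.01 = 0.001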

def train(args, model, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, y) in enumerate(train_loader):
        if use_cuda:
            data, y = data.cuda(non_blocking=True), y.cuda(non_blocking=True)
        optimizer.zero_grad()
        # standard cross-entropy loss; F.cross_entropy already averages over
        # the batch, so no extra 1/batch_size factor is needed
        nat_logits = model(data)
        loss = F.cross_entropy(nat_logits, y)
        loss.backward()
        optimizer.step()
        # print progress
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

def eval_train(model, train_loader):
    model.eval()
    train_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in train_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # reduction='sum': accumulate per-sample losses, averaged over
            # the whole dataset below
            train_loss += F.cross_entropy(output, target, reduction='sum').item()
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()
    train_loss /= len(train_loader.dataset)
    print('Training: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
        train_loss, correct, len(train_loader.dataset),
        100. * correct / len(train_loader.dataset)))
    training_accuracy = correct / len(train_loader.dataset)
    return train_loss, training_accuracy

def eval_test(model, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.cross_entropy(output, target, reduction='sum').item()
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)
    print('Test: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
    test_accuracy = correct / len(test_loader.dataset)
    return test_loss, test_accuracy

def main():
    print(args)
    model = get_model(args.model_name, num_classes=NUM_CLASSES).to(device)
    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    for epoch in range(1, args.epochs + 1):
        # adjust learning rate for SGD
        adjust_learning_rate(optimizer, epoch)
        # standard (natural) training; no adversarial examples are involved
        train(args, model, tongue_train_loader, optimizer, epoch)
        # evaluation on the training set (no held-out test set is built yet)
        print('================================================================')
        eval_train(model, tongue_train_loader)
        # eval_test(model, test_loader)
        print('================================================================')
        # save checkpoint
        if epoch % args.save_freq == 0:
            torch.save(model.state_dict(),
                       os.path.join(model_dir, 'model-{}-epoch{}.pt'.format(args.model_name, epoch)))
            torch.save(optimizer.state_dict(),
                       os.path.join(model_dir, 'opt-{}-cpt_epoch{}.pt'.format(args.model_name, epoch)))

if __name__ == '__main__':
    main()
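
# Sketch (not part of the commit) of restoring a saved checkpoint, matching
# the file names written in main(); the epoch number is hypothetical:
#
#   model = get_model(args.model_name, num_classes=NUM_CLASSES).to(device)
#   ckpt = os.path.join(model_dir, 'model-{}-epoch{}.pt'.format(args.model_name, 200))
#   model.load_state_dict(torch.load(ckpt, map_location=device))
#   model.eval()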