Commit 098dc40b authored by ZhouZhiyang

Support choosing whether to use pretrained weights

parent 3c9c5251
@@ -163,7 +163,62 @@ def crop_img():
         region.save(croped_img_file)
 
-if __name__ == "__main__":
-    preprocess_data()
-    # crop_img()
+def stats():
+    excel_name_cnt = 0
+    excel_2_file_name_cnt = 0
+    rootdir = 'E:\\workspace\\zhongyi'
+    dirs = os.listdir(rootdir)  # list every entry under the root directory
+    for i in range(0, len(dirs)):
+        path = os.path.join(rootdir, dirs[i])
+        if os.path.isdir(path):
+            # take image dir and xlsx file
+            sub_paths = os.listdir(path)
+            img_paths = []
+            xlsx_path = None
+            for j in range(0, len(sub_paths)):
+                sub_path = os.path.join(path, sub_paths[j])
+                if os.path.isdir(sub_path):
+                    img_paths.append(sub_path)
+                elif os.path.isfile(sub_path):
+                    xlsx_path = sub_path
+            img_files = []
+            for img_path in img_paths:
+                cur_dir_img_files = os.listdir(img_path)
+                img_files.extend(cur_dir_img_files)
+            print('[{0}] contains [{1}] images in total'.format(path, len(img_files)))
+            # look up each patient name from the excel file among the image file names
+            patient_info = pd.read_excel(xlsx_path, sheet_name=None)
+            patient_names = []
+            for sheet_name, sheet in patient_info.items():
+                is_find_name = False
+                for col_nm1 in sheet.columns:
+                    if '姓名' in col_nm1:
+                        is_find_name = True
+                        cur_clo_names = sheet[col_nm1].values
+                        # print('--> found [{3}] people in [{0}].[{1}].[{2}]'.format(xlsx_path, sheet_name, col_nm1, len(cur_clo_names)))
+                        patient_names.extend(cur_clo_names)
+                if not is_find_name:
+                    print('WARNING: no column containing [姓名] found in [{0}].[{1}]'.format(xlsx_path, sheet_name))
+            patient_names = list(set(patient_names))
+            print('--> found [{1}] people in [{0}]'.format(xlsx_path, len(patient_names)))
+            find_name_cnt = 0
+            for name in patient_names:
+                if isinstance(name, float):
+                    # print('error name:', name)
+                    continue
+                for img in img_files:
+                    if str(name) in img:
+                        find_name_cnt += 1
+                        break
+            print('--> matched {0}/{1} excel names to image files'.format(find_name_cnt, len(patient_names)))
+            excel_name_cnt += len(patient_names)
+            excel_2_file_name_cnt += find_name_cnt
+            print()
+    print('--> summary: matched {0}/{1} excel names to image files'.format(excel_2_file_name_cnt, excel_name_cnt))
+
+# if __name__ == "__main__":
+#     # preprocess_data()
+#     # crop_img()
+#     stats()
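
For context on the lookup in stats() above: pd.read_excel(..., sheet_name=None) returns a dict mapping each sheet name to its DataFrame, which is why the code iterates over .items(). A minimal standalone sketch (the file name 'patients.xlsx' is hypothetical):

    # Minimal sketch: enumerate name columns across all sheets of a workbook.
    import pandas as pd

    sheets = pd.read_excel('patients.xlsx', sheet_name=None)  # dict: sheet name -> DataFrame
    for sheet_name, df in sheets.items():
        name_cols = [c for c in df.columns if '姓名' in str(c)]  # str() guards non-string headers
        print(sheet_name, name_cols)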
@@ -8,7 +8,8 @@ import numpy as np
 import torch.nn.functional as F
 from img_dataset import Image_DataSet
-from models import wideresnet, resnet
+# from torch.autograd import Variable
+# from graphviz import Digraph
 from torchvision import models
@@ -22,9 +23,9 @@ parser.add_argument('--lr', type=float, default=0.001, metavar='LR', help='learn
 parser.add_argument('--momentum', type=float, default=0.9, metavar='M', help='SGD momentum')
 parser.add_argument('--weight_decay', '--wd', default=0, type=float, metavar='W')
 parser.add_argument('--epochs', type=int, default=40, metavar='N', help='number of epochs to train')
-parser.add_argument('--schedule', type=int, nargs='+', default=[25, 30, 35],
+parser.add_argument('--schedule', type=int, nargs='+', default=[10, 20, 30],
                     help='Decrease learning rate at these epochs.')
-parser.add_argument('--gamma', type=float, default=0.2, help='LR is multiplied by gamma on schedule.')
+parser.add_argument('--gamma', type=float, default=0.5, help='LR is multiplied by gamma on schedule.')
 parser.add_argument('--model_dir', default='dnn_models/tongue_modes/shetai/', help='directory of model for saving checkpoint')
 parser.add_argument('--training_data_dir', default='./processed-data-shetai', help='directory of training data')
 parser.add_argument('--test_data_dir', default='./shetai-test', help='directory of test data')
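
The new defaults halve the learning rate at epochs 10, 20 and 30, taking the default lr of 0.001 to 0.0005, 0.00025 and finally 0.000125. A minimal sketch of the same milestone decay using PyTorch's built-in scheduler (the script itself may apply the schedule by hand; the model and optimizer here are stand-ins):

    # Sketch: MultiStepLR reproduces schedule=[10, 20, 30] with gamma=0.5.
    import torch

    model = torch.nn.Linear(10, 2)  # stand-in model for the sketch
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[10, 20, 30], gamma=0.5)
    for epoch in range(40):
        # ... one epoch of training ...
        scheduler.step()  # lr: 1e-3, then 5e-4 after epoch 10, 2.5e-4 after 20, 1.25e-4 after 30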
@@ -32,6 +33,8 @@ parser.add_argument('--save_freq', '-s', default=1, type=int, metavar='N', help=
 parser.add_argument('--log_interval', type=int, default=10, metavar='N',
                     help='how many batches to wait before logging training status')
 parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed (default: 1)')
+parser.add_argument('--pretrain', action='store_true', default=True,
+                    help='whether to use pretrain')
 args = parser.parse_args()
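
One caveat with the flag as added: action='store_true' combined with default=True means args.pretrain is always True, so passing --pretrain on the command line changes nothing and pretraining cannot be switched off. A minimal sketch of a toggle that works both ways (the --no-pretrain spelling is hypothetical, not part of this commit):

    # Sketch: a boolean flag that can actually be disabled on the command line.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--no-pretrain', dest='pretrain', action='store_false',
                        help='disable pretrained weights (enabled by default)')
    parser.set_defaults(pretrain=True)

    print(parser.parse_args([]).pretrain)                 # True
    print(parser.parse_args(['--no-pretrain']).pretrain)  # False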
@@ -83,15 +86,15 @@ tongue_test_loader = torch.utils.data.DataLoader(tongue_dataset_test, batch_size
 def get_model(model_name):
     if model_name == 'densenet121':
-        model = models.densenet121(pretrained=True)
+        model = models.densenet121(pretrained=args.pretrain)
         model.classifier = torch.nn.Linear(in_features=1024, out_features=NUM_CLASSES, bias=True)
         return model.to(device)
     elif model_name == 'resnet-50':
-        model = models.resnet50(pretrained=True)
+        model = models.resnet50(pretrained=args.pretrain)
         model.fc = torch.nn.Linear(in_features=2048, out_features=NUM_CLASSES, bias=True)
         return model.to(device)
     elif model_name == 'vgg19_bn':
-        model = models.vgg19_bn(pretrained=True)
+        model = models.vgg19_bn(pretrained=args.pretrain)
         model.classifier[6] = torch.nn.Linear(in_features=4096, out_features=NUM_CLASSES, bias=True)
         return model.to(device)
     else:
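
With args.pretrain set, each branch above loads ImageNet weights and then replaces the classification head with a fresh NUM_CLASSES-way linear layer, so only the head starts untrained. The training loop is not shown in this commit; a common alternative to fine-tuning every layer is freezing the pretrained backbone, sketched below (NUM_CLASSES = 5 is an assumed value for illustration):

    # Sketch: train only the replaced head of a pretrained resnet-50.
    import torch
    from torchvision import models

    NUM_CLASSES = 5  # assumed value; the real constant lives elsewhere in the script
    model = models.resnet50(pretrained=True)
    for p in model.parameters():
        p.requires_grad_(False)  # freeze the pretrained backbone
    model.fc = torch.nn.Linear(in_features=2048, out_features=NUM_CLASSES)  # new head, trainable
    optimizer = torch.optim.SGD(model.fc.parameters(), lr=0.001, momentum=0.9)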
@@ -171,6 +174,77 @@ def eval_test(model, test_loader):
     return test_loss, test_accuracy
 
+# NOTE: make_dot relies on Variable and Digraph, whose imports are commented
+# out at the top of this file; re-enable them before calling print_model_arch.
+def make_dot(var, params=None):
+    """Produces a Graphviz representation of the PyTorch autograd graph.
+
+    Blue nodes are the Variables that require grad, orange are Tensors
+    saved for backward in torch.autograd.Function.
+
+    Args:
+        var: output Variable
+        params: dict of (name, Variable) to add names to nodes that
+            require grad (TODO: make optional)
+    """
+    if params is not None:
+        assert isinstance(list(params.values())[0], Variable)  # list() needed on Python 3
+        param_map = {id(v): k for k, v in params.items()}
+    node_attr = dict(style='filled',
+                     shape='box',
+                     align='left',
+                     fontsize='12',
+                     ranksep='0.1',
+                     height='0.2')
+    dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
+    seen = set()
+
+    def size_to_str(size):
+        return '(' + ', '.join(['%d' % v for v in size]) + ')'
+
+    def add_nodes(var):
+        if var not in seen:
+            if torch.is_tensor(var):
+                dot.node(str(id(var)), size_to_str(var.size()), fillcolor='orange')
+            elif hasattr(var, 'variable'):
+                u = var.variable
+                name = param_map[id(u)] if params is not None else ''
+                node_name = '%s\n %s' % (name, size_to_str(u.size()))
+                dot.node(str(id(var)), node_name, fillcolor='lightblue')
+            else:
+                dot.node(str(id(var)), str(type(var).__name__))
+            seen.add(var)
+            if hasattr(var, 'next_functions'):
+                for u in var.next_functions:
+                    if u[0] is not None:
+                        dot.edge(str(id(u[0])), str(id(var)))
+                        add_nodes(u[0])
+            if hasattr(var, 'saved_tensors'):
+                for t in var.saved_tensors:
+                    dot.edge(str(id(t)), str(id(var)))
+                    add_nodes(t)
+    add_nodes(var.grad_fn)
+    return dot
+
+def print_model_arch(model):
+    x = torch.randn(1, 3, 298, 224).to(device)
+    y = model(x)
+    g = make_dot(y)
+    g.view()
+    # from torchsummary import summary
+    # summary(model, (3, 224, 224))
+    params = list(model.parameters())
+    k = 0
+    for i in params:
+        l = 1
+        print('layer shape: ' + str(list(i.size())))
+        for j in i.size():
+            l *= j
+        print('layer parameter count: ' + str(l))
+        k = k + l
+    print('total parameter count: ' + str(k))
+
 def main():
     print(args)
     model = get_model(args.model_name).to(device)
......
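
print_model_arch in the last hunk multiplies out each parameter tensor's shape by hand; the same totals can be cross-checked in one line each (model is assumed to be in scope, e.g. the result of get_model):

    # Sketch: one-line equivalents of print_model_arch's manual count.
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('total parameter count:', total, 'trainable:', trainable)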