zhiyang.zhou / tongue-diagnosis

Commit 098dc40b, authored Jul 26, 2021 by ZhouZhiyang
Commit message: 支持是否使用pretrain (support choosing whether to use pretrained weights)
parent 3c9c5251
Showing 2 changed files with 139 additions and 10 deletions (+139 −10):

img_dataset.py            +59 −4
train_tongue_diagnosis.py +80 −6
img_dataset.py
@@ -163,7 +163,62 @@ def crop_img():
             region.save(croped_img_file)
 
 
+def stats():
+    excel_name_cnt = 0
+    excel_2_file_name_cnt = 0
+    rootdir = 'E:\\workspace\\zhongyi'
+    dirs = os.listdir(rootdir)  # list all directories and files under the root folder
+    for i in range(0, len(dirs)):
+        path = os.path.join(rootdir, dirs[i])
+        if os.path.isdir(path):
+            # take image dir and xlsx file
+            sub_paths = os.listdir(path)
+            img_paths = []
+            xlsx_path = None
+            for j in range(0, len(sub_paths)):
+                sub_path = os.path.join(path, sub_paths[j])
+                if os.path.isdir(sub_path):
+                    img_paths.append(sub_path)
+                elif os.path.isfile(sub_path):
+                    xlsx_path = sub_path
+            img_files = []
+            for img_path in img_paths:
+                cur_dir_img_files = os.listdir(img_path)
+                img_files.extend(cur_dir_img_files)
+            print('there are [{1}] images in total under [{0}]'.format(path, len(img_files)))
+            # look up each patient name from the excel file among the image files
+            patient_info = pd.read_excel(xlsx_path, sheet_name=None)
+            patient_names = []
+            for sheet_name, sheet in patient_info.items():
+                is_find_name = False
+                for col_nm1 in sheet.columns:
+                    if '姓名' in col_nm1:  # '姓名' is the patient-name column header
+                        is_find_name = True
+                        cur_clo_names = sheet[col_nm1].values
+                        # print('--> found [{3}] people in [{0}].[{1}].[{2}]'.format(xlsx_path, sheet_name, col_nm1, len(cur_clo_names)))
+                        patient_names.extend(cur_clo_names)
+                if not is_find_name:
+                    print('WARNING, no column containing [姓名] found in [{0}].[{1}].[{2}]'.format(xlsx_path, sheet_name, col_nm1))
+            patient_names = list(set(patient_names))
+            print('--> found [{1}] people in [{0}]'.format(xlsx_path, len(patient_names)))
+            find_name_cnt = 0
+            for name in patient_names:
+                if isinstance(name, float):
+                    # print('error name:', name)
+                    continue
+                for img in img_files:
+                    if str(name) in img:
+                        find_name_cnt += 1
+                        break
+            print('--> matched images for {0}/{1} excel names'.format(find_name_cnt, len(patient_names)))
+            excel_name_cnt += len(patient_names)
+            excel_2_file_name_cnt += find_name_cnt
+            print()
+    print('--> summary: matched images for {0}/{1} excel names'.format(excel_2_file_name_cnt, excel_name_cnt))
+
+
-if __name__ == "__main__":
-    preprocess_data()
-    # crop_img()
+# if __name__ == "__main__":
+#     # preprocess_data()
+#     # crop_img()
+#     stats()
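The name-matching above hinges on pd.read_excel(xlsx_path, sheet_name=None), which returns a dict mapping sheet name to DataFrame rather than a single frame, so the code iterates over every sheet. A minimal sketch of the same lookup pattern; the workbook name and layout here are assumptions for illustration, not files from this repo:

    import pandas as pd

    # hypothetical workbook; the real patient spreadsheets differ
    sheets = pd.read_excel('patients.xlsx', sheet_name=None)  # dict: sheet name -> DataFrame
    for sheet_name, df in sheets.items():
        name_cols = [c for c in df.columns if '姓名' in str(c)]  # headers containing the name column marker
        for col in name_cols:
            # dropna() removes empty cells, which pandas reads back as float NaN;
            # this mirrors the isinstance(name, float) guard in stats()
            names = [str(n) for n in df[col].dropna().values]
            print(sheet_name, col, len(names))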
train_tongue_diagnosis.py
@@ -8,7 +8,8 @@ import numpy as np
 import torch.nn.functional as F
 from img_dataset import Image_DataSet
 from models import wideresnet, resnet
 # from torch.autograd import Variable
 # from graphviz import Digraph
 from torchvision import models
@@ -22,9 +23,9 @@ parser.add_argument('--lr', type=float, default=0.001, metavar='LR', help='learn
 parser.add_argument('--momentum', type=float, default=0.9, metavar='M', help='SGD momentum')
 parser.add_argument('--weight_decay', '--wd', default=0, type=float, metavar='W')
 parser.add_argument('--epochs', type=int, default=40, metavar='N', help='number of epochs to train')
-parser.add_argument('--schedule', type=int, nargs='+', default=[25, 30, 35],
+parser.add_argument('--schedule', type=int, nargs='+', default=[10, 20, 30],
                     help='Decrease learning rate at these epochs.')
-parser.add_argument('--gamma', type=float, default=0.2, help='LR is multiplied by gamma on schedule.')
+parser.add_argument('--gamma', type=float, default=0.5, help='LR is multiplied by gamma on schedule.')
 parser.add_argument('--model_dir', default='dnn_models/tongue_modes/shetai/', help='directory of model for saving checkpoint')
 parser.add_argument('--training_data_dir', default='./processed-data-shetai', help='directory of training data')
 parser.add_argument('--test_data_dir', default='./shetai-test', help='directory of test data')
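The changed defaults decay the learning rate earlier and more gently: the LR is multiplied by --gamma each time the epoch passes a milestone in --schedule. The scheduler itself is not visible in this diff, so the following is only a sketch of that semantics; the name adjust_learning_rate is an assumption, not necessarily what the repo uses:

    def adjust_learning_rate(optimizer, epoch, args):
        # hypothetical helper mirroring the --schedule/--gamma semantics above
        lr = args.lr
        for milestone in args.schedule:   # e.g. [10, 20, 30]
            if epoch >= milestone:
                lr *= args.gamma          # e.g. halve at each milestone
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

With the new defaults, an initial LR of 0.001 becomes 0.001 * 0.5**3 = 0.000125 after epoch 30, versus 0.001 * 0.2**3 = 8e-6 under the old [25, 30, 35] / 0.2 settings.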
@@ -32,6 +33,8 @@ parser.add_argument('--save_freq', '-s', default=1, type=int, metavar='N', help=
 parser.add_argument('--log_interval', type=int, default=10, metavar='N', help='how many batches to wait before logging training status')
 parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed (default: 1)')
+parser.add_argument('--pretrain', action='store_true', default=True, help='whether to use pretrain')
 args = parser.parse_args()
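One caveat with the new flag: action='store_true' combined with default=True means --pretrain can never actually be switched off from the command line, so training from scratch still requires editing the default. Two standard argparse patterns that make the toggle real; this is a sketch of generic argparse usage, not code from this repo:

    import argparse

    parser = argparse.ArgumentParser()

    # option 1: paired flags writing to one destination
    parser.add_argument('--pretrain', dest='pretrain', action='store_true', default=True)
    parser.add_argument('--no-pretrain', dest='pretrain', action='store_false')

    # option 2 (Python 3.9+): generates both --pretrain and --no-pretrain
    # parser.add_argument('--pretrain', action=argparse.BooleanOptionalAction, default=True)

    args = parser.parse_args(['--no-pretrain'])
    print(args.pretrain)  # False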
@@ -83,15 +86,15 @@ tongue_test_loader = torch.utils.data.DataLoader(tongue_dataset_test, batch_size
 def get_model(model_name):
     if model_name == 'densenet121':
-        model = models.densenet121(pretrained=True)
+        model = models.densenet121(pretrained=args.pretrain)
         model.classifier = torch.nn.Linear(in_features=1024, out_features=NUM_CLASSES, bias=True)
         return model.to(device)
     elif model_name == 'resnet-50':
-        model = models.resnet50(pretrained=True)
+        model = models.resnet50(pretrained=args.pretrain)
         model.fc = torch.nn.Linear(in_features=2048, out_features=NUM_CLASSES, bias=True)
         return model.to(device)
     elif model_name == 'vgg19_bn':
-        model = models.vgg19_bn(pretrained=True)
+        model = models.vgg19_bn(pretrained=args.pretrain)
         model.classifier[6] = torch.nn.Linear(in_features=4096, out_features=NUM_CLASSES, bias=True)
         return model.to(device)
     else:
         ...
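This hunk is the core of the commit: each backbone now loads ImageNet weights only when args.pretrain is set, and the classifier head is replaced to emit NUM_CLASSES logits either way. For reference, torchvision 0.13+ deprecated the boolean pretrained argument in favor of weights enums; an equivalent under that newer API might look like the sketch below, assuming a recent torchvision rather than the 2021-era one this commit targets:

    import torch
    from torchvision import models

    def get_resnet50(num_classes, use_pretrained=True):
        # weights=None yields randomly initialized parameters
        weights = models.ResNet50_Weights.IMAGENET1K_V1 if use_pretrained else None
        model = models.resnet50(weights=weights)
        model.fc = torch.nn.Linear(in_features=2048, out_features=num_classes, bias=True)
        return model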
@@ -171,6 +174,77 @@ def eval_test(model, test_loader):
     return test_loss, test_accuracy
 
 
+def make_dot(var, params=None):
+    """ Produces Graphviz representation of PyTorch autograd graph
+    Blue nodes are the Variables that require grad, orange are Tensors
+    saved for backward in torch.autograd.Function
+    Args:
+        var: output Variable
+        params: dict of (name, Variable) to add names to node that
+            require grad (TODO: make optional)
+    """
+    if params is not None:
+        # note: params.values()[0] is Python 2 style; Python 3 needs list(params.values())[0]
+        assert isinstance(params.values()[0], Variable)
+        param_map = {id(v): k for k, v in params.items()}
+
+    node_attr = dict(style='filled',
+                     shape='box',
+                     align='left',
+                     fontsize='12',
+                     ranksep='0.1',
+                     height='0.2')
+    dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
+    seen = set()
+
+    def size_to_str(size):
+        return '(' + (', ').join(['%d' % v for v in size]) + ')'
+
+    def add_nodes(var):
+        if var not in seen:
+            if torch.is_tensor(var):
+                dot.node(str(id(var)), size_to_str(var.size()), fillcolor='orange')
+            elif hasattr(var, 'variable'):
+                u = var.variable
+                name = param_map[id(u)] if params is not None else ''
+                node_name = '%s\n %s' % (name, size_to_str(u.size()))
+                dot.node(str(id(var)), node_name, fillcolor='lightblue')
+            else:
+                dot.node(str(id(var)), str(type(var).__name__))
+            seen.add(var)
+            if hasattr(var, 'next_functions'):
+                for u in var.next_functions:
+                    if u[0] is not None:
+                        dot.edge(str(id(u[0])), str(id(var)))
+                        add_nodes(u[0])
+            if hasattr(var, 'saved_tensors'):
+                for t in var.saved_tensors:
+                    dot.edge(str(id(t)), str(id(var)))
+                    add_nodes(t)
+    add_nodes(var.grad_fn)
+    return dot
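make_dot here is the classic community Graphviz autograd-walker snippet; note that the from graphviz import Digraph and from torch.autograd import Variable imports it depends on are commented out at the top of this file, so they would need re-enabling before it runs. The maintained packaging of the same idea is the torchviz library. A sketch of the equivalent call, assuming pip install torchviz, which is not a dependency declared by this repo:

    import torch
    from torchviz import make_dot  # would shadow the local helper above

    model = torch.nn.Linear(4, 2)   # any model works; stand-in for get_model(...)
    y = model(torch.randn(1, 4))
    dot = make_dot(y, params=dict(model.named_parameters()))
    dot.render('autograd_graph', format='png')  # writes autograd_graph.png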
+def print_model_arch(model):
+    x = torch.randn(1, 3, 298, 224).to(device)
+    y = model(x)
+    g = make_dot(y)
+    g.view()
+    # from torchsummary import summary
+    # summary(model, (3, 224, 224))
+    params = list(model.parameters())
+    k = 0
+    for i in params:
+        l = 1
+        print("shape of this layer: " + str(list(i.size())))
+        for j in i.size():
+            l *= j
+        print("parameter count of this layer: " + str(l))
+        k = k + l
+    print("total parameter count: " + str(k))
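The counting loop above multiplies out each parameter tensor's dimensions by hand; tensor.numel() gives the same product directly, so the whole tally collapses to a couple of lines. A compact equivalent in standard PyTorch, shown only for comparison:

    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('total parameters: {0}, trainable: {1}'.format(total, trainable))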
 
 
 def main():
     print(args)
     model = get_model(args.model_name).to(device)
     ...