Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in
Toggle navigation
T
tongue-diagnosis
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
zhiyang.zhou
tongue-diagnosis
Commits
103c8ddd
Commit
103c8ddd
authored
Jul 09, 2021
by
zhiyang.zhou
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
支持指定训练集和测试集目录
parent
445d587e
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
28 additions
and
17 deletions
+28
-17
img_dataset.py
img_dataset.py
+8
-7
train_tongue_diagnosis.py
train_tongue_diagnosis.py
+20
-10
No files found.
img_dataset.py
View file @
103c8ddd
...
@@ -48,7 +48,7 @@ def check_label(xlsx_path, img_file, property='舌色'):
...
@@ -48,7 +48,7 @@ def check_label(xlsx_path, img_file, property='舌色'):
patient_info
=
pd
.
read_excel
(
xlsx_path
,
sheet_name
=
None
)
patient_info
=
pd
.
read_excel
(
xlsx_path
,
sheet_name
=
None
)
patient_name
=
img_file
.
split
(
'-'
)[
0
]
patient_name
=
img_file
.
split
(
'-'
)[
0
]
print
(
'searching [{}] in [{}]'
.
format
(
patient_nam
e
,
xlsx_path
))
print
(
'searching [{}] in [{}]'
.
format
(
img_fil
e
,
xlsx_path
))
for
sheet_name
,
sheet
in
patient_info
.
items
():
for
sheet_name
,
sheet
in
patient_info
.
items
():
# print(v)
# print(v)
for
col_nm1
in
sheet
.
columns
:
for
col_nm1
in
sheet
.
columns
:
...
@@ -64,6 +64,8 @@ def check_label(xlsx_path, img_file, property='舌色'):
...
@@ -64,6 +64,8 @@ def check_label(xlsx_path, img_file, property='舌色'):
label
=
label
[
0
]
label
=
label
[
0
]
# if type(label) == float64:
# if type(label) == float64:
# label=int(label)
# label=int(label)
if
pd
.
isna
(
label
):
return
False
,
patient_name
,
None
if
is_digit
(
str
(
label
)):
if
is_digit
(
str
(
label
)):
label
=
int
(
label
)
label
=
int
(
label
)
else
:
else
:
...
@@ -73,14 +75,17 @@ def check_label(xlsx_path, img_file, property='舌色'):
...
@@ -73,14 +75,17 @@ def check_label(xlsx_path, img_file, property='舌色'):
print
(
'error label {}, which not found in property_label_to_idx'
.
format
(
label
))
print
(
'error label {}, which not found in property_label_to_idx'
.
format
(
label
))
# raise ValueError(
# raise ValueError(
# 'error label {}, which not found in property_label_to_idx'.format(label))
# 'error label {}, which not found in property_label_to_idx'.format(label))
label
=
None
return
False
,
tmp_name
,
None
return
False
,
tmp_name
,
label
print
(
'I find [{0} → {1}] in [{2}].[{3}]'
.
format
(
tmp_name
,
label
,
sheet_name
,
col_nm2
))
print
(
'I find [{0} → {1}] in [{2}].[{3}]'
.
format
(
tmp_name
,
label
,
sheet_name
,
col_nm2
))
return
True
,
tmp_name
,
label
return
True
,
tmp_name
,
label
return
False
,
patient_name
,
None
return
False
,
patient_name
,
None
def
preprocess_data
():
def
preprocess_data
():
processed_dir
=
'./processed-data'
if
not
os
.
path
.
exists
(
processed_dir
):
os
.
makedirs
(
processed_dir
)
rootdir
=
'./tongue-tops'
rootdir
=
'./tongue-tops'
property
=
'舌色'
property
=
'舌色'
list
=
os
.
listdir
(
rootdir
)
# 列出文件夹下所有的目录与文件
list
=
os
.
listdir
(
rootdir
)
# 列出文件夹下所有的目录与文件
...
@@ -100,11 +105,7 @@ def preprocess_data():
...
@@ -100,11 +105,7 @@ def preprocess_data():
# take images
# take images
for
img_path
in
img_paths
:
for
img_path
in
img_paths
:
processed_dir
=
'./processed-data'
if
not
os
.
path
.
exists
(
processed_dir
):
os
.
makedirs
(
processed_dir
)
img_files
=
os
.
listdir
(
img_path
)
img_files
=
os
.
listdir
(
img_path
)
for
k
in
range
(
0
,
len
(
img_files
)):
for
k
in
range
(
0
,
len
(
img_files
)):
img_file
=
os
.
path
.
join
(
img_path
,
img_files
[
k
])
img_file
=
os
.
path
.
join
(
img_path
,
img_files
[
k
])
img_file_only
=
img_files
[
k
]
img_file_only
=
img_files
[
k
]
...
...
train_tongue_diagnosis.py
View file @
103c8ddd
...
@@ -12,9 +12,9 @@ from models import wideresnet, resnet
...
@@ -12,9 +12,9 @@ from models import wideresnet, resnet
from
torchvision
import
models
from
torchvision
import
models
parser
=
argparse
.
ArgumentParser
(
description
=
'PyTorch
DenseNet
Training'
)
parser
=
argparse
.
ArgumentParser
(
description
=
'PyTorch
Tongue-diagnosis
Training'
)
parser
.
add_argument
(
'--model_name'
,
default
=
'resnet-50'
,
parser
.
add_argument
(
'--model_name'
,
default
=
'resnet-50'
,
help
=
'models: densenet121, resnet-50,
resnet-101, inception_v3,
vgg19_bn'
)
help
=
'models: densenet121, resnet-50, vgg19_bn'
)
parser
.
add_argument
(
'--gpuid'
,
default
=
'0'
,
type
=
str
,
help
=
'which gpu to use'
)
parser
.
add_argument
(
'--gpuid'
,
default
=
'0'
,
type
=
str
,
help
=
'which gpu to use'
)
parser
.
add_argument
(
'--no_cuda'
,
action
=
'store_true'
,
default
=
False
,
help
=
'disables CUDA training'
)
parser
.
add_argument
(
'--no_cuda'
,
action
=
'store_true'
,
default
=
False
,
help
=
'disables CUDA training'
)
parser
.
add_argument
(
'-b'
,
'--batch_size'
,
default
=
64
,
type
=
int
,
help
=
'mini-batch size (default: 64)'
)
parser
.
add_argument
(
'-b'
,
'--batch_size'
,
default
=
64
,
type
=
int
,
help
=
'mini-batch size (default: 64)'
)
...
@@ -26,6 +26,8 @@ parser.add_argument('--schedule', type=int, nargs='+', default=[25, 30, 35],
...
@@ -26,6 +26,8 @@ parser.add_argument('--schedule', type=int, nargs='+', default=[25, 30, 35],
help
=
'Decrease learning rate at these epochs.'
)
help
=
'Decrease learning rate at these epochs.'
)
parser
.
add_argument
(
'--gamma'
,
type
=
float
,
default
=
0.2
,
help
=
'LR is multiplied by gamma on schedule.'
)
parser
.
add_argument
(
'--gamma'
,
type
=
float
,
default
=
0.2
,
help
=
'LR is multiplied by gamma on schedule.'
)
parser
.
add_argument
(
'--model_dir'
,
default
=
'dnn_models/tongue_modes/'
,
help
=
'directory of model for saving checkpoint'
)
parser
.
add_argument
(
'--model_dir'
,
default
=
'dnn_models/tongue_modes/'
,
help
=
'directory of model for saving checkpoint'
)
parser
.
add_argument
(
'--training_data_dir'
,
default
=
'./processed-data'
,
help
=
'directory of training data'
)
parser
.
add_argument
(
'--test_data_dir'
,
default
=
'./processed-testdata'
,
help
=
'directory of test data'
)
parser
.
add_argument
(
'--save_freq'
,
'-s'
,
default
=
1
,
type
=
int
,
metavar
=
'N'
,
help
=
'save frequency'
)
parser
.
add_argument
(
'--save_freq'
,
'-s'
,
default
=
1
,
type
=
int
,
metavar
=
'N'
,
help
=
'save frequency'
)
parser
.
add_argument
(
'--log_interval'
,
type
=
int
,
default
=
10
,
metavar
=
'N'
,
parser
.
add_argument
(
'--log_interval'
,
type
=
int
,
default
=
10
,
metavar
=
'N'
,
help
=
'how many batches to wait before logging training status'
)
help
=
'how many batches to wait before logging training status'
)
...
@@ -52,19 +54,31 @@ torch.backends.cudnn.enabled = False
...
@@ -52,19 +54,31 @@ torch.backends.cudnn.enabled = False
kwargs
=
{
'num_workers'
:
4
,
'pin_memory'
:
True
}
if
use_cuda
else
{}
kwargs
=
{
'num_workers'
:
4
,
'pin_memory'
:
True
}
if
use_cuda
else
{}
normalizer
=
transforms
.
Normalize
(
mean
=
[
x
/
255.0
for
x
in
[
125.3
,
123.0
,
113.9
]],
normalizer
=
transforms
.
Normalize
(
mean
=
[
x
/
255.0
for
x
in
[
125.3
,
123.0
,
113.9
]],
std
=
[
x
/
255.0
for
x
in
[
63.0
,
62.1
,
66.7
]])
std
=
[
x
/
255.0
for
x
in
[
63.0
,
62.1
,
66.7
]])
#trining data
transform_train
=
transforms
.
Compose
([
transform_train
=
transforms
.
Compose
([
transforms
.
ToTensor
(),
transforms
.
ToTensor
(),
transforms
.
ToPILImage
(),
transforms
.
ToPILImage
(),
transforms
.
Resize
((
224
)),
transforms
.
Resize
((
224
)),
# transforms.RandomCrop((224, 224), padding=14),
transforms
.
RandomHorizontalFlip
(),
transforms
.
RandomHorizontalFlip
(),
transforms
.
ToTensor
(),
transforms
.
ToTensor
(),
normalizer
normalizer
])
])
processed_dir
=
'./processed-data'
training_data_dir
=
args
.
training_data_dir
tongue_dataset_train
=
Image_DataSet
(
processed_dir
=
processed
_dir
,
transform
=
transform_train
)
tongue_dataset_train
=
Image_DataSet
(
processed_dir
=
training_data
_dir
,
transform
=
transform_train
)
tongue_train_loader
=
torch
.
utils
.
data
.
DataLoader
(
tongue_dataset_train
,
batch_size
=
args
.
batch_size
,
shuffle
=
True
,
tongue_train_loader
=
torch
.
utils
.
data
.
DataLoader
(
tongue_dataset_train
,
batch_size
=
args
.
batch_size
,
shuffle
=
True
,
**
kwargs
)
**
kwargs
)
# test data
transform_test
=
transforms
.
Compose
([
transforms
.
ToTensor
(),
transforms
.
ToPILImage
(),
transforms
.
Resize
((
298
,
224
)),
transforms
.
ToTensor
(),
normalizer
])
test_data_dir
=
args
.
test_data_dir
tongue_dataset_test
=
Image_DataSet
(
processed_dir
=
test_data_dir
,
transform
=
transform_test
)
tongue_test_loader
=
torch
.
utils
.
data
.
DataLoader
(
tongue_dataset_test
,
batch_size
=
args
.
batch_size
,
shuffle
=
False
,
**
kwargs
)
def
get_model
(
model_name
):
def
get_model
(
model_name
):
...
@@ -78,10 +92,6 @@ def get_model(model_name):
...
@@ -78,10 +92,6 @@ def get_model(model_name):
model
=
models
.
vgg19_bn
(
pretrained
=
True
)
model
=
models
.
vgg19_bn
(
pretrained
=
True
)
model
.
classifier
[
6
]
=
torch
.
nn
.
Linear
(
in_features
=
4096
,
out_features
=
NUM_CLASSES
,
bias
=
True
)
model
.
classifier
[
6
]
=
torch
.
nn
.
Linear
(
in_features
=
4096
,
out_features
=
NUM_CLASSES
,
bias
=
True
)
return
model
.
to
(
device
)
return
model
.
to
(
device
)
elif
model_name
==
'inception_v3'
:
model
=
models
.
inception_v3
(
pretrained
=
True
)
model
.
fc
=
torch
.
nn
.
Linear
(
in_features
=
2048
,
out_features
=
1000
,
bias
=
True
)
return
model
.
to
(
device
)
else
:
else
:
raise
ValueError
(
'Unsupport model: {0}'
,
model_name
)
raise
ValueError
(
'Unsupport model: {0}'
,
model_name
)
...
@@ -176,7 +186,7 @@ def main():
...
@@ -176,7 +186,7 @@ def main():
# evaluation on natural examples
# evaluation on natural examples
print
(
'================================================================'
)
print
(
'================================================================'
)
eval_train
(
model
,
tongue_train_loader
)
eval_train
(
model
,
tongue_train_loader
)
# eval_test(model,
test_loader)
eval_test
(
model
,
tongue_
test_loader
)
print
(
'================================================================'
)
print
(
'================================================================'
)
# save checkpoint
# save checkpoint
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment