zhiyang.zhou / tongue-diagnosis · Commits

Commit 3c9c5251, authored Jul 14, 2021 by zhiyang.zhou
Parent: de5d8ac1

    shetai

Showing 2 changed files, with 22 additions and 15 deletions (+22 / -15):
  img_dataset.py             +17 / -10
  train_tongue_diagnosis.py   +5 /  -5
img_dataset.py  (view file @ 3c9c5251)

@@ -8,6 +8,8 @@ import torch
 property_idx_to_label = {0: '淡红', 1: '淡白', 2: '红', 3: '淡紫', 4: '紫暗', 5: '紫红', 6: '青紫', 7: '其他'}
 property_label_to_idx = {'淡红': 0, '淡白': 1, '红': 2, '淡紫': 3, '紫暗': 4, '紫红': 5, '青紫': 6, '其他': 7}
+# property_idx_to_shetai = {0: '润', 1: '欠润', 2: '滑', 3: '燥', 4: '糙'}
+property_shetai_to_idx = {'润': 1, '欠润': 2, '滑': 3, '燥': 4, '糙': 5}  # 需要再次减去1 (must be decremented by 1 again)
 class Image_DataSet(torch.utils.data.Dataset):
@@ -21,7 +23,7 @@ class Image_DataSet(torch.utils.data.Dataset):
         img_file = os.path.join(self.processed_dir, imgs[index])
         source_image = PIL.Image.open(img_file).convert('RGB')
-        # source_image =np.array(source_image)
+        # tmp =np.array(source_image)
         # image = np.expand_dims(source_image.transpose((2, 0, 1)), 0)
         if self.transform is not None:
             image = self.transform(source_image)
@@ -69,12 +71,12 @@ def check_label(xlsx_path, img_file, property='舌色'):
     if is_digit(str(label)):
         label = int(label)
     else:
-        if label in property_label_to_idx:
-            label = property_label_to_idx[label]
+        if label in property_shetai_to_idx:
+            label = property_shetai_to_idx[label]
         else:
-            print('error label {}, which not found in property_label_to_idx'.format(label))
+            print('error label {}, which not found in property_shetai_to_idx'.format(label))
             # raise ValueError(
-            #     'error label {}, which not found in property_label_to_idx'.format(label))
+            #     'error label {}, which not found in property_shetai_to_idx'.format(label))
             return False, tmp_name, None
     print('I find [{0} → {1}] in [{2}].[{3}]'.format(tmp_name, label, sheet_name, col_nm2))
     return True, tmp_name, label
@@ -82,12 +84,12 @@ def check_label(xlsx_path, img_file, property='舌色'):
 def preprocess_data():
-    processed_dir = './processed-data'
+    processed_dir = './processed-data-shetai'
     if not os.path.exists(processed_dir):
         os.makedirs(processed_dir)
-    rootdir = './tongue-tops'
-    property = '舌色'
+    rootdir = 'E:\\workspace\\zhongyi'
+    property = '舌苔润燥'
     list = os.listdir(rootdir)  # 列出文件夹下所有的目录与文件 (list all directories and files under the folder)
     for i in range(0, len(list)):
         path = os.path.join(rootdir, list[i])
@@ -111,10 +113,15 @@ def preprocess_data():
             img_file_only = img_files[k]
             islabeled, name, label = check_label(xlsx_path, img_file_only, property=property)
             if islabeled and label != None:
-                if label >= 8:
+                if label >= 6:
                     continue
+                if list[i] == '软件所-第一批数据':
+                    new_label = label
+                else:
+                    new_label = label - 1
+                print('原始label: {0} 重写为 {1}'.format(label, new_label))
                 tmpstrs = os.path.splitext(img_file_only)
-                dst_img_file_only = tmpstrs[0] + '-' + property + '_' + str(label) + tmpstrs[1]
+                dst_img_file_only = tmpstrs[0] + '-' + property + '_' + str(new_label) + tmpstrs[1]
                 dst_img_file = os.path.join(processed_dir, dst_img_file_only)
                 copyfile(img_file, dst_img_file)
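Note on this file's changes: the new property_shetai_to_idx dict maps the coating-moisture labels to indices 1–5, and its inline comment warns that they still need to be decremented by 1; the preprocess_data hunk then does exactly that for every batch except '软件所-第一批数据', whose labels are kept as-is, after first skipping anything with index 6 or above. A minimal sketch of that remapping follows; the dict and the batch name come from the diff, while remap_label() is a hypothetical helper written here only for illustration, and the reading that the first batch is already 0-based is inferred from the diff rather than stated anywhere in the repo.

# Sketch of the label remapping this commit introduces (see caveats above).
property_shetai_to_idx = {'润': 1, '欠润': 2, '滑': 3, '燥': 4, '糙': 5}

def remap_label(label, batch_dir):
    """Return the 0-based training label, or None if out of range."""
    if label >= 6:                        # only indices below 6 are valid shetai classes
        return None
    if batch_dir == '软件所-第一批数据':    # first batch: keep label as-is (assumption: already 0-based)
        return label
    return label - 1                      # later batches: shift the 1-based index down

print(remap_label(property_shetai_to_idx['欠润'], 'batch-2'))  # prints 1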
train_tongue_diagnosis.py  (view file @ 3c9c5251)

@@ -20,14 +20,14 @@ parser.add_argument('--no_cuda', action='store_true', default=False, help='disab
 parser.add_argument('-b', '--batch_size', default=64, type=int, help='mini-batch size (default: 64)')
 parser.add_argument('--lr', type=float, default=0.001, metavar='LR', help='learning rate')
 parser.add_argument('--momentum', type=float, default=0.9, metavar='M', help='SGD momentum')
-parser.add_argument('--weight_decay', '--wd', default=2e-4, type=float, metavar='W')
+parser.add_argument('--weight_decay', '--wd', default=0, type=float, metavar='W')
 parser.add_argument('--epochs', type=int, default=40, metavar='N', help='number of epochs to train')
 parser.add_argument('--schedule', type=int, nargs='+', default=[25, 30, 35],
                     help='Decrease learning rate at these epochs.')
 parser.add_argument('--gamma', type=float, default=0.2, help='LR is multiplied by gamma on schedule.')
-parser.add_argument('--model_dir', default='dnn_models/tongue_modes/', help='directory of model for saving checkpoint')
-parser.add_argument('--training_data_dir', default='./processed-data', help='directory of training data')
-parser.add_argument('--test_data_dir', default='./processed-testdata', help='directory of test data')
+parser.add_argument('--model_dir', default='dnn_models/tongue_modes/shetai/', help='directory of model for saving checkpoint')
+parser.add_argument('--training_data_dir', default='./processed-data-shetai', help='directory of training data')
+parser.add_argument('--test_data_dir', default='./shetai-test', help='directory of test data')
 parser.add_argument('--save_freq', '-s', default=1, type=int, metavar='N', help='save frequency')
 parser.add_argument('--log_interval', type=int, default=10, metavar='N',
                     help='how many batches to wait before logging training status')
@@ -61,7 +61,7 @@ transform_train = transforms.Compose([
     transforms.Resize((224)),
     transforms.RandomHorizontalFlip(),
     transforms.ToTensor(),
-    normalizer
+    # normalizer
 ])
 training_data_dir = args.training_data_dir
 tongue_dataset_train = Image_DataSet(processed_dir=training_data_dir, transform=transform_train)
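Note on this file's changes: the commit zeroes out weight decay, points the model/data directories at the new shetai sets, and comments normalizer out of transform_train, so the network now trains on raw [0,1] tensors straight from ToTensor(). Also worth noting: Resize((224)) passes the plain int 224 (the parentheses do not make a tuple), so torchvision scales the shorter edge to 224 rather than producing a square image. A sketch of how the post-commit defaults compose follows; the transform, paths, and Image_DataSet call mirror the diff, while the DataLoader wiring is ordinary torch.utils.data usage assumed here for illustration.

# Sketch of the post-commit training-data pipeline (see caveats above).
import torch
import torchvision.transforms as transforms
from img_dataset import Image_DataSet

transform_train = transforms.Compose([
    transforms.Resize((224)),        # (224) is just the int 224: shorter edge -> 224
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    # normalizer                     # disabled by this commit: raw [0,1] inputs
])

tongue_dataset_train = Image_DataSet(processed_dir='./processed-data-shetai',
                                     transform=transform_train)
train_loader = torch.utils.data.DataLoader(tongue_dataset_train,
                                           batch_size=64, shuffle=True)  # matches --batch_size default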