Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in
Toggle navigation
T
tongue-diagnosis
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
zhiyang.zhou
tongue-diagnosis
Commits
ca355dee
Commit
ca355dee
authored
Jul 08, 2021
by
zhiyang.zhou
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
use pretrained models
parent
aa8f9999
Pipeline
#191
canceled with stages
Changes
2
Pipelines
1
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
14 additions
and
12 deletions
+14
-12
img_dataset.py
img_dataset.py
+1
-1
train_tongue_diagnosis.py
train_tongue_diagnosis.py
+13
-11
No files found.
img_dataset.py
View file @
ca355dee
...
@@ -20,7 +20,7 @@ class Image_DataSet(torch.utils.data.Dataset):
...
@@ -20,7 +20,7 @@ class Image_DataSet(torch.utils.data.Dataset):
index
=
index
%
(
len
(
imgs
)
-
1
)
index
=
index
%
(
len
(
imgs
)
-
1
)
img_file
=
os
.
path
.
join
(
self
.
processed_dir
,
imgs
[
index
])
img_file
=
os
.
path
.
join
(
self
.
processed_dir
,
imgs
[
index
])
source_image
=
PIL
.
Image
.
open
(
img_file
)
source_image
=
PIL
.
Image
.
open
(
img_file
)
.
convert
(
'RGB'
)
# source_image=np.array(source_image)
# source_image=np.array(source_image)
# image = np.expand_dims(source_image.transpose((2, 0, 1)), 0)
# image = np.expand_dims(source_image.transpose((2, 0, 1)), 0)
if
self
.
transform
is
not
None
:
if
self
.
transform
is
not
None
:
...
...
train_tongue_diagnosis.py
View file @
ca355dee
...
@@ -10,10 +10,11 @@ import torch.nn.functional as F
...
@@ -10,10 +10,11 @@ import torch.nn.functional as F
from
img_dataset
import
Image_DataSet
from
img_dataset
import
Image_DataSet
from
models
import
wideresnet
,
resnet
from
models
import
wideresnet
,
resnet
from
torchvision
import
models
parser
=
argparse
.
ArgumentParser
(
description
=
'PyTorch DenseNet Training'
)
parser
=
argparse
.
ArgumentParser
(
description
=
'PyTorch DenseNet Training'
)
parser
.
add_argument
(
'--model_name'
,
default
=
'resnet-50'
,
parser
.
add_argument
(
'--model_name'
,
default
=
'resnet-50'
,
help
=
'model name, cifar-10 models: wide_resnet-34x10, cifar-100 models: resnet-50, resnet-101,'
help
=
'models: densenet121, resnet-50, resnet-101, inception_v, vgg19_bn'
)
'resnet-152'
)
parser
.
add_argument
(
'--gpuid'
,
default
=
'0'
,
type
=
str
,
help
=
'which gpu to use'
)
parser
.
add_argument
(
'--gpuid'
,
default
=
'0'
,
type
=
str
,
help
=
'which gpu to use'
)
parser
.
add_argument
(
'--no_cuda'
,
action
=
'store_true'
,
default
=
False
,
help
=
'disables CUDA training'
)
parser
.
add_argument
(
'--no_cuda'
,
action
=
'store_true'
,
default
=
False
,
help
=
'disables CUDA training'
)
parser
.
add_argument
(
'-b'
,
'--batch_size'
,
default
=
32
,
type
=
int
,
help
=
'mini-batch size (default: 32)'
)
parser
.
add_argument
(
'-b'
,
'--batch_size'
,
default
=
32
,
type
=
int
,
help
=
'mini-batch size (default: 32)'
)
...
@@ -58,15 +59,15 @@ tongue_train_loader = torch.utils.data.DataLoader(tongue_dataset_train, batch_si
...
@@ -58,15 +59,15 @@ tongue_train_loader = torch.utils.data.DataLoader(tongue_dataset_train, batch_si
**
kwargs
)
**
kwargs
)
def
get_model
(
model_name
,
num_classes
=
NUM_CLASSES
):
def
get_model
(
model_name
):
if
model_name
==
'
wide_resnet-34x10
'
:
if
model_name
==
'
densenet121
'
:
return
wideresnet
.
WideResNet
(
num_classes
=
num_classes
)
return
models
.
densenet121
(
pretrained
=
True
,
num_classes
=
NUM_CLASSES
)
.
to
(
device
)
elif
model_name
==
'resnet-50'
:
elif
model_name
==
'resnet-50'
:
return
resnet
.
ResNet50
(
num_classes
=
num_classes
)
return
models
.
resnet50
(
pretrained
=
True
,
num_classes
=
NUM_CLASSES
)
.
to
(
device
)
elif
model_name
==
'
resnet-101
'
:
elif
model_name
==
'
vgg19_bn
'
:
return
resnet
.
ResNet101
(
num_classes
=
num_classes
)
return
models
.
vgg19_bn
(
pretrained
=
True
,
num_classes
=
NUM_CLASSES
)
.
to
(
device
)
elif
model_name
==
'
resnet-152
'
:
elif
model_name
==
'
inception_v3
'
:
return
resnet
.
ResNet152
(
num_classes
=
num_classes
)
return
models
.
inception_v3
(
pretrained
=
True
,
num_classes
=
NUM_CLASSES
)
.
to
(
device
)
else
:
else
:
raise
ValueError
(
'Unsupport model: {0}'
,
model_name
)
raise
ValueError
(
'Unsupport model: {0}'
,
model_name
)
...
@@ -146,7 +147,8 @@ def eval_test(model, test_loader):
...
@@ -146,7 +147,8 @@ def eval_test(model, test_loader):
def
main
():
def
main
():
print
(
args
)
print
(
args
)
model
=
get_model
(
args
.
model_name
,
num_classes
=
NUM_CLASSES
)
.
to
(
device
)
# model = get_model(args.model_name, num_classes=NUM_CLASSES).to(device)
model
=
models
.
resnet50
(
pretrained
=
False
,
num_classes
=
NUM_CLASSES
)
.
to
(
device
)
optimizer
=
optim
.
SGD
(
model
.
parameters
(),
lr
=
args
.
lr
,
momentum
=
args
.
momentum
,
weight_decay
=
args
.
weight_decay
)
optimizer
=
optim
.
SGD
(
model
.
parameters
(),
lr
=
args
.
lr
,
momentum
=
args
.
momentum
,
weight_decay
=
args
.
weight_decay
)
for
epoch
in
range
(
1
,
args
.
epochs
+
1
):
for
epoch
in
range
(
1
,
args
.
epochs
+
1
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment