Skip to content

Commit 09be781

Browse files
committed
add weights_only=False for the latest torch
1 parent bfb2a89 commit 09be781

10 files changed

Lines changed: 42 additions & 42 deletions

File tree

yolov10/gen_wts.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,9 @@ def parse_args():
3838
device = 'cpu'
3939

4040
# Load model
41-
model = torch.load(pt_file, map_location=device)['model'].float() # load to FP32
41+
model = torch.load(pt_file, map_location=device, weights_only=False)['model'].float() # load to FP32
4242
# If the training is not finished, the model will be interrupted.
43-
# model = torch.load(pt_file, map_location=device)['ema'].float() # load to FP32
43+
# model = torch.load(pt_file, map_location=device, weights_only=False)['ema'].float() # load to FP32
4444

4545
model.to(device).eval()
4646

yolov3-spp/gen_wts.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
import struct
22
import sys
3-
from models import *
4-
from utils.utils import *
3+
import torch
4+
from models import * # noqa: F403
5+
from utils.utils import * # noqa: F403
56

6-
model = Darknet('cfg/yolov3-spp.cfg', (416, 416))
7+
model = Darknet('cfg/yolov3-spp.cfg', (416, 416)) # noqa: F405
78
weights = sys.argv[1]
89
dev = '0'
9-
device = torch_utils.select_device(dev)
10-
model.load_state_dict(torch.load(weights, map_location=device)['model'])
10+
device = torch_utils.select_device(dev) # noqa: F405
11+
model.load_state_dict(torch.load(weights, map_location=device, weights_only=False)['model'])
1112

1213

1314
with open('yolov3-spp_ultralytics68.wts', 'w') as f:
@@ -17,6 +18,5 @@
1718
f.write('{} {} '.format(k, len(vr)))
1819
for vv in vr:
1920
f.write(' ')
20-
f.write(struct.pack('>f',float(vv)).hex())
21+
f.write(struct.pack('>f', float(vv)).hex())
2122
f.write('\n')
22-

yolov3-tiny/gen_wts.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,16 @@
11
import struct
2+
import torch
23
import sys
3-
from models import *
4-
from utils.utils import *
4+
from models import * # noqa: F403
5+
from utils.utils import * # noqa: F403
56

6-
model = Darknet('cfg/yolov3-tiny.cfg', (608, 608))
7+
model = Darknet('cfg/yolov3-tiny.cfg', (608, 608)) # noqa: F405
78
weights = sys.argv[1]
8-
device = torch_utils.select_device('0')
9+
device = torch_utils.select_device('0') # noqa: F405
910
if weights.endswith('.pt'): # pytorch format
10-
model.load_state_dict(torch.load(weights, map_location=device)['model'])
11+
model.load_state_dict(torch.load(weights, map_location=device, weights_only=False)['model'])
1112
else: # darknet format
12-
load_darknet_weights(model, weights)
13+
load_darknet_weights(model, weights) # noqa: F405
1314
model = model.eval()
1415

1516
with open('yolov3-tiny.wts', 'w') as f:
@@ -19,6 +20,5 @@
1920
f.write('{} {} '.format(k, len(vr)))
2021
for vv in vr:
2122
f.write(' ')
22-
f.write(struct.pack('>f',float(vv)).hex())
23+
f.write(struct.pack('>f', float(vv)).hex())
2324
f.write('\n')
24-

yolov3/gen_wts.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,16 @@
11
import struct
22
import sys
3-
from models import *
4-
from utils.utils import *
3+
import torch
4+
from models import * # noqa: F403
5+
from utils.utils import * # noqa: F403
56

6-
model = Darknet('cfg/yolov3.cfg', (608, 608))
7+
model = Darknet('cfg/yolov3.cfg', (608, 608)) # noqa: F405
78
weights = sys.argv[1]
8-
device = torch_utils.select_device('0')
9+
device = torch_utils.select_device('0') # noqa: F405
910
if weights.endswith('.pt'): # pytorch format
10-
model.load_state_dict(torch.load(weights, map_location=device)['model'])
11+
model.load_state_dict(torch.load(weights, map_location=device, weights_only=False)['model'])
1112
else: # darknet format
12-
load_darknet_weights(model, weights)
13+
load_darknet_weights(model, weights) # noqa: F405
1314
model = model.eval()
1415

1516
with open('yolov3.wts', 'w') as f:
@@ -19,6 +20,5 @@
1920
f.write('{} {} '.format(k, len(vr)))
2021
for vv in vr:
2122
f.write(' ')
22-
f.write(struct.pack('>f',float(vv)).hex())
23+
f.write(struct.pack('>f', float(vv)).hex())
2324
f.write('\n')
24-

yolov4/gen_wts.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,16 @@
11
import struct
22
import sys
3-
from models import *
4-
from utils.utils import *
3+
import torch
4+
from models import * # noqa: F403
5+
from utils.utils import * # noqa: F403
56

6-
model = Darknet('cfg/yolov4.cfg', (608, 608))
7+
model = Darknet('cfg/yolov4.cfg', (608, 608)) # noqa: F405
78
weights = sys.argv[1]
8-
device = torch_utils.select_device('0')
9+
device = torch_utils.select_device('0') # noqa: F405
910
if weights.endswith('.pt'): # pytorch format
10-
model.load_state_dict(torch.load(weights, map_location=device)['model'])
11+
model.load_state_dict(torch.load(weights, map_location=device, weights_only=False)['model'])
1112
else: # darknet format
12-
load_darknet_weights(model, weights)
13+
load_darknet_weights(model, weights) # noqa: F405
1314

1415
with open('yolov4.wts', 'w') as f:
1516
f.write('{}\n'.format(len(model.state_dict().keys())))
@@ -18,6 +19,5 @@
1819
f.write('{} {} '.format(k, len(vr)))
1920
for vv in vr:
2021
f.write(' ')
21-
f.write(struct.pack('>f',float(vv)).hex())
22+
f.write(struct.pack('>f', float(vv)).hex())
2223
f.write('\n')
23-

yolov5/gen_wts.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import sys
21
import argparse
32
import os
43
import struct
@@ -33,7 +32,7 @@ def parse_args():
3332
# Load model
3433
print(f'Loading {pt_file}')
3534
device = select_device('cpu')
36-
model = torch.load(pt_file, map_location=device) # Load FP32 weights
35+
model = torch.load(pt_file, map_location=device, weights_only=False) # Load FP32 weights
3736
model = model['ema' if model.get('ema') else 'model'].float()
3837

3938
if m_type in ['detect', 'seg']:

yolov7/gen_wts.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import sys
21
import argparse
32
import os
43
import struct
@@ -27,13 +26,14 @@ def parse_args():
2726
# Initialize
2827
device = select_device('cpu')
2928
# Load model
30-
model = torch.load(pt_file, map_location=device)['model'].float() # load to FP32
29+
model = torch.load(pt_file, map_location=device, weights_only=False)['model'].float() # load to FP32
3130

3231
# update anchor_grid info
3332
anchor_grid = model.model[-1].anchors * model.model[-1].stride[..., None, None]
3433
# model.model[-1].anchor_grid = anchor_grid
3534
delattr(model.model[-1], 'anchor_grid') # model.model[-1] is detect layer
36-
model.model[-1].register_buffer("anchor_grid", anchor_grid) # The parameters are saved in the OrderDict through the "register_buffer" method, and then saved to the weight.
35+
# The parameters are saved in the OrderDict through the "register_buffer" method, and then saved to the weight.
36+
model.model[-1].register_buffer("anchor_grid", anchor_grid)
3737

3838
model.to(device).eval()
3939

yolov8/gen_wts.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def parse_args():
3737
device = 'cpu'
3838

3939
# Load model
40-
model = torch.load(pt_file, map_location=device)['model'].float() # load to FP32
40+
model = torch.load(pt_file, map_location=device, weights_only=False)['model'].float() # load to FP32
4141

4242
if m_type in ['detect', 'seg', 'pose']:
4343
anchor_grid = model.model[-1].anchors * model.model[-1].stride[..., None, None]

yolov8/yolov8_trt10/gen_wts.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def parse_args():
3737
device = 'cpu'
3838

3939
# Load model
40-
model = torch.load(pt_file, map_location=device)['model'].float() # load to FP32
40+
model = torch.load(pt_file, map_location=device, weights_only=False)['model'].float() # load to FP32
4141

4242
if m_type in ['detect', 'seg', 'pose']:
4343
anchor_grid = model.model[-1].anchors * model.model[-1].stride[..., None, None]

yolov9/gen_wts.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1-
import sys
21
import argparse
32
import os
43
import struct
54
import torch
65
from utils.torch_utils import select_device
76

7+
88
def parse_args():
99
parser = argparse.ArgumentParser(description='Convert .pt file to .wts')
1010
parser.add_argument('-w', '--weights', default='yolov9-e.pt',
@@ -25,13 +25,14 @@ def parse_args():
2525
os.path.splitext(os.path.basename(args.weights))[0] + '.wts')
2626
return args.weights, args.output, args.type
2727

28+
2829
pt_file, wts_file, m_type = parse_args()
2930
print(f'Generating .wts for {m_type} model')
3031

3132
# Load model
3233
print(f'Loading {pt_file}')
3334
device = select_device('cpu')
34-
model = torch.load(pt_file, map_location=device) # Load FP32 weights
35+
model = torch.load(pt_file, map_location=device, weights_only=False) # Load FP32 weights
3536
model = model["model"].float()
3637

3738
if m_type in ['detect', 'seg']:

0 commit comments

Comments
 (0)