Skip to content

Commit 9305313

Browse files
committed
Default to old checkpoint format for now, still want compatibility with older torch ver for released models
1 parent a4d8fea commit 9305313

File tree

2 files changed

+10
-2
lines changed

2 files changed

+10
-2
lines changed

avg_checkpoints.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,11 @@ def main():
103103
v = v.clamp(float32_info.min, float32_info.max)
104104
final_state_dict[k] = v.to(dtype=torch.float32)
105105

106-
torch.save(final_state_dict, args.output)
106+
try:
107+
torch.save(final_state_dict, args.output, _use_new_zipfile_serialization=False)
108+
except:
109+
torch.save(final_state_dict, args.output)
110+
107111
with open(args.output, 'rb') as f:
108112
sha_hash = hashlib.sha256(f.read()).hexdigest()
109113
print("=> Saved state_dict to '{}, SHA256: {}'".format(args.output, sha_hash))

clean_checkpoint.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,11 @@ def main():
5757
new_state_dict[name] = v
5858
print("=> Loaded state_dict from '{}'".format(args.checkpoint))
5959

60-
torch.save(new_state_dict, _TEMP_NAME)
60+
try:
61+
torch.save(new_state_dict, _TEMP_NAME, _use_new_zipfile_serialization=False)
62+
except:
63+
torch.save(new_state_dict, _TEMP_NAME)
64+
6165
with open(_TEMP_NAME, 'rb') as f:
6266
sha_hash = hashlib.sha256(f.read()).hexdigest()
6367

0 commit comments

Comments
 (0)