Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""Hash with SHA256 for PyTorch checkpoints."""
3import hashlib
4import os
5import shutil
6from argparse import ArgumentParser
def _hash_ckpt(
    ckpt_file: str,
    jitted: bool = False,
    output_path: str = "",
    *,
    chunk_size: int = 1 << 20,
):
    """Rename a checkpoint to carry its SHA256 prefix and move it.

    The file at ``ckpt_file`` is hashed with SHA256, renamed to
    ``<stem>-<first 8 hex chars of the digest><ext>`` (``.ptc`` when
    ``jitted`` is true, ``.pt`` otherwise) and moved into ``output_path``.

    Args:
        ckpt_file: Path to the checkpoint file. The file is moved, not
            copied, so it no longer exists at this path afterwards.
        jitted: Whether the checkpoint is jitted; selects the ``.ptc``
            extension instead of ``.pt``.
        output_path: Destination directory. An empty string means the
            current working directory. The directory is created if it
            does not exist yet.
        chunk_size: Number of bytes read per hashing step. Bounds memory
            use when hashing large checkpoints. Must be positive.

    Returns:
        A ``(filename, sha_hash)`` tuple. ``filename`` is the new base name
        (without ``output_path``) and ``sha_hash`` is the full hex digest.

    Raises:
        ValueError: If ``chunk_size`` is not positive.
        OSError: If the checkpoint cannot be read or moved.
    """
    if chunk_size <= 0:
        # read(0) would return b"" immediately and silently hash nothing.
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")

    # Stream the file rather than reading it whole: checkpoints can be
    # several GB, and a single read() would hold all of it in memory.
    hasher = hashlib.sha256()
    with open(ckpt_file, "rb") as file:
        for chunk in iter(lambda: file.read(chunk_size), b""):
            hasher.update(chunk)
    sha_hash = hasher.hexdigest()

    # os.path.basename also understands "/" on Windows, unlike split(os.sep).
    ckpt_file_name = os.path.splitext(os.path.basename(ckpt_file))[0]
    extension = ".ptc" if jitted else ".pt"
    filename = f"{ckpt_file_name}-{sha_hash[:8]}{extension}"

    # shutil.move fails with FileNotFoundError if the target directory is missing.
    if output_path:
        os.makedirs(output_path, exist_ok=True)
    shutil.move(ckpt_file, os.path.join(output_path, filename))
    print(f"==> Saved state dict into {filename} | SHA256: {sha_hash}")

    return filename, sha_hash
def _parse_bool(value: str) -> bool:
    """Convert a command-line string such as ``"true"`` or ``"0"`` to a bool.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is
    ``True``, so every non-empty string would enable the flag.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean
            spelling. argparse turns this into a usage error.
    """
    from argparse import ArgumentTypeError

    lowered = value.strip().lower()
    if lowered in {"1", "true", "t", "yes", "y", "on"}:
        return True
    if lowered in {"0", "false", "f", "no", "n", "off"}:
        return False
    raise ArgumentTypeError(f"expected a boolean value, got {value!r}")


def _build_parser() -> ArgumentParser:
    """Build the command-line parser for the checkpoint-hashing script."""
    parser = ArgumentParser(description="Hash checkpointed PyTorch file.")
    parser.add_argument(
        "--ckpt_file",
        type=str,
        required=True,
        help="Path to the checkpoint file including filename",
    )
    # nargs="?" + const=True: bare "--jitted" means True, while the
    # historical "--jitted True" / "--jitted False" forms keep parsing.
    parser.add_argument(
        "--jitted",
        type=_parse_bool,
        nargs="?",
        const=True,
        default=False,
        help="Is the checkpoint file jitted? (.ptc for jitted else .pt)",
    )
    parser.add_argument(
        "--output_path",
        type=str,
        default=os.getcwd(),
        help="Path to the hashed file",
    )
    return parser


def main(argv=None) -> None:
    """Parse ``argv`` (``sys.argv[1:]`` when None) and hash the checkpoint."""
    args = _build_parser().parse_args(argv)
    _hash_ckpt(args.ckpt_file, args.jitted, args.output_path)


if __name__ == "__main__":
    main()