Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1"""Hash with SHA256 for PyTorch checkpoints.""" 

2 

3import hashlib 

4import os 

5import shutil 

6from argparse import ArgumentParser 

7 

8 

9def _hash_ckpt(ckpt_file: str, jitted: bool = False, output_path: str = ""): 

10 with open(ckpt_file, "rb") as file: 

11 sha_hash = hashlib.sha256(file.read()).hexdigest() 

12 

13 path_to_file_list = ckpt_file.split(os.sep) 

14 ckpt_file_name = os.path.splitext(path_to_file_list[-1])[0] 

15 if jitted: 

16 filename = "-".join((ckpt_file_name, sha_hash[:8])) + ".ptc" 

17 else: 

18 filename = "-".join((ckpt_file_name, sha_hash[:8])) + ".pt" 

19 

20 shutil.move(ckpt_file, os.path.join(output_path, filename)) 

21 print(f"==> Saved state dict into {filename} | SHA256: {sha_hash}") 

22 

23 return filename, sha_hash 

24 

25 

26if __name__ == "__main__": 

27 parser = ArgumentParser(description="Hash checkpointed PyTorch file.") 

28 parser.add_argument( 

29 "--ckpt_file", 

30 type=str, 

31 required=True, 

32 help="Path to the checkpoint file including filename", 

33 ) 

34 parser.add_argument( 

35 "--jitted", 

36 type=bool, 

37 required=False, 

38 default=False, 

39 help="Is the checkpoint file jitted? (.ptc for jitted else .pt)", 

40 ) 

41 parser.add_argument( 

42 "--output_path", 

43 type=str, 

44 required=False, 

45 default=os.getcwd(), 

46 help="Path to the hashed file", 

47 ) 

48 args = parser.parse_args() 

49 _hash_ckpt(args.ckpt_file, args.jitted, args.output_path)