Eric Bower
·
04 Oct 24
detector.py
1import sys
2import glob
3from PIL import Image
4from transformers import pipeline
5import torch
6
# ANSI SGR escape sequences used to colour status words in terminal output.
CGREEN = "\033[92m"   # bright green foreground
CYELLOW = "\033[93m"  # bright yellow foreground
CRED = "\033[91m"     # bright red foreground
CEND = "\033[0m"      # reset all attributes back to the terminal default
11
def images(root_dir):
    """Lazily yield every ``*.jpg`` image found beneath ``root_dir``.

    The directory tree is searched recursively. Files that cannot be opened
    are reported on stdout and skipped, but they still count towards the
    total that is printed once the generator is exhausted.

    Args:
        root_dir: Directory to scan, with or without a trailing separator.

    Yields:
        ``(image, filename)`` pairs, where ``image`` is whatever
        ``Image.open`` returned and ``filename`` is the matched path.
        The image is not closed here; the consumer owns it.
    """
    import os  # local import: the module's import header is left untouched

    # Join path components instead of concatenating strings. Without a
    # trailing separator ``root_dir + '**/*.jpg'`` became ``dir**/*.jpg``,
    # which is not a recursive pattern and also matches sibling directories
    # whose names merely start with ``dir``.
    pattern = os.path.join(root_dir, '**', '*.jpg')

    count = 0
    for filename in glob.iglob(pattern, recursive=True):
        try:
            img = Image.open(filename)
        except Exception as err:
            # Deliberately best-effort: one unreadable file must not abort
            # the whole scan.
            print("failed to open file", err)
        else:
            yield img, filename
        count += 1
    print(f"scanned {count} images")
24
def extract_nsfw_score(result, label="nsfw"):
    """Return the score reported for ``label`` in a classification result.

    The image-classification pipeline returns a list of
    ``{"label": ..., "score": ...}`` dicts ranked by score, so the position
    of a given label changes from image to image. Looking the label up by
    name avoids reading the wrong entry.

    Args:
        result: Iterable of dicts with ``"label"`` and ``"score"`` keys.
        label: Label whose score is wanted.

    Returns:
        The score of the first entry whose label equals ``label``, or
        ``0.0`` when no such entry is present.
    """
    for entry in result:
        if entry["label"] == label:
            return entry["score"]
    return 0.0


def main():
    """Scan the image folder given as ``sys.argv[1]`` and report NSFW hits.

    Prints one ``failed`` line on stdout for every image whose NSFW score
    exceeds the threshold. Images that cannot be classified are reported on
    stderr and skipped. Exits with an error message when no folder is given.
    """
    if len(sys.argv) < 2:
        # Usage error: exit with the message on stderr instead of a traceback.
        sys.exit(f"{CRED}error!: please provide root image folder{CEND}")
    root_dir = sys.argv[1]
    print(f"root_dir {root_dir}")
    threshold = 0.6

    print(f"failure threshold is set to {threshold:.4f}")

    print("loading model")
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    classify = pipeline(
        "image-classification",
        model="Falconsai/nsfw_image_detection",
        device=device,
    )

    print("scanning images")
    for img, filename in images(root_dir):
        try:
            # Release the underlying file handle once classification is done
            # so a large scan does not accumulate open files.
            with img:
                result = classify(img)
        except Exception as err:
            # Keep scanning, but make the skipped file visible: an image that
            # could not be scored has not been checked.
            print(
                f"{CYELLOW}err{CEND} (score:n/a) {filename} {err}",
                file=sys.stderr,
            )
            continue

        # Results are ranked by score, so a fixed index such as result[1]
        # always picks the lower-scoring label rather than the NSFW label.
        nsfw_score = extract_nsfw_score(result)
        if nsfw_score > threshold:
            print(f"{CRED}failed{CEND} (score:{nsfw_score:.4f}) {filename}")


if __name__ == '__main__':
    main()