repos / ops

infra for pico services
git clone https://github.com/picosh/ops.git

ops / scripts / nsfw_detector
Eric Bower · 04 Oct 24

detector.py

 1import sys
 2import glob
 3from PIL import Image
 4from transformers import pipeline
 5import torch
 6
 7CGREEN = '\033[92m'
 8CYELLOW = '\033[93m'
 9CRED = '\033[91m'
10CEND = '\033[0m'
11
def images(root_dir):
    """Yield ``(image, filename)`` for every ``*.jpg`` file under *root_dir*.

    The directory tree is searched recursively. Files that cannot be opened
    are reported on stdout and skipped. When the generator is exhausted it
    prints how many ``*.jpg`` paths were found, including the ones that
    failed to open.

    Each yielded image is closed as soon as the consumer advances the
    generator. Use the image inside the loop body and do not keep references
    to it past the current iteration.

    Args:
        root_dir: directory to scan; a trailing path separator is optional.
    """
    import os  # local import: only this function needs it

    # Join (rather than concatenate) so "dir" and "dir/" behave the same:
    # with plain "+", "dir**/*.jpg" turns "**" into a non-recursive "*" that
    # also matches sibling directories such as "dir_backup". Escape the root
    # so glob metacharacters in the path ("[", "*", "?") are taken literally.
    pattern = os.path.join(glob.escape(root_dir), '**', '*.jpg')

    count = 0
    for filename in glob.iglob(pattern, recursive=True):
        count += 1
        try:
            img = Image.open(filename)
        except Exception as err:
            # Best-effort scan: one unreadable file must not abort the run.
            print("failed to open file", filename, err)
            continue
        # Image.open keeps the file handle open lazily; close it once the
        # consumer is done so large trees do not exhaust file descriptors.
        with img:
            yield img, filename
    print(f"scanned {count} images")
24
25if __name__ == '__main__':
26    if len(sys.argv) < 2:
27        raise Exception(f"{CRED}error!: please provide root image folder{CEND}")
28    root_dir = sys.argv[1]
29    print(f"root_dir {root_dir}")
30    threshold = 0.6
31
32    print(f"failure threshold is set to {threshold:.4f}")
33
34    print("loading model")
35    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
36    classify = pipeline(
37        "image-classification",
38        model="Falconsai/nsfw_image_detection",
39        device=device,
40    )
41
42    print("scanning images")
43    for img, filename in images(root_dir):
44        result = None
45        try:
46            result = classify(img)
47        except Exception as err:
48            # print(f"{CYELLOW}err{CEND} (score:n/a) {filename} {err}")
49            continue
50
51        nsfw_score = result[1]["score"]
52        score_read = '%.4f' % nsfw_score
53        if nsfw_score > threshold:
54            print(f"{CRED}failed{CEND} (score:{score_read}) {filename}")
55        else:
56            # print(f"{CGREEN}passed{CEND} (score:{score_read}) {filename}")
57            pass