repos / ops

infra for pico services
git clone https://github.com/picosh/ops.git

ops / scripts / nsfw_detector
Eric Bower · 27 Sep 24

detector.py

import glob
import os
import sys

import torch
from PIL import Image
from transformers import pipeline
 6
 7CGREEN = '\033[92m'
 8CYELLOW = '\033[93m'
 9CRED = '\033[91m'
10CEND = '\033[0m'
11
def images(root_dir, limit=10):
    """Yield ``(image, filename)`` for each ``.jpg`` file under *root_dir*.

    The directory tree is searched recursively. A file that cannot be opened
    is reported on stdout and skipped; it still counts toward *limit*.

    Args:
        root_dir: Directory to search, with or without a trailing path
            separator.
        limit: Maximum number of matching files to visit, or ``None`` to
            visit every match. Defaults to 10.

    Yields:
        A tuple of the opened PIL image and the path it was opened from. The
        image is not closed here; the consumer is responsible for closing it.
    """
    # Join the pattern onto the directory instead of concatenating strings.
    # Without a separator, "dir" + "**/*.jpg" becomes "dir**/*.jpg": the "**"
    # is then no longer a path component of its own, so the search stops being
    # recursive and also matches sibling directories such as "dir_other/".
    pattern = os.path.join(root_dir, '**', '*.jpg')
    for visited, filename in enumerate(glob.iglob(pattern, recursive=True)):
        if limit is not None and visited >= limit:
            return
        # Keep the guarded region down to the one call that is expected to
        # fail on unreadable or non-image files.
        try:
            img = Image.open(filename)
        except Exception as err:
            print("failed to open file", err)
            continue
        yield img, filename
23
def label_score(result, label):
    """Return the score the classifier assigned to *label*, or ``None``.

    The image-classification pipeline returns its entries ordered by
    descending score, so the position of a given label changes from image to
    image; the score has to be looked up by label name, not by index.

    Args:
        result: Classifier output, an iterable of mappings that each carry a
            ``"label"`` and a ``"score"`` entry.
        label: Label to look up; compared case-insensitively.

    Returns:
        The ``"score"`` of the first entry whose label matches, or ``None``
        when no entry matches.
    """
    wanted = label.casefold()
    return next(
        (
            entry["score"]
            for entry in result
            if str(entry["label"]).casefold() == wanted
        ),
        None,
    )


if __name__ == '__main__':
    if len(sys.argv) < 2:
        # Exit with the message on stderr and status 1 instead of raising a
        # bare Exception, which buried the usage hint under a traceback.
        sys.exit(f"{CRED}error!: please provide root image folder{CEND}")
    root_dir = sys.argv[1]
    print(f"root_dir {root_dir}")
    # An image fails the scan when its nsfw score is strictly above this.
    threshold = 0.3
    # Label the model is expected to use for unsafe images.
    # NOTE(review): confirm against the model card if the model is swapped.
    nsfw_label = "nsfw"

    print(f"failure threshold is set to {threshold:.4f}")

    print("loading model")
    # Use the first CUDA device when torch reports one, otherwise the CPU.
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    classify = pipeline(
        "image-classification",
        model="Falconsai/nsfw_image_detection",
        device=device,
    )

    print("scanning images")
    # NOTE: images() visits at most 10 files per run unless told otherwise.
    for img, filename in images(root_dir):
        try:
            # Release the image's file handle once it has been classified,
            # whether or not classification succeeded.
            with img:
                result = classify(img)
        except Exception as err:
            print(f"{CYELLOW}err{CEND} (score:n/a) {filename} {err}")
            continue

        nsfw_score = label_score(result, nsfw_label)
        if nsfw_score is None:
            print(
                f"{CYELLOW}err{CEND} (score:n/a) {filename} "
                f"no '{nsfw_label}' label in result"
            )
            continue

        score_read = f"{nsfw_score:.4f}"
        if nsfw_score > threshold:
            print(f"{CRED}failed{CEND} (score:{score_read}) {filename}")
        else:
            print(f"{CGREEN}passed{CEND} (score:{score_read}) {filename}")