Skip to content

Instantly share code, notes, and snippets.

@dsevero
Last active October 13, 2023 19:26
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save dsevero/6140d918c0dfcdf6af3dca3cc8a261b2 to your computer and use it in GitHub Desktop.
Save dsevero/6140d918c0dfcdf6af3dca3cc8a261b2 to your computer and use it in GitHub Desktop.
Helper scripts to compute bits-per-dimension of images in a directory.
"""
Helper script to compute BPD of images in a directory.
The script will never modify any files.
All conversions are done in memory on a copy.
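Usage: python compute_bpd.py your_glob_pattern [extension] [colorspace] [psnr-check] [per-channel]
Args:
- extension: any valid PIL image extension (e.g., PNG, WebP, JPEG).
- colorspace: any valid PIL colorspace plus the lossless YCoCg.
- psnr-check (flag): if set, will compute the PSNR of the saved image with respect to the original colorspace.
- per-channel (flag): if set, each channel is encoded separately and a bpd is reported per channel.
Flags are only recognized after both extension and colorspace have been given.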
Examples:
# Compute BPD of all images in the directory
python compute_bpd.py your_glob
# Convert to PNG and compute bpd
python compute_bpd.py your_glob png
# Convert to webp and compute bpd
python compute_bpd.py your_glob webp
# Convert to lossless (with respect to RGB) YCoCg and save as PNG
python compute_bpd.py your_glob png YCoCg
# Convert to lossless (with respect to RGB) YCoCg, save as PNG, and check PSNR
python compute_bpd.py your_glob png YCoCg psnr-check
"""
import glob
import os
import sys
import math
import numpy as np
import io
from pathlib import Path
from multiprocessing import Pool
from PIL import Image
from functools import partial
from typing import Optional
# Only files whose suffix appears in this list are processed; the script
# ignores every other extension.
IMAGE_EXTENSIONS = [
    ".jpg",
    ".jpeg",
    ".png",
    ".gif",
    ".bmp",
]
def compute_psnr(img1, img2, max_value=255):
    """Return the peak signal-to-noise ratio (in dB) between two images.

    Args:
        img1: array-like of pixel values.
        img2: array-like of pixel values, broadcastable against ``img1``.
        max_value: largest value a pixel can take (255 for 8-bit images).

    Returns:
        ``float("inf")`` when the inputs are identical, otherwise
        ``20 * log10(max_value / sqrt(mse))``.
    """
    # Cast to float before subtracting: arrays built from 8-bit images are
    # uint8, where both the difference and its square wrap around modulo 256
    # and silently produce a wrong (possibly zero) mean squared error.
    diff = np.asarray(img1, dtype=np.float64) - np.asarray(img2, dtype=np.float64)
    mse = np.mean(diff**2)
    if mse == 0:
        return float("inf")
    return 20 * math.log10(max_value / math.sqrt(mse))
def YCoCg_from_RGB(img_array):
    """Map an array whose last axis holds (R, G, B) to (Y, Co, Cg).

    The transform is two integer lifting steps built from subtraction,
    addition and a right shift by one. No dtype conversion is performed, so
    all arithmetic happens in the dtype of ``img_array`` (for uint8 input
    NumPy wraps intermediate values modulo 256).

    Returns:
        A new array with Y, Co and Cg stacked along the last axis.
    """
    red = img_array[..., 0]
    green = img_array[..., 1]
    blue = img_array[..., 2]

    # Lifting step 1: (R, B) -> (Co, midpoint of R and B).
    co = blue - red
    red_blue_mid = red + (co >> 1)

    # Lifting step 2: (midpoint, G) -> (Cg, Y).
    cg = red_blue_mid - green
    luma = green + (cg >> 1)

    return np.stack([luma, co, cg], axis=-1)
def RGB_from_YCoCg(img_array):
    """Map an array whose last axis holds (Y, Co, Cg) back to (R, G, B).

    The two lifting steps are undone in reverse order using the same
    shift-by-one terms. Arithmetic happens in the dtype of ``img_array``; no
    dtype conversion is performed.

    Returns:
        A new array with R, G and B stacked along the last axis.
    """
    luma = img_array[..., 0]
    co = img_array[..., 1]
    cg = img_array[..., 2]

    # Undo the (midpoint, G) -> (Cg, Y) step.
    green = luma - (cg >> 1)
    red_blue_mid = green + cg

    # Undo the (R, B) -> (Co, midpoint) step.
    red = red_blue_mid - (co >> 1)
    blue = red + co

    return np.stack([red, green, blue], axis=-1)
def is_image_file(filename):
    """Return True if ``filename`` has one of the suffixes in IMAGE_EXTENSIONS.

    Args:
        filename: a ``pathlib.Path`` (anything exposing a ``suffix`` string).

    The suffix is lower-cased before the lookup so that files such as
    ``IMG.JPG`` are not silently skipped; add entries to IMAGE_EXTENSIONS to
    support more formats.
    """
    return filename.suffix.lower() in IMAGE_EXTENSIONS
def compute_bpd(
    file_path: Path,
    extension: Optional[str],
    colorspace: Optional[str],
    psnr_check: bool,
    per_channel: bool,
):
    """Compute the bits-per-dimension of one image, re-encoding if required.

    When no target extension is given, or the file already has the requested
    extension, and no colorspace is requested, the size of the file on disk
    is used. In every other case the image is converted and re-encoded in
    memory.

    Args:
        file_path: path of the image.
        extension: lower-case target format without the dot, or None.
        colorspace: target colorspace, or None to keep the file's own.
        psnr_check: forwarded to the in-memory conversion path.
        per_channel: forwarded to the in-memory conversion path; must be
            False on the on-disk path.

    Returns:
        Whatever the selected helper returns (a bpd value, or a per-channel
        dict of bpd values).
    """
    if extension is None:
        keep_file_encoding = True
    else:
        keep_file_encoding = str(file_path.suffix).lower() == "." + extension

    if not keep_file_encoding or colorspace is not None:
        # Re-encoding needs an explicit target format.
        assert extension is not None
        return convert_and_compute_bpd_with_pil(
            file_path, extension, colorspace, psnr_check, per_channel
        )

    # The on-disk size cannot be split by channel.
    assert per_channel == False  # noqa: E712
    return compute_bpd_from_file(file_path)
def convert_to_colorspace(img: Image.Image, colorspace: str):
    """Return ``img`` converted to ``colorspace``.

    Every colorspace except ``"YCoCg"`` is delegated to ``Image.convert``.
    For ``"YCoCg"`` the image is first converted to RGB, transformed with
    ``YCoCg_from_RGB`` and wrapped back into an image.
    """
    if colorspace != "YCoCg":
        return img.convert(colorspace)

    rgb_pixels = np.array(img.convert("RGB"))
    # The transformed array is wrapped in an ordinary image, so PNG will
    # think this is RGB.
    return Image.fromarray(YCoCg_from_RGB(rgb_pixels))
def convert_and_compute_bpd_with_pil(
    file_path: Path,
    extension: str,
    colorspace: Optional[str],
    psnr_check: bool,
    per_channel: bool,
):
    """Re-encode an image in memory and report its bits-per-dimension (bpd).

    The image at ``file_path`` is optionally converted to ``colorspace``,
    saved into an in-memory buffer using the format named by ``extension``,
    and the buffer size in bits is divided by width * height * channels.
    One line is printed per encoded image. The file on disk is never
    modified.

    Args:
        file_path: path of the image to read.
        extension: format name passed to ``Image.save``.
        colorspace: target colorspace; None, or a value equal to the image's
            mode, keeps the image as read from the file.
        psnr_check: if True, decode the buffer again and print the PSNR with
            respect to the image read from the file.
        per_channel: if True, encode every channel of the (converted) image
            separately.

    Returns:
        The bpd as a float, or, when ``per_channel`` is True, a dict mapping
        channel index to that channel's bpd.
    """

    def encode_and_report(img, img_conv, colorspace_msg):
        # Save ``img_conv`` into memory, print its bpd (and PSNR) and return
        # the bpd.
        with io.BytesIO() as byte_stream:
            # NOTE(review): lossless/optimize are passed for every format;
            # presumably encoders that do not know them ignore them - confirm.
            img_conv.save(byte_stream, format=extension, lossless=True, optimize=True)
            bits = len(byte_stream.getvalue()) * 8
            channels = len(img_conv.getbands())
            dims = math.prod(img.size) * channels
            if per_channel:
                assert channels == 1
            bpd = bits / dims
            print(
                f"{bpd: .2f} bpd ({file_path}) -> .{extension} w/ {colorspace_msg}",
                end="",
            )
            if psnr_check:
                assert img.size == img_conv.size
                # NOTE(review): in per-channel mode this still compares the
                # single decoded channel against the full original image.
                arr = np.array(img)
                # Decode while the buffer is still open.
                img_conv_dec = Image.open(byte_stream)
                # TODO dsevero: need to make YCoCg a proper PIL plugin.
                if colorspace == "YCoCg" and img.mode == "RGB":
                    arr_conv = RGB_from_YCoCg(np.array(img_conv_dec))
                else:
                    arr_conv = np.array(img_conv_dec.convert(img.mode))
                psnr = compute_psnr(arr, arr_conv)
                print(f" PSNR={psnr:.2f}")
            else:
                print("")
        return bpd

    # The context manager closes the underlying file once all work is done.
    with Image.open(file_path) as img:
        # Convert once, not once per channel.
        if colorspace is not None and colorspace != img.mode:
            img_conv = convert_to_colorspace(img, colorspace)
            colorspace_msg = f"{colorspace} (converted)"
        else:
            img_conv = img
            colorspace_msg = f"{img.mode} (from file)"

        if not per_channel:
            return encode_and_report(img, img_conv, colorspace_msg)

        bpd = {}
        # Iterate over the bands of the *converted* image: a conversion may
        # change the number of channels (e.g. RGBA -> RGB or L -> YCoCg).
        for channel in range(len(img_conv.getbands())):
            print(f"Channel {channel}: ", end="")
            bpd[channel] = encode_and_report(
                img, img_conv.getchannel(channel), colorspace_msg
            )
        return bpd
def compute_bpd_from_file(file_path: Path):
    """Compute bits-per-dimension from the size of the file on disk.

    The file size in bits is divided by width * height * channels of the
    image, the result is printed and returned.

    Args:
        file_path: path of the image file.

    Returns:
        The bpd as a float.
    """
    # Open the image only to read its geometry; the context manager makes
    # sure the file handle is closed again instead of leaking per image.
    with Image.open(file_path) as img:
        channels = len(img.getbands())
        pixels = math.prod(img.size)
    bits = os.path.getsize(file_path) * 8
    dims = pixels * channels
    bpd = bits / dims
    print(f"{bpd: .2f} bpd ({file_path}) ")
    return bpd
def main():
    """Command-line entry point: compute the bpd of every matching image.

    Positional arguments are read from ``sys.argv``: a glob pattern, then
    optionally an extension, a colorspace and the flags ``psnr-check`` /
    ``per-channel`` (flags are only looked for from the fourth argument on).
    Images are processed in a process pool and the average bpd is printed.

    Exits with a usage message when the glob pattern is missing, and reports
    "Found 0 images" instead of crashing when nothing matches.
    """
    if len(sys.argv) < 2:
        sys.exit(
            "Usage: python compute_bpd.py your_glob_pattern "
            "[extension] [colorspace] [psnr-check] [per-channel]"
        )
    directory_path = sys.argv[1]
    extension = None if len(sys.argv) < 3 else sys.argv[2].lower()
    colorspace = None if len(sys.argv) < 4 else sys.argv[3]
    # Slicing past the end yields an empty list, so both flags default to False.
    flags = sys.argv[4:]
    psnr_check = "psnr-check" in flags
    per_channel = "per-channel" in flags

    # Expand the glob and keep only regular files with an image extension.
    files = [
        path
        for path in map(Path, glob.glob(directory_path, recursive=True))
        if path.is_file() and is_image_file(path)
    ]
    if not files:
        # Nothing to average; avoid dividing by zero below.
        print("-----------------------------------------------------")
        print("Found 0 images")
        return

    # Process the images in parallel, one worker per CPU.
    with Pool(processes=os.cpu_count()) as pool:
        f = partial(
            compute_bpd,
            extension=extension,
            colorspace=colorspace,
            psnr_check=psnr_check,
            per_channel=per_channel,
        )
        bpds = pool.map(f, files)

    print("-----------------------------------------------------")
    if per_channel:
        # Images may have different numbers of channels: use every channel
        # seen and average each one over the images that actually have it.
        channels = sorted({c for bpd in bpds for c in bpd})
        print(f"Found {len(bpds)} images")
        avg_bpd = 0
        for c in channels:
            total = 0
            count = 0
            for bpd in bpds:
                if c in bpd:
                    total += bpd[c]
                    count += 1
            avg_bpd_channel = total / count
            avg_bpd += avg_bpd_channel
            print(f"Channel {c}: {avg_bpd_channel:.2f} bpd (Average)")
        avg_bpd /= len(channels)
        print(f"BPD all channels: {avg_bpd}")
    else:
        avg_bpd = sum(bpds) / len(bpds)
        print(f"Found {len(bpds)} images: {avg_bpd:.2f} bpd (Average)")


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment