Skip to content
Snippets Groups Projects
metricsmanager.py 6.8 KiB
Newer Older
Ulysse Darmet's avatar
Ulysse Darmet committed
#!/usr/bin/python3

import os
import re
import json
from pprint import *
from argparse import *
from itertools import *

import numpy as np
import matplotlib.pyplot as plt

import bjontegaard

#///////////////////////////////////////////////////////////////////////////////

# Command-line interface: one optional JSON metrics file per coding scheme,
# plus an optional output directory for the plots.
parser = ArgumentParser(
	# TODO: Add a big description on what the script is aimed to do.
	description='Compute and plot Bjontegaard metrics',
	allow_abbrev=False,
	formatter_class=RawTextHelpFormatter,
)

for _flags, _help in ((('-H', '--HM'), 'the JSON HM metrics file'),
		(('-S', '--SHM'), 'the JSON SHM metrics file'),
		(('-A', '--ARC'), 'the JSON ARC metrics file')):
	parser.add_argument(*_flags, type=str, help=_help)
# Keep the module namespace limited to <parser>.
del _flags, _help

# "-s" alone saves into the working directory captured when this module is
# imported; "-s DIR" saves into DIR; without the option, save is None.
parser.add_argument('-s', '--save', nargs='?', const=os.getcwd(),
	metavar='DIR', help='save the plots')

#///////////////////////////////////////////////////////////////////////////////

def set_as(obj, key, T):
	'''
	Fetch <obj>[<key>], making sure it is an instance of <T>.

	If the slot already holds a <T>, it is returned unchanged. Otherwise a
	fresh <T>() is stored there and returned. A list that is too short is
	first padded with None up to index <key>; a dict lacking <key> simply
	gets the new entry.
	'''
	try:
		existing = obj[key]
	except IndexError:
		# Pad the list so that index <key> exists.
		obj.extend([None] * (key + 1 - len(obj)))
	except KeyError:
		pass
	else:
		if isinstance(existing, T):
			return existing
	obj[key] = fresh = T()
	return fresh

#///////////////////////////////////////////////////////////////////////////////

def load_metrics(filename):
	'''
	Load the computed metrics from a JSON file and reorganize them into a
	short, easy-to-use dictionary of lists.

	The top-level keys of the JSON file are profile labels, either "<value>"
	or "<value>_<ratio>" (digits only in the second form). Profiles are
	processed in ascending <value> order. Labels without a ratio fill the
	returned dictionary directly; labels with a ratio fill a sub-dictionary
	stored under float(<ratio>).

	Every (sub-)dictionary collects, one entry per profile:
	  'Labels'    -- float(<value>);
	  'Bitrates'  -- a list of per-layer lists, taken from
	                 profile['SUMMARY']['Bitrate'] when there is a 'SUMMARY'
	                 section, else from profile['Bitrate'];
	  other names -- the last element of every other 'SUMMARY' value (names
	                 containing 'PSNR' are renamed HDRTools-style, e.g.
	                 'Y-PSNR' -> 'PSNR-Y'), and every value of the optional
	                 'D_Avg' (HDRMetrics) and 'HDR-VQM' (HDRVQM) sections.

	Returns {} if <filename> is None; also returns {} (after emitting a
	warning) if the file cannot be opened.
	'''
	import warnings

	# Raw strings: '\d' and '\s' in plain literals are invalid escapes.
	label_pattern = re.compile(r'(\d+)_(\d+)')
	psnr_pattern = re.compile(r'(.*?)[-_\s]*PSNR')

	def split_label(label):
		'''
		Split <label> into a tuple: (value, ratio), ratio being None when
		<label> has no "_<ratio>" part.
		'''
		match = label_pattern.match(label)
		if match is not None:
			return tuple(map(float, match.groups()))
		return float(label), None

	def append_bitrates(current_dict, layers):
		'''
		Append each layer's bitrate to its own list in current_dict['Bitrates'].
		'''
		bitrates = set_as(current_dict, 'Bitrates', list)
		for layer, bitrate in enumerate(layers):
			set_as(bitrates, layer, list).append(bitrate)

	# The command-line option was not given: nothing to load.
	if filename is None:
		return {}
	try:
		metrics_file = open(filename, encoding='utf-8')
	except (TypeError, OSError) as error:
		# Best effort: the other files can still be plotted, but tell the
		# user why this one is missing instead of failing silently.
		warnings.warn('cannot read metrics file {!r}: {}'.format(filename, error))
		return {}
	with metrics_file:
		metrics_list = sorted(json.load(metrics_file).items(),
			key=lambda item: split_label(item[0])[0])

	metrics_dict = {}
	for label, metrics in metrics_list:
		value, ratio = split_label(label)

		# Profiles with a ratio go into a per-ratio sub-dictionary; the
		# others are stored at the top level.
		if ratio is not None:
			current_dict = set_as(metrics_dict, ratio, dict)
		else:
			current_dict = metrics_dict

		set_as(current_dict, 'Labels', list).append(value)

		if 'SUMMARY' in metrics:
			# Metrics parsed at encoding time.
			for name, values in metrics['SUMMARY'].items():
				if name == 'Bitrate':
					append_bitrates(current_dict, values)
					continue
				match = psnr_pattern.match(name)
				if match is not None:
					# HDRTools naming: 'PSNR' first, e.g. 'Y-PSNR' -> 'PSNR-Y'.
					name = 'PSNR-{}'.format(match.group(1))
				set_as(current_dict, name, list).append(values[-1])
		else:
			# No encoder summary: fall back on the bitrates computed from
			# the file sizes.
			append_bitrates(current_dict, metrics['Bitrate'])

		# Metrics parsed by HDRMetrics and HDRVQM. Each section is optional
		# on its own, so the script can run right after encoding or after
		# only one of the two tools.
		for section in ('D_Avg', 'HDR-VQM'):
			for name, values in metrics.get(section, {}).items():
				set_as(current_dict, name, list).append(values)

	return metrics_dict

#///////////////////////////////////////////////////////////////////////////////

class profile:
	'''
	Rate/quality data of one coding configuration, wrapping a dictionary of
	lists such as the ones built by load_metrics().
	'''

	# Composite PSNR metrics: name -> (Y, U, V) keys looked up in self.dict.
	_PSNR_COMPONENTS = {
		'PSNR': ('PSNR-Y', 'PSNR-U', 'PSNR-V'),
		'mPSNR': ('mPSNRY', 'mPSNRU', 'mPSNRV'),
		'tPSNR': ('tPSNR-Y', 'tPSNR-U', 'tPSNR-V'),
	}

	def __init__(self, metrics_dict):
		self.dict = metrics_dict

	def get_metrics_set(self, metric):
		'''
		Return a list of (bitrate, value) pairs for <metric>.

		The bitrate is the sum over all layers of self.dict['Bitrates']. For
		'PSNR', 'mPSNR' and 'tPSNR' the value is the Y/U/V average weighted
		6:1:1; any other <metric> is read directly from self.dict.

		Raises KeyError when a required entry is missing.
		'''
		bitrates = np.sum(self.dict['Bitrates'], axis=0)

		components = self._PSNR_COMPONENTS.get(metric)
		if components is None:
			values = self.dict[metric]
		else:
			y, u, v = (np.array(self.dict[key]) for key in components)
			values = (6 * y + u + v) / 8

		return list(zip(bitrates, values))

#///////////////////////////////////////////////////////////////////////////////

def plot_bjontegaard(title, profiles1, profiles2, mode, *metrics):
	'''
	Plot one Bjontegaard delta curve per metric as a function of the ratio.

	<profiles1> and <profiles2> map ratios to profile objects; one of them
	may instead be a single profile, used for every ratio of the other.
	For each metric, the point at a ratio present in both is
	func(profiles1[ratio] set, profiles2[ratio] set), where func is
	bjontegaard.bdrate for mode 'BD-BR', 'BD-RATE' or 'BITRATE', and
	bjontegaard.bdsnr for 'BD-PSNR' or 'PSNR' (case-insensitive). A metric
	that raises KeyError for any ratio is left off the plot.

	Returns the new matplotlib figure.
	Raises ValueError for an unknown <mode> and TypeError when neither
	argument is a dict.
	'''
	mode_key = mode.upper()
	if mode_key in ('BD-BR', 'BD-RATE', 'BITRATE'):
		func = bjontegaard.bdrate
		ylabel = '{} (%)'.format(mode)
	elif mode_key in ('BD-PSNR', 'PSNR'):
		func = bjontegaard.bdsnr
		ylabel = '{} (dB)'.format(mode)
	else:
		raise ValueError('unknown Bjontegaard mode: {!r}'.format(mode))

	# Broadcast a single reference profile over the other side's ratios.
	if isinstance(profiles1, profile) and isinstance(profiles2, dict):
		profiles1 = dict.fromkeys(profiles2, profiles1)
	elif isinstance(profiles2, profile) and isinstance(profiles1, dict):
		profiles2 = dict.fromkeys(profiles1, profiles2)
	elif not (isinstance(profiles1, dict) and isinstance(profiles2, dict)):
		raise TypeError(
			'expected a dict of profiles and a profile (or two dicts), '
			'got {} and {}'.format(type(profiles1).__name__,
				type(profiles2).__name__))

	ratios = sorted(profiles1.keys() & profiles2.keys())

	fig, ax = plt.subplots()
	plotted = False
	for metric in metrics:
		try:
			y = [func(profiles1[ratio].get_metrics_set(metric),
					profiles2[ratio].get_metrics_set(metric))
				for ratio in ratios]
		except KeyError:
			# Metric not available in every profile: skip it.
			continue
		ax.plot(ratios, y, label=metric)
		plotted = True

	ax.set_title(title)
	ax.set_xlabel(r'$\tau = \frac{EL}{EL + BL}$')
	ax.set_ylabel(ylabel)
	# A legend without any labelled line only triggers a warning.
	if plotted:
		ax.legend()
	return fig


def main():
	'''
	Load the metrics files named on the command line, plot BD-rate and
	BD-PSNR curves of SHM against HM, optionally save them, then show them.
	'''
	args = parser.parse_args()

	# Load metrics from JSON files
	simulcast = profile(load_metrics(args.HM))
	# NOTE(review): the ARC profile is loaded but not plotted yet.
	adaptive = profile(load_metrics(args.ARC))
	# The SHM ratios are divided by 100 to get tau on the x axis.
	scalables = {ratio / 100: profile(metrics_dict)
		for ratio, metrics_dict in load_metrics(args.SHM).items()}

	metrics = ['PSNR', 'mPSNR', 'tPSNR', 'PSNR_DE0100', 'HDR-VQM']
	title = r'BD(SHM, HM) = $f(\tau)$'
	rate_figure = plot_bjontegaard(title, scalables, simulcast,
		'BD-RATE', *metrics)
	psnr_figure = plot_bjontegaard(title, scalables, simulcast,
		'BD-PSNR', *metrics)

	# Save before plt.show(), which blocks until the windows are closed.
	if args.save:
		os.makedirs(args.save, exist_ok=True)
		rate_figure.savefig(os.path.join(args.save, 'bdrate.png'))
		psnr_figure.savefig(os.path.join(args.save, 'bdpsnr.png'))
	plt.show()


if __name__ == '__main__':
	main()