Skip to content
Snippets Groups Projects
interactive.py 7.49 KiB
Newer Older
Ulysse Darmet's avatar
Ulysse Darmet committed
#!/usr/bin/python3

import os
import re
import json
from pprint import *
from argparse import *
from itertools import *

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider, RadioButtons

from metricsmanager import load_metrics
from bjontegaard import bdsnr, bdrate

#///////////////////////////////////////////////////////////////////////////////

parser = ArgumentParser(
	# TODO: Add a big description on what the script is aimed to do.
	description	= 'Description here...',
	formatter_class	= RawTextHelpFormatter,
	allow_abbrev	= False)

# One option per JSON metrics file; they all share the same shape, so they are
# registered from a small table (order is kept for the --help output)
for _flag, _help in (
	('--hm',	'the JSON HM metrics'),
	('--shm',	'the JSON SHM metrics'),
	('--arc',	'the JSON ARC metrics')):
	parser.add_argument(_flag, type = str, help = _help)

# Verbosity switch, off unless given on the command line
parser.add_argument('-v', '--verbose',
	help	= 'be verbose',
	action	= 'store_true')

#///////////////////////////////////////////////////////////////////////////////

if __name__ == '__main__':
	# Everything below only runs when the file is executed as a script; the
	# helper functions defined further down are closures over the names
	# created here (simulcast, scalable, the widgets, ...)
	args = parser.parse_args()

	# Load metrics from JSON files
	# <simulcast> is used as a single profile (indexed by metric name),
	# <scalable> as a collection of such profiles indexed by a numeric ratio key
	# NOTE(review): <adaptive> is loaded but not referenced anywhere in the
	# visible code -- confirm whether it is still needed
	simulcast	= load_metrics(args.hm)
	scalable	= load_metrics(args.shm)
	adaptive	= load_metrics(args.arc)

	# What to compute?
	def compute_bjontegaard(metric, ratio):
		'''
		Compute the Bjontegaard metric between simulcast and scalable curves

		<metric> is the name of the metric to compare; it must be a key of both
		the simulcast profile and the scalable profiles.
		<ratio> is the scalable ratio as a fraction (e.g. 0.3); the scalable
		profile whose key is the closest to <ratio> * 100 is used.
		Returns the value computed by bdsnr for the two sets of
		(total bitrate, metric) points.
		'''
		# Look the profile up by nearest key instead of indexing with
		# round(ratio, 1) * 100: that product is a float that is not always
		# exact (round(0.3, 1) * 100 == 30.000000000000004), so it can miss a
		# key such as 30
		target = ratio * 100
		profile = scalable[min(scalable, key = lambda key: abs(key - target))]

		# zip() is a one-shot iterator in Python 3; hand real lists over so the
		# point sets can be indexed or walked more than once
		metrics_set1 = list(zip(
			np.sum(profile['Bitrates'], axis=0),
			profile[metric]))
		metrics_set2 = list(zip(
			np.sum(simulcast['Bitrates'], axis=0),
			simulcast[metric]))
		return bdsnr(metrics_set1, metrics_set2)

	# Available metrics list: every key of the simulcast profile except the
	# 'Labels' and 'Bitrates' entries, in sorted order
	metrics_list = sorted(
		name for name in simulcast.keys()
		if name not in ('Labels', 'Bitrates'))
	
	# Create window and widgets placeholders on a 20x20 grid:
	# curves on the top right, ratio slider below them, metric menu on the top
	# left and the Bjontegaard label below the menu
	axcurve		= plt.subplot2grid((20, 20), ( 0,  7), rowspan = 18, colspan = 13)
	axratio 	= plt.subplot2grid((20, 20), (19,  7), rowspan =  1, colspan = 13)
	axchoice	= plt.subplot2grid((20, 20), ( 0,  0), rowspan = 18, colspan =  5)
	axbjont		= plt.subplot2grid((20, 20), (19,  0), rowspan =  1, colspan =  5)

	# Yellow background for the slider. The 'axisbg' creation keyword has been
	# removed from recent matplotlib releases in favour of 'facecolor', so the
	# colour is set after creation with whichever setter this version provides
	if hasattr(axratio, 'set_facecolor'):
		axratio.set_facecolor('yellow')
	else:
		axratio.set_axis_bgcolor('yellow')

	plt.subplots_adjust(left = 0.05, right = 0.9)
	
	# Create slider
	# Selects the scalable ratio, between 0.1 and 0.9
	ratio = Slider(axratio, 'Ratio', 0.1, 0.9)

	# Create bjontegaard metric label
	# Starts empty; its text is filled in by update_bjontegaard. The axes
	# decorations are hidden so that only the text shows
	bjont = axbjont.text(0, 0.5, '',
		horizontalalignment	= 'left',
		verticalalignment	= 'center')
	axbjont.set_axis_off()
	
	# Create metrics choice menu
	# tOSNR-XYZ is preselected when the loaded metrics provide it; otherwise
	# fall back to the first entry instead of failing with a ValueError
	default_metric = 'tOSNR-XYZ'
	choice = RadioButtons(axchoice, metrics_list,
		active = metrics_list.index(default_metric) if default_metric in metrics_list else 0)
	
	def update_circles(widget):
		'''
		Adjust the RadioButtons circles so that they look circular in any
		circumstances

		<widget> is the RadioButtons instance whose circles must be fixed.
		The height of every circle is recomputed from its width and from the
		displayed size of the widget axes (figure size in inches multiplied by
		the width and height of the axes figbox).
		'''
		width, height = widget.canvas.figure.get_size_inches()
		width *= widget.ax.figbox.width
		height *= widget.ax.figbox.height
		# Work on the widget that was passed in rather than on the global
		# <choice>, so that the parameter is actually honoured
		for each in widget.circles:
			each.height = each.width * width / height
	
	def enable_dist_curve(curve, val = True):
		'''
		Show (<val> true) or hide (<val> false) a single rate-distortion curve

		<curve> is the pair returned by new_dist_curve: the visibility is
		applied to its first element, then to its second one.
		'''
		for index in (0, 1):
			curve[index].set_visible(val)
	
	# Plot rate-distorsion curves
	def update_dist_curve(curve, autoscale = False,
		profile = None, layers = None, metric = None,
		rate = None, dist = None):
		'''
		Refresh a rate-distortion curve created by new_dist_curve.

		<curve> is the [fitted line, sample dots] pair to update.
		When both <profile> and <metric> are given, the X values are the
		bitrates of <profile> summed over <layers> (a list of layer ids; all
		the layers when omitted) and the Y values are the log of the metric.
		Otherwise <rate> and <dist> are used as X and Y values as they are.
		When <autoscale> is true the axes limits are recomputed from the
		visible artists.
		'''
		fitted = curve[0]
		
		# Derive the samples from the profile unless raw values must be used
		from_profile = not (profile is None or metric is None)
		if from_profile:
			bitrates = profile['Bitrates']
			if layers is not None:
				bitrates = [bitrates[i] for i in layers]
			rate = np.sum(bitrates, axis=0)
			dist = np.log(profile[metric])
		
		# Third order polynomial fit, drawn a bit beyond both ends of the
		# sampled bitrates
		x = np.linspace(min(rate) * 0.9, max(rate) * 1.1)
		coefficients = np.polyfit(rate, dist, 3)
		y = np.polyval(coefficients, x)
		
		# Line gets the fitted curve, dots get the raw samples
		fitted.set_data(x, y)
		curve[1].set_data(rate, dist)
		
		# Optionally fit the axes limits to what is currently visible
		axe = fitted._axes
		if autoscale:
			axe.relim(visible_only = True)
			axe.autoscale_view()
		
		# Ask for a redraw of the figure holding the curve
		fitted.figure.canvas.draw_idle()
		
	def new_dist_curve(axe, **kwargs):
		'''
		Add a new empty rate-distortion curve on <axe>

		<kwargs> are forwarded to the plot call creating the fitted line
		(colour, label, ...). Returns [poly, dots]: the fitted line and the
		markers-only line holding the samples, both with placeholder data
		until update_dist_curve fills them in.
		'''
		# scalex/scaley are disabled so that the (0, 0) placeholder point does
		# not affect the axes limits
		poly = axe.plot(0, 0, scalex = False, scaley = False, **kwargs)[0]
		dots = axe.plot(0, 0, scalex = False, scaley = False,
			# Same colour as the fitted line, read through the public accessor
			# instead of the private _color attribute
			color		= poly.get_color(),
			marker		= 'o',
			linestyle	= ' ')[0]
		return [poly, dots]
	
	# Create rate-distortion curves
	# p1 and p2 are [fitted line, sample dots] pairs, empty until the update
	# callbacks fill them in
	p1 = new_dist_curve(axcurve, color = 'blue', label = 'Simulcast')
	p2 = new_dist_curve(axcurve, color = 'red', label = 'Scalable')
	axcurve.set_xlabel('Bitrate (in kb/s)')
	# NOTE(review): the Y label is fixed to log(PSNR) whatever metric is
	# selected in the menu -- confirm this is intended
	axcurve.set_ylabel('log(PSNR)', rotation = 0)
	# Move the axis labels to custom positions, given in axes coordinates
	axcurve.xaxis.set_label_coords(0.25, 0.05)
	axcurve.yaxis.set_label_coords(0, 1.01)
	# Only the fitted lines appear in the legend, not the sample dots
	axcurve.legend(handles = [p1[0], p2[0]], loc = 4)
	axcurve.margins(0.1, 0.1)
	
	# Update bjontegaard metric label
	def update_bjontegaard():
		'''
		Update the bjontegaard metric label

		The value is computed for the metric currently selected in the menu and
		for the ratio currently set on the slider, and shown with two decimals.
		'''
		value = compute_bjontegaard(choice.value_selected, ratio.val)
		# Go through set_text() instead of assigning the private _text
		# attribute, so that matplotlib is told that the artist has changed
		bjont.set_text('Bjontegaard:\n{0:.2f}\n'.format(value))
	
	# Update simulcast curve p1 when clicking a button
	def update_simulcast_curve():
		'''
		Redraw <p1> from the simulcast profile for the metric currently
		selected in the menu, rescaling the axes
		'''
		selected = choice.value_selected
		update_dist_curve(p1,
			autoscale	= True,
			profile		= simulcast,
			metric		= selected)
	
	# Update scalable curve p2 when moving the slider
	def update_scalable_curve():
		'''
		Update <p2> by finding the closest ratio in <scalable>

		The keys of <scalable> are compared with the slider value multiplied by
		100, and the profile with the nearest key is drawn.
		'''
		# A nearest-key search replaces scalable[round(ratio.val, 1) * 100]:
		# that product is a float that is not always exact
		# (round(0.3, 1) * 100 == 30.000000000000004) and can miss a key
		target = ratio.val * 100
		closest = min(scalable, key = lambda key: abs(key - target))
		update_dist_curve(p2, profile = scalable[closest], metric = choice.value_selected, autoscale = True)
	
	def update_scalable_curve_smooth():
		'''
		Update <p2> by interpolating all the <scalable> ratios

		For every sample position, the total bitrate and the selected metric
		are fitted against the ratio with a 3rd order polynomial, which is then
		evaluated at the slider value. The log of the interpolated metric is
		used as the distortion.
		'''
		selected = choice.value_selected
		position = ratio.val
		profiles = list(scalable.values())
		
		# x = [0.1, 0.2, ..., 0.9]
		# It does not depend on the sample, so it is computed once, outside of
		# the loop; keys and values come from the same dict, hence in the same
		# order
		x = np.divide(list(scalable.keys()), 100)
		
		# Get a list of lists of bitrates and metrics
		# The rows in those lists must be compared in parallel (i.e.
		# <bitrates_list>[0][0] and <bitrates_list>[1][0] are bitrates computed
		# using the same profile but with different ratios)
		bitrates_list = [np.sum(profile['Bitrates'], axis=0) for profile in profiles]
		# Named <samples_list> so that it does not shadow the outer
		# <metrics_list> holding the metric names
		samples_list = [profile[selected] for profile in profiles]
		
		rate = []
		dist = []
		for bitrates_samples, metrics_samples in zip(zip(*bitrates_list), zip(*samples_list)):
			# Interpolate between the different ratios by using a 3rd order
			# polynomial
			rate.append(np.polyval(np.polyfit(x, list(bitrates_samples), 3), position))
			dist.append(np.log(np.polyval(np.polyfit(x, list(metrics_samples), 3), position)))
		update_dist_curve(p2, rate = rate, dist = dist, autoscale = True)
	
	# Update all
	def slider_update_callback(val = None):
		'''
		Slider callback: refresh the Bjontegaard label and the scalable curve

		<val> is not used in the body; the helpers called here take no argument.
		'''
		update_bjontegaard()
		update_scalable_curve_smooth() # Change this line for raw or smooth update
	
	def choice_update_callback(label = None):
		'''
		Metric menu callback: refresh everything for the selected metric

		Updates the Bjontegaard label and both curves, then redraws the preview
		shown behind the slider (the Bjontegaard metric of every available
		ratio, normalized to [0, 1]) and resets the slider to the ratio with
		the lowest value.
		<label> is not used in the body; the selection is read from the menu
		widget.
		'''
		update_bjontegaard()
		update_simulcast_curve()
		update_scalable_curve_smooth() # Change this line for raw or smooth update
		
		# Bjontegaard metric for every ratio available in <scalable>
		x = np.array(sorted(scalable.keys())) / 100
		y = np.array([compute_bjontegaard(choice.value_selected, v) for v in x])
		
		# Normalize to [0, 1]. When every value is the same the span is zero:
		# use a flat profile instead of dividing by zero
		ymin, ymax = min(y), max(y)
		span = ymax - ymin
		if span:
			y = (y - ymin) / span
		else:
			y = np.zeros(len(y))
		
		# Drop the previous preview; there is none on the very first call
		try:
			ratio.preview.remove()
		except AttributeError:
			pass
		ratio.preview = axratio.fill_between(x, y, color = 'black', alpha = 0.3, zorder = 2)
		
		# Make the ratio with the lowest value the slider's initial position,
		# move the marker line there and reset the slider to it
		ratio.valinit = min(zip(x, y), key = lambda v: v[1])[0]
		ratio.vline.set_xdata([ratio.valinit, ratio.valinit])
		ratio.reset()
	
	def window_resize_callback(event = None):
		'''
		Resize callback: fix the aspect of the metric menu circles

		<event> is not used in the body.
		'''
		update_circles(choice)
	
	# Connect callback
	# Slider moves, menu clicks and window resizes each get their handler
	ratio.on_changed(slider_update_callback)
	choice.on_clicked(choice_update_callback)
	plt.gcf().canvas.mpl_connect('resize_event', window_resize_callback)
	
	# Show window
	# Both callbacks are called once by hand so that the curves, the label, the
	# slider preview and the menu circles are set up before the window appears
	choice_update_callback()
	window_resize_callback()
	plt.show()