#!/usr/bin/python3
import numpy as np
import matplotlib.pyplot as plt

# Multiplier applied to the after-tax amount in salary_after_tax; 0.98 keeps
# 98% of it, i.e. deducts a further 2%.
# NOTE(review): a Medicare-style levy is normally 2% of *taxable* (gross)
# income, whereas scaling the after-tax amount levies 2% of the net figure —
# confirm which model is intended.
medicare = 0.98

# Income brackets as (lower, upper, after_tax) tuples.  ``after_tax`` is called
# positionally as after_tax(lower, upper, x) and returns x minus the income tax
# owed on x for lower <= x <= upper.  The constant in each formula is the
# cumulative tax from all lower brackets (e.g. 5092 = (45000 - 18200) * 0.19),
# so the after-tax curve is continuous at every boundary.
# NOTE(review): thresholds/rates look like Australian resident rates
# (2020-21 to 2023-24) — confirm the intended income year.
brackets = [
		# Tax-free band: no tax is owed, so the whole amount is kept.
		(0, 18200, lambda lo, hi, x: x),
		(18200, 45000, lambda lo, hi, x: x-((x-lo)*0.19)),
		(45000, 120000, lambda lo, hi, x: x-(5092+(x-lo)*0.325)),
		(120000, 180000, lambda lo, hi, x: x-(29467+(x-lo)*0.37)),
		(180000, float("inf"), lambda lo, hi, x: x-(51667+(x-lo)*0.45)),
]

def salary_after_tax(x):
	"""Return the net salary for a gross salary *x*.

	The first entry of ``brackets`` whose inclusive range ``[lower, upper]``
	contains *x* supplies the after-tax function; its result is then scaled
	by ``medicare``.

	Raises:
		ValueError: if no bracket covers *x* (e.g. a negative amount or NaN).
	"""
	for lower, upper, after_tax in brackets:
		# Inclusive on both ends: at a shared boundary the earlier bracket wins.
		if lower <= x <= upper:
			return after_tax(lower, upper, x) * medicare
	# Previously this fell through and returned None, which only failed later
	# and far from the cause; fail loudly here instead.
	raise ValueError(f"no tax bracket covers salary {x!r}")

def salary(x):
	"""Gross (pre-tax) salary: the input amount, passed through unchanged."""
	gross = x
	return gross

MAX_X = 200000  # exclusive upper bound of the gross-salary range plotted
x = range(MAX_X)

fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
# Evaluate gross and net salary at every whole-number amount in the range.
y1, y2 = zip(*[(salary(val), salary_after_tax(val)) for val in x])
# Draw on the explicit Axes rather than mixing in the pyplot state machine.
ax.plot(x, y1, label='Before tax')
ax.plot(x, y2, label='After tax and Medicare')
# Both curves are salary amounts, so the axes are labeled as such.
ax.set_xlabel('Gross salary')
ax.set_ylabel('Salary')
ax.legend()
# Major ticks (and the stronger gridlines) mark the bracket lower bounds on
# both axes; minor ticks give a 10,000-step reference grid.
thresholds = [lower for lower, _upper, _func in brackets]
ax.set_xticks(thresholds)
ax.set_xticks(np.arange(0, MAX_X, 10000), minor=True)
ax.set_yticks(thresholds)
ax.set_yticks(np.arange(0, MAX_X, 10000), minor=True)
ax.grid(which='minor', alpha=0.2)
ax.grid(which='major', alpha=0.5)
plt.show()
