#!/usr/bin/python3

# Author: Paul Wise <pabs@debian.org>
# License: MIT/Expat

import os
import sys
import ldap

def filter(line):
	"""Return line, commented out with a reason when its host is not usable.

	The host is the text before the first comma of the line. It is checked
	against the module-level ``hosts`` (hostname -> allowed groups),
	``groups`` (groups read from the local groups file) and ``status``
	(hostname -> status value) data.

	Blank lines and lines that already start with '#' are returned unchanged.
	Otherwise the line is returned as-is when the host is known, shares at
	least one group with ``groups`` and has no status entry; in every other
	case it is returned prefixed with '# <reason> # '.
	"""
	# startswith() instead of line[0]: indexing an empty line raises
	# IndexError, which would abort the whole rewrite on a blank line.
	if not line.strip() or line.startswith('#'):
		return line
	host = line.split(',', 1)[0]
	if host not in hosts:
		error = 'Unknown host'
	elif not hosts[host]:
		# Host exists but has no allowed groups listed at all.
		error = 'No groups allowed'
	elif set(hosts[host]).isdisjoint(groups):
		error = 'Not in allowed groups'
	elif host in status:
		# Any status entry is treated as "down"; the value is shown verbatim.
		error = 'Host is down %s' % status[host]
	else:
		return line
	return '# %s # %s' % (error, line)

# Query the LDAP directory for every debianServer entry and build two
# lookup tables keyed by hostname:
#   hosts  -> list of group names allowed on that host (empty if none listed)
#   status -> the host's first 'status' value, only for hosts that have one
ldap_server = 'ldaps://db.debian.org'
ldap_hosts_search = 'ou=hosts,dc=debian,dc=org'
ldap_search = '(objectClass=debianServer)'
ldap_attributes = ['hostname', 'allowedGroups', 'status']


def _ldap_text(value):
	# python-ldap on Python 3 returns attribute values as bytes; decode them so
	# they compare equal to the str hostnames/groups read from local files.
	return value.decode('utf-8') if isinstance(value, bytes) else value


l = ldap.initialize(ldap_server)
try:
	r = l.search_s(ldap_hosts_search, ldap.SCOPE_SUBTREE, ldap_search, ldap_attributes)
finally:
	# Release the connection even if the search fails.
	l.unbind_s()
hosts = {
	_ldap_text(e['hostname'][0]): [_ldap_text(g) for g in e.get('allowedGroups', [])]
	for dn, e in r
}
status = {
	_ldap_text(e['hostname'][0]): _ldap_text(e['status'][0])
	for dn, e in r
	if 'status' in e
}

# Load the local list of groups, one group name per line.
groups_file = os.path.expanduser('~/.cache/projects-remote-debian')
# The with-block closes the file even if reading fails.
with open(groups_file) as groups_f:
	groups = groups_f.read().splitlines()

# Rewrite the ssh hosts file in place, passing every line through filter()
# so entries for unusable hosts end up commented out.
ssh_hosts_file = os.path.expanduser('~/.ssh/hosts/debian')
# 'r+' opens for reading and writing without truncating; the previous 'rw+'
# mode is rejected by Python 3's open() with ValueError.
with open(ssh_hosts_file, 'r+') as ssh_hosts_f:
	ssh_hosts_lines = [filter(line) for line in ssh_hosts_f.read().splitlines()]
	ssh_hosts_f.seek(0, 0)
	# Terminate every line (including the last) so the file keeps its
	# trailing newline across runs.
	ssh_hosts_f.write(''.join(line + '\n' for line in ssh_hosts_lines))
	# Drop any leftover bytes from the previous contents.
	ssh_hosts_f.truncate()
