#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import argparse
import os
import re
import shutil
import socket
import sys

# Home directory of the current user. Prefer the HOME environment variable,
# but fall back to os.path.expanduser so that importing this module does not
# fail with KeyError when HOME is unset.
HOME = os.environ.get('HOME') or os.path.expanduser('~')
# Path of the known_hosts file that this script backs up and rewrites.
KNOWN_HOSTS = os.path.join(HOME, '.ssh', 'known_hosts')


def parse_args():
    """Parse the command-line options of the script.

    Returns:
        The argparse namespace. Its ``strip`` attribute is True when
        ``-s``/``--strip-ips`` was given and False otherwise.
    """
    parser = argparse.ArgumentParser(description='Cleanup known_hosts file')
    parser.add_argument(
        '-s', '--strip-ips', action='store_true', dest='strip',
        help='remove IP address entries, keeping only host names',
    )
    return parser.parse_args()


def backup():
    """Copy the known_hosts file to a sibling file with an ``.old`` suffix."""
    backup_path = "{}.old".format(KNOWN_HOSTS)
    shutil.copyfile(KNOWN_HOSTS, backup_path)


def is_ip(ip):
    """Return True when *ip* is an IP address rather than a host name.

    Everything ``socket.inet_aton`` accepts counts as an IPv4 address. In
    addition, IPv6 literals are recognised, as is the ``[address]:port``
    notation used for hosts on a non-default port.

    Args:
        ip: A single host field taken from a known_hosts line.

    Returns:
        True if the field is an IPv4 or IPv6 address, False otherwise.
    """
    candidate = ip
    if candidate.startswith('['):
        # "[address]:port" form: only the bracketed part is the address.
        candidate, closing, _port = candidate[1:].partition(']')
        if not closing:
            return False

    try:
        socket.inet_aton(candidate)
        return True
    except socket.error:
        pass

    # inet_pton does not understand a zone suffix such as "fe80::1%eth0",
    # so cut it off before testing for IPv6.
    candidate = candidate.partition('%')[0]
    try:
        socket.inet_pton(socket.AF_INET6, candidate)
        return True
    except socket.error:
        return False


def read_and_reduce():
    """Read the known_hosts file and merge the lines that share a key.

    Blank lines and ``#`` comment lines are skipped. Fields may be separated
    by any run of whitespace, and anything after the third field (a trailing
    comment) is ignored.

    Returns:
        A dict keyed by ``keytype + fingerprint``. Each value is a dict with
        ``'hosts'`` (set of host names/addresses seen for that key),
        ``'keytype'`` and ``'fingerprint'``.

    Raises:
        ValueError: If a line starts with an ``@`` marker, which the
            returned structure cannot represent, or has fewer than three
            fields.
    """
    knownhosts = {}
    with open(KNOWN_HOSTS) as f:
        for lineno, line in enumerate(f, 1):
            entry = line.strip()
            if not entry or entry.startswith("#"):
                continue
            if entry.startswith("@"):
                # Fail loudly instead of writing the line back without its
                # marker, which would change its meaning.
                raise ValueError(
                    "line {}: marker lines are not supported: {!r}".format(
                        lineno, entry))
            fields = entry.split()
            if len(fields) < 3:
                raise ValueError(
                    "line {}: expected 'hosts keytype key', got {!r}".format(
                        lineno, entry))
            hosts, keytype, fingerprint = fields[:3]
            dictkey = keytype + fingerprint
            if dictkey not in knownhosts:
                knownhosts[dictkey] = {
                    'hosts': set(),
                    'keytype': keytype,
                    'fingerprint': fingerprint,
                }
            # Ignore empty names produced by stray commas ("a,,b").
            knownhosts[dictkey]['hosts'].update(
                h for h in hosts.split(",") if h)

    return knownhosts


def write_hosts_file(knownhosts):
    """Overwrite the known_hosts file with the given entries.

    One line is written per entry as ``hosts keytype fingerprint``, where
    ``hosts`` is the comma-joined host collection in reverse sorted order.
    Entries whose host collection is empty are skipped. The passed-in
    mapping is not modified.

    Args:
        knownhosts: Mapping whose values are dicts with ``'hosts'``,
            ``'keytype'`` and ``'fingerprint'`` keys.
    """
    with open(KNOWN_HOSTS, 'w') as f:
        for entry in knownhosts.values():
            if not entry['hosts']:
                continue

            hosts_joined = ",".join(sorted(entry['hosts'], reverse=True))
            f.write("{} {} {}\n".format(
                hosts_joined, entry['keytype'], entry['fingerprint']))


def main():
    """Back up known_hosts, merge its entries and write the result back.

    With ``--strip-ips`` every host that ``is_ip`` recognises is removed
    from each entry before the file is rewritten.
    """
    options = parse_args()

    backup()
    entries = read_and_reduce()

    if options.strip:
        for entry in entries.values():
            names_only = []
            for host in entry['hosts']:
                if is_ip(host):
                    continue
                names_only.append(host)
            entry['hosts'] = names_only

    write_hosts_file(entries)
    print("OK. Cleaned up", KNOWN_HOSTS)


# Run the cleanup only when this file is executed as a script.
if __name__ == "__main__":
    main()