Try to shutdown the UnixSocketServer when our program exits
Also add logging #1 - Forward using iptables, pr0xy and custom DNS
This commit is contained in:
@@ -72,6 +72,7 @@ class UnixSocketServer(socketserver.ThreadingMixIn, socketserver.UnixStreamServe
|
|||||||
|
|
||||||
def stop(self):
|
def stop(self):
|
||||||
self.shutdown()
|
self.shutdown()
|
||||||
|
self.server_close()
|
||||||
|
|
||||||
def isAlive(self):
|
def isAlive(self):
|
||||||
return self.thread.isAlive()
|
return self.thread.isAlive()
|
||||||
@@ -182,17 +183,12 @@ class RandomResolver(BaseResolver):
|
|||||||
return reply
|
return reply
|
||||||
|
|
||||||
|
|
||||||
def handle_sig(signum, frame):
|
|
||||||
logger.info('pid=%d, got signal: %s, stopping...', os.getpid(), signal.Signals(signum).name)
|
|
||||||
exit(0)
|
|
||||||
|
|
||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
"""
|
"""
|
||||||
:param args:
|
:param args:
|
||||||
:type args: argparse.Namespace
|
:type args: argparse.Namespace
|
||||||
"""
|
"""
|
||||||
signal.signal(signal.SIGTERM, handle_sig)
|
log = logging.getLogger("fake-dns.main")
|
||||||
|
|
||||||
port = args.port
|
port = args.port
|
||||||
resolve_dir = Path(args.resolve_dir)
|
resolve_dir = Path(args.resolve_dir)
|
||||||
@@ -204,7 +200,20 @@ def main(args):
|
|||||||
servers = [udp_server, tcp_server]
|
servers = [udp_server, tcp_server]
|
||||||
if args.socket_path:
|
if args.socket_path:
|
||||||
socket_path = Path(args.socket_path)
|
socket_path = Path(args.socket_path)
|
||||||
servers.append(UnixSocketServer(socket_path, UnixSocketHandler))
|
socket_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
log.info("Creating unix socket to %s", socket_path)
|
||||||
|
servers.append(UnixSocketServer(str(socket_path), UnixSocketHandler))
|
||||||
|
|
||||||
|
def stop_servers():
|
||||||
|
for _server in servers:
|
||||||
|
_server.stop()
|
||||||
|
|
||||||
|
def handle_sig(signum, frame):
|
||||||
|
logger.info('pid=%d, got signal: %s, stopping...', os.getpid(), signal.Signals(signum).name)
|
||||||
|
stop_servers()
|
||||||
|
exit(0)
|
||||||
|
|
||||||
|
signal.signal(signal.SIGTERM | signal.SIGINT, handle_sig)
|
||||||
|
|
||||||
logger.info('starting DNS server on port %d', port)
|
logger.info('starting DNS server on port %d', port)
|
||||||
for server in servers:
|
for server in servers:
|
||||||
@@ -215,11 +224,13 @@ def main(args):
|
|||||||
sleep(1)
|
sleep(1)
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
pass
|
pass
|
||||||
for server in servers:
|
finally:
|
||||||
server.stop()
|
stop_servers()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
logging.basicConfig(level=logging.DEBUG)
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
description="A DNS server that returns fake IPs for requests"
|
description="A DNS server that returns fake IPs for requests"
|
||||||
)
|
)
|
||||||
|
Reference in New Issue
Block a user