Use asyncio.start_server instead of loop.create_server
The protocol_factory doesn't support `Protocol`s with async handlers e.g `async def connection_made` won't actually be awaited and thus nothing ever happens. Now there's a pretty ugly solution with one long-ass method, but maybe that can be trimmed or a callable can be used. #4 - Investigate extending pr0xy to use SAM
This commit is contained in:
@@ -14,26 +14,24 @@
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
import argparse
|
||||
import asyncio
|
||||
import logging
|
||||
import multiprocessing
|
||||
import os
|
||||
import typing
|
||||
from time import sleep
|
||||
|
||||
|
||||
from i2plib import sam
|
||||
|
||||
from trans_proxy import fake_dns
|
||||
from trans_proxy.process import AsyncProcess
|
||||
from trans_proxy.servers import ClientTcpTunnel
|
||||
from trans_proxy.servers import start_client_tcp_tunnel
|
||||
|
||||
ENV_PORT = "PROXY_PORT"
|
||||
ENV_SAM_HOST = "PROXY_SAM_HOST"
|
||||
ENV_SAM_PORT = "PROXY_SAM_PORT"
|
||||
ENV_DNS_PORT = "PROXY_DNS_PORT"
|
||||
|
||||
logger = logging.getLogger("trans_proxy")
|
||||
logger = logging.getLogger("trans_proxy.cli")
|
||||
|
||||
|
||||
def main():
|
||||
@@ -111,23 +109,7 @@ def exec_processes(processes: typing.List[multiprocessing.Process]):
|
||||
finally:
|
||||
for process in processes:
|
||||
if process.is_alive():
|
||||
process.close()
|
||||
|
||||
|
||||
async def start_client_tcp_tunnel(
|
||||
sam_host,
|
||||
sam_port,
|
||||
ip_dict,
|
||||
host="127.0.0.1", port=1234,
|
||||
**kwargs
|
||||
):
|
||||
loop = asyncio.get_running_loop()
|
||||
server = await loop.create_server(lambda: ClientTcpTunnel(
|
||||
sam_host=sam_host,
|
||||
sam_port=sam_port,
|
||||
ip_dict=ip_dict,
|
||||
), host=host, port=port)
|
||||
await server.serve_forever()
|
||||
process.kill()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@@ -14,67 +14,91 @@
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
import asyncio
|
||||
import typing
|
||||
from asyncio import transports
|
||||
import logging
|
||||
from asyncio.streams import StreamReader, StreamWriter
|
||||
|
||||
import i2plib as i2plib
|
||||
import i2plib
|
||||
|
||||
from trans_proxy.utils import get_original_ip
|
||||
|
||||
|
||||
class ClientTcpTunnel(asyncio.Protocol):
|
||||
SESSION_NAME = "trans-proxy-sessions"
|
||||
|
||||
async def start_client_tcp_tunnel(
|
||||
sam_host,
|
||||
sam_port,
|
||||
ip_dict,
|
||||
host="127.0.0.1", port=1234,
|
||||
**kwargs
|
||||
):
|
||||
"""
|
||||
The transparent proxy that forwards clients to their I2P destinations
|
||||
"""
|
||||
logger = logging.getLogger("servers.client_tcp_tunnel")
|
||||
logger.info("Starting client tcp tunnel at %s:%s", host, port)
|
||||
server = await asyncio.start_server(ClientTcpTunnel.make(
|
||||
sam_host=sam_host,
|
||||
sam_port=sam_port,
|
||||
ip_dict=ip_dict,
|
||||
logger=logger,
|
||||
), host=host, port=port)
|
||||
await server.serve_forever()
|
||||
|
||||
def __init__(self, sam_host: str, sam_port: int, ip_dict: dict):
|
||||
self.sam_port = sam_port
|
||||
self.sam_host = sam_host
|
||||
self.ip_dict = ip_dict
|
||||
|
||||
self.reader: typing.Optional[StreamReader] = None
|
||||
self.writer: typing.Optional[StreamWriter] = None
|
||||
self.transport: typing.Optional[transports.Transport] = None
|
||||
# TODO generate a unique session name
|
||||
self.session_name = "test-connect"
|
||||
self._destination = None
|
||||
class ClientTcpTunnel:
|
||||
session_made = False
|
||||
|
||||
@property
|
||||
def destination(self) -> str:
|
||||
"""
|
||||
The I2P destination that this tunnel is pointing to
|
||||
"""
|
||||
if self._destination is None:
|
||||
self._update_destination()
|
||||
return self._destination
|
||||
@classmethod
|
||||
def make(
|
||||
cls,
|
||||
sam_host: str,
|
||||
sam_port: int,
|
||||
ip_dict: dict,
|
||||
logger: logging.Logger,
|
||||
):
|
||||
# TODO: this seems really dirty. Maybe functools.partial would be cleaner
|
||||
async def handle_connection(reader: StreamReader, writer: StreamWriter):
|
||||
log = logging.getLogger("ClientTcpTunnel")
|
||||
logger.debug("Created")
|
||||
logger.debug("ip_dict: %s", ip_dict)
|
||||
# The I2P destination that this tunnel is pointing to
|
||||
try:
|
||||
destination = ip_dict[
|
||||
get_original_ip(writer.get_extra_info('socket'))
|
||||
]
|
||||
logger.debug("Set destination to %s", destination)
|
||||
|
||||
def _update_destination(self):
|
||||
# TODO: handle non-existent destination
|
||||
self._destination = self.ip_dict[
|
||||
get_original_ip(self.transport.get_extra_info('socket'))
|
||||
]
|
||||
# create a SAM stream session
|
||||
sam_address = (sam_host, sam_port)
|
||||
# TODO: find a cleaner way of doing this
|
||||
if not cls.session_made:
|
||||
try:
|
||||
await i2plib.create_session(SESSION_NAME, sam_address)
|
||||
cls.session_made = True
|
||||
except i2plib.exceptions.DuplicatedId:
|
||||
cls.session_made = True
|
||||
|
||||
async def connection_made(self, transport: transports.Transport) -> None:
|
||||
self.transport = transport
|
||||
self._update_destination()
|
||||
session_name = self.session_name
|
||||
# connect to a destination
|
||||
i2p_reader, i2p_writer = await i2plib.stream_connect(
|
||||
SESSION_NAME,
|
||||
destination,
|
||||
sam_address=sam_address,
|
||||
)
|
||||
|
||||
# create a SAM stream session
|
||||
await i2plib.create_session(session_name, (self.sam_host, self.sam_port))
|
||||
while data := await reader.read(4098):
|
||||
log.debug("data_received")
|
||||
# write data to a socket
|
||||
i2p_writer.write(data)
|
||||
|
||||
# connect to a destination
|
||||
# TODO: Add destination_lookup to FakeResolver
|
||||
self.reader, self.writer = await i2plib.stream_connect(session_name, self.destination)
|
||||
# asynchronously receive data
|
||||
while i2p_response := await i2p_reader.read(4096):
|
||||
writer.write(i2p_response)
|
||||
|
||||
async def data_received(self, data: bytes) -> None:
|
||||
# write data to a socket
|
||||
self.writer.write(data)
|
||||
# close the connection
|
||||
i2p_writer.close()
|
||||
except:
|
||||
logger.exception("Error while proxying")
|
||||
finally:
|
||||
writer.close()
|
||||
|
||||
# asynchronously receive data
|
||||
while i2p_response := await self.reader.read(4096):
|
||||
self.transport.write(i2p_response)
|
||||
|
||||
def eof_received(self):
|
||||
# close the connection
|
||||
self.writer.close()
|
||||
return handle_connection
|
||||
|
Reference in New Issue
Block a user