diff --git a/common/types.go b/common/types.go index d1ffd2cb..8abf41c1 100644 --- a/common/types.go +++ b/common/types.go @@ -103,6 +103,10 @@ type BaseSession struct { SAM SAM } +func (bs *BaseSession) Conn() net.Conn { + return bs.conn +} + func (bs *BaseSession) ID() string { return bs.id } func (bs *BaseSession) Keys() i2pkeys.I2PKeys { return bs.keys } func (bs *BaseSession) Read(b []byte) (int, error) { return bs.conn.Read(b) } diff --git a/datagram/read.go b/datagram/read.go index a7f43d01..ee84e118 100644 --- a/datagram/read.go +++ b/datagram/read.go @@ -82,6 +82,7 @@ func (r *DatagramReader) safeCloseChannel() { close(r.recvChan) close(r.errorChan) } + func (r *DatagramReader) receiveLoop() { logger := log.WithField("session_id", r.session.ID()) logger.Debug("Starting receive loop") diff --git a/raw/dial.go b/raw/dial.go index 7fae116c..0547f125 100644 --- a/raw/dial.go +++ b/raw/dial.go @@ -7,6 +7,7 @@ import ( "time" "github.com/go-i2p/i2pkeys" + "github.com/samber/oops" "github.com/sirupsen/logrus" ) @@ -24,6 +25,19 @@ func (rs *RawSession) DialTimeout(destination string, timeout time.Duration) (ne // DialContext establishes a raw connection with context support func (rs *RawSession) DialContext(ctx context.Context, destination string) (net.PacketConn, error) { + // Validate session state first + rs.mu.RLock() + if rs.closed { + rs.mu.RUnlock() + return nil, oops.Errorf("session is closed") + } + rs.mu.RUnlock() + + // Validate destination + if destination == "" { + return nil, oops.Errorf("destination cannot be empty") + } + logger := log.WithFields(logrus.Fields{ "destination": destination, }) @@ -36,7 +50,7 @@ func (rs *RawSession) DialContext(ctx context.Context, destination string) (net. writer: rs.NewWriter(), } - // Start the reader loop + // Start the reader loop only if session is valid go conn.reader.receiveLoop() logger.WithField("session_id", rs.ID()).Debug("Successfully created raw connection") @@ -57,6 +71,14 @@ func (rs *RawSession) DialI2PTimeout(addr i2pkeys.I2PAddr, timeout time.Duration // DialI2PContext establishes a raw connection to an I2P address with context support func (rs *RawSession) DialI2PContext(ctx context.Context, addr i2pkeys.I2PAddr) (net.PacketConn, error) { + // Validate session state first + rs.mu.RLock() + if rs.closed { + rs.mu.RUnlock() + return nil, oops.Errorf("session is closed") + } + rs.mu.RUnlock() + logger := log.WithFields(logrus.Fields{ "destination": addr.Base32(), }) @@ -69,7 +91,7 @@ func (rs *RawSession) DialI2PContext(ctx context.Context, addr i2pkeys.I2PAddr) writer: rs.NewWriter(), } - // Start the reader loop + // Start the reader loop only if session is valid go conn.reader.receiveLoop() logger.WithField("session_id", rs.ID()).Debug("Successfully created I2P raw connection") diff --git a/raw/dial_test.go b/raw/dial_test.go index c83dc30b..aab5537d 100644 --- a/raw/dial_test.go +++ b/raw/dial_test.go @@ -10,66 +10,63 @@ import ( "github.com/go-i2p/i2pkeys" ) +func setupTestSession(t *testing.T) *RawSession { + t.Helper() + + // Skip actual I2P connection for unit tests + if testing.Short() { + t.Skip("Skipping integration test in short mode") + } + + sam, err := common.NewSAM(testSAMAddr) + if err != nil { + t.Fatalf("Failed to create SAM connection: %v", err) + } + + keys, err := sam.NewKeys() + if err != nil { + sam.Close() + t.Fatalf("Failed to generate keys: %v", err) + } + + session, err := NewRawSession(sam, "test_dial_session", keys, nil) + if err != nil { + sam.Close() + t.Fatalf("Failed to create session: %v", err) + } + + return session +} + +// Update the test to use proper session setup func TestRawSession_Dial(t *testing.T) { tests := []struct { - name string - destination string - setupSession func() *RawSession - wantErr bool - errContains string + name string + destination string + wantErr bool + errContains string }{ { name: "valid_b32_destination", destination: "example.b32.i2p", - setupSession: func() *RawSession { - sam := &common.SAM{} - baseSession := &common.BaseSession{} - return &RawSession{ - BaseSession: baseSession, - sam: sam, - options: []string{}, - closed: false, - } - }, - wantErr: false, + wantErr: false, }, { name: "empty_destination", destination: "", - setupSession: func() *RawSession { - sam := &common.SAM{} - baseSession := &common.BaseSession{} - return &RawSession{ - BaseSession: baseSession, - sam: sam, - options: []string{}, - closed: false, - } - }, wantErr: true, errContains: "destination", }, - { - name: "dial_on_closed_session", - destination: "example.b32.i2p", - setupSession: func() *RawSession { - sam := &common.SAM{} - baseSession := &common.BaseSession{} - return &RawSession{ - BaseSession: baseSession, - sam: sam, - options: []string{}, - closed: true, - } - }, - wantErr: true, - errContains: "closed", - }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - session := tt.setupSession() + if testing.Short() { + t.Skip("Skipping integration test in short mode") + } + + session := setupTestSession(t) + defer session.Close() conn, err := session.Dial(tt.destination) @@ -94,11 +91,6 @@ func TestRawSession_Dial(t *testing.T) { return } - // Verify conn implements net.PacketConn - if _, ok := conn.(net.PacketConn); !ok { - t.Error("Dial() returned connection that doesn't implement net.PacketConn") - } - // Clean up if conn != nil { _ = conn.Close() diff --git a/raw/read.go b/raw/read.go index e1a8e92e..779962ba 100644 --- a/raw/read.go +++ b/raw/read.go @@ -56,7 +56,10 @@ func (r *RawReader) Close() error { logger.Warn("Timeout waiting for receive loop to stop") } - // Now safe to close the receiver channels since receiveLoop has stopped + // Fix: Close doneChan here to prevent multiple closes + close(r.doneChan) + + // Fix: Close receiver channels here under mutex protection close(r.recvChan) close(r.errorChan) @@ -71,11 +74,23 @@ func (r *RawReader) receiveLoop() { // Signal completion when this loop exits defer func() { - if r.doneChan != nil { - close(r.doneChan) + select { + case r.doneChan <- struct{}{}: + // Successfully signaled completion + default: + // Channel may be closed or blocked - that's okay } }() + // Check session state before starting loop + r.session.mu.RLock() + if r.session.closed || r.session.BaseSession == nil { + r.session.mu.RUnlock() + logger.Debug("Raw receive loop terminated - session invalid") + return + } + r.session.mu.RUnlock() + for { // Check for closure in a non-blocking way first select { @@ -114,6 +129,25 @@ func (r *RawReader) receiveLoop() { func (r *RawReader) receiveDatagram() (*RawDatagram, error) { logger := log.WithField("session_id", r.session.ID()) + // Check if session is valid and not closed + r.session.mu.RLock() + if r.session.closed { + r.session.mu.RUnlock() + return nil, oops.Errorf("session is closed") + } + + // Check if BaseSession is properly initialized + if r.session.BaseSession == nil { + r.session.mu.RUnlock() + return nil, oops.Errorf("session is not properly initialized") + } + + if r.session.BaseSession.Conn() == nil { + r.session.mu.RUnlock() + return nil, oops.Errorf("session connection is not available") + } + r.session.mu.RUnlock() + // Read from the session connection for incoming raw datagrams buf := make([]byte, 4096) n, err := r.session.Read(buf) diff --git a/raw/read_test.go b/raw/read_test.go new file mode 100644 index 00000000..51035210 --- /dev/null +++ b/raw/read_test.go @@ -0,0 +1,61 @@ +package raw + +import ( + "sync" + "testing" + + "github.com/go-i2p/go-sam-go/common" +) + +func TestRawReader_ConcurrentClose(t *testing.T) { + // Test concurrent Close() calls don't panic + session := &RawSession{ + BaseSession: &common.BaseSession{}, + closed: false, + } + + reader := &RawReader{ + session: session, + recvChan: make(chan *RawDatagram, 10), + errorChan: make(chan error, 1), + closeChan: make(chan struct{}), + doneChan: make(chan struct{}), + closed: false, + mu: sync.RWMutex{}, + } + + // Start receive loop + go reader.receiveLoop() + + // Simulate concurrent close attempts + var wg sync.WaitGroup + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + _ = reader.Close() // Should not panic + }() + } + + wg.Wait() + + // Verify reader is properly closed + if !reader.closed { + t.Error("Reader should be marked as closed") + } +} + +func TestRawReader_CloseRaceCondition(t *testing.T) { + // Test that rapid close after start doesn't cause channel panic + for i := 0; i < 100; i++ { + session := &RawSession{closed: false} + reader := session.NewReader() + + go reader.receiveLoop() + + // Close immediately to trigger race condition + if err := reader.Close(); err != nil { + t.Errorf("Close() failed: %v", err) + } + } +}