-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
security (dialFunc): adding a dialFunc to estabilishing network conne…
…ctions
- Loading branch information
1 parent
93cd19b
commit 2c7a67a
Showing
1 changed file
with
65 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
package security | ||
|
||
import ( | ||
"crypto/tls" | ||
"errors" | ||
"net" | ||
"time" | ||
) | ||
|
||
var errIpNotAllowed error = errors.New("ip adress is not allowed") | ||
|
||
// IsDisallowedIP checks if the provided host is a disallowed IP address. | ||
// It parses the given host into an IP address and returns true if the IP is multicast, | ||
// unspecified, a loopback address, or a private address. | ||
func IsDisallowedIP(host string) bool { | ||
ip := net.ParseIP(host) | ||
|
||
return ip.IsMulticast() || ip.IsUnspecified() || ip.IsLoopback() || ip.IsPrivate() | ||
} | ||
|
||
// checkDisallowedIP checks if the IP address of the incoming connection is disallowed. | ||
func checkDisallowedIP(conn net.Conn) error { | ||
ip, _, _ := net.SplitHostPort(conn.RemoteAddr().String()) | ||
|
||
if IsDisallowedIP(ip) { | ||
conn.Close() | ||
return errIpNotAllowed | ||
} | ||
|
||
return nil | ||
} | ||
|
||
// dialFunc establishes a network connection to a specified address with optional TLS configuration and timeout. | ||
// It first checks if a TLS configuration is provided, if so, it dials with TLS using the provided TLS configuration. | ||
// If not, it dials without TLS. After establishing the connection, it checks if the remote IP is disallowed. | ||
// Returns the connection and any error encountered during the process. | ||
func dialFunc(network string, addr string, timeout time.Duration, tlsConfig *tls.Config) (net.Conn, error) { | ||
dialer := &net.Dialer{ | ||
Timeout: timeout, | ||
} | ||
|
||
if tlsConfig != nil { | ||
conn, err := tls.DialWithDialer(dialer, network, addr, tlsConfig) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
if err := checkDisallowedIP(conn); err != nil { | ||
return nil, err | ||
} | ||
|
||
return conn, err | ||
} | ||
|
||
conn, err := dialer.Dial(network, addr) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
if err := checkDisallowedIP(conn); err != nil { | ||
return nil, err | ||
} | ||
|
||
return conn, err | ||
} |