Skip to content

Commit

Permalink
feat(portforwarding): VPN_PORT_FORWARDING_DOWN_COMMAND option
Browse files Browse the repository at this point in the history
  • Loading branch information
qdm12 committed Nov 10, 2024
1 parent a035a15 commit 0374c14
Show file tree
Hide file tree
Showing 8 changed files with 36 additions and 7 deletions.
1 change: 1 addition & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ ENV VPN_SERVICE_PROVIDER=pia \
VPN_PORT_FORWARDING_USERNAME= \
VPN_PORT_FORWARDING_PASSWORD= \
VPN_PORT_FORWARDING_UP_COMMAND= \
VPN_PORT_FORWARDING_DOWN_COMMAND= \
# # Cyberghost only:
OPENVPN_CERT= \
OPENVPN_KEY= \
Expand Down
18 changes: 15 additions & 3 deletions internal/configuration/settings/portforward.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ type PortForwarding struct {
// It can be the empty string to indicate not to run a command.
// It cannot be nil in the internal state.
UpCommand *string `json:"up_command"`
// DownCommand is the command to use after the port forwarding goes down.
// It can be the empty string to indicate to NOT run a command.
// It cannot be nil in the internal state.
DownCommand *string `json:"down_command"`
// ListeningPort is the port traffic would be redirected to from the
// forwarded port. The redirection is disabled if it is set to 0, which
// is its default as well.
Expand Down Expand Up @@ -89,6 +93,7 @@ func (p *PortForwarding) Copy() (copied PortForwarding) {
Provider: gosettings.CopyPointer(p.Provider),
Filepath: gosettings.CopyPointer(p.Filepath),
UpCommand: gosettings.CopyPointer(p.UpCommand),
DownCommand: gosettings.CopyPointer(p.DownCommand),
ListeningPort: gosettings.CopyPointer(p.ListeningPort),
Username: p.Username,
Password: p.Password,
Expand All @@ -100,6 +105,7 @@ func (p *PortForwarding) OverrideWith(other PortForwarding) {
p.Provider = gosettings.OverrideWithPointer(p.Provider, other.Provider)
p.Filepath = gosettings.OverrideWithPointer(p.Filepath, other.Filepath)
p.UpCommand = gosettings.OverrideWithPointer(p.UpCommand, other.UpCommand)
p.DownCommand = gosettings.OverrideWithPointer(p.DownCommand, other.DownCommand)
p.ListeningPort = gosettings.OverrideWithPointer(p.ListeningPort, other.ListeningPort)
p.Username = gosettings.OverrideWithComparable(p.Username, other.Username)
p.Password = gosettings.OverrideWithComparable(p.Password, other.Password)
Expand All @@ -110,6 +116,7 @@ func (p *PortForwarding) setDefaults() {
p.Provider = gosettings.DefaultPointer(p.Provider, "")
p.Filepath = gosettings.DefaultPointer(p.Filepath, "/tmp/gluetun/forwarded_port")
p.UpCommand = gosettings.DefaultPointer(p.UpCommand, "")
p.DownCommand = gosettings.DefaultPointer(p.DownCommand, "")
p.ListeningPort = gosettings.DefaultPointer(p.ListeningPort, 0)
}

Expand Down Expand Up @@ -142,9 +149,11 @@ func (p PortForwarding) toLinesNode() (node *gotree.Node) {
}
node.Appendf("Forwarded port file path: %s", filepath)

command := *p.UpCommand
if command != "" {
node.Appendf("Forwarded port command: %s", command)
if *p.UpCommand != "" {
node.Appendf("Forwarded port up command: %s", *p.UpCommand)
}
if *p.DownCommand != "" {
node.Appendf("Forwarded port down command: %s", *p.DownCommand)
}

if p.Username != "" {
Expand Down Expand Up @@ -178,6 +187,9 @@ func (p *PortForwarding) read(r *reader.Reader) (err error) {
p.UpCommand = r.Get("VPN_PORT_FORWARDING_UP_COMMAND",
reader.ForceLowercase(false))

p.DownCommand = r.Get("VPN_PORT_FORWARDING_DOWN_COMMAND",
reader.ForceLowercase(false))

p.ListeningPort, err = r.Uint16Ptr("VPN_PORT_FORWARDING_LISTENING_PORT")
if err != nil {
return err
Expand Down
1 change: 1 addition & 0 deletions internal/portforward/loop.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ func NewLoop(settings settings.PortForwarding, routing Routing,
Enabled: settings.Enabled,
Filepath: *settings.Filepath,
UpCommand: *settings.UpCommand,
DownCommand: *settings.DownCommand,
ListeningPort: *settings.ListeningPort,
},
},
Expand Down
2 changes: 1 addition & 1 deletion internal/portforward/service/command.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
"github.com/qdm12/gluetun/internal/command"
)

func runUpCommand(ctx context.Context, cmder Cmder, logger Logger,
func runCommand(ctx context.Context, cmder Cmder, logger Logger,
commandTemplate string, ports []uint16,
) (err error) {
portStrings := make([]string, len(ports))
Expand Down
4 changes: 2 additions & 2 deletions internal/portforward/service/command_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (
"github.com/stretchr/testify/require"
)

func Test_Service_runUpCommand(t *testing.T) {
func Test_Service_runCommand(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)

Expand All @@ -22,7 +22,7 @@ func Test_Service_runUpCommand(t *testing.T) {
logger := NewMockLogger(ctrl)
logger.EXPECT().Info("1234,5678")

err := runUpCommand(ctx, cmder, logger, commandTemplate, ports)
err := runCommand(ctx, cmder, logger, commandTemplate, ports)

require.NoError(t, err)
}
3 changes: 3 additions & 0 deletions internal/portforward/service/settings.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ type Settings struct {
PortForwarder PortForwarder
Filepath string
UpCommand string
DownCommand string
Interface string // needed for PIA, PrivateVPN and ProtonVPN, tun0 for example
ServerName string // needed for PIA
CanPortForward bool // needed for PIA
Expand All @@ -26,6 +27,7 @@ func (s Settings) Copy() (copied Settings) {
copied.PortForwarder = s.PortForwarder
copied.Filepath = s.Filepath
copied.UpCommand = s.UpCommand
copied.DownCommand = s.DownCommand
copied.Interface = s.Interface
copied.ServerName = s.ServerName
copied.CanPortForward = s.CanPortForward
Expand All @@ -40,6 +42,7 @@ func (s *Settings) OverrideWith(update Settings) {
s.PortForwarder = gosettings.OverrideWithComparable(s.PortForwarder, update.PortForwarder)
s.Filepath = gosettings.OverrideWithComparable(s.Filepath, update.Filepath)
s.UpCommand = gosettings.OverrideWithComparable(s.UpCommand, update.UpCommand)
s.DownCommand = gosettings.OverrideWithComparable(s.DownCommand, update.DownCommand)
s.Interface = gosettings.OverrideWithComparable(s.Interface, update.Interface)
s.ServerName = gosettings.OverrideWithComparable(s.ServerName, update.ServerName)
s.CanPortForward = gosettings.OverrideWithComparable(s.CanPortForward, update.CanPortForward)
Expand Down
2 changes: 1 addition & 1 deletion internal/portforward/service/start.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ func (s *Service) Start(ctx context.Context) (runError <-chan error, err error)
s.portMutex.Unlock()

if s.settings.UpCommand != "" {
err = runUpCommand(ctx, s.cmder, s.logger, s.settings.UpCommand, ports)
err = runCommand(ctx, s.cmder, s.logger, s.settings.UpCommand, ports)
if err != nil {
err = fmt.Errorf("running up command: %w", err)
s.logger.Error(err.Error())
Expand Down
12 changes: 12 additions & 0 deletions internal/portforward/service/stop.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"fmt"
"os"
"time"
)

func (s *Service) Stop() (err error) {
Expand All @@ -30,6 +31,17 @@ func (s *Service) cleanup() (err error) {
s.portMutex.Lock()
defer s.portMutex.Unlock()

if s.settings.DownCommand != "" {
const downTimeout = 60 * time.Second
ctx, cancel := context.WithTimeout(context.Background(), downTimeout)
defer cancel()
err = runCommand(ctx, s.cmder, s.logger, s.settings.DownCommand, s.ports)
if err != nil {
err = fmt.Errorf("running down command: %w", err)
s.logger.Error(err.Error())
}
}

for _, port := range s.ports {
err = s.portAllower.RemoveAllowedPort(context.Background(), port)
if err != nil {
Expand Down

0 comments on commit 0374c14

Please sign in to comment.