diff --git a/README.md b/README.md index e601019..7a16981 100644 --- a/README.md +++ b/README.md @@ -37,23 +37,42 @@ SSH settings configure the integrated SSH server in `sshmux`. They are grouped u Auth settings configures the authentication and authorization API used by `sshmux`. They are grouped under `auth` in the TOML file. -| Key | Type | Description | Required | Example | -| -------------------------- | ---------- | -------------------------------------------------------------------------- | -------- | ----------------------------- | -| `endpoint` | `string` | Endpoint URL that `sshmux` will use for authentication and authorization. | Yes | `"http://127.0.0.1:5000/ssh"` | -| `token` | `string` | Token used to authenticate with the API endpoint. | Yes | `"long-and-random-token"` | -| `all-username-nopassword` | `bool` | If set to `true`, no users will be asked for UNIX password. | No | `true` | -| `usernames-nopassword` | `[]string` | Usernames that won't be asked for UNIX password. | No | `["vlab", "ubuntu", "root"]` | -| `invalid-usernames` | `[]string` | Usernames that are known to be invalid. | No | `["user"]` | -| `invalid-username-message` | `string` | Message to display when the requested username is invalid. | No | `"Invalid username %s."` | +| Key | Type | Description | Required | Example | +| ---------- | -------------- | -------------------------------------------------------------------------- | -------- | -------------------------------------------------- | +| `endpoint` | `string` | Endpoint URL that `sshmux` will use for authentication and authorization. | Yes | `"http://127.0.0.1:5000/ssh"` | +| `version` | `string` | Auth endpoint API version (`"legacy"`, `"v1"`). Defaults to `"legacy"`. | No | `"v1"` | +| `headers` | `[]HTTPHeader` | Extra HTTP headers to send to API server. | No | See [`fixtures/config.toml`](fixtures/config.toml) | + +#### Legacy Auth Settings + +The following settings are only used by `legacy` auth APIs. They are also grouped under `auth` in the TOML file. + +| Key | Type | Description | Required | Example | +| -------------------------- | ---------- | ----------------------------------------------------------- | ------------------------------- | ---------------------------- | +| `token` | `string` | Token used to authenticate with the API endpoint. | If `auth.version` is `"legacy"` | `"long-and-random-token"` | +| `all-username-nopassword` | `bool` | If set to `true`, no users will be asked for UNIX password. | No | `true` | +| `usernames-nopassword` | `[]string` | Usernames that won't be asked for UNIX password. | No | `["vlab", "ubuntu", "root"]` | +| `invalid-usernames` | `[]string` | Usernames that are known to be invalid. | No | `["user"]` | +| `invalid-username-message` | `string` | Message to display when the requested username is invalid. | No | `"Invalid username %s."` | + +#### Recovery Settings + +Recovery settings configures Vlab recovery service support of `sshmux` for `legacy` auth APIs. They are grouped under `recovery` in the TOML file. + +| Key | Type | Description | Required | Example | +| ----------- | ---------- | ----------------------------------------------------- | -------- | ------------------------- | +| `address` | `string` | SSH host and port of the recovery server. | No | `"172.30.0.101:2222"` | +| `usernames` | `[]string` | Usernames dedicated to the recovery server. | No | `["recovery", "console"]` | +| `token` | `string` | Token used to authenticate with the recovery backend. | No | `"long-and-random-token"` | ### Logger Settings Logger settings configures the logger behavior of `sshmux`. They are grouped under `logger` in the TOML file. -| Key | Type | Description | Required | Example | -| ---------- | -------- | ----------------------------------------------------------------------------- | ---------------------- | ------------------------ | -| `enabled` | `bool` | Whether the logger is enabled. Defaults to `false`. | No | `true` | -| `endpoint` | `string` | Endpoint URL that `sshmux` will log onto. Only `udp` scheme is supported now. | If `enabled` is `true` | `"udp://127.0.0.1:5556"` | +| Key | Type | Description | Required | Example | +| ---------- | -------- | ----------------------------------------------------------------------------- | ----------------------------- | ------------------------ | +| `enabled` | `bool` | Whether the logger is enabled. Defaults to `false`. | No | `true` | +| `endpoint` | `string` | Endpoint URL that `sshmux` will log onto. Only `udp` scheme is supported now. | If `logger.enabled` is `true` | `"udp://127.0.0.1:5556"` | ### PROXY Protocol Settings @@ -65,41 +84,77 @@ PROXY protocol settings configures [PROXY protocol](https://www.haproxy.com/blog | `hosts` | `[]string` | Host names from which PROXY protocol is allowed. | No | `["nginx.local", "127.0.0.22"]` | | `networks` | `[]string` | Network CIDRs from which PROXY protocol is allowed. | No | `["10.10.0.0/24"]` | -### Recovery Settings +## Auth API -Recovery settings configures Vlab recovery service support of `sshmux`. They are grouped under `recovery` in the TOML file. +`sshmux` uses a RESTful API to perform authentication and authorization for a user. -| Key | Type | Description | Required | Example | -| ----------- | ---------- | ----------------------------------------------------- | -------- | ------------------------- | -| `address` | `string` | SSH host and port of the recovery server. | No | `"172.30.0.101:2222"` | -| `usernames` | `[]string` | Usernames dedicated to the recovery server. | No | `["recovery", "console"]` | -| `token` | `string` | Token used to authenticate with the recovery backend. | No | `"long-and-random-token"` | +### `POST /v1/auth/:username` + +#### Input + +| Key | Type | Description | Position | Required | +| ----------------- | --------------------- | ---------------------------------------------------------------------------------------------- | -------- | -------- | +| `username` | `string` | SSH user name. Usually the one for logging into the target server. | Path | Yes | +| `method` | `string` | SSH authentication method. Usually one of `"none"`, `"publickey"` or `"keyboard-interactive"`. | Body | Yes | +| `public_key` | `string` | User public key, serialized in OpenSSH format. | Body | No | +| `payload` | `Map` | Authentication payload constructed from interactive input. | Body | No | + +#### Output: `200 OK` + +| Key | Type | Description | Required | +| ---------------- | ----------------------- | ----------------------------- | -------- | +| `upstream` | [`Upstream`](#upstream) | SSH upstream information. | Yes | +| `proxy` | [`Proxy`](#proxy) | PROXY protocol configuration. | No | + +##### `Upstream` + +| Key | Type | Description | Required | +| ------------- | -------- | --------------------------------------------------------------------------- | -------- | +| `host` | `string` | Host name or IP of upstream SSH server. | Yes | +| `port` | `uint` | Port number of upstream SSH server. Defaults to `22`. | No | +| `private_key` | `string` | Private key for authenticating with upstream, serialized in OpenSSH format. | No | +| `certificate` | `string` | Certificate for authenticating with upstream, serialized in OpenSSH format. | No | +| `password` | `string` | Password for authenticating with upstream. | No | + +##### `Proxy` + +| Key | Type | Description | Required | +| ------------- | -------- | ----------------------------------------------------------------------------------- | -------- | +| `host` | `string` | Host name or IP of the proxy server. Defaults to `upstream.host`. | No | +| `port` | `uint` | Port number of the proxy server. Defaults to `upstream.port`. | No | +| `protocol` | `string` | PROXY protocol version to use. Must be one of `"v1"` or `"v2"`. Defaults to `"v2"`. | No | + +#### Output: `401 Not Authorized` + +| Key | Type | Description | Required | +| ------------ | --------------------------- | ------------------------------------------------------------------------------------------------ | -------- | +| `challenges` | [`[]Challenge`](#challenge) | Challenges for extra inputs from user. Only applicable to `keyboard-interactive` authentication. | Yes | + +##### `Challenge` -## API server +| Key | Type | Description | Required | +| ------------- | ------------------------------------- | ---------------------------------- | -------- | +| `instruction` | `string` | Instruction for the challenge. | Yes | +| `fields` | [`[]ChallengeField`](#challengefield) | Requested fields by the challenge. | No | -`sshmux` requires an API server to perform authentication and authorization for a user. +##### `ChallengeField` -The API accepts JSON input with the following keys: +| Key | Type | Description | Required | +| -------- | -------- | ---------------------------------------------------------- | -------- | +| `key` | `string` | Key to set the user input on. | Yes | +| `prompt` | `string` | Prompt for the input field. | Yes | +| `secret` | `bool` | Whether to treat the input as secret. Defaults to `false`. | No | -| Key | Type | Description | -| ----------------- | -------- | -------------------------------------------------------------------------------------------------------- | -| `auth_type` | `string` | The authentication type. Always set to `"key"` at the moment. | -| `username` | `string` | Vlab username. Omitted if the user is authenticating with public key. | -| `password` | `string` | Vlab password. Omitted if the user is authenticating with public key. | -| `public_key_type` | `string` | SSH public key type. Omitted if the user is authenticating with username and password. | -| `public_key_data` | `string` | Base64-encoded SSH public key payload. Omitted if the user is authenticating with username and password. | -| `unix_username` | `string` | UNIX username the user is requesting access to. | -| `token` | `string` | Token used to authenticate the `sshmux` instance. | +#### Output: `403 Forbidden` -The API responds with JSON output with the following keys: +| Key | Type | Description | Required | +| --------- | --------------------- | ------------------------- | -------- | +| `failure` | [`Failure`](#failure) | Auth failure information. | No | -| Key | Type | Description | -| ---------------- | --------- | ---------------------------------------------------------------------------------------------------------------- | -| `status` | `string` | The authentication status. Should be `"ok"` if the user is authorized. | -| `address` | `string` | TCP host and port of the downstream SSH server the user is requesting for. | -| `private_key` | `string` | SSH private key to authenticate for the downstream. | -| `cert` | `string` | The certificate associated with the SSH private key. | -| `vmid` | `integer` | ID of the requested VM. Only used for recovery access. | -| `proxy_protocol` | `integer` | PROXY protocol version to use for the downstream. Should be `1`, `2` or omitted (which disables PROXY protocol). | +##### `Failure` -Note that if the user is not authorized, the API server should return a `status` other than `"ok"`, and other keys can be safely ommitted. +| Key | Type | Description | Required | +| ------------ | -------- | --------------------------------------------------------------------------- | -------- | +| `message` | `string` | Message from the server to describe the failure. | Yes | +| `disconnect` | `string` | Whether to disconnect the downstream user. Defaults to `false`. | No | +| `reason` | `uint` | SSH disconnect reason code. Defaults to `11` (`DISCONNECT_BY_APPLICATION`). | No | diff --git a/auth.go b/auth.go index c50c6dd..5ffd496 100644 --- a/auth.go +++ b/auth.go @@ -2,143 +2,124 @@ package main import ( "bytes" - "encoding/base64" "encoding/json" "fmt" "io" "net/http" - "slices" + "net/url" "golang.org/x/crypto/ssh" ) -type AuthRequestPublicKey struct { - AuthType string `json:"auth_type"` - UnixUsername string `json:"unix_username"` - PublicKeyType string `json:"public_key_type"` - PublicKeyData string `json:"public_key_data"` - Token string `json:"token"` +type AuthRequest struct { + Method string `json:"method"` + PublicKey string `json:"public_key,omitempty"` + Payload map[string]string `json:"payload"` } -type AuthRequestPassword struct { - AuthType string `json:"auth_type"` - Username string `json:"username"` - Password string `json:"password"` - UnixUsername string `json:"unix_username"` - Token string `json:"token"` +type AuthResponse struct { + Challenges []AuthChallenge `json:"challenges,omitempty"` + Failure *AuthFailure `json:"failure,omitempty"` + Upstream *AuthUpstream `json:"upstream,omitempty"` + Proxy *AuthProxy `json:"proxy,omitempty"` } -type AuthResponse struct { - Status string `json:"status"` - Address string `json:"address"` - PrivateKey string `json:"private_key"` - Cert string `json:"cert"` - Id int `json:"vmid"` - ProxyProtocol byte `json:"proxy_protocol,omitempty"` +type AuthChallenge struct { + Instruction string `json:"instruction"` + Fields []AuthChallengeField `json:"fields"` } -type UpstreamInformation struct { - Host string - Signer ssh.Signer - Password *string - ProxyProtocol byte +type AuthChallengeField struct { + Key string `json:"key"` + Prompt string `json:"prompt"` + Secret bool `json:"secret"` } -type Authenticator struct { - Endpoint string - Token string - Recovery RecoveryConfig +type AuthFailure struct { + Message string `json:"message"` + Disconnect bool `json:"disconnect,omitempty"` + Reason uint32 `json:"reason,omitempty"` } -func makeAuthenticator(auth AuthConfig, recovery RecoveryConfig) Authenticator { - return Authenticator{ - Endpoint: auth.Endpoint, - Token: auth.Token, - Recovery: recovery, - } +type AuthUpstream struct { + Host string `json:"host"` + Port uint16 `json:"port,omitempty"` + PrivateKey string `json:"private_key,omitempty"` + Certificate string `json:"certificate,omitempty"` + Password *string `json:"password,omitempty"` } -func parsePrivateKey(key string, cert string) ssh.Signer { - if key == "" { - return nil - } - signer, err := ssh.ParsePrivateKey([]byte(key)) - if err != nil { - return nil +type AuthProxy struct { + Host string `json:"host,omitempty"` + Port uint16 `json:"port,omitempty"` + Protocol *string `json:"protocol,omitempty"` +} + +type Authenticator interface { + Auth(request AuthRequest, username string) (int, *AuthResponse, error) +} + +func makeAuthenticator(auth AuthConfig) (Authenticator, error) { + if auth.Version == "" { + auth.Version = "v1" } - if cert == "" { - return signer + headers := http.Header{} + for _, header := range auth.Headers { + headers.Add(header.Name, header.Value) } - pk, _, _, _, err := ssh.ParseAuthorizedKey([]byte(cert)) + auth_url, err := url.Parse(auth.Endpoint) if err != nil { - return signer + return nil, err } - certSigner, err := ssh.NewCertSigner(pk.(*ssh.Certificate), signer) - if err != nil { - return signer + authenticator := RESTfulAuthenticator{ + Endpoint: auth_url, + Version: auth.Version, + Headers: headers, } - return certSigner + return &authenticator, nil +} + +type RESTfulAuthenticator struct { + Endpoint *url.URL + Version string + Headers http.Header } -func (auth Authenticator) AuthUser(request any, username string) (*UpstreamInformation, error) { +func (auth *RESTfulAuthenticator) Auth(request AuthRequest, username string) (int, *AuthResponse, error) { + if auth.Version != "v1" { + return 500, nil, fmt.Errorf("unsupported API version: %s", auth.Version) + } + auth_url := auth.Endpoint.JoinPath("v1", "auth", username).String() + payload := new(bytes.Buffer) if err := json.NewEncoder(payload).Encode(request); err != nil { - return nil, err + return 0, nil, err + } + + req, err := http.NewRequest("POST", auth_url, payload) + if err != nil { + return 0, nil, err } - res, err := http.Post(auth.Endpoint, "application/json", payload) + req.Header = auth.Headers.Clone() + req.Header.Set("accept", "application/json") + req.Header.Set("content-type", "application/json") + + res, err := http.DefaultClient.Do(req) if err != nil { - return nil, err + return 0, nil, err } defer res.Body.Close() body, err := io.ReadAll(res.Body) if err != nil { - return nil, err + return res.StatusCode, nil, err } + var response AuthResponse err = json.Unmarshal(body, &response) if err != nil { - return nil, err + return res.StatusCode, nil, err } - if response.Status != "ok" { - return nil, nil - } - - var upstream UpstreamInformation - // FIXME: Can this be handled in API server? - if slices.Contains(auth.Recovery.Usernames, username) { - upstream.Host = auth.Recovery.Address - password := fmt.Sprintf("%d %s", response.Id, auth.Recovery.Token) - upstream.Password = &password - } else { - upstream.Host = response.Address - } - upstream.Signer = parsePrivateKey(response.PrivateKey, response.Cert) - upstream.ProxyProtocol = response.ProxyProtocol - return &upstream, nil -} - -func (auth Authenticator) AuthUserWithPublicKey(key ssh.PublicKey, unixUsername string) (*UpstreamInformation, error) { - keyType := key.Type() - keyData := base64.StdEncoding.EncodeToString(key.Marshal()) - request := &AuthRequestPublicKey{ - AuthType: "key", - UnixUsername: unixUsername, - PublicKeyType: keyType, - PublicKeyData: keyData, - Token: auth.Token, - } - return auth.AuthUser(request, unixUsername) -} - -func (auth Authenticator) AuthUserWithUserPass(username string, password string, unixUsername string) (*UpstreamInformation, error) { - request := &AuthRequestPassword{ - AuthType: "key", - Username: username, - Password: password, - UnixUsername: unixUsername, - Token: auth.Token, - } - return auth.AuthUser(request, unixUsername) + return res.StatusCode, &response, nil } func removePublicKeyMethod(methods []string) []string { @@ -150,3 +131,25 @@ func removePublicKeyMethod(methods []string) []string { } return res } + +func parsePrivateKey(key string, cert string) ssh.Signer { + if key == "" { + return nil + } + signer, err := ssh.ParsePrivateKey([]byte(key)) + if err != nil { + return nil + } + if cert == "" { + return signer + } + pk, _, _, _, err := ssh.ParseAuthorizedKey([]byte(cert)) + if err != nil { + return signer + } + certSigner, err := ssh.NewCertSigner(pk.(*ssh.Certificate), signer) + if err != nil { + return signer + } + return certSigner +} diff --git a/config.go b/config.go index 4636e9d..ace70f4 100644 --- a/config.go +++ b/config.go @@ -17,15 +17,22 @@ type SSHKeyConfig struct { } type AuthConfig struct { - Endpoint string `toml:"endpoint"` - Token string `toml:"token"` - // The following should be moved into API server + Endpoint string `toml:"endpoint"` + Version string `toml:"version,omitempty"` + Headers []AuthHTTPHeaderConfig `toml:"headers,omitempty"` + // The following settings are for legacy API only + Token string `toml:"token,omitempty"` InvalidUsernames []string `toml:"invalid-usernames,omitempty"` InvalidUsernameMessage string `toml:"invalid-username-message,omitempty"` AllUsernameNoPassword bool `toml:"all-username-nopassword,omitempty"` UsernamesNoPassword []string `toml:"usernames-nopassword,omitempty"` } +type AuthHTTPHeaderConfig struct { + Name string `toml:"name"` + Value string `toml:"value"` +} + type LoggerConfig struct { Enabled bool `toml:"enabled"` Endpoint string `toml:"endpoint,omitempty"` @@ -38,9 +45,9 @@ type ProxyProtocolConfig struct { } type RecoveryConfig struct { - Address string `toml:"address"` - Usernames []string `toml:"usernames"` - Token string `toml:"token"` + Address string `toml:"address,omitempty"` + Usernames []string `toml:"usernames,omitempty"` + Token string `toml:"token,omitempty"` } type Config struct { @@ -101,6 +108,7 @@ func convertLegacyConfig(config LegacyConfig) Config { }, Auth: AuthConfig{ Endpoint: config.API, + Version: "legacy", Token: config.Token, InvalidUsernames: config.InvalidUsername, InvalidUsernameMessage: config.InvalidUsernameMessage, diff --git a/etc/config.example.toml b/etc/config.example.toml index acfaa6f..dd4f3cb 100644 --- a/etc/config.example.toml +++ b/etc/config.example.toml @@ -10,8 +10,9 @@ host-keys = [ [auth] endpoint = "http://127.0.0.1:5000/ssh" -token = "token" +version = "legacy" # Legacy settings +token = "token" all-username-nopassword = true usernames-nopassword = ["vlab", "ubuntu", "root"] invalid-usernames = ["用户名"] diff --git a/fixtures/config.toml b/fixtures/config.toml index 172d7e8..341c06f 100644 --- a/fixtures/config.toml +++ b/fixtures/config.toml @@ -19,13 +19,11 @@ base64 = "LS0tLS1CRUdJTiBPUEVOU1NIIFBSSVZBVEUgS0VZLS0tLS0KYjNCbGJuTnphQzFyWlhrdG path = "fixtures/ssh_host_rsa_key" [auth] -endpoint = "http://127.0.0.1:5000/ssh" -token = "token" -# Legacy settings -all-username-nopassword = true -usernames-nopassword = ["vlab", "ubuntu", "root"] -invalid-usernames = ["用户名"] -invalid-username-message = "Invalid username %s. Please check https://vlab.ustc.edu.cn/docs/login/ssh/#username for more information." +endpoint = "http://127.0.0.1:5000" +version = "v1" +headers = [ + { name = "Authorization", value = "ApiKey 12345678" }, +] [logger] enabled = false diff --git a/fixtures/legacy.toml b/fixtures/legacy.toml new file mode 100644 index 0000000..e0791e3 --- /dev/null +++ b/fixtures/legacy.toml @@ -0,0 +1,42 @@ +address = "0.0.0.0:8022" + +[ssh] +banner = "Welcome to Vlab\n" +[[ssh.host-keys]] +content = """ +-----BEGIN OPENSSH PRIVATE KEY----- +b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAaAAAABNlY2RzYS +1zaGEyLW5pc3RwMjU2AAAACG5pc3RwMjU2AAAAQQQNVE69PKFYERMMmQVUDdmz6cP6i44e +6LhN5091KWPVToekpMKvPYxMgfQWPFkmRSB1t2eMCrI9Vr9vfEZCaM/tAAAAmCtjMwcrYz +MHAAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBA1UTr08oVgREwyZ +BVQN2bPpw/qLjh7ouE3nT3UpY9VOh6Skwq89jEyB9BY8WSZFIHW3Z4wKsj1Wv298RkJoz+ +0AAAAgHkhPmtcUZwSkQAjy8QtHjdJ7AM4eGXhJWBp9icCRvWUAAAAA +-----END OPENSSH PRIVATE KEY----- +""" +[[ssh.host-keys]] +base64 = "LS0tLS1CRUdJTiBPUEVOU1NIIFBSSVZBVEUgS0VZLS0tLS0KYjNCbGJuTnphQzFyWlhrdGRqRUFBQUFBQkc1dmJtVUFBQUFFYm05dVpRQUFBQUFBQUFBQkFBQUFhQUFBQUJObFkyUnpZUwoxemFHRXlMVzVwYzNSd01qVTJBQUFBQ0c1cGMzUndNalUyQUFBQVFRUU5WRTY5UEtGWUVSTU1tUVZVRGRtejZjUDZpNDRlCjZMaE41MDkxS1dQVlRvZWtwTUt2UFl4TWdmUVdQRmttUlNCMXQyZU1Dckk5VnI5dmZFWkNhTS90QUFBQW1DdGpNd2NyWXoKTUhBQUFBRTJWalpITmhMWE5vWVRJdGJtbHpkSEF5TlRZQUFBQUlibWx6ZEhBeU5UWUFBQUJCQkExVVRyMDhvVmdSRXd5WgpCVlFOMmJQcHcvcUxqaDdvdUUzblQzVXBZOVZPaDZTa3dxODlqRXlCOUJZOFdTWkZJSFczWjR3S3NqMVd2Mjk4UmtKb3orCjBBQUFBZ0hraFBtdGNVWndTa1FBank4UXRIamRKN0FNNGVHWGhKV0JwOWljQ1J2V1VBQUFBQQotLS0tLUVORCBPUEVOU1NIIFBSSVZBVEUgS0VZLS0tLS0K" +[[ssh.host-keys]] +path = "fixtures/ssh_host_rsa_key" + +[auth] +endpoint = "http://127.0.0.1:5000/ssh" +version = "legacy" +# Legacy settings +token = "token" +all-username-nopassword = true +usernames-nopassword = ["vlab", "ubuntu", "root"] +invalid-usernames = ["用户名"] +invalid-username-message = "Invalid username %s. Please check https://vlab.ustc.edu.cn/docs/login/ssh/#username for more information." + +[logger] +enabled = true +endpoint = "udp://127.0.0.1:5556" + +[proxy-protocol] +enabled = true +hosts = ["127.0.0.22"] + +[recovery] +address = "172.30.0.101:2222" +usernames = ["recovery", "console", "serial"] +token = "token" diff --git a/go.mod b/go.mod index db54b8d..b72f20b 100644 --- a/go.mod +++ b/go.mod @@ -10,8 +10,8 @@ require github.com/pires/go-proxyproto v0.7.0 require github.com/pelletier/go-toml/v2 v2.2.2 -require ( - golang.org/x/sys v0.21.0 // indirect -) +require github.com/julienschmidt/httprouter v1.3.0 + +require golang.org/x/sys v0.21.0 // indirect replace golang.org/x/crypto => ./crypto diff --git a/go.sum b/go.sum index e537255..ad0f9a7 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,7 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/julienschmidt/httprouter v1.3.0 h1:U0609e9tgbseu3rBINet9P48AI/D3oJs4dN7jwJOQ1U= +github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM= github.com/libp2p/go-reuseport v0.4.0 h1:nR5KU7hD0WxXCJbmw7r2rhRYruNRl2koHw8fQscQm2s= github.com/libp2p/go-reuseport v0.4.0/go.mod h1:ZtI03j/wO5hZVDFo2jKywN6bYKWLOy8Se6DrI2E1cLU= github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM= diff --git a/legacy_auth.go b/legacy_auth.go new file mode 100644 index 0000000..be8c736 --- /dev/null +++ b/legacy_auth.go @@ -0,0 +1,223 @@ +package main + +import ( + "bytes" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "net/http" + "net/netip" + "slices" + + "golang.org/x/crypto/ssh" +) + +type LegacyAuthRequestPublicKey struct { + AuthType string `json:"auth_type"` + UnixUsername string `json:"unix_username"` + PublicKeyType string `json:"public_key_type"` + PublicKeyData string `json:"public_key_data"` + Token string `json:"token"` +} + +type LegacyAuthRequestPassword struct { + AuthType string `json:"auth_type"` + Username string `json:"username"` + Password string `json:"password"` + UnixUsername string `json:"unix_username"` + Token string `json:"token"` +} + +type LegacyAuthResponse struct { + Status string `json:"status"` + Address string `json:"address"` + PrivateKey string `json:"private_key"` + Cert string `json:"cert"` + Id int `json:"vmid"` + ProxyProtocol byte `json:"proxy_protocol,omitempty"` +} + +type LegacyAuthUpstream struct { + Host string + PrivateKey string + Certificate string + Password *string + ProxyProtocol byte +} + +type LegacyAuthenticator struct { + Endpoint string + Token string + Recovery RecoveryConfig + UsernamePolicy UsernamePolicyConfig + PasswordPolicy PasswordPolicyConfig + Headers http.Header +} + +func makeLegacyAuthenticator(auth AuthConfig, recovery RecoveryConfig) LegacyAuthenticator { + headers := http.Header{} + for _, header := range auth.Headers { + headers.Add(header.Name, header.Value) + } + return LegacyAuthenticator{ + Endpoint: auth.Endpoint, + Token: auth.Token, + Recovery: recovery, + UsernamePolicy: UsernamePolicyConfig{ + InvalidUsernames: auth.InvalidUsernames, + InvalidUsernameMessage: auth.InvalidUsernameMessage, + }, + PasswordPolicy: PasswordPolicyConfig{ + AllUsernameNoPassword: auth.AllUsernameNoPassword, + UsernamesNoPassword: auth.UsernamesNoPassword, + }, + Headers: headers, + } +} + +func (auth *LegacyAuthenticator) Auth(request AuthRequest, username string) (int, *AuthResponse, error) { + var upstream *LegacyAuthUpstream + var err error + if slices.Contains(auth.UsernamePolicy.InvalidUsernames, username) { + // 15: SSH_DISCONNECT_ILLEGAL_USER_NAME + msg := fmt.Sprintf(auth.UsernamePolicy.InvalidUsernameMessage, username) + failure := AuthFailure{Message: msg, Reason: 15, Disconnect: true} + return 403, &AuthResponse{Failure: &failure}, nil + } + if request.Method == "publickey" { + publicKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(request.PublicKey)) + if err != nil { + return 500, nil, err + } + upstream, err = auth.AuthUserWithPublicKey(publicKey, username) + if err != nil { + return 500, nil, err + } + } + if request.Method == "keyboard-interactive" { + requireUnixPassword := !auth.PasswordPolicy.AllUsernameNoPassword && + !slices.Contains(auth.Recovery.Usernames, username) && + !slices.Contains(auth.PasswordPolicy.UsernamesNoPassword, username) + username, has_username := request.Payload["username"] + password, has_password := request.Payload["password"] + if !has_username || !has_password { + challenge := AuthChallenge{ + Instruction: "Please enter Vlab username & password.", + Fields: []AuthChallengeField{ + {Key: "username", Prompt: "Vlab username (Student ID): "}, + {Key: "password", Prompt: "Vlab password: ", Secret: true}, + }, + } + resp := AuthResponse{Challenges: []AuthChallenge{challenge}} + return 401, &resp, nil + } + _, has_unix_password := request.Payload["unix_password"] + if requireUnixPassword && !has_unix_password { + challenge := AuthChallenge{ + Instruction: "Please enter UNIX password.", + Fields: []AuthChallengeField{ + {Key: "unix_password", Prompt: "UNIX password: ", Secret: true}, + }, + } + resp := AuthResponse{Challenges: []AuthChallenge{challenge}} + return 401, &resp, nil + } + upstream, err = auth.AuthUserWithUserPass(username, password, username) + if err != nil { + return 500, nil, err + } + } + if upstream != nil { + address, err := netip.ParseAddrPort(upstream.Host) + if err != nil { + return 500, nil, err + } + resp := AuthResponse{ + Upstream: &AuthUpstream{ + Host: address.Addr().String(), + Port: address.Port(), + PrivateKey: upstream.PrivateKey, + Certificate: upstream.Certificate, + Password: upstream.Password, + }, + } + unix_password, has_unix_password := request.Payload["unix_password"] + if has_unix_password { + resp.Upstream.Password = &unix_password + } + if upstream.ProxyProtocol > 0 { + protocolVersion := fmt.Sprintf("v%d", upstream.ProxyProtocol) + resp.Proxy = &AuthProxy{Protocol: &protocolVersion} + } + return 200, &resp, nil + } + return 403, &AuthResponse{}, nil +} + +func (auth LegacyAuthenticator) AuthUser(request any, username string) (*LegacyAuthUpstream, error) { + payload := new(bytes.Buffer) + if err := json.NewEncoder(payload).Encode(request); err != nil { + return nil, err + } + req, err := http.NewRequest("POST", auth.Endpoint, payload) + if err != nil { + return nil, err + } + req.Header = auth.Headers.Clone() + req.Header.Set("accept", "application/json") + req.Header.Set("content-type", "application/json") + res, err := http.DefaultClient.Do(req) + if err != nil { + return nil, err + } + defer res.Body.Close() + body, err := io.ReadAll(res.Body) + if err != nil { + return nil, err + } + var response LegacyAuthResponse + err = json.Unmarshal(body, &response) + if err != nil { + return nil, err + } + if response.Status != "ok" { + return nil, nil + } + var upstream LegacyAuthUpstream + if slices.Contains(auth.Recovery.Usernames, username) { + upstream.Host = auth.Recovery.Address + password := fmt.Sprintf("%d %s", response.Id, auth.Recovery.Token) + upstream.Password = &password + } else { + upstream.Host = response.Address + } + upstream.PrivateKey = response.PrivateKey + upstream.Certificate = response.Cert + upstream.ProxyProtocol = response.ProxyProtocol + return &upstream, nil +} + +func (auth LegacyAuthenticator) AuthUserWithPublicKey(key ssh.PublicKey, unixUsername string) (*LegacyAuthUpstream, error) { + keyType := key.Type() + keyData := base64.StdEncoding.EncodeToString(key.Marshal()) + request := &LegacyAuthRequestPublicKey{ + AuthType: "key", + UnixUsername: unixUsername, + PublicKeyType: keyType, + PublicKeyData: keyData, + Token: auth.Token, + } + return auth.AuthUser(request, unixUsername) +} + +func (auth LegacyAuthenticator) AuthUserWithUserPass(username string, password string, unixUsername string) (*LegacyAuthUpstream, error) { + request := &LegacyAuthRequestPassword{ + AuthType: "key", + Username: username, + Password: password, + UnixUsername: unixUsername, + Token: auth.Token, + } + return auth.AuthUser(request, unixUsername) +} diff --git a/sshmux.go b/sshmux.go index 48caad3..7fa0589 100644 --- a/sshmux.go +++ b/sshmux.go @@ -3,6 +3,7 @@ package main import ( "context" "encoding/base64" + "errors" "fmt" "io" "log" @@ -11,7 +12,7 @@ import ( "net/netip" "net/url" "os" - "slices" + "strconv" "sync" "time" @@ -21,26 +22,32 @@ import ( ) type Server struct { - listener net.Listener - wg sync.WaitGroup - ctx context.Context - cancel context.CancelFunc - Address string - Banner string - SSHConfig *ssh.ServerConfig - Authenticator Authenticator - LogWriter io.Writer - ProxyPolicy ProxyPolicyConfig - UsernamePolicy UsernamePolicyConfig - PasswordPolicy PasswordPolicyConfig + listener net.Listener + wg sync.WaitGroup + ctx context.Context + cancel context.CancelFunc + Address string + Banner string + SSHConfig *ssh.ServerConfig + Authenticator Authenticator + LogWriter io.Writer + ProxyPolicy ProxyPolicyConfig +} + +type upstreamInformation struct { + Address string + Signer ssh.Signer + Password *string + ProxyProtocol *byte + ProxyDestination string } func validateKey(config SSHKeyConfig) (ssh.Signer, error) { if config.Path == "" && config.Base64 == "" && config.Content == "" { - return nil, fmt.Errorf("one of path, base64 or content of the SSH key must be set") + return nil, errors.New("one of path, base64 or content of the SSH key must be set") } if (config.Path != "" && config.Base64 != "") || (config.Path != "" && config.Content != "") || (config.Base64 != "" && config.Content != "") { - return nil, fmt.Errorf("only one of path, base64 or content of the SSH key can be set") + return nil, errors.New("only one of path, base64 or content of the SSH key can be set") } var pemFile []byte if config.Path != "" { @@ -88,30 +95,33 @@ func makeServer(config Config) (*Server, error) { if loggerURL.Scheme == "udp" { conn, err := net.Dial("udp", loggerURL.Host) if err != nil { - log.Fatalf("Logger Dial failed: %s\n", err) + return nil, fmt.Errorf("logger dial failed: %w", err) } logWriter = conn } else { - log.Fatalf("unsupported logger endpoint: %s\n", config.Logger.Endpoint) + return nil, fmt.Errorf("unsupported logger endpoint: %s", config.Logger.Endpoint) } } else { logWriter = io.Discard } + var authenticator Authenticator + if config.Auth.Version == "" || config.Auth.Version == "legacy" { + legacyAuthenticator := makeLegacyAuthenticator(config.Auth, config.Recovery) + authenticator = &legacyAuthenticator + } else { + var err error + authenticator, err = makeAuthenticator(config.Auth) + if err != nil { + return nil, err + } + } sshmux := &Server{ Address: config.Address, Banner: config.SSH.Banner, SSHConfig: sshConfig, - Authenticator: makeAuthenticator(config.Auth, config.Recovery), + Authenticator: authenticator, LogWriter: logWriter, ProxyPolicy: proxyPolicyConfig, - UsernamePolicy: UsernamePolicyConfig{ - InvalidUsernames: config.Auth.InvalidUsernames, - InvalidUsernameMessage: config.Auth.InvalidUsernameMessage, - }, - PasswordPolicy: PasswordPolicyConfig{ - AllUsernameNoPassword: config.Auth.AllUsernameNoPassword, - UsernamesNoPassword: config.Auth.UsernamesNoPassword, - }, } return sshmux, nil } @@ -173,91 +183,132 @@ func (s *Server) handler(conn net.Conn) { func (s *Server) Handshake(session *ssh.PipeSession) error { hasSetUser := false var user string - var upstream *UpstreamInformation + var upstream *upstreamInformation if s.Banner != "" { err := session.Downstream.SendBanner(s.Banner) if err != nil { return err } } - // Stage 1: Get publickey or keyboard-interactive answers, and authenticate the user with with API + // Stage 1: Authenticate the user with API +auth_requests: for { - req, err := session.Downstream.ReadAuthRequest(true) + authReq, err := session.Downstream.ReadAuthRequest(true) if err != nil { return err } if !hasSetUser { - user = req.User + user = authReq.User session.Downstream.SetUser(user) hasSetUser = true } - if slices.Contains(s.UsernamePolicy.InvalidUsernames, user) { - // 15: SSH_DISCONNECT_ILLEGAL_USER_NAME - msg := fmt.Sprintf(s.UsernamePolicy.InvalidUsernameMessage, user) - session.Downstream.WriteDisconnectMsg(15, msg) - return fmt.Errorf("ssh: invalid username") + req := AuthRequest{Method: authReq.Method} + if authReq.Method == "publickey" && !authReq.IsPublicKeyQuery { + req.PublicKey = string(ssh.MarshalAuthorizedKey(*authReq.PublicKey)) } - if req.Method == "none" { - session.Downstream.WriteAuthFailure([]string{"publickey", "keyboard-interactive"}, false) - } else if req.Method == "publickey" && !req.IsPublicKeyQuery { - upstream, err = s.Authenticator.AuthUserWithPublicKey(*req.PublicKey, user) - if err != nil { - return err - } - if upstream != nil { - break - } - session.Downstream.WriteAuthFailure([]string{"publickey", "keyboard-interactive"}, false) - } else if req.Method == "keyboard-interactive" { - // FIXME: Can this be handled by API server? - requireUnixPassword := !s.PasswordPolicy.AllUsernameNoPassword && - !slices.Contains(s.Authenticator.Recovery.Usernames, user) && - !slices.Contains(s.PasswordPolicy.UsernamesNoPassword, user) - interactiveQuestions := []string{"Vlab username (Student ID): ", "Vlab password: "} - interactiveEcho := []bool{true, false} - - answers, err := session.Downstream.InteractiveChallenge("", - "Please enter Vlab username & password.", - interactiveQuestions, interactiveEcho) + for { + status, resp, err := s.Authenticator.Auth(req, user) if err != nil { return err } - if len(answers) != len(interactiveQuestions) { - return fmt.Errorf("ssh: numbers of answers and questions do not match") - } - username := answers[0] - password := answers[1] - upstream, err = s.Authenticator.AuthUserWithUserPass(username, password, user) - if err != nil { - return err - } - if upstream != nil { - if requireUnixPassword { - answers, err := session.Downstream.InteractiveChallenge("", - "Please enter UNIX password.", - []string{"UNIX password: "}, []bool{false}) + switch status { + case 200: + upstreamResp := *resp.Upstream + if upstreamResp.Port == 0 { + upstreamResp.Port = 22 + } + upstream = &upstreamInformation{ + Signer: parsePrivateKey(upstreamResp.PrivateKey, upstreamResp.Certificate), + Password: upstreamResp.Password, + } + upstream.Address = net.JoinHostPort(upstreamResp.Host, strconv.Itoa(int(upstreamResp.Port))) + if resp.Proxy != nil { + proxyConfig := *resp.Proxy + // parse protocol version + var protocolVersion byte + if proxyConfig.Protocol != nil { + switch *proxyConfig.Protocol { + case "v1": + protocolVersion = 1 + case "v2": + protocolVersion = 2 + default: + return fmt.Errorf("unknown PROXY protocol version: %s", *proxyConfig.Protocol) + } + } + upstream.ProxyProtocol = &protocolVersion + // parse protocol destination + upstream.ProxyDestination = upstream.Address + if proxyConfig.Host == "" { + proxyConfig.Host = upstreamResp.Host + } + if proxyConfig.Port == 0 { + proxyConfig.Port = upstreamResp.Port + } + upstream.Address = net.JoinHostPort(proxyConfig.Host, strconv.Itoa(int(proxyConfig.Port))) + } + break auth_requests + case 401: + if len(resp.Challenges) == 0 { + // The API server is requesting no challenges, which is abnormal and will + // likely lead to an infinite loop + session.Downstream.WriteAuthFailure([]string{"publickey", "keyboard-interactive"}, false) + continue auth_requests + } + for _, challenge := range resp.Challenges { + questions := make([]string, 0, len(challenge.Fields)) + withEcho := make([]bool, 0, len(challenge.Fields)) + for _, field := range challenge.Fields { + questions = append(questions, field.Prompt) + withEcho = append(withEcho, !field.Secret) + } + answers, err := session.Downstream.InteractiveChallenge("", challenge.Instruction, questions, withEcho) if err != nil { return err } - if len(answers) != 1 { - return fmt.Errorf("ssh: expected UNIX password") + if len(answers) != len(questions) { + return errors.New("ssh: numbers of answers and questions do not match") + } + if req.Payload == nil { + req.Payload = make(map[string]string, len(challenge.Fields)) + } + for i, answer := range answers { + req.Payload[challenge.Fields[i].Key] = answer + } + } + continue + case 403: + if resp.Failure != nil { + failure := *resp.Failure + if failure.Disconnect { + if failure.Reason == 0 { + // 11: SSH_DISCONNECT_BY_APPLICATION + failure.Reason = 11 + } + session.Downstream.WriteDisconnectMsg(failure.Reason, failure.Message) + return fmt.Errorf("ssh(%d): %s", failure.Reason, failure.Message) } - upstream.Password = &answers[0] } - break + fallthrough + default: + session.Downstream.WriteAuthFailure([]string{"publickey", "keyboard-interactive"}, false) + continue auth_requests } - session.Downstream.WriteAuthFailure([]string{"publickey", "keyboard-interactive"}, false) - } else { - session.Downstream.WriteAuthFailure([]string{"publickey", "keyboard-interactive"}, false) } } // Stage 2: connect to upstream - conn, err := net.Dial("tcp", upstream.Host) + conn, err := net.Dial("tcp", upstream.Address) if err != nil { return err } - if upstream.ProxyProtocol > 0 { - header := proxyproto.HeaderProxyFromAddrs(upstream.ProxyProtocol, session.Downstream.RemoteAddr(), conn.RemoteAddr()) + if upstream.ProxyProtocol != nil { + dest := conn.RemoteAddr() + if upstream.ProxyDestination != upstream.Address { + if addr, err := net.ResolveTCPAddr("tcp", upstream.ProxyDestination); err == nil { + dest = addr + } + } + header := proxyproto.HeaderProxyFromAddrs(*upstream.ProxyProtocol, session.Downstream.RemoteAddr(), dest) _, err := header.WriteTo(conn) if err != nil { return err @@ -267,7 +318,7 @@ func (s *Server) Handshake(session *ssh.PipeSession) error { User: user, HostKeyCallback: ssh.InsecureIgnoreHostKey(), } - err = session.InitUpstream(conn, upstream.Host, sshConfig) + err = session.InitUpstream(conn, upstream.Address, sshConfig) if err != nil { return err } @@ -287,7 +338,7 @@ func (s *Server) Handshake(session *ssh.PipeSession) error { if err != nil { return err } - // For the first auth fail, we mark it partial succss + // For the first auth fail, we mark it as partial success if !res.Success { err = session.Downstream.WriteAuthFailure(removePublicKeyMethod(res.Methods), true) } else { @@ -299,7 +350,7 @@ func (s *Server) Handshake(session *ssh.PipeSession) error { if res.Success { return nil } - // Finally, pipe downstream and upstream's auth request and result + // Finally, pipe downstream and upstream's auth requests and results // Note that publickey auth cannot be used anymore after this point for { req, err := session.Downstream.ReadAuthRequest(true) diff --git a/sshmux_test.go b/sshmux_test.go index b9e2aeb..6e9937a 100644 --- a/sshmux_test.go +++ b/sshmux_test.go @@ -13,6 +13,7 @@ import ( "testing" "time" + "github.com/julienschmidt/httprouter" "github.com/pires/go-proxyproto" ) @@ -32,7 +33,7 @@ func localhostTCPAddr(port int) *net.TCPAddr { var enableProxy bool func initHttp(sshPrivateKey []byte) { - sshAPIHandler := func(w http.ResponseWriter, r *http.Request) { + sshAPIHandler := func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { body, err := io.ReadAll(r.Body) if err != nil { http.Error(w, "Cannot read body", http.StatusBadRequest) @@ -44,16 +45,16 @@ func initHttp(sshPrivateKey []byte) { return } - res := &AuthResponse{ - Status: "ok", - Id: 1141919, - PrivateKey: string(sshPrivateKey), + res := map[string]any{ + "status": "ok", + "vmid": 1141919, + "private_key": string(sshPrivateKey), } if enableProxy { - res.Address = sshdProxiedAddr.String() - res.ProxyProtocol = 2 + res["address"] = sshdProxiedAddr.String() + res["proxy_protocol"] = 2 } else { - res.Address = sshdServerAddr.String() + res["address"] = sshdServerAddr.String() } jsonRes, err := json.Marshal(res) @@ -65,9 +66,46 @@ func initHttp(sshPrivateKey []byte) { w.Write(jsonRes) } - http.HandleFunc("/ssh", sshAPIHandler) + authAPIHandler := func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { + body, err := io.ReadAll(r.Body) + if err != nil { + http.Error(w, "Cannot read body", http.StatusBadRequest) + return + } + var dat map[string]interface{} + if err := json.Unmarshal(body, &dat); err != nil { + http.Error(w, "Not JSON", http.StatusBadRequest) + return + } + + res := map[string]any{ + "upstream": map[string]any{ + "host": sshdServerAddr.IP.String(), + "port": sshdServerAddr.Port, + "private_key": string(sshPrivateKey), + }, + } + if enableProxy { + res["proxy"] = map[string]any{ + "host": sshdProxiedAddr.IP.String(), + "port": sshdProxiedAddr.Port, + } + } + + jsonRes, err := json.Marshal(res) + if err != nil { + http.Error(w, "Cannot encode JSON", http.StatusInternalServerError) + return + } + w.Header().Set("Content-Type", "application/json") + w.Write(jsonRes) + } + + router := httprouter.New() + router.POST("/ssh", sshAPIHandler) + router.POST("/v1/auth/:name", authAPIHandler) - if err := http.ListenAndServe(apiServerAddr.String(), nil); err != nil { + if err := http.ListenAndServe(apiServerAddr.String(), router); err != nil { log.Fatal(err) } } @@ -229,7 +267,7 @@ func testWithSSHClient(t *testing.T, address *net.TCPAddr, description string, p func TestSSHClientConnection(t *testing.T) { initEnv(t) - configFiles := []string{"config.toml", "config.json"} + configFiles := []string{"config.toml", "legacy.toml", "config.json"} for _, configFile := range configFiles { // start sshmux server