mirror of
https://github.com/certimate-go/certimate.git
synced 2026-06-27 21:03:11 +08:00
232 lines
5.4 KiB
Go
232 lines
5.4 KiB
Go
package ssh
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"net"
|
|
"strconv"
|
|
"strings"
|
|
|
|
"github.com/samber/lo"
|
|
"golang.org/x/crypto/ssh"
|
|
)
|
|
|
|
type Client struct {
|
|
conns []net.Conn
|
|
clis []*ssh.Client
|
|
}
|
|
|
|
func NewClient(config *Config) (*Client, error) {
|
|
if config == nil {
|
|
return nil, fmt.Errorf("the configuration of SSH client is nil")
|
|
}
|
|
|
|
conns, clis, err := createConnsAndSshClients(config)
|
|
if err != nil {
|
|
for i := len(clis) - 1; i >= 0; i-- {
|
|
clis[i].Close()
|
|
}
|
|
|
|
for i := len(conns) - 1; i >= 0; i-- {
|
|
conns[i].Close()
|
|
}
|
|
|
|
return nil, err
|
|
}
|
|
|
|
return &Client{conns: conns, clis: clis}, nil
|
|
}
|
|
|
|
func (c *Client) Close() error {
|
|
errs := make([]error, 0)
|
|
|
|
for i := len(c.clis) - 1; i >= 0; i-- {
|
|
cli := c.clis[i]
|
|
if err := cli.Close(); err != nil {
|
|
errs = append(errs, err)
|
|
}
|
|
}
|
|
|
|
for i := len(c.conns) - 1; i >= 0; i-- {
|
|
conn := c.conns[i]
|
|
if err := conn.Close(); err != nil {
|
|
errs = append(errs, err)
|
|
}
|
|
}
|
|
|
|
if len(errs) == 0 {
|
|
return nil
|
|
} else if len(errs) == 1 {
|
|
return errs[0]
|
|
} else {
|
|
return errors.Join(errs...)
|
|
}
|
|
}
|
|
|
|
func (c *Client) GetClient() *ssh.Client {
|
|
if len(c.clis) == 0 {
|
|
return nil
|
|
}
|
|
|
|
return c.clis[len(c.clis)-1]
|
|
}
|
|
|
|
func createConnsAndSshClients(config *Config) (conns []net.Conn, clis []*ssh.Client, err error) {
|
|
conns = make([]net.Conn, 0)
|
|
clis = make([]*ssh.Client, 0)
|
|
|
|
var targetConn net.Conn
|
|
if len(config.JumpServers) > 0 {
|
|
var jumpCli *ssh.Client
|
|
|
|
for i, jumpConfig := range config.JumpServers {
|
|
var jumpConn net.Conn
|
|
if jumpCli == nil {
|
|
jumpConn, err = net.Dial("tcp", resolveAddr(jumpConfig.Host, jumpConfig.Port))
|
|
} else {
|
|
jumpConn, err = jumpCli.Dial("tcp", resolveAddr(jumpConfig.Host, jumpConfig.Port))
|
|
}
|
|
if err != nil {
|
|
err = fmt.Errorf("ssh: failed to connect to jump server [%d]: %w", i+1, err)
|
|
return
|
|
}
|
|
|
|
conns = append(conns, jumpConn)
|
|
|
|
jumpCli, err = createSshClientWithConn(&jumpConfig, jumpConn)
|
|
if err != nil {
|
|
err = fmt.Errorf("ssh: failed to create jump server SSH client[%d]: %w", i+1, err)
|
|
return
|
|
}
|
|
|
|
clis = append(clis, jumpCli)
|
|
}
|
|
|
|
// 通过跳板机发起 TCP 连接到目标服务器
|
|
targetConn, err = jumpCli.Dial("tcp", resolveAddr(config.Host, config.Port))
|
|
if err != nil {
|
|
err = fmt.Errorf("ssh: failed to connect to target server: %w", err)
|
|
return
|
|
}
|
|
|
|
conns = append(conns, targetConn)
|
|
} else {
|
|
// 直接发起 TCP 连接到目标服务器
|
|
targetConn, err = net.Dial("tcp", resolveAddr(config.Host, config.Port))
|
|
if err != nil {
|
|
err = fmt.Errorf("ssh: failed to connect to target server: %w", err)
|
|
return
|
|
}
|
|
|
|
conns = append(conns, targetConn)
|
|
}
|
|
|
|
// 创建 SSH 客户端
|
|
targetCli, err := createSshClientWithConn(&config.ServerConfig, targetConn)
|
|
if err != nil {
|
|
return nil, nil, fmt.Errorf("ssh: failed to create SSH client: %w", err)
|
|
}
|
|
|
|
clis = append(clis, targetCli)
|
|
|
|
return conns, clis, nil
|
|
}
|
|
|
|
func createSshClientWithConn(config *ServerConfig, conn net.Conn) (*ssh.Client, error) {
|
|
if conn == nil {
|
|
return nil, fmt.Errorf("ssh: nil conn")
|
|
}
|
|
|
|
authMethodType := lo.
|
|
If(string(config.AuthMethod) != "", config.AuthMethod).
|
|
ElseIf(config.Key != "", AuthMethodTypeKey).
|
|
ElseIf(config.Password != "", AuthMethodTypePassword).
|
|
Else(AuthMethodTypeNone)
|
|
authMethods := make([]ssh.AuthMethod, 0)
|
|
switch authMethodType {
|
|
case AuthMethodTypeNone:
|
|
{
|
|
if config.Username == "" {
|
|
return nil, fmt.Errorf("ssh: unset username")
|
|
}
|
|
}
|
|
|
|
case AuthMethodTypePassword:
|
|
{
|
|
if config.Username == "" {
|
|
return nil, fmt.Errorf("ssh: unset username")
|
|
}
|
|
if config.Password == "" {
|
|
return nil, fmt.Errorf("ssh: unset password")
|
|
}
|
|
|
|
password := config.Password
|
|
authMethods = append(authMethods, ssh.Password(password))
|
|
authMethods = append(authMethods, ssh.KeyboardInteractive(func(user, instruction string, questions []string, echos []bool) ([]string, error) {
|
|
answers := make([]string, len(questions))
|
|
if len(answers) == 0 {
|
|
return answers, nil
|
|
}
|
|
|
|
for i, question := range questions {
|
|
question = strings.TrimSpace(strings.TrimSuffix(strings.TrimSpace(question), ":"))
|
|
if strings.EqualFold(question, "Password") {
|
|
answers[i] = password
|
|
return answers, nil
|
|
}
|
|
}
|
|
|
|
return nil, fmt.Errorf("unexpected keyboard interactive question '%s'", strings.Join(questions, ", "))
|
|
}))
|
|
}
|
|
|
|
case AuthMethodTypeKey:
|
|
{
|
|
if config.Username == "" {
|
|
return nil, fmt.Errorf("ssh: unset username")
|
|
}
|
|
if config.Key == "" {
|
|
return nil, fmt.Errorf("ssh: unset key")
|
|
}
|
|
|
|
key := config.Key
|
|
keyPassphrase := config.KeyPassphrase
|
|
|
|
var signer ssh.Signer
|
|
var err error
|
|
if keyPassphrase != "" {
|
|
signer, err = ssh.ParsePrivateKeyWithPassphrase([]byte(key), []byte(keyPassphrase))
|
|
} else {
|
|
signer, err = ssh.ParsePrivateKey([]byte(key))
|
|
}
|
|
if err != nil {
|
|
return nil, fmt.Errorf("ssh: %w", err)
|
|
}
|
|
|
|
authMethods = append(authMethods, ssh.PublicKeys(signer))
|
|
}
|
|
|
|
default:
|
|
return nil, fmt.Errorf("ssh: unsupported auth method '%s'", authMethodType)
|
|
}
|
|
|
|
addr := resolveAddr(config.Host, config.Port)
|
|
sshConn, chans, reqs, err := ssh.NewClientConn(conn, addr, &ssh.ClientConfig{
|
|
User: config.Username,
|
|
Auth: authMethods,
|
|
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
|
})
|
|
if err != nil {
|
|
return nil, fmt.Errorf("ssh: %w", err)
|
|
}
|
|
|
|
return ssh.NewClient(sshConn, chans, reqs), nil
|
|
}
|
|
|
|
func resolveAddr(host string, port int) string {
|
|
if port == 0 {
|
|
port = defaultPort
|
|
}
|
|
return net.JoinHostPort(host, strconv.Itoa(port))
|
|
}
|