Skip to content
This repository was archived by the owner on Jan 17, 2021. It is now read-only.

Commit 2693c3f

Browse files
authored
Merge pull request #116 from cdr/reuse-ssh-connection
Add SSH master connection feature
2 parents c637d40 + bbd94c5 commit 2693c3f

File tree

2 files changed

+158
-19
lines changed

2 files changed

+158
-19
lines changed

main.go

+13-10
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,12 @@ var _ interface {
3434
} = new(rootCmd)
3535

3636
type rootCmd struct {
37-
skipSync bool
38-
syncBack bool
39-
printVersion bool
40-
bindAddr string
41-
sshFlags string
37+
skipSync bool
38+
syncBack bool
39+
printVersion bool
40+
noReuseConnection bool
41+
bindAddr string
42+
sshFlags string
4243
}
4344

4445
func (c *rootCmd) Spec() cli.CommandSpec {
@@ -53,6 +54,7 @@ func (c *rootCmd) RegisterFlags(fl *flag.FlagSet) {
5354
fl.BoolVar(&c.skipSync, "skipsync", false, "skip syncing local settings and extensions to remote host")
5455
fl.BoolVar(&c.syncBack, "b", false, "sync extensions back on termination")
5556
fl.BoolVar(&c.printVersion, "version", false, "print version information and exit")
57+
fl.BoolVar(&c.noReuseConnection, "no-reuse-connection", false, "do not reuse SSH connection via control socket")
5658
fl.StringVar(&c.bindAddr, "bind", "", "local bind address for SSH tunnel, in [HOST][:PORT] syntax (default: 127.0.0.1)")
5759
fl.StringVar(&c.sshFlags, "ssh-flags", "", "custom SSH flags")
5860
}
@@ -76,10 +78,11 @@ func (c *rootCmd) Run(fl *flag.FlagSet) {
7678
}
7779

7880
err := sshCode(host, dir, options{
79-
skipSync: c.skipSync,
80-
sshFlags: c.sshFlags,
81-
bindAddr: c.bindAddr,
82-
syncBack: c.syncBack,
81+
skipSync: c.skipSync,
82+
sshFlags: c.sshFlags,
83+
bindAddr: c.bindAddr,
84+
syncBack: c.syncBack,
85+
reuseConnection: !c.noReuseConnection,
8386
})
8487

8588
if err != nil {
@@ -101,7 +104,7 @@ Environment variables:
101104
More info: https://github.com/cdr/sshcode
102105
103106
Arguments:
104-
%vHOST is passed into the ssh command. Valid formats are '<ip-address>' or 'gcp:<instance-name>'.
107+
%vHOST is passed into the ssh command. Valid formats are '<ip-address>' or 'gcp:<instance-name>'.
105108
%vDIR is optional.`,
106109
helpTab, vsCodeConfigDirEnv,
107110
helpTab, vsCodeExtensionsDirEnv,

sshcode.go

+145-9
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212
"path/filepath"
1313
"strconv"
1414
"strings"
15+
"syscall"
1516
"time"
1617

1718
"github.com/pkg/browser"
@@ -21,18 +22,23 @@ import (
2122

2223
const codeServerPath = "~/.cache/sshcode/sshcode-server"
2324

25+
const (
26+
sshDirectory = "~/.ssh"
27+
sshDirectoryUnsafeModeMask = 0022
28+
sshControlPath = sshDirectory + "/control-%h-%p-%r"
29+
)
30+
2431
type options struct {
25-
skipSync bool
26-
syncBack bool
27-
noOpen bool
28-
bindAddr string
29-
remotePort string
30-
sshFlags string
32+
skipSync bool
33+
syncBack bool
34+
noOpen bool
35+
reuseConnection bool
36+
bindAddr string
37+
remotePort string
38+
sshFlags string
3139
}
3240

3341
func sshCode(host, dir string, o options) error {
34-
flog.Info("ensuring code-server is updated...")
35-
3642
host, extraSSHFlags, err := parseHost(host)
3743
if err != nil {
3844
return xerrors.Errorf("failed to parse host IP: %w", err)
@@ -53,6 +59,24 @@ func sshCode(host, dir string, o options) error {
5359
return xerrors.Errorf("failed to find available remote port: %w", err)
5460
}
5561

62+
// Check the SSH directory's permissions and warn the user if it is not safe.
63+
o.reuseConnection = checkSSHDirectory(sshDirectory, o.reuseConnection)
64+
65+
// Start SSH master connection socket. This prevents multiple password prompts from appearing as authentication
66+
// only happens on the initial connection.
67+
if o.reuseConnection {
68+
flog.Info("starting SSH master connection...")
69+
newSSHFlags, cancel, err := startSSHMaster(o.sshFlags, sshControlPath, host)
70+
defer cancel()
71+
if err != nil {
72+
flog.Error("failed to start SSH master connection: %v", err)
73+
o.reuseConnection = false
74+
} else {
75+
o.sshFlags = newSSHFlags
76+
}
77+
}
78+
79+
flog.Info("ensuring code-server is updated...")
5680
dlScript := downloadScript(codeServerPath)
5781

5882
// Downloads the latest code-server and allows it to be executed.
@@ -147,8 +171,8 @@ func sshCode(host, dir string, o options) error {
147171
case <-c:
148172
}
149173

174+
flog.Info("shutting down")
150175
if !o.syncBack || o.skipSync {
151-
flog.Info("shutting down")
152176
return nil
153177
}
154178

@@ -167,6 +191,24 @@ func sshCode(host, dir string, o options) error {
167191
return nil
168192
}
169193

194+
// expandPath returns an expanded version of path.
195+
func expandPath(path string) string {
196+
path = filepath.Clean(os.ExpandEnv(path))
197+
198+
// Replace tilde notation in path with the home directory. You can't replace the first instance of `~` in the
199+
// string with the homedir as having a tilde in the middle of a filename is valid.
200+
homedir := os.Getenv("HOME")
201+
if homedir != "" {
202+
if path == "~" {
203+
path = homedir
204+
} else if strings.HasPrefix(path, "~/") {
205+
path = filepath.Join(homedir, path[2:])
206+
}
207+
}
208+
209+
return filepath.Clean(path)
210+
}
211+
170212
func parseBindAddr(bindAddr string) (string, error) {
171213
if !strings.Contains(bindAddr, ":") {
172214
bindAddr += ":"
@@ -263,6 +305,100 @@ func randomPort() (string, error) {
263305
return "", xerrors.Errorf("max number of tries exceeded: %d", maxTries)
264306
}
265307

308+
// checkSSHDirectory performs sanity and safety checks on sshDirectory, and
309+
// returns a new value for o.reuseConnection depending on the checks.
310+
func checkSSHDirectory(sshDirectory string, reuseConnection bool) bool {
311+
sshDirectoryMode, err := os.Lstat(expandPath(sshDirectory))
312+
if err != nil {
313+
if reuseConnection {
314+
flog.Info("failed to stat %v directory, disabling connection reuse feature: %v", sshDirectory, err)
315+
}
316+
reuseConnection = false
317+
} else {
318+
if !sshDirectoryMode.IsDir() {
319+
if reuseConnection {
320+
flog.Info("%v is not a directory, disabling connection reuse feature", sshDirectory)
321+
} else {
322+
flog.Info("warning: %v is not a directory", sshDirectory)
323+
}
324+
reuseConnection = false
325+
}
326+
if sshDirectoryMode.Mode().Perm()&sshDirectoryUnsafeModeMask != 0 {
327+
flog.Info("warning: the %v directory has unsafe permissions, they should only be writable by "+
328+
"the owner (and files inside should be set to 0600)", sshDirectory)
329+
}
330+
}
331+
return reuseConnection
332+
}
333+
334+
// startSSHMaster starts an SSH master connection and waits for it to be ready.
335+
// It returns a new set of SSH flags for child SSH processes to use.
336+
func startSSHMaster(sshFlags string, sshControlPath string, host string) (string, func(), error) {
337+
ctx, cancel := context.WithCancel(context.Background())
338+
339+
newSSHFlags := fmt.Sprintf(`%v -o "ControlPath=%v"`, sshFlags, sshControlPath)
340+
341+
// -MN means "start a master socket and don't open a session, just connect".
342+
sshCmdStr := fmt.Sprintf(`exec ssh %v -MNq %v`, newSSHFlags, host)
343+
sshMasterCmd := exec.CommandContext(ctx, "sh", "-c", sshCmdStr)
344+
sshMasterCmd.Stdin = os.Stdin
345+
sshMasterCmd.Stderr = os.Stderr
346+
347+
// Gracefully stop the SSH master.
348+
stopSSHMaster := func() {
349+
if sshMasterCmd.Process != nil {
350+
if sshMasterCmd.ProcessState != nil && sshMasterCmd.ProcessState.Exited() {
351+
return
352+
}
353+
err := sshMasterCmd.Process.Signal(syscall.SIGTERM)
354+
if err != nil {
355+
flog.Error("failed to send SIGTERM to SSH master process: %v", err)
356+
}
357+
}
358+
cancel()
359+
}
360+
361+
// Start ssh master and wait. Waiting prevents the process from becoming a zombie process if it dies before
362+
// sshcode does, and allows sshMasterCmd.ProcessState to be populated.
363+
err := sshMasterCmd.Start()
364+
go sshMasterCmd.Wait()
365+
if err != nil {
366+
return "", stopSSHMaster, err
367+
}
368+
err = checkSSHMaster(sshMasterCmd, newSSHFlags, host)
369+
if err != nil {
370+
stopSSHMaster()
371+
return "", stopSSHMaster, xerrors.Errorf("SSH master wasn't ready on time: %w", err)
372+
}
373+
return newSSHFlags, stopSSHMaster, nil
374+
}
375+
376+
// checkSSHMaster polls every second for 30 seconds to check if the SSH master
377+
// is ready.
378+
func checkSSHMaster(sshMasterCmd *exec.Cmd, sshFlags string, host string) error {
379+
var (
380+
maxTries = 30
381+
sleepDur = time.Second
382+
err error
383+
)
384+
for i := 0; i < maxTries; i++ {
385+
// Check if the master is running.
386+
if sshMasterCmd.Process == nil || (sshMasterCmd.ProcessState != nil && sshMasterCmd.ProcessState.Exited()) {
387+
return xerrors.Errorf("SSH master process is not running")
388+
}
389+
390+
// Check if it's ready.
391+
sshCmdStr := fmt.Sprintf(`ssh %v -O check %v`, sshFlags, host)
392+
sshCmd := exec.Command("sh", "-c", sshCmdStr)
393+
err = sshCmd.Run()
394+
if err == nil {
395+
return nil
396+
}
397+
time.Sleep(sleepDur)
398+
}
399+
return xerrors.Errorf("max number of tries exceeded: %d", maxTries)
400+
}
401+
266402
func syncUserSettings(sshFlags string, host string, back bool) error {
267403
localConfDir, err := configDir()
268404
if err != nil {

0 commit comments

Comments
 (0)