|
@@ -0,0 +1,255 @@
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+package main
|
|
|
+
|
|
|
+import (
|
|
|
+ "bytes"
|
|
|
+ "crypto/rand"
|
|
|
+ "crypto/subtle"
|
|
|
+ "errors"
|
|
|
+ "fmt"
|
|
|
+ "sync"
|
|
|
+ "time"
|
|
|
+
|
|
|
+ "golang.org/x/crypto/ssh"
|
|
|
+ "golang.org/x/crypto/ssh/agent"
|
|
|
+)
|
|
|
+
|
|
|
+type privKey struct {
|
|
|
+ signer ssh.Signer
|
|
|
+ comment string
|
|
|
+ confirm bool
|
|
|
+ expire *time.Time
|
|
|
+}
|
|
|
+
|
|
|
+type ConfirmFunction func(comment string) bool
|
|
|
+
|
|
|
+type keyring struct {
|
|
|
+ confirmFunction ConfirmFunction
|
|
|
+
|
|
|
+ mu sync.Mutex
|
|
|
+ keys []privKey
|
|
|
+
|
|
|
+ locked bool
|
|
|
+ passphrase []byte
|
|
|
+}
|
|
|
+
|
|
|
+var errLocked = errors.New("agent: locked")
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+func NewKeyring(confirmFunction ConfirmFunction) agent.ExtendedAgent {
|
|
|
+ return &keyring{
|
|
|
+ confirmFunction: confirmFunction,
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+
|
|
|
+func (r *keyring) RemoveAll() error {
|
|
|
+ r.mu.Lock()
|
|
|
+ defer r.mu.Unlock()
|
|
|
+ if r.locked {
|
|
|
+ return errLocked
|
|
|
+ }
|
|
|
+
|
|
|
+ r.keys = nil
|
|
|
+ return nil
|
|
|
+}
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+func (r *keyring) removeLocked(want []byte) error {
|
|
|
+ found := false
|
|
|
+ for i := 0; i < len(r.keys); {
|
|
|
+ if bytes.Equal(r.keys[i].signer.PublicKey().Marshal(), want) {
|
|
|
+ found = true
|
|
|
+ r.keys[i] = r.keys[len(r.keys)-1]
|
|
|
+ r.keys = r.keys[:len(r.keys)-1]
|
|
|
+ continue
|
|
|
+ } else {
|
|
|
+ i++
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ if !found {
|
|
|
+ return errors.New("agent: key not found")
|
|
|
+ }
|
|
|
+ return nil
|
|
|
+}
|
|
|
+
|
|
|
+
|
|
|
+func (r *keyring) Remove(key ssh.PublicKey) error {
|
|
|
+ r.mu.Lock()
|
|
|
+ defer r.mu.Unlock()
|
|
|
+ if r.locked {
|
|
|
+ return errLocked
|
|
|
+ }
|
|
|
+
|
|
|
+ return r.removeLocked(key.Marshal())
|
|
|
+}
|
|
|
+
|
|
|
+
|
|
|
+func (r *keyring) Lock(passphrase []byte) error {
|
|
|
+ r.mu.Lock()
|
|
|
+ defer r.mu.Unlock()
|
|
|
+ if r.locked {
|
|
|
+ return errLocked
|
|
|
+ }
|
|
|
+
|
|
|
+ r.locked = true
|
|
|
+ r.passphrase = passphrase
|
|
|
+ return nil
|
|
|
+}
|
|
|
+
|
|
|
+
|
|
|
+func (r *keyring) Unlock(passphrase []byte) error {
|
|
|
+ r.mu.Lock()
|
|
|
+ defer r.mu.Unlock()
|
|
|
+ if !r.locked {
|
|
|
+ return errors.New("agent: not locked")
|
|
|
+ }
|
|
|
+ if 1 != subtle.ConstantTimeCompare(passphrase, r.passphrase) {
|
|
|
+ return fmt.Errorf("agent: incorrect passphrase")
|
|
|
+ }
|
|
|
+
|
|
|
+ r.locked = false
|
|
|
+ r.passphrase = nil
|
|
|
+ return nil
|
|
|
+}
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+func (r *keyring) expireKeysLocked() {
|
|
|
+ for _, k := range r.keys {
|
|
|
+ if k.expire != nil && time.Now().After(*k.expire) {
|
|
|
+ r.removeLocked(k.signer.PublicKey().Marshal())
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+
|
|
|
+func (r *keyring) List() ([]*agent.Key, error) {
|
|
|
+ r.mu.Lock()
|
|
|
+ defer r.mu.Unlock()
|
|
|
+ if r.locked {
|
|
|
+
|
|
|
+ return nil, nil
|
|
|
+ }
|
|
|
+
|
|
|
+ r.expireKeysLocked()
|
|
|
+ var ids []*agent.Key
|
|
|
+ for _, k := range r.keys {
|
|
|
+ pub := k.signer.PublicKey()
|
|
|
+ ids = append(ids, &agent.Key{
|
|
|
+ Format: pub.Type(),
|
|
|
+ Blob: pub.Marshal(),
|
|
|
+ Comment: k.comment})
|
|
|
+ }
|
|
|
+ return ids, nil
|
|
|
+}
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+func (r *keyring) Add(key agent.AddedKey) error {
|
|
|
+ r.mu.Lock()
|
|
|
+ defer r.mu.Unlock()
|
|
|
+ if r.locked {
|
|
|
+ return errLocked
|
|
|
+ }
|
|
|
+ signer, err := ssh.NewSignerFromKey(key.PrivateKey)
|
|
|
+
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+
|
|
|
+ if cert := key.Certificate; cert != nil {
|
|
|
+ signer, err = ssh.NewCertSigner(cert, signer)
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ p := privKey{
|
|
|
+ signer: signer,
|
|
|
+ comment: key.Comment,
|
|
|
+ confirm: key.ConfirmBeforeUse,
|
|
|
+ }
|
|
|
+
|
|
|
+ if key.LifetimeSecs > 0 {
|
|
|
+ t := time.Now().Add(time.Duration(key.LifetimeSecs) * time.Second)
|
|
|
+ p.expire = &t
|
|
|
+ }
|
|
|
+
|
|
|
+ r.keys = append(r.keys, p)
|
|
|
+
|
|
|
+ return nil
|
|
|
+}
|
|
|
+
|
|
|
+
|
|
|
+func (r *keyring) Sign(key ssh.PublicKey, data []byte) (*ssh.Signature, error) {
|
|
|
+ return r.SignWithFlags(key, data, 0)
|
|
|
+}
|
|
|
+
|
|
|
+func (r *keyring) SignWithFlags(key ssh.PublicKey, data []byte, flags agent.SignatureFlags) (*ssh.Signature, error) {
|
|
|
+ r.mu.Lock()
|
|
|
+ defer r.mu.Unlock()
|
|
|
+ if r.locked {
|
|
|
+ return nil, errLocked
|
|
|
+ }
|
|
|
+
|
|
|
+ r.expireKeysLocked()
|
|
|
+ wanted := key.Marshal()
|
|
|
+ for _, k := range r.keys {
|
|
|
+ if bytes.Equal(k.signer.PublicKey().Marshal(), wanted) {
|
|
|
+ if k.confirm {
|
|
|
+ if !r.confirmFunction(k.comment) {
|
|
|
+ return nil, fmt.Errorf("agent: confirmation failed")
|
|
|
+ }
|
|
|
+ }
|
|
|
+ if flags == 0 {
|
|
|
+ return k.signer.Sign(rand.Reader, data)
|
|
|
+ } else {
|
|
|
+ if algorithmSigner, ok := k.signer.(ssh.AlgorithmSigner); !ok {
|
|
|
+ return nil, fmt.Errorf("agent: signature does not support non-default signature algorithm: %T", k.signer)
|
|
|
+ } else {
|
|
|
+ var algorithm string
|
|
|
+ switch flags {
|
|
|
+ case agent.SignatureFlagRsaSha256:
|
|
|
+ algorithm = ssh.KeyAlgoRSASHA256
|
|
|
+ case agent.SignatureFlagRsaSha512:
|
|
|
+ algorithm = ssh.KeyAlgoRSASHA512
|
|
|
+ default:
|
|
|
+ return nil, fmt.Errorf("agent: unsupported signature flags: %d", flags)
|
|
|
+ }
|
|
|
+ return algorithmSigner.SignWithAlgorithm(rand.Reader, data, algorithm)
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return nil, errors.New("not found")
|
|
|
+}
|
|
|
+
|
|
|
+
|
|
|
+func (r *keyring) Signers() ([]ssh.Signer, error) {
|
|
|
+ r.mu.Lock()
|
|
|
+ defer r.mu.Unlock()
|
|
|
+ if r.locked {
|
|
|
+ return nil, errLocked
|
|
|
+ }
|
|
|
+
|
|
|
+ r.expireKeysLocked()
|
|
|
+ s := make([]ssh.Signer, 0, len(r.keys))
|
|
|
+ for _, k := range r.keys {
|
|
|
+ s = append(s, k.signer)
|
|
|
+ }
|
|
|
+ return s, nil
|
|
|
+}
|
|
|
+
|
|
|
+
|
|
|
+func (r *keyring) Extension(extensionType string, contents []byte) ([]byte, error) {
|
|
|
+ return nil, agent.ErrExtensionUnsupported
|
|
|
+}
|