Skip to content

Commit 4c90f4b

Browse files
committed
feat: port ssh cmd from python, refactored ssh key gen
1 parent 930fc3c commit 4c90f4b

File tree

6 files changed

+320
-120
lines changed

6 files changed

+320
-120
lines changed

api/user.go

+68-28
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,21 @@ package api
22

33
import (
44
"encoding/json"
5-
"errors"
65
"fmt"
76
"io"
87
"strings"
8+
9+
"golang.org/x/crypto/ssh"
910
)
1011

11-
func GetPublicSSHKeys() (keys string, err error) {
12+
type SSHKey struct {
13+
Name string `json:"name"`
14+
Type string `json:"type"`
15+
Key string `json:"key"`
16+
Fingerprint string `json:"fingerprint"`
17+
}
18+
19+
func GetPublicSSHKeys() (string, []SSHKey, error) {
1220
input := Input{
1321
Query: `
1422
query myself {
@@ -19,48 +27,79 @@ func GetPublicSSHKeys() (keys string, err error) {
1927
}
2028
`,
2129
}
30+
2231
res, err := Query(input)
2332
if err != nil {
24-
return "", err
33+
return "", nil, err
2534
}
35+
defer res.Body.Close()
36+
2637
if res.StatusCode != 200 {
27-
err = fmt.Errorf("statuscode %d", res.StatusCode)
28-
return
38+
return "", nil, fmt.Errorf("unexpected status code: %d", res.StatusCode)
2939
}
30-
defer res.Body.Close()
40+
3141
rawData, err := io.ReadAll(res.Body)
3242
if err != nil {
33-
return "", err
43+
return "", nil, fmt.Errorf("failed to read response body: %w", err)
3444
}
35-
data := &UserOut{}
36-
if err = json.Unmarshal(rawData, data); err != nil {
37-
return "", err
45+
46+
var data UserOut
47+
if err := json.Unmarshal(rawData, &data); err != nil {
48+
return "", nil, fmt.Errorf("JSON unmarshal error: %w", err)
3849
}
50+
3951
if len(data.Errors) > 0 {
40-
err = errors.New(data.Errors[0].Message)
41-
return "", err
52+
return "", nil, fmt.Errorf("API error: %s", data.Errors[0].Message)
4253
}
43-
if data == nil || data.Data == nil || data.Data.Myself == nil {
44-
err = fmt.Errorf("data is nil: %s", string(rawData))
45-
return "", err
54+
55+
if data.Data == nil || data.Data.Myself == nil {
56+
return "", nil, fmt.Errorf("nil data received: %s", string(rawData))
57+
}
58+
59+
// Parse the public key string into a list of SSHKey structs
60+
var keys []SSHKey
61+
keyStrings := strings.Split(data.Data.Myself.PubKey, "\n")
62+
for _, keyString := range keyStrings {
63+
if keyString == "" {
64+
continue
65+
}
66+
67+
pubKey, name, _, _, err := ssh.ParseAuthorizedKey([]byte(keyString))
68+
if err != nil {
69+
continue // Skip keys that can't be parsed
70+
}
71+
72+
keys = append(keys, SSHKey{
73+
Name: name,
74+
Type: pubKey.Type(),
75+
Key: string(ssh.MarshalAuthorizedKey(pubKey)),
76+
Fingerprint: ssh.FingerprintSHA256(pubKey),
77+
})
4678
}
47-
return data.Data.Myself.PubKey, nil
79+
80+
return data.Data.Myself.PubKey, keys, nil
4881
}
4982

5083
func AddPublicSSHKey(key []byte) error {
51-
//pull existing pubKey
52-
existingKeys, err := GetPublicSSHKeys()
84+
rawKeys, existingKeys, err := GetPublicSSHKeys()
5385
if err != nil {
54-
return err
86+
return fmt.Errorf("failed to get existing SSH keys: %w", err)
5587
}
88+
5689
keyStr := string(key)
57-
//check for key present
58-
if strings.Contains(existingKeys, keyStr) {
59-
return nil
90+
for _, k := range existingKeys {
91+
if strings.TrimSpace(k.Key) == strings.TrimSpace(keyStr) {
92+
return nil
93+
}
6094
}
61-
// concat key onto pubKey
62-
newKeys := existingKeys + "\n\n" + keyStr
63-
// set new pubKey
95+
96+
// Concatenate the new key onto the existing keys, separated by a newline
97+
newKeys := strings.TrimSpace(rawKeys)
98+
if newKeys != "" {
99+
newKeys += "\n\n"
100+
}
101+
newKeys += strings.TrimSpace(keyStr)
102+
64103
input := Input{
65104
Query: `
66105
mutation Mutation($input: UpdateUserSettingsInput) {
@@ -71,9 +110,10 @@ func AddPublicSSHKey(key []byte) error {
71110
`,
72111
Variables: map[string]interface{}{"input": map[string]interface{}{"pubKey": newKeys}},
73112
}
74-
_, err = Query(input)
75-
if err != nil {
76-
return err
113+
114+
if _, err = Query(input); err != nil {
115+
return fmt.Errorf("failed to update SSH keys: %w", err)
77116
}
117+
78118
return nil
79119
}

cmd/config/config.go

+22-70
Original file line numberDiff line numberDiff line change
@@ -2,100 +2,52 @@ package config
22

33
import (
44
"cli/api"
5-
"crypto/rand"
6-
"crypto/rsa"
7-
"crypto/x509"
8-
"encoding/pem"
9-
"errors"
5+
"cli/cmd/ssh"
106
"fmt"
117
"os"
12-
"path/filepath"
138

149
"github.com/spf13/cobra"
1510
"github.com/spf13/viper"
16-
"golang.org/x/crypto/ssh"
1711
)
1812

19-
var ConfigFile string
20-
var apiKey string
21-
var apiUrl string
13+
var (
14+
ConfigFile string
15+
apiKey string
16+
apiUrl string
17+
)
2218

2319
var ConfigCmd = &cobra.Command{
2420
Use: "config",
2521
Short: "CLI Config",
2622
Long: "RunPod CLI Config Settings",
2723
Run: func(c *cobra.Command, args []string) {
28-
err := viper.WriteConfig()
29-
cobra.CheckErr(err)
30-
fmt.Println("saved apiKey into config file: " + ConfigFile)
31-
home, err := os.UserHomeDir()
24+
if err := viper.WriteConfig(); err != nil {
25+
fmt.Fprintf(os.Stderr, "Error saving config: %v\n", err)
26+
return
27+
}
28+
fmt.Println("Configuration saved to file:", viper.ConfigFileUsed())
29+
30+
publicKey, err := ssh.GenerateSSHKeyPair("RunPod-Key-Go")
3231
if err != nil {
33-
fmt.Println("couldn't get user home dir path")
32+
fmt.Fprintf(os.Stderr, "Failed to generate SSH key: %v\n", err)
3433
return
3534
}
36-
sshFolderPath := filepath.Join(home, ".runpod", "ssh")
37-
os.MkdirAll(sshFolderPath, os.ModePerm)
38-
privateSshPath := filepath.Join(sshFolderPath, "RunPod-Key-Go")
39-
publicSshPath := filepath.Join(sshFolderPath, "RunPod-Key-Go.pub")
40-
publicKey, _ := os.ReadFile(publicSshPath)
41-
if _, err := os.Stat(privateSshPath); errors.Is(err, os.ErrNotExist) {
42-
publicKey = makeRSAKey(privateSshPath)
35+
36+
if err := api.AddPublicSSHKey(publicKey); err != nil {
37+
fmt.Fprintf(os.Stderr, "Failed to add the SSH key: %v\n", err)
38+
return
4339
}
44-
api.AddPublicSSHKey(publicKey)
40+
fmt.Println("SSH key added successfully.")
4541
},
4642
}
4743

4844
func init() {
49-
ConfigCmd.Flags().StringVar(&apiKey, "apiKey", "", "runpod api key")
50-
ConfigCmd.MarkFlagRequired("apiKey")
45+
ConfigCmd.Flags().StringVar(&apiKey, "apiKey", "", "RunPod API key")
5146
viper.BindPFlag("apiKey", ConfigCmd.Flags().Lookup("apiKey")) //nolint
5247
viper.SetDefault("apiKey", "")
5348

54-
ConfigCmd.Flags().StringVar(&apiUrl, "apiUrl", "", "runpod api url")
49+
ConfigCmd.Flags().StringVar(&apiUrl, "apiUrl", "https://api.runpod.io/graphql", "RunPod API URL")
5550
viper.BindPFlag("apiUrl", ConfigCmd.Flags().Lookup("apiUrl")) //nolint
56-
viper.SetDefault("apiUrl", "https://api.runpod.io/graphql")
57-
}
5851

59-
func makeRSAKey(filename string) []byte {
60-
bitSize := 2048
61-
62-
// Generate RSA key.
63-
key, err := rsa.GenerateKey(rand.Reader, bitSize)
64-
if err != nil {
65-
panic(err)
66-
}
67-
68-
// Extract public component.
69-
pub := key.PublicKey
70-
71-
// Encode private key to PKCS#1 ASN.1 PEM.
72-
keyPEM := pem.EncodeToMemory(
73-
&pem.Block{
74-
Type: "RSA PRIVATE KEY",
75-
Bytes: x509.MarshalPKCS1PrivateKey(key),
76-
},
77-
)
78-
79-
// generate and write public key
80-
publicKey, err := ssh.NewPublicKey(&pub)
81-
if err != nil {
82-
fmt.Println("err in NewPublicKey")
83-
fmt.Println(err)
84-
}
85-
pubBytes := ssh.MarshalAuthorizedKey(publicKey)
86-
pubBytes = append(pubBytes, []byte(" "+filename)...)
87-
88-
// Write private key to file.
89-
if err := os.WriteFile(filename, keyPEM, 0600); err != nil {
90-
fmt.Println("err writing priv")
91-
panic(err)
92-
}
93-
94-
// Write public key to file.
95-
if err := os.WriteFile(filename+".pub", pubBytes, 0600); err != nil {
96-
fmt.Println("err writing pub")
97-
panic(err)
98-
}
99-
fmt.Println("saved new SSH public key into", filename+".pub")
100-
return pubBytes
52+
ConfigCmd.MarkFlagRequired("apiKey")
10153
}

cmd/root.go

+32-22
Original file line numberDiff line numberDiff line change
@@ -14,41 +14,51 @@ import (
1414

1515
var version string
1616

17-
// rootCmd represents the base command when called without any subcommands
18-
var RootCmd = &cobra.Command{
17+
// Entrypoint for the CLI
18+
var rootCmd = &cobra.Command{
1919
Use: "runpodctl",
2020
Aliases: []string{"runpod"},
2121
Short: "CLI for runpod.io",
2222
Long: "CLI tool to manage your pods for runpod.io",
2323
}
2424

25-
// Execute adds all child commands to the root command and sets flags appropriately.
26-
// This is called by main.main(). It only needs to happen once to the rootCmd.
27-
func Execute(ver string) {
28-
version = ver
29-
api.Version = ver
30-
err := RootCmd.Execute()
31-
if err != nil {
32-
os.Exit(1)
33-
}
25+
func GetRootCmd() *cobra.Command {
26+
return rootCmd
3427
}
3528

3629
func init() {
3730
cobra.OnInitialize(initConfig)
38-
RootCmd.AddCommand(config.ConfigCmd)
31+
registerCommands()
32+
}
33+
34+
func registerCommands() {
35+
rootCmd.AddCommand(config.ConfigCmd)
3936
// RootCmd.AddCommand(connectCmd)
4037
// RootCmd.AddCommand(copyCmd)
41-
RootCmd.AddCommand(createCmd)
42-
RootCmd.AddCommand(getCmd)
43-
RootCmd.AddCommand(removeCmd)
44-
RootCmd.AddCommand(startCmd)
45-
RootCmd.AddCommand(stopCmd)
46-
RootCmd.AddCommand(versionCmd)
47-
RootCmd.AddCommand(projectCmd)
48-
RootCmd.AddCommand(updateCmd)
38+
rootCmd.AddCommand(createCmd)
39+
rootCmd.AddCommand(getCmd)
40+
rootCmd.AddCommand(removeCmd)
41+
rootCmd.AddCommand(startCmd)
42+
rootCmd.AddCommand(stopCmd)
43+
rootCmd.AddCommand(versionCmd)
44+
rootCmd.AddCommand(projectCmd)
45+
rootCmd.AddCommand(updateCmd)
46+
rootCmd.AddCommand(sshCmd)
4947

50-
RootCmd.AddCommand(croc.ReceiveCmd)
51-
RootCmd.AddCommand(croc.SendCmd)
48+
// file transfer via croc
49+
rootCmd.AddCommand(croc.ReceiveCmd)
50+
rootCmd.AddCommand(croc.SendCmd)
51+
}
52+
53+
// Execute adds all child commands to the root command and sets flags appropriately.
54+
// This is called by main.main(). It only needs to happen once to the rootCmd.
55+
func Execute(ver string) {
56+
version = ver
57+
api.Version = ver
58+
if err := rootCmd.Execute(); err != nil {
59+
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
60+
os.Exit(1)
61+
}
5262
}
5363

5464
// initConfig reads in config file and ENV variables if set.

cmd/ssh.go

+18
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
package cmd
2+
3+
import (
4+
"cli/cmd/ssh"
5+
6+
"github.com/spf13/cobra"
7+
)
8+
9+
var sshCmd = &cobra.Command{
10+
Use: "ssh",
11+
Short: "SSH keys and commands",
12+
Long: "SSH key management and connection to pods",
13+
}
14+
15+
func init() {
16+
sshCmd.AddCommand(ssh.ListKeysCmd)
17+
sshCmd.AddCommand(ssh.AddKeyCmd)
18+
}

0 commit comments

Comments
 (0)