Skip to content

Commit 53ce856

Browse files
committed
feat: remote exec
1 parent f5cde85 commit 53ce856

File tree

7 files changed

+192
-84
lines changed

7 files changed

+192
-84
lines changed

api/query.go

+11-1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ package api
33
import (
44
"bytes"
55
"encoding/json"
6+
"errors"
7+
"fmt"
68
"net/http"
79
"os"
810
"runtime"
@@ -19,17 +21,25 @@ type Input struct {
1921
func Query(input Input) (res *http.Response, err error) {
2022
jsonValue, err := json.Marshal(input)
2123
if err != nil {
22-
return
24+
return nil, err
2325
}
2426

2527
apiUrl := os.Getenv("RUNPOD_API_URL")
2628
if apiUrl == "" {
2729
apiUrl = viper.GetString("apiUrl")
2830
}
31+
2932
apiKey := os.Getenv("RUNPOD_API_KEY")
3033
if apiKey == "" {
3134
apiKey = viper.GetString("apiKey")
3235
}
36+
37+
// Check if the API key is present
38+
if apiKey == "" {
39+
fmt.Println("API key not found")
40+
return nil, errors.New("API key not found")
41+
}
42+
3343
req, err := http.NewRequest("POST", apiUrl+"?api_key="+apiKey, bytes.NewBuffer(jsonValue))
3444
if err != nil {
3545
return

cmd/exec.go

+18
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
package cmd
2+
3+
import (
4+
"cli/cmd/exec"
5+
6+
"github.com/spf13/cobra"
7+
)
8+
9+
// execCmd represents the base command for executing commands in a pod
10+
var execCmd = &cobra.Command{
11+
Use: "exec",
12+
Short: "Execute commands in a pod",
13+
Long: `Execute a local file remotely in a pod.`,
14+
}
15+
16+
func init() {
17+
execCmd.AddCommand(exec.RemotePythonCmd)
18+
}

cmd/exec/commands.go

+39
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
package exec
2+
3+
import (
4+
"fmt"
5+
"os"
6+
7+
"github.com/spf13/cobra"
8+
)
9+
10+
var RemotePythonCmd = &cobra.Command{
11+
Use: "python [file]",
12+
Short: "Runs a remote Python shell",
13+
Long: `Runs a remote Python shell with a local script file.`,
14+
Args: cobra.ExactArgs(1),
15+
Run: func(cmd *cobra.Command, args []string) {
16+
podID, _ := cmd.Flags().GetString("pod_id")
17+
file := args[0]
18+
19+
// Default to the session pod if no pod_id is provided
20+
// if podID == "" {
21+
// var err error
22+
// podID, err = api.GetSessionPod()
23+
// if err != nil {
24+
// fmt.Fprintf(os.Stderr, "Error retrieving session pod: %v\n", err)
25+
// return
26+
// }
27+
// }
28+
29+
fmt.Println("Running remote Python shell...")
30+
if err := PythonOverSSH(podID, file); err != nil {
31+
fmt.Fprintf(os.Stderr, "Error executing Python over SSH: %v\n", err)
32+
}
33+
},
34+
}
35+
36+
func init() {
37+
RemotePythonCmd.Flags().String("pod_id", "", "The ID of the pod to run the command on.")
38+
RemotePythonCmd.MarkFlagRequired("file")
39+
}

cmd/exec/functions.go

+25
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
package exec
2+
3+
import (
4+
"cli/cmd/project"
5+
"fmt"
6+
)
7+
8+
func PythonOverSSH(podID string, file string) error {
9+
sshConn, err := project.PodSSHConnection(podID)
10+
if err != nil {
11+
return fmt.Errorf("getting SSH connection: %w", err)
12+
}
13+
14+
// Copy the file to the pod using Rsync
15+
if err := sshConn.Rsync(file, "/tmp/"+file, false); err != nil {
16+
return fmt.Errorf("copying file to pod: %w", err)
17+
}
18+
19+
// Run the file on the pod
20+
if err := sshConn.RunCommand("python3.11 /tmp/" + file); err != nil {
21+
return fmt.Errorf("running Python command: %w", err)
22+
}
23+
24+
return nil
25+
}

cmd/project/functions.go

+15-3
Original file line numberDiff line numberDiff line change
@@ -42,19 +42,31 @@ func copyFiles(files fs.FS, source string, dest string) error {
4242
if path == source {
4343
return nil
4444
}
45+
46+
relPath, err := filepath.Rel(source, path)
47+
if err != nil {
48+
return err
49+
}
50+
4551
// Generate the corresponding path in the new project folder
46-
newPath := filepath.Join(dest, path[len(source):])
52+
newPath := filepath.Join(dest, relPath)
4753
if d.IsDir() {
48-
return os.MkdirAll(newPath, os.ModePerm)
54+
if err := os.MkdirAll(newPath, os.ModePerm); err != nil {
55+
return err
56+
}
4957
} else {
5058
content, err := fs.ReadFile(files, path)
5159
if err != nil {
5260
return err
5361
}
54-
return os.WriteFile(newPath, content, 0644)
62+
if err := os.WriteFile(newPath, content, 0644); err != nil {
63+
return err
64+
}
5565
}
66+
return nil
5667
})
5768
}
69+
5870
func createNewProject(projectName string, cudaVersion string,
5971
pythonVersion string, modelType string, modelName string, initCurrentDir bool) {
6072
projectFolder, _ := os.Getwd()

0 commit comments

Comments
 (0)