-
Notifications
You must be signed in to change notification settings - Fork 160
/
Copy pathmain.py
78 lines (67 loc) · 2.53 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import subprocess
from threading import Thread
import time
import requests
from gradio_ui import app
from util import download_weights
import torch
import socket
def run():
    """Start the Omniparser server, wait until it answers, then run the UI.

    Steps, in order:
      1. Print CUDA / MPS diagnostics (best effort; failures only print a hint).
      2. Launch ``./omniserver.py`` as a child process and forward its stdout
         and stderr to this console through two daemon threads.
      3. Call ``download_weights.download()``, then poll
         ``http://127.0.0.1:8000/probe/`` every 10 seconds until it returns
         HTTP 200.
      4. Call ``app.run()``.

    Whatever happens after the child has been launched, the child is stopped
    on the way out: first with ``terminate()``, then with ``kill()`` if it has
    not exited within 8 seconds.

    Raises:
        RuntimeError: if the server process exits before the probe succeeds.
    """
    # Local import: only needed here, to locate the running interpreter.
    import sys

    try:
        print("cuda is_available: ", torch.cuda.is_available())  # should be True
        print("MPS is_available: ", torch.backends.mps.is_available())
        print("cuda device_count", torch.cuda.device_count())  # should be at least 1
        print("cuda device_name", torch.cuda.get_device_name(0))  # should show the GPU name
    except Exception:
        # Diagnostics only: report the problem and keep going.
        print("显卡驱动不适配,请根据readme安装合适版本的 torch!")

    # Use the interpreter that is running this script rather than whatever
    # "python" happens to resolve to on PATH, so the server sees the same
    # environment (and installed packages) as this process.
    server_process = subprocess.Popen(
        [sys.executable, "./omniserver.py"],
        stdout=subprocess.PIPE,  # capture standard output
        stderr=subprocess.PIPE,
        text=True,
    )

    # Drain both pipes in the background so the child never blocks on a full
    # pipe buffer; daemon threads do not keep the interpreter alive on exit.
    stdout_thread = Thread(
        target=stream_reader,
        args=(server_process.stdout, "SERVER-OUT"),
    )
    stderr_thread = Thread(
        target=stream_reader,
        args=(server_process.stderr, "SERVER-ERR"),
    )
    stdout_thread.daemon = True
    stderr_thread.daemon = True
    stdout_thread.start()
    stderr_thread.start()

    try:
        # Download the weight files.
        download_weights.download()
        print("启动Omniserver服务中,因为加载模型真的超级慢,请耐心等待!")
        while True:
            try:
                res = requests.get("http://127.0.0.1:8000/probe/", timeout=5)
                if res.status_code == 200:
                    print("Omniparser服务启动成功...")
                    break
            except (requests.ConnectionError, requests.Timeout):
                # Expected while the server is still starting; retry below.
                pass
            # Stop waiting as soon as the child has died.
            if server_process.poll() is not None:
                raise RuntimeError(f"服务器进程报错退出:{server_process.returncode}")
            print("等待服务启动...")
            time.sleep(10)
        app.run()
    finally:
        if server_process.poll() is None:  # the process is still running
            server_process.terminate()  # ask it to stop
            try:
                server_process.wait(timeout=8)  # give it time to exit
            except subprocess.TimeoutExpired:
                # Graceful shutdown timed out: force-kill so the child is not
                # left running, and so this cleanup does not raise and mask
                # an exception already propagating from the try body.
                server_process.kill()
                server_process.wait()
def stream_reader(pipe, prefix):
    """Echo every line read from ``pipe`` to stdout, tagged with ``[prefix]``.

    Args:
        pipe: An iterable of text lines (for example a subprocess stdout or
            stderr stream opened in text mode).
        prefix: Label placed in square brackets before each echoed line.

    Returns when ``pipe`` is exhausted.
    """
    tag = f"[{prefix}]"
    for text in pipe:
        # Lines already carry their own newline, so suppress print's; flush
        # so output appears immediately instead of sitting in a buffer.
        print(tag, text, end="", flush=True)
def is_port_occupied(port):
    """Return True if a TCP connection to ``localhost:port`` succeeds.

    Args:
        port: TCP port number to probe.

    Returns:
        bool: True when ``connect_ex`` reports success (status 0), i.e.
        something accepted the connection; False otherwise.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        # connect_ex returns an error code instead of raising on failure.
        status = probe.connect_ex(('localhost', port))
    finally:
        probe.close()
    return status == 0
if __name__ == '__main__':
    # Check whether port 8000 is already in use before starting anything;
    # run() probes the server on that port.
    if is_port_occupied(8000):
        print("8000端口被占用,请先关闭占用该端口的进程")
        # Raise SystemExit directly instead of calling the site-provided
        # exit() helper (not guaranteed to exist outside the interactive
        # shell), and use a non-zero status so callers can detect the failure.
        raise SystemExit(1)
    run()