diff --git a/README.md b/README.md index c2d755a..8fcc07f 100644 --- a/README.md +++ b/README.md @@ -60,15 +60,21 @@ import ( ) func main() { - ret := nvml.Init() + ret, err := nvml.Init() + if err != nil { + log.Fatalf("Unable to open NVML library: %v", err) + } if ret != nvml.SUCCESS { log.Fatalf("Unable to initialize NVML: %v", nvml.ErrorString(ret)) } defer func() { - ret := nvml.Shutdown() + ret, err := nvml.Shutdown() if ret != nvml.SUCCESS { log.Fatalf("Unable to shutdown NVML: %v", nvml.ErrorString(ret)) } + if err != nil { + log.Fatalf("Unable to close NVML library: %v", err) + } }() count, ret := nvml.DeviceGetCount() diff --git a/examples/compute-processes/main.go b/examples/compute-processes/main.go index 5adeb5a..5d34877 100644 --- a/examples/compute-processes/main.go +++ b/examples/compute-processes/main.go @@ -24,15 +24,21 @@ import ( ) func main() { - ret := nvml.Init() + ret, err := nvml.Init() + if err != nil { + log.Fatalf("Unable to open NVML library: %v", err) + } if ret != nvml.SUCCESS { log.Fatalf("Unable to initialize NVML: %v", nvml.ErrorString(ret)) } defer func() { - ret := nvml.Shutdown() + ret, err := nvml.Shutdown() if ret != nvml.SUCCESS { log.Fatalf("Unable to shutdown NVML: %v", nvml.ErrorString(ret)) } + if err != nil { + log.Fatalf("Unable to close NVML library: %v", err) + } }() count, ret := nvml.DeviceGetCount() diff --git a/examples/devices/main.go b/examples/devices/main.go index 649f371..c6d20eb 100644 --- a/examples/devices/main.go +++ b/examples/devices/main.go @@ -24,15 +24,21 @@ import ( ) func main() { - ret := nvml.Init() + ret, err := nvml.Init() + if err != nil { + log.Fatalf("Unable to open NVML library: %v", err) + } if ret != nvml.SUCCESS { log.Fatalf("Unable to initialize NVML: %v", nvml.ErrorString(ret)) } defer func() { - ret := nvml.Shutdown() + ret, err := nvml.Shutdown() if ret != nvml.SUCCESS { log.Fatalf("Unable to shutdown NVML: %v", nvml.ErrorString(ret)) } + if err != nil { + log.Fatalf("Unable to close NVML library: %v", err) + } }() count, ret := nvml.DeviceGetCount() diff --git a/gen/nvml/init.go b/gen/nvml/init.go index 1572f81..c8cb0d0 100644 --- a/gen/nvml/init.go +++ b/gen/nvml/init.go @@ -15,8 +15,6 @@ package nvml import ( - "fmt" - "github.com/NVIDIA/go-nvml/pkg/dl" ) @@ -30,45 +28,45 @@ const ( var nvml *dl.DynamicLibrary // nvml.Init() -func Init() Return { +func Init() (Return, error) { lib := dl.New(nvmlLibraryName, nvmlLibraryLoadFlags) err := lib.Open() if err != nil { - return ERROR_LIBRARY_NOT_FOUND + return ERROR_LIBRARY_NOT_FOUND, err } nvml = lib updateVersionedSymbols() - return nvmlInit() + return nvmlInit(), nil } // nvml.InitWithFlags() -func InitWithFlags(Flags uint32) Return { +func InitWithFlags(Flags uint32) (Return, error) { lib := dl.New(nvmlLibraryName, nvmlLibraryLoadFlags) err := lib.Open() if err != nil { - return ERROR_LIBRARY_NOT_FOUND + return ERROR_LIBRARY_NOT_FOUND, err } nvml = lib - return nvmlInitWithFlags(Flags) + return nvmlInitWithFlags(Flags), nil } // nvml.Shutdown() -func Shutdown() Return { +func Shutdown() (Return, error) { ret := nvmlShutdown() if ret != SUCCESS { - return ret + return ret, nil } err := nvml.Close() if err != nil { - panic(fmt.Sprintf("error closing %s: %v", nvmlLibraryName, err)) + return ret, err } - return ret + return ret, nil } // Default all versioned APIs to v1 (to infer the types) diff --git a/gen/nvml/nvml_test.go b/gen/nvml/nvml_test.go index 9d8824f..42355fa 100644 --- a/gen/nvml/nvml_test.go +++ b/gen/nvml/nvml_test.go @@ -19,19 +19,25 @@ import ( ) func TestInit(t *testing.T) { - ret := Init() + ret, err := Init() + if err != nil { + t.Errorf("NVML open: %v", err) + } if ret != SUCCESS { t.Errorf("Init: %v", ret) } else { t.Logf("Init: %v", ret) } - ret = Shutdown() + ret, err = Shutdown() if ret != SUCCESS { t.Errorf("Shutdown: %v", ret) } else { t.Logf("Shutdown: %v", ret) } + if err != nil { + t.Errorf("NVML close: %v", err) + } } func TestSystem(t *testing.T) { diff --git a/pkg/nvml/init.go b/pkg/nvml/init.go index 1572f81..c8cb0d0 100644 --- a/pkg/nvml/init.go +++ b/pkg/nvml/init.go @@ -15,8 +15,6 @@ package nvml import ( - "fmt" - "github.com/NVIDIA/go-nvml/pkg/dl" ) @@ -30,45 +28,45 @@ const ( var nvml *dl.DynamicLibrary // nvml.Init() -func Init() Return { +func Init() (Return, error) { lib := dl.New(nvmlLibraryName, nvmlLibraryLoadFlags) err := lib.Open() if err != nil { - return ERROR_LIBRARY_NOT_FOUND + return ERROR_LIBRARY_NOT_FOUND, err } nvml = lib updateVersionedSymbols() - return nvmlInit() + return nvmlInit(), nil } // nvml.InitWithFlags() -func InitWithFlags(Flags uint32) Return { +func InitWithFlags(Flags uint32) (Return, error) { lib := dl.New(nvmlLibraryName, nvmlLibraryLoadFlags) err := lib.Open() if err != nil { - return ERROR_LIBRARY_NOT_FOUND + return ERROR_LIBRARY_NOT_FOUND, err } nvml = lib - return nvmlInitWithFlags(Flags) + return nvmlInitWithFlags(Flags), nil } // nvml.Shutdown() -func Shutdown() Return { +func Shutdown() (Return, error) { ret := nvmlShutdown() if ret != SUCCESS { - return ret + return ret, nil } err := nvml.Close() if err != nil { - panic(fmt.Sprintf("error closing %s: %v", nvmlLibraryName, err)) + return ret, err } - return ret + return ret, nil } // Default all versioned APIs to v1 (to infer the types) diff --git a/pkg/nvml/nvml_test.go b/pkg/nvml/nvml_test.go index 9d8824f..42355fa 100644 --- a/pkg/nvml/nvml_test.go +++ b/pkg/nvml/nvml_test.go @@ -19,19 +19,25 @@ import ( ) func TestInit(t *testing.T) { - ret := Init() + ret, err := Init() + if err != nil { + t.Errorf("NVML open: %v", err) + } if ret != SUCCESS { t.Errorf("Init: %v", ret) } else { t.Logf("Init: %v", ret) } - ret = Shutdown() + ret, err = Shutdown() if ret != SUCCESS { t.Errorf("Shutdown: %v", ret) } else { t.Logf("Shutdown: %v", ret) } + if err != nil { + t.Errorf("NVML close: %v", err) + } } func TestSystem(t *testing.T) {