Get 'The engine PyTorch was not able to initialize' when call TransformersEmbeddingModel's embed method concurrently #836

Closed
nichozhan opened this issue Jun 7, 2024 · 2 comments

Comments

@nichozhan
Contributor

Bug description
I get an error when my program calls TransformersEmbeddingModel's embed() method concurrently to compute many embeddings, even though I call afterPropertiesSet() before the concurrent calls. Here is the stack trace:

java.lang.IllegalStateException: The engine PyTorch was not able to initialize
	at ai.djl.engine.Engine.getEngine(Engine.java:218)
	at ai.djl.engine.Engine.getInstance(Engine.java:149)
	at ai.djl.ndarray.NDManager.newBaseManager(NDManager.java:120)
	at org.springframework.ai.transformers.TransformersEmbeddingModel.call(TransformersEmbeddingModel.java:280)
	at org.springframework.ai.transformers.TransformersEmbeddingModel.embed(TransformersEmbeddingModel.java:232)
	at org.springframework.ai.transformers.TransformersEmbeddingModel.embed(TransformersEmbeddingModel.java:212)
	at org.springframework.ai.transformers.TransformersEmbeddingModel.embed(TransformersEmbeddingModel.java:217)
	at org.springframework.ai.transformers.TransformersEmbeddingModelTests.lambda$parallelEmbedDocument$0(TransformersEmbeddingModelTests.java:71)
	at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1136)
	at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:635)
	at java.base/java.lang.Thread.run(Thread.java:840)
java.lang.IllegalStateException: The engine PyTorch was not able to initialize
	at ai.djl.engine.Engine.getEngine(Engine.java:218)
	at ai.djl.engine.Engine.getInstance(Engine.java:149)
	at ai.djl.ndarray.NDManager.newBaseManager(NDManager.java:120)
	at org.springframework.ai.transformers.TransformersEmbeddingModel.call(TransformersEmbeddingModel.java:280)
	at org.springframework.ai.transformers.TransformersEmbeddingModel.embed(TransformersEmbeddingModel.java:232)
	at org.springframework.ai.transformers.TransformersEmbeddingModel.embed(TransformersEmbeddingModel.java:212)
	at org.springframework.ai.transformers.TransformersEmbeddingModel.embed(TransformersEmbeddingModel.java:217)
	at org.springframework.ai.transformers.TransformersEmbeddingModelTests.lambda$parallelEmbedDocument$0(TransformersEmbeddingModelTests.java:71)
	at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1136)
	at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:635)
	at java.base/java.lang.Thread.run(Thread.java:840)
...

I did some troubleshooting, and my conclusion is that DJL 0.26.0, which spring-ai uses, has a known issue that matches what I encountered. That issue is fixed in DJL 0.27.0. See: deepjavalibrary/djl#3005.

I verified with DJL 0.28.0 that the errors no longer occur, so I suggest we upgrade DJL to 0.28.0.

Environment
Spring AI version: 0.8.1, v1.0.0-M1
Java: 17
Vector store: none

Steps to reproduce
You can use the test below, which I adapted from TransformersEmbeddingModelTests, to reproduce the error:

@Test
void parallelEmbedDocument() throws InterruptedException {
	TransformersEmbeddingModel embeddingModel = new TransformersEmbeddingModel();
	// Initialize the model up front, before any concurrent call is made.
	try {
		embeddingModel.afterPropertiesSet();
	} catch (Exception e) {
		throw new RuntimeException(e);
	}
	// Call embed() from 10 threads at the same time.
	ExecutorService executorService = Executors.newFixedThreadPool(10);
	for (int i = 0; i < 10; i++) {
		executorService.execute(() -> {
			try {
				List<Double> embed = embeddingModel.embed(new Document("Hello world"));
				assertThat(embed).hasSize(384);
				assertThat(DF.format(embed.get(0))).isEqualTo(DF.format(-0.19744634628295898));
				assertThat(DF.format(embed.get(383))).isEqualTo(DF.format(0.17298996448516846));
			} catch (Exception e) {
				// The IllegalStateException from the stack trace above is printed here.
				e.printStackTrace();
			}
		});
	}
	executorService.shutdown();
	executorService.awaitTermination(30, TimeUnit.SECONDS);
}
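
Note that the test above only prints exceptions raised on the worker threads, so JUnit would still report it as passing. The following is a minimal sketch of one way to hand worker failures back to the calling thread, using only the JDK. The class name ConcurrentFailureCollector and the method runConcurrently are hypothetical names made up for this illustration; they are not part of Spring AI or DJL, and the lambda in main merely stands in for the embeddingModel.embed(...) call.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;

// Hypothetical helper for illustration only; not part of Spring AI or DJL.
public class ConcurrentFailureCollector {

    // Submits the same task once per thread and returns every exception
    // that a worker threw, so the caller can fail the test explicitly.
    static <T> List<Throwable> runConcurrently(int threads, Callable<T> task)
            throws InterruptedException {
        ExecutorService pool = Executors.newFixedThreadPool(threads);
        List<Future<T>> futures = new ArrayList<>();
        for (int i = 0; i < threads; i++) {
            futures.add(pool.submit(task));
        }
        pool.shutdown();

        List<Throwable> failures = new ArrayList<>();
        if (!pool.awaitTermination(30, TimeUnit.SECONDS)) {
            // Workers did not finish in time: interrupt them and report it.
            pool.shutdownNow();
            failures.add(new IllegalStateException("Timed out waiting for workers"));
            return failures;
        }
        for (Future<T> future : futures) {
            try {
                future.get();
            } catch (ExecutionException e) {
                // getCause() is the exception thrown inside the worker.
                failures.add(e.getCause());
            }
        }
        return failures;
    }

    public static void main(String[] args) throws InterruptedException {
        // Stand-in task that always fails, mimicking the error in this issue.
        List<Throwable> failures = runConcurrently(10, () -> {
            throw new IllegalStateException("The engine PyTorch was not able to initialize");
        });
        System.out.println("Failed workers: " + failures.size());
    }
}
```

In a test, asserting that the returned list is empty would make the failure visible to JUnit instead of leaving it only in the console output.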

Expected behavior
No errors about "The engine PyTorch was not able to initialize".

Minimal Complete Reproducible example
See my comment above.

@ThomasVitale
Contributor

@nichozhan Did #837 fix this issue?

@nichozhan
Contributor Author

Hi @ThomasVitale , yes, I think so.
