diff --git a/packages/common/infra/package.json b/packages/common/infra/package.json index 093f2cf22f149..03486b0badcc2 100644 --- a/packages/common/infra/package.json +++ b/packages/common/infra/package.json @@ -7,6 +7,7 @@ "./command": "./src/command/index.ts", "./atom": "./src/atom/index.ts", "./app-config-storage": "./src/app-config-storage.ts", + "./di": "./src/di/index.ts", "./livedata": "./src/livedata/index.ts", ".": "./src/index.ts" }, diff --git a/packages/common/infra/src/di/__tests__/di.spec.ts b/packages/common/infra/src/di/__tests__/di.spec.ts new file mode 100644 index 0000000000000..6828d5ba1c472 --- /dev/null +++ b/packages/common/infra/src/di/__tests__/di.spec.ts @@ -0,0 +1,357 @@ +import { describe, expect, test } from 'vitest'; + +import { + CircularDependencyError, + createIdentifier, + createScope, + DuplicateServiceDefinitionError, + MissingDependencyError, + RecursionLimitError, + ServiceCollection, + ServiceNotFoundError, + ServiceProvider, +} from '../'; + +describe('di', () => { + test('basic', () => { + const serviceCollection = new ServiceCollection(); + class TestService { + a = 'b'; + } + + serviceCollection.add(TestService); + + const provider = serviceCollection.provider(); + expect(provider.get(TestService)).toEqual({ a: 'b' }); + }); + + test('size', () => { + const serviceCollection = new ServiceCollection(); + class TestService { + a = 'b'; + } + + serviceCollection.add(TestService); + + expect(serviceCollection.size).toEqual(1); + }); + + test('dependency', () => { + const serviceCollection = new ServiceCollection(); + + class A { + value = 'hello world'; + } + + class B { + constructor(public a: A) {} + } + + class C { + constructor(public b: B) {} + } + + serviceCollection.add(A).add(B, [A]).add(C, [B]); + + const provider = serviceCollection.provider(); + + expect(provider.get(C).b.a.value).toEqual('hello world'); + }); + + test('identifier', () => { + interface Animal { + name: string; + } + const Animal = createIdentifier('Animal'); + + class Cat { + constructor() {} + name = 'cat'; + } + + class Zoo { + constructor(public animal: Animal) {} + } + + const serviceCollection = new ServiceCollection(); + serviceCollection.addImpl(Animal, Cat).add(Zoo, [Animal]); + + const provider = serviceCollection.provider(); + expect(provider.get(Zoo).animal.name).toEqual('cat'); + }); + + test('variant', () => { + const serviceCollection = new ServiceCollection(); + + interface USB { + speed: number; + } + + const USB = createIdentifier('USB'); + + class TypeA implements USB { + speed = 100; + } + class TypeC implements USB { + speed = 300; + } + + class PC { + constructor( + public typeA: USB, + public ports: USB[] + ) {} + } + + serviceCollection + .addImpl(USB('A'), TypeA) + .addImpl(USB('C'), TypeC) + .add(PC, [USB('A'), [USB]]); + + const provider = serviceCollection.provider(); + expect(provider.get(USB('A')).speed).toEqual(100); + expect(provider.get(USB('C')).speed).toEqual(300); + expect(provider.get(PC).typeA.speed).toEqual(100); + expect(provider.get(PC).ports.length).toEqual(2); + }); + + test('lazy initialization', () => { + const serviceCollection = new ServiceCollection(); + interface Command { + shortcut: string; + callback: () => void; + } + const Command = createIdentifier('command'); + + let pageSystemInitialized = false; + + class PageSystem { + mode = 'page'; + name = 'helloworld'; + + constructor() { + pageSystemInitialized = true; + } + + switchToEdgeless() { + this.mode = 'edgeless'; + } + + rename() { + this.name = 'foobar'; + } + } + + class CommandSystem { + constructor(public commands: Command[]) {} + + execute(shortcut: string) { + const command = this.commands.find(c => c.shortcut === shortcut); + if (command) { + command.callback(); + } + } + } + + serviceCollection.add(PageSystem); + serviceCollection.add(CommandSystem, [[Command]]); + serviceCollection.addImpl(Command('switch'), p => ({ + shortcut: 'option+s', + callback: () => p.get(PageSystem).switchToEdgeless(), + })); + serviceCollection.addImpl(Command('rename'), p => ({ + shortcut: 'f2', + callback: () => p.get(PageSystem).rename(), + })); + + const provider = serviceCollection.provider(); + const commandSystem = provider.get(CommandSystem); + + expect( + pageSystemInitialized, + "PageSystem won't be initialized until command executed" + ).toEqual(false); + + commandSystem.execute('option+s'); + expect(pageSystemInitialized).toEqual(true); + expect(provider.get(PageSystem).mode).toEqual('edgeless'); + + expect(provider.get(PageSystem).name).toEqual('helloworld'); + expect(commandSystem.commands.length).toEqual(2); + commandSystem.execute('f2'); + expect(provider.get(PageSystem).name).toEqual('foobar'); + }); + + test('duplicate, override', () => { + const serviceCollection = new ServiceCollection(); + + const something = createIdentifier('USB'); + + class A { + a = 'i am A'; + } + + class B { + b = 'i am B'; + } + + serviceCollection.addImpl(something, A).override(something, B); + + const provider = serviceCollection.provider(); + expect(provider.get(something)).toEqual({ b: 'i am B' }); + }); + + test('scope', () => { + const services = new ServiceCollection(); + + const workspaceScope = createScope('workspace'); + const pageScope = createScope('page', workspaceScope); + const editorScope = createScope('editor', pageScope); + + class System { + appName = 'affine'; + } + + services.add(System); + + class Workspace { + name = 'workspace'; + constructor(public system: System) {} + } + + services.scope(workspaceScope).add(Workspace, [System]); + class Page { + name = 'page'; + constructor( + public system: System, + public workspace: Workspace + ) {} + } + + services.scope(pageScope).add(Page, [System, Workspace]); + + class Editor { + name = 'editor'; + constructor(public page: Page) {} + } + + services.scope(editorScope).add(Editor, [Page]); + + const root = services.provider(); + expect(root.get(System).appName).toEqual('affine'); + expect(() => root.get(Workspace)).toThrowError(ServiceNotFoundError); + + const workspace = services.provider(workspaceScope, root); + expect(workspace.get(Workspace).name).toEqual('workspace'); + expect(workspace.get(System).appName).toEqual('affine'); + expect(() => root.get(Page)).toThrowError(ServiceNotFoundError); + + const page = services.provider(pageScope, workspace); + expect(page.get(Page).name).toEqual('page'); + expect(page.get(Workspace).name).toEqual('workspace'); + expect(page.get(System).appName).toEqual('affine'); + + const editor = services.provider(editorScope, page); + expect(editor.get(Editor).name).toEqual('editor'); + }); + + test('service not found', () => { + const serviceCollection = new ServiceCollection(); + + const provider = serviceCollection.provider(); + expect(() => provider.get(createIdentifier('SomeService'))).toThrowError( + ServiceNotFoundError + ); + }); + + test('missing dependency', () => { + const serviceCollection = new ServiceCollection(); + + class A { + value = 'hello world'; + } + + class B { + constructor(public a: A) {} + } + + serviceCollection.add(B, [A]); + + const provider = serviceCollection.provider(); + expect(() => provider.get(B)).toThrowError(MissingDependencyError); + }); + + test('circular dependency', () => { + const serviceCollection = new ServiceCollection(); + + class A { + constructor(public c: C) {} + } + + class B { + constructor(public a: A) {} + } + + class C { + constructor(public b: B) {} + } + + serviceCollection.add(A, [C]).add(B, [A]).add(C, [B]); + + const provider = serviceCollection.provider(); + expect(() => provider.get(A)).toThrowError(CircularDependencyError); + expect(() => provider.get(B)).toThrowError(CircularDependencyError); + expect(() => provider.get(C)).toThrowError(CircularDependencyError); + }); + + test('duplicate service definition', () => { + const serviceCollection = new ServiceCollection(); + + class A {} + + serviceCollection.add(A); + expect(() => serviceCollection.add(A)).toThrowError( + DuplicateServiceDefinitionError + ); + + class B {} + const Something = createIdentifier('something'); + serviceCollection.addImpl(Something, A); + expect(() => serviceCollection.addImpl(Something, B)).toThrowError( + DuplicateServiceDefinitionError + ); + }); + + test('recursion limit', () => { + // maxmium resolve depth is 100 + const serviceCollection = new ServiceCollection(); + const Something = createIdentifier('something'); + let i = 0; + for (; i < 100; i++) { + const next = i + 1; + + class Test { + constructor(_next: any) {} + } + + serviceCollection.addImpl(Something(i.toString()), Test, [ + Something(next.toString()), + ]); + } + + class Final { + a = 'b'; + } + serviceCollection.addImpl(Something(i.toString()), Final); + const provider = serviceCollection.provider(); + expect(() => provider.get(Something('0'))).toThrowError( + RecursionLimitError + ); + }); + + test('self resolve', () => { + const serviceCollection = new ServiceCollection(); + const provider = serviceCollection.provider(); + expect(provider.get(ServiceProvider)).toEqual(provider); + }); +}); diff --git a/packages/common/infra/src/di/core/collection.ts b/packages/common/infra/src/di/core/collection.ts new file mode 100644 index 0000000000000..5fbd5b6840049 --- /dev/null +++ b/packages/common/infra/src/di/core/collection.ts @@ -0,0 +1,459 @@ +import { DEFAULT_SERVICE_VARIANT, ROOT_SCOPE } from './consts'; +import { DuplicateServiceDefinitionError } from './error'; +import { parseIdentifier } from './identifier'; +import type { ServiceProvider } from './provider'; +import { BasicServiceProvider } from './provider'; +import { stringifyScope } from './scope'; +import type { + GeneralServiceIdentifier, + ServiceFactory, + ServiceIdentifier, + ServiceIdentifierType, + ServiceIdentifierValue, + ServiceScope, + ServiceVariant, + Type, + TypesToDeps, +} from './types'; + +/** + * A collection of services. + * + * ServiceCollection basically is a tuple of `[scope, identifier, variant, factory]` with some helper methods. + * It just stores the definitions of services. It never holds any instances of services. + * + * # Usage + * + * ```ts + * const services = new ServiceCollection(); + * class ServiceA { + * // ... + * } + * // add a service + * services.add(ServiceA); + * + * class ServiceB { + * constructor(serviceA: ServiceA) {} + * } + * // add a service with dependency + * services.add(ServiceB, [ServiceA]); + * ^ dependency class/identifier, match ServiceB's constructor + * + * const FeatureA = createIdentifier('Config'); + * + * // add a implementation for a service identifier + * services.addImpl(FeatureA, ServiceA); + * + * // override a service + * services.override(ServiceA, NewServiceA); + * + * // create a service provider + * const provider = services.provider(); + * ``` + * + * # The data structure + * + * The data structure of ServiceCollection is a three-layer nested Map, used to represent the tuple of + * `[scope, identifier, variant, factory]`. + * Such a data structure ensures that a service factory can be uniquely determined by `[scope, identifier, variant]`. + * + * When a service added: + * + * ```ts + * services.add(ServiceClass) + * ``` + * + * The data structure will be: + * + * ```ts + * Map { + * '': Map { // scope + * 'ServiceClass': Map { // identifier + * 'default': // variant + * () => new ServiceClass() // factory + * } + * } + * ``` + * + * # Dependency relationship + * + * The dependency relationships of services are not actually stored in the ServiceCollection, + * but are transformed into a factory function when the service is added. + * + * For example: + * + * ```ts + * services.add(ServiceB, [ServiceA]); + * + * // is equivalent to + * services.addFactory(ServiceB, (provider) => new ServiceB(provider.get(ServiceA))); + * ``` + * + * For multiple implementations of the same service identifier, can be defined as: + * + * ```ts + * services.add(ServiceB, [[FeatureA]]); + * + * // is equivalent to + * services.addFactory(ServiceB, (provider) => new ServiceB(provider.getAll(FeatureA))); + * ``` + */ +export class ServiceCollection { + private readonly services: Map< + string, + Map> + > = new Map(); + + /** + * Create an empty service collection. + * + * same as `new ServiceCollection()` + */ + static get EMPTY() { + return new ServiceCollection(); + } + + /** + * The number of services in the collection. + */ + get size() { + let size = 0; + for (const [, identifiers] of this.services) { + for (const [, variants] of identifiers) { + size += variants.size; + } + } + return size; + } + + /** + * @see {@link ServiceCollectionEditor.add} + */ + get add() { + return new ServiceCollectionEditor(this).add; + } + + /** + * @see {@link ServiceCollectionEditor.addImpl} + */ + get addImpl() { + return new ServiceCollectionEditor(this).addImpl; + } + + /** + * @see {@link ServiceCollectionEditor.scope} + */ + get scope() { + return new ServiceCollectionEditor(this).scope; + } + + /** + * @see {@link ServiceCollectionEditor.scope} + */ + get override() { + return new ServiceCollectionEditor(this).override; + } + + /** + * @internal Use {@link addImpl} instead. + */ + addValue( + identifier: GeneralServiceIdentifier, + value: T, + { scope, override }: { scope?: ServiceScope; override?: boolean } = {} + ) { + this.addFactory( + parseIdentifier(identifier) as ServiceIdentifier, + () => value, + { + scope, + override, + } + ); + } + + /** + * @internal Use {@link addImpl} instead. + */ + addFactory( + identifier: GeneralServiceIdentifier, + factory: ServiceFactory, + { scope, override }: { scope?: ServiceScope; override?: boolean } = {} + ) { + // convert scope to string + const normalizedScope = stringifyScope(scope ?? ROOT_SCOPE); + const normalizedIdentifier = parseIdentifier(identifier); + const normalizedVariant = + normalizedIdentifier.variant ?? DEFAULT_SERVICE_VARIANT; + + const services = + this.services.get(normalizedScope) ?? + new Map>(); + + const variants = + services.get(normalizedIdentifier.identifierName) ?? + new Map(); + + // throw if service already exists, unless it is an override + if (variants.has(normalizedVariant) && !override) { + throw new DuplicateServiceDefinitionError(normalizedIdentifier); + } + variants.set(normalizedVariant, factory); + services.set(normalizedIdentifier.identifierName, variants); + this.services.set(normalizedScope, services); + } + + /** + * Create a service provider from the collection. + * + * @example + * ```ts + * provider() // create a service provider for root scope + * provider(ScopeA, parentProvider) // create a service provider for scope A + * ``` + * + * @param scope The scope of the service provider, default to the root scope. + * @param parent The parent service provider, it is required if the scope is not the root scope. + */ + provider( + scope: ServiceScope = ROOT_SCOPE, + parent: ServiceProvider | null = null + ): ServiceProvider { + return new BasicServiceProvider(this, scope, parent); + } + + /** + * @internal + */ + getFactory( + identifier: ServiceIdentifierValue, + scope: ServiceScope = ROOT_SCOPE + ): ServiceFactory | undefined { + return this.services + .get(stringifyScope(scope)) + ?.get(identifier.identifierName) + ?.get(identifier.variant ?? DEFAULT_SERVICE_VARIANT); + } + + /** + * @internal + */ + getFactoryAll( + identifier: ServiceIdentifierValue, + scope: ServiceScope = ROOT_SCOPE + ): Map { + return new Map( + this.services.get(stringifyScope(scope))?.get(identifier.identifierName) + ); + } + + /** + * Clone the entire service collection. + * + * This method is quite cheap as it only clones the references. + * + * @returns A new service collection with the same services. + */ + clone(): ServiceCollection { + const di = new ServiceCollection(); + for (const [scope, identifiers] of this.services) { + const s = new Map(); + for (const [identifier, variants] of identifiers) { + s.set(identifier, new Map(variants)); + } + di.services.set(scope, s); + } + return di; + } +} + +/** + * A helper class to edit a service collection. + */ +class ServiceCollectionEditor { + private currentScope: ServiceScope = ROOT_SCOPE; + + constructor(private readonly collection: ServiceCollection) {} + + /** + * Add a service to the collection. + * + * @see {@link ServiceCollection} + * + * @example + * ```ts + * add(ServiceClass, [dependencies, ...]) + * ``` + */ + add = < + T extends new (...args: any) => any, + const Deps extends TypesToDeps> = TypesToDeps< + ConstructorParameters + >, + >( + cls: T, + ...[deps]: Deps extends [] ? [] : [Deps] + ): this => { + this.collection.addFactory( + cls as any, + dependenciesToFactory(cls, deps as any), + { scope: this.currentScope } + ); + + return this; + }; + + /** + * Add an implementation for identifier to the collection. + * + * @see {@link ServiceCollection} + * + * @example + * ```ts + * addImpl(ServiceIdentifier, ServiceClass, [dependencies, ...]) + * or + * addImpl(ServiceIdentifier, Instance) + * or + * addImpl(ServiceIdentifier, Factory) + * ``` + */ + addImpl = < + Arg1 extends ServiceIdentifier, + Arg2 extends Type | ServiceFactory | Trait, + Trait = ServiceIdentifierType, + Deps extends Arg2 extends Type + ? TypesToDeps> + : [] = Arg2 extends Type + ? TypesToDeps> + : [], + Arg3 extends Deps = Deps, + >( + identifier: Arg1, + arg2: Arg2, + ...[arg3]: Arg3 extends [] ? [] : [Arg3] + ): this => { + if (arg2 instanceof Function) { + this.collection.addFactory( + identifier, + dependenciesToFactory(arg2, arg3 as any[]), + { scope: this.currentScope } + ); + } else { + this.collection.addValue(identifier, arg2 as any, { + scope: this.currentScope, + }); + } + + return this; + }; + + /** + * same as {@link addImpl} but this method will override the service if it exists. + * + * @see {@link ServiceCollection} + * + * @example + * ```ts + * override(OriginServiceClass, NewServiceClass, [dependencies, ...]) + * or + * override(ServiceIdentifier, ServiceClass, [dependencies, ...]) + * or + * override(ServiceIdentifier, Instance) + * or + * override(ServiceIdentifier, Factory) + * ``` + */ + override = < + Arg1 extends ServiceIdentifier, + Arg2 extends Type | ServiceFactory | Trait, + Trait = ServiceIdentifierType, + Deps extends Arg2 extends Type + ? TypesToDeps> + : [] = Arg2 extends Type + ? TypesToDeps> + : [], + Arg3 extends Deps = Deps, + >( + identifier: Arg1, + arg2: Arg2, + ...[arg3]: Arg3 extends [] ? [] : [Arg3] + ): this => { + if (arg2 instanceof Function) { + this.collection.addFactory( + identifier, + dependenciesToFactory(arg2, arg3 as any[]), + { scope: this.currentScope, override: true } + ); + } else { + this.collection.addValue(identifier, arg2 as any, { + scope: this.currentScope, + override: true, + }); + } + + return this; + }; + + /** + * Set the scope for the service registered subsequently + * + * @example + * + * ```ts + * const ScopeA = createScope('a'); + * + * services.scope(ScopeA).add(XXXService, ...); + * ``` + */ + scope = (scope: ServiceScope): ServiceCollectionEditor => { + this.currentScope = scope; + return this; + }; +} + +/** + * Convert dependencies definition to a factory function. + */ +function dependenciesToFactory( + cls: any, + deps: any[] = [] +): ServiceFactory { + return (provider: ServiceProvider) => { + const args = []; + for (const dep of deps) { + let isAll; + let identifier; + if (Array.isArray(dep)) { + if (dep.length !== 1) { + throw new Error('Invalid dependency'); + } + isAll = true; + identifier = dep[0]; + } else { + isAll = false; + identifier = dep; + } + if (isAll) { + args.push(Array.from(provider.getAll(identifier).values())); + } else { + args.push(provider.get(identifier)); + } + } + if (isConstructor(cls)) { + return new cls(...args, provider); + } else { + return cls(...args, provider); + } + }; +} + +// a hack to check if a function is a constructor +// https://github.com/zloirock/core-js/blob/232c8462c26c75864b4397b7f643a4f57c6981d5/packages/core-js/internals/is-constructor.js#L15 +function isConstructor(cls: any) { + try { + Reflect.construct(function () {}, [], cls); + return true; + } catch (error) { + return false; + } +} diff --git a/packages/common/infra/src/di/core/consts.ts b/packages/common/infra/src/di/core/consts.ts new file mode 100644 index 0000000000000..dc43ed89530a7 --- /dev/null +++ b/packages/common/infra/src/di/core/consts.ts @@ -0,0 +1,4 @@ +import type { ServiceVariant } from './types'; + +export const DEFAULT_SERVICE_VARIANT: ServiceVariant = 'default'; +export const ROOT_SCOPE = []; diff --git a/packages/common/infra/src/di/core/error.ts b/packages/common/infra/src/di/core/error.ts new file mode 100644 index 0000000000000..90fab9c35c71d --- /dev/null +++ b/packages/common/infra/src/di/core/error.ts @@ -0,0 +1,59 @@ +import { DEFAULT_SERVICE_VARIANT } from './consts'; +import type { ServiceIdentifierValue } from './types'; + +export class RecursionLimitError extends Error { + constructor() { + super('Dynamic resolve recursion limit reached'); + } +} + +export class CircularDependencyError extends Error { + constructor(public readonly dependencyStack: ServiceIdentifierValue[]) { + super( + `A circular dependency was detected.\n` + + stringifyDependencyStack(dependencyStack) + ); + } +} + +export class ServiceNotFoundError extends Error { + constructor(public readonly identifier: ServiceIdentifierValue) { + super(`Service ${stringifyIdentifier(identifier)} not found in container`); + } +} + +export class MissingDependencyError extends Error { + constructor( + public readonly from: ServiceIdentifierValue, + public readonly target: ServiceIdentifierValue, + public readonly dependencyStack: ServiceIdentifierValue[] + ) { + super( + `Missing dependency ${stringifyIdentifier( + target + )} in creating service ${stringifyIdentifier( + from + )}.\n${stringifyDependencyStack(dependencyStack)}` + ); + } +} + +export class DuplicateServiceDefinitionError extends Error { + constructor(public readonly identifier: ServiceIdentifierValue) { + super(`Service ${stringifyIdentifier(identifier)} already exists`); + } +} + +function stringifyIdentifier(identifier: ServiceIdentifierValue) { + return `[${identifier.identifierName}]${ + identifier.variant !== DEFAULT_SERVICE_VARIANT + ? `(${identifier.variant})` + : '' + }`; +} + +function stringifyDependencyStack(dependencyStack: ServiceIdentifierValue[]) { + return dependencyStack + .map(identifier => `${stringifyIdentifier(identifier)}`) + .join(' -> '); +} diff --git a/packages/common/infra/src/di/core/identifier.ts b/packages/common/infra/src/di/core/identifier.ts new file mode 100644 index 0000000000000..5812207e2d57e --- /dev/null +++ b/packages/common/infra/src/di/core/identifier.ts @@ -0,0 +1,113 @@ +import { stableHash } from '../../utils/stable-hash'; +import { DEFAULT_SERVICE_VARIANT } from './consts'; +import type { + ServiceIdentifier, + ServiceIdentifierValue, + ServiceVariant, + Type, +} from './types'; + +/** + * create a ServiceIdentifier. + * + * ServiceIdentifier is used to identify a certain type of service. With the identifier, you can reference one or more services + * without knowing the specific implementation, thereby achieving + * [inversion of control](https://en.wikipedia.org/wiki/Inversion_of_control). + * + * @example + * ```ts + * // define a interface + * interface Storage { + * get(key: string): string | null; + * set(key: string, value: string): void; + * } + * + * // create a identifier + * // NOTICE: Highly recommend to use the interface name as the identifier name, + * // so that it is easy to understand. and it is legal to do so in TypeScript. + * const Storage = createIdentifier('Storage'); + * + * // create a implementation + * class LocalStorage implements Storage { + * get(key: string): string | null { + * return localStorage.getItem(key); + * } + * set(key: string, value: string): void { + * localStorage.setItem(key, value); + * } + * } + * + * // register the implementation to the identifier + * services.addImpl(Storage, LocalStorage); + * + * // get the implementation from the identifier + * const storage = services.provider().get(Storage); + * storage.set('foo', 'bar'); + * ``` + * + * With identifier: + * + * * You can easily replace the implementation of a `Storage` without changing the code that uses it. + * * You can easily mock a `Storage` for testing. + * + * # Variant + * + * Sometimes, you may want to register multiple implementations for the same interface. + * For example, you may want have both `LocalStorage` and `SessionStorage` for `Storage`, + * and use them in same time. + * + * In this case, you can use `variant` to distinguish them. + * + * ```ts + * const Storage = createIdentifier('Storage'); + * const LocalStorage = Storage('local'); + * const SessionStorage = Storage('session'); + * + * services.addImpl(LocalStorage, LocalStorageImpl); + * services.addImpl(SessionStorage, SessionStorageImpl); + * + * // get the implementation from the identifier + * const localStorage = services.provider().get(LocalStorage); + * const sessionStorage = services.provider().get(SessionStorage); + * const storage = services.provider().getAll(Storage); // { local: LocalStorageImpl, session: SessionStorageImpl } + * ``` + * + * @param name unique name of the identifier. + * @param variant The default variant name of the identifier, can be overridden by `identifier("variant")`. + */ +export function createIdentifier( + name: string, + variant: ServiceVariant = DEFAULT_SERVICE_VARIANT +): ServiceIdentifier & ((variant: ServiceVariant) => ServiceIdentifier) { + return Object.assign( + (variant: ServiceVariant) => { + return createIdentifier(name, variant); + }, + { + identifierName: name, + variant, + } + ) as any; +} + +/** + * Convert the constructor into a ServiceIdentifier. + * As we always deal with ServiceIdentifier in the DI container. + * + * @internal + */ +export function createIdentifierFromConstructor( + target: Type +): ServiceIdentifier { + return createIdentifier(`${target.name}${stableHash(target)}`); +} + +export function parseIdentifier(input: any): ServiceIdentifierValue { + if (input.identifierName) { + return input as ServiceIdentifierValue; + } else if (typeof input === 'function' && input.name) { + return createIdentifierFromConstructor(input); + } else { + throw new Error('Input is not a service identifier.'); + } +} diff --git a/packages/common/infra/src/di/core/index.ts b/packages/common/infra/src/di/core/index.ts new file mode 100644 index 0000000000000..f86d0240c01cd --- /dev/null +++ b/packages/common/infra/src/di/core/index.ts @@ -0,0 +1,7 @@ +export * from './collection'; +export * from './consts'; +export * from './error'; +export * from './identifier'; +export * from './provider'; +export * from './scope'; +export * from './types'; diff --git a/packages/common/infra/src/di/core/provider.ts b/packages/common/infra/src/di/core/provider.ts new file mode 100644 index 0000000000000..5026e53259599 --- /dev/null +++ b/packages/common/infra/src/di/core/provider.ts @@ -0,0 +1,216 @@ +import type { ServiceCollection } from './collection'; +import { + CircularDependencyError, + MissingDependencyError, + RecursionLimitError, + ServiceNotFoundError, +} from './error'; +import { parseIdentifier } from './identifier'; +import { + type GeneralServiceIdentifier, + type ServiceIdentifierValue, + type ServiceVariant, +} from './types'; + +export interface ResolveOptions { + sameScope?: boolean; + optional?: boolean; +} + +export abstract class ServiceProvider { + abstract collection: ServiceCollection; + abstract getRaw( + identifier: ServiceIdentifierValue, + options?: ResolveOptions + ): any; + abstract getAllRaw( + identifier: ServiceIdentifierValue, + options?: ResolveOptions + ): Map; + + get(identifier: GeneralServiceIdentifier, options?: ResolveOptions): T { + return this.getRaw(parseIdentifier(identifier), { + ...options, + optional: false, + }); + } + + getAll( + identifier: GeneralServiceIdentifier, + options?: ResolveOptions + ): Map { + return this.getAllRaw(parseIdentifier(identifier), { + ...options, + }); + } + + getOptional( + identifier: GeneralServiceIdentifier, + options?: ResolveOptions + ): T | null { + return this.getRaw(parseIdentifier(identifier), { + ...options, + optional: true, + }); + } +} + +export class ServiceCachePool { + cache: Map> = new Map(); + + getOrInsert(identifier: ServiceIdentifierValue, insert: () => any) { + const cache = this.cache.get(identifier.identifierName) ?? new Map(); + if (!cache.has(identifier.variant)) { + cache.set(identifier.variant, insert()); + } + const cached = cache.get(identifier.variant); + this.cache.set(identifier.identifierName, cache); + return cached; + } +} + +export class ServiceResolver extends ServiceProvider { + constructor( + public readonly provider: BasicServiceProvider, + public readonly depth = 0, + public readonly stack: ServiceIdentifierValue[] = [] + ) { + super(); + } + + collection = this.provider.collection; + + getRaw( + identifier: ServiceIdentifierValue, + { sameScope = false, optional = false }: ResolveOptions = {} + ) { + const factory = this.provider.collection.getFactory( + identifier, + this.provider.scope + ); + if (!factory) { + if (this.provider.parent && !sameScope) { + return this.provider.parent.getRaw(identifier, { + sameScope, + optional, + }); + } + + if (optional) { + return undefined; + } + throw new ServiceNotFoundError(identifier); + } + + return this.provider.cache.getOrInsert(identifier, () => { + const nextResolver = this.track(identifier); + try { + return factory(nextResolver); + } catch (err) { + if (err instanceof ServiceNotFoundError) { + throw new MissingDependencyError( + identifier, + err.identifier, + this.stack + ); + } + throw err; + } + }); + } + + getAllRaw( + identifier: ServiceIdentifierValue, + { sameScope = false }: ResolveOptions = {} + ): Map { + const vars = this.provider.collection.getFactoryAll( + identifier, + this.provider.scope + ); + + if (vars === undefined) { + if (this.provider.parent && !sameScope) { + return this.provider.parent.getAllRaw(identifier); + } + + return new Map(); + } + + const result = new Map(); + + for (const [variant, factory] of vars) { + const service = this.provider.cache.getOrInsert( + { identifierName: identifier.identifierName, variant }, + () => { + const nextResolver = this.track(identifier); + try { + return factory(nextResolver); + } catch (err) { + if (err instanceof ServiceNotFoundError) { + throw new MissingDependencyError( + identifier, + err.identifier, + this.stack + ); + } + throw err; + } + } + ); + result.set(variant, service); + } + + return result; + } + + track(identifier: ServiceIdentifierValue): ServiceResolver { + const depth = this.depth + 1; + if (depth >= 100) { + throw new RecursionLimitError(); + } + const circular = this.stack.find( + i => + i.identifierName === identifier.identifierName && + i.variant === identifier.variant + ); + if (circular) { + throw new CircularDependencyError([...this.stack, identifier]); + } + + return new ServiceResolver(this.provider, depth, [ + ...this.stack, + identifier, + ]); + } +} + +export class BasicServiceProvider extends ServiceProvider { + public readonly cache = new ServiceCachePool(); + public readonly collection: ServiceCollection; + + constructor( + collection: ServiceCollection, + public readonly scope: string[], + public readonly parent: ServiceProvider | null + ) { + super(); + this.collection = collection.clone(); + this.collection.addValue(ServiceProvider, this, { + scope: scope, + override: true, + }); + } + + getRaw(identifier: ServiceIdentifierValue, options?: ResolveOptions) { + const resolver = new ServiceResolver(this); + return resolver.getRaw(identifier, options); + } + + getAllRaw( + identifier: ServiceIdentifierValue, + options?: ResolveOptions + ): Map { + const resolver = new ServiceResolver(this); + return resolver.getAllRaw(identifier, options); + } +} diff --git a/packages/common/infra/src/di/core/scope.ts b/packages/common/infra/src/di/core/scope.ts new file mode 100644 index 0000000000000..190bbd7d8d864 --- /dev/null +++ b/packages/common/infra/src/di/core/scope.ts @@ -0,0 +1,13 @@ +import { ROOT_SCOPE } from './consts'; +import type { ServiceScope } from './types'; + +export function createScope( + name: string, + base: ServiceScope = ROOT_SCOPE +): ServiceScope { + return [...base, name]; +} + +export function stringifyScope(scope: ServiceScope): string { + return scope.join('/'); +} diff --git a/packages/common/infra/src/di/core/types.ts b/packages/common/infra/src/di/core/types.ts new file mode 100644 index 0000000000000..e81e7ef2ee4f8 --- /dev/null +++ b/packages/common/infra/src/di/core/types.ts @@ -0,0 +1,37 @@ +import type { ServiceProvider } from './provider'; + +// eslint-disable-next-line @typescript-eslint/ban-types +export type Type = abstract new (...args: any) => T; + +export type ServiceFactory = (provider: ServiceProvider) => T; +export type ServiceVariant = string; + +/** + * + */ +export type ServiceScope = string[]; + +export type ServiceIdentifierValue = { + identifierName: string; + variant: ServiceVariant; +}; + +export type GeneralServiceIdentifier = ServiceIdentifier | Type; + +export type ServiceIdentifier = { + identifierName: string; + variant: ServiceVariant; + __TYPE__: T; +}; + +export type ServiceIdentifierType = T extends ServiceIdentifier + ? R + : T extends Type + ? R + : never; + +export type TypesToDeps = { + [index in keyof T]: + | GeneralServiceIdentifier + | (T[index] extends (infer I)[] ? [GeneralServiceIdentifier] : never); +}; diff --git a/packages/common/infra/src/di/index.ts b/packages/common/infra/src/di/index.ts new file mode 100644 index 0000000000000..78a29b8b5a3fa --- /dev/null +++ b/packages/common/infra/src/di/index.ts @@ -0,0 +1,2 @@ +export * from './core'; +export * from './react'; diff --git a/packages/common/infra/src/di/react/index.ts b/packages/common/infra/src/di/react/index.ts new file mode 100644 index 0000000000000..3f5c43fffda52 --- /dev/null +++ b/packages/common/infra/src/di/react/index.ts @@ -0,0 +1,30 @@ +import React, { useContext } from 'react'; + +import type { ServiceProvider } from '../core'; +import { type GeneralServiceIdentifier, ServiceCollection } from '../core'; + +export const ServiceProviderContext = React.createContext( + ServiceCollection.EMPTY.provider() +); + +export function useService( + identifier: GeneralServiceIdentifier, + { provider }: { provider?: ServiceProvider } = {} +): T { + const contextServiceProvider = useContext(ServiceProviderContext); + + const serviceProvider = provider ?? contextServiceProvider; + + return serviceProvider.get(identifier); +} + +export function useServiceOptional( + identifier: GeneralServiceIdentifier, + { provider }: { provider?: ServiceProvider } = {} +): T | null { + const contextServiceProvider = useContext(ServiceProviderContext); + + const serviceProvider = provider ?? contextServiceProvider; + + return serviceProvider.getOptional(identifier); +} diff --git a/packages/common/infra/src/utils/stable-hash.ts b/packages/common/infra/src/utils/stable-hash.ts new file mode 100644 index 0000000000000..406a944211452 --- /dev/null +++ b/packages/common/infra/src/utils/stable-hash.ts @@ -0,0 +1,59 @@ +// copied from https://github.com/shuding/stable-hash + +// Use WeakMap to store the object-key mapping so the objects can still be +// garbage collected. WeakMap uses a hashtable under the hood, so the lookup +// complexity is almost O(1). +const table = new WeakMap(); + +// A counter of the key. +let counter = 0; + +// A stable hash implementation that supports: +// - Fast and ensures unique hash properties +// - Handles unserializable values +// - Handles object key ordering +// - Generates short results +// +// This is not a serialization function, and the result is not guaranteed to be +// parsable. +export function stableHash(arg: any): string { + const type = typeof arg; + const constructor = arg && arg.constructor; + const isDate = constructor === Date; + + if (Object(arg) === arg && !isDate && constructor !== RegExp) { + // Object/function, not null/date/regexp. Use WeakMap to store the id first. + // If it's already hashed, directly return the result. + let result = table.get(arg); + if (result) return result; + // Store the hash first for circular reference detection before entering the + // recursive `stableHash` calls. + // For other objects like set and map, we use this id directly as the hash. + result = ++counter + '~'; + table.set(arg, result); + let index: any; + + if (constructor === Array) { + // Array. + result = '@'; + for (index = 0; index < arg.length; index++) { + result += stableHash(arg[index]) + ','; + } + table.set(arg, result); + } else if (constructor === Object) { + // Object, sort keys. + result = '#'; + const keys = Object.keys(arg).sort(); + while ((index = keys.pop() as string) !== undefined) { + if (arg[index] !== undefined) { + result += index + ':' + stableHash(arg[index]) + ','; + } + } + table.set(arg, result); + } + return result; + } + if (isDate) return arg.toJSON(); + if (type === 'symbol') return arg.toString(); + return type === 'string' ? JSON.stringify(arg) : '' + arg; +}