diff --git a/apps/server/src/schemas/entities.ts b/apps/server/src/schemas/entities.ts index 78607590..5a640032 100644 --- a/apps/server/src/schemas/entities.ts +++ b/apps/server/src/schemas/entities.ts @@ -22,6 +22,13 @@ export const EntitySchema = z.object({ region: RegionSchema.nullable().default(null), address: z.string().trim().nullish().default(null), location: PointSchema.nullable().default(null), + parent: z + .object({ + id: z.number(), + name: z.string(), + }) + .nullable() + .default(null), totalTastings: z.number().readonly(), totalBottles: z.number().readonly(), @@ -39,6 +46,13 @@ export const EntityInputSchema = EntitySchema.omit({ }).extend({ country: z.number().nullish().default(null), region: z.number().nullish().default(null), + parent: z + .number() + .nullish() + .default(null) + .refine((val) => val === null || val === undefined || val > 0, { + message: "Parent entity ID must be a positive number", + }), }); export const EntityMergeSchema = z.object({ diff --git a/apps/server/src/serializers/entity.ts b/apps/server/src/serializers/entity.ts index f3078fce..cbff53da 100644 --- a/apps/server/src/serializers/entity.ts +++ b/apps/server/src/serializers/entity.ts @@ -1,8 +1,14 @@ -import { inArray } from "drizzle-orm"; +import { eq, inArray } from "drizzle-orm"; import { type z } from "zod"; import { serialize, serializer } from "."; import { db } from "../db"; -import { countries, regions, type Entity, type User } from "../db/schema"; +import { + countries, + entities, + regions, + type Entity, + type User, +} from "../db/schema"; import { notEmpty } from "../lib/filter"; import { type EntitySchema } from "../schemas"; import { CountrySerializer } from "./country"; @@ -39,6 +45,23 @@ export const EntitySerializer = serializer({ ) : {}; + const parentIds = itemList.map((i) => i.parentId).filter(notEmpty); + const parentList = parentIds.length + ? await db.select().from(entities).where(inArray(entities.id, parentIds)) + : []; + + const parentsById = parentList.length + ? Object.fromEntries( + parentList.map((parent) => [ + parent.id, + { + id: parent.id, + name: parent.name, + }, + ]), + ) + : {}; + return Object.fromEntries( itemList.map((item) => { return [ @@ -46,6 +69,7 @@ export const EntitySerializer = serializer({ { country: item.countryId ? countriesById[item.countryId] : null, region: item.regionId ? regionsById[item.regionId] : null, + parent: item.parentId ? parentsById[item.parentId] : null, }, ]; }), @@ -66,6 +90,7 @@ export const EntitySerializer = serializer({ website: item.website, country: attrs.country, region: attrs.region, + parent: attrs.parent, address: item.address, location: item.location, createdAt: item.createdAt.toISOString(), diff --git a/apps/server/src/trpc/routes/entityCreate.test.ts b/apps/server/src/trpc/routes/entityCreate.test.ts index d410a8f6..d6fad520 100644 --- a/apps/server/src/trpc/routes/entityCreate.test.ts +++ b/apps/server/src/trpc/routes/entityCreate.test.ts @@ -68,3 +68,41 @@ test("updates existing entity with new type", async ({ expect(brand.id).toEqual(entity.id); expect(brand.type).toEqual(["distiller", "brand"]); }); + +test("creates a new entity with parent", async ({ fixtures }) => { + const parentEntity = await fixtures.Entity(); + const caller = createCaller({ user: await fixtures.User({ mod: true }) }); + + const data = await caller.entityCreate({ + name: "Child Entity", + parent: parentEntity.id, + }); + + expect(data.id).toBeDefined(); + expect(data.parent).toBeDefined(); + expect(data.parent?.id).toEqual(parentEntity.id); + expect(data.parent?.name).toEqual(parentEntity.name); + + const [childEntity] = await db + .select() + .from(entities) + .where(eq(entities.id, data.id)); + + expect(childEntity.parentId).toEqual(parentEntity.id); +}); + +test("fails with invalid parent entity ID", async ({ fixtures }) => { + const caller = createCaller({ user: await fixtures.User({ mod: true }) }); + + const nonExistentParentId = 999999; // A parent ID that doesn't exist + + const err = await waitError( + caller.entityCreate({ + name: "Child Entity", + parent: nonExistentParentId, + }), + ); + + expect(err).toMatchInlineSnapshot(`[TRPCError: NOT_FOUND]`); + expect(err.message).toContain("Parent entity not found"); +}); diff --git a/apps/server/src/trpc/routes/entityCreate.ts b/apps/server/src/trpc/routes/entityCreate.ts index 7b760ec6..a4403590 100644 --- a/apps/server/src/trpc/routes/entityCreate.ts +++ b/apps/server/src/trpc/routes/entityCreate.ts @@ -26,6 +26,7 @@ export default verifiedProcedure name: normalizeEntityName(input.name), type: input.type || [], createdById: ctx.user.id, + parentId: input.parent || null, }; if (input.country) { diff --git a/apps/server/src/trpc/routes/entityUpdate.test.ts b/apps/server/src/trpc/routes/entityUpdate.test.ts index e3c71cbb..5d1d3820 100644 --- a/apps/server/src/trpc/routes/entityUpdate.test.ts +++ b/apps/server/src/trpc/routes/entityUpdate.test.ts @@ -417,3 +417,133 @@ test("updates existing conflicting alias", async ({ fixtures }) => { expect(newAlias.name).toEqual("Cool Cats Single Barrel Bourbon"); expect(newAlias.bottleId).toEqual(newBottle.id); }); + +test("can change parent", async ({ fixtures }) => { + const entity = await fixtures.Entity(); + const parentEntity = await fixtures.Entity(); + + const caller = createCaller({ + user: await fixtures.User({ mod: true }), + }); + + const data = await caller.entityUpdate({ + entity: entity.id, + parent: parentEntity.id, + }); + + expect(data.id).toBeDefined(); + expect(data.parent).toBeDefined(); + expect(data.parent?.id).toEqual(parentEntity.id); + expect(data.parent?.name).toEqual(parentEntity.name); + + const [newEntity] = await db + .select() + .from(entities) + .where(eq(entities.id, data.id)); + + expect(omit(entity, "parentId", "searchVector", "updatedAt")).toEqual( + omit(newEntity, "parentId", "searchVector", "updatedAt"), + ); + expect(newEntity.parentId).toBe(parentEntity.id); + + // Verify that the change is recorded in the changes table + const [change] = await db + .select() + .from(changes) + .where(eq(changes.objectId, newEntity.id)) + .orderBy(desc(changes.id)) + .limit(1); + + expect(change).toBeDefined(); + expect(change.data).toHaveProperty("parentId", parentEntity.id); +}); + +test("can remove parent", async ({ fixtures }) => { + const parentEntity = await fixtures.Entity(); + const entity = await fixtures.Entity({ + parentId: parentEntity.id, + }); + + const caller = createCaller({ + user: await fixtures.User({ mod: true }), + }); + + const data = await caller.entityUpdate({ + entity: entity.id, + parent: null, + }); + + expect(data.id).toBeDefined(); + expect(data.parent).toBeNull(); + + const [newEntity] = await db + .select() + .from(entities) + .where(eq(entities.id, data.id)); + + expect(newEntity.parentId).toBeNull(); +}); + +test("prevents circular parent references", async ({ fixtures }) => { + const entity = await fixtures.Entity(); + const childEntity = await fixtures.Entity({ + parentId: entity.id, + }); + + const caller = createCaller({ + user: await fixtures.User({ mod: true }), + }); + + const err = await waitError( + caller.entityUpdate({ + entity: entity.id, + parent: childEntity.id, + }), + ); + + expect(err).toMatchInlineSnapshot(`[TRPCError: BAD_REQUEST]`); + expect(err.message).toContain("circular reference"); +}); + +test("prevents deep circular parent references", async ({ fixtures }) => { + const rootEntity = await fixtures.Entity(); + const midEntity = await fixtures.Entity({ + parentId: rootEntity.id, + }); + const leafEntity = await fixtures.Entity({ + parentId: midEntity.id, + }); + + const caller = createCaller({ + user: await fixtures.User({ mod: true }), + }); + + const err = await waitError( + caller.entityUpdate({ + entity: rootEntity.id, + parent: leafEntity.id, + }), + ); + + expect(err).toMatchInlineSnapshot(`[TRPCError: BAD_REQUEST]`); + expect(err.message).toContain("circular reference"); +}); + +test("fails with invalid parent entity ID", async ({ fixtures }) => { + const entity = await fixtures.Entity(); + const nonExistentParentId = 999999; // A parent ID that doesn't exist + + const caller = createCaller({ + user: await fixtures.User({ mod: true }), + }); + + const err = await waitError( + caller.entityUpdate({ + entity: entity.id, + parent: nonExistentParentId, + }), + ); + + expect(err).toMatchInlineSnapshot(`[TRPCError: NOT_FOUND]`); + expect(err.message).toContain("Parent entity not found"); +}); diff --git a/apps/server/src/trpc/routes/entityUpdate.ts b/apps/server/src/trpc/routes/entityUpdate.ts index c76826a2..677291e4 100644 --- a/apps/server/src/trpc/routes/entityUpdate.ts +++ b/apps/server/src/trpc/routes/entityUpdate.ts @@ -49,23 +49,46 @@ export default modProcedure data.shortName = input.shortName; } - if (input.country) { - if (input.country) { - const [country] = await db + if (input.parent !== undefined && input.parent !== entity.parentId) { + // Check for circular reference + let parentId = input.parent; + while (parentId) { + const [parent] = await db .select() - .from(countries) - .where(eq(countries.id, input.country)) - .limit(1); - if (!country) { + .from(entities) + .where(eq(entities.id, parentId)); + if (!parent) { throw new TRPCError({ - message: "Country not found.", + message: "Parent entity not found.", code: "NOT_FOUND", }); } - if (country.id !== entity.countryId) { - data.countryId = country.id; - data.regionId = null; + if (parent.id === entity.id) { + throw new TRPCError({ + message: "Cannot create circular reference in entity hierarchy.", + code: "BAD_REQUEST", + }); } + parentId = parent.parentId; + } + data.parentId = input.parent; + } + + if (input.country) { + const [country] = await db + .select() + .from(countries) + .where(eq(countries.id, input.country)) + .limit(1); + if (!country) { + throw new TRPCError({ + message: "Country not found.", + code: "NOT_FOUND", + }); + } + if (country.id !== entity.countryId) { + data.countryId = country.id; + data.regionId = null; } } else if (input.country === null) { if (entity.countryId) { diff --git a/apps/server/src/worker/jobs/mergeEntity.ts b/apps/server/src/worker/jobs/mergeEntity.ts index 0aff5dbb..27d840d2 100644 --- a/apps/server/src/worker/jobs/mergeEntity.ts +++ b/apps/server/src/worker/jobs/mergeEntity.ts @@ -105,7 +105,6 @@ export default async function mergeEntity({ newEntityId: toEntity.id, }); } - await tx.delete(entities).where(inArray(entities.id, fromEntityIds)); }); diff --git a/apps/web/src/components/entityForm.tsx b/apps/web/src/components/entityForm.tsx index 933f47eb..41834911 100644 --- a/apps/web/src/components/entityForm.tsx +++ b/apps/web/src/components/entityForm.tsx @@ -4,6 +4,7 @@ import { toTitleCase } from "@peated/server/lib/strings"; import { EntityInputSchema } from "@peated/server/schemas"; import { type Entity } from "@peated/server/types"; import CountryField from "@peated/web/components/countryField"; +import EntityField from "@peated/web/components/entityField"; import Fieldset from "@peated/web/components/fieldset"; import Form from "@peated/web/components/form"; import FormError from "@peated/web/components/formError"; @@ -54,6 +55,7 @@ export default function EntityForm({ ...initialData, country: initialData.country ? initialData.country.id : null, region: initialData.region ? initialData.region.id : null, + parent: initialData.parent ? initialData.parent.id : null, }, }); @@ -70,6 +72,15 @@ export default function EntityForm({ : undefined, ); + const [parentValue, setParentValue] = useState