Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(core): Support transpose with explicit permutation #256

Merged
merged 2 commits into from
Feb 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/late-tables-peel.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"@zarrita/core": patch
---

Support transpose with explicit permutation
15 changes: 5 additions & 10 deletions packages/core/src/codecs/bytes.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,7 @@ import type {
DataType,
TypedArrayConstructor,
} from "../metadata.js";
import {
byteswap_inplace,
get_array_order,
get_ctr,
get_strides,
} from "../util.js";
import { byteswap_inplace, get_ctr, get_strides } from "../util.js";

const LITTLE_ENDIAN_OS = system_is_little_endian();

Expand All @@ -31,10 +26,10 @@ function bytes_per_element<D extends DataType>(

export class BytesCodec<D extends Exclude<DataType, "v2:object">> {
kind = "array_to_bytes";
#strides: number[];
#stride: Array<number>;
#TypedArray: TypedArrayConstructor<D>;
#BYTES_PER_ELEMENT: number;
#shape: number[];
#shape: Array<number>;
#endian?: "little" | "big";

constructor(
Expand All @@ -44,7 +39,7 @@ export class BytesCodec<D extends Exclude<DataType, "v2:object">> {
this.#endian = configuration?.endian;
this.#TypedArray = get_ctr(meta.data_type);
this.#shape = meta.shape;
this.#strides = get_strides(meta.shape, get_array_order(meta.codecs));
this.#stride = get_strides(meta.shape, "C");
// TODO: fix me.
// hack to get bytes per element since it's dynamic for string types.
const sample = new this.#TypedArray(0);
Expand Down Expand Up @@ -77,7 +72,7 @@ export class BytesCodec<D extends Exclude<DataType, "v2:object">> {
bytes.byteLength / this.#BYTES_PER_ELEMENT,
),
shape: this.#shape,
stride: this.#strides,
stride: this.#stride,
};
}
}
80 changes: 67 additions & 13 deletions packages/core/src/codecs/transpose.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import assert from "node:assert";
import type {
Chunk,
DataType,
Expand Down Expand Up @@ -41,7 +42,7 @@ function proxy<D extends DataType>(arr: TypedArray<D>): TypedArrayProxy<D> {

function empty_like<D extends DataType>(
chunk: Chunk<D>,
order: "C" | "F",
order: Order,
): Chunk<D> {
let data: TypedArray<D>;
if (
Expand All @@ -67,7 +68,7 @@ function empty_like<D extends DataType>(

function convert_array_order<D extends DataType>(
src: Chunk<D>,
target: "C" | "F",
target: Order,
): Chunk<D> {
let out = empty_like(src, target);
let n_dims = src.shape.length;
Expand Down Expand Up @@ -99,30 +100,83 @@ function convert_array_order<D extends DataType>(
return out;
}

function get_order(arr: Chunk<DataType>): "C" | "F" {
// Assume C order if no stride is given
if (!arr.stride) return "C";
let row_major_strides = get_strides(arr.shape, "C");
return arr.stride.every((s, i) => s === row_major_strides[i]) ? "C" : "F";
/**
 * Derive the axis permutation that describes how a chunk is laid out in memory.
 *
 * Axes are listed from the largest stride (slowest varying) down to the
 * smallest stride (fastest varying).
 *
 * @param chunk - Chunk whose `shape` and `stride` describe the layout.
 * @returns Axis indices sorted by descending stride.
 * @throws If `shape` and `stride` do not have the same number of entries.
 */
function get_order(chunk: Chunk<DataType>): number[] {
  const { shape, stride } = chunk;
  assert(
    shape.length === stride.length,
    "Shape and stride must have the same length.",
  );
  // Pair every axis with its stride so the axis survives the sort.
  const ranked = stride.map(
    (step, axis): [axis: number, step: number] => [axis, step],
  );
  ranked.sort(([, left], [, right]) => right - left);
  return ranked.map(([axis]) => axis);
}

/**
 * Check whether a chunk's memory layout already matches a target order.
 *
 * The chunk's layout is obtained from `get_order` (axes sorted by descending
 * stride) and compared element-wise against the target permutation.
 *
 * @param chunk - Chunk whose layout is inspected.
 * @param target - `"C"`, `"F"`, or an explicit permutation of axis indices.
 * @returns `true` when every axis of the chunk's layout equals the target.
 * @throws If an explicit permutation's length differs from the chunk's rank.
 */
function matches_order(chunk: Chunk<DataType>, target: Order) {
  let source = get_order(chunk);
  let rank = source.length;
  // "C" / "F" are shorthands for the identity / reversed permutation. Resolve
  // them first: a string's `.length` and characters are not comparable to
  // axis indices.
  let expected: ReadonlyArray<number> =
    typeof target === "string"
      ? Array.from({ length: rank }, (_, i) =>
          target === "C" ? i : rank - 1 - i,
        )
      : target;
  assert(source.length === expected.length, "Orders must match");
  return source.every((dim, i) => dim === expected[i]);
}

/**
 * Memory layout of a chunk: `"C"` (row-major), `"F"` (column-major), or an
 * explicit permutation of axis indices.
 */
type Order = "C" | "F" | Array<number>;

export class TransposeCodec {
kind = "array_to_array";
#order: Array<number>;
#inverseOrder: Array<number>;

constructor(configuration: { order?: Order }, meta: { shape: number[] }) {
let value = configuration.order ?? "C";
let rank = meta.shape.length;
let order = new Array<number>(rank);
let inverseOrder = new Array<number>(rank);

if (value === "C") {
for (let i = 0; i < rank; ++i) {
order[i] = i;
inverseOrder[i] = i;
}
} else if (value === "F") {
for (let i = 0; i < rank; ++i) {
order[i] = rank - i - 1;
inverseOrder[i] = rank - i - 1;
}
} else {
order = value;
order.forEach((x, i) => {
assert(
inverseOrder[x] === undefined,
`Invalid permutation: ${JSON.stringify(value)}`,
);
inverseOrder[x] = i;
});
}

constructor(public configuration?: { order: "C" | "F" }) {}
this.#order = order;
this.#inverseOrder = inverseOrder;
}

static fromConfig(configuration: { order: "C" | "F" }) {
return new TransposeCodec(configuration);
static fromConfig(
configuration: { order: Order },
meta: { shape: number[] },
) {
return new TransposeCodec(configuration, meta);
}

encode<D extends DataType>(arr: Chunk<D>): Chunk<D> {
if (get_order(arr) === this.configuration?.order) {
if (matches_order(arr, this.#inverseOrder)) {
// can skip making a copy
return arr;
}
return convert_array_order(arr, this.configuration?.order ?? "C");
return convert_array_order(arr, this.#inverseOrder);
}

decode<D extends DataType>(arr: Chunk<D>): Chunk<D> {
return arr;
return {
data: arr.data,
shape: arr.shape,
stride: get_strides(arr.shape, this.#order),
};
}
}
10 changes: 9 additions & 1 deletion packages/core/src/hierarchy.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import type {
ArrayMetadata,
Attributes,
Chunk,
CodecMetadata,
DataType,
GroupMetadata,
Scalar,
Expand All @@ -19,7 +20,6 @@ import {
import {
create_chunk_key_encoder,
ensure_correct_scalar,
get_array_order,
get_ctr,
get_strides,
} from "./util.js";
Expand Down Expand Up @@ -63,6 +63,14 @@ export class Group<Store extends Readable> extends Location<Store> {
}
}

/**
 * Look up the array order declared by the first `transpose` codec.
 *
 * @param codecs - Codec metadata entries to search.
 * @returns The `order` value from the first codec named `"transpose"`, or
 *   `"C"` when there is no such codec or it declares no order.
 */
function get_array_order(
  codecs: CodecMetadata[],
): "C" | "F" | globalThis.Array<number> {
  for (const codec of codecs) {
    if (codec.name !== "transpose") continue;
    // The configured value is passed through without runtime validation.
    // @ts-expect-error - TODO: Should validate?
    return codec.configuration?.order ?? "C";
  }
  return "C";
}

const CONTEXT_MARKER = Symbol("zarrita.context");

export function get_context<T>(obj: { [CONTEXT_MARKER]: T }): T {
Expand Down
47 changes: 21 additions & 26 deletions packages/core/src/util.ts
Original file line number Diff line number Diff line change
Expand Up @@ -79,27 +79,29 @@ export function get_ctr<D extends DataType>(
}

/** Compute strides from shape for a 'C'- or 'F'-ordered array, or for an explicit axis permutation */
export function get_strides(shape: readonly number[], order: "C" | "F") {
return (order === "C" ? row_major_stride : col_major_stride)(shape);
}

function row_major_stride(shape: readonly number[]) {
const ndim = shape.length;
const stride: number[] = globalThis.Array(ndim);
for (let i = ndim - 1, step = 1; i >= 0; i--) {
stride[i] = step;
step *= shape[i];
export function get_strides(
shape: readonly number[],
order: "C" | "F" | Array<number>,
) {
const rank = shape.length;
if (typeof order === "string") {
order =
order === "C"
? Array.from({ length: rank }, (_, i) => i) // Row-major (identity order)
: Array.from({ length: rank }, (_, i) => rank - 1 - i); // Column-major (reverse order)
}
return stride;
}
assert(
rank === order.length,
"Order length must match the number of dimensions.",
);

function col_major_stride(shape: readonly number[]) {
const ndim = shape.length;
const stride: number[] = globalThis.Array(ndim);
for (let i = 0, step = 1; i < ndim; i++) {
stride[i] = step;
step *= shape[i];
let step = 1;
let stride = new Array(rank);
for (let i = order.length - 1; i >= 0; i--) {
stride[order[i]] = step;
step *= shape[order[i]];
}

return stride;
}

Expand All @@ -119,21 +121,14 @@ export function create_chunk_key_encoder({
throw new Error(`Unknown chunk key encoding: ${name}`);
}

/**
 * Report whether the codec pipeline requests column-major order.
 *
 * @param codecs - Codec metadata entries to search.
 * @returns `"F"` when the first codec named `"transpose"` has
 *   `configuration.order === "F"`; `"C"` in every other case.
 */
export function get_array_order(codecs: CodecMetadata[]): "C" | "F" {
  for (const codec of codecs) {
    if (codec.name === "transpose") {
      return codec.configuration?.order === "F" ? "F" : "C";
    }
  }
  return "C";
}

const endian_regex = /^([<|>])(.*)$/;

function coerce_dtype(
dtype: string,
): { data_type: DataType } | { data_type: DataType; endian: "little" | "big" } {
if (dtype === "|O") {
return { data_type: "v2:object" };
}

let match = dtype.match(endian_regex);
let match = dtype.match(/^([<|>])(.*)$/);
assert(match, `Invalid dtype: ${dtype}`);

let [, endian, rest] = match;
Expand Down