Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
160 changes: 160 additions & 0 deletions __tests__/datatypes.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
import { DataType } from "@polars/datatypes";

describe("DataType variants", () => {
describe("Simple types", () => {
it("Null should have variant 'Null'", () => {
const dt = DataType.Null;
expect(dt.variant).toBe("Null");
});

it("Bool should have variant 'Bool'", () => {
const dt = DataType.Bool;
expect(dt.variant).toBe("Bool");
});

it("Int8 should have variant 'Int8'", () => {
const dt = DataType.Int8;
expect(dt.variant).toBe("Int8");
});

it("Int16 should have variant 'Int16'", () => {
const dt = DataType.Int16;
expect(dt.variant).toBe("Int16");
});

it("Int32 should have variant 'Int32'", () => {
const dt = DataType.Int32;
expect(dt.variant).toBe("Int32");
});

it("Int64 should have variant 'Int64'", () => {
const dt = DataType.Int64;
expect(dt.variant).toBe("Int64");
});

it("UInt8 should have variant 'UInt8'", () => {
const dt = DataType.UInt8;
expect(dt.variant).toBe("UInt8");
});

it("UInt16 should have variant 'UInt16'", () => {
const dt = DataType.UInt16;
expect(dt.variant).toBe("UInt16");
});

it("UInt32 should have variant 'UInt32'", () => {
const dt = DataType.UInt32;
expect(dt.variant).toBe("UInt32");
});

it("UInt64 should have variant 'UInt64'", () => {
const dt = DataType.UInt64;
expect(dt.variant).toBe("UInt64");
});

it("Float32 should have variant 'Float32'", () => {
const dt = DataType.Float32;
expect(dt.variant).toBe("Float32");
});

it("Float64 should have variant 'Float64'", () => {
const dt = DataType.Float64;
expect(dt.variant).toBe("Float64");
});

it("Date should have variant 'Date'", () => {
const dt = DataType.Date;
expect(dt.variant).toBe("Date");
});

it("Time should have variant 'Time'", () => {
const dt = DataType.Time;
expect(dt.variant).toBe("Time");
});

it("Object should have variant 'Object'", () => {
const dt = DataType.Object;
expect(dt.variant).toBe("Object");
});

it("Utf8 should have variant 'Utf8'", () => {
const dt = DataType.Utf8;
expect(dt.variant).toBe("Utf8");
});

it("String should have variant 'String'", () => {
const dt = DataType.String;
expect(dt.variant).toBe("String");
});

it("Categorical should have variant 'Categorical'", () => {
const dt = DataType.Categorical;
expect(dt.variant).toBe("Categorical");
});
});

describe("Complex types", () => {
it("Decimal should have variant 'Decimal'", () => {
const dt = DataType.Decimal();
expect(dt.variant).toBe("Decimal");
});

it("Decimal with precision and scale should have variant 'Decimal'", () => {
const dt = DataType.Decimal(10, 2);
expect(dt.variant).toBe("Decimal");
});

it("Datetime should have variant 'Datetime'", () => {
const dt = DataType.Datetime();
expect(dt.variant).toBe("Datetime");
});

it("Datetime with timeUnit should have variant 'Datetime'", () => {
const dt = DataType.Datetime("ms");
expect(dt.variant).toBe("Datetime");
});

it("Datetime with timeUnit and timeZone should have variant 'Datetime'", () => {
const dt = DataType.Datetime("ms", "America/New_York");
expect(dt.variant).toBe("Datetime");
});

it("List should have variant 'List'", () => {
const dt = DataType.List(DataType.Int32);
expect(dt.variant).toBe("List");
});

it("FixedSizeList should have variant 'FixedSizeList'", () => {
const dt = DataType.FixedSizeList(DataType.Float64, 5);
expect(dt.variant).toBe("FixedSizeList");
});

it("Struct should have variant 'Struct'", () => {
const dt = DataType.Struct({
a: DataType.Int32,
b: DataType.Utf8,
});
expect(dt.variant).toBe("Struct");
});
});

describe("Variant usage in equals", () => {
it("should use variant for type comparison", () => {
const dt1 = DataType.Int32;
const dt2 = DataType.Int32;
const dt3 = DataType.Int64;

expect(dt1.equals(dt2)).toBe(true);
expect(dt1.equals(dt3)).toBe(false);
});

it("should use variant for complex type comparison", () => {
const dt1 = DataType.List(DataType.Int32);
const dt2 = DataType.List(DataType.Int32);
const dt3 = DataType.List(DataType.Float64);

expect(dt1.equals(dt2)).toBe(true);
expect(dt1.equals(dt3)).toBe(false);
});
});
});
3 changes: 2 additions & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,9 @@
"lint": "yarn lint:ts && yarn format:rs",
"prepublishOnly": "napi prepublish -t npm",
"test": "jest",
"type": "tsc --noEmit -p tsconfig.test.json",
"version": "napi version",
"precommit": "yarn lint && yarn test"
"precommit": "yarn lint && yarn type && yarn test"
},
"devDependencies": {
"@biomejs/biome": "=2.2.4",
Expand Down
27 changes: 24 additions & 3 deletions polars/datatypes/datatype.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@ import { Field } from "./field";

export abstract class DataType<Dtype extends DataTypeName = any> {
declare readonly __dtype: Dtype;
get variant() {
return this.constructor.name as DataTypeName;
}
abstract readonly variant: DataTypeName;
protected identity = "DataType";
protected get inner(): null | any[] {
return null;
Expand Down Expand Up @@ -166,66 +164,85 @@ export abstract class DataType<Dtype extends DataTypeName = any> {

export class Null extends DataType<"Null"> {
declare __dtype: "Null";
readonly variant = "Null";
}

export class Bool extends DataType<"Bool"> {
declare __dtype: "Bool";
readonly variant = "Bool";
}
export class Int8 extends DataType<"Int8"> {
declare __dtype: "Int8";
readonly variant = "Int8";
}
export class Int16 extends DataType<"Int16"> {
declare __dtype: "Int16";
readonly variant = "Int16";
}
export class Int32 extends DataType<"Int32"> {
declare __dtype: "Int32";
readonly variant = "Int32";
}
export class Int64 extends DataType<"Int64"> {
declare __dtype: "Int64";
readonly variant = "Int64";
}
export class UInt8 extends DataType<"UInt8"> {
declare __dtype: "UInt8";
readonly variant = "UInt8";
}
export class UInt16 extends DataType<"UInt16"> {
declare __dtype: "UInt16";
readonly variant = "UInt16";
}
export class UInt32 extends DataType<"UInt32"> {
declare __dtype: "UInt32";
readonly variant = "UInt32";
}
export class UInt64 extends DataType<"UInt64"> {
declare __dtype: "UInt64";
readonly variant = "UInt64";
}
export class Float32 extends DataType<"Float32"> {
declare __dtype: "Float32";
readonly variant = "Float32";
}
export class Float64 extends DataType<"Float64"> {
declare __dtype: "Float64";
readonly variant = "Float64";
}

// biome-ignore lint/suspicious/noShadowRestrictedNames: Using Polars Date
export class Date extends DataType<"Date"> {
declare __dtype: "Date";
readonly variant = "Date";
}
export class Time extends DataType<"Time"> {
declare __dtype: "Time";
readonly variant = "Time";
}
export class Object_ extends DataType<"Object"> {
declare __dtype: "Object";
readonly variant = "Object";
}
export class Utf8 extends DataType<"Utf8"> {
declare __dtype: "Utf8";
readonly variant = "Utf8";
}
// biome-ignore lint/suspicious/noShadowRestrictedNames: Using Polars String
export class String extends DataType<"String"> {
declare __dtype: "String";
readonly variant = "String";
}

export class Categorical extends DataType<"Categorical"> {
declare __dtype: "Categorical";
readonly variant = "Categorical";
}

export class Decimal extends DataType<"Decimal"> {
declare __dtype: "Decimal";
readonly variant = "Decimal";
private precision: number | null;
private scale: number | null;
constructor(precision?: number, scale?: number) {
Expand Down Expand Up @@ -263,6 +280,7 @@ export class Decimal extends DataType<"Decimal"> {
*/
export class Datetime extends DataType<"Datetime"> {
declare __dtype: "Datetime";
readonly variant = "Datetime";
constructor(
private timeUnit: TimeUnit | "ms" | "ns" | "us" = "ms",
private timeZone?: string | null,
Expand All @@ -286,6 +304,7 @@ export class Datetime extends DataType<"Datetime"> {

export class List extends DataType<"List"> {
declare __dtype: "List";
readonly variant = "List";
constructor(protected __inner: DataType) {
super();
}
Expand All @@ -302,6 +321,7 @@ export class List extends DataType<"List"> {

export class FixedSizeList extends DataType<"FixedSizeList"> {
declare __dtype: "FixedSizeList";
readonly variant = "FixedSizeList";
constructor(
protected __inner: DataType,
protected listSize: number,
Expand Down Expand Up @@ -336,6 +356,7 @@ export class FixedSizeList extends DataType<"FixedSizeList"> {

export class Struct extends DataType<"Struct"> {
declare __dtype: "Struct";
readonly variant = "Struct";
private fields: Field[];

constructor(
Expand Down
Loading