Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,23 @@ By default, `compile(...)` returns a public-safe result surface: emitted SQL on
`voight` currently ships with these policy helpers:

- `tenantScopingPolicy(...)` to inject and enforce tenant filters on configured tables
- `maxLimitPolicy(...)` to cap `LIMIT`, optionally cap `OFFSET`, and optionally add a default `LIMIT`
- `maxLimitPolicy(...)` to cap the outer `LIMIT`, optionally cap outer `OFFSET`, and optionally add a default outer `LIMIT`
- `allowedFunctionsPolicy(...)` to restrict callable SQL functions and `CURRENT_*` keywords
- `supportedOperatorsPolicy()` to reject operators outside the supported policy surface

`maxLimitPolicy(...)` constrains the final result set by default, so nested selects are not
limited unless `recursive: true` is configured.

`tenantScopingPolicy(...)` requires `scopeValueType` on every scope rule and enforces that type
at runtime. Use `scopeValueType: "string"` for string scope columns, or configure the matching
numeric or boolean type, for example `scopeValueType: "bigint"` for a `BIGINT` `project_id` column.

String tenant scopes require careful database configuration. MySQL in particular has many
string comparison gotchas around implicit casts, collations, charsets, padding, and
case/accent equivalence. Avoid string tenant identifiers unless those semantics are deliberate;
if you use them, choose binary or otherwise case-sensitive comparison semantics so values such
as `project-alpha` and `PROJECT-ALPHA` cannot share a scope unintentionally.

## Example

```ts
Expand Down Expand Up @@ -112,6 +125,7 @@ const result = compile(
tables: ["tracking.time_series_stats"],
scopeColumn: "tenant_id",
contextKey: "tenantId",
scopeValueType: "string",
}),
],
policyContext: {
Expand Down
16 changes: 15 additions & 1 deletion packages/voight/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ const result = compile(
tables: ["tracking.time_series_stats"],
scopeColumn: "tenant_id",
contextKey: "tenantId",
scopeValueType: "string",
}),
],
policyContext: {
Expand Down Expand Up @@ -133,10 +134,23 @@ Use `AliasCatalog` and `createCatalogAlias(...)` if your public logical table na
## Built-in Policies

- `tenantScopingPolicy(...)` injects tenant predicates during rewrite and verifies them during enforcement.
- `maxLimitPolicy(...)` caps `LIMIT`, can cap `OFFSET`, and can add a default `LIMIT`.
- `maxLimitPolicy(...)` caps the outer `LIMIT`, can cap the outer `OFFSET`, and can add a default outer `LIMIT`.
- `allowedFunctionsPolicy(...)` allowlists function calls and `CURRENT_*` keywords.
- `supportedOperatorsPolicy()` rejects operators outside the supported policy surface.

`maxLimitPolicy(...)` constrains the final result set by default, so nested selects are not
limited unless `recursive: true` is configured.

`tenantScopingPolicy(...)` requires `scopeValueType` on every scope rule and enforces that type
at runtime. Use `scopeValueType: "string"` for string scope columns, or configure the matching
numeric or boolean type, for example `scopeValueType: "bigint"` for a `BIGINT` `project_id` column.

String tenant scopes require careful database configuration. MySQL in particular has many string
comparison gotchas around implicit casts, collations, charsets, padding, and case/accent
equivalence. Avoid string tenant identifiers unless those semantics are deliberate; if you use
them, choose binary or otherwise case-sensitive comparison semantics so values such as
`project-alpha` and `PROJECT-ALPHA` cannot share a scope unintentionally.

## Repository

The source repository lives at [github.com/lukaskratzel/voight](https://github.com/lukaskratzel/voight). The workspace README has more detail on the parser stack, development workflow, and release process.
44 changes: 44 additions & 0 deletions packages/voight/src/binder/expression.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,11 @@ import type {
BoundWindowSpecification,
CaseExpressionNode,
CastExpressionNode,
CastTypeNode,
ExistsExpressionNode,
ExpressionNode,
GroupingExpressionNode,
IdentifierNode,
IdentifierExpressionNode,
InListExpressionNode,
InSubqueryExpressionNode,
Expand All @@ -54,6 +56,8 @@ import {
import { stageFailure, stageSuccess, type StageResult } from "../core/result";
import type { SourceSpan } from "../core/source";

// Shape of a "simple" identifier: a lowercase ASCII letter or underscore, followed by any number
// of lowercase letters, digits, underscores, or dollar signs. The pattern has no uppercase
// range, so it is meant to be tested against already-normalized (lowercased) names, which is
// how isSafeRawSqlIdentifier uses it.
const RAW_SQL_IDENTIFIER_PATTERN = /^[a-z_][a-z0-9_$]*$/;

export type BindResult<T> = StageResult<T, CompilerStage.Binder, { scopeSize: number }>;

export interface BinderExpressionContext {
Expand Down Expand Up @@ -239,6 +243,14 @@ function bindFunction(
node: BoundFunctionCall["ast"],
): BindResult<BoundFunctionCall> {
const callee = normalizeIdentifier(node.callee.name);
if (!isSafeRawSqlIdentifier(node.callee, callee)) {
return context.fail(
DiagnosticCode.UnsupportedConstruct,
"Function names must be unquoted simple identifiers.",
node.callee.span,
);
}

const args: BoundExpression[] = [];
for (const arg of node.arguments) {
const bound = context.bindExpression(arg);
Expand Down Expand Up @@ -350,6 +362,15 @@ function bindCast(
context: BinderExpressionContext,
node: CastExpressionNode,
): BindResult<BoundCastExpression> {
const unsafeTypeIdentifier = findUnsafeCastTypeIdentifier(node.targetType);
if (unsafeTypeIdentifier) {
return context.fail(
DiagnosticCode.UnsupportedConstruct,
"CAST target type names must be unquoted simple identifiers.",
unsafeTypeIdentifier.span,
);
}

const expression = context.bindExpression(node.expression);
if (!expression.ok) {
return expression;
Expand All @@ -368,6 +389,29 @@ function bindCast(
);
}

/**
 * Searches a CAST target type for the first identifier that is not a safe raw SQL identifier.
 *
 * The parts of the type's own name are checked first, in order. If all of them pass, each
 * argument that is itself a `CastType` node is searched recursively, depth-first. Arguments of
 * any other kind are ignored.
 *
 * @param type - The CAST target type node to inspect.
 * @returns The first offending identifier node, or `undefined` when every identifier is safe.
 */
function findUnsafeCastTypeIdentifier(type: CastTypeNode): IdentifierNode | undefined {
    const offendingPart = type.name.parts.find(
        (segment) => !isSafeRawSqlIdentifier(segment, normalizeIdentifier(segment.name)),
    );
    if (offendingPart !== undefined) {
        return offendingPart;
    }

    for (const nested of type.arguments) {
        // Only nested type arguments carry identifiers that need checking.
        if (nested.kind !== "CastType") {
            continue;
        }

        const offendingNested = findUnsafeCastTypeIdentifier(nested);
        if (offendingNested !== undefined) {
            return offendingNested;
        }
    }

    return undefined;
}

/**
 * Decides whether an identifier counts as a safe raw SQL identifier.
 *
 * @param identifier - The identifier node; a quoted identifier is never considered safe.
 * @param normalized - The normalized form of the identifier's name, matched against
 *   `RAW_SQL_IDENTIFIER_PATTERN`.
 * @returns `true` only when the identifier is unquoted and its normalized name matches the pattern.
 */
function isSafeRawSqlIdentifier(identifier: IdentifierNode, normalized: string): boolean {
    if (identifier.quoted) {
        return false;
    }

    return RAW_SQL_IDENTIFIER_PATTERN.test(normalized);
}

function bindCase(
context: BinderExpressionContext,
node: CaseExpressionNode,
Expand Down
15 changes: 12 additions & 3 deletions packages/voight/src/catalog/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ export class InMemoryCatalog implements Catalog {

export function createIdentifierPath(...parts: string[]): IdentifierPath {
return {
parts: parts.map(normalizeIdentifier),
parts: parts.map((part) => normalizeIdentifierPathSegment(part)),
};
}

Expand All @@ -57,7 +57,7 @@ export function createTableSchema(input: {
}
)[];
}): TableSchema {
const normalizedPath = input.path.map(normalizeIdentifier);
const normalizedPath = input.path.map((part) => normalizeIdentifierPathSegment(part));
if (normalizedPath.length === 0 || normalizedPath.some((part) => part.length === 0)) {
throw new Error("createTableSchema requires a non-empty path.");
}
Expand Down Expand Up @@ -95,8 +95,17 @@ export function normalizeIdentifier(value: string): string {
return value.toLowerCase();
}

/**
 * Normalizes one segment of a catalog identifier path and validates the result.
 *
 * @param value - The raw path segment.
 * @returns The normalized segment.
 * @throws Error when the normalized segment is empty or contains a dot.
 */
function normalizeIdentifierPathSegment(value: string): string {
    const segment = normalizeIdentifier(value);
    const isValidSegment = segment.length > 0 && !segment.includes(".");
    if (isValidSegment) {
        return segment;
    }

    throw new Error("Catalog identifier path segments cannot be empty or contain dots.");
}

function normalizeIdentifierPath(path: IdentifierPath): string {
return path.parts.map(normalizeIdentifier).join(".");
return JSON.stringify(path.parts.map(normalizeIdentifier));
}

export class AliasCatalog implements Catalog {
Expand Down
3 changes: 3 additions & 0 deletions packages/voight/src/compiler/enforcer.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import type { BoundQuery } from "../ast";
import type { Catalog } from "../catalog";
import { CompilerStage, type Diagnostic } from "../core/diagnostics";
import { type CompilerPolicy, type PolicyContext, resolvePolicies } from "../policies";
import { stageFailure, stageSuccess, type StageResult } from "../core/result";

/**
 * Optional inputs accepted by `enforce(...)`. Every field may be omitted.
 */
export interface EnforcementOptions {
// Policies to apply; `enforce` hands the options object to `resolvePolicies(...)` to build the list it iterates.
readonly policies?: readonly CompilerPolicy[];
// Stored as `context` on the context object `enforce` builds; falls back to `{}` when omitted.
readonly policyContext?: PolicyContext;
// Stored as `catalog` on that same context object; stays undefined when omitted.
readonly catalog?: Catalog;
}

export type EnforcementResult = StageResult<
Expand All @@ -19,6 +21,7 @@ export function enforce(bound: BoundQuery, options: EnforcementOptions = {}): En
const policies = resolvePolicies(options);
const context = {
context: options.policyContext ?? {},
catalog: options.catalog,
};

policies.forEach((policy) => {
Expand Down
1 change: 1 addition & 0 deletions packages/voight/src/compiler/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ function compileInternal(source: string, options: CompileOptions): CompileResult
const enforced = enforce(bound.value, {
policies: options.policies,
policyContext: options.policyContext,
catalog: options.catalog,
});
if (!enforced.ok) {
return {
Expand Down
49 changes: 45 additions & 4 deletions packages/voight/src/emitter/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -164,9 +164,19 @@ function emitBoundExpression(expression: BoundExpression, parameterIndices: numb
: `-${emitBoundExpression(expression.operand, parameterIndices)}`;
case "BoundBinaryExpression":
return emitBinary(
emitBoundBinaryOperand(expression.left, expression.operator, parameterIndices),
emitBoundBinaryOperand(
expression.left,
expression.operator,
parameterIndices,
"left",
),
expression.operator,
emitBoundBinaryOperand(expression.right, expression.operator, parameterIndices),
emitBoundBinaryOperand(
expression.right,
expression.operator,
parameterIndices,
"right",
),
);
case "BoundFunctionCall":
return `${expression.callee}(${expression.distinct ? "DISTINCT " : ""}${expression.arguments.map((arg) => emitBoundExpression(arg, parameterIndices)).join(", ")})${expression.over ? ` ${emitWindowSpecification(expression.over, parameterIndices)}` : ""}`;
Expand Down Expand Up @@ -333,19 +343,50 @@ function emitBoundBinaryOperand(
expression: BoundExpression,
parentOperator: BinaryExpressionNode["operator"],
parameterIndices: number[],
side: "left" | "right",
): string {
const emitted = emitBoundExpression(expression, parameterIndices);
return expression.kind === "BoundBinaryExpression" &&
shouldParenthesizeBinary(expression.operator, parentOperator)
shouldParenthesizeBinary(expression.operator, parentOperator, side)
? `(${emitted})`
: emitted;
}

function shouldParenthesizeBinary(
childOperator: BinaryExpressionNode["operator"],
parentOperator: BinaryExpressionNode["operator"],
side: "left" | "right",
): boolean {
return binaryPrecedence(childOperator) < binaryPrecedence(parentOperator);
const childPrecedence = binaryPrecedence(childOperator);
const parentPrecedence = binaryPrecedence(parentOperator);
if (childPrecedence < parentPrecedence) {
return true;
}

return (
side === "right" &&
childPrecedence === parentPrecedence &&
!canFlattenRightBinaryOperand(parentOperator, childOperator)
);
}

/**
 * Reports whether a right-hand binary operand may be emitted without parentheses beneath its
 * parent operator.
 *
 * Flattening is allowed only when the parent is `AND`, `OR`, `+`, or `*` and the child uses the
 * very same operator. Every other combination returns `false`.
 *
 * @param parentOperator - Operator of the enclosing binary expression.
 * @param childOperator - Operator of the binary expression on the parent's right-hand side.
 * @returns `true` when the right operand can be written without parentheses.
 */
function canFlattenRightBinaryOperand(
    parentOperator: BinaryExpressionNode["operator"],
    childOperator: BinaryExpressionNode["operator"],
): boolean {
    switch (parentOperator) {
        case "AND":
        case "OR":
        case "+":
        case "*":
            return childOperator === parentOperator;
        default:
            return false;
    }
}

function binaryPrecedence(operator: BinaryExpressionNode["operator"]): number {
Expand Down
4 changes: 3 additions & 1 deletion packages/voight/src/policies/allowed-functions.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@ export function allowedFunctionsPolicy(options: AllowedFunctionsPolicyOptions):
return new AllowedFunctionsPolicy(options);
}

export const ALLOWED_FUNCTIONS_POLICY_NAME = "allowed-functions";

class AllowedFunctionsPolicy implements CompilerPolicy {
readonly name = "allowed-functions";
readonly name = ALLOWED_FUNCTIONS_POLICY_NAME;
readonly #allowedFunctions: ReadonlySet<string>;

constructor(options: AllowedFunctionsPolicyOptions) {
Expand Down
25 changes: 21 additions & 4 deletions packages/voight/src/policies/index.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import { allowedFunctionsPolicy, type AllowedFunctionsPolicyOptions } from "./allowed-functions";
import { maxLimitPolicy, type MaxLimitPolicyOptions } from "./max-limit";
import { supportedOperatorsPolicy } from "./supported-operators";
import {
ALLOWED_FUNCTIONS_POLICY_NAME,
allowedFunctionsPolicy,
type AllowedFunctionsPolicyOptions,
} from "./allowed-functions";
import { MAX_LIMIT_POLICY_NAME, maxLimitPolicy, type MaxLimitPolicyOptions } from "./max-limit";
import { SUPPORTED_OPERATORS_POLICY_NAME, supportedOperatorsPolicy } from "./supported-operators";
import {
PolicyConflictError,
PolicyConfigurationError,
Expand All @@ -11,9 +15,11 @@ import {
type PolicySelectionOptions,
} from "./shared";
import {
TENANT_SCOPING_POLICY_NAME,
tenantScopingPolicy,
type TenantScopingPolicyOptions,
type TenantScopingScopeOptions,
type TenantScopeValueType,
} from "./tenant-scoping";

export type {
Expand All @@ -29,20 +35,31 @@ export type {
MaxLimitPolicyOptions,
TenantScopingPolicyOptions,
TenantScopingScopeOptions,
TenantScopeValueType,
};

export {
ALLOWED_FUNCTIONS_POLICY_NAME,
MAX_LIMIT_POLICY_NAME,
allowedFunctionsPolicy,
maxLimitPolicy,
PolicyConflictError,
PolicyConfigurationError,
PolicyDiagnosticError,
PolicyError,
PolicyUsageError,
SUPPORTED_OPERATORS_POLICY_NAME,
TENANT_SCOPING_POLICY_NAME,
supportedOperatorsPolicy,
tenantScopingPolicy,
};

export function resolvePolicies(options: PolicySelectionOptions = {}) {
return dedupePoliciesByName(options.policies ?? []);
const policies = dedupePoliciesByName(options.policies ?? []);
if (policies.some((policy) => policy.name === ALLOWED_FUNCTIONS_POLICY_NAME)) {
return policies;
}

// If no allowed functions policy is provided, add a default one that disallows all functions.
return [allowedFunctionsPolicy({ allowedFunctions: new Set() }), ...policies];
}
Loading
Loading