Skip to content

Commit b5deb41

Browse files
authored
[Jinja] Everything is an identifier... & other improvements (#1418)
- Aligned usage of identifiers w/ official jinja implementation (anything can now be an identifier) - We now support comment parsing and formatting (meaning this won't be preprocessed away when formatting). Rendering remains unchanged of course (ignore comments). - Differentiate between Integer and Float types, which is something Python does, but JavaScript doesn't, so we needed to take special care here. - Add support for new statement types: - filter: `{% filter %}...{% endfilter %}` - call: `{% call %}...{% endcall %}` - (custom) generation: `{% generation %}{% endgeneration %}` - Add support for new expression types: - spread: `{{ fn(*args) }}` - ternary (previously we just used the if expression): `{{ 1 if true else 2 }}` - Add support for new functions: - `replace` - `strftime_now` (get current time according to a narrow set of time templates, found in the wild) - and many others - Reduce number of redundant brackets when formatting membership and property accesses. - Improved binary operator precedence rules and general formatting rules. - Fixed edge-case lexing issues This now means we support (at least) parsing, formatting, and rendering of the top 100,000 transformers-compatible chat templates on the Hugging Face Hub 🥳
1 parent 88989d2 commit b5deb41

File tree

11 files changed

+2552
-1054
lines changed

11 files changed

+2552
-1054
lines changed

packages/jinja/src/ast.ts

Lines changed: 58 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,13 @@ export class Macro extends Statement {
7777
}
7878
}
7979

80+
export class Comment extends Statement {
81+
override type = "Comment";
82+
constructor(public value: string) {
83+
super();
84+
}
85+
}
86+
8087
/**
8188
* Expressions will result in a value at runtime (unlike statements).
8289
*/
@@ -133,11 +140,12 @@ abstract class Literal<T> extends Expression {
133140
}
134141
}
135142

136-
/**
137-
* Represents a numeric constant in the template.
138-
*/
139-
export class NumericLiteral extends Literal<number> {
140-
override type = "NumericLiteral";
143+
export class IntegerLiteral extends Literal<number> {
144+
override type = "IntegerLiteral";
145+
}
146+
147+
export class FloatLiteral extends Literal<number> {
148+
override type = "FloatLiteral";
141149
}
142150

143151
/**
@@ -147,20 +155,6 @@ export class StringLiteral extends Literal<string> {
147155
override type = "StringLiteral";
148156
}
149157

150-
/**
151-
* Represents a boolean constant in the template.
152-
*/
153-
export class BooleanLiteral extends Literal<boolean> {
154-
override type = "BooleanLiteral";
155-
}
156-
157-
/**
158-
* Represents null (none) in the template.
159-
*/
160-
export class NullLiteral extends Literal<null> {
161-
override type = "NullLiteral";
162-
}
163-
164158
/**
165159
* Represents an array literal in the template.
166160
*/
@@ -214,15 +208,28 @@ export class FilterExpression extends Expression {
214208
}
215209
}
216210

211+
export class FilterStatement extends Statement {
212+
override type = "FilterStatement";
213+
214+
constructor(
215+
public filter: Identifier | CallExpression,
216+
public body: Statement[]
217+
) {
218+
super();
219+
}
220+
}
221+
217222
/**
218223
* An operation which filters a sequence of objects by applying a test to each object,
219224
* and only selecting the objects with the test succeeding.
225+
*
226+
* It may also be used as a shortcut for a ternary operator.
220227
*/
221228
export class SelectExpression extends Expression {
222229
override type = "SelectExpression";
223230

224231
constructor(
225-
public iterable: Expression,
232+
public lhs: Expression,
226233
public test: Expression
227234
) {
228235
super();
@@ -258,17 +265,6 @@ export class UnaryExpression extends Expression {
258265
}
259266
}
260267

261-
/**
262-
* Logical negation of an expression.
263-
*/
264-
export class LogicalNegationExpression extends Expression {
265-
override type = "LogicalNegationExpression";
266-
267-
constructor(public argument: Expression) {
268-
super();
269-
}
270-
}
271-
272268
export class SliceExpression extends Expression {
273269
override type = "SliceExpression";
274270

@@ -291,3 +287,34 @@ export class KeywordArgumentExpression extends Expression {
291287
super();
292288
}
293289
}
290+
291+
export class SpreadExpression extends Expression {
292+
override type = "SpreadExpression";
293+
294+
constructor(public argument: Expression) {
295+
super();
296+
}
297+
}
298+
299+
export class CallStatement extends Statement {
300+
override type = "CallStatement";
301+
302+
constructor(
303+
public call: CallExpression,
304+
public callerArgs: Expression[] | null,
305+
public body: Statement[]
306+
) {
307+
super();
308+
}
309+
}
310+
311+
export class Ternary extends Expression {
312+
override type = "Ternary";
313+
constructor(
314+
public condition: Expression,
315+
public trueExpr: Expression,
316+
public falseExpr: Expression
317+
) {
318+
super();
319+
}
320+
}

packages/jinja/src/format.ts

Lines changed: 96 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import type {
22
Program,
33
Statement,
4+
Comment,
45
If,
56
For,
67
SetStatement,
@@ -9,9 +10,9 @@ import type {
910
MemberExpression,
1011
CallExpression,
1112
Identifier,
12-
NumericLiteral,
13+
FloatLiteral,
14+
IntegerLiteral,
1315
StringLiteral,
14-
BooleanLiteral,
1516
ArrayLiteral,
1617
TupleLiteral,
1718
ObjectLiteral,
@@ -20,20 +21,33 @@ import type {
2021
SelectExpression,
2122
TestExpression,
2223
UnaryExpression,
23-
LogicalNegationExpression,
2424
SliceExpression,
2525
KeywordArgumentExpression,
26+
CallStatement,
27+
FilterStatement,
28+
SpreadExpression,
29+
Ternary,
2630
} from "./ast";
2731

2832
const NEWLINE = "\n";
2933
const OPEN_STATEMENT = "{%- ";
3034
const CLOSE_STATEMENT = " -%}";
3135

32-
const OPERATOR_PRECEDENCE: Record<string, number> = {
33-
MultiplicativeBinaryOperator: 2,
34-
AdditiveBinaryOperator: 1,
35-
ComparisonBinaryOperator: 0,
36-
};
36+
function getBinaryOperatorPrecedence(expr: BinaryExpression): number {
37+
switch (expr.operator.type) {
38+
case "MultiplicativeBinaryOperator":
39+
return 4;
40+
case "AdditiveBinaryOperator":
41+
return 3;
42+
case "ComparisonBinaryOperator":
43+
return 2;
44+
case "Identifier":
45+
if (expr.operator.value === "and") return 1;
46+
if (expr.operator.value === "in" || expr.operator.value === "not in") return 2;
47+
return 0;
48+
}
49+
return 0;
50+
}
3751

3852
export function format(program: Program, indent: string | number = "\t"): string {
3953
const indentStr = typeof indent === "number" ? " ".repeat(indent) : indent;
@@ -66,6 +80,12 @@ function formatStatement(node: Statement, depth: number, indentStr: string): str
6680
return pad + createStatement("break");
6781
case "Continue":
6882
return pad + createStatement("continue");
83+
case "CallStatement":
84+
return formatCallStatement(node as CallStatement, depth, indentStr);
85+
case "FilterStatement":
86+
return formatFilterStatement(node as FilterStatement, depth, indentStr);
87+
case "Comment":
88+
return pad + "{# " + (node as Comment).value + " #}";
6989
default:
7090
return pad + "{{- " + formatExpression(node as Expression) + " -}}";
7191
}
@@ -93,7 +113,7 @@ function formatIf(node: If, depth: number, indentStr: string): string {
93113
formatStatements(clauses[0].body, depth + 1, indentStr);
94114

95115
// ELIF(s)
96-
for (let i = 1; i < clauses.length; i++) {
116+
for (let i = 1; i < clauses.length; ++i) {
97117
out +=
98118
NEWLINE +
99119
pad +
@@ -119,7 +139,7 @@ function formatFor(node: For, depth: number, indentStr: string): string {
119139
if (node.iterable.type === "SelectExpression") {
120140
// Handle special case: e.g., `for x in [1, 2, 3] if x > 2`
121141
const n = node.iterable as SelectExpression;
122-
formattedIterable = `${formatExpression(n.iterable)} if ${formatExpression(n.test)}`;
142+
formattedIterable = `${formatExpression(n.lhs)} if ${formatExpression(n.test)}`;
123143
} else {
124144
formattedIterable = formatExpression(node.iterable);
125145
}
@@ -166,20 +186,46 @@ function formatMacro(node: Macro, depth: number, indentStr: string): string {
166186
);
167187
}
168188

189+
function formatCallStatement(node: CallStatement, depth: number, indentStr: string): string {
190+
const pad = indentStr.repeat(depth);
191+
const params =
192+
node.callerArgs && node.callerArgs.length > 0 ? `(${node.callerArgs.map(formatExpression).join(", ")})` : "";
193+
const callExpr = formatExpression(node.call);
194+
let out = pad + createStatement(`call${params}`, callExpr) + NEWLINE;
195+
out += formatStatements(node.body, depth + 1, indentStr) + NEWLINE;
196+
out += pad + createStatement("endcall");
197+
return out;
198+
}
199+
200+
function formatFilterStatement(node: FilterStatement, depth: number, indentStr: string): string {
201+
const pad = indentStr.repeat(depth);
202+
const spec =
203+
node.filter.type === "Identifier"
204+
? (node.filter as Identifier).value
205+
: formatExpression(node.filter as CallExpression);
206+
let out = pad + createStatement("filter", spec) + NEWLINE;
207+
out += formatStatements(node.body, depth + 1, indentStr) + NEWLINE;
208+
out += pad + createStatement("endfilter");
209+
return out;
210+
}
211+
169212
function formatExpression(node: Expression, parentPrec: number = -1): string {
170213
switch (node.type) {
214+
case "SpreadExpression": {
215+
const n = node as SpreadExpression;
216+
return `*${formatExpression(n.argument)}`;
217+
}
171218
case "Identifier":
172219
return (node as Identifier).value;
173-
case "NullLiteral":
174-
return "none";
175-
case "NumericLiteral":
176-
case "BooleanLiteral":
177-
return `${(node as NumericLiteral | BooleanLiteral).value}`;
220+
case "IntegerLiteral":
221+
return `${(node as IntegerLiteral).value}`;
222+
case "FloatLiteral":
223+
return `${(node as FloatLiteral).value}`;
178224
case "StringLiteral":
179225
return JSON.stringify((node as StringLiteral).value);
180226
case "BinaryExpression": {
181227
const n = node as BinaryExpression;
182-
const thisPrecedence = OPERATOR_PRECEDENCE[n.operator.type] ?? 0;
228+
const thisPrecedence = getBinaryOperatorPrecedence(n);
183229
const left = formatExpression(n.left, thisPrecedence);
184230
const right = formatExpression(n.right, thisPrecedence + 1);
185231
const expr = `${left} ${n.operator.value} ${right}`;
@@ -190,20 +236,31 @@ function formatExpression(node: Expression, parentPrec: number = -1): string {
190236
const val = n.operator.value + (n.operator.value === "not" ? " " : "") + formatExpression(n.argument, Infinity);
191237
return val;
192238
}
193-
case "LogicalNegationExpression":
194-
return `not ${formatExpression((node as LogicalNegationExpression).argument, Infinity)}`;
195239
case "CallExpression": {
196240
const n = node as CallExpression;
197-
const args = n.args.map((a) => formatExpression(a, -1)).join(", ");
198-
return `${formatExpression(n.callee, -1)}(${args})`;
241+
const args = n.args.map(formatExpression).join(", ");
242+
return `${formatExpression(n.callee)}(${args})`;
199243
}
200244
case "MemberExpression": {
201245
const n = node as MemberExpression;
202-
let obj = formatExpression(n.object, -1);
203-
if (n.object.type !== "Identifier") {
246+
let obj = formatExpression(n.object);
247+
// only wrap if it's not a simple or chained access/call
248+
if (
249+
![
250+
"Identifier",
251+
"MemberExpression",
252+
"CallExpression",
253+
"StringLiteral",
254+
"IntegerLiteral",
255+
"FloatLiteral",
256+
"ArrayLiteral",
257+
"TupleLiteral",
258+
"ObjectLiteral",
259+
].includes(n.object.type)
260+
) {
204261
obj = `(${obj})`;
205262
}
206-
let prop = formatExpression(n.property, -1);
263+
let prop = formatExpression(n.property);
207264
if (!n.computed && n.property.type !== "Identifier") {
208265
prop = `(${prop})`;
209266
}
@@ -213,48 +270,47 @@ function formatExpression(node: Expression, parentPrec: number = -1): string {
213270
const n = node as FilterExpression;
214271
const operand = formatExpression(n.operand, Infinity);
215272
if (n.filter.type === "CallExpression") {
216-
return `${operand} | ${formatExpression(n.filter, -1)}`;
273+
return `${operand} | ${formatExpression(n.filter)}`;
217274
}
218275
return `${operand} | ${(n.filter as Identifier).value}`;
219276
}
220277
case "SelectExpression": {
221278
const n = node as SelectExpression;
222-
return `${formatExpression(n.iterable, -1)} | select(${formatExpression(n.test, -1)})`;
279+
return `${formatExpression(n.lhs)} if ${formatExpression(n.test)}`;
223280
}
224281
case "TestExpression": {
225282
const n = node as TestExpression;
226-
return `${formatExpression(n.operand, -1)} is${n.negate ? " not" : ""} ${n.test.value}`;
283+
return `${formatExpression(n.operand)} is${n.negate ? " not" : ""} ${n.test.value}`;
227284
}
228285
case "ArrayLiteral":
229286
case "TupleLiteral": {
230-
const elems = ((node as ArrayLiteral | TupleLiteral).value as Expression[]).map((e) => formatExpression(e, -1));
287+
const elems = ((node as ArrayLiteral | TupleLiteral).value as Expression[]).map(formatExpression);
231288
const brackets = node.type === "ArrayLiteral" ? "[]" : "()";
232289
return `${brackets[0]}${elems.join(", ")}${brackets[1]}`;
233290
}
234291
case "ObjectLiteral": {
235292
const entries = Array.from((node as ObjectLiteral).value.entries()).map(
236-
([k, v]) => `${formatExpression(k, -1)}: ${formatExpression(v, -1)}`
293+
([k, v]) => `${formatExpression(k)}: ${formatExpression(v)}`
237294
);
238-
return `{ ${entries.join(", ")} }`;
295+
return `{${entries.join(", ")}}`;
239296
}
240297
case "SliceExpression": {
241298
const n = node as SliceExpression;
242-
const s = n.start ? formatExpression(n.start, -1) : "";
243-
const t = n.stop ? formatExpression(n.stop, -1) : "";
244-
const st = n.step ? `:${formatExpression(n.step, -1)}` : "";
299+
const s = n.start ? formatExpression(n.start) : "";
300+
const t = n.stop ? formatExpression(n.stop) : "";
301+
const st = n.step ? `:${formatExpression(n.step)}` : "";
245302
return `${s}:${t}${st}`;
246303
}
247304
case "KeywordArgumentExpression": {
248305
const n = node as KeywordArgumentExpression;
249-
return `${n.key.value}=${formatExpression(n.value, -1)}`;
306+
return `${n.key.value}=${formatExpression(n.value)}`;
250307
}
251-
case "If": {
252-
// Special case for ternary operator (If as an expression, not a statement)
253-
const n = node as If;
254-
const test = formatExpression(n.test, -1);
255-
const body = formatExpression(n.body[0], 0); // Ternary operators have a single body and alternate
256-
const alternate = formatExpression(n.alternate[0], -1);
257-
return `${body} if ${test} else ${alternate}`;
308+
case "Ternary": {
309+
const n = node as Ternary;
310+
const expr = `${formatExpression(n.trueExpr)} if ${formatExpression(n.condition, 0)} else ${formatExpression(
311+
n.falseExpr
312+
)}`;
313+
return parentPrec > -1 ? `(${expr})` : expr;
258314
}
259315
default:
260316
throw new Error(`Unknown expression type: ${node.type}`);

0 commit comments

Comments
 (0)