Skip to content

Commit e42039e

Browse files
committed
Add Fix command
-Add logging to Fix command -Format the copied directives in project.scala -Add support for test. directives in main scope -Add support for sources with shebang -Move removing directives as the last operation in Fix command -Remove ordering via hardcoded names -Apply review fixes
1 parent ffd6765 commit e42039e

File tree

10 files changed

+901
-23
lines changed

10 files changed

+901
-23
lines changed

modules/build/src/main/scala/scala/build/preprocessing/SheBang.scala

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,19 @@ object SheBang {
77

88
def isShebangScript(content: String): Boolean = sheBangRegex.unanchored.matches(content)
99

10+
/** Returns the shebang section and the content without the shebang section */
11+
def partitionOnShebangSection(content: String): (String, String) =
12+
if (content.startsWith("#!")) {
13+
val regexMatch = sheBangRegex.findFirstMatchIn(content)
14+
regexMatch match {
15+
case Some(firstMatch) =>
16+
(firstMatch.toString(), content.replaceFirst(firstMatch.toString(), ""))
17+
case None => ("", content)
18+
}
19+
}
20+
else
21+
("", content)
22+
1023
def ignoreSheBangLines(content: String): (String, Boolean) =
1124
if (content.startsWith("#!")) {
1225
val regexMatch = sheBangRegex.findFirstMatchIn(content)

modules/cli/src/main/scala/scala/cli/ScalaCliCommands.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ class ScalaCliCommands(
3838
directories.Directories,
3939
doc.Doc,
4040
export0.Export,
41+
fix.Fix,
4142
fmt.Fmt,
4243
new HelpCmd(help),
4344
installcompletions.InstallCompletions,
Lines changed: 366 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,366 @@
1+
package scala.cli.commands.fix
2+
3+
import caseapp.core.RemainingArgs
4+
5+
import scala.build.Ops.EitherMap2
6+
import scala.build.errors.{BuildException, CompositeBuildException}
7+
import scala.build.input.*
8+
import scala.build.internal.Constants
9+
import scala.build.options.{BuildOptions, Scope, SuppressWarningOptions}
10+
import scala.build.preprocessing.directives.*
11+
import scala.build.preprocessing.{ExtractedDirectives, SheBang}
12+
import scala.build.{CrossSources, Logger, Position, Sources}
13+
import scala.cli.commands.shared.SharedOptions
14+
import scala.cli.commands.{ScalaCommand, SpecificationLevel}
15+
import scala.collection.immutable.HashMap
16+
import scala.util.chaining.scalaUtilChainingOps
17+
18+
object Fix extends ScalaCommand[FixOptions] {
19+
override def group = "Main"
20+
override def scalaSpecificationLevel = SpecificationLevel.EXPERIMENTAL
21+
override def sharedOptions(options: FixOptions): Option[SharedOptions] = Some(options.shared)
22+
23+
lazy val targetDirectivesKeysSet = DirectivesPreprocessingUtils.requireDirectiveHandlers
24+
.flatMap(_.keys.flatMap(_.nameAliases)).toSet
25+
lazy val usingDirectivesKeysGrouped = DirectivesPreprocessingUtils.usingDirectiveHandlers
26+
.flatMap(_.keys)
27+
lazy val usingDirectivesWithTestPrefixKeysGrouped =
28+
DirectivesPreprocessingUtils.usingDirectiveWithReqsHandlers
29+
.flatMap(_.keys)
30+
31+
val newLine = System.lineSeparator()
32+
33+
override def runCommand(options: FixOptions, args: RemainingArgs, logger: Logger): Unit = {
34+
val inputs = options.shared.inputs(args.remaining, () => Inputs.default()).orExit(logger)
35+
36+
val (mainSources, testSources) = getProjectSources(inputs)
37+
.left.map(CompositeBuildException(_))
38+
.orExit(logger)
39+
40+
// Only initial inputs are used, new inputs discovered during processing of
41+
// CrossSources.forInput may be shared between projects
42+
val writableInputs: Seq[OnDisk] = inputs.flattened()
43+
.collect { case onDisk: OnDisk => onDisk }
44+
45+
def isExtractedFromWritableInput(position: Option[Position.File]): Boolean = {
46+
val originOrPathOpt = position.map(_.path)
47+
originOrPathOpt match {
48+
case Some(Right(path)) => writableInputs.exists(_.path == path)
49+
case _ => false
50+
}
51+
}
52+
53+
val projectFileContents = new StringBuilder()
54+
55+
given LoggingUtilities(logger, inputs.workspace)
56+
57+
// Deal with directives from the Main scope
58+
val (directivesFromWritableMainInputs, testDirectivesFromMain) = {
59+
val originalMainDirectives = getExtractedDirectives(mainSources)
60+
.filterNot(hasTargetDirectives)
61+
62+
val transformedMainDirectives = unifyCorrespondingNameAliases(originalMainDirectives)
63+
64+
val allDirectives = for {
65+
transformedMainDirective <- transformedMainDirectives
66+
directive <- transformedMainDirective.directives
67+
} yield directive
68+
69+
val (testScopeDirectives, allMainDirectives) =
70+
allDirectives.partition(_.key.startsWith("test"))
71+
72+
createFormattedLinesAndAppend(allMainDirectives, projectFileContents, isTest = false)
73+
74+
(
75+
transformedMainDirectives.filter(d => isExtractedFromWritableInput(d.positions)),
76+
testScopeDirectives
77+
)
78+
}
79+
80+
// Deal with directives from the Test scope
81+
val directivesFromWritableTestInputs: Seq[TransformedTestDirectives] =
82+
if (
83+
testSources.paths.nonEmpty || testSources.inMemory.nonEmpty || testDirectivesFromMain.nonEmpty
84+
) {
85+
val originalTestDirectives = getExtractedDirectives(testSources)
86+
.filterNot(hasTargetDirectives)
87+
88+
val transformedTestDirectives = unifyCorrespondingNameAliases(originalTestDirectives)
89+
.pipe(maybeTransformIntoTestEquivalent)
90+
91+
val allDirectives = for {
92+
directivesWithTestPrefix <- transformedTestDirectives.map(_.withTestPrefix)
93+
directive <- directivesWithTestPrefix ++ testDirectivesFromMain
94+
} yield directive
95+
96+
createFormattedLinesAndAppend(allDirectives, projectFileContents, isTest = true)
97+
98+
transformedTestDirectives
99+
.filter(ttd => isExtractedFromWritableInput(ttd.positions))
100+
}
101+
else Seq(TransformedTestDirectives(Nil, Nil, None))
102+
103+
projectFileContents.append(newLine)
104+
105+
// Write extracted directives to project.scala
106+
logger.message(s"Writing ${Constants.projectFileName}")
107+
os.write.over(inputs.workspace / Constants.projectFileName, projectFileContents.toString)
108+
109+
def isProjectFile(position: Option[Position.File]): Boolean =
110+
position.exists(_.path.contains(inputs.workspace / Constants.projectFileName))
111+
112+
// Remove directives from their original files, skip the project.scala file
113+
directivesFromWritableMainInputs
114+
.filterNot(e => isProjectFile(e.positions))
115+
.foreach(d => removeDirectivesFrom(d.positions))
116+
directivesFromWritableTestInputs
117+
.filterNot(ttd => isProjectFile(ttd.positions))
118+
.foreach(ttd => removeDirectivesFrom(ttd.positions, toKeep = ttd.noTestPrefixAvailable))
119+
}
120+
121+
def getProjectSources(inputs: Inputs): Either[::[BuildException], (Sources, Sources)] = {
122+
val buildOptions = BuildOptions()
123+
124+
val (crossSources, _) = CrossSources.forInputs(
125+
inputs,
126+
preprocessors = Sources.defaultPreprocessors(
127+
buildOptions.archiveCache,
128+
buildOptions.internal.javaClassNameVersionOpt,
129+
() => buildOptions.javaHome().value.javaCommand
130+
),
131+
logger = logger,
132+
suppressWarningOptions = SuppressWarningOptions.suppressAll,
133+
exclude = buildOptions.internal.exclude
134+
).orExit(logger)
135+
136+
val sharedOptions = crossSources.sharedOptions(buildOptions)
137+
val scopedSources = crossSources.scopedSources(sharedOptions).orExit(logger)
138+
139+
val mainSources = scopedSources.sources(Scope.Main, sharedOptions, inputs.workspace)
140+
val testSources = scopedSources.sources(Scope.Test, sharedOptions, inputs.workspace)
141+
142+
(mainSources, testSources).traverseN
143+
}
144+
145+
def getExtractedDirectives(sources: Sources)(
146+
using loggingUtilities: LoggingUtilities
147+
): Seq[ExtractedDirectives] = {
148+
val logger = loggingUtilities.logger
149+
150+
val fromPaths = sources.paths.map { (path, _) =>
151+
val (_, content) = SheBang.partitionOnShebangSection(os.read(path))
152+
logger.debug(s"Extracting directives from ${loggingUtilities.relativePath(path)}")
153+
ExtractedDirectives.from(content.toCharArray, Right(path), logger, _ => None).orExit(logger)
154+
}
155+
156+
val fromInMemory = sources.inMemory.map { inMem =>
157+
val originOrPath = inMem.originalPath.map((_, path) => path)
158+
val content = originOrPath match {
159+
case Right(path) =>
160+
logger.debug(s"Extracting directives from ${loggingUtilities.relativePath(path)}")
161+
os.read(path)
162+
case Left(origin) =>
163+
logger.debug(s"Extracting directives from $origin")
164+
inMem.wrapperParamsOpt match {
165+
// In case of script snippets, we need to drop the top wrapper lines
166+
case Some(wrapperParams) => String(inMem.content)
167+
.linesWithSeparators
168+
.drop(wrapperParams.topWrapperLineCount)
169+
.mkString
170+
case None => String(inMem.content)
171+
}
172+
}
173+
174+
val (_, contentWithNoShebang) = SheBang.partitionOnShebangSection(content)
175+
176+
ExtractedDirectives.from(
177+
contentWithNoShebang.toCharArray,
178+
originOrPath,
179+
logger,
180+
_ => None
181+
).orExit(logger)
182+
}
183+
184+
fromPaths ++ fromInMemory
185+
}
186+
187+
def hasTargetDirectives(extractedDirectives: ExtractedDirectives): Boolean = {
188+
// Filter out all elements that contain using target directives
189+
val directivesInElement = extractedDirectives.directives.map(_.key)
190+
directivesInElement.exists(key => targetDirectivesKeysSet.contains(key))
191+
}
192+
193+
def unifyCorrespondingNameAliases(extractedDirectives: Seq[ExtractedDirectives]) =
194+
extractedDirectives.map { extracted =>
195+
// All keys that we migrate, not all in general
196+
val allKeysGrouped = usingDirectivesKeysGrouped ++ usingDirectivesWithTestPrefixKeysGrouped
197+
val strictDirectives = extracted.directives
198+
199+
val strictDirectivesWithNewKeys = strictDirectives.flatMap { strictDir =>
200+
val newKeyOpt = allKeysGrouped.find(_.nameAliases.contains(strictDir.key))
201+
.flatMap(_.nameAliases.headOption)
202+
.map { key =>
203+
if (key.startsWith("test"))
204+
val withTestStripped = key.stripPrefix("test").stripPrefix(".")
205+
"test." + withTestStripped.take(1).toLowerCase + withTestStripped.drop(1)
206+
else
207+
key
208+
}
209+
210+
newKeyOpt.map(newKey => strictDir.copy(key = newKey))
211+
}
212+
213+
extracted.copy(directives = strictDirectivesWithNewKeys)
214+
}
215+
216+
/** Transforms directives into their 'test.' equivalent if it exists
217+
*
218+
* @param extractedDirectives
219+
* @return
220+
* an instance of TransformedTestDirectives containing transformed directives and those that
221+
* could not be transformed since they have no 'test.' equivalent
222+
*/
223+
def maybeTransformIntoTestEquivalent(extractedDirectives: Seq[ExtractedDirectives])
224+
: Seq[TransformedTestDirectives] =
225+
for {
226+
extractedFromSingleElement <- extractedDirectives
227+
directives = extractedFromSingleElement.directives
228+
} yield {
229+
val (withTestEquivalent, noTestEquivalent) = directives.partition { directive =>
230+
usingDirectivesWithTestPrefixKeysGrouped.exists(
231+
_.nameAliases.contains("test." + directive.key)
232+
)
233+
}
234+
235+
val transformedToTestEquivalents = withTestEquivalent.map {
236+
case StrictDirective(key, values) => StrictDirective("test." + key, values)
237+
}
238+
239+
TransformedTestDirectives(
240+
withTestPrefix = transformedToTestEquivalents,
241+
noTestPrefixAvailable = noTestEquivalent,
242+
positions = extractedFromSingleElement.positions
243+
)
244+
}
245+
246+
def removeDirectivesFrom(
247+
position: Option[Position.File],
248+
toKeep: Seq[StrictDirective] = Nil
249+
)(
250+
using loggingUtilities: LoggingUtilities
251+
): Unit = {
252+
position match {
253+
case Some(Position.File(Right(path), _, _, offset)) =>
254+
val (shebangSection, strippedContent) = SheBang.partitionOnShebangSection(os.read(path))
255+
256+
def ignoreOrAddNewLine(str: String) = if str.isBlank then "" else str + newLine
257+
258+
val keepLines = ignoreOrAddNewLine(shebangSection) + ignoreOrAddNewLine(toKeep.mkString(
259+
"",
260+
newLine,
261+
newLine
262+
))
263+
val newContents = keepLines + strippedContent.drop(offset).stripLeading()
264+
val relativePath = loggingUtilities.relativePath(path)
265+
266+
loggingUtilities.logger.message(s"Removing directives from $relativePath")
267+
if (toKeep.nonEmpty) {
268+
loggingUtilities.logger.message(" Keeping:")
269+
toKeep.foreach(d => loggingUtilities.logger.message(s" $d"))
270+
}
271+
272+
os.write.over(path, newContents.stripLeading())
273+
case _ => ()
274+
}
275+
}
276+
277+
def createFormattedLinesAndAppend(
278+
strictDirectives: Seq[StrictDirective],
279+
projectFileContents: StringBuilder,
280+
isTest: Boolean
281+
): Unit = {
282+
if (strictDirectives.nonEmpty) {
283+
projectFileContents
284+
.append(if (projectFileContents.nonEmpty) newLine else "")
285+
.append(if isTest then "// Test" else "// Main")
286+
.append(newLine)
287+
288+
strictDirectives
289+
// group by key to merge values
290+
.groupBy(_.key)
291+
.map { (key, directives) =>
292+
StrictDirective(key, directives.flatMap(_.values))
293+
}
294+
// group by key prefixes to create splits between groups
295+
.groupBy(dir => (if (isTest) dir.key.stripPrefix("test.") else dir.key).takeWhile(_ != '.'))
296+
.map { (_, directives) =>
297+
directives.flatMap(_.explodeToStringsWithColLimit()).toSeq.sorted
298+
}
299+
.toSeq
300+
.filter(_.nonEmpty)
301+
.sortBy(_.head)(using directivesOrdering)
302+
// append groups to the StringBuilder, add new lines between groups that are bigger than one line
303+
.foldLeft(0) { (lastSize, directiveLines) =>
304+
val newSize = directiveLines.size
305+
if (lastSize > 1 || (lastSize != 0 && newSize > 1)) projectFileContents.append(newLine)
306+
307+
directiveLines.foreach(projectFileContents.append(_).append(newLine))
308+
309+
newSize
310+
}
311+
}
312+
}
313+
314+
case class TransformedTestDirectives(
315+
withTestPrefix: Seq[StrictDirective],
316+
noTestPrefixAvailable: Seq[StrictDirective],
317+
positions: Option[Position.File]
318+
)
319+
320+
case class LoggingUtilities(
321+
logger: Logger,
322+
workspacePath: os.Path
323+
) {
324+
def relativePath(path: os.Path) =
325+
if (path.startsWith(workspacePath)) path.relativeTo(workspacePath)
326+
else path
327+
}
328+
329+
private val directivesOrdering: Ordering[String] = {
330+
def directivesOrder(key: String): Int = {
331+
val handlersOrder = Seq(
332+
ScalaVersion.handler.keys,
333+
Platform.handler.keys,
334+
Jvm.handler.keys,
335+
JavaHome.handler.keys,
336+
ScalaNative.handler.keys,
337+
ScalaJs.handler.keys,
338+
ScalacOptions.handler.keys,
339+
JavaOptions.handler.keys,
340+
JavacOptions.handler.keys,
341+
JavaProps.handler.keys,
342+
MainClass.handler.keys,
343+
scala.build.preprocessing.directives.Sources.handler.keys,
344+
ObjectWrapper.handler.keys,
345+
Toolkit.handler.keys,
346+
Dependency.handler.keys
347+
)
348+
349+
handlersOrder.zipWithIndex
350+
.find(_._1.flatMap(_.nameAliases).contains(key))
351+
.map(_._2)
352+
.getOrElse(if key.startsWith("publish") then 20 else 15)
353+
}
354+
355+
Ordering.by { directiveLine =>
356+
val key = directiveLine
357+
.stripPrefix("//> using")
358+
.stripLeading()
359+
.stripPrefix("test.")
360+
// separate key from value
361+
.takeWhile(!_.isWhitespace)
362+
363+
directivesOrder(key)
364+
}
365+
}
366+
}

0 commit comments

Comments
 (0)