diff --git a/docs/compiler_options.md b/docs/compiler_options.md index f09f45227..15b92698c 100644 --- a/docs/compiler_options.md +++ b/docs/compiler_options.md @@ -26,6 +26,7 @@ return { | `--gen-target` | `gen_target` | `string` | `build` `gen` `run` | Minimum targeted Lua version for generated code. Options are `5.1`, `5.3` and `5.4`. See [below](#generated-code) for details. | | `include` | `{string}` | `build` | The set of files to compile/check. See below for details on patterns. | | `exclude` | `{string}` | `build` | The set of files to exclude. See below for details on patterns. +| `--keep-hashbang` | | | `gen` | Preserve hashbang line (`#!`) at the top of file if present. | `-s --source-dir` | `source_dir` | `string` | `build` | Set the directory to be searched for files. `build` will compile every .tl file in every subdirectory by default. | `-b --build-dir` | `build_dir` | `string` | `build` | Set the directory for generated files, mimicking the file structure of the source files. | | `files` | `{string}` | `build` | The names of files to be compiled. Does not accept patterns like `include`. diff --git a/spec/cli/gen_spec.lua b/spec/cli/gen_spec.lua index 2d79e03ad..d3132871d 100644 --- a/spec/cli/gen_spec.lua +++ b/spec/cli/gen_spec.lua @@ -72,6 +72,16 @@ end local c = 100 ]] +local script_with_hashbang = [[ +#!/usr/bin/env lua +print("hello world") +]] + +local script_without_hashbang = [[ + +print("hello world") +]] + local function tl_to_lua(name) return (name:gsub("%.tl$", ".lua"):gsub("^" .. util.os_tmp .. util.os_sep, "")) end @@ -185,6 +195,26 @@ describe("tl gen", function() end) end) + it("preserves hashbang with --keep-hashbang", function() + local name = util.write_tmp_file(finally, script_with_hashbang) + local pd = io.popen(util.tl_cmd("gen", "--keep-hashbang", name), "r") + local output = pd:read("*a") + util.assert_popen_close(0, pd:close()) + local lua_name = tl_to_lua(name) + assert.match("Wrote: " .. lua_name, output, 1, true) + util.assert_line_by_line(script_with_hashbang, util.read_file(lua_name)) + end) + + it("drops hashbang when not using --keep-hashbang", function() + local name = util.write_tmp_file(finally, script_with_hashbang) + local pd = io.popen(util.tl_cmd("gen", name), "r") + local output = pd:read("*a") + util.assert_popen_close(0, pd:close()) + local lua_name = tl_to_lua(name) + assert.match("Wrote: " .. lua_name, output, 1, true) + util.assert_line_by_line(script_without_hashbang, util.read_file(lua_name)) + end) + describe("with --gen-target=5.1", function() it("targets generated code to Lua 5.1+", function() local name = util.write_tmp_file(finally, [[ diff --git a/spec/lexer/hashbang_spec.lua b/spec/lexer/hashbang_spec.lua index e249722ca..7d69f1ccb 100644 --- a/spec/lexer/hashbang_spec.lua +++ b/spec/lexer/hashbang_spec.lua @@ -12,9 +12,9 @@ describe("lexer", function() it("skips hashbang at the beginning of a file", function() local syntax_errors = {} local tokens = tl.lex("#!/usr/bin/env lua\nlocal x = 1") + assert.same({"#!/usr/bin/env lua\n", "local", "x", "=", "1", "$EOF$"}, map(function(x) return x.tk end, tokens)) + tl.parse_program(tokens, syntax_errors) assert.same({}, syntax_errors) - assert.same(5, #tokens) - assert.same({"local", "x", "=", "1", "$EOF$"}, map(function(x) return x.tk end, tokens)) end) end) diff --git a/tl b/tl index c8c49339d..a7f214784 100755 --- a/tl +++ b/tl @@ -230,7 +230,7 @@ local function type_check_and_load(tlconfig, filename) return chunk end -local function write_out(tlconfig, result, output_file) +local function write_out(tlconfig, result, output_file, pp_opts) if tlconfig["pretend"] then print("Would Write: " .. output_file) return @@ -243,7 +243,7 @@ local function write_out(tlconfig, result, output_file) end local _ - _, err = ofd:write(tl.pretty_print_ast(result.ast, tlconfig.gen_target) .. "\n") + _, err = ofd:write(tl.pretty_print_ast(result.ast, tlconfig.gen_target, pp_opts) .. "\n") if err then die("error writing " .. output_file .. ": " .. err) end @@ -863,6 +863,7 @@ local function get_args_parser() local gen_command = parser:command("gen", "Generate a Lua file for one or more Teal files.") gen_command:argument("file", "The Teal source file."):args("+") gen_command:flag("-c --check", "Type check and fail on type errors.") + gen_command:flag("--keep-hashbang", "Preserve hashbang line (#!) at the top of file if present.") gen_command:option("-o --output", "Write to instead.") :argname("") @@ -1227,9 +1228,15 @@ commands["gen"] = function(tlconfig, args) local results = {} local err local env + local pp_opts for i, input_file in ipairs(args["file"]) do if not env then env = setup_env(tlconfig, input_file) + pp_opts = { + preserve_indent = true, + preserve_newlines = true, + preserve_hashbang = args["keep_hashbang"] + } end local res = { @@ -1248,7 +1255,7 @@ commands["gen"] = function(tlconfig, args) for _, res in ipairs(results) do if #res.tl_result.syntax_errors == 0 then - write_out(tlconfig, res.tl_result, args["output"] or res.output_file) + write_out(tlconfig, res.tl_result, args["output"] or res.output_file, pp_opts) end end diff --git a/tl.lua b/tl.lua index 057639120..20e41c077 100644 --- a/tl.lua +++ b/tl.lua @@ -1,7 +1,13 @@ local _tl_compat; if (tonumber((_VERSION or ''):match('[%d.]*$')) or 0) < 5.3 then local p, m = pcall(require, 'compat53.module'); if p then _tl_compat = m end end; local assert = _tl_compat and _tl_compat.assert or assert; local debug = _tl_compat and _tl_compat.debug or debug; local io = _tl_compat and _tl_compat.io or io; local ipairs = _tl_compat and _tl_compat.ipairs or ipairs; local load = _tl_compat and _tl_compat.load or load; local math = _tl_compat and _tl_compat.math or math; local _tl_math_maxinteger = math.maxinteger or math.pow(2, 53); local os = _tl_compat and _tl_compat.os or os; local package = _tl_compat and _tl_compat.package or package; local pairs = _tl_compat and _tl_compat.pairs or pairs; local string = _tl_compat and _tl_compat.string or string; local table = _tl_compat and _tl_compat.table or table; local _tl_table_unpack = unpack or table.unpack local VERSION = "0.15.3+dev" -local tl = {TypeCheckOptions = {}, Env = {}, Symbol = {}, Result = {}, Error = {}, TypeInfo = {}, TypeReport = {}, TypeReportEnv = {}, } +local tl = {PrettyPrintOptions = {}, TypeCheckOptions = {}, Env = {}, Symbol = {}, Result = {}, Error = {}, TypeInfo = {}, TypeReport = {}, TypeReportEnv = {}, } + + + + + + @@ -217,6 +223,7 @@ tl.typecodes = { + local TL_DEBUG = os.getenv("TL_DEBUG") local TL_DEBUG_MAXLINE = _tl_math_maxinteger @@ -279,6 +286,7 @@ end + do @@ -592,10 +600,12 @@ do local len = #input if input:sub(1, 2) == "#!" then + begin_token() i = input:find("\n") if not i then i = len + 1 end + end_token_here("hashbang") y = 2 x = 0 end @@ -1327,6 +1337,7 @@ local is_attribute = attributes + local function is_array_type(t) @@ -3164,7 +3175,16 @@ function tl.parse_program(tokens, errs, filename) filename = filename or "", required_modules = {}, } - local _, node = parse_statements(ps, 1, true) + local i = 1 + local hashbang + if ps.tokens[i].kind == "hashbang" then + hashbang = ps.tokens[i].tk + i = i + 1 + end + local _, node = parse_statements(ps, i, true) + if hashbang then + node.hashbang = hashbang + end clear_redundant_errors(errs) return node, ps.required_modules @@ -3689,18 +3709,16 @@ local spaced_op = { } - - - - local default_pretty_print_ast_opts = { preserve_indent = true, preserve_newlines = true, + preserve_hashbang = false, } local fast_pretty_print_ast_opts = { preserve_indent = false, preserve_newlines = true, + preserve_hashbang = false, } local primitive = { @@ -3837,6 +3855,9 @@ function tl.pretty_print_ast(ast, gen_target, mode) ["statements"] = { after = function(node, children) local out = { y = node.y, h = 0 } + if opts.preserve_hashbang and node.hashbang then + table.insert(out, node.hashbang) + end local space for i, child in ipairs(children) do add_child(out, child, space, indent) @@ -10854,7 +10875,7 @@ function tl.process_string(input, is_lua, env, filename, module_name) return result end -tl.gen = function(input, env) +tl.gen = function(input, env, pp) env = env or assert(tl.init_env(), "Default environment initialization failed") local result = tl.process_string(input, false, env) @@ -10863,7 +10884,7 @@ tl.gen = function(input, env) end local code - code, result.gen_error = tl.pretty_print_ast(result.ast, env.gen_target) + code, result.gen_error = tl.pretty_print_ast(result.ast, env.gen_target, pp) return code, result end diff --git a/tl.tl b/tl.tl index 80d230066..7b0ba54c1 100644 --- a/tl.tl +++ b/tl.tl @@ -24,6 +24,12 @@ local record tl "5.4" end + record PrettyPrintOptions + preserve_indent: boolean + preserve_newlines: boolean + preserve_hashbang: boolean + end + record TypeCheckOptions lax: boolean filename: string @@ -125,7 +131,7 @@ local record tl load: function(string, string, LoadMode, {any:any}): LoadFunction, string process: function(string, Env, string, FILE): (Result, string) process_string: function(string, boolean, Env, string, string): Result - gen: function(string, Env): string, Result + gen: function(string, Env, PrettyPrintOptions): string, Result type_check: function(Node, TypeCheckOptions): Result, string init_env: function(boolean, boolean | CompatMode, TargetMode, {string}): Env, string version: function(): string @@ -204,6 +210,7 @@ local type Result = tl.Result local type Env = tl.Env local type Error = tl.Error local type CompatMode = tl.CompatMode +local type PrettyPrintOptions = tl.PrettyPrintOptions local type TypeCheckOptions = tl.TypeCheckOptions local type LoadMode = tl.LoadMode local type LoadFunction = tl.LoadFunction @@ -258,6 +265,7 @@ end -------------------------------------------------------------------------------- local enum TokenKind + "hashbang" "keyword" "op" "string" @@ -592,10 +600,12 @@ do local len = #input if input:sub(1,2) == "#!" then + begin_token() i = input:find("\n") if not i then i = len + 1 end + end_token_here("hashbang") y = 2 x = 0 end @@ -1250,6 +1260,7 @@ local record Node kind: NodeKind symbol_list_slot: integer semicolon: boolean + hashbang: string is_longstring: boolean @@ -3164,7 +3175,16 @@ function tl.parse_program(tokens: {Token}, errs: {Error}, filename: string): Nod filename = filename or "", required_modules = {}, } - local _, node = parse_statements(ps, 1, true) + local i = 1 + local hashbang: string + if ps.tokens[i].kind == "hashbang" then + hashbang = ps.tokens[i].tk + i = i + 1 + end + local _, node = parse_statements(ps, i, true) + if hashbang then + node.hashbang = hashbang + end clear_redundant_errors(errs) return node, ps.required_modules @@ -3688,19 +3708,17 @@ local spaced_op: {integer:{string:boolean}} = { }, } -local record PrettyPrintOpts - preserve_indent: boolean - preserve_newlines: boolean -end -local default_pretty_print_ast_opts: PrettyPrintOpts = { +local default_pretty_print_ast_opts: PrettyPrintOptions = { preserve_indent = true, preserve_newlines = true, + preserve_hashbang = false, } -local fast_pretty_print_ast_opts: PrettyPrintOpts = { +local fast_pretty_print_ast_opts: PrettyPrintOptions = { preserve_indent = false, preserve_newlines = true, + preserve_hashbang = false, } local primitive: {TypeName:string} = { @@ -3714,12 +3732,12 @@ local primitive: {TypeName:string} = { ["thread"] = "thread", } -function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | PrettyPrintOpts): string, string +function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | PrettyPrintOptions): string, string local err: string local indent = 0 - local opts: PrettyPrintOpts - if mode is PrettyPrintOpts then + local opts: PrettyPrintOptions + if mode is PrettyPrintOptions then opts = mode elseif mode == true then opts = fast_pretty_print_ast_opts @@ -3837,6 +3855,9 @@ function tl.pretty_print_ast(ast: Node, gen_target: TargetMode, mode: boolean | ["statements"] = { after = function(node: Node, children: {Output}): Output local out: Output = { y = node.y, h = 0 } + if opts.preserve_hashbang and node.hashbang then + table.insert(out, node.hashbang) + end local space: string for i, child in ipairs(children) do add_child(out, child, space, indent) @@ -10854,7 +10875,7 @@ function tl.process_string(input: string, is_lua: boolean, env: Env, filename: s return result end -tl.gen = function(input: string, env: Env): string, Result +tl.gen = function(input: string, env: Env, pp: PrettyPrintOptions): string, Result env = env or assert(tl.init_env(), "Default environment initialization failed") local result = tl.process_string(input, false, env) @@ -10863,7 +10884,7 @@ tl.gen = function(input: string, env: Env): string, Result end local code: string - code, result.gen_error = tl.pretty_print_ast(result.ast, env.gen_target) + code, result.gen_error = tl.pretty_print_ast(result.ast, env.gen_target, pp) return code, result end