From d5101ca34bce8d90210a824de413bd46fc6fef18 Mon Sep 17 00:00:00 2001 From: Evgeny Date: Mon, 27 May 2024 16:40:17 +0400 Subject: [PATCH] Update utils.lua as of AIO Launcher 5.3.1 --- README.md | 3 + lib/utils.lua | 122 +++++++++++++++----- samples/utils-tests.lua | 242 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 342 insertions(+), 25 deletions(-) create mode 100644 samples/utils-tests.lua diff --git a/README.md b/README.md index e245ae8..e8d285c 100644 --- a/README.md +++ b/README.md @@ -775,6 +775,9 @@ The standard Lua API is extended with the following features: * `string:split(delimeter)` - splits the string using the specified delimiter and returns a table; * `string:replace(regexp, string)` - replaces the text found by the regular expression with another text; +* `string:trim()` - removes leading and trailing spaces from the string; +* `string:starts_with(substring)` - returns true if the string starts with the specified substring; +* `string:ends_with(substring)` - returns true if the string ends with the specified substring; * `slice(table, start, end)` - returns the part of the table starting with the `start` index and ending with `end` index; * `index(table, value)` - returns the index of the table element; * `key(table, value)` - returns the key of the table element; diff --git a/lib/utils.lua b/lib/utils.lua index 0918600..2141282 100644 --- a/lib/utils.lua +++ b/lib/utils.lua @@ -1,5 +1,25 @@ -- Standard AIO Launcher library +function string:trim() + if #self == 0 then return self end + + return self:match("^%s*(.-)%s*$") +end + +function string:starts_with(start_str) + if #self == 0 or #self < #start_str then return false end + if start_str == nil or start_str == "" then return true end + + return self:sub(1, #start_str) == start_str +end + +function string:ends_with(end_str) + if #self == 0 or #self < #end_str then return false end + if end_str == nil or end_str == "" then return true end + + return self:sub(-#end_str) == end_str +end + function string:split(sep) if sep == nil then sep = "%s" @@ -39,11 +59,6 @@ function index(tab, val) return 0 end --- Deprecated -function get_index(tab, val) - return index(tab, val) -end - function key(tab, val) for index, value in pairs(tab) do if value == val then @@ -54,20 +69,19 @@ function key(tab, val) return 0 end --- Deprecated -function get_key(tab, val) - return key(tab, val) -end - function concat(t1, t2) for _,v in ipairs(t2) do table.insert(t1, v) end end --- Deprecated -function concat_tables(t1, t2) - concat(t1, t2) +function contains(table, val) + for i=1, #table do + if table[i] == val then + return true + end + end + return false end function reverse(tab) @@ -78,20 +92,20 @@ function reverse(tab) end function serialize(tab, ind) - ind = ind and (ind .. " ") or " " + ind = ind and (ind .. " ") or " " local nl = "\n" local str = "{" .. nl for k, v in pairs(tab) do local pr = (type(k)=="string") and ("[\"" .. k .. "\"] = ") or "" str = str .. ind .. pr if type(v) == "table" then - str = str .. serialize(v, ind) .. "," + str = str .. serialize(v, ind) .. ",\n" elseif type(v) == "string" then - str = str .. "\"" .. tostring(v) .. "\"," + str = str .. "\"" .. tostring(v) .. "\",\n" elseif type(v) == "number" or type(v) == "boolean" then - str = str .. tostring(v) .. "," + str = str .. tostring(v) .. ",\n" else - str = str .. "[[" .. tostring(v) .. "]]," + str = str .. "[[" .. tostring(v) .. "]],\n" end end str = str:gsub(".$","") @@ -99,6 +113,21 @@ function serialize(tab, ind) return str end +function deep_copy(orig) + local orig_type = type(orig) + local copy + if orig_type == 'table' then + copy = {} + for orig_key, orig_value in next, orig, nil do + copy[deep_copy(orig_key)] = deep_copy(orig_value) + end + setmetatable(copy, deep_copy(getmetatable(orig))) + else + copy = orig + end + return copy +end + function round(x, n) local n = math.pow(10, n or 0) local x = x * n @@ -109,13 +138,34 @@ end function use(module, ...) for k,v in pairs(module) do if _G[k] then - io.stderr:write("use: skipping duplicate symbol ", k, "\n") + print("use: skipping duplicate symbol ", k, "\n") else _G[k] = module[k] end end end +function for_each(tbl, callback) + for index, value in ipairs(tbl) do + callback(value, index, tbl) + end +end + +-- Deprecated +function get_index(tab, val) + return index(tab, val) +end + +-- Deprecated +function get_key(tab, val) + return key(tab, val) +end + +-- Deprecated +function concat_tables(t1, t2) + concat(t1, t2) +end + -- Functional Library -- -- @file functional.lua @@ -146,6 +196,29 @@ function filter(func, tbl) return newtbl end +-- skip(table, N) +-- e.g: skip({1,2,3,4}, 2) -> {3,4} +function skip(tbl, N) + local result = {} + for i = N+1, #tbl do + table.insert(result, tbl[i]) + end + return result +end + +-- take(table, N) +-- e.g: take({1,2,3,4}, 2) -> {1,2} +function take(tbl, N) + local result = {} + for i = 1, N do + if tbl[i] == nil then + break + end + table.insert(result, tbl[i]) + end + return result +end + -- head(table) -- e.g: head({1,2,3}) -> 1 function head(tbl) @@ -158,13 +231,12 @@ end -- XXX This is a BAD and ugly implementation. -- should return the address to next porinter, like in C (arr+1) function tail(tbl) - if table.getn(tbl) < 1 then + if #tbl < 1 then return nil else local newtbl = {} - local tblsize = table.getn(tbl) local i = 2 - while (i <= tblsize) do + while (i <= #tbl) do table.insert(newtbl, i-1, tbl[i]) i = i + 1 end @@ -190,9 +262,9 @@ end -- curry(f,g) -- e.g: printf = curry(io.write, string.format) -- -> function(...) return io.write(string.format(unpack(arg))) end -function curry(f,g) +function curry(f, g) return function (...) - return f(g(unpack(arg))) + return f(g(table.unpack({...}))) end end @@ -225,7 +297,7 @@ end -- local is_odd = is(bind2(math.mod, 2), 0) is = function(check, expected) return function (...) - if (check(unpack(arg)) == expected) then + if (check(table.unpack({...})) == expected) then return true else return false diff --git a/samples/utils-tests.lua b/samples/utils-tests.lua new file mode 100644 index 0000000..0df6ea9 --- /dev/null +++ b/samples/utils-tests.lua @@ -0,0 +1,242 @@ +local print_tab = {} + +function print(str) + table.insert(print_tab, str) + ui:show_lines(print_tab) +end + +function test_trim() + assert((" hello "):trim() == "hello") + assert(("no_spaces"):trim() == "no_spaces") + assert((""):trim() == "") + assert((" "):trim() == "") + print("test_trim passed") +end + +function test_starts_with() + assert(("hello"):starts_with("he") == true) + assert(("hello"):starts_with("hello") == true) + assert(("hello"):starts_with("hi") == false) + assert((""):starts_with("he") == false) + assert(("hello"):starts_with("") == true) + print("test_starts_with passed") +end + +function test_ends_with() + assert(("hello"):ends_with("lo") == true) + assert(("hello"):ends_with("hello") == true) + assert(("hello"):ends_with("world") == false) + assert((""):ends_with("lo") == false) + assert(("hello"):ends_with("") == true) + print("test_ends_with passed") +end + +function test_split() + local result = ("a,b,c"):split(",") + assert(#result == 3 and result[1] == "a" and result[2] == "b" and result[3] == "c") + result = ("a b c"):split(" ") + assert(#result == 3 and result[1] == "a" and result[2] == "b" and result[3] == "c") + result = ("a b c"):split("%s+") + assert(#result == 3 and result[1] == "a" and result[2] == "b" and result[3] == "c") + print("test_split passed") +end + +function test_replace() + assert(("hello world"):replace("world", "Lua") == "hello Lua") + assert(("hello world world"):replace("world", "Lua") == "hello Lua Lua") + assert((""):replace("world", "Lua") == "") + print("test_replace passed") +end + +function test_slice() + local result = slice({1, 2, 3, 4, 5}, 2, 4) + assert(#result == 3 and result[1] == 2 and result[2] == 3 and result[3] == 4) + print("test_slice passed") +end + +function test_index() + assert(index({1, 2, 3}, 2) == 2) + assert(index({1, 2, 3}, 4) == 0) + print("test_index passed") +end + +function test_key() + assert(key({a = 1, b = 2, c = 3}, 2) == "b") + assert(key({a = 1, b = 2, c = 3}, 4) == 0) + print("test_key passed") +end + +function test_concat() + local t1 = {1, 2} + local t2 = {3, 4} + concat(t1, t2) + assert(#t1 == 4 and t1[3] == 3 and t1[4] == 4) + print("test_concat passed") +end + +function test_contains() + assert(contains({1, 2, 3}, 2) == true) + assert(contains({1, 2, 3}, 4) == false) + print("test_contains passed") +end + +function test_reverse() + local result = reverse({1, 2, 3}) + assert(#result == 3 and result[1] == 3 and result[2] == 2 and result[3] == 1) + print("test_reverse passed") +end + +function test_serialize() + local result = serialize({a = 1, b = {c = 2}}) + assert(result:replace("\n", "") == '{ ["a"] = 1, ["b"] = { ["c"] = 2, }, }') + print("test_serialize passed") +end + +function test_deep_copy() + local original = {a = 1, b = {c = 2}} + local copy = deep_copy(original) + assert(copy.b.c == 2) + copy.b.c = 3 + assert(original.b.c == 2) + print("test_deep_copy passed") +end + +function test_round() + assert(round(1.2345, 2) == 1.23) + assert(round(1.2345, 0) == 1) + assert(round(-1.2345, 2) == -1.23) + print("test_round passed") +end + +function test_use() + local module = {test_var = 42} + use(module) + assert(test_var == 42) + print("test_use passed") +end + +function test_for_each() + local sum = 0 + for_each({1, 2, 3}, function(value) sum = sum + value end) + assert(sum == 6) + print("test_for_each passed") +end + +function test_map() + local result = map(function(x) return x * 2 end, {1, 2, 3}) + assert(#result == 3 and result[1] == 2 and result[2] == 4 and result[3] == 6) + print("test_map passed") +end + +function test_filter() + local result = filter(function(x) return x % 2 == 0 end, {1, 2, 3, 4}) + debug:log(#result) + assert(#result == 2 and result[2] == 2 and result[4] == 4) + print("test_filter passed") +end + +function test_skip() + local result = skip({1, 2, 3, 4}, 2) + assert(#result == 2 and result[1] == 3 and result[2] == 4) + print("test_skip passed") +end + +function test_take() + local result = take({1, 2, 3, 4}, 2) + assert(#result == 2 and result[1] == 1 and result[2] == 2) + print("test_take passed") +end + +function test_head() + assert(head({1, 2, 3}) == 1) + print("test_head passed") +end + +function test_tail() + local result = tail({1, 2, 3}) + assert(#result == 2 and result[1] == 2 and result[2] == 3) + print("test_tail passed") +end + +function test_foldr() + local result = foldr(operator.mul, 1, {1, 2, 3, 4, 5}) + assert(result == 120) + print("test_foldr passed") +end + +function test_reduce() + local result = reduce(operator.add, {1, 2, 3, 4}) + assert(result == 10) + print("test_reduce passed") +end + +function test_curry() + local function add(x, y) + return x + y + end + + local function multiply(x, y) + return x * y + end + + local curried_add = curry(add, function(...) return ... end) + local curried_multiply = curry(multiply, function(...) return ... end) + + assert(curried_add(3, 4) == 7) -- 3 + 4 = 7 + assert(curried_add(10, 20) == 30) -- 10 + 20 = 30 + + assert(curried_multiply(3, 4) == 12) -- 3 * 4 = 12 + assert(curried_multiply(10, 20) == 200) -- 10 * 20 = 200 + + print("test_curry passed") +end + +function test_bind1() + local mul5 = bind1(operator.mul, 5) + assert(mul5(10) == 50) + print("test_bind1 passed") +end + +function test_bind2() + local sub2 = bind2(operator.sub, 2) + assert(sub2(5) == 3) + print("test_bind2 passed") +end + +function test_is() + local is_table = is(type, "table") + assert(is_table({}) == true) + assert(is_table(42) == false) + print("test_is passed") +end + +function on_resume() + test_trim() + test_starts_with() + test_ends_with() + test_split() + test_replace() + test_slice() + test_index() + test_key() + test_concat() + test_contains() + test_reverse() + test_serialize() + test_deep_copy() + test_round() + test_use() + test_for_each() + test_map() + test_filter() + test_skip() + test_take() + test_head() + test_tail() + test_foldr() + test_reduce() + test_bind1() + test_bind2() + test_curry() + test_is() +end