diff --git a/AK/SIMDExtras.h b/AK/SIMDExtras.h index eba20f2b97..78116ef53a 100644 --- a/AK/SIMDExtras.h +++ b/AK/SIMDExtras.h @@ -6,6 +6,7 @@ #pragma once +#include #include // Functions returning vectors or accepting vector arguments have different calling conventions @@ -166,6 +167,31 @@ ALWAYS_INLINE static void store4_masked(VectorType v, UnderlyingType* a, Underly *d = v[3]; } +// Shuffle + +template T> +ALWAYS_INLINE static T shuffle(T a, T control) +{ + // FIXME: This is probably not the fastest way to do this. + return T { + a[control[0] & 0xf], + a[control[1] & 0xf], + a[control[2] & 0xf], + a[control[3] & 0xf], + a[control[4] & 0xf], + a[control[5] & 0xf], + a[control[6] & 0xf], + a[control[7] & 0xf], + a[control[8] & 0xf], + a[control[9] & 0xf], + a[control[10] & 0xf], + a[control[11] & 0xf], + a[control[12] & 0xf], + a[control[13] & 0xf], + a[control[14] & 0xf], + a[control[15] & 0xf], + }; +} } #pragma GCC diagnostic pop diff --git a/Meta/CMake/wasm_spec_tests.cmake b/Meta/CMake/wasm_spec_tests.cmake index 1a4bf480b6..38d79eb99f 100644 --- a/Meta/CMake/wasm_spec_tests.cmake +++ b/Meta/CMake/wasm_spec_tests.cmake @@ -16,7 +16,7 @@ if(INCLUDE_WASM_SPEC_TESTS) find_program(WAT2WASM wat2wasm REQUIRED) find_program(PRETTIER prettier OPTIONAL) - if (NOT SKIP_PRETTIER AND NOT PRETTIER_FOUND) + if (NOT SKIP_PRETTIER AND PRETTIER EQUAL "PRETTIER-NOTFOUND") message(FATAL_ERROR "Prettier required to format Wasm spec tests! Install prettier or set WASM_SPEC_TEST_SKIP_FORMATTING to ON") endif() diff --git a/Meta/generate-libwasm-spec-test.py b/Meta/generate-libwasm-spec-test.py index 117ceba6b0..3f4bbecf72 100644 --- a/Meta/generate-libwasm-spec-test.py +++ b/Meta/generate-libwasm-spec-test.py @@ -140,6 +140,7 @@ def parse_typed_value(ast): value.frombytes(parse_v128_chunk(num[0], ast[1][0])) assert len(value) - s == size, f'Expected {size} bytes, got {len(value) - s} bytes' + assert len(value) == 16, f'Expected 16 bytes, got {len(value)} bytes' return { 'type': types[ast[0][0]], 'value': value.tobytes().hex() diff --git a/Userland/Libraries/LibWasm/AbstractMachine/AbstractMachine.h b/Userland/Libraries/LibWasm/AbstractMachine/AbstractMachine.h index ed73d6fb8f..bc1166bd42 100644 --- a/Userland/Libraries/LibWasm/AbstractMachine/AbstractMachine.h +++ b/Userland/Libraries/LibWasm/AbstractMachine/AbstractMachine.h @@ -140,10 +140,17 @@ public: if constexpr (IsSame || (!IsFloatingPoint && IsSame>)) { result = static_cast(value); } else if constexpr (!IsFloatingPoint && IsConvertible) { - if (AK::is_within_range(value)) - result = static_cast(value); + // NOTE: No implicit vector <-> scalar conversion. + if constexpr (!IsSame) { + if (AK::is_within_range(value)) + result = static_cast(value); + } } }, + [&](u128 value) { + if constexpr (IsSame) + result = value; + }, [&](Reference const& value) { if constexpr (IsSame) { result = value; diff --git a/Userland/Libraries/LibWasm/AbstractMachine/BytecodeInterpreter.cpp b/Userland/Libraries/LibWasm/AbstractMachine/BytecodeInterpreter.cpp index 362446be36..085735de75 100644 --- a/Userland/Libraries/LibWasm/AbstractMachine/BytecodeInterpreter.cpp +++ b/Userland/Libraries/LibWasm/AbstractMachine/BytecodeInterpreter.cpp @@ -5,9 +5,11 @@ * SPDX-License-Identifier: BSD-2-Clause */ +#include #include #include #include +#include #include #include #include @@ -15,6 +17,8 @@ #include #include +using namespace AK::SIMD; + namespace Wasm { #define TRAP_IF_NOT(x) \ @@ -110,6 +114,148 @@ void BytecodeInterpreter::load_and_push(Configuration& configuration, Instructio configuration.stack().peek() = Value(static_cast(read_value(slice))); } +template +ALWAYS_INLINE static TDst convert_vector(TSrc v) +{ + return __builtin_convertvector(v, TDst); +} + +template typename SetSign> +void BytecodeInterpreter::load_and_push_mxn(Configuration& configuration, Instruction const& instruction) +{ + auto& address = configuration.frame().module().memories().first(); + auto memory = configuration.store().get(address); + if (!memory) { + m_trap = Trap { "Nonexistent memory" }; + return; + } + auto& arg = instruction.arguments().get(); + auto& entry = configuration.stack().peek(); + auto base = entry.get().to(); + if (!base.has_value()) { + m_trap = Trap { "Memory access out of bounds" }; + return; + } + u64 instance_address = static_cast(bit_cast(base.value())) + arg.offset; + Checked addition { instance_address }; + addition += M * N / 8; + if (addition.has_overflow() || addition.value() > memory->size()) { + m_trap = Trap { "Memory access out of bounds" }; + dbgln("LibWasm: Memory access out of bounds (expected {} to be less than or equal to {})", instance_address + M * N / 8, memory->size()); + return; + } + dbgln_if(WASM_TRACE_DEBUG, "vec-load({} : {}) -> stack", instance_address, M * N / 8); + auto slice = memory->data().bytes().slice(instance_address, M * N / 8); + using V64 = NativeVectorType; + using V128 = NativeVectorType; + + V64 bytes { 0 }; + if (bit_cast(slice.data()) % sizeof(V64) == 0) + bytes = *bit_cast(slice.data()); + else + ByteReader::load(slice.data(), bytes); + + configuration.stack().peek() = Value(bit_cast(convert_vector(bytes))); +} + +template +void BytecodeInterpreter::load_and_push_m_splat(Configuration& configuration, Instruction const& instruction) +{ + auto& address = configuration.frame().module().memories().first(); + auto memory = configuration.store().get(address); + if (!memory) { + m_trap = Trap { "Nonexistent memory" }; + return; + } + auto& arg = instruction.arguments().get(); + auto& entry = configuration.stack().peek(); + auto base = entry.get().to(); + if (!base.has_value()) { + m_trap = Trap { "Memory access out of bounds" }; + return; + } + u64 instance_address = static_cast(bit_cast(base.value())) + arg.offset; + Checked addition { instance_address }; + addition += M / 8; + if (addition.has_overflow() || addition.value() > memory->size()) { + m_trap = Trap { "Memory access out of bounds" }; + dbgln("LibWasm: Memory access out of bounds (expected {} to be less than or equal to {})", instance_address + M / 8, memory->size()); + return; + } + dbgln_if(WASM_TRACE_DEBUG, "vec-splat({} : {}) -> stack", instance_address, M / 8); + auto slice = memory->data().bytes().slice(instance_address, M / 8); + auto value = read_value>(slice); + set_top_m_splat(configuration, value); +} + +template typename NativeType> +void BytecodeInterpreter::set_top_m_splat(Wasm::Configuration& configuration, NativeType value) +{ + auto push = [&](auto result) { + configuration.stack().peek() = Value(bit_cast(result)); + }; + + if constexpr (IsFloatingPoint>) { + if constexpr (M == 32) // 32 -> 32x4 + push(expand4(value)); + else if constexpr (M == 64) // 64 -> 64x2 + push(f64x2 { value, value }); + else + static_assert(DependentFalse>, "Invalid vector size"); + } else { + if constexpr (M == 8) // 8 -> 8x4 -> 32x4 + push(expand4(bit_cast(u8x4 { value, value, value, value }))); + else if constexpr (M == 16) // 16 -> 16x2 -> 32x4 + push(expand4(bit_cast(u16x2 { value, value }))); + else if constexpr (M == 32) // 32 -> 32x4 + push(expand4(value)); + else if constexpr (M == 64) // 64 -> 64x2 + push(u64x2 { value, value }); + else + static_assert(DependentFalse>, "Invalid vector size"); + } +} + +template typename NativeType> +void BytecodeInterpreter::pop_and_push_m_splat(Wasm::Configuration& configuration, Instruction const&) +{ + using PopT = Conditional, NativeType<64>>; + using ReadT = NativeType; + auto entry = configuration.stack().peek(); + auto value = static_cast(*entry.get().to()); + dbgln_if(WASM_TRACE_DEBUG, "stack({}) -> splat({})", value, M); + set_top_m_splat(configuration, value); +} + +template typename SetSign, typename VectorType> +Optional BytecodeInterpreter::pop_vector(Configuration& configuration) +{ + auto value = peek_vector(configuration); + if (value.has_value()) + configuration.stack().pop(); + return value; +} + +template typename SetSign, typename VectorType> +Optional BytecodeInterpreter::peek_vector(Configuration& configuration) +{ + auto& entry = configuration.stack().peek(); + auto value = entry.get().value().get_pointer(); + if (!value) + return {}; + auto vector = bit_cast(*value); + dbgln_if(WASM_TRACE_DEBUG, "stack({}) peek-> vector({:x})", *value, bit_cast(vector)); + return vector; +} + +template +static u128 shuffle_vector(VectorType values, VectorType indices) +{ + auto vector = bit_cast(values); + auto indices_vector = bit_cast(indices); + return bit_cast(shuffle(vector, indices_vector)); +} + void BytecodeInterpreter::call_address(Configuration& configuration, FunctionAddress address) { TRAP_IF_NOT(m_stack_info.size_free() >= Constants::minimum_stack_space_to_keep_free); @@ -150,15 +296,15 @@ void BytecodeInterpreter::call_address(Configuration& configuration, FunctionAdd configuration.stack().entries().unchecked_append(move(entry)); } -template +template void BytecodeInterpreter::binary_numeric_operation(Configuration& configuration) { auto rhs_entry = configuration.stack().pop(); auto& lhs_entry = configuration.stack().peek(); auto rhs_ptr = rhs_entry.get_pointer(); auto lhs_ptr = lhs_entry.get_pointer(); - auto rhs = rhs_ptr->to(); - auto lhs = lhs_ptr->to(); + auto rhs = rhs_ptr->to(); + auto lhs = lhs_ptr->to(); PushType result; auto call_result = Operator {}(lhs.value(), rhs.value()); if constexpr (IsSpecializationOf) { @@ -1016,6 +1162,78 @@ void BytecodeInterpreter::interpret(Configuration& configuration, InstructionPoi return unary_operation>(configuration); case Instructions::i64_trunc_sat_f64_u.value(): return unary_operation>(configuration); + case Instructions::v128_const.value(): + configuration.stack().push(Value(instruction.arguments().get())); + return; + case Instructions::v128_load.value(): + return load_and_push(configuration, instruction); + case Instructions::v128_load8x8_s.value(): + return load_and_push_mxn<8, 8, MakeSigned>(configuration, instruction); + case Instructions::v128_load8x8_u.value(): + return load_and_push_mxn<8, 8, MakeUnsigned>(configuration, instruction); + case Instructions::v128_load16x4_s.value(): + return load_and_push_mxn<16, 4, MakeSigned>(configuration, instruction); + case Instructions::v128_load16x4_u.value(): + return load_and_push_mxn<16, 4, MakeUnsigned>(configuration, instruction); + case Instructions::v128_load32x2_s.value(): + return load_and_push_mxn<32, 2, MakeSigned>(configuration, instruction); + case Instructions::v128_load32x2_u.value(): + return load_and_push_mxn<32, 2, MakeUnsigned>(configuration, instruction); + case Instructions::v128_load8_splat.value(): + return load_and_push_m_splat<8>(configuration, instruction); + case Instructions::v128_load16_splat.value(): + return load_and_push_m_splat<16>(configuration, instruction); + case Instructions::v128_load32_splat.value(): + return load_and_push_m_splat<32>(configuration, instruction); + case Instructions::v128_load64_splat.value(): + return load_and_push_m_splat<64>(configuration, instruction); + case Instructions::i8x16_splat.value(): + return pop_and_push_m_splat<8, NativeIntegralType>(configuration, instruction); + case Instructions::i16x8_splat.value(): + return pop_and_push_m_splat<16, NativeIntegralType>(configuration, instruction); + case Instructions::i32x4_splat.value(): + return pop_and_push_m_splat<32, NativeIntegralType>(configuration, instruction); + case Instructions::i64x2_splat.value(): + return pop_and_push_m_splat<64, NativeIntegralType>(configuration, instruction); + case Instructions::f32x4_splat.value(): + return pop_and_push_m_splat<32, NativeFloatingType>(configuration, instruction); + case Instructions::f64x2_splat.value(): + return pop_and_push_m_splat<64, NativeFloatingType>(configuration, instruction); + case Instructions::i8x16_shuffle.value(): { + auto indices = pop_vector(configuration); + TRAP_IF_NOT(indices.has_value()); + auto vector = peek_vector(configuration); + TRAP_IF_NOT(vector.has_value()); + auto result = shuffle_vector(vector.value(), indices.value()); + configuration.stack().peek() = Value(result); + return; + } + case Instructions::v128_store.value(): + return pop_and_store(configuration, instruction); + case Instructions::i8x16_shl.value(): + return binary_numeric_operation, i32>(configuration); + case Instructions::i8x16_shr_u.value(): + return binary_numeric_operation, i32>(configuration); + case Instructions::i8x16_shr_s.value(): + return binary_numeric_operation, i32>(configuration); + case Instructions::i16x8_shl.value(): + return binary_numeric_operation, i32>(configuration); + case Instructions::i16x8_shr_u.value(): + return binary_numeric_operation, i32>(configuration); + case Instructions::i16x8_shr_s.value(): + return binary_numeric_operation, i32>(configuration); + case Instructions::i32x4_shl.value(): + return binary_numeric_operation, i32>(configuration); + case Instructions::i32x4_shr_u.value(): + return binary_numeric_operation, i32>(configuration); + case Instructions::i32x4_shr_s.value(): + return binary_numeric_operation, i32>(configuration); + case Instructions::i64x2_shl.value(): + return binary_numeric_operation, i32>(configuration); + case Instructions::i64x2_shr_u.value(): + return binary_numeric_operation, i32>(configuration); + case Instructions::i64x2_shr_s.value(): + return binary_numeric_operation, i32>(configuration); case Instructions::table_init.value(): case Instructions::elem_drop.value(): case Instructions::table_copy.value(): @@ -1024,7 +1242,7 @@ void BytecodeInterpreter::interpret(Configuration& configuration, InstructionPoi case Instructions::table_fill.value(): default: unimplemented:; - dbgln("Instruction '{}' not implemented", instruction_name(instruction.opcode())); + dbgln_if(WASM_TRACE_DEBUG, "Instruction '{}' not implemented", instruction_name(instruction.opcode())); m_trap = Trap { DeprecatedString::formatted("Unimplemented instruction {}", instruction_name(instruction.opcode())) }; return; } diff --git a/Userland/Libraries/LibWasm/AbstractMachine/BytecodeInterpreter.h b/Userland/Libraries/LibWasm/AbstractMachine/BytecodeInterpreter.h index 9b2d96e660..db5ab5467b 100644 --- a/Userland/Libraries/LibWasm/AbstractMachine/BytecodeInterpreter.h +++ b/Userland/Libraries/LibWasm/AbstractMachine/BytecodeInterpreter.h @@ -50,10 +50,22 @@ protected: void load_and_push(Configuration&, Instruction const&); template void pop_and_store(Configuration&, Instruction const&); + template typename SetSign> + void load_and_push_mxn(Configuration&, Instruction const&); + template + void load_and_push_m_splat(Configuration&, Instruction const&); + template typename NativeType> + void set_top_m_splat(Configuration&, NativeType); + template typename NativeType> + void pop_and_push_m_splat(Configuration&, Instruction const&); + template typename SetSign, typename VectorType = Native128ByteVectorOf> + Optional pop_vector(Configuration&); + template typename SetSign, typename VectorType = Native128ByteVectorOf> + Optional peek_vector(Configuration&); void store_to_memory(Configuration&, Instruction const&, ReadonlyBytes data, i32 base); void call_address(Configuration&, FunctionAddress); - template + template void binary_numeric_operation(Configuration&); template diff --git a/Userland/Libraries/LibWasm/AbstractMachine/Operators.h b/Userland/Libraries/LibWasm/AbstractMachine/Operators.h index a3a61589d1..c869a36a52 100644 --- a/Userland/Libraries/LibWasm/AbstractMachine/Operators.h +++ b/Userland/Libraries/LibWasm/AbstractMachine/Operators.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, Ali Mohammad Pur + * Copyright (c) 2021-2023, Ali Mohammad Pur * * SPDX-License-Identifier: BSD-2-Clause */ @@ -9,12 +9,15 @@ #include #include #include +#include #include #include #include #include -namespace Operators { +namespace Wasm::Operators { + +using namespace AK::SIMD; #define DEFINE_BINARY_OPERATOR(Name, operation) \ struct Name { \ @@ -62,6 +65,7 @@ struct Divide { static StringView name() { return "/"sv; } }; + struct Modulo { template auto operator()(Lhs lhs, Rhs rhs) const @@ -77,18 +81,21 @@ struct Modulo { static StringView name() { return "%"sv; } }; + struct BitShiftLeft { template auto operator()(Lhs lhs, Rhs rhs) const { return lhs << (rhs % (sizeof(lhs) * 8)); } static StringView name() { return "<<"sv; } }; + struct BitShiftRight { template auto operator()(Lhs lhs, Rhs rhs) const { return lhs >> (rhs % (sizeof(lhs) * 8)); } static StringView name() { return ">>"sv; } }; + struct BitRotateLeft { template auto operator()(Lhs lhs, Rhs rhs) const @@ -102,6 +109,7 @@ struct BitRotateLeft { static StringView name() { return "rotate_left"sv; } }; + struct BitRotateRight { template auto operator()(Lhs lhs, Rhs rhs) const @@ -115,6 +123,55 @@ struct BitRotateRight { static StringView name() { return "rotate_right"sv; } }; + +template +struct VectorShiftLeft { + auto operator()(u128 lhs, i32 rhs) const + { + auto shift_value = rhs % (sizeof(lhs) * 8 / VectorSize); + return bit_cast(bit_cast, MakeUnsigned>>(lhs) << shift_value); + } + static StringView name() + { + switch (VectorSize) { + case 16: + return "vec(8x16)<<"sv; + case 8: + return "vec(16x8)<<"sv; + case 4: + return "vec(32x4)<<"sv; + case 2: + return "vec(64x2)<<"sv; + default: + VERIFY_NOT_REACHED(); + } + } +}; + +template typename SetSign> +struct VectorShiftRight { + auto operator()(u128 lhs, i32 rhs) const + { + auto shift_value = rhs % (sizeof(lhs) * 8 / VectorSize); + return bit_cast(bit_cast, SetSign>>(lhs) >> shift_value); + } + static StringView name() + { + switch (VectorSize) { + case 16: + return "vec(8x16)>>"sv; + case 8: + return "vec(16x8)>>"sv; + case 4: + return "vec(32x4)>>"sv; + case 2: + return "vec(64x2)>>"sv; + default: + VERIFY_NOT_REACHED(); + } + } +}; + struct Minimum { template auto operator()(Lhs lhs, Rhs rhs) const @@ -134,6 +191,7 @@ struct Minimum { static StringView name() { return "minimum"sv; } }; + struct Maximum { template auto operator()(Lhs lhs, Rhs rhs) const @@ -153,6 +211,7 @@ struct Maximum { static StringView name() { return "maximum"sv; } }; + struct CopySign { template auto operator()(Lhs lhs, Rhs rhs) const @@ -176,6 +235,7 @@ struct EqualsZero { static StringView name() { return "== 0"sv; } }; + struct CountLeadingZeros { template i32 operator()(Lhs lhs) const @@ -191,6 +251,7 @@ struct CountLeadingZeros { static StringView name() { return "clz"sv; } }; + struct CountTrailingZeros { template i32 operator()(Lhs lhs) const @@ -206,6 +267,7 @@ struct CountTrailingZeros { static StringView name() { return "ctz"sv; } }; + struct PopCount { template auto operator()(Lhs lhs) const @@ -218,18 +280,21 @@ struct PopCount { static StringView name() { return "popcnt"sv; } }; + struct Absolute { template auto operator()(Lhs lhs) const { return AK::abs(lhs); } static StringView name() { return "abs"sv; } }; + struct Negate { template auto operator()(Lhs lhs) const { return -lhs; } static StringView name() { return "== 0"sv; } }; + struct Ceil { template auto operator()(Lhs lhs) const @@ -244,6 +309,7 @@ struct Ceil { static StringView name() { return "ceil"sv; } }; + struct Floor { template auto operator()(Lhs lhs) const @@ -258,9 +324,10 @@ struct Floor { static StringView name() { return "floor"sv; } }; + struct Truncate { template - Result operator()(Lhs lhs) const + AK::Result operator()(Lhs lhs) const { if constexpr (IsSame) return truncf(lhs); @@ -272,6 +339,7 @@ struct Truncate { static StringView name() { return "truncate"sv; } }; + struct NearbyIntegral { template auto operator()(Lhs lhs) const @@ -286,6 +354,7 @@ struct NearbyIntegral { static StringView name() { return "round"sv; } }; + struct SquareRoot { template auto operator()(Lhs lhs) const diff --git a/Userland/Libraries/LibWasm/Types.h b/Userland/Libraries/LibWasm/Types.h index ffba0ffb2c..a2769c5979 100644 --- a/Userland/Libraries/LibWasm/Types.h +++ b/Userland/Libraries/LibWasm/Types.h @@ -20,6 +20,18 @@ namespace Wasm { +template +using NativeIntegralType = Conditional>>>; + +template +using NativeFloatingType = Conditional>; + +template typename SetSign, typename ElementType = SetSign>> +using NativeVectorType __attribute__((vector_size(N * sizeof(ElementType)))) = ElementType; + +template typename SetSign> +using Native128ByteVectorOf = NativeVectorType; + enum class ParseError { UnexpectedEof, UnknownInstruction,