diff --git a/Libraries/LibC/netdb.cpp b/Libraries/LibC/netdb.cpp index f58d45f02c..bc8aebc400 100644 --- a/Libraries/LibC/netdb.cpp +++ b/Libraries/LibC/netdb.cpp @@ -49,6 +49,7 @@ static hostent __gethostbyaddr_buffer; static char __gethostbyaddr_name_buffer[512]; static in_addr_t* __gethostbyaddr_address_list_buffer[2]; +//Get service entry buffers and file information for the getservent() family of functions static FILE* services_file = nullptr; static const char* services_path = "/etc/services"; @@ -62,6 +63,19 @@ static Vector __getserv_alias_list; static bool keep_service_file_open = false; static ssize_t service_file_offset = 0; +//Get protocol entry buffers and file information for the getprotent() family of functions +static FILE* protocols_file = nullptr; +static const char* protocols_path = "/etc/protocols"; + +static bool fill_getproto_buffers(char* line, ssize_t read); +static protoent __getproto_buffer; +static char __getproto_name_buffer[512]; +static Vector __getproto_alias_list_buffer; +static Vector __getproto_alias_list; +static int __getproto_protocol_buffer; +static bool keep_protocols_file_open = false; +static ssize_t protocol_file_offset = 0; + static int connect_to_lookup_server() { int fd = socket(AF_LOCAL, SOCK_STREAM | SOCK_CLOEXEC, 0); @@ -263,9 +277,11 @@ struct servent* getservent() __getserv_buffer.s_port = __getserv_port_buffer; __getserv_buffer.s_proto = __getserv_protocol_buffer; + __getserv_alias_list.clear(); for (auto& alias : __getserv_alias_list_buffer) { __getserv_alias_list.append((char*)alias.data()); } + __getserv_buffer.s_aliases = __getserv_alias_list.data(); service_entry = &__getserv_buffer; @@ -395,7 +411,185 @@ static bool fill_getserv_buffers(char* line, ssize_t read) if (split_line[i].starts_with('#')) { break; } - __getserv_alias_list_buffer.append(split_line[i].to_byte_buffer()); + auto alias = split_line[i].to_byte_buffer(); + alias.append("\0", sizeof(char)); + __getserv_alias_list_buffer.append(alias); + } + } + + return true; +} + +struct protoent* getprotoent() +{ + //If protocols file isn't open, attempt to open and return null on failure. + if (!protocols_file) { + protocols_file = fopen(protocols_path, "r"); + + if (!protocols_file) { + perror("error opening protocols file"); + return nullptr; + } + } + + if (fseek(protocols_file, protocol_file_offset, SEEK_SET) != 0) { + perror("error seeking protocols file"); + fclose(protocols_file); + return nullptr; + } + + char* line = nullptr; + size_t len = 0; + ssize_t read; + + auto free_line_on_exit = ScopeGuard([line] { + if (line) { + free(line); + } + }); + + do { + read = getline(&line, &len, protocols_file); + protocol_file_offset += read; + if (read > 0 && (line[0] >= 65 && line[0] <= 122)) { + break; + } + } while (read != -1); + + if (read == -1) { + fclose(protocols_file); + protocols_file = nullptr; + protocol_file_offset = 0; + return nullptr; + } + + struct protoent* protocol_entry = nullptr; + if (!fill_getproto_buffers(line, read)) + return nullptr; + + __getproto_buffer.p_name = __getproto_name_buffer; + __getproto_buffer.p_proto = __getproto_protocol_buffer; + + __getproto_alias_list.clear(); + + for (auto& alias : __getproto_alias_list_buffer) { + __getproto_alias_list.append((char*)alias.data()); + } + + __getproto_buffer.p_aliases = __getproto_alias_list.data(); + protocol_entry = &__getproto_buffer; + + if (!keep_protocols_file_open) + endprotoent(); + + return protocol_entry; +} + +struct protoent* getprotobyname(const char* name) +{ + bool previous_file_open_setting = keep_protocols_file_open; + setprotoent(1); + struct protoent* current_protocol = nullptr; + auto protocol_file_handler = ScopeGuard([previous_file_open_setting] { + if (!previous_file_open_setting) { + endprotoent(); + } + }); + + while (true) { + current_protocol = getprotoent(); + if (current_protocol == nullptr) + break; + else if (strcmp(current_protocol->p_name, name) == 0) + break; + } + + return current_protocol; +} + +struct protoent* getprotobynumber(int proto) +{ + bool previous_file_open_setting = keep_protocols_file_open; + setprotoent(1); + struct protoent* current_protocol = nullptr; + auto protocol_file_handler = ScopeGuard([previous_file_open_setting] { + if (!previous_file_open_setting) { + endprotoent(); + } + }); + + while (true) { + current_protocol = getprotoent(); + if (current_protocol == nullptr) + break; + else if (current_protocol->p_proto == proto) + break; + } + + return current_protocol; +} + +void setprotoent(int stay_open) +{ + if (!protocols_file) { + protocols_file = fopen(protocols_path, "r"); + + if (!protocols_file) { + perror("error opening protocols file"); + return; + } + } + rewind(protocols_file); + keep_protocols_file_open = stay_open; + protocol_file_offset = 0; +} + +void endprotoent() +{ + if (!protocols_file) { + return; + } + fclose(protocols_file); + protocols_file = nullptr; +} + +static bool fill_getproto_buffers(char* line, ssize_t read) +{ + String string_line = String(line, read); + string_line.replace(" ", "\t", true); + auto split_line = string_line.split('\t'); + + //This indicates an incorrect file format. Protocols file entries should always have at least a name and a protocol. + if (split_line.size() < 2) { + perror("malformed protocols file: entry"); + return false; + } + if (sizeof(__getproto_name_buffer) >= split_line[0].length() + 1) { + strncpy(__getproto_name_buffer, split_line[0].characters(), split_line[0].length() + 1); + } else { + perror("invalid buffer length: protocol name"); + return false; + } + + bool conversion_checker; + __getproto_protocol_buffer = split_line[1].to_int(conversion_checker); + + if (!conversion_checker) { + return false; + } + + __getproto_alias_list_buffer.clear(); + + //If there are aliases for the protocol, we will fill the alias list buffer. + if (split_line.size() > 2 && !split_line[2].starts_with('#')) { + + for (size_t i = 2; i < split_line.size(); i++) { + if (split_line[i].starts_with('#')) { + break; + } + auto alias = split_line[i].to_byte_buffer(); + alias.append("\0", sizeof(char)); + __getproto_alias_list_buffer.append(alias); } } diff --git a/Libraries/LibC/netdb.h b/Libraries/LibC/netdb.h index 802c89a39d..c2bd5a130f 100644 --- a/Libraries/LibC/netdb.h +++ b/Libraries/LibC/netdb.h @@ -55,6 +55,19 @@ struct servent* getservbyname(const char* name, const char* protocol); struct servent* getservbyport(int port, const char* protocol); void setservent(int stay_open); void endservent(); + +struct protoent { + char* p_name; + char** p_aliases; + int p_proto; +}; + +void endprotoent(); +struct protoent* getprotobyname(const char* name); +struct protoent* getprotobynumber(int proto); +struct protoent* getprotoent(); +void setprotoent(int stay_open); + extern int h_errno; #define HOST_NOT_FOUND 101