From ebd6d7466d047c0b05ff16d7b6f72b3bff0f31d3 Mon Sep 17 00:00:00 2001 From: wheelchairy Date: Tue, 28 Jan 2025 21:22:41 +0300 Subject: [PATCH] =?UTF-8?q?=D0=9D=D0=B0=D1=87=D0=B0=D0=BB=D0=BE.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- Makefile | 56 + bfsk.cpp | 78 + bfsk.hpp | 11 + cli.cpp | 56 + cli.hpp | 5 + commands.cpp | 113 + commands.hpp | 6 + config.hpp | 30 + history.txt | 6 + libs/httplib.h | 10351 +++++++++++++++++++++++++++ libs/linenoise.c | 1349 ++++ libs/linenoise.h | 113 + libs/monocypher.c | 2956 ++++++++ libs/monocypher.h | 321 + libs/optional/monocypher-ed25519.c | 500 ++ libs/optional/monocypher-ed25519.h | 140 + main.cpp | 8 + sound.cpp | 180 + sound.hpp | 6 + webserver.cpp | 66 + webserver.hpp | 7 + x25519_handshake.cpp | 34 + x25519_handshake.hpp | 9 + 23 files changed, 16401 insertions(+) create mode 100644 Makefile create mode 100644 bfsk.cpp create mode 100644 bfsk.hpp create mode 100644 cli.cpp create mode 100644 cli.hpp create mode 100644 commands.cpp create mode 100644 commands.hpp create mode 100644 config.hpp create mode 100644 history.txt create mode 100644 libs/httplib.h create mode 100644 libs/linenoise.c create mode 100644 libs/linenoise.h create mode 100644 libs/monocypher.c create mode 100644 libs/monocypher.h create mode 100644 libs/optional/monocypher-ed25519.c create mode 100644 libs/optional/monocypher-ed25519.h create mode 100644 main.cpp create mode 100644 sound.cpp create mode 100644 sound.hpp create mode 100644 webserver.cpp create mode 100644 webserver.hpp create mode 100644 x25519_handshake.cpp create mode 100644 x25519_handshake.hpp diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..e5f6885 --- /dev/null +++ b/Makefile @@ -0,0 +1,56 @@ +CC = clang +CXX = clang++ + +CFLAGS = -O2 -Wall -I./libs +CXXFLAGS = -std=c++17 -O2 -Wall -I./libs + +UNAME_S := $(shell uname -s) +BREW_PREFIX := $(shell brew --prefix portaudio 2>/dev/null) + +ifeq ($(UNAME_S),Darwin) + ifneq ($(BREW_PREFIX),) + CFLAGS += -I$(BREW_PREFIX)/include + CXXFLAGS += -I$(BREW_PREFIX)/include + LIBS += -L$(BREW_PREFIX)/lib -lportaudio + else + $(warning [Makefile] PortAudio не найдена через Homebrew. Установите её: brew install portaudio) + endif +else ifeq ($(UNAME_S),Linux) + HAVE_PKG := $(shell pkg-config --exists portaudio-2.0 && echo yes || echo no) + ifeq ($(HAVE_PKG),yes) + PORTAUDIO_CFLAGS := $(shell pkg-config --cflags portaudio-2.0) + PORTAUDIO_LIBS := $(shell pkg-config --libs portaudio-2.0) + CFLAGS += $(PORTAUDIO_CFLAGS) + CXXFLAGS += $(PORTAUDIO_CFLAGS) + LIBS += $(PORTAUDIO_LIBS) + else + $(warning [Makefile] portaudio-2.0 не найдена через pkg-config. Использую -lportaudio) + LIBS += -lportaudio + endif +else + $(warning [Makefile] Unsupported OS: $(UNAME_S)) +endif + +SRCS_CPP = main.cpp commands.cpp cli.cpp webserver.cpp sound.cpp bfsk.cpp x25519_handshake.cpp +SRCS_C = libs/monocypher.c libs/linenoise.c + +OBJS_CPP = $(SRCS_CPP:.cpp=.o) +OBJS_C = $(SRCS_C:.c=.o) + +TARGET = cerberus + +.PHONY: all clean + +all: $(TARGET) + +$(TARGET): $(OBJS_CPP) $(OBJS_C) + $(CXX) $(CXXFLAGS) -o $@ $^ $(LIBS) + +%.o: %.cpp + $(CXX) $(CXXFLAGS) -c $< -o $@ + +%.o: %.c + $(CC) $(CFLAGS) -c $< -o $@ + +clean: + rm -f $(TARGET) *.o libs/*.o diff --git a/bfsk.cpp b/bfsk.cpp new file mode 100644 index 0000000..d887881 --- /dev/null +++ b/bfsk.cpp @@ -0,0 +1,78 @@ +#include "bfsk.hpp" +#include +#include +#include + +std::vector bfskModulate(const std::vector &data) { + int samplesPerBit = (int)((double)SAMPLE_RATE / BFSK_BAUD); + size_t totalBits = data.size() * 8; + size_t totalSamples = totalBits * samplesPerBit; + std::vector out(totalSamples * 2, 0.0f); + + double phase0 = 0.0; + double phase1 = 0.0; + double inc0 = 2.0 * M_PI * BFSK_FREQ0 / (double)SAMPLE_RATE; + double inc1 = 2.0 * M_PI * BFSK_FREQ1 / (double)SAMPLE_RATE; + + size_t sampleIndex = 0; + + for (auto byteVal : data) { + for (int b = 0; b < 8; b++) { + int bit = (byteVal >> b) & 1; + for (int s = 0; s < samplesPerBit; s++) { + float val; + if (bit == 0) { + val = sinf((float)phase0); + phase0 += inc0; + } else { + val = sinf((float)phase1); + phase1 += inc1; + } + out[sampleIndex*2 + 0] = val * 0.3f; + out[sampleIndex*2 + 1] = val * 0.3f; + sampleIndex++; + } + } + } + return out; +} + +std::vector bfskDemodulate(const std::vector &monoData) { + int samplesPerBit = (int)((double)SAMPLE_RATE / BFSK_BAUD); + size_t totalBits = monoData.size() / samplesPerBit; + size_t totalBytes = totalBits / 8; + + double inc0 = 2.0 * M_PI * BFSK_FREQ0 / (double)SAMPLE_RATE; + double inc1 = 2.0 * M_PI * BFSK_FREQ1 / (double)SAMPLE_RATE; + + std::vector result(totalBytes, 0); + + size_t bitIndex = 0; + for (size_t b = 0; b < totalBits; b++) { + double sum0 = 0.0; + double sum1 = 0.0; + double phase0 = 0.0; + double phase1 = 0.0; + + size_t startSample = b * samplesPerBit; + for (int s = 0; s < samplesPerBit; s++) { + float sample = monoData[startSample + s]; + float ref0 = sinf((float)phase0); + float ref1 = sinf((float)phase1); + sum0 += sample * ref0; + sum1 += sample * ref1; + phase0 += inc0; + phase1 += inc1; + } + + int bit = (std::fabs(sum1) > std::fabs(sum0)) ? 1 : 0; + size_t bytePos = bitIndex / 8; + int bitPos = bitIndex % 8; + if (bytePos < result.size()) { + result[bytePos] |= (bit << bitPos); + } + bitIndex++; + } + + return result; +} diff --git a/bfsk.hpp b/bfsk.hpp new file mode 100644 index 0000000..f85cff1 --- /dev/null +++ b/bfsk.hpp @@ -0,0 +1,11 @@ +#pragma once +#include + +constexpr double BFSK_FREQ0 = 1000.0; +constexpr double BFSK_FREQ1 = 2000.0; +constexpr double BFSK_BAUD = 100.0; +constexpr int SAMPLE_RATE = 44100; + +std::vector bfskModulate(const std::vector &data); + +std::vector bfskDemodulate(const std::vector &monoData); diff --git a/cli.cpp b/cli.cpp new file mode 100644 index 0000000..81e577c --- /dev/null +++ b/cli.cpp @@ -0,0 +1,56 @@ +#include "cli.hpp" +#include "commands.hpp" +#include "config.hpp" + +#include +#include +#include + +#include "libs/linenoise.h" + +static void completionCallback(const char *input, linenoiseCompletions *completions) { + if (strncmp(input, "ni", 2) == 0) { + linenoiseAddCompletion(completions, "nick set "); + linenoiseAddCompletion(completions, "nick generatekey"); + } + else if (strncmp(input, "web", 3) == 0) { + linenoiseAddCompletion(completions, "web start"); + linenoiseAddCompletion(completions, "web connect pm 127.0.0.1"); + linenoiseAddCompletion(completions, "web stop"); + } + else if (strncmp(input, "sound", 5) == 0) { + linenoiseAddCompletion(completions, "sound find"); + linenoiseAddCompletion(completions, "sound lose"); + } +} + +void runCLI(AppConfig &config) { + linenoiseHistoryLoad("history.txt"); + + linenoiseSetCompletionCallback(completionCallback); + + std::cout << CLR_CYAN "Cerberus BFSK Demo (linenoise + color)\n" + << "Доступные команды:\n" + << " nick set \n" + << " nick generatekey\n" + << " web start / web connect pm|server / web stop\n" + << " sound find / sound lose\n" + << " exit\n\n" CLR_RESET; + + while (true) { + char *line = linenoise(CLR_BOLD ">_ " CLR_RESET); + if (!line) { + std::cout << "\n[cli] EOF/Ctrl+C - выходим.\n"; + break; + } + std::string input(line); + free(line); + + if (!input.empty()) { + linenoiseHistoryAdd(input.c_str()); + linenoiseHistorySave("history.txt"); + + processCommand(input, config); + } + } +} diff --git a/cli.hpp b/cli.hpp new file mode 100644 index 0000000..0a99e81 --- /dev/null +++ b/cli.hpp @@ -0,0 +1,5 @@ +#pragma once + +#include "config.hpp" + +void runCLI(AppConfig &config); diff --git a/commands.cpp b/commands.cpp new file mode 100644 index 0000000..bdc1f5c --- /dev/null +++ b/commands.cpp @@ -0,0 +1,113 @@ +#include "commands.hpp" +#include "config.hpp" +#include "webserver.hpp" +#include "sound.hpp" + +#include +#include +#include +#include +#include + +static std::vector splitTokens(const std::string &line) { + std::istringstream iss(line); + return { std::istream_iterator(iss), + std::istream_iterator() }; +} + +static void generateKey(AppConfig &config) { + config.key.resize(32); + FILE* f = fopen("/dev/urandom", "rb"); + if (!f) { + std::cerr << CLR_RED "[nick] Не удалось открыть /dev/urandom\n" CLR_RESET; + return; + } + fread(config.key.data(), 1, 32, f); + fclose(f); + + std::cout << CLR_GREEN "[nick] 256-битный ключ сгенерирован!\n" CLR_RESET; +} + +void processCommand(const std::string &input, AppConfig &config) { + if (input.empty()) return; + + auto tokens = splitTokens(input); + if (tokens.empty()) return; + + std::string cmd = tokens[0]; + + if (cmd == "nick") { + if (tokens.size() < 2) { + std::cout << CLR_YELLOW "[nick] Доступно: set , generatekey\n" CLR_RESET; + return; + } + std::string sub = tokens[1]; + if (sub == "set") { + if (tokens.size() < 3) { + std::cout << CLR_YELLOW "[nick] Использование: nick set \n" CLR_RESET; + return; + } + std::string newNick; + for (size_t i = 2; i < tokens.size(); i++) { + if (i > 2) newNick += " "; + newNick += tokens[i]; + } + config.nickname = newNick; + std::cout << CLR_GREEN "[nick] Установлен ник: " << newNick << "\n" CLR_RESET; + } + else if (sub == "generatekey") { + generateKey(config); + } + else { + std::cout << CLR_YELLOW "[nick] Неизвестная подкоманда: " << sub << "\n" CLR_RESET; + } + } + else if (cmd == "web") { + if (tokens.size() < 2) { + std::cout << CLR_YELLOW "[web] Доступно: start, connect pm|server , stop\n" CLR_RESET; + return; + } + std::string sub = tokens[1]; + if (sub == "start") { + webServerStart(config); + } + else if (sub == "connect") { + if (tokens.size() < 4) { + std::cout << CLR_YELLOW "[web] Использование: web connect \n" CLR_RESET; + return; + } + std::string ctype = tokens[2]; + std::string ip = tokens[3]; + webServerConnect(config, ctype, ip); + } + else if (sub == "stop") { + webServerStop(config); + } + else { + std::cout << CLR_YELLOW "[web] Неизвестная подкоманда: " << sub << "\n" CLR_RESET; + } + } + else if (cmd == "sound") { + if (tokens.size() < 2) { + std::cout << CLR_YELLOW "[sound] Доступно: find, lose\n" CLR_RESET; + return; + } + std::string sub = tokens[1]; + if (sub == "find") { + soundFind(config); + } + else if (sub == "lose") { + soundLose(config); + } + else { + std::cout << CLR_YELLOW "[sound] Неизвестная подкоманда: " << sub << "\n" CLR_RESET; + } + } + else if (cmd == "exit") { + std::cout << CLR_CYAN "[cli] Завершаем работу по команде 'exit'\n" CLR_RESET; + exit(0); + } + else { + std::cout << CLR_RED "Неизвестная команда: " << cmd << CLR_RESET << "\n"; + } +} diff --git a/commands.hpp b/commands.hpp new file mode 100644 index 0000000..c141c34 --- /dev/null +++ b/commands.hpp @@ -0,0 +1,6 @@ +#pragma once + +#include +#include "config.hpp" + +void processCommand(const std::string &input, AppConfig &config); diff --git a/config.hpp b/config.hpp new file mode 100644 index 0000000..6317f3e --- /dev/null +++ b/config.hpp @@ -0,0 +1,30 @@ +#pragma once +#include +#include +#include + +#define CLR_RESET "\x1b[0m" +#define CLR_BOLD "\x1b[1m" +#define CLR_RED "\x1b[31m" +#define CLR_GREEN "\x1b[32m" +#define CLR_YELLOW "\x1b[33m" +#define CLR_BLUE "\x1b[34m" +#define CLR_MAGENTA "\x1b[35m" +#define CLR_CYAN "\x1b[36m" +#define CLR_WHITE "\x1b[37m" + +struct AppConfig { + std::string nickname = "noname"; + + std::vector key; + + bool webServerRunning = false; + bool soundExchangeActive = false; + + + uint8_t ephemeralSec[32]; + uint8_t ephemeralPub[32]; + + uint8_t sharedSecret[32]; + bool haveSharedSecret = false; +}; diff --git a/history.txt b/history.txt new file mode 100644 index 0000000..bc1f6e7 --- /dev/null +++ b/history.txt @@ -0,0 +1,6 @@ +sound find +ls +nick set platon +nick generatekey +web start +web connect pm localhost diff --git a/libs/httplib.h b/libs/httplib.h new file mode 100644 index 0000000..7813cd4 --- /dev/null +++ b/libs/httplib.h @@ -0,0 +1,10351 @@ +// +// httplib.h +// +// Copyright (c) 2025 Yuji Hirose. All rights reserved. +// MIT License +// + +#ifndef CPPHTTPLIB_HTTPLIB_H +#define CPPHTTPLIB_HTTPLIB_H + +#define CPPHTTPLIB_VERSION "0.18.5" + +/* + * Configuration + */ + +#ifndef CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND +#define CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND 5 +#endif + +#ifndef CPPHTTPLIB_KEEPALIVE_TIMEOUT_CHECK_INTERVAL_USECOND +#define CPPHTTPLIB_KEEPALIVE_TIMEOUT_CHECK_INTERVAL_USECOND 10000 +#endif + +#ifndef CPPHTTPLIB_KEEPALIVE_MAX_COUNT +#define CPPHTTPLIB_KEEPALIVE_MAX_COUNT 100 +#endif + +#ifndef CPPHTTPLIB_CONNECTION_TIMEOUT_SECOND +#define CPPHTTPLIB_CONNECTION_TIMEOUT_SECOND 300 +#endif + +#ifndef CPPHTTPLIB_CONNECTION_TIMEOUT_USECOND +#define CPPHTTPLIB_CONNECTION_TIMEOUT_USECOND 0 +#endif + +#ifndef CPPHTTPLIB_SERVER_READ_TIMEOUT_SECOND +#define CPPHTTPLIB_SERVER_READ_TIMEOUT_SECOND 5 +#endif + +#ifndef CPPHTTPLIB_SERVER_READ_TIMEOUT_USECOND +#define CPPHTTPLIB_SERVER_READ_TIMEOUT_USECOND 0 +#endif + +#ifndef CPPHTTPLIB_SERVER_WRITE_TIMEOUT_SECOND +#define CPPHTTPLIB_SERVER_WRITE_TIMEOUT_SECOND 5 +#endif + +#ifndef CPPHTTPLIB_SERVER_WRITE_TIMEOUT_USECOND +#define CPPHTTPLIB_SERVER_WRITE_TIMEOUT_USECOND 0 +#endif + +#ifndef CPPHTTPLIB_CLIENT_READ_TIMEOUT_SECOND +#define CPPHTTPLIB_CLIENT_READ_TIMEOUT_SECOND 300 +#endif + +#ifndef CPPHTTPLIB_CLIENT_READ_TIMEOUT_USECOND +#define CPPHTTPLIB_CLIENT_READ_TIMEOUT_USECOND 0 +#endif + +#ifndef CPPHTTPLIB_CLIENT_WRITE_TIMEOUT_SECOND +#define CPPHTTPLIB_CLIENT_WRITE_TIMEOUT_SECOND 5 +#endif + +#ifndef CPPHTTPLIB_CLIENT_WRITE_TIMEOUT_USECOND +#define CPPHTTPLIB_CLIENT_WRITE_TIMEOUT_USECOND 0 +#endif + +#ifndef CPPHTTPLIB_IDLE_INTERVAL_SECOND +#define CPPHTTPLIB_IDLE_INTERVAL_SECOND 0 +#endif + +#ifndef CPPHTTPLIB_IDLE_INTERVAL_USECOND +#ifdef _WIN32 +#define CPPHTTPLIB_IDLE_INTERVAL_USECOND 10000 +#else +#define CPPHTTPLIB_IDLE_INTERVAL_USECOND 0 +#endif +#endif + +#ifndef CPPHTTPLIB_REQUEST_URI_MAX_LENGTH +#define CPPHTTPLIB_REQUEST_URI_MAX_LENGTH 8192 +#endif + +#ifndef CPPHTTPLIB_HEADER_MAX_LENGTH +#define CPPHTTPLIB_HEADER_MAX_LENGTH 8192 +#endif + +#ifndef CPPHTTPLIB_REDIRECT_MAX_COUNT +#define CPPHTTPLIB_REDIRECT_MAX_COUNT 20 +#endif + +#ifndef CPPHTTPLIB_MULTIPART_FORM_DATA_FILE_MAX_COUNT +#define CPPHTTPLIB_MULTIPART_FORM_DATA_FILE_MAX_COUNT 1024 +#endif + +#ifndef CPPHTTPLIB_PAYLOAD_MAX_LENGTH +#define CPPHTTPLIB_PAYLOAD_MAX_LENGTH ((std::numeric_limits::max)()) +#endif + +#ifndef CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH +#define CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH 8192 +#endif + +#ifndef CPPHTTPLIB_RANGE_MAX_COUNT +#define CPPHTTPLIB_RANGE_MAX_COUNT 1024 +#endif + +#ifndef CPPHTTPLIB_TCP_NODELAY +#define CPPHTTPLIB_TCP_NODELAY false +#endif + +#ifndef CPPHTTPLIB_IPV6_V6ONLY +#define CPPHTTPLIB_IPV6_V6ONLY false +#endif + +#ifndef CPPHTTPLIB_RECV_BUFSIZ +#define CPPHTTPLIB_RECV_BUFSIZ size_t(16384u) +#endif + +#ifndef CPPHTTPLIB_COMPRESSION_BUFSIZ +#define CPPHTTPLIB_COMPRESSION_BUFSIZ size_t(16384u) +#endif + +#ifndef CPPHTTPLIB_THREAD_POOL_COUNT +#define CPPHTTPLIB_THREAD_POOL_COUNT \ + ((std::max)(8u, std::thread::hardware_concurrency() > 0 \ + ? std::thread::hardware_concurrency() - 1 \ + : 0)) +#endif + +#ifndef CPPHTTPLIB_RECV_FLAGS +#define CPPHTTPLIB_RECV_FLAGS 0 +#endif + +#ifndef CPPHTTPLIB_SEND_FLAGS +#define CPPHTTPLIB_SEND_FLAGS 0 +#endif + +#ifndef CPPHTTPLIB_LISTEN_BACKLOG +#define CPPHTTPLIB_LISTEN_BACKLOG 5 +#endif + +/* + * Headers + */ + +#ifdef _WIN32 +#ifndef _CRT_SECURE_NO_WARNINGS +#define _CRT_SECURE_NO_WARNINGS +#endif //_CRT_SECURE_NO_WARNINGS + +#ifndef _CRT_NONSTDC_NO_DEPRECATE +#define _CRT_NONSTDC_NO_DEPRECATE +#endif //_CRT_NONSTDC_NO_DEPRECATE + +#if defined(_MSC_VER) +#if _MSC_VER < 1900 +#error Sorry, Visual Studio versions prior to 2015 are not supported +#endif + +#pragma comment(lib, "ws2_32.lib") + +#ifdef _WIN64 +using ssize_t = __int64; +#else +using ssize_t = long; +#endif +#endif // _MSC_VER + +#ifndef S_ISREG +#define S_ISREG(m) (((m) & S_IFREG) == S_IFREG) +#endif // S_ISREG + +#ifndef S_ISDIR +#define S_ISDIR(m) (((m) & S_IFDIR) == S_IFDIR) +#endif // S_ISDIR + +#ifndef NOMINMAX +#define NOMINMAX +#endif // NOMINMAX + +#include +#include +#include + +#ifndef WSA_FLAG_NO_HANDLE_INHERIT +#define WSA_FLAG_NO_HANDLE_INHERIT 0x80 +#endif + +using socket_t = SOCKET; +#ifdef CPPHTTPLIB_USE_POLL +#define poll(fds, nfds, timeout) WSAPoll(fds, nfds, timeout) +#endif + +#else // not _WIN32 + +#include +#if !defined(_AIX) && !defined(__MVS__) +#include +#endif +#ifdef __MVS__ +#include +#ifndef NI_MAXHOST +#define NI_MAXHOST 1025 +#endif +#endif +#include +#include +#include +#ifdef __linux__ +#include +#endif +#include +#ifdef CPPHTTPLIB_USE_POLL +#include +#endif +#include +#include +#include +#include +#include +#include +#include + +using socket_t = int; +#ifndef INVALID_SOCKET +#define INVALID_SOCKET (-1) +#endif +#endif //_WIN32 + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +#ifdef _WIN32 +#include + +// these are defined in wincrypt.h and it breaks compilation if BoringSSL is +// used +#undef X509_NAME +#undef X509_CERT_PAIR +#undef X509_EXTENSIONS +#undef PKCS7_SIGNER_INFO + +#ifdef _MSC_VER +#pragma comment(lib, "crypt32.lib") +#endif +#elif defined(CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN) && defined(__APPLE__) +#include +#if TARGET_OS_OSX +#include +#include +#endif // TARGET_OS_OSX +#endif // _WIN32 + +#include +#include +#include +#include + +#if defined(_WIN32) && defined(OPENSSL_USE_APPLINK) +#include +#endif + +#include +#include + +#if defined(OPENSSL_IS_BORINGSSL) || defined(LIBRESSL_VERSION_NUMBER) +#if OPENSSL_VERSION_NUMBER < 0x1010107f +#error Please use OpenSSL or a current version of BoringSSL +#endif +#define SSL_get1_peer_certificate SSL_get_peer_certificate +#elif OPENSSL_VERSION_NUMBER < 0x30000000L +#error Sorry, OpenSSL versions prior to 3.0.0 are not supported +#endif + +#endif + +#ifdef CPPHTTPLIB_ZLIB_SUPPORT +#include +#endif + +#ifdef CPPHTTPLIB_BROTLI_SUPPORT +#include +#include +#endif + +/* + * Declaration + */ +namespace httplib { + +namespace detail { + +/* + * Backport std::make_unique from C++14. + * + * NOTE: This code came up with the following stackoverflow post: + * https://stackoverflow.com/questions/10149840/c-arrays-and-make-unique + * + */ + +template +typename std::enable_if::value, std::unique_ptr>::type +make_unique(Args &&...args) { + return std::unique_ptr(new T(std::forward(args)...)); +} + +template +typename std::enable_if::value, std::unique_ptr>::type +make_unique(std::size_t n) { + typedef typename std::remove_extent::type RT; + return std::unique_ptr(new RT[n]); +} + +namespace case_ignore { + +inline unsigned char to_lower(int c) { + const static unsigned char table[256] = { + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, + 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, + 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, + 60, 61, 62, 63, 64, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, + 122, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, + 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, + 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, + 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, + 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, + 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, + 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 224, 225, 226, + 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, + 242, 243, 244, 245, 246, 215, 248, 249, 250, 251, 252, 253, 254, 223, 224, + 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, + 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, + 255, + }; + return table[(unsigned char)(char)c]; +} + +inline bool equal(const std::string &a, const std::string &b) { + return a.size() == b.size() && + std::equal(a.begin(), a.end(), b.begin(), [](char ca, char cb) { + return to_lower(ca) == to_lower(cb); + }); +} + +struct equal_to { + bool operator()(const std::string &a, const std::string &b) const { + return equal(a, b); + } +}; + +struct hash { + size_t operator()(const std::string &key) const { + return hash_core(key.data(), key.size(), 0); + } + + size_t hash_core(const char *s, size_t l, size_t h) const { + return (l == 0) ? h + : hash_core(s + 1, l - 1, + // Unsets the 6 high bits of h, therefore no + // overflow happens + (((std::numeric_limits::max)() >> 6) & + h * 33) ^ + static_cast(to_lower(*s))); + } +}; + +} // namespace case_ignore + +// This is based on +// "http://www.open-std.org/jtc1/sc22/wg21/docs/papers/2014/n4189". + +struct scope_exit { + explicit scope_exit(std::function &&f) + : exit_function(std::move(f)), execute_on_destruction{true} {} + + scope_exit(scope_exit &&rhs) noexcept + : exit_function(std::move(rhs.exit_function)), + execute_on_destruction{rhs.execute_on_destruction} { + rhs.release(); + } + + ~scope_exit() { + if (execute_on_destruction) { this->exit_function(); } + } + + void release() { this->execute_on_destruction = false; } + +private: + scope_exit(const scope_exit &) = delete; + void operator=(const scope_exit &) = delete; + scope_exit &operator=(scope_exit &&) = delete; + + std::function exit_function; + bool execute_on_destruction; +}; + +} // namespace detail + +enum StatusCode { + // Information responses + Continue_100 = 100, + SwitchingProtocol_101 = 101, + Processing_102 = 102, + EarlyHints_103 = 103, + + // Successful responses + OK_200 = 200, + Created_201 = 201, + Accepted_202 = 202, + NonAuthoritativeInformation_203 = 203, + NoContent_204 = 204, + ResetContent_205 = 205, + PartialContent_206 = 206, + MultiStatus_207 = 207, + AlreadyReported_208 = 208, + IMUsed_226 = 226, + + // Redirection messages + MultipleChoices_300 = 300, + MovedPermanently_301 = 301, + Found_302 = 302, + SeeOther_303 = 303, + NotModified_304 = 304, + UseProxy_305 = 305, + unused_306 = 306, + TemporaryRedirect_307 = 307, + PermanentRedirect_308 = 308, + + // Client error responses + BadRequest_400 = 400, + Unauthorized_401 = 401, + PaymentRequired_402 = 402, + Forbidden_403 = 403, + NotFound_404 = 404, + MethodNotAllowed_405 = 405, + NotAcceptable_406 = 406, + ProxyAuthenticationRequired_407 = 407, + RequestTimeout_408 = 408, + Conflict_409 = 409, + Gone_410 = 410, + LengthRequired_411 = 411, + PreconditionFailed_412 = 412, + PayloadTooLarge_413 = 413, + UriTooLong_414 = 414, + UnsupportedMediaType_415 = 415, + RangeNotSatisfiable_416 = 416, + ExpectationFailed_417 = 417, + ImATeapot_418 = 418, + MisdirectedRequest_421 = 421, + UnprocessableContent_422 = 422, + Locked_423 = 423, + FailedDependency_424 = 424, + TooEarly_425 = 425, + UpgradeRequired_426 = 426, + PreconditionRequired_428 = 428, + TooManyRequests_429 = 429, + RequestHeaderFieldsTooLarge_431 = 431, + UnavailableForLegalReasons_451 = 451, + + // Server error responses + InternalServerError_500 = 500, + NotImplemented_501 = 501, + BadGateway_502 = 502, + ServiceUnavailable_503 = 503, + GatewayTimeout_504 = 504, + HttpVersionNotSupported_505 = 505, + VariantAlsoNegotiates_506 = 506, + InsufficientStorage_507 = 507, + LoopDetected_508 = 508, + NotExtended_510 = 510, + NetworkAuthenticationRequired_511 = 511, +}; + +using Headers = + std::unordered_multimap; + +using Params = std::multimap; +using Match = std::smatch; + +using Progress = std::function; + +struct Response; +using ResponseHandler = std::function; + +struct MultipartFormData { + std::string name; + std::string content; + std::string filename; + std::string content_type; +}; +using MultipartFormDataItems = std::vector; +using MultipartFormDataMap = std::multimap; + +class DataSink { +public: + DataSink() : os(&sb_), sb_(*this) {} + + DataSink(const DataSink &) = delete; + DataSink &operator=(const DataSink &) = delete; + DataSink(DataSink &&) = delete; + DataSink &operator=(DataSink &&) = delete; + + std::function write; + std::function is_writable; + std::function done; + std::function done_with_trailer; + std::ostream os; + +private: + class data_sink_streambuf final : public std::streambuf { + public: + explicit data_sink_streambuf(DataSink &sink) : sink_(sink) {} + + protected: + std::streamsize xsputn(const char *s, std::streamsize n) override { + sink_.write(s, static_cast(n)); + return n; + } + + private: + DataSink &sink_; + }; + + data_sink_streambuf sb_; +}; + +using ContentProvider = + std::function; + +using ContentProviderWithoutLength = + std::function; + +using ContentProviderResourceReleaser = std::function; + +struct MultipartFormDataProvider { + std::string name; + ContentProviderWithoutLength provider; + std::string filename; + std::string content_type; +}; +using MultipartFormDataProviderItems = std::vector; + +using ContentReceiverWithProgress = + std::function; + +using ContentReceiver = + std::function; + +using MultipartContentHeader = + std::function; + +class ContentReader { +public: + using Reader = std::function; + using MultipartReader = std::function; + + ContentReader(Reader reader, MultipartReader multipart_reader) + : reader_(std::move(reader)), + multipart_reader_(std::move(multipart_reader)) {} + + bool operator()(MultipartContentHeader header, + ContentReceiver receiver) const { + return multipart_reader_(std::move(header), std::move(receiver)); + } + + bool operator()(ContentReceiver receiver) const { + return reader_(std::move(receiver)); + } + + Reader reader_; + MultipartReader multipart_reader_; +}; + +using Range = std::pair; +using Ranges = std::vector; + +struct Request { + std::string method; + std::string path; + Params params; + Headers headers; + std::string body; + + std::string remote_addr; + int remote_port = -1; + std::string local_addr; + int local_port = -1; + + // for server + std::string version; + std::string target; + MultipartFormDataMap files; + Ranges ranges; + Match matches; + std::unordered_map path_params; + std::function is_connection_closed = []() { return true; }; + + // for client + ResponseHandler response_handler; + ContentReceiverWithProgress content_receiver; + Progress progress; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + const SSL *ssl = nullptr; +#endif + + bool has_header(const std::string &key) const; + std::string get_header_value(const std::string &key, const char *def = "", + size_t id = 0) const; + uint64_t get_header_value_u64(const std::string &key, uint64_t def = 0, + size_t id = 0) const; + size_t get_header_value_count(const std::string &key) const; + void set_header(const std::string &key, const std::string &val); + + bool has_param(const std::string &key) const; + std::string get_param_value(const std::string &key, size_t id = 0) const; + size_t get_param_value_count(const std::string &key) const; + + bool is_multipart_form_data() const; + + bool has_file(const std::string &key) const; + MultipartFormData get_file_value(const std::string &key) const; + std::vector get_file_values(const std::string &key) const; + + // private members... + size_t redirect_count_ = CPPHTTPLIB_REDIRECT_MAX_COUNT; + size_t content_length_ = 0; + ContentProvider content_provider_; + bool is_chunked_content_provider_ = false; + size_t authorization_count_ = 0; +}; + +struct Response { + std::string version; + int status = -1; + std::string reason; + Headers headers; + std::string body; + std::string location; // Redirect location + + bool has_header(const std::string &key) const; + std::string get_header_value(const std::string &key, const char *def = "", + size_t id = 0) const; + uint64_t get_header_value_u64(const std::string &key, uint64_t def = 0, + size_t id = 0) const; + size_t get_header_value_count(const std::string &key) const; + void set_header(const std::string &key, const std::string &val); + + void set_redirect(const std::string &url, int status = StatusCode::Found_302); + void set_content(const char *s, size_t n, const std::string &content_type); + void set_content(const std::string &s, const std::string &content_type); + void set_content(std::string &&s, const std::string &content_type); + + void set_content_provider( + size_t length, const std::string &content_type, ContentProvider provider, + ContentProviderResourceReleaser resource_releaser = nullptr); + + void set_content_provider( + const std::string &content_type, ContentProviderWithoutLength provider, + ContentProviderResourceReleaser resource_releaser = nullptr); + + void set_chunked_content_provider( + const std::string &content_type, ContentProviderWithoutLength provider, + ContentProviderResourceReleaser resource_releaser = nullptr); + + void set_file_content(const std::string &path, + const std::string &content_type); + void set_file_content(const std::string &path); + + Response() = default; + Response(const Response &) = default; + Response &operator=(const Response &) = default; + Response(Response &&) = default; + Response &operator=(Response &&) = default; + ~Response() { + if (content_provider_resource_releaser_) { + content_provider_resource_releaser_(content_provider_success_); + } + } + + // private members... + size_t content_length_ = 0; + ContentProvider content_provider_; + ContentProviderResourceReleaser content_provider_resource_releaser_; + bool is_chunked_content_provider_ = false; + bool content_provider_success_ = false; + std::string file_content_path_; + std::string file_content_content_type_; +}; + +class Stream { +public: + virtual ~Stream() = default; + + virtual bool is_readable() const = 0; + virtual bool is_writable() const = 0; + + virtual ssize_t read(char *ptr, size_t size) = 0; + virtual ssize_t write(const char *ptr, size_t size) = 0; + virtual void get_remote_ip_and_port(std::string &ip, int &port) const = 0; + virtual void get_local_ip_and_port(std::string &ip, int &port) const = 0; + virtual socket_t socket() const = 0; + + ssize_t write(const char *ptr); + ssize_t write(const std::string &s); +}; + +class TaskQueue { +public: + TaskQueue() = default; + virtual ~TaskQueue() = default; + + virtual bool enqueue(std::function fn) = 0; + virtual void shutdown() = 0; + + virtual void on_idle() {} +}; + +class ThreadPool final : public TaskQueue { +public: + explicit ThreadPool(size_t n, size_t mqr = 0) + : shutdown_(false), max_queued_requests_(mqr) { + while (n) { + threads_.emplace_back(worker(*this)); + n--; + } + } + + ThreadPool(const ThreadPool &) = delete; + ~ThreadPool() override = default; + + bool enqueue(std::function fn) override { + { + std::unique_lock lock(mutex_); + if (max_queued_requests_ > 0 && jobs_.size() >= max_queued_requests_) { + return false; + } + jobs_.push_back(std::move(fn)); + } + + cond_.notify_one(); + return true; + } + + void shutdown() override { + // Stop all worker threads... + { + std::unique_lock lock(mutex_); + shutdown_ = true; + } + + cond_.notify_all(); + + // Join... + for (auto &t : threads_) { + t.join(); + } + } + +private: + struct worker { + explicit worker(ThreadPool &pool) : pool_(pool) {} + + void operator()() { + for (;;) { + std::function fn; + { + std::unique_lock lock(pool_.mutex_); + + pool_.cond_.wait( + lock, [&] { return !pool_.jobs_.empty() || pool_.shutdown_; }); + + if (pool_.shutdown_ && pool_.jobs_.empty()) { break; } + + fn = pool_.jobs_.front(); + pool_.jobs_.pop_front(); + } + + assert(true == static_cast(fn)); + fn(); + } + +#if defined(CPPHTTPLIB_OPENSSL_SUPPORT) && !defined(OPENSSL_IS_BORINGSSL) && \ + !defined(LIBRESSL_VERSION_NUMBER) + OPENSSL_thread_stop(); +#endif + } + + ThreadPool &pool_; + }; + friend struct worker; + + std::vector threads_; + std::list> jobs_; + + bool shutdown_; + size_t max_queued_requests_ = 0; + + std::condition_variable cond_; + std::mutex mutex_; +}; + +using Logger = std::function; + +using SocketOptions = std::function; + +void default_socket_options(socket_t sock); + +const char *status_message(int status); + +std::string get_bearer_token_auth(const Request &req); + +namespace detail { + +class MatcherBase { +public: + virtual ~MatcherBase() = default; + + // Match request path and populate its matches and + virtual bool match(Request &request) const = 0; +}; + +/** + * Captures parameters in request path and stores them in Request::path_params + * + * Capture name is a substring of a pattern from : to /. + * The rest of the pattern is matched agains the request path directly + * Parameters are captured starting from the next character after + * the end of the last matched static pattern fragment until the next /. + * + * Example pattern: + * "/path/fragments/:capture/more/fragments/:second_capture" + * Static fragments: + * "/path/fragments/", "more/fragments/" + * + * Given the following request path: + * "/path/fragments/:1/more/fragments/:2" + * the resulting capture will be + * {{"capture", "1"}, {"second_capture", "2"}} + */ +class PathParamsMatcher final : public MatcherBase { +public: + PathParamsMatcher(const std::string &pattern); + + bool match(Request &request) const override; + +private: + // Treat segment separators as the end of path parameter capture + // Does not need to handle query parameters as they are parsed before path + // matching + static constexpr char separator = '/'; + + // Contains static path fragments to match against, excluding the '/' after + // path params + // Fragments are separated by path params + std::vector static_fragments_; + // Stores the names of the path parameters to be used as keys in the + // Request::path_params map + std::vector param_names_; +}; + +/** + * Performs std::regex_match on request path + * and stores the result in Request::matches + * + * Note that regex match is performed directly on the whole request. + * This means that wildcard patterns may match multiple path segments with /: + * "/begin/(.*)/end" will match both "/begin/middle/end" and "/begin/1/2/end". + */ +class RegexMatcher final : public MatcherBase { +public: + RegexMatcher(const std::string &pattern) : regex_(pattern) {} + + bool match(Request &request) const override; + +private: + std::regex regex_; +}; + +ssize_t write_headers(Stream &strm, const Headers &headers); + +} // namespace detail + +class Server { +public: + using Handler = std::function; + + using ExceptionHandler = + std::function; + + enum class HandlerResponse { + Handled, + Unhandled, + }; + using HandlerWithResponse = + std::function; + + using HandlerWithContentReader = std::function; + + using Expect100ContinueHandler = + std::function; + + Server(); + + virtual ~Server(); + + virtual bool is_valid() const; + + Server &Get(const std::string &pattern, Handler handler); + Server &Post(const std::string &pattern, Handler handler); + Server &Post(const std::string &pattern, HandlerWithContentReader handler); + Server &Put(const std::string &pattern, Handler handler); + Server &Put(const std::string &pattern, HandlerWithContentReader handler); + Server &Patch(const std::string &pattern, Handler handler); + Server &Patch(const std::string &pattern, HandlerWithContentReader handler); + Server &Delete(const std::string &pattern, Handler handler); + Server &Delete(const std::string &pattern, HandlerWithContentReader handler); + Server &Options(const std::string &pattern, Handler handler); + + bool set_base_dir(const std::string &dir, + const std::string &mount_point = std::string()); + bool set_mount_point(const std::string &mount_point, const std::string &dir, + Headers headers = Headers()); + bool remove_mount_point(const std::string &mount_point); + Server &set_file_extension_and_mimetype_mapping(const std::string &ext, + const std::string &mime); + Server &set_default_file_mimetype(const std::string &mime); + Server &set_file_request_handler(Handler handler); + + template + Server &set_error_handler(ErrorHandlerFunc &&handler) { + return set_error_handler_core( + std::forward(handler), + std::is_convertible{}); + } + + Server &set_exception_handler(ExceptionHandler handler); + Server &set_pre_routing_handler(HandlerWithResponse handler); + Server &set_post_routing_handler(Handler handler); + + Server &set_expect_100_continue_handler(Expect100ContinueHandler handler); + Server &set_logger(Logger logger); + + Server &set_address_family(int family); + Server &set_tcp_nodelay(bool on); + Server &set_ipv6_v6only(bool on); + Server &set_socket_options(SocketOptions socket_options); + + Server &set_default_headers(Headers headers); + Server & + set_header_writer(std::function const &writer); + + Server &set_keep_alive_max_count(size_t count); + Server &set_keep_alive_timeout(time_t sec); + + Server &set_read_timeout(time_t sec, time_t usec = 0); + template + Server &set_read_timeout(const std::chrono::duration &duration); + + Server &set_write_timeout(time_t sec, time_t usec = 0); + template + Server &set_write_timeout(const std::chrono::duration &duration); + + Server &set_idle_interval(time_t sec, time_t usec = 0); + template + Server &set_idle_interval(const std::chrono::duration &duration); + + Server &set_payload_max_length(size_t length); + + bool bind_to_port(const std::string &host, int port, int socket_flags = 0); + int bind_to_any_port(const std::string &host, int socket_flags = 0); + bool listen_after_bind(); + + bool listen(const std::string &host, int port, int socket_flags = 0); + + bool is_running() const; + void wait_until_ready() const; + void stop(); + void decommission(); + + std::function new_task_queue; + +protected: + bool process_request(Stream &strm, const std::string &remote_addr, + int remote_port, const std::string &local_addr, + int local_port, bool close_connection, + bool &connection_closed, + const std::function &setup_request); + + std::atomic svr_sock_{INVALID_SOCKET}; + size_t keep_alive_max_count_ = CPPHTTPLIB_KEEPALIVE_MAX_COUNT; + time_t keep_alive_timeout_sec_ = CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND; + time_t read_timeout_sec_ = CPPHTTPLIB_SERVER_READ_TIMEOUT_SECOND; + time_t read_timeout_usec_ = CPPHTTPLIB_SERVER_READ_TIMEOUT_USECOND; + time_t write_timeout_sec_ = CPPHTTPLIB_SERVER_WRITE_TIMEOUT_SECOND; + time_t write_timeout_usec_ = CPPHTTPLIB_SERVER_WRITE_TIMEOUT_USECOND; + time_t idle_interval_sec_ = CPPHTTPLIB_IDLE_INTERVAL_SECOND; + time_t idle_interval_usec_ = CPPHTTPLIB_IDLE_INTERVAL_USECOND; + size_t payload_max_length_ = CPPHTTPLIB_PAYLOAD_MAX_LENGTH; + +private: + using Handlers = + std::vector, Handler>>; + using HandlersForContentReader = + std::vector, + HandlerWithContentReader>>; + + static std::unique_ptr + make_matcher(const std::string &pattern); + + Server &set_error_handler_core(HandlerWithResponse handler, std::true_type); + Server &set_error_handler_core(Handler handler, std::false_type); + + socket_t create_server_socket(const std::string &host, int port, + int socket_flags, + SocketOptions socket_options) const; + int bind_internal(const std::string &host, int port, int socket_flags); + bool listen_internal(); + + bool routing(Request &req, Response &res, Stream &strm); + bool handle_file_request(const Request &req, Response &res, + bool head = false); + bool dispatch_request(Request &req, Response &res, + const Handlers &handlers) const; + bool dispatch_request_for_content_reader( + Request &req, Response &res, ContentReader content_reader, + const HandlersForContentReader &handlers) const; + + bool parse_request_line(const char *s, Request &req) const; + void apply_ranges(const Request &req, Response &res, + std::string &content_type, std::string &boundary) const; + bool write_response(Stream &strm, bool close_connection, Request &req, + Response &res); + bool write_response_with_content(Stream &strm, bool close_connection, + const Request &req, Response &res); + bool write_response_core(Stream &strm, bool close_connection, + const Request &req, Response &res, + bool need_apply_ranges); + bool write_content_with_provider(Stream &strm, const Request &req, + Response &res, const std::string &boundary, + const std::string &content_type); + bool read_content(Stream &strm, Request &req, Response &res); + bool + read_content_with_content_receiver(Stream &strm, Request &req, Response &res, + ContentReceiver receiver, + MultipartContentHeader multipart_header, + ContentReceiver multipart_receiver); + bool read_content_core(Stream &strm, Request &req, Response &res, + ContentReceiver receiver, + MultipartContentHeader multipart_header, + ContentReceiver multipart_receiver) const; + + virtual bool process_and_close_socket(socket_t sock); + + std::atomic is_running_{false}; + std::atomic is_decommisioned{false}; + + struct MountPointEntry { + std::string mount_point; + std::string base_dir; + Headers headers; + }; + std::vector base_dirs_; + std::map file_extension_and_mimetype_map_; + std::string default_file_mimetype_ = "application/octet-stream"; + Handler file_request_handler_; + + Handlers get_handlers_; + Handlers post_handlers_; + HandlersForContentReader post_handlers_for_content_reader_; + Handlers put_handlers_; + HandlersForContentReader put_handlers_for_content_reader_; + Handlers patch_handlers_; + HandlersForContentReader patch_handlers_for_content_reader_; + Handlers delete_handlers_; + HandlersForContentReader delete_handlers_for_content_reader_; + Handlers options_handlers_; + + HandlerWithResponse error_handler_; + ExceptionHandler exception_handler_; + HandlerWithResponse pre_routing_handler_; + Handler post_routing_handler_; + Expect100ContinueHandler expect_100_continue_handler_; + + Logger logger_; + + int address_family_ = AF_UNSPEC; + bool tcp_nodelay_ = CPPHTTPLIB_TCP_NODELAY; + bool ipv6_v6only_ = CPPHTTPLIB_IPV6_V6ONLY; + SocketOptions socket_options_ = default_socket_options; + + Headers default_headers_; + std::function header_writer_ = + detail::write_headers; +}; + +enum class Error { + Success = 0, + Unknown, + Connection, + BindIPAddress, + Read, + Write, + ExceedRedirectCount, + Canceled, + SSLConnection, + SSLLoadingCerts, + SSLServerVerification, + SSLServerHostnameVerification, + UnsupportedMultipartBoundaryChars, + Compression, + ConnectionTimeout, + ProxyConnection, + + // For internal use only + SSLPeerCouldBeClosed_, +}; + +std::string to_string(Error error); + +std::ostream &operator<<(std::ostream &os, const Error &obj); + +class Result { +public: + Result() = default; + Result(std::unique_ptr &&res, Error err, + Headers &&request_headers = Headers{}) + : res_(std::move(res)), err_(err), + request_headers_(std::move(request_headers)) {} + // Response + operator bool() const { return res_ != nullptr; } + bool operator==(std::nullptr_t) const { return res_ == nullptr; } + bool operator!=(std::nullptr_t) const { return res_ != nullptr; } + const Response &value() const { return *res_; } + Response &value() { return *res_; } + const Response &operator*() const { return *res_; } + Response &operator*() { return *res_; } + const Response *operator->() const { return res_.get(); } + Response *operator->() { return res_.get(); } + + // Error + Error error() const { return err_; } + + // Request Headers + bool has_request_header(const std::string &key) const; + std::string get_request_header_value(const std::string &key, + const char *def = "", + size_t id = 0) const; + uint64_t get_request_header_value_u64(const std::string &key, + uint64_t def = 0, size_t id = 0) const; + size_t get_request_header_value_count(const std::string &key) const; + +private: + std::unique_ptr res_; + Error err_ = Error::Unknown; + Headers request_headers_; +}; + +class ClientImpl { +public: + explicit ClientImpl(const std::string &host); + + explicit ClientImpl(const std::string &host, int port); + + explicit ClientImpl(const std::string &host, int port, + const std::string &client_cert_path, + const std::string &client_key_path); + + virtual ~ClientImpl(); + + virtual bool is_valid() const; + + Result Get(const std::string &path); + Result Get(const std::string &path, const Headers &headers); + Result Get(const std::string &path, Progress progress); + Result Get(const std::string &path, const Headers &headers, + Progress progress); + Result Get(const std::string &path, ContentReceiver content_receiver); + Result Get(const std::string &path, const Headers &headers, + ContentReceiver content_receiver); + Result Get(const std::string &path, ContentReceiver content_receiver, + Progress progress); + Result Get(const std::string &path, const Headers &headers, + ContentReceiver content_receiver, Progress progress); + Result Get(const std::string &path, ResponseHandler response_handler, + ContentReceiver content_receiver); + Result Get(const std::string &path, const Headers &headers, + ResponseHandler response_handler, + ContentReceiver content_receiver); + Result Get(const std::string &path, ResponseHandler response_handler, + ContentReceiver content_receiver, Progress progress); + Result Get(const std::string &path, const Headers &headers, + ResponseHandler response_handler, ContentReceiver content_receiver, + Progress progress); + + Result Get(const std::string &path, const Params ¶ms, + const Headers &headers, Progress progress = nullptr); + Result Get(const std::string &path, const Params ¶ms, + const Headers &headers, ContentReceiver content_receiver, + Progress progress = nullptr); + Result Get(const std::string &path, const Params ¶ms, + const Headers &headers, ResponseHandler response_handler, + ContentReceiver content_receiver, Progress progress = nullptr); + + Result Head(const std::string &path); + Result Head(const std::string &path, const Headers &headers); + + Result Post(const std::string &path); + Result Post(const std::string &path, const Headers &headers); + Result Post(const std::string &path, const char *body, size_t content_length, + const std::string &content_type); + Result Post(const std::string &path, const Headers &headers, const char *body, + size_t content_length, const std::string &content_type); + Result Post(const std::string &path, const Headers &headers, const char *body, + size_t content_length, const std::string &content_type, + Progress progress); + Result Post(const std::string &path, const std::string &body, + const std::string &content_type); + Result Post(const std::string &path, const std::string &body, + const std::string &content_type, Progress progress); + Result Post(const std::string &path, const Headers &headers, + const std::string &body, const std::string &content_type); + Result Post(const std::string &path, const Headers &headers, + const std::string &body, const std::string &content_type, + Progress progress); + Result Post(const std::string &path, size_t content_length, + ContentProvider content_provider, + const std::string &content_type); + Result Post(const std::string &path, + ContentProviderWithoutLength content_provider, + const std::string &content_type); + Result Post(const std::string &path, const Headers &headers, + size_t content_length, ContentProvider content_provider, + const std::string &content_type); + Result Post(const std::string &path, const Headers &headers, + ContentProviderWithoutLength content_provider, + const std::string &content_type); + Result Post(const std::string &path, const Params ¶ms); + Result Post(const std::string &path, const Headers &headers, + const Params ¶ms); + Result Post(const std::string &path, const Headers &headers, + const Params ¶ms, Progress progress); + Result Post(const std::string &path, const MultipartFormDataItems &items); + Result Post(const std::string &path, const Headers &headers, + const MultipartFormDataItems &items); + Result Post(const std::string &path, const Headers &headers, + const MultipartFormDataItems &items, const std::string &boundary); + Result Post(const std::string &path, const Headers &headers, + const MultipartFormDataItems &items, + const MultipartFormDataProviderItems &provider_items); + + Result Put(const std::string &path); + Result Put(const std::string &path, const char *body, size_t content_length, + const std::string &content_type); + Result Put(const std::string &path, const Headers &headers, const char *body, + size_t content_length, const std::string &content_type); + Result Put(const std::string &path, const Headers &headers, const char *body, + size_t content_length, const std::string &content_type, + Progress progress); + Result Put(const std::string &path, const std::string &body, + const std::string &content_type); + Result Put(const std::string &path, const std::string &body, + const std::string &content_type, Progress progress); + Result Put(const std::string &path, const Headers &headers, + const std::string &body, const std::string &content_type); + Result Put(const std::string &path, const Headers &headers, + const std::string &body, const std::string &content_type, + Progress progress); + Result Put(const std::string &path, size_t content_length, + ContentProvider content_provider, const std::string &content_type); + Result Put(const std::string &path, + ContentProviderWithoutLength content_provider, + const std::string &content_type); + Result Put(const std::string &path, const Headers &headers, + size_t content_length, ContentProvider content_provider, + const std::string &content_type); + Result Put(const std::string &path, const Headers &headers, + ContentProviderWithoutLength content_provider, + const std::string &content_type); + Result Put(const std::string &path, const Params ¶ms); + Result Put(const std::string &path, const Headers &headers, + const Params ¶ms); + Result Put(const std::string &path, const Headers &headers, + const Params ¶ms, Progress progress); + Result Put(const std::string &path, const MultipartFormDataItems &items); + Result Put(const std::string &path, const Headers &headers, + const MultipartFormDataItems &items); + Result Put(const std::string &path, const Headers &headers, + const MultipartFormDataItems &items, const std::string &boundary); + Result Put(const std::string &path, const Headers &headers, + const MultipartFormDataItems &items, + const MultipartFormDataProviderItems &provider_items); + + Result Patch(const std::string &path); + Result Patch(const std::string &path, const char *body, size_t content_length, + const std::string &content_type); + Result Patch(const std::string &path, const char *body, size_t content_length, + const std::string &content_type, Progress progress); + Result Patch(const std::string &path, const Headers &headers, + const char *body, size_t content_length, + const std::string &content_type); + Result Patch(const std::string &path, const Headers &headers, + const char *body, size_t content_length, + const std::string &content_type, Progress progress); + Result Patch(const std::string &path, const std::string &body, + const std::string &content_type); + Result Patch(const std::string &path, const std::string &body, + const std::string &content_type, Progress progress); + Result Patch(const std::string &path, const Headers &headers, + const std::string &body, const std::string &content_type); + Result Patch(const std::string &path, const Headers &headers, + const std::string &body, const std::string &content_type, + Progress progress); + Result Patch(const std::string &path, size_t content_length, + ContentProvider content_provider, + const std::string &content_type); + Result Patch(const std::string &path, + ContentProviderWithoutLength content_provider, + const std::string &content_type); + Result Patch(const std::string &path, const Headers &headers, + size_t content_length, ContentProvider content_provider, + const std::string &content_type); + Result Patch(const std::string &path, const Headers &headers, + ContentProviderWithoutLength content_provider, + const std::string &content_type); + + Result Delete(const std::string &path); + Result Delete(const std::string &path, const Headers &headers); + Result Delete(const std::string &path, const char *body, + size_t content_length, const std::string &content_type); + Result Delete(const std::string &path, const char *body, + size_t content_length, const std::string &content_type, + Progress progress); + Result Delete(const std::string &path, const Headers &headers, + const char *body, size_t content_length, + const std::string &content_type); + Result Delete(const std::string &path, const Headers &headers, + const char *body, size_t content_length, + const std::string &content_type, Progress progress); + Result Delete(const std::string &path, const std::string &body, + const std::string &content_type); + Result Delete(const std::string &path, const std::string &body, + const std::string &content_type, Progress progress); + Result Delete(const std::string &path, const Headers &headers, + const std::string &body, const std::string &content_type); + Result Delete(const std::string &path, const Headers &headers, + const std::string &body, const std::string &content_type, + Progress progress); + + Result Options(const std::string &path); + Result Options(const std::string &path, const Headers &headers); + + bool send(Request &req, Response &res, Error &error); + Result send(const Request &req); + + void stop(); + + std::string host() const; + int port() const; + + size_t is_socket_open() const; + socket_t socket() const; + + void set_hostname_addr_map(std::map addr_map); + + void set_default_headers(Headers headers); + + void + set_header_writer(std::function const &writer); + + void set_address_family(int family); + void set_tcp_nodelay(bool on); + void set_ipv6_v6only(bool on); + void set_socket_options(SocketOptions socket_options); + + void set_connection_timeout(time_t sec, time_t usec = 0); + template + void + set_connection_timeout(const std::chrono::duration &duration); + + void set_read_timeout(time_t sec, time_t usec = 0); + template + void set_read_timeout(const std::chrono::duration &duration); + + void set_write_timeout(time_t sec, time_t usec = 0); + template + void set_write_timeout(const std::chrono::duration &duration); + + void set_basic_auth(const std::string &username, const std::string &password); + void set_bearer_token_auth(const std::string &token); +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + void set_digest_auth(const std::string &username, + const std::string &password); +#endif + + void set_keep_alive(bool on); + void set_follow_location(bool on); + + void set_url_encode(bool on); + + void set_compress(bool on); + + void set_decompress(bool on); + + void set_interface(const std::string &intf); + + void set_proxy(const std::string &host, int port); + void set_proxy_basic_auth(const std::string &username, + const std::string &password); + void set_proxy_bearer_token_auth(const std::string &token); +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + void set_proxy_digest_auth(const std::string &username, + const std::string &password); +#endif + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + void set_ca_cert_path(const std::string &ca_cert_file_path, + const std::string &ca_cert_dir_path = std::string()); + void set_ca_cert_store(X509_STORE *ca_cert_store); + X509_STORE *create_ca_cert_store(const char *ca_cert, std::size_t size) const; +#endif + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + void enable_server_certificate_verification(bool enabled); + void enable_server_hostname_verification(bool enabled); + void set_server_certificate_verifier(std::function verifier); +#endif + + void set_logger(Logger logger); + +protected: + struct Socket { + socket_t sock = INVALID_SOCKET; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + SSL *ssl = nullptr; +#endif + + bool is_open() const { return sock != INVALID_SOCKET; } + }; + + virtual bool create_and_connect_socket(Socket &socket, Error &error); + + // All of: + // shutdown_ssl + // shutdown_socket + // close_socket + // should ONLY be called when socket_mutex_ is locked. + // Also, shutdown_ssl and close_socket should also NOT be called concurrently + // with a DIFFERENT thread sending requests using that socket. + virtual void shutdown_ssl(Socket &socket, bool shutdown_gracefully); + void shutdown_socket(Socket &socket) const; + void close_socket(Socket &socket); + + bool process_request(Stream &strm, Request &req, Response &res, + bool close_connection, Error &error); + + bool write_content_with_provider(Stream &strm, const Request &req, + Error &error) const; + + void copy_settings(const ClientImpl &rhs); + + // Socket endpoint information + const std::string host_; + const int port_; + const std::string host_and_port_; + + // Current open socket + Socket socket_; + mutable std::mutex socket_mutex_; + std::recursive_mutex request_mutex_; + + // These are all protected under socket_mutex + size_t socket_requests_in_flight_ = 0; + std::thread::id socket_requests_are_from_thread_ = std::thread::id(); + bool socket_should_be_closed_when_request_is_done_ = false; + + // Hostname-IP map + std::map addr_map_; + + // Default headers + Headers default_headers_; + + // Header writer + std::function header_writer_ = + detail::write_headers; + + // Settings + std::string client_cert_path_; + std::string client_key_path_; + + time_t connection_timeout_sec_ = CPPHTTPLIB_CONNECTION_TIMEOUT_SECOND; + time_t connection_timeout_usec_ = CPPHTTPLIB_CONNECTION_TIMEOUT_USECOND; + time_t read_timeout_sec_ = CPPHTTPLIB_CLIENT_READ_TIMEOUT_SECOND; + time_t read_timeout_usec_ = CPPHTTPLIB_CLIENT_READ_TIMEOUT_USECOND; + time_t write_timeout_sec_ = CPPHTTPLIB_CLIENT_WRITE_TIMEOUT_SECOND; + time_t write_timeout_usec_ = CPPHTTPLIB_CLIENT_WRITE_TIMEOUT_USECOND; + + std::string basic_auth_username_; + std::string basic_auth_password_; + std::string bearer_token_auth_token_; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + std::string digest_auth_username_; + std::string digest_auth_password_; +#endif + + bool keep_alive_ = false; + bool follow_location_ = false; + + bool url_encode_ = true; + + int address_family_ = AF_UNSPEC; + bool tcp_nodelay_ = CPPHTTPLIB_TCP_NODELAY; + bool ipv6_v6only_ = CPPHTTPLIB_IPV6_V6ONLY; + SocketOptions socket_options_ = nullptr; + + bool compress_ = false; + bool decompress_ = true; + + std::string interface_; + + std::string proxy_host_; + int proxy_port_ = -1; + + std::string proxy_basic_auth_username_; + std::string proxy_basic_auth_password_; + std::string proxy_bearer_token_auth_token_; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + std::string proxy_digest_auth_username_; + std::string proxy_digest_auth_password_; +#endif + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + std::string ca_cert_file_path_; + std::string ca_cert_dir_path_; + + X509_STORE *ca_cert_store_ = nullptr; +#endif + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + bool server_certificate_verification_ = true; + bool server_hostname_verification_ = true; + std::function server_certificate_verifier_; +#endif + + Logger logger_; + +private: + bool send_(Request &req, Response &res, Error &error); + Result send_(Request &&req); + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + bool is_ssl_peer_could_be_closed(SSL *ssl) const; +#endif + socket_t create_client_socket(Error &error) const; + bool read_response_line(Stream &strm, const Request &req, + Response &res) const; + bool write_request(Stream &strm, Request &req, bool close_connection, + Error &error); + bool redirect(Request &req, Response &res, Error &error); + bool handle_request(Stream &strm, Request &req, Response &res, + bool close_connection, Error &error); + std::unique_ptr send_with_content_provider( + Request &req, const char *body, size_t content_length, + ContentProvider content_provider, + ContentProviderWithoutLength content_provider_without_length, + const std::string &content_type, Error &error); + Result send_with_content_provider( + const std::string &method, const std::string &path, + const Headers &headers, const char *body, size_t content_length, + ContentProvider content_provider, + ContentProviderWithoutLength content_provider_without_length, + const std::string &content_type, Progress progress); + ContentProviderWithoutLength get_multipart_content_provider( + const std::string &boundary, const MultipartFormDataItems &items, + const MultipartFormDataProviderItems &provider_items) const; + + std::string adjust_host_string(const std::string &host) const; + + virtual bool process_socket(const Socket &socket, + std::function callback); + virtual bool is_ssl() const; +}; + +class Client { +public: + // Universal interface + explicit Client(const std::string &scheme_host_port); + + explicit Client(const std::string &scheme_host_port, + const std::string &client_cert_path, + const std::string &client_key_path); + + // HTTP only interface + explicit Client(const std::string &host, int port); + + explicit Client(const std::string &host, int port, + const std::string &client_cert_path, + const std::string &client_key_path); + + Client(Client &&) = default; + Client &operator=(Client &&) = default; + + ~Client(); + + bool is_valid() const; + + Result Get(const std::string &path); + Result Get(const std::string &path, const Headers &headers); + Result Get(const std::string &path, Progress progress); + Result Get(const std::string &path, const Headers &headers, + Progress progress); + Result Get(const std::string &path, ContentReceiver content_receiver); + Result Get(const std::string &path, const Headers &headers, + ContentReceiver content_receiver); + Result Get(const std::string &path, ContentReceiver content_receiver, + Progress progress); + Result Get(const std::string &path, const Headers &headers, + ContentReceiver content_receiver, Progress progress); + Result Get(const std::string &path, ResponseHandler response_handler, + ContentReceiver content_receiver); + Result Get(const std::string &path, const Headers &headers, + ResponseHandler response_handler, + ContentReceiver content_receiver); + Result Get(const std::string &path, const Headers &headers, + ResponseHandler response_handler, ContentReceiver content_receiver, + Progress progress); + Result Get(const std::string &path, ResponseHandler response_handler, + ContentReceiver content_receiver, Progress progress); + + Result Get(const std::string &path, const Params ¶ms, + const Headers &headers, Progress progress = nullptr); + Result Get(const std::string &path, const Params ¶ms, + const Headers &headers, ContentReceiver content_receiver, + Progress progress = nullptr); + Result Get(const std::string &path, const Params ¶ms, + const Headers &headers, ResponseHandler response_handler, + ContentReceiver content_receiver, Progress progress = nullptr); + + Result Head(const std::string &path); + Result Head(const std::string &path, const Headers &headers); + + Result Post(const std::string &path); + Result Post(const std::string &path, const Headers &headers); + Result Post(const std::string &path, const char *body, size_t content_length, + const std::string &content_type); + Result Post(const std::string &path, const Headers &headers, const char *body, + size_t content_length, const std::string &content_type); + Result Post(const std::string &path, const Headers &headers, const char *body, + size_t content_length, const std::string &content_type, + Progress progress); + Result Post(const std::string &path, const std::string &body, + const std::string &content_type); + Result Post(const std::string &path, const std::string &body, + const std::string &content_type, Progress progress); + Result Post(const std::string &path, const Headers &headers, + const std::string &body, const std::string &content_type); + Result Post(const std::string &path, const Headers &headers, + const std::string &body, const std::string &content_type, + Progress progress); + Result Post(const std::string &path, size_t content_length, + ContentProvider content_provider, + const std::string &content_type); + Result Post(const std::string &path, + ContentProviderWithoutLength content_provider, + const std::string &content_type); + Result Post(const std::string &path, const Headers &headers, + size_t content_length, ContentProvider content_provider, + const std::string &content_type); + Result Post(const std::string &path, const Headers &headers, + ContentProviderWithoutLength content_provider, + const std::string &content_type); + Result Post(const std::string &path, const Params ¶ms); + Result Post(const std::string &path, const Headers &headers, + const Params ¶ms); + Result Post(const std::string &path, const Headers &headers, + const Params ¶ms, Progress progress); + Result Post(const std::string &path, const MultipartFormDataItems &items); + Result Post(const std::string &path, const Headers &headers, + const MultipartFormDataItems &items); + Result Post(const std::string &path, const Headers &headers, + const MultipartFormDataItems &items, const std::string &boundary); + Result Post(const std::string &path, const Headers &headers, + const MultipartFormDataItems &items, + const MultipartFormDataProviderItems &provider_items); + + Result Put(const std::string &path); + Result Put(const std::string &path, const char *body, size_t content_length, + const std::string &content_type); + Result Put(const std::string &path, const Headers &headers, const char *body, + size_t content_length, const std::string &content_type); + Result Put(const std::string &path, const Headers &headers, const char *body, + size_t content_length, const std::string &content_type, + Progress progress); + Result Put(const std::string &path, const std::string &body, + const std::string &content_type); + Result Put(const std::string &path, const std::string &body, + const std::string &content_type, Progress progress); + Result Put(const std::string &path, const Headers &headers, + const std::string &body, const std::string &content_type); + Result Put(const std::string &path, const Headers &headers, + const std::string &body, const std::string &content_type, + Progress progress); + Result Put(const std::string &path, size_t content_length, + ContentProvider content_provider, const std::string &content_type); + Result Put(const std::string &path, + ContentProviderWithoutLength content_provider, + const std::string &content_type); + Result Put(const std::string &path, const Headers &headers, + size_t content_length, ContentProvider content_provider, + const std::string &content_type); + Result Put(const std::string &path, const Headers &headers, + ContentProviderWithoutLength content_provider, + const std::string &content_type); + Result Put(const std::string &path, const Params ¶ms); + Result Put(const std::string &path, const Headers &headers, + const Params ¶ms); + Result Put(const std::string &path, const Headers &headers, + const Params ¶ms, Progress progress); + Result Put(const std::string &path, const MultipartFormDataItems &items); + Result Put(const std::string &path, const Headers &headers, + const MultipartFormDataItems &items); + Result Put(const std::string &path, const Headers &headers, + const MultipartFormDataItems &items, const std::string &boundary); + Result Put(const std::string &path, const Headers &headers, + const MultipartFormDataItems &items, + const MultipartFormDataProviderItems &provider_items); + + Result Patch(const std::string &path); + Result Patch(const std::string &path, const char *body, size_t content_length, + const std::string &content_type); + Result Patch(const std::string &path, const char *body, size_t content_length, + const std::string &content_type, Progress progress); + Result Patch(const std::string &path, const Headers &headers, + const char *body, size_t content_length, + const std::string &content_type); + Result Patch(const std::string &path, const Headers &headers, + const char *body, size_t content_length, + const std::string &content_type, Progress progress); + Result Patch(const std::string &path, const std::string &body, + const std::string &content_type); + Result Patch(const std::string &path, const std::string &body, + const std::string &content_type, Progress progress); + Result Patch(const std::string &path, const Headers &headers, + const std::string &body, const std::string &content_type); + Result Patch(const std::string &path, const Headers &headers, + const std::string &body, const std::string &content_type, + Progress progress); + Result Patch(const std::string &path, size_t content_length, + ContentProvider content_provider, + const std::string &content_type); + Result Patch(const std::string &path, + ContentProviderWithoutLength content_provider, + const std::string &content_type); + Result Patch(const std::string &path, const Headers &headers, + size_t content_length, ContentProvider content_provider, + const std::string &content_type); + Result Patch(const std::string &path, const Headers &headers, + ContentProviderWithoutLength content_provider, + const std::string &content_type); + + Result Delete(const std::string &path); + Result Delete(const std::string &path, const Headers &headers); + Result Delete(const std::string &path, const char *body, + size_t content_length, const std::string &content_type); + Result Delete(const std::string &path, const char *body, + size_t content_length, const std::string &content_type, + Progress progress); + Result Delete(const std::string &path, const Headers &headers, + const char *body, size_t content_length, + const std::string &content_type); + Result Delete(const std::string &path, const Headers &headers, + const char *body, size_t content_length, + const std::string &content_type, Progress progress); + Result Delete(const std::string &path, const std::string &body, + const std::string &content_type); + Result Delete(const std::string &path, const std::string &body, + const std::string &content_type, Progress progress); + Result Delete(const std::string &path, const Headers &headers, + const std::string &body, const std::string &content_type); + Result Delete(const std::string &path, const Headers &headers, + const std::string &body, const std::string &content_type, + Progress progress); + + Result Options(const std::string &path); + Result Options(const std::string &path, const Headers &headers); + + bool send(Request &req, Response &res, Error &error); + Result send(const Request &req); + + void stop(); + + std::string host() const; + int port() const; + + size_t is_socket_open() const; + socket_t socket() const; + + void set_hostname_addr_map(std::map addr_map); + + void set_default_headers(Headers headers); + + void + set_header_writer(std::function const &writer); + + void set_address_family(int family); + void set_tcp_nodelay(bool on); + void set_socket_options(SocketOptions socket_options); + + void set_connection_timeout(time_t sec, time_t usec = 0); + template + void + set_connection_timeout(const std::chrono::duration &duration); + + void set_read_timeout(time_t sec, time_t usec = 0); + template + void set_read_timeout(const std::chrono::duration &duration); + + void set_write_timeout(time_t sec, time_t usec = 0); + template + void set_write_timeout(const std::chrono::duration &duration); + + void set_basic_auth(const std::string &username, const std::string &password); + void set_bearer_token_auth(const std::string &token); +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + void set_digest_auth(const std::string &username, + const std::string &password); +#endif + + void set_keep_alive(bool on); + void set_follow_location(bool on); + + void set_url_encode(bool on); + + void set_compress(bool on); + + void set_decompress(bool on); + + void set_interface(const std::string &intf); + + void set_proxy(const std::string &host, int port); + void set_proxy_basic_auth(const std::string &username, + const std::string &password); + void set_proxy_bearer_token_auth(const std::string &token); +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + void set_proxy_digest_auth(const std::string &username, + const std::string &password); +#endif + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + void enable_server_certificate_verification(bool enabled); + void enable_server_hostname_verification(bool enabled); + void set_server_certificate_verifier(std::function verifier); +#endif + + void set_logger(Logger logger); + + // SSL +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + void set_ca_cert_path(const std::string &ca_cert_file_path, + const std::string &ca_cert_dir_path = std::string()); + + void set_ca_cert_store(X509_STORE *ca_cert_store); + void load_ca_cert_store(const char *ca_cert, std::size_t size); + + long get_openssl_verify_result() const; + + SSL_CTX *ssl_context() const; +#endif + +private: + std::unique_ptr cli_; + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + bool is_ssl_ = false; +#endif +}; + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +class SSLServer : public Server { +public: + SSLServer(const char *cert_path, const char *private_key_path, + const char *client_ca_cert_file_path = nullptr, + const char *client_ca_cert_dir_path = nullptr, + const char *private_key_password = nullptr); + + SSLServer(X509 *cert, EVP_PKEY *private_key, + X509_STORE *client_ca_cert_store = nullptr); + + SSLServer( + const std::function &setup_ssl_ctx_callback); + + ~SSLServer() override; + + bool is_valid() const override; + + SSL_CTX *ssl_context() const; + + void update_certs(X509 *cert, EVP_PKEY *private_key, + X509_STORE *client_ca_cert_store = nullptr); + +private: + bool process_and_close_socket(socket_t sock) override; + + SSL_CTX *ctx_; + std::mutex ctx_mutex_; +}; + +class SSLClient final : public ClientImpl { +public: + explicit SSLClient(const std::string &host); + + explicit SSLClient(const std::string &host, int port); + + explicit SSLClient(const std::string &host, int port, + const std::string &client_cert_path, + const std::string &client_key_path, + const std::string &private_key_password = std::string()); + + explicit SSLClient(const std::string &host, int port, X509 *client_cert, + EVP_PKEY *client_key, + const std::string &private_key_password = std::string()); + + ~SSLClient() override; + + bool is_valid() const override; + + void set_ca_cert_store(X509_STORE *ca_cert_store); + void load_ca_cert_store(const char *ca_cert, std::size_t size); + + long get_openssl_verify_result() const; + + SSL_CTX *ssl_context() const; + +private: + bool create_and_connect_socket(Socket &socket, Error &error) override; + void shutdown_ssl(Socket &socket, bool shutdown_gracefully) override; + void shutdown_ssl_impl(Socket &socket, bool shutdown_gracefully); + + bool process_socket(const Socket &socket, + std::function callback) override; + bool is_ssl() const override; + + bool connect_with_proxy(Socket &sock, Response &res, bool &success, + Error &error); + bool initialize_ssl(Socket &socket, Error &error); + + bool load_certs(); + + bool verify_host(X509 *server_cert) const; + bool verify_host_with_subject_alt_name(X509 *server_cert) const; + bool verify_host_with_common_name(X509 *server_cert) const; + bool check_host_name(const char *pattern, size_t pattern_len) const; + + SSL_CTX *ctx_; + std::mutex ctx_mutex_; + std::once_flag initialize_cert_; + + std::vector host_components_; + + long verify_result_ = 0; + + friend class ClientImpl; +}; +#endif + +/* + * Implementation of template methods. + */ + +namespace detail { + +template +inline void duration_to_sec_and_usec(const T &duration, U callback) { + auto sec = std::chrono::duration_cast(duration).count(); + auto usec = std::chrono::duration_cast( + duration - std::chrono::seconds(sec)) + .count(); + callback(static_cast(sec), static_cast(usec)); +} + +inline bool is_numeric(const std::string &str) { + return !str.empty() && std::all_of(str.begin(), str.end(), ::isdigit); +} + +inline uint64_t get_header_value_u64(const Headers &headers, + const std::string &key, uint64_t def, + size_t id, bool &is_invalid_value) { + is_invalid_value = false; + auto rng = headers.equal_range(key); + auto it = rng.first; + std::advance(it, static_cast(id)); + if (it != rng.second) { + if (is_numeric(it->second)) { + return std::strtoull(it->second.data(), nullptr, 10); + } else { + is_invalid_value = true; + } + } + return def; +} + +inline uint64_t get_header_value_u64(const Headers &headers, + const std::string &key, uint64_t def, + size_t id) { + bool dummy = false; + return get_header_value_u64(headers, key, def, id, dummy); +} + +} // namespace detail + +inline uint64_t Request::get_header_value_u64(const std::string &key, + uint64_t def, size_t id) const { + return detail::get_header_value_u64(headers, key, def, id); +} + +inline uint64_t Response::get_header_value_u64(const std::string &key, + uint64_t def, size_t id) const { + return detail::get_header_value_u64(headers, key, def, id); +} + +inline void default_socket_options(socket_t sock) { + int opt = 1; +#ifdef _WIN32 + setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, + reinterpret_cast(&opt), sizeof(opt)); +#else +#ifdef SO_REUSEPORT + setsockopt(sock, SOL_SOCKET, SO_REUSEPORT, + reinterpret_cast(&opt), sizeof(opt)); +#else + setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, + reinterpret_cast(&opt), sizeof(opt)); +#endif +#endif +} + +inline const char *status_message(int status) { + switch (status) { + case StatusCode::Continue_100: return "Continue"; + case StatusCode::SwitchingProtocol_101: return "Switching Protocol"; + case StatusCode::Processing_102: return "Processing"; + case StatusCode::EarlyHints_103: return "Early Hints"; + case StatusCode::OK_200: return "OK"; + case StatusCode::Created_201: return "Created"; + case StatusCode::Accepted_202: return "Accepted"; + case StatusCode::NonAuthoritativeInformation_203: + return "Non-Authoritative Information"; + case StatusCode::NoContent_204: return "No Content"; + case StatusCode::ResetContent_205: return "Reset Content"; + case StatusCode::PartialContent_206: return "Partial Content"; + case StatusCode::MultiStatus_207: return "Multi-Status"; + case StatusCode::AlreadyReported_208: return "Already Reported"; + case StatusCode::IMUsed_226: return "IM Used"; + case StatusCode::MultipleChoices_300: return "Multiple Choices"; + case StatusCode::MovedPermanently_301: return "Moved Permanently"; + case StatusCode::Found_302: return "Found"; + case StatusCode::SeeOther_303: return "See Other"; + case StatusCode::NotModified_304: return "Not Modified"; + case StatusCode::UseProxy_305: return "Use Proxy"; + case StatusCode::unused_306: return "unused"; + case StatusCode::TemporaryRedirect_307: return "Temporary Redirect"; + case StatusCode::PermanentRedirect_308: return "Permanent Redirect"; + case StatusCode::BadRequest_400: return "Bad Request"; + case StatusCode::Unauthorized_401: return "Unauthorized"; + case StatusCode::PaymentRequired_402: return "Payment Required"; + case StatusCode::Forbidden_403: return "Forbidden"; + case StatusCode::NotFound_404: return "Not Found"; + case StatusCode::MethodNotAllowed_405: return "Method Not Allowed"; + case StatusCode::NotAcceptable_406: return "Not Acceptable"; + case StatusCode::ProxyAuthenticationRequired_407: + return "Proxy Authentication Required"; + case StatusCode::RequestTimeout_408: return "Request Timeout"; + case StatusCode::Conflict_409: return "Conflict"; + case StatusCode::Gone_410: return "Gone"; + case StatusCode::LengthRequired_411: return "Length Required"; + case StatusCode::PreconditionFailed_412: return "Precondition Failed"; + case StatusCode::PayloadTooLarge_413: return "Payload Too Large"; + case StatusCode::UriTooLong_414: return "URI Too Long"; + case StatusCode::UnsupportedMediaType_415: return "Unsupported Media Type"; + case StatusCode::RangeNotSatisfiable_416: return "Range Not Satisfiable"; + case StatusCode::ExpectationFailed_417: return "Expectation Failed"; + case StatusCode::ImATeapot_418: return "I'm a teapot"; + case StatusCode::MisdirectedRequest_421: return "Misdirected Request"; + case StatusCode::UnprocessableContent_422: return "Unprocessable Content"; + case StatusCode::Locked_423: return "Locked"; + case StatusCode::FailedDependency_424: return "Failed Dependency"; + case StatusCode::TooEarly_425: return "Too Early"; + case StatusCode::UpgradeRequired_426: return "Upgrade Required"; + case StatusCode::PreconditionRequired_428: return "Precondition Required"; + case StatusCode::TooManyRequests_429: return "Too Many Requests"; + case StatusCode::RequestHeaderFieldsTooLarge_431: + return "Request Header Fields Too Large"; + case StatusCode::UnavailableForLegalReasons_451: + return "Unavailable For Legal Reasons"; + case StatusCode::NotImplemented_501: return "Not Implemented"; + case StatusCode::BadGateway_502: return "Bad Gateway"; + case StatusCode::ServiceUnavailable_503: return "Service Unavailable"; + case StatusCode::GatewayTimeout_504: return "Gateway Timeout"; + case StatusCode::HttpVersionNotSupported_505: + return "HTTP Version Not Supported"; + case StatusCode::VariantAlsoNegotiates_506: return "Variant Also Negotiates"; + case StatusCode::InsufficientStorage_507: return "Insufficient Storage"; + case StatusCode::LoopDetected_508: return "Loop Detected"; + case StatusCode::NotExtended_510: return "Not Extended"; + case StatusCode::NetworkAuthenticationRequired_511: + return "Network Authentication Required"; + + default: + case StatusCode::InternalServerError_500: return "Internal Server Error"; + } +} + +inline std::string get_bearer_token_auth(const Request &req) { + if (req.has_header("Authorization")) { + static std::string BearerHeaderPrefix = "Bearer "; + return req.get_header_value("Authorization") + .substr(BearerHeaderPrefix.length()); + } + return ""; +} + +template +inline Server & +Server::set_read_timeout(const std::chrono::duration &duration) { + detail::duration_to_sec_and_usec( + duration, [&](time_t sec, time_t usec) { set_read_timeout(sec, usec); }); + return *this; +} + +template +inline Server & +Server::set_write_timeout(const std::chrono::duration &duration) { + detail::duration_to_sec_and_usec( + duration, [&](time_t sec, time_t usec) { set_write_timeout(sec, usec); }); + return *this; +} + +template +inline Server & +Server::set_idle_interval(const std::chrono::duration &duration) { + detail::duration_to_sec_and_usec( + duration, [&](time_t sec, time_t usec) { set_idle_interval(sec, usec); }); + return *this; +} + +inline std::string to_string(const Error error) { + switch (error) { + case Error::Success: return "Success (no error)"; + case Error::Connection: return "Could not establish connection"; + case Error::BindIPAddress: return "Failed to bind IP address"; + case Error::Read: return "Failed to read connection"; + case Error::Write: return "Failed to write connection"; + case Error::ExceedRedirectCount: return "Maximum redirect count exceeded"; + case Error::Canceled: return "Connection handling canceled"; + case Error::SSLConnection: return "SSL connection failed"; + case Error::SSLLoadingCerts: return "SSL certificate loading failed"; + case Error::SSLServerVerification: return "SSL server verification failed"; + case Error::SSLServerHostnameVerification: + return "SSL server hostname verification failed"; + case Error::UnsupportedMultipartBoundaryChars: + return "Unsupported HTTP multipart boundary characters"; + case Error::Compression: return "Compression failed"; + case Error::ConnectionTimeout: return "Connection timed out"; + case Error::ProxyConnection: return "Proxy connection failed"; + case Error::Unknown: return "Unknown"; + default: break; + } + + return "Invalid"; +} + +inline std::ostream &operator<<(std::ostream &os, const Error &obj) { + os << to_string(obj); + os << " (" << static_cast::type>(obj) << ')'; + return os; +} + +inline uint64_t Result::get_request_header_value_u64(const std::string &key, + uint64_t def, + size_t id) const { + return detail::get_header_value_u64(request_headers_, key, def, id); +} + +template +inline void ClientImpl::set_connection_timeout( + const std::chrono::duration &duration) { + detail::duration_to_sec_and_usec(duration, [&](time_t sec, time_t usec) { + set_connection_timeout(sec, usec); + }); +} + +template +inline void ClientImpl::set_read_timeout( + const std::chrono::duration &duration) { + detail::duration_to_sec_and_usec( + duration, [&](time_t sec, time_t usec) { set_read_timeout(sec, usec); }); +} + +template +inline void ClientImpl::set_write_timeout( + const std::chrono::duration &duration) { + detail::duration_to_sec_and_usec( + duration, [&](time_t sec, time_t usec) { set_write_timeout(sec, usec); }); +} + +template +inline void Client::set_connection_timeout( + const std::chrono::duration &duration) { + cli_->set_connection_timeout(duration); +} + +template +inline void +Client::set_read_timeout(const std::chrono::duration &duration) { + cli_->set_read_timeout(duration); +} + +template +inline void +Client::set_write_timeout(const std::chrono::duration &duration) { + cli_->set_write_timeout(duration); +} + +/* + * Forward declarations and types that will be part of the .h file if split into + * .h + .cc. + */ + +std::string hosted_at(const std::string &hostname); + +void hosted_at(const std::string &hostname, std::vector &addrs); + +std::string append_query_params(const std::string &path, const Params ¶ms); + +std::pair make_range_header(const Ranges &ranges); + +std::pair +make_basic_authentication_header(const std::string &username, + const std::string &password, + bool is_proxy = false); + +namespace detail { + +#if defined(_WIN32) +inline std::wstring u8string_to_wstring(const char *s) { + std::wstring ws; + auto len = static_cast(strlen(s)); + auto wlen = ::MultiByteToWideChar(CP_UTF8, 0, s, len, nullptr, 0); + if (wlen > 0) { + ws.resize(wlen); + wlen = ::MultiByteToWideChar( + CP_UTF8, 0, s, len, + const_cast(reinterpret_cast(ws.data())), wlen); + if (wlen != static_cast(ws.size())) { ws.clear(); } + } + return ws; +} +#endif + +struct FileStat { + FileStat(const std::string &path); + bool is_file() const; + bool is_dir() const; + +private: +#if defined(_WIN32) + struct _stat st_; +#else + struct stat st_; +#endif + int ret_ = -1; +}; + +std::string encode_query_param(const std::string &value); + +std::string decode_url(const std::string &s, bool convert_plus_to_space); + +void read_file(const std::string &path, std::string &out); + +std::string trim_copy(const std::string &s); + +void divide( + const char *data, std::size_t size, char d, + std::function + fn); + +void divide( + const std::string &str, char d, + std::function + fn); + +void split(const char *b, const char *e, char d, + std::function fn); + +void split(const char *b, const char *e, char d, size_t m, + std::function fn); + +bool process_client_socket(socket_t sock, time_t read_timeout_sec, + time_t read_timeout_usec, time_t write_timeout_sec, + time_t write_timeout_usec, + std::function callback); + +socket_t create_client_socket(const std::string &host, const std::string &ip, + int port, int address_family, bool tcp_nodelay, + bool ipv6_v6only, SocketOptions socket_options, + time_t connection_timeout_sec, + time_t connection_timeout_usec, + time_t read_timeout_sec, time_t read_timeout_usec, + time_t write_timeout_sec, + time_t write_timeout_usec, + const std::string &intf, Error &error); + +const char *get_header_value(const Headers &headers, const std::string &key, + const char *def, size_t id); + +std::string params_to_query_str(const Params ¶ms); + +void parse_query_text(const char *data, std::size_t size, Params ¶ms); + +void parse_query_text(const std::string &s, Params ¶ms); + +bool parse_multipart_boundary(const std::string &content_type, + std::string &boundary); + +bool parse_range_header(const std::string &s, Ranges &ranges); + +int close_socket(socket_t sock); + +ssize_t send_socket(socket_t sock, const void *ptr, size_t size, int flags); + +ssize_t read_socket(socket_t sock, void *ptr, size_t size, int flags); + +enum class EncodingType { None = 0, Gzip, Brotli }; + +EncodingType encoding_type(const Request &req, const Response &res); + +class BufferStream final : public Stream { +public: + BufferStream() = default; + ~BufferStream() override = default; + + bool is_readable() const override; + bool is_writable() const override; + ssize_t read(char *ptr, size_t size) override; + ssize_t write(const char *ptr, size_t size) override; + void get_remote_ip_and_port(std::string &ip, int &port) const override; + void get_local_ip_and_port(std::string &ip, int &port) const override; + socket_t socket() const override; + + const std::string &get_buffer() const; + +private: + std::string buffer; + size_t position = 0; +}; + +class compressor { +public: + virtual ~compressor() = default; + + typedef std::function Callback; + virtual bool compress(const char *data, size_t data_length, bool last, + Callback callback) = 0; +}; + +class decompressor { +public: + virtual ~decompressor() = default; + + virtual bool is_valid() const = 0; + + typedef std::function Callback; + virtual bool decompress(const char *data, size_t data_length, + Callback callback) = 0; +}; + +class nocompressor final : public compressor { +public: + ~nocompressor() override = default; + + bool compress(const char *data, size_t data_length, bool /*last*/, + Callback callback) override; +}; + +#ifdef CPPHTTPLIB_ZLIB_SUPPORT +class gzip_compressor final : public compressor { +public: + gzip_compressor(); + ~gzip_compressor() override; + + bool compress(const char *data, size_t data_length, bool last, + Callback callback) override; + +private: + bool is_valid_ = false; + z_stream strm_; +}; + +class gzip_decompressor final : public decompressor { +public: + gzip_decompressor(); + ~gzip_decompressor() override; + + bool is_valid() const override; + + bool decompress(const char *data, size_t data_length, + Callback callback) override; + +private: + bool is_valid_ = false; + z_stream strm_; +}; +#endif + +#ifdef CPPHTTPLIB_BROTLI_SUPPORT +class brotli_compressor final : public compressor { +public: + brotli_compressor(); + ~brotli_compressor(); + + bool compress(const char *data, size_t data_length, bool last, + Callback callback) override; + +private: + BrotliEncoderState *state_ = nullptr; +}; + +class brotli_decompressor final : public decompressor { +public: + brotli_decompressor(); + ~brotli_decompressor(); + + bool is_valid() const override; + + bool decompress(const char *data, size_t data_length, + Callback callback) override; + +private: + BrotliDecoderResult decoder_r; + BrotliDecoderState *decoder_s = nullptr; +}; +#endif + +// NOTE: until the read size reaches `fixed_buffer_size`, use `fixed_buffer` +// to store data. The call can set memory on stack for performance. +class stream_line_reader { +public: + stream_line_reader(Stream &strm, char *fixed_buffer, + size_t fixed_buffer_size); + const char *ptr() const; + size_t size() const; + bool end_with_crlf() const; + bool getline(); + +private: + void append(char c); + + Stream &strm_; + char *fixed_buffer_; + const size_t fixed_buffer_size_; + size_t fixed_buffer_used_size_ = 0; + std::string glowable_buffer_; +}; + +class mmap { +public: + mmap(const char *path); + ~mmap(); + + bool open(const char *path); + void close(); + + bool is_open() const; + size_t size() const; + const char *data() const; + +private: +#if defined(_WIN32) + HANDLE hFile_ = NULL; + HANDLE hMapping_ = NULL; +#else + int fd_ = -1; +#endif + size_t size_ = 0; + void *addr_ = nullptr; + bool is_open_empty_file = false; +}; + +// NOTE: https://www.rfc-editor.org/rfc/rfc9110#section-5 +namespace fields { + +inline bool is_token_char(char c) { + return std::isalnum(c) || c == '!' || c == '#' || c == '$' || c == '%' || + c == '&' || c == '\'' || c == '*' || c == '+' || c == '-' || + c == '.' || c == '^' || c == '_' || c == '`' || c == '|' || c == '~'; +} + +inline bool is_token(const std::string &s) { + if (s.empty()) { return false; } + for (auto c : s) { + if (!is_token_char(c)) { return false; } + } + return true; +} + +inline bool is_field_name(const std::string &s) { return is_token(s); } + +inline bool is_vchar(char c) { return c >= 33 && c <= 126; } + +inline bool is_obs_text(char c) { return 128 <= static_cast(c); } + +inline bool is_field_vchar(char c) { return is_vchar(c) || is_obs_text(c); } + +inline bool is_field_content(const std::string &s) { + if (s.empty()) { return false; } + + if (s.size() == 1) { + return is_field_vchar(s[0]); + } else if (s.size() == 2) { + return is_field_vchar(s[0]) && is_field_vchar(s[1]); + } else { + size_t i = 0; + + if (!is_field_vchar(s[i])) { return false; } + i++; + + while (i < s.size() - 1) { + auto c = s[i++]; + if (c == ' ' || c == '\t' || is_field_vchar(c)) { + } else { + return false; + } + } + + return is_field_vchar(s[i]); + } +} + +inline bool is_field_value(const std::string &s) { return is_field_content(s); } + +} // namespace fields + +} // namespace detail + +// ---------------------------------------------------------------------------- + +/* + * Implementation that will be part of the .cc file if split into .h + .cc. + */ + +namespace detail { + +inline bool is_hex(char c, int &v) { + if (0x20 <= c && isdigit(c)) { + v = c - '0'; + return true; + } else if ('A' <= c && c <= 'F') { + v = c - 'A' + 10; + return true; + } else if ('a' <= c && c <= 'f') { + v = c - 'a' + 10; + return true; + } + return false; +} + +inline bool from_hex_to_i(const std::string &s, size_t i, size_t cnt, + int &val) { + if (i >= s.size()) { return false; } + + val = 0; + for (; cnt; i++, cnt--) { + if (!s[i]) { return false; } + auto v = 0; + if (is_hex(s[i], v)) { + val = val * 16 + v; + } else { + return false; + } + } + return true; +} + +inline std::string from_i_to_hex(size_t n) { + static const auto charset = "0123456789abcdef"; + std::string ret; + do { + ret = charset[n & 15] + ret; + n >>= 4; + } while (n > 0); + return ret; +} + +inline size_t to_utf8(int code, char *buff) { + if (code < 0x0080) { + buff[0] = static_cast(code & 0x7F); + return 1; + } else if (code < 0x0800) { + buff[0] = static_cast(0xC0 | ((code >> 6) & 0x1F)); + buff[1] = static_cast(0x80 | (code & 0x3F)); + return 2; + } else if (code < 0xD800) { + buff[0] = static_cast(0xE0 | ((code >> 12) & 0xF)); + buff[1] = static_cast(0x80 | ((code >> 6) & 0x3F)); + buff[2] = static_cast(0x80 | (code & 0x3F)); + return 3; + } else if (code < 0xE000) { // D800 - DFFF is invalid... + return 0; + } else if (code < 0x10000) { + buff[0] = static_cast(0xE0 | ((code >> 12) & 0xF)); + buff[1] = static_cast(0x80 | ((code >> 6) & 0x3F)); + buff[2] = static_cast(0x80 | (code & 0x3F)); + return 3; + } else if (code < 0x110000) { + buff[0] = static_cast(0xF0 | ((code >> 18) & 0x7)); + buff[1] = static_cast(0x80 | ((code >> 12) & 0x3F)); + buff[2] = static_cast(0x80 | ((code >> 6) & 0x3F)); + buff[3] = static_cast(0x80 | (code & 0x3F)); + return 4; + } + + // NOTREACHED + return 0; +} + +// NOTE: This code came up with the following stackoverflow post: +// https://stackoverflow.com/questions/180947/base64-decode-snippet-in-c +inline std::string base64_encode(const std::string &in) { + static const auto lookup = + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; + + std::string out; + out.reserve(in.size()); + + auto val = 0; + auto valb = -6; + + for (auto c : in) { + val = (val << 8) + static_cast(c); + valb += 8; + while (valb >= 0) { + out.push_back(lookup[(val >> valb) & 0x3F]); + valb -= 6; + } + } + + if (valb > -6) { out.push_back(lookup[((val << 8) >> (valb + 8)) & 0x3F]); } + + while (out.size() % 4) { + out.push_back('='); + } + + return out; +} + +inline bool is_valid_path(const std::string &path) { + size_t level = 0; + size_t i = 0; + + // Skip slash + while (i < path.size() && path[i] == '/') { + i++; + } + + while (i < path.size()) { + // Read component + auto beg = i; + while (i < path.size() && path[i] != '/') { + if (path[i] == '\0') { + return false; + } else if (path[i] == '\\') { + return false; + } + i++; + } + + auto len = i - beg; + assert(len > 0); + + if (!path.compare(beg, len, ".")) { + ; + } else if (!path.compare(beg, len, "..")) { + if (level == 0) { return false; } + level--; + } else { + level++; + } + + // Skip slash + while (i < path.size() && path[i] == '/') { + i++; + } + } + + return true; +} + +inline FileStat::FileStat(const std::string &path) { +#if defined(_WIN32) + auto wpath = u8string_to_wstring(path.c_str()); + ret_ = _wstat(wpath.c_str(), &st_); +#else + ret_ = stat(path.c_str(), &st_); +#endif +} +inline bool FileStat::is_file() const { + return ret_ >= 0 && S_ISREG(st_.st_mode); +} +inline bool FileStat::is_dir() const { + return ret_ >= 0 && S_ISDIR(st_.st_mode); +} + +inline std::string encode_query_param(const std::string &value) { + std::ostringstream escaped; + escaped.fill('0'); + escaped << std::hex; + + for (auto c : value) { + if (std::isalnum(static_cast(c)) || c == '-' || c == '_' || + c == '.' || c == '!' || c == '~' || c == '*' || c == '\'' || c == '(' || + c == ')') { + escaped << c; + } else { + escaped << std::uppercase; + escaped << '%' << std::setw(2) + << static_cast(static_cast(c)); + escaped << std::nouppercase; + } + } + + return escaped.str(); +} + +inline std::string encode_url(const std::string &s) { + std::string result; + result.reserve(s.size()); + + for (size_t i = 0; s[i]; i++) { + switch (s[i]) { + case ' ': result += "%20"; break; + case '+': result += "%2B"; break; + case '\r': result += "%0D"; break; + case '\n': result += "%0A"; break; + case '\'': result += "%27"; break; + case ',': result += "%2C"; break; + // case ':': result += "%3A"; break; // ok? probably... + case ';': result += "%3B"; break; + default: + auto c = static_cast(s[i]); + if (c >= 0x80) { + result += '%'; + char hex[4]; + auto len = snprintf(hex, sizeof(hex) - 1, "%02X", c); + assert(len == 2); + result.append(hex, static_cast(len)); + } else { + result += s[i]; + } + break; + } + } + + return result; +} + +inline std::string decode_url(const std::string &s, + bool convert_plus_to_space) { + std::string result; + + for (size_t i = 0; i < s.size(); i++) { + if (s[i] == '%' && i + 1 < s.size()) { + if (s[i + 1] == 'u') { + auto val = 0; + if (from_hex_to_i(s, i + 2, 4, val)) { + // 4 digits Unicode codes + char buff[4]; + size_t len = to_utf8(val, buff); + if (len > 0) { result.append(buff, len); } + i += 5; // 'u0000' + } else { + result += s[i]; + } + } else { + auto val = 0; + if (from_hex_to_i(s, i + 1, 2, val)) { + // 2 digits hex codes + result += static_cast(val); + i += 2; // '00' + } else { + result += s[i]; + } + } + } else if (convert_plus_to_space && s[i] == '+') { + result += ' '; + } else { + result += s[i]; + } + } + + return result; +} + +inline void read_file(const std::string &path, std::string &out) { + std::ifstream fs(path, std::ios_base::binary); + fs.seekg(0, std::ios_base::end); + auto size = fs.tellg(); + fs.seekg(0); + out.resize(static_cast(size)); + fs.read(&out[0], static_cast(size)); +} + +inline std::string file_extension(const std::string &path) { + std::smatch m; + static auto re = std::regex("\\.([a-zA-Z0-9]+)$"); + if (std::regex_search(path, m, re)) { return m[1].str(); } + return std::string(); +} + +inline bool is_space_or_tab(char c) { return c == ' ' || c == '\t'; } + +inline std::pair trim(const char *b, const char *e, size_t left, + size_t right) { + while (b + left < e && is_space_or_tab(b[left])) { + left++; + } + while (right > 0 && is_space_or_tab(b[right - 1])) { + right--; + } + return std::make_pair(left, right); +} + +inline std::string trim_copy(const std::string &s) { + auto r = trim(s.data(), s.data() + s.size(), 0, s.size()); + return s.substr(r.first, r.second - r.first); +} + +inline std::string trim_double_quotes_copy(const std::string &s) { + if (s.length() >= 2 && s.front() == '"' && s.back() == '"') { + return s.substr(1, s.size() - 2); + } + return s; +} + +inline void +divide(const char *data, std::size_t size, char d, + std::function + fn) { + const auto it = std::find(data, data + size, d); + const auto found = static_cast(it != data + size); + const auto lhs_data = data; + const auto lhs_size = static_cast(it - data); + const auto rhs_data = it + found; + const auto rhs_size = size - lhs_size - found; + + fn(lhs_data, lhs_size, rhs_data, rhs_size); +} + +inline void +divide(const std::string &str, char d, + std::function + fn) { + divide(str.data(), str.size(), d, std::move(fn)); +} + +inline void split(const char *b, const char *e, char d, + std::function fn) { + return split(b, e, d, (std::numeric_limits::max)(), std::move(fn)); +} + +inline void split(const char *b, const char *e, char d, size_t m, + std::function fn) { + size_t i = 0; + size_t beg = 0; + size_t count = 1; + + while (e ? (b + i < e) : (b[i] != '\0')) { + if (b[i] == d && count < m) { + auto r = trim(b, e, beg, i); + if (r.first < r.second) { fn(&b[r.first], &b[r.second]); } + beg = i + 1; + count++; + } + i++; + } + + if (i) { + auto r = trim(b, e, beg, i); + if (r.first < r.second) { fn(&b[r.first], &b[r.second]); } + } +} + +inline stream_line_reader::stream_line_reader(Stream &strm, char *fixed_buffer, + size_t fixed_buffer_size) + : strm_(strm), fixed_buffer_(fixed_buffer), + fixed_buffer_size_(fixed_buffer_size) {} + +inline const char *stream_line_reader::ptr() const { + if (glowable_buffer_.empty()) { + return fixed_buffer_; + } else { + return glowable_buffer_.data(); + } +} + +inline size_t stream_line_reader::size() const { + if (glowable_buffer_.empty()) { + return fixed_buffer_used_size_; + } else { + return glowable_buffer_.size(); + } +} + +inline bool stream_line_reader::end_with_crlf() const { + auto end = ptr() + size(); + return size() >= 2 && end[-2] == '\r' && end[-1] == '\n'; +} + +inline bool stream_line_reader::getline() { + fixed_buffer_used_size_ = 0; + glowable_buffer_.clear(); + +#ifndef CPPHTTPLIB_ALLOW_LF_AS_LINE_TERMINATOR + char prev_byte = 0; +#endif + + for (size_t i = 0;; i++) { + char byte; + auto n = strm_.read(&byte, 1); + + if (n < 0) { + return false; + } else if (n == 0) { + if (i == 0) { + return false; + } else { + break; + } + } + + append(byte); + +#ifdef CPPHTTPLIB_ALLOW_LF_AS_LINE_TERMINATOR + if (byte == '\n') { break; } +#else + if (prev_byte == '\r' && byte == '\n') { break; } + prev_byte = byte; +#endif + } + + return true; +} + +inline void stream_line_reader::append(char c) { + if (fixed_buffer_used_size_ < fixed_buffer_size_ - 1) { + fixed_buffer_[fixed_buffer_used_size_++] = c; + fixed_buffer_[fixed_buffer_used_size_] = '\0'; + } else { + if (glowable_buffer_.empty()) { + assert(fixed_buffer_[fixed_buffer_used_size_] == '\0'); + glowable_buffer_.assign(fixed_buffer_, fixed_buffer_used_size_); + } + glowable_buffer_ += c; + } +} + +inline mmap::mmap(const char *path) { open(path); } + +inline mmap::~mmap() { close(); } + +inline bool mmap::open(const char *path) { + close(); + +#if defined(_WIN32) + auto wpath = u8string_to_wstring(path); + if (wpath.empty()) { return false; } + +#if _WIN32_WINNT >= _WIN32_WINNT_WIN8 + hFile_ = ::CreateFile2(wpath.c_str(), GENERIC_READ, FILE_SHARE_READ, + OPEN_EXISTING, NULL); +#else + hFile_ = ::CreateFileW(wpath.c_str(), GENERIC_READ, FILE_SHARE_READ, NULL, + OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, NULL); +#endif + + if (hFile_ == INVALID_HANDLE_VALUE) { return false; } + + LARGE_INTEGER size{}; + if (!::GetFileSizeEx(hFile_, &size)) { return false; } + // If the following line doesn't compile due to QuadPart, update Windows SDK. + // See: + // https://github.com/yhirose/cpp-httplib/issues/1903#issuecomment-2316520721 + if (static_cast(size.QuadPart) > + (std::numeric_limits::max)()) { + // `size_t` might be 32-bits, on 32-bits Windows. + return false; + } + size_ = static_cast(size.QuadPart); + +#if _WIN32_WINNT >= _WIN32_WINNT_WIN8 + hMapping_ = + ::CreateFileMappingFromApp(hFile_, NULL, PAGE_READONLY, size_, NULL); +#else + hMapping_ = ::CreateFileMappingW(hFile_, NULL, PAGE_READONLY, 0, 0, NULL); +#endif + + // Special treatment for an empty file... + if (hMapping_ == NULL && size_ == 0) { + close(); + is_open_empty_file = true; + return true; + } + + if (hMapping_ == NULL) { + close(); + return false; + } + +#if _WIN32_WINNT >= _WIN32_WINNT_WIN8 + addr_ = ::MapViewOfFileFromApp(hMapping_, FILE_MAP_READ, 0, 0); +#else + addr_ = ::MapViewOfFile(hMapping_, FILE_MAP_READ, 0, 0, 0); +#endif + + if (addr_ == nullptr) { + close(); + return false; + } +#else + fd_ = ::open(path, O_RDONLY); + if (fd_ == -1) { return false; } + + struct stat sb; + if (fstat(fd_, &sb) == -1) { + close(); + return false; + } + size_ = static_cast(sb.st_size); + + addr_ = ::mmap(NULL, size_, PROT_READ, MAP_PRIVATE, fd_, 0); + + // Special treatment for an empty file... + if (addr_ == MAP_FAILED && size_ == 0) { + close(); + is_open_empty_file = true; + return false; + } +#endif + + return true; +} + +inline bool mmap::is_open() const { + return is_open_empty_file ? true : addr_ != nullptr; +} + +inline size_t mmap::size() const { return size_; } + +inline const char *mmap::data() const { + return is_open_empty_file ? "" : static_cast(addr_); +} + +inline void mmap::close() { +#if defined(_WIN32) + if (addr_) { + ::UnmapViewOfFile(addr_); + addr_ = nullptr; + } + + if (hMapping_) { + ::CloseHandle(hMapping_); + hMapping_ = NULL; + } + + if (hFile_ != INVALID_HANDLE_VALUE) { + ::CloseHandle(hFile_); + hFile_ = INVALID_HANDLE_VALUE; + } + + is_open_empty_file = false; +#else + if (addr_ != nullptr) { + munmap(addr_, size_); + addr_ = nullptr; + } + + if (fd_ != -1) { + ::close(fd_); + fd_ = -1; + } +#endif + size_ = 0; +} +inline int close_socket(socket_t sock) { +#ifdef _WIN32 + return closesocket(sock); +#else + return close(sock); +#endif +} + +template inline ssize_t handle_EINTR(T fn) { + ssize_t res = 0; + while (true) { + res = fn(); + if (res < 0 && errno == EINTR) { + std::this_thread::sleep_for(std::chrono::microseconds{1}); + continue; + } + break; + } + return res; +} + +inline ssize_t read_socket(socket_t sock, void *ptr, size_t size, int flags) { + return handle_EINTR([&]() { + return recv(sock, +#ifdef _WIN32 + static_cast(ptr), static_cast(size), +#else + ptr, size, +#endif + flags); + }); +} + +inline ssize_t send_socket(socket_t sock, const void *ptr, size_t size, + int flags) { + return handle_EINTR([&]() { + return send(sock, +#ifdef _WIN32 + static_cast(ptr), static_cast(size), +#else + ptr, size, +#endif + flags); + }); +} + +inline ssize_t select_read(socket_t sock, time_t sec, time_t usec) { +#ifdef CPPHTTPLIB_USE_POLL + struct pollfd pfd_read; + pfd_read.fd = sock; + pfd_read.events = POLLIN; + + auto timeout = static_cast(sec * 1000 + usec / 1000); + + return handle_EINTR([&]() { return poll(&pfd_read, 1, timeout); }); +#else +#ifndef _WIN32 + if (sock >= FD_SETSIZE) { return -1; } +#endif + + fd_set fds; + FD_ZERO(&fds); + FD_SET(sock, &fds); + + timeval tv; + tv.tv_sec = static_cast(sec); + tv.tv_usec = static_cast(usec); + + return handle_EINTR([&]() { + return select(static_cast(sock + 1), &fds, nullptr, nullptr, &tv); + }); +#endif +} + +inline ssize_t select_write(socket_t sock, time_t sec, time_t usec) { +#ifdef CPPHTTPLIB_USE_POLL + struct pollfd pfd_read; + pfd_read.fd = sock; + pfd_read.events = POLLOUT; + + auto timeout = static_cast(sec * 1000 + usec / 1000); + + return handle_EINTR([&]() { return poll(&pfd_read, 1, timeout); }); +#else +#ifndef _WIN32 + if (sock >= FD_SETSIZE) { return -1; } +#endif + + fd_set fds; + FD_ZERO(&fds); + FD_SET(sock, &fds); + + timeval tv; + tv.tv_sec = static_cast(sec); + tv.tv_usec = static_cast(usec); + + return handle_EINTR([&]() { + return select(static_cast(sock + 1), nullptr, &fds, nullptr, &tv); + }); +#endif +} + +inline Error wait_until_socket_is_ready(socket_t sock, time_t sec, + time_t usec) { +#ifdef CPPHTTPLIB_USE_POLL + struct pollfd pfd_read; + pfd_read.fd = sock; + pfd_read.events = POLLIN | POLLOUT; + + auto timeout = static_cast(sec * 1000 + usec / 1000); + + auto poll_res = handle_EINTR([&]() { return poll(&pfd_read, 1, timeout); }); + + if (poll_res == 0) { return Error::ConnectionTimeout; } + + if (poll_res > 0 && pfd_read.revents & (POLLIN | POLLOUT)) { + auto error = 0; + socklen_t len = sizeof(error); + auto res = getsockopt(sock, SOL_SOCKET, SO_ERROR, + reinterpret_cast(&error), &len); + auto successful = res >= 0 && !error; + return successful ? Error::Success : Error::Connection; + } + + return Error::Connection; +#else +#ifndef _WIN32 + if (sock >= FD_SETSIZE) { return Error::Connection; } +#endif + + fd_set fdsr; + FD_ZERO(&fdsr); + FD_SET(sock, &fdsr); + + auto fdsw = fdsr; + auto fdse = fdsr; + + timeval tv; + tv.tv_sec = static_cast(sec); + tv.tv_usec = static_cast(usec); + + auto ret = handle_EINTR([&]() { + return select(static_cast(sock + 1), &fdsr, &fdsw, &fdse, &tv); + }); + + if (ret == 0) { return Error::ConnectionTimeout; } + + if (ret > 0 && (FD_ISSET(sock, &fdsr) || FD_ISSET(sock, &fdsw))) { + auto error = 0; + socklen_t len = sizeof(error); + auto res = getsockopt(sock, SOL_SOCKET, SO_ERROR, + reinterpret_cast(&error), &len); + auto successful = res >= 0 && !error; + return successful ? Error::Success : Error::Connection; + } + return Error::Connection; +#endif +} + +inline bool is_socket_alive(socket_t sock) { + const auto val = detail::select_read(sock, 0, 0); + if (val == 0) { + return true; + } else if (val < 0 && errno == EBADF) { + return false; + } + char buf[1]; + return detail::read_socket(sock, &buf[0], sizeof(buf), MSG_PEEK) > 0; +} + +class SocketStream final : public Stream { +public: + SocketStream(socket_t sock, time_t read_timeout_sec, time_t read_timeout_usec, + time_t write_timeout_sec, time_t write_timeout_usec); + ~SocketStream() override; + + bool is_readable() const override; + bool is_writable() const override; + ssize_t read(char *ptr, size_t size) override; + ssize_t write(const char *ptr, size_t size) override; + void get_remote_ip_and_port(std::string &ip, int &port) const override; + void get_local_ip_and_port(std::string &ip, int &port) const override; + socket_t socket() const override; + +private: + socket_t sock_; + time_t read_timeout_sec_; + time_t read_timeout_usec_; + time_t write_timeout_sec_; + time_t write_timeout_usec_; + + std::vector read_buff_; + size_t read_buff_off_ = 0; + size_t read_buff_content_size_ = 0; + + static const size_t read_buff_size_ = 1024l * 4; +}; + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +class SSLSocketStream final : public Stream { +public: + SSLSocketStream(socket_t sock, SSL *ssl, time_t read_timeout_sec, + time_t read_timeout_usec, time_t write_timeout_sec, + time_t write_timeout_usec); + ~SSLSocketStream() override; + + bool is_readable() const override; + bool is_writable() const override; + ssize_t read(char *ptr, size_t size) override; + ssize_t write(const char *ptr, size_t size) override; + void get_remote_ip_and_port(std::string &ip, int &port) const override; + void get_local_ip_and_port(std::string &ip, int &port) const override; + socket_t socket() const override; + +private: + socket_t sock_; + SSL *ssl_; + time_t read_timeout_sec_; + time_t read_timeout_usec_; + time_t write_timeout_sec_; + time_t write_timeout_usec_; +}; +#endif + +inline bool keep_alive(const std::atomic &svr_sock, socket_t sock, + time_t keep_alive_timeout_sec) { + using namespace std::chrono; + + const auto interval_usec = + CPPHTTPLIB_KEEPALIVE_TIMEOUT_CHECK_INTERVAL_USECOND; + + // Avoid expensive `steady_clock::now()` call for the first time + if (select_read(sock, 0, interval_usec) > 0) { return true; } + + const auto start = steady_clock::now() - microseconds{interval_usec}; + const auto timeout = seconds{keep_alive_timeout_sec}; + + while (true) { + if (svr_sock == INVALID_SOCKET) { + break; // Server socket is closed + } + + auto val = select_read(sock, 0, interval_usec); + if (val < 0) { + break; // Ssocket error + } else if (val == 0) { + if (steady_clock::now() - start > timeout) { + break; // Timeout + } + } else { + return true; // Ready for read + } + } + + return false; +} + +template +inline bool +process_server_socket_core(const std::atomic &svr_sock, socket_t sock, + size_t keep_alive_max_count, + time_t keep_alive_timeout_sec, T callback) { + assert(keep_alive_max_count > 0); + auto ret = false; + auto count = keep_alive_max_count; + while (count > 0 && keep_alive(svr_sock, sock, keep_alive_timeout_sec)) { + auto close_connection = count == 1; + auto connection_closed = false; + ret = callback(close_connection, connection_closed); + if (!ret || connection_closed) { break; } + count--; + } + return ret; +} + +template +inline bool +process_server_socket(const std::atomic &svr_sock, socket_t sock, + size_t keep_alive_max_count, + time_t keep_alive_timeout_sec, time_t read_timeout_sec, + time_t read_timeout_usec, time_t write_timeout_sec, + time_t write_timeout_usec, T callback) { + return process_server_socket_core( + svr_sock, sock, keep_alive_max_count, keep_alive_timeout_sec, + [&](bool close_connection, bool &connection_closed) { + SocketStream strm(sock, read_timeout_sec, read_timeout_usec, + write_timeout_sec, write_timeout_usec); + return callback(strm, close_connection, connection_closed); + }); +} + +inline bool process_client_socket(socket_t sock, time_t read_timeout_sec, + time_t read_timeout_usec, + time_t write_timeout_sec, + time_t write_timeout_usec, + std::function callback) { + SocketStream strm(sock, read_timeout_sec, read_timeout_usec, + write_timeout_sec, write_timeout_usec); + return callback(strm); +} + +inline int shutdown_socket(socket_t sock) { +#ifdef _WIN32 + return shutdown(sock, SD_BOTH); +#else + return shutdown(sock, SHUT_RDWR); +#endif +} + +inline std::string escape_abstract_namespace_unix_domain(const std::string &s) { + if (s.size() > 1 && s[0] == '\0') { + auto ret = s; + ret[0] = '@'; + return ret; + } + return s; +} + +inline std::string +unescape_abstract_namespace_unix_domain(const std::string &s) { + if (s.size() > 1 && s[0] == '@') { + auto ret = s; + ret[0] = '\0'; + return ret; + } + return s; +} + +template +socket_t create_socket(const std::string &host, const std::string &ip, int port, + int address_family, int socket_flags, bool tcp_nodelay, + bool ipv6_v6only, SocketOptions socket_options, + BindOrConnect bind_or_connect) { + // Get address info + const char *node = nullptr; + struct addrinfo hints; + struct addrinfo *result; + + memset(&hints, 0, sizeof(struct addrinfo)); + hints.ai_socktype = SOCK_STREAM; + hints.ai_protocol = IPPROTO_IP; + + if (!ip.empty()) { + node = ip.c_str(); + // Ask getaddrinfo to convert IP in c-string to address + hints.ai_family = AF_UNSPEC; + hints.ai_flags = AI_NUMERICHOST; + } else { + if (!host.empty()) { node = host.c_str(); } + hints.ai_family = address_family; + hints.ai_flags = socket_flags; + } + +#ifndef _WIN32 + if (hints.ai_family == AF_UNIX) { + const auto addrlen = host.length(); + if (addrlen > sizeof(sockaddr_un::sun_path)) { return INVALID_SOCKET; } + +#ifdef SOCK_CLOEXEC + auto sock = socket(hints.ai_family, hints.ai_socktype | SOCK_CLOEXEC, + hints.ai_protocol); +#else + auto sock = socket(hints.ai_family, hints.ai_socktype, hints.ai_protocol); +#endif + + if (sock != INVALID_SOCKET) { + sockaddr_un addr{}; + addr.sun_family = AF_UNIX; + + auto unescaped_host = unescape_abstract_namespace_unix_domain(host); + std::copy(unescaped_host.begin(), unescaped_host.end(), addr.sun_path); + + hints.ai_addr = reinterpret_cast(&addr); + hints.ai_addrlen = static_cast( + sizeof(addr) - sizeof(addr.sun_path) + addrlen); + +#ifndef SOCK_CLOEXEC + fcntl(sock, F_SETFD, FD_CLOEXEC); +#endif + + if (socket_options) { socket_options(sock); } + + bool dummy; + if (!bind_or_connect(sock, hints, dummy)) { + close_socket(sock); + sock = INVALID_SOCKET; + } + } + return sock; + } +#endif + + auto service = std::to_string(port); + + if (getaddrinfo(node, service.c_str(), &hints, &result)) { +#if defined __linux__ && !defined __ANDROID__ + res_init(); +#endif + return INVALID_SOCKET; + } + auto se = detail::scope_exit([&] { freeaddrinfo(result); }); + + for (auto rp = result; rp; rp = rp->ai_next) { + // Create a socket +#ifdef _WIN32 + auto sock = + WSASocketW(rp->ai_family, rp->ai_socktype, rp->ai_protocol, nullptr, 0, + WSA_FLAG_NO_HANDLE_INHERIT | WSA_FLAG_OVERLAPPED); + /** + * Since the WSA_FLAG_NO_HANDLE_INHERIT is only supported on Windows 7 SP1 + * and above the socket creation fails on older Windows Systems. + * + * Let's try to create a socket the old way in this case. + * + * Reference: + * https://docs.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-wsasocketa + * + * WSA_FLAG_NO_HANDLE_INHERIT: + * This flag is supported on Windows 7 with SP1, Windows Server 2008 R2 with + * SP1, and later + * + */ + if (sock == INVALID_SOCKET) { + sock = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol); + } +#else + +#ifdef SOCK_CLOEXEC + auto sock = + socket(rp->ai_family, rp->ai_socktype | SOCK_CLOEXEC, rp->ai_protocol); +#else + auto sock = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol); +#endif + +#endif + if (sock == INVALID_SOCKET) { continue; } + +#if !defined _WIN32 && !defined SOCK_CLOEXEC + if (fcntl(sock, F_SETFD, FD_CLOEXEC) == -1) { + close_socket(sock); + continue; + } +#endif + + if (tcp_nodelay) { + auto opt = 1; +#ifdef _WIN32 + setsockopt(sock, IPPROTO_TCP, TCP_NODELAY, + reinterpret_cast(&opt), sizeof(opt)); +#else + setsockopt(sock, IPPROTO_TCP, TCP_NODELAY, + reinterpret_cast(&opt), sizeof(opt)); +#endif + } + + if (rp->ai_family == AF_INET6) { + auto opt = ipv6_v6only ? 1 : 0; +#ifdef _WIN32 + setsockopt(sock, IPPROTO_IPV6, IPV6_V6ONLY, + reinterpret_cast(&opt), sizeof(opt)); +#else + setsockopt(sock, IPPROTO_IPV6, IPV6_V6ONLY, + reinterpret_cast(&opt), sizeof(opt)); +#endif + } + + if (socket_options) { socket_options(sock); } + + // bind or connect + auto quit = false; + if (bind_or_connect(sock, *rp, quit)) { return sock; } + + close_socket(sock); + + if (quit) { break; } + } + + return INVALID_SOCKET; +} + +inline void set_nonblocking(socket_t sock, bool nonblocking) { +#ifdef _WIN32 + auto flags = nonblocking ? 1UL : 0UL; + ioctlsocket(sock, FIONBIO, &flags); +#else + auto flags = fcntl(sock, F_GETFL, 0); + fcntl(sock, F_SETFL, + nonblocking ? (flags | O_NONBLOCK) : (flags & (~O_NONBLOCK))); +#endif +} + +inline bool is_connection_error() { +#ifdef _WIN32 + return WSAGetLastError() != WSAEWOULDBLOCK; +#else + return errno != EINPROGRESS; +#endif +} + +inline bool bind_ip_address(socket_t sock, const std::string &host) { + struct addrinfo hints; + struct addrinfo *result; + + memset(&hints, 0, sizeof(struct addrinfo)); + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + hints.ai_protocol = 0; + + if (getaddrinfo(host.c_str(), "0", &hints, &result)) { return false; } + auto se = detail::scope_exit([&] { freeaddrinfo(result); }); + + auto ret = false; + for (auto rp = result; rp; rp = rp->ai_next) { + const auto &ai = *rp; + if (!::bind(sock, ai.ai_addr, static_cast(ai.ai_addrlen))) { + ret = true; + break; + } + } + + return ret; +} + +#if !defined _WIN32 && !defined ANDROID && !defined _AIX && !defined __MVS__ +#define USE_IF2IP +#endif + +#ifdef USE_IF2IP +inline std::string if2ip(int address_family, const std::string &ifn) { + struct ifaddrs *ifap; + getifaddrs(&ifap); + auto se = detail::scope_exit([&] { freeifaddrs(ifap); }); + + std::string addr_candidate; + for (auto ifa = ifap; ifa; ifa = ifa->ifa_next) { + if (ifa->ifa_addr && ifn == ifa->ifa_name && + (AF_UNSPEC == address_family || + ifa->ifa_addr->sa_family == address_family)) { + if (ifa->ifa_addr->sa_family == AF_INET) { + auto sa = reinterpret_cast(ifa->ifa_addr); + char buf[INET_ADDRSTRLEN]; + if (inet_ntop(AF_INET, &sa->sin_addr, buf, INET_ADDRSTRLEN)) { + return std::string(buf, INET_ADDRSTRLEN); + } + } else if (ifa->ifa_addr->sa_family == AF_INET6) { + auto sa = reinterpret_cast(ifa->ifa_addr); + if (!IN6_IS_ADDR_LINKLOCAL(&sa->sin6_addr)) { + char buf[INET6_ADDRSTRLEN] = {}; + if (inet_ntop(AF_INET6, &sa->sin6_addr, buf, INET6_ADDRSTRLEN)) { + // equivalent to mac's IN6_IS_ADDR_UNIQUE_LOCAL + auto s6_addr_head = sa->sin6_addr.s6_addr[0]; + if (s6_addr_head == 0xfc || s6_addr_head == 0xfd) { + addr_candidate = std::string(buf, INET6_ADDRSTRLEN); + } else { + return std::string(buf, INET6_ADDRSTRLEN); + } + } + } + } + } + } + return addr_candidate; +} +#endif + +inline socket_t create_client_socket( + const std::string &host, const std::string &ip, int port, + int address_family, bool tcp_nodelay, bool ipv6_v6only, + SocketOptions socket_options, time_t connection_timeout_sec, + time_t connection_timeout_usec, time_t read_timeout_sec, + time_t read_timeout_usec, time_t write_timeout_sec, + time_t write_timeout_usec, const std::string &intf, Error &error) { + auto sock = create_socket( + host, ip, port, address_family, 0, tcp_nodelay, ipv6_v6only, + std::move(socket_options), + [&](socket_t sock2, struct addrinfo &ai, bool &quit) -> bool { + if (!intf.empty()) { +#ifdef USE_IF2IP + auto ip_from_if = if2ip(address_family, intf); + if (ip_from_if.empty()) { ip_from_if = intf; } + if (!bind_ip_address(sock2, ip_from_if)) { + error = Error::BindIPAddress; + return false; + } +#endif + } + + set_nonblocking(sock2, true); + + auto ret = + ::connect(sock2, ai.ai_addr, static_cast(ai.ai_addrlen)); + + if (ret < 0) { + if (is_connection_error()) { + error = Error::Connection; + return false; + } + error = wait_until_socket_is_ready(sock2, connection_timeout_sec, + connection_timeout_usec); + if (error != Error::Success) { + if (error == Error::ConnectionTimeout) { quit = true; } + return false; + } + } + + set_nonblocking(sock2, false); + + { +#ifdef _WIN32 + auto timeout = static_cast(read_timeout_sec * 1000 + + read_timeout_usec / 1000); + setsockopt(sock2, SOL_SOCKET, SO_RCVTIMEO, + reinterpret_cast(&timeout), sizeof(timeout)); +#else + timeval tv; + tv.tv_sec = static_cast(read_timeout_sec); + tv.tv_usec = static_cast(read_timeout_usec); + setsockopt(sock2, SOL_SOCKET, SO_RCVTIMEO, + reinterpret_cast(&tv), sizeof(tv)); +#endif + } + { + +#ifdef _WIN32 + auto timeout = static_cast(write_timeout_sec * 1000 + + write_timeout_usec / 1000); + setsockopt(sock2, SOL_SOCKET, SO_SNDTIMEO, + reinterpret_cast(&timeout), sizeof(timeout)); +#else + timeval tv; + tv.tv_sec = static_cast(write_timeout_sec); + tv.tv_usec = static_cast(write_timeout_usec); + setsockopt(sock2, SOL_SOCKET, SO_SNDTIMEO, + reinterpret_cast(&tv), sizeof(tv)); +#endif + } + + error = Error::Success; + return true; + }); + + if (sock != INVALID_SOCKET) { + error = Error::Success; + } else { + if (error == Error::Success) { error = Error::Connection; } + } + + return sock; +} + +inline bool get_ip_and_port(const struct sockaddr_storage &addr, + socklen_t addr_len, std::string &ip, int &port) { + if (addr.ss_family == AF_INET) { + port = ntohs(reinterpret_cast(&addr)->sin_port); + } else if (addr.ss_family == AF_INET6) { + port = + ntohs(reinterpret_cast(&addr)->sin6_port); + } else { + return false; + } + + std::array ipstr{}; + if (getnameinfo(reinterpret_cast(&addr), addr_len, + ipstr.data(), static_cast(ipstr.size()), nullptr, + 0, NI_NUMERICHOST)) { + return false; + } + + ip = ipstr.data(); + return true; +} + +inline void get_local_ip_and_port(socket_t sock, std::string &ip, int &port) { + struct sockaddr_storage addr; + socklen_t addr_len = sizeof(addr); + if (!getsockname(sock, reinterpret_cast(&addr), + &addr_len)) { + get_ip_and_port(addr, addr_len, ip, port); + } +} + +inline void get_remote_ip_and_port(socket_t sock, std::string &ip, int &port) { + struct sockaddr_storage addr; + socklen_t addr_len = sizeof(addr); + + if (!getpeername(sock, reinterpret_cast(&addr), + &addr_len)) { +#ifndef _WIN32 + if (addr.ss_family == AF_UNIX) { +#if defined(__linux__) + struct ucred ucred; + socklen_t len = sizeof(ucred); + if (getsockopt(sock, SOL_SOCKET, SO_PEERCRED, &ucred, &len) == 0) { + port = ucred.pid; + } +#elif defined(SOL_LOCAL) && defined(SO_PEERPID) // __APPLE__ + pid_t pid; + socklen_t len = sizeof(pid); + if (getsockopt(sock, SOL_LOCAL, SO_PEERPID, &pid, &len) == 0) { + port = pid; + } +#endif + return; + } +#endif + get_ip_and_port(addr, addr_len, ip, port); + } +} + +inline constexpr unsigned int str2tag_core(const char *s, size_t l, + unsigned int h) { + return (l == 0) + ? h + : str2tag_core( + s + 1, l - 1, + // Unsets the 6 high bits of h, therefore no overflow happens + (((std::numeric_limits::max)() >> 6) & + h * 33) ^ + static_cast(*s)); +} + +inline unsigned int str2tag(const std::string &s) { + return str2tag_core(s.data(), s.size(), 0); +} + +namespace udl { + +inline constexpr unsigned int operator""_t(const char *s, size_t l) { + return str2tag_core(s, l, 0); +} + +} // namespace udl + +inline std::string +find_content_type(const std::string &path, + const std::map &user_data, + const std::string &default_content_type) { + auto ext = file_extension(path); + + auto it = user_data.find(ext); + if (it != user_data.end()) { return it->second; } + + using udl::operator""_t; + + switch (str2tag(ext)) { + default: return default_content_type; + + case "css"_t: return "text/css"; + case "csv"_t: return "text/csv"; + case "htm"_t: + case "html"_t: return "text/html"; + case "js"_t: + case "mjs"_t: return "text/javascript"; + case "txt"_t: return "text/plain"; + case "vtt"_t: return "text/vtt"; + + case "apng"_t: return "image/apng"; + case "avif"_t: return "image/avif"; + case "bmp"_t: return "image/bmp"; + case "gif"_t: return "image/gif"; + case "png"_t: return "image/png"; + case "svg"_t: return "image/svg+xml"; + case "webp"_t: return "image/webp"; + case "ico"_t: return "image/x-icon"; + case "tif"_t: return "image/tiff"; + case "tiff"_t: return "image/tiff"; + case "jpg"_t: + case "jpeg"_t: return "image/jpeg"; + + case "mp4"_t: return "video/mp4"; + case "mpeg"_t: return "video/mpeg"; + case "webm"_t: return "video/webm"; + + case "mp3"_t: return "audio/mp3"; + case "mpga"_t: return "audio/mpeg"; + case "weba"_t: return "audio/webm"; + case "wav"_t: return "audio/wave"; + + case "otf"_t: return "font/otf"; + case "ttf"_t: return "font/ttf"; + case "woff"_t: return "font/woff"; + case "woff2"_t: return "font/woff2"; + + case "7z"_t: return "application/x-7z-compressed"; + case "atom"_t: return "application/atom+xml"; + case "pdf"_t: return "application/pdf"; + case "json"_t: return "application/json"; + case "rss"_t: return "application/rss+xml"; + case "tar"_t: return "application/x-tar"; + case "xht"_t: + case "xhtml"_t: return "application/xhtml+xml"; + case "xslt"_t: return "application/xslt+xml"; + case "xml"_t: return "application/xml"; + case "gz"_t: return "application/gzip"; + case "zip"_t: return "application/zip"; + case "wasm"_t: return "application/wasm"; + } +} + +inline bool can_compress_content_type(const std::string &content_type) { + using udl::operator""_t; + + auto tag = str2tag(content_type); + + switch (tag) { + case "image/svg+xml"_t: + case "application/javascript"_t: + case "application/json"_t: + case "application/xml"_t: + case "application/protobuf"_t: + case "application/xhtml+xml"_t: return true; + + case "text/event-stream"_t: return false; + + default: return !content_type.rfind("text/", 0); + } +} + +inline EncodingType encoding_type(const Request &req, const Response &res) { + auto ret = + detail::can_compress_content_type(res.get_header_value("Content-Type")); + if (!ret) { return EncodingType::None; } + + const auto &s = req.get_header_value("Accept-Encoding"); + (void)(s); + +#ifdef CPPHTTPLIB_BROTLI_SUPPORT + // TODO: 'Accept-Encoding' has br, not br;q=0 + ret = s.find("br") != std::string::npos; + if (ret) { return EncodingType::Brotli; } +#endif + +#ifdef CPPHTTPLIB_ZLIB_SUPPORT + // TODO: 'Accept-Encoding' has gzip, not gzip;q=0 + ret = s.find("gzip") != std::string::npos; + if (ret) { return EncodingType::Gzip; } +#endif + + return EncodingType::None; +} + +inline bool nocompressor::compress(const char *data, size_t data_length, + bool /*last*/, Callback callback) { + if (!data_length) { return true; } + return callback(data, data_length); +} + +#ifdef CPPHTTPLIB_ZLIB_SUPPORT +inline gzip_compressor::gzip_compressor() { + std::memset(&strm_, 0, sizeof(strm_)); + strm_.zalloc = Z_NULL; + strm_.zfree = Z_NULL; + strm_.opaque = Z_NULL; + + is_valid_ = deflateInit2(&strm_, Z_DEFAULT_COMPRESSION, Z_DEFLATED, 31, 8, + Z_DEFAULT_STRATEGY) == Z_OK; +} + +inline gzip_compressor::~gzip_compressor() { deflateEnd(&strm_); } + +inline bool gzip_compressor::compress(const char *data, size_t data_length, + bool last, Callback callback) { + assert(is_valid_); + + do { + constexpr size_t max_avail_in = + (std::numeric_limits::max)(); + + strm_.avail_in = static_cast( + (std::min)(data_length, max_avail_in)); + strm_.next_in = const_cast(reinterpret_cast(data)); + + data_length -= strm_.avail_in; + data += strm_.avail_in; + + auto flush = (last && data_length == 0) ? Z_FINISH : Z_NO_FLUSH; + auto ret = Z_OK; + + std::array buff{}; + do { + strm_.avail_out = static_cast(buff.size()); + strm_.next_out = reinterpret_cast(buff.data()); + + ret = deflate(&strm_, flush); + if (ret == Z_STREAM_ERROR) { return false; } + + if (!callback(buff.data(), buff.size() - strm_.avail_out)) { + return false; + } + } while (strm_.avail_out == 0); + + assert((flush == Z_FINISH && ret == Z_STREAM_END) || + (flush == Z_NO_FLUSH && ret == Z_OK)); + assert(strm_.avail_in == 0); + } while (data_length > 0); + + return true; +} + +inline gzip_decompressor::gzip_decompressor() { + std::memset(&strm_, 0, sizeof(strm_)); + strm_.zalloc = Z_NULL; + strm_.zfree = Z_NULL; + strm_.opaque = Z_NULL; + + // 15 is the value of wbits, which should be at the maximum possible value + // to ensure that any gzip stream can be decoded. The offset of 32 specifies + // that the stream type should be automatically detected either gzip or + // deflate. + is_valid_ = inflateInit2(&strm_, 32 + 15) == Z_OK; +} + +inline gzip_decompressor::~gzip_decompressor() { inflateEnd(&strm_); } + +inline bool gzip_decompressor::is_valid() const { return is_valid_; } + +inline bool gzip_decompressor::decompress(const char *data, size_t data_length, + Callback callback) { + assert(is_valid_); + + auto ret = Z_OK; + + do { + constexpr size_t max_avail_in = + (std::numeric_limits::max)(); + + strm_.avail_in = static_cast( + (std::min)(data_length, max_avail_in)); + strm_.next_in = const_cast(reinterpret_cast(data)); + + data_length -= strm_.avail_in; + data += strm_.avail_in; + + std::array buff{}; + while (strm_.avail_in > 0 && ret == Z_OK) { + strm_.avail_out = static_cast(buff.size()); + strm_.next_out = reinterpret_cast(buff.data()); + + ret = inflate(&strm_, Z_NO_FLUSH); + + assert(ret != Z_STREAM_ERROR); + switch (ret) { + case Z_NEED_DICT: + case Z_DATA_ERROR: + case Z_MEM_ERROR: inflateEnd(&strm_); return false; + } + + if (!callback(buff.data(), buff.size() - strm_.avail_out)) { + return false; + } + } + + if (ret != Z_OK && ret != Z_STREAM_END) { return false; } + + } while (data_length > 0); + + return true; +} +#endif + +#ifdef CPPHTTPLIB_BROTLI_SUPPORT +inline brotli_compressor::brotli_compressor() { + state_ = BrotliEncoderCreateInstance(nullptr, nullptr, nullptr); +} + +inline brotli_compressor::~brotli_compressor() { + BrotliEncoderDestroyInstance(state_); +} + +inline bool brotli_compressor::compress(const char *data, size_t data_length, + bool last, Callback callback) { + std::array buff{}; + + auto operation = last ? BROTLI_OPERATION_FINISH : BROTLI_OPERATION_PROCESS; + auto available_in = data_length; + auto next_in = reinterpret_cast(data); + + for (;;) { + if (last) { + if (BrotliEncoderIsFinished(state_)) { break; } + } else { + if (!available_in) { break; } + } + + auto available_out = buff.size(); + auto next_out = buff.data(); + + if (!BrotliEncoderCompressStream(state_, operation, &available_in, &next_in, + &available_out, &next_out, nullptr)) { + return false; + } + + auto output_bytes = buff.size() - available_out; + if (output_bytes) { + callback(reinterpret_cast(buff.data()), output_bytes); + } + } + + return true; +} + +inline brotli_decompressor::brotli_decompressor() { + decoder_s = BrotliDecoderCreateInstance(0, 0, 0); + decoder_r = decoder_s ? BROTLI_DECODER_RESULT_NEEDS_MORE_INPUT + : BROTLI_DECODER_RESULT_ERROR; +} + +inline brotli_decompressor::~brotli_decompressor() { + if (decoder_s) { BrotliDecoderDestroyInstance(decoder_s); } +} + +inline bool brotli_decompressor::is_valid() const { return decoder_s; } + +inline bool brotli_decompressor::decompress(const char *data, + size_t data_length, + Callback callback) { + if (decoder_r == BROTLI_DECODER_RESULT_SUCCESS || + decoder_r == BROTLI_DECODER_RESULT_ERROR) { + return 0; + } + + auto next_in = reinterpret_cast(data); + size_t avail_in = data_length; + size_t total_out; + + decoder_r = BROTLI_DECODER_RESULT_NEEDS_MORE_OUTPUT; + + std::array buff{}; + while (decoder_r == BROTLI_DECODER_RESULT_NEEDS_MORE_OUTPUT) { + char *next_out = buff.data(); + size_t avail_out = buff.size(); + + decoder_r = BrotliDecoderDecompressStream( + decoder_s, &avail_in, &next_in, &avail_out, + reinterpret_cast(&next_out), &total_out); + + if (decoder_r == BROTLI_DECODER_RESULT_ERROR) { return false; } + + if (!callback(buff.data(), buff.size() - avail_out)) { return false; } + } + + return decoder_r == BROTLI_DECODER_RESULT_SUCCESS || + decoder_r == BROTLI_DECODER_RESULT_NEEDS_MORE_INPUT; +} +#endif + +inline bool has_header(const Headers &headers, const std::string &key) { + return headers.find(key) != headers.end(); +} + +inline const char *get_header_value(const Headers &headers, + const std::string &key, const char *def, + size_t id) { + auto rng = headers.equal_range(key); + auto it = rng.first; + std::advance(it, static_cast(id)); + if (it != rng.second) { return it->second.c_str(); } + return def; +} + +template +inline bool parse_header(const char *beg, const char *end, T fn) { + // Skip trailing spaces and tabs. + while (beg < end && is_space_or_tab(end[-1])) { + end--; + } + + auto p = beg; + while (p < end && *p != ':') { + p++; + } + + if (p == end) { return false; } + + auto key_end = p; + + if (*p++ != ':') { return false; } + + while (p < end && is_space_or_tab(*p)) { + p++; + } + + if (p <= end) { + auto key_len = key_end - beg; + if (!key_len) { return false; } + + auto key = std::string(beg, key_end); + auto val = case_ignore::equal(key, "Location") + ? std::string(p, end) + : decode_url(std::string(p, end), false); + + // NOTE: From RFC 9110: + // Field values containing CR, LF, or NUL characters are + // invalid and dangerous, due to the varying ways that + // implementations might parse and interpret those + // characters; a recipient of CR, LF, or NUL within a field + // value MUST either reject the message or replace each of + // those characters with SP before further processing or + // forwarding of that message. + static const std::string CR_LF_NUL("\r\n\0", 3); + if (val.find_first_of(CR_LF_NUL) != std::string::npos) { return false; } + + fn(key, val); + return true; + } + + return false; +} + +inline bool read_headers(Stream &strm, Headers &headers) { + const auto bufsiz = 2048; + char buf[bufsiz]; + stream_line_reader line_reader(strm, buf, bufsiz); + + for (;;) { + if (!line_reader.getline()) { return false; } + + // Check if the line ends with CRLF. + auto line_terminator_len = 2; + if (line_reader.end_with_crlf()) { + // Blank line indicates end of headers. + if (line_reader.size() == 2) { break; } + } else { +#ifdef CPPHTTPLIB_ALLOW_LF_AS_LINE_TERMINATOR + // Blank line indicates end of headers. + if (line_reader.size() == 1) { break; } + line_terminator_len = 1; +#else + continue; // Skip invalid line. +#endif + } + + if (line_reader.size() > CPPHTTPLIB_HEADER_MAX_LENGTH) { return false; } + + // Exclude line terminator + auto end = line_reader.ptr() + line_reader.size() - line_terminator_len; + + if (!parse_header(line_reader.ptr(), end, + [&](const std::string &key, std::string &val) { + headers.emplace(key, val); + })) { + return false; + } + } + + return true; +} + +inline bool read_content_with_length(Stream &strm, uint64_t len, + Progress progress, + ContentReceiverWithProgress out) { + char buf[CPPHTTPLIB_RECV_BUFSIZ]; + + uint64_t r = 0; + while (r < len) { + auto read_len = static_cast(len - r); + auto n = strm.read(buf, (std::min)(read_len, CPPHTTPLIB_RECV_BUFSIZ)); + if (n <= 0) { return false; } + + if (!out(buf, static_cast(n), r, len)) { return false; } + r += static_cast(n); + + if (progress) { + if (!progress(r, len)) { return false; } + } + } + + return true; +} + +inline void skip_content_with_length(Stream &strm, uint64_t len) { + char buf[CPPHTTPLIB_RECV_BUFSIZ]; + uint64_t r = 0; + while (r < len) { + auto read_len = static_cast(len - r); + auto n = strm.read(buf, (std::min)(read_len, CPPHTTPLIB_RECV_BUFSIZ)); + if (n <= 0) { return; } + r += static_cast(n); + } +} + +inline bool read_content_without_length(Stream &strm, + ContentReceiverWithProgress out) { + char buf[CPPHTTPLIB_RECV_BUFSIZ]; + uint64_t r = 0; + for (;;) { + auto n = strm.read(buf, CPPHTTPLIB_RECV_BUFSIZ); + if (n <= 0) { return true; } + + if (!out(buf, static_cast(n), r, 0)) { return false; } + r += static_cast(n); + } + + return true; +} + +template +inline bool read_content_chunked(Stream &strm, T &x, + ContentReceiverWithProgress out) { + const auto bufsiz = 16; + char buf[bufsiz]; + + stream_line_reader line_reader(strm, buf, bufsiz); + + if (!line_reader.getline()) { return false; } + + unsigned long chunk_len; + while (true) { + char *end_ptr; + + chunk_len = std::strtoul(line_reader.ptr(), &end_ptr, 16); + + if (end_ptr == line_reader.ptr()) { return false; } + if (chunk_len == ULONG_MAX) { return false; } + + if (chunk_len == 0) { break; } + + if (!read_content_with_length(strm, chunk_len, nullptr, out)) { + return false; + } + + if (!line_reader.getline()) { return false; } + + if (strcmp(line_reader.ptr(), "\r\n") != 0) { return false; } + + if (!line_reader.getline()) { return false; } + } + + assert(chunk_len == 0); + + // NOTE: In RFC 9112, '7.1 Chunked Transfer Coding' mentiones "The chunked + // transfer coding is complete when a chunk with a chunk-size of zero is + // received, possibly followed by a trailer section, and finally terminated by + // an empty line". https://www.rfc-editor.org/rfc/rfc9112.html#section-7.1 + // + // In '7.1.3. Decoding Chunked', however, the pseudo-code in the section + // does't care for the existence of the final CRLF. In other words, it seems + // to be ok whether the final CRLF exists or not in the chunked data. + // https://www.rfc-editor.org/rfc/rfc9112.html#section-7.1.3 + // + // According to the reference code in RFC 9112, cpp-htpplib now allows + // chuncked transfer coding data without the final CRLF. + if (!line_reader.getline()) { return true; } + + while (strcmp(line_reader.ptr(), "\r\n") != 0) { + if (line_reader.size() > CPPHTTPLIB_HEADER_MAX_LENGTH) { return false; } + + // Exclude line terminator + constexpr auto line_terminator_len = 2; + auto end = line_reader.ptr() + line_reader.size() - line_terminator_len; + + parse_header(line_reader.ptr(), end, + [&](const std::string &key, const std::string &val) { + x.headers.emplace(key, val); + }); + + if (!line_reader.getline()) { return false; } + } + + return true; +} + +inline bool is_chunked_transfer_encoding(const Headers &headers) { + return case_ignore::equal( + get_header_value(headers, "Transfer-Encoding", "", 0), "chunked"); +} + +template +bool prepare_content_receiver(T &x, int &status, + ContentReceiverWithProgress receiver, + bool decompress, U callback) { + if (decompress) { + std::string encoding = x.get_header_value("Content-Encoding"); + std::unique_ptr decompressor; + + if (encoding == "gzip" || encoding == "deflate") { +#ifdef CPPHTTPLIB_ZLIB_SUPPORT + decompressor = detail::make_unique(); +#else + status = StatusCode::UnsupportedMediaType_415; + return false; +#endif + } else if (encoding.find("br") != std::string::npos) { +#ifdef CPPHTTPLIB_BROTLI_SUPPORT + decompressor = detail::make_unique(); +#else + status = StatusCode::UnsupportedMediaType_415; + return false; +#endif + } + + if (decompressor) { + if (decompressor->is_valid()) { + ContentReceiverWithProgress out = [&](const char *buf, size_t n, + uint64_t off, uint64_t len) { + return decompressor->decompress(buf, n, + [&](const char *buf2, size_t n2) { + return receiver(buf2, n2, off, len); + }); + }; + return callback(std::move(out)); + } else { + status = StatusCode::InternalServerError_500; + return false; + } + } + } + + ContentReceiverWithProgress out = [&](const char *buf, size_t n, uint64_t off, + uint64_t len) { + return receiver(buf, n, off, len); + }; + return callback(std::move(out)); +} + +template +bool read_content(Stream &strm, T &x, size_t payload_max_length, int &status, + Progress progress, ContentReceiverWithProgress receiver, + bool decompress) { + return prepare_content_receiver( + x, status, std::move(receiver), decompress, + [&](const ContentReceiverWithProgress &out) { + auto ret = true; + auto exceed_payload_max_length = false; + + if (is_chunked_transfer_encoding(x.headers)) { + ret = read_content_chunked(strm, x, out); + } else if (!has_header(x.headers, "Content-Length")) { + ret = read_content_without_length(strm, out); + } else { + auto is_invalid_value = false; + auto len = get_header_value_u64(x.headers, "Content-Length", + (std::numeric_limits::max)(), + 0, is_invalid_value); + + if (is_invalid_value) { + ret = false; + } else if (len > payload_max_length) { + exceed_payload_max_length = true; + skip_content_with_length(strm, len); + ret = false; + } else if (len > 0) { + ret = read_content_with_length(strm, len, std::move(progress), out); + } + } + + if (!ret) { + status = exceed_payload_max_length ? StatusCode::PayloadTooLarge_413 + : StatusCode::BadRequest_400; + } + return ret; + }); +} + +inline ssize_t write_request_line(Stream &strm, const std::string &method, + const std::string &path) { + std::string s = method; + s += " "; + s += path; + s += " HTTP/1.1\r\n"; + return strm.write(s.data(), s.size()); +} + +inline ssize_t write_response_line(Stream &strm, int status) { + std::string s = "HTTP/1.1 "; + s += std::to_string(status); + s += " "; + s += httplib::status_message(status); + s += "\r\n"; + return strm.write(s.data(), s.size()); +} + +inline ssize_t write_headers(Stream &strm, const Headers &headers) { + ssize_t write_len = 0; + for (const auto &x : headers) { + std::string s; + s = x.first; + s += ": "; + s += x.second; + s += "\r\n"; + + auto len = strm.write(s.data(), s.size()); + if (len < 0) { return len; } + write_len += len; + } + auto len = strm.write("\r\n"); + if (len < 0) { return len; } + write_len += len; + return write_len; +} + +inline bool write_data(Stream &strm, const char *d, size_t l) { + size_t offset = 0; + while (offset < l) { + auto length = strm.write(d + offset, l - offset); + if (length < 0) { return false; } + offset += static_cast(length); + } + return true; +} + +template +inline bool write_content(Stream &strm, const ContentProvider &content_provider, + size_t offset, size_t length, T is_shutting_down, + Error &error) { + size_t end_offset = offset + length; + auto ok = true; + DataSink data_sink; + + data_sink.write = [&](const char *d, size_t l) -> bool { + if (ok) { + if (strm.is_writable() && write_data(strm, d, l)) { + offset += l; + } else { + ok = false; + } + } + return ok; + }; + + data_sink.is_writable = [&]() -> bool { return strm.is_writable(); }; + + while (offset < end_offset && !is_shutting_down()) { + if (!strm.is_writable()) { + error = Error::Write; + return false; + } else if (!content_provider(offset, end_offset - offset, data_sink)) { + error = Error::Canceled; + return false; + } else if (!ok) { + error = Error::Write; + return false; + } + } + + error = Error::Success; + return true; +} + +template +inline bool write_content(Stream &strm, const ContentProvider &content_provider, + size_t offset, size_t length, + const T &is_shutting_down) { + auto error = Error::Success; + return write_content(strm, content_provider, offset, length, is_shutting_down, + error); +} + +template +inline bool +write_content_without_length(Stream &strm, + const ContentProvider &content_provider, + const T &is_shutting_down) { + size_t offset = 0; + auto data_available = true; + auto ok = true; + DataSink data_sink; + + data_sink.write = [&](const char *d, size_t l) -> bool { + if (ok) { + offset += l; + if (!strm.is_writable() || !write_data(strm, d, l)) { ok = false; } + } + return ok; + }; + + data_sink.is_writable = [&]() -> bool { return strm.is_writable(); }; + + data_sink.done = [&](void) { data_available = false; }; + + while (data_available && !is_shutting_down()) { + if (!strm.is_writable()) { + return false; + } else if (!content_provider(offset, 0, data_sink)) { + return false; + } else if (!ok) { + return false; + } + } + return true; +} + +template +inline bool +write_content_chunked(Stream &strm, const ContentProvider &content_provider, + const T &is_shutting_down, U &compressor, Error &error) { + size_t offset = 0; + auto data_available = true; + auto ok = true; + DataSink data_sink; + + data_sink.write = [&](const char *d, size_t l) -> bool { + if (ok) { + data_available = l > 0; + offset += l; + + std::string payload; + if (compressor.compress(d, l, false, + [&](const char *data, size_t data_len) { + payload.append(data, data_len); + return true; + })) { + if (!payload.empty()) { + // Emit chunked response header and footer for each chunk + auto chunk = + from_i_to_hex(payload.size()) + "\r\n" + payload + "\r\n"; + if (!strm.is_writable() || + !write_data(strm, chunk.data(), chunk.size())) { + ok = false; + } + } + } else { + ok = false; + } + } + return ok; + }; + + data_sink.is_writable = [&]() -> bool { return strm.is_writable(); }; + + auto done_with_trailer = [&](const Headers *trailer) { + if (!ok) { return; } + + data_available = false; + + std::string payload; + if (!compressor.compress(nullptr, 0, true, + [&](const char *data, size_t data_len) { + payload.append(data, data_len); + return true; + })) { + ok = false; + return; + } + + if (!payload.empty()) { + // Emit chunked response header and footer for each chunk + auto chunk = from_i_to_hex(payload.size()) + "\r\n" + payload + "\r\n"; + if (!strm.is_writable() || + !write_data(strm, chunk.data(), chunk.size())) { + ok = false; + return; + } + } + + static const std::string done_marker("0\r\n"); + if (!write_data(strm, done_marker.data(), done_marker.size())) { + ok = false; + } + + // Trailer + if (trailer) { + for (const auto &kv : *trailer) { + std::string field_line = kv.first + ": " + kv.second + "\r\n"; + if (!write_data(strm, field_line.data(), field_line.size())) { + ok = false; + } + } + } + + static const std::string crlf("\r\n"); + if (!write_data(strm, crlf.data(), crlf.size())) { ok = false; } + }; + + data_sink.done = [&](void) { done_with_trailer(nullptr); }; + + data_sink.done_with_trailer = [&](const Headers &trailer) { + done_with_trailer(&trailer); + }; + + while (data_available && !is_shutting_down()) { + if (!strm.is_writable()) { + error = Error::Write; + return false; + } else if (!content_provider(offset, 0, data_sink)) { + error = Error::Canceled; + return false; + } else if (!ok) { + error = Error::Write; + return false; + } + } + + error = Error::Success; + return true; +} + +template +inline bool write_content_chunked(Stream &strm, + const ContentProvider &content_provider, + const T &is_shutting_down, U &compressor) { + auto error = Error::Success; + return write_content_chunked(strm, content_provider, is_shutting_down, + compressor, error); +} + +template +inline bool redirect(T &cli, Request &req, Response &res, + const std::string &path, const std::string &location, + Error &error) { + Request new_req = req; + new_req.path = path; + new_req.redirect_count_ -= 1; + + if (res.status == StatusCode::SeeOther_303 && + (req.method != "GET" && req.method != "HEAD")) { + new_req.method = "GET"; + new_req.body.clear(); + new_req.headers.clear(); + } + + Response new_res; + + auto ret = cli.send(new_req, new_res, error); + if (ret) { + req = new_req; + res = new_res; + + if (res.location.empty()) { res.location = location; } + } + return ret; +} + +inline std::string params_to_query_str(const Params ¶ms) { + std::string query; + + for (auto it = params.begin(); it != params.end(); ++it) { + if (it != params.begin()) { query += "&"; } + query += it->first; + query += "="; + query += encode_query_param(it->second); + } + return query; +} + +inline void parse_query_text(const char *data, std::size_t size, + Params ¶ms) { + std::set cache; + split(data, data + size, '&', [&](const char *b, const char *e) { + std::string kv(b, e); + if (cache.find(kv) != cache.end()) { return; } + cache.insert(std::move(kv)); + + std::string key; + std::string val; + divide(b, static_cast(e - b), '=', + [&](const char *lhs_data, std::size_t lhs_size, const char *rhs_data, + std::size_t rhs_size) { + key.assign(lhs_data, lhs_size); + val.assign(rhs_data, rhs_size); + }); + + if (!key.empty()) { + params.emplace(decode_url(key, true), decode_url(val, true)); + } + }); +} + +inline void parse_query_text(const std::string &s, Params ¶ms) { + parse_query_text(s.data(), s.size(), params); +} + +inline bool parse_multipart_boundary(const std::string &content_type, + std::string &boundary) { + auto boundary_keyword = "boundary="; + auto pos = content_type.find(boundary_keyword); + if (pos == std::string::npos) { return false; } + auto end = content_type.find(';', pos); + auto beg = pos + strlen(boundary_keyword); + boundary = trim_double_quotes_copy(content_type.substr(beg, end - beg)); + return !boundary.empty(); +} + +inline void parse_disposition_params(const std::string &s, Params ¶ms) { + std::set cache; + split(s.data(), s.data() + s.size(), ';', [&](const char *b, const char *e) { + std::string kv(b, e); + if (cache.find(kv) != cache.end()) { return; } + cache.insert(kv); + + std::string key; + std::string val; + split(b, e, '=', [&](const char *b2, const char *e2) { + if (key.empty()) { + key.assign(b2, e2); + } else { + val.assign(b2, e2); + } + }); + + if (!key.empty()) { + params.emplace(trim_double_quotes_copy((key)), + trim_double_quotes_copy((val))); + } + }); +} + +#ifdef CPPHTTPLIB_NO_EXCEPTIONS +inline bool parse_range_header(const std::string &s, Ranges &ranges) { +#else +inline bool parse_range_header(const std::string &s, Ranges &ranges) try { +#endif + auto is_valid = [](const std::string &str) { + return std::all_of(str.cbegin(), str.cend(), + [](unsigned char c) { return std::isdigit(c); }); + }; + + if (s.size() > 7 && s.compare(0, 6, "bytes=") == 0) { + const auto pos = static_cast(6); + const auto len = static_cast(s.size() - 6); + auto all_valid_ranges = true; + split(&s[pos], &s[pos + len], ',', [&](const char *b, const char *e) { + if (!all_valid_ranges) { return; } + + const auto it = std::find(b, e, '-'); + if (it == e) { + all_valid_ranges = false; + return; + } + + const auto lhs = std::string(b, it); + const auto rhs = std::string(it + 1, e); + if (!is_valid(lhs) || !is_valid(rhs)) { + all_valid_ranges = false; + return; + } + + const auto first = + static_cast(lhs.empty() ? -1 : std::stoll(lhs)); + const auto last = + static_cast(rhs.empty() ? -1 : std::stoll(rhs)); + if ((first == -1 && last == -1) || + (first != -1 && last != -1 && first > last)) { + all_valid_ranges = false; + return; + } + + ranges.emplace_back(first, last); + }); + return all_valid_ranges && !ranges.empty(); + } + return false; +#ifdef CPPHTTPLIB_NO_EXCEPTIONS +} +#else +} catch (...) { return false; } +#endif + +class MultipartFormDataParser { +public: + MultipartFormDataParser() = default; + + void set_boundary(std::string &&boundary) { + boundary_ = boundary; + dash_boundary_crlf_ = dash_ + boundary_ + crlf_; + crlf_dash_boundary_ = crlf_ + dash_ + boundary_; + } + + bool is_valid() const { return is_valid_; } + + bool parse(const char *buf, size_t n, const ContentReceiver &content_callback, + const MultipartContentHeader &header_callback) { + + buf_append(buf, n); + + while (buf_size() > 0) { + switch (state_) { + case 0: { // Initial boundary + buf_erase(buf_find(dash_boundary_crlf_)); + if (dash_boundary_crlf_.size() > buf_size()) { return true; } + if (!buf_start_with(dash_boundary_crlf_)) { return false; } + buf_erase(dash_boundary_crlf_.size()); + state_ = 1; + break; + } + case 1: { // New entry + clear_file_info(); + state_ = 2; + break; + } + case 2: { // Headers + auto pos = buf_find(crlf_); + if (pos > CPPHTTPLIB_HEADER_MAX_LENGTH) { return false; } + while (pos < buf_size()) { + // Empty line + if (pos == 0) { + if (!header_callback(file_)) { + is_valid_ = false; + return false; + } + buf_erase(crlf_.size()); + state_ = 3; + break; + } + + const auto header = buf_head(pos); + + if (!parse_header(header.data(), header.data() + header.size(), + [&](const std::string &, const std::string &) {})) { + is_valid_ = false; + return false; + } + + static const std::string header_content_type = "Content-Type:"; + + if (start_with_case_ignore(header, header_content_type)) { + file_.content_type = + trim_copy(header.substr(header_content_type.size())); + } else { + static const std::regex re_content_disposition( + R"~(^Content-Disposition:\s*form-data;\s*(.*)$)~", + std::regex_constants::icase); + + std::smatch m; + if (std::regex_match(header, m, re_content_disposition)) { + Params params; + parse_disposition_params(m[1], params); + + auto it = params.find("name"); + if (it != params.end()) { + file_.name = it->second; + } else { + is_valid_ = false; + return false; + } + + it = params.find("filename"); + if (it != params.end()) { file_.filename = it->second; } + + it = params.find("filename*"); + if (it != params.end()) { + // Only allow UTF-8 enconnding... + static const std::regex re_rfc5987_encoding( + R"~(^UTF-8''(.+?)$)~", std::regex_constants::icase); + + std::smatch m2; + if (std::regex_match(it->second, m2, re_rfc5987_encoding)) { + file_.filename = decode_url(m2[1], false); // override... + } else { + is_valid_ = false; + return false; + } + } + } + } + buf_erase(pos + crlf_.size()); + pos = buf_find(crlf_); + } + if (state_ != 3) { return true; } + break; + } + case 3: { // Body + if (crlf_dash_boundary_.size() > buf_size()) { return true; } + auto pos = buf_find(crlf_dash_boundary_); + if (pos < buf_size()) { + if (!content_callback(buf_data(), pos)) { + is_valid_ = false; + return false; + } + buf_erase(pos + crlf_dash_boundary_.size()); + state_ = 4; + } else { + auto len = buf_size() - crlf_dash_boundary_.size(); + if (len > 0) { + if (!content_callback(buf_data(), len)) { + is_valid_ = false; + return false; + } + buf_erase(len); + } + return true; + } + break; + } + case 4: { // Boundary + if (crlf_.size() > buf_size()) { return true; } + if (buf_start_with(crlf_)) { + buf_erase(crlf_.size()); + state_ = 1; + } else { + if (dash_.size() > buf_size()) { return true; } + if (buf_start_with(dash_)) { + buf_erase(dash_.size()); + is_valid_ = true; + buf_erase(buf_size()); // Remove epilogue + } else { + return true; + } + } + break; + } + } + } + + return true; + } + +private: + void clear_file_info() { + file_.name.clear(); + file_.filename.clear(); + file_.content_type.clear(); + } + + bool start_with_case_ignore(const std::string &a, + const std::string &b) const { + if (a.size() < b.size()) { return false; } + for (size_t i = 0; i < b.size(); i++) { + if (case_ignore::to_lower(a[i]) != case_ignore::to_lower(b[i])) { + return false; + } + } + return true; + } + + const std::string dash_ = "--"; + const std::string crlf_ = "\r\n"; + std::string boundary_; + std::string dash_boundary_crlf_; + std::string crlf_dash_boundary_; + + size_t state_ = 0; + bool is_valid_ = false; + MultipartFormData file_; + + // Buffer + bool start_with(const std::string &a, size_t spos, size_t epos, + const std::string &b) const { + if (epos - spos < b.size()) { return false; } + for (size_t i = 0; i < b.size(); i++) { + if (a[i + spos] != b[i]) { return false; } + } + return true; + } + + size_t buf_size() const { return buf_epos_ - buf_spos_; } + + const char *buf_data() const { return &buf_[buf_spos_]; } + + std::string buf_head(size_t l) const { return buf_.substr(buf_spos_, l); } + + bool buf_start_with(const std::string &s) const { + return start_with(buf_, buf_spos_, buf_epos_, s); + } + + size_t buf_find(const std::string &s) const { + auto c = s.front(); + + size_t off = buf_spos_; + while (off < buf_epos_) { + auto pos = off; + while (true) { + if (pos == buf_epos_) { return buf_size(); } + if (buf_[pos] == c) { break; } + pos++; + } + + auto remaining_size = buf_epos_ - pos; + if (s.size() > remaining_size) { return buf_size(); } + + if (start_with(buf_, pos, buf_epos_, s)) { return pos - buf_spos_; } + + off = pos + 1; + } + + return buf_size(); + } + + void buf_append(const char *data, size_t n) { + auto remaining_size = buf_size(); + if (remaining_size > 0 && buf_spos_ > 0) { + for (size_t i = 0; i < remaining_size; i++) { + buf_[i] = buf_[buf_spos_ + i]; + } + } + buf_spos_ = 0; + buf_epos_ = remaining_size; + + if (remaining_size + n > buf_.size()) { buf_.resize(remaining_size + n); } + + for (size_t i = 0; i < n; i++) { + buf_[buf_epos_ + i] = data[i]; + } + buf_epos_ += n; + } + + void buf_erase(size_t size) { buf_spos_ += size; } + + std::string buf_; + size_t buf_spos_ = 0; + size_t buf_epos_ = 0; +}; + +inline std::string random_string(size_t length) { + static const char data[] = + "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"; + + // std::random_device might actually be deterministic on some + // platforms, but due to lack of support in the c++ standard library, + // doing better requires either some ugly hacks or breaking portability. + static std::random_device seed_gen; + + // Request 128 bits of entropy for initialization + static std::seed_seq seed_sequence{seed_gen(), seed_gen(), seed_gen(), + seed_gen()}; + + static std::mt19937 engine(seed_sequence); + + std::string result; + for (size_t i = 0; i < length; i++) { + result += data[engine() % (sizeof(data) - 1)]; + } + return result; +} + +inline std::string make_multipart_data_boundary() { + return "--cpp-httplib-multipart-data-" + detail::random_string(16); +} + +inline bool is_multipart_boundary_chars_valid(const std::string &boundary) { + auto valid = true; + for (size_t i = 0; i < boundary.size(); i++) { + auto c = boundary[i]; + if (!std::isalnum(c) && c != '-' && c != '_') { + valid = false; + break; + } + } + return valid; +} + +template +inline std::string +serialize_multipart_formdata_item_begin(const T &item, + const std::string &boundary) { + std::string body = "--" + boundary + "\r\n"; + body += "Content-Disposition: form-data; name=\"" + item.name + "\""; + if (!item.filename.empty()) { + body += "; filename=\"" + item.filename + "\""; + } + body += "\r\n"; + if (!item.content_type.empty()) { + body += "Content-Type: " + item.content_type + "\r\n"; + } + body += "\r\n"; + + return body; +} + +inline std::string serialize_multipart_formdata_item_end() { return "\r\n"; } + +inline std::string +serialize_multipart_formdata_finish(const std::string &boundary) { + return "--" + boundary + "--\r\n"; +} + +inline std::string +serialize_multipart_formdata_get_content_type(const std::string &boundary) { + return "multipart/form-data; boundary=" + boundary; +} + +inline std::string +serialize_multipart_formdata(const MultipartFormDataItems &items, + const std::string &boundary, bool finish = true) { + std::string body; + + for (const auto &item : items) { + body += serialize_multipart_formdata_item_begin(item, boundary); + body += item.content + serialize_multipart_formdata_item_end(); + } + + if (finish) { body += serialize_multipart_formdata_finish(boundary); } + + return body; +} + +inline bool range_error(Request &req, Response &res) { + if (!req.ranges.empty() && 200 <= res.status && res.status < 300) { + ssize_t contant_len = static_cast( + res.content_length_ ? res.content_length_ : res.body.size()); + + ssize_t prev_first_pos = -1; + ssize_t prev_last_pos = -1; + size_t overwrapping_count = 0; + + // NOTE: The following Range check is based on '14.2. Range' in RFC 9110 + // 'HTTP Semantics' to avoid potential denial-of-service attacks. + // https://www.rfc-editor.org/rfc/rfc9110#section-14.2 + + // Too many ranges + if (req.ranges.size() > CPPHTTPLIB_RANGE_MAX_COUNT) { return true; } + + for (auto &r : req.ranges) { + auto &first_pos = r.first; + auto &last_pos = r.second; + + if (first_pos == -1 && last_pos == -1) { + first_pos = 0; + last_pos = contant_len; + } + + if (first_pos == -1) { + first_pos = contant_len - last_pos; + last_pos = contant_len - 1; + } + + // NOTE: RFC-9110 '14.1.2. Byte Ranges': + // A client can limit the number of bytes requested without knowing the + // size of the selected representation. If the last-pos value is absent, + // or if the value is greater than or equal to the current length of the + // representation data, the byte range is interpreted as the remainder of + // the representation (i.e., the server replaces the value of last-pos + // with a value that is one less than the current length of the selected + // representation). + // https://www.rfc-editor.org/rfc/rfc9110.html#section-14.1.2-6 + if (last_pos == -1 || last_pos >= contant_len) { + last_pos = contant_len - 1; + } + + // Range must be within content length + if (!(0 <= first_pos && first_pos <= last_pos && + last_pos <= contant_len - 1)) { + return true; + } + + // Ranges must be in ascending order + if (first_pos <= prev_first_pos) { return true; } + + // Request must not have more than two overlapping ranges + if (first_pos <= prev_last_pos) { + overwrapping_count++; + if (overwrapping_count > 2) { return true; } + } + + prev_first_pos = (std::max)(prev_first_pos, first_pos); + prev_last_pos = (std::max)(prev_last_pos, last_pos); + } + } + + return false; +} + +inline std::pair +get_range_offset_and_length(Range r, size_t content_length) { + assert(r.first != -1 && r.second != -1); + assert(0 <= r.first && r.first < static_cast(content_length)); + assert(r.first <= r.second && + r.second < static_cast(content_length)); + (void)(content_length); + return std::make_pair(r.first, static_cast(r.second - r.first) + 1); +} + +inline std::string make_content_range_header_field( + const std::pair &offset_and_length, size_t content_length) { + auto st = offset_and_length.first; + auto ed = st + offset_and_length.second - 1; + + std::string field = "bytes "; + field += std::to_string(st); + field += "-"; + field += std::to_string(ed); + field += "/"; + field += std::to_string(content_length); + return field; +} + +template +bool process_multipart_ranges_data(const Request &req, + const std::string &boundary, + const std::string &content_type, + size_t content_length, SToken stoken, + CToken ctoken, Content content) { + for (size_t i = 0; i < req.ranges.size(); i++) { + ctoken("--"); + stoken(boundary); + ctoken("\r\n"); + if (!content_type.empty()) { + ctoken("Content-Type: "); + stoken(content_type); + ctoken("\r\n"); + } + + auto offset_and_length = + get_range_offset_and_length(req.ranges[i], content_length); + + ctoken("Content-Range: "); + stoken(make_content_range_header_field(offset_and_length, content_length)); + ctoken("\r\n"); + ctoken("\r\n"); + + if (!content(offset_and_length.first, offset_and_length.second)) { + return false; + } + ctoken("\r\n"); + } + + ctoken("--"); + stoken(boundary); + ctoken("--"); + + return true; +} + +inline void make_multipart_ranges_data(const Request &req, Response &res, + const std::string &boundary, + const std::string &content_type, + size_t content_length, + std::string &data) { + process_multipart_ranges_data( + req, boundary, content_type, content_length, + [&](const std::string &token) { data += token; }, + [&](const std::string &token) { data += token; }, + [&](size_t offset, size_t length) { + assert(offset + length <= content_length); + data += res.body.substr(offset, length); + return true; + }); +} + +inline size_t get_multipart_ranges_data_length(const Request &req, + const std::string &boundary, + const std::string &content_type, + size_t content_length) { + size_t data_length = 0; + + process_multipart_ranges_data( + req, boundary, content_type, content_length, + [&](const std::string &token) { data_length += token.size(); }, + [&](const std::string &token) { data_length += token.size(); }, + [&](size_t /*offset*/, size_t length) { + data_length += length; + return true; + }); + + return data_length; +} + +template +inline bool +write_multipart_ranges_data(Stream &strm, const Request &req, Response &res, + const std::string &boundary, + const std::string &content_type, + size_t content_length, const T &is_shutting_down) { + return process_multipart_ranges_data( + req, boundary, content_type, content_length, + [&](const std::string &token) { strm.write(token); }, + [&](const std::string &token) { strm.write(token); }, + [&](size_t offset, size_t length) { + return write_content(strm, res.content_provider_, offset, length, + is_shutting_down); + }); +} + +inline bool expect_content(const Request &req) { + if (req.method == "POST" || req.method == "PUT" || req.method == "PATCH" || + req.method == "PRI" || req.method == "DELETE") { + return true; + } + // TODO: check if Content-Length is set + return false; +} + +inline bool has_crlf(const std::string &s) { + auto p = s.c_str(); + while (*p) { + if (*p == '\r' || *p == '\n') { return true; } + p++; + } + return false; +} + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +inline std::string message_digest(const std::string &s, const EVP_MD *algo) { + auto context = std::unique_ptr( + EVP_MD_CTX_new(), EVP_MD_CTX_free); + + unsigned int hash_length = 0; + unsigned char hash[EVP_MAX_MD_SIZE]; + + EVP_DigestInit_ex(context.get(), algo, nullptr); + EVP_DigestUpdate(context.get(), s.c_str(), s.size()); + EVP_DigestFinal_ex(context.get(), hash, &hash_length); + + std::stringstream ss; + for (auto i = 0u; i < hash_length; ++i) { + ss << std::hex << std::setw(2) << std::setfill('0') + << static_cast(hash[i]); + } + + return ss.str(); +} + +inline std::string MD5(const std::string &s) { + return message_digest(s, EVP_md5()); +} + +inline std::string SHA_256(const std::string &s) { + return message_digest(s, EVP_sha256()); +} + +inline std::string SHA_512(const std::string &s) { + return message_digest(s, EVP_sha512()); +} +#endif + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +#ifdef _WIN32 +// NOTE: This code came up with the following stackoverflow post: +// https://stackoverflow.com/questions/9507184/can-openssl-on-windows-use-the-system-certificate-store +inline bool load_system_certs_on_windows(X509_STORE *store) { + auto hStore = CertOpenSystemStoreW((HCRYPTPROV_LEGACY)NULL, L"ROOT"); + if (!hStore) { return false; } + + auto result = false; + PCCERT_CONTEXT pContext = NULL; + while ((pContext = CertEnumCertificatesInStore(hStore, pContext)) != + nullptr) { + auto encoded_cert = + static_cast(pContext->pbCertEncoded); + + auto x509 = d2i_X509(NULL, &encoded_cert, pContext->cbCertEncoded); + if (x509) { + X509_STORE_add_cert(store, x509); + X509_free(x509); + result = true; + } + } + + CertFreeCertificateContext(pContext); + CertCloseStore(hStore, 0); + + return result; +} +#elif defined(CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN) && defined(__APPLE__) +#if TARGET_OS_OSX +template +using CFObjectPtr = + std::unique_ptr::type, void (*)(CFTypeRef)>; + +inline void cf_object_ptr_deleter(CFTypeRef obj) { + if (obj) { CFRelease(obj); } +} + +inline bool retrieve_certs_from_keychain(CFObjectPtr &certs) { + CFStringRef keys[] = {kSecClass, kSecMatchLimit, kSecReturnRef}; + CFTypeRef values[] = {kSecClassCertificate, kSecMatchLimitAll, + kCFBooleanTrue}; + + CFObjectPtr query( + CFDictionaryCreate(nullptr, reinterpret_cast(keys), values, + sizeof(keys) / sizeof(keys[0]), + &kCFTypeDictionaryKeyCallBacks, + &kCFTypeDictionaryValueCallBacks), + cf_object_ptr_deleter); + + if (!query) { return false; } + + CFTypeRef security_items = nullptr; + if (SecItemCopyMatching(query.get(), &security_items) != errSecSuccess || + CFArrayGetTypeID() != CFGetTypeID(security_items)) { + return false; + } + + certs.reset(reinterpret_cast(security_items)); + return true; +} + +inline bool retrieve_root_certs_from_keychain(CFObjectPtr &certs) { + CFArrayRef root_security_items = nullptr; + if (SecTrustCopyAnchorCertificates(&root_security_items) != errSecSuccess) { + return false; + } + + certs.reset(root_security_items); + return true; +} + +inline bool add_certs_to_x509_store(CFArrayRef certs, X509_STORE *store) { + auto result = false; + for (auto i = 0; i < CFArrayGetCount(certs); ++i) { + const auto cert = reinterpret_cast( + CFArrayGetValueAtIndex(certs, i)); + + if (SecCertificateGetTypeID() != CFGetTypeID(cert)) { continue; } + + CFDataRef cert_data = nullptr; + if (SecItemExport(cert, kSecFormatX509Cert, 0, nullptr, &cert_data) != + errSecSuccess) { + continue; + } + + CFObjectPtr cert_data_ptr(cert_data, cf_object_ptr_deleter); + + auto encoded_cert = static_cast( + CFDataGetBytePtr(cert_data_ptr.get())); + + auto x509 = + d2i_X509(NULL, &encoded_cert, CFDataGetLength(cert_data_ptr.get())); + + if (x509) { + X509_STORE_add_cert(store, x509); + X509_free(x509); + result = true; + } + } + + return result; +} + +inline bool load_system_certs_on_macos(X509_STORE *store) { + auto result = false; + CFObjectPtr certs(nullptr, cf_object_ptr_deleter); + if (retrieve_certs_from_keychain(certs) && certs) { + result = add_certs_to_x509_store(certs.get(), store); + } + + if (retrieve_root_certs_from_keychain(certs) && certs) { + result = add_certs_to_x509_store(certs.get(), store) || result; + } + + return result; +} +#endif // TARGET_OS_OSX +#endif // _WIN32 +#endif // CPPHTTPLIB_OPENSSL_SUPPORT + +#ifdef _WIN32 +class WSInit { +public: + WSInit() { + WSADATA wsaData; + if (WSAStartup(0x0002, &wsaData) == 0) is_valid_ = true; + } + + ~WSInit() { + if (is_valid_) WSACleanup(); + } + + bool is_valid_ = false; +}; + +static WSInit wsinit_; +#endif + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +inline std::pair make_digest_authentication_header( + const Request &req, const std::map &auth, + size_t cnonce_count, const std::string &cnonce, const std::string &username, + const std::string &password, bool is_proxy = false) { + std::string nc; + { + std::stringstream ss; + ss << std::setfill('0') << std::setw(8) << std::hex << cnonce_count; + nc = ss.str(); + } + + std::string qop; + if (auth.find("qop") != auth.end()) { + qop = auth.at("qop"); + if (qop.find("auth-int") != std::string::npos) { + qop = "auth-int"; + } else if (qop.find("auth") != std::string::npos) { + qop = "auth"; + } else { + qop.clear(); + } + } + + std::string algo = "MD5"; + if (auth.find("algorithm") != auth.end()) { algo = auth.at("algorithm"); } + + std::string response; + { + auto H = algo == "SHA-256" ? detail::SHA_256 + : algo == "SHA-512" ? detail::SHA_512 + : detail::MD5; + + auto A1 = username + ":" + auth.at("realm") + ":" + password; + + auto A2 = req.method + ":" + req.path; + if (qop == "auth-int") { A2 += ":" + H(req.body); } + + if (qop.empty()) { + response = H(H(A1) + ":" + auth.at("nonce") + ":" + H(A2)); + } else { + response = H(H(A1) + ":" + auth.at("nonce") + ":" + nc + ":" + cnonce + + ":" + qop + ":" + H(A2)); + } + } + + auto opaque = (auth.find("opaque") != auth.end()) ? auth.at("opaque") : ""; + + auto field = "Digest username=\"" + username + "\", realm=\"" + + auth.at("realm") + "\", nonce=\"" + auth.at("nonce") + + "\", uri=\"" + req.path + "\", algorithm=" + algo + + (qop.empty() ? ", response=\"" + : ", qop=" + qop + ", nc=" + nc + ", cnonce=\"" + + cnonce + "\", response=\"") + + response + "\"" + + (opaque.empty() ? "" : ", opaque=\"" + opaque + "\""); + + auto key = is_proxy ? "Proxy-Authorization" : "Authorization"; + return std::make_pair(key, field); +} +#endif + +inline bool parse_www_authenticate(const Response &res, + std::map &auth, + bool is_proxy) { + auto auth_key = is_proxy ? "Proxy-Authenticate" : "WWW-Authenticate"; + if (res.has_header(auth_key)) { + static auto re = std::regex(R"~((?:(?:,\s*)?(.+?)=(?:"(.*?)"|([^,]*))))~"); + auto s = res.get_header_value(auth_key); + auto pos = s.find(' '); + if (pos != std::string::npos) { + auto type = s.substr(0, pos); + if (type == "Basic") { + return false; + } else if (type == "Digest") { + s = s.substr(pos + 1); + auto beg = std::sregex_iterator(s.begin(), s.end(), re); + for (auto i = beg; i != std::sregex_iterator(); ++i) { + const auto &m = *i; + auto key = s.substr(static_cast(m.position(1)), + static_cast(m.length(1))); + auto val = m.length(2) > 0 + ? s.substr(static_cast(m.position(2)), + static_cast(m.length(2))) + : s.substr(static_cast(m.position(3)), + static_cast(m.length(3))); + auth[key] = val; + } + return true; + } + } + } + return false; +} + +class ContentProviderAdapter { +public: + explicit ContentProviderAdapter( + ContentProviderWithoutLength &&content_provider) + : content_provider_(content_provider) {} + + bool operator()(size_t offset, size_t, DataSink &sink) { + return content_provider_(offset, sink); + } + +private: + ContentProviderWithoutLength content_provider_; +}; + +} // namespace detail + +inline std::string hosted_at(const std::string &hostname) { + std::vector addrs; + hosted_at(hostname, addrs); + if (addrs.empty()) { return std::string(); } + return addrs[0]; +} + +inline void hosted_at(const std::string &hostname, + std::vector &addrs) { + struct addrinfo hints; + struct addrinfo *result; + + memset(&hints, 0, sizeof(struct addrinfo)); + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + hints.ai_protocol = 0; + + if (getaddrinfo(hostname.c_str(), nullptr, &hints, &result)) { +#if defined __linux__ && !defined __ANDROID__ + res_init(); +#endif + return; + } + auto se = detail::scope_exit([&] { freeaddrinfo(result); }); + + for (auto rp = result; rp; rp = rp->ai_next) { + const auto &addr = + *reinterpret_cast(rp->ai_addr); + std::string ip; + auto dummy = -1; + if (detail::get_ip_and_port(addr, sizeof(struct sockaddr_storage), ip, + dummy)) { + addrs.push_back(ip); + } + } +} + +inline std::string append_query_params(const std::string &path, + const Params ¶ms) { + std::string path_with_query = path; + const static std::regex re("[^?]+\\?.*"); + auto delm = std::regex_match(path, re) ? '&' : '?'; + path_with_query += delm + detail::params_to_query_str(params); + return path_with_query; +} + +// Header utilities +inline std::pair +make_range_header(const Ranges &ranges) { + std::string field = "bytes="; + auto i = 0; + for (const auto &r : ranges) { + if (i != 0) { field += ", "; } + if (r.first != -1) { field += std::to_string(r.first); } + field += '-'; + if (r.second != -1) { field += std::to_string(r.second); } + i++; + } + return std::make_pair("Range", std::move(field)); +} + +inline std::pair +make_basic_authentication_header(const std::string &username, + const std::string &password, bool is_proxy) { + auto field = "Basic " + detail::base64_encode(username + ":" + password); + auto key = is_proxy ? "Proxy-Authorization" : "Authorization"; + return std::make_pair(key, std::move(field)); +} + +inline std::pair +make_bearer_token_authentication_header(const std::string &token, + bool is_proxy = false) { + auto field = "Bearer " + token; + auto key = is_proxy ? "Proxy-Authorization" : "Authorization"; + return std::make_pair(key, std::move(field)); +} + +// Request implementation +inline bool Request::has_header(const std::string &key) const { + return detail::has_header(headers, key); +} + +inline std::string Request::get_header_value(const std::string &key, + const char *def, size_t id) const { + return detail::get_header_value(headers, key, def, id); +} + +inline size_t Request::get_header_value_count(const std::string &key) const { + auto r = headers.equal_range(key); + return static_cast(std::distance(r.first, r.second)); +} + +inline void Request::set_header(const std::string &key, + const std::string &val) { + if (detail::fields::is_field_name(key) && + detail::fields::is_field_value(val)) { + headers.emplace(key, val); + } +} + +inline bool Request::has_param(const std::string &key) const { + return params.find(key) != params.end(); +} + +inline std::string Request::get_param_value(const std::string &key, + size_t id) const { + auto rng = params.equal_range(key); + auto it = rng.first; + std::advance(it, static_cast(id)); + if (it != rng.second) { return it->second; } + return std::string(); +} + +inline size_t Request::get_param_value_count(const std::string &key) const { + auto r = params.equal_range(key); + return static_cast(std::distance(r.first, r.second)); +} + +inline bool Request::is_multipart_form_data() const { + const auto &content_type = get_header_value("Content-Type"); + return !content_type.rfind("multipart/form-data", 0); +} + +inline bool Request::has_file(const std::string &key) const { + return files.find(key) != files.end(); +} + +inline MultipartFormData Request::get_file_value(const std::string &key) const { + auto it = files.find(key); + if (it != files.end()) { return it->second; } + return MultipartFormData(); +} + +inline std::vector +Request::get_file_values(const std::string &key) const { + std::vector values; + auto rng = files.equal_range(key); + for (auto it = rng.first; it != rng.second; it++) { + values.push_back(it->second); + } + return values; +} + +// Response implementation +inline bool Response::has_header(const std::string &key) const { + return headers.find(key) != headers.end(); +} + +inline std::string Response::get_header_value(const std::string &key, + const char *def, + size_t id) const { + return detail::get_header_value(headers, key, def, id); +} + +inline size_t Response::get_header_value_count(const std::string &key) const { + auto r = headers.equal_range(key); + return static_cast(std::distance(r.first, r.second)); +} + +inline void Response::set_header(const std::string &key, + const std::string &val) { + if (detail::fields::is_field_name(key) && + detail::fields::is_field_value(val)) { + headers.emplace(key, val); + } +} + +inline void Response::set_redirect(const std::string &url, int stat) { + if (detail::fields::is_field_value(url)) { + set_header("Location", url); + if (300 <= stat && stat < 400) { + this->status = stat; + } else { + this->status = StatusCode::Found_302; + } + } +} + +inline void Response::set_content(const char *s, size_t n, + const std::string &content_type) { + body.assign(s, n); + + auto rng = headers.equal_range("Content-Type"); + headers.erase(rng.first, rng.second); + set_header("Content-Type", content_type); +} + +inline void Response::set_content(const std::string &s, + const std::string &content_type) { + set_content(s.data(), s.size(), content_type); +} + +inline void Response::set_content(std::string &&s, + const std::string &content_type) { + body = std::move(s); + + auto rng = headers.equal_range("Content-Type"); + headers.erase(rng.first, rng.second); + set_header("Content-Type", content_type); +} + +inline void Response::set_content_provider( + size_t in_length, const std::string &content_type, ContentProvider provider, + ContentProviderResourceReleaser resource_releaser) { + set_header("Content-Type", content_type); + content_length_ = in_length; + if (in_length > 0) { content_provider_ = std::move(provider); } + content_provider_resource_releaser_ = std::move(resource_releaser); + is_chunked_content_provider_ = false; +} + +inline void Response::set_content_provider( + const std::string &content_type, ContentProviderWithoutLength provider, + ContentProviderResourceReleaser resource_releaser) { + set_header("Content-Type", content_type); + content_length_ = 0; + content_provider_ = detail::ContentProviderAdapter(std::move(provider)); + content_provider_resource_releaser_ = std::move(resource_releaser); + is_chunked_content_provider_ = false; +} + +inline void Response::set_chunked_content_provider( + const std::string &content_type, ContentProviderWithoutLength provider, + ContentProviderResourceReleaser resource_releaser) { + set_header("Content-Type", content_type); + content_length_ = 0; + content_provider_ = detail::ContentProviderAdapter(std::move(provider)); + content_provider_resource_releaser_ = std::move(resource_releaser); + is_chunked_content_provider_ = true; +} + +inline void Response::set_file_content(const std::string &path, + const std::string &content_type) { + file_content_path_ = path; + file_content_content_type_ = content_type; +} + +inline void Response::set_file_content(const std::string &path) { + file_content_path_ = path; +} + +// Result implementation +inline bool Result::has_request_header(const std::string &key) const { + return request_headers_.find(key) != request_headers_.end(); +} + +inline std::string Result::get_request_header_value(const std::string &key, + const char *def, + size_t id) const { + return detail::get_header_value(request_headers_, key, def, id); +} + +inline size_t +Result::get_request_header_value_count(const std::string &key) const { + auto r = request_headers_.equal_range(key); + return static_cast(std::distance(r.first, r.second)); +} + +// Stream implementation +inline ssize_t Stream::write(const char *ptr) { + return write(ptr, strlen(ptr)); +} + +inline ssize_t Stream::write(const std::string &s) { + return write(s.data(), s.size()); +} + +namespace detail { + +// Socket stream implementation +inline SocketStream::SocketStream(socket_t sock, time_t read_timeout_sec, + time_t read_timeout_usec, + time_t write_timeout_sec, + time_t write_timeout_usec) + : sock_(sock), read_timeout_sec_(read_timeout_sec), + read_timeout_usec_(read_timeout_usec), + write_timeout_sec_(write_timeout_sec), + write_timeout_usec_(write_timeout_usec), read_buff_(read_buff_size_, 0) {} + +inline SocketStream::~SocketStream() = default; + +inline bool SocketStream::is_readable() const { + return select_read(sock_, read_timeout_sec_, read_timeout_usec_) > 0; +} + +inline bool SocketStream::is_writable() const { + return select_write(sock_, write_timeout_sec_, write_timeout_usec_) > 0 && + is_socket_alive(sock_); +} + +inline ssize_t SocketStream::read(char *ptr, size_t size) { +#ifdef _WIN32 + size = + (std::min)(size, static_cast((std::numeric_limits::max)())); +#else + size = (std::min)(size, + static_cast((std::numeric_limits::max)())); +#endif + + if (read_buff_off_ < read_buff_content_size_) { + auto remaining_size = read_buff_content_size_ - read_buff_off_; + if (size <= remaining_size) { + memcpy(ptr, read_buff_.data() + read_buff_off_, size); + read_buff_off_ += size; + return static_cast(size); + } else { + memcpy(ptr, read_buff_.data() + read_buff_off_, remaining_size); + read_buff_off_ += remaining_size; + return static_cast(remaining_size); + } + } + + if (!is_readable()) { return -1; } + + read_buff_off_ = 0; + read_buff_content_size_ = 0; + + if (size < read_buff_size_) { + auto n = read_socket(sock_, read_buff_.data(), read_buff_size_, + CPPHTTPLIB_RECV_FLAGS); + if (n <= 0) { + return n; + } else if (n <= static_cast(size)) { + memcpy(ptr, read_buff_.data(), static_cast(n)); + return n; + } else { + memcpy(ptr, read_buff_.data(), size); + read_buff_off_ = size; + read_buff_content_size_ = static_cast(n); + return static_cast(size); + } + } else { + return read_socket(sock_, ptr, size, CPPHTTPLIB_RECV_FLAGS); + } +} + +inline ssize_t SocketStream::write(const char *ptr, size_t size) { + if (!is_writable()) { return -1; } + +#if defined(_WIN32) && !defined(_WIN64) + size = + (std::min)(size, static_cast((std::numeric_limits::max)())); +#endif + + return send_socket(sock_, ptr, size, CPPHTTPLIB_SEND_FLAGS); +} + +inline void SocketStream::get_remote_ip_and_port(std::string &ip, + int &port) const { + return detail::get_remote_ip_and_port(sock_, ip, port); +} + +inline void SocketStream::get_local_ip_and_port(std::string &ip, + int &port) const { + return detail::get_local_ip_and_port(sock_, ip, port); +} + +inline socket_t SocketStream::socket() const { return sock_; } + +// Buffer stream implementation +inline bool BufferStream::is_readable() const { return true; } + +inline bool BufferStream::is_writable() const { return true; } + +inline ssize_t BufferStream::read(char *ptr, size_t size) { +#if defined(_MSC_VER) && _MSC_VER < 1910 + auto len_read = buffer._Copy_s(ptr, size, size, position); +#else + auto len_read = buffer.copy(ptr, size, position); +#endif + position += static_cast(len_read); + return static_cast(len_read); +} + +inline ssize_t BufferStream::write(const char *ptr, size_t size) { + buffer.append(ptr, size); + return static_cast(size); +} + +inline void BufferStream::get_remote_ip_and_port(std::string & /*ip*/, + int & /*port*/) const {} + +inline void BufferStream::get_local_ip_and_port(std::string & /*ip*/, + int & /*port*/) const {} + +inline socket_t BufferStream::socket() const { return 0; } + +inline const std::string &BufferStream::get_buffer() const { return buffer; } + +inline PathParamsMatcher::PathParamsMatcher(const std::string &pattern) { + static constexpr char marker[] = "/:"; + + // One past the last ending position of a path param substring + std::size_t last_param_end = 0; + +#ifndef CPPHTTPLIB_NO_EXCEPTIONS + // Needed to ensure that parameter names are unique during matcher + // construction + // If exceptions are disabled, only last duplicate path + // parameter will be set + std::unordered_set param_name_set; +#endif + + while (true) { + const auto marker_pos = pattern.find( + marker, last_param_end == 0 ? last_param_end : last_param_end - 1); + if (marker_pos == std::string::npos) { break; } + + static_fragments_.push_back( + pattern.substr(last_param_end, marker_pos - last_param_end + 1)); + + const auto param_name_start = marker_pos + 2; + + auto sep_pos = pattern.find(separator, param_name_start); + if (sep_pos == std::string::npos) { sep_pos = pattern.length(); } + + auto param_name = + pattern.substr(param_name_start, sep_pos - param_name_start); + +#ifndef CPPHTTPLIB_NO_EXCEPTIONS + if (param_name_set.find(param_name) != param_name_set.cend()) { + std::string msg = "Encountered path parameter '" + param_name + + "' multiple times in route pattern '" + pattern + "'."; + throw std::invalid_argument(msg); + } +#endif + + param_names_.push_back(std::move(param_name)); + + last_param_end = sep_pos + 1; + } + + if (last_param_end < pattern.length()) { + static_fragments_.push_back(pattern.substr(last_param_end)); + } +} + +inline bool PathParamsMatcher::match(Request &request) const { + request.matches = std::smatch(); + request.path_params.clear(); + request.path_params.reserve(param_names_.size()); + + // One past the position at which the path matched the pattern last time + std::size_t starting_pos = 0; + for (size_t i = 0; i < static_fragments_.size(); ++i) { + const auto &fragment = static_fragments_[i]; + + if (starting_pos + fragment.length() > request.path.length()) { + return false; + } + + // Avoid unnecessary allocation by using strncmp instead of substr + + // comparison + if (std::strncmp(request.path.c_str() + starting_pos, fragment.c_str(), + fragment.length()) != 0) { + return false; + } + + starting_pos += fragment.length(); + + // Should only happen when we have a static fragment after a param + // Example: '/users/:id/subscriptions' + // The 'subscriptions' fragment here does not have a corresponding param + if (i >= param_names_.size()) { continue; } + + auto sep_pos = request.path.find(separator, starting_pos); + if (sep_pos == std::string::npos) { sep_pos = request.path.length(); } + + const auto ¶m_name = param_names_[i]; + + request.path_params.emplace( + param_name, request.path.substr(starting_pos, sep_pos - starting_pos)); + + // Mark everything up to '/' as matched + starting_pos = sep_pos + 1; + } + // Returns false if the path is longer than the pattern + return starting_pos >= request.path.length(); +} + +inline bool RegexMatcher::match(Request &request) const { + request.path_params.clear(); + return std::regex_match(request.path, request.matches, regex_); +} + +} // namespace detail + +// HTTP server implementation +inline Server::Server() + : new_task_queue( + [] { return new ThreadPool(CPPHTTPLIB_THREAD_POOL_COUNT); }) { +#ifndef _WIN32 + signal(SIGPIPE, SIG_IGN); +#endif +} + +inline Server::~Server() = default; + +inline std::unique_ptr +Server::make_matcher(const std::string &pattern) { + if (pattern.find("/:") != std::string::npos) { + return detail::make_unique(pattern); + } else { + return detail::make_unique(pattern); + } +} + +inline Server &Server::Get(const std::string &pattern, Handler handler) { + get_handlers_.emplace_back(make_matcher(pattern), std::move(handler)); + return *this; +} + +inline Server &Server::Post(const std::string &pattern, Handler handler) { + post_handlers_.emplace_back(make_matcher(pattern), std::move(handler)); + return *this; +} + +inline Server &Server::Post(const std::string &pattern, + HandlerWithContentReader handler) { + post_handlers_for_content_reader_.emplace_back(make_matcher(pattern), + std::move(handler)); + return *this; +} + +inline Server &Server::Put(const std::string &pattern, Handler handler) { + put_handlers_.emplace_back(make_matcher(pattern), std::move(handler)); + return *this; +} + +inline Server &Server::Put(const std::string &pattern, + HandlerWithContentReader handler) { + put_handlers_for_content_reader_.emplace_back(make_matcher(pattern), + std::move(handler)); + return *this; +} + +inline Server &Server::Patch(const std::string &pattern, Handler handler) { + patch_handlers_.emplace_back(make_matcher(pattern), std::move(handler)); + return *this; +} + +inline Server &Server::Patch(const std::string &pattern, + HandlerWithContentReader handler) { + patch_handlers_for_content_reader_.emplace_back(make_matcher(pattern), + std::move(handler)); + return *this; +} + +inline Server &Server::Delete(const std::string &pattern, Handler handler) { + delete_handlers_.emplace_back(make_matcher(pattern), std::move(handler)); + return *this; +} + +inline Server &Server::Delete(const std::string &pattern, + HandlerWithContentReader handler) { + delete_handlers_for_content_reader_.emplace_back(make_matcher(pattern), + std::move(handler)); + return *this; +} + +inline Server &Server::Options(const std::string &pattern, Handler handler) { + options_handlers_.emplace_back(make_matcher(pattern), std::move(handler)); + return *this; +} + +inline bool Server::set_base_dir(const std::string &dir, + const std::string &mount_point) { + return set_mount_point(mount_point, dir); +} + +inline bool Server::set_mount_point(const std::string &mount_point, + const std::string &dir, Headers headers) { + detail::FileStat stat(dir); + if (stat.is_dir()) { + std::string mnt = !mount_point.empty() ? mount_point : "/"; + if (!mnt.empty() && mnt[0] == '/') { + base_dirs_.push_back({mnt, dir, std::move(headers)}); + return true; + } + } + return false; +} + +inline bool Server::remove_mount_point(const std::string &mount_point) { + for (auto it = base_dirs_.begin(); it != base_dirs_.end(); ++it) { + if (it->mount_point == mount_point) { + base_dirs_.erase(it); + return true; + } + } + return false; +} + +inline Server & +Server::set_file_extension_and_mimetype_mapping(const std::string &ext, + const std::string &mime) { + file_extension_and_mimetype_map_[ext] = mime; + return *this; +} + +inline Server &Server::set_default_file_mimetype(const std::string &mime) { + default_file_mimetype_ = mime; + return *this; +} + +inline Server &Server::set_file_request_handler(Handler handler) { + file_request_handler_ = std::move(handler); + return *this; +} + +inline Server &Server::set_error_handler_core(HandlerWithResponse handler, + std::true_type) { + error_handler_ = std::move(handler); + return *this; +} + +inline Server &Server::set_error_handler_core(Handler handler, + std::false_type) { + error_handler_ = [handler](const Request &req, Response &res) { + handler(req, res); + return HandlerResponse::Handled; + }; + return *this; +} + +inline Server &Server::set_exception_handler(ExceptionHandler handler) { + exception_handler_ = std::move(handler); + return *this; +} + +inline Server &Server::set_pre_routing_handler(HandlerWithResponse handler) { + pre_routing_handler_ = std::move(handler); + return *this; +} + +inline Server &Server::set_post_routing_handler(Handler handler) { + post_routing_handler_ = std::move(handler); + return *this; +} + +inline Server &Server::set_logger(Logger logger) { + logger_ = std::move(logger); + return *this; +} + +inline Server & +Server::set_expect_100_continue_handler(Expect100ContinueHandler handler) { + expect_100_continue_handler_ = std::move(handler); + return *this; +} + +inline Server &Server::set_address_family(int family) { + address_family_ = family; + return *this; +} + +inline Server &Server::set_tcp_nodelay(bool on) { + tcp_nodelay_ = on; + return *this; +} + +inline Server &Server::set_ipv6_v6only(bool on) { + ipv6_v6only_ = on; + return *this; +} + +inline Server &Server::set_socket_options(SocketOptions socket_options) { + socket_options_ = std::move(socket_options); + return *this; +} + +inline Server &Server::set_default_headers(Headers headers) { + default_headers_ = std::move(headers); + return *this; +} + +inline Server &Server::set_header_writer( + std::function const &writer) { + header_writer_ = writer; + return *this; +} + +inline Server &Server::set_keep_alive_max_count(size_t count) { + keep_alive_max_count_ = count; + return *this; +} + +inline Server &Server::set_keep_alive_timeout(time_t sec) { + keep_alive_timeout_sec_ = sec; + return *this; +} + +inline Server &Server::set_read_timeout(time_t sec, time_t usec) { + read_timeout_sec_ = sec; + read_timeout_usec_ = usec; + return *this; +} + +inline Server &Server::set_write_timeout(time_t sec, time_t usec) { + write_timeout_sec_ = sec; + write_timeout_usec_ = usec; + return *this; +} + +inline Server &Server::set_idle_interval(time_t sec, time_t usec) { + idle_interval_sec_ = sec; + idle_interval_usec_ = usec; + return *this; +} + +inline Server &Server::set_payload_max_length(size_t length) { + payload_max_length_ = length; + return *this; +} + +inline bool Server::bind_to_port(const std::string &host, int port, + int socket_flags) { + auto ret = bind_internal(host, port, socket_flags); + if (ret == -1) { is_decommisioned = true; } + return ret >= 0; +} +inline int Server::bind_to_any_port(const std::string &host, int socket_flags) { + auto ret = bind_internal(host, 0, socket_flags); + if (ret == -1) { is_decommisioned = true; } + return ret; +} + +inline bool Server::listen_after_bind() { return listen_internal(); } + +inline bool Server::listen(const std::string &host, int port, + int socket_flags) { + return bind_to_port(host, port, socket_flags) && listen_internal(); +} + +inline bool Server::is_running() const { return is_running_; } + +inline void Server::wait_until_ready() const { + while (!is_running_ && !is_decommisioned) { + std::this_thread::sleep_for(std::chrono::milliseconds{1}); + } +} + +inline void Server::stop() { + if (is_running_) { + assert(svr_sock_ != INVALID_SOCKET); + std::atomic sock(svr_sock_.exchange(INVALID_SOCKET)); + detail::shutdown_socket(sock); + detail::close_socket(sock); + } + is_decommisioned = false; +} + +inline void Server::decommission() { is_decommisioned = true; } + +inline bool Server::parse_request_line(const char *s, Request &req) const { + auto len = strlen(s); + if (len < 2 || s[len - 2] != '\r' || s[len - 1] != '\n') { return false; } + len -= 2; + + { + size_t count = 0; + + detail::split(s, s + len, ' ', [&](const char *b, const char *e) { + switch (count) { + case 0: req.method = std::string(b, e); break; + case 1: req.target = std::string(b, e); break; + case 2: req.version = std::string(b, e); break; + default: break; + } + count++; + }); + + if (count != 3) { return false; } + } + + static const std::set methods{ + "GET", "HEAD", "POST", "PUT", "DELETE", + "CONNECT", "OPTIONS", "TRACE", "PATCH", "PRI"}; + + if (methods.find(req.method) == methods.end()) { return false; } + + if (req.version != "HTTP/1.1" && req.version != "HTTP/1.0") { return false; } + + { + // Skip URL fragment + for (size_t i = 0; i < req.target.size(); i++) { + if (req.target[i] == '#') { + req.target.erase(i); + break; + } + } + + detail::divide(req.target, '?', + [&](const char *lhs_data, std::size_t lhs_size, + const char *rhs_data, std::size_t rhs_size) { + req.path = detail::decode_url( + std::string(lhs_data, lhs_size), false); + detail::parse_query_text(rhs_data, rhs_size, req.params); + }); + } + + return true; +} + +inline bool Server::write_response(Stream &strm, bool close_connection, + Request &req, Response &res) { + // NOTE: `req.ranges` should be empty, otherwise it will be applied + // incorrectly to the error content. + req.ranges.clear(); + return write_response_core(strm, close_connection, req, res, false); +} + +inline bool Server::write_response_with_content(Stream &strm, + bool close_connection, + const Request &req, + Response &res) { + return write_response_core(strm, close_connection, req, res, true); +} + +inline bool Server::write_response_core(Stream &strm, bool close_connection, + const Request &req, Response &res, + bool need_apply_ranges) { + assert(res.status != -1); + + if (400 <= res.status && error_handler_ && + error_handler_(req, res) == HandlerResponse::Handled) { + need_apply_ranges = true; + } + + std::string content_type; + std::string boundary; + if (need_apply_ranges) { apply_ranges(req, res, content_type, boundary); } + + // Prepare additional headers + if (close_connection || req.get_header_value("Connection") == "close") { + res.set_header("Connection", "close"); + } else { + std::string s = "timeout="; + s += std::to_string(keep_alive_timeout_sec_); + s += ", max="; + s += std::to_string(keep_alive_max_count_); + res.set_header("Keep-Alive", s); + } + + if ((!res.body.empty() || res.content_length_ > 0 || res.content_provider_) && + !res.has_header("Content-Type")) { + res.set_header("Content-Type", "text/plain"); + } + + if (res.body.empty() && !res.content_length_ && !res.content_provider_ && + !res.has_header("Content-Length")) { + res.set_header("Content-Length", "0"); + } + + if (req.method == "HEAD" && !res.has_header("Accept-Ranges")) { + res.set_header("Accept-Ranges", "bytes"); + } + + if (post_routing_handler_) { post_routing_handler_(req, res); } + + // Response line and headers + { + detail::BufferStream bstrm; + if (!detail::write_response_line(bstrm, res.status)) { return false; } + if (!header_writer_(bstrm, res.headers)) { return false; } + + // Flush buffer + auto &data = bstrm.get_buffer(); + detail::write_data(strm, data.data(), data.size()); + } + + // Body + auto ret = true; + if (req.method != "HEAD") { + if (!res.body.empty()) { + if (!detail::write_data(strm, res.body.data(), res.body.size())) { + ret = false; + } + } else if (res.content_provider_) { + if (write_content_with_provider(strm, req, res, boundary, content_type)) { + res.content_provider_success_ = true; + } else { + ret = false; + } + } + } + + // Log + if (logger_) { logger_(req, res); } + + return ret; +} + +inline bool +Server::write_content_with_provider(Stream &strm, const Request &req, + Response &res, const std::string &boundary, + const std::string &content_type) { + auto is_shutting_down = [this]() { + return this->svr_sock_ == INVALID_SOCKET; + }; + + if (res.content_length_ > 0) { + if (req.ranges.empty()) { + return detail::write_content(strm, res.content_provider_, 0, + res.content_length_, is_shutting_down); + } else if (req.ranges.size() == 1) { + auto offset_and_length = detail::get_range_offset_and_length( + req.ranges[0], res.content_length_); + + return detail::write_content(strm, res.content_provider_, + offset_and_length.first, + offset_and_length.second, is_shutting_down); + } else { + return detail::write_multipart_ranges_data( + strm, req, res, boundary, content_type, res.content_length_, + is_shutting_down); + } + } else { + if (res.is_chunked_content_provider_) { + auto type = detail::encoding_type(req, res); + + std::unique_ptr compressor; + if (type == detail::EncodingType::Gzip) { +#ifdef CPPHTTPLIB_ZLIB_SUPPORT + compressor = detail::make_unique(); +#endif + } else if (type == detail::EncodingType::Brotli) { +#ifdef CPPHTTPLIB_BROTLI_SUPPORT + compressor = detail::make_unique(); +#endif + } else { + compressor = detail::make_unique(); + } + assert(compressor != nullptr); + + return detail::write_content_chunked(strm, res.content_provider_, + is_shutting_down, *compressor); + } else { + return detail::write_content_without_length(strm, res.content_provider_, + is_shutting_down); + } + } +} + +inline bool Server::read_content(Stream &strm, Request &req, Response &res) { + MultipartFormDataMap::iterator cur; + auto file_count = 0; + if (read_content_core( + strm, req, res, + // Regular + [&](const char *buf, size_t n) { + if (req.body.size() + n > req.body.max_size()) { return false; } + req.body.append(buf, n); + return true; + }, + // Multipart + [&](const MultipartFormData &file) { + if (file_count++ == CPPHTTPLIB_MULTIPART_FORM_DATA_FILE_MAX_COUNT) { + return false; + } + cur = req.files.emplace(file.name, file); + return true; + }, + [&](const char *buf, size_t n) { + auto &content = cur->second.content; + if (content.size() + n > content.max_size()) { return false; } + content.append(buf, n); + return true; + })) { + const auto &content_type = req.get_header_value("Content-Type"); + if (!content_type.find("application/x-www-form-urlencoded")) { + if (req.body.size() > CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH) { + res.status = StatusCode::PayloadTooLarge_413; // NOTE: should be 414? + return false; + } + detail::parse_query_text(req.body, req.params); + } + return true; + } + return false; +} + +inline bool Server::read_content_with_content_receiver( + Stream &strm, Request &req, Response &res, ContentReceiver receiver, + MultipartContentHeader multipart_header, + ContentReceiver multipart_receiver) { + return read_content_core(strm, req, res, std::move(receiver), + std::move(multipart_header), + std::move(multipart_receiver)); +} + +inline bool +Server::read_content_core(Stream &strm, Request &req, Response &res, + ContentReceiver receiver, + MultipartContentHeader multipart_header, + ContentReceiver multipart_receiver) const { + detail::MultipartFormDataParser multipart_form_data_parser; + ContentReceiverWithProgress out; + + if (req.is_multipart_form_data()) { + const auto &content_type = req.get_header_value("Content-Type"); + std::string boundary; + if (!detail::parse_multipart_boundary(content_type, boundary)) { + res.status = StatusCode::BadRequest_400; + return false; + } + + multipart_form_data_parser.set_boundary(std::move(boundary)); + out = [&](const char *buf, size_t n, uint64_t /*off*/, uint64_t /*len*/) { + /* For debug + size_t pos = 0; + while (pos < n) { + auto read_size = (std::min)(1, n - pos); + auto ret = multipart_form_data_parser.parse( + buf + pos, read_size, multipart_receiver, multipart_header); + if (!ret) { return false; } + pos += read_size; + } + return true; + */ + return multipart_form_data_parser.parse(buf, n, multipart_receiver, + multipart_header); + }; + } else { + out = [receiver](const char *buf, size_t n, uint64_t /*off*/, + uint64_t /*len*/) { return receiver(buf, n); }; + } + + if (req.method == "DELETE" && !req.has_header("Content-Length")) { + return true; + } + + if (!detail::read_content(strm, req, payload_max_length_, res.status, nullptr, + out, true)) { + return false; + } + + if (req.is_multipart_form_data()) { + if (!multipart_form_data_parser.is_valid()) { + res.status = StatusCode::BadRequest_400; + return false; + } + } + + return true; +} + +inline bool Server::handle_file_request(const Request &req, Response &res, + bool head) { + for (const auto &entry : base_dirs_) { + // Prefix match + if (!req.path.compare(0, entry.mount_point.size(), entry.mount_point)) { + std::string sub_path = "/" + req.path.substr(entry.mount_point.size()); + if (detail::is_valid_path(sub_path)) { + auto path = entry.base_dir + sub_path; + if (path.back() == '/') { path += "index.html"; } + + detail::FileStat stat(path); + + if (stat.is_dir()) { + res.set_redirect(sub_path + "/", StatusCode::MovedPermanently_301); + return true; + } + + if (stat.is_file()) { + for (const auto &kv : entry.headers) { + res.set_header(kv.first, kv.second); + } + + auto mm = std::make_shared(path.c_str()); + if (!mm->is_open()) { return false; } + + res.set_content_provider( + mm->size(), + detail::find_content_type(path, file_extension_and_mimetype_map_, + default_file_mimetype_), + [mm](size_t offset, size_t length, DataSink &sink) -> bool { + sink.write(mm->data() + offset, length); + return true; + }); + + if (!head && file_request_handler_) { + file_request_handler_(req, res); + } + + return true; + } + } + } + } + return false; +} + +inline socket_t +Server::create_server_socket(const std::string &host, int port, + int socket_flags, + SocketOptions socket_options) const { + return detail::create_socket( + host, std::string(), port, address_family_, socket_flags, tcp_nodelay_, + ipv6_v6only_, std::move(socket_options), + [](socket_t sock, struct addrinfo &ai, bool & /*quit*/) -> bool { + if (::bind(sock, ai.ai_addr, static_cast(ai.ai_addrlen))) { + return false; + } + if (::listen(sock, CPPHTTPLIB_LISTEN_BACKLOG)) { return false; } + return true; + }); +} + +inline int Server::bind_internal(const std::string &host, int port, + int socket_flags) { + if (is_decommisioned) { return -1; } + + if (!is_valid()) { return -1; } + + svr_sock_ = create_server_socket(host, port, socket_flags, socket_options_); + if (svr_sock_ == INVALID_SOCKET) { return -1; } + + if (port == 0) { + struct sockaddr_storage addr; + socklen_t addr_len = sizeof(addr); + if (getsockname(svr_sock_, reinterpret_cast(&addr), + &addr_len) == -1) { + return -1; + } + if (addr.ss_family == AF_INET) { + return ntohs(reinterpret_cast(&addr)->sin_port); + } else if (addr.ss_family == AF_INET6) { + return ntohs(reinterpret_cast(&addr)->sin6_port); + } else { + return -1; + } + } else { + return port; + } +} + +inline bool Server::listen_internal() { + if (is_decommisioned) { return false; } + + auto ret = true; + is_running_ = true; + auto se = detail::scope_exit([&]() { is_running_ = false; }); + + { + std::unique_ptr task_queue(new_task_queue()); + + while (svr_sock_ != INVALID_SOCKET) { +#ifndef _WIN32 + if (idle_interval_sec_ > 0 || idle_interval_usec_ > 0) { +#endif + auto val = detail::select_read(svr_sock_, idle_interval_sec_, + idle_interval_usec_); + if (val == 0) { // Timeout + task_queue->on_idle(); + continue; + } +#ifndef _WIN32 + } +#endif + +#if defined _WIN32 + // sockets conneced via WASAccept inherit flags NO_HANDLE_INHERIT, + // OVERLAPPED + socket_t sock = WSAAccept(svr_sock_, nullptr, nullptr, nullptr, 0); +#elif defined SOCK_CLOEXEC + socket_t sock = accept4(svr_sock_, nullptr, nullptr, SOCK_CLOEXEC); +#else + socket_t sock = accept(svr_sock_, nullptr, nullptr); +#endif + + if (sock == INVALID_SOCKET) { + if (errno == EMFILE) { + // The per-process limit of open file descriptors has been reached. + // Try to accept new connections after a short sleep. + std::this_thread::sleep_for(std::chrono::microseconds{1}); + continue; + } else if (errno == EINTR || errno == EAGAIN) { + continue; + } + if (svr_sock_ != INVALID_SOCKET) { + detail::close_socket(svr_sock_); + ret = false; + } else { + ; // The server socket was closed by user. + } + break; + } + + { +#ifdef _WIN32 + auto timeout = static_cast(read_timeout_sec_ * 1000 + + read_timeout_usec_ / 1000); + setsockopt(sock, SOL_SOCKET, SO_RCVTIMEO, + reinterpret_cast(&timeout), sizeof(timeout)); +#else + timeval tv; + tv.tv_sec = static_cast(read_timeout_sec_); + tv.tv_usec = static_cast(read_timeout_usec_); + setsockopt(sock, SOL_SOCKET, SO_RCVTIMEO, + reinterpret_cast(&tv), sizeof(tv)); +#endif + } + { + +#ifdef _WIN32 + auto timeout = static_cast(write_timeout_sec_ * 1000 + + write_timeout_usec_ / 1000); + setsockopt(sock, SOL_SOCKET, SO_SNDTIMEO, + reinterpret_cast(&timeout), sizeof(timeout)); +#else + timeval tv; + tv.tv_sec = static_cast(write_timeout_sec_); + tv.tv_usec = static_cast(write_timeout_usec_); + setsockopt(sock, SOL_SOCKET, SO_SNDTIMEO, + reinterpret_cast(&tv), sizeof(tv)); +#endif + } + + if (!task_queue->enqueue( + [this, sock]() { process_and_close_socket(sock); })) { + detail::shutdown_socket(sock); + detail::close_socket(sock); + } + } + + task_queue->shutdown(); + } + + is_decommisioned = !ret; + return ret; +} + +inline bool Server::routing(Request &req, Response &res, Stream &strm) { + if (pre_routing_handler_ && + pre_routing_handler_(req, res) == HandlerResponse::Handled) { + return true; + } + + // File handler + auto is_head_request = req.method == "HEAD"; + if ((req.method == "GET" || is_head_request) && + handle_file_request(req, res, is_head_request)) { + return true; + } + + if (detail::expect_content(req)) { + // Content reader handler + { + ContentReader reader( + [&](ContentReceiver receiver) { + return read_content_with_content_receiver( + strm, req, res, std::move(receiver), nullptr, nullptr); + }, + [&](MultipartContentHeader header, ContentReceiver receiver) { + return read_content_with_content_receiver(strm, req, res, nullptr, + std::move(header), + std::move(receiver)); + }); + + if (req.method == "POST") { + if (dispatch_request_for_content_reader( + req, res, std::move(reader), + post_handlers_for_content_reader_)) { + return true; + } + } else if (req.method == "PUT") { + if (dispatch_request_for_content_reader( + req, res, std::move(reader), + put_handlers_for_content_reader_)) { + return true; + } + } else if (req.method == "PATCH") { + if (dispatch_request_for_content_reader( + req, res, std::move(reader), + patch_handlers_for_content_reader_)) { + return true; + } + } else if (req.method == "DELETE") { + if (dispatch_request_for_content_reader( + req, res, std::move(reader), + delete_handlers_for_content_reader_)) { + return true; + } + } + } + + // Read content into `req.body` + if (!read_content(strm, req, res)) { return false; } + } + + // Regular handler + if (req.method == "GET" || req.method == "HEAD") { + return dispatch_request(req, res, get_handlers_); + } else if (req.method == "POST") { + return dispatch_request(req, res, post_handlers_); + } else if (req.method == "PUT") { + return dispatch_request(req, res, put_handlers_); + } else if (req.method == "DELETE") { + return dispatch_request(req, res, delete_handlers_); + } else if (req.method == "OPTIONS") { + return dispatch_request(req, res, options_handlers_); + } else if (req.method == "PATCH") { + return dispatch_request(req, res, patch_handlers_); + } + + res.status = StatusCode::BadRequest_400; + return false; +} + +inline bool Server::dispatch_request(Request &req, Response &res, + const Handlers &handlers) const { + for (const auto &x : handlers) { + const auto &matcher = x.first; + const auto &handler = x.second; + + if (matcher->match(req)) { + handler(req, res); + return true; + } + } + return false; +} + +inline void Server::apply_ranges(const Request &req, Response &res, + std::string &content_type, + std::string &boundary) const { + if (req.ranges.size() > 1 && res.status == StatusCode::PartialContent_206) { + auto it = res.headers.find("Content-Type"); + if (it != res.headers.end()) { + content_type = it->second; + res.headers.erase(it); + } + + boundary = detail::make_multipart_data_boundary(); + + res.set_header("Content-Type", + "multipart/byteranges; boundary=" + boundary); + } + + auto type = detail::encoding_type(req, res); + + if (res.body.empty()) { + if (res.content_length_ > 0) { + size_t length = 0; + if (req.ranges.empty() || res.status != StatusCode::PartialContent_206) { + length = res.content_length_; + } else if (req.ranges.size() == 1) { + auto offset_and_length = detail::get_range_offset_and_length( + req.ranges[0], res.content_length_); + + length = offset_and_length.second; + + auto content_range = detail::make_content_range_header_field( + offset_and_length, res.content_length_); + res.set_header("Content-Range", content_range); + } else { + length = detail::get_multipart_ranges_data_length( + req, boundary, content_type, res.content_length_); + } + res.set_header("Content-Length", std::to_string(length)); + } else { + if (res.content_provider_) { + if (res.is_chunked_content_provider_) { + res.set_header("Transfer-Encoding", "chunked"); + if (type == detail::EncodingType::Gzip) { + res.set_header("Content-Encoding", "gzip"); + } else if (type == detail::EncodingType::Brotli) { + res.set_header("Content-Encoding", "br"); + } + } + } + } + } else { + if (req.ranges.empty() || res.status != StatusCode::PartialContent_206) { + ; + } else if (req.ranges.size() == 1) { + auto offset_and_length = + detail::get_range_offset_and_length(req.ranges[0], res.body.size()); + auto offset = offset_and_length.first; + auto length = offset_and_length.second; + + auto content_range = detail::make_content_range_header_field( + offset_and_length, res.body.size()); + res.set_header("Content-Range", content_range); + + assert(offset + length <= res.body.size()); + res.body = res.body.substr(offset, length); + } else { + std::string data; + detail::make_multipart_ranges_data(req, res, boundary, content_type, + res.body.size(), data); + res.body.swap(data); + } + + if (type != detail::EncodingType::None) { + std::unique_ptr compressor; + std::string content_encoding; + + if (type == detail::EncodingType::Gzip) { +#ifdef CPPHTTPLIB_ZLIB_SUPPORT + compressor = detail::make_unique(); + content_encoding = "gzip"; +#endif + } else if (type == detail::EncodingType::Brotli) { +#ifdef CPPHTTPLIB_BROTLI_SUPPORT + compressor = detail::make_unique(); + content_encoding = "br"; +#endif + } + + if (compressor) { + std::string compressed; + if (compressor->compress(res.body.data(), res.body.size(), true, + [&](const char *data, size_t data_len) { + compressed.append(data, data_len); + return true; + })) { + res.body.swap(compressed); + res.set_header("Content-Encoding", content_encoding); + } + } + } + + auto length = std::to_string(res.body.size()); + res.set_header("Content-Length", length); + } +} + +inline bool Server::dispatch_request_for_content_reader( + Request &req, Response &res, ContentReader content_reader, + const HandlersForContentReader &handlers) const { + for (const auto &x : handlers) { + const auto &matcher = x.first; + const auto &handler = x.second; + + if (matcher->match(req)) { + handler(req, res, content_reader); + return true; + } + } + return false; +} + +inline bool +Server::process_request(Stream &strm, const std::string &remote_addr, + int remote_port, const std::string &local_addr, + int local_port, bool close_connection, + bool &connection_closed, + const std::function &setup_request) { + std::array buf{}; + + detail::stream_line_reader line_reader(strm, buf.data(), buf.size()); + + // Connection has been closed on client + if (!line_reader.getline()) { return false; } + + Request req; + + Response res; + res.version = "HTTP/1.1"; + res.headers = default_headers_; + +#ifdef _WIN32 + // TODO: Increase FD_SETSIZE statically (libzmq), dynamically (MySQL). +#else +#ifndef CPPHTTPLIB_USE_POLL + // Socket file descriptor exceeded FD_SETSIZE... + if (strm.socket() >= FD_SETSIZE) { + Headers dummy; + detail::read_headers(strm, dummy); + res.status = StatusCode::InternalServerError_500; + return write_response(strm, close_connection, req, res); + } +#endif +#endif + + // Check if the request URI doesn't exceed the limit + if (line_reader.size() > CPPHTTPLIB_REQUEST_URI_MAX_LENGTH) { + Headers dummy; + detail::read_headers(strm, dummy); + res.status = StatusCode::UriTooLong_414; + return write_response(strm, close_connection, req, res); + } + + // Request line and headers + if (!parse_request_line(line_reader.ptr(), req) || + !detail::read_headers(strm, req.headers)) { + res.status = StatusCode::BadRequest_400; + return write_response(strm, close_connection, req, res); + } + + if (req.get_header_value("Connection") == "close") { + connection_closed = true; + } + + if (req.version == "HTTP/1.0" && + req.get_header_value("Connection") != "Keep-Alive") { + connection_closed = true; + } + + req.remote_addr = remote_addr; + req.remote_port = remote_port; + req.set_header("REMOTE_ADDR", req.remote_addr); + req.set_header("REMOTE_PORT", std::to_string(req.remote_port)); + + req.local_addr = local_addr; + req.local_port = local_port; + req.set_header("LOCAL_ADDR", req.local_addr); + req.set_header("LOCAL_PORT", std::to_string(req.local_port)); + + if (req.has_header("Range")) { + const auto &range_header_value = req.get_header_value("Range"); + if (!detail::parse_range_header(range_header_value, req.ranges)) { + res.status = StatusCode::RangeNotSatisfiable_416; + return write_response(strm, close_connection, req, res); + } + } + + if (setup_request) { setup_request(req); } + + if (req.get_header_value("Expect") == "100-continue") { + int status = StatusCode::Continue_100; + if (expect_100_continue_handler_) { + status = expect_100_continue_handler_(req, res); + } + switch (status) { + case StatusCode::Continue_100: + case StatusCode::ExpectationFailed_417: + detail::write_response_line(strm, status); + strm.write("\r\n"); + break; + default: + connection_closed = true; + return write_response(strm, true, req, res); + } + } + + // Setup `is_connection_closed` method + req.is_connection_closed = [&]() { + return !detail::is_socket_alive(strm.socket()); + }; + + // Routing + auto routed = false; +#ifdef CPPHTTPLIB_NO_EXCEPTIONS + routed = routing(req, res, strm); +#else + try { + routed = routing(req, res, strm); + } catch (std::exception &e) { + if (exception_handler_) { + auto ep = std::current_exception(); + exception_handler_(req, res, ep); + routed = true; + } else { + res.status = StatusCode::InternalServerError_500; + std::string val; + auto s = e.what(); + for (size_t i = 0; s[i]; i++) { + switch (s[i]) { + case '\r': val += "\\r"; break; + case '\n': val += "\\n"; break; + default: val += s[i]; break; + } + } + res.set_header("EXCEPTION_WHAT", val); + } + } catch (...) { + if (exception_handler_) { + auto ep = std::current_exception(); + exception_handler_(req, res, ep); + routed = true; + } else { + res.status = StatusCode::InternalServerError_500; + res.set_header("EXCEPTION_WHAT", "UNKNOWN"); + } + } +#endif + if (routed) { + if (res.status == -1) { + res.status = req.ranges.empty() ? StatusCode::OK_200 + : StatusCode::PartialContent_206; + } + + // Serve file content by using a content provider + if (!res.file_content_path_.empty()) { + const auto &path = res.file_content_path_; + auto mm = std::make_shared(path.c_str()); + if (!mm->is_open()) { + res.body.clear(); + res.content_length_ = 0; + res.content_provider_ = nullptr; + res.status = StatusCode::NotFound_404; + return write_response(strm, close_connection, req, res); + } + + auto content_type = res.file_content_content_type_; + if (content_type.empty()) { + content_type = detail::find_content_type( + path, file_extension_and_mimetype_map_, default_file_mimetype_); + } + + res.set_content_provider( + mm->size(), content_type, + [mm](size_t offset, size_t length, DataSink &sink) -> bool { + sink.write(mm->data() + offset, length); + return true; + }); + } + + if (detail::range_error(req, res)) { + res.body.clear(); + res.content_length_ = 0; + res.content_provider_ = nullptr; + res.status = StatusCode::RangeNotSatisfiable_416; + return write_response(strm, close_connection, req, res); + } + + return write_response_with_content(strm, close_connection, req, res); + } else { + if (res.status == -1) { res.status = StatusCode::NotFound_404; } + + return write_response(strm, close_connection, req, res); + } +} + +inline bool Server::is_valid() const { return true; } + +inline bool Server::process_and_close_socket(socket_t sock) { + std::string remote_addr; + int remote_port = 0; + detail::get_remote_ip_and_port(sock, remote_addr, remote_port); + + std::string local_addr; + int local_port = 0; + detail::get_local_ip_and_port(sock, local_addr, local_port); + + auto ret = detail::process_server_socket( + svr_sock_, sock, keep_alive_max_count_, keep_alive_timeout_sec_, + read_timeout_sec_, read_timeout_usec_, write_timeout_sec_, + write_timeout_usec_, + [&](Stream &strm, bool close_connection, bool &connection_closed) { + return process_request(strm, remote_addr, remote_port, local_addr, + local_port, close_connection, connection_closed, + nullptr); + }); + + detail::shutdown_socket(sock); + detail::close_socket(sock); + return ret; +} + +// HTTP client implementation +inline ClientImpl::ClientImpl(const std::string &host) + : ClientImpl(host, 80, std::string(), std::string()) {} + +inline ClientImpl::ClientImpl(const std::string &host, int port) + : ClientImpl(host, port, std::string(), std::string()) {} + +inline ClientImpl::ClientImpl(const std::string &host, int port, + const std::string &client_cert_path, + const std::string &client_key_path) + : host_(detail::escape_abstract_namespace_unix_domain(host)), port_(port), + host_and_port_(adjust_host_string(host_) + ":" + std::to_string(port)), + client_cert_path_(client_cert_path), client_key_path_(client_key_path) {} + +inline ClientImpl::~ClientImpl() { + std::lock_guard guard(socket_mutex_); + shutdown_socket(socket_); + close_socket(socket_); +} + +inline bool ClientImpl::is_valid() const { return true; } + +inline void ClientImpl::copy_settings(const ClientImpl &rhs) { + client_cert_path_ = rhs.client_cert_path_; + client_key_path_ = rhs.client_key_path_; + connection_timeout_sec_ = rhs.connection_timeout_sec_; + read_timeout_sec_ = rhs.read_timeout_sec_; + read_timeout_usec_ = rhs.read_timeout_usec_; + write_timeout_sec_ = rhs.write_timeout_sec_; + write_timeout_usec_ = rhs.write_timeout_usec_; + basic_auth_username_ = rhs.basic_auth_username_; + basic_auth_password_ = rhs.basic_auth_password_; + bearer_token_auth_token_ = rhs.bearer_token_auth_token_; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + digest_auth_username_ = rhs.digest_auth_username_; + digest_auth_password_ = rhs.digest_auth_password_; +#endif + keep_alive_ = rhs.keep_alive_; + follow_location_ = rhs.follow_location_; + url_encode_ = rhs.url_encode_; + address_family_ = rhs.address_family_; + tcp_nodelay_ = rhs.tcp_nodelay_; + ipv6_v6only_ = rhs.ipv6_v6only_; + socket_options_ = rhs.socket_options_; + compress_ = rhs.compress_; + decompress_ = rhs.decompress_; + interface_ = rhs.interface_; + proxy_host_ = rhs.proxy_host_; + proxy_port_ = rhs.proxy_port_; + proxy_basic_auth_username_ = rhs.proxy_basic_auth_username_; + proxy_basic_auth_password_ = rhs.proxy_basic_auth_password_; + proxy_bearer_token_auth_token_ = rhs.proxy_bearer_token_auth_token_; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + proxy_digest_auth_username_ = rhs.proxy_digest_auth_username_; + proxy_digest_auth_password_ = rhs.proxy_digest_auth_password_; +#endif +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + ca_cert_file_path_ = rhs.ca_cert_file_path_; + ca_cert_dir_path_ = rhs.ca_cert_dir_path_; + ca_cert_store_ = rhs.ca_cert_store_; +#endif +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + server_certificate_verification_ = rhs.server_certificate_verification_; + server_hostname_verification_ = rhs.server_hostname_verification_; + server_certificate_verifier_ = rhs.server_certificate_verifier_; +#endif + logger_ = rhs.logger_; +} + +inline socket_t ClientImpl::create_client_socket(Error &error) const { + if (!proxy_host_.empty() && proxy_port_ != -1) { + return detail::create_client_socket( + proxy_host_, std::string(), proxy_port_, address_family_, tcp_nodelay_, + ipv6_v6only_, socket_options_, connection_timeout_sec_, + connection_timeout_usec_, read_timeout_sec_, read_timeout_usec_, + write_timeout_sec_, write_timeout_usec_, interface_, error); + } + + // Check is custom IP specified for host_ + std::string ip; + auto it = addr_map_.find(host_); + if (it != addr_map_.end()) { ip = it->second; } + + return detail::create_client_socket( + host_, ip, port_, address_family_, tcp_nodelay_, ipv6_v6only_, + socket_options_, connection_timeout_sec_, connection_timeout_usec_, + read_timeout_sec_, read_timeout_usec_, write_timeout_sec_, + write_timeout_usec_, interface_, error); +} + +inline bool ClientImpl::create_and_connect_socket(Socket &socket, + Error &error) { + auto sock = create_client_socket(error); + if (sock == INVALID_SOCKET) { return false; } + socket.sock = sock; + return true; +} + +inline void ClientImpl::shutdown_ssl(Socket & /*socket*/, + bool /*shutdown_gracefully*/) { + // If there are any requests in flight from threads other than us, then it's + // a thread-unsafe race because individual ssl* objects are not thread-safe. + assert(socket_requests_in_flight_ == 0 || + socket_requests_are_from_thread_ == std::this_thread::get_id()); +} + +inline void ClientImpl::shutdown_socket(Socket &socket) const { + if (socket.sock == INVALID_SOCKET) { return; } + detail::shutdown_socket(socket.sock); +} + +inline void ClientImpl::close_socket(Socket &socket) { + // If there are requests in flight in another thread, usually closing + // the socket will be fine and they will simply receive an error when + // using the closed socket, but it is still a bug since rarely the OS + // may reassign the socket id to be used for a new socket, and then + // suddenly they will be operating on a live socket that is different + // than the one they intended! + assert(socket_requests_in_flight_ == 0 || + socket_requests_are_from_thread_ == std::this_thread::get_id()); + + // It is also a bug if this happens while SSL is still active +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + assert(socket.ssl == nullptr); +#endif + if (socket.sock == INVALID_SOCKET) { return; } + detail::close_socket(socket.sock); + socket.sock = INVALID_SOCKET; +} + +inline bool ClientImpl::read_response_line(Stream &strm, const Request &req, + Response &res) const { + std::array buf{}; + + detail::stream_line_reader line_reader(strm, buf.data(), buf.size()); + + if (!line_reader.getline()) { return false; } + +#ifdef CPPHTTPLIB_ALLOW_LF_AS_LINE_TERMINATOR + const static std::regex re("(HTTP/1\\.[01]) (\\d{3})(?: (.*?))?\r?\n"); +#else + const static std::regex re("(HTTP/1\\.[01]) (\\d{3})(?: (.*?))?\r\n"); +#endif + + std::cmatch m; + if (!std::regex_match(line_reader.ptr(), m, re)) { + return req.method == "CONNECT"; + } + res.version = std::string(m[1]); + res.status = std::stoi(std::string(m[2])); + res.reason = std::string(m[3]); + + // Ignore '100 Continue' + while (res.status == StatusCode::Continue_100) { + if (!line_reader.getline()) { return false; } // CRLF + if (!line_reader.getline()) { return false; } // next response line + + if (!std::regex_match(line_reader.ptr(), m, re)) { return false; } + res.version = std::string(m[1]); + res.status = std::stoi(std::string(m[2])); + res.reason = std::string(m[3]); + } + + return true; +} + +inline bool ClientImpl::send(Request &req, Response &res, Error &error) { + std::lock_guard request_mutex_guard(request_mutex_); + auto ret = send_(req, res, error); + if (error == Error::SSLPeerCouldBeClosed_) { + assert(!ret); + ret = send_(req, res, error); + } + return ret; +} + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +inline bool ClientImpl::is_ssl_peer_could_be_closed(SSL *ssl) const { + detail::set_nonblocking(socket_.sock, true); + auto se = detail::scope_exit( + [&]() { detail::set_nonblocking(socket_.sock, false); }); + + char buf[1]; + return !SSL_peek(ssl, buf, 1) && + SSL_get_error(ssl, 0) == SSL_ERROR_ZERO_RETURN; +} +#endif + +inline bool ClientImpl::send_(Request &req, Response &res, Error &error) { + { + std::lock_guard guard(socket_mutex_); + + // Set this to false immediately - if it ever gets set to true by the end of + // the request, we know another thread instructed us to close the socket. + socket_should_be_closed_when_request_is_done_ = false; + + auto is_alive = false; + if (socket_.is_open()) { + is_alive = detail::is_socket_alive(socket_.sock); + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + if (is_alive && is_ssl()) { + if (is_ssl_peer_could_be_closed(socket_.ssl)) { is_alive = false; } + } +#endif + + if (!is_alive) { + // Attempt to avoid sigpipe by shutting down nongracefully if it seems + // like the other side has already closed the connection Also, there + // cannot be any requests in flight from other threads since we locked + // request_mutex_, so safe to close everything immediately + const bool shutdown_gracefully = false; + shutdown_ssl(socket_, shutdown_gracefully); + shutdown_socket(socket_); + close_socket(socket_); + } + } + + if (!is_alive) { + if (!create_and_connect_socket(socket_, error)) { return false; } + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + // TODO: refactoring + if (is_ssl()) { + auto &scli = static_cast(*this); + if (!proxy_host_.empty() && proxy_port_ != -1) { + auto success = false; + if (!scli.connect_with_proxy(socket_, res, success, error)) { + return success; + } + } + + if (!scli.initialize_ssl(socket_, error)) { return false; } + } +#endif + } + + // Mark the current socket as being in use so that it cannot be closed by + // anyone else while this request is ongoing, even though we will be + // releasing the mutex. + if (socket_requests_in_flight_ > 1) { + assert(socket_requests_are_from_thread_ == std::this_thread::get_id()); + } + socket_requests_in_flight_ += 1; + socket_requests_are_from_thread_ = std::this_thread::get_id(); + } + + for (const auto &header : default_headers_) { + if (req.headers.find(header.first) == req.headers.end()) { + req.headers.insert(header); + } + } + + auto ret = false; + auto close_connection = !keep_alive_; + + auto se = detail::scope_exit([&]() { + // Briefly lock mutex in order to mark that a request is no longer ongoing + std::lock_guard guard(socket_mutex_); + socket_requests_in_flight_ -= 1; + if (socket_requests_in_flight_ <= 0) { + assert(socket_requests_in_flight_ == 0); + socket_requests_are_from_thread_ = std::thread::id(); + } + + if (socket_should_be_closed_when_request_is_done_ || close_connection || + !ret) { + shutdown_ssl(socket_, true); + shutdown_socket(socket_); + close_socket(socket_); + } + }); + + ret = process_socket(socket_, [&](Stream &strm) { + return handle_request(strm, req, res, close_connection, error); + }); + + if (!ret) { + if (error == Error::Success) { error = Error::Unknown; } + } + + return ret; +} + +inline Result ClientImpl::send(const Request &req) { + auto req2 = req; + return send_(std::move(req2)); +} + +inline Result ClientImpl::send_(Request &&req) { + auto res = detail::make_unique(); + auto error = Error::Success; + auto ret = send(req, *res, error); + return Result{ret ? std::move(res) : nullptr, error, std::move(req.headers)}; +} + +inline bool ClientImpl::handle_request(Stream &strm, Request &req, + Response &res, bool close_connection, + Error &error) { + if (req.path.empty()) { + error = Error::Connection; + return false; + } + + auto req_save = req; + + bool ret; + + if (!is_ssl() && !proxy_host_.empty() && proxy_port_ != -1) { + auto req2 = req; + req2.path = "http://" + host_and_port_ + req.path; + ret = process_request(strm, req2, res, close_connection, error); + req = req2; + req.path = req_save.path; + } else { + ret = process_request(strm, req, res, close_connection, error); + } + + if (!ret) { return false; } + + if (res.get_header_value("Connection") == "close" || + (res.version == "HTTP/1.0" && res.reason != "Connection established")) { + // TODO this requires a not-entirely-obvious chain of calls to be correct + // for this to be safe. + + // This is safe to call because handle_request is only called by send_ + // which locks the request mutex during the process. It would be a bug + // to call it from a different thread since it's a thread-safety issue + // to do these things to the socket if another thread is using the socket. + std::lock_guard guard(socket_mutex_); + shutdown_ssl(socket_, true); + shutdown_socket(socket_); + close_socket(socket_); + } + + if (300 < res.status && res.status < 400 && follow_location_) { + req = req_save; + ret = redirect(req, res, error); + } + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + if ((res.status == StatusCode::Unauthorized_401 || + res.status == StatusCode::ProxyAuthenticationRequired_407) && + req.authorization_count_ < 5) { + auto is_proxy = res.status == StatusCode::ProxyAuthenticationRequired_407; + const auto &username = + is_proxy ? proxy_digest_auth_username_ : digest_auth_username_; + const auto &password = + is_proxy ? proxy_digest_auth_password_ : digest_auth_password_; + + if (!username.empty() && !password.empty()) { + std::map auth; + if (detail::parse_www_authenticate(res, auth, is_proxy)) { + Request new_req = req; + new_req.authorization_count_ += 1; + new_req.headers.erase(is_proxy ? "Proxy-Authorization" + : "Authorization"); + new_req.headers.insert(detail::make_digest_authentication_header( + req, auth, new_req.authorization_count_, detail::random_string(10), + username, password, is_proxy)); + + Response new_res; + + ret = send(new_req, new_res, error); + if (ret) { res = new_res; } + } + } + } +#endif + + return ret; +} + +inline bool ClientImpl::redirect(Request &req, Response &res, Error &error) { + if (req.redirect_count_ == 0) { + error = Error::ExceedRedirectCount; + return false; + } + + auto location = res.get_header_value("location"); + if (location.empty()) { return false; } + + const static std::regex re( + R"((?:(https?):)?(?://(?:\[([a-fA-F\d:]+)\]|([^:/?#]+))(?::(\d+))?)?([^?#]*)(\?[^#]*)?(?:#.*)?)"); + + std::smatch m; + if (!std::regex_match(location, m, re)) { return false; } + + auto scheme = is_ssl() ? "https" : "http"; + + auto next_scheme = m[1].str(); + auto next_host = m[2].str(); + if (next_host.empty()) { next_host = m[3].str(); } + auto port_str = m[4].str(); + auto next_path = m[5].str(); + auto next_query = m[6].str(); + + auto next_port = port_; + if (!port_str.empty()) { + next_port = std::stoi(port_str); + } else if (!next_scheme.empty()) { + next_port = next_scheme == "https" ? 443 : 80; + } + + if (next_scheme.empty()) { next_scheme = scheme; } + if (next_host.empty()) { next_host = host_; } + if (next_path.empty()) { next_path = "/"; } + + auto path = detail::decode_url(next_path, true) + next_query; + + if (next_scheme == scheme && next_host == host_ && next_port == port_) { + return detail::redirect(*this, req, res, path, location, error); + } else { + if (next_scheme == "https") { +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + SSLClient cli(next_host, next_port); + cli.copy_settings(*this); + if (ca_cert_store_) { cli.set_ca_cert_store(ca_cert_store_); } + return detail::redirect(cli, req, res, path, location, error); +#else + return false; +#endif + } else { + ClientImpl cli(next_host, next_port); + cli.copy_settings(*this); + return detail::redirect(cli, req, res, path, location, error); + } + } +} + +inline bool ClientImpl::write_content_with_provider(Stream &strm, + const Request &req, + Error &error) const { + auto is_shutting_down = []() { return false; }; + + if (req.is_chunked_content_provider_) { + // TODO: Brotli support + std::unique_ptr compressor; +#ifdef CPPHTTPLIB_ZLIB_SUPPORT + if (compress_) { + compressor = detail::make_unique(); + } else +#endif + { + compressor = detail::make_unique(); + } + + return detail::write_content_chunked(strm, req.content_provider_, + is_shutting_down, *compressor, error); + } else { + return detail::write_content(strm, req.content_provider_, 0, + req.content_length_, is_shutting_down, error); + } +} + +inline bool ClientImpl::write_request(Stream &strm, Request &req, + bool close_connection, Error &error) { + // Prepare additional headers + if (close_connection) { + if (!req.has_header("Connection")) { + req.set_header("Connection", "close"); + } + } + + if (!req.has_header("Host")) { + if (is_ssl()) { + if (port_ == 443) { + req.set_header("Host", host_); + } else { + req.set_header("Host", host_and_port_); + } + } else { + if (port_ == 80) { + req.set_header("Host", host_); + } else { + req.set_header("Host", host_and_port_); + } + } + } + + if (!req.has_header("Accept")) { req.set_header("Accept", "*/*"); } + + if (!req.content_receiver) { + if (!req.has_header("Accept-Encoding")) { + std::string accept_encoding; +#ifdef CPPHTTPLIB_BROTLI_SUPPORT + accept_encoding = "br"; +#endif +#ifdef CPPHTTPLIB_ZLIB_SUPPORT + if (!accept_encoding.empty()) { accept_encoding += ", "; } + accept_encoding += "gzip, deflate"; +#endif + req.set_header("Accept-Encoding", accept_encoding); + } + +#ifndef CPPHTTPLIB_NO_DEFAULT_USER_AGENT + if (!req.has_header("User-Agent")) { + auto agent = std::string("cpp-httplib/") + CPPHTTPLIB_VERSION; + req.set_header("User-Agent", agent); + } +#endif + }; + + if (req.body.empty()) { + if (req.content_provider_) { + if (!req.is_chunked_content_provider_) { + if (!req.has_header("Content-Length")) { + auto length = std::to_string(req.content_length_); + req.set_header("Content-Length", length); + } + } + } else { + if (req.method == "POST" || req.method == "PUT" || + req.method == "PATCH") { + req.set_header("Content-Length", "0"); + } + } + } else { + if (!req.has_header("Content-Type")) { + req.set_header("Content-Type", "text/plain"); + } + + if (!req.has_header("Content-Length")) { + auto length = std::to_string(req.body.size()); + req.set_header("Content-Length", length); + } + } + + if (!basic_auth_password_.empty() || !basic_auth_username_.empty()) { + if (!req.has_header("Authorization")) { + req.headers.insert(make_basic_authentication_header( + basic_auth_username_, basic_auth_password_, false)); + } + } + + if (!proxy_basic_auth_username_.empty() && + !proxy_basic_auth_password_.empty()) { + if (!req.has_header("Proxy-Authorization")) { + req.headers.insert(make_basic_authentication_header( + proxy_basic_auth_username_, proxy_basic_auth_password_, true)); + } + } + + if (!bearer_token_auth_token_.empty()) { + if (!req.has_header("Authorization")) { + req.headers.insert(make_bearer_token_authentication_header( + bearer_token_auth_token_, false)); + } + } + + if (!proxy_bearer_token_auth_token_.empty()) { + if (!req.has_header("Proxy-Authorization")) { + req.headers.insert(make_bearer_token_authentication_header( + proxy_bearer_token_auth_token_, true)); + } + } + + // Request line and headers + { + detail::BufferStream bstrm; + + const auto &path_with_query = + req.params.empty() ? req.path + : append_query_params(req.path, req.params); + + const auto &path = + url_encode_ ? detail::encode_url(path_with_query) : path_with_query; + + detail::write_request_line(bstrm, req.method, path); + + header_writer_(bstrm, req.headers); + + // Flush buffer + auto &data = bstrm.get_buffer(); + if (!detail::write_data(strm, data.data(), data.size())) { + error = Error::Write; + return false; + } + } + + // Body + if (req.body.empty()) { + return write_content_with_provider(strm, req, error); + } + + if (!detail::write_data(strm, req.body.data(), req.body.size())) { + error = Error::Write; + return false; + } + + return true; +} + +inline std::unique_ptr ClientImpl::send_with_content_provider( + Request &req, const char *body, size_t content_length, + ContentProvider content_provider, + ContentProviderWithoutLength content_provider_without_length, + const std::string &content_type, Error &error) { + if (!content_type.empty()) { req.set_header("Content-Type", content_type); } + +#ifdef CPPHTTPLIB_ZLIB_SUPPORT + if (compress_) { req.set_header("Content-Encoding", "gzip"); } +#endif + +#ifdef CPPHTTPLIB_ZLIB_SUPPORT + if (compress_ && !content_provider_without_length) { + // TODO: Brotli support + detail::gzip_compressor compressor; + + if (content_provider) { + auto ok = true; + size_t offset = 0; + DataSink data_sink; + + data_sink.write = [&](const char *data, size_t data_len) -> bool { + if (ok) { + auto last = offset + data_len == content_length; + + auto ret = compressor.compress( + data, data_len, last, + [&](const char *compressed_data, size_t compressed_data_len) { + req.body.append(compressed_data, compressed_data_len); + return true; + }); + + if (ret) { + offset += data_len; + } else { + ok = false; + } + } + return ok; + }; + + while (ok && offset < content_length) { + if (!content_provider(offset, content_length - offset, data_sink)) { + error = Error::Canceled; + return nullptr; + } + } + } else { + if (!compressor.compress(body, content_length, true, + [&](const char *data, size_t data_len) { + req.body.append(data, data_len); + return true; + })) { + error = Error::Compression; + return nullptr; + } + } + } else +#endif + { + if (content_provider) { + req.content_length_ = content_length; + req.content_provider_ = std::move(content_provider); + req.is_chunked_content_provider_ = false; + } else if (content_provider_without_length) { + req.content_length_ = 0; + req.content_provider_ = detail::ContentProviderAdapter( + std::move(content_provider_without_length)); + req.is_chunked_content_provider_ = true; + req.set_header("Transfer-Encoding", "chunked"); + } else { + req.body.assign(body, content_length); + } + } + + auto res = detail::make_unique(); + return send(req, *res, error) ? std::move(res) : nullptr; +} + +inline Result ClientImpl::send_with_content_provider( + const std::string &method, const std::string &path, const Headers &headers, + const char *body, size_t content_length, ContentProvider content_provider, + ContentProviderWithoutLength content_provider_without_length, + const std::string &content_type, Progress progress) { + Request req; + req.method = method; + req.headers = headers; + req.path = path; + req.progress = progress; + + auto error = Error::Success; + + auto res = send_with_content_provider( + req, body, content_length, std::move(content_provider), + std::move(content_provider_without_length), content_type, error); + + return Result{std::move(res), error, std::move(req.headers)}; +} + +inline std::string +ClientImpl::adjust_host_string(const std::string &host) const { + if (host.find(':') != std::string::npos) { return "[" + host + "]"; } + return host; +} + +inline bool ClientImpl::process_request(Stream &strm, Request &req, + Response &res, bool close_connection, + Error &error) { + // Send request + if (!write_request(strm, req, close_connection, error)) { return false; } + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + if (is_ssl()) { + auto is_proxy_enabled = !proxy_host_.empty() && proxy_port_ != -1; + if (!is_proxy_enabled) { + if (is_ssl_peer_could_be_closed(socket_.ssl)) { + error = Error::SSLPeerCouldBeClosed_; + return false; + } + } + } +#endif + + // Receive response and headers + if (!read_response_line(strm, req, res) || + !detail::read_headers(strm, res.headers)) { + error = Error::Read; + return false; + } + + // Body + if ((res.status != StatusCode::NoContent_204) && req.method != "HEAD" && + req.method != "CONNECT") { + auto redirect = 300 < res.status && res.status < 400 && + res.status != StatusCode::NotModified_304 && + follow_location_; + + if (req.response_handler && !redirect) { + if (!req.response_handler(res)) { + error = Error::Canceled; + return false; + } + } + + auto out = + req.content_receiver + ? static_cast( + [&](const char *buf, size_t n, uint64_t off, uint64_t len) { + if (redirect) { return true; } + auto ret = req.content_receiver(buf, n, off, len); + if (!ret) { error = Error::Canceled; } + return ret; + }) + : static_cast( + [&](const char *buf, size_t n, uint64_t /*off*/, + uint64_t /*len*/) { + assert(res.body.size() + n <= res.body.max_size()); + res.body.append(buf, n); + return true; + }); + + auto progress = [&](uint64_t current, uint64_t total) { + if (!req.progress || redirect) { return true; } + auto ret = req.progress(current, total); + if (!ret) { error = Error::Canceled; } + return ret; + }; + + if (res.has_header("Content-Length")) { + if (!req.content_receiver) { + auto len = res.get_header_value_u64("Content-Length"); + if (len > res.body.max_size()) { + error = Error::Read; + return false; + } + res.body.reserve(static_cast(len)); + } + } + + if (res.status != StatusCode::NotModified_304) { + int dummy_status; + if (!detail::read_content(strm, res, (std::numeric_limits::max)(), + dummy_status, std::move(progress), + std::move(out), decompress_)) { + if (error != Error::Canceled) { error = Error::Read; } + return false; + } + } + } + + // Log + if (logger_) { logger_(req, res); } + + return true; +} + +inline ContentProviderWithoutLength ClientImpl::get_multipart_content_provider( + const std::string &boundary, const MultipartFormDataItems &items, + const MultipartFormDataProviderItems &provider_items) const { + size_t cur_item = 0; + size_t cur_start = 0; + // cur_item and cur_start are copied to within the std::function and maintain + // state between successive calls + return [&, cur_item, cur_start](size_t offset, + DataSink &sink) mutable -> bool { + if (!offset && !items.empty()) { + sink.os << detail::serialize_multipart_formdata(items, boundary, false); + return true; + } else if (cur_item < provider_items.size()) { + if (!cur_start) { + const auto &begin = detail::serialize_multipart_formdata_item_begin( + provider_items[cur_item], boundary); + offset += begin.size(); + cur_start = offset; + sink.os << begin; + } + + DataSink cur_sink; + auto has_data = true; + cur_sink.write = sink.write; + cur_sink.done = [&]() { has_data = false; }; + + if (!provider_items[cur_item].provider(offset - cur_start, cur_sink)) { + return false; + } + + if (!has_data) { + sink.os << detail::serialize_multipart_formdata_item_end(); + cur_item++; + cur_start = 0; + } + return true; + } else { + sink.os << detail::serialize_multipart_formdata_finish(boundary); + sink.done(); + return true; + } + }; +} + +inline bool +ClientImpl::process_socket(const Socket &socket, + std::function callback) { + return detail::process_client_socket( + socket.sock, read_timeout_sec_, read_timeout_usec_, write_timeout_sec_, + write_timeout_usec_, std::move(callback)); +} + +inline bool ClientImpl::is_ssl() const { return false; } + +inline Result ClientImpl::Get(const std::string &path) { + return Get(path, Headers(), Progress()); +} + +inline Result ClientImpl::Get(const std::string &path, Progress progress) { + return Get(path, Headers(), std::move(progress)); +} + +inline Result ClientImpl::Get(const std::string &path, const Headers &headers) { + return Get(path, headers, Progress()); +} + +inline Result ClientImpl::Get(const std::string &path, const Headers &headers, + Progress progress) { + Request req; + req.method = "GET"; + req.path = path; + req.headers = headers; + req.progress = std::move(progress); + + return send_(std::move(req)); +} + +inline Result ClientImpl::Get(const std::string &path, + ContentReceiver content_receiver) { + return Get(path, Headers(), nullptr, std::move(content_receiver), nullptr); +} + +inline Result ClientImpl::Get(const std::string &path, + ContentReceiver content_receiver, + Progress progress) { + return Get(path, Headers(), nullptr, std::move(content_receiver), + std::move(progress)); +} + +inline Result ClientImpl::Get(const std::string &path, const Headers &headers, + ContentReceiver content_receiver) { + return Get(path, headers, nullptr, std::move(content_receiver), nullptr); +} + +inline Result ClientImpl::Get(const std::string &path, const Headers &headers, + ContentReceiver content_receiver, + Progress progress) { + return Get(path, headers, nullptr, std::move(content_receiver), + std::move(progress)); +} + +inline Result ClientImpl::Get(const std::string &path, + ResponseHandler response_handler, + ContentReceiver content_receiver) { + return Get(path, Headers(), std::move(response_handler), + std::move(content_receiver), nullptr); +} + +inline Result ClientImpl::Get(const std::string &path, const Headers &headers, + ResponseHandler response_handler, + ContentReceiver content_receiver) { + return Get(path, headers, std::move(response_handler), + std::move(content_receiver), nullptr); +} + +inline Result ClientImpl::Get(const std::string &path, + ResponseHandler response_handler, + ContentReceiver content_receiver, + Progress progress) { + return Get(path, Headers(), std::move(response_handler), + std::move(content_receiver), std::move(progress)); +} + +inline Result ClientImpl::Get(const std::string &path, const Headers &headers, + ResponseHandler response_handler, + ContentReceiver content_receiver, + Progress progress) { + Request req; + req.method = "GET"; + req.path = path; + req.headers = headers; + req.response_handler = std::move(response_handler); + req.content_receiver = + [content_receiver](const char *data, size_t data_length, + uint64_t /*offset*/, uint64_t /*total_length*/) { + return content_receiver(data, data_length); + }; + req.progress = std::move(progress); + + return send_(std::move(req)); +} + +inline Result ClientImpl::Get(const std::string &path, const Params ¶ms, + const Headers &headers, Progress progress) { + if (params.empty()) { return Get(path, headers); } + + std::string path_with_query = append_query_params(path, params); + return Get(path_with_query, headers, std::move(progress)); +} + +inline Result ClientImpl::Get(const std::string &path, const Params ¶ms, + const Headers &headers, + ContentReceiver content_receiver, + Progress progress) { + return Get(path, params, headers, nullptr, std::move(content_receiver), + std::move(progress)); +} + +inline Result ClientImpl::Get(const std::string &path, const Params ¶ms, + const Headers &headers, + ResponseHandler response_handler, + ContentReceiver content_receiver, + Progress progress) { + if (params.empty()) { + return Get(path, headers, std::move(response_handler), + std::move(content_receiver), std::move(progress)); + } + + std::string path_with_query = append_query_params(path, params); + return Get(path_with_query, headers, std::move(response_handler), + std::move(content_receiver), std::move(progress)); +} + +inline Result ClientImpl::Head(const std::string &path) { + return Head(path, Headers()); +} + +inline Result ClientImpl::Head(const std::string &path, + const Headers &headers) { + Request req; + req.method = "HEAD"; + req.headers = headers; + req.path = path; + + return send_(std::move(req)); +} + +inline Result ClientImpl::Post(const std::string &path) { + return Post(path, std::string(), std::string()); +} + +inline Result ClientImpl::Post(const std::string &path, + const Headers &headers) { + return Post(path, headers, nullptr, 0, std::string()); +} + +inline Result ClientImpl::Post(const std::string &path, const char *body, + size_t content_length, + const std::string &content_type) { + return Post(path, Headers(), body, content_length, content_type, nullptr); +} + +inline Result ClientImpl::Post(const std::string &path, const Headers &headers, + const char *body, size_t content_length, + const std::string &content_type) { + return send_with_content_provider("POST", path, headers, body, content_length, + nullptr, nullptr, content_type, nullptr); +} + +inline Result ClientImpl::Post(const std::string &path, const Headers &headers, + const char *body, size_t content_length, + const std::string &content_type, + Progress progress) { + return send_with_content_provider("POST", path, headers, body, content_length, + nullptr, nullptr, content_type, progress); +} + +inline Result ClientImpl::Post(const std::string &path, const std::string &body, + const std::string &content_type) { + return Post(path, Headers(), body, content_type); +} + +inline Result ClientImpl::Post(const std::string &path, const std::string &body, + const std::string &content_type, + Progress progress) { + return Post(path, Headers(), body, content_type, progress); +} + +inline Result ClientImpl::Post(const std::string &path, const Headers &headers, + const std::string &body, + const std::string &content_type) { + return send_with_content_provider("POST", path, headers, body.data(), + body.size(), nullptr, nullptr, content_type, + nullptr); +} + +inline Result ClientImpl::Post(const std::string &path, const Headers &headers, + const std::string &body, + const std::string &content_type, + Progress progress) { + return send_with_content_provider("POST", path, headers, body.data(), + body.size(), nullptr, nullptr, content_type, + progress); +} + +inline Result ClientImpl::Post(const std::string &path, const Params ¶ms) { + return Post(path, Headers(), params); +} + +inline Result ClientImpl::Post(const std::string &path, size_t content_length, + ContentProvider content_provider, + const std::string &content_type) { + return Post(path, Headers(), content_length, std::move(content_provider), + content_type); +} + +inline Result ClientImpl::Post(const std::string &path, + ContentProviderWithoutLength content_provider, + const std::string &content_type) { + return Post(path, Headers(), std::move(content_provider), content_type); +} + +inline Result ClientImpl::Post(const std::string &path, const Headers &headers, + size_t content_length, + ContentProvider content_provider, + const std::string &content_type) { + return send_with_content_provider("POST", path, headers, nullptr, + content_length, std::move(content_provider), + nullptr, content_type, nullptr); +} + +inline Result ClientImpl::Post(const std::string &path, const Headers &headers, + ContentProviderWithoutLength content_provider, + const std::string &content_type) { + return send_with_content_provider("POST", path, headers, nullptr, 0, nullptr, + std::move(content_provider), content_type, + nullptr); +} + +inline Result ClientImpl::Post(const std::string &path, const Headers &headers, + const Params ¶ms) { + auto query = detail::params_to_query_str(params); + return Post(path, headers, query, "application/x-www-form-urlencoded"); +} + +inline Result ClientImpl::Post(const std::string &path, const Headers &headers, + const Params ¶ms, Progress progress) { + auto query = detail::params_to_query_str(params); + return Post(path, headers, query, "application/x-www-form-urlencoded", + progress); +} + +inline Result ClientImpl::Post(const std::string &path, + const MultipartFormDataItems &items) { + return Post(path, Headers(), items); +} + +inline Result ClientImpl::Post(const std::string &path, const Headers &headers, + const MultipartFormDataItems &items) { + const auto &boundary = detail::make_multipart_data_boundary(); + const auto &content_type = + detail::serialize_multipart_formdata_get_content_type(boundary); + const auto &body = detail::serialize_multipart_formdata(items, boundary); + return Post(path, headers, body, content_type); +} + +inline Result ClientImpl::Post(const std::string &path, const Headers &headers, + const MultipartFormDataItems &items, + const std::string &boundary) { + if (!detail::is_multipart_boundary_chars_valid(boundary)) { + return Result{nullptr, Error::UnsupportedMultipartBoundaryChars}; + } + + const auto &content_type = + detail::serialize_multipart_formdata_get_content_type(boundary); + const auto &body = detail::serialize_multipart_formdata(items, boundary); + return Post(path, headers, body, content_type); +} + +inline Result +ClientImpl::Post(const std::string &path, const Headers &headers, + const MultipartFormDataItems &items, + const MultipartFormDataProviderItems &provider_items) { + const auto &boundary = detail::make_multipart_data_boundary(); + const auto &content_type = + detail::serialize_multipart_formdata_get_content_type(boundary); + return send_with_content_provider( + "POST", path, headers, nullptr, 0, nullptr, + get_multipart_content_provider(boundary, items, provider_items), + content_type, nullptr); +} + +inline Result ClientImpl::Put(const std::string &path) { + return Put(path, std::string(), std::string()); +} + +inline Result ClientImpl::Put(const std::string &path, const char *body, + size_t content_length, + const std::string &content_type) { + return Put(path, Headers(), body, content_length, content_type); +} + +inline Result ClientImpl::Put(const std::string &path, const Headers &headers, + const char *body, size_t content_length, + const std::string &content_type) { + return send_with_content_provider("PUT", path, headers, body, content_length, + nullptr, nullptr, content_type, nullptr); +} + +inline Result ClientImpl::Put(const std::string &path, const Headers &headers, + const char *body, size_t content_length, + const std::string &content_type, + Progress progress) { + return send_with_content_provider("PUT", path, headers, body, content_length, + nullptr, nullptr, content_type, progress); +} + +inline Result ClientImpl::Put(const std::string &path, const std::string &body, + const std::string &content_type) { + return Put(path, Headers(), body, content_type); +} + +inline Result ClientImpl::Put(const std::string &path, const std::string &body, + const std::string &content_type, + Progress progress) { + return Put(path, Headers(), body, content_type, progress); +} + +inline Result ClientImpl::Put(const std::string &path, const Headers &headers, + const std::string &body, + const std::string &content_type) { + return send_with_content_provider("PUT", path, headers, body.data(), + body.size(), nullptr, nullptr, content_type, + nullptr); +} + +inline Result ClientImpl::Put(const std::string &path, const Headers &headers, + const std::string &body, + const std::string &content_type, + Progress progress) { + return send_with_content_provider("PUT", path, headers, body.data(), + body.size(), nullptr, nullptr, content_type, + progress); +} + +inline Result ClientImpl::Put(const std::string &path, size_t content_length, + ContentProvider content_provider, + const std::string &content_type) { + return Put(path, Headers(), content_length, std::move(content_provider), + content_type); +} + +inline Result ClientImpl::Put(const std::string &path, + ContentProviderWithoutLength content_provider, + const std::string &content_type) { + return Put(path, Headers(), std::move(content_provider), content_type); +} + +inline Result ClientImpl::Put(const std::string &path, const Headers &headers, + size_t content_length, + ContentProvider content_provider, + const std::string &content_type) { + return send_with_content_provider("PUT", path, headers, nullptr, + content_length, std::move(content_provider), + nullptr, content_type, nullptr); +} + +inline Result ClientImpl::Put(const std::string &path, const Headers &headers, + ContentProviderWithoutLength content_provider, + const std::string &content_type) { + return send_with_content_provider("PUT", path, headers, nullptr, 0, nullptr, + std::move(content_provider), content_type, + nullptr); +} + +inline Result ClientImpl::Put(const std::string &path, const Params ¶ms) { + return Put(path, Headers(), params); +} + +inline Result ClientImpl::Put(const std::string &path, const Headers &headers, + const Params ¶ms) { + auto query = detail::params_to_query_str(params); + return Put(path, headers, query, "application/x-www-form-urlencoded"); +} + +inline Result ClientImpl::Put(const std::string &path, const Headers &headers, + const Params ¶ms, Progress progress) { + auto query = detail::params_to_query_str(params); + return Put(path, headers, query, "application/x-www-form-urlencoded", + progress); +} + +inline Result ClientImpl::Put(const std::string &path, + const MultipartFormDataItems &items) { + return Put(path, Headers(), items); +} + +inline Result ClientImpl::Put(const std::string &path, const Headers &headers, + const MultipartFormDataItems &items) { + const auto &boundary = detail::make_multipart_data_boundary(); + const auto &content_type = + detail::serialize_multipart_formdata_get_content_type(boundary); + const auto &body = detail::serialize_multipart_formdata(items, boundary); + return Put(path, headers, body, content_type); +} + +inline Result ClientImpl::Put(const std::string &path, const Headers &headers, + const MultipartFormDataItems &items, + const std::string &boundary) { + if (!detail::is_multipart_boundary_chars_valid(boundary)) { + return Result{nullptr, Error::UnsupportedMultipartBoundaryChars}; + } + + const auto &content_type = + detail::serialize_multipart_formdata_get_content_type(boundary); + const auto &body = detail::serialize_multipart_formdata(items, boundary); + return Put(path, headers, body, content_type); +} + +inline Result +ClientImpl::Put(const std::string &path, const Headers &headers, + const MultipartFormDataItems &items, + const MultipartFormDataProviderItems &provider_items) { + const auto &boundary = detail::make_multipart_data_boundary(); + const auto &content_type = + detail::serialize_multipart_formdata_get_content_type(boundary); + return send_with_content_provider( + "PUT", path, headers, nullptr, 0, nullptr, + get_multipart_content_provider(boundary, items, provider_items), + content_type, nullptr); +} +inline Result ClientImpl::Patch(const std::string &path) { + return Patch(path, std::string(), std::string()); +} + +inline Result ClientImpl::Patch(const std::string &path, const char *body, + size_t content_length, + const std::string &content_type) { + return Patch(path, Headers(), body, content_length, content_type); +} + +inline Result ClientImpl::Patch(const std::string &path, const char *body, + size_t content_length, + const std::string &content_type, + Progress progress) { + return Patch(path, Headers(), body, content_length, content_type, progress); +} + +inline Result ClientImpl::Patch(const std::string &path, const Headers &headers, + const char *body, size_t content_length, + const std::string &content_type) { + return Patch(path, headers, body, content_length, content_type, nullptr); +} + +inline Result ClientImpl::Patch(const std::string &path, const Headers &headers, + const char *body, size_t content_length, + const std::string &content_type, + Progress progress) { + return send_with_content_provider("PATCH", path, headers, body, + content_length, nullptr, nullptr, + content_type, progress); +} + +inline Result ClientImpl::Patch(const std::string &path, + const std::string &body, + const std::string &content_type) { + return Patch(path, Headers(), body, content_type); +} + +inline Result ClientImpl::Patch(const std::string &path, + const std::string &body, + const std::string &content_type, + Progress progress) { + return Patch(path, Headers(), body, content_type, progress); +} + +inline Result ClientImpl::Patch(const std::string &path, const Headers &headers, + const std::string &body, + const std::string &content_type) { + return Patch(path, headers, body, content_type, nullptr); +} + +inline Result ClientImpl::Patch(const std::string &path, const Headers &headers, + const std::string &body, + const std::string &content_type, + Progress progress) { + return send_with_content_provider("PATCH", path, headers, body.data(), + body.size(), nullptr, nullptr, content_type, + progress); +} + +inline Result ClientImpl::Patch(const std::string &path, size_t content_length, + ContentProvider content_provider, + const std::string &content_type) { + return Patch(path, Headers(), content_length, std::move(content_provider), + content_type); +} + +inline Result ClientImpl::Patch(const std::string &path, + ContentProviderWithoutLength content_provider, + const std::string &content_type) { + return Patch(path, Headers(), std::move(content_provider), content_type); +} + +inline Result ClientImpl::Patch(const std::string &path, const Headers &headers, + size_t content_length, + ContentProvider content_provider, + const std::string &content_type) { + return send_with_content_provider("PATCH", path, headers, nullptr, + content_length, std::move(content_provider), + nullptr, content_type, nullptr); +} + +inline Result ClientImpl::Patch(const std::string &path, const Headers &headers, + ContentProviderWithoutLength content_provider, + const std::string &content_type) { + return send_with_content_provider("PATCH", path, headers, nullptr, 0, nullptr, + std::move(content_provider), content_type, + nullptr); +} + +inline Result ClientImpl::Delete(const std::string &path) { + return Delete(path, Headers(), std::string(), std::string()); +} + +inline Result ClientImpl::Delete(const std::string &path, + const Headers &headers) { + return Delete(path, headers, std::string(), std::string()); +} + +inline Result ClientImpl::Delete(const std::string &path, const char *body, + size_t content_length, + const std::string &content_type) { + return Delete(path, Headers(), body, content_length, content_type); +} + +inline Result ClientImpl::Delete(const std::string &path, const char *body, + size_t content_length, + const std::string &content_type, + Progress progress) { + return Delete(path, Headers(), body, content_length, content_type, progress); +} + +inline Result ClientImpl::Delete(const std::string &path, + const Headers &headers, const char *body, + size_t content_length, + const std::string &content_type) { + return Delete(path, headers, body, content_length, content_type, nullptr); +} + +inline Result ClientImpl::Delete(const std::string &path, + const Headers &headers, const char *body, + size_t content_length, + const std::string &content_type, + Progress progress) { + Request req; + req.method = "DELETE"; + req.headers = headers; + req.path = path; + req.progress = progress; + + if (!content_type.empty()) { req.set_header("Content-Type", content_type); } + req.body.assign(body, content_length); + + return send_(std::move(req)); +} + +inline Result ClientImpl::Delete(const std::string &path, + const std::string &body, + const std::string &content_type) { + return Delete(path, Headers(), body.data(), body.size(), content_type); +} + +inline Result ClientImpl::Delete(const std::string &path, + const std::string &body, + const std::string &content_type, + Progress progress) { + return Delete(path, Headers(), body.data(), body.size(), content_type, + progress); +} + +inline Result ClientImpl::Delete(const std::string &path, + const Headers &headers, + const std::string &body, + const std::string &content_type) { + return Delete(path, headers, body.data(), body.size(), content_type); +} + +inline Result ClientImpl::Delete(const std::string &path, + const Headers &headers, + const std::string &body, + const std::string &content_type, + Progress progress) { + return Delete(path, headers, body.data(), body.size(), content_type, + progress); +} + +inline Result ClientImpl::Options(const std::string &path) { + return Options(path, Headers()); +} + +inline Result ClientImpl::Options(const std::string &path, + const Headers &headers) { + Request req; + req.method = "OPTIONS"; + req.headers = headers; + req.path = path; + + return send_(std::move(req)); +} + +inline void ClientImpl::stop() { + std::lock_guard guard(socket_mutex_); + + // If there is anything ongoing right now, the ONLY thread-safe thing we can + // do is to shutdown_socket, so that threads using this socket suddenly + // discover they can't read/write any more and error out. Everything else + // (closing the socket, shutting ssl down) is unsafe because these actions are + // not thread-safe. + if (socket_requests_in_flight_ > 0) { + shutdown_socket(socket_); + + // Aside from that, we set a flag for the socket to be closed when we're + // done. + socket_should_be_closed_when_request_is_done_ = true; + return; + } + + // Otherwise, still holding the mutex, we can shut everything down ourselves + shutdown_ssl(socket_, true); + shutdown_socket(socket_); + close_socket(socket_); +} + +inline std::string ClientImpl::host() const { return host_; } + +inline int ClientImpl::port() const { return port_; } + +inline size_t ClientImpl::is_socket_open() const { + std::lock_guard guard(socket_mutex_); + return socket_.is_open(); +} + +inline socket_t ClientImpl::socket() const { return socket_.sock; } + +inline void ClientImpl::set_connection_timeout(time_t sec, time_t usec) { + connection_timeout_sec_ = sec; + connection_timeout_usec_ = usec; +} + +inline void ClientImpl::set_read_timeout(time_t sec, time_t usec) { + read_timeout_sec_ = sec; + read_timeout_usec_ = usec; +} + +inline void ClientImpl::set_write_timeout(time_t sec, time_t usec) { + write_timeout_sec_ = sec; + write_timeout_usec_ = usec; +} + +inline void ClientImpl::set_basic_auth(const std::string &username, + const std::string &password) { + basic_auth_username_ = username; + basic_auth_password_ = password; +} + +inline void ClientImpl::set_bearer_token_auth(const std::string &token) { + bearer_token_auth_token_ = token; +} + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +inline void ClientImpl::set_digest_auth(const std::string &username, + const std::string &password) { + digest_auth_username_ = username; + digest_auth_password_ = password; +} +#endif + +inline void ClientImpl::set_keep_alive(bool on) { keep_alive_ = on; } + +inline void ClientImpl::set_follow_location(bool on) { follow_location_ = on; } + +inline void ClientImpl::set_url_encode(bool on) { url_encode_ = on; } + +inline void +ClientImpl::set_hostname_addr_map(std::map addr_map) { + addr_map_ = std::move(addr_map); +} + +inline void ClientImpl::set_default_headers(Headers headers) { + default_headers_ = std::move(headers); +} + +inline void ClientImpl::set_header_writer( + std::function const &writer) { + header_writer_ = writer; +} + +inline void ClientImpl::set_address_family(int family) { + address_family_ = family; +} + +inline void ClientImpl::set_tcp_nodelay(bool on) { tcp_nodelay_ = on; } + +inline void ClientImpl::set_ipv6_v6only(bool on) { ipv6_v6only_ = on; } + +inline void ClientImpl::set_socket_options(SocketOptions socket_options) { + socket_options_ = std::move(socket_options); +} + +inline void ClientImpl::set_compress(bool on) { compress_ = on; } + +inline void ClientImpl::set_decompress(bool on) { decompress_ = on; } + +inline void ClientImpl::set_interface(const std::string &intf) { + interface_ = intf; +} + +inline void ClientImpl::set_proxy(const std::string &host, int port) { + proxy_host_ = host; + proxy_port_ = port; +} + +inline void ClientImpl::set_proxy_basic_auth(const std::string &username, + const std::string &password) { + proxy_basic_auth_username_ = username; + proxy_basic_auth_password_ = password; +} + +inline void ClientImpl::set_proxy_bearer_token_auth(const std::string &token) { + proxy_bearer_token_auth_token_ = token; +} + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +inline void ClientImpl::set_proxy_digest_auth(const std::string &username, + const std::string &password) { + proxy_digest_auth_username_ = username; + proxy_digest_auth_password_ = password; +} + +inline void ClientImpl::set_ca_cert_path(const std::string &ca_cert_file_path, + const std::string &ca_cert_dir_path) { + ca_cert_file_path_ = ca_cert_file_path; + ca_cert_dir_path_ = ca_cert_dir_path; +} + +inline void ClientImpl::set_ca_cert_store(X509_STORE *ca_cert_store) { + if (ca_cert_store && ca_cert_store != ca_cert_store_) { + ca_cert_store_ = ca_cert_store; + } +} + +inline X509_STORE *ClientImpl::create_ca_cert_store(const char *ca_cert, + std::size_t size) const { + auto mem = BIO_new_mem_buf(ca_cert, static_cast(size)); + auto se = detail::scope_exit([&] { BIO_free_all(mem); }); + if (!mem) { return nullptr; } + + auto inf = PEM_X509_INFO_read_bio(mem, nullptr, nullptr, nullptr); + if (!inf) { return nullptr; } + + auto cts = X509_STORE_new(); + if (cts) { + for (auto i = 0; i < static_cast(sk_X509_INFO_num(inf)); i++) { + auto itmp = sk_X509_INFO_value(inf, i); + if (!itmp) { continue; } + + if (itmp->x509) { X509_STORE_add_cert(cts, itmp->x509); } + if (itmp->crl) { X509_STORE_add_crl(cts, itmp->crl); } + } + } + + sk_X509_INFO_pop_free(inf, X509_INFO_free); + return cts; +} + +inline void ClientImpl::enable_server_certificate_verification(bool enabled) { + server_certificate_verification_ = enabled; +} + +inline void ClientImpl::enable_server_hostname_verification(bool enabled) { + server_hostname_verification_ = enabled; +} + +inline void ClientImpl::set_server_certificate_verifier( + std::function verifier) { + server_certificate_verifier_ = verifier; +} +#endif + +inline void ClientImpl::set_logger(Logger logger) { + logger_ = std::move(logger); +} + +/* + * SSL Implementation + */ +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +namespace detail { + +template +inline SSL *ssl_new(socket_t sock, SSL_CTX *ctx, std::mutex &ctx_mutex, + U SSL_connect_or_accept, V setup) { + SSL *ssl = nullptr; + { + std::lock_guard guard(ctx_mutex); + ssl = SSL_new(ctx); + } + + if (ssl) { + set_nonblocking(sock, true); + auto bio = BIO_new_socket(static_cast(sock), BIO_NOCLOSE); + BIO_set_nbio(bio, 1); + SSL_set_bio(ssl, bio, bio); + + if (!setup(ssl) || SSL_connect_or_accept(ssl) != 1) { + SSL_shutdown(ssl); + { + std::lock_guard guard(ctx_mutex); + SSL_free(ssl); + } + set_nonblocking(sock, false); + return nullptr; + } + BIO_set_nbio(bio, 0); + set_nonblocking(sock, false); + } + + return ssl; +} + +inline void ssl_delete(std::mutex &ctx_mutex, SSL *ssl, socket_t sock, + bool shutdown_gracefully) { + // sometimes we may want to skip this to try to avoid SIGPIPE if we know + // the remote has closed the network connection + // Note that it is not always possible to avoid SIGPIPE, this is merely a + // best-efforts. + if (shutdown_gracefully) { +#ifdef _WIN32 + (void)(sock); + SSL_shutdown(ssl); +#else + timeval tv; + tv.tv_sec = 1; + tv.tv_usec = 0; + setsockopt(sock, SOL_SOCKET, SO_RCVTIMEO, + reinterpret_cast(&tv), sizeof(tv)); + + auto ret = SSL_shutdown(ssl); + while (ret == 0) { + std::this_thread::sleep_for(std::chrono::milliseconds{100}); + ret = SSL_shutdown(ssl); + } +#endif + } + + std::lock_guard guard(ctx_mutex); + SSL_free(ssl); +} + +template +bool ssl_connect_or_accept_nonblocking(socket_t sock, SSL *ssl, + U ssl_connect_or_accept, + time_t timeout_sec, + time_t timeout_usec) { + auto res = 0; + while ((res = ssl_connect_or_accept(ssl)) != 1) { + auto err = SSL_get_error(ssl, res); + switch (err) { + case SSL_ERROR_WANT_READ: + if (select_read(sock, timeout_sec, timeout_usec) > 0) { continue; } + break; + case SSL_ERROR_WANT_WRITE: + if (select_write(sock, timeout_sec, timeout_usec) > 0) { continue; } + break; + default: break; + } + return false; + } + return true; +} + +template +inline bool process_server_socket_ssl( + const std::atomic &svr_sock, SSL *ssl, socket_t sock, + size_t keep_alive_max_count, time_t keep_alive_timeout_sec, + time_t read_timeout_sec, time_t read_timeout_usec, time_t write_timeout_sec, + time_t write_timeout_usec, T callback) { + return process_server_socket_core( + svr_sock, sock, keep_alive_max_count, keep_alive_timeout_sec, + [&](bool close_connection, bool &connection_closed) { + SSLSocketStream strm(sock, ssl, read_timeout_sec, read_timeout_usec, + write_timeout_sec, write_timeout_usec); + return callback(strm, close_connection, connection_closed); + }); +} + +template +inline bool +process_client_socket_ssl(SSL *ssl, socket_t sock, time_t read_timeout_sec, + time_t read_timeout_usec, time_t write_timeout_sec, + time_t write_timeout_usec, T callback) { + SSLSocketStream strm(sock, ssl, read_timeout_sec, read_timeout_usec, + write_timeout_sec, write_timeout_usec); + return callback(strm); +} + +class SSLInit { +public: + SSLInit() { + OPENSSL_init_ssl( + OPENSSL_INIT_LOAD_SSL_STRINGS | OPENSSL_INIT_LOAD_CRYPTO_STRINGS, NULL); + } +}; + +// SSL socket stream implementation +inline SSLSocketStream::SSLSocketStream(socket_t sock, SSL *ssl, + time_t read_timeout_sec, + time_t read_timeout_usec, + time_t write_timeout_sec, + time_t write_timeout_usec) + : sock_(sock), ssl_(ssl), read_timeout_sec_(read_timeout_sec), + read_timeout_usec_(read_timeout_usec), + write_timeout_sec_(write_timeout_sec), + write_timeout_usec_(write_timeout_usec) { + SSL_clear_mode(ssl, SSL_MODE_AUTO_RETRY); +} + +inline SSLSocketStream::~SSLSocketStream() = default; + +inline bool SSLSocketStream::is_readable() const { + return detail::select_read(sock_, read_timeout_sec_, read_timeout_usec_) > 0; +} + +inline bool SSLSocketStream::is_writable() const { + return select_write(sock_, write_timeout_sec_, write_timeout_usec_) > 0 && + is_socket_alive(sock_); +} + +inline ssize_t SSLSocketStream::read(char *ptr, size_t size) { + if (SSL_pending(ssl_) > 0) { + return SSL_read(ssl_, ptr, static_cast(size)); + } else if (is_readable()) { + auto ret = SSL_read(ssl_, ptr, static_cast(size)); + if (ret < 0) { + auto err = SSL_get_error(ssl_, ret); + auto n = 1000; +#ifdef _WIN32 + while (--n >= 0 && (err == SSL_ERROR_WANT_READ || + (err == SSL_ERROR_SYSCALL && + WSAGetLastError() == WSAETIMEDOUT))) { +#else + while (--n >= 0 && err == SSL_ERROR_WANT_READ) { +#endif + if (SSL_pending(ssl_) > 0) { + return SSL_read(ssl_, ptr, static_cast(size)); + } else if (is_readable()) { + std::this_thread::sleep_for(std::chrono::microseconds{10}); + ret = SSL_read(ssl_, ptr, static_cast(size)); + if (ret >= 0) { return ret; } + err = SSL_get_error(ssl_, ret); + } else { + return -1; + } + } + } + return ret; + } + return -1; +} + +inline ssize_t SSLSocketStream::write(const char *ptr, size_t size) { + if (is_writable()) { + auto handle_size = static_cast( + std::min(size, (std::numeric_limits::max)())); + + auto ret = SSL_write(ssl_, ptr, static_cast(handle_size)); + if (ret < 0) { + auto err = SSL_get_error(ssl_, ret); + auto n = 1000; +#ifdef _WIN32 + while (--n >= 0 && (err == SSL_ERROR_WANT_WRITE || + (err == SSL_ERROR_SYSCALL && + WSAGetLastError() == WSAETIMEDOUT))) { +#else + while (--n >= 0 && err == SSL_ERROR_WANT_WRITE) { +#endif + if (is_writable()) { + std::this_thread::sleep_for(std::chrono::microseconds{10}); + ret = SSL_write(ssl_, ptr, static_cast(handle_size)); + if (ret >= 0) { return ret; } + err = SSL_get_error(ssl_, ret); + } else { + return -1; + } + } + } + return ret; + } + return -1; +} + +inline void SSLSocketStream::get_remote_ip_and_port(std::string &ip, + int &port) const { + detail::get_remote_ip_and_port(sock_, ip, port); +} + +inline void SSLSocketStream::get_local_ip_and_port(std::string &ip, + int &port) const { + detail::get_local_ip_and_port(sock_, ip, port); +} + +inline socket_t SSLSocketStream::socket() const { return sock_; } + +static SSLInit sslinit_; + +} // namespace detail + +// SSL HTTP server implementation +inline SSLServer::SSLServer(const char *cert_path, const char *private_key_path, + const char *client_ca_cert_file_path, + const char *client_ca_cert_dir_path, + const char *private_key_password) { + ctx_ = SSL_CTX_new(TLS_server_method()); + + if (ctx_) { + SSL_CTX_set_options(ctx_, + SSL_OP_NO_COMPRESSION | + SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION); + + SSL_CTX_set_min_proto_version(ctx_, TLS1_2_VERSION); + + if (private_key_password != nullptr && (private_key_password[0] != '\0')) { + SSL_CTX_set_default_passwd_cb_userdata( + ctx_, + reinterpret_cast(const_cast(private_key_password))); + } + + if (SSL_CTX_use_certificate_chain_file(ctx_, cert_path) != 1 || + SSL_CTX_use_PrivateKey_file(ctx_, private_key_path, SSL_FILETYPE_PEM) != + 1 || + SSL_CTX_check_private_key(ctx_) != 1) { + SSL_CTX_free(ctx_); + ctx_ = nullptr; + } else if (client_ca_cert_file_path || client_ca_cert_dir_path) { + SSL_CTX_load_verify_locations(ctx_, client_ca_cert_file_path, + client_ca_cert_dir_path); + + SSL_CTX_set_verify( + ctx_, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, nullptr); + } + } +} + +inline SSLServer::SSLServer(X509 *cert, EVP_PKEY *private_key, + X509_STORE *client_ca_cert_store) { + ctx_ = SSL_CTX_new(TLS_server_method()); + + if (ctx_) { + SSL_CTX_set_options(ctx_, + SSL_OP_NO_COMPRESSION | + SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION); + + SSL_CTX_set_min_proto_version(ctx_, TLS1_2_VERSION); + + if (SSL_CTX_use_certificate(ctx_, cert) != 1 || + SSL_CTX_use_PrivateKey(ctx_, private_key) != 1) { + SSL_CTX_free(ctx_); + ctx_ = nullptr; + } else if (client_ca_cert_store) { + SSL_CTX_set_cert_store(ctx_, client_ca_cert_store); + + SSL_CTX_set_verify( + ctx_, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, nullptr); + } + } +} + +inline SSLServer::SSLServer( + const std::function &setup_ssl_ctx_callback) { + ctx_ = SSL_CTX_new(TLS_method()); + if (ctx_) { + if (!setup_ssl_ctx_callback(*ctx_)) { + SSL_CTX_free(ctx_); + ctx_ = nullptr; + } + } +} + +inline SSLServer::~SSLServer() { + if (ctx_) { SSL_CTX_free(ctx_); } +} + +inline bool SSLServer::is_valid() const { return ctx_; } + +inline SSL_CTX *SSLServer::ssl_context() const { return ctx_; } + +inline void SSLServer::update_certs(X509 *cert, EVP_PKEY *private_key, + X509_STORE *client_ca_cert_store) { + + std::lock_guard guard(ctx_mutex_); + + SSL_CTX_use_certificate(ctx_, cert); + SSL_CTX_use_PrivateKey(ctx_, private_key); + + if (client_ca_cert_store != nullptr) { + SSL_CTX_set_cert_store(ctx_, client_ca_cert_store); + } +} + +inline bool SSLServer::process_and_close_socket(socket_t sock) { + auto ssl = detail::ssl_new( + sock, ctx_, ctx_mutex_, + [&](SSL *ssl2) { + return detail::ssl_connect_or_accept_nonblocking( + sock, ssl2, SSL_accept, read_timeout_sec_, read_timeout_usec_); + }, + [](SSL * /*ssl2*/) { return true; }); + + auto ret = false; + if (ssl) { + std::string remote_addr; + int remote_port = 0; + detail::get_remote_ip_and_port(sock, remote_addr, remote_port); + + std::string local_addr; + int local_port = 0; + detail::get_local_ip_and_port(sock, local_addr, local_port); + + ret = detail::process_server_socket_ssl( + svr_sock_, ssl, sock, keep_alive_max_count_, keep_alive_timeout_sec_, + read_timeout_sec_, read_timeout_usec_, write_timeout_sec_, + write_timeout_usec_, + [&](Stream &strm, bool close_connection, bool &connection_closed) { + return process_request(strm, remote_addr, remote_port, local_addr, + local_port, close_connection, + connection_closed, + [&](Request &req) { req.ssl = ssl; }); + }); + + // Shutdown gracefully if the result seemed successful, non-gracefully if + // the connection appeared to be closed. + const bool shutdown_gracefully = ret; + detail::ssl_delete(ctx_mutex_, ssl, sock, shutdown_gracefully); + } + + detail::shutdown_socket(sock); + detail::close_socket(sock); + return ret; +} + +// SSL HTTP client implementation +inline SSLClient::SSLClient(const std::string &host) + : SSLClient(host, 443, std::string(), std::string()) {} + +inline SSLClient::SSLClient(const std::string &host, int port) + : SSLClient(host, port, std::string(), std::string()) {} + +inline SSLClient::SSLClient(const std::string &host, int port, + const std::string &client_cert_path, + const std::string &client_key_path, + const std::string &private_key_password) + : ClientImpl(host, port, client_cert_path, client_key_path) { + ctx_ = SSL_CTX_new(TLS_client_method()); + + SSL_CTX_set_min_proto_version(ctx_, TLS1_2_VERSION); + + detail::split(&host_[0], &host_[host_.size()], '.', + [&](const char *b, const char *e) { + host_components_.emplace_back(b, e); + }); + + if (!client_cert_path.empty() && !client_key_path.empty()) { + if (!private_key_password.empty()) { + SSL_CTX_set_default_passwd_cb_userdata( + ctx_, reinterpret_cast( + const_cast(private_key_password.c_str()))); + } + + if (SSL_CTX_use_certificate_file(ctx_, client_cert_path.c_str(), + SSL_FILETYPE_PEM) != 1 || + SSL_CTX_use_PrivateKey_file(ctx_, client_key_path.c_str(), + SSL_FILETYPE_PEM) != 1) { + SSL_CTX_free(ctx_); + ctx_ = nullptr; + } + } +} + +inline SSLClient::SSLClient(const std::string &host, int port, + X509 *client_cert, EVP_PKEY *client_key, + const std::string &private_key_password) + : ClientImpl(host, port) { + ctx_ = SSL_CTX_new(TLS_client_method()); + + detail::split(&host_[0], &host_[host_.size()], '.', + [&](const char *b, const char *e) { + host_components_.emplace_back(b, e); + }); + + if (client_cert != nullptr && client_key != nullptr) { + if (!private_key_password.empty()) { + SSL_CTX_set_default_passwd_cb_userdata( + ctx_, reinterpret_cast( + const_cast(private_key_password.c_str()))); + } + + if (SSL_CTX_use_certificate(ctx_, client_cert) != 1 || + SSL_CTX_use_PrivateKey(ctx_, client_key) != 1) { + SSL_CTX_free(ctx_); + ctx_ = nullptr; + } + } +} + +inline SSLClient::~SSLClient() { + if (ctx_) { SSL_CTX_free(ctx_); } + // Make sure to shut down SSL since shutdown_ssl will resolve to the + // base function rather than the derived function once we get to the + // base class destructor, and won't free the SSL (causing a leak). + shutdown_ssl_impl(socket_, true); +} + +inline bool SSLClient::is_valid() const { return ctx_; } + +inline void SSLClient::set_ca_cert_store(X509_STORE *ca_cert_store) { + if (ca_cert_store) { + if (ctx_) { + if (SSL_CTX_get_cert_store(ctx_) != ca_cert_store) { + // Free memory allocated for old cert and use new store `ca_cert_store` + SSL_CTX_set_cert_store(ctx_, ca_cert_store); + } + } else { + X509_STORE_free(ca_cert_store); + } + } +} + +inline void SSLClient::load_ca_cert_store(const char *ca_cert, + std::size_t size) { + set_ca_cert_store(ClientImpl::create_ca_cert_store(ca_cert, size)); +} + +inline long SSLClient::get_openssl_verify_result() const { + return verify_result_; +} + +inline SSL_CTX *SSLClient::ssl_context() const { return ctx_; } + +inline bool SSLClient::create_and_connect_socket(Socket &socket, Error &error) { + return is_valid() && ClientImpl::create_and_connect_socket(socket, error); +} + +// Assumes that socket_mutex_ is locked and that there are no requests in flight +inline bool SSLClient::connect_with_proxy(Socket &socket, Response &res, + bool &success, Error &error) { + success = true; + Response proxy_res; + if (!detail::process_client_socket( + socket.sock, read_timeout_sec_, read_timeout_usec_, + write_timeout_sec_, write_timeout_usec_, [&](Stream &strm) { + Request req2; + req2.method = "CONNECT"; + req2.path = host_and_port_; + return process_request(strm, req2, proxy_res, false, error); + })) { + // Thread-safe to close everything because we are assuming there are no + // requests in flight + shutdown_ssl(socket, true); + shutdown_socket(socket); + close_socket(socket); + success = false; + return false; + } + + if (proxy_res.status == StatusCode::ProxyAuthenticationRequired_407) { + if (!proxy_digest_auth_username_.empty() && + !proxy_digest_auth_password_.empty()) { + std::map auth; + if (detail::parse_www_authenticate(proxy_res, auth, true)) { + proxy_res = Response(); + if (!detail::process_client_socket( + socket.sock, read_timeout_sec_, read_timeout_usec_, + write_timeout_sec_, write_timeout_usec_, [&](Stream &strm) { + Request req3; + req3.method = "CONNECT"; + req3.path = host_and_port_; + req3.headers.insert(detail::make_digest_authentication_header( + req3, auth, 1, detail::random_string(10), + proxy_digest_auth_username_, proxy_digest_auth_password_, + true)); + return process_request(strm, req3, proxy_res, false, error); + })) { + // Thread-safe to close everything because we are assuming there are + // no requests in flight + shutdown_ssl(socket, true); + shutdown_socket(socket); + close_socket(socket); + success = false; + return false; + } + } + } + } + + // If status code is not 200, proxy request is failed. + // Set error to ProxyConnection and return proxy response + // as the response of the request + if (proxy_res.status != StatusCode::OK_200) { + error = Error::ProxyConnection; + res = std::move(proxy_res); + // Thread-safe to close everything because we are assuming there are + // no requests in flight + shutdown_ssl(socket, true); + shutdown_socket(socket); + close_socket(socket); + return false; + } + + return true; +} + +inline bool SSLClient::load_certs() { + auto ret = true; + + std::call_once(initialize_cert_, [&]() { + std::lock_guard guard(ctx_mutex_); + if (!ca_cert_file_path_.empty()) { + if (!SSL_CTX_load_verify_locations(ctx_, ca_cert_file_path_.c_str(), + nullptr)) { + ret = false; + } + } else if (!ca_cert_dir_path_.empty()) { + if (!SSL_CTX_load_verify_locations(ctx_, nullptr, + ca_cert_dir_path_.c_str())) { + ret = false; + } + } else { + auto loaded = false; +#ifdef _WIN32 + loaded = + detail::load_system_certs_on_windows(SSL_CTX_get_cert_store(ctx_)); +#elif defined(CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN) && defined(__APPLE__) +#if TARGET_OS_OSX + loaded = detail::load_system_certs_on_macos(SSL_CTX_get_cert_store(ctx_)); +#endif // TARGET_OS_OSX +#endif // _WIN32 + if (!loaded) { SSL_CTX_set_default_verify_paths(ctx_); } + } + }); + + return ret; +} + +inline bool SSLClient::initialize_ssl(Socket &socket, Error &error) { + auto ssl = detail::ssl_new( + socket.sock, ctx_, ctx_mutex_, + [&](SSL *ssl2) { + if (server_certificate_verification_) { + if (!load_certs()) { + error = Error::SSLLoadingCerts; + return false; + } + SSL_set_verify(ssl2, SSL_VERIFY_NONE, nullptr); + } + + if (!detail::ssl_connect_or_accept_nonblocking( + socket.sock, ssl2, SSL_connect, connection_timeout_sec_, + connection_timeout_usec_)) { + error = Error::SSLConnection; + return false; + } + + if (server_certificate_verification_) { + if (server_certificate_verifier_) { + if (!server_certificate_verifier_(ssl2)) { + error = Error::SSLServerVerification; + return false; + } + } else { + verify_result_ = SSL_get_verify_result(ssl2); + + if (verify_result_ != X509_V_OK) { + error = Error::SSLServerVerification; + return false; + } + + auto server_cert = SSL_get1_peer_certificate(ssl2); + auto se = detail::scope_exit([&] { X509_free(server_cert); }); + + if (server_cert == nullptr) { + error = Error::SSLServerVerification; + return false; + } + + if (server_hostname_verification_) { + if (!verify_host(server_cert)) { + error = Error::SSLServerHostnameVerification; + return false; + } + } + } + } + + return true; + }, + [&](SSL *ssl2) { +#if defined(OPENSSL_IS_BORINGSSL) + SSL_set_tlsext_host_name(ssl2, host_.c_str()); +#else + // NOTE: Direct call instead of using the OpenSSL macro to suppress + // -Wold-style-cast warning + SSL_ctrl(ssl2, SSL_CTRL_SET_TLSEXT_HOSTNAME, TLSEXT_NAMETYPE_host_name, + static_cast(const_cast(host_.c_str()))); +#endif + return true; + }); + + if (ssl) { + socket.ssl = ssl; + return true; + } + + shutdown_socket(socket); + close_socket(socket); + return false; +} + +inline void SSLClient::shutdown_ssl(Socket &socket, bool shutdown_gracefully) { + shutdown_ssl_impl(socket, shutdown_gracefully); +} + +inline void SSLClient::shutdown_ssl_impl(Socket &socket, + bool shutdown_gracefully) { + if (socket.sock == INVALID_SOCKET) { + assert(socket.ssl == nullptr); + return; + } + if (socket.ssl) { + detail::ssl_delete(ctx_mutex_, socket.ssl, socket.sock, + shutdown_gracefully); + socket.ssl = nullptr; + } + assert(socket.ssl == nullptr); +} + +inline bool +SSLClient::process_socket(const Socket &socket, + std::function callback) { + assert(socket.ssl); + return detail::process_client_socket_ssl( + socket.ssl, socket.sock, read_timeout_sec_, read_timeout_usec_, + write_timeout_sec_, write_timeout_usec_, std::move(callback)); +} + +inline bool SSLClient::is_ssl() const { return true; } + +inline bool SSLClient::verify_host(X509 *server_cert) const { + /* Quote from RFC2818 section 3.1 "Server Identity" + + If a subjectAltName extension of type dNSName is present, that MUST + be used as the identity. Otherwise, the (most specific) Common Name + field in the Subject field of the certificate MUST be used. Although + the use of the Common Name is existing practice, it is deprecated and + Certification Authorities are encouraged to use the dNSName instead. + + Matching is performed using the matching rules specified by + [RFC2459]. If more than one identity of a given type is present in + the certificate (e.g., more than one dNSName name, a match in any one + of the set is considered acceptable.) Names may contain the wildcard + character * which is considered to match any single domain name + component or component fragment. E.g., *.a.com matches foo.a.com but + not bar.foo.a.com. f*.com matches foo.com but not bar.com. + + In some cases, the URI is specified as an IP address rather than a + hostname. In this case, the iPAddress subjectAltName must be present + in the certificate and must exactly match the IP in the URI. + + */ + return verify_host_with_subject_alt_name(server_cert) || + verify_host_with_common_name(server_cert); +} + +inline bool +SSLClient::verify_host_with_subject_alt_name(X509 *server_cert) const { + auto ret = false; + + auto type = GEN_DNS; + + struct in6_addr addr6{}; + struct in_addr addr{}; + size_t addr_len = 0; + +#ifndef __MINGW32__ + if (inet_pton(AF_INET6, host_.c_str(), &addr6)) { + type = GEN_IPADD; + addr_len = sizeof(struct in6_addr); + } else if (inet_pton(AF_INET, host_.c_str(), &addr)) { + type = GEN_IPADD; + addr_len = sizeof(struct in_addr); + } +#endif + + auto alt_names = static_cast( + X509_get_ext_d2i(server_cert, NID_subject_alt_name, nullptr, nullptr)); + + if (alt_names) { + auto dsn_matched = false; + auto ip_matched = false; + + auto count = sk_GENERAL_NAME_num(alt_names); + + for (decltype(count) i = 0; i < count && !dsn_matched; i++) { + auto val = sk_GENERAL_NAME_value(alt_names, i); + if (val->type == type) { + auto name = + reinterpret_cast(ASN1_STRING_get0_data(val->d.ia5)); + auto name_len = static_cast(ASN1_STRING_length(val->d.ia5)); + + switch (type) { + case GEN_DNS: dsn_matched = check_host_name(name, name_len); break; + + case GEN_IPADD: + if (!memcmp(&addr6, name, addr_len) || + !memcmp(&addr, name, addr_len)) { + ip_matched = true; + } + break; + } + } + } + + if (dsn_matched || ip_matched) { ret = true; } + } + + GENERAL_NAMES_free(const_cast( + reinterpret_cast(alt_names))); + return ret; +} + +inline bool SSLClient::verify_host_with_common_name(X509 *server_cert) const { + const auto subject_name = X509_get_subject_name(server_cert); + + if (subject_name != nullptr) { + char name[BUFSIZ]; + auto name_len = X509_NAME_get_text_by_NID(subject_name, NID_commonName, + name, sizeof(name)); + + if (name_len != -1) { + return check_host_name(name, static_cast(name_len)); + } + } + + return false; +} + +inline bool SSLClient::check_host_name(const char *pattern, + size_t pattern_len) const { + if (host_.size() == pattern_len && host_ == pattern) { return true; } + + // Wildcard match + // https://bugs.launchpad.net/ubuntu/+source/firefox-3.0/+bug/376484 + std::vector pattern_components; + detail::split(&pattern[0], &pattern[pattern_len], '.', + [&](const char *b, const char *e) { + pattern_components.emplace_back(b, e); + }); + + if (host_components_.size() != pattern_components.size()) { return false; } + + auto itr = pattern_components.begin(); + for (const auto &h : host_components_) { + auto &p = *itr; + if (p != h && p != "*") { + auto partial_match = (p.size() > 0 && p[p.size() - 1] == '*' && + !p.compare(0, p.size() - 1, h)); + if (!partial_match) { return false; } + } + ++itr; + } + + return true; +} +#endif + +// Universal client implementation +inline Client::Client(const std::string &scheme_host_port) + : Client(scheme_host_port, std::string(), std::string()) {} + +inline Client::Client(const std::string &scheme_host_port, + const std::string &client_cert_path, + const std::string &client_key_path) { + const static std::regex re( + R"((?:([a-z]+):\/\/)?(?:\[([a-fA-F\d:]+)\]|([^:/?#]+))(?::(\d+))?)"); + + std::smatch m; + if (std::regex_match(scheme_host_port, m, re)) { + auto scheme = m[1].str(); + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + if (!scheme.empty() && (scheme != "http" && scheme != "https")) { +#else + if (!scheme.empty() && scheme != "http") { +#endif +#ifndef CPPHTTPLIB_NO_EXCEPTIONS + std::string msg = "'" + scheme + "' scheme is not supported."; + throw std::invalid_argument(msg); +#endif + return; + } + + auto is_ssl = scheme == "https"; + + auto host = m[2].str(); + if (host.empty()) { host = m[3].str(); } + + auto port_str = m[4].str(); + auto port = !port_str.empty() ? std::stoi(port_str) : (is_ssl ? 443 : 80); + + if (is_ssl) { +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + cli_ = detail::make_unique(host, port, client_cert_path, + client_key_path); + is_ssl_ = is_ssl; +#endif + } else { + cli_ = detail::make_unique(host, port, client_cert_path, + client_key_path); + } + } else { + // NOTE: Update TEST(UniversalClientImplTest, Ipv6LiteralAddress) + // if port param below changes. + cli_ = detail::make_unique(scheme_host_port, 80, + client_cert_path, client_key_path); + } +} // namespace detail + +inline Client::Client(const std::string &host, int port) + : cli_(detail::make_unique(host, port)) {} + +inline Client::Client(const std::string &host, int port, + const std::string &client_cert_path, + const std::string &client_key_path) + : cli_(detail::make_unique(host, port, client_cert_path, + client_key_path)) {} + +inline Client::~Client() = default; + +inline bool Client::is_valid() const { + return cli_ != nullptr && cli_->is_valid(); +} + +inline Result Client::Get(const std::string &path) { return cli_->Get(path); } +inline Result Client::Get(const std::string &path, const Headers &headers) { + return cli_->Get(path, headers); +} +inline Result Client::Get(const std::string &path, Progress progress) { + return cli_->Get(path, std::move(progress)); +} +inline Result Client::Get(const std::string &path, const Headers &headers, + Progress progress) { + return cli_->Get(path, headers, std::move(progress)); +} +inline Result Client::Get(const std::string &path, + ContentReceiver content_receiver) { + return cli_->Get(path, std::move(content_receiver)); +} +inline Result Client::Get(const std::string &path, const Headers &headers, + ContentReceiver content_receiver) { + return cli_->Get(path, headers, std::move(content_receiver)); +} +inline Result Client::Get(const std::string &path, + ContentReceiver content_receiver, Progress progress) { + return cli_->Get(path, std::move(content_receiver), std::move(progress)); +} +inline Result Client::Get(const std::string &path, const Headers &headers, + ContentReceiver content_receiver, Progress progress) { + return cli_->Get(path, headers, std::move(content_receiver), + std::move(progress)); +} +inline Result Client::Get(const std::string &path, + ResponseHandler response_handler, + ContentReceiver content_receiver) { + return cli_->Get(path, std::move(response_handler), + std::move(content_receiver)); +} +inline Result Client::Get(const std::string &path, const Headers &headers, + ResponseHandler response_handler, + ContentReceiver content_receiver) { + return cli_->Get(path, headers, std::move(response_handler), + std::move(content_receiver)); +} +inline Result Client::Get(const std::string &path, + ResponseHandler response_handler, + ContentReceiver content_receiver, Progress progress) { + return cli_->Get(path, std::move(response_handler), + std::move(content_receiver), std::move(progress)); +} +inline Result Client::Get(const std::string &path, const Headers &headers, + ResponseHandler response_handler, + ContentReceiver content_receiver, Progress progress) { + return cli_->Get(path, headers, std::move(response_handler), + std::move(content_receiver), std::move(progress)); +} +inline Result Client::Get(const std::string &path, const Params ¶ms, + const Headers &headers, Progress progress) { + return cli_->Get(path, params, headers, std::move(progress)); +} +inline Result Client::Get(const std::string &path, const Params ¶ms, + const Headers &headers, + ContentReceiver content_receiver, Progress progress) { + return cli_->Get(path, params, headers, std::move(content_receiver), + std::move(progress)); +} +inline Result Client::Get(const std::string &path, const Params ¶ms, + const Headers &headers, + ResponseHandler response_handler, + ContentReceiver content_receiver, Progress progress) { + return cli_->Get(path, params, headers, std::move(response_handler), + std::move(content_receiver), std::move(progress)); +} + +inline Result Client::Head(const std::string &path) { return cli_->Head(path); } +inline Result Client::Head(const std::string &path, const Headers &headers) { + return cli_->Head(path, headers); +} + +inline Result Client::Post(const std::string &path) { return cli_->Post(path); } +inline Result Client::Post(const std::string &path, const Headers &headers) { + return cli_->Post(path, headers); +} +inline Result Client::Post(const std::string &path, const char *body, + size_t content_length, + const std::string &content_type) { + return cli_->Post(path, body, content_length, content_type); +} +inline Result Client::Post(const std::string &path, const Headers &headers, + const char *body, size_t content_length, + const std::string &content_type) { + return cli_->Post(path, headers, body, content_length, content_type); +} +inline Result Client::Post(const std::string &path, const Headers &headers, + const char *body, size_t content_length, + const std::string &content_type, Progress progress) { + return cli_->Post(path, headers, body, content_length, content_type, + progress); +} +inline Result Client::Post(const std::string &path, const std::string &body, + const std::string &content_type) { + return cli_->Post(path, body, content_type); +} +inline Result Client::Post(const std::string &path, const std::string &body, + const std::string &content_type, Progress progress) { + return cli_->Post(path, body, content_type, progress); +} +inline Result Client::Post(const std::string &path, const Headers &headers, + const std::string &body, + const std::string &content_type) { + return cli_->Post(path, headers, body, content_type); +} +inline Result Client::Post(const std::string &path, const Headers &headers, + const std::string &body, + const std::string &content_type, Progress progress) { + return cli_->Post(path, headers, body, content_type, progress); +} +inline Result Client::Post(const std::string &path, size_t content_length, + ContentProvider content_provider, + const std::string &content_type) { + return cli_->Post(path, content_length, std::move(content_provider), + content_type); +} +inline Result Client::Post(const std::string &path, + ContentProviderWithoutLength content_provider, + const std::string &content_type) { + return cli_->Post(path, std::move(content_provider), content_type); +} +inline Result Client::Post(const std::string &path, const Headers &headers, + size_t content_length, + ContentProvider content_provider, + const std::string &content_type) { + return cli_->Post(path, headers, content_length, std::move(content_provider), + content_type); +} +inline Result Client::Post(const std::string &path, const Headers &headers, + ContentProviderWithoutLength content_provider, + const std::string &content_type) { + return cli_->Post(path, headers, std::move(content_provider), content_type); +} +inline Result Client::Post(const std::string &path, const Params ¶ms) { + return cli_->Post(path, params); +} +inline Result Client::Post(const std::string &path, const Headers &headers, + const Params ¶ms) { + return cli_->Post(path, headers, params); +} +inline Result Client::Post(const std::string &path, const Headers &headers, + const Params ¶ms, Progress progress) { + return cli_->Post(path, headers, params, progress); +} +inline Result Client::Post(const std::string &path, + const MultipartFormDataItems &items) { + return cli_->Post(path, items); +} +inline Result Client::Post(const std::string &path, const Headers &headers, + const MultipartFormDataItems &items) { + return cli_->Post(path, headers, items); +} +inline Result Client::Post(const std::string &path, const Headers &headers, + const MultipartFormDataItems &items, + const std::string &boundary) { + return cli_->Post(path, headers, items, boundary); +} +inline Result +Client::Post(const std::string &path, const Headers &headers, + const MultipartFormDataItems &items, + const MultipartFormDataProviderItems &provider_items) { + return cli_->Post(path, headers, items, provider_items); +} +inline Result Client::Put(const std::string &path) { return cli_->Put(path); } +inline Result Client::Put(const std::string &path, const char *body, + size_t content_length, + const std::string &content_type) { + return cli_->Put(path, body, content_length, content_type); +} +inline Result Client::Put(const std::string &path, const Headers &headers, + const char *body, size_t content_length, + const std::string &content_type) { + return cli_->Put(path, headers, body, content_length, content_type); +} +inline Result Client::Put(const std::string &path, const Headers &headers, + const char *body, size_t content_length, + const std::string &content_type, Progress progress) { + return cli_->Put(path, headers, body, content_length, content_type, progress); +} +inline Result Client::Put(const std::string &path, const std::string &body, + const std::string &content_type) { + return cli_->Put(path, body, content_type); +} +inline Result Client::Put(const std::string &path, const std::string &body, + const std::string &content_type, Progress progress) { + return cli_->Put(path, body, content_type, progress); +} +inline Result Client::Put(const std::string &path, const Headers &headers, + const std::string &body, + const std::string &content_type) { + return cli_->Put(path, headers, body, content_type); +} +inline Result Client::Put(const std::string &path, const Headers &headers, + const std::string &body, + const std::string &content_type, Progress progress) { + return cli_->Put(path, headers, body, content_type, progress); +} +inline Result Client::Put(const std::string &path, size_t content_length, + ContentProvider content_provider, + const std::string &content_type) { + return cli_->Put(path, content_length, std::move(content_provider), + content_type); +} +inline Result Client::Put(const std::string &path, + ContentProviderWithoutLength content_provider, + const std::string &content_type) { + return cli_->Put(path, std::move(content_provider), content_type); +} +inline Result Client::Put(const std::string &path, const Headers &headers, + size_t content_length, + ContentProvider content_provider, + const std::string &content_type) { + return cli_->Put(path, headers, content_length, std::move(content_provider), + content_type); +} +inline Result Client::Put(const std::string &path, const Headers &headers, + ContentProviderWithoutLength content_provider, + const std::string &content_type) { + return cli_->Put(path, headers, std::move(content_provider), content_type); +} +inline Result Client::Put(const std::string &path, const Params ¶ms) { + return cli_->Put(path, params); +} +inline Result Client::Put(const std::string &path, const Headers &headers, + const Params ¶ms) { + return cli_->Put(path, headers, params); +} +inline Result Client::Put(const std::string &path, const Headers &headers, + const Params ¶ms, Progress progress) { + return cli_->Put(path, headers, params, progress); +} +inline Result Client::Put(const std::string &path, + const MultipartFormDataItems &items) { + return cli_->Put(path, items); +} +inline Result Client::Put(const std::string &path, const Headers &headers, + const MultipartFormDataItems &items) { + return cli_->Put(path, headers, items); +} +inline Result Client::Put(const std::string &path, const Headers &headers, + const MultipartFormDataItems &items, + const std::string &boundary) { + return cli_->Put(path, headers, items, boundary); +} +inline Result +Client::Put(const std::string &path, const Headers &headers, + const MultipartFormDataItems &items, + const MultipartFormDataProviderItems &provider_items) { + return cli_->Put(path, headers, items, provider_items); +} +inline Result Client::Patch(const std::string &path) { + return cli_->Patch(path); +} +inline Result Client::Patch(const std::string &path, const char *body, + size_t content_length, + const std::string &content_type) { + return cli_->Patch(path, body, content_length, content_type); +} +inline Result Client::Patch(const std::string &path, const char *body, + size_t content_length, + const std::string &content_type, + Progress progress) { + return cli_->Patch(path, body, content_length, content_type, progress); +} +inline Result Client::Patch(const std::string &path, const Headers &headers, + const char *body, size_t content_length, + const std::string &content_type) { + return cli_->Patch(path, headers, body, content_length, content_type); +} +inline Result Client::Patch(const std::string &path, const Headers &headers, + const char *body, size_t content_length, + const std::string &content_type, + Progress progress) { + return cli_->Patch(path, headers, body, content_length, content_type, + progress); +} +inline Result Client::Patch(const std::string &path, const std::string &body, + const std::string &content_type) { + return cli_->Patch(path, body, content_type); +} +inline Result Client::Patch(const std::string &path, const std::string &body, + const std::string &content_type, + Progress progress) { + return cli_->Patch(path, body, content_type, progress); +} +inline Result Client::Patch(const std::string &path, const Headers &headers, + const std::string &body, + const std::string &content_type) { + return cli_->Patch(path, headers, body, content_type); +} +inline Result Client::Patch(const std::string &path, const Headers &headers, + const std::string &body, + const std::string &content_type, + Progress progress) { + return cli_->Patch(path, headers, body, content_type, progress); +} +inline Result Client::Patch(const std::string &path, size_t content_length, + ContentProvider content_provider, + const std::string &content_type) { + return cli_->Patch(path, content_length, std::move(content_provider), + content_type); +} +inline Result Client::Patch(const std::string &path, + ContentProviderWithoutLength content_provider, + const std::string &content_type) { + return cli_->Patch(path, std::move(content_provider), content_type); +} +inline Result Client::Patch(const std::string &path, const Headers &headers, + size_t content_length, + ContentProvider content_provider, + const std::string &content_type) { + return cli_->Patch(path, headers, content_length, std::move(content_provider), + content_type); +} +inline Result Client::Patch(const std::string &path, const Headers &headers, + ContentProviderWithoutLength content_provider, + const std::string &content_type) { + return cli_->Patch(path, headers, std::move(content_provider), content_type); +} +inline Result Client::Delete(const std::string &path) { + return cli_->Delete(path); +} +inline Result Client::Delete(const std::string &path, const Headers &headers) { + return cli_->Delete(path, headers); +} +inline Result Client::Delete(const std::string &path, const char *body, + size_t content_length, + const std::string &content_type) { + return cli_->Delete(path, body, content_length, content_type); +} +inline Result Client::Delete(const std::string &path, const char *body, + size_t content_length, + const std::string &content_type, + Progress progress) { + return cli_->Delete(path, body, content_length, content_type, progress); +} +inline Result Client::Delete(const std::string &path, const Headers &headers, + const char *body, size_t content_length, + const std::string &content_type) { + return cli_->Delete(path, headers, body, content_length, content_type); +} +inline Result Client::Delete(const std::string &path, const Headers &headers, + const char *body, size_t content_length, + const std::string &content_type, + Progress progress) { + return cli_->Delete(path, headers, body, content_length, content_type, + progress); +} +inline Result Client::Delete(const std::string &path, const std::string &body, + const std::string &content_type) { + return cli_->Delete(path, body, content_type); +} +inline Result Client::Delete(const std::string &path, const std::string &body, + const std::string &content_type, + Progress progress) { + return cli_->Delete(path, body, content_type, progress); +} +inline Result Client::Delete(const std::string &path, const Headers &headers, + const std::string &body, + const std::string &content_type) { + return cli_->Delete(path, headers, body, content_type); +} +inline Result Client::Delete(const std::string &path, const Headers &headers, + const std::string &body, + const std::string &content_type, + Progress progress) { + return cli_->Delete(path, headers, body, content_type, progress); +} +inline Result Client::Options(const std::string &path) { + return cli_->Options(path); +} +inline Result Client::Options(const std::string &path, const Headers &headers) { + return cli_->Options(path, headers); +} + +inline bool Client::send(Request &req, Response &res, Error &error) { + return cli_->send(req, res, error); +} + +inline Result Client::send(const Request &req) { return cli_->send(req); } + +inline void Client::stop() { cli_->stop(); } + +inline std::string Client::host() const { return cli_->host(); } + +inline int Client::port() const { return cli_->port(); } + +inline size_t Client::is_socket_open() const { return cli_->is_socket_open(); } + +inline socket_t Client::socket() const { return cli_->socket(); } + +inline void +Client::set_hostname_addr_map(std::map addr_map) { + cli_->set_hostname_addr_map(std::move(addr_map)); +} + +inline void Client::set_default_headers(Headers headers) { + cli_->set_default_headers(std::move(headers)); +} + +inline void Client::set_header_writer( + std::function const &writer) { + cli_->set_header_writer(writer); +} + +inline void Client::set_address_family(int family) { + cli_->set_address_family(family); +} + +inline void Client::set_tcp_nodelay(bool on) { cli_->set_tcp_nodelay(on); } + +inline void Client::set_socket_options(SocketOptions socket_options) { + cli_->set_socket_options(std::move(socket_options)); +} + +inline void Client::set_connection_timeout(time_t sec, time_t usec) { + cli_->set_connection_timeout(sec, usec); +} + +inline void Client::set_read_timeout(time_t sec, time_t usec) { + cli_->set_read_timeout(sec, usec); +} + +inline void Client::set_write_timeout(time_t sec, time_t usec) { + cli_->set_write_timeout(sec, usec); +} + +inline void Client::set_basic_auth(const std::string &username, + const std::string &password) { + cli_->set_basic_auth(username, password); +} +inline void Client::set_bearer_token_auth(const std::string &token) { + cli_->set_bearer_token_auth(token); +} +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +inline void Client::set_digest_auth(const std::string &username, + const std::string &password) { + cli_->set_digest_auth(username, password); +} +#endif + +inline void Client::set_keep_alive(bool on) { cli_->set_keep_alive(on); } +inline void Client::set_follow_location(bool on) { + cli_->set_follow_location(on); +} + +inline void Client::set_url_encode(bool on) { cli_->set_url_encode(on); } + +inline void Client::set_compress(bool on) { cli_->set_compress(on); } + +inline void Client::set_decompress(bool on) { cli_->set_decompress(on); } + +inline void Client::set_interface(const std::string &intf) { + cli_->set_interface(intf); +} + +inline void Client::set_proxy(const std::string &host, int port) { + cli_->set_proxy(host, port); +} +inline void Client::set_proxy_basic_auth(const std::string &username, + const std::string &password) { + cli_->set_proxy_basic_auth(username, password); +} +inline void Client::set_proxy_bearer_token_auth(const std::string &token) { + cli_->set_proxy_bearer_token_auth(token); +} +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +inline void Client::set_proxy_digest_auth(const std::string &username, + const std::string &password) { + cli_->set_proxy_digest_auth(username, password); +} +#endif + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +inline void Client::enable_server_certificate_verification(bool enabled) { + cli_->enable_server_certificate_verification(enabled); +} + +inline void Client::enable_server_hostname_verification(bool enabled) { + cli_->enable_server_hostname_verification(enabled); +} + +inline void Client::set_server_certificate_verifier( + std::function verifier) { + cli_->set_server_certificate_verifier(verifier); +} +#endif + +inline void Client::set_logger(Logger logger) { + cli_->set_logger(std::move(logger)); +} + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +inline void Client::set_ca_cert_path(const std::string &ca_cert_file_path, + const std::string &ca_cert_dir_path) { + cli_->set_ca_cert_path(ca_cert_file_path, ca_cert_dir_path); +} + +inline void Client::set_ca_cert_store(X509_STORE *ca_cert_store) { + if (is_ssl_) { + static_cast(*cli_).set_ca_cert_store(ca_cert_store); + } else { + cli_->set_ca_cert_store(ca_cert_store); + } +} + +inline void Client::load_ca_cert_store(const char *ca_cert, std::size_t size) { + set_ca_cert_store(cli_->create_ca_cert_store(ca_cert, size)); +} + +inline long Client::get_openssl_verify_result() const { + if (is_ssl_) { + return static_cast(*cli_).get_openssl_verify_result(); + } + return -1; // NOTE: -1 doesn't match any of X509_V_ERR_??? +} + +inline SSL_CTX *Client::ssl_context() const { + if (is_ssl_) { return static_cast(*cli_).ssl_context(); } + return nullptr; +} +#endif + +// ---------------------------------------------------------------------------- + +} // namespace httplib + +#if defined(_WIN32) && defined(CPPHTTPLIB_USE_POLL) +#undef poll +#endif + +#endif // CPPHTTPLIB_HTTPLIB_H diff --git a/libs/linenoise.c b/libs/linenoise.c new file mode 100644 index 0000000..574ab1f --- /dev/null +++ b/libs/linenoise.c @@ -0,0 +1,1349 @@ +/* linenoise.c -- guerrilla line editing library against the idea that a + * line editing lib needs to be 20,000 lines of C code. + * + * You can find the latest source code at: + * + * http://github.com/antirez/linenoise + * + * Does a number of crazy assumptions that happen to be true in 99.9999% of + * the 2010 UNIX computers around. + * + * ------------------------------------------------------------------------ + * + * Copyright (c) 2010-2023, Salvatore Sanfilippo + * Copyright (c) 2010-2013, Pieter Noordhuis + * + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * ------------------------------------------------------------------------ + * + * References: + * - http://invisible-island.net/xterm/ctlseqs/ctlseqs.html + * - http://www.3waylabs.com/nw/WWW/products/wizcon/vt220.html + * + * Todo list: + * - Filter bogus Ctrl+ combinations. + * - Win32 support + * + * Bloat: + * - History search like Ctrl+r in readline? + * + * List of escape sequences used by this program, we do everything just + * with three sequences. In order to be so cheap we may have some + * flickering effect with some slow terminal, but the lesser sequences + * the more compatible. + * + * EL (Erase Line) + * Sequence: ESC [ n K + * Effect: if n is 0 or missing, clear from cursor to end of line + * Effect: if n is 1, clear from beginning of line to cursor + * Effect: if n is 2, clear entire line + * + * CUF (CUrsor Forward) + * Sequence: ESC [ n C + * Effect: moves cursor forward n chars + * + * CUB (CUrsor Backward) + * Sequence: ESC [ n D + * Effect: moves cursor backward n chars + * + * The following is used to get the terminal width if getting + * the width with the TIOCGWINSZ ioctl fails + * + * DSR (Device Status Report) + * Sequence: ESC [ 6 n + * Effect: reports the current cusor position as ESC [ n ; m R + * where n is the row and m is the column + * + * When multi line mode is enabled, we also use an additional escape + * sequence. However multi line editing is disabled by default. + * + * CUU (Cursor Up) + * Sequence: ESC [ n A + * Effect: moves cursor up of n chars. + * + * CUD (Cursor Down) + * Sequence: ESC [ n B + * Effect: moves cursor down of n chars. + * + * When linenoiseClearScreen() is called, two additional escape sequences + * are used in order to clear the screen and position the cursor at home + * position. + * + * CUP (Cursor position) + * Sequence: ESC [ H + * Effect: moves the cursor to upper left corner + * + * ED (Erase display) + * Sequence: ESC [ 2 J + * Effect: clear the whole screen + * + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "linenoise.h" + +#define LINENOISE_DEFAULT_HISTORY_MAX_LEN 100 +#define LINENOISE_MAX_LINE 4096 +static char *unsupported_term[] = {"dumb","cons25","emacs",NULL}; +static linenoiseCompletionCallback *completionCallback = NULL; +static linenoiseHintsCallback *hintsCallback = NULL; +static linenoiseFreeHintsCallback *freeHintsCallback = NULL; +static char *linenoiseNoTTY(void); +static void refreshLineWithCompletion(struct linenoiseState *ls, linenoiseCompletions *lc, int flags); +static void refreshLineWithFlags(struct linenoiseState *l, int flags); + +static struct termios orig_termios; /* In order to restore at exit.*/ +static int maskmode = 0; /* Show "***" instead of input. For passwords. */ +static int rawmode = 0; /* For atexit() function to check if restore is needed*/ +static int mlmode = 0; /* Multi line mode. Default is single line. */ +static int atexit_registered = 0; /* Register atexit just 1 time. */ +static int history_max_len = LINENOISE_DEFAULT_HISTORY_MAX_LEN; +static int history_len = 0; +static char **history = NULL; + +enum KEY_ACTION{ + KEY_NULL = 0, /* NULL */ + CTRL_A = 1, /* Ctrl+a */ + CTRL_B = 2, /* Ctrl-b */ + CTRL_C = 3, /* Ctrl-c */ + CTRL_D = 4, /* Ctrl-d */ + CTRL_E = 5, /* Ctrl-e */ + CTRL_F = 6, /* Ctrl-f */ + CTRL_H = 8, /* Ctrl-h */ + TAB = 9, /* Tab */ + CTRL_K = 11, /* Ctrl+k */ + CTRL_L = 12, /* Ctrl+l */ + ENTER = 13, /* Enter */ + CTRL_N = 14, /* Ctrl-n */ + CTRL_P = 16, /* Ctrl-p */ + CTRL_T = 20, /* Ctrl-t */ + CTRL_U = 21, /* Ctrl+u */ + CTRL_W = 23, /* Ctrl+w */ + ESC = 27, /* Escape */ + BACKSPACE = 127 /* Backspace */ +}; + +static void linenoiseAtExit(void); +int linenoiseHistoryAdd(const char *line); +#define REFRESH_CLEAN (1<<0) // Clean the old prompt from the screen +#define REFRESH_WRITE (1<<1) // Rewrite the prompt on the screen. +#define REFRESH_ALL (REFRESH_CLEAN|REFRESH_WRITE) // Do both. +static void refreshLine(struct linenoiseState *l); + +/* Debugging macro. */ +#if 0 +FILE *lndebug_fp = NULL; +#define lndebug(...) \ + do { \ + if (lndebug_fp == NULL) { \ + lndebug_fp = fopen("/tmp/lndebug.txt","a"); \ + fprintf(lndebug_fp, \ + "[%d %d %d] p: %d, rows: %d, rpos: %d, max: %d, oldmax: %d\n", \ + (int)l->len,(int)l->pos,(int)l->oldpos,plen,rows,rpos, \ + (int)l->oldrows,old_rows); \ + } \ + fprintf(lndebug_fp, ", " __VA_ARGS__); \ + fflush(lndebug_fp); \ + } while (0) +#else +#define lndebug(fmt, ...) +#endif + +/* ======================= Low level terminal handling ====================== */ + +/* Enable "mask mode". When it is enabled, instead of the input that + * the user is typing, the terminal will just display a corresponding + * number of asterisks, like "****". This is useful for passwords and other + * secrets that should not be displayed. */ +void linenoiseMaskModeEnable(void) { + maskmode = 1; +} + +/* Disable mask mode. */ +void linenoiseMaskModeDisable(void) { + maskmode = 0; +} + +/* Set if to use or not the multi line mode. */ +void linenoiseSetMultiLine(int ml) { + mlmode = ml; +} + +/* Return true if the terminal name is in the list of terminals we know are + * not able to understand basic escape sequences. */ +static int isUnsupportedTerm(void) { + char *term = getenv("TERM"); + int j; + + if (term == NULL) return 0; + for (j = 0; unsupported_term[j]; j++) + if (!strcasecmp(term,unsupported_term[j])) return 1; + return 0; +} + +/* Raw mode: 1960 magic shit. */ +static int enableRawMode(int fd) { + struct termios raw; + + if (!isatty(STDIN_FILENO)) goto fatal; + if (!atexit_registered) { + atexit(linenoiseAtExit); + atexit_registered = 1; + } + if (tcgetattr(fd,&orig_termios) == -1) goto fatal; + + raw = orig_termios; /* modify the original mode */ + /* input modes: no break, no CR to NL, no parity check, no strip char, + * no start/stop output control. */ + raw.c_iflag &= ~(BRKINT | ICRNL | INPCK | ISTRIP | IXON); + /* output modes - disable post processing */ + raw.c_oflag &= ~(OPOST); + /* control modes - set 8 bit chars */ + raw.c_cflag |= (CS8); + /* local modes - choing off, canonical off, no extended functions, + * no signal chars (^Z,^C) */ + raw.c_lflag &= ~(ECHO | ICANON | IEXTEN | ISIG); + /* control chars - set return condition: min number of bytes and timer. + * We want read to return every single byte, without timeout. */ + raw.c_cc[VMIN] = 1; raw.c_cc[VTIME] = 0; /* 1 byte, no timer */ + + /* put terminal in raw mode after flushing */ + if (tcsetattr(fd,TCSAFLUSH,&raw) < 0) goto fatal; + rawmode = 1; + return 0; + +fatal: + errno = ENOTTY; + return -1; +} + +static void disableRawMode(int fd) { + /* Don't even check the return value as it's too late. */ + if (rawmode && tcsetattr(fd,TCSAFLUSH,&orig_termios) != -1) + rawmode = 0; +} + +/* Use the ESC [6n escape sequence to query the horizontal cursor position + * and return it. On error -1 is returned, on success the position of the + * cursor. */ +static int getCursorPosition(int ifd, int ofd) { + char buf[32]; + int cols, rows; + unsigned int i = 0; + + /* Report cursor location */ + if (write(ofd, "\x1b[6n", 4) != 4) return -1; + + /* Read the response: ESC [ rows ; cols R */ + while (i < sizeof(buf)-1) { + if (read(ifd,buf+i,1) != 1) break; + if (buf[i] == 'R') break; + i++; + } + buf[i] = '\0'; + + /* Parse it. */ + if (buf[0] != ESC || buf[1] != '[') return -1; + if (sscanf(buf+2,"%d;%d",&rows,&cols) != 2) return -1; + return cols; +} + +/* Try to get the number of columns in the current terminal, or assume 80 + * if it fails. */ +static int getColumns(int ifd, int ofd) { + struct winsize ws; + + if (ioctl(1, TIOCGWINSZ, &ws) == -1 || ws.ws_col == 0) { + /* ioctl() failed. Try to query the terminal itself. */ + int start, cols; + + /* Get the initial position so we can restore it later. */ + start = getCursorPosition(ifd,ofd); + if (start == -1) goto failed; + + /* Go to right margin and get position. */ + if (write(ofd,"\x1b[999C",6) != 6) goto failed; + cols = getCursorPosition(ifd,ofd); + if (cols == -1) goto failed; + + /* Restore position. */ + if (cols > start) { + char seq[32]; + snprintf(seq,32,"\x1b[%dD",cols-start); + if (write(ofd,seq,strlen(seq)) == -1) { + /* Can't recover... */ + } + } + return cols; + } else { + return ws.ws_col; + } + +failed: + return 80; +} + +/* Clear the screen. Used to handle ctrl+l */ +void linenoiseClearScreen(void) { + if (write(STDOUT_FILENO,"\x1b[H\x1b[2J",7) <= 0) { + /* nothing to do, just to avoid warning. */ + } +} + +/* Beep, used for completion when there is nothing to complete or when all + * the choices were already shown. */ +static void linenoiseBeep(void) { + fprintf(stderr, "\x7"); + fflush(stderr); +} + +/* ============================== Completion ================================ */ + +/* Free a list of completion option populated by linenoiseAddCompletion(). */ +static void freeCompletions(linenoiseCompletions *lc) { + size_t i; + for (i = 0; i < lc->len; i++) + free(lc->cvec[i]); + if (lc->cvec != NULL) + free(lc->cvec); +} + +/* Called by completeLine() and linenoiseShow() to render the current + * edited line with the proposed completion. If the current completion table + * is already available, it is passed as second argument, otherwise the + * function will use the callback to obtain it. + * + * Flags are the same as refreshLine*(), that is REFRESH_* macros. */ +static void refreshLineWithCompletion(struct linenoiseState *ls, linenoiseCompletions *lc, int flags) { + /* Obtain the table of completions if the caller didn't provide one. */ + linenoiseCompletions ctable = { 0, NULL }; + if (lc == NULL) { + completionCallback(ls->buf,&ctable); + lc = &ctable; + } + + /* Show the edited line with completion if possible, or just refresh. */ + if (ls->completion_idx < lc->len) { + struct linenoiseState saved = *ls; + ls->len = ls->pos = strlen(lc->cvec[ls->completion_idx]); + ls->buf = lc->cvec[ls->completion_idx]; + refreshLineWithFlags(ls,flags); + ls->len = saved.len; + ls->pos = saved.pos; + ls->buf = saved.buf; + } else { + refreshLineWithFlags(ls,flags); + } + + /* Free the completions table if needed. */ + if (lc != &ctable) freeCompletions(&ctable); +} + +/* This is an helper function for linenoiseEdit*() and is called when the + * user types the key in order to complete the string currently in the + * input. + * + * The state of the editing is encapsulated into the pointed linenoiseState + * structure as described in the structure definition. + * + * If the function returns non-zero, the caller should handle the + * returned value as a byte read from the standard input, and process + * it as usually: this basically means that the function may return a byte + * read from the termianl but not processed. Otherwise, if zero is returned, + * the input was consumed by the completeLine() function to navigate the + * possible completions, and the caller should read for the next characters + * from stdin. */ +static int completeLine(struct linenoiseState *ls, int keypressed) { + linenoiseCompletions lc = { 0, NULL }; + int nwritten; + char c = keypressed; + + completionCallback(ls->buf,&lc); + if (lc.len == 0) { + linenoiseBeep(); + ls->in_completion = 0; + } else { + switch(c) { + case 9: /* tab */ + if (ls->in_completion == 0) { + ls->in_completion = 1; + ls->completion_idx = 0; + } else { + ls->completion_idx = (ls->completion_idx+1) % (lc.len+1); + if (ls->completion_idx == lc.len) linenoiseBeep(); + } + c = 0; + break; + case 27: /* escape */ + /* Re-show original buffer */ + if (ls->completion_idx < lc.len) refreshLine(ls); + ls->in_completion = 0; + c = 0; + break; + default: + /* Update buffer and return */ + if (ls->completion_idx < lc.len) { + nwritten = snprintf(ls->buf,ls->buflen,"%s", + lc.cvec[ls->completion_idx]); + ls->len = ls->pos = nwritten; + } + ls->in_completion = 0; + break; + } + + /* Show completion or original buffer */ + if (ls->in_completion && ls->completion_idx < lc.len) { + refreshLineWithCompletion(ls,&lc,REFRESH_ALL); + } else { + refreshLine(ls); + } + } + + freeCompletions(&lc); + return c; /* Return last read character */ +} + +/* Register a callback function to be called for tab-completion. */ +void linenoiseSetCompletionCallback(linenoiseCompletionCallback *fn) { + completionCallback = fn; +} + +/* Register a hits function to be called to show hits to the user at the + * right of the prompt. */ +void linenoiseSetHintsCallback(linenoiseHintsCallback *fn) { + hintsCallback = fn; +} + +/* Register a function to free the hints returned by the hints callback + * registered with linenoiseSetHintsCallback(). */ +void linenoiseSetFreeHintsCallback(linenoiseFreeHintsCallback *fn) { + freeHintsCallback = fn; +} + +/* This function is used by the callback function registered by the user + * in order to add completion options given the input string when the + * user typed . See the example.c source code for a very easy to + * understand example. */ +void linenoiseAddCompletion(linenoiseCompletions *lc, const char *str) { + size_t len = strlen(str); + char *copy, **cvec; + + copy = malloc(len+1); + if (copy == NULL) return; + memcpy(copy,str,len+1); + cvec = realloc(lc->cvec,sizeof(char*)*(lc->len+1)); + if (cvec == NULL) { + free(copy); + return; + } + lc->cvec = cvec; + lc->cvec[lc->len++] = copy; +} + +/* =========================== Line editing ================================= */ + +/* We define a very simple "append buffer" structure, that is an heap + * allocated string where we can append to. This is useful in order to + * write all the escape sequences in a buffer and flush them to the standard + * output in a single call, to avoid flickering effects. */ +struct abuf { + char *b; + int len; +}; + +static void abInit(struct abuf *ab) { + ab->b = NULL; + ab->len = 0; +} + +static void abAppend(struct abuf *ab, const char *s, int len) { + char *new = realloc(ab->b,ab->len+len); + + if (new == NULL) return; + memcpy(new+ab->len,s,len); + ab->b = new; + ab->len += len; +} + +static void abFree(struct abuf *ab) { + free(ab->b); +} + +/* Helper of refreshSingleLine() and refreshMultiLine() to show hints + * to the right of the prompt. */ +void refreshShowHints(struct abuf *ab, struct linenoiseState *l, int plen) { + char seq[64]; + if (hintsCallback && plen+l->len < l->cols) { + int color = -1, bold = 0; + char *hint = hintsCallback(l->buf,&color,&bold); + if (hint) { + int hintlen = strlen(hint); + int hintmaxlen = l->cols-(plen+l->len); + if (hintlen > hintmaxlen) hintlen = hintmaxlen; + if (bold == 1 && color == -1) color = 37; + if (color != -1 || bold != 0) + snprintf(seq,64,"\033[%d;%d;49m",bold,color); + else + seq[0] = '\0'; + abAppend(ab,seq,strlen(seq)); + abAppend(ab,hint,hintlen); + if (color != -1 || bold != 0) + abAppend(ab,"\033[0m",4); + /* Call the function to free the hint returned. */ + if (freeHintsCallback) freeHintsCallback(hint); + } + } +} + +/* Single line low level line refresh. + * + * Rewrite the currently edited line accordingly to the buffer content, + * cursor position, and number of columns of the terminal. + * + * Flags is REFRESH_* macros. The function can just remove the old + * prompt, just write it, or both. */ +static void refreshSingleLine(struct linenoiseState *l, int flags) { + char seq[64]; + size_t plen = strlen(l->prompt); + int fd = l->ofd; + char *buf = l->buf; + size_t len = l->len; + size_t pos = l->pos; + struct abuf ab; + + while((plen+pos) >= l->cols) { + buf++; + len--; + pos--; + } + while (plen+len > l->cols) { + len--; + } + + abInit(&ab); + /* Cursor to left edge */ + snprintf(seq,sizeof(seq),"\r"); + abAppend(&ab,seq,strlen(seq)); + + if (flags & REFRESH_WRITE) { + /* Write the prompt and the current buffer content */ + abAppend(&ab,l->prompt,strlen(l->prompt)); + if (maskmode == 1) { + while (len--) abAppend(&ab,"*",1); + } else { + abAppend(&ab,buf,len); + } + /* Show hits if any. */ + refreshShowHints(&ab,l,plen); + } + + /* Erase to right */ + snprintf(seq,sizeof(seq),"\x1b[0K"); + abAppend(&ab,seq,strlen(seq)); + + if (flags & REFRESH_WRITE) { + /* Move cursor to original position. */ + snprintf(seq,sizeof(seq),"\r\x1b[%dC", (int)(pos+plen)); + abAppend(&ab,seq,strlen(seq)); + } + + if (write(fd,ab.b,ab.len) == -1) {} /* Can't recover from write error. */ + abFree(&ab); +} + +/* Multi line low level line refresh. + * + * Rewrite the currently edited line accordingly to the buffer content, + * cursor position, and number of columns of the terminal. + * + * Flags is REFRESH_* macros. The function can just remove the old + * prompt, just write it, or both. */ +static void refreshMultiLine(struct linenoiseState *l, int flags) { + char seq[64]; + int plen = strlen(l->prompt); + int rows = (plen+l->len+l->cols-1)/l->cols; /* rows used by current buf. */ + int rpos = (plen+l->oldpos+l->cols)/l->cols; /* cursor relative row. */ + int rpos2; /* rpos after refresh. */ + int col; /* colum position, zero-based. */ + int old_rows = l->oldrows; + int fd = l->ofd, j; + struct abuf ab; + + l->oldrows = rows; + + /* First step: clear all the lines used before. To do so start by + * going to the last row. */ + abInit(&ab); + + if (flags & REFRESH_CLEAN) { + if (old_rows-rpos > 0) { + lndebug("go down %d", old_rows-rpos); + snprintf(seq,64,"\x1b[%dB", old_rows-rpos); + abAppend(&ab,seq,strlen(seq)); + } + + /* Now for every row clear it, go up. */ + for (j = 0; j < old_rows-1; j++) { + lndebug("clear+up"); + snprintf(seq,64,"\r\x1b[0K\x1b[1A"); + abAppend(&ab,seq,strlen(seq)); + } + } + + if (flags & REFRESH_ALL) { + /* Clean the top line. */ + lndebug("clear"); + snprintf(seq,64,"\r\x1b[0K"); + abAppend(&ab,seq,strlen(seq)); + } + + if (flags & REFRESH_WRITE) { + /* Write the prompt and the current buffer content */ + abAppend(&ab,l->prompt,strlen(l->prompt)); + if (maskmode == 1) { + unsigned int i; + for (i = 0; i < l->len; i++) abAppend(&ab,"*",1); + } else { + abAppend(&ab,l->buf,l->len); + } + + /* Show hits if any. */ + refreshShowHints(&ab,l,plen); + + /* If we are at the very end of the screen with our prompt, we need to + * emit a newline and move the prompt to the first column. */ + if (l->pos && + l->pos == l->len && + (l->pos+plen) % l->cols == 0) + { + lndebug(""); + abAppend(&ab,"\n",1); + snprintf(seq,64,"\r"); + abAppend(&ab,seq,strlen(seq)); + rows++; + if (rows > (int)l->oldrows) l->oldrows = rows; + } + + /* Move cursor to right position. */ + rpos2 = (plen+l->pos+l->cols)/l->cols; /* Current cursor relative row */ + lndebug("rpos2 %d", rpos2); + + /* Go up till we reach the expected positon. */ + if (rows-rpos2 > 0) { + lndebug("go-up %d", rows-rpos2); + snprintf(seq,64,"\x1b[%dA", rows-rpos2); + abAppend(&ab,seq,strlen(seq)); + } + + /* Set column. */ + col = (plen+(int)l->pos) % (int)l->cols; + lndebug("set col %d", 1+col); + if (col) + snprintf(seq,64,"\r\x1b[%dC", col); + else + snprintf(seq,64,"\r"); + abAppend(&ab,seq,strlen(seq)); + } + + lndebug("\n"); + l->oldpos = l->pos; + + if (write(fd,ab.b,ab.len) == -1) {} /* Can't recover from write error. */ + abFree(&ab); +} + +/* Calls the two low level functions refreshSingleLine() or + * refreshMultiLine() according to the selected mode. */ +static void refreshLineWithFlags(struct linenoiseState *l, int flags) { + if (mlmode) + refreshMultiLine(l,flags); + else + refreshSingleLine(l,flags); +} + +/* Utility function to avoid specifying REFRESH_ALL all the times. */ +static void refreshLine(struct linenoiseState *l) { + refreshLineWithFlags(l,REFRESH_ALL); +} + +/* Hide the current line, when using the multiplexing API. */ +void linenoiseHide(struct linenoiseState *l) { + if (mlmode) + refreshMultiLine(l,REFRESH_CLEAN); + else + refreshSingleLine(l,REFRESH_CLEAN); +} + +/* Show the current line, when using the multiplexing API. */ +void linenoiseShow(struct linenoiseState *l) { + if (l->in_completion) { + refreshLineWithCompletion(l,NULL,REFRESH_WRITE); + } else { + refreshLineWithFlags(l,REFRESH_WRITE); + } +} + +/* Insert the character 'c' at cursor current position. + * + * On error writing to the terminal -1 is returned, otherwise 0. */ +int linenoiseEditInsert(struct linenoiseState *l, char c) { + if (l->len < l->buflen) { + if (l->len == l->pos) { + l->buf[l->pos] = c; + l->pos++; + l->len++; + l->buf[l->len] = '\0'; + if ((!mlmode && l->plen+l->len < l->cols && !hintsCallback)) { + /* Avoid a full update of the line in the + * trivial case. */ + char d = (maskmode==1) ? '*' : c; + if (write(l->ofd,&d,1) == -1) return -1; + } else { + refreshLine(l); + } + } else { + memmove(l->buf+l->pos+1,l->buf+l->pos,l->len-l->pos); + l->buf[l->pos] = c; + l->len++; + l->pos++; + l->buf[l->len] = '\0'; + refreshLine(l); + } + } + return 0; +} + +/* Move cursor on the left. */ +void linenoiseEditMoveLeft(struct linenoiseState *l) { + if (l->pos > 0) { + l->pos--; + refreshLine(l); + } +} + +/* Move cursor on the right. */ +void linenoiseEditMoveRight(struct linenoiseState *l) { + if (l->pos != l->len) { + l->pos++; + refreshLine(l); + } +} + +/* Move cursor to the start of the line. */ +void linenoiseEditMoveHome(struct linenoiseState *l) { + if (l->pos != 0) { + l->pos = 0; + refreshLine(l); + } +} + +/* Move cursor to the end of the line. */ +void linenoiseEditMoveEnd(struct linenoiseState *l) { + if (l->pos != l->len) { + l->pos = l->len; + refreshLine(l); + } +} + +/* Substitute the currently edited line with the next or previous history + * entry as specified by 'dir'. */ +#define LINENOISE_HISTORY_NEXT 0 +#define LINENOISE_HISTORY_PREV 1 +void linenoiseEditHistoryNext(struct linenoiseState *l, int dir) { + if (history_len > 1) { + /* Update the current history entry before to + * overwrite it with the next one. */ + free(history[history_len - 1 - l->history_index]); + history[history_len - 1 - l->history_index] = strdup(l->buf); + /* Show the new entry */ + l->history_index += (dir == LINENOISE_HISTORY_PREV) ? 1 : -1; + if (l->history_index < 0) { + l->history_index = 0; + return; + } else if (l->history_index >= history_len) { + l->history_index = history_len-1; + return; + } + strncpy(l->buf,history[history_len - 1 - l->history_index],l->buflen); + l->buf[l->buflen-1] = '\0'; + l->len = l->pos = strlen(l->buf); + refreshLine(l); + } +} + +/* Delete the character at the right of the cursor without altering the cursor + * position. Basically this is what happens with the "Delete" keyboard key. */ +void linenoiseEditDelete(struct linenoiseState *l) { + if (l->len > 0 && l->pos < l->len) { + memmove(l->buf+l->pos,l->buf+l->pos+1,l->len-l->pos-1); + l->len--; + l->buf[l->len] = '\0'; + refreshLine(l); + } +} + +/* Backspace implementation. */ +void linenoiseEditBackspace(struct linenoiseState *l) { + if (l->pos > 0 && l->len > 0) { + memmove(l->buf+l->pos-1,l->buf+l->pos,l->len-l->pos); + l->pos--; + l->len--; + l->buf[l->len] = '\0'; + refreshLine(l); + } +} + +/* Delete the previosu word, maintaining the cursor at the start of the + * current word. */ +void linenoiseEditDeletePrevWord(struct linenoiseState *l) { + size_t old_pos = l->pos; + size_t diff; + + while (l->pos > 0 && l->buf[l->pos-1] == ' ') + l->pos--; + while (l->pos > 0 && l->buf[l->pos-1] != ' ') + l->pos--; + diff = old_pos - l->pos; + memmove(l->buf+l->pos,l->buf+old_pos,l->len-old_pos+1); + l->len -= diff; + refreshLine(l); +} + +/* This function is part of the multiplexed API of Linenoise, that is used + * in order to implement the blocking variant of the API but can also be + * called by the user directly in an event driven program. It will: + * + * 1. Initialize the linenoise state passed by the user. + * 2. Put the terminal in RAW mode. + * 3. Show the prompt. + * 4. Return control to the user, that will have to call linenoiseEditFeed() + * each time there is some data arriving in the standard input. + * + * The user can also call linenoiseEditHide() and linenoiseEditShow() if it + * is required to show some input arriving asyncronously, without mixing + * it with the currently edited line. + * + * When linenoiseEditFeed() returns non-NULL, the user finished with the + * line editing session (pressed enter CTRL-D/C): in this case the caller + * needs to call linenoiseEditStop() to put back the terminal in normal + * mode. This will not destroy the buffer, as long as the linenoiseState + * is still valid in the context of the caller. + * + * The function returns 0 on success, or -1 if writing to standard output + * fails. If stdin_fd or stdout_fd are set to -1, the default is to use + * STDIN_FILENO and STDOUT_FILENO. + */ +int linenoiseEditStart(struct linenoiseState *l, int stdin_fd, int stdout_fd, char *buf, size_t buflen, const char *prompt) { + /* Populate the linenoise state that we pass to functions implementing + * specific editing functionalities. */ + l->in_completion = 0; + l->ifd = stdin_fd != -1 ? stdin_fd : STDIN_FILENO; + l->ofd = stdout_fd != -1 ? stdout_fd : STDOUT_FILENO; + l->buf = buf; + l->buflen = buflen; + l->prompt = prompt; + l->plen = strlen(prompt); + l->oldpos = l->pos = 0; + l->len = 0; + + /* Enter raw mode. */ + if (enableRawMode(l->ifd) == -1) return -1; + + l->cols = getColumns(stdin_fd, stdout_fd); + l->oldrows = 0; + l->history_index = 0; + + /* Buffer starts empty. */ + l->buf[0] = '\0'; + l->buflen--; /* Make sure there is always space for the nulterm */ + + /* If stdin is not a tty, stop here with the initialization. We + * will actually just read a line from standard input in blocking + * mode later, in linenoiseEditFeed(). */ + if (!isatty(l->ifd)) return 0; + + /* The latest history entry is always our current buffer, that + * initially is just an empty string. */ + linenoiseHistoryAdd(""); + + if (write(l->ofd,prompt,l->plen) == -1) return -1; + return 0; +} + +char *linenoiseEditMore = "If you see this, you are misusing the API: when linenoiseEditFeed() is called, if it returns linenoiseEditMore the user is yet editing the line. See the README file for more information."; + +/* This function is part of the multiplexed API of linenoise, see the top + * comment on linenoiseEditStart() for more information. Call this function + * each time there is some data to read from the standard input file + * descriptor. In the case of blocking operations, this function can just be + * called in a loop, and block. + * + * The function returns linenoiseEditMore to signal that line editing is still + * in progress, that is, the user didn't yet pressed enter / CTRL-D. Otherwise + * the function returns the pointer to the heap-allocated buffer with the + * edited line, that the user should free with linenoiseFree(). + * + * On special conditions, NULL is returned and errno is populated: + * + * EAGAIN if the user pressed Ctrl-C + * ENOENT if the user pressed Ctrl-D + * + * Some other errno: I/O error. + */ +char *linenoiseEditFeed(struct linenoiseState *l) { + /* Not a TTY, pass control to line reading without character + * count limits. */ + if (!isatty(l->ifd)) return linenoiseNoTTY(); + + char c; + int nread; + char seq[3]; + + nread = read(l->ifd,&c,1); + if (nread <= 0) return NULL; + + /* Only autocomplete when the callback is set. It returns < 0 when + * there was an error reading from fd. Otherwise it will return the + * character that should be handled next. */ + if ((l->in_completion || c == 9) && completionCallback != NULL) { + c = completeLine(l,c); + /* Return on errors */ + if (c < 0) return NULL; + /* Read next character when 0 */ + if (c == 0) return linenoiseEditMore; + } + + switch(c) { + case ENTER: /* enter */ + history_len--; + free(history[history_len]); + if (mlmode) linenoiseEditMoveEnd(l); + if (hintsCallback) { + /* Force a refresh without hints to leave the previous + * line as the user typed it after a newline. */ + linenoiseHintsCallback *hc = hintsCallback; + hintsCallback = NULL; + refreshLine(l); + hintsCallback = hc; + } + return strdup(l->buf); + case CTRL_C: /* ctrl-c */ + errno = EAGAIN; + return NULL; + case BACKSPACE: /* backspace */ + case 8: /* ctrl-h */ + linenoiseEditBackspace(l); + break; + case CTRL_D: /* ctrl-d, remove char at right of cursor, or if the + line is empty, act as end-of-file. */ + if (l->len > 0) { + linenoiseEditDelete(l); + } else { + history_len--; + free(history[history_len]); + errno = ENOENT; + return NULL; + } + break; + case CTRL_T: /* ctrl-t, swaps current character with previous. */ + if (l->pos > 0 && l->pos < l->len) { + int aux = l->buf[l->pos-1]; + l->buf[l->pos-1] = l->buf[l->pos]; + l->buf[l->pos] = aux; + if (l->pos != l->len-1) l->pos++; + refreshLine(l); + } + break; + case CTRL_B: /* ctrl-b */ + linenoiseEditMoveLeft(l); + break; + case CTRL_F: /* ctrl-f */ + linenoiseEditMoveRight(l); + break; + case CTRL_P: /* ctrl-p */ + linenoiseEditHistoryNext(l, LINENOISE_HISTORY_PREV); + break; + case CTRL_N: /* ctrl-n */ + linenoiseEditHistoryNext(l, LINENOISE_HISTORY_NEXT); + break; + case ESC: /* escape sequence */ + /* Read the next two bytes representing the escape sequence. + * Use two calls to handle slow terminals returning the two + * chars at different times. */ + if (read(l->ifd,seq,1) == -1) break; + if (read(l->ifd,seq+1,1) == -1) break; + + /* ESC [ sequences. */ + if (seq[0] == '[') { + if (seq[1] >= '0' && seq[1] <= '9') { + /* Extended escape, read additional byte. */ + if (read(l->ifd,seq+2,1) == -1) break; + if (seq[2] == '~') { + switch(seq[1]) { + case '3': /* Delete key. */ + linenoiseEditDelete(l); + break; + } + } + } else { + switch(seq[1]) { + case 'A': /* Up */ + linenoiseEditHistoryNext(l, LINENOISE_HISTORY_PREV); + break; + case 'B': /* Down */ + linenoiseEditHistoryNext(l, LINENOISE_HISTORY_NEXT); + break; + case 'C': /* Right */ + linenoiseEditMoveRight(l); + break; + case 'D': /* Left */ + linenoiseEditMoveLeft(l); + break; + case 'H': /* Home */ + linenoiseEditMoveHome(l); + break; + case 'F': /* End*/ + linenoiseEditMoveEnd(l); + break; + } + } + } + + /* ESC O sequences. */ + else if (seq[0] == 'O') { + switch(seq[1]) { + case 'H': /* Home */ + linenoiseEditMoveHome(l); + break; + case 'F': /* End*/ + linenoiseEditMoveEnd(l); + break; + } + } + break; + default: + if (linenoiseEditInsert(l,c)) return NULL; + break; + case CTRL_U: /* Ctrl+u, delete the whole line. */ + l->buf[0] = '\0'; + l->pos = l->len = 0; + refreshLine(l); + break; + case CTRL_K: /* Ctrl+k, delete from current to end of line. */ + l->buf[l->pos] = '\0'; + l->len = l->pos; + refreshLine(l); + break; + case CTRL_A: /* Ctrl+a, go to the start of the line */ + linenoiseEditMoveHome(l); + break; + case CTRL_E: /* ctrl+e, go to the end of the line */ + linenoiseEditMoveEnd(l); + break; + case CTRL_L: /* ctrl+l, clear screen */ + linenoiseClearScreen(); + refreshLine(l); + break; + case CTRL_W: /* ctrl+w, delete previous word */ + linenoiseEditDeletePrevWord(l); + break; + } + return linenoiseEditMore; +} + +/* This is part of the multiplexed linenoise API. See linenoiseEditStart() + * for more information. This function is called when linenoiseEditFeed() + * returns something different than NULL. At this point the user input + * is in the buffer, and we can restore the terminal in normal mode. */ +void linenoiseEditStop(struct linenoiseState *l) { + if (!isatty(l->ifd)) return; + disableRawMode(l->ifd); + printf("\n"); +} + +/* This just implements a blocking loop for the multiplexed API. + * In many applications that are not event-drivern, we can just call + * the blocking linenoise API, wait for the user to complete the editing + * and return the buffer. */ +static char *linenoiseBlockingEdit(int stdin_fd, int stdout_fd, char *buf, size_t buflen, const char *prompt) +{ + struct linenoiseState l; + + /* Editing without a buffer is invalid. */ + if (buflen == 0) { + errno = EINVAL; + return NULL; + } + + linenoiseEditStart(&l,stdin_fd,stdout_fd,buf,buflen,prompt); + char *res; + while((res = linenoiseEditFeed(&l)) == linenoiseEditMore); + linenoiseEditStop(&l); + return res; +} + +/* This special mode is used by linenoise in order to print scan codes + * on screen for debugging / development purposes. It is implemented + * by the linenoise_example program using the --keycodes option. */ +void linenoisePrintKeyCodes(void) { + char quit[4]; + + printf("Linenoise key codes debugging mode.\n" + "Press keys to see scan codes. Type 'quit' at any time to exit.\n"); + if (enableRawMode(STDIN_FILENO) == -1) return; + memset(quit,' ',4); + while(1) { + char c; + int nread; + + nread = read(STDIN_FILENO,&c,1); + if (nread <= 0) continue; + memmove(quit,quit+1,sizeof(quit)-1); /* shift string to left. */ + quit[sizeof(quit)-1] = c; /* Insert current char on the right. */ + if (memcmp(quit,"quit",sizeof(quit)) == 0) break; + + printf("'%c' %02x (%d) (type quit to exit)\n", + isprint(c) ? c : '?', (int)c, (int)c); + printf("\r"); /* Go left edge manually, we are in raw mode. */ + fflush(stdout); + } + disableRawMode(STDIN_FILENO); +} + +/* This function is called when linenoise() is called with the standard + * input file descriptor not attached to a TTY. So for example when the + * program using linenoise is called in pipe or with a file redirected + * to its standard input. In this case, we want to be able to return the + * line regardless of its length (by default we are limited to 4k). */ +static char *linenoiseNoTTY(void) { + char *line = NULL; + size_t len = 0, maxlen = 0; + + while(1) { + if (len == maxlen) { + if (maxlen == 0) maxlen = 16; + maxlen *= 2; + char *oldval = line; + line = realloc(line,maxlen); + if (line == NULL) { + if (oldval) free(oldval); + return NULL; + } + } + int c = fgetc(stdin); + if (c == EOF || c == '\n') { + if (c == EOF && len == 0) { + free(line); + return NULL; + } else { + line[len] = '\0'; + return line; + } + } else { + line[len] = c; + len++; + } + } +} + +/* The high level function that is the main API of the linenoise library. + * This function checks if the terminal has basic capabilities, just checking + * for a blacklist of stupid terminals, and later either calls the line + * editing function or uses dummy fgets() so that you will be able to type + * something even in the most desperate of the conditions. */ +char *linenoise(const char *prompt) { + char buf[LINENOISE_MAX_LINE]; + + if (!isatty(STDIN_FILENO)) { + /* Not a tty: read from file / pipe. In this mode we don't want any + * limit to the line size, so we call a function to handle that. */ + return linenoiseNoTTY(); + } else if (isUnsupportedTerm()) { + size_t len; + + printf("%s",prompt); + fflush(stdout); + if (fgets(buf,LINENOISE_MAX_LINE,stdin) == NULL) return NULL; + len = strlen(buf); + while(len && (buf[len-1] == '\n' || buf[len-1] == '\r')) { + len--; + buf[len] = '\0'; + } + return strdup(buf); + } else { + char *retval = linenoiseBlockingEdit(STDIN_FILENO,STDOUT_FILENO,buf,LINENOISE_MAX_LINE,prompt); + return retval; + } +} + +/* This is just a wrapper the user may want to call in order to make sure + * the linenoise returned buffer is freed with the same allocator it was + * created with. Useful when the main program is using an alternative + * allocator. */ +void linenoiseFree(void *ptr) { + if (ptr == linenoiseEditMore) return; // Protect from API misuse. + free(ptr); +} + +/* ================================ History ================================= */ + +/* Free the history, but does not reset it. Only used when we have to + * exit() to avoid memory leaks are reported by valgrind & co. */ +static void freeHistory(void) { + if (history) { + int j; + + for (j = 0; j < history_len; j++) + free(history[j]); + free(history); + } +} + +/* At exit we'll try to fix the terminal to the initial conditions. */ +static void linenoiseAtExit(void) { + disableRawMode(STDIN_FILENO); + freeHistory(); +} + +/* This is the API call to add a new entry in the linenoise history. + * It uses a fixed array of char pointers that are shifted (memmoved) + * when the history max length is reached in order to remove the older + * entry and make room for the new one, so it is not exactly suitable for huge + * histories, but will work well for a few hundred of entries. + * + * Using a circular buffer is smarter, but a bit more complex to handle. */ +int linenoiseHistoryAdd(const char *line) { + char *linecopy; + + if (history_max_len == 0) return 0; + + /* Initialization on first call. */ + if (history == NULL) { + history = malloc(sizeof(char*)*history_max_len); + if (history == NULL) return 0; + memset(history,0,(sizeof(char*)*history_max_len)); + } + + /* Don't add duplicated lines. */ + if (history_len && !strcmp(history[history_len-1], line)) return 0; + + /* Add an heap allocated copy of the line in the history. + * If we reached the max length, remove the older line. */ + linecopy = strdup(line); + if (!linecopy) return 0; + if (history_len == history_max_len) { + free(history[0]); + memmove(history,history+1,sizeof(char*)*(history_max_len-1)); + history_len--; + } + history[history_len] = linecopy; + history_len++; + return 1; +} + +/* Set the maximum length for the history. This function can be called even + * if there is already some history, the function will make sure to retain + * just the latest 'len' elements if the new history length value is smaller + * than the amount of items already inside the history. */ +int linenoiseHistorySetMaxLen(int len) { + char **new; + + if (len < 1) return 0; + if (history) { + int tocopy = history_len; + + new = malloc(sizeof(char*)*len); + if (new == NULL) return 0; + + /* If we can't copy everything, free the elements we'll not use. */ + if (len < tocopy) { + int j; + + for (j = 0; j < tocopy-len; j++) free(history[j]); + tocopy = len; + } + memset(new,0,sizeof(char*)*len); + memcpy(new,history+(history_len-tocopy), sizeof(char*)*tocopy); + free(history); + history = new; + } + history_max_len = len; + if (history_len > history_max_len) + history_len = history_max_len; + return 1; +} + +/* Save the history in the specified file. On success 0 is returned + * otherwise -1 is returned. */ +int linenoiseHistorySave(const char *filename) { + mode_t old_umask = umask(S_IXUSR|S_IRWXG|S_IRWXO); + FILE *fp; + int j; + + fp = fopen(filename,"w"); + umask(old_umask); + if (fp == NULL) return -1; + chmod(filename,S_IRUSR|S_IWUSR); + for (j = 0; j < history_len; j++) + fprintf(fp,"%s\n",history[j]); + fclose(fp); + return 0; +} + +/* Load the history from the specified file. If the file does not exist + * zero is returned and no operation is performed. + * + * If the file exists and the operation succeeded 0 is returned, otherwise + * on error -1 is returned. */ +int linenoiseHistoryLoad(const char *filename) { + FILE *fp = fopen(filename,"r"); + char buf[LINENOISE_MAX_LINE]; + + if (fp == NULL) return -1; + + while (fgets(buf,LINENOISE_MAX_LINE,fp) != NULL) { + char *p; + + p = strchr(buf,'\r'); + if (!p) p = strchr(buf,'\n'); + if (p) *p = '\0'; + linenoiseHistoryAdd(buf); + } + fclose(fp); + return 0; +} diff --git a/libs/linenoise.h b/libs/linenoise.h new file mode 100644 index 0000000..3f0270e --- /dev/null +++ b/libs/linenoise.h @@ -0,0 +1,113 @@ +/* linenoise.h -- VERSION 1.0 + * + * Guerrilla line editing library against the idea that a line editing lib + * needs to be 20,000 lines of C code. + * + * See linenoise.c for more information. + * + * ------------------------------------------------------------------------ + * + * Copyright (c) 2010-2023, Salvatore Sanfilippo + * Copyright (c) 2010-2013, Pieter Noordhuis + * + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ + +#ifndef __LINENOISE_H +#define __LINENOISE_H + +#ifdef __cplusplus +extern "C" { +#endif + +#include /* For size_t. */ + +extern char *linenoiseEditMore; + +/* The linenoiseState structure represents the state during line editing. + * We pass this state to functions implementing specific editing + * functionalities. */ +struct linenoiseState { + int in_completion; /* The user pressed TAB and we are now in completion + * mode, so input is handled by completeLine(). */ + size_t completion_idx; /* Index of next completion to propose. */ + int ifd; /* Terminal stdin file descriptor. */ + int ofd; /* Terminal stdout file descriptor. */ + char *buf; /* Edited line buffer. */ + size_t buflen; /* Edited line buffer size. */ + const char *prompt; /* Prompt to display. */ + size_t plen; /* Prompt length. */ + size_t pos; /* Current cursor position. */ + size_t oldpos; /* Previous refresh cursor position. */ + size_t len; /* Current edited line length. */ + size_t cols; /* Number of columns in terminal. */ + size_t oldrows; /* Rows used by last refrehsed line (multiline mode) */ + int history_index; /* The history index we are currently editing. */ +}; + +typedef struct linenoiseCompletions { + size_t len; + char **cvec; +} linenoiseCompletions; + +/* Non blocking API. */ +int linenoiseEditStart(struct linenoiseState *l, int stdin_fd, int stdout_fd, char *buf, size_t buflen, const char *prompt); +char *linenoiseEditFeed(struct linenoiseState *l); +void linenoiseEditStop(struct linenoiseState *l); +void linenoiseHide(struct linenoiseState *l); +void linenoiseShow(struct linenoiseState *l); + +/* Blocking API. */ +char *linenoise(const char *prompt); +void linenoiseFree(void *ptr); + +/* Completion API. */ +typedef void(linenoiseCompletionCallback)(const char *, linenoiseCompletions *); +typedef char*(linenoiseHintsCallback)(const char *, int *color, int *bold); +typedef void(linenoiseFreeHintsCallback)(void *); +void linenoiseSetCompletionCallback(linenoiseCompletionCallback *); +void linenoiseSetHintsCallback(linenoiseHintsCallback *); +void linenoiseSetFreeHintsCallback(linenoiseFreeHintsCallback *); +void linenoiseAddCompletion(linenoiseCompletions *, const char *); + +/* History API. */ +int linenoiseHistoryAdd(const char *line); +int linenoiseHistorySetMaxLen(int len); +int linenoiseHistorySave(const char *filename); +int linenoiseHistoryLoad(const char *filename); + +/* Other utilities. */ +void linenoiseClearScreen(void); +void linenoiseSetMultiLine(int ml); +void linenoisePrintKeyCodes(void); +void linenoiseMaskModeEnable(void); +void linenoiseMaskModeDisable(void); + +#ifdef __cplusplus +} +#endif + +#endif /* __LINENOISE_H */ diff --git a/libs/monocypher.c b/libs/monocypher.c new file mode 100644 index 0000000..d3930fb --- /dev/null +++ b/libs/monocypher.c @@ -0,0 +1,2956 @@ +// Monocypher version 4.0.2 +// +// This file is dual-licensed. Choose whichever licence you want from +// the two licences listed below. +// +// The first licence is a regular 2-clause BSD licence. The second licence +// is the CC-0 from Creative Commons. It is intended to release Monocypher +// to the public domain. The BSD licence serves as a fallback option. +// +// SPDX-License-Identifier: BSD-2-Clause OR CC0-1.0 +// +// ------------------------------------------------------------------------ +// +// Copyright (c) 2017-2020, Loup Vaillant +// All rights reserved. +// +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// 1. Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the +// distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// ------------------------------------------------------------------------ +// +// Written in 2017-2020 by Loup Vaillant +// +// To the extent possible under law, the author(s) have dedicated all copyright +// and related neighboring rights to this software to the public domain +// worldwide. This software is distributed without any warranty. +// +// You should have received a copy of the CC0 Public Domain Dedication along +// with this software. If not, see +// + +#include "monocypher.h" + +#ifdef MONOCYPHER_CPP_NAMESPACE +namespace MONOCYPHER_CPP_NAMESPACE { +#endif + +///////////////// +/// Utilities /// +///////////////// +#define FOR_T(type, i, start, end) for (type i = (start); i < (end); i++) +#define FOR(i, start, end) FOR_T(size_t, i, start, end) +#define COPY(dst, src, size) FOR(_i_, 0, size) (dst)[_i_] = (src)[_i_] +#define ZERO(buf, size) FOR(_i_, 0, size) (buf)[_i_] = 0 +#define WIPE_CTX(ctx) crypto_wipe(ctx , sizeof(*(ctx))) +#define WIPE_BUFFER(buffer) crypto_wipe(buffer, sizeof(buffer)) +#define MIN(a, b) ((a) <= (b) ? (a) : (b)) +#define MAX(a, b) ((a) >= (b) ? (a) : (b)) + +typedef int8_t i8; +typedef uint8_t u8; +typedef int16_t i16; +typedef uint32_t u32; +typedef int32_t i32; +typedef int64_t i64; +typedef uint64_t u64; + +static const u8 zero[128] = {0}; + +// returns the smallest positive integer y such that +// (x + y) % pow_2 == 0 +// Basically, y is the "gap" missing to align x. +// Only works when pow_2 is a power of 2. +// Note: we use ~x+1 instead of -x to avoid compiler warnings +static size_t gap(size_t x, size_t pow_2) +{ + return (~x + 1) & (pow_2 - 1); +} + +static u32 load24_le(const u8 s[3]) +{ + return + ((u32)s[0] << 0) | + ((u32)s[1] << 8) | + ((u32)s[2] << 16); +} + +static u32 load32_le(const u8 s[4]) +{ + return + ((u32)s[0] << 0) | + ((u32)s[1] << 8) | + ((u32)s[2] << 16) | + ((u32)s[3] << 24); +} + +static u64 load64_le(const u8 s[8]) +{ + return load32_le(s) | ((u64)load32_le(s+4) << 32); +} + +static void store32_le(u8 out[4], u32 in) +{ + out[0] = in & 0xff; + out[1] = (in >> 8) & 0xff; + out[2] = (in >> 16) & 0xff; + out[3] = (in >> 24) & 0xff; +} + +static void store64_le(u8 out[8], u64 in) +{ + store32_le(out , (u32)in ); + store32_le(out + 4, in >> 32); +} + +static void load32_le_buf (u32 *dst, const u8 *src, size_t size) { + FOR(i, 0, size) { dst[i] = load32_le(src + i*4); } +} +static void load64_le_buf (u64 *dst, const u8 *src, size_t size) { + FOR(i, 0, size) { dst[i] = load64_le(src + i*8); } +} +static void store32_le_buf(u8 *dst, const u32 *src, size_t size) { + FOR(i, 0, size) { store32_le(dst + i*4, src[i]); } +} +static void store64_le_buf(u8 *dst, const u64 *src, size_t size) { + FOR(i, 0, size) { store64_le(dst + i*8, src[i]); } +} + +static u64 rotr64(u64 x, u64 n) { return (x >> n) ^ (x << (64 - n)); } +static u32 rotl32(u32 x, u32 n) { return (x << n) ^ (x >> (32 - n)); } + +static int neq0(u64 diff) +{ + // constant time comparison to zero + // return diff != 0 ? -1 : 0 + u64 half = (diff >> 32) | ((u32)diff); + return (1 & ((half - 1) >> 32)) - 1; +} + +static u64 x16(const u8 a[16], const u8 b[16]) +{ + return (load64_le(a + 0) ^ load64_le(b + 0)) + | (load64_le(a + 8) ^ load64_le(b + 8)); +} +static u64 x32(const u8 a[32],const u8 b[32]){return x16(a,b)| x16(a+16, b+16);} +static u64 x64(const u8 a[64],const u8 b[64]){return x32(a,b)| x32(a+32, b+32);} +int crypto_verify16(const u8 a[16], const u8 b[16]){ return neq0(x16(a, b)); } +int crypto_verify32(const u8 a[32], const u8 b[32]){ return neq0(x32(a, b)); } +int crypto_verify64(const u8 a[64], const u8 b[64]){ return neq0(x64(a, b)); } + +void crypto_wipe(void *secret, size_t size) +{ + volatile u8 *v_secret = (u8*)secret; + ZERO(v_secret, size); +} + +///////////////// +/// Chacha 20 /// +///////////////// +#define QUARTERROUND(a, b, c, d) \ + a += b; d = rotl32(d ^ a, 16); \ + c += d; b = rotl32(b ^ c, 12); \ + a += b; d = rotl32(d ^ a, 8); \ + c += d; b = rotl32(b ^ c, 7) + +static void chacha20_rounds(u32 out[16], const u32 in[16]) +{ + // The temporary variables make Chacha20 10% faster. + u32 t0 = in[ 0]; u32 t1 = in[ 1]; u32 t2 = in[ 2]; u32 t3 = in[ 3]; + u32 t4 = in[ 4]; u32 t5 = in[ 5]; u32 t6 = in[ 6]; u32 t7 = in[ 7]; + u32 t8 = in[ 8]; u32 t9 = in[ 9]; u32 t10 = in[10]; u32 t11 = in[11]; + u32 t12 = in[12]; u32 t13 = in[13]; u32 t14 = in[14]; u32 t15 = in[15]; + + FOR (i, 0, 10) { // 20 rounds, 2 rounds per loop. + QUARTERROUND(t0, t4, t8 , t12); // column 0 + QUARTERROUND(t1, t5, t9 , t13); // column 1 + QUARTERROUND(t2, t6, t10, t14); // column 2 + QUARTERROUND(t3, t7, t11, t15); // column 3 + QUARTERROUND(t0, t5, t10, t15); // diagonal 0 + QUARTERROUND(t1, t6, t11, t12); // diagonal 1 + QUARTERROUND(t2, t7, t8 , t13); // diagonal 2 + QUARTERROUND(t3, t4, t9 , t14); // diagonal 3 + } + out[ 0] = t0; out[ 1] = t1; out[ 2] = t2; out[ 3] = t3; + out[ 4] = t4; out[ 5] = t5; out[ 6] = t6; out[ 7] = t7; + out[ 8] = t8; out[ 9] = t9; out[10] = t10; out[11] = t11; + out[12] = t12; out[13] = t13; out[14] = t14; out[15] = t15; +} + +static const u8 *chacha20_constant = (const u8*)"expand 32-byte k"; // 16 bytes + +void crypto_chacha20_h(u8 out[32], const u8 key[32], const u8 in [16]) +{ + u32 block[16]; + load32_le_buf(block , chacha20_constant, 4); + load32_le_buf(block + 4, key , 8); + load32_le_buf(block + 12, in , 4); + + chacha20_rounds(block, block); + + // prevent reversal of the rounds by revealing only half of the buffer. + store32_le_buf(out , block , 4); // constant + store32_le_buf(out+16, block+12, 4); // counter and nonce + WIPE_BUFFER(block); +} + +u64 crypto_chacha20_djb(u8 *cipher_text, const u8 *plain_text, + size_t text_size, const u8 key[32], const u8 nonce[8], + u64 ctr) +{ + u32 input[16]; + load32_le_buf(input , chacha20_constant, 4); + load32_le_buf(input + 4, key , 8); + load32_le_buf(input + 14, nonce , 2); + input[12] = (u32) ctr; + input[13] = (u32)(ctr >> 32); + + // Whole blocks + u32 pool[16]; + size_t nb_blocks = text_size >> 6; + FOR (i, 0, nb_blocks) { + chacha20_rounds(pool, input); + if (plain_text != 0) { + FOR (j, 0, 16) { + u32 p = pool[j] + input[j]; + store32_le(cipher_text, p ^ load32_le(plain_text)); + cipher_text += 4; + plain_text += 4; + } + } else { + FOR (j, 0, 16) { + u32 p = pool[j] + input[j]; + store32_le(cipher_text, p); + cipher_text += 4; + } + } + input[12]++; + if (input[12] == 0) { + input[13]++; + } + } + text_size &= 63; + + // Last (incomplete) block + if (text_size > 0) { + if (plain_text == 0) { + plain_text = zero; + } + chacha20_rounds(pool, input); + u8 tmp[64]; + FOR (i, 0, 16) { + store32_le(tmp + i*4, pool[i] + input[i]); + } + FOR (i, 0, text_size) { + cipher_text[i] = tmp[i] ^ plain_text[i]; + } + WIPE_BUFFER(tmp); + } + ctr = input[12] + ((u64)input[13] << 32) + (text_size > 0); + + WIPE_BUFFER(pool); + WIPE_BUFFER(input); + return ctr; +} + +u32 crypto_chacha20_ietf(u8 *cipher_text, const u8 *plain_text, + size_t text_size, + const u8 key[32], const u8 nonce[12], u32 ctr) +{ + u64 big_ctr = ctr + ((u64)load32_le(nonce) << 32); + return (u32)crypto_chacha20_djb(cipher_text, plain_text, text_size, + key, nonce + 4, big_ctr); +} + +u64 crypto_chacha20_x(u8 *cipher_text, const u8 *plain_text, + size_t text_size, + const u8 key[32], const u8 nonce[24], u64 ctr) +{ + u8 sub_key[32]; + crypto_chacha20_h(sub_key, key, nonce); + ctr = crypto_chacha20_djb(cipher_text, plain_text, text_size, + sub_key, nonce + 16, ctr); + WIPE_BUFFER(sub_key); + return ctr; +} + +///////////////// +/// Poly 1305 /// +///////////////// + +// h = (h + c) * r +// preconditions: +// ctx->h <= 4_ffffffff_ffffffff_ffffffff_ffffffff +// ctx->r <= 0ffffffc_0ffffffc_0ffffffc_0fffffff +// end <= 1 +// Postcondition: +// ctx->h <= 4_ffffffff_ffffffff_ffffffff_ffffffff +static void poly_blocks(crypto_poly1305_ctx *ctx, const u8 *in, + size_t nb_blocks, unsigned end) +{ + // Local all the things! + const u32 r0 = ctx->r[0]; + const u32 r1 = ctx->r[1]; + const u32 r2 = ctx->r[2]; + const u32 r3 = ctx->r[3]; + const u32 rr0 = (r0 >> 2) * 5; // lose 2 bits... + const u32 rr1 = (r1 >> 2) + r1; // rr1 == (r1 >> 2) * 5 + const u32 rr2 = (r2 >> 2) + r2; // rr1 == (r2 >> 2) * 5 + const u32 rr3 = (r3 >> 2) + r3; // rr1 == (r3 >> 2) * 5 + const u32 rr4 = r0 & 3; // ...recover 2 bits + u32 h0 = ctx->h[0]; + u32 h1 = ctx->h[1]; + u32 h2 = ctx->h[2]; + u32 h3 = ctx->h[3]; + u32 h4 = ctx->h[4]; + + FOR (i, 0, nb_blocks) { + // h + c, without carry propagation + const u64 s0 = (u64)h0 + load32_le(in); in += 4; + const u64 s1 = (u64)h1 + load32_le(in); in += 4; + const u64 s2 = (u64)h2 + load32_le(in); in += 4; + const u64 s3 = (u64)h3 + load32_le(in); in += 4; + const u32 s4 = h4 + end; + + // (h + c) * r, without carry propagation + const u64 x0 = s0*r0+ s1*rr3+ s2*rr2+ s3*rr1+ s4*rr0; + const u64 x1 = s0*r1+ s1*r0 + s2*rr3+ s3*rr2+ s4*rr1; + const u64 x2 = s0*r2+ s1*r1 + s2*r0 + s3*rr3+ s4*rr2; + const u64 x3 = s0*r3+ s1*r2 + s2*r1 + s3*r0 + s4*rr3; + const u32 x4 = s4*rr4; + + // partial reduction modulo 2^130 - 5 + const u32 u5 = x4 + (x3 >> 32); // u5 <= 7ffffff5 + const u64 u0 = (u5 >> 2) * 5 + (x0 & 0xffffffff); + const u64 u1 = (u0 >> 32) + (x1 & 0xffffffff) + (x0 >> 32); + const u64 u2 = (u1 >> 32) + (x2 & 0xffffffff) + (x1 >> 32); + const u64 u3 = (u2 >> 32) + (x3 & 0xffffffff) + (x2 >> 32); + const u32 u4 = (u3 >> 32) + (u5 & 3); // u4 <= 4 + + // Update the hash + h0 = u0 & 0xffffffff; + h1 = u1 & 0xffffffff; + h2 = u2 & 0xffffffff; + h3 = u3 & 0xffffffff; + h4 = u4; + } + ctx->h[0] = h0; + ctx->h[1] = h1; + ctx->h[2] = h2; + ctx->h[3] = h3; + ctx->h[4] = h4; +} + +void crypto_poly1305_init(crypto_poly1305_ctx *ctx, const u8 key[32]) +{ + ZERO(ctx->h, 5); // Initial hash is zero + ctx->c_idx = 0; + // load r and pad (r has some of its bits cleared) + load32_le_buf(ctx->r , key , 4); + load32_le_buf(ctx->pad, key+16, 4); + FOR (i, 0, 1) { ctx->r[i] &= 0x0fffffff; } + FOR (i, 1, 4) { ctx->r[i] &= 0x0ffffffc; } +} + +void crypto_poly1305_update(crypto_poly1305_ctx *ctx, + const u8 *message, size_t message_size) +{ + // Avoid undefined NULL pointer increments with empty messages + if (message_size == 0) { + return; + } + + // Align ourselves with block boundaries + size_t aligned = MIN(gap(ctx->c_idx, 16), message_size); + FOR (i, 0, aligned) { + ctx->c[ctx->c_idx] = *message; + ctx->c_idx++; + message++; + message_size--; + } + + // If block is complete, process it + if (ctx->c_idx == 16) { + poly_blocks(ctx, ctx->c, 1, 1); + ctx->c_idx = 0; + } + + // Process the message block by block + size_t nb_blocks = message_size >> 4; + poly_blocks(ctx, message, nb_blocks, 1); + message += nb_blocks << 4; + message_size &= 15; + + // remaining bytes (we never complete a block here) + FOR (i, 0, message_size) { + ctx->c[ctx->c_idx] = message[i]; + ctx->c_idx++; + } +} + +void crypto_poly1305_final(crypto_poly1305_ctx *ctx, u8 mac[16]) +{ + // Process the last block (if any) + // We move the final 1 according to remaining input length + // (this will add less than 2^130 to the last input block) + if (ctx->c_idx != 0) { + ZERO(ctx->c + ctx->c_idx, 16 - ctx->c_idx); + ctx->c[ctx->c_idx] = 1; + poly_blocks(ctx, ctx->c, 1, 0); + } + + // check if we should subtract 2^130-5 by performing the + // corresponding carry propagation. + u64 c = 5; + FOR (i, 0, 4) { + c += ctx->h[i]; + c >>= 32; + } + c += ctx->h[4]; + c = (c >> 2) * 5; // shift the carry back to the beginning + // c now indicates how many times we should subtract 2^130-5 (0 or 1) + FOR (i, 0, 4) { + c += (u64)ctx->h[i] + ctx->pad[i]; + store32_le(mac + i*4, (u32)c); + c = c >> 32; + } + WIPE_CTX(ctx); +} + +void crypto_poly1305(u8 mac[16], const u8 *message, + size_t message_size, const u8 key[32]) +{ + crypto_poly1305_ctx ctx; + crypto_poly1305_init (&ctx, key); + crypto_poly1305_update(&ctx, message, message_size); + crypto_poly1305_final (&ctx, mac); +} + +//////////////// +/// BLAKE2 b /// +//////////////// +static const u64 iv[8] = { + 0x6a09e667f3bcc908, 0xbb67ae8584caa73b, + 0x3c6ef372fe94f82b, 0xa54ff53a5f1d36f1, + 0x510e527fade682d1, 0x9b05688c2b3e6c1f, + 0x1f83d9abfb41bd6b, 0x5be0cd19137e2179, +}; + +static void blake2b_compress(crypto_blake2b_ctx *ctx, int is_last_block) +{ + static const u8 sigma[12][16] = { + { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 }, + { 14, 10, 4, 8, 9, 15, 13, 6, 1, 12, 0, 2, 11, 7, 5, 3 }, + { 11, 8, 12, 0, 5, 2, 15, 13, 10, 14, 3, 6, 7, 1, 9, 4 }, + { 7, 9, 3, 1, 13, 12, 11, 14, 2, 6, 5, 10, 4, 0, 15, 8 }, + { 9, 0, 5, 7, 2, 4, 10, 15, 14, 1, 11, 12, 6, 8, 3, 13 }, + { 2, 12, 6, 10, 0, 11, 8, 3, 4, 13, 7, 5, 15, 14, 1, 9 }, + { 12, 5, 1, 15, 14, 13, 4, 10, 0, 7, 6, 3, 9, 2, 8, 11 }, + { 13, 11, 7, 14, 12, 1, 3, 9, 5, 0, 15, 4, 8, 6, 2, 10 }, + { 6, 15, 14, 9, 11, 3, 0, 8, 12, 2, 13, 7, 1, 4, 10, 5 }, + { 10, 2, 8, 4, 7, 6, 1, 5, 15, 11, 9, 14, 3, 12, 13, 0 }, + { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 }, + { 14, 10, 4, 8, 9, 15, 13, 6, 1, 12, 0, 2, 11, 7, 5, 3 }, + }; + + // increment input offset + u64 *x = ctx->input_offset; + size_t y = ctx->input_idx; + x[0] += y; + if (x[0] < y) { + x[1]++; + } + + // init work vector + u64 v0 = ctx->hash[0]; u64 v8 = iv[0]; + u64 v1 = ctx->hash[1]; u64 v9 = iv[1]; + u64 v2 = ctx->hash[2]; u64 v10 = iv[2]; + u64 v3 = ctx->hash[3]; u64 v11 = iv[3]; + u64 v4 = ctx->hash[4]; u64 v12 = iv[4] ^ ctx->input_offset[0]; + u64 v5 = ctx->hash[5]; u64 v13 = iv[5] ^ ctx->input_offset[1]; + u64 v6 = ctx->hash[6]; u64 v14 = iv[6] ^ (u64)~(is_last_block - 1); + u64 v7 = ctx->hash[7]; u64 v15 = iv[7]; + + // mangle work vector + u64 *input = ctx->input; +#define BLAKE2_G(a, b, c, d, x, y) \ + a += b + x; d = rotr64(d ^ a, 32); \ + c += d; b = rotr64(b ^ c, 24); \ + a += b + y; d = rotr64(d ^ a, 16); \ + c += d; b = rotr64(b ^ c, 63) +#define BLAKE2_ROUND(i) \ + BLAKE2_G(v0, v4, v8 , v12, input[sigma[i][ 0]], input[sigma[i][ 1]]); \ + BLAKE2_G(v1, v5, v9 , v13, input[sigma[i][ 2]], input[sigma[i][ 3]]); \ + BLAKE2_G(v2, v6, v10, v14, input[sigma[i][ 4]], input[sigma[i][ 5]]); \ + BLAKE2_G(v3, v7, v11, v15, input[sigma[i][ 6]], input[sigma[i][ 7]]); \ + BLAKE2_G(v0, v5, v10, v15, input[sigma[i][ 8]], input[sigma[i][ 9]]); \ + BLAKE2_G(v1, v6, v11, v12, input[sigma[i][10]], input[sigma[i][11]]); \ + BLAKE2_G(v2, v7, v8 , v13, input[sigma[i][12]], input[sigma[i][13]]); \ + BLAKE2_G(v3, v4, v9 , v14, input[sigma[i][14]], input[sigma[i][15]]) + +#ifdef BLAKE2_NO_UNROLLING + FOR (i, 0, 12) { + BLAKE2_ROUND(i); + } +#else + BLAKE2_ROUND(0); BLAKE2_ROUND(1); BLAKE2_ROUND(2); BLAKE2_ROUND(3); + BLAKE2_ROUND(4); BLAKE2_ROUND(5); BLAKE2_ROUND(6); BLAKE2_ROUND(7); + BLAKE2_ROUND(8); BLAKE2_ROUND(9); BLAKE2_ROUND(10); BLAKE2_ROUND(11); +#endif + + // update hash + ctx->hash[0] ^= v0 ^ v8; ctx->hash[1] ^= v1 ^ v9; + ctx->hash[2] ^= v2 ^ v10; ctx->hash[3] ^= v3 ^ v11; + ctx->hash[4] ^= v4 ^ v12; ctx->hash[5] ^= v5 ^ v13; + ctx->hash[6] ^= v6 ^ v14; ctx->hash[7] ^= v7 ^ v15; +} + +void crypto_blake2b_keyed_init(crypto_blake2b_ctx *ctx, size_t hash_size, + const u8 *key, size_t key_size) +{ + // initial hash + COPY(ctx->hash, iv, 8); + ctx->hash[0] ^= 0x01010000 ^ (key_size << 8) ^ hash_size; + + ctx->input_offset[0] = 0; // beginning of the input, no offset + ctx->input_offset[1] = 0; // beginning of the input, no offset + ctx->hash_size = hash_size; + ctx->input_idx = 0; + ZERO(ctx->input, 16); + + // if there is a key, the first block is that key (padded with zeroes) + if (key_size > 0) { + u8 key_block[128] = {0}; + COPY(key_block, key, key_size); + // same as calling crypto_blake2b_update(ctx, key_block , 128) + load64_le_buf(ctx->input, key_block, 16); + ctx->input_idx = 128; + } +} + +void crypto_blake2b_init(crypto_blake2b_ctx *ctx, size_t hash_size) +{ + crypto_blake2b_keyed_init(ctx, hash_size, 0, 0); +} + +void crypto_blake2b_update(crypto_blake2b_ctx *ctx, + const u8 *message, size_t message_size) +{ + // Avoid undefined NULL pointer increments with empty messages + if (message_size == 0) { + return; + } + + // Align with word boundaries + if ((ctx->input_idx & 7) != 0) { + size_t nb_bytes = MIN(gap(ctx->input_idx, 8), message_size); + size_t word = ctx->input_idx >> 3; + size_t byte = ctx->input_idx & 7; + FOR (i, 0, nb_bytes) { + ctx->input[word] |= (u64)message[i] << ((byte + i) << 3); + } + ctx->input_idx += nb_bytes; + message += nb_bytes; + message_size -= nb_bytes; + } + + // Align with block boundaries (faster than byte by byte) + if ((ctx->input_idx & 127) != 0) { + size_t nb_words = MIN(gap(ctx->input_idx, 128), message_size) >> 3; + load64_le_buf(ctx->input + (ctx->input_idx >> 3), message, nb_words); + ctx->input_idx += nb_words << 3; + message += nb_words << 3; + message_size -= nb_words << 3; + } + + // Process block by block + size_t nb_blocks = message_size >> 7; + FOR (i, 0, nb_blocks) { + if (ctx->input_idx == 128) { + blake2b_compress(ctx, 0); + } + load64_le_buf(ctx->input, message, 16); + message += 128; + ctx->input_idx = 128; + } + message_size &= 127; + + if (message_size != 0) { + // Compress block & flush input buffer as needed + if (ctx->input_idx == 128) { + blake2b_compress(ctx, 0); + ctx->input_idx = 0; + } + if (ctx->input_idx == 0) { + ZERO(ctx->input, 16); + } + // Fill remaining words (faster than byte by byte) + size_t nb_words = message_size >> 3; + load64_le_buf(ctx->input, message, nb_words); + ctx->input_idx += nb_words << 3; + message += nb_words << 3; + message_size -= nb_words << 3; + + // Fill remaining bytes + FOR (i, 0, message_size) { + size_t word = ctx->input_idx >> 3; + size_t byte = ctx->input_idx & 7; + ctx->input[word] |= (u64)message[i] << (byte << 3); + ctx->input_idx++; + } + } +} + +void crypto_blake2b_final(crypto_blake2b_ctx *ctx, u8 *hash) +{ + blake2b_compress(ctx, 1); // compress the last block + size_t hash_size = MIN(ctx->hash_size, 64); + size_t nb_words = hash_size >> 3; + store64_le_buf(hash, ctx->hash, nb_words); + FOR (i, nb_words << 3, hash_size) { + hash[i] = (ctx->hash[i >> 3] >> (8 * (i & 7))) & 0xff; + } + WIPE_CTX(ctx); +} + +void crypto_blake2b_keyed(u8 *hash, size_t hash_size, + const u8 *key, size_t key_size, + const u8 *message, size_t message_size) +{ + crypto_blake2b_ctx ctx; + crypto_blake2b_keyed_init(&ctx, hash_size, key, key_size); + crypto_blake2b_update (&ctx, message, message_size); + crypto_blake2b_final (&ctx, hash); +} + +void crypto_blake2b(u8 *hash, size_t hash_size, const u8 *msg, size_t msg_size) +{ + crypto_blake2b_keyed(hash, hash_size, 0, 0, msg, msg_size); +} + +////////////// +/// Argon2 /// +////////////// +// references to R, Z, Q etc. come from the spec + +// Argon2 operates on 1024 byte blocks. +typedef struct { u64 a[128]; } blk; + +// updates a BLAKE2 hash with a 32 bit word, little endian. +static void blake_update_32(crypto_blake2b_ctx *ctx, u32 input) +{ + u8 buf[4]; + store32_le(buf, input); + crypto_blake2b_update(ctx, buf, 4); + WIPE_BUFFER(buf); +} + +static void blake_update_32_buf(crypto_blake2b_ctx *ctx, + const u8 *buf, u32 size) +{ + blake_update_32(ctx, size); + crypto_blake2b_update(ctx, buf, size); +} + + +static void copy_block(blk *o,const blk*in){FOR(i, 0, 128) o->a[i] = in->a[i];} +static void xor_block(blk *o,const blk*in){FOR(i, 0, 128) o->a[i] ^= in->a[i];} + +// Hash with a virtually unlimited digest size. +// Doesn't extract more entropy than the base hash function. +// Mainly used for filling a whole kilobyte block with pseudo-random bytes. +// (One could use a stream cipher with a seed hash as the key, but +// this would introduce another dependency —and point of failure.) +static void extended_hash(u8 *digest, u32 digest_size, + const u8 *input , u32 input_size) +{ + crypto_blake2b_ctx ctx; + crypto_blake2b_init (&ctx, MIN(digest_size, 64)); + blake_update_32 (&ctx, digest_size); + crypto_blake2b_update(&ctx, input, input_size); + crypto_blake2b_final (&ctx, digest); + + if (digest_size > 64) { + // the conversion to u64 avoids integer overflow on + // ludicrously big hash sizes. + u32 r = (u32)(((u64)digest_size + 31) >> 5) - 2; + u32 i = 1; + u32 in = 0; + u32 out = 32; + while (i < r) { + // Input and output overlap. This is intentional + crypto_blake2b(digest + out, 64, digest + in, 64); + i += 1; + in += 32; + out += 32; + } + crypto_blake2b(digest + out, digest_size - (32 * r), digest + in , 64); + } +} + +#define LSB(x) ((u64)(u32)x) +#define G(a, b, c, d) \ + a += b + ((LSB(a) * LSB(b)) << 1); d ^= a; d = rotr64(d, 32); \ + c += d + ((LSB(c) * LSB(d)) << 1); b ^= c; b = rotr64(b, 24); \ + a += b + ((LSB(a) * LSB(b)) << 1); d ^= a; d = rotr64(d, 16); \ + c += d + ((LSB(c) * LSB(d)) << 1); b ^= c; b = rotr64(b, 63) +#define ROUND(v0, v1, v2, v3, v4, v5, v6, v7, \ + v8, v9, v10, v11, v12, v13, v14, v15) \ + G(v0, v4, v8, v12); G(v1, v5, v9, v13); \ + G(v2, v6, v10, v14); G(v3, v7, v11, v15); \ + G(v0, v5, v10, v15); G(v1, v6, v11, v12); \ + G(v2, v7, v8, v13); G(v3, v4, v9, v14) + +// Core of the compression function G. Computes Z from R in place. +static void g_rounds(blk *b) +{ + // column rounds (work_block = Q) + for (int i = 0; i < 128; i += 16) { + ROUND(b->a[i ], b->a[i+ 1], b->a[i+ 2], b->a[i+ 3], + b->a[i+ 4], b->a[i+ 5], b->a[i+ 6], b->a[i+ 7], + b->a[i+ 8], b->a[i+ 9], b->a[i+10], b->a[i+11], + b->a[i+12], b->a[i+13], b->a[i+14], b->a[i+15]); + } + // row rounds (b = Z) + for (int i = 0; i < 16; i += 2) { + ROUND(b->a[i ], b->a[i+ 1], b->a[i+ 16], b->a[i+ 17], + b->a[i+32], b->a[i+33], b->a[i+ 48], b->a[i+ 49], + b->a[i+64], b->a[i+65], b->a[i+ 80], b->a[i+ 81], + b->a[i+96], b->a[i+97], b->a[i+112], b->a[i+113]); + } +} + +const crypto_argon2_extras crypto_argon2_no_extras = { 0, 0, 0, 0 }; + +void crypto_argon2(u8 *hash, u32 hash_size, void *work_area, + crypto_argon2_config config, + crypto_argon2_inputs inputs, + crypto_argon2_extras extras) +{ + const u32 segment_size = config.nb_blocks / config.nb_lanes / 4; + const u32 lane_size = segment_size * 4; + const u32 nb_blocks = lane_size * config.nb_lanes; // rounding down + + // work area seen as blocks (must be suitably aligned) + blk *blocks = (blk*)work_area; + { + u8 initial_hash[72]; // 64 bytes plus 2 words for future hashes + crypto_blake2b_ctx ctx; + crypto_blake2b_init (&ctx, 64); + blake_update_32 (&ctx, config.nb_lanes ); // p: number of "threads" + blake_update_32 (&ctx, hash_size); + blake_update_32 (&ctx, config.nb_blocks); + blake_update_32 (&ctx, config.nb_passes); + blake_update_32 (&ctx, 0x13); // v: version number + blake_update_32 (&ctx, config.algorithm); // y: Argon2i, Argon2d... + blake_update_32_buf (&ctx, inputs.pass, inputs.pass_size); + blake_update_32_buf (&ctx, inputs.salt, inputs.salt_size); + blake_update_32_buf (&ctx, extras.key, extras.key_size); + blake_update_32_buf (&ctx, extras.ad, extras.ad_size); + crypto_blake2b_final(&ctx, initial_hash); // fill 64 first bytes only + + // fill first 2 blocks of each lane + u8 hash_area[1024]; + FOR_T(u32, l, 0, config.nb_lanes) { + FOR_T(u32, i, 0, 2) { + store32_le(initial_hash + 64, i); // first additional word + store32_le(initial_hash + 68, l); // second additional word + extended_hash(hash_area, 1024, initial_hash, 72); + load64_le_buf(blocks[l * lane_size + i].a, hash_area, 128); + } + } + + WIPE_BUFFER(initial_hash); + WIPE_BUFFER(hash_area); + } + + // Argon2i and Argon2id start with constant time indexing + int constant_time = config.algorithm != CRYPTO_ARGON2_D; + + // Fill (and re-fill) the rest of the blocks + // + // Note: even though each segment within the same slice can be + // computed in parallel, (one thread per lane), we are computing + // them sequentially, because Monocypher doesn't support threads. + // + // Yet optimal performance (and therefore security) requires one + // thread per lane. The only reason Monocypher supports multiple + // lanes is compatibility. + blk tmp; + FOR_T(u32, pass, 0, config.nb_passes) { + FOR_T(u32, slice, 0, 4) { + // On the first slice of the first pass, + // blocks 0 and 1 are already filled, hence pass_offset. + u32 pass_offset = pass == 0 && slice == 0 ? 2 : 0; + u32 slice_offset = slice * segment_size; + + // Argon2id switches back to non-constant time indexing + // after the first two slices of the first pass + if (slice == 2 && config.algorithm == CRYPTO_ARGON2_ID) { + constant_time = 0; + } + + // Each iteration of the following loop may be performed in + // a separate thread. All segments must be fully completed + // before we start filling the next slice. + FOR_T(u32, segment, 0, config.nb_lanes) { + blk index_block; + u32 index_ctr = 1; + FOR_T (u32, block, pass_offset, segment_size) { + // Current and previous blocks + u32 lane_offset = segment * lane_size; + blk *segment_start = blocks + lane_offset + slice_offset; + blk *current = segment_start + block; + blk *previous = + block == 0 && slice_offset == 0 + ? segment_start + lane_size - 1 + : segment_start + block - 1; + + u64 index_seed; + if (constant_time) { + if (block == pass_offset || (block % 128) == 0) { + // Fill or refresh deterministic indices block + + // seed the beginning of the block... + ZERO(index_block.a, 128); + index_block.a[0] = pass; + index_block.a[1] = segment; + index_block.a[2] = slice; + index_block.a[3] = nb_blocks; + index_block.a[4] = config.nb_passes; + index_block.a[5] = config.algorithm; + index_block.a[6] = index_ctr; + index_ctr++; + + // ... then shuffle it + copy_block(&tmp, &index_block); + g_rounds (&index_block); + xor_block (&index_block, &tmp); + copy_block(&tmp, &index_block); + g_rounds (&index_block); + xor_block (&index_block, &tmp); + } + index_seed = index_block.a[block % 128]; + } else { + index_seed = previous->a[0]; + } + + // Establish the reference set. *Approximately* comprises: + // - The last 3 slices (if they exist yet) + // - The already constructed blocks in the current segment + u32 next_slice = ((slice + 1) % 4) * segment_size; + u32 window_start = pass == 0 ? 0 : next_slice; + u32 nb_segments = pass == 0 ? slice : 3; + u64 lane = + pass == 0 && slice == 0 + ? segment + : (index_seed >> 32) % config.nb_lanes; + u32 window_size = + nb_segments * segment_size + + (lane == segment ? block-1 : + block == 0 ? (u32)-1 : 0); + + // Find reference block + u64 j1 = index_seed & 0xffffffff; // block selector + u64 x = (j1 * j1) >> 32; + u64 y = (window_size * x) >> 32; + u64 z = (window_size - 1) - y; + u64 ref = (window_start + z) % lane_size; + u32 index = lane * lane_size + (u32)ref; + blk *reference = blocks + index; + + // Shuffle the previous & reference block + // into the current block + copy_block(&tmp, previous); + xor_block (&tmp, reference); + if (pass == 0) { copy_block(current, &tmp); } + else { xor_block (current, &tmp); } + g_rounds (&tmp); + xor_block (current, &tmp); + } + } + } + } + + // Wipe temporary block + volatile u64* p = tmp.a; + ZERO(p, 128); + + // XOR last blocks of each lane + blk *last_block = blocks + lane_size - 1; + FOR_T (u32, lane, 1, config.nb_lanes) { + blk *next_block = last_block + lane_size; + xor_block(next_block, last_block); + last_block = next_block; + } + + // Serialize last block + u8 final_block[1024]; + store64_le_buf(final_block, last_block->a, 128); + + // Wipe work area + p = (u64*)work_area; + ZERO(p, 128 * nb_blocks); + + // Hash the very last block with H' into the output hash + extended_hash(hash, hash_size, final_block, 1024); + WIPE_BUFFER(final_block); +} + +//////////////////////////////////// +/// Arithmetic modulo 2^255 - 19 /// +//////////////////////////////////// +// Originally taken from SUPERCOP's ref10 implementation. +// A bit bigger than TweetNaCl, over 4 times faster. + +// field element +typedef i32 fe[10]; + +// field constants +// +// fe_one : 1 +// sqrtm1 : sqrt(-1) +// d : -121665 / 121666 +// D2 : 2 * -121665 / 121666 +// lop_x, lop_y: low order point in Edwards coordinates +// ufactor : -sqrt(-1) * 2 +// A2 : 486662^2 (A squared) +static const fe fe_one = {1}; +static const fe sqrtm1 = { + -32595792, -7943725, 9377950, 3500415, 12389472, + -272473, -25146209, -2005654, 326686, 11406482, +}; +static const fe d = { + -10913610, 13857413, -15372611, 6949391, 114729, + -8787816, -6275908, -3247719, -18696448, -12055116, +}; +static const fe D2 = { + -21827239, -5839606, -30745221, 13898782, 229458, + 15978800, -12551817, -6495438, 29715968, 9444199, +}; +static const fe lop_x = { + 21352778, 5345713, 4660180, -8347857, 24143090, + 14568123, 30185756, -12247770, -33528939, 8345319, +}; +static const fe lop_y = { + -6952922, -1265500, 6862341, -7057498, -4037696, + -5447722, 31680899, -15325402, -19365852, 1569102, +}; +static const fe ufactor = { + -1917299, 15887451, -18755900, -7000830, -24778944, + 544946, -16816446, 4011309, -653372, 10741468, +}; +static const fe A2 = { + 12721188, 3529, 0, 0, 0, 0, 0, 0, 0, 0, +}; + +static void fe_0(fe h) { ZERO(h , 10); } +static void fe_1(fe h) { h[0] = 1; ZERO(h+1, 9); } + +static void fe_copy(fe h,const fe f ){FOR(i,0,10) h[i] = f[i]; } +static void fe_neg (fe h,const fe f ){FOR(i,0,10) h[i] = -f[i]; } +static void fe_add (fe h,const fe f,const fe g){FOR(i,0,10) h[i] = f[i] + g[i];} +static void fe_sub (fe h,const fe f,const fe g){FOR(i,0,10) h[i] = f[i] - g[i];} + +static void fe_cswap(fe f, fe g, int b) +{ + i32 mask = -b; // -1 = 0xffffffff + FOR (i, 0, 10) { + i32 x = (f[i] ^ g[i]) & mask; + f[i] = f[i] ^ x; + g[i] = g[i] ^ x; + } +} + +static void fe_ccopy(fe f, const fe g, int b) +{ + i32 mask = -b; // -1 = 0xffffffff + FOR (i, 0, 10) { + i32 x = (f[i] ^ g[i]) & mask; + f[i] = f[i] ^ x; + } +} + + +// Signed carry propagation +// ------------------------ +// +// Let t be a number. It can be uniquely decomposed thus: +// +// t = h*2^26 + l +// such that -2^25 <= l < 2^25 +// +// Let c = (t + 2^25) / 2^26 (rounded down) +// c = (h*2^26 + l + 2^25) / 2^26 (rounded down) +// c = h + (l + 2^25) / 2^26 (rounded down) +// c = h (exactly) +// Because 0 <= l + 2^25 < 2^26 +// +// Let u = t - c*2^26 +// u = h*2^26 + l - h*2^26 +// u = l +// Therefore, -2^25 <= u < 2^25 +// +// Additionally, if |t| < x, then |h| < x/2^26 (rounded down) +// +// Notations: +// - In C, 1<<25 means 2^25. +// - In C, x>>25 means floor(x / (2^25)). +// - All of the above applies with 25 & 24 as well as 26 & 25. +// +// +// Note on negative right shifts +// ----------------------------- +// +// In C, x >> n, where x is a negative integer, is implementation +// defined. In practice, all platforms do arithmetic shift, which is +// equivalent to division by 2^26, rounded down. Some compilers, like +// GCC, even guarantee it. +// +// If we ever stumble upon a platform that does not propagate the sign +// bit (we won't), visible failures will show at the slightest test, and +// the signed shifts can be replaced by the following: +// +// typedef struct { i64 x:39; } s25; +// typedef struct { i64 x:38; } s26; +// i64 shift25(i64 x) { s25 s; s.x = ((u64)x)>>25; return s.x; } +// i64 shift26(i64 x) { s26 s; s.x = ((u64)x)>>26; return s.x; } +// +// Current compilers cannot optimise this, causing a 30% drop in +// performance. Fairly expensive for something that never happens. +// +// +// Precondition +// ------------ +// +// |t0| < 2^63 +// |t1|..|t9| < 2^62 +// +// Algorithm +// --------- +// c = t0 + 2^25 / 2^26 -- |c| <= 2^36 +// t0 -= c * 2^26 -- |t0| <= 2^25 +// t1 += c -- |t1| <= 2^63 +// +// c = t4 + 2^25 / 2^26 -- |c| <= 2^36 +// t4 -= c * 2^26 -- |t4| <= 2^25 +// t5 += c -- |t5| <= 2^63 +// +// c = t1 + 2^24 / 2^25 -- |c| <= 2^38 +// t1 -= c * 2^25 -- |t1| <= 2^24 +// t2 += c -- |t2| <= 2^63 +// +// c = t5 + 2^24 / 2^25 -- |c| <= 2^38 +// t5 -= c * 2^25 -- |t5| <= 2^24 +// t6 += c -- |t6| <= 2^63 +// +// c = t2 + 2^25 / 2^26 -- |c| <= 2^37 +// t2 -= c * 2^26 -- |t2| <= 2^25 < 1.1 * 2^25 (final t2) +// t3 += c -- |t3| <= 2^63 +// +// c = t6 + 2^25 / 2^26 -- |c| <= 2^37 +// t6 -= c * 2^26 -- |t6| <= 2^25 < 1.1 * 2^25 (final t6) +// t7 += c -- |t7| <= 2^63 +// +// c = t3 + 2^24 / 2^25 -- |c| <= 2^38 +// t3 -= c * 2^25 -- |t3| <= 2^24 < 1.1 * 2^24 (final t3) +// t4 += c -- |t4| <= 2^25 + 2^38 < 2^39 +// +// c = t7 + 2^24 / 2^25 -- |c| <= 2^38 +// t7 -= c * 2^25 -- |t7| <= 2^24 < 1.1 * 2^24 (final t7) +// t8 += c -- |t8| <= 2^63 +// +// c = t4 + 2^25 / 2^26 -- |c| <= 2^13 +// t4 -= c * 2^26 -- |t4| <= 2^25 < 1.1 * 2^25 (final t4) +// t5 += c -- |t5| <= 2^24 + 2^13 < 1.1 * 2^24 (final t5) +// +// c = t8 + 2^25 / 2^26 -- |c| <= 2^37 +// t8 -= c * 2^26 -- |t8| <= 2^25 < 1.1 * 2^25 (final t8) +// t9 += c -- |t9| <= 2^63 +// +// c = t9 + 2^24 / 2^25 -- |c| <= 2^38 +// t9 -= c * 2^25 -- |t9| <= 2^24 < 1.1 * 2^24 (final t9) +// t0 += c * 19 -- |t0| <= 2^25 + 2^38*19 < 2^44 +// +// c = t0 + 2^25 / 2^26 -- |c| <= 2^18 +// t0 -= c * 2^26 -- |t0| <= 2^25 < 1.1 * 2^25 (final t0) +// t1 += c -- |t1| <= 2^24 + 2^18 < 1.1 * 2^24 (final t1) +// +// Postcondition +// ------------- +// |t0|, |t2|, |t4|, |t6|, |t8| < 1.1 * 2^25 +// |t1|, |t3|, |t5|, |t7|, |t9| < 1.1 * 2^24 +#define FE_CARRY \ + i64 c; \ + c = (t0 + ((i64)1<<25)) >> 26; t0 -= c * ((i64)1 << 26); t1 += c; \ + c = (t4 + ((i64)1<<25)) >> 26; t4 -= c * ((i64)1 << 26); t5 += c; \ + c = (t1 + ((i64)1<<24)) >> 25; t1 -= c * ((i64)1 << 25); t2 += c; \ + c = (t5 + ((i64)1<<24)) >> 25; t5 -= c * ((i64)1 << 25); t6 += c; \ + c = (t2 + ((i64)1<<25)) >> 26; t2 -= c * ((i64)1 << 26); t3 += c; \ + c = (t6 + ((i64)1<<25)) >> 26; t6 -= c * ((i64)1 << 26); t7 += c; \ + c = (t3 + ((i64)1<<24)) >> 25; t3 -= c * ((i64)1 << 25); t4 += c; \ + c = (t7 + ((i64)1<<24)) >> 25; t7 -= c * ((i64)1 << 25); t8 += c; \ + c = (t4 + ((i64)1<<25)) >> 26; t4 -= c * ((i64)1 << 26); t5 += c; \ + c = (t8 + ((i64)1<<25)) >> 26; t8 -= c * ((i64)1 << 26); t9 += c; \ + c = (t9 + ((i64)1<<24)) >> 25; t9 -= c * ((i64)1 << 25); t0 += c * 19; \ + c = (t0 + ((i64)1<<25)) >> 26; t0 -= c * ((i64)1 << 26); t1 += c; \ + h[0]=(i32)t0; h[1]=(i32)t1; h[2]=(i32)t2; h[3]=(i32)t3; h[4]=(i32)t4; \ + h[5]=(i32)t5; h[6]=(i32)t6; h[7]=(i32)t7; h[8]=(i32)t8; h[9]=(i32)t9 + +// Decodes a field element from a byte buffer. +// mask specifies how many bits we ignore. +// Traditionally we ignore 1. It's useful for EdDSA, +// which uses that bit to denote the sign of x. +// Elligator however uses positive representatives, +// which means ignoring 2 bits instead. +static void fe_frombytes_mask(fe h, const u8 s[32], unsigned nb_mask) +{ + u32 mask = 0xffffff >> nb_mask; + i64 t0 = load32_le(s); // t0 < 2^32 + i64 t1 = load24_le(s + 4) << 6; // t1 < 2^30 + i64 t2 = load24_le(s + 7) << 5; // t2 < 2^29 + i64 t3 = load24_le(s + 10) << 3; // t3 < 2^27 + i64 t4 = load24_le(s + 13) << 2; // t4 < 2^26 + i64 t5 = load32_le(s + 16); // t5 < 2^32 + i64 t6 = load24_le(s + 20) << 7; // t6 < 2^31 + i64 t7 = load24_le(s + 23) << 5; // t7 < 2^29 + i64 t8 = load24_le(s + 26) << 4; // t8 < 2^28 + i64 t9 = (load24_le(s + 29) & mask) << 2; // t9 < 2^25 + FE_CARRY; // Carry precondition OK +} + +static void fe_frombytes(fe h, const u8 s[32]) +{ + fe_frombytes_mask(h, s, 1); +} + + +// Precondition +// |h[0]|, |h[2]|, |h[4]|, |h[6]|, |h[8]| < 1.1 * 2^25 +// |h[1]|, |h[3]|, |h[5]|, |h[7]|, |h[9]| < 1.1 * 2^24 +// +// Therefore, |h| < 2^255-19 +// There are two possibilities: +// +// - If h is positive, all we need to do is reduce its individual +// limbs down to their tight positive range. +// - If h is negative, we also need to add 2^255-19 to it. +// Or just remove 19 and chop off any excess bit. +static void fe_tobytes(u8 s[32], const fe h) +{ + i32 t[10]; + COPY(t, h, 10); + i32 q = (19 * t[9] + (((i32) 1) << 24)) >> 25; + // |t9| < 1.1 * 2^24 + // -1.1 * 2^24 < t9 < 1.1 * 2^24 + // -21 * 2^24 < 19 * t9 < 21 * 2^24 + // -2^29 < 19 * t9 + 2^24 < 2^29 + // -2^29 / 2^25 < (19 * t9 + 2^24) / 2^25 < 2^29 / 2^25 + // -16 < (19 * t9 + 2^24) / 2^25 < 16 + FOR (i, 0, 5) { + q += t[2*i ]; q >>= 26; // q = 0 or -1 + q += t[2*i+1]; q >>= 25; // q = 0 or -1 + } + // q = 0 iff h >= 0 + // q = -1 iff h < 0 + // Adding q * 19 to h reduces h to its proper range. + q *= 19; // Shift carry back to the beginning + FOR (i, 0, 5) { + t[i*2 ] += q; q = t[i*2 ] >> 26; t[i*2 ] -= q * ((i32)1 << 26); + t[i*2+1] += q; q = t[i*2+1] >> 25; t[i*2+1] -= q * ((i32)1 << 25); + } + // h is now fully reduced, and q represents the excess bit. + + store32_le(s + 0, ((u32)t[0] >> 0) | ((u32)t[1] << 26)); + store32_le(s + 4, ((u32)t[1] >> 6) | ((u32)t[2] << 19)); + store32_le(s + 8, ((u32)t[2] >> 13) | ((u32)t[3] << 13)); + store32_le(s + 12, ((u32)t[3] >> 19) | ((u32)t[4] << 6)); + store32_le(s + 16, ((u32)t[5] >> 0) | ((u32)t[6] << 25)); + store32_le(s + 20, ((u32)t[6] >> 7) | ((u32)t[7] << 19)); + store32_le(s + 24, ((u32)t[7] >> 13) | ((u32)t[8] << 12)); + store32_le(s + 28, ((u32)t[8] >> 20) | ((u32)t[9] << 6)); + + WIPE_BUFFER(t); +} + +// Precondition +// ------------- +// |f0|, |f2|, |f4|, |f6|, |f8| < 1.65 * 2^26 +// |f1|, |f3|, |f5|, |f7|, |f9| < 1.65 * 2^25 +// +// |g0|, |g2|, |g4|, |g6|, |g8| < 1.65 * 2^26 +// |g1|, |g3|, |g5|, |g7|, |g9| < 1.65 * 2^25 +static void fe_mul_small(fe h, const fe f, i32 g) +{ + i64 t0 = f[0] * (i64) g; i64 t1 = f[1] * (i64) g; + i64 t2 = f[2] * (i64) g; i64 t3 = f[3] * (i64) g; + i64 t4 = f[4] * (i64) g; i64 t5 = f[5] * (i64) g; + i64 t6 = f[6] * (i64) g; i64 t7 = f[7] * (i64) g; + i64 t8 = f[8] * (i64) g; i64 t9 = f[9] * (i64) g; + // |t0|, |t2|, |t4|, |t6|, |t8| < 1.65 * 2^26 * 2^31 < 2^58 + // |t1|, |t3|, |t5|, |t7|, |t9| < 1.65 * 2^25 * 2^31 < 2^57 + + FE_CARRY; // Carry precondition OK +} + +// Precondition +// ------------- +// |f0|, |f2|, |f4|, |f6|, |f8| < 1.65 * 2^26 +// |f1|, |f3|, |f5|, |f7|, |f9| < 1.65 * 2^25 +// +// |g0|, |g2|, |g4|, |g6|, |g8| < 1.65 * 2^26 +// |g1|, |g3|, |g5|, |g7|, |g9| < 1.65 * 2^25 +static void fe_mul(fe h, const fe f, const fe g) +{ + // Everything is unrolled and put in temporary variables. + // We could roll the loop, but that would make curve25519 twice as slow. + i32 f0 = f[0]; i32 f1 = f[1]; i32 f2 = f[2]; i32 f3 = f[3]; i32 f4 = f[4]; + i32 f5 = f[5]; i32 f6 = f[6]; i32 f7 = f[7]; i32 f8 = f[8]; i32 f9 = f[9]; + i32 g0 = g[0]; i32 g1 = g[1]; i32 g2 = g[2]; i32 g3 = g[3]; i32 g4 = g[4]; + i32 g5 = g[5]; i32 g6 = g[6]; i32 g7 = g[7]; i32 g8 = g[8]; i32 g9 = g[9]; + i32 F1 = f1*2; i32 F3 = f3*2; i32 F5 = f5*2; i32 F7 = f7*2; i32 F9 = f9*2; + i32 G1 = g1*19; i32 G2 = g2*19; i32 G3 = g3*19; + i32 G4 = g4*19; i32 G5 = g5*19; i32 G6 = g6*19; + i32 G7 = g7*19; i32 G8 = g8*19; i32 G9 = g9*19; + // |F1|, |F3|, |F5|, |F7|, |F9| < 1.65 * 2^26 + // |G0|, |G2|, |G4|, |G6|, |G8| < 2^31 + // |G1|, |G3|, |G5|, |G7|, |G9| < 2^30 + + i64 t0 = f0*(i64)g0 + F1*(i64)G9 + f2*(i64)G8 + F3*(i64)G7 + f4*(i64)G6 + + F5*(i64)G5 + f6*(i64)G4 + F7*(i64)G3 + f8*(i64)G2 + F9*(i64)G1; + i64 t1 = f0*(i64)g1 + f1*(i64)g0 + f2*(i64)G9 + f3*(i64)G8 + f4*(i64)G7 + + f5*(i64)G6 + f6*(i64)G5 + f7*(i64)G4 + f8*(i64)G3 + f9*(i64)G2; + i64 t2 = f0*(i64)g2 + F1*(i64)g1 + f2*(i64)g0 + F3*(i64)G9 + f4*(i64)G8 + + F5*(i64)G7 + f6*(i64)G6 + F7*(i64)G5 + f8*(i64)G4 + F9*(i64)G3; + i64 t3 = f0*(i64)g3 + f1*(i64)g2 + f2*(i64)g1 + f3*(i64)g0 + f4*(i64)G9 + + f5*(i64)G8 + f6*(i64)G7 + f7*(i64)G6 + f8*(i64)G5 + f9*(i64)G4; + i64 t4 = f0*(i64)g4 + F1*(i64)g3 + f2*(i64)g2 + F3*(i64)g1 + f4*(i64)g0 + + F5*(i64)G9 + f6*(i64)G8 + F7*(i64)G7 + f8*(i64)G6 + F9*(i64)G5; + i64 t5 = f0*(i64)g5 + f1*(i64)g4 + f2*(i64)g3 + f3*(i64)g2 + f4*(i64)g1 + + f5*(i64)g0 + f6*(i64)G9 + f7*(i64)G8 + f8*(i64)G7 + f9*(i64)G6; + i64 t6 = f0*(i64)g6 + F1*(i64)g5 + f2*(i64)g4 + F3*(i64)g3 + f4*(i64)g2 + + F5*(i64)g1 + f6*(i64)g0 + F7*(i64)G9 + f8*(i64)G8 + F9*(i64)G7; + i64 t7 = f0*(i64)g7 + f1*(i64)g6 + f2*(i64)g5 + f3*(i64)g4 + f4*(i64)g3 + + f5*(i64)g2 + f6*(i64)g1 + f7*(i64)g0 + f8*(i64)G9 + f9*(i64)G8; + i64 t8 = f0*(i64)g8 + F1*(i64)g7 + f2*(i64)g6 + F3*(i64)g5 + f4*(i64)g4 + + F5*(i64)g3 + f6*(i64)g2 + F7*(i64)g1 + f8*(i64)g0 + F9*(i64)G9; + i64 t9 = f0*(i64)g9 + f1*(i64)g8 + f2*(i64)g7 + f3*(i64)g6 + f4*(i64)g5 + + f5*(i64)g4 + f6*(i64)g3 + f7*(i64)g2 + f8*(i64)g1 + f9*(i64)g0; + // t0 < 0.67 * 2^61 + // t1 < 0.41 * 2^61 + // t2 < 0.52 * 2^61 + // t3 < 0.32 * 2^61 + // t4 < 0.38 * 2^61 + // t5 < 0.22 * 2^61 + // t6 < 0.23 * 2^61 + // t7 < 0.13 * 2^61 + // t8 < 0.09 * 2^61 + // t9 < 0.03 * 2^61 + + FE_CARRY; // Everything below 2^62, Carry precondition OK +} + +// Precondition +// ------------- +// |f0|, |f2|, |f4|, |f6|, |f8| < 1.65 * 2^26 +// |f1|, |f3|, |f5|, |f7|, |f9| < 1.65 * 2^25 +// +// Note: we could use fe_mul() for this, but this is significantly faster +static void fe_sq(fe h, const fe f) +{ + i32 f0 = f[0]; i32 f1 = f[1]; i32 f2 = f[2]; i32 f3 = f[3]; i32 f4 = f[4]; + i32 f5 = f[5]; i32 f6 = f[6]; i32 f7 = f[7]; i32 f8 = f[8]; i32 f9 = f[9]; + i32 f0_2 = f0*2; i32 f1_2 = f1*2; i32 f2_2 = f2*2; i32 f3_2 = f3*2; + i32 f4_2 = f4*2; i32 f5_2 = f5*2; i32 f6_2 = f6*2; i32 f7_2 = f7*2; + i32 f5_38 = f5*38; i32 f6_19 = f6*19; i32 f7_38 = f7*38; + i32 f8_19 = f8*19; i32 f9_38 = f9*38; + // |f0_2| , |f2_2| , |f4_2| , |f6_2| , |f8_2| < 1.65 * 2^27 + // |f1_2| , |f3_2| , |f5_2| , |f7_2| , |f9_2| < 1.65 * 2^26 + // |f5_38|, |f6_19|, |f7_38|, |f8_19|, |f9_38| < 2^31 + + i64 t0 = f0 *(i64)f0 + f1_2*(i64)f9_38 + f2_2*(i64)f8_19 + + f3_2*(i64)f7_38 + f4_2*(i64)f6_19 + f5 *(i64)f5_38; + i64 t1 = f0_2*(i64)f1 + f2 *(i64)f9_38 + f3_2*(i64)f8_19 + + f4 *(i64)f7_38 + f5_2*(i64)f6_19; + i64 t2 = f0_2*(i64)f2 + f1_2*(i64)f1 + f3_2*(i64)f9_38 + + f4_2*(i64)f8_19 + f5_2*(i64)f7_38 + f6 *(i64)f6_19; + i64 t3 = f0_2*(i64)f3 + f1_2*(i64)f2 + f4 *(i64)f9_38 + + f5_2*(i64)f8_19 + f6 *(i64)f7_38; + i64 t4 = f0_2*(i64)f4 + f1_2*(i64)f3_2 + f2 *(i64)f2 + + f5_2*(i64)f9_38 + f6_2*(i64)f8_19 + f7 *(i64)f7_38; + i64 t5 = f0_2*(i64)f5 + f1_2*(i64)f4 + f2_2*(i64)f3 + + f6 *(i64)f9_38 + f7_2*(i64)f8_19; + i64 t6 = f0_2*(i64)f6 + f1_2*(i64)f5_2 + f2_2*(i64)f4 + + f3_2*(i64)f3 + f7_2*(i64)f9_38 + f8 *(i64)f8_19; + i64 t7 = f0_2*(i64)f7 + f1_2*(i64)f6 + f2_2*(i64)f5 + + f3_2*(i64)f4 + f8 *(i64)f9_38; + i64 t8 = f0_2*(i64)f8 + f1_2*(i64)f7_2 + f2_2*(i64)f6 + + f3_2*(i64)f5_2 + f4 *(i64)f4 + f9 *(i64)f9_38; + i64 t9 = f0_2*(i64)f9 + f1_2*(i64)f8 + f2_2*(i64)f7 + + f3_2*(i64)f6 + f4 *(i64)f5_2; + // t0 < 0.67 * 2^61 + // t1 < 0.41 * 2^61 + // t2 < 0.52 * 2^61 + // t3 < 0.32 * 2^61 + // t4 < 0.38 * 2^61 + // t5 < 0.22 * 2^61 + // t6 < 0.23 * 2^61 + // t7 < 0.13 * 2^61 + // t8 < 0.09 * 2^61 + // t9 < 0.03 * 2^61 + + FE_CARRY; +} + +// Parity check. Returns 0 if even, 1 if odd +static int fe_isodd(const fe f) +{ + u8 s[32]; + fe_tobytes(s, f); + u8 isodd = s[0] & 1; + WIPE_BUFFER(s); + return isodd; +} + +// Returns 1 if equal, 0 if not equal +static int fe_isequal(const fe f, const fe g) +{ + u8 fs[32]; + u8 gs[32]; + fe_tobytes(fs, f); + fe_tobytes(gs, g); + int isdifferent = crypto_verify32(fs, gs); + WIPE_BUFFER(fs); + WIPE_BUFFER(gs); + return 1 + isdifferent; +} + +// Inverse square root. +// Returns true if x is a square, false otherwise. +// After the call: +// isr = sqrt(1/x) if x is a non-zero square. +// isr = sqrt(sqrt(-1)/x) if x is not a square. +// isr = 0 if x is zero. +// We do not guarantee the sign of the square root. +// +// Notes: +// Let quartic = x^((p-1)/4) +// +// x^((p-1)/2) = chi(x) +// quartic^2 = chi(x) +// quartic = sqrt(chi(x)) +// quartic = 1 or -1 or sqrt(-1) or -sqrt(-1) +// +// Note that x is a square if quartic is 1 or -1 +// There are 4 cases to consider: +// +// if quartic = 1 (x is a square) +// then x^((p-1)/4) = 1 +// x^((p-5)/4) * x = 1 +// x^((p-5)/4) = 1/x +// x^((p-5)/8) = sqrt(1/x) or -sqrt(1/x) +// +// if quartic = -1 (x is a square) +// then x^((p-1)/4) = -1 +// x^((p-5)/4) * x = -1 +// x^((p-5)/4) = -1/x +// x^((p-5)/8) = sqrt(-1) / sqrt(x) +// x^((p-5)/8) * sqrt(-1) = sqrt(-1)^2 / sqrt(x) +// x^((p-5)/8) * sqrt(-1) = -1/sqrt(x) +// x^((p-5)/8) * sqrt(-1) = -sqrt(1/x) or sqrt(1/x) +// +// if quartic = sqrt(-1) (x is not a square) +// then x^((p-1)/4) = sqrt(-1) +// x^((p-5)/4) * x = sqrt(-1) +// x^((p-5)/4) = sqrt(-1)/x +// x^((p-5)/8) = sqrt(sqrt(-1)/x) or -sqrt(sqrt(-1)/x) +// +// Note that the product of two non-squares is always a square: +// For any non-squares a and b, chi(a) = -1 and chi(b) = -1. +// Since chi(x) = x^((p-1)/2), chi(a)*chi(b) = chi(a*b) = 1. +// Therefore a*b is a square. +// +// Since sqrt(-1) and x are both non-squares, their product is a +// square, and we can compute their square root. +// +// if quartic = -sqrt(-1) (x is not a square) +// then x^((p-1)/4) = -sqrt(-1) +// x^((p-5)/4) * x = -sqrt(-1) +// x^((p-5)/4) = -sqrt(-1)/x +// x^((p-5)/8) = sqrt(-sqrt(-1)/x) +// x^((p-5)/8) = sqrt( sqrt(-1)/x) * sqrt(-1) +// x^((p-5)/8) * sqrt(-1) = sqrt( sqrt(-1)/x) * sqrt(-1)^2 +// x^((p-5)/8) * sqrt(-1) = sqrt( sqrt(-1)/x) * -1 +// x^((p-5)/8) * sqrt(-1) = -sqrt(sqrt(-1)/x) or sqrt(sqrt(-1)/x) +static int invsqrt(fe isr, const fe x) +{ + fe t0, t1, t2; + + // t0 = x^((p-5)/8) + // Can be achieved with a simple double & add ladder, + // but it would be slower. + fe_sq(t0, x); + fe_sq(t1,t0); fe_sq(t1, t1); fe_mul(t1, x, t1); + fe_mul(t0, t0, t1); + fe_sq(t0, t0); fe_mul(t0, t1, t0); + fe_sq(t1, t0); FOR (i, 1, 5) { fe_sq(t1, t1); } fe_mul(t0, t1, t0); + fe_sq(t1, t0); FOR (i, 1, 10) { fe_sq(t1, t1); } fe_mul(t1, t1, t0); + fe_sq(t2, t1); FOR (i, 1, 20) { fe_sq(t2, t2); } fe_mul(t1, t2, t1); + fe_sq(t1, t1); FOR (i, 1, 10) { fe_sq(t1, t1); } fe_mul(t0, t1, t0); + fe_sq(t1, t0); FOR (i, 1, 50) { fe_sq(t1, t1); } fe_mul(t1, t1, t0); + fe_sq(t2, t1); FOR (i, 1, 100) { fe_sq(t2, t2); } fe_mul(t1, t2, t1); + fe_sq(t1, t1); FOR (i, 1, 50) { fe_sq(t1, t1); } fe_mul(t0, t1, t0); + fe_sq(t0, t0); FOR (i, 1, 2) { fe_sq(t0, t0); } fe_mul(t0, t0, x); + + // quartic = x^((p-1)/4) + i32 *quartic = t1; + fe_sq (quartic, t0); + fe_mul(quartic, quartic, x); + + i32 *check = t2; + fe_0 (check); int z0 = fe_isequal(x , check); + fe_1 (check); int p1 = fe_isequal(quartic, check); + fe_neg(check, check ); int m1 = fe_isequal(quartic, check); + fe_neg(check, sqrtm1); int ms = fe_isequal(quartic, check); + + // if quartic == -1 or sqrt(-1) + // then isr = x^((p-1)/4) * sqrt(-1) + // else isr = x^((p-1)/4) + fe_mul(isr, t0, sqrtm1); + fe_ccopy(isr, t0, 1 - (m1 | ms)); + + WIPE_BUFFER(t0); + WIPE_BUFFER(t1); + WIPE_BUFFER(t2); + return p1 | m1 | z0; +} + +// Inverse in terms of inverse square root. +// Requires two additional squarings to get rid of the sign. +// +// 1/x = x * (+invsqrt(x^2))^2 +// = x * (-invsqrt(x^2))^2 +// +// A fully optimised exponentiation by p-1 would save 6 field +// multiplications, but it would require more code. +static void fe_invert(fe out, const fe x) +{ + fe tmp; + fe_sq(tmp, x); + invsqrt(tmp, tmp); + fe_sq(tmp, tmp); + fe_mul(out, tmp, x); + WIPE_BUFFER(tmp); +} + +// trim a scalar for scalar multiplication +void crypto_eddsa_trim_scalar(u8 out[32], const u8 in[32]) +{ + COPY(out, in, 32); + out[ 0] &= 248; + out[31] &= 127; + out[31] |= 64; +} + +// get bit from scalar at position i +static int scalar_bit(const u8 s[32], int i) +{ + if (i < 0) { return 0; } // handle -1 for sliding windows + return (s[i>>3] >> (i&7)) & 1; +} + +/////////////// +/// X-25519 /// Taken from SUPERCOP's ref10 implementation. +/////////////// +static void scalarmult(u8 q[32], const u8 scalar[32], const u8 p[32], + int nb_bits) +{ + // computes the scalar product + fe x1; + fe_frombytes(x1, p); + + // computes the actual scalar product (the result is in x2 and z2) + fe x2, z2, x3, z3, t0, t1; + // Montgomery ladder + // In projective coordinates, to avoid divisions: x = X / Z + // We don't care about the y coordinate, it's only 1 bit of information + fe_1(x2); fe_0(z2); // "zero" point + fe_copy(x3, x1); fe_1(z3); // "one" point + int swap = 0; + for (int pos = nb_bits-1; pos >= 0; --pos) { + // constant time conditional swap before ladder step + int b = scalar_bit(scalar, pos); + swap ^= b; // xor trick avoids swapping at the end of the loop + fe_cswap(x2, x3, swap); + fe_cswap(z2, z3, swap); + swap = b; // anticipates one last swap after the loop + + // Montgomery ladder step: replaces (P2, P3) by (P2*2, P2+P3) + // with differential addition + fe_sub(t0, x3, z3); + fe_sub(t1, x2, z2); + fe_add(x2, x2, z2); + fe_add(z2, x3, z3); + fe_mul(z3, t0, x2); + fe_mul(z2, z2, t1); + fe_sq (t0, t1 ); + fe_sq (t1, x2 ); + fe_add(x3, z3, z2); + fe_sub(z2, z3, z2); + fe_mul(x2, t1, t0); + fe_sub(t1, t1, t0); + fe_sq (z2, z2 ); + fe_mul_small(z3, t1, 121666); + fe_sq (x3, x3 ); + fe_add(t0, t0, z3); + fe_mul(z3, x1, z2); + fe_mul(z2, t1, t0); + } + // last swap is necessary to compensate for the xor trick + // Note: after this swap, P3 == P2 + P1. + fe_cswap(x2, x3, swap); + fe_cswap(z2, z3, swap); + + // normalises the coordinates: x == X / Z + fe_invert(z2, z2); + fe_mul(x2, x2, z2); + fe_tobytes(q, x2); + + WIPE_BUFFER(x1); + WIPE_BUFFER(x2); WIPE_BUFFER(z2); WIPE_BUFFER(t0); + WIPE_BUFFER(x3); WIPE_BUFFER(z3); WIPE_BUFFER(t1); +} + +void crypto_x25519(u8 raw_shared_secret[32], + const u8 your_secret_key [32], + const u8 their_public_key [32]) +{ + // restrict the possible scalar values + u8 e[32]; + crypto_eddsa_trim_scalar(e, your_secret_key); + scalarmult(raw_shared_secret, e, their_public_key, 255); + WIPE_BUFFER(e); +} + +void crypto_x25519_public_key(u8 public_key[32], + const u8 secret_key[32]) +{ + static const u8 base_point[32] = {9}; + crypto_x25519(public_key, secret_key, base_point); +} + +/////////////////////////// +/// Arithmetic modulo L /// +/////////////////////////// +static const u32 L[8] = { + 0x5cf5d3ed, 0x5812631a, 0xa2f79cd6, 0x14def9de, + 0x00000000, 0x00000000, 0x00000000, 0x10000000, +}; + +// p = a*b + p +static void multiply(u32 p[16], const u32 a[8], const u32 b[8]) +{ + FOR (i, 0, 8) { + u64 carry = 0; + FOR (j, 0, 8) { + carry += p[i+j] + (u64)a[i] * b[j]; + p[i+j] = (u32)carry; + carry >>= 32; + } + p[i+8] = (u32)carry; + } +} + +static int is_above_l(const u32 x[8]) +{ + // We work with L directly, in a 2's complement encoding + // (-L == ~L + 1) + u64 carry = 1; + FOR (i, 0, 8) { + carry += (u64)x[i] + (~L[i] & 0xffffffff); + carry >>= 32; + } + return (int)carry; // carry is either 0 or 1 +} + +// Final reduction modulo L, by conditionally removing L. +// if x < l , then r = x +// if l <= x 2*l, then r = x-l +// otherwise the result will be wrong +static void remove_l(u32 r[8], const u32 x[8]) +{ + u64 carry = (u64)is_above_l(x); + u32 mask = ~(u32)carry + 1; // carry == 0 or 1 + FOR (i, 0, 8) { + carry += (u64)x[i] + (~L[i] & mask); + r[i] = (u32)carry; + carry >>= 32; + } +} + +// Full reduction modulo L (Barrett reduction) +static void mod_l(u8 reduced[32], const u32 x[16]) +{ + static const u32 r[9] = { + 0x0a2c131b,0xed9ce5a3,0x086329a7,0x2106215d, + 0xffffffeb,0xffffffff,0xffffffff,0xffffffff,0xf, + }; + // xr = x * r + u32 xr[25] = {0}; + FOR (i, 0, 9) { + u64 carry = 0; + FOR (j, 0, 16) { + carry += xr[i+j] + (u64)r[i] * x[j]; + xr[i+j] = (u32)carry; + carry >>= 32; + } + xr[i+16] = (u32)carry; + } + // xr = floor(xr / 2^512) * L + // Since the result is guaranteed to be below 2*L, + // it is enough to only compute the first 256 bits. + // The division is performed by saying xr[i+16]. (16 * 32 = 512) + ZERO(xr, 8); + FOR (i, 0, 8) { + u64 carry = 0; + FOR (j, 0, 8-i) { + carry += xr[i+j] + (u64)xr[i+16] * L[j]; + xr[i+j] = (u32)carry; + carry >>= 32; + } + } + // xr = x - xr + u64 carry = 1; + FOR (i, 0, 8) { + carry += (u64)x[i] + (~xr[i] & 0xffffffff); + xr[i] = (u32)carry; + carry >>= 32; + } + // Final reduction modulo L (conditional subtraction) + remove_l(xr, xr); + store32_le_buf(reduced, xr, 8); + + WIPE_BUFFER(xr); +} + +void crypto_eddsa_reduce(u8 reduced[32], const u8 expanded[64]) +{ + u32 x[16]; + load32_le_buf(x, expanded, 16); + mod_l(reduced, x); + WIPE_BUFFER(x); +} + +// r = (a * b) + c +void crypto_eddsa_mul_add(u8 r[32], + const u8 a[32], const u8 b[32], const u8 c[32]) +{ + u32 A[8]; load32_le_buf(A, a, 8); + u32 B[8]; load32_le_buf(B, b, 8); + u32 p[16]; load32_le_buf(p, c, 8); ZERO(p + 8, 8); + multiply(p, A, B); + mod_l(r, p); + WIPE_BUFFER(p); + WIPE_BUFFER(A); + WIPE_BUFFER(B); +} + +/////////////// +/// Ed25519 /// +/////////////// + +// Point (group element, ge) in a twisted Edwards curve, +// in extended projective coordinates. +// ge : x = X/Z, y = Y/Z, T = XY/Z +// ge_cached : Yp = X+Y, Ym = X-Y, T2 = T*D2 +// ge_precomp: Z = 1 +typedef struct { fe X; fe Y; fe Z; fe T; } ge; +typedef struct { fe Yp; fe Ym; fe Z; fe T2; } ge_cached; +typedef struct { fe Yp; fe Ym; fe T2; } ge_precomp; + +static void ge_zero(ge *p) +{ + fe_0(p->X); + fe_1(p->Y); + fe_1(p->Z); + fe_0(p->T); +} + +static void ge_tobytes(u8 s[32], const ge *h) +{ + fe recip, x, y; + fe_invert(recip, h->Z); + fe_mul(x, h->X, recip); + fe_mul(y, h->Y, recip); + fe_tobytes(s, y); + s[31] ^= fe_isodd(x) << 7; + + WIPE_BUFFER(recip); + WIPE_BUFFER(x); + WIPE_BUFFER(y); +} + +// h = -s, where s is a point encoded in 32 bytes +// +// Variable time! Inputs must not be secret! +// => Use only to *check* signatures. +// +// From the specifications: +// The encoding of s contains y and the sign of x +// x = sqrt((y^2 - 1) / (d*y^2 + 1)) +// In extended coordinates: +// X = x, Y = y, Z = 1, T = x*y +// +// Note that num * den is a square iff num / den is a square +// If num * den is not a square, the point was not on the curve. +// From the above: +// Let num = y^2 - 1 +// Let den = d*y^2 + 1 +// x = sqrt((y^2 - 1) / (d*y^2 + 1)) +// x = sqrt(num / den) +// x = sqrt(num^2 / (num * den)) +// x = num * sqrt(1 / (num * den)) +// +// Therefore, we can just compute: +// num = y^2 - 1 +// den = d*y^2 + 1 +// isr = invsqrt(num * den) // abort if not square +// x = num * isr +// Finally, negate x if its sign is not as specified. +static int ge_frombytes_neg_vartime(ge *h, const u8 s[32]) +{ + fe_frombytes(h->Y, s); + fe_1(h->Z); + fe_sq (h->T, h->Y); // t = y^2 + fe_mul(h->X, h->T, d ); // x = d*y^2 + fe_sub(h->T, h->T, h->Z); // t = y^2 - 1 + fe_add(h->X, h->X, h->Z); // x = d*y^2 + 1 + fe_mul(h->X, h->T, h->X); // x = (y^2 - 1) * (d*y^2 + 1) + int is_square = invsqrt(h->X, h->X); + if (!is_square) { + return -1; // Not on the curve, abort + } + fe_mul(h->X, h->T, h->X); // x = sqrt((y^2 - 1) / (d*y^2 + 1)) + if (fe_isodd(h->X) == (s[31] >> 7)) { + fe_neg(h->X, h->X); + } + fe_mul(h->T, h->X, h->Y); + return 0; +} + +static void ge_cache(ge_cached *c, const ge *p) +{ + fe_add (c->Yp, p->Y, p->X); + fe_sub (c->Ym, p->Y, p->X); + fe_copy(c->Z , p->Z ); + fe_mul (c->T2, p->T, D2 ); +} + +// Internal buffers are not wiped! Inputs must not be secret! +// => Use only to *check* signatures. +static void ge_add(ge *s, const ge *p, const ge_cached *q) +{ + fe a, b; + fe_add(a , p->Y, p->X ); + fe_sub(b , p->Y, p->X ); + fe_mul(a , a , q->Yp); + fe_mul(b , b , q->Ym); + fe_add(s->Y, a , b ); + fe_sub(s->X, a , b ); + + fe_add(s->Z, p->Z, p->Z ); + fe_mul(s->Z, s->Z, q->Z ); + fe_mul(s->T, p->T, q->T2); + fe_add(a , s->Z, s->T ); + fe_sub(b , s->Z, s->T ); + + fe_mul(s->T, s->X, s->Y); + fe_mul(s->X, s->X, b ); + fe_mul(s->Y, s->Y, a ); + fe_mul(s->Z, a , b ); +} + +// Internal buffers are not wiped! Inputs must not be secret! +// => Use only to *check* signatures. +static void ge_sub(ge *s, const ge *p, const ge_cached *q) +{ + ge_cached neg; + fe_copy(neg.Ym, q->Yp); + fe_copy(neg.Yp, q->Ym); + fe_copy(neg.Z , q->Z ); + fe_neg (neg.T2, q->T2); + ge_add(s, p, &neg); +} + +static void ge_madd(ge *s, const ge *p, const ge_precomp *q, fe a, fe b) +{ + fe_add(a , p->Y, p->X ); + fe_sub(b , p->Y, p->X ); + fe_mul(a , a , q->Yp); + fe_mul(b , b , q->Ym); + fe_add(s->Y, a , b ); + fe_sub(s->X, a , b ); + + fe_add(s->Z, p->Z, p->Z ); + fe_mul(s->T, p->T, q->T2); + fe_add(a , s->Z, s->T ); + fe_sub(b , s->Z, s->T ); + + fe_mul(s->T, s->X, s->Y); + fe_mul(s->X, s->X, b ); + fe_mul(s->Y, s->Y, a ); + fe_mul(s->Z, a , b ); +} + +// Internal buffers are not wiped! Inputs must not be secret! +// => Use only to *check* signatures. +static void ge_msub(ge *s, const ge *p, const ge_precomp *q, fe a, fe b) +{ + ge_precomp neg; + fe_copy(neg.Ym, q->Yp); + fe_copy(neg.Yp, q->Ym); + fe_neg (neg.T2, q->T2); + ge_madd(s, p, &neg, a, b); +} + +static void ge_double(ge *s, const ge *p, ge *q) +{ + fe_sq (q->X, p->X); + fe_sq (q->Y, p->Y); + fe_sq (q->Z, p->Z); // qZ = pZ^2 + fe_mul_small(q->Z, q->Z, 2); // qZ = pZ^2 * 2 + fe_add(q->T, p->X, p->Y); + fe_sq (s->T, q->T); + fe_add(q->T, q->Y, q->X); + fe_sub(q->Y, q->Y, q->X); + fe_sub(q->X, s->T, q->T); + fe_sub(q->Z, q->Z, q->Y); + + fe_mul(s->X, q->X , q->Z); + fe_mul(s->Y, q->T , q->Y); + fe_mul(s->Z, q->Y , q->Z); + fe_mul(s->T, q->X , q->T); +} + +// 5-bit signed window in cached format (Niels coordinates, Z=1) +static const ge_precomp b_window[8] = { + {{25967493,-14356035,29566456,3660896,-12694345, + 4014787,27544626,-11754271,-6079156,2047605,}, + {-12545711,934262,-2722910,3049990,-727428, + 9406986,12720692,5043384,19500929,-15469378,}, + {-8738181,4489570,9688441,-14785194,10184609, + -12363380,29287919,11864899,-24514362,-4438546,},}, + {{15636291,-9688557,24204773,-7912398,616977, + -16685262,27787600,-14772189,28944400,-1550024,}, + {16568933,4717097,-11556148,-1102322,15682896, + -11807043,16354577,-11775962,7689662,11199574,}, + {30464156,-5976125,-11779434,-15670865,23220365, + 15915852,7512774,10017326,-17749093,-9920357,},}, + {{10861363,11473154,27284546,1981175,-30064349, + 12577861,32867885,14515107,-15438304,10819380,}, + {4708026,6336745,20377586,9066809,-11272109, + 6594696,-25653668,12483688,-12668491,5581306,}, + {19563160,16186464,-29386857,4097519,10237984, + -4348115,28542350,13850243,-23678021,-15815942,},}, + {{5153746,9909285,1723747,-2777874,30523605, + 5516873,19480852,5230134,-23952439,-15175766,}, + {-30269007,-3463509,7665486,10083793,28475525, + 1649722,20654025,16520125,30598449,7715701,}, + {28881845,14381568,9657904,3680757,-20181635, + 7843316,-31400660,1370708,29794553,-1409300,},}, + {{-22518993,-6692182,14201702,-8745502,-23510406, + 8844726,18474211,-1361450,-13062696,13821877,}, + {-6455177,-7839871,3374702,-4740862,-27098617, + -10571707,31655028,-7212327,18853322,-14220951,}, + {4566830,-12963868,-28974889,-12240689,-7602672, + -2830569,-8514358,-10431137,2207753,-3209784,},}, + {{-25154831,-4185821,29681144,7868801,-6854661, + -9423865,-12437364,-663000,-31111463,-16132436,}, + {25576264,-2703214,7349804,-11814844,16472782, + 9300885,3844789,15725684,171356,6466918,}, + {23103977,13316479,9739013,-16149481,817875, + -15038942,8965339,-14088058,-30714912,16193877,},}, + {{-33521811,3180713,-2394130,14003687,-16903474, + -16270840,17238398,4729455,-18074513,9256800,}, + {-25182317,-4174131,32336398,5036987,-21236817, + 11360617,22616405,9761698,-19827198,630305,}, + {-13720693,2639453,-24237460,-7406481,9494427, + -5774029,-6554551,-15960994,-2449256,-14291300,},}, + {{-3151181,-5046075,9282714,6866145,-31907062, + -863023,-18940575,15033784,25105118,-7894876,}, + {-24326370,15950226,-31801215,-14592823,-11662737, + -5090925,1573892,-2625887,2198790,-15804619,}, + {-3099351,10324967,-2241613,7453183,-5446979, + -2735503,-13812022,-16236442,-32461234,-12290683,},}, +}; + +// Incremental sliding windows (left to right) +// Based on Roberto Maria Avanzi[2005] +typedef struct { + i16 next_index; // position of the next signed digit + i8 next_digit; // next signed digit (odd number below 2^window_width) + u8 next_check; // point at which we must check for a new window +} slide_ctx; + +static void slide_init(slide_ctx *ctx, const u8 scalar[32]) +{ + // scalar is guaranteed to be below L, either because we checked (s), + // or because we reduced it modulo L (h_ram). L is under 2^253, so + // so bits 253 to 255 are guaranteed to be zero. No need to test them. + // + // Note however that L is very close to 2^252, so bit 252 is almost + // always zero. If we were to start at bit 251, the tests wouldn't + // catch the off-by-one error (constructing one that does would be + // prohibitively expensive). + // + // We should still check bit 252, though. + int i = 252; + while (i > 0 && scalar_bit(scalar, i) == 0) { + i--; + } + ctx->next_check = (u8)(i + 1); + ctx->next_index = -1; + ctx->next_digit = -1; +} + +static int slide_step(slide_ctx *ctx, int width, int i, const u8 scalar[32]) +{ + if (i == ctx->next_check) { + if (scalar_bit(scalar, i) == scalar_bit(scalar, i - 1)) { + ctx->next_check--; + } else { + // compute digit of next window + int w = MIN(width, i + 1); + int v = -(scalar_bit(scalar, i) << (w-1)); + FOR_T (int, j, 0, w-1) { + v += scalar_bit(scalar, i-(w-1)+j) << j; + } + v += scalar_bit(scalar, i-w); + int lsb = v & (~v + 1); // smallest bit of v + int s = // log2(lsb) + (((lsb & 0xAA) != 0) << 0) | + (((lsb & 0xCC) != 0) << 1) | + (((lsb & 0xF0) != 0) << 2); + ctx->next_index = (i16)(i-(w-1)+s); + ctx->next_digit = (i8) (v >> s ); + ctx->next_check -= (u8) w; + } + } + return i == ctx->next_index ? ctx->next_digit: 0; +} + +#define P_W_WIDTH 3 // Affects the size of the stack +#define B_W_WIDTH 5 // Affects the size of the binary +#define P_W_SIZE (1<<(P_W_WIDTH-2)) + +int crypto_eddsa_check_equation(const u8 signature[64], const u8 public_key[32], + const u8 h[32]) +{ + ge minus_A; // -public_key + ge minus_R; // -first_half_of_signature + const u8 *s = signature + 32; + + // Check that A and R are on the curve + // Check that 0 <= S < L (prevents malleability) + // *Allow* non-cannonical encoding for A and R + { + u32 s32[8]; + load32_le_buf(s32, s, 8); + if (ge_frombytes_neg_vartime(&minus_A, public_key) || + ge_frombytes_neg_vartime(&minus_R, signature) || + is_above_l(s32)) { + return -1; + } + } + + // look-up table for minus_A + ge_cached lutA[P_W_SIZE]; + { + ge minus_A2, tmp; + ge_double(&minus_A2, &minus_A, &tmp); + ge_cache(&lutA[0], &minus_A); + FOR (i, 1, P_W_SIZE) { + ge_add(&tmp, &minus_A2, &lutA[i-1]); + ge_cache(&lutA[i], &tmp); + } + } + + // sum = [s]B - [h]A + // Merged double and add ladder, fused with sliding + slide_ctx h_slide; slide_init(&h_slide, h); + slide_ctx s_slide; slide_init(&s_slide, s); + int i = MAX(h_slide.next_check, s_slide.next_check); + ge *sum = &minus_A; // reuse minus_A for the sum + ge_zero(sum); + while (i >= 0) { + ge tmp; + ge_double(sum, sum, &tmp); + int h_digit = slide_step(&h_slide, P_W_WIDTH, i, h); + int s_digit = slide_step(&s_slide, B_W_WIDTH, i, s); + if (h_digit > 0) { ge_add(sum, sum, &lutA[ h_digit / 2]); } + if (h_digit < 0) { ge_sub(sum, sum, &lutA[-h_digit / 2]); } + fe t1, t2; + if (s_digit > 0) { ge_madd(sum, sum, b_window + s_digit/2, t1, t2); } + if (s_digit < 0) { ge_msub(sum, sum, b_window + -s_digit/2, t1, t2); } + i--; + } + + // Compare [8](sum-R) and the zero point + // The multiplication by 8 eliminates any low-order component + // and ensures consistency with batched verification. + ge_cached cached; + u8 check[32]; + static const u8 zero_point[32] = {1}; // Point of order 1 + ge_cache(&cached, &minus_R); + ge_add(sum, sum, &cached); + ge_double(sum, sum, &minus_R); // reuse minus_R as temporary + ge_double(sum, sum, &minus_R); // reuse minus_R as temporary + ge_double(sum, sum, &minus_R); // reuse minus_R as temporary + ge_tobytes(check, sum); + return crypto_verify32(check, zero_point); +} + +// 5-bit signed comb in cached format (Niels coordinates, Z=1) +static const ge_precomp b_comb_low[8] = { + {{-6816601,-2324159,-22559413,124364,18015490, + 8373481,19993724,1979872,-18549925,9085059,}, + {10306321,403248,14839893,9633706,8463310, + -8354981,-14305673,14668847,26301366,2818560,}, + {-22701500,-3210264,-13831292,-2927732,-16326337, + -14016360,12940910,177905,12165515,-2397893,},}, + {{-12282262,-7022066,9920413,-3064358,-32147467, + 2927790,22392436,-14852487,2719975,16402117,}, + {-7236961,-4729776,2685954,-6525055,-24242706, + -15940211,-6238521,14082855,10047669,12228189,}, + {-30495588,-12893761,-11161261,3539405,-11502464, + 16491580,-27286798,-15030530,-7272871,-15934455,},}, + {{17650926,582297,-860412,-187745,-12072900, + -10683391,-20352381,15557840,-31072141,-5019061,}, + {-6283632,-2259834,-4674247,-4598977,-4089240, + 12435688,-31278303,1060251,6256175,10480726,}, + {-13871026,2026300,-21928428,-2741605,-2406664, + -8034988,7355518,15733500,-23379862,7489131,},}, + {{6883359,695140,23196907,9644202,-33430614, + 11354760,-20134606,6388313,-8263585,-8491918,}, + {-7716174,-13605463,-13646110,14757414,-19430591, + -14967316,10359532,-11059670,-21935259,12082603,}, + {-11253345,-15943946,10046784,5414629,24840771, + 8086951,-6694742,9868723,15842692,-16224787,},}, + {{9639399,11810955,-24007778,-9320054,3912937, + -9856959,996125,-8727907,-8919186,-14097242,}, + {7248867,14468564,25228636,-8795035,14346339, + 8224790,6388427,-7181107,6468218,-8720783,}, + {15513115,15439095,7342322,-10157390,18005294, + -7265713,2186239,4884640,10826567,7135781,},}, + {{-14204238,5297536,-5862318,-6004934,28095835, + 4236101,-14203318,1958636,-16816875,3837147,}, + {-5511166,-13176782,-29588215,12339465,15325758, + -15945770,-8813185,11075932,-19608050,-3776283,}, + {11728032,9603156,-4637821,-5304487,-7827751, + 2724948,31236191,-16760175,-7268616,14799772,},}, + {{-28842672,4840636,-12047946,-9101456,-1445464, + 381905,-30977094,-16523389,1290540,12798615,}, + {27246947,-10320914,14792098,-14518944,5302070, + -8746152,-3403974,-4149637,-27061213,10749585,}, + {25572375,-6270368,-15353037,16037944,1146292, + 32198,23487090,9585613,24714571,-1418265,},}, + {{19844825,282124,-17583147,11004019,-32004269, + -2716035,6105106,-1711007,-21010044,14338445,}, + {8027505,8191102,-18504907,-12335737,25173494, + -5923905,15446145,7483684,-30440441,10009108,}, + {-14134701,-4174411,10246585,-14677495,33553567, + -14012935,23366126,15080531,-7969992,7663473,},}, +}; + +static const ge_precomp b_comb_high[8] = { + {{33055887,-4431773,-521787,6654165,951411, + -6266464,-5158124,6995613,-5397442,-6985227,}, + {4014062,6967095,-11977872,3960002,8001989, + 5130302,-2154812,-1899602,-31954493,-16173976,}, + {16271757,-9212948,23792794,731486,-25808309, + -3546396,6964344,-4767590,10976593,10050757,},}, + {{2533007,-4288439,-24467768,-12387405,-13450051, + 14542280,12876301,13893535,15067764,8594792,}, + {20073501,-11623621,3165391,-13119866,13188608, + -11540496,-10751437,-13482671,29588810,2197295,}, + {-1084082,11831693,6031797,14062724,14748428, + -8159962,-20721760,11742548,31368706,13161200,},}, + {{2050412,-6457589,15321215,5273360,25484180, + 124590,-18187548,-7097255,-6691621,-14604792,}, + {9938196,2162889,-6158074,-1711248,4278932, + -2598531,-22865792,-7168500,-24323168,11746309,}, + {-22691768,-14268164,5965485,9383325,20443693, + 5854192,28250679,-1381811,-10837134,13717818,},}, + {{-8495530,16382250,9548884,-4971523,-4491811, + -3902147,6182256,-12832479,26628081,10395408,}, + {27329048,-15853735,7715764,8717446,-9215518, + -14633480,28982250,-5668414,4227628,242148,}, + {-13279943,-7986904,-7100016,8764468,-27276630, + 3096719,29678419,-9141299,3906709,11265498,},}, + {{11918285,15686328,-17757323,-11217300,-27548967, + 4853165,-27168827,6807359,6871949,-1075745,}, + {-29002610,13984323,-27111812,-2713442,28107359, + -13266203,6155126,15104658,3538727,-7513788,}, + {14103158,11233913,-33165269,9279850,31014152, + 4335090,-1827936,4590951,13960841,12787712,},}, + {{1469134,-16738009,33411928,13942824,8092558, + -8778224,-11165065,1437842,22521552,-2792954,}, + {31352705,-4807352,-25327300,3962447,12541566, + -9399651,-27425693,7964818,-23829869,5541287,}, + {-25732021,-6864887,23848984,3039395,-9147354, + 6022816,-27421653,10590137,25309915,-1584678,},}, + {{-22951376,5048948,31139401,-190316,-19542447, + -626310,-17486305,-16511925,-18851313,-12985140,}, + {-9684890,14681754,30487568,7717771,-10829709, + 9630497,30290549,-10531496,-27798994,-13812825,}, + {5827835,16097107,-24501327,12094619,7413972, + 11447087,28057551,-1793987,-14056981,4359312,},}, + {{26323183,2342588,-21887793,-1623758,-6062284, + 2107090,-28724907,9036464,-19618351,-13055189,}, + {-29697200,14829398,-4596333,14220089,-30022969, + 2955645,12094100,-13693652,-5941445,7047569,}, + {-3201977,14413268,-12058324,-16417589,-9035655, + -7224648,9258160,1399236,30397584,-5684634,},}, +}; + +static void lookup_add(ge *p, ge_precomp *tmp_c, fe tmp_a, fe tmp_b, + const ge_precomp comb[8], const u8 scalar[32], int i) +{ + u8 teeth = (u8)((scalar_bit(scalar, i) ) + + (scalar_bit(scalar, i + 32) << 1) + + (scalar_bit(scalar, i + 64) << 2) + + (scalar_bit(scalar, i + 96) << 3)); + u8 high = teeth >> 3; + u8 index = (teeth ^ (high - 1)) & 7; + FOR (j, 0, 8) { + i32 select = 1 & (((j ^ index) - 1) >> 8); + fe_ccopy(tmp_c->Yp, comb[j].Yp, select); + fe_ccopy(tmp_c->Ym, comb[j].Ym, select); + fe_ccopy(tmp_c->T2, comb[j].T2, select); + } + fe_neg(tmp_a, tmp_c->T2); + fe_cswap(tmp_c->T2, tmp_a , high ^ 1); + fe_cswap(tmp_c->Yp, tmp_c->Ym, high ^ 1); + ge_madd(p, p, tmp_c, tmp_a, tmp_b); +} + +// p = [scalar]B, where B is the base point +static void ge_scalarmult_base(ge *p, const u8 scalar[32]) +{ + // twin 4-bits signed combs, from Mike Hamburg's + // Fast and compact elliptic-curve cryptography (2012) + // 1 / 2 modulo L + static const u8 half_mod_L[32] = { + 247,233,122,46,141,49,9,44,107,206,123,81,239,124,111,10, + 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,8, + }; + // (2^256 - 1) / 2 modulo L + static const u8 half_ones[32] = { + 142,74,204,70,186,24,118,107,184,231,190,57,250,173,119,99, + 255,255,255,255,255,255,255,255,255,255,255,255,255,255,255,7, + }; + + // All bits set form: 1 means 1, 0 means -1 + u8 s_scalar[32]; + crypto_eddsa_mul_add(s_scalar, scalar, half_mod_L, half_ones); + + // Double and add ladder + fe tmp_a, tmp_b; // temporaries for addition + ge_precomp tmp_c; // temporary for comb lookup + ge tmp_d; // temporary for doubling + fe_1(tmp_c.Yp); + fe_1(tmp_c.Ym); + fe_0(tmp_c.T2); + + // Save a double on the first iteration + ge_zero(p); + lookup_add(p, &tmp_c, tmp_a, tmp_b, b_comb_low , s_scalar, 31); + lookup_add(p, &tmp_c, tmp_a, tmp_b, b_comb_high, s_scalar, 31+128); + // Regular double & add for the rest + for (int i = 30; i >= 0; i--) { + ge_double(p, p, &tmp_d); + lookup_add(p, &tmp_c, tmp_a, tmp_b, b_comb_low , s_scalar, i); + lookup_add(p, &tmp_c, tmp_a, tmp_b, b_comb_high, s_scalar, i+128); + } + // Note: we could save one addition at the end if we assumed the + // scalar fit in 252 bits. Which it does in practice if it is + // selected at random. However, non-random, non-hashed scalars + // *can* overflow 252 bits in practice. Better account for that + // than leaving that kind of subtle corner case. + + WIPE_BUFFER(tmp_a); WIPE_CTX(&tmp_d); + WIPE_BUFFER(tmp_b); WIPE_CTX(&tmp_c); + WIPE_BUFFER(s_scalar); +} + +void crypto_eddsa_scalarbase(u8 point[32], const u8 scalar[32]) +{ + ge P; + ge_scalarmult_base(&P, scalar); + ge_tobytes(point, &P); + WIPE_CTX(&P); +} + +void crypto_eddsa_key_pair(u8 secret_key[64], u8 public_key[32], u8 seed[32]) +{ + // To allow overlaps, observable writes happen in this order: + // 1. seed + // 2. secret_key + // 3. public_key + u8 a[64]; + COPY(a, seed, 32); + crypto_wipe(seed, 32); + COPY(secret_key, a, 32); + crypto_blake2b(a, 64, a, 32); + crypto_eddsa_trim_scalar(a, a); + crypto_eddsa_scalarbase(secret_key + 32, a); + COPY(public_key, secret_key + 32, 32); + WIPE_BUFFER(a); +} + +static void hash_reduce(u8 h[32], + const u8 *a, size_t a_size, + const u8 *b, size_t b_size, + const u8 *c, size_t c_size) +{ + u8 hash[64]; + crypto_blake2b_ctx ctx; + crypto_blake2b_init (&ctx, 64); + crypto_blake2b_update(&ctx, a, a_size); + crypto_blake2b_update(&ctx, b, b_size); + crypto_blake2b_update(&ctx, c, c_size); + crypto_blake2b_final (&ctx, hash); + crypto_eddsa_reduce(h, hash); +} + +// Digital signature of a message with from a secret key. +// +// The secret key comprises two parts: +// - The seed that generates the key (secret_key[ 0..31]) +// - The public key (secret_key[32..63]) +// +// The seed and the public key are bundled together to make sure users +// don't use mismatched seeds and public keys, which would instantly +// leak the secret scalar and allow forgeries (allowing this to happen +// has resulted in critical vulnerabilities in the wild). +// +// The seed is hashed to derive the secret scalar and a secret prefix. +// The sole purpose of the prefix is to generate a secret random nonce. +// The properties of that nonce must be as follows: +// - Unique: we need a different one for each message. +// - Secret: third parties must not be able to predict it. +// - Random: any detectable bias would break all security. +// +// There are two ways to achieve these properties. The obvious one is +// to simply generate a random number. Here that would be a parameter +// (Monocypher doesn't have an RNG). It works, but then users may reuse +// the nonce by accident, which _also_ leaks the secret scalar and +// allows forgeries. This has happened in the wild too. +// +// This is no good, so instead we generate that nonce deterministically +// by reducing modulo L a hash of the secret prefix and the message. +// The secret prefix makes the nonce unpredictable, the message makes it +// unique, and the hash/reduce removes all bias. +// +// The cost of that safety is hashing the message twice. If that cost +// is unacceptable, there are two alternatives: +// +// - Signing a hash of the message instead of the message itself. This +// is fine as long as the hash is collision resistant. It is not +// compatible with existing "pure" signatures, but at least it's safe. +// +// - Using a random nonce. Please exercise **EXTREME CAUTION** if you +// ever do that. It is absolutely **critical** that the nonce is +// really an unbiased random number between 0 and L-1, never reused, +// and wiped immediately. +// +// To lower the likelihood of complete catastrophe if the RNG is +// either flawed or misused, you can hash the RNG output together with +// the secret prefix and the beginning of the message, and use the +// reduction of that hash instead of the RNG output itself. It's not +// foolproof (you'd need to hash the whole message) but it helps. +// +// Signing a message involves the following operations: +// +// scalar, prefix = HASH(secret_key) +// r = HASH(prefix || message) % L +// R = [r]B +// h = HASH(R || public_key || message) % L +// S = ((h * a) + r) % L +// signature = R || S +void crypto_eddsa_sign(u8 signature [64], const u8 secret_key[64], + const u8 *message, size_t message_size) +{ + u8 a[64]; // secret scalar and prefix + u8 r[32]; // secret deterministic "random" nonce + u8 h[32]; // publically verifiable hash of the message (not wiped) + u8 R[32]; // first half of the signature (allows overlapping inputs) + + crypto_blake2b(a, 64, secret_key, 32); + crypto_eddsa_trim_scalar(a, a); + hash_reduce(r, a + 32, 32, message, message_size, 0, 0); + crypto_eddsa_scalarbase(R, r); + hash_reduce(h, R, 32, secret_key + 32, 32, message, message_size); + COPY(signature, R, 32); + crypto_eddsa_mul_add(signature + 32, h, a, r); + + WIPE_BUFFER(a); + WIPE_BUFFER(r); +} + +// To check the signature R, S of the message M with the public key A, +// there are 3 steps: +// +// compute h = HASH(R || A || message) % L +// check that A is on the curve. +// check that R == [s]B - [h]A +// +// The last two steps are done in crypto_eddsa_check_equation() +int crypto_eddsa_check(const u8 signature[64], const u8 public_key[32], + const u8 *message, size_t message_size) +{ + u8 h[32]; + hash_reduce(h, signature, 32, public_key, 32, message, message_size); + return crypto_eddsa_check_equation(signature, public_key, h); +} + +///////////////////////// +/// EdDSA <--> X25519 /// +///////////////////////// +void crypto_eddsa_to_x25519(u8 x25519[32], const u8 eddsa[32]) +{ + // (u, v) = ((1+y)/(1-y), sqrt(-486664)*u/x) + // Only converting y to u, the sign of x is ignored. + fe t1, t2; + fe_frombytes(t2, eddsa); + fe_add(t1, fe_one, t2); + fe_sub(t2, fe_one, t2); + fe_invert(t2, t2); + fe_mul(t1, t1, t2); + fe_tobytes(x25519, t1); + WIPE_BUFFER(t1); + WIPE_BUFFER(t2); +} + +void crypto_x25519_to_eddsa(u8 eddsa[32], const u8 x25519[32]) +{ + // (x, y) = (sqrt(-486664)*u/v, (u-1)/(u+1)) + // Only converting u to y, x is assumed positive. + fe t1, t2; + fe_frombytes(t2, x25519); + fe_sub(t1, t2, fe_one); + fe_add(t2, t2, fe_one); + fe_invert(t2, t2); + fe_mul(t1, t1, t2); + fe_tobytes(eddsa, t1); + WIPE_BUFFER(t1); + WIPE_BUFFER(t2); +} + +///////////////////////////////////////////// +/// Dirty ephemeral public key generation /// +///////////////////////////////////////////// + +// Those functions generates a public key, *without* clearing the +// cofactor. Sending that key over the network leaks 3 bits of the +// private key. Use only to generate ephemeral keys that will be hidden +// with crypto_curve_to_hidden(). +// +// The public key is otherwise compatible with crypto_x25519(), which +// properly clears the cofactor. +// +// Note that the distribution of the resulting public keys is almost +// uniform. Flipping the sign of the v coordinate (not provided by this +// function), covers the entire key space almost perfectly, where +// "almost" means a 2^-128 bias (undetectable). This uniformity is +// needed to ensure the proper randomness of the resulting +// representatives (once we apply crypto_curve_to_hidden()). +// +// Recall that Curve25519 has order C = 2^255 + e, with e < 2^128 (not +// to be confused with the prime order of the main subgroup, L, which is +// 8 times less than that). +// +// Generating all points would require us to multiply a point of order C +// (the base point plus any point of order 8) by all scalars from 0 to +// C-1. Clamping limits us to scalars between 2^254 and 2^255 - 1. But +// by negating the resulting point at random, we also cover scalars from +// -2^255 + 1 to -2^254 (which modulo C is congruent to e+1 to 2^254 + e). +// +// In practice: +// - Scalars from 0 to e + 1 are never generated +// - Scalars from 2^255 to 2^255 + e are never generated +// - Scalars from 2^254 + 1 to 2^254 + e are generated twice +// +// Since e < 2^128, detecting this bias requires observing over 2^100 +// representatives from a given source (this will never happen), *and* +// recovering enough of the private key to determine that they do, or do +// not, belong to the biased set (this practically requires solving +// discrete logarithm, which is conjecturally intractable). +// +// In practice, this means the bias is impossible to detect. + +// s + (x*L) % 8*L +// Guaranteed to fit in 256 bits iff s fits in 255 bits. +// L < 2^253 +// x%8 < 2^3 +// L * (x%8) < 2^255 +// s < 2^255 +// s + L * (x%8) < 2^256 +static void add_xl(u8 s[32], u8 x) +{ + u64 mod8 = x & 7; + u64 carry = 0; + FOR (i , 0, 8) { + carry = carry + load32_le(s + 4*i) + L[i] * mod8; + store32_le(s + 4*i, (u32)carry); + carry >>= 32; + } +} + +// "Small" dirty ephemeral key. +// Use if you need to shrink the size of the binary, and can afford to +// slow down by a factor of two (compared to the fast version) +// +// This version works by decoupling the cofactor from the main factor. +// +// - The trimmed scalar determines the main factor +// - The clamped bits of the scalar determine the cofactor. +// +// Cofactor and main factor are combined into a single scalar, which is +// then multiplied by a point of order 8*L (unlike the base point, which +// has prime order). That "dirty" base point is the addition of the +// regular base point (9), and a point of order 8. +void crypto_x25519_dirty_small(u8 public_key[32], const u8 secret_key[32]) +{ + // Base point of order 8*L + // Raw scalar multiplication with it does not clear the cofactor, + // and the resulting public key will reveal 3 bits of the scalar. + // + // The low order component of this base point has been chosen + // to yield the same results as crypto_x25519_dirty_fast(). + static const u8 dirty_base_point[32] = { + 0xd8, 0x86, 0x1a, 0xa2, 0x78, 0x7a, 0xd9, 0x26, + 0x8b, 0x74, 0x74, 0xb6, 0x82, 0xe3, 0xbe, 0xc3, + 0xce, 0x36, 0x9a, 0x1e, 0x5e, 0x31, 0x47, 0xa2, + 0x6d, 0x37, 0x7c, 0xfd, 0x20, 0xb5, 0xdf, 0x75, + }; + // separate the main factor & the cofactor of the scalar + u8 scalar[32]; + crypto_eddsa_trim_scalar(scalar, secret_key); + + // Separate the main factor and the cofactor + // + // The scalar is trimmed, so its cofactor is cleared. The three + // least significant bits however still have a main factor. We must + // remove it for X25519 compatibility. + // + // cofactor = lsb * L (modulo 8*L) + // combined = scalar + cofactor (modulo 8*L) + add_xl(scalar, secret_key[0]); + scalarmult(public_key, scalar, dirty_base_point, 256); + WIPE_BUFFER(scalar); +} + +// Select low order point +// We're computing the [cofactor]lop scalar multiplication, where: +// +// cofactor = tweak & 7. +// lop = (lop_x, lop_y) +// lop_x = sqrt((sqrt(d + 1) + 1) / d) +// lop_y = -lop_x * sqrtm1 +// +// The low order point has order 8. There are 4 such points. We've +// chosen the one whose both coordinates are positive (below p/2). +// The 8 low order points are as follows: +// +// [0]lop = ( 0 , 1 ) +// [1]lop = ( lop_x , lop_y) +// [2]lop = ( sqrt(-1), -0 ) +// [3]lop = ( lop_x , -lop_y) +// [4]lop = (-0 , -1 ) +// [5]lop = (-lop_x , -lop_y) +// [6]lop = (-sqrt(-1), 0 ) +// [7]lop = (-lop_x , lop_y) +// +// The x coordinate is either 0, sqrt(-1), lop_x, or their opposite. +// The y coordinate is either 0, -1 , lop_y, or their opposite. +// The pattern for both is the same, except for a rotation of 2 (modulo 8) +// +// This helper function captures the pattern, and we can use it thus: +// +// select_lop(x, lop_x, sqrtm1, cofactor); +// select_lop(y, lop_y, fe_one, cofactor + 2); +// +// This is faster than an actual scalar multiplication, +// and requires less code than naive constant time look up. +static void select_lop(fe out, const fe x, const fe k, u8 cofactor) +{ + fe tmp; + fe_0(out); + fe_ccopy(out, k , (cofactor >> 1) & 1); // bit 1 + fe_ccopy(out, x , (cofactor >> 0) & 1); // bit 0 + fe_neg (tmp, out); + fe_ccopy(out, tmp, (cofactor >> 2) & 1); // bit 2 + WIPE_BUFFER(tmp); +} + +// "Fast" dirty ephemeral key +// We use this one by default. +// +// This version works by performing a regular scalar multiplication, +// then add a low order point. The scalar multiplication is done in +// Edwards space for more speed (*2 compared to the "small" version). +// The cost is a bigger binary for programs that don't also sign messages. +void crypto_x25519_dirty_fast(u8 public_key[32], const u8 secret_key[32]) +{ + // Compute clean scalar multiplication + u8 scalar[32]; + ge pk; + crypto_eddsa_trim_scalar(scalar, secret_key); + ge_scalarmult_base(&pk, scalar); + + // Compute low order point + fe t1, t2; + select_lop(t1, lop_x, sqrtm1, secret_key[0]); + select_lop(t2, lop_y, fe_one, secret_key[0] + 2); + ge_precomp low_order_point; + fe_add(low_order_point.Yp, t2, t1); + fe_sub(low_order_point.Ym, t2, t1); + fe_mul(low_order_point.T2, t2, t1); + fe_mul(low_order_point.T2, low_order_point.T2, D2); + + // Add low order point to the public key + ge_madd(&pk, &pk, &low_order_point, t1, t2); + + // Convert to Montgomery u coordinate (we ignore the sign) + fe_add(t1, pk.Z, pk.Y); + fe_sub(t2, pk.Z, pk.Y); + fe_invert(t2, t2); + fe_mul(t1, t1, t2); + + fe_tobytes(public_key, t1); + + WIPE_BUFFER(t1); WIPE_CTX(&pk); + WIPE_BUFFER(t2); WIPE_CTX(&low_order_point); + WIPE_BUFFER(scalar); +} + +/////////////////// +/// Elligator 2 /// +/////////////////// +static const fe A = {486662}; + +// Elligator direct map +// +// Computes the point corresponding to a representative, encoded in 32 +// bytes (little Endian). Since positive representatives fits in 254 +// bits, The two most significant bits are ignored. +// +// From the paper: +// w = -A / (fe(1) + non_square * r^2) +// e = chi(w^3 + A*w^2 + w) +// u = e*w - (fe(1)-e)*(A//2) +// v = -e * sqrt(u^3 + A*u^2 + u) +// +// We ignore v because we don't need it for X25519 (the Montgomery +// ladder only uses u). +// +// Note that e is either 0, 1 or -1 +// if e = 0 u = 0 and v = 0 +// if e = 1 u = w +// if e = -1 u = -w - A = w * non_square * r^2 +// +// Let r1 = non_square * r^2 +// Let r2 = 1 + r1 +// Note that r2 cannot be zero, -1/non_square is not a square. +// We can (tediously) verify that: +// w^3 + A*w^2 + w = (A^2*r1 - r2^2) * A / r2^3 +// Therefore: +// chi(w^3 + A*w^2 + w) = chi((A^2*r1 - r2^2) * (A / r2^3)) +// chi(w^3 + A*w^2 + w) = chi((A^2*r1 - r2^2) * (A / r2^3)) * 1 +// chi(w^3 + A*w^2 + w) = chi((A^2*r1 - r2^2) * (A / r2^3)) * chi(r2^6) +// chi(w^3 + A*w^2 + w) = chi((A^2*r1 - r2^2) * (A / r2^3) * r2^6) +// chi(w^3 + A*w^2 + w) = chi((A^2*r1 - r2^2) * A * r2^3) +// Corollary: +// e = 1 if (A^2*r1 - r2^2) * A * r2^3) is a non-zero square +// e = -1 if (A^2*r1 - r2^2) * A * r2^3) is not a square +// Note that w^3 + A*w^2 + w (and therefore e) can never be zero: +// w^3 + A*w^2 + w = w * (w^2 + A*w + 1) +// w^3 + A*w^2 + w = w * (w^2 + A*w + A^2/4 - A^2/4 + 1) +// w^3 + A*w^2 + w = w * (w + A/2)^2 - A^2/4 + 1) +// which is zero only if: +// w = 0 (impossible) +// (w + A/2)^2 = A^2/4 - 1 (impossible, because A^2/4-1 is not a square) +// +// Let isr = invsqrt((A^2*r1 - r2^2) * A * r2^3) +// isr = sqrt(1 / ((A^2*r1 - r2^2) * A * r2^3)) if e = 1 +// isr = sqrt(sqrt(-1) / ((A^2*r1 - r2^2) * A * r2^3)) if e = -1 +// +// if e = 1 +// let u1 = -A * (A^2*r1 - r2^2) * A * r2^2 * isr^2 +// u1 = w +// u1 = u +// +// if e = -1 +// let ufactor = -non_square * sqrt(-1) * r^2 +// let vfactor = sqrt(ufactor) +// let u2 = -A * (A^2*r1 - r2^2) * A * r2^2 * isr^2 * ufactor +// u2 = w * -1 * -non_square * r^2 +// u2 = w * non_square * r^2 +// u2 = u +void crypto_elligator_map(u8 curve[32], const u8 hidden[32]) +{ + fe r, u, t1, t2, t3; + fe_frombytes_mask(r, hidden, 2); // r is encoded in 254 bits. + fe_sq(r, r); + fe_add(t1, r, r); + fe_add(u, t1, fe_one); + fe_sq (t2, u); + fe_mul(t3, A2, t1); + fe_sub(t3, t3, t2); + fe_mul(t3, t3, A); + fe_mul(t1, t2, u); + fe_mul(t1, t3, t1); + int is_square = invsqrt(t1, t1); + fe_mul(u, r, ufactor); + fe_ccopy(u, fe_one, is_square); + fe_sq (t1, t1); + fe_mul(u, u, A); + fe_mul(u, u, t3); + fe_mul(u, u, t2); + fe_mul(u, u, t1); + fe_neg(u, u); + fe_tobytes(curve, u); + + WIPE_BUFFER(t1); WIPE_BUFFER(r); + WIPE_BUFFER(t2); WIPE_BUFFER(u); + WIPE_BUFFER(t3); +} + +// Elligator inverse map +// +// Computes the representative of a point, if possible. If not, it does +// nothing and returns -1. Note that the success of the operation +// depends only on the point (more precisely its u coordinate). The +// tweak parameter is used only upon success +// +// The tweak should be a random byte. Beyond that, its contents are an +// implementation detail. Currently, the tweak comprises: +// - Bit 1 : sign of the v coordinate (0 if positive, 1 if negative) +// - Bit 2-5: not used +// - Bits 6-7: random padding +// +// From the paper: +// Let sq = -non_square * u * (u+A) +// if sq is not a square, or u = -A, there is no mapping +// Assuming there is a mapping: +// if v is positive: r = sqrt(-u / (non_square * (u+A))) +// if v is negative: r = sqrt(-(u+A) / (non_square * u )) +// +// We compute isr = invsqrt(-non_square * u * (u+A)) +// if it wasn't a square, abort. +// else, isr = sqrt(-1 / (non_square * u * (u+A)) +// +// If v is positive, we return isr * u: +// isr * u = sqrt(-1 / (non_square * u * (u+A)) * u +// isr * u = sqrt(-u / (non_square * (u+A)) +// +// If v is negative, we return isr * (u+A): +// isr * (u+A) = sqrt(-1 / (non_square * u * (u+A)) * (u+A) +// isr * (u+A) = sqrt(-(u+A) / (non_square * u) +int crypto_elligator_rev(u8 hidden[32], const u8 public_key[32], u8 tweak) +{ + fe t1, t2, t3; + fe_frombytes(t1, public_key); // t1 = u + + fe_add(t2, t1, A); // t2 = u + A + fe_mul(t3, t1, t2); + fe_mul_small(t3, t3, -2); + int is_square = invsqrt(t3, t3); // t3 = sqrt(-1 / non_square * u * (u+A)) + if (is_square) { + // The only variable time bit. This ultimately reveals how many + // tries it took us to find a representable key. + // This does not affect security as long as we try keys at random. + + fe_ccopy (t1, t2, tweak & 1); // multiply by u if v is positive, + fe_mul (t3, t1, t3); // multiply by u+A otherwise + fe_mul_small(t1, t3, 2); + fe_neg (t2, t3); + fe_ccopy (t3, t2, fe_isodd(t1)); + fe_tobytes(hidden, t3); + + // Pad with two random bits + hidden[31] |= tweak & 0xc0; + } + + WIPE_BUFFER(t1); + WIPE_BUFFER(t2); + WIPE_BUFFER(t3); + return is_square - 1; +} + +void crypto_elligator_key_pair(u8 hidden[32], u8 secret_key[32], u8 seed[32]) +{ + u8 pk [32]; // public key + u8 buf[64]; // seed + representative + COPY(buf + 32, seed, 32); + do { + crypto_chacha20_djb(buf, 0, 64, buf+32, zero, 0); + crypto_x25519_dirty_fast(pk, buf); // or the "small" version + } while(crypto_elligator_rev(buf+32, pk, buf[32])); + // Note that the return value of crypto_elligator_rev() is + // independent from its tweak parameter. + // Therefore, buf[32] is not actually reused. Either we loop one + // more time and buf[32] is used for the new seed, or we succeeded, + // and buf[32] becomes the tweak parameter. + + crypto_wipe(seed, 32); + COPY(hidden , buf + 32, 32); + COPY(secret_key, buf , 32); + WIPE_BUFFER(buf); + WIPE_BUFFER(pk); +} + +/////////////////////// +/// Scalar division /// +/////////////////////// + +// Montgomery reduction. +// Divides x by (2^256), and reduces the result modulo L +// +// Precondition: +// x < L * 2^256 +// Constants: +// r = 2^256 (makes division by r trivial) +// k = (r * (1/r) - 1) // L (1/r is computed modulo L ) +// Algorithm: +// s = (x * k) % r +// t = x + s*L (t is always a multiple of r) +// u = (t/r) % L (u is always below 2*L, conditional subtraction is enough) +static void redc(u32 u[8], u32 x[16]) +{ + static const u32 k[8] = { + 0x12547e1b, 0xd2b51da3, 0xfdba84ff, 0xb1a206f2, + 0xffa36bea, 0x14e75438, 0x6fe91836, 0x9db6c6f2, + }; + + // s = x * k (modulo 2^256) + // This is cheaper than the full multiplication. + u32 s[8] = {0}; + FOR (i, 0, 8) { + u64 carry = 0; + FOR (j, 0, 8-i) { + carry += s[i+j] + (u64)x[i] * k[j]; + s[i+j] = (u32)carry; + carry >>= 32; + } + } + u32 t[16] = {0}; + multiply(t, s, L); + + // t = t + x + u64 carry = 0; + FOR (i, 0, 16) { + carry += (u64)t[i] + x[i]; + t[i] = (u32)carry; + carry >>= 32; + } + + // u = (t / 2^256) % L + // Note that t / 2^256 is always below 2*L, + // So a constant time conditional subtraction is enough + remove_l(u, t+8); + + WIPE_BUFFER(s); + WIPE_BUFFER(t); +} + +void crypto_x25519_inverse(u8 blind_salt [32], const u8 private_key[32], + const u8 curve_point[32]) +{ + static const u8 Lm2[32] = { // L - 2 + 0xeb, 0xd3, 0xf5, 0x5c, 0x1a, 0x63, 0x12, 0x58, + 0xd6, 0x9c, 0xf7, 0xa2, 0xde, 0xf9, 0xde, 0x14, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x10, + }; + // 1 in Montgomery form + u32 m_inv [8] = { + 0x8d98951d, 0xd6ec3174, 0x737dcf70, 0xc6ef5bf4, + 0xfffffffe, 0xffffffff, 0xffffffff, 0x0fffffff, + }; + + u8 scalar[32]; + crypto_eddsa_trim_scalar(scalar, private_key); + + // Convert the scalar in Montgomery form + // m_scl = scalar * 2^256 (modulo L) + u32 m_scl[8]; + { + u32 tmp[16]; + ZERO(tmp, 8); + load32_le_buf(tmp+8, scalar, 8); + mod_l(scalar, tmp); + load32_le_buf(m_scl, scalar, 8); + WIPE_BUFFER(tmp); // Wipe ASAP to save stack space + } + + // Compute the inverse + u32 product[16]; + for (int i = 252; i >= 0; i--) { + ZERO(product, 16); + multiply(product, m_inv, m_inv); + redc(m_inv, product); + if (scalar_bit(Lm2, i)) { + ZERO(product, 16); + multiply(product, m_inv, m_scl); + redc(m_inv, product); + } + } + // Convert the inverse *out* of Montgomery form + // scalar = m_inv / 2^256 (modulo L) + COPY(product, m_inv, 8); + ZERO(product + 8, 8); + redc(m_inv, product); + store32_le_buf(scalar, m_inv, 8); // the *inverse* of the scalar + + // Clear the cofactor of scalar: + // cleared = scalar * (3*L + 1) (modulo 8*L) + // cleared = scalar + scalar * 3 * L (modulo 8*L) + // Note that (scalar * 3) is reduced modulo 8, so we only need the + // first byte. + add_xl(scalar, scalar[0] * 3); + + // Recall that 8*L < 2^256. However it is also very close to + // 2^255. If we spanned the ladder over 255 bits, random tests + // wouldn't catch the off-by-one error. + scalarmult(blind_salt, scalar, curve_point, 256); + + WIPE_BUFFER(scalar); WIPE_BUFFER(m_scl); + WIPE_BUFFER(product); WIPE_BUFFER(m_inv); +} + +//////////////////////////////// +/// Authenticated encryption /// +//////////////////////////////// +static void lock_auth(u8 mac[16], const u8 auth_key[32], + const u8 *ad , size_t ad_size, + const u8 *cipher_text, size_t text_size) +{ + u8 sizes[16]; // Not secret, not wiped + store64_le(sizes + 0, ad_size); + store64_le(sizes + 8, text_size); + crypto_poly1305_ctx poly_ctx; // auto wiped... + crypto_poly1305_init (&poly_ctx, auth_key); + crypto_poly1305_update(&poly_ctx, ad , ad_size); + crypto_poly1305_update(&poly_ctx, zero , gap(ad_size, 16)); + crypto_poly1305_update(&poly_ctx, cipher_text, text_size); + crypto_poly1305_update(&poly_ctx, zero , gap(text_size, 16)); + crypto_poly1305_update(&poly_ctx, sizes , 16); + crypto_poly1305_final (&poly_ctx, mac); // ...here +} + +void crypto_aead_init_x(crypto_aead_ctx *ctx, + u8 const key[32], const u8 nonce[24]) +{ + crypto_chacha20_h(ctx->key, key, nonce); + COPY(ctx->nonce, nonce + 16, 8); + ctx->counter = 0; +} + +void crypto_aead_init_djb(crypto_aead_ctx *ctx, + const u8 key[32], const u8 nonce[8]) +{ + COPY(ctx->key , key , 32); + COPY(ctx->nonce, nonce, 8); + ctx->counter = 0; +} + +void crypto_aead_init_ietf(crypto_aead_ctx *ctx, + const u8 key[32], const u8 nonce[12]) +{ + COPY(ctx->key , key , 32); + COPY(ctx->nonce, nonce + 4, 8); + ctx->counter = (u64)load32_le(nonce) << 32; +} + +void crypto_aead_write(crypto_aead_ctx *ctx, u8 *cipher_text, u8 mac[16], + const u8 *ad, size_t ad_size, + const u8 *plain_text, size_t text_size) +{ + u8 auth_key[64]; // the last 32 bytes are used for rekeying. + crypto_chacha20_djb(auth_key, 0, 64, ctx->key, ctx->nonce, ctx->counter); + crypto_chacha20_djb(cipher_text, plain_text, text_size, + ctx->key, ctx->nonce, ctx->counter + 1); + lock_auth(mac, auth_key, ad, ad_size, cipher_text, text_size); + COPY(ctx->key, auth_key + 32, 32); + WIPE_BUFFER(auth_key); +} + +int crypto_aead_read(crypto_aead_ctx *ctx, u8 *plain_text, const u8 mac[16], + const u8 *ad, size_t ad_size, + const u8 *cipher_text, size_t text_size) +{ + u8 auth_key[64]; // the last 32 bytes are used for rekeying. + u8 real_mac[16]; + crypto_chacha20_djb(auth_key, 0, 64, ctx->key, ctx->nonce, ctx->counter); + lock_auth(real_mac, auth_key, ad, ad_size, cipher_text, text_size); + int mismatch = crypto_verify16(mac, real_mac); + if (!mismatch) { + crypto_chacha20_djb(plain_text, cipher_text, text_size, + ctx->key, ctx->nonce, ctx->counter + 1); + COPY(ctx->key, auth_key + 32, 32); + } + WIPE_BUFFER(auth_key); + WIPE_BUFFER(real_mac); + return mismatch; +} + +void crypto_aead_lock(u8 *cipher_text, u8 mac[16], const u8 key[32], + const u8 nonce[24], const u8 *ad, size_t ad_size, + const u8 *plain_text, size_t text_size) +{ + crypto_aead_ctx ctx; + crypto_aead_init_x(&ctx, key, nonce); + crypto_aead_write(&ctx, cipher_text, mac, ad, ad_size, + plain_text, text_size); + crypto_wipe(&ctx, sizeof(ctx)); +} + +int crypto_aead_unlock(u8 *plain_text, const u8 mac[16], const u8 key[32], + const u8 nonce[24], const u8 *ad, size_t ad_size, + const u8 *cipher_text, size_t text_size) +{ + crypto_aead_ctx ctx; + crypto_aead_init_x(&ctx, key, nonce); + int mismatch = crypto_aead_read(&ctx, plain_text, mac, ad, ad_size, + cipher_text, text_size); + crypto_wipe(&ctx, sizeof(ctx)); + return mismatch; +} + +#ifdef MONOCYPHER_CPP_NAMESPACE +} +#endif diff --git a/libs/monocypher.h b/libs/monocypher.h new file mode 100644 index 0000000..765a07f --- /dev/null +++ b/libs/monocypher.h @@ -0,0 +1,321 @@ +// Monocypher version 4.0.2 +// +// This file is dual-licensed. Choose whichever licence you want from +// the two licences listed below. +// +// The first licence is a regular 2-clause BSD licence. The second licence +// is the CC-0 from Creative Commons. It is intended to release Monocypher +// to the public domain. The BSD licence serves as a fallback option. +// +// SPDX-License-Identifier: BSD-2-Clause OR CC0-1.0 +// +// ------------------------------------------------------------------------ +// +// Copyright (c) 2017-2019, Loup Vaillant +// All rights reserved. +// +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// 1. Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the +// distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// ------------------------------------------------------------------------ +// +// Written in 2017-2019 by Loup Vaillant +// +// To the extent possible under law, the author(s) have dedicated all copyright +// and related neighboring rights to this software to the public domain +// worldwide. This software is distributed without any warranty. +// +// You should have received a copy of the CC0 Public Domain Dedication along +// with this software. If not, see +// + +#ifndef MONOCYPHER_H +#define MONOCYPHER_H + +#include +#include + +#ifdef MONOCYPHER_CPP_NAMESPACE +namespace MONOCYPHER_CPP_NAMESPACE { +#elif defined(__cplusplus) +extern "C" { +#endif + +// Constant time comparisons +// ------------------------- + +// Return 0 if a and b are equal, -1 otherwise +int crypto_verify16(const uint8_t a[16], const uint8_t b[16]); +int crypto_verify32(const uint8_t a[32], const uint8_t b[32]); +int crypto_verify64(const uint8_t a[64], const uint8_t b[64]); + + +// Erase sensitive data +// -------------------- +void crypto_wipe(void *secret, size_t size); + + +// Authenticated encryption +// ------------------------ +void crypto_aead_lock(uint8_t *cipher_text, + uint8_t mac [16], + const uint8_t key [32], + const uint8_t nonce[24], + const uint8_t *ad, size_t ad_size, + const uint8_t *plain_text, size_t text_size); +int crypto_aead_unlock(uint8_t *plain_text, + const uint8_t mac [16], + const uint8_t key [32], + const uint8_t nonce[24], + const uint8_t *ad, size_t ad_size, + const uint8_t *cipher_text, size_t text_size); + +// Authenticated stream +// -------------------- +typedef struct { + uint64_t counter; + uint8_t key[32]; + uint8_t nonce[8]; +} crypto_aead_ctx; + +void crypto_aead_init_x(crypto_aead_ctx *ctx, + const uint8_t key[32], const uint8_t nonce[24]); +void crypto_aead_init_djb(crypto_aead_ctx *ctx, + const uint8_t key[32], const uint8_t nonce[8]); +void crypto_aead_init_ietf(crypto_aead_ctx *ctx, + const uint8_t key[32], const uint8_t nonce[12]); + +void crypto_aead_write(crypto_aead_ctx *ctx, + uint8_t *cipher_text, + uint8_t mac[16], + const uint8_t *ad , size_t ad_size, + const uint8_t *plain_text, size_t text_size); +int crypto_aead_read(crypto_aead_ctx *ctx, + uint8_t *plain_text, + const uint8_t mac[16], + const uint8_t *ad , size_t ad_size, + const uint8_t *cipher_text, size_t text_size); + + +// General purpose hash (BLAKE2b) +// ------------------------------ + +// Direct interface +void crypto_blake2b(uint8_t *hash, size_t hash_size, + const uint8_t *message, size_t message_size); + +void crypto_blake2b_keyed(uint8_t *hash, size_t hash_size, + const uint8_t *key, size_t key_size, + const uint8_t *message, size_t message_size); + +// Incremental interface +typedef struct { + // Do not rely on the size or contents of this type, + // for they may change without notice. + uint64_t hash[8]; + uint64_t input_offset[2]; + uint64_t input[16]; + size_t input_idx; + size_t hash_size; +} crypto_blake2b_ctx; + +void crypto_blake2b_init(crypto_blake2b_ctx *ctx, size_t hash_size); +void crypto_blake2b_keyed_init(crypto_blake2b_ctx *ctx, size_t hash_size, + const uint8_t *key, size_t key_size); +void crypto_blake2b_update(crypto_blake2b_ctx *ctx, + const uint8_t *message, size_t message_size); +void crypto_blake2b_final(crypto_blake2b_ctx *ctx, uint8_t *hash); + + +// Password key derivation (Argon2) +// -------------------------------- +#define CRYPTO_ARGON2_D 0 +#define CRYPTO_ARGON2_I 1 +#define CRYPTO_ARGON2_ID 2 + +typedef struct { + uint32_t algorithm; // Argon2d, Argon2i, Argon2id + uint32_t nb_blocks; // memory hardness, >= 8 * nb_lanes + uint32_t nb_passes; // CPU hardness, >= 1 (>= 3 recommended for Argon2i) + uint32_t nb_lanes; // parallelism level (single threaded anyway) +} crypto_argon2_config; + +typedef struct { + const uint8_t *pass; + const uint8_t *salt; + uint32_t pass_size; + uint32_t salt_size; // 16 bytes recommended +} crypto_argon2_inputs; + +typedef struct { + const uint8_t *key; // may be NULL if no key + const uint8_t *ad; // may be NULL if no additional data + uint32_t key_size; // 0 if no key (32 bytes recommended otherwise) + uint32_t ad_size; // 0 if no additional data +} crypto_argon2_extras; + +extern const crypto_argon2_extras crypto_argon2_no_extras; + +void crypto_argon2(uint8_t *hash, uint32_t hash_size, void *work_area, + crypto_argon2_config config, + crypto_argon2_inputs inputs, + crypto_argon2_extras extras); + + +// Key exchange (X-25519) +// ---------------------- + +// Shared secrets are not quite random. +// Hash them to derive an actual shared key. +void crypto_x25519_public_key(uint8_t public_key[32], + const uint8_t secret_key[32]); +void crypto_x25519(uint8_t raw_shared_secret[32], + const uint8_t your_secret_key [32], + const uint8_t their_public_key [32]); + +// Conversion to EdDSA +void crypto_x25519_to_eddsa(uint8_t eddsa[32], const uint8_t x25519[32]); + +// scalar "division" +// Used for OPRF. Be aware that exponential blinding is less secure +// than Diffie-Hellman key exchange. +void crypto_x25519_inverse(uint8_t blind_salt [32], + const uint8_t private_key[32], + const uint8_t curve_point[32]); + +// "Dirty" versions of x25519_public_key(). +// Use with crypto_elligator_rev(). +// Leaks 3 bits of the private key. +void crypto_x25519_dirty_small(uint8_t pk[32], const uint8_t sk[32]); +void crypto_x25519_dirty_fast (uint8_t pk[32], const uint8_t sk[32]); + + +// Signatures +// ---------- + +// EdDSA with curve25519 + BLAKE2b +void crypto_eddsa_key_pair(uint8_t secret_key[64], + uint8_t public_key[32], + uint8_t seed[32]); +void crypto_eddsa_sign(uint8_t signature [64], + const uint8_t secret_key[64], + const uint8_t *message, size_t message_size); +int crypto_eddsa_check(const uint8_t signature [64], + const uint8_t public_key[32], + const uint8_t *message, size_t message_size); + +// Conversion to X25519 +void crypto_eddsa_to_x25519(uint8_t x25519[32], const uint8_t eddsa[32]); + +// EdDSA building blocks +void crypto_eddsa_trim_scalar(uint8_t out[32], const uint8_t in[32]); +void crypto_eddsa_reduce(uint8_t reduced[32], const uint8_t expanded[64]); +void crypto_eddsa_mul_add(uint8_t r[32], + const uint8_t a[32], + const uint8_t b[32], + const uint8_t c[32]); +void crypto_eddsa_scalarbase(uint8_t point[32], const uint8_t scalar[32]); +int crypto_eddsa_check_equation(const uint8_t signature[64], + const uint8_t public_key[32], + const uint8_t h_ram[32]); + + +// Chacha20 +// -------- + +// Specialised hash. +// Used to hash X25519 shared secrets. +void crypto_chacha20_h(uint8_t out[32], + const uint8_t key[32], + const uint8_t in [16]); + +// Unauthenticated stream cipher. +// Don't forget to add authentication. +uint64_t crypto_chacha20_djb(uint8_t *cipher_text, + const uint8_t *plain_text, + size_t text_size, + const uint8_t key[32], + const uint8_t nonce[8], + uint64_t ctr); +uint32_t crypto_chacha20_ietf(uint8_t *cipher_text, + const uint8_t *plain_text, + size_t text_size, + const uint8_t key[32], + const uint8_t nonce[12], + uint32_t ctr); +uint64_t crypto_chacha20_x(uint8_t *cipher_text, + const uint8_t *plain_text, + size_t text_size, + const uint8_t key[32], + const uint8_t nonce[24], + uint64_t ctr); + + +// Poly 1305 +// --------- + +// This is a *one time* authenticator. +// Disclosing the mac reveals the key. +// See crypto_lock() on how to use it properly. + +// Direct interface +void crypto_poly1305(uint8_t mac[16], + const uint8_t *message, size_t message_size, + const uint8_t key[32]); + +// Incremental interface +typedef struct { + // Do not rely on the size or contents of this type, + // for they may change without notice. + uint8_t c[16]; // chunk of the message + size_t c_idx; // How many bytes are there in the chunk. + uint32_t r [4]; // constant multiplier (from the secret key) + uint32_t pad[4]; // random number added at the end (from the secret key) + uint32_t h [5]; // accumulated hash +} crypto_poly1305_ctx; + +void crypto_poly1305_init (crypto_poly1305_ctx *ctx, const uint8_t key[32]); +void crypto_poly1305_update(crypto_poly1305_ctx *ctx, + const uint8_t *message, size_t message_size); +void crypto_poly1305_final (crypto_poly1305_ctx *ctx, uint8_t mac[16]); + + +// Elligator 2 +// ----------- + +// Elligator mappings proper +void crypto_elligator_map(uint8_t curve [32], const uint8_t hidden[32]); +int crypto_elligator_rev(uint8_t hidden[32], const uint8_t curve [32], + uint8_t tweak); + +// Easy to use key pair generation +void crypto_elligator_key_pair(uint8_t hidden[32], uint8_t secret_key[32], + uint8_t seed[32]); + +#ifdef __cplusplus +} +#endif + +#endif // MONOCYPHER_H diff --git a/libs/optional/monocypher-ed25519.c b/libs/optional/monocypher-ed25519.c new file mode 100644 index 0000000..1dbcfbb --- /dev/null +++ b/libs/optional/monocypher-ed25519.c @@ -0,0 +1,500 @@ +// Monocypher version 4.0.2 +// +// This file is dual-licensed. Choose whichever licence you want from +// the two licences listed below. +// +// The first licence is a regular 2-clause BSD licence. The second licence +// is the CC-0 from Creative Commons. It is intended to release Monocypher +// to the public domain. The BSD licence serves as a fallback option. +// +// SPDX-License-Identifier: BSD-2-Clause OR CC0-1.0 +// +// ------------------------------------------------------------------------ +// +// Copyright (c) 2017-2019, Loup Vaillant +// All rights reserved. +// +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// 1. Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the +// distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// ------------------------------------------------------------------------ +// +// Written in 2017-2019 by Loup Vaillant +// +// To the extent possible under law, the author(s) have dedicated all copyright +// and related neighboring rights to this software to the public domain +// worldwide. This software is distributed without any warranty. +// +// You should have received a copy of the CC0 Public Domain Dedication along +// with this software. If not, see +// + +#include "monocypher-ed25519.h" + +#ifdef MONOCYPHER_CPP_NAMESPACE +namespace MONOCYPHER_CPP_NAMESPACE { +#endif + +///////////////// +/// Utilities /// +///////////////// +#define FOR(i, min, max) for (size_t i = min; i < max; i++) +#define COPY(dst, src, size) FOR(_i_, 0, size) (dst)[_i_] = (src)[_i_] +#define ZERO(buf, size) FOR(_i_, 0, size) (buf)[_i_] = 0 +#define WIPE_CTX(ctx) crypto_wipe(ctx , sizeof(*(ctx))) +#define WIPE_BUFFER(buffer) crypto_wipe(buffer, sizeof(buffer)) +#define MIN(a, b) ((a) <= (b) ? (a) : (b)) +typedef uint8_t u8; +typedef uint64_t u64; + +// Returns the smallest positive integer y such that +// (x + y) % pow_2 == 0 +// Basically, it's how many bytes we need to add to "align" x. +// Only works when pow_2 is a power of 2. +// Note: we use ~x+1 instead of -x to avoid compiler warnings +static size_t align(size_t x, size_t pow_2) +{ + return (~x + 1) & (pow_2 - 1); +} + +static u64 load64_be(const u8 s[8]) +{ + return((u64)s[0] << 56) + | ((u64)s[1] << 48) + | ((u64)s[2] << 40) + | ((u64)s[3] << 32) + | ((u64)s[4] << 24) + | ((u64)s[5] << 16) + | ((u64)s[6] << 8) + | (u64)s[7]; +} + +static void store64_be(u8 out[8], u64 in) +{ + out[0] = (in >> 56) & 0xff; + out[1] = (in >> 48) & 0xff; + out[2] = (in >> 40) & 0xff; + out[3] = (in >> 32) & 0xff; + out[4] = (in >> 24) & 0xff; + out[5] = (in >> 16) & 0xff; + out[6] = (in >> 8) & 0xff; + out[7] = in & 0xff; +} + +static void load64_be_buf (u64 *dst, const u8 *src, size_t size) { + FOR(i, 0, size) { dst[i] = load64_be(src + i*8); } +} + +/////////////// +/// SHA 512 /// +/////////////// +static u64 rot(u64 x, int c ) { return (x >> c) | (x << (64 - c)); } +static u64 ch (u64 x, u64 y, u64 z) { return (x & y) ^ (~x & z); } +static u64 maj(u64 x, u64 y, u64 z) { return (x & y) ^ ( x & z) ^ (y & z); } +static u64 big_sigma0(u64 x) { return rot(x, 28) ^ rot(x, 34) ^ rot(x, 39); } +static u64 big_sigma1(u64 x) { return rot(x, 14) ^ rot(x, 18) ^ rot(x, 41); } +static u64 lit_sigma0(u64 x) { return rot(x, 1) ^ rot(x, 8) ^ (x >> 7); } +static u64 lit_sigma1(u64 x) { return rot(x, 19) ^ rot(x, 61) ^ (x >> 6); } + +static const u64 K[80] = { + 0x428a2f98d728ae22,0x7137449123ef65cd,0xb5c0fbcfec4d3b2f,0xe9b5dba58189dbbc, + 0x3956c25bf348b538,0x59f111f1b605d019,0x923f82a4af194f9b,0xab1c5ed5da6d8118, + 0xd807aa98a3030242,0x12835b0145706fbe,0x243185be4ee4b28c,0x550c7dc3d5ffb4e2, + 0x72be5d74f27b896f,0x80deb1fe3b1696b1,0x9bdc06a725c71235,0xc19bf174cf692694, + 0xe49b69c19ef14ad2,0xefbe4786384f25e3,0x0fc19dc68b8cd5b5,0x240ca1cc77ac9c65, + 0x2de92c6f592b0275,0x4a7484aa6ea6e483,0x5cb0a9dcbd41fbd4,0x76f988da831153b5, + 0x983e5152ee66dfab,0xa831c66d2db43210,0xb00327c898fb213f,0xbf597fc7beef0ee4, + 0xc6e00bf33da88fc2,0xd5a79147930aa725,0x06ca6351e003826f,0x142929670a0e6e70, + 0x27b70a8546d22ffc,0x2e1b21385c26c926,0x4d2c6dfc5ac42aed,0x53380d139d95b3df, + 0x650a73548baf63de,0x766a0abb3c77b2a8,0x81c2c92e47edaee6,0x92722c851482353b, + 0xa2bfe8a14cf10364,0xa81a664bbc423001,0xc24b8b70d0f89791,0xc76c51a30654be30, + 0xd192e819d6ef5218,0xd69906245565a910,0xf40e35855771202a,0x106aa07032bbd1b8, + 0x19a4c116b8d2d0c8,0x1e376c085141ab53,0x2748774cdf8eeb99,0x34b0bcb5e19b48a8, + 0x391c0cb3c5c95a63,0x4ed8aa4ae3418acb,0x5b9cca4f7763e373,0x682e6ff3d6b2b8a3, + 0x748f82ee5defb2fc,0x78a5636f43172f60,0x84c87814a1f0ab72,0x8cc702081a6439ec, + 0x90befffa23631e28,0xa4506cebde82bde9,0xbef9a3f7b2c67915,0xc67178f2e372532b, + 0xca273eceea26619c,0xd186b8c721c0c207,0xeada7dd6cde0eb1e,0xf57d4f7fee6ed178, + 0x06f067aa72176fba,0x0a637dc5a2c898a6,0x113f9804bef90dae,0x1b710b35131c471b, + 0x28db77f523047d84,0x32caab7b40c72493,0x3c9ebe0a15c9bebc,0x431d67c49c100d4c, + 0x4cc5d4becb3e42b6,0x597f299cfc657e2a,0x5fcb6fab3ad6faec,0x6c44198c4a475817 +}; + +static void sha512_compress(crypto_sha512_ctx *ctx) +{ + u64 a = ctx->hash[0]; u64 b = ctx->hash[1]; + u64 c = ctx->hash[2]; u64 d = ctx->hash[3]; + u64 e = ctx->hash[4]; u64 f = ctx->hash[5]; + u64 g = ctx->hash[6]; u64 h = ctx->hash[7]; + + FOR (j, 0, 16) { + u64 in = K[j] + ctx->input[j]; + u64 t1 = big_sigma1(e) + ch (e, f, g) + h + in; + u64 t2 = big_sigma0(a) + maj(a, b, c); + h = g; g = f; f = e; e = d + t1; + d = c; c = b; b = a; a = t1 + t2; + } + size_t i16 = 0; + FOR(i, 1, 5) { + i16 += 16; + FOR (j, 0, 16) { + ctx->input[j] += lit_sigma1(ctx->input[(j- 2) & 15]); + ctx->input[j] += lit_sigma0(ctx->input[(j-15) & 15]); + ctx->input[j] += ctx->input[(j- 7) & 15]; + u64 in = K[i16 + j] + ctx->input[j]; + u64 t1 = big_sigma1(e) + ch (e, f, g) + h + in; + u64 t2 = big_sigma0(a) + maj(a, b, c); + h = g; g = f; f = e; e = d + t1; + d = c; c = b; b = a; a = t1 + t2; + } + } + + ctx->hash[0] += a; ctx->hash[1] += b; + ctx->hash[2] += c; ctx->hash[3] += d; + ctx->hash[4] += e; ctx->hash[5] += f; + ctx->hash[6] += g; ctx->hash[7] += h; +} + +// Write 1 input byte +static void sha512_set_input(crypto_sha512_ctx *ctx, u8 input) +{ + size_t word = ctx->input_idx >> 3; + size_t byte = ctx->input_idx & 7; + ctx->input[word] |= (u64)input << (8 * (7 - byte)); +} + +// Increment a 128-bit "word". +static void sha512_incr(u64 x[2], u64 y) +{ + x[1] += y; + if (x[1] < y) { + x[0]++; + } +} + +void crypto_sha512_init(crypto_sha512_ctx *ctx) +{ + ctx->hash[0] = 0x6a09e667f3bcc908; + ctx->hash[1] = 0xbb67ae8584caa73b; + ctx->hash[2] = 0x3c6ef372fe94f82b; + ctx->hash[3] = 0xa54ff53a5f1d36f1; + ctx->hash[4] = 0x510e527fade682d1; + ctx->hash[5] = 0x9b05688c2b3e6c1f; + ctx->hash[6] = 0x1f83d9abfb41bd6b; + ctx->hash[7] = 0x5be0cd19137e2179; + ctx->input_size[0] = 0; + ctx->input_size[1] = 0; + ctx->input_idx = 0; + ZERO(ctx->input, 16); +} + +void crypto_sha512_update(crypto_sha512_ctx *ctx, + const u8 *message, size_t message_size) +{ + // Avoid undefined NULL pointer increments with empty messages + if (message_size == 0) { + return; + } + + // Align ourselves with word boundaries + if ((ctx->input_idx & 7) != 0) { + size_t nb_bytes = MIN(align(ctx->input_idx, 8), message_size); + FOR (i, 0, nb_bytes) { + sha512_set_input(ctx, message[i]); + ctx->input_idx++; + } + message += nb_bytes; + message_size -= nb_bytes; + } + + // Align ourselves with block boundaries + if ((ctx->input_idx & 127) != 0) { + size_t nb_words = MIN(align(ctx->input_idx, 128), message_size) >> 3; + load64_be_buf(ctx->input + (ctx->input_idx >> 3), message, nb_words); + ctx->input_idx += nb_words << 3; + message += nb_words << 3; + message_size -= nb_words << 3; + } + + // Compress block if needed + if (ctx->input_idx == 128) { + sha512_incr(ctx->input_size, 1024); // size is in bits + sha512_compress(ctx); + ctx->input_idx = 0; + ZERO(ctx->input, 16); + } + + // Process the message block by block + FOR (i, 0, message_size >> 7) { // number of blocks + load64_be_buf(ctx->input, message, 16); + sha512_incr(ctx->input_size, 1024); // size is in bits + sha512_compress(ctx); + ctx->input_idx = 0; + ZERO(ctx->input, 16); + message += 128; + } + message_size &= 127; + + if (message_size != 0) { + // Remaining words + size_t nb_words = message_size >> 3; + load64_be_buf(ctx->input, message, nb_words); + ctx->input_idx += nb_words << 3; + message += nb_words << 3; + message_size -= nb_words << 3; + + // Remaining bytes + FOR (i, 0, message_size) { + sha512_set_input(ctx, message[i]); + ctx->input_idx++; + } + } +} + +void crypto_sha512_final(crypto_sha512_ctx *ctx, u8 hash[64]) +{ + // Add padding bit + if (ctx->input_idx == 0) { + ZERO(ctx->input, 16); + } + sha512_set_input(ctx, 128); + + // Update size + sha512_incr(ctx->input_size, ctx->input_idx * 8); + + // Compress penultimate block (if any) + if (ctx->input_idx > 111) { + sha512_compress(ctx); + ZERO(ctx->input, 14); + } + // Compress last block + ctx->input[14] = ctx->input_size[0]; + ctx->input[15] = ctx->input_size[1]; + sha512_compress(ctx); + + // Copy hash to output (big endian) + FOR (i, 0, 8) { + store64_be(hash + i*8, ctx->hash[i]); + } + + WIPE_CTX(ctx); +} + +void crypto_sha512(u8 hash[64], const u8 *message, size_t message_size) +{ + crypto_sha512_ctx ctx; + crypto_sha512_init (&ctx); + crypto_sha512_update(&ctx, message, message_size); + crypto_sha512_final (&ctx, hash); +} + +//////////////////// +/// HMAC SHA 512 /// +//////////////////// +void crypto_sha512_hmac_init(crypto_sha512_hmac_ctx *ctx, + const u8 *key, size_t key_size) +{ + // hash key if it is too long + if (key_size > 128) { + crypto_sha512(ctx->key, key, key_size); + key = ctx->key; + key_size = 64; + } + // Compute inner key: padded key XOR 0x36 + FOR (i, 0, key_size) { ctx->key[i] = key[i] ^ 0x36; } + FOR (i, key_size, 128) { ctx->key[i] = 0x36; } + // Start computing inner hash + crypto_sha512_init (&ctx->ctx); + crypto_sha512_update(&ctx->ctx, ctx->key, 128); +} + +void crypto_sha512_hmac_update(crypto_sha512_hmac_ctx *ctx, + const u8 *message, size_t message_size) +{ + crypto_sha512_update(&ctx->ctx, message, message_size); +} + +void crypto_sha512_hmac_final(crypto_sha512_hmac_ctx *ctx, u8 hmac[64]) +{ + // Finish computing inner hash + crypto_sha512_final(&ctx->ctx, hmac); + // Compute outer key: padded key XOR 0x5c + FOR (i, 0, 128) { + ctx->key[i] ^= 0x36 ^ 0x5c; + } + // Compute outer hash + crypto_sha512_init (&ctx->ctx); + crypto_sha512_update(&ctx->ctx, ctx->key , 128); + crypto_sha512_update(&ctx->ctx, hmac, 64); + crypto_sha512_final (&ctx->ctx, hmac); // outer hash + WIPE_CTX(ctx); +} + +void crypto_sha512_hmac(u8 hmac[64], const u8 *key, size_t key_size, + const u8 *message, size_t message_size) +{ + crypto_sha512_hmac_ctx ctx; + crypto_sha512_hmac_init (&ctx, key, key_size); + crypto_sha512_hmac_update(&ctx, message, message_size); + crypto_sha512_hmac_final (&ctx, hmac); +} + +//////////////////// +/// HKDF SHA 512 /// +//////////////////// +void crypto_sha512_hkdf_expand(u8 *okm, size_t okm_size, + const u8 *prk, size_t prk_size, + const u8 *info, size_t info_size) +{ + int not_first = 0; + u8 ctr = 1; + u8 blk[64]; + + while (okm_size > 0) { + size_t out_size = MIN(okm_size, sizeof(blk)); + + crypto_sha512_hmac_ctx ctx; + crypto_sha512_hmac_init(&ctx, prk , prk_size); + if (not_first) { + // For some reason HKDF uses some kind of CBC mode. + // For some reason CTR mode alone wasn't enough. + // Like what, they didn't trust HMAC in 2010? Really?? + crypto_sha512_hmac_update(&ctx, blk , sizeof(blk)); + } + crypto_sha512_hmac_update(&ctx, info, info_size); + crypto_sha512_hmac_update(&ctx, &ctr, 1); + crypto_sha512_hmac_final(&ctx, blk); + + COPY(okm, blk, out_size); + + not_first = 1; + okm += out_size; + okm_size -= out_size; + ctr++; + } +} + +void crypto_sha512_hkdf(u8 *okm , size_t okm_size, + const u8 *ikm , size_t ikm_size, + const u8 *salt, size_t salt_size, + const u8 *info, size_t info_size) +{ + // Extract + u8 prk[64]; + crypto_sha512_hmac(prk, salt, salt_size, ikm, ikm_size); + + // Expand + crypto_sha512_hkdf_expand(okm, okm_size, prk, sizeof(prk), info, info_size); +} + +/////////////// +/// Ed25519 /// +/////////////// +void crypto_ed25519_key_pair(u8 secret_key[64], u8 public_key[32], u8 seed[32]) +{ + u8 a[64]; + COPY(a, seed, 32); // a[ 0..31] = seed + crypto_wipe(seed, 32); + COPY(secret_key, a, 32); // secret key = seed + crypto_sha512(a, a, 32); // a[ 0..31] = scalar + crypto_eddsa_trim_scalar(a, a); // a[ 0..31] = trimmed scalar + crypto_eddsa_scalarbase(public_key, a); // public key = [trimmed scalar]B + COPY(secret_key + 32, public_key, 32); // secret key includes public half + WIPE_BUFFER(a); +} + +static void hash_reduce(u8 h[32], + const u8 *a, size_t a_size, + const u8 *b, size_t b_size, + const u8 *c, size_t c_size, + const u8 *d, size_t d_size) +{ + u8 hash[64]; + crypto_sha512_ctx ctx; + crypto_sha512_init (&ctx); + crypto_sha512_update(&ctx, a, a_size); + crypto_sha512_update(&ctx, b, b_size); + crypto_sha512_update(&ctx, c, c_size); + crypto_sha512_update(&ctx, d, d_size); + crypto_sha512_final (&ctx, hash); + crypto_eddsa_reduce(h, hash); +} + +static void ed25519_dom_sign(u8 signature [64], const u8 secret_key[32], + const u8 *dom, size_t dom_size, + const u8 *message, size_t message_size) +{ + u8 a[64]; // secret scalar and prefix + u8 r[32]; // secret deterministic "random" nonce + u8 h[32]; // publically verifiable hash of the message (not wiped) + u8 R[32]; // first half of the signature (allows overlapping inputs) + const u8 *pk = secret_key + 32; + + crypto_sha512(a, secret_key, 32); + crypto_eddsa_trim_scalar(a, a); + hash_reduce(r, dom, dom_size, a + 32, 32, message, message_size, 0, 0); + crypto_eddsa_scalarbase(R, r); + hash_reduce(h, dom, dom_size, R, 32, pk, 32, message, message_size); + COPY(signature, R, 32); + crypto_eddsa_mul_add(signature + 32, h, a, r); + + WIPE_BUFFER(a); + WIPE_BUFFER(r); +} + +void crypto_ed25519_sign(u8 signature [64], const u8 secret_key[64], + const u8 *message, size_t message_size) +{ + ed25519_dom_sign(signature, secret_key, 0, 0, message, message_size); +} + +int crypto_ed25519_check(const u8 signature[64], const u8 public_key[32], + const u8 *msg, size_t msg_size) +{ + u8 h_ram[32]; + hash_reduce(h_ram, signature, 32, public_key, 32, msg, msg_size, 0, 0); + return crypto_eddsa_check_equation(signature, public_key, h_ram); +} + +static const u8 domain[34] = "SigEd25519 no Ed25519 collisions\1"; + +void crypto_ed25519_ph_sign(uint8_t signature[64], const uint8_t secret_key[64], + const uint8_t message_hash[64]) +{ + ed25519_dom_sign(signature, secret_key, domain, sizeof(domain), + message_hash, 64); +} + +int crypto_ed25519_ph_check(const uint8_t sig[64], const uint8_t pk[32], + const uint8_t msg_hash[64]) +{ + u8 h_ram[32]; + hash_reduce(h_ram, domain, sizeof(domain), sig, 32, pk, 32, msg_hash, 64); + return crypto_eddsa_check_equation(sig, pk, h_ram); +} + + +#ifdef MONOCYPHER_CPP_NAMESPACE +} +#endif diff --git a/libs/optional/monocypher-ed25519.h b/libs/optional/monocypher-ed25519.h new file mode 100644 index 0000000..1e6d705 --- /dev/null +++ b/libs/optional/monocypher-ed25519.h @@ -0,0 +1,140 @@ +// Monocypher version 4.0.2 +// +// This file is dual-licensed. Choose whichever licence you want from +// the two licences listed below. +// +// The first licence is a regular 2-clause BSD licence. The second licence +// is the CC-0 from Creative Commons. It is intended to release Monocypher +// to the public domain. The BSD licence serves as a fallback option. +// +// SPDX-License-Identifier: BSD-2-Clause OR CC0-1.0 +// +// ------------------------------------------------------------------------ +// +// Copyright (c) 2017-2019, Loup Vaillant +// All rights reserved. +// +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// 1. Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the +// distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// ------------------------------------------------------------------------ +// +// Written in 2017-2019 by Loup Vaillant +// +// To the extent possible under law, the author(s) have dedicated all copyright +// and related neighboring rights to this software to the public domain +// worldwide. This software is distributed without any warranty. +// +// You should have received a copy of the CC0 Public Domain Dedication along +// with this software. If not, see +// + +#ifndef ED25519_H +#define ED25519_H + +#include "monocypher.h" + +#ifdef MONOCYPHER_CPP_NAMESPACE +namespace MONOCYPHER_CPP_NAMESPACE { +#elif defined(__cplusplus) +extern "C" { +#endif + +//////////////////////// +/// Type definitions /// +//////////////////////// + +// Do not rely on the size or content on any of those types, +// they may change without notice. +typedef struct { + uint64_t hash[8]; + uint64_t input[16]; + uint64_t input_size[2]; + size_t input_idx; +} crypto_sha512_ctx; + +typedef struct { + uint8_t key[128]; + crypto_sha512_ctx ctx; +} crypto_sha512_hmac_ctx; + + +// SHA 512 +// ------- +void crypto_sha512_init (crypto_sha512_ctx *ctx); +void crypto_sha512_update(crypto_sha512_ctx *ctx, + const uint8_t *message, size_t message_size); +void crypto_sha512_final (crypto_sha512_ctx *ctx, uint8_t hash[64]); +void crypto_sha512(uint8_t hash[64], + const uint8_t *message, size_t message_size); + +// SHA 512 HMAC +// ------------ +void crypto_sha512_hmac_init(crypto_sha512_hmac_ctx *ctx, + const uint8_t *key, size_t key_size); +void crypto_sha512_hmac_update(crypto_sha512_hmac_ctx *ctx, + const uint8_t *message, size_t message_size); +void crypto_sha512_hmac_final(crypto_sha512_hmac_ctx *ctx, uint8_t hmac[64]); +void crypto_sha512_hmac(uint8_t hmac[64], + const uint8_t *key , size_t key_size, + const uint8_t *message, size_t message_size); + +// SHA 512 HKDF +// ------------ +void crypto_sha512_hkdf_expand(uint8_t *okm, size_t okm_size, + const uint8_t *prk, size_t prk_size, + const uint8_t *info, size_t info_size); +void crypto_sha512_hkdf(uint8_t *okm , size_t okm_size, + const uint8_t *ikm , size_t ikm_size, + const uint8_t *salt, size_t salt_size, + const uint8_t *info, size_t info_size); + +// Ed25519 +// ------- +// Signatures (EdDSA with curve25519 + SHA-512) +// -------------------------------------------- +void crypto_ed25519_key_pair(uint8_t secret_key[64], + uint8_t public_key[32], + uint8_t seed[32]); +void crypto_ed25519_sign(uint8_t signature [64], + const uint8_t secret_key[64], + const uint8_t *message, size_t message_size); +int crypto_ed25519_check(const uint8_t signature [64], + const uint8_t public_key[32], + const uint8_t *message, size_t message_size); + +// Pre-hash variants +void crypto_ed25519_ph_sign(uint8_t signature [64], + const uint8_t secret_key [64], + const uint8_t message_hash[64]); +int crypto_ed25519_ph_check(const uint8_t signature [64], + const uint8_t public_key [32], + const uint8_t message_hash[64]); + +#ifdef __cplusplus +} +#endif + +#endif // ED25519_H diff --git a/main.cpp b/main.cpp new file mode 100644 index 0000000..4000ef2 --- /dev/null +++ b/main.cpp @@ -0,0 +1,8 @@ +#include "cli.hpp" +#include "config.hpp" + +int main() { + AppConfig config; + runCLI(config); + return 0; +} diff --git a/sound.cpp b/sound.cpp new file mode 100644 index 0000000..c94f749 --- /dev/null +++ b/sound.cpp @@ -0,0 +1,180 @@ + +#include "sound.hpp" +#include "bfsk.hpp" +#include "config.hpp" +#include "x25519_handshake.hpp" + +#include +#include +#include +#include +#include + +#include + +extern "C" { +#include "monocypher.h" +} + +constexpr int CAPTURE_SECONDS = 3; + +static std::atomic gSoundActive{false}; +static std::thread gSoundThread; + +static std::vector gInputBuffer; +static size_t gWritePos = 0; + +static std::vector gOutputBuffer; +static size_t gReadPos = 0; + +static int audioCallback(const void *input, + void *output, + unsigned long frameCount, + const PaStreamCallbackTimeInfo* timeInfo, + PaStreamCallbackFlags statusFlags, + void *userData) +{ + (void)timeInfo; (void)statusFlags; (void)userData; + const float *in = static_cast(input); + float *out = static_cast(output); + + for (unsigned long i = 0; i < frameCount; i++) { + if (gWritePos < gInputBuffer.size()) { + gInputBuffer[gWritePos++] = in ? in[i] : 0.0f; + } + + float sample = 0.0f; + if (gReadPos < gOutputBuffer.size()) { + sample = gOutputBuffer[gReadPos++]; + } + out[i*2 + 0] = sample; + out[i*2 + 1] = sample; + } + + return paContinue; +} + +static void soundThreadFunc(AppConfig config) { + PaError err = Pa_Initialize(); + if (err != paNoError) { + std::cerr << CLR_RED "[sound] Pa_Initialize error: " << Pa_GetErrorText(err) << CLR_RESET "\n"; + return; + } + + PaStreamParameters inParams, outParams; + + inParams.device = Pa_GetDefaultInputDevice(); + if (inParams.device == paNoDevice) { + std::cerr << CLR_RED "[sound] Нет устройства ввода.\n" CLR_RESET; + Pa_Terminate(); + return; + } + inParams.channelCount = 1; + inParams.sampleFormat = paFloat32; + inParams.suggestedLatency = Pa_GetDeviceInfo(inParams.device)->defaultLowInputLatency; + inParams.hostApiSpecificStreamInfo = nullptr; + + outParams.device = Pa_GetDefaultOutputDevice(); + if (outParams.device == paNoDevice) { + std::cerr << CLR_RED "[sound] Нет устройства вывода.\n" CLR_RESET; + Pa_Terminate(); + return; + } + outParams.channelCount = 2; + outParams.sampleFormat = paFloat32; + outParams.suggestedLatency = Pa_GetDeviceInfo(outParams.device)->defaultLowOutputLatency; + outParams.hostApiSpecificStreamInfo = nullptr; + + PaStream *stream = nullptr; + + err = Pa_OpenStream(&stream, + &inParams, + &outParams, + SAMPLE_RATE, + 256, + paNoFlag, + audioCallback, + nullptr); + if (err != paNoError) { + std::cerr << CLR_RED "[sound] Pa_OpenStream error: " << Pa_GetErrorText(err) << CLR_RESET "\n"; + Pa_Terminate(); + return; + } + + err = Pa_StartStream(stream); + if (err != paNoError) { + std::cerr << CLR_RED "[sound] Pa_StartStream error: " << Pa_GetErrorText(err) << CLR_RESET "\n"; + Pa_CloseStream(stream); + Pa_Terminate(); + return; + } + + std::cout << CLR_BLUE "[sound] Старт записи/воспроизведения (3 сек)...\n" CLR_RESET; + Pa_Sleep(CAPTURE_SECONDS * 1000); + + Pa_StopStream(stream); + Pa_CloseStream(stream); + Pa_Terminate(); + + std::cout << CLR_BLUE "[sound] Остановка аудиопотока...\n" CLR_RESET; + + auto received = bfskDemodulate(gInputBuffer); + if (!received.empty()) { + if (received.size() >= 33 && received[0] == 'E') { + uint8_t otherPub[32]; + std::memcpy(otherPub, received.data() + 1, 32); + + x25519ComputeShared(config, otherPub); + std::cout << CLR_GREEN "[x25519] Общий сеансовый ключ вычислен!\n" CLR_RESET; + } else { + std::cout << CLR_YELLOW "[sound] Получены " << received.size() + << " байт, но не формат 'E' + 32 байта.\n" CLR_RESET; + } + } else { + std::cout << CLR_YELLOW "[sound] Ничего не демодулировано.\n" CLR_RESET; + } + + gSoundActive = false; +} + +void soundFind(AppConfig &config) { + if (config.soundExchangeActive) { + std::cout << CLR_YELLOW "[sound] Уже идёт процесс.\n" CLR_RESET; + return; + } + config.soundExchangeActive = true; + gSoundActive = true; + + x25519GenerateEphemeral(config); + + std::vector packet; + packet.push_back('E'); + packet.insert(packet.end(), config.ephemeralPub, config.ephemeralPub + 32); + + gOutputBuffer = bfskModulate(packet); + gReadPos = 0; + + gInputBuffer.clear(); + gInputBuffer.resize(SAMPLE_RATE * CAPTURE_SECONDS, 0.0f); + gWritePos = 0; + + gSoundThread = std::thread(soundThreadFunc, config); + + std::cout << CLR_GREEN "[sound] Отправляем свой публичный ключ X25519 и слушаем!\n" CLR_RESET; +} + +void soundLose(AppConfig &config) { + if (!config.soundExchangeActive) { + std::cout << CLR_YELLOW "[sound] Процесс не активен.\n" CLR_RESET; + return; + } + config.soundExchangeActive = false; + + if (gSoundActive) { + gSoundActive = false; + } + if (gSoundThread.joinable()) { + gSoundThread.join(); + } + std::cout << CLR_GREEN "[sound] Процесс остановлен.\n" CLR_RESET; +} diff --git a/sound.hpp b/sound.hpp new file mode 100644 index 0000000..814d9f2 --- /dev/null +++ b/sound.hpp @@ -0,0 +1,6 @@ +#pragma once +#include "config.hpp" + +void soundFind(AppConfig &config); + +void soundLose(AppConfig &config); diff --git a/webserver.cpp b/webserver.cpp new file mode 100644 index 0000000..972593e --- /dev/null +++ b/webserver.cpp @@ -0,0 +1,66 @@ +#include "webserver.hpp" +#include "config.hpp" + +#include +#include +#include + +#include "libs/httplib.h" + +static std::atomic g_serverRunning{false}; +static std::thread g_serverThread; + +static void serverThreadFunc() { + httplib::Server svr; + svr.Get("/", [](const httplib::Request&, httplib::Response &res){ + res.set_content("Hello from Cerberus BFSK!", "text/plain"); + }); + + if (!svr.listen("0.0.0.0", 8080)) { + std::cerr << CLR_RED "[web] Ошибка listen(8080). Возможно, порт занят.\n" CLR_RESET; + } + g_serverRunning = false; +} + +void webServerStart(AppConfig &config) { + if (config.webServerRunning) { + std::cout << CLR_YELLOW "[web] Сервер уже запущен.\n" CLR_RESET; + return; + } + g_serverRunning = true; + g_serverThread = std::thread(serverThreadFunc); + + config.webServerRunning = true; + std::cout << CLR_GREEN "[web] Сервер запущен на порту 8080.\n" CLR_RESET; +} + +void webServerConnect(AppConfig &config, const std::string &type, const std::string &ip) { + if (!config.webServerRunning) { + std::cout << CLR_YELLOW "[web] Сначала запустите сервер (web start)\n" CLR_RESET; + return; + } + httplib::Client cli(ip.c_str(), 8080); + if (auto res = cli.Get("/")) { + if (res->status == 200) { + std::cout << CLR_CYAN "[web] Ответ от " << ip << ": " << res->body << CLR_RESET "\n"; + } else { + std::cout << CLR_YELLOW "[web] Подключились, статус: " << res->status << CLR_RESET "\n"; + } + } else { + std::cout << CLR_RED "[web] Не удалось подключиться к " << ip << ":8080.\n" CLR_RESET; + } +} + +void webServerStop(AppConfig &config) { + if (!config.webServerRunning) { + std::cout << CLR_YELLOW "[web] Сервер не запущен.\n" CLR_RESET; + return; + } + g_serverRunning = false; + if (g_serverThread.joinable()) { + g_serverThread.detach(); + } + + config.webServerRunning = false; + std::cout << CLR_GREEN "[web] Сервер остановлен (демо).\n" CLR_RESET; +} diff --git a/webserver.hpp b/webserver.hpp new file mode 100644 index 0000000..5cedbc8 --- /dev/null +++ b/webserver.hpp @@ -0,0 +1,7 @@ +#pragma once +#include "config.hpp" +#include + +void webServerStart(AppConfig &config); +void webServerConnect(AppConfig &config, const std::string &type, const std::string &ip); +void webServerStop(AppConfig &config); diff --git a/x25519_handshake.cpp b/x25519_handshake.cpp new file mode 100644 index 0000000..bf87348 --- /dev/null +++ b/x25519_handshake.cpp @@ -0,0 +1,34 @@ +#include "x25519_handshake.hpp" +#include "config.hpp" + +extern "C" { +#include "libs/monocypher.h" +} + +#include +#include +#include + +void x25519GenerateEphemeral(AppConfig &config) { + FILE* f = fopen("/dev/urandom", "rb"); + if (!f) { + std::cerr << "[x25519] Не удалось открыть /dev/urandom\n"; + return; + } + fread(config.ephemeralSec, 1, 32, f); + fclose(f); + + crypto_x25519_public_key(config.ephemeralPub, config.ephemeralSec); + + memset(config.sharedSecret, 0, 32); + config.haveSharedSecret = false; +} + +void x25519ComputeShared(AppConfig &config, const uint8_t otherPub[32]) { + uint8_t shared[32]; + crypto_x25519(shared, config.ephemeralSec, otherPub); + memcpy(config.sharedSecret, shared, 32); + + config.haveSharedSecret = true; + std::cout << "[x25519] Получен общий сеансовый ключ (32 байта).\n"; +} diff --git a/x25519_handshake.hpp b/x25519_handshake.hpp new file mode 100644 index 0000000..35329d1 --- /dev/null +++ b/x25519_handshake.hpp @@ -0,0 +1,9 @@ +#pragma once +#include +#include + +#include "config.hpp" + +void x25519GenerateEphemeral(AppConfig &config); + +void x25519ComputeShared(AppConfig &config, const uint8_t otherPub[32]);