1
use std::{borrow::Cow, sync::Arc};
2

            
3
use axum::{extract, extract::Extension, http::HeaderValue, response::Html, routing::get, Router};
4
use bonsaidb::{
5
    core::{async_trait::async_trait, connection::AsyncConnection},
6
    server::{CustomServer, HttpService, Peer},
7
};
8
use cfg_if::cfg_if;
9
use futures::{stream::FuturesUnordered, StreamExt};
10
use hyper::{header, server::conn::Http, Body, Request, Response, StatusCode};
11
use minority_game_shared::whole_percent;
12
use serde::{Deserialize, Serialize};
13
use tera::Tera;
14
use tower_http::{services::ServeDir, set_header::SetResponseHeaderLayer};
15

            
16
use crate::{
17
    schema::{PlayerByScore, PlayerStats},
18
    sort_players, CustomServerExt, Game,
19
};
20

            
21
// Directories served under `/static` and `/pkg` by `WebServer::webapp`. Both are relative
// paths, so they resolve against the process's working directory. The debug values point
// into `client/` — presumably because debug builds are launched from the workspace root
// (TODO confirm against the dev workflow).
#[cfg(debug_assertions)]
const STATIC_PATH: &str = "./client/static";
#[cfg(debug_assertions)]
const PKG_PATH: &str = "./client/pkg";

#[cfg(not(debug_assertions))]
const STATIC_PATH: &str = "./static";
#[cfg(not(debug_assertions))]
const PKG_PATH: &str = "./pkg";
30

            
31
#[derive(Debug, Clone)]
32
pub struct WebServer {
33
    server: CustomServer<Game>,
34
    templates: Arc<Tera>,
35
}
36

            
37
impl WebServer {
38
    pub(super) async fn new(server: CustomServer<Game>) -> Self {
39
        let mut templates = Tera::default();
40
        templates
41
            .add_raw_template("stats", &stats_template().await)
42
            .unwrap();
43
        let templates = Arc::new(templates);
44

            
45
        Self { server, templates }
46
    }
47
}
48

            
49
#[async_trait]
50
impl HttpService for WebServer {
51
    async fn handle_connection<
52
        S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
53
    >(
54
        &self,
55
        connection: S,
56
        peer: &Peer,
57
    ) -> Result<(), S> {
58
        if let Err(err) = Http::new()
59
            .serve_connection(connection, self.router(peer))
60
            .with_upgrades()
61
            .await
62
        {
63
            log::error!("[http] error serving {}: {:?}", peer.address, err);
64
        }
65

            
66
        Ok(())
67
    }
68
}
69

            
70
impl WebServer {
71
    fn webapp(&self, peer: &Peer) -> Router {
72
        Router::new()
73
            .nest(
74
                "/pkg",
75
                axum::routing::get_service(ServeDir::new(PKG_PATH)).handle_error(
76
                    |err: std::io::Error| async move {
77
                        (
78
                            StatusCode::INTERNAL_SERVER_ERROR,
79
                            format!("unhandled internal error: {}", err),
80
                        )
81
                    },
82
                ),
83
            )
84
            .nest(
85
                "/static",
86
                axum::routing::get_service(ServeDir::new(STATIC_PATH)).handle_error(
87
                    |err: std::io::Error| async move {
88
                        (
89
                            StatusCode::INTERNAL_SERVER_ERROR,
90
                            format!("unhandled internal error: {}", err),
91
                        )
92
                    },
93
                ),
94
            )
95
            .route("/ws", get(upgrade_websocket))
96
            .route("/game", axum::routing::get(spa_index))
97
            .route("/stats", axum::routing::get(stats))
98
            .route("/", axum::routing::get(index))
99
            // Attach the server and the remote address as extractable data for the /ws route
100
            .layer(Extension(self.server.clone()))
101
            .layer(Extension(peer.clone()))
102
            .layer(Extension(self.templates.clone()))
103
            .layer(SetResponseHeaderLayer::if_not_present(
104
                header::STRICT_TRANSPORT_SECURITY,
105
                HeaderValue::from_static("max-age=31536000; preload"),
106
            ))
107
    }
108

            
109
    #[cfg(debug_assertions)]
110
    fn router(&self, peer: &Peer) -> Router {
111
        self.webapp(peer)
112
    }
113

            
114
    #[cfg(not(debug_assertions))]
115
    fn router(&self, peer: &Peer) -> Router {
116
        if peer.secure {
117
            self.webapp(peer)
118
        } else {
119
            Router::new()
120
                .nest("/", axum::routing::get(redirect_to_https))
121
                .layer(Extension(self.server.clone()))
122
        }
123
    }
124
}
125

            
126
/// Redirects a plain-HTTP request to the same path and query on
/// `https://{primary_domain}`.
///
/// Only compiled into release builds, where `WebServer::router` sends insecure connections
/// here. Answers with `308 Permanent Redirect`, which asks clients to repeat the request
/// with the same method and body. The query string is kept so links such as
/// `/game?x=1` survive the upgrade.
///
/// If the resulting `Location` is not a valid header value (e.g. a misconfigured
/// primary domain), the failure is logged and a `500` is returned instead of panicking.
#[cfg(not(debug_assertions))]
async fn redirect_to_https(
    server: extract::Extension<CustomServer<Game>>,
    req: hyper::Request<Body>,
) -> hyper::Response<Body> {
    // `path_and_query` is absent only for URIs without a path component; fall back to
    // `path()` there, which is what this handler used exclusively before.
    let target = req
        .uri()
        .path_and_query()
        .map_or_else(|| req.uri().path(), |pq| pq.as_str());
    let location = format!("https://{}{}", server.primary_domain(), target);

    let mut response = hyper::Response::new(Body::empty());
    match HeaderValue::from_str(&location) {
        Ok(value) => {
            *response.status_mut() = hyper::StatusCode::PERMANENT_REDIRECT;
            response.headers_mut().insert(header::LOCATION, value);
        }
        Err(err) => {
            log::error!(
                "[http] unable to build redirect location {:?}: {}",
                location,
                err
            );
            *response.status_mut() = hyper::StatusCode::INTERNAL_SERVER_ERROR;
        }
    }
    response
}
140

            
141
async fn upgrade_websocket(
142
    server: extract::Extension<CustomServer<Game>>,
143
    peer: extract::Extension<Peer>,
144
    req: Request<Body>,
145
) -> Response<Body> {
146
    server.upgrade_websocket(peer.address, req).await
147
}
148

            
149
#[allow(clippy::unused_async)]
150
async fn index() -> Html<Cow<'static, str>> {
151
    let file_contents = {
152
        cfg_if! {
153
            if #[cfg(debug_assertions)] {
154
                Cow::Owned(tokio::fs::read_to_string("server/src/index.html")
155
                    .await
156
                    .unwrap())
157
            } else {
158
                Cow::Borrowed(include_str!("../../server/src/index.html"))
159
            }
160
        }
161
    };
162

            
163
    Html::from(file_contents)
164
}
165

            
166
#[allow(clippy::unused_async)]
167
async fn spa_index() -> Html<Cow<'static, str>> {
168
    let file_contents = {
169
        cfg_if! {
170
            if #[cfg(debug_assertions)] {
171
                Cow::Owned(tokio::fs::read_to_string("client/bootstrap.html")
172
                    .await
173
                    .unwrap())
174
            } else {
175
                Cow::Borrowed(include_str!("../../client/bootstrap.html"))
176
            }
177
        }
178
    };
179

            
180
    Html::from(file_contents)
181
}
182

            
183
async fn stats_template() -> Cow<'static, str> {
184
    cfg_if! {
185
        if #[cfg(debug_assertions)] {
186
            Cow::Owned(tokio::fs::read_to_string("server/src/stats.tera.html")
187
                .await
188
                .unwrap())
189
        } else {
190
            Cow::Borrowed(include_str!("../../server/src/stats.tera.html"))
191
        }
192
    }
193
}
194

            
195
async fn stats(
196
    server: extract::Extension<CustomServer<Game>>,
197
    templates: extract::Extension<Arc<Tera>>,
198
) -> Html<String> {
199
    let mut current_players = server
200
        .connected_clients()
201
        .await
202
        .iter()
203
        .map(|client| client.client_data())
204
        .collect::<FuturesUnordered<_>>()
205
        .filter_map(|player| async move { player.clone() })
206
        .collect::<Vec<_>>()
207
        .await;
208

            
209
    sort_players(&mut current_players);
210

            
211
    let db = server.game_database().await.unwrap();
212
    let top_players = db
213
        .view::<PlayerByScore>()
214
        .descending()
215
        .limit(10)
216
        .query()
217
        .await
218
        .unwrap();
219

            
220
    let html = templates
221
        .render(
222
            "stats",
223
            &tera::Context::from_serialize(Stats {
224
                current_players: current_players
225
                    .iter()
226
                    .enumerate()
227
                    .map(|(index, player)| {
228
                        RankedPlayer::from_player_stats(
229
                            &player.contents.stats,
230
                            player.header.id,
231
                            index,
232
                        )
233
                    })
234
                    .collect(),
235
                top_players: top_players
236
                    .into_iter()
237
                    .enumerate()
238
                    .map(|(index, map)| {
239
                        RankedPlayer::from_player_stats(
240
                            &map.value,
241
                            map.source.id.deserialize().unwrap(),
242
                            index,
243
                        )
244
                    })
245
                    .collect(),
246
            })
247
            .unwrap(),
248
        )
249
        .unwrap();
250

            
251
    Html::from(html)
252
}
253

            
254
#[derive(Serialize, Deserialize, Debug)]
255
struct RankedPlayer {
256
    id: u64,
257
    rank: u32,
258
    happiness: u32,
259
    times_went_out: u32,
260
    times_stayed_in: u32,
261
}
262

            
263
impl RankedPlayer {
264
    pub fn from_player_stats(player: &PlayerStats, id: u64, index: usize) -> Self {
265
        Self {
266
            id,
267
            rank: index as u32 + 1,
268
            happiness: whole_percent(player.happiness),
269
            times_stayed_in: player.times_stayed_in,
270
            times_went_out: player.times_went_out,
271
        }
272
    }
273
}
274

            
275
#[derive(Serialize, Deserialize, Debug)]
276
struct Stats {
277
    current_players: Vec<RankedPlayer>,
278
    top_players: Vec<RankedPlayer>,
279
}