Files

524 lines
19 KiB
Python

# Run in dev mode using:
# fastapi dev app.py
#
# This will automatically start both the web server AND the UDP relay (server.py).
# After starting, navigate to http://localhost:8000 to see the web interface.
#
# Note: This requires the FastAPI CLI (the `fastapi` command) to be installed.
# Run it from the directory that contains `app.py`, `server.py`, `index.html`,
# and the `static/` folder.
import asyncio
from contextlib import asynccontextmanager
import socket
import subprocess
import os
import signal
from fastapi import FastAPI, Request
from fastapi.responses import FileResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
# Per-connection message queues for the /news SSE endpoint; notify_clients()
# fans every broadcast message out to each queue registered here.
clients: set[asyncio.Queue] = set()
server_process = None  # Popen handle for server.py once lifespan() starts it


# Broadcast function to notify all SSE clients
async def notify_clients(message: str):
    """Deliver ``message`` to every currently registered SSE client queue.

    Iterates over a snapshot of ``clients``: ``await queue.put(...)`` is a
    potential suspension point, during which an SSE handler may register or
    unregister its queue, and mutating a set while iterating it raises
    ``RuntimeError``.

    Args:
        message: Text to forward to each client's event stream.
    """
    for queue in list(clients):
        await queue.put(message)
async def sock_recvfrom(nonblocking_sock, *pos, loop, **kw):
    """Await one datagram from a non-blocking socket without blocking the loop.

    Tries ``recvfrom`` immediately. When nothing is queued, it registers a
    reader callback on ``loop``, waits until the socket becomes readable, and
    retries.

    Args:
        nonblocking_sock: A socket already switched to non-blocking mode.
        *pos, **kw: Forwarded unchanged to ``socket.recvfrom`` (e.g. bufsize).
        loop: The running event loop; it must support ``add_reader``
            (selector-based loops).

    Returns:
        The ``(data, address)`` tuple from ``socket.recvfrom``.
    """
    # Resolve the fd once so the reader is removed from the right descriptor
    # even if the socket is closed while we are waiting.
    fd = nonblocking_sock.fileno()
    while True:
        try:
            return nonblocking_sock.recvfrom(*pos, **kw)
        except BlockingIOError:
            pass
        # loop.create_future() is the documented way to create a future
        # bound to this loop.
        ready = loop.create_future()

        def _wake(fut=ready):
            # Readiness callbacks can fire again before this coroutine
            # resumes; never call set_result on an already-finished future.
            if not fut.done():
                fut.set_result(None)

        loop.add_reader(fd, _wake)
        try:
            await ready
        finally:
            loop.remove_reader(fd)
# Background task: UDP listener for SSE clients
async def udp_listener(host: str = "0.0.0.0", port: int = 5001):
    """Relay every UDP datagram received on ``(host, port)`` to all SSE clients.

    Runs until cancelled. Each datagram (up to 1024 bytes) is decoded as UTF-8
    and handed to ``notify_clients``.

    Args:
        host: Interface to bind; defaults to all IPv4 interfaces.
        port: UDP port to listen on; defaults to 5001.
    """
    loop = asyncio.get_running_loop()
    sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    try:
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        sock.bind((host, port))
        sock.setblocking(False)
        while True:
            data, _addr = await sock_recvfrom(sock, 1024, loop=loop)
            # errors="replace": one malformed datagram must not raise
            # UnicodeDecodeError, which would end this loop (and the task).
            message = data.decode("utf-8", errors="replace")
            await notify_clients(message)
    finally:
        # Release the UDP port when the task is cancelled (e.g. on shutdown).
        sock.close()
# Background task: Tracking data listener
async def tracking_listener():
    """Record tracking datagrams from UDP port 5002 while a session is active.

    server.py is expected to tag each datagram as ``SOURCE_IP:<ip>|<payload>``.
    While ``tracking_state["is_recording"]`` is true, payloads from the two
    configured player IPs are parsed with ``parse_tracking_data`` and appended
    to ``tracking_state["samples"]`` together with player, role and word
    metadata. Runs until cancelled; per-datagram errors are printed and skipped.
    """
    import time

    loop = asyncio.get_running_loop()
    sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    try:
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        sock.bind(("0.0.0.0", 5002))  # Different port for tracking data
        sock.setblocking(False)
        while True:
            # Larger buffer for tracking data
            data, addr = await sock_recvfrom(sock, 1024 * 16, loop=loop)
            # Only record if tracking is active
            if not tracking_state["is_recording"]:
                continue
            try:
                # One clock reading per datagram keeps timestamp, elapsed_time
                # and word_time_remaining mutually consistent.
                sample = _build_tracking_sample(data.decode(), addr[0], time.time())
                if sample is not None:
                    tracking_state["samples"].append(sample)
            except Exception as e:
                print(f"Error processing tracking data: {e}")
    finally:
        # Release UDP port 5002 when the task is cancelled.
        sock.close()


def _build_tracking_sample(message, fallback_ip, now):
    """Build one recorded sample from a decoded tracking datagram.

    Args:
        message: Decoded datagram, normally ``SOURCE_IP:<ip>|<payload>``.
        fallback_ip: Sender address, used only when the datagram is untagged.
        now: ``time.time()`` value to stamp the sample with.

    Returns:
        The sample dict, or None when the source is not one of the two
        configured players or the payload cannot be parsed.
    """
    # Extract source IP from tagged data (added by server.py)
    if message.startswith("SOURCE_IP:"):
        tag, _, actual_data = message.partition("|")
        source_ip = tag[len("SOURCE_IP:"):]
    else:
        # Fallback if data isn't tagged (shouldn't happen normally)
        source_ip = fallback_ip
        actual_data = message
    player1_ip = tracking_state["player1_ip"]
    if source_ip not in (player1_ip, tracking_state["player2_ip"]):
        return None
    parsed = parse_tracking_data(actual_data)
    if not parsed:
        return None
    # Calculate word time remaining if word is active
    word_time_remaining = 0.0
    word_start = current_word_state["startTime"]
    if word_start:
        word_time_remaining = max(0, current_word_state["timeSeconds"] - (now - word_start))
    return {
        "timestamp": now,
        # Time elapsed since experiment start
        "elapsed_time": now - tracking_state["experiment_start_time"],
        "player_id": "player1" if source_ip == player1_ip else "player2",
        # active_player_ip is the mimicker
        "role": "mimicker" if source_ip == tracking_state["active_player_ip"] else "guesser",
        "current_word": current_word_state["word"],
        "word_time_remaining": word_time_remaining,
        "data": parsed,
    }
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run the UDP relay subprocess and background listeners for the app's lifetime.

    Startup launches ``server.py`` from this file's directory and starts
    ``udp_listener`` and ``tracking_listener`` as tasks. Shutdown cancels those
    tasks, then terminates ``server.py``, killing it if it has not exited
    within 5 seconds.
    """
    import sys

    global server_process
    # Start server.py subprocess
    print("Starting UDP relay server (server.py)...")
    try:
        server_process = subprocess.Popen(
            # Use the same interpreter (and virtualenv) as this app instead of
            # whatever "python" resolves to on PATH, which may differ or be absent.
            [sys.executable, "server.py"],
            # stdout/stderr are inherited rather than PIPEd: nothing in this
            # module reads the pipes, and a full pipe buffer would block
            # server.py on its next write.
            cwd=os.path.dirname(os.path.abspath(__file__)),
        )
        print(f"UDP relay server started with PID {server_process.pid}")
    except Exception as e:
        print(f"Warning: Failed to start server.py: {e}")
        print("You may need to start server.py manually.")
    # Keep strong references: the event loop only holds weak references to
    # tasks, so an unreferenced task may be garbage-collected mid-run.
    background_tasks = [
        asyncio.create_task(udp_listener()),
        asyncio.create_task(tracking_listener()),
    ]
    try:
        yield
    finally:
        # Stop the listeners first so their pending socket readers are removed.
        for task in background_tasks:
            task.cancel()
        await asyncio.gather(*background_tasks, return_exceptions=True)
        # Cleanup: Stop server.py subprocess
        if server_process:
            print("Stopping UDP relay server...")
            server_process.terminate()
            try:
                server_process.wait(timeout=5)
                print("UDP relay server stopped")
            except subprocess.TimeoutExpired:
                print("Forcefully killing UDP relay server...")
                server_process.kill()
                server_process.wait()
# UDP destination used by /facialexpressions.
# NOTE(review): presumably the local server.py relay started in lifespan() — confirm the port.
SERVER_IP = "127.0.0.1"
SERVER_PORT = 5000

app = FastAPI(lifespan=lifespan)
# Serve CSS, JavaScript, and HTML assets under /static.
# The "static" directory is resolved against the current working directory.
app.mount("/static", StaticFiles(directory="static"), name="static")

# Unbound UDP socket shared by the handlers that send datagrams
# (/facialexpressions and /word).
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
@app.get("/")
async def read_index():
    """Serve the main web interface page (``index.html`` in the working directory)."""
    index_page = "index.html"
    return FileResponse(index_page)
@app.get("/display")
async def read_display():
    """Serve the player display page from the static folder."""
    display_page = "static/player-display.html"
    return FileResponse(display_page)
@app.post("/facialexpressions")
def read_item(weights: list[float]):
    """Forward facial-expression weights over UDP as ``w0;w1;...;wn``.

    Args:
        weights: JSON array of floats from the request body.

    Returns:
        ``{"status": "ok"}`` once the datagram has been handed to the socket.
    """
    payload = ";".join(map(str, weights))
    print(len(weights), payload)
    destination = (SERVER_IP, SERVER_PORT)
    sock.sendto(payload.encode('utf-8'), destination)
    return {"status": "ok"}
class Word(BaseModel):
    """Request body for POST /word: one charade word sent to one headset."""
    # Host/IP the CHARADE UDP message is sent to (on port 5000).
    target: str
    # Status value forwarded with the word.
    # NOTE(review): its meaning is defined by the receiving client — confirm.
    lastWordStatus: int
    # Countdown length in seconds; values <= 0 start no countdown.
    timeSeconds: float
    # Word to show; blank words are still sent but do not update the player display.
    word: str
class WordList(BaseModel):
    """Request body for POST /shuffle."""
    # Words to be returned in random order.
    words: list[str]
class VRConfig(BaseModel):
    """Request body for POST /send-config."""
    # Headset addresses; each receives the IP and MODE messages on UDP port 5000.
    player1_ip: str
    player2_ip: str
    # Sent to both headsets as "IP:<server_ip>".
    server_ip: str
    # Sent to both headsets as "MODE:<mode>".
    mode: str
class TrackingStart(BaseModel):
    """Request body for POST /tracking/start."""
    # Session labels, copied into every row of the downloaded CSV.
    group_id: str
    condition: str
    # Only tracking datagrams from these IPs are recorded (as player1 / player2).
    player1_ip: str
    player2_ip: str
    active_player_ip: str # The player who mimics (mimicker); all other samples get role "guesser"
# Shared state for the player display: the most recent non-blank word posted
# to /word and when its countdown started.
current_word_state = dict(
    word="",  # word currently shown
    timeSeconds=0.0,  # countdown length for that word, in seconds
    lastWordStatus=-1,  # status value forwarded with the last word
    startTime=None,  # time.time() when the countdown began; None if no countdown
)
# Shared state for tracking-data recording: session metadata written by
# /tracking/start and the samples appended by tracking_listener().
tracking_state = dict(
    is_recording=False,
    group_id="",
    condition="",
    player1_ip="",
    player2_ip="",
    active_player_ip="",  # The mimicker
    experiment_start_time=None,  # time.time() at /tracking/start
    samples=[],  # List of tracking data samples
)
def _parse_pose(parts, start):
    """Read position (x, y, z) then rotation (w, x, y, z) from ``parts[start:start + 7]``."""
    px, py, pz, rw, rx, ry, rz = (float(parts[start + i]) for i in range(7))
    return {
        "pos": {"x": px, "y": py, "z": pz},
        "rot": {"w": rw, "x": rx, "y": ry, "z": rz},
    }


def parse_tracking_data(data_str):
    """
    Parse tracking data from VR headset to extract camera and controller positions/rotations.

    Args:
        data_str: ``;``-separated values as produced by TrackWeights.cs.

    Returns:
        Dict with ``center_eye``, ``left_hand`` and ``right_hand`` entries,
        each ``{"pos": {x, y, z}, "rot": {w, x, y, z}}``, or None if parsing
        fails (too few values, non-numeric values, negative or non-finite
        length fields).
    """
    try:
        parts = data_str.split(';')
        # Layout from TrackWeights.cs:
        #   0-62  face expressions (63)      63-66 root orientation (4)
        #   67-69 root position (3)          70    root scale (1)
        #   71    bone rotations length N, followed by N*4 rotation values
        bone_rot_length = int(float(parts[71]))
        # A negative count would move idx backwards and, through Python's
        # negative indexing, silently read unrelated fields.
        if bone_rot_length < 0:
            raise ValueError(f"negative bone rotation count: {bone_rot_length}")
        # Skip the rotations, then IsDataValid and IsDataHighConfidence.
        idx = 72 + bone_rot_length * 4 + 2
        # idx: bone translations length M, followed by M*3 translation values
        bone_trans_length = int(float(parts[idx]))
        if bone_trans_length < 0:
            raise ValueError(f"negative bone translation count: {bone_trans_length}")
        # Skip the length field, the translations, SkeletonChangedCount,
        # then left and right eye rotations (4 values each).
        idx += 1 + bone_trans_length * 3 + 1 + 4 + 4
        # Three 7-value poses follow: center eye camera, left and right controller.
        return {
            "center_eye": _parse_pose(parts, idx),
            "left_hand": _parse_pose(parts, idx + 7),
            "right_hand": _parse_pose(parts, idx + 14),
        }
    except (IndexError, ValueError, OverflowError) as e:
        # OverflowError: int(float("inf")) on a corrupted length field.
        print(f"Error parsing tracking data: {e}")
        return None
@app.post("/word")
def read_word(word: Word):
    """Send a charade word to one headset and track it for the player display.

    The UDP message ``CHARADE:<lastWordStatus>;<timeSeconds>;<word>`` always
    goes to ``word.target`` on port 5000. Only non-blank words update
    ``current_word_state``, so the empty word sent to the "other" player does
    not clear the display.

    Returns:
        ``{"status": "ok"}``.
    """
    import time

    if word.word.strip():
        current_word_state.update(
            word=word.word,
            timeSeconds=word.timeSeconds,
            lastWordStatus=word.lastWordStatus,
            # No countdown (startTime None) when the word has no positive time limit.
            startTime=time.time() if word.timeSeconds > 0 else None,
        )
    packet = f"CHARADE:{word.lastWordStatus};{word.timeSeconds};{word.word}"
    print(packet)
    sock.sendto(packet.encode('utf-8'), (word.target, 5000))
    return {"status": "ok"}
@app.get("/current-word")
def get_current_word():
    """Return the word on the player display and its remaining countdown.

    ``timeRemaining`` is 0.0 when no countdown was started. ``isActive`` is
    true when a word is set and, if a countdown is running, time remains.
    """
    import time

    word = current_word_state["word"]
    started_at = current_word_state["startTime"]
    if started_at is not None:
        elapsed = time.time() - started_at
        remaining = max(0, current_word_state["timeSeconds"] - elapsed)
        active = remaining > 0 and bool(word)
    else:
        remaining = 0.0
        active = bool(word)
    return {
        "word": word,
        "timeRemaining": remaining,
        "lastWordStatus": current_word_state["lastWordStatus"],
        "isActive": active,
    }
@app.post("/shuffle")
def shuffle_words(word_list: WordList):
    """Return the submitted words in random order; the request list itself is not modified."""
    import random

    reordered = list(word_list.words)
    random.shuffle(reordered)
    return {"status": "ok", "shuffled_words": reordered}
@app.post("/send-config")
def send_vr_config(config: VRConfig):
    """
    Send IP and MODE configuration to VR headsets.
    This integrates the functionality from control.py into the web interface.

    Each player receives ``IP:<server_ip>`` and then ``MODE:<mode>`` on UDP
    port 5000. Returns ``{"status": "ok"}``, or
    ``{"status": "error", "message": ...}`` if building or sending a message fails.
    """
    try:
        ip_msg = f"IP:{config.server_ip}".encode('utf-8')
        mode_msg = f"MODE:{config.mode}".encode('utf-8')
        # `with` closes the command socket even when sendto() raises
        # (e.g. for an unresolvable player address).
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as cmd_sock:
            # IP configuration to both players first, then MODE configuration.
            for msg in (ip_msg, mode_msg):
                for player_ip in (config.player1_ip, config.player2_ip):
                    cmd_sock.sendto(msg, (player_ip, 5000))
        print(f"Sent IP config: {ip_msg.decode()}")
        print(f"Sent MODE config: {mode_msg.decode()}")
        print(f"To players: {config.player1_ip}, {config.player2_ip}")
        return { "status": "ok" }
    except Exception as e:
        print(f"Error sending VR config: {e}")
        return { "status": "error", "message": str(e) }
@app.post("/tracking/start")
def start_tracking(config: TrackingStart):
    """
    Start recording tracking data from both players.

    Stores the session metadata, starts the experiment clock, resets the
    sample buffer, and only then enables recording.
    """
    import time

    # FastAPI runs plain `def` handlers in a worker thread, while
    # tracking_listener reads tracking_state on the event loop. Keep recording
    # off while the session fields are replaced, and switch it on last. This
    # way the listener does not see recording enabled alongside missing or
    # stale metadata (e.g. experiment_start_time still None).
    tracking_state["is_recording"] = False
    tracking_state["group_id"] = config.group_id
    tracking_state["condition"] = config.condition
    tracking_state["player1_ip"] = config.player1_ip
    tracking_state["player2_ip"] = config.player2_ip
    tracking_state["active_player_ip"] = config.active_player_ip
    tracking_state["experiment_start_time"] = time.time()
    tracking_state["samples"] = []
    tracking_state["is_recording"] = True
    print(f"Started tracking: group={config.group_id}, condition={config.condition}")
    print(f" Player1: {config.player1_ip}, Player2: {config.player2_ip}")
    print(f" Mimicker: {config.active_player_ip}")
    return { "status": "ok", "message": "Tracking started" }
@app.post("/tracking/stop")
def stop_tracking():
    """Stop recording tracking data; samples recorded so far are kept for download."""
    tracking_state["is_recording"] = False
    recorded = len(tracking_state["samples"])
    print(f"Stopped tracking: {recorded} samples recorded")
    return {"status": "ok", "message": f"Tracking stopped. {recorded} samples recorded."}
@app.get("/tracking/status")
def get_tracking_status():
    """Report whether recording is active, how many samples exist, and the session labels."""
    return dict(
        is_recording=tracking_state["is_recording"],
        sample_count=len(tracking_state["samples"]),
        group_id=tracking_state["group_id"],
        condition=tracking_state["condition"],
    )
# Pose columns in CSV order: for each tracked body part, position x,y,z then
# rotation w,x,y,z (matching the dicts produced by parse_tracking_data).
_POSE_COLUMNS = tuple(
    (part, kind, axis)
    for part in ("center_eye", "left_hand", "right_hand")
    for kind, axes in (("pos", "xyz"), ("rot", "wxyz"))
    for axis in axes
)
_TRACKING_CSV_HEADER = [
    "timestamp", "elapsed_time", "player_id", "role", "group_id", "condition",
    "current_word", "word_time_remaining",
] + [f"{part}_{kind}_{axis}" for part, kind, axis in _POSE_COLUMNS]


def _tracking_sample_row(sample, group_id, condition):
    """Format one recorded sample as the list of CSV field strings."""
    data = sample["data"]
    return [
        str(sample["timestamp"]),
        f"{sample['elapsed_time']:.4f}",
        sample["player_id"],
        sample["role"],
        group_id,
        condition,
        sample["current_word"],
        f"{sample['word_time_remaining']:.4f}",
    ] + [f"{data[part][kind][axis]:.4f}" for part, kind, axis in _POSE_COLUMNS]


@app.get("/tracking/download")
def download_tracking_csv():
    """
    Download tracking data as CSV.

    Returns:
        A semicolon-separated ``text/csv`` attachment named
        ``<group_id>_tracking_<YYYY-mm-dd_HH-MM-SS>.csv``, or
        ``{"status": "error", ...}`` when no samples have been recorded.
    """
    import csv
    import io
    import re
    from datetime import datetime
    from fastapi.responses import Response

    # Snapshot: tracking_listener may keep appending while the file is built.
    samples = list(tracking_state["samples"])
    if not samples:
        return { "status": "error", "message": "No tracking data available" }
    group_id = tracking_state["group_id"]
    condition = tracking_state["condition"]

    output = io.StringIO()
    # csv.writer quotes fields containing ';', quotes or line breaks, so a
    # free-text word, group id or condition cannot shift columns or rows.
    writer = csv.writer(output, delimiter=";", lineterminator="\n")
    writer.writerow(_TRACKING_CSV_HEADER)
    writer.writerows(_tracking_sample_row(s, group_id, condition) for s in samples)
    csv_content = output.getvalue()
    output.close()

    # group_id is client-supplied free text; keep the download name header-safe.
    safe_group = re.sub(r"[^A-Za-z0-9._-]", "_", group_id)
    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    filename = f"{safe_group}_tracking_{timestamp}.csv"
    return Response(
        content=csv_content,
        media_type="text/csv",
        headers={"Content-Disposition": f'attachment; filename="{filename}"'},
    )
def _format_sse_update(message: str) -> str:
    """Frame ``message`` as a single Server-Sent Events ``update`` event.

    An SSE field ends at CR, LF or CRLF, so a message containing line breaks
    is sent as one ``data:`` field per line. The browser rejoins those lines
    with newlines. A single-line message yields
    ``event: update\\ndata: <message>\\n\\n``.
    """
    lines = message.replace("\r\n", "\n").replace("\r", "\n").split("\n")
    data_fields = "\n".join(f"data: {line}" for line in lines)
    return f"event: update\n{data_fields}\n\n"


# SSE endpoint
@app.get("/news")
async def sse_endpoint(request: Request):
    """Stream messages broadcast by notify_clients() as SSE ``update`` events."""
    queue = asyncio.Queue()

    async def event_generator():
        # Register only once streaming starts. An async generator that is
        # never iterated never runs its finally block, so a queue registered
        # up front could stay in `clients` forever and keep accumulating
        # every broadcast message.
        clients.add(queue)
        try:
            while True:
                if await request.is_disconnected():
                    break
                message = await queue.get()
                yield _format_sse_update(message)
        finally:
            clients.discard(queue)

    return StreamingResponse(event_generator(), media_type="text/event-stream")