524 lines
19 KiB
Python
524 lines
19 KiB
Python
# Run in dev mode using:
|
|
# fastapi dev app.py
|
|
#
|
|
# This will automatically start both the web server AND the UDP relay (server.py).
|
|
# After starting, navigate to http://localhost:8000 to see the web interface.
|
|
#
|
|
# Note: This requires the user to have the fastapi CLI tool installed.
|
|
# Run the command from the directory that contains `app.py`, `server.py`, `index.html`, and the `static/` folder.
|
|
|
|
import asyncio
|
|
from contextlib import asynccontextmanager
|
|
import socket
|
|
import subprocess
|
|
import os
|
|
import signal
|
|
|
|
from fastapi import FastAPI, Request
|
|
from fastapi.responses import FileResponse, StreamingResponse
|
|
from fastapi.staticfiles import StaticFiles
|
|
|
|
from pydantic import BaseModel
|
|
|
|
|
|
# One asyncio.Queue per open SSE stream on /news; notify_clients fans out to these.
clients: set[asyncio.Queue] = set()

# Handle for the server.py subprocess; assigned by lifespan() at startup.
server_process: "subprocess.Popen | None" = None
|
|
|
|
# Broadcast function to notify all SSE clients
async def notify_clients(message: str, targets=None):
    """Put *message* on every registered SSE client queue.

    Args:
        message: Text to deliver to each client queue.
        targets: Optional collection of queues to notify instead of the
            module-level ``clients`` set.

    The recipients are copied into a list before the loop. Clients connect
    and disconnect by mutating the set, and iterating the live set across
    an ``await`` could raise "Set changed size during iteration".
    """
    recipients = clients if targets is None else targets
    for queue in list(recipients):
        await queue.put(message)
|
|
|
|
async def sock_recvfrom(nonblocking_sock, *pos, loop=None, **kw):
    """Await one datagram from a non-blocking socket without blocking the loop.

    Calls ``recvfrom`` repeatedly. Whenever it reports no data
    (``BlockingIOError``), a reader callback is registered on the socket's
    file descriptor and the coroutine suspends until the loop sees it readable.

    Args:
        nonblocking_sock: Socket already put in non-blocking mode.
        *pos, **kw: Passed unchanged to ``socket.recvfrom`` (e.g. bufsize).
        loop: Event loop used for the reader; defaults to the running loop.

    Returns:
        The ``(data, address)`` tuple returned by ``socket.recvfrom``.
    """
    if loop is None:
        loop = asyncio.get_running_loop()
    fd = nonblocking_sock.fileno()
    while True:
        try:
            return nonblocking_sock.recvfrom(*pos, **kw)
        except BlockingIOError:
            pass

        # loop.create_future() is the documented way to make a loop-bound future.
        readable = loop.create_future()

        def _wake(fut=readable):
            # The future may already be done (e.g. cancelled) when the fd fires.
            if not fut.done():
                fut.set_result(None)

        loop.add_reader(fd, _wake)
        try:
            await readable
        finally:
            loop.remove_reader(fd)
|
|
|
|
# Background task: UDP listener for SSE clients
async def udp_listener(host: str = "0.0.0.0", port: int = 5001):
    """Forward every datagram received on *host*:*port* to all SSE clients.

    Loops until the task is cancelled. Payloads are decoded as UTF-8 with
    undecodable bytes replaced, so one malformed packet cannot raise
    ``UnicodeDecodeError`` and silently end this background task. The socket
    is closed when the task exits.
    """
    loop = asyncio.get_running_loop()
    listen_sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    try:
        listen_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        listen_sock.bind((host, port))
        listen_sock.setblocking(False)
        while True:
            data, _addr = await sock_recvfrom(listen_sock, 1024, loop=loop)
            await notify_clients(data.decode("utf-8", errors="replace"))
    finally:
        listen_sock.close()
|
|
|
|
def _split_source_tag(message, fallback_ip):
    """Split a ``SOURCE_IP:<ip>|<payload>`` message into ``(ip, payload)``.

    The tag is added by server.py. An untagged message is attributed to
    *fallback_ip* and returned whole. A tagged message with no ``|`` yields
    an empty payload.
    """
    prefix = "SOURCE_IP:"
    if not message.startswith(prefix):
        # Fallback if data isn't tagged (shouldn't happen normally)
        return fallback_ip, message
    tag, _sep, payload = message.partition("|")
    return tag[len(prefix):], payload


def _record_tracking_sample(data, addr, now):
    """Decode one datagram and append a sample to ``tracking_state["samples"]``.

    A sample is stored only when the source IP is one of the two configured
    players and the payload parses. The single clock reading *now* is used
    for ``timestamp``, ``elapsed_time`` and ``word_time_remaining``, so the
    three values agree with each other.
    """
    source_ip, payload = _split_source_tag(data.decode(), addr[0])

    # Record data from both players
    if source_ip not in (tracking_state["player1_ip"], tracking_state["player2_ip"]):
        return
    parsed = parse_tracking_data(payload)
    if not parsed:
        return

    player_id = "player1" if source_ip == tracking_state["player1_ip"] else "player2"
    # active_player_ip is the mimicker; the other player is guessing.
    role = "mimicker" if source_ip == tracking_state["active_player_ip"] else "guesser"

    sample = {
        "timestamp": now,
        "elapsed_time": now - tracking_state["experiment_start_time"],
        "player_id": player_id,
        "role": role,
        "current_word": current_word_state["word"],
        "word_time_remaining": 0.0,
        "data": parsed,
    }
    # Calculate word time remaining if word is active
    if current_word_state["startTime"]:
        word_elapsed = now - current_word_state["startTime"]
        sample["word_time_remaining"] = max(0, current_word_state["timeSeconds"] - word_elapsed)

    tracking_state["samples"].append(sample)


# Background task: Tracking data listener
async def tracking_listener(host: str = "0.0.0.0", port: int = 5002):
    """Record tracking samples from both players while recording is active.

    Listens on *host*:*port*, a different port from the SSE relay listener.
    Packets are dropped unless ``tracking_state["is_recording"]`` is true.
    Otherwise each packet goes to ``_record_tracking_sample``. An error while
    processing one packet is printed, and the listener keeps running. The
    socket is closed when the task exits.
    """
    import time

    loop = asyncio.get_running_loop()
    listen_sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    try:
        listen_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        listen_sock.bind((host, port))
        listen_sock.setblocking(False)
        while True:
            # Larger buffer for tracking data
            data, addr = await sock_recvfrom(listen_sock, 1024 * 16, loop=loop)
            if not tracking_state["is_recording"]:
                continue
            try:
                _record_tracking_sample(data, addr, time.time())
            except Exception as e:
                # Best effort: a bad packet must not stop the listener.
                print(f"Error processing tracking data: {e}")
    finally:
        listen_sock.close()
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run server.py and the UDP background listeners for the app's lifetime.

    At startup, launches ``server.py`` (located next to this file) and the
    ``udp_listener`` / ``tracking_listener`` tasks. At shutdown, cancels the
    tasks so their sockets are released. It then terminates server.py,
    killing it if it has not exited within 5 seconds. If server.py fails to
    launch, a warning is printed and the web app still starts.
    """
    global server_process
    import sys

    app_dir = os.path.dirname(os.path.abspath(__file__))

    # Start server.py subprocess
    print("Starting UDP relay server (server.py)...")
    try:
        # - sys.executable runs the child under the same interpreter/venv as
        #   this app; a bare "python" may resolve to a different one, or none.
        # - stdout/stderr are inherited rather than piped: nothing ever read
        #   the pipes, so once the OS pipe buffer filled, the child would
        #   block on its next write.
        server_process = subprocess.Popen(
            [sys.executable, os.path.join(app_dir, "server.py")],
            cwd=app_dir,
        )
        print(f"UDP relay server started with PID {server_process.pid}")
    except Exception as e:
        print(f"Warning: Failed to start server.py: {e}")
        print("You may need to start server.py manually.")

    # Start background tasks. The event loop keeps only weak references to
    # tasks, so hold them here to stop them being garbage-collected.
    background_tasks = [
        asyncio.create_task(udp_listener()),
        asyncio.create_task(tracking_listener()),
    ]

    try:
        yield
    finally:
        for task in background_tasks:
            task.cancel()
        await asyncio.gather(*background_tasks, return_exceptions=True)

        # Cleanup: Stop server.py subprocess
        if server_process:
            print("Stopping UDP relay server...")
            server_process.terminate()
            try:
                server_process.wait(timeout=5)
                print("UDP relay server stopped")
            except subprocess.TimeoutExpired:
                print("Forcefully killing UDP relay server...")
                server_process.kill()
                server_process.wait()
|
|
|
|
app = FastAPI(lifespan=lifespan)

# Resolve bundled assets relative to this file, not the current working
# directory, so the app still starts when launched from another folder.
_STATIC_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "static")

# Mount static files for serving CSS, JavaScript, and HTML assets
app.mount("/static", StaticFiles(directory=_STATIC_DIR), name="static")

# Shared UDP socket the HTTP handlers use for sending only. It is left
# unbound on purpose (the OS assigns a source port on first sendto);
# tracking_listener already binds port 5002.
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)

# Destination for /facialexpressions frames (presumably the local server.py
# relay; confirm against server.py).
SERVER_IP = "127.0.0.1"
SERVER_PORT = 5000
|
|
|
|
@app.get("/")
async def read_index():
    """Serve the main web interface (index.html located next to this file).

    The path is resolved from this module's location rather than the current
    working directory, so the page loads wherever the server was started from.
    """
    index_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "index.html")
    return FileResponse(index_path)
|
|
|
|
@app.get("/display")
async def read_display():
    """Serve the player display page from the ``static`` folder next to this file.

    The path is resolved from this module's location rather than the current
    working directory.
    """
    display_path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "static", "player-display.html"
    )
    return FileResponse(display_path)
|
|
|
|
@app.post("/facialexpressions")
def read_item(weights: list[float]):
    """Forward one frame of facial-expression weights over UDP.

    The weights are joined into a semicolon-separated string, printed along
    with their count, and sent as one datagram to SERVER_IP:SERVER_PORT.
    """
    payload = ";".join(map(str, weights))
    print(len(weights), payload)
    destination = (SERVER_IP, SERVER_PORT)
    sock.sendto(payload.encode("utf-8"), destination)
    return {"status": "ok"}
|
|
|
|
class Word(BaseModel):
    # Request body for POST /word.
    # Host that read_word sends the CHARADE datagram to (UDP port 5000).
    target: str
    # Forwarded verbatim in the CHARADE message. NOTE(review): the meaning of
    # the codes is defined by the receiving client; current_word_state starts at -1.
    lastWordStatus: int
    # Countdown length in seconds; a value <= 0 starts no timer.
    timeSeconds: float
    # Word to show; blank words do not overwrite the player display state.
    word: str
|
|
|
|
class WordList(BaseModel):
    # Request body for POST /shuffle: the words to return in random order.
    words: list[str]
|
|
|
|
class VRConfig(BaseModel):
    # Request body for POST /send-config.
    # Headset addresses; each receives the IP and MODE messages on UDP port 5000.
    player1_ip: str
    player2_ip: str
    # Sent to both headsets as "IP:<server_ip>".
    server_ip: str
    # Sent to both headsets as "MODE:<mode>".
    mode: str
|
|
|
|
class TrackingStart(BaseModel):
    # Request body for POST /tracking/start.
    # Written to every CSV row and used as the download filename prefix.
    group_id: str
    # Condition label, written to every CSV row.
    condition: str
    # Source IPs whose tracking packets are recorded, as "player1" / "player2".
    player1_ip: str
    player2_ip: str
    active_player_ip: str  # The player who mimics (mimicker)
|
|
|
|
# Global state for current word display: the last non-blank word posted to
# /word, its countdown length, its status code, and when the countdown began.
current_word_state: dict = dict(
    word="",
    timeSeconds=0.0,
    lastWordStatus=-1,
    startTime=None,  # time.time() when a countdown started, otherwise None
)
|
|
|
|
# Global state for tracking data recording: session metadata set by
# /tracking/start plus the buffer of samples appended by tracking_listener.
tracking_state: dict = dict(
    is_recording=False,
    group_id="",
    condition="",
    player1_ip="",
    player2_ip="",
    active_player_ip="",  # The mimicker
    experiment_start_time=None,
    samples=[],  # List of tracking data samples
)
|
|
|
|
def _read_pose(parts, idx):
    """Return ``{"pos": {...}, "rot": {...}}`` from the 7 fields starting at *idx*.

    Field order is position x, y, z followed by rotation w, x, y, z.
    Raises IndexError if fewer than 7 fields remain, and ValueError if a
    field is not a number.
    """
    x, y, z, rw, rx, ry, rz = [float(parts[idx + k]) for k in range(7)]
    return {
        "pos": {"x": x, "y": y, "z": z},
        "rot": {"w": rw, "x": rx, "y": ry, "z": rz},
    }


def _read_count(parts, idx, what):
    """Read a length field (sent as a float string) and reject negative values.

    A negative length would move the cursor backwards, and Python's negative
    indexing would then silently read unrelated fields.
    """
    count = int(float(parts[idx]))
    if count < 0:
        raise ValueError(f"negative {what} count {count} at field {idx}")
    return count


def parse_tracking_data(data_str):
    """
    Parse tracking data from VR headset to extract camera and controller positions/rotations.
    Returns dict with center_eye, left_hand, right_hand data or None if parsing fails.

    Each entry has the shape ``{"pos": {"x", "y", "z"}, "rot": {"w", "x", "y", "z"}}``.
    Malformed input (too few fields, non-numeric values, negative or
    non-finite length fields) is printed and yields None.
    """
    try:
        parts = data_str.split(';')

        # The data structure from TrackWeights.cs:
        # 0-62: Face expressions (63 values)
        # 63-66: Root orientation (4 values)
        # 67-69: Root position (3 values)
        # 70: Root scale (1 value)
        # 71: Bone rotations length (1 value)
        bone_rot_length = _read_count(parts, 71, "bone rotation")
        # 72 to 72+(bone_rot_length*4)-1: Bone rotations (4 values each)
        idx = 72 + bone_rot_length * 4
        # IsDataValid, IsDataHighConfidence
        idx += 2
        # Bone translations length, then the translations (3 values each)
        bone_trans_length = _read_count(parts, idx, "bone translation")
        idx += 1 + bone_trans_length * 3
        # SkeletonChangedCount
        idx += 1
        # Left eye rotation (4 values), right eye rotation (4 values)
        idx += 8

        # Now at the data we want: three poses of 7 values each, in order
        # center eye camera, left hand controller, right hand controller.
        return {
            "center_eye": _read_pose(parts, idx),
            "left_hand": _read_pose(parts, idx + 7),
            "right_hand": _read_pose(parts, idx + 14),
        }
    except (IndexError, ValueError, OverflowError) as e:
        # OverflowError: int(float("inf")) in a length field.
        print(f"Error parsing tracking data: {e}")
        return None
|
|
|
|
@app.post("/word")
def read_word(word: Word):
    """Store the current charade word and send it to the target headset.

    A non-blank word updates ``current_word_state``. Blank words, which are
    sent to the "other" player, leave the display state alone. The countdown
    start time is recorded only when ``timeSeconds`` is positive. The word is
    then sent as ``CHARADE:<status>;<seconds>;<word>`` to ``word.target`` on
    UDP port 5000.
    """
    import time

    text = word.word
    if text and text.strip():
        current_word_state.update(
            word=text,
            timeSeconds=word.timeSeconds,
            lastWordStatus=word.lastWordStatus,
            startTime=time.time() if word.timeSeconds > 0 else None,
        )

    payload = "CHARADE:{};{};{}".format(word.lastWordStatus, word.timeSeconds, text)
    print(payload)
    sock.sendto(payload.encode("utf-8"), (word.target, 5000))
    return {"status": "ok"}
|
|
|
|
@app.get("/current-word")
def get_current_word():
    """Report the displayed word and how much of its countdown is left.

    With no running countdown (``startTime`` is None), ``timeRemaining`` is
    0.0 and the word counts as active whenever it is non-empty. Otherwise the
    word is active only while time remains.
    """
    import time

    state = current_word_state
    started_at = state["startTime"]
    if started_at is not None:
        remaining = max(0, state["timeSeconds"] - (time.time() - started_at))
        active = remaining > 0 and bool(state["word"])
    else:
        remaining = 0.0
        active = bool(state["word"])

    return {
        "word": state["word"],
        "timeRemaining": remaining,
        "lastWordStatus": state["lastWordStatus"],
        "isActive": active,
    }
|
|
|
|
@app.post("/shuffle")
def shuffle_words(word_list: WordList):
    """Return the submitted words in random order; the request list is not mutated."""
    import random

    order = list(word_list.words)
    random.shuffle(order)
    return {"status": "ok", "shuffled_words": order}
|
|
|
|
@app.post("/send-config")
def send_vr_config(config: VRConfig):
    """
    Send IP and MODE configuration to VR headsets.
    This integrates the functionality from control.py into the web interface.

    Both headsets receive ``IP:<server_ip>``, then ``MODE:<mode>``, on UDP
    port 5000. Failures are reported in the JSON body
    (``{"status": "error", "message": ...}``) rather than as an HTTP error.
    """
    try:
        ip_msg = f"IP:{config.server_ip}".encode('utf-8')
        mode_msg = f"MODE:{config.mode}".encode('utf-8')
        players = (config.player1_ip, config.player2_ip)

        # The `with` block closes the socket even when sendto raises (e.g.
        # socket.gaierror for an unresolvable host), so it cannot leak.
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as cmd_sock:
            # Send IP configuration to both players, then MODE configuration.
            for msg in (ip_msg, mode_msg):
                for player_ip in players:
                    cmd_sock.sendto(msg, (player_ip, 5000))

        print(f"Sent IP config: {ip_msg.decode()}")
        print(f"Sent MODE config: {mode_msg.decode()}")
        print(f"To players: {config.player1_ip}, {config.player2_ip}")

        return { "status": "ok" }
    except Exception as e:
        print(f"Error sending VR config: {e}")
        return { "status": "error", "message": str(e) }
|
|
|
|
@app.post("/tracking/start")
def start_tracking(config: TrackingStart):
    """
    Start recording tracking data from both players.

    Clears the sample buffer and stores the session metadata. FastAPI runs
    this sync handler in a worker thread while tracking_listener reads
    ``tracking_state`` on the event loop. Recording is therefore paused first
    and switched on only after every other field is set. Otherwise a packet
    could be recorded against the previous session's IPs, or against a
    ``None`` start time.
    """
    import time

    tracking_state["is_recording"] = False
    tracking_state["group_id"] = config.group_id
    tracking_state["condition"] = config.condition
    tracking_state["player1_ip"] = config.player1_ip
    tracking_state["player2_ip"] = config.player2_ip
    tracking_state["active_player_ip"] = config.active_player_ip
    tracking_state["experiment_start_time"] = time.time()
    tracking_state["samples"] = []
    tracking_state["is_recording"] = True

    print(f"Started tracking: group={config.group_id}, condition={config.condition}")
    print(f" Player1: {config.player1_ip}, Player2: {config.player2_ip}")
    print(f" Mimicker: {config.active_player_ip}")
    return { "status": "ok", "message": "Tracking started" }
|
|
|
|
@app.post("/tracking/stop")
def stop_tracking():
    """
    Stop recording tracking data.

    Recorded samples are kept, so they can still be downloaded; the response
    reports how many were collected.
    """
    tracking_state["is_recording"] = False
    summary = f"{len(tracking_state['samples'])} samples recorded"
    print(f"Stopped tracking: {summary}")
    return {"status": "ok", "message": f"Tracking stopped. {summary}."}
|
|
|
|
@app.get("/tracking/status")
def get_tracking_status():
    """
    Get current tracking status: whether recording is on, how many samples
    are buffered, and the current group/condition labels.
    """
    state = tracking_state
    return dict(
        is_recording=state["is_recording"],
        sample_count=len(state["samples"]),
        group_id=state["group_id"],
        condition=state["condition"],
    )
|
|
|
|
# Column order of the tracking CSV export; must match _tracking_csv_row.
_TRACKING_CSV_HEADER = [
    "timestamp", "elapsed_time", "player_id", "role", "group_id", "condition", "current_word", "word_time_remaining",
    "center_eye_pos_x", "center_eye_pos_y", "center_eye_pos_z",
    "center_eye_rot_w", "center_eye_rot_x", "center_eye_rot_y", "center_eye_rot_z",
    "left_hand_pos_x", "left_hand_pos_y", "left_hand_pos_z",
    "left_hand_rot_w", "left_hand_rot_x", "left_hand_rot_y", "left_hand_rot_z",
    "right_hand_pos_x", "right_hand_pos_y", "right_hand_pos_z",
    "right_hand_rot_w", "right_hand_rot_x", "right_hand_rot_y", "right_hand_rot_z",
]


def _pose_columns(pose):
    """Format one pose as 7 strings (4 decimals): pos x, y, z then rot w, x, y, z."""
    position = [f"{pose['pos'][axis]:.4f}" for axis in "xyz"]
    rotation = [f"{pose['rot'][axis]:.4f}" for axis in "wxyz"]
    return position + rotation


def _tracking_csv_row(sample, group_id, condition):
    """Build the CSV fields for one recorded sample, in _TRACKING_CSV_HEADER order."""
    data = sample["data"]
    return [
        str(sample["timestamp"]),
        f"{sample['elapsed_time']:.4f}",
        sample["player_id"],
        sample["role"],
        group_id,
        condition,
        sample["current_word"],
        f"{sample['word_time_remaining']:.4f}",
        *_pose_columns(data["center_eye"]),
        *_pose_columns(data["left_hand"]),
        *_pose_columns(data["right_hand"]),
    ]


@app.get("/tracking/download")
def download_tracking_csv():
    """
    Download tracking data as CSV.

    Returns a ``;``-separated CSV attachment with one row per recorded sample
    (group id and condition are repeated on every row). If nothing has been
    recorded, returns a JSON error body instead.
    """
    import csv
    import io
    import re
    from datetime import datetime
    from fastapi.responses import Response

    # Snapshot: tracking_listener may keep appending while this runs.
    samples = list(tracking_state["samples"])
    if len(samples) == 0:
        return { "status": "error", "message": "No tracking data available" }

    group_id = tracking_state["group_id"]
    condition = tracking_state["condition"]

    output = io.StringIO()
    # csv.writer quotes fields containing ';', '"' or line breaks (e.g. a
    # typed word or group id), which a plain ";".join would corrupt.
    # Ordinary fields are written unchanged.
    writer = csv.writer(output, delimiter=";", lineterminator="\n")
    writer.writerow(_TRACKING_CSV_HEADER)
    writer.writerows(_tracking_csv_row(s, group_id, condition) for s in samples)
    csv_content = output.getvalue()
    output.close()

    # Create filename with timestamp. Restrict the group id to header-safe
    # characters: spaces, quotes, ';' or non-Latin-1 text would corrupt the
    # Content-Disposition header or fail to encode.
    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    safe_group = re.sub(r"[^A-Za-z0-9._-]", "_", group_id)
    filename = f"{safe_group}_tracking_{timestamp}.csv"

    return Response(
        content=csv_content,
        media_type="text/csv",
        headers={"Content-Disposition": f'attachment; filename="{filename}"'},
    )
|
|
|
|
|
|
def _format_sse(message, event="update"):
    """Encode *message* as one Server-Sent Event of type *event*.

    Each line of the message gets its own ``data:`` field. Embedded line
    breaks therefore cannot end the event early or be misread as other
    fields, and the browser rejoins the lines with "\\n". A single-line
    message produces ``event: <event>\\ndata: <message>\\n\\n``.
    """
    lines = message.replace("\r\n", "\n").replace("\r", "\n").split("\n")
    data = "".join(f"data: {line}\n" for line in lines)
    return f"event: {event}\n{data}\n"


# SSE endpoint
@app.get("/news")
async def sse_endpoint(request: Request):
    """Stream UDP relay messages to the browser as Server-Sent Events.

    The client's queue is registered inside the generator, so the ``finally``
    that unregisters it always covers it. If it were registered before the
    response started streaming, a stream that never began would leave its
    queue in ``clients`` forever, collecting every broadcast.
    """
    queue = asyncio.Queue()

    async def event_generator():
        clients.add(queue)
        try:
            while True:
                if await request.is_disconnected():
                    break
                message = await queue.get()
                yield _format_sse(message)
        finally:
            clients.discard(queue)

    return StreamingResponse(event_generator(), media_type="text/event-stream")
|